add rl_config.device param to support use apple chip GPU

This commit is contained in:
zhangzhichao
2022-12-23 19:51:33 +08:00
parent 3012c55ec5
commit fed9220b55
3 changed files with 23 additions and 15 deletions

View File

@@ -58,10 +58,16 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
net_arch=self.net_arch)
if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
kwargs = self.freqai_info.get('model_training_parameters', {})
# set device if device is not None
if self.device != '':
kwargs['device'] = self.device
model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,
tensorboard_log=Path(
dk.full_path / "tensorboard" / dk.pair.split('/')[0]),
**self.freqai_info.get('model_training_parameters', {})
**kwargs
)
else:
logger.info('Continual training activated - starting training from previously '