add rl_config.device param to support Apple Silicon GPUs
@@ -44,6 +44,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         self.max_threads = min(self.freqai_info['rl_config'].get(
             'cpu_count', 1), max(int(self.max_system_threads / 2), 1))
         th.set_num_threads(self.max_threads)
+        self.device = self.freqai_info['rl_config'].get('device', '')
         self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
         self.train_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env()
         self.eval_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env()
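For context, here is a minimal sketch of how the new option would be set and read. Only the `rl_config` lookup itself comes from the diff; the surrounding dict and the `"mps"` value (PyTorch's Apple-GPU backend) are assumed examples, not part of the commit.

```python
# Minimal sketch, assuming a freqtrade-style config dict.
# Concrete values here are illustrative, not from the commit.
freqai_info = {
    "rl_config": {
        "device": "mps",   # assumed example: PyTorch's Apple-GPU (MPS) backend
        "cpu_count": 4,
    }
}

# Same lookup the patched BaseReinforcementLearningModel performs:
device = freqai_info["rl_config"].get("device", "")
print(device or "unset: the model will fall back to its default device")
```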
@@ -58,10 +58,16 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
                                         net_arch=self.net_arch)
 
         if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
+            kwargs = self.freqai_info.get('model_training_parameters', {})
+
+            # set device only if it was configured (non-empty string)
+            if self.device != '':
+                kwargs['device'] = self.device
+
             model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,
                                     tensorboard_log=Path(
                                         dk.full_path / "tensorboard" / dk.pair.split('/')[0]),
-                                    **self.freqai_info.get('model_training_parameters', {})
+                                    **kwargs
                                     )
         else:
             logger.info('Continual training activated - starting training from previously '
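To see where the `device` kwarg ends up, here is a hedged stand-alone sketch of the patched branch: `PPO` stands in for `self.MODELCLASS`, `"CartPole-v1"` stands in for `self.train_env`, and the MPS availability probe is an added assumption about how one might pick the device on Apple hardware. Stable-baselines3 model constructors do accept a `device` keyword.

```python
# Stand-alone sketch of the patched fit() logic. PPO stands in for
# self.MODELCLASS; "CartPole-v1" stands in for self.train_env.
# The MPS probe is an added assumption, not part of the commit.
import torch as th
from stable_baselines3 import PPO

device = ""  # as read from rl_config.get('device', '')
if device == "" and th.backends.mps.is_available():
    device = "mps"  # assumed fallback to the Apple GPU when present

kwargs = {"learning_rate": 3e-4}  # stand-in for model_training_parameters
if device != "":
    kwargs["device"] = device  # mirrors the patched branch

model = PPO("MlpPolicy", "CartPole-v1", **kwargs)
model.learn(total_timesteps=1_000)
```

Collecting the constructor arguments into `kwargs` first keeps the device override optional without duplicating the constructor call; stacking a second `**` expansion onto the original one, as the raw diff juxtaposition might suggest, would not be valid, which is why the old `**self.freqai_info.get(...)` line is replaced by `**kwargs`.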