add rl_config.device param to support use apple chip GPU

This commit is contained in:
zhangzhichao
2022-12-23 19:51:33 +08:00
parent 3012c55ec5
commit fed9220b55
3 changed files with 23 additions and 15 deletions

View File

@@ -44,6 +44,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
self.max_threads = min(self.freqai_info['rl_config'].get(
'cpu_count', 1), max(int(self.max_system_threads / 2), 1))
th.set_num_threads(self.max_threads)
self.device = self.freqai_info['rl_config'].get('device', '')
self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
self.train_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env()
self.eval_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env()