add rl_config.device param to support use apple chip GPU

This commit is contained in:
zhangzhichao 2022-12-23 19:51:33 +08:00
parent 3012c55ec5
commit fed9220b55
3 changed files with 23 additions and 15 deletions

View File

@ -69,12 +69,13 @@ Mandatory parameters are marked as **Required** and have to be set in one of the
### Reinforcement Learning parameters ### Reinforcement Learning parameters
| Parameter | Description | | Parameter | Description |
|------------|-------------| |-------------------------------|-------------|
| | **Reinforcement Learning Parameters within the `freqai.rl_config` sub dictionary** | | **Reinforcement Learning Parameters within the `freqai.rl_config` sub dictionary**
| `rl_config` | A dictionary containing the control parameters for a Reinforcement Learning model. <br> **Datatype:** Dictionary. | `rl_config` | A dictionary containing the control parameters for a Reinforcement Learning model. <br> **Datatype:** Dictionary.
| `device` | Specify where to run. (cpu,mps,cuda) For example, you can specify 'mps' to use the GPU on an apple chip <br> **Datatype:** string.
| `train_cycles` | Training time steps will be set based on the `train_cycles * number of training data points. <br> **Datatype:** Integer. | `train_cycles` | Training time steps will be set based on the `train_cycles * number of training data points. <br> **Datatype:** Integer.
| `cpu_count` | Number of processors to dedicate to the Reinforcement Learning training process. <br> **Datatype:** int. | `cpu_count` | Number of processors to dedicate to the Reinforcement Learning training process. <br> **Datatype:** int.
| `max_trade_duration_candles`| Guides the agent training to keep trades below desired length. Example usage shown in `prediction_models/ReinforcementLearner.py` within the customizable `calculate_reward()` function. <br> **Datatype:** int. | `max_trade_duration_candles` | Guides the agent training to keep trades below desired length. Example usage shown in `prediction_models/ReinforcementLearner.py` within the customizable `calculate_reward()` function. <br> **Datatype:** int.
| `model_type` | Model string from stable_baselines3 or SBcontrib. Available strings include: `'TRPO', 'ARS', 'RecurrentPPO', 'MaskablePPO', 'PPO', 'A2C', 'DQN'`. User should ensure that `model_training_parameters` match those available to the corresponding stable_baselines3 model by visiting their documentaiton. [PPO doc](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (external website) <br> **Datatype:** string. | `model_type` | Model string from stable_baselines3 or SBcontrib. Available strings include: `'TRPO', 'ARS', 'RecurrentPPO', 'MaskablePPO', 'PPO', 'A2C', 'DQN'`. User should ensure that `model_training_parameters` match those available to the corresponding stable_baselines3 model by visiting their documentaiton. [PPO doc](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (external website) <br> **Datatype:** string.
| `policy_type` | One of the available policy types from stable_baselines3 <br> **Datatype:** string. | `policy_type` | One of the available policy types from stable_baselines3 <br> **Datatype:** string.
| `max_training_drawdown_pct` | The maximum drawdown that the agent is allowed to experience during training. <br> **Datatype:** float. <br> Default: 0.8 | `max_training_drawdown_pct` | The maximum drawdown that the agent is allowed to experience during training. <br> **Datatype:** float. <br> Default: 0.8

View File

@ -44,6 +44,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
self.max_threads = min(self.freqai_info['rl_config'].get( self.max_threads = min(self.freqai_info['rl_config'].get(
'cpu_count', 1), max(int(self.max_system_threads / 2), 1)) 'cpu_count', 1), max(int(self.max_system_threads / 2), 1))
th.set_num_threads(self.max_threads) th.set_num_threads(self.max_threads)
self.device = self.freqai_info['rl_config'].get('device', '')
self.reward_params = self.freqai_info['rl_config']['model_reward_parameters'] self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
self.train_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env() self.train_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env()
self.eval_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env() self.eval_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env()

View File

@ -58,10 +58,16 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
net_arch=self.net_arch) net_arch=self.net_arch)
if dk.pair not in self.dd.model_dictionary or not self.continual_learning: if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
kwargs = self.freqai_info.get('model_training_parameters', {})
# set device if device is not None
if self.device != '':
kwargs['device'] = self.device
model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs, model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,
tensorboard_log=Path( tensorboard_log=Path(
dk.full_path / "tensorboard" / dk.pair.split('/')[0]), dk.full_path / "tensorboard" / dk.pair.split('/')[0]),
**self.freqai_info.get('model_training_parameters', {}) **kwargs
) )
else: else:
logger.info('Continual training activated - starting training from previously ' logger.info('Continual training activated - starting training from previously '