diff --git a/docs/freqai-parameter-table.md b/docs/freqai-parameter-table.md
index d05ce80f3..89cad0d7e 100644
--- a/docs/freqai-parameter-table.md
+++ b/docs/freqai-parameter-table.md
@@ -68,20 +68,21 @@ Mandatory parameters are marked as **Required** and have to be set in one of the
 ### Reinforcement Learning parameters
-| Parameter | Description |
-|------------|-------------|
-| | **Reinforcement Learning Parameters within the `freqai.rl_config` sub dictionary**
-| `rl_config` | A dictionary containing the control parameters for a Reinforcement Learning model.<br>
**Datatype:** Dictionary. -| `train_cycles` | Training time steps will be set based on the `train_cycles * number of training data points.
**Datatype:** Integer. -| `cpu_count` | Number of processors to dedicate to the Reinforcement Learning training process.
**Datatype:** int. -| `max_trade_duration_candles`| Guides the agent training to keep trades below desired length. Example usage shown in `prediction_models/ReinforcementLearner.py` within the customizable `calculate_reward()` function.
**Datatype:** int. -| `model_type` | Model string from stable_baselines3 or SBcontrib. Available strings include: `'TRPO', 'ARS', 'RecurrentPPO', 'MaskablePPO', 'PPO', 'A2C', 'DQN'`. User should ensure that `model_training_parameters` match those available to the corresponding stable_baselines3 model by visiting their documentaiton. [PPO doc](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (external website)
**Datatype:** string. -| `policy_type` | One of the available policy types from stable_baselines3
**Datatype:** string. -| `max_training_drawdown_pct` | The maximum drawdown that the agent is allowed to experience during training.
**Datatype:** float.
Default: 0.8 -| `cpu_count` | Number of threads/cpus to dedicate to the Reinforcement Learning training process (depending on if `ReinforcementLearning_multiproc` is selected or not). Recommended to leave this untouched, by default, this value is set to the total number of physical cores minus 1.
**Datatype:** int. -| `model_reward_parameters` | Parameters used inside the customizable `calculate_reward()` function in `ReinforcementLearner.py`
**Datatype:** int. -| `add_state_info` | Tell FreqAI to include state information in the feature set for training and inferencing. The current state variables include trade duration, current profit, trade position. This is only available in dry/live runs, and is automatically switched to false for backtesting.
**Datatype:** bool.
Default: `False`.
-| `net_arch` | Network architecture which is well described in [`stable_baselines3` doc](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html#examples). In summary: `[, dict(vf=[], pi=[])]`. By default this is set to `[128, 128]`, which defines 2 shared hidden layers with 128 units each.
+| Parameter | Description |
+|-------------------------------|-------------|
+| | **Reinforcement Learning Parameters within the `freqai.rl_config` sub dictionary**
+| `rl_config` | A dictionary containing the control parameters for a Reinforcement Learning model.<br>
**Datatype:** Dictionary. +| `device` | The device to run the model training on: `cpu`, `cuda`, or `mps`. For example, set `mps` to use the GPU on Apple Silicon. If left unset, stable_baselines3 selects the device automatically.<br>
**Datatype:** string. +| `train_cycles` | Training time steps will be set based on `train_cycles * number of training data points`.<br>
**Datatype:** Integer. +| `cpu_count` | Number of processors to dedicate to the Reinforcement Learning training process.
**Datatype:** int. +| `max_trade_duration_candles` | Guides the agent training to keep trades below the desired length. Example usage shown in `prediction_models/ReinforcementLearner.py` within the customizable `calculate_reward()` function.<br>
**Datatype:** int. +| `model_type` | Model string from stable_baselines3 or SBcontrib. Available strings include: `'TRPO', 'ARS', 'RecurrentPPO', 'MaskablePPO', 'PPO', 'A2C', 'DQN'`. The user should ensure that `model_training_parameters` match those available to the corresponding stable_baselines3 model by visiting its documentation. [PPO doc](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (external website)<br>
**Datatype:** string. +| `policy_type` | One of the available policy types from stable_baselines3.<br>
**Datatype:** string. +| `max_training_drawdown_pct` | The maximum drawdown that the agent is allowed to experience during training.
**Datatype:** float.
Default: 0.8 +| `cpu_count` | Number of threads/CPUs to dedicate to the Reinforcement Learning training process (depending on whether `ReinforcementLearning_multiproc` is selected or not). It is recommended to leave this untouched; by default, this value is set to the total number of physical cores minus 1.<br>
**Datatype:** int. +| `model_reward_parameters` | Parameters used inside the customizable `calculate_reward()` function in `ReinforcementLearner.py`.<br>
**Datatype:** dictionary. +| `add_state_info` | Tell FreqAI to include state information in the feature set for training and inferencing. The current state variables include trade duration, current profit, and trade position. This is only available in dry/live runs, and is automatically switched to `False` for backtesting.<br>
**Datatype:** bool.
Default: `False`.
+| `net_arch` | Network architecture which is well described in [`stable_baselines3` doc](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html#examples). In summary: `[<shared layers>, dict(vf=[<non-shared value network layers>], pi=[<non-shared policy network layers>])]`. By default this is set to `[128, 128]`, which defines 2 shared hidden layers with 128 units each.
 | `randomize_starting_position` | Randomize the starting point of each episode to avoid overfitting.<br>
**Datatype:** bool.
Default: `False`.
 ### Additional parameters
diff --git a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
index af0726c0b..a9859da0e 100644
--- a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
+++ b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
@@ -44,6 +44,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         self.max_threads = min(self.freqai_info['rl_config'].get(
             'cpu_count', 1), max(int(self.max_system_threads / 2), 1))
         th.set_num_threads(self.max_threads)
+        self.device = self.freqai_info['rl_config'].get('device', '')
         self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
         self.train_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env()
         self.eval_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env()
diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearner.py b/freqtrade/freqai/prediction_models/ReinforcementLearner.py
index 2a87151f9..01de65205 100644
--- a/freqtrade/freqai/prediction_models/ReinforcementLearner.py
+++ b/freqtrade/freqai/prediction_models/ReinforcementLearner.py
@@ -58,10 +58,16 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
                                  net_arch=self.net_arch)

         if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
+            kwargs = self.freqai_info.get('model_training_parameters', {})
+
+            # pass the configured device to the model only if one was specified
+            if self.device != '':
+                kwargs['device'] = self.device
+
             model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,
                                     tensorboard_log=Path(
                                         dk.full_path / "tensorboard" / dk.pair.split('/')[0]),
-                                    **self.freqai_info.get('model_training_parameters', {})
+                                    **kwargs
                                     )
         else:
             logger.info('Continual training activated - starting training from previously '
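Below is a minimal, self-contained sketch (not freqtrade code, and not part of the diff) of the behaviour the two Python hunks implement: the optional `device` value from `rl_config` is forwarded to the stable_baselines3 model only when it is non-empty, so the library's own default device selection (`auto`) applies otherwise. The config values, the `CartPole-v1` environment, and the `learning_rate` entry are placeholders, not freqtrade defaults.

```python
# Sketch of how rl_config["device"] is threaded into a stable_baselines3 model,
# mirroring the kwargs handling added in ReinforcementLearner.py.
# Assumes stable_baselines3 is installed; config values are placeholders.
from stable_baselines3 import PPO

freqai_info = {
    "rl_config": {"device": "cpu"},  # e.g. "cpu", "cuda", or "mps" on Apple Silicon
    "model_training_parameters": {"learning_rate": 0.0003},  # hypothetical entry
}

# Read the optional device, defaulting to an empty string as in BaseReinforcementLearningModel
device = freqai_info["rl_config"].get("device", "")

# Start from model_training_parameters and add `device` only when it was set,
# so stable_baselines3 falls back to its own default ("auto") otherwise
kwargs = dict(freqai_info.get("model_training_parameters", {}))
if device != "":
    kwargs["device"] = device

# stable_baselines3 accepts a gym env id string and builds the environment itself
model = PPO("MlpPolicy", "CartPole-v1", **kwargs)
print(model.device)  # the torch device the model was actually placed on
```

Note that `cuda` and `mps` only work where the corresponding PyTorch backend is available; otherwise stable_baselines3 will not be able to place the model on the requested device.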