diff --git a/docs/freqai.md b/docs/freqai.md index 938fb70f4..20562aadc 100644 --- a/docs/freqai.md +++ b/docs/freqai.md @@ -131,7 +131,7 @@ Mandatory parameters are marked as **Required**, which means that they are requi | | *Reinforcement Learning Parameters** | `rl_config` | A dictionary containing the control parameters for a Reinforcement Learning model.
**Datatype:** Dictionary. | `train_cycles` | Training time steps will be set based on the `train_cycles * number of training data points.
**Datatype:** Integer. -| `thread_count` | Number of threads to dedicate to the Reinforcement Learning training process.
**Datatype:** int. +| `cpu_count` | Number of processors to dedicate to the Reinforcement Learning training process.
**Datatype:** int. | `max_trade_duration_candles`| Guides the agent training to keep trades below desired length. Example usage shown in `prediction_models/ReinforcementLearner.py` within the user customizable `calculate_reward()`
**Datatype:** int. | `model_type` | Model string from stable_baselines3 or SBcontrib. Available strings include: `'TRPO', 'ARS', 'RecurrentPPO', 'MaskablePPO', 'PPO', 'A2C', 'DQN'`. User should ensure that `model_training_parameters` match those available to the corresponding stable_baselines3 model by visiting their documentaiton. [PPO doc](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (external website)
**Datatype:** string. | `policy_type` | One of the available policy types from stable_baselines3
**Datatype:** string. diff --git a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py index 70b3e58ef..705c35297 100644 --- a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py +++ b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py @@ -39,7 +39,9 @@ class BaseReinforcementLearningModel(IFreqaiModel): def __init__(self, **kwargs): super().__init__(config=kwargs['config']) - th.set_num_threads(self.freqai_info['rl_config'].get('thread_count', 4)) + self.max_threads = min(self.freqai_info['rl_config'].get( + 'cpu_count', 1), max(int(self.max_system_threads / 2), 1)) + th.set_num_threads(self.max_threads) self.reward_params = self.freqai_info['rl_config']['model_reward_parameters'] self.train_env: Union[SubprocVecEnv, gym.Env] = None self.eval_env: Union[SubprocVecEnv, gym.Env] = None diff --git a/freqtrade/freqai/data_kitchen.py b/freqtrade/freqai/data_kitchen.py index 005005368..73717abce 100644 --- a/freqtrade/freqai/data_kitchen.py +++ b/freqtrade/freqai/data_kitchen.py @@ -9,6 +9,7 @@ from typing import Any, Dict, List, Tuple import numpy as np import numpy.typing as npt import pandas as pd +import psutil from pandas import DataFrame from scipy import stats from sklearn import linear_model @@ -95,7 +96,10 @@ class FreqaiDataKitchen: ) self.data['extra_returns_per_train'] = self.freqai_config.get('extra_returns_per_train', {}) - self.thread_count = self.freqai_config.get("data_kitchen_thread_count", -1) + if not self.freqai_config.get("data_kitchen_thread_count", 0): + self.thread_count = max(int(psutil.cpu_count() * 2 - 2), 1) + else: + self.thread_count = self.freqai_config["data_kitchen_thread_count"] self.train_dates: DataFrame = pd.DataFrame() self.unique_classes: Dict[str, list] = {} self.unique_class_list: list = [] diff --git a/freqtrade/freqai/freqai_interface.py b/freqtrade/freqai/freqai_interface.py index 1a847a25e..44535f191 100644 --- a/freqtrade/freqai/freqai_interface.py +++ b/freqtrade/freqai/freqai_interface.py @@ -11,6 +11,7 @@ from typing import Any, Dict, List, Optional, Tuple import numpy as np import pandas as pd +import psutil from numpy.typing import NDArray from pandas import DataFrame @@ -96,6 +97,7 @@ class IFreqaiModel(ABC): self._threads: List[threading.Thread] = [] self._stop_event = threading.Event() self.strategy: Optional[IStrategy] = None + self.max_system_threads = max(int(psutil.cpu_count() * 2 - 2), 1) def __getstate__(self): """ @@ -158,6 +160,13 @@ class IFreqaiModel(ABC): self.model = None self.dk = None + def _on_stop(self): + """ + Callback for Subclasses to override to include logic for shutting down resources + when SIGINT is sent. + """ + return + def shutdown(self): """ Cleans up threads on Shutdown, set stop event. Join threads to wait @@ -166,6 +175,8 @@ class IFreqaiModel(ABC): logger.info("Stopping FreqAI") self._stop_event.set() + self._on_stop() + logger.info("Waiting on Training iteration") for _thread in self._threads: _thread.join() diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py b/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py index 0e6449dcd..a644c0c04 100644 --- a/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py +++ b/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py @@ -73,18 +73,28 @@ class ReinforcementLearner_multiproc(BaseReinforcementLearningModel): test_df = data_dictionary["test_features"] env_id = "train_env" - num_cpu = int(self.freqai_info["rl_config"]["thread_count"]) self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1, train_df, prices_train, self.reward_params, self.CONV_WIDTH, monitor=True, config=self.config) for i - in range(num_cpu)]) + in range(self.max_threads)]) eval_env_id = 'eval_env' self.eval_env = SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1, test_df, prices_test, self.reward_params, self.CONV_WIDTH, monitor=True, config=self.config) for i - in range(num_cpu)]) + in range(self.max_threads)]) self.eval_callback = EvalCallback(self.eval_env, deterministic=True, render=False, eval_freq=len(train_df), best_model_save_path=str(dk.data_path)) + + def _on_stop(self): + """ + Hook called on bot shutdown. Close SubprocVecEnv subprocesses for clean shutdown. + """ + + if self.train_env: + self.train_env.close() + + if self.eval_env: + self.eval_env.close()