diff --git a/docs/freqai.md b/docs/freqai.md
index 938fb70f4..20562aadc 100644
--- a/docs/freqai.md
+++ b/docs/freqai.md
@@ -131,7 +131,7 @@ Mandatory parameters are marked as **Required**, which means that they are requi
| | *Reinforcement Learning Parameters**
| `rl_config` | A dictionary containing the control parameters for a Reinforcement Learning model.
**Datatype:** Dictionary.
| `train_cycles` | Training time steps will be set based on the `train_cycles * number of training data points.
**Datatype:** Integer.
-| `thread_count` | Number of threads to dedicate to the Reinforcement Learning training process.
**Datatype:** int.
+| `cpu_count` | Number of processors to dedicate to the Reinforcement Learning training process.
**Datatype:** int.
| `max_trade_duration_candles`| Guides the agent training to keep trades below desired length. Example usage shown in `prediction_models/ReinforcementLearner.py` within the user customizable `calculate_reward()`
**Datatype:** int.
| `model_type` | Model string from stable_baselines3 or SBcontrib. Available strings include: `'TRPO', 'ARS', 'RecurrentPPO', 'MaskablePPO', 'PPO', 'A2C', 'DQN'`. User should ensure that `model_training_parameters` match those available to the corresponding stable_baselines3 model by visiting their documentaiton. [PPO doc](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (external website)
**Datatype:** string.
| `policy_type` | One of the available policy types from stable_baselines3
**Datatype:** string.
diff --git a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
index 70b3e58ef..705c35297 100644
--- a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
+++ b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
@@ -39,7 +39,9 @@ class BaseReinforcementLearningModel(IFreqaiModel):
def __init__(self, **kwargs):
super().__init__(config=kwargs['config'])
- th.set_num_threads(self.freqai_info['rl_config'].get('thread_count', 4))
+ self.max_threads = min(self.freqai_info['rl_config'].get(
+ 'cpu_count', 1), max(int(self.max_system_threads / 2), 1))
+ th.set_num_threads(self.max_threads)
self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
self.train_env: Union[SubprocVecEnv, gym.Env] = None
self.eval_env: Union[SubprocVecEnv, gym.Env] = None
diff --git a/freqtrade/freqai/data_kitchen.py b/freqtrade/freqai/data_kitchen.py
index 005005368..73717abce 100644
--- a/freqtrade/freqai/data_kitchen.py
+++ b/freqtrade/freqai/data_kitchen.py
@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Tuple
import numpy as np
import numpy.typing as npt
import pandas as pd
+import psutil
from pandas import DataFrame
from scipy import stats
from sklearn import linear_model
@@ -95,7 +96,10 @@ class FreqaiDataKitchen:
)
self.data['extra_returns_per_train'] = self.freqai_config.get('extra_returns_per_train', {})
- self.thread_count = self.freqai_config.get("data_kitchen_thread_count", -1)
+ if not self.freqai_config.get("data_kitchen_thread_count", 0):
+ self.thread_count = max(int(psutil.cpu_count() * 2 - 2), 1)
+ else:
+ self.thread_count = self.freqai_config["data_kitchen_thread_count"]
self.train_dates: DataFrame = pd.DataFrame()
self.unique_classes: Dict[str, list] = {}
self.unique_class_list: list = []
diff --git a/freqtrade/freqai/freqai_interface.py b/freqtrade/freqai/freqai_interface.py
index 1a847a25e..44535f191 100644
--- a/freqtrade/freqai/freqai_interface.py
+++ b/freqtrade/freqai/freqai_interface.py
@@ -11,6 +11,7 @@ from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import pandas as pd
+import psutil
from numpy.typing import NDArray
from pandas import DataFrame
@@ -96,6 +97,7 @@ class IFreqaiModel(ABC):
self._threads: List[threading.Thread] = []
self._stop_event = threading.Event()
self.strategy: Optional[IStrategy] = None
+ self.max_system_threads = max(int(psutil.cpu_count() * 2 - 2), 1)
def __getstate__(self):
"""
@@ -158,6 +160,13 @@ class IFreqaiModel(ABC):
self.model = None
self.dk = None
+ def _on_stop(self):
+ """
+ Callback for Subclasses to override to include logic for shutting down resources
+ when SIGINT is sent.
+ """
+ return
+
def shutdown(self):
"""
Cleans up threads on Shutdown, set stop event. Join threads to wait
@@ -166,6 +175,8 @@ class IFreqaiModel(ABC):
logger.info("Stopping FreqAI")
self._stop_event.set()
+ self._on_stop()
+
logger.info("Waiting on Training iteration")
for _thread in self._threads:
_thread.join()
diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py b/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py
index 0e6449dcd..a644c0c04 100644
--- a/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py
+++ b/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py
@@ -73,18 +73,28 @@ class ReinforcementLearner_multiproc(BaseReinforcementLearningModel):
test_df = data_dictionary["test_features"]
env_id = "train_env"
- num_cpu = int(self.freqai_info["rl_config"]["thread_count"])
self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1, train_df, prices_train,
self.reward_params, self.CONV_WIDTH, monitor=True,
config=self.config) for i
- in range(num_cpu)])
+ in range(self.max_threads)])
eval_env_id = 'eval_env'
self.eval_env = SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1,
test_df, prices_test,
self.reward_params, self.CONV_WIDTH, monitor=True,
config=self.config) for i
- in range(num_cpu)])
+ in range(self.max_threads)])
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
render=False, eval_freq=len(train_df),
best_model_save_path=str(dk.data_path))
+
+ def _on_stop(self):
+ """
+ Hook called on bot shutdown. Close SubprocVecEnv subprocesses for clean shutdown.
+ """
+
+ if self.train_env:
+ self.train_env.close()
+
+ if self.eval_env:
+ self.eval_env.close()