From 83343dc2f11988cc2ee384ebdcba2731d156e26d Mon Sep 17 00:00:00 2001
From: robcaulk
Date: Thu, 29 Sep 2022 00:10:18 +0200
Subject: [PATCH] control number of threads, update doc

---
 docs/freqai.md                                           | 2 +-
 freqtrade/freqai/RL/BaseReinforcementLearningModel.py    | 4 +++-
 freqtrade/freqai/data_kitchen.py                         | 6 +++++-
 freqtrade/freqai/freqai_interface.py                     | 2 ++
 .../prediction_models/ReinforcementLearner_multiproc.py  | 5 ++---
 5 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/docs/freqai.md b/docs/freqai.md
index 938fb70f4..20562aadc 100644
--- a/docs/freqai.md
+++ b/docs/freqai.md
@@ -131,7 +131,7 @@ Mandatory parameters are marked as **Required**, which means that they are requi
 | |  **Reinforcement Learning Parameters**
 | `rl_config` | A dictionary containing the control parameters for a Reinforcement Learning model. <br> **Datatype:** Dictionary.
 | `train_cycles` | Training time steps will be set based on `train_cycles` * number of training data points. <br> **Datatype:** Integer.
-| `thread_count` | Number of threads to dedicate to the Reinforcement Learning training process. <br> **Datatype:** int.
+| `cpu_count` | Number of processors to dedicate to the Reinforcement Learning training process. <br> **Datatype:** int.
 | `max_trade_duration_candles` | Guides the agent training to keep trades below the desired length. Example usage shown in `prediction_models/ReinforcementLearner.py` within the user-customizable `calculate_reward()`. <br> **Datatype:** int.
 | `model_type` | Model string from stable_baselines3 or SBcontrib. Available strings include: `'TRPO', 'ARS', 'RecurrentPPO', 'MaskablePPO', 'PPO', 'A2C', 'DQN'`. User should ensure that `model_training_parameters` match those available to the corresponding stable_baselines3 model by visiting their documentation. [PPO doc](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (external website) <br> **Datatype:** string.
 | `policy_type` | One of the available policy types from stable_baselines3. <br> **Datatype:** string.
diff --git a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
index 70b3e58ef..8785192f4 100644
--- a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
+++ b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
@@ -39,7 +39,9 @@ class BaseReinforcementLearningModel(IFreqaiModel):
 
     def __init__(self, **kwargs):
         super().__init__(config=kwargs['config'])
-        th.set_num_threads(self.freqai_info['rl_config'].get('thread_count', 4))
+        self.max_threads = max(self.freqai_info['rl_config'].get(
+            'cpu_count', 0), int(self.max_system_threads / 2))
+        th.set_num_threads(self.max_threads)
         self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
         self.train_env: Union[SubprocVecEnv, gym.Env] = None
         self.eval_env: Union[SubprocVecEnv, gym.Env] = None
diff --git a/freqtrade/freqai/data_kitchen.py b/freqtrade/freqai/data_kitchen.py
index 005005368..9f84e63b7 100644
--- a/freqtrade/freqai/data_kitchen.py
+++ b/freqtrade/freqai/data_kitchen.py
@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Tuple
 import numpy as np
 import numpy.typing as npt
 import pandas as pd
+import psutil
 from pandas import DataFrame
 from scipy import stats
 from sklearn import linear_model
@@ -95,7 +96,10 @@ class FreqaiDataKitchen:
             )
 
         self.data['extra_returns_per_train'] = self.freqai_config.get('extra_returns_per_train', {})
-        self.thread_count = self.freqai_config.get("data_kitchen_thread_count", -1)
+        if not self.freqai_config.get("data_kitchen_thread_count", 0):
+            self.thread_count = int(psutil.cpu_count() * 2 - 2)
+        else:
+            self.thread_count = self.freqai_config["data_kitchen_thread_count"]
         self.train_dates: DataFrame = pd.DataFrame()
         self.unique_classes: Dict[str, list] = {}
         self.unique_class_list: list = []
diff --git a/freqtrade/freqai/freqai_interface.py b/freqtrade/freqai/freqai_interface.py
index f8ca34ddb..5fe3c318c 100644
--- a/freqtrade/freqai/freqai_interface.py
+++ b/freqtrade/freqai/freqai_interface.py
@@ -11,6 +11,7 @@ from typing import Any, Dict, List, Optional, Tuple
 
 import numpy as np
 import pandas as pd
+import psutil
 from numpy.typing import NDArray
 from pandas import DataFrame
 
@@ -96,6 +97,7 @@ class IFreqaiModel(ABC):
         self._threads: List[threading.Thread] = []
         self._stop_event = threading.Event()
         self.strategy: Optional[IStrategy] = None
+        self.max_system_threads = int(psutil.cpu_count() * 2 - 2)
 
     def __getstate__(self):
         """
diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py b/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py
index d01c409c3..a644c0c04 100644
--- a/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py
+++ b/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py
@@ -73,18 +73,17 @@
         test_df = data_dictionary["test_features"]
 
         env_id = "train_env"
-        num_cpu = int(self.freqai_info["rl_config"].get("cpu_count", 2))
         self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1, train_df, prices_train,
                                                  self.reward_params, self.CONV_WIDTH, monitor=True,
                                                  config=self.config) for i
-                                        in range(num_cpu)])
+                                        in range(self.max_threads)])
 
         eval_env_id = 'eval_env'
         self.eval_env = SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1,
                                                 test_df, prices_test,
                                                 self.reward_params, self.CONV_WIDTH, monitor=True,
                                                 config=self.config) for i
-                                       in range(num_cpu)])
+                                       in range(self.max_threads)])
         self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
                                           render=False, eval_freq=len(train_df),
                                           best_model_save_path=str(dk.data_path))
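
For reference, a minimal standalone sketch (not part of the patch) of how the new thread accounting fits together. The helper function names below are hypothetical; the formulas mirror the diff above (`IFreqaiModel.__init__`, `FreqaiDataKitchen.__init__`, `BaseReinforcementLearningModel.__init__`).

```python
# Hypothetical helpers mirroring the thread limits introduced in this patch.
import psutil


def max_system_threads() -> int:
    # freqai_interface.py: machine-derived upper bound, 2x logical CPUs minus 2.
    return int(psutil.cpu_count() * 2 - 2)


def data_kitchen_threads(freqai_config: dict) -> int:
    # data_kitchen.py: honor "data_kitchen_thread_count" when set,
    # otherwise fall back to the system-derived limit.
    if not freqai_config.get("data_kitchen_thread_count", 0):
        return max_system_threads()
    return freqai_config["data_kitchen_thread_count"]


def rl_training_threads(rl_config: dict) -> int:
    # BaseReinforcementLearningModel.py: the larger of rl_config["cpu_count"]
    # and half the system limit; passed to torch.set_num_threads() and used as
    # the SubprocVecEnv worker count in ReinforcementLearner_multiproc.
    return max(rl_config.get("cpu_count", 0), int(max_system_threads() / 2))


if __name__ == "__main__":
    freqai_config = {"data_kitchen_thread_count": 0, "rl_config": {"cpu_count": 4}}
    print(data_kitchen_threads(freqai_config))
    print(rl_training_threads(freqai_config["rl_config"]))
```

Note that, as written, `max(...)` means a `cpu_count` smaller than half the system limit has no effect; the RL training always gets at least half the system-derived thread budget.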