Merge pull request #7492 from wizrds/freqai-rl-dev

Shutdown Subproc Env on signal
Robert Caulk 2022-09-30 00:19:44 +02:00 committed by GitHub
commit 09e834fa21
5 changed files with 33 additions and 6 deletions


@@ -131,7 +131,7 @@ Mandatory parameters are marked as **Required**, which means that they are required
 |  |  **Reinforcement Learning Parameters**
 | `rl_config` | A dictionary containing the control parameters for a Reinforcement Learning model. <br> **Datatype:** Dictionary.
 | `train_cycles` | Training time steps will be set based on the `train_cycles` * number of training data points. <br> **Datatype:** Integer.
-| `thread_count` | Number of threads to dedicate to the Reinforcement Learning training process. <br> **Datatype:** int.
+| `cpu_count` | Number of processors to dedicate to the Reinforcement Learning training process. <br> **Datatype:** int.
 | `max_trade_duration_candles`| Guides the agent training to keep trades below desired length. Example usage shown in `prediction_models/ReinforcementLearner.py` within the user customizable `calculate_reward()` <br> **Datatype:** int.
 | `model_type` | Model string from stable_baselines3 or SBcontrib. Available strings include: `'TRPO', 'ARS', 'RecurrentPPO', 'MaskablePPO', 'PPO', 'A2C', 'DQN'`. User should ensure that `model_training_parameters` match those available to the corresponding stable_baselines3 model by visiting their documentation. [PPO doc](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (external website) <br> **Datatype:** string.
 | `policy_type` | One of the available policy types from stable_baselines3 <br> **Datatype:** string.
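For users updating their configuration after the `thread_count` -> `cpu_count` rename, a minimal sketch of an `rl_config` section might look as follows. The dictionary name `freqai_settings` and all values are illustrative assumptions, not recommendations; only the keys come from the table above and from `model_reward_parameters`, which the base model reads.

```python
# Illustrative rl_config section using the renamed "cpu_count" key.
# Values are examples only; adjust to your hardware and strategy.
freqai_settings = {
    "rl_config": {
        "train_cycles": 25,
        "cpu_count": 4,                     # processors dedicated to RL training
        "max_trade_duration_candles": 300,  # used inside calculate_reward()
        "model_type": "PPO",
        "policy_type": "MlpPolicy",
        "model_reward_parameters": {"rr": 1, "profit_aim": 0.025},
    }
}
```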


@@ -39,7 +39,9 @@ class BaseReinforcementLearningModel(IFreqaiModel):
     def __init__(self, **kwargs):
         super().__init__(config=kwargs['config'])
-        th.set_num_threads(self.freqai_info['rl_config'].get('thread_count', 4))
+        self.max_threads = min(self.freqai_info['rl_config'].get(
+            'cpu_count', 1), max(int(self.max_system_threads / 2), 1))
+        th.set_num_threads(self.max_threads)
         self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
         self.train_env: Union[SubprocVecEnv, gym.Env] = None
         self.eval_env: Union[SubprocVecEnv, gym.Env] = None
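To see how the new `cpu_count` setting interacts with the system-wide cap added to `IFreqaiModel` (see the hunk further down), here is a small standalone sketch of the same arithmetic; the class wiring is omitted and `cpu_count = 4` is a hypothetical user setting.

```python
import psutil

# Standalone sketch of the thread budgeting in this PR: the system-wide cap is
# 2 * logical cores - 2 (at least 1), and the RL trainer uses at most half of
# that, further limited by the user's rl_config "cpu_count".
max_system_threads = max(int(psutil.cpu_count() * 2 - 2), 1)

cpu_count = 4  # hypothetical value from freqai_info['rl_config']['cpu_count']
max_threads = min(cpu_count, max(int(max_system_threads / 2), 1))

# On an 8-core machine: max_system_threads = 14, so max_threads = min(4, 7) = 4.
print(max_threads)
```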


@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Tuple
 import numpy as np
 import numpy.typing as npt
 import pandas as pd
+import psutil
 from pandas import DataFrame
 from scipy import stats
 from sklearn import linear_model
@@ -95,7 +96,10 @@ class FreqaiDataKitchen:
         )
         self.data['extra_returns_per_train'] = self.freqai_config.get('extra_returns_per_train', {})
-        self.thread_count = self.freqai_config.get("data_kitchen_thread_count", -1)
+        if not self.freqai_config.get("data_kitchen_thread_count", 0):
+            self.thread_count = max(int(psutil.cpu_count() * 2 - 2), 1)
+        else:
+            self.thread_count = self.freqai_config["data_kitchen_thread_count"]
         self.train_dates: DataFrame = pd.DataFrame()
         self.unique_classes: Dict[str, list] = {}
         self.unique_class_list: list = []
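The data kitchen change follows the same pattern: an unset or zero `data_kitchen_thread_count` now auto-detects a thread budget via `psutil` instead of passing `-1` through. A minimal sketch of just that decision, pulled out of the class; the function name is hypothetical.

```python
import psutil

def resolve_data_kitchen_threads(freqai_config: dict) -> int:
    """Mirror the fallback logic from this hunk as a free function (sketch only)."""
    if not freqai_config.get("data_kitchen_thread_count", 0):
        # Unset or 0: derive a default from the machine, never below 1.
        return max(int(psutil.cpu_count() * 2 - 2), 1)
    return freqai_config["data_kitchen_thread_count"]

print(resolve_data_kitchen_threads({}))                                # auto-detected
print(resolve_data_kitchen_threads({"data_kitchen_thread_count": 6}))  # explicit -> 6
```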


@@ -11,6 +11,7 @@ from typing import Any, Dict, List, Optional, Tuple
 import numpy as np
 import pandas as pd
+import psutil
 from numpy.typing import NDArray
 from pandas import DataFrame
@@ -96,6 +97,7 @@ class IFreqaiModel(ABC):
         self._threads: List[threading.Thread] = []
         self._stop_event = threading.Event()
         self.strategy: Optional[IStrategy] = None
+        self.max_system_threads = max(int(psutil.cpu_count() * 2 - 2), 1)

     def __getstate__(self):
         """
@@ -158,6 +160,13 @@
         self.model = None
         self.dk = None

+    def _on_stop(self):
+        """
+        Callback for Subclasses to override to include logic for shutting down resources
+        when SIGINT is sent.
+        """
+        return
+
     def shutdown(self):
         """
         Cleans up threads on Shutdown, set stop event. Join threads to wait
@@ -166,6 +175,8 @@
         logger.info("Stopping FreqAI")
         self._stop_event.set()

+        self._on_stop()
+
         logger.info("Waiting on Training iteration")
         for _thread in self._threads:
             _thread.join()
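The ordering in `shutdown()` is the core of this PR: set the stop event, give subclasses a chance to release external resources through `_on_stop()`, and only then join the worker threads. A toy sketch of that control flow, using hypothetical `BaseModel`/`MultiprocModel` classes rather than the freqtrade ones:

```python
import threading

class BaseModel:
    """Toy stand-in for IFreqaiModel's shutdown ordering (sketch only)."""

    def __init__(self):
        self._threads: list = []
        self._stop_event = threading.Event()

    def _on_stop(self):
        # Default is a no-op; subclasses override to close external resources.
        return

    def shutdown(self):
        self._stop_event.set()
        self._on_stop()              # resources are released before joining threads
        for thread in self._threads:
            thread.join()

class MultiprocModel(BaseModel):
    def _on_stop(self):
        print("closing subprocess environments")

MultiprocModel().shutdown()  # prints "closing subprocess environments"
```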


@@ -73,18 +73,28 @@ class ReinforcementLearner_multiproc(BaseReinforcementLearningModel):
         test_df = data_dictionary["test_features"]

         env_id = "train_env"
-        num_cpu = int(self.freqai_info["rl_config"]["thread_count"])
         self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1, train_df, prices_train,
                                                  self.reward_params, self.CONV_WIDTH, monitor=True,
                                                  config=self.config) for i
-                                        in range(num_cpu)])
+                                        in range(self.max_threads)])

         eval_env_id = 'eval_env'
         self.eval_env = SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1,
                                                 test_df, prices_test,
                                                 self.reward_params, self.CONV_WIDTH, monitor=True,
                                                 config=self.config) for i
-                                       in range(num_cpu)])
+                                       in range(self.max_threads)])

         self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
                                           render=False, eval_freq=len(train_df),
                                           best_model_save_path=str(dk.data_path))
+
+    def _on_stop(self):
+        """
+        Hook called on bot shutdown. Close SubprocVecEnv subprocesses for clean shutdown.
+        """
+
+        if self.train_env:
+            self.train_env.close()
+
+        if self.eval_env:
+            self.eval_env.close()