Merge pull request #7492 from wizrds/freqai-rl-dev
Shutdown Subproc Env on signal
Commit 09e834fa21
@@ -131,7 +131,7 @@ Mandatory parameters are marked as **Required**, which means that they are requi
 | **Reinforcement Learning Parameters**
 | `rl_config` | A dictionary containing the control parameters for a Reinforcement Learning model. <br> **Datatype:** Dictionary.
 | `train_cycles` | Training time steps will be set based on the `train_cycles * number of training data points`. <br> **Datatype:** Integer.
-| `thread_count` | Number of threads to dedicate to the Reinforcement Learning training process. <br> **Datatype:** int.
+| `cpu_count` | Number of processors to dedicate to the Reinforcement Learning training process. <br> **Datatype:** int.
 | `max_trade_duration_candles`| Guides the agent training to keep trades below desired length. Example usage shown in `prediction_models/ReinforcementLearner.py` within the user customizable `calculate_reward()` <br> **Datatype:** int.
 | `model_type` | Model string from stable_baselines3 or SBcontrib. Available strings include: `'TRPO', 'ARS', 'RecurrentPPO', 'MaskablePPO', 'PPO', 'A2C', 'DQN'`. User should ensure that `model_training_parameters` match those available to the corresponding stable_baselines3 model by visiting their documentation. [PPO doc](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (external website) <br> **Datatype:** string.
 | `policy_type` | One of the available policy types from stable_baselines3 <br> **Datatype:** string.
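For orientation, the renamed key sits in the `rl_config` dictionary alongside the other parameters documented above. A minimal sketch, not part of the diff, with illustrative values only (the contents of `model_reward_parameters` depend on the reward function in use):

```python
# Illustrative only: the rl_config dictionary FreqAI reads from the "freqai"
# section of the configuration, shown here as a plain Python dict.
rl_config = {
    "train_cycles": 25,                 # training steps = train_cycles * training data points
    "cpu_count": 4,                     # processors for RL training (replaces thread_count)
    "max_trade_duration_candles": 300,  # used by the example calculate_reward()
    "model_type": "PPO",                # any listed stable_baselines3 / SBcontrib model string
    "policy_type": "MlpPolicy",         # a stable_baselines3 policy type
    "model_reward_parameters": {        # passed to calculate_reward(); assumed example values
        "rr": 1,
        "profit_aim": 0.025,
    },
}
```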
@@ -39,7 +39,9 @@ class BaseReinforcementLearningModel(IFreqaiModel):

     def __init__(self, **kwargs):
         super().__init__(config=kwargs['config'])
-        th.set_num_threads(self.freqai_info['rl_config'].get('thread_count', 4))
+        self.max_threads = min(self.freqai_info['rl_config'].get(
+            'cpu_count', 1), max(int(self.max_system_threads / 2), 1))
+        th.set_num_threads(self.max_threads)
         self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
         self.train_env: Union[SubprocVecEnv, gym.Env] = None
         self.eval_env: Union[SubprocVecEnv, gym.Env] = None
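The new initializer clamps the user's `cpu_count` to half of `self.max_system_threads`, which is set from `psutil.cpu_count()` in the `IFreqaiModel` hunk further down. A standalone sketch of the same clamping arithmetic, with assumed example numbers and assuming `th` is the `torch` alias used in the RL base model:

```python
import psutil
import torch as th  # assumption: the diff's `th` alias refers to torch

cpu_count = 4  # assumed value of rl_config["cpu_count"]
max_system_threads = max(int(psutil.cpu_count() * 2 - 2), 1)       # e.g. 14 when psutil reports 8 logical CPUs
max_threads = min(cpu_count, max(int(max_system_threads / 2), 1))  # min(4, 7) -> 4
th.set_num_threads(max_threads)
```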
@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Tuple
 import numpy as np
 import numpy.typing as npt
 import pandas as pd
+import psutil
 from pandas import DataFrame
 from scipy import stats
 from sklearn import linear_model
@@ -95,7 +96,10 @@ class FreqaiDataKitchen:
         )

         self.data['extra_returns_per_train'] = self.freqai_config.get('extra_returns_per_train', {})
-        self.thread_count = self.freqai_config.get("data_kitchen_thread_count", -1)
+        if not self.freqai_config.get("data_kitchen_thread_count", 0):
+            self.thread_count = max(int(psutil.cpu_count() * 2 - 2), 1)
+        else:
+            self.thread_count = self.freqai_config["data_kitchen_thread_count"]
         self.train_dates: DataFrame = pd.DataFrame()
         self.unique_classes: Dict[str, list] = {}
         self.unique_class_list: list = []
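For users, the practical effect is that `data_kitchen_thread_count` becomes opt-in: omit it (or set it to 0) to accept the `psutil`-derived default, or set it explicitly to pin the pool size. A hypothetical config fragment, with the key placement inferred from the `self.freqai_config` attribute:

```python
# Hypothetical "freqai" config fragment (dict form); only the key relevant to this hunk is shown.
freqai_config = {
    # Explicit override pins the data-kitchen pool size.
    # Omit the key (or set it to 0) to fall back to max(int(psutil.cpu_count() * 2 - 2), 1).
    "data_kitchen_thread_count": 8,
}
```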
@@ -11,6 +11,7 @@ from typing import Any, Dict, List, Optional, Tuple

 import numpy as np
 import pandas as pd
+import psutil
 from numpy.typing import NDArray
 from pandas import DataFrame

@@ -96,6 +97,7 @@ class IFreqaiModel(ABC):
         self._threads: List[threading.Thread] = []
         self._stop_event = threading.Event()
         self.strategy: Optional[IStrategy] = None
+        self.max_system_threads = max(int(psutil.cpu_count() * 2 - 2), 1)

     def __getstate__(self):
         """
@@ -158,6 +160,13 @@ class IFreqaiModel(ABC):
         self.model = None
         self.dk = None

+    def _on_stop(self):
+        """
+        Callback for subclasses to override to include logic for shutting down resources
+        when SIGINT is sent.
+        """
+        return
+
     def shutdown(self):
         """
         Cleans up threads on Shutdown, set stop event. Join threads to wait
@@ -166,6 +175,8 @@ class IFreqaiModel(ABC):
         logger.info("Stopping FreqAI")
         self._stop_event.set()

+        self._on_stop()
+
         logger.info("Waiting on Training iteration")
         for _thread in self._threads:
             _thread.join()
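Taken together, the two hunks above give `shutdown()` a fixed order: set the stop event, run the new `_on_stop()` hook, then join the training threads. A minimal, self-contained sketch of that hook pattern (class and attribute names below are invented for illustration, not freqtrade APIs):

```python
import threading


class Base:
    """Toy stand-in for the shutdown flow shown above (not freqtrade code)."""

    def __init__(self):
        self._threads = []            # worker threads started elsewhere
        self._stop_event = threading.Event()

    def _on_stop(self):
        """Subclasses override this to release their own resources."""
        return

    def shutdown(self):
        self._stop_event.set()        # 1. signal workers to stop
        self._on_stop()               # 2. subclass cleanup runs before joining
        for _thread in self._threads:
            _thread.join()            # 3. wait for training threads to finish


class MultiprocModel(Base):
    def __init__(self):
        super().__init__()
        self.train_env = None         # would be a SubprocVecEnv in the real model

    def _on_stop(self):
        if self.train_env:
            self.train_env.close()


MultiprocModel().shutdown()           # runs cleanup in the order above
```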
@@ -73,18 +73,28 @@ class ReinforcementLearner_multiproc(BaseReinforcementLearningModel):
         test_df = data_dictionary["test_features"]

         env_id = "train_env"
-        num_cpu = int(self.freqai_info["rl_config"]["thread_count"])
         self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1, train_df, prices_train,
                                                  self.reward_params, self.CONV_WIDTH, monitor=True,
                                                  config=self.config) for i
-                                        in range(num_cpu)])
+                                        in range(self.max_threads)])

         eval_env_id = 'eval_env'
         self.eval_env = SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1,
                                                 test_df, prices_test,
                                                 self.reward_params, self.CONV_WIDTH, monitor=True,
                                                 config=self.config) for i
-                                       in range(num_cpu)])
+                                       in range(self.max_threads)])
         self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
                                           render=False, eval_freq=len(train_df),
                                           best_model_save_path=str(dk.data_path))
+
+    def _on_stop(self):
+        """
+        Hook called on bot shutdown. Close SubprocVecEnv subprocesses for clean shutdown.
+        """
+
+        if self.train_env:
+            self.train_env.close()
+
+        if self.eval_env:
+            self.eval_env.close()
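Closing a `SubprocVecEnv` joins and terminates its worker processes, which is what makes the Ctrl-C shutdown clean instead of leaving orphaned subprocesses behind. A standalone sketch of the same close semantics using a stock environment rather than the freqtrade env factory (exact imports depend on the installed gym/stable_baselines3 versions; `gym` is used here to match the `gym.Env` annotation in the diff):

```python
import gym
from stable_baselines3.common.vec_env import SubprocVecEnv

if __name__ == "__main__":  # required: SubprocVecEnv spawns worker processes
    # Four independent worker processes, each hosting one CartPole instance.
    vec_env = SubprocVecEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
    try:
        vec_env.reset()
    finally:
        vec_env.close()  # joins/terminates the worker processes
```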