clean RL tests to avoid dir pollution and increase speed

This commit is contained in:
robcaulk 2022-10-08 12:10:38 +02:00
parent 3e258e000e
commit 8d7adfabe9
6 changed files with 43 additions and 53 deletions

View File

@ -63,6 +63,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
self.MODELCLASS = getattr(mod, self.model_type)
self.policy_type = self.freqai_info['rl_config']['policy_type']
self.unset_outlier_removal()
self.net_arch = self.rl_config.get('net_arch', [128, 128])
def unset_outlier_removal(self):
"""
@ -287,6 +288,17 @@ class BaseReinforcementLearningModel(IFreqaiModel):
return model
def _on_stop(self):
"""
Hook called on bot shutdown. Close SubprocVecEnv subprocesses for clean shutdown.
"""
if self.train_env:
self.train_env.close()
if self.eval_env:
self.eval_env.close()
# Nested class which can be overridden by user to customize further
class MyRLEnv(Base5ActionRLEnv):
"""

View File

@ -31,7 +31,7 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)
policy_kwargs = dict(activation_fn=th.nn.ReLU,
net_arch=[128, 128])
net_arch=self.net_arch)
if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,

View File

@ -28,7 +28,7 @@ class ReinforcementLearner_multiproc(BaseReinforcementLearningModel):
# model arch
policy_kwargs = dict(activation_fn=th.nn.ReLU,
net_arch=[128, 128])
net_arch=self.net_arch)
if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,
@ -87,14 +87,3 @@ class ReinforcementLearner_multiproc(BaseReinforcementLearningModel):
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
render=False, eval_freq=len(train_df),
best_model_save_path=str(dk.data_path))
def _on_stop(self):
"""
Hook called on bot shutdown. Close SubprocVecEnv subprocesses for clean shutdown.
"""
if self.train_env:
self.train_env.close()
if self.eval_env:
self.eval_env.close()

View File

@ -58,6 +58,30 @@ def freqai_conf(default_conf, tmpdir):
return freqaiconf
def make_rl_config(conf):
conf.update({"strategy": "freqai_rl_test_strat"})
conf["freqai"].update({"model_training_parameters": {
"learning_rate": 0.00025,
"gamma": 0.9,
"verbose": 1
}})
conf["freqai"]["rl_config"] = {
"train_cycles": 1,
"thread_count": 2,
"max_trade_duration_candles": 300,
"model_type": "PPO",
"policy_type": "MlpPolicy",
"max_training_drawdown_pct": 0.5,
"net_arch": [32, 32],
"model_reward_parameters": {
"rr": 1,
"profit_aim": 0.02,
"win_reward_factor": 2
}}
return conf
def get_patched_data_kitchen(mocker, freqaiconf):
dk = FreqaiDataKitchen(freqaiconf)
return dk

View File

@ -14,7 +14,7 @@ from freqtrade.optimize.backtesting import Backtesting
from freqtrade.persistence import Trade
from freqtrade.plugins.pairlistmanager import PairListManager
from tests.conftest import get_patched_exchange, log_has_re
from tests.freqai.conftest import get_patched_freqai_strategy
from tests.freqai.conftest import get_patched_freqai_strategy, make_rl_config
def is_arm() -> bool:
@ -49,25 +49,7 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model):
if 'ReinforcementLearner' in model:
model_save_ext = 'zip'
freqai_conf.update({"strategy": "freqai_rl_test_strat"})
freqai_conf["freqai"].update({"model_training_parameters": {
"learning_rate": 0.00025,
"gamma": 0.9,
"verbose": 1
}})
freqai_conf["freqai"].update({"model_save_type": 'stable_baselines'})
freqai_conf["freqai"]["rl_config"] = {
"train_cycles": 1,
"thread_count": 2,
"max_trade_duration_candles": 300,
"model_type": "PPO",
"policy_type": "MlpPolicy",
"max_training_drawdown_pct": 0.5,
"model_reward_parameters": {
"rr": 1,
"profit_aim": 0.02,
"win_reward_factor": 2
}}
freqai_conf = make_rl_config(freqai_conf)
if 'test_4ac' in model:
freqai_conf["freqaimodel_path"] = str(Path(__file__).parents[1] / "freqai" / "test_models")
@ -79,6 +61,7 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model):
freqai = strategy.freqai
freqai.live = True
freqai.dk = FreqaiDataKitchen(freqai_conf)
freqai.dk.set_paths('ADA/BTC', 10000)
timerange = TimeRange.parse_timerange("20180110-20180130")
freqai.dd.load_all_pair_histories(timerange, freqai.dk)
@ -204,25 +187,7 @@ def test_start_backtesting(mocker, freqai_conf, model, num_files, strat):
freqai_conf.update({"strategy": strat})
if 'ReinforcementLearner' in model:
freqai_conf["freqai"].update({"model_training_parameters": {
"learning_rate": 0.00025,
"gamma": 0.9,
"verbose": 1
}})
freqai_conf["freqai"].update({"model_save_type": 'stable_baselines'})
freqai_conf["freqai"]["rl_config"] = {
"train_cycles": 1,
"thread_count": 2,
"max_trade_duration_candles": 300,
"model_type": "PPO",
"policy_type": "MlpPolicy",
"max_training_drawdown_pct": 0.5,
"model_reward_parameters": {
"rr": 1,
"profit_aim": 0.02,
"win_reward_factor": 2
}}
freqai_conf = make_rl_config(freqai_conf)
if 'test_4ac' in model:
freqai_conf["freqaimodel_path"] = str(Path(__file__).parents[1] / "freqai" / "test_models")

View File

@ -24,7 +24,7 @@ class ReinforcementLearner_test_4ac(BaseReinforcementLearningModel):
total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)
policy_kwargs = dict(activation_fn=th.nn.ReLU,
net_arch=[128, 128])
net_arch=[64, 64])
if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,