From 8d7adfabe97e7e7db23df2108e181452fd9f14ac Mon Sep 17 00:00:00 2001
From: robcaulk
Date: Sat, 8 Oct 2022 12:10:38 +0200
Subject: [PATCH] clean RL tests to avoid dir pollution and increase speed

---
 .../RL/BaseReinforcementLearningModel.py      | 12 ++++++
 .../prediction_models/ReinforcementLearner.py |  2 +-
 .../ReinforcementLearner_multiproc.py         | 13 +-----
 tests/freqai/conftest.py                      | 24 +++++++++++
 tests/freqai/test_freqai_interface.py         | 43 ++-----------------
 .../ReinforcementLearner_test_4ac.py          |  2 +-
 6 files changed, 43 insertions(+), 53 deletions(-)

diff --git a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
index e89320668..64af31c45 100644
--- a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
+++ b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
@@ -63,6 +63,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         self.MODELCLASS = getattr(mod, self.model_type)
         self.policy_type = self.freqai_info['rl_config']['policy_type']
         self.unset_outlier_removal()
+        self.net_arch = self.rl_config.get('net_arch', [128, 128])
 
     def unset_outlier_removal(self):
         """
@@ -287,6 +288,17 @@ class BaseReinforcementLearningModel(IFreqaiModel):
 
         return model
 
+    def _on_stop(self):
+        """
+        Hook called on bot shutdown. Close SubprocVecEnv subprocesses for clean shutdown.
+        """
+
+        if self.train_env:
+            self.train_env.close()
+
+        if self.eval_env:
+            self.eval_env.close()
+
     # Nested class which can be overridden by user to customize further
     class MyRLEnv(Base5ActionRLEnv):
         """
diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearner.py b/freqtrade/freqai/prediction_models/ReinforcementLearner.py
index 48519c34c..4bf990172 100644
--- a/freqtrade/freqai/prediction_models/ReinforcementLearner.py
+++ b/freqtrade/freqai/prediction_models/ReinforcementLearner.py
@@ -31,7 +31,7 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
         total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)
 
         policy_kwargs = dict(activation_fn=th.nn.ReLU,
-                             net_arch=[128, 128])
+                             net_arch=self.net_arch)
 
         if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
             model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,
diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py b/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py
index a644c0c04..41345b967 100644
--- a/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py
+++ b/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py
@@ -28,7 +28,7 @@ class ReinforcementLearner_multiproc(BaseReinforcementLearningModel):
 
         # model arch
         policy_kwargs = dict(activation_fn=th.nn.ReLU,
-                             net_arch=[128, 128])
+                             net_arch=self.net_arch)
 
         if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
             model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,
@@ -87,14 +87,3 @@ class ReinforcementLearner_multiproc(BaseReinforcementLearningModel):
         self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
                                           render=False, eval_freq=len(train_df),
                                           best_model_save_path=str(dk.data_path))
-
-    def _on_stop(self):
-        """
-        Hook called on bot shutdown. Close SubprocVecEnv subprocesses for clean shutdown.
-        """
-
-        if self.train_env:
-            self.train_env.close()
-
-        if self.eval_env:
-            self.eval_env.close()
diff --git a/tests/freqai/conftest.py b/tests/freqai/conftest.py
index 026b45afc..7f4897439 100644
--- a/tests/freqai/conftest.py
+++ b/tests/freqai/conftest.py
@@ -58,6 +58,30 @@ def freqai_conf(default_conf, tmpdir):
     return freqaiconf
 
 
+def make_rl_config(conf):
+    conf.update({"strategy": "freqai_rl_test_strat"})
+    conf["freqai"].update({"model_training_parameters": {
+        "learning_rate": 0.00025,
+        "gamma": 0.9,
+        "verbose": 1
+    }})
+    conf["freqai"]["rl_config"] = {
+        "train_cycles": 1,
+        "thread_count": 2,
+        "max_trade_duration_candles": 300,
+        "model_type": "PPO",
+        "policy_type": "MlpPolicy",
+        "max_training_drawdown_pct": 0.5,
+        "net_arch": [32, 32],
+        "model_reward_parameters": {
+            "rr": 1,
+            "profit_aim": 0.02,
+            "win_reward_factor": 2
+        }}
+
+    return conf
+
+
 def get_patched_data_kitchen(mocker, freqaiconf):
     dk = FreqaiDataKitchen(freqaiconf)
     return dk
diff --git a/tests/freqai/test_freqai_interface.py b/tests/freqai/test_freqai_interface.py
index bd7c62c5f..40a573547 100644
--- a/tests/freqai/test_freqai_interface.py
+++ b/tests/freqai/test_freqai_interface.py
@@ -14,7 +14,7 @@ from freqtrade.optimize.backtesting import Backtesting
 from freqtrade.persistence import Trade
 from freqtrade.plugins.pairlistmanager import PairListManager
 from tests.conftest import get_patched_exchange, log_has_re
-from tests.freqai.conftest import get_patched_freqai_strategy
+from tests.freqai.conftest import get_patched_freqai_strategy, make_rl_config
 
 
 def is_arm() -> bool:
@@ -49,25 +49,7 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model):
     if 'ReinforcementLearner' in model:
         model_save_ext = 'zip'
 
-        freqai_conf.update({"strategy": "freqai_rl_test_strat"})
-        freqai_conf["freqai"].update({"model_training_parameters": {
-            "learning_rate": 0.00025,
-            "gamma": 0.9,
-            "verbose": 1
-        }})
-        freqai_conf["freqai"].update({"model_save_type": 'stable_baselines'})
-        freqai_conf["freqai"]["rl_config"] = {
-            "train_cycles": 1,
-            "thread_count": 2,
-            "max_trade_duration_candles": 300,
-            "model_type": "PPO",
-            "policy_type": "MlpPolicy",
-            "max_training_drawdown_pct": 0.5,
-            "model_reward_parameters": {
-                "rr": 1,
-                "profit_aim": 0.02,
-                "win_reward_factor": 2
-            }}
+        freqai_conf = make_rl_config(freqai_conf)
 
     if 'test_4ac' in model:
         freqai_conf["freqaimodel_path"] = str(Path(__file__).parents[1] / "freqai" / "test_models")
@@ -79,6 +61,7 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model):
     freqai = strategy.freqai
     freqai.live = True
     freqai.dk = FreqaiDataKitchen(freqai_conf)
+    freqai.dk.set_paths('ADA/BTC', 10000)
 
     timerange = TimeRange.parse_timerange("20180110-20180130")
     freqai.dd.load_all_pair_histories(timerange, freqai.dk)
@@ -204,25 +187,7 @@ def test_start_backtesting(mocker, freqai_conf, model, num_files, strat):
     freqai_conf.update({"strategy": strat})
 
     if 'ReinforcementLearner' in model:
-
-        freqai_conf["freqai"].update({"model_training_parameters": {
-            "learning_rate": 0.00025,
-            "gamma": 0.9,
-            "verbose": 1
-        }})
-        freqai_conf["freqai"].update({"model_save_type": 'stable_baselines'})
-        freqai_conf["freqai"]["rl_config"] = {
-            "train_cycles": 1,
-            "thread_count": 2,
-            "max_trade_duration_candles": 300,
-            "model_type": "PPO",
-            "policy_type": "MlpPolicy",
-            "max_training_drawdown_pct": 0.5,
-            "model_reward_parameters": {
-                "rr": 1,
-                "profit_aim": 0.02,
-                "win_reward_factor": 2
-            }}
+        freqai_conf = make_rl_config(freqai_conf)
 
     if 'test_4ac' in model:
         freqai_conf["freqaimodel_path"] = str(Path(__file__).parents[1] / "freqai" / "test_models")
diff --git a/tests/freqai/test_models/ReinforcementLearner_test_4ac.py b/tests/freqai/test_models/ReinforcementLearner_test_4ac.py
index 9a8f800bd..13e5af02f 100644
--- a/tests/freqai/test_models/ReinforcementLearner_test_4ac.py
+++ b/tests/freqai/test_models/ReinforcementLearner_test_4ac.py
@@ -24,7 +24,7 @@ class ReinforcementLearner_test_4ac(BaseReinforcementLearningModel):
         total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)
 
         policy_kwargs = dict(activation_fn=th.nn.ReLU,
-                             net_arch=[128, 128])
+                             net_arch=[64, 64])
 
         if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
            model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,
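For reference, and not part of the patch itself: a minimal sketch of how the now-configurable net_arch flows from rl_config into the stable-baselines3 policy_kwargs. The [128, 128] fallback matches the value that was hardcoded before this change, and [32, 32] is the smaller network the shared make_rl_config fixture uses to keep the tests fast.

    # Sketch only: mirrors BaseReinforcementLearningModel reading `net_arch`
    # from rl_config and ReinforcementLearner.fit() passing it to the model.
    import torch as th

    rl_config = {"net_arch": [32, 32]}  # e.g. the reduced size used by the tests

    # Default to [128, 128], the value hardcoded prior to this patch.
    net_arch = rl_config.get("net_arch", [128, 128])

    policy_kwargs = dict(activation_fn=th.nn.ReLU,
                         net_arch=net_arch)
    print(policy_kwargs)

A user config that omits net_arch therefore behaves exactly as before, while the tests shrink the network (and run a single train cycle) to cut runtime.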