clean RL tests to avoid dir pollution and increase speed
commit 8d7adfabe9 (parent 3e258e000e)
@@ -63,6 +63,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         self.MODELCLASS = getattr(mod, self.model_type)
         self.policy_type = self.freqai_info['rl_config']['policy_type']
         self.unset_outlier_removal()
+        self.net_arch = self.rl_config.get('net_arch', [128, 128])
 
     def unset_outlier_removal(self):
         """
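Note: the new attribute reads the network architecture from the user's rl_config section and falls back to the previously hard-coded [128, 128]. A minimal standalone sketch of the intended lookup (names mirror the diff; the surrounding class is not reproduced here):

    # net_arch comes from the "rl_config" section of the freqai config; [128, 128] is the default
    rl_config = {"net_arch": [32, 32]}                  # e.g. what the test config now sets
    net_arch = rl_config.get('net_arch', [128, 128])
    assert net_arch == [32, 32]                         # without the key, this resolves to [128, 128]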
@@ -287,6 +288,17 @@ class BaseReinforcementLearningModel(IFreqaiModel):
 
         return model
 
+    def _on_stop(self):
+        """
+        Hook called on bot shutdown. Close SubprocVecEnv subprocesses for clean shutdown.
+        """
+
+        if self.train_env:
+            self.train_env.close()
+
+        if self.eval_env:
+            self.eval_env.close()
+
     # Nested class which can be overridden by user to customize further
     class MyRLEnv(Base5ActionRLEnv):
         """
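Note: _on_stop() moves up into the base RL model (it is removed from the multiproc subclass further down), so every RL model gets a clean SubprocVecEnv shutdown. A minimal sketch of how a caller could invoke the hook at bot shutdown; shutdown_freqai() is a hypothetical helper, not part of this diff:

    def shutdown_freqai(model) -> None:
        # Hypothetical shutdown path: call the hook if the model defines it, so
        # SubprocVecEnv worker processes are closed instead of lingering after exit.
        on_stop = getattr(model, "_on_stop", None)
        if callable(on_stop):
            on_stop()   # closes model.train_env / model.eval_env when they are set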
@@ -31,7 +31,7 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
         total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)
 
         policy_kwargs = dict(activation_fn=th.nn.ReLU,
-                             net_arch=[128, 128])
+                             net_arch=self.net_arch)
 
         if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
             model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,
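Note: both prediction models now size the network from self.net_arch instead of a hard-coded list, which is what lets the tests request a small [32, 32] net. A rough sketch of how policy_kwargs reaches the stable-baselines3 model, assuming stable-baselines3 and gym are installed and using CartPole as a stand-in for the trading environment:

    import gym
    import torch as th
    from stable_baselines3 import PPO

    net_arch = [32, 32]                                   # e.g. rl_config["net_arch"]
    policy_kwargs = dict(activation_fn=th.nn.ReLU,        # same structure as in the diff
                         net_arch=net_arch)
    model = PPO("MlpPolicy", gym.make("CartPole-v1"),
                policy_kwargs=policy_kwargs, verbose=0)   # a smaller net_arch trains faster in tests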
@@ -28,7 +28,7 @@ class ReinforcementLearner_multiproc(BaseReinforcementLearningModel):
 
         # model arch
         policy_kwargs = dict(activation_fn=th.nn.ReLU,
-                             net_arch=[128, 128])
+                             net_arch=self.net_arch)
 
         if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
             model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,
@@ -87,14 +87,3 @@ class ReinforcementLearner_multiproc(BaseReinforcementLearningModel):
         self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
                                           render=False, eval_freq=len(train_df),
                                           best_model_save_path=str(dk.data_path))
-
-    def _on_stop(self):
-        """
-        Hook called on bot shutdown. Close SubprocVecEnv subprocesses for clean shutdown.
-        """
-
-        if self.train_env:
-            self.train_env.close()
-
-        if self.eval_env:
-            self.eval_env.close()
@@ -58,6 +58,30 @@ def freqai_conf(default_conf, tmpdir):
     return freqaiconf
 
 
+def make_rl_config(conf):
+    conf.update({"strategy": "freqai_rl_test_strat"})
+    conf["freqai"].update({"model_training_parameters": {
+        "learning_rate": 0.00025,
+        "gamma": 0.9,
+        "verbose": 1
+    }})
+    conf["freqai"]["rl_config"] = {
+        "train_cycles": 1,
+        "thread_count": 2,
+        "max_trade_duration_candles": 300,
+        "model_type": "PPO",
+        "policy_type": "MlpPolicy",
+        "max_training_drawdown_pct": 0.5,
+        "net_arch": [32, 32],
+        "model_reward_parameters": {
+            "rr": 1,
+            "profit_aim": 0.02,
+            "win_reward_factor": 2
+        }}
+
+    return conf
+
+
 def get_patched_data_kitchen(mocker, freqaiconf):
     dk = FreqaiDataKitchen(freqaiconf)
     return dk
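Note: make_rl_config() centralises the RL test configuration (one training cycle, a small [32, 32] network) so the per-test copies below can be dropped and the tests run faster. A sketch of the intended usage; the test name is illustrative only:

    def test_rl_smoke(freqai_conf):                       # hypothetical test using the fixture above
        freqai_conf = make_rl_config(freqai_conf)
        assert freqai_conf["strategy"] == "freqai_rl_test_strat"
        assert freqai_conf["freqai"]["rl_config"]["net_arch"] == [32, 32]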
@@ -14,7 +14,7 @@ from freqtrade.optimize.backtesting import Backtesting
 from freqtrade.persistence import Trade
 from freqtrade.plugins.pairlistmanager import PairListManager
 from tests.conftest import get_patched_exchange, log_has_re
-from tests.freqai.conftest import get_patched_freqai_strategy
+from tests.freqai.conftest import get_patched_freqai_strategy, make_rl_config
 
 
 def is_arm() -> bool:
@@ -49,25 +49,7 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model):
 
     if 'ReinforcementLearner' in model:
         model_save_ext = 'zip'
-        freqai_conf.update({"strategy": "freqai_rl_test_strat"})
-        freqai_conf["freqai"].update({"model_training_parameters": {
-            "learning_rate": 0.00025,
-            "gamma": 0.9,
-            "verbose": 1
-        }})
-        freqai_conf["freqai"].update({"model_save_type": 'stable_baselines'})
-        freqai_conf["freqai"]["rl_config"] = {
-            "train_cycles": 1,
-            "thread_count": 2,
-            "max_trade_duration_candles": 300,
-            "model_type": "PPO",
-            "policy_type": "MlpPolicy",
-            "max_training_drawdown_pct": 0.5,
-            "model_reward_parameters": {
-                "rr": 1,
-                "profit_aim": 0.02,
-                "win_reward_factor": 2
-            }}
+        freqai_conf = make_rl_config(freqai_conf)
 
     if 'test_4ac' in model:
         freqai_conf["freqaimodel_path"] = str(Path(__file__).parents[1] / "freqai" / "test_models")
@@ -79,6 +61,7 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model):
     freqai = strategy.freqai
     freqai.live = True
     freqai.dk = FreqaiDataKitchen(freqai_conf)
+    freqai.dk.set_paths('ADA/BTC', 10000)
     timerange = TimeRange.parse_timerange("20180110-20180130")
     freqai.dd.load_all_pair_histories(timerange, freqai.dk)
 
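Note: this is the dir-pollution part of the change. The freqai_conf fixture builds its config around a pytest tmpdir, and calling set_paths() before training makes the data kitchen derive its output directories from that config; the exact directory layout described below is an assumption, not taken from this diff:

    # After the call, dk.data_path is expected to resolve under the tmpdir-backed
    # user_data_dir from the fixture, so model artifacts are cleaned up with the
    # tmpdir instead of accumulating in the working tree.
    freqai.dk.set_paths('ADA/BTC', 10000)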
@@ -204,25 +187,7 @@ def test_start_backtesting(mocker, freqai_conf, model, num_files, strat):
     freqai_conf.update({"strategy": strat})
 
     if 'ReinforcementLearner' in model:
-        freqai_conf["freqai"].update({"model_training_parameters": {
-            "learning_rate": 0.00025,
-            "gamma": 0.9,
-            "verbose": 1
-        }})
-        freqai_conf["freqai"].update({"model_save_type": 'stable_baselines'})
-        freqai_conf["freqai"]["rl_config"] = {
-            "train_cycles": 1,
-            "thread_count": 2,
-            "max_trade_duration_candles": 300,
-            "model_type": "PPO",
-            "policy_type": "MlpPolicy",
-            "max_training_drawdown_pct": 0.5,
-            "model_reward_parameters": {
-                "rr": 1,
-                "profit_aim": 0.02,
-                "win_reward_factor": 2
-            }}
+        freqai_conf = make_rl_config(freqai_conf)
 
     if 'test_4ac' in model:
         freqai_conf["freqaimodel_path"] = str(Path(__file__).parents[1] / "freqai" / "test_models")
@@ -24,7 +24,7 @@ class ReinforcementLearner_test_4ac(BaseReinforcementLearningModel):
         total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)
 
         policy_kwargs = dict(activation_fn=th.nn.ReLU,
-                             net_arch=[128, 128])
+                             net_arch=[64, 64])
 
         if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
             model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,