From 1c56fa034f908ae005e0167830e18ef54667f1a4 Mon Sep 17 00:00:00 2001
From: robcaulk
Date: Fri, 23 Sep 2022 09:19:16 +0200
Subject: [PATCH] add test_models folder

---
 .../ReinforcementLearner_test_4ac.py          | 104 ++++++++++++++++++
 1 file changed, 104 insertions(+)
 create mode 100644 tests/freqai/test_models/ReinforcementLearner_test_4ac.py

diff --git a/tests/freqai/test_models/ReinforcementLearner_test_4ac.py b/tests/freqai/test_models/ReinforcementLearner_test_4ac.py
new file mode 100644
index 000000000..9a8f800bd
--- /dev/null
+++ b/tests/freqai/test_models/ReinforcementLearner_test_4ac.py
@@ -0,0 +1,104 @@
+import logging
+from pathlib import Path
+from typing import Any, Dict
+
+import numpy as np
+import torch as th
+
+from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
+from freqtrade.freqai.RL.Base4ActionRLEnv import Actions, Base4ActionRLEnv, Positions
+from freqtrade.freqai.RL.BaseReinforcementLearningModel import BaseReinforcementLearningModel
+
+
+logger = logging.getLogger(__name__)
+
+
+class ReinforcementLearner_test_4ac(BaseReinforcementLearningModel):
+    """
+    User created Reinforcement Learning Model prediction model.
+    """
+
+    def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
+
+        train_df = data_dictionary["train_features"]
+        total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)
+
+        policy_kwargs = dict(activation_fn=th.nn.ReLU,
+                             net_arch=[128, 128])
+
+        if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
+            model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs,
+                                    tensorboard_log=Path(
+                                        dk.full_path / "tensorboard" / dk.pair.split('/')[0]),
+                                    **self.freqai_info['model_training_parameters']
+                                    )
+        else:
+            logger.info('Continual training activated - starting training from previously '
+                        'trained agent.')
+            model = self.dd.model_dictionary[dk.pair]
+            model.set_env(self.train_env)
+
+        model.learn(
+            total_timesteps=int(total_timesteps),
+            callback=self.eval_callback
+        )
+
+        if Path(dk.data_path / "best_model.zip").is_file():
+            logger.info('Callback found a best model.')
+            best_model = self.MODELCLASS.load(dk.data_path / "best_model")
+            return best_model
+
+        logger.info("Couldn't find best model, using final model instead.")
+
+        return model
+
+    class MyRLEnv(Base4ActionRLEnv):
+        """
+        User can override any function in BaseRLEnv and gym.Env. Here the user
+        sets a custom reward based on profit and trade duration.
+ """ + + def calculate_reward(self, action): + + # first, penalize if the action is not valid + if not self._is_valid(action): + return -2 + + pnl = self.get_unrealized_profit() + rew = np.sign(pnl) * (pnl + 1) + factor = 100 + + # reward agent for entering trades + if (action in (Actions.Long_enter.value, Actions.Short_enter.value) + and self._position == Positions.Neutral): + return 25 + # discourage agent from not entering trades + if action == Actions.Neutral.value and self._position == Positions.Neutral: + return -1 + + max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300) + trade_duration = self._current_tick - self._last_trade_tick + + if trade_duration <= max_trade_duration: + factor *= 1.5 + elif trade_duration > max_trade_duration: + factor *= 0.5 + + # discourage sitting in position + if (self._position in (Positions.Short, Positions.Long) and + action == Actions.Neutral.value): + return -1 * trade_duration / max_trade_duration + + # close long + if action == Actions.Exit.value and self._position == Positions.Long: + if pnl > self.profit_aim * self.rr: + factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2) + return float(rew * factor) + + # close short + if action == Actions.Exit.value and self._position == Positions.Short: + if pnl > self.profit_aim * self.rr: + factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2) + return float(rew * factor) + + return 0.