From a8c9aa01fb3c11330618f26efa822bfe9394124e Mon Sep 17 00:00:00 2001
From: Emre
Date: Fri, 16 Dec 2022 22:31:44 +0300
Subject: [PATCH] Add 3ac test

---
 tests/freqai/test_freqai_interface.py    |  5 +-
 .../ReinforcementLearner_test_3ac.py      | 65 ++++++++++++++++++++
 2 files changed, 68 insertions(+), 2 deletions(-)
 create mode 100644 tests/freqai/test_models/ReinforcementLearner_test_3ac.py

diff --git a/tests/freqai/test_freqai_interface.py b/tests/freqai/test_freqai_interface.py
index f19acb018..2c58d4c0a 100644
--- a/tests/freqai/test_freqai_interface.py
+++ b/tests/freqai/test_freqai_interface.py
@@ -34,6 +34,7 @@ def is_mac() -> bool:
     ('CatboostRegressor', False, False, False),
     ('ReinforcementLearner', False, True, False),
     ('ReinforcementLearner_multiproc', False, False, False),
+    ('ReinforcementLearner_test_3ac', False, False, False),
     ('ReinforcementLearner_test_4ac', False, False, False)
     ])
 def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca, dbscan, float32):
@@ -58,7 +59,7 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca,
         freqai_conf['freqai']['feature_parameters'].update({"use_SVM_to_remove_outliers": True})
         freqai_conf['freqai']['data_split_parameters'].update({'shuffle': True})
 
-    if 'test_4ac' in model:
+    if 'test_3ac' in model or 'test_4ac' in model:
         freqai_conf["freqaimodel_path"] = str(Path(__file__).parents[1] / "freqai" / "test_models")
 
     if 'ReinforcementLearner' in model:
@@ -68,7 +69,7 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca,
         freqai_conf['freqai']['feature_parameters'].update({"use_SVM_to_remove_outliers": True})
         freqai_conf['freqai']['data_split_parameters'].update({'shuffle': True})
 
-    if 'test_4ac' in model:
+    if 'test_3ac' in model or 'test_4ac' in model:
         freqai_conf["freqaimodel_path"] = str(Path(__file__).parents[1] / "freqai" / "test_models")
 
     strategy = get_patched_freqai_strategy(mocker, freqai_conf)
diff --git a/tests/freqai/test_models/ReinforcementLearner_test_3ac.py b/tests/freqai/test_models/ReinforcementLearner_test_3ac.py
new file mode 100644
index 000000000..c267c76a8
--- /dev/null
+++ b/tests/freqai/test_models/ReinforcementLearner_test_3ac.py
@@ -0,0 +1,65 @@
+import logging
+
+import numpy as np
+
+from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
+from freqtrade.freqai.RL.Base3ActionRLEnv import Actions, Base3ActionRLEnv, Positions
+
+
+logger = logging.getLogger(__name__)
+
+
+class ReinforcementLearner_test_3ac(ReinforcementLearner):
+    """
+    User created Reinforcement Learning Model prediction model.
+    """
+
+    class MyRLEnv(Base3ActionRLEnv):
+        """
+        User can override any function in BaseRLEnv and gym.Env. Here the user
+        sets a custom reward based on profit and trade duration.
+        """
+
+        def calculate_reward(self, action: int) -> float:
+
+            # first, penalize if the action is not valid
+            if not self._is_valid(action):
+                return -2
+
+            pnl = self.get_unrealized_profit()
+            rew = np.sign(pnl) * (pnl + 1)
+            factor = 100.
+
+            # reward agent for entering trades
+            if (action in (Actions.Buy.value, Actions.Sell.value)
+                    and self._position == Positions.Neutral):
+                return 25
+            # discourage agent from not entering trades
+            if action == Actions.Neutral.value and self._position == Positions.Neutral:
+                return -1
+
+            max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
+            trade_duration = self._current_tick - self._last_trade_tick  # type: ignore
+
+            if trade_duration <= max_trade_duration:
+                factor *= 1.5
+            elif trade_duration > max_trade_duration:
+                factor *= 0.5
+
+            # discourage sitting in position
+            if self._position in (Positions.Short, Positions.Long) and (
+                action == Actions.Neutral.value
+                or (action == Actions.Sell.value and self._position == Positions.Short)
+                or (action == Actions.Buy.value and self._position == Positions.Long)
+            ):
+                return -1 * trade_duration / max_trade_duration
+
+            # close position
+            if (action == Actions.Buy.value and self._position == Positions.Short) or (
+                action == Actions.Sell.value and self._position == Positions.Long
+            ):
+                if pnl > self.profit_aim * self.rr:
+                    factor *= self.rl_config["model_reward_parameters"].get("win_reward_factor", 2)
+                return float(rew * factor)
+
+            return 0.
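
Note (not applied by git am): the reward arithmetic in the "close position"
branch above can be sanity-checked in isolation. The standalone sketch below
mirrors just that branch; the helper name shaped_close_reward, the defaults
for profit_aim and rr, and the sample inputs are hypothetical assumptions for
illustration, not freqtrade API (only the win_reward_factor default of 2
comes from the patch itself).

    import numpy as np

    def shaped_close_reward(pnl: float, trade_duration: int,
                            max_trade_duration: int = 300,
                            profit_aim: float = 0.025, rr: float = 1.0,
                            win_reward_factor: float = 2.0) -> float:
        # Base reward scales with unrealized profit, as in calculate_reward.
        rew = np.sign(pnl) * (pnl + 1)
        factor = 100.
        # Duration shaping: boost trades closed within the limit, damp slow ones.
        factor *= 1.5 if trade_duration <= max_trade_duration else 0.5
        # Extra boost when the win exceeds the profit aim (hypothetical defaults).
        if pnl > profit_aim * rr:
            factor *= win_reward_factor
        return float(rew * factor)

    # A 5% win closed after 40 candles: 1.05 * 100 * 1.5 * 2 -> 315.0
    print(shaped_close_reward(0.05, 40))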