From d1bee29b1e5b01eb3465deea1b64968660e42b82 Mon Sep 17 00:00:00 2001
From: robcaulk
Date: Wed, 24 Aug 2022 18:32:40 +0200
Subject: [PATCH] improve default reward, fix bugs in environment

---
 freqtrade/freqai/RL/Base5ActionRLEnv.py       | 40 ++++++++----
 .../RL/BaseReinforcementLearningModel.py      | 61 +++++++++----------
 .../prediction_models/ReinforcementLearner.py | 54 +++++++++++++---
 3 files changed, 102 insertions(+), 53 deletions(-)

diff --git a/freqtrade/freqai/RL/Base5ActionRLEnv.py b/freqtrade/freqai/RL/Base5ActionRLEnv.py
index 64d7061fc..9f7c52c9c 100644
--- a/freqtrade/freqai/RL/Base5ActionRLEnv.py
+++ b/freqtrade/freqai/RL/Base5ActionRLEnv.py
@@ -140,30 +140,32 @@ class Base5ActionRLEnv(gym.Env):
             if action == Actions.Neutral.value:
                 self._position = Positions.Neutral
                 trade_type = "neutral"
+                self._last_trade_tick = None
             elif action == Actions.Long_enter.value:
                 self._position = Positions.Long
                 trade_type = "long"
+                self._last_trade_tick = self._current_tick
             elif action == Actions.Short_enter.value:
                 self._position = Positions.Short
                 trade_type = "short"
+                self._last_trade_tick = self._current_tick
             elif action == Actions.Long_exit.value:
                 self._position = Positions.Neutral
                 trade_type = "neutral"
+                self._last_trade_tick = None
             elif action == Actions.Short_exit.value:
                 self._position = Positions.Neutral
                 trade_type = "neutral"
+                self._last_trade_tick = None
             else:
                 print("case not defined")

-            # Update last trade tick
-            self._last_trade_tick = self._current_tick
-
             if trade_type is not None:
                 self.trade_history.append(
                     {'price': self.current_price(), 'index': self._current_tick,
                      'type': trade_type})

-        if self._total_profit < 0.2:
+        if self._total_profit < 0.5:
             self._done = True

         self._position_history.append(self._position)
@@ -221,8 +223,7 @@ class Base5ActionRLEnv(gym.Env):
     def is_tradesignal(self, action: int):
         # trade signal
         """
-        not trade signal is :
-        Determine if the signal is non sensical
+        Determine if the signal is a trade signal (i.e. not nonsensical),
         e.g.: agent wants a Actions.Long_exit while it is in a Positions.short
         """
         return not ((action == Actions.Neutral.value and self._position == Positions.Neutral) or
@@ -237,6 +238,24 @@ class Base5ActionRLEnv(gym.Env):
                     (action == Actions.Long_exit.value and self._position == Positions.Short) or
                     (action == Actions.Long_exit.value and self._position == Positions.Neutral))

+    def _is_valid(self, action: int):
+        # trade signal
+        """
+        Determine if the signal is valid.
+        e.g.: an Actions.Long_exit while in a Positions.Neutral is not valid
+        """
+        # Agent should only try to exit if it is in position
+        if action in (Actions.Short_exit.value, Actions.Long_exit.value):
+            if self._position not in (Positions.Short, Positions.Long):
+                return False
+
+        # Agent should only try to enter if it is not in position
+        if action in (Actions.Short_enter.value, Actions.Long_enter.value):
+            if self._position != Positions.Neutral:
+                return False
+
+        return True
+
     def _is_trade(self, action: Actions):
         return ((action == Actions.Long_enter.value and self._position == Positions.Neutral) or
                 (action == Actions.Short_enter.value and self._position == Positions.Neutral))
@@ -278,13 +297,8 @@ class Base5ActionRLEnv(gym.Env):
         if self._is_trade(action) or self._done:
             pnl = self.get_unrealized_profit()

-            if self._position == Positions.Long:
-                self._total_profit = self._total_profit + self._total_profit * pnl
-                self._profits.append((self._current_tick, self._total_profit))
-                self.close_trade_profit.append(pnl)
-
-            if self._position == Positions.Short:
-                self._total_profit = self._total_profit + self._total_profit * pnl
+            if self._position in (Positions.Long, Positions.Short):
+                self._total_profit *= (1 + pnl)
                 self._profits.append((self._current_tick, self._total_profit))
                 self.close_trade_profit.append(pnl)

diff --git a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
index 0f0120365..84d19f269 100644
--- a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
+++ b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
@@ -19,7 +19,6 @@ from typing import Callable
 from datetime import datetime, timezone
 from stable_baselines3.common.utils import set_random_seed
 import gym
-from pathlib import Path

 logger = logging.getLogger(__name__)
 torch.multiprocessing.set_sharing_strategy('file_system')
@@ -112,27 +111,14 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         test_df = data_dictionary["test_features"]
         eval_freq = self.freqai_info["rl_config"]["eval_cycles"] * len(test_df)

-        # environments
-        if not self.train_env:
-            self.train_env = MyRLEnv(df=train_df, prices=prices_train, window_size=self.CONV_WIDTH,
-                                     reward_kwargs=self.reward_params, config=self.config)
-            self.eval_env = Monitor(MyRLEnv(df=test_df, prices=prices_test,
-                                            window_size=self.CONV_WIDTH,
-                                            reward_kwargs=self.reward_params, config=self.config))
-            self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
-                                              render=False, eval_freq=eval_freq,
-                                              best_model_save_path=str(dk.data_path))
-        else:
-            self.train_env.reset()
-            self.eval_env.reset()
-            self.train_env.reset_env(train_df, prices_train, self.CONV_WIDTH, self.reward_params)
-            self.eval_env.reset_env(test_df, prices_test, self.CONV_WIDTH, self.reward_params)
-            # self.eval_callback.eval_env = self.eval_env
-            # self.eval_callback.best_model_save_path = str(dk.data_path)
-            # self.eval_callback._init_callback()
-            self.eval_callback.__init__(self.eval_env, deterministic=True,
-                                        render=False, eval_freq=eval_freq,
-                                        best_model_save_path=str(dk.data_path))
+        self.train_env = MyRLEnv(df=train_df, prices=prices_train, window_size=self.CONV_WIDTH,
+                                 reward_kwargs=self.reward_params, config=self.config)
+        self.eval_env = Monitor(MyRLEnv(df=test_df, prices=prices_test,
+                                        window_size=self.CONV_WIDTH,
+                                        reward_kwargs=self.reward_params, config=self.config))
+        self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
+                                          render=False, eval_freq=eval_freq,
+                                          best_model_save_path=str(dk.data_path))

     @abstractmethod
     def fit_rl(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen):
@@ -284,30 +270,43 @@ class MyRLEnv(Base5ActionRLEnv):

     def calculate_reward(self, action):

-        if self._last_trade_tick is None:
-            return 0.
+        # first, penalize if the action is not valid
+        if not self._is_valid(action):
+            return -15

         pnl = self.get_unrealized_profit()
-        max_trade_duration = self.rl_config.get('max_trade_duration_candles', 100)
+        rew = np.sign(pnl) * (pnl + 1)
+        factor = 100
+
+        # reward agent for entering trades
+        if action in (Actions.Long_enter.value, Actions.Short_enter.value):
+            return 25
+        # discourage agent from not entering trades
+        if action == Actions.Neutral.value and self._position == Positions.Neutral:
+            return -15
+
+        max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
         trade_duration = self._current_tick - self._last_trade_tick
-        factor = 1

         if trade_duration <= max_trade_duration:
             factor *= 1.5
         elif trade_duration > max_trade_duration:
             factor *= 0.5

+        # discourage sitting in position
+        if action == Actions.Neutral.value and self._position in (Positions.Short, Positions.Long):
+            return -50 * trade_duration / max_trade_duration
+
         # close long
         if action == Actions.Long_exit.value and self._position == Positions.Long:
-            if self.close_trade_profit and self.close_trade_profit[-1] > self.profit_aim * self.rr:
+            if pnl > self.profit_aim * self.rr:
                 factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
-            return float(pnl * factor)
+            return float(rew * factor)

         # close short
         if action == Actions.Short_exit.value and self._position == Positions.Short:
-            factor = 1
-            if self.close_trade_profit and self.close_trade_profit[-1] > self.profit_aim * self.rr:
+            if pnl > self.profit_aim * self.rr:
                 factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
-            return float(pnl * factor)
+            return float(rew * factor)

         return 0.
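Note (reviewer illustration, not part of the patch): the reshaped reward above can be sanity-checked in isolation. The sketch below re-implements only the exit branch of calculate_reward(); example_exit_reward and the PROFIT_AIM, RR and WIN_REWARD_FACTOR constants are hypothetical stand-ins for self.profit_aim, self.rr and the 'win_reward_factor' entry of model_reward_parameters, chosen only to show how the np.sign(pnl) * (pnl + 1) term, the duration factor and the win bonus combine.

import numpy as np

# Hypothetical example values -- stand-ins for self.profit_aim, self.rr and
# model_reward_parameters['win_reward_factor'] in the environment above.
PROFIT_AIM = 0.025
RR = 1.0
WIN_REWARD_FACTOR = 2


def example_exit_reward(pnl: float, trade_duration: int,
                        max_trade_duration: int = 300) -> float:
    """Mirror of the close-long / close-short branch introduced above."""
    rew = np.sign(pnl) * (pnl + 1)
    factor = 100.0
    # trades shorter than max_trade_duration are boosted, overlong ones damped
    factor *= 1.5 if trade_duration <= max_trade_duration else 0.5
    # exits beyond the profit target get the extra win bonus
    if pnl > PROFIT_AIM * RR:
        factor *= WIN_REWARD_FACTOR
    return float(rew * factor)


if __name__ == "__main__":
    for pnl in (-0.02, 0.0, 0.01, 0.05):
        print(f"pnl={pnl:+.2f} -> reward={example_exit_reward(pnl, trade_duration=40):.1f}")

With these example numbers, a losing exit is penalized roughly in proportion to its size, while a winning exit past profit_aim * rr receives the win_reward_factor boost on top of the duration factor.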
diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearner.py b/freqtrade/freqai/prediction_models/ReinforcementLearner.py
index f7f016ab4..2d1cafab5 100644
--- a/freqtrade/freqai/prediction_models/ReinforcementLearner.py
+++ b/freqtrade/freqai/prediction_models/ReinforcementLearner.py
@@ -6,6 +6,10 @@ from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
 from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions
 from freqtrade.freqai.RL.BaseReinforcementLearningModel import BaseReinforcementLearningModel
 from pathlib import Path
+from pandas import DataFrame
+from stable_baselines3.common.callbacks import EvalCallback
+from stable_baselines3.common.monitor import Monitor
+import numpy as np

 logger = logging.getLogger(__name__)

@@ -49,6 +53,25 @@ class ReinforcementLearner(BaseReinforcementLearningModel):

         return model

+    def set_train_and_eval_environments(self, data_dictionary: Dict[str, DataFrame],
+                                        prices_train: DataFrame, prices_test: DataFrame,
+                                        dk: FreqaiDataKitchen):
+        """
+        Users can override this if they are using a custom MyRLEnv
+        """
+        train_df = data_dictionary["train_features"]
+        test_df = data_dictionary["test_features"]
+        eval_freq = self.freqai_info["rl_config"]["eval_cycles"] * len(test_df)
+
+        self.train_env = MyRLEnv(df=train_df, prices=prices_train, window_size=self.CONV_WIDTH,
+                                 reward_kwargs=self.reward_params, config=self.config)
+        self.eval_env = Monitor(MyRLEnv(df=test_df, prices=prices_test,
+                                        window_size=self.CONV_WIDTH,
+                                        reward_kwargs=self.reward_params, config=self.config))
+        self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
+                                          render=False, eval_freq=eval_freq,
+                                          best_model_save_path=str(dk.data_path))
+

 class MyRLEnv(Base5ActionRLEnv):
     """
@@ -58,30 +81,43 @@ class MyRLEnv(Base5ActionRLEnv):

     def calculate_reward(self, action):

-        if self._last_trade_tick is None:
-            return 0.
+        # first, penalize if the action is not valid
+        if not self._is_valid(action):
+            return -15

         pnl = self.get_unrealized_profit()
-        max_trade_duration = self.rl_config.get('max_trade_duration_candles', 100)
+        rew = np.sign(pnl) * (pnl + 1)
+        factor = 100
+
+        # reward agent for entering trades
+        if action in (Actions.Long_enter.value, Actions.Short_enter.value):
+            return 25
+        # discourage agent from not entering trades
+        if action == Actions.Neutral.value and self._position == Positions.Neutral:
+            return -15
+
+        max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
         trade_duration = self._current_tick - self._last_trade_tick
-        factor = 1

         if trade_duration <= max_trade_duration:
             factor *= 1.5
         elif trade_duration > max_trade_duration:
             factor *= 0.5

+        # discourage sitting in position
+        if action == Actions.Neutral.value and self._position in (Positions.Short, Positions.Long):
+            return -50 * trade_duration / max_trade_duration
+
         # close long
         if action == Actions.Long_exit.value and self._position == Positions.Long:
-            if self.close_trade_profit and self.close_trade_profit[-1] > self.profit_aim * self.rr:
+            if pnl > self.profit_aim * self.rr:
                 factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
-            return float(pnl * factor)
+            return float(rew * factor)

         # close short
         if action == Actions.Short_exit.value and self._position == Positions.Short:
-            factor = 1
-            if self.close_trade_profit and self.close_trade_profit[-1] > self.profit_aim * self.rr:
+            if pnl > self.profit_aim * self.rr:
                 factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
-            return float(pnl * factor)
+            return float(rew * factor)

         return 0.
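Note (reviewer illustration, not part of the patch): the new set_train_and_eval_environments() hook is what lets a user model swap in its own environment class. A minimal sketch of that pattern follows, reusing only imports, constructor arguments and attributes that appear in the patch itself; MyCustomEnv, MyCustomLearner and the simplified reward are illustrative names and logic, not part of freqtrade.

from typing import Dict

from pandas import DataFrame
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor

from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions


class MyCustomEnv(Base5ActionRLEnv):
    """Illustrative environment: keep the validity gate, simplify the payout."""

    def calculate_reward(self, action):
        if not self._is_valid(action):
            return -15
        # pay out the (scaled) unrealized profit only when a position is closed
        if (action in (Actions.Long_exit.value, Actions.Short_exit.value)
                and self._position in (Positions.Long, Positions.Short)):
            return float(self.get_unrealized_profit() * 100)
        return 0.


class MyCustomLearner(ReinforcementLearner):
    """Illustrative user model: same learner, different environment class."""

    def set_train_and_eval_environments(self, data_dictionary: Dict[str, DataFrame],
                                        prices_train: DataFrame, prices_test: DataFrame,
                                        dk: FreqaiDataKitchen):
        train_df = data_dictionary["train_features"]
        test_df = data_dictionary["test_features"]
        eval_freq = self.freqai_info["rl_config"]["eval_cycles"] * len(test_df)

        self.train_env = MyCustomEnv(df=train_df, prices=prices_train,
                                     window_size=self.CONV_WIDTH,
                                     reward_kwargs=self.reward_params, config=self.config)
        self.eval_env = Monitor(MyCustomEnv(df=test_df, prices=prices_test,
                                            window_size=self.CONV_WIDTH,
                                            reward_kwargs=self.reward_params,
                                            config=self.config))
        self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
                                          render=False, eval_freq=eval_freq,
                                          best_model_save_path=str(dk.data_path))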