improve default reward, fix bugs in environment
		| @@ -140,30 +140,32 @@ class Base5ActionRLEnv(gym.Env): | ||||
|             if action == Actions.Neutral.value: | ||||
|                 self._position = Positions.Neutral | ||||
|                 trade_type = "neutral" | ||||
|                 self._last_trade_tick = None | ||||
|             elif action == Actions.Long_enter.value: | ||||
|                 self._position = Positions.Long | ||||
|                 trade_type = "long" | ||||
|                 self._last_trade_tick = self._current_tick | ||||
|             elif action == Actions.Short_enter.value: | ||||
|                 self._position = Positions.Short | ||||
|                 trade_type = "short" | ||||
|                 self._last_trade_tick = self._current_tick | ||||
|             elif action == Actions.Long_exit.value: | ||||
|                 self._position = Positions.Neutral | ||||
|                 trade_type = "neutral" | ||||
|                 self._last_trade_tick = None | ||||
|             elif action == Actions.Short_exit.value: | ||||
|                 self._position = Positions.Neutral | ||||
|                 trade_type = "neutral" | ||||
|                 self._last_trade_tick = None | ||||
|             else: | ||||
|                 print("case not defined") | ||||
|  | ||||
|             # Update last trade tick | ||||
|             self._last_trade_tick = self._current_tick | ||||
|  | ||||
|             if trade_type is not None: | ||||
|                 self.trade_history.append( | ||||
|                     {'price': self.current_price(), 'index': self._current_tick, | ||||
|                      'type': trade_type}) | ||||
|  | ||||
|         if self._total_profit < 0.2: | ||||
|         if self._total_profit < 0.5: | ||||
|             self._done = True | ||||
|  | ||||
|         self._position_history.append(self._position) | ||||
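The early-stop threshold on cumulative profit is tightened from 0.2 to 0.5. Assuming _total_profit starts at 1.0 in reset() (not shown in this hunk), an episode now ends after roughly a 50% cumulative loss instead of an 80% one. A quick hand-check using the multiplicative update introduced later in this commit, with illustrative values only:

    total_profit = 1.0                # assumed starting value (set in reset(), not shown here)
    for pnl in (-0.2, -0.2, -0.2):    # three illustrative losing trades
        total_profit *= (1 + pnl)     # 0.8, 0.64, 0.512
    assert total_profit > 0.5         # still above the new threshold; one more such loss ends the episode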
| @@ -221,8 +223,7 @@ class Base5ActionRLEnv(gym.Env): | ||||
|     def is_tradesignal(self, action: int): | ||||
|         # trade signal | ||||
|         """ | ||||
|         not trade signal is : | ||||
|         Determine if the signal is non sensical | ||||
|         Determine if the signal is a trade signal | ||||
|         e.g.: agent wants a Actions.Long_exit while it is in a Positions.short | ||||
|         """ | ||||
|         return not ((action == Actions.Neutral.value and self._position == Positions.Neutral) or | ||||
| @@ -237,6 +238,24 @@ class Base5ActionRLEnv(gym.Env): | ||||
|                     (action == Actions.Long_exit.value and self._position == Positions.Short) or | ||||
|                     (action == Actions.Long_exit.value and self._position == Positions.Neutral)) | ||||
|  | ||||
|     def _is_valid(self, action: int): | ||||
|         # trade signal | ||||
|         """ | ||||
|         Determine if the signal is valid. | ||||
|         e.g.: agent wants a Actions.Long_exit while it is in a Positions.short | ||||
|         """ | ||||
|         # Agent should only try to exit if it is in position | ||||
|         if action in (Actions.Short_exit.value, Actions.Long_exit.value): | ||||
|             if self._position not in (Positions.Short, Positions.Long): | ||||
|                 return False | ||||
|  | ||||
|         # Agent should only try to enter if it is not in position | ||||
|         if action in (Actions.Short_enter.value, Actions.Long_enter.value): | ||||
|             if self._position != Positions.Neutral: | ||||
|                 return False | ||||
|  | ||||
|         return True | ||||
|  | ||||
|     def _is_trade(self, action: Actions): | ||||
|         return ((action == Actions.Long_enter.value and self._position == Positions.Neutral) or | ||||
|                 (action == Actions.Short_enter.value and self._position == Positions.Neutral)) | ||||
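A quick illustration of the gating the new _is_valid adds (a minimal sketch; env stands for an already constructed Base5ActionRLEnv instance, and setting _position directly is for demonstration only):

    env._position = Positions.Neutral
    assert env._is_valid(Actions.Long_enter.value)        # entries are valid while flat
    assert not env._is_valid(Actions.Long_exit.value)     # nothing to exit yet

    env._position = Positions.Long
    assert env._is_valid(Actions.Long_exit.value)         # exits are valid while in a position
    assert not env._is_valid(Actions.Short_enter.value)   # cannot enter while already in a position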
| @@ -278,13 +297,8 @@ class Base5ActionRLEnv(gym.Env): | ||||
|         if self._is_trade(action) or self._done: | ||||
|             pnl = self.get_unrealized_profit() | ||||
|  | ||||
|             if self._position == Positions.Long: | ||||
|                 self._total_profit = self._total_profit + self._total_profit * pnl | ||||
|                 self._profits.append((self._current_tick, self._total_profit)) | ||||
|                 self.close_trade_profit.append(pnl) | ||||
|  | ||||
|             if self._position == Positions.Short: | ||||
|                 self._total_profit = self._total_profit + self._total_profit * pnl | ||||
|             if self._position in (Positions.Long, Positions.Short): | ||||
|                 self._total_profit *= (1 + pnl) | ||||
|                 self._profits.append((self._current_tick, self._total_profit)) | ||||
|                 self.close_trade_profit.append(pnl) | ||||
|  | ||||
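The long and short exit branches are merged here because both applied the same compounding update: x + x * pnl is algebraically x * (1 + pnl). A short check with illustrative values:

    total_profit, pnl = 1.0, 0.03                  # illustrative values
    old_form = total_profit + total_profit * pnl   # per-position branches before this commit
    new_form = total_profit * (1 + pnl)            # consolidated form after this commit
    assert abs(old_form - new_form) < 1e-12        # both equal 1.03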
|   | ||||
| @@ -19,7 +19,6 @@ from typing import Callable | ||||
| from datetime import datetime, timezone | ||||
| from stable_baselines3.common.utils import set_random_seed | ||||
| import gym | ||||
| from pathlib import Path | ||||
| logger = logging.getLogger(__name__) | ||||
|  | ||||
| torch.multiprocessing.set_sharing_strategy('file_system') | ||||
| @@ -112,27 +111,14 @@ class BaseReinforcementLearningModel(IFreqaiModel): | ||||
|         test_df = data_dictionary["test_features"] | ||||
|         eval_freq = self.freqai_info["rl_config"]["eval_cycles"] * len(test_df) | ||||
|  | ||||
|         # environments | ||||
|         if not self.train_env: | ||||
|             self.train_env = MyRLEnv(df=train_df, prices=prices_train, window_size=self.CONV_WIDTH, | ||||
|                                      reward_kwargs=self.reward_params, config=self.config) | ||||
|             self.eval_env = Monitor(MyRLEnv(df=test_df, prices=prices_test, | ||||
|                                     window_size=self.CONV_WIDTH, | ||||
|                                     reward_kwargs=self.reward_params, config=self.config)) | ||||
|             self.eval_callback = EvalCallback(self.eval_env, deterministic=True, | ||||
|                                               render=False, eval_freq=eval_freq, | ||||
|                                               best_model_save_path=str(dk.data_path)) | ||||
|         else: | ||||
|             self.train_env.reset() | ||||
|             self.eval_env.reset() | ||||
|             self.train_env.reset_env(train_df, prices_train, self.CONV_WIDTH, self.reward_params) | ||||
|             self.eval_env.reset_env(test_df, prices_test, self.CONV_WIDTH, self.reward_params) | ||||
|             # self.eval_callback.eval_env = self.eval_env | ||||
|             # self.eval_callback.best_model_save_path = str(dk.data_path) | ||||
|             # self.eval_callback._init_callback() | ||||
|             self.eval_callback.__init__(self.eval_env, deterministic=True, | ||||
|                                         render=False, eval_freq=eval_freq, | ||||
|                                         best_model_save_path=str(dk.data_path)) | ||||
|         self.train_env = MyRLEnv(df=train_df, prices=prices_train, window_size=self.CONV_WIDTH, | ||||
|                                  reward_kwargs=self.reward_params, config=self.config) | ||||
|         self.eval_env = Monitor(MyRLEnv(df=test_df, prices=prices_test, | ||||
|                                 window_size=self.CONV_WIDTH, | ||||
|                                 reward_kwargs=self.reward_params, config=self.config)) | ||||
|         self.eval_callback = EvalCallback(self.eval_env, deterministic=True, | ||||
|                                           render=False, eval_freq=eval_freq, | ||||
|                                           best_model_save_path=str(dk.data_path)) | ||||
|  | ||||
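For reference, a sketch of the rl_config keys read by the code in this file; the key names are taken from this diff, the values are illustrative only:

    rl_config = {
        "eval_cycles": 5,                    # eval_freq = eval_cycles * len(test_df)
        "max_trade_duration_candles": 300,   # fallback default used by calculate_reward below
        "model_reward_parameters": {
            "win_reward_factor": 2,          # applied on exit when pnl > profit_aim * rr
        },
    }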
|     @abstractmethod | ||||
|     def fit_rl(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen): | ||||
| @@ -284,30 +270,43 @@ class MyRLEnv(Base5ActionRLEnv): | ||||
|  | ||||
|     def calculate_reward(self, action): | ||||
|  | ||||
|         if self._last_trade_tick is None: | ||||
|             return 0. | ||||
|         # first, penalize if the action is not valid | ||||
|         if not self._is_valid(action): | ||||
|             return -15 | ||||
|  | ||||
|         pnl = self.get_unrealized_profit() | ||||
|         max_trade_duration = self.rl_config.get('max_trade_duration_candles', 100) | ||||
|         rew = np.sign(pnl) * (pnl + 1) | ||||
|         factor = 100 | ||||
|  | ||||
|         # reward agent for entering trades | ||||
|         if action in (Actions.Long_enter.value, Actions.Short_enter.value): | ||||
|             return 25 | ||||
|         # discourage agent from not entering trades | ||||
|         if action == Actions.Neutral.value and self._position == Positions.Neutral: | ||||
|             return -15 | ||||
|  | ||||
|         max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300) | ||||
|         trade_duration = self._current_tick - self._last_trade_tick | ||||
|  | ||||
|         factor = 1 | ||||
|         if trade_duration <= max_trade_duration: | ||||
|             factor *= 1.5 | ||||
|         elif trade_duration > max_trade_duration: | ||||
|             factor *= 0.5 | ||||
|  | ||||
|         # discourage sitting in position | ||||
|         if self._position in (Positions.Short, Positions.Long): | ||||
|             return -50 * trade_duration / max_trade_duration | ||||
|  | ||||
|         # close long | ||||
|         if action == Actions.Long_exit.value and self._position == Positions.Long: | ||||
|             if self.close_trade_profit and self.close_trade_profit[-1] > self.profit_aim * self.rr: | ||||
|             if pnl > self.profit_aim * self.rr: | ||||
|                 factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2) | ||||
|             return float(pnl * factor) | ||||
|             return float(rew * factor) | ||||
|  | ||||
|         # close short | ||||
|         if action == Actions.Short_exit.value and self._position == Positions.Short: | ||||
|             factor = 1 | ||||
|             if self.close_trade_profit and self.close_trade_profit[-1] > self.profit_aim * self.rr: | ||||
|             if pnl > self.profit_aim * self.rr: | ||||
|                 factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2) | ||||
|             return float(pnl * factor) | ||||
|             return float(rew * factor) | ||||
|  | ||||
|         return 0. | ||||
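For orientation, a hand-worked evaluation of the new close-long branch in isolation, under assumed inputs (pnl = 0.02, profit_aim * rr = 0.015, win_reward_factor = 2, trade duration within max_trade_duration); this only traces the code above and adds no new logic:

    import numpy as np

    pnl = 0.02                         # assumed unrealized profit at exit
    rew = np.sign(pnl) * (pnl + 1)     # 1.02
    factor = 100                       # base factor
    factor *= 1.5                      # trade_duration <= max_trade_duration
    factor *= 2                        # win_reward_factor, since pnl > profit_aim * rr
    reward = float(rew * factor)       # 1.02 * 300 = 306.0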
|   | ||||
| @@ -6,6 +6,10 @@ from freqtrade.freqai.data_kitchen import FreqaiDataKitchen | ||||
| from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions | ||||
| from freqtrade.freqai.RL.BaseReinforcementLearningModel import BaseReinforcementLearningModel | ||||
| from pathlib import Path | ||||
| from pandas import DataFrame | ||||
| from stable_baselines3.common.callbacks import EvalCallback | ||||
| from stable_baselines3.common.monitor import Monitor | ||||
| import numpy as np | ||||
|  | ||||
| logger = logging.getLogger(__name__) | ||||
|  | ||||
| @@ -49,6 +53,25 @@ class ReinforcementLearner(BaseReinforcementLearningModel): | ||||
|  | ||||
|         return model | ||||
|  | ||||
|     def set_train_and_eval_environments(self, data_dictionary: Dict[str, DataFrame], | ||||
|                                         prices_train: DataFrame, prices_test: DataFrame, | ||||
|                                         dk: FreqaiDataKitchen): | ||||
|         """ | ||||
|         User can override this if they are using a custom MyRLEnv | ||||
|         """ | ||||
|         train_df = data_dictionary["train_features"] | ||||
|         test_df = data_dictionary["test_features"] | ||||
|         eval_freq = self.freqai_info["rl_config"]["eval_cycles"] * len(test_df) | ||||
|  | ||||
|         self.train_env = MyRLEnv(df=train_df, prices=prices_train, window_size=self.CONV_WIDTH, | ||||
|                                  reward_kwargs=self.reward_params, config=self.config) | ||||
|         self.eval_env = Monitor(MyRLEnv(df=test_df, prices=prices_test, | ||||
|                                 window_size=self.CONV_WIDTH, | ||||
|                                 reward_kwargs=self.reward_params, config=self.config)) | ||||
|         self.eval_callback = EvalCallback(self.eval_env, deterministic=True, | ||||
|                                           render=False, eval_freq=eval_freq, | ||||
|                                           best_model_save_path=str(dk.data_path)) | ||||
|  | ||||
|  | ||||
| class MyRLEnv(Base5ActionRLEnv): | ||||
|     """ | ||||
| @@ -58,30 +81,43 @@ class MyRLEnv(Base5ActionRLEnv): | ||||
|  | ||||
|     def calculate_reward(self, action): | ||||
|  | ||||
|         if self._last_trade_tick is None: | ||||
|             return 0. | ||||
|         # first, penalize if the action is not valid | ||||
|         if not self._is_valid(action): | ||||
|             return -15 | ||||
|  | ||||
|         pnl = self.get_unrealized_profit() | ||||
|         max_trade_duration = self.rl_config.get('max_trade_duration_candles', 100) | ||||
|         rew = np.sign(pnl) * (pnl + 1) | ||||
|         factor = 100 | ||||
|  | ||||
|         # reward agent for entering trades | ||||
|         if action in (Actions.Long_enter.value, Actions.Short_enter.value): | ||||
|             return 25 | ||||
|         # discourage agent from not entering trades | ||||
|         if action == Actions.Neutral.value and self._position == Positions.Neutral: | ||||
|             return -15 | ||||
|  | ||||
|         max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300) | ||||
|         trade_duration = self._current_tick - self._last_trade_tick | ||||
|  | ||||
|         factor = 1 | ||||
|         if trade_duration <= max_trade_duration: | ||||
|             factor *= 1.5 | ||||
|         elif trade_duration > max_trade_duration: | ||||
|             factor *= 0.5 | ||||
|  | ||||
|         # discourage sitting in position | ||||
|         if self._position in (Positions.Short, Positions.Long): | ||||
|             return -50 * trade_duration / max_trade_duration | ||||
|  | ||||
|         # close long | ||||
|         if action == Actions.Long_exit.value and self._position == Positions.Long: | ||||
|             if self.close_trade_profit and self.close_trade_profit[-1] > self.profit_aim * self.rr: | ||||
|             if pnl > self.profit_aim * self.rr: | ||||
|                 factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2) | ||||
|             return float(pnl * factor) | ||||
|             return float(rew * factor) | ||||
|  | ||||
|         # close short | ||||
|         if action == Actions.Short_exit.value and self._position == Positions.Short: | ||||
|             factor = 1 | ||||
|             if self.close_trade_profit and self.close_trade_profit[-1] > self.profit_aim * self.rr: | ||||
|             if pnl > self.profit_aim * self.rr: | ||||
|                 factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2) | ||||
|             return float(pnl * factor) | ||||
|             return float(rew * factor) | ||||
|  | ||||
|         return 0. | ||||
|   | ||||