improve default reward, fix bugs in environment

robcaulk 2022-08-24 18:32:40 +02:00
parent a61821e1c6
commit d1bee29b1e
3 changed files with 102 additions and 53 deletions

View File

@@ -140,30 +140,32 @@ class Base5ActionRLEnv(gym.Env):
             if action == Actions.Neutral.value:
                 self._position = Positions.Neutral
                 trade_type = "neutral"
+                self._last_trade_tick = None
             elif action == Actions.Long_enter.value:
                 self._position = Positions.Long
                 trade_type = "long"
+                self._last_trade_tick = self._current_tick
             elif action == Actions.Short_enter.value:
                 self._position = Positions.Short
                 trade_type = "short"
+                self._last_trade_tick = self._current_tick
             elif action == Actions.Long_exit.value:
                 self._position = Positions.Neutral
                 trade_type = "neutral"
+                self._last_trade_tick = None
             elif action == Actions.Short_exit.value:
                 self._position = Positions.Neutral
                 trade_type = "neutral"
+                self._last_trade_tick = None
             else:
                 print("case not defined")
-            # Update last trade tick
-            self._last_trade_tick = self._current_tick
             if trade_type is not None:
                 self.trade_history.append(
                     {'price': self.current_price(), 'index': self._current_tick,
                      'type': trade_type})
-        if self._total_profit < 0.2:
+        if self._total_profit < 0.5:
            self._done = True
         self._position_history.append(self._position)
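A minimal sketch of the bookkeeping rules this hunk introduces, not part of the commit: entries record the current tick in _last_trade_tick, exits and the neutral action clear it, and the episode now ends once total profit drops below 0.5 instead of 0.2, i.e. a 50% drawdown; the starting value of 1.0 is an assumption based on how these environments track profit multiplicatively.

def update_last_trade_tick(action: str, current_tick: int, last_trade_tick):
    if action in ("long_enter", "short_enter"):
        return current_tick          # position opened on this candle
    return None                      # neutral / long_exit / short_exit

def episode_done(total_profit: float) -> bool:
    return total_profit < 0.5        # new, tighter threshold (was 0.2)

print(update_last_trade_tick("long_enter", 120, None))   # 120
print(update_last_trade_tick("long_exit", 135, 120))     # None
print(episode_done(0.45))                                # True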
@@ -221,8 +223,7 @@ class Base5ActionRLEnv(gym.Env):
     def is_tradesignal(self, action: int):
         # trade signal
         """
-        not trade signal is :
-        Determine if the signal is non sensical
+        Determine if the signal is a trade signal
         e.g.: agent wants a Actions.Long_exit while it is in a Positions.short
         """
         return not ((action == Actions.Neutral.value and self._position == Positions.Neutral) or
@@ -237,6 +238,24 @@ class Base5ActionRLEnv(gym.Env):
                 (action == Actions.Long_exit.value and self._position == Positions.Short) or
                 (action == Actions.Long_exit.value and self._position == Positions.Neutral))
+    def _is_valid(self, action: int):
+        # trade signal
+        """
+        Determine if the signal is valid.
+        e.g.: agent wants a Actions.Long_exit while it is in a Positions.short
+        """
+        # Agent should only try to exit if it is in position
+        if action in (Actions.Short_exit.value, Actions.Long_exit.value):
+            if self._position not in (Positions.Short, Positions.Long):
+                return False
+        # Agent should only try to enter if it is not in position
+        if action in (Actions.Short_enter.value, Actions.Long_enter.value):
+            if self._position != Positions.Neutral:
+                return False
+        return True
+
     def _is_trade(self, action: Actions):
         return ((action == Actions.Long_enter.value and self._position == Positions.Neutral) or
                 (action == Actions.Short_enter.value and self._position == Positions.Neutral))
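For readability, a standalone sketch, not part of the commit, that enumerates which action/position pairs the new _is_valid() accepts: exits require an open position, entries require a neutral position, and the neutral action is always allowed. The string labels stand in for the Actions/Positions enums.

from itertools import product

def is_valid(action: str, position: str) -> bool:
    if action in ("long_exit", "short_exit"):
        return position in ("long", "short")      # exit only from an open position
    if action in ("long_enter", "short_enter"):
        return position == "neutral"              # enter only when flat
    return True                                   # neutral action is always valid

for act, pos in product(("neutral", "long_enter", "short_enter", "long_exit", "short_exit"),
                        ("neutral", "long", "short")):
    print(f"{act:12s} while {pos:8s} -> {is_valid(act, pos)}")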
@@ -278,13 +297,8 @@ class Base5ActionRLEnv(gym.Env):
         if self._is_trade(action) or self._done:
             pnl = self.get_unrealized_profit()
-            if self._position == Positions.Long:
-                self._total_profit = self._total_profit + self._total_profit * pnl
-                self._profits.append((self._current_tick, self._total_profit))
-                self.close_trade_profit.append(pnl)
-            if self._position == Positions.Short:
-                self._total_profit = self._total_profit + self._total_profit * pnl
+            if self._position in (Positions.Long, Positions.Short):
+                self._total_profit *= (1 + pnl)
                 self._profits.append((self._current_tick, self._total_profit))
                 self.close_trade_profit.append(pnl)
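Worth noting, and easy to check with a short sketch (not part of the commit): the old update total + total * pnl and the new total * (1 + pnl) are algebraically identical, so this hunk only merges the duplicated Long/Short branches and keeps the same compounding behaviour in a more compact form.

total_profit = 1.0
for pnl in (0.05, -0.02, 0.03):                 # example closed-trade returns
    old_style = total_profit + total_profit * pnl
    new_style = total_profit * (1 + pnl)
    assert abs(old_style - new_style) < 1e-12   # same number either way
    total_profit = new_style

print(round(total_profit, 5))                   # 1.05987 -> about +6% compounded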

View File

@@ -19,7 +19,6 @@ from typing import Callable
 from datetime import datetime, timezone
 from stable_baselines3.common.utils import set_random_seed
 import gym
-from pathlib import Path
 logger = logging.getLogger(__name__)
 torch.multiprocessing.set_sharing_strategy('file_system')
@@ -112,27 +111,14 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         test_df = data_dictionary["test_features"]
         eval_freq = self.freqai_info["rl_config"]["eval_cycles"] * len(test_df)
-        # environments
-        if not self.train_env:
-            self.train_env = MyRLEnv(df=train_df, prices=prices_train, window_size=self.CONV_WIDTH,
-                                     reward_kwargs=self.reward_params, config=self.config)
-            self.eval_env = Monitor(MyRLEnv(df=test_df, prices=prices_test,
-                                            window_size=self.CONV_WIDTH,
-                                            reward_kwargs=self.reward_params, config=self.config))
-            self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
-                                              render=False, eval_freq=eval_freq,
-                                              best_model_save_path=str(dk.data_path))
-        else:
-            self.train_env.reset()
-            self.eval_env.reset()
-            self.train_env.reset_env(train_df, prices_train, self.CONV_WIDTH, self.reward_params)
-            self.eval_env.reset_env(test_df, prices_test, self.CONV_WIDTH, self.reward_params)
-            # self.eval_callback.eval_env = self.eval_env
-            # self.eval_callback.best_model_save_path = str(dk.data_path)
-            # self.eval_callback._init_callback()
-            self.eval_callback.__init__(self.eval_env, deterministic=True,
-                                        render=False, eval_freq=eval_freq,
-                                        best_model_save_path=str(dk.data_path))
+        self.train_env = MyRLEnv(df=train_df, prices=prices_train, window_size=self.CONV_WIDTH,
+                                 reward_kwargs=self.reward_params, config=self.config)
+        self.eval_env = Monitor(MyRLEnv(df=test_df, prices=prices_test,
+                                        window_size=self.CONV_WIDTH,
+                                        reward_kwargs=self.reward_params, config=self.config))
+        self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
+                                          render=False, eval_freq=eval_freq,
+                                          best_model_save_path=str(dk.data_path))
     @abstractmethod
     def fit_rl(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen):
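For readers unfamiliar with the stable-baselines3 pieces used above, a minimal sketch (not part of the commit) of the Monitor + EvalCallback pattern on a toy gym task; PPO and CartPole-v1 are stand-ins for the freqtrade model and MyRLEnv, and the save path and evaluation frequency are arbitrary.

import gym
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor

train_env = gym.make("CartPole-v1")
eval_env = Monitor(gym.make("CartPole-v1"))            # Monitor records episode returns/lengths
eval_callback = EvalCallback(eval_env, deterministic=True, render=False,
                             eval_freq=1000,           # evaluate every 1000 training steps
                             best_model_save_path="./best_model")

model = PPO("MlpPolicy", train_env, verbose=0)
model.learn(total_timesteps=5000, callback=eval_callback)   # best eval model is saved during training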
@@ -284,30 +270,43 @@ class MyRLEnv(Base5ActionRLEnv):
     def calculate_reward(self, action):
-        if self._last_trade_tick is None:
-            return 0.
+        # first, penalize if the action is not valid
+        if not self._is_valid(action):
+            return -15
         pnl = self.get_unrealized_profit()
+        rew = np.sign(pnl) * (pnl + 1)
+        factor = 100
+        # reward agent for entering trades
+        if action in (Actions.Long_enter.value, Actions.Short_enter.value):
+            return 25
+        # discourage agent from not entering trades
+        if action == Actions.Neutral.value and self._position == Positions.Neutral:
+            return -15
-        max_trade_duration = self.rl_config.get('max_trade_duration_candles', 100)
+        max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
         trade_duration = self._current_tick - self._last_trade_tick
-        factor = 1
         if trade_duration <= max_trade_duration:
             factor *= 1.5
         elif trade_duration > max_trade_duration:
             factor *= 0.5
+        # discourage sitting in position
+        if self._position in (Positions.Short, Positions.Long):
+            return -50 * trade_duration / max_trade_duration
         # close long
         if action == Actions.Long_exit.value and self._position == Positions.Long:
-            if self.close_trade_profit and self.close_trade_profit[-1] > self.profit_aim * self.rr:
+            if pnl > self.profit_aim * self.rr:
                 factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
-            return float(pnl * factor)
+            return float(rew * factor)
         # close short
         if action == Actions.Short_exit.value and self._position == Positions.Short:
-            factor = 1
-            if self.close_trade_profit and self.close_trade_profit[-1] > self.profit_aim * self.rr:
+            if pnl > self.profit_aim * self.rr:
                 factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
-            return float(pnl * factor)
+            return float(rew * factor)
         return 0.
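As a rough feel for the new default reward, a sketch (not part of the commit) of the values the close-long/close-short branches compute; profit_aim * rr = 0.02 and win_reward_factor = 2 are hypothetical configuration values. The fixed branches return -15 for invalid actions, +25 for entering a trade, -15 for idling flat, and -50 * trade_duration / max_trade_duration while holding a position.

import numpy as np

PROFIT_TARGET = 0.02        # hypothetical profit_aim * rr
WIN_REWARD_FACTOR = 2       # hypothetical model_reward_parameters value

def exit_reward(pnl: float, trade_duration: int, max_trade_duration: int = 300) -> float:
    factor = 100.0
    factor *= 1.5 if trade_duration <= max_trade_duration else 0.5
    if pnl > PROFIT_TARGET:
        factor *= WIN_REWARD_FACTOR
    return float(np.sign(pnl) * (pnl + 1) * factor)

print(exit_reward(0.03, 50))      # winning exit closed quickly:   309.0
print(exit_reward(-0.01, 400))    # losing exit held too long:     -49.5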

View File

@@ -6,6 +6,10 @@ from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
 from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions
 from freqtrade.freqai.RL.BaseReinforcementLearningModel import BaseReinforcementLearningModel
 from pathlib import Path
+from pandas import DataFrame
+from stable_baselines3.common.callbacks import EvalCallback
+from stable_baselines3.common.monitor import Monitor
+import numpy as np
 logger = logging.getLogger(__name__)
@@ -49,6 +53,25 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
         return model
+    def set_train_and_eval_environments(self, data_dictionary: Dict[str, DataFrame],
+                                        prices_train: DataFrame, prices_test: DataFrame,
+                                        dk: FreqaiDataKitchen):
+        """
+        User can override this if they are using a custom MyRLEnv
+        """
+        train_df = data_dictionary["train_features"]
+        test_df = data_dictionary["test_features"]
+        eval_freq = self.freqai_info["rl_config"]["eval_cycles"] * len(test_df)
+        self.train_env = MyRLEnv(df=train_df, prices=prices_train, window_size=self.CONV_WIDTH,
+                                 reward_kwargs=self.reward_params, config=self.config)
+        self.eval_env = Monitor(MyRLEnv(df=test_df, prices=prices_test,
+                                        window_size=self.CONV_WIDTH,
+                                        reward_kwargs=self.reward_params, config=self.config))
+        self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
+                                          render=False, eval_freq=eval_freq,
+                                          best_model_save_path=str(dk.data_path))
+
 class MyRLEnv(Base5ActionRLEnv):
     """
@@ -58,30 +81,43 @@ class MyRLEnv(Base5ActionRLEnv):
     def calculate_reward(self, action):
-        if self._last_trade_tick is None:
-            return 0.
+        # first, penalize if the action is not valid
+        if not self._is_valid(action):
+            return -15
         pnl = self.get_unrealized_profit()
+        rew = np.sign(pnl) * (pnl + 1)
+        factor = 100
+        # reward agent for entering trades
+        if action in (Actions.Long_enter.value, Actions.Short_enter.value):
+            return 25
+        # discourage agent from not entering trades
+        if action == Actions.Neutral.value and self._position == Positions.Neutral:
+            return -15
-        max_trade_duration = self.rl_config.get('max_trade_duration_candles', 100)
+        max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
         trade_duration = self._current_tick - self._last_trade_tick
-        factor = 1
         if trade_duration <= max_trade_duration:
             factor *= 1.5
         elif trade_duration > max_trade_duration:
             factor *= 0.5
+        # discourage sitting in position
+        if self._position in (Positions.Short, Positions.Long):
+            return -50 * trade_duration / max_trade_duration
         # close long
         if action == Actions.Long_exit.value and self._position == Positions.Long:
-            if self.close_trade_profit and self.close_trade_profit[-1] > self.profit_aim * self.rr:
+            if pnl > self.profit_aim * self.rr:
                 factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
-            return float(pnl * factor)
+            return float(rew * factor)
         # close short
         if action == Actions.Short_exit.value and self._position == Positions.Short:
-            factor = 1
-            if self.close_trade_profit and self.close_trade_profit[-1] > self.profit_aim * self.rr:
+            if pnl > self.profit_aim * self.rr:
                 factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
-            return float(pnl * factor)
+            return float(rew * factor)
         return 0.
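One property of the shaping term worth seeing in numbers, sketched below (not part of the commit): rew = np.sign(pnl) * (pnl + 1) is discontinuous at break-even, so even tiny wins or losses map to roughly +1 or -1 before the factor is applied, rather than to values near zero.

import numpy as np

for pnl in (-0.01, -0.001, 0.0, 0.001, 0.01):
    print(f"pnl={pnl:+.3f} -> rew={np.sign(pnl) * (pnl + 1):+.3f}")
# pnl=-0.010 -> rew=-0.990 ... pnl=+0.001 -> rew=+1.001 ... pnl=+0.010 -> rew=+1.010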