Review notes for this patch (only added "+" lines were altered; hunk line counts are unchanged):
  * MyRLEnv.reset: self._position is set to Positions.Neutral before _position_history is
    built, so a position left over from the previous episode is not recorded as the first entry.
  * MyRLEnv._is_trade: the last two clauses compared the action with Enum members reached via
    Actions.Neutral.<name>; they now compare with Actions.Short_sell.value and
    Actions.Long_sell.value, like every other clause.
  * MyRLEnv.is_hold: Actions.Short and Actions.Long are not members of the five-action enum
    added below; Short_buy and Long_buy are used instead, mirroring step(), which maps those
    actions to Positions.Short and Positions.Long (assumed intent, confirm with the author).
  * MyRLEnv docstring: "Adds" changed to "adds".
  * Line breaks and indentation were rebuilt from a flattened copy of the patch; verify the
    whitespace of context lines against the target file before applying.

diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearningTDQN.py b/freqtrade/freqai/prediction_models/ReinforcementLearningTDQN.py
index f042762e4..5ec917719 100644
--- a/freqtrade/freqai/prediction_models/ReinforcementLearningTDQN.py
+++ b/freqtrade/freqai/prediction_models/ReinforcementLearningTDQN.py
@@ -1,16 +1,17 @@
 import logging
-from typing import Any, Dict, Optional
-
+from typing import Any, Dict  # Optional
+from enum import Enum
 import numpy as np
 import torch as th
 from stable_baselines3.common.callbacks import EvalCallback
 from stable_baselines3.common.monitor import Monitor
 # from stable_baselines3.common.vec_env import SubprocVecEnv
-from freqtrade.freqai.RL.BaseRLEnv import BaseRLEnv, Actions, Positions
+from freqtrade.freqai.RL.BaseRLEnv import BaseRLEnv
 from freqtrade.freqai.RL.BaseReinforcementLearningModel import BaseReinforcementLearningModel
 from freqtrade.freqai.RL.TDQNagent import TDQN
 from stable_baselines3.common.buffers import ReplayBuffer
-
+from gym import spaces
+from gym.utils import seeding
 logger = logging.getLogger(__name__)
 
 
@@ -57,7 +58,7 @@ class ReinforcementLearningTDQN(BaseReinforcementLearningModel):
                      learning_rate=0.00025, gamma=0.9,
                      target_update_interval=5000, buffer_size=50000,
                      exploration_initial_eps=1, exploration_final_eps=0.1,
-                     replay_buffer_class=Optional(ReplayBuffer)
+                     replay_buffer_class=ReplayBuffer
                      )
 
         model.learn(
@@ -70,11 +71,102 @@ class ReinforcementLearningTDQN(BaseReinforcementLearningModel):
         return model
 
 
+class Actions(Enum):
+    Neutral = 0
+    Long_buy = 1
+    Long_sell = 2
+    Short_buy = 3
+    Short_sell = 4
+
+
+class Positions(Enum):
+    Short = 0
+    Long = 1
+    Neutral = 0.5
+
+    def opposite(self):
+        return Positions.Short if self == Positions.Long else Positions.Long
+
+
 class MyRLEnv(BaseRLEnv):
     """
-    User can override any function in BaseRLEnv and gym.Env
+    User can override any function in BaseRLEnv and gym.Env. Here the user
+    adds 5 actions.
     """
 
+    metadata = {'render.modes': ['human']}
+
+    def __init__(self, df, prices, reward_kwargs, window_size=10, starting_point=True, ):
+        assert df.ndim == 2
+
+        self.seed()
+        self.df = df
+        self.signal_features = self.df
+        self.prices = prices
+        self.window_size = window_size
+        self.starting_point = starting_point
+        self.rr = reward_kwargs["rr"]
+        self.profit_aim = reward_kwargs["profit_aim"]
+
+        self.fee = 0.0015
+
+        # # spaces
+        self.shape = (window_size, self.signal_features.shape[1])
+        self.action_space = spaces.Discrete(len(Actions))
+        self.observation_space = spaces.Box(
+            low=-np.inf, high=np.inf, shape=self.shape, dtype=np.float32)
+
+        # episode
+        self._start_tick = self.window_size
+        self._end_tick = len(self.prices) - 1
+        self._done = None
+        self._current_tick = None
+        self._last_trade_tick = None
+        self._position = Positions.Neutral
+        self._position_history = None
+        self.total_reward = None
+        self._total_profit = None
+        self._first_rendering = None
+        self.history = None
+        self.trade_history = []
+
+        # self.A_t, self.B_t = 0.000639, 0.00001954
+        self.r_t_change = 0.
+
+        self.returns_report = []
+
+    def seed(self, seed=None):
+        self.np_random, seed = seeding.np_random(seed)
+        return [seed]
+
+    def reset(self):
+
+        self._done = False
+        self._position = Positions.Neutral
+
+        if self.starting_point is True:
+            self._position_history = (self._start_tick * [None]) + [self._position]
+        else:
+            self._position_history = (self.window_size * [None]) + [self._position]
+
+        self._current_tick = self._start_tick
+        self._last_trade_tick = None
+
+        self.total_reward = 0.
+        self._total_profit = 1.  # unit
+        self._first_rendering = True
+        self.history = {}
+        self.trade_history = []
+        self.portfolio_log_returns = np.zeros(len(self.prices))
+
+        self._profits = [(self._start_tick, 1)]
+        self.close_trade_profit = []
+        self.r_t_change = 0.
+
+        self.returns_report = []
+
+        return self._get_observation()
+
     def step(self, action):
         self._done = False
         self._current_tick += 1
@@ -85,11 +177,12 @@ class MyRLEnv(BaseRLEnv):
         self.update_portfolio_log_returns(action)
 
         self._update_profit(action)
-        step_reward = self._calculate_reward(action)
+        step_reward = self.calculate_reward(action)
         self.total_reward += step_reward
 
         trade_type = None
-        if self.is_tradesignal(action):
+        if self.is_tradesignal(action):  # exclude 3 case not trade
+            # Update position
             """
             Action: Neutral, position: Long -> Close Long
             Action: Neutral, position: Short -> Close Short
@@ -104,12 +197,18 @@ class MyRLEnv(BaseRLEnv):
             if action == Actions.Neutral.value:
                 self._position = Positions.Neutral
                 trade_type = "neutral"
-            elif action == Actions.Long.value:
+            elif action == Actions.Long_buy.value:
                 self._position = Positions.Long
                 trade_type = "long"
-            elif action == Actions.Short.value:
+            elif action == Actions.Short_buy.value:
                 self._position = Positions.Short
                 trade_type = "short"
+            elif action == Actions.Long_sell.value:
+                self._position = Positions.Neutral
+                trade_type = "neutral"
+            elif action == Actions.Short_sell.value:
+                self._position = Positions.Neutral
+                trade_type = "neutral"
             else:
                 print("case not defined")
 
@@ -136,33 +235,69 @@ class MyRLEnv(BaseRLEnv):
 
         return observation, step_reward, self._done, info
 
-    def calculate_reward(self, action):
+    def _get_observation(self):
+        return self.signal_features[(self._current_tick - self.window_size):self._current_tick]
+
+    def get_unrealized_profit(self):
 
         if self._last_trade_tick is None:
             return 0.
 
-        # close long
-        if action == Actions.Long_sell.value and self._position == Positions.Long:
-            last_trade_price = self.add_buy_fee(self.prices.iloc[self._last_trade_tick].open)
-            current_price = self.add_sell_fee(self.prices.iloc[self._current_tick].open)
-            return float(np.log(current_price) - np.log(last_trade_price))
-
-        if action == Actions.Long_sell.value and self._position == Positions.Long:
-            if self.close_trade_profit[-1] > self.profit_aim * self.rr:
-                last_trade_price = self.add_buy_fee(self.prices.iloc[self._last_trade_tick].open)
-                current_price = self.add_sell_fee(self.prices.iloc[self._current_tick].open)
-                return float((np.log(current_price) - np.log(last_trade_price)) * 2)
-
-        # close short
-        if action == Actions.Short_buy.value and self._position == Positions.Short:
-            last_trade_price = self.add_sell_fee(self.prices.iloc[self._last_trade_tick].open)
+        if self._position == Positions.Neutral:
+            return 0.
+        elif self._position == Positions.Short:
             current_price = self.add_buy_fee(self.prices.iloc[self._current_tick].open)
-            return float(np.log(last_trade_price) - np.log(current_price))
+            last_trade_price = self.add_sell_fee(self.prices.iloc[self._last_trade_tick].open)
+            return (last_trade_price - current_price) / last_trade_price
+        elif self._position == Positions.Long:
+            current_price = self.add_sell_fee(self.prices.iloc[self._current_tick].open)
+            last_trade_price = self.add_buy_fee(self.prices.iloc[self._last_trade_tick].open)
+            return (current_price - last_trade_price) / last_trade_price
+        else:
+            return 0.
 
-        if action == Actions.Short_buy.value and self._position == Positions.Short:
-            if self.close_trade_profit[-1] > self.profit_aim * self.rr:
-                last_trade_price = self.add_sell_fee(self.prices.iloc[self._last_trade_tick].open)
-                current_price = self.add_buy_fee(self.prices.iloc[self._current_tick].open)
-                return float((np.log(last_trade_price) - np.log(current_price)) * 2)
+    def is_tradesignal(self, action):
+        # trade signal
+        """
+        not trade signal is :
+        Action: Neutral, position: Neutral -> Nothing
+        Action: Long, position: Long -> Hold Long
+        Action: Short, position: Short -> Hold Short
+        """
+        return not ((action == Actions.Neutral.value and self._position == Positions.Neutral) or
+                    (action == Actions.Short_buy.value and self._position == Positions.Short) or
+                    (action == Actions.Short_sell.value and self._position == Positions.Short) or
+                    (action == Actions.Short_buy.value and self._position == Positions.Long) or
+                    (action == Actions.Short_sell.value and self._position == Positions.Long) or
 
-        return 0.
+                    (action == Actions.Long_buy.value and self._position == Positions.Long) or
+                    (action == Actions.Long_sell.value and self._position == Positions.Long) or
+                    (action == Actions.Long_buy.value and self._position == Positions.Short) or
+                    (action == Actions.Long_sell.value and self._position == Positions.Short))
+
+    def _is_trade(self, action):
+        return ((action == Actions.Long_buy.value and self._position == Positions.Short) or
+                (action == Actions.Short_buy.value and self._position == Positions.Long) or
+                (action == Actions.Neutral.value and self._position == Positions.Long) or
+                (action == Actions.Neutral.value and self._position == Positions.Short) or
+
+                (action == Actions.Short_sell.value and self._position == Positions.Long) or
+                (action == Actions.Long_sell.value and self._position == Positions.Short)
+                )
+
+    def is_hold(self, action):
+        return ((action == Actions.Short_buy.value and self._position == Positions.Short)
+                or (action == Actions.Long_buy.value and self._position == Positions.Long))
+
+    def add_buy_fee(self, price):
+        return price * (1 + self.fee)
+
+    def add_sell_fee(self, price):
+        return price / (1 + self.fee)
+
+    def _update_history(self, info):
+        if not self.history:
+            self.history = {key: [] for key in info.keys()}
+
+        for key, value in info.items():
+            self.history[key].append(value)