get TDQN working with 5 action environment

robcaulk 2022-08-15 11:11:16 +02:00
parent d4db5c3281
commit 6048f60f13
1 changed file with 168 additions and 33 deletions


@@ -1,16 +1,17 @@
 import logging
-from typing import Any, Dict, Optional
+from typing import Any, Dict  # Optional
+from enum import Enum

 import numpy as np
 import torch as th
 from stable_baselines3.common.callbacks import EvalCallback
 from stable_baselines3.common.monitor import Monitor
 # from stable_baselines3.common.vec_env import SubprocVecEnv

-from freqtrade.freqai.RL.BaseRLEnv import BaseRLEnv, Actions, Positions
+from freqtrade.freqai.RL.BaseRLEnv import BaseRLEnv
 from freqtrade.freqai.RL.BaseReinforcementLearningModel import BaseReinforcementLearningModel
 from freqtrade.freqai.RL.TDQNagent import TDQN
 from stable_baselines3.common.buffers import ReplayBuffer
+from gym import spaces
+from gym.utils import seeding

 logger = logging.getLogger(__name__)
@@ -57,7 +58,7 @@ class ReinforcementLearningTDQN(BaseReinforcementLearningModel):
                      learning_rate=0.00025, gamma=0.9,
                      target_update_interval=5000, buffer_size=50000,
                      exploration_initial_eps=1, exploration_final_eps=0.1,
-                     replay_buffer_class=Optional(ReplayBuffer)
+                     replay_buffer_class=ReplayBuffer
                      )

         model.learn(
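The dropped `Optional(ReplayBuffer)` was a live bug, not just a style fix: `typing.Optional` is an annotation form, not a callable, so evaluating `Optional(ReplayBuffer)` raises a `TypeError` before the model is even constructed. Stable-baselines3 expects the buffer class itself (or `None`). A minimal sketch of the corrected pattern, using stock SB3 `DQN` as a stand-in for `TDQN` (whose constructor is assumed to accept the same keyword arguments); `"CartPole-v1"` is a placeholder environment:

    from stable_baselines3 import DQN
    from stable_baselines3.common.buffers import ReplayBuffer

    # Pass the class itself; SB3 instantiates the buffer internally.
    model = DQN(
        "MlpPolicy", "CartPole-v1",
        learning_rate=0.00025, gamma=0.9,
        target_update_interval=5000, buffer_size=50000,
        exploration_initial_eps=1, exploration_final_eps=0.1,
        replay_buffer_class=ReplayBuffer,
    )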
@@ -70,11 +71,102 @@ class ReinforcementLearningTDQN(BaseReinforcementLearningModel):
     return model


+class Actions(Enum):
+    Neutral = 0
+    Long_buy = 1
+    Long_sell = 2
+    Short_buy = 3
+    Short_sell = 4
+
+
+class Positions(Enum):
+    Short = 0
+    Long = 1
+    Neutral = 0.5
+
+    def opposite(self):
+        return Positions.Short if self == Positions.Long else Positions.Long
+
+
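The new enums split each side into an explicit enter/exit pair while positions stay three-valued; `Neutral = 0.5` sits between `Short = 0` and `Long = 1` so a position history can be plotted on a single axis, and `opposite()` never returns `Neutral`. A sketch (illustrative only, not part of the commit) of what each action is meant to do, matching the transition table in `step()` below:

    # Illustrative summary of the five actions, per the step() docstring.
    ACTION_MEANING = {
        Actions.Neutral: "do nothing / close any open position",
        Actions.Long_buy: "enter a long",
        Actions.Long_sell: "exit a long (back to neutral)",
        Actions.Short_buy: "enter a short",
        Actions.Short_sell: "exit a short (back to neutral)",
    }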
 class MyRLEnv(BaseRLEnv):
     """
-    User can override any function in BaseRLEnv and gym.Env
+    User can override any function in BaseRLEnv and gym.Env. Here the user
+    adds 5 actions.
     """
+    metadata = {'render.modes': ['human']}
+
+    def __init__(self, df, prices, reward_kwargs, window_size=10, starting_point=True):
+        assert df.ndim == 2
+
+        self.seed()
+        self.df = df
+        self.signal_features = self.df
+        self.prices = prices
+        self.window_size = window_size
+        self.starting_point = starting_point
+        self.rr = reward_kwargs["rr"]
+        self.profit_aim = reward_kwargs["profit_aim"]
+        self.fee = 0.0015
+
+        # spaces
+        self.shape = (window_size, self.signal_features.shape[1])
+        self.action_space = spaces.Discrete(len(Actions))
+        self.observation_space = spaces.Box(
+            low=-np.inf, high=np.inf, shape=self.shape, dtype=np.float32)
+
+        # episode
+        self._start_tick = self.window_size
+        self._end_tick = len(self.prices) - 1
+        self._done = None
+        self._current_tick = None
+        self._last_trade_tick = None
+        self._position = Positions.Neutral
+        self._position_history = None
+        self.total_reward = None
+        self._total_profit = None
+        self._first_rendering = None
+        self.history = None
+        self.trade_history = []
+
+        # self.A_t, self.B_t = 0.000639, 0.00001954
+        self.r_t_change = 0.
+
+        self.returns_report = []
+
+    def seed(self, seed=None):
+        self.np_random, seed = seeding.np_random(seed)
+        return [seed]
+
+    def reset(self):
+        self._done = False
+
+        if self.starting_point is True:
+            self._position_history = (self._start_tick * [None]) + [self._position]
+        else:
+            self._position_history = (self.window_size * [None]) + [self._position]
+
+        self._current_tick = self._start_tick
+        self._last_trade_tick = None
+        self._position = Positions.Neutral
+
+        self.total_reward = 0.
+        self._total_profit = 1.  # unit
+        self._first_rendering = True
+        self.history = {}
+        self.trade_history = []
+        self.portfolio_log_returns = np.zeros(len(self.prices))
+
+        self._profits = [(self._start_tick, 1)]
+        self.close_trade_profit = []
+        self.r_t_change = 0.
+
+        self.returns_report = []
+
+        return self._get_observation()
+
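With `seed()` and `reset()` in place the class follows the standard gym API, so a rollout is the usual reset/step loop. A minimal smoke-test sketch, assuming `df` and `prices` are per-candle pandas DataFrames, `reward_kwargs` carries the two keys read in `__init__`, and the helpers `step()` calls (`update_portfolio_log_returns`, `_update_profit`, `calculate_reward`) are inherited from `BaseRLEnv`; all column names and values below are made up:

    import pandas as pd

    n = 500
    prices = pd.DataFrame({"open": [100 + i * 0.01 for i in range(n)]})
    df = pd.DataFrame({"feat_a": range(n), "feat_b": range(n)})

    env = MyRLEnv(df, prices, reward_kwargs={"rr": 1, "profit_aim": 0.02})
    obs = env.reset()                      # (window_size, n_features) window
    obs, reward, done, info = env.step(env.action_space.sample())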
     def step(self, action):
         self._done = False
         self._current_tick += 1

@@ -85,11 +177,12 @@ class MyRLEnv(BaseRLEnv):
         self.update_portfolio_log_returns(action)

         self._update_profit(action)
-        step_reward = self._calculate_reward(action)
+        step_reward = self.calculate_reward(action)
         self.total_reward += step_reward

         trade_type = None
-        if self.is_tradesignal(action):
+        if self.is_tradesignal(action):  # exclude the non-trade cases
+            # Update position
             """
             Action: Neutral, position: Long -> Close Long
             Action: Neutral, position: Short -> Close Short
@@ -104,12 +197,18 @@ class MyRLEnv(BaseRLEnv):
             if action == Actions.Neutral.value:
                 self._position = Positions.Neutral
                 trade_type = "neutral"
-            elif action == Actions.Long.value:
+            elif action == Actions.Long_buy.value:
                 self._position = Positions.Long
                 trade_type = "long"
-            elif action == Actions.Short.value:
+            elif action == Actions.Short_buy.value:
                 self._position = Positions.Short
                 trade_type = "short"
+            elif action == Actions.Long_sell.value:
+                self._position = Positions.Neutral
+                trade_type = "neutral"
+            elif action == Actions.Short_sell.value:
+                self._position = Positions.Neutral
+                trade_type = "neutral"
             else:
                 print("case not defined")
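The `if`/`elif` chain above encodes one target position per action; written as a table (a sketch, not part of the commit) it is easier to see that both sell actions flatten the position back to neutral:

    # Equivalent dispatch table for the position update in step().
    TARGET_POSITION = {
        Actions.Neutral.value: Positions.Neutral,
        Actions.Long_buy.value: Positions.Long,
        Actions.Long_sell.value: Positions.Neutral,
        Actions.Short_buy.value: Positions.Short,
        Actions.Short_sell.value: Positions.Neutral,
    }
    # e.g. inside step(): self._position = TARGET_POSITION[action]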
@@ -136,33 +235,69 @@ class MyRLEnv(BaseRLEnv):
         return observation, step_reward, self._done, info

-    def calculate_reward(self, action):
-        if self._last_trade_tick is None:
-            return 0.
-
-        # close long
-        if action == Actions.Long_sell.value and self._position == Positions.Long:
-            last_trade_price = self.add_buy_fee(self.prices.iloc[self._last_trade_tick].open)
-            current_price = self.add_sell_fee(self.prices.iloc[self._current_tick].open)
-            return float(np.log(current_price) - np.log(last_trade_price))
-
-        if action == Actions.Long_sell.value and self._position == Positions.Long:
-            if self.close_trade_profit[-1] > self.profit_aim * self.rr:
-                last_trade_price = self.add_buy_fee(self.prices.iloc[self._last_trade_tick].open)
-                current_price = self.add_sell_fee(self.prices.iloc[self._current_tick].open)
-                return float((np.log(current_price) - np.log(last_trade_price)) * 2)
-
-        # close short
-        if action == Actions.Short_buy.value and self._position == Positions.Short:
-            last_trade_price = self.add_sell_fee(self.prices.iloc[self._last_trade_tick].open)
-            current_price = self.add_buy_fee(self.prices.iloc[self._current_tick].open)
-            return float(np.log(last_trade_price) - np.log(current_price))
-
-        if action == Actions.Short_buy.value and self._position == Positions.Short:
-            if self.close_trade_profit[-1] > self.profit_aim * self.rr:
-                last_trade_price = self.add_sell_fee(self.prices.iloc[self._last_trade_tick].open)
-                current_price = self.add_buy_fee(self.prices.iloc[self._current_tick].open)
-                return float((np.log(last_trade_price) - np.log(current_price)) * 2)
-
-        return 0.
+    def _get_observation(self):
+        return self.signal_features[(self._current_tick - self.window_size):self._current_tick]
+
+    def get_unrealized_profit(self):
+        if self._last_trade_tick is None:
+            return 0.
+
+        if self._position == Positions.Neutral:
+            return 0.
+        elif self._position == Positions.Short:
+            current_price = self.add_buy_fee(self.prices.iloc[self._current_tick].open)
+            last_trade_price = self.add_sell_fee(self.prices.iloc[self._last_trade_tick].open)
+            return (last_trade_price - current_price) / last_trade_price
+        elif self._position == Positions.Long:
+            current_price = self.add_sell_fee(self.prices.iloc[self._current_tick].open)
+            last_trade_price = self.add_buy_fee(self.prices.iloc[self._last_trade_tick].open)
+            return (current_price - last_trade_price) / last_trade_price
+        else:
+            return 0.
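Worked numbers for the long branch, with the environment's `fee = 0.0015`: an entry open of 100 costs an effective 100.15 after `add_buy_fee`, an exit open of 105 nets 105 / 1.0015 ≈ 104.84 after `add_sell_fee`, so the unrealized profit is about 4.69% rather than the frictionless 5%. A quick check (prices here are made up):

    fee = 0.0015
    entry, mark = 100.0, 105.0                  # hypothetical open prices
    last_trade_price = entry * (1 + fee)        # add_buy_fee  -> 100.15
    current_price = mark / (1 + fee)            # add_sell_fee -> ~104.84
    print((current_price - last_trade_price) / last_trade_price)  # ~0.0469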
+    def is_tradesignal(self, action):
+        """
+        Not a trade signal:
+        Action: Neutral, position: Neutral -> Nothing
+        Action: Long, position: Long -> Hold Long
+        Action: Short, position: Short -> Hold Short
+        """
+        return not ((action == Actions.Neutral.value and self._position == Positions.Neutral) or
+                    (action == Actions.Short_buy.value and self._position == Positions.Short) or
+                    (action == Actions.Short_sell.value and self._position == Positions.Short) or
+                    (action == Actions.Short_buy.value and self._position == Positions.Long) or
+                    (action == Actions.Short_sell.value and self._position == Positions.Long) or
+                    (action == Actions.Long_buy.value and self._position == Positions.Long) or
+                    (action == Actions.Long_sell.value and self._position == Positions.Long) or
+                    (action == Actions.Long_buy.value and self._position == Positions.Short) or
+                    (action == Actions.Long_sell.value and self._position == Positions.Short))
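The nine excluded pairs have a compact reading: `Neutral`-on-`Neutral`, plus every non-neutral action taken while already in a position. In other words, an action is a trade signal exactly when one of {action, position} is neutral and the other is not, so the predicate reduces to a single exclusive-or (an equivalent sketch, not what the commit ships):

    def is_tradesignal(self, action):
        # Trade signal iff exactly one of {action, position} is Neutral.
        return (action == Actions.Neutral.value) != (self._position == Positions.Neutral)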
+    def _is_trade(self, action):
+        return ((action == Actions.Long_buy.value and self._position == Positions.Short) or
+                (action == Actions.Short_buy.value and self._position == Positions.Long) or
+                (action == Actions.Neutral.value and self._position == Positions.Long) or
+                (action == Actions.Neutral.value and self._position == Positions.Short) or
+                (action == Actions.Short_sell.value and self._position == Positions.Long) or
+                (action == Actions.Long_sell.value and self._position == Positions.Short))
+
+    def is_hold(self, action):
+        return ((action == Actions.Short_buy.value and self._position == Positions.Short)
+                or (action == Actions.Long_buy.value and self._position == Positions.Long))
+    def add_buy_fee(self, price):
+        return price * (1 + self.fee)
+
+    def add_sell_fee(self, price):
+        return price / (1 + self.fee)
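The two fee helpers are deliberately asymmetric: a buy costs more (`price * (1 + fee)`) while a sell nets less (`price / (1 + fee)`), so a round trip at an unchanged price loses roughly twice the fee. A quick check of that arithmetic:

    fee = 0.0015
    price = 100.0
    round_trip = (price / (1 + fee)) / (price * (1 + fee)) - 1
    print(round_trip)   # ~ -0.0030, i.e. about -2 * fee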
+    def _update_history(self, info):
+        if not self.history:
+            self.history = {key: [] for key in info.keys()}
+
+        for key, value in info.items():
+            self.history[key].append(value)