2022-08-25 19:40:16 +00:00
|
|
|
import logging
|
|
|
|
from enum import Enum
|
|
|
|
|
|
|
|
from gym import spaces
|
2022-08-28 17:21:57 +00:00
|
|
|
|
|
|
|
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
|
|
|
|
|
|
|
|
|
2022-08-25 19:40:16 +00:00
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
|
|
|
|
|
|
class Actions(Enum):
    """
    Discrete actions available to the agent in the 4 action environment.

    Members are declared in value order; ``len(Actions)`` is used to size the
    discrete action space.
    """

    Neutral = 0  # take no trading action
    Exit = 1  # leave the currently open position
    Long_enter = 2  # open a long position
    Short_enter = 3  # open a short position
|
|
|
|
|
|
|
|
|
2022-08-28 17:21:57 +00:00
|
|
|
class Base4ActionRLEnv(BaseEnvironment):
    """
    Base class for a 4 action environment.

    Actions (see ``Actions``):
        Neutral     -- no trade; the current position is kept
        Exit        -- close the open long or short position
        Long_enter  -- open a long position (only from a neutral position)
        Short_enter -- open a short position (only from a neutral position)

    A position cannot be flipped directly: the agent has to ``Exit`` before it
    can enter the opposite side.
    """

    def set_action_space(self):
        """Create a discrete action space with one action per ``Actions`` member."""
        self.action_space = spaces.Discrete(len(Actions))

    def step(self, action: int):
        """
        Advance the environment by one tick and apply ``action``.

        Portfolio log returns, profit and the step reward are all updated
        *before* the position changes. They therefore see the position held at
        the start of the step. The position is then updated only if
        ``is_tradesignal(action)`` is True.

        :param action: value of an ``Actions`` member chosen by the agent.
        :returns: ``(observation, step_reward, done, info)``. ``done`` is True
            when the end tick is reached, or when ``_total_profit`` drops below
            ``1 - max_training_drawdown_pct`` (config default 0.8).
        """
        self._done = False
        self._current_tick += 1

        if self._current_tick == self._end_tick:
            self._done = True

        self.update_portfolio_log_returns(action)

        self._update_profit(action)
        step_reward = self.calculate_reward(action)
        self.total_reward += step_reward

        trade_type = None
        if self.is_tradesignal(action):
            # Transitions accepted by this class's is_tradesignal():
            #   Long_enter  while Neutral       -> open long
            #   Short_enter while Neutral       -> open short
            #   Exit        while Long or Short -> close position (back to Neutral)
            if action == Actions.Neutral.value:
                # Only reachable if a subclass's is_tradesignal() accepts Neutral.
                self._position = Positions.Neutral
                trade_type = "neutral"
                self._last_trade_tick = None
            elif action == Actions.Long_enter.value:
                self._position = Positions.Long
                trade_type = "long"
                self._last_trade_tick = self._current_tick
            elif action == Actions.Short_enter.value:
                self._position = Positions.Short
                trade_type = "short"
                self._last_trade_tick = self._current_tick
            elif action == Actions.Exit.value:
                self._position = Positions.Neutral
                # An exit is recorded with the resulting position's type.
                trade_type = "neutral"
                self._last_trade_tick = None
            else:
                # Action values outside ``Actions`` end up here. Report them
                # through the module logger rather than printing to stdout.
                logger.warning("case not defined for action %s", action)

            if trade_type is not None:
                self.trade_history.append(
                    {'price': self.current_price(), 'index': self._current_tick,
                     'type': trade_type})

        # NOTE(review): _total_profit looks like a ratio starting at 1 -- confirm
        # in BaseEnvironment. The episode ends early once it falls below
        # 1 - max_training_drawdown_pct.
        if self._total_profit < 1 - self.rl_config.get('max_training_drawdown_pct', 0.8):
            self._done = True

        self._position_history.append(self._position)

        info = dict(
            tick=self._current_tick,
            total_reward=self.total_reward,
            total_profit=self._total_profit,
            position=self._position.value
        )

        observation = self._get_observation()

        self._update_history(info)

        return observation, step_reward, self._done, info

    def is_tradesignal(self, action: int) -> bool:
        """
        Determine if ``action`` is a trade signal, i.e. should change the position.

        Trade signals are:
            * ``Long_enter`` / ``Short_enter`` while the position is Neutral
            * ``Exit`` while the position is Long or Short
        ``Neutral`` is never a trade signal. Entering while already in a
        position (e.g. ``Long_enter`` while Short) is not a trade signal either.

        Action values outside ``Actions`` return True. This sends them to the
        undefined-case branch in ``step()``.

        :param action: value of an ``Actions`` member.
        :return: True if the action should change the position.
        """
        if action == Actions.Neutral.value:
            return False
        if action in (Actions.Long_enter.value, Actions.Short_enter.value):
            return self._position == Positions.Neutral
        if action == Actions.Exit.value:
            return self._position in (Positions.Long, Positions.Short)
        # Unknown action value: let step() reach its "case not defined" branch.
        return True

    def _is_valid(self, action: int) -> bool:
        """
        Determine if ``action`` is allowed in the current position.

        ``Exit`` is only valid while in a Long or Short position.
        ``Long_enter`` / ``Short_enter`` are only valid while Neutral. Every
        other action (including ``Neutral``) is always valid.

        :param action: value of an ``Actions`` member.
        :return: True if the action is valid for the current position.
        """
        if action == Actions.Exit.value:
            # Agent should only try to exit if it is in a position.
            return self._position in (Positions.Short, Positions.Long)
        if action in (Actions.Short_enter.value, Actions.Long_enter.value):
            # Agent should only try to enter if it is not in a position.
            return self._position == Positions.Neutral
        return True
|