import logging
from enum import Enum

from gym import spaces

from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions


logger = logging.getLogger(__name__)


class Actions(Enum):
    """
    Discrete actions available to the agent in the 4 action environment.
    The integer values are what `step()` receives as `action`.
    """
    Neutral = 0      # never treated as a trade signal by is_tradesignal()
    Exit = 1         # leave the current position
    Long_enter = 2   # open a long position
    Short_enter = 3  # open a short position


class Base4ActionRLEnv(BaseEnvironment):
    """
    Base class for a 4 action environment
    """

    def set_action_space(self):
        """
        Define the action space as one discrete choice per `Actions` member.
        """
        self.action_space = spaces.Discrete(len(Actions))

    def step(self, action: int):
        """
        Advance the environment by one tick and apply the agent's action.

        :param action: integer value of one of the `Actions` members
        :return: tuple of (observation, step_reward, done, info) where info
            holds the tick, total reward, total profit and position value
        """
        self._done = False
        self._current_tick += 1

        if self._current_tick == self._end_tick:
            self._done = True

        # The bookkeeping helpers and the reward are evaluated before the
        # position is changed below, so they see the position held when the
        # action was chosen.
        self.update_portfolio_log_returns(action)

        self._update_profit(action)
        step_reward = self.calculate_reward(action)
        self.total_reward += step_reward

        trade_type = None
        if self.is_tradesignal(action):
            # With is_tradesignal() as defined in this class, the transitions
            # that reach this point are:
            #   Action: Exit,        position: Long    -> Close Long
            #   Action: Exit,        position: Short   -> Close Short
            #   Action: Long_enter,  position: Neutral -> Open Long
            #   Action: Short_enter, position: Neutral -> Open Short
            # The Neutral branch is kept in case a subclass overrides
            # is_tradesignal() and lets Actions.Neutral through.
            if action == Actions.Neutral.value:
                self._position = Positions.Neutral
                trade_type = "neutral"
                self._last_trade_tick = None
            elif action == Actions.Long_enter.value:
                self._position = Positions.Long
                trade_type = "long"
                self._last_trade_tick = self._current_tick
            elif action == Actions.Short_enter.value:
                self._position = Positions.Short
                trade_type = "short"
                self._last_trade_tick = self._current_tick
            elif action == Actions.Exit.value:
                self._position = Positions.Neutral
                trade_type = "neutral"
                self._last_trade_tick = None
            else:
                # Only reachable for an action value outside of `Actions`.
                logger.warning("case not defined")

            if trade_type is not None:
                self.trade_history.append(
                    {'price': self.current_price(), 'index': self._current_tick,
                     'type': trade_type})

        # End the episode once total profit falls below
        # 1 - max_training_drawdown_pct (0.8 when the key is not configured).
        if self._total_profit < 1 - self.rl_config.get('max_training_drawdown_pct', 0.8):
            self._done = True

        self._position_history.append(self._position)

        info = dict(
            tick=self._current_tick,
            total_reward=self.total_reward,
            total_profit=self._total_profit,
            position=self._position.value
        )

        observation = self._get_observation()

        self._update_history(info)

        return observation, step_reward, self._done, info

    def is_tradesignal(self, action: int):
        """
        Determine if the signal is a trade signal, i.e. whether the action
        should change the current position.
        e.g.: Actions.Exit while in Positions.Long is a trade signal,
        Actions.Long_enter while in Positions.Short is not.

        :param action: integer value of one of the `Actions` members
        :return: True if `step()` should act on the action, False otherwise
        """
        in_trade = (Positions.Short, Positions.Long)

        # Neutral is rejected for every known position.
        if action == Actions.Neutral.value:
            return self._position not in (Positions.Neutral,) + in_trade

        # Entering (either direction) is rejected while a position is open,
        # so the agent cannot add to or flip a position in one step.
        if action in (Actions.Short_enter.value, Actions.Long_enter.value):
            return self._position not in in_trade

        # Exiting is rejected when there is nothing to exit.
        if action == Actions.Exit.value:
            return self._position != Positions.Neutral

        # Any other value matches none of the rejected combinations.
        return True

    def _is_valid(self, action: int):
        """
        Determine if the signal is valid for the current position.
        e.g.: Actions.Exit while in Positions.Neutral is not valid.

        :param action: integer value of one of the `Actions` members
        :return: False if the action conflicts with the position, else True
        """
        # Agent should only try to exit if it is in position
        if action == Actions.Exit.value:
            if self._position not in (Positions.Short, Positions.Long):
                return False

        # Agent should only try to enter if it is not in position
        if action in (Actions.Short_enter.value, Actions.Long_enter.value):
            if self._position != Positions.Neutral:
                return False

        return True