diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearningTDQN.py b/freqtrade/freqai/prediction_models/ReinforcementLearningTDQN.py index 5ec917719..8f5fe4e03 100644 --- a/freqtrade/freqai/prediction_models/ReinforcementLearningTDQN.py +++ b/freqtrade/freqai/prediction_models/ReinforcementLearningTDQN.py @@ -72,11 +72,9 @@ class ReinforcementLearningTDQN(BaseReinforcementLearningModel): class Actions(Enum): - Neutral = 0 - Long_buy = 1 - Long_sell = 2 - Short_buy = 3 - Short_sell = 4 + Short = 0 + Long = 1 + Neutral = 2 class Positions(Enum): @@ -181,36 +179,31 @@ class MyRLEnv(BaseRLEnv): self.total_reward += step_reward trade_type = None - if self.is_tradesignal(action): # exclude 3 case not trade + if self.is_tradesignal(action): # exclude 3 case not trade # Update position """ - Action: Neutral, position: Long -> Close Long - Action: Neutral, position: Short -> Close Short - - Action: Long, position: Neutral -> Open Long + Action: Neutral, position: Long -> Close Long + Action: Neutral, position: Short -> Close Short + + Action: Long, position: Neutral -> Open Long Action: Long, position: Short -> Close Short and Open Long - - Action: Short, position: Neutral -> Open Short + + Action: Short, position: Neutral -> Open Short Action: Short, position: Long -> Close Long and Open Short """ - + + temp_position = self._position if action == Actions.Neutral.value: self._position = Positions.Neutral trade_type = "neutral" - elif action == Actions.Long_buy.value: + elif action == Actions.Long.value: self._position = Positions.Long trade_type = "long" - elif action == Actions.Short_buy.value: + elif action == Actions.Short.value: self._position = Positions.Short trade_type = "short" - elif action == Actions.Long_sell.value: - self._position = Positions.Neutral - trade_type = "neutral" - elif action == Actions.Short_sell.value: - self._position = Positions.Neutral - trade_type = "neutral" else: - print("case not defined") + print("case not define") # Update last trade tick self._last_trade_tick = self._current_tick @@ -257,33 +250,23 @@ class MyRLEnv(BaseRLEnv): return 0. def is_tradesignal(self, action): - # trade signal + # trade signal """ not trade signal is : - Action: Neutral, position: Neutral -> Nothing + Action: Neutral, position: Neutral -> Nothing Action: Long, position: Long -> Hold Long Action: Short, position: Short -> Hold Short """ - return not ((action == Actions.Neutral.value and self._position == Positions.Neutral) or - (action == Actions.Short_buy.value and self._position == Positions.Short) or - (action == Actions.Short_sell.value and self._position == Positions.Short) or - (action == Actions.Short_buy.value and self._position == Positions.Long) or - (action == Actions.Short_sell.value and self._position == Positions.Long) or + return not ((action == Actions.Neutral.value and self._position == Positions.Neutral) + or (action == Actions.Short.value and self._position == Positions.Short) + or (action == Actions.Long.value and self._position == Positions.Long)) - (action == Actions.Long_buy.value and self._position == Positions.Long) or - (action == Actions.Long_sell.value and self._position == Positions.Long) or - (action == Actions.Long_buy.value and self._position == Positions.Short) or - (action == Actions.Long_sell.value and self._position == Positions.Short)) - - def _is_trade(self, action): - return ((action == Actions.Long_buy.value and self._position == Positions.Short) or - (action == Actions.Short_buy.value and self._position == Positions.Long) or - (action == Actions.Neutral.value and self._position == Positions.Long) or - (action == Actions.Neutral.value and self._position == Positions.Short) or - - (action == Actions.Neutral.Short_sell and self._position == Positions.Long) or - (action == Actions.Neutral.Long_sell and self._position == Positions.Short) - ) + def _is_trade(self, action: Actions): + return ((action == Actions.Long.value and self._position == Positions.Short) or + (action == Actions.Short.value and self._position == Positions.Long) or + (action == Actions.Neutral.value and self._position == Positions.Long) or + (action == Actions.Neutral.value and self._position == Positions.Short) + ) def is_hold(self, action): return ((action == Actions.Short.value and self._position == Positions.Short)