action fix

This commit is contained in:
MukavaValkku 2022-08-15 12:29:44 +03:00 committed by robcaulk
parent 9c78e6c26f
commit 718c9d0440

View File

@ -72,11 +72,9 @@ class ReinforcementLearningTDQN(BaseReinforcementLearningModel):
class Actions(Enum): class Actions(Enum):
Neutral = 0 Short = 0
Long_buy = 1 Long = 1
Long_sell = 2 Neutral = 2
Short_buy = 3
Short_sell = 4
class Positions(Enum): class Positions(Enum):
@ -181,7 +179,7 @@ class MyRLEnv(BaseRLEnv):
self.total_reward += step_reward self.total_reward += step_reward
trade_type = None trade_type = None
if self.is_tradesignal(action): # exclude 3 case not trade if self.is_tradesignal(action): # exclude 3 case not trade
# Update position # Update position
""" """
Action: Neutral, position: Long -> Close Long Action: Neutral, position: Long -> Close Long
@ -194,23 +192,18 @@ class MyRLEnv(BaseRLEnv):
Action: Short, position: Long -> Close Long and Open Short Action: Short, position: Long -> Close Long and Open Short
""" """
temp_position = self._position
if action == Actions.Neutral.value: if action == Actions.Neutral.value:
self._position = Positions.Neutral self._position = Positions.Neutral
trade_type = "neutral" trade_type = "neutral"
elif action == Actions.Long_buy.value: elif action == Actions.Long.value:
self._position = Positions.Long self._position = Positions.Long
trade_type = "long" trade_type = "long"
elif action == Actions.Short_buy.value: elif action == Actions.Short.value:
self._position = Positions.Short self._position = Positions.Short
trade_type = "short" trade_type = "short"
elif action == Actions.Long_sell.value:
self._position = Positions.Neutral
trade_type = "neutral"
elif action == Actions.Short_sell.value:
self._position = Positions.Neutral
trade_type = "neutral"
else: else:
print("case not defined") print("case not define")
# Update last trade tick # Update last trade tick
self._last_trade_tick = self._current_tick self._last_trade_tick = self._current_tick
@ -264,26 +257,16 @@ class MyRLEnv(BaseRLEnv):
Action: Long, position: Long -> Hold Long Action: Long, position: Long -> Hold Long
Action: Short, position: Short -> Hold Short Action: Short, position: Short -> Hold Short
""" """
return not ((action == Actions.Neutral.value and self._position == Positions.Neutral) or return not ((action == Actions.Neutral.value and self._position == Positions.Neutral)
(action == Actions.Short_buy.value and self._position == Positions.Short) or or (action == Actions.Short.value and self._position == Positions.Short)
(action == Actions.Short_sell.value and self._position == Positions.Short) or or (action == Actions.Long.value and self._position == Positions.Long))
(action == Actions.Short_buy.value and self._position == Positions.Long) or
(action == Actions.Short_sell.value and self._position == Positions.Long) or
(action == Actions.Long_buy.value and self._position == Positions.Long) or def _is_trade(self, action: Actions):
(action == Actions.Long_sell.value and self._position == Positions.Long) or return ((action == Actions.Long.value and self._position == Positions.Short) or
(action == Actions.Long_buy.value and self._position == Positions.Short) or (action == Actions.Short.value and self._position == Positions.Long) or
(action == Actions.Long_sell.value and self._position == Positions.Short)) (action == Actions.Neutral.value and self._position == Positions.Long) or
(action == Actions.Neutral.value and self._position == Positions.Short)
def _is_trade(self, action): )
return ((action == Actions.Long_buy.value and self._position == Positions.Short) or
(action == Actions.Short_buy.value and self._position == Positions.Long) or
(action == Actions.Neutral.value and self._position == Positions.Long) or
(action == Actions.Neutral.value and self._position == Positions.Short) or
(action == Actions.Neutral.Short_sell and self._position == Positions.Long) or
(action == Actions.Neutral.Long_sell and self._position == Positions.Short)
)
def is_hold(self, action): def is_hold(self, action):
return ((action == Actions.Short.value and self._position == Positions.Short) return ((action == Actions.Short.value and self._position == Positions.Short)