3ac to 5ac

This commit is contained in:
MukavaValkku 2022-08-15 12:45:08 +03:00 committed by robcaulk
parent 718c9d0440
commit 096533bcb9

View File

@ -72,9 +72,11 @@ class ReinforcementLearningTDQN(BaseReinforcementLearningModel):
class Actions(Enum):
Short = 0
Long = 1
Neutral = 2
Neutral = 0
Long_buy = 1
Long_sell = 2
Short_buy = 3
Short_sell = 4
class Positions(Enum):
@ -192,18 +194,23 @@ class MyRLEnv(BaseRLEnv):
Action: Short, position: Long -> Close Long and Open Short
"""
temp_position = self._position
if action == Actions.Neutral.value:
self._position = Positions.Neutral
trade_type = "neutral"
elif action == Actions.Long.value:
elif action == Actions.Long_buy.value:
self._position = Positions.Long
trade_type = "long"
elif action == Actions.Short.value:
elif action == Actions.Short_buy.value:
self._position = Positions.Short
trade_type = "short"
elif action == Actions.Long_sell.value:
self._position = Positions.Neutral
trade_type = "neutral"
elif action == Actions.Short_sell.value:
self._position = Positions.Neutral
trade_type = "neutral"
else:
print("case not define")
print("case not defined")
# Update last trade tick
self._last_trade_tick = self._current_tick
@ -257,15 +264,25 @@ class MyRLEnv(BaseRLEnv):
Action: Long, position: Long -> Hold Long
Action: Short, position: Short -> Hold Short
"""
return not ((action == Actions.Neutral.value and self._position == Positions.Neutral)
or (action == Actions.Short.value and self._position == Positions.Short)
or (action == Actions.Long.value and self._position == Positions.Long))
return not ((action == Actions.Neutral.value and self._position == Positions.Neutral) or
(action == Actions.Short_buy.value and self._position == Positions.Short) or
(action == Actions.Short_sell.value and self._position == Positions.Short) or
(action == Actions.Short_buy.value and self._position == Positions.Long) or
(action == Actions.Short_sell.value and self._position == Positions.Long) or
def _is_trade(self, action: Actions):
return ((action == Actions.Long.value and self._position == Positions.Short) or
(action == Actions.Short.value and self._position == Positions.Long) or
(action == Actions.Long_buy.value and self._position == Positions.Long) or
(action == Actions.Long_sell.value and self._position == Positions.Long) or
(action == Actions.Long_buy.value and self._position == Positions.Short) or
(action == Actions.Long_sell.value and self._position == Positions.Short))
def _is_trade(self, action):
return ((action == Actions.Long_buy.value and self._position == Positions.Short) or
(action == Actions.Short_buy.value and self._position == Positions.Long) or
(action == Actions.Neutral.value and self._position == Positions.Long) or
(action == Actions.Neutral.value and self._position == Positions.Short)
(action == Actions.Neutral.value and self._position == Positions.Short) or
(action == Actions.Neutral.Short_sell and self._position == Positions.Long) or
(action == Actions.Neutral.Long_sell and self._position == Positions.Short)
)
def is_hold(self, action):