improve nomenclature and fix short exit bug

This commit is contained in:
robcaulk 2022-08-19 11:04:15 +02:00
parent 4baa36bdcf
commit 4b9499e321

View File

@ -13,10 +13,10 @@ logger = logging.getLogger(__name__)
class Actions(Enum): class Actions(Enum):
Neutral = 0 Neutral = 0
Long_buy = 1 Long_enter = 1
Long_sell = 2 Long_exit = 2
Short_buy = 3 Short_enter = 3
Short_sell = 4 Short_exit = 4
class Positions(Enum): class Positions(Enum):
@ -139,16 +139,16 @@ class Base5ActionRLEnv(gym.Env):
if action == Actions.Neutral.value: if action == Actions.Neutral.value:
self._position = Positions.Neutral self._position = Positions.Neutral
trade_type = "neutral" trade_type = "neutral"
elif action == Actions.Long_buy.value: elif action == Actions.Long_enter.value:
self._position = Positions.Long self._position = Positions.Long
trade_type = "long" trade_type = "long"
elif action == Actions.Short_buy.value: elif action == Actions.Short_enter.value:
self._position = Positions.Short self._position = Positions.Short
trade_type = "short" trade_type = "short"
elif action == Actions.Long_sell.value: elif action == Actions.Long_exit.value:
self._position = Positions.Neutral self._position = Positions.Neutral
trade_type = "neutral" trade_type = "neutral"
elif action == Actions.Short_sell.value: elif action == Actions.Short_exit.value:
self._position = Positions.Neutral self._position = Positions.Neutral
trade_type = "neutral" trade_type = "neutral"
else: else:
@ -221,24 +221,24 @@ class Base5ActionRLEnv(gym.Env):
return not ((action == Actions.Neutral.value and self._position == Positions.Neutral) or return not ((action == Actions.Neutral.value and self._position == Positions.Neutral) or
(action == Actions.Neutral.value and self._position == Positions.Short) or (action == Actions.Neutral.value and self._position == Positions.Short) or
(action == Actions.Neutral.value and self._position == Positions.Long) or (action == Actions.Neutral.value and self._position == Positions.Long) or
(action == Actions.Short_buy.value and self._position == Positions.Short) or (action == Actions.Short_enter.value and self._position == Positions.Short) or
(action == Actions.Short_buy.value and self._position == Positions.Long) or (action == Actions.Short_enter.value and self._position == Positions.Long) or
(action == Actions.Short_sell.value and self._position == Positions.Short) or (action == Actions.Short_exit.value and self._position == Positions.Short) or
(action == Actions.Short_sell.value and self._position == Positions.Long) or (action == Actions.Short_exit.value and self._position == Positions.Long) or
(action == Actions.Short_sell.value and self._position == Positions.Neutral) or (action == Actions.Short_exit.value and self._position == Positions.Neutral) or
(action == Actions.Long_buy.value and self._position == Positions.Long) or (action == Actions.Long_enter.value and self._position == Positions.Long) or
(action == Actions.Long_buy.value and self._position == Positions.Short) or (action == Actions.Long_enter.value and self._position == Positions.Short) or
(action == Actions.Long_sell.value and self._position == Positions.Long) or (action == Actions.Long_exit.value and self._position == Positions.Long) or
(action == Actions.Long_sell.value and self._position == Positions.Short) or (action == Actions.Long_exit.value and self._position == Positions.Short) or
(action == Actions.Long_sell.value and self._position == Positions.Neutral)) (action == Actions.Long_exit.value and self._position == Positions.Neutral))
def _is_trade(self, action: Actions): def _is_trade(self, action: Actions):
return ((action == Actions.Long_buy.value and self._position == Positions.Neutral) or return ((action == Actions.Long_enter.value and self._position == Positions.Neutral) or
(action == Actions.Short_buy.value and self._position == Positions.Neutral)) (action == Actions.Short_enter.value and self._position == Positions.Neutral))
def is_hold(self, action): def is_hold(self, action):
return ((action == Actions.Short_buy.value and self._position == Positions.Short) or return ((action == Actions.Short_enter.value and self._position == Positions.Short) or
(action == Actions.Long_buy.value and self._position == Positions.Long) or (action == Actions.Long_enter.value and self._position == Positions.Long) or
(action == Actions.Neutral.value and self._position == Positions.Long) or (action == Actions.Neutral.value and self._position == Positions.Long) or
(action == Actions.Neutral.value and self._position == Positions.Short) or (action == Actions.Neutral.value and self._position == Positions.Short) or
(action == Actions.Neutral.value and self._position == Positions.Neutral)) (action == Actions.Neutral.value and self._position == Positions.Neutral))
@ -265,7 +265,7 @@ class Base5ActionRLEnv(gym.Env):
return 0. return 0.
# close long # close long
if action == Actions.Long_sell.value and self._position == Positions.Long: if action == Actions.Long_exit.value and self._position == Positions.Long:
if len(self.close_trade_profit): if len(self.close_trade_profit):
# aim x2 rw # aim x2 rw
if self.close_trade_profit[-1] > self.profit_aim * self.rr: if self.close_trade_profit[-1] > self.profit_aim * self.rr:
@ -292,7 +292,7 @@ class Base5ActionRLEnv(gym.Env):
# return float((np.log(current_price) - np.log(last_trade_price)) * 2) * -1 # return float((np.log(current_price) - np.log(last_trade_price)) * 2) * -1
# close short # close short
if action == Actions.Short_buy.value and self._position == Positions.Short: if action == Actions.Short_exit.value and self._position == Positions.Short:
if len(self.close_trade_profit): if len(self.close_trade_profit):
# aim x2 rw # aim x2 rw
if self.close_trade_profit[-1] > self.profit_aim * self.rr: if self.close_trade_profit[-1] > self.profit_aim * self.rr:
@ -346,7 +346,7 @@ class Base5ActionRLEnv(gym.Env):
# Long positions # Long positions
if self._position == Positions.Long: if self._position == Positions.Long:
current_price = self.prices.iloc[self._current_tick].open current_price = self.prices.iloc[self._current_tick].open
if action == Actions.Short_buy.value or action == Actions.Neutral.value: if action == Actions.Short_enter.value or action == Actions.Neutral.value:
current_price = self.add_sell_fee(current_price) current_price = self.add_sell_fee(current_price)
previous_price = self.prices.iloc[self._current_tick - 1].open previous_price = self.prices.iloc[self._current_tick - 1].open
@ -360,7 +360,7 @@ class Base5ActionRLEnv(gym.Env):
# Short positions # Short positions
if self._position == Positions.Short: if self._position == Positions.Short:
current_price = self.prices.iloc[self._current_tick].open current_price = self.prices.iloc[self._current_tick].open
if action == Actions.Long_buy.value or action == Actions.Neutral.value: if action == Actions.Long_enter.value or action == Actions.Neutral.value:
current_price = self.add_buy_fee(current_price) current_price = self.add_buy_fee(current_price)
previous_price = self.prices.iloc[self._current_tick - 1].open previous_price = self.prices.iloc[self._current_tick - 1].open