improve default reward, fix bugs in environment

This commit is contained in:
robcaulk
2022-08-24 18:32:40 +02:00
parent a61821e1c6
commit d1bee29b1e
3 changed files with 102 additions and 53 deletions

View File

@@ -140,30 +140,32 @@ class Base5ActionRLEnv(gym.Env):
if action == Actions.Neutral.value:
self._position = Positions.Neutral
trade_type = "neutral"
self._last_trade_tick = None
elif action == Actions.Long_enter.value:
self._position = Positions.Long
trade_type = "long"
self._last_trade_tick = self._current_tick
elif action == Actions.Short_enter.value:
self._position = Positions.Short
trade_type = "short"
self._last_trade_tick = self._current_tick
elif action == Actions.Long_exit.value:
self._position = Positions.Neutral
trade_type = "neutral"
self._last_trade_tick = None
elif action == Actions.Short_exit.value:
self._position = Positions.Neutral
trade_type = "neutral"
self._last_trade_tick = None
else:
print("case not defined")
# Update last trade tick
self._last_trade_tick = self._current_tick
if trade_type is not None:
self.trade_history.append(
{'price': self.current_price(), 'index': self._current_tick,
'type': trade_type})
if self._total_profit < 0.2:
if self._total_profit < 0.5:
self._done = True
self._position_history.append(self._position)
@@ -221,8 +223,7 @@ class Base5ActionRLEnv(gym.Env):
def is_tradesignal(self, action: int):
# trade signal
"""
not trade signal is :
Determine if the signal is non sensical
Determine if the signal is a trade signal
e.g.: agent wants a Actions.Long_exit while it is in a Positions.short
"""
return not ((action == Actions.Neutral.value and self._position == Positions.Neutral) or
@@ -237,6 +238,24 @@ class Base5ActionRLEnv(gym.Env):
(action == Actions.Long_exit.value and self._position == Positions.Short) or
(action == Actions.Long_exit.value and self._position == Positions.Neutral))
def _is_valid(self, action: int):
    """
    Determine whether the requested action is valid for the current position.

    An exit is only valid when it matches the side of the open position, and
    an entry is only valid while no position is open.
    e.g.: agent wants a Actions.Long_exit while it is in a Positions.Short
    -> invalid.

    :param action: integer value of one of the Actions members
    :return: False if the action is nonsensical for the current position,
             True otherwise
    """
    # A long exit only makes sense while a long position is open
    if action == Actions.Long_exit.value:
        return self._position == Positions.Long
    # A short exit only makes sense while a short position is open
    if action == Actions.Short_exit.value:
        return self._position == Positions.Short
    # Agent should only try to enter if it is not in position
    if action in (Actions.Short_enter.value, Actions.Long_enter.value):
        return self._position == Positions.Neutral
    # Every other action value (e.g. Actions.Neutral) is accepted as before
    return True
def _is_trade(self, action: Actions):
    """
    Tell whether the action is an entry requested while no position is open.

    The action is compared against the `.value` of Actions.Long_enter and
    Actions.Short_enter.

    :param action: action requested by the agent
    :return: truthy if the action is a long or short entry and the current
             position is Positions.Neutral, falsy otherwise
    """
    flat = self._position == Positions.Neutral
    wants_long = action == Actions.Long_enter.value
    wants_short = action == Actions.Short_enter.value
    return (wants_long and flat) or (wants_short and flat)
@@ -278,13 +297,8 @@ class Base5ActionRLEnv(gym.Env):
if self._is_trade(action) or self._done:
pnl = self.get_unrealized_profit()
if self._position == Positions.Long:
self._total_profit = self._total_profit + self._total_profit * pnl
self._profits.append((self._current_tick, self._total_profit))
self.close_trade_profit.append(pnl)
if self._position == Positions.Short:
self._total_profit = self._total_profit + self._total_profit * pnl
if self._position in (Positions.Long, Positions.Short):
self._total_profit *= (1 + pnl)
self._profits.append((self._current_tick, self._total_profit))
self.close_trade_profit.append(pnl)