improve default reward, fix bugs in environment
This commit is contained in:
@@ -140,30 +140,32 @@ class Base5ActionRLEnv(gym.Env):
|
||||
if action == Actions.Neutral.value:
|
||||
self._position = Positions.Neutral
|
||||
trade_type = "neutral"
|
||||
self._last_trade_tick = None
|
||||
elif action == Actions.Long_enter.value:
|
||||
self._position = Positions.Long
|
||||
trade_type = "long"
|
||||
self._last_trade_tick = self._current_tick
|
||||
elif action == Actions.Short_enter.value:
|
||||
self._position = Positions.Short
|
||||
trade_type = "short"
|
||||
self._last_trade_tick = self._current_tick
|
||||
elif action == Actions.Long_exit.value:
|
||||
self._position = Positions.Neutral
|
||||
trade_type = "neutral"
|
||||
self._last_trade_tick = None
|
||||
elif action == Actions.Short_exit.value:
|
||||
self._position = Positions.Neutral
|
||||
trade_type = "neutral"
|
||||
self._last_trade_tick = None
|
||||
else:
|
||||
print("case not defined")
|
||||
|
||||
# Update last trade tick
|
||||
self._last_trade_tick = self._current_tick
|
||||
|
||||
if trade_type is not None:
|
||||
self.trade_history.append(
|
||||
{'price': self.current_price(), 'index': self._current_tick,
|
||||
'type': trade_type})
|
||||
|
||||
if self._total_profit < 0.2:
|
||||
if self._total_profit < 0.5:
|
||||
self._done = True
|
||||
|
||||
self._position_history.append(self._position)
|
||||
@@ -221,8 +223,7 @@ class Base5ActionRLEnv(gym.Env):
|
||||
def is_tradesignal(self, action: int):
|
||||
# trade signal
|
||||
"""
|
||||
not trade signal is :
|
||||
Determine if the signal is non sensical
|
||||
Determine if the signal is a trade signal
|
||||
e.g.: agent wants a Actions.Long_exit while it is in a Positions.short
|
||||
"""
|
||||
return not ((action == Actions.Neutral.value and self._position == Positions.Neutral) or
|
||||
@@ -237,6 +238,24 @@ class Base5ActionRLEnv(gym.Env):
|
||||
(action == Actions.Long_exit.value and self._position == Positions.Short) or
|
||||
(action == Actions.Long_exit.value and self._position == Positions.Neutral))
|
||||
|
||||
def _is_valid(self, action: int):
|
||||
# trade signal
|
||||
"""
|
||||
Determine if the signal is valid.
|
||||
e.g.: agent wants a Actions.Long_exit while it is in a Positions.short
|
||||
"""
|
||||
# Agent should only try to exit if it is in position
|
||||
if action in (Actions.Short_exit.value, Actions.Long_exit.value):
|
||||
if self._position not in (Positions.Short, Positions.Long):
|
||||
return False
|
||||
|
||||
# Agent should only try to enter if it is not in position
|
||||
if action in (Actions.Short_enter.value, Actions.Long_enter.value):
|
||||
if self._position != Positions.Neutral:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _is_trade(self, action: Actions):
|
||||
return ((action == Actions.Long_enter.value and self._position == Positions.Neutral) or
|
||||
(action == Actions.Short_enter.value and self._position == Positions.Neutral))
|
||||
@@ -278,13 +297,8 @@ class Base5ActionRLEnv(gym.Env):
|
||||
if self._is_trade(action) or self._done:
|
||||
pnl = self.get_unrealized_profit()
|
||||
|
||||
if self._position == Positions.Long:
|
||||
self._total_profit = self._total_profit + self._total_profit * pnl
|
||||
self._profits.append((self._current_tick, self._total_profit))
|
||||
self.close_trade_profit.append(pnl)
|
||||
|
||||
if self._position == Positions.Short:
|
||||
self._total_profit = self._total_profit + self._total_profit * pnl
|
||||
if self._position in (Positions.Long, Positions.Short):
|
||||
self._total_profit *= (1 + pnl)
|
||||
self._profits.append((self._current_tick, self._total_profit))
|
||||
self.close_trade_profit.append(pnl)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user