ensure typing, remove unused code

robcaulk 2022-11-26 12:11:59 +01:00
parent 8dbfd2cacf
commit 81fd2e588f
5 changed files with 46 additions and 44 deletions

View File

@@ -195,7 +195,7 @@ As you begin to modify the strategy and the prediction model, you will quickly r
     Users can override any functions from those parent classes. Here is an example
     of a user customized `calculate_reward()` function.
     """
-    def calculate_reward(self, action):
+    def calculate_reward(self, action: int) -> float:
         # first, penalize if the action is not valid
         if not self._is_valid(action):
             return -2
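Fleshed out, the typed override from this docs hunk might look like the sketch below. The class name, the import paths, and everything past the `return -2` guard are illustrative assumptions for this note, not part of the commit; only the signature and the validity guard come from the hunk above.

from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
from freqtrade.freqai.RL.BaseEnvironment import Positions


class MyRLEnv(Base5ActionRLEnv):
    """
    Illustrative user environment: only calculate_reward() is overridden.
    """
    def calculate_reward(self, action: int) -> float:
        # first, penalize if the action is not valid
        if not self._is_valid(action):
            return -2
        # illustrative shaping (assumption): pay out the unrealized profit when
        # the agent closes a long, otherwise apply a small holding cost
        pnl = self.get_unrealized_profit()
        if action == Actions.Long_exit.value and self._position == Positions.Long:
            return float(pnl)
        return -0.01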

View File

@@ -158,7 +158,7 @@ class Base5ActionRLEnv(BaseEnvironment):
                 (action == Actions.Long_exit.value and self._position == Positions.Short) or
                 (action == Actions.Long_exit.value and self._position == Positions.Neutral))

-    def _is_valid(self, action: int):
+    def _is_valid(self, action: int) -> bool:
         # trade signal
         """
         Determine if the signal is valid.

View File

@ -208,13 +208,13 @@ class BaseEnvironment(gym.Env):
""" """
return return
def _is_valid(self, action: int): def _is_valid(self, action: int) -> bool:
""" """
Determine if the signal is valid.This is Determine if the signal is valid.This is
unique to the actions in the environment, and therefore must be unique to the actions in the environment, and therefore must be
inherited. inherited.
""" """
return return True
def add_entry_fee(self, price): def add_entry_fee(self, price):
return price * (1 + self.fee) return price * (1 + self.fee)
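Because the base-class stub now returns True and advertises `-> bool`, inheriting environments are expected to hand back a real boolean. A minimal override sketch, assuming the `Actions`/`Positions` names and import paths used elsewhere in this diff; the specific validity rule is illustrative:

from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
from freqtrade.freqai.RL.BaseEnvironment import Positions


class MyValidatingEnv(Base5ActionRLEnv):
    def _is_valid(self, action: int) -> bool:
        # illustrative rule: an exit signal is only valid while the matching
        # position is open; every other action is accepted
        if action == Actions.Long_exit.value:
            return self._position == Positions.Long
        return True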
@@ -230,7 +230,7 @@ class BaseEnvironment(gym.Env):
             self.history[key].append(value)

     @abstractmethod
-    def calculate_reward(self, action):
+    def calculate_reward(self, action: int) -> float:
         """
         An example reward function. This is the one function that users will likely
         wish to inject their own creativity into.
@@ -263,38 +263,40 @@ class BaseEnvironment(gym.Env):
         # assumes unit stake and no compounding
         self._total_profit += pnl

-    def most_recent_return(self, action: int):
-        """
-        Calculate the tick to tick return if in a trade.
-        Return is generated from rising prices in Long
-        and falling prices in Short positions.
-        The actions Sell/Buy or Hold during a Long position trigger the sell/buy-fee.
-        """
-        # Long positions
-        if self._position == Positions.Long:
-            current_price = self.prices.iloc[self._current_tick].open
-            previous_price = self.prices.iloc[self._current_tick - 1].open
-            if (self._position_history[self._current_tick - 1] == Positions.Short
-                    or self._position_history[self._current_tick - 1] == Positions.Neutral):
-                previous_price = self.add_entry_fee(previous_price)
-            return np.log(current_price) - np.log(previous_price)
-
-        # Short positions
-        if self._position == Positions.Short:
-            current_price = self.prices.iloc[self._current_tick].open
-            previous_price = self.prices.iloc[self._current_tick - 1].open
-            if (self._position_history[self._current_tick - 1] == Positions.Long
-                    or self._position_history[self._current_tick - 1] == Positions.Neutral):
-                previous_price = self.add_exit_fee(previous_price)
-            return np.log(previous_price) - np.log(current_price)
-
-        return 0
-
-    def update_portfolio_log_returns(self, action):
-        self.portfolio_log_returns[self._current_tick] = self.most_recent_return(action)
-
     def current_price(self) -> float:
         return self.prices.iloc[self._current_tick].open
+
+    # Keeping around incase we want to start building more complex environment
+    # templates in the future.
+    # def most_recent_return(self):
+    #     """
+    #     Calculate the tick to tick return if in a trade.
+    #     Return is generated from rising prices in Long
+    #     and falling prices in Short positions.
+    #     The actions Sell/Buy or Hold during a Long position trigger the sell/buy-fee.
+    #     """
+    #     # Long positions
+    #     if self._position == Positions.Long:
+    #         current_price = self.prices.iloc[self._current_tick].open
+    #         previous_price = self.prices.iloc[self._current_tick - 1].open
+    #         if (self._position_history[self._current_tick - 1] == Positions.Short
+    #                 or self._position_history[self._current_tick - 1] == Positions.Neutral):
+    #             previous_price = self.add_entry_fee(previous_price)
+    #         return np.log(current_price) - np.log(previous_price)
+    #     # Short positions
+    #     if self._position == Positions.Short:
+    #         current_price = self.prices.iloc[self._current_tick].open
+    #         previous_price = self.prices.iloc[self._current_tick - 1].open
+    #         if (self._position_history[self._current_tick - 1] == Positions.Long
+    #                 or self._position_history[self._current_tick - 1] == Positions.Neutral):
+    #             previous_price = self.add_exit_fee(previous_price)
+    #         return np.log(previous_price) - np.log(current_price)
+    #     return 0
+    # def update_portfolio_log_returns(self, action):
+    #     self.portfolio_log_returns[self._current_tick] = self.most_recent_return(action)
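The logic being commented out above computes a one-candle log return, folding the entry or exit fee into the reference price on the tick where the position changed. A standalone sketch of the same arithmetic; the helper name and the `1 - fee` exit-fee form are assumptions, since only `add_entry_fee()` is visible in this diff:

import numpy as np


def one_tick_log_return(prev_open: float, curr_open: float, fee: float,
                        is_long: bool, just_changed_position: bool) -> float:
    # fold the fee into the previous open if the position flipped on the
    # previous tick, mirroring add_entry_fee()/add_exit_fee() above
    if just_changed_position:
        prev_open = prev_open * (1 + fee) if is_long else prev_open * (1 - fee)
    # long trades profit from rising prices, short trades from falling ones
    if is_long:
        return float(np.log(curr_open) - np.log(prev_open))
    return float(np.log(prev_open) - np.log(curr_open))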

View File

@@ -89,7 +89,7 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
        sets a custom reward based on profit and trade duration.
        """

-        def calculate_reward(self, action):
+        def calculate_reward(self, action: int) -> float:
            """
            An example reward function. This is the one function that users will likely
            wish to inject their own creativity into.
@@ -103,7 +103,7 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
                return -2

            pnl = self.get_unrealized_profit()
-            factor = 100
+            factor = 100.

            # reward agent for entering trades
            if (action in (Actions.Long_enter.value, Actions.Short_enter.value)
@@ -114,7 +114,7 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
                return -1

            max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
-            trade_duration = self._current_tick - self._last_trade_tick
+            trade_duration = self._current_tick - self._last_trade_tick  # type: ignore
            if trade_duration <= max_trade_duration:
                factor *= 1.5
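The `# type: ignore` presumably silences mypy because `_last_trade_tick` starts out unset before any trade is open. To see how the pieces in these two hunks fit together, here is a standalone sketch of a duration-aware reward term; only the float `factor` and the `<= max_trade_duration` boost mirror the hunks, while the function name and the final `pnl * factor` combination are assumptions for illustration:

def duration_scaled_reward(pnl: float, trade_duration: int,
                           max_trade_duration: int = 300) -> float:
    # start from a float factor so a -> float annotation holds on every path
    factor = 100.
    # boost the reward while the trade is still inside the configured window
    if trade_duration <= max_trade_duration:
        factor *= 1.5
    # combining pnl and factor multiplicatively is an assumption of this sketch
    return pnl * factor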

View File

@@ -20,7 +20,7 @@ class ReinforcementLearner_test_4ac(ReinforcementLearner):
        sets a custom reward based on profit and trade duration.
        """

-        def calculate_reward(self, action):
+        def calculate_reward(self, action: int) -> float:

            # first, penalize if the action is not valid
            if not self._is_valid(action):
@@ -28,7 +28,7 @@ class ReinforcementLearner_test_4ac(ReinforcementLearner):
            pnl = self.get_unrealized_profit()
            rew = np.sign(pnl) * (pnl + 1)
-            factor = 100
+            factor = 100.

            # reward agent for entering trades
            if (action in (Actions.Long_enter.value, Actions.Short_enter.value)
@@ -39,7 +39,7 @@ class ReinforcementLearner_test_4ac(ReinforcementLearner):
                return -1

            max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
-            trade_duration = self._current_tick - self._last_trade_tick
+            trade_duration = self._current_tick - self._last_trade_tick  # type: ignore
            if trade_duration <= max_trade_duration:
                factor *= 1.5