ensure typing, remove unsued code

This commit is contained in:
robcaulk 2022-11-26 12:11:59 +01:00
parent 8dbfd2cacf
commit 81fd2e588f
5 changed files with 46 additions and 44 deletions

View File

@ -195,7 +195,7 @@ As you begin to modify the strategy and the prediction model, you will quickly r
Users can override any functions from those parent classes. Here is an example
of a user customized `calculate_reward()` function.
"""
def calculate_reward(self, action):
def calculate_reward(self, action: int) -> float:
# first, penalize if the action is not valid
if not self._is_valid(action):
return -2

View File

@ -158,7 +158,7 @@ class Base5ActionRLEnv(BaseEnvironment):
(action == Actions.Long_exit.value and self._position == Positions.Short) or
(action == Actions.Long_exit.value and self._position == Positions.Neutral))
def _is_valid(self, action: int):
def _is_valid(self, action: int) -> bool:
# trade signal
"""
Determine if the signal is valid.

View File

@ -208,13 +208,13 @@ class BaseEnvironment(gym.Env):
"""
return
def _is_valid(self, action: int):
def _is_valid(self, action: int) -> bool:
"""
Determine if the signal is valid.This is
unique to the actions in the environment, and therefore must be
inherited.
"""
return
return True
def add_entry_fee(self, price):
return price * (1 + self.fee)
@ -230,7 +230,7 @@ class BaseEnvironment(gym.Env):
self.history[key].append(value)
@abstractmethod
def calculate_reward(self, action):
def calculate_reward(self, action: int) -> float:
"""
An example reward function. This is the one function that users will likely
wish to inject their own creativity into.
@ -263,38 +263,40 @@ class BaseEnvironment(gym.Env):
# assumes unit stake and no compounding
self._total_profit += pnl
def most_recent_return(self, action: int):
"""
Calculate the tick to tick return if in a trade.
Return is generated from rising prices in Long
and falling prices in Short positions.
The actions Sell/Buy or Hold during a Long position trigger the sell/buy-fee.
"""
# Long positions
if self._position == Positions.Long:
current_price = self.prices.iloc[self._current_tick].open
previous_price = self.prices.iloc[self._current_tick - 1].open
if (self._position_history[self._current_tick - 1] == Positions.Short
or self._position_history[self._current_tick - 1] == Positions.Neutral):
previous_price = self.add_entry_fee(previous_price)
return np.log(current_price) - np.log(previous_price)
# Short positions
if self._position == Positions.Short:
current_price = self.prices.iloc[self._current_tick].open
previous_price = self.prices.iloc[self._current_tick - 1].open
if (self._position_history[self._current_tick - 1] == Positions.Long
or self._position_history[self._current_tick - 1] == Positions.Neutral):
previous_price = self.add_exit_fee(previous_price)
return np.log(previous_price) - np.log(current_price)
return 0
def update_portfolio_log_returns(self, action):
self.portfolio_log_returns[self._current_tick] = self.most_recent_return(action)
def current_price(self) -> float:
return self.prices.iloc[self._current_tick].open
# Keeping around incase we want to start building more complex environment
# templates in the future.
# def most_recent_return(self):
# """
# Calculate the tick to tick return if in a trade.
# Return is generated from rising prices in Long
# and falling prices in Short positions.
# The actions Sell/Buy or Hold during a Long position trigger the sell/buy-fee.
# """
# # Long positions
# if self._position == Positions.Long:
# current_price = self.prices.iloc[self._current_tick].open
# previous_price = self.prices.iloc[self._current_tick - 1].open
# if (self._position_history[self._current_tick - 1] == Positions.Short
# or self._position_history[self._current_tick - 1] == Positions.Neutral):
# previous_price = self.add_entry_fee(previous_price)
# return np.log(current_price) - np.log(previous_price)
# # Short positions
# if self._position == Positions.Short:
# current_price = self.prices.iloc[self._current_tick].open
# previous_price = self.prices.iloc[self._current_tick - 1].open
# if (self._position_history[self._current_tick - 1] == Positions.Long
# or self._position_history[self._current_tick - 1] == Positions.Neutral):
# previous_price = self.add_exit_fee(previous_price)
# return np.log(previous_price) - np.log(current_price)
# return 0
# def update_portfolio_log_returns(self, action):
# self.portfolio_log_returns[self._current_tick] = self.most_recent_return(action)

View File

@ -89,7 +89,7 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
sets a custom reward based on profit and trade duration.
"""
def calculate_reward(self, action):
def calculate_reward(self, action: int) -> float:
"""
An example reward function. This is the one function that users will likely
wish to inject their own creativity into.
@ -103,7 +103,7 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
return -2
pnl = self.get_unrealized_profit()
factor = 100
factor = 100.
# reward agent for entering trades
if (action in (Actions.Long_enter.value, Actions.Short_enter.value)
@ -114,7 +114,7 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
return -1
max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
trade_duration = self._current_tick - self._last_trade_tick
trade_duration = self._current_tick - self._last_trade_tick # type: ignore
if trade_duration <= max_trade_duration:
factor *= 1.5

View File

@ -20,7 +20,7 @@ class ReinforcementLearner_test_4ac(ReinforcementLearner):
sets a custom reward based on profit and trade duration.
"""
def calculate_reward(self, action):
def calculate_reward(self, action: int) -> float:
# first, penalize if the action is not valid
if not self._is_valid(action):
@ -28,7 +28,7 @@ class ReinforcementLearner_test_4ac(ReinforcementLearner):
pnl = self.get_unrealized_profit()
rew = np.sign(pnl) * (pnl + 1)
factor = 100
factor = 100.
# reward agent for entering trades
if (action in (Actions.Long_enter.value, Actions.Short_enter.value)
@ -39,7 +39,7 @@ class ReinforcementLearner_test_4ac(ReinforcementLearner):
return -1
max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
trade_duration = self._current_tick - self._last_trade_tick
trade_duration = self._current_tick - self._last_trade_tick # type: ignore
if trade_duration <= max_trade_duration:
factor *= 1.5