improve typing, improve docstrings, ensure global tests pass

robcaulk
2022-09-23 19:17:27 +02:00
parent 9c361f4422
commit 77c360b264
7 changed files with 124 additions and 40 deletions


@@ -2,7 +2,7 @@ import logging
from abc import abstractmethod
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Callable, Dict, Tuple
from typing import Any, Callable, Dict, Tuple, Type, Union
import gym
import numpy as np
@@ -19,8 +19,9 @@ from freqtrade.exceptions import OperationalException
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.freqai_interface import IFreqaiModel
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
from freqtrade.freqai.RL.BaseEnvironment import Positions
from freqtrade.persistence import Trade
from stable_baselines3.common.vec_env import SubprocVecEnv
logger = logging.getLogger(__name__)
@@ -33,15 +34,15 @@ SB3_CONTRIB_MODELS = ['TRPO', 'ARS', 'RecurrentPPO', 'MaskablePPO']
class BaseReinforcementLearningModel(IFreqaiModel):
"""
User created Reinforcement Learning Model prediction model.
User created Reinforcement Learning Model prediction class
"""
def __init__(self, **kwargs):
super().__init__(config=kwargs['config'])
th.set_num_threads(self.freqai_info['rl_config'].get('thread_count', 4))
self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
self.train_env: BaseEnvironment = None
self.eval_env: BaseEnvironment = None
self.train_env: Union[SubprocVecEnv, gym.Env] = None
self.eval_env: Union[SubprocVecEnv, gym.Env] = None
self.eval_callback: EvalCallback = None
self.model_type = self.freqai_info['rl_config']['model_type']
self.rl_config = self.freqai_info['rl_config']
@@ -126,6 +127,13 @@ class BaseReinforcementLearningModel(IFreqaiModel):
dk: FreqaiDataKitchen):
"""
User can override this if they are using a custom MyRLEnv
:params:
data_dictionary: dict = common data dictionary containing train and test
features/labels/weights.
prices_train/test: DataFrame = dataframe comprised of the prices to be used in
the environment during training or testing
dk: FreqaiDataKitchen = the datakitchen for the current pair
"""
train_df = data_dictionary["train_features"]
test_df = data_dictionary["test_features"]
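Since the docstring above invites overriding this method, the following is a minimal sketch of what a user-side override could look like. MyCoolRLModel, the MyRLEnv constructor kwargs (df, prices, window_size, reward_kwargs) and self.CONV_WIDTH are assumptions for illustration, not guaranteed by this diff.

from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor

class MyCoolRLModel(BaseReinforcementLearningModel):
    def set_train_and_eval_environments(self, data_dictionary, prices_train,
                                        prices_test, dk):
        train_df = data_dictionary["train_features"]
        test_df = data_dictionary["test_features"]

        # wrap the custom environments in a Monitor so episode stats are logged
        self.train_env = Monitor(MyRLEnv(df=train_df, prices=prices_train,
                                         window_size=self.CONV_WIDTH,
                                         reward_kwargs=self.reward_params))
        self.eval_env = Monitor(MyRLEnv(df=test_df, prices=prices_test,
                                        window_size=self.CONV_WIDTH,
                                        reward_kwargs=self.reward_params))
        # evaluate periodically during training, saving the best model
        self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
                                          render=False, eval_freq=len(train_df),
                                          best_model_save_path=str(dk.data_path))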
@@ -148,15 +156,24 @@ class BaseReinforcementLearningModel(IFreqaiModel):
"""
return
def get_state_info(self, pair: str):
def get_state_info(self, pair: str) -> Tuple[float, float, int]:
"""
State info during dry/live/backtesting which is fed back
into the model.
:param:
pair: str = COIN/STAKE to get the environment information for
:returns:
market_side: float = representing short, long, or neutral for
the pair
current_profit: float = the unrealized profit of the currently open trade
trade_duration: int = the number of candles that the trade has
been open for
"""
open_trades = Trade.get_trades_proxy(is_open=True)
market_side = 0.5
current_profit: float = 0
trade_duration = 0
for trade in open_trades:
if trade.pair == pair:
# FIXME: get_rate and trade_duration shouldn't work with backtesting,
# we need to use candle dates and prices to compute that.
if self.strategy.dp._exchange is None: # type: ignore
logger.error('No exchange available.')
else:
@@ -172,11 +189,6 @@ class BaseReinforcementLearningModel(IFreqaiModel):
market_side = 0
current_profit = (openrate - current_value) / openrate
# total_profit = 0
# closed_trades = Trade.get_trades_proxy(pair=pair, is_open=False)
# for trade in closed_trades:
# total_profit += trade.close_profit
return market_side, current_profit, int(trade_duration)
def predict(
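Regarding the FIXME above: a hedged sketch of a backtest-safe alternative that derives the duration from candle dates rather than wall-clock helpers. The helper name candles_between is hypothetical and not part of this diff.

from datetime import datetime, timezone

def candles_between(open_date: datetime, current_date: datetime,
                    timeframe_minutes: int) -> int:
    # count whole candles elapsed between the trade open and the current candle
    elapsed_minutes = (current_date - open_date).total_seconds() / 60
    return max(int(elapsed_minutes // timeframe_minutes), 0)

# e.g. a trade opened 3 hours earlier on a 5m timeframe has been open 36 candles
candles_between(datetime(2022, 9, 23, 16, 0, tzinfo=timezone.utc),
                datetime(2022, 9, 23, 19, 0, tzinfo=timezone.utc), 5)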
@@ -209,7 +221,13 @@ class BaseReinforcementLearningModel(IFreqaiModel):
def rl_model_predict(self, dataframe: DataFrame,
dk: FreqaiDataKitchen, model: Any) -> DataFrame:
"""
A helper function to make predictions in the Reinforcement learning module.
:params:
dataframe: DataFrame = the dataframe of features to make the predictions on
dk: FreqaiDataKitchen = data kitchen for the current pair
model: Any = the trained model used for inference on the features
"""
output = pd.DataFrame(np.zeros(len(dataframe)), columns=dk.label_list)
def _predict(window):
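The body of _predict falls outside this hunk; as a rough, hedged sketch (the observation slicing and the stable_baselines3 predict() call are assumptions consistent with SB3, not shown by this diff), the rolling helper could look like:

def _predict(window):
    # slice the features at the rolling window's index and query the agent
    observations = dataframe.iloc[window.index]
    res, _ = model.predict(observations, deterministic=True)  # SB3 returns (action, state)
    return res

output = output.rolling(window=self.CONV_WIDTH).apply(_predict)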
@@ -274,26 +292,37 @@ class BaseReinforcementLearningModel(IFreqaiModel):
sets a custom reward based on profit and trade duration.
"""
def calculate_reward(self, action):
def calculate_reward(self, action: int) -> float:
"""
An example reward function. This is the one function that users will likely
wish to inject their own creativity into.
:params:
action: int = The action made by the agent for the current candle.
:returns:
float = the reward to give to the agent for the current step (used for optimization
of weights in NN)
"""
# first, penalize if the action is not valid
if not self._is_valid(action):
return -2
pnl = self.get_unrealized_profit()
rew = np.sign(pnl) * (pnl + 1)
factor = 100
factor = 100.
# reward agent for entering trades
if action in (Actions.Long_enter.value, Actions.Short_enter.value) \
and self._position == Positions.Neutral:
if (action in (Actions.Long_enter.value, Actions.Short_enter.value)
and self._position == Positions.Neutral):
return 25
# discourage agent from not entering trades
if action == Actions.Neutral.value and self._position == Positions.Neutral:
return -1
max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
trade_duration = self._current_tick - self._last_trade_tick
if self._last_trade_tick:
trade_duration = self._current_tick - self._last_trade_tick
else:
trade_duration = 0
if trade_duration <= max_trade_duration:
factor *= 1.5
@@ -301,8 +330,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
factor *= 0.5
# discourage sitting in position
if self._position in (Positions.Short, Positions.Long) and \
action == Actions.Neutral.value:
if (self._position in (Positions.Short, Positions.Long) and
action == Actions.Neutral.value):
return -1 * trade_duration / max_trade_duration
# close long
@@ -320,7 +349,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
return 0.
def make_env(MyRLEnv: BaseEnvironment, env_id: str, rank: int,
def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
seed: int, train_df: DataFrame, price: DataFrame,
reward_params: Dict[str, int], window_size: int, monitor: bool = False,
config: Dict[str, Any] = {}) -> Callable:
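make_env returns a factory Callable, so a typical (hedged) usage is to build one seeded, ranked environment per worker and hand the factories to SubprocVecEnv; MyRLEnv, num_cpu, and the data variables below are assumptions taken from the surrounding context.

num_cpu = 4
env_fns = [make_env(MyRLEnv, "train_env", rank, seed=42,
                    train_df=train_df, price=prices_train,
                    reward_params=reward_params, window_size=10)
           for rank in range(num_cpu)]
# each factory is executed in its own subprocess by SubprocVecEnv
train_env = SubprocVecEnv(env_fns)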