improve typing, improve docstrings, ensure global tests pass
This commit is contained in:
@@ -2,7 +2,7 @@ import logging
|
||||
from abc import abstractmethod
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Tuple
|
||||
from typing import Any, Callable, Dict, Tuple, Type, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
@@ -19,8 +19,9 @@ from freqtrade.exceptions import OperationalException
|
||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||
from freqtrade.freqai.freqai_interface import IFreqaiModel
|
||||
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
|
||||
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
|
||||
from freqtrade.freqai.RL.BaseEnvironment import Positions
|
||||
from freqtrade.persistence import Trade
|
||||
from stable_baselines3.common.vec_env import SubprocVecEnv
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -33,15 +34,15 @@ SB3_CONTRIB_MODELS = ['TRPO', 'ARS', 'RecurrentPPO', 'MaskablePPO']
|
||||
|
||||
class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
"""
|
||||
User created Reinforcement Learning Model prediction model.
|
||||
User created Reinforcement Learning Model prediction class
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(config=kwargs['config'])
|
||||
th.set_num_threads(self.freqai_info['rl_config'].get('thread_count', 4))
|
||||
self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
|
||||
self.train_env: BaseEnvironment = None
|
||||
self.eval_env: BaseEnvironment = None
|
||||
self.train_env: Union[SubprocVecEnv, gym.Env] = None
|
||||
self.eval_env: Union[SubprocVecEnv, gym.Env] = None
|
||||
self.eval_callback: EvalCallback = None
|
||||
self.model_type = self.freqai_info['rl_config']['model_type']
|
||||
self.rl_config = self.freqai_info['rl_config']
|
||||
@@ -126,6 +127,13 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
dk: FreqaiDataKitchen):
|
||||
"""
|
||||
User can override this if they are using a custom MyRLEnv
|
||||
:params:
|
||||
data_dictionary: dict = common data dictionary containing train and test
|
||||
features/labels/weights.
|
||||
prices_train/test: DataFrame = dataframe comprised of the prices to be used in the
|
||||
environment during training
|
||||
or testing
|
||||
dk: FreqaiDataKitchen = the datakitchen for the current pair
|
||||
"""
|
||||
train_df = data_dictionary["train_features"]
|
||||
test_df = data_dictionary["test_features"]
|
||||
@@ -148,15 +156,24 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
"""
|
||||
return
|
||||
|
||||
def get_state_info(self, pair: str):
|
||||
def get_state_info(self, pair: str) -> Tuple[float, float, int]:
|
||||
"""
|
||||
State info during dry/live/backtesting which is fed back
|
||||
into the model.
|
||||
:params:
|
||||
pair: str = COIN/STAKE to get the environment information for
|
||||
:returns:
|
||||
market_side: float = representing short, long, or neutral for
|
||||
pair
|
||||
current_profit: float = unrealized profit ratio of the open trade for pair
|
||||
trade_duration: int = the number of candles that the trade has
|
||||
been open for
|
||||
"""
|
||||
open_trades = Trade.get_trades_proxy(is_open=True)
|
||||
market_side = 0.5
|
||||
current_profit: float = 0
|
||||
trade_duration = 0
|
||||
for trade in open_trades:
|
||||
if trade.pair == pair:
|
||||
# FIXME: get_rate and trade_duration shouldn't work with backtesting,
|
||||
# we need to use candle dates and prices to compute that.
|
||||
if self.strategy.dp._exchange is None: # type: ignore
|
||||
logger.error('No exchange available.')
|
||||
else:
|
||||
@@ -172,11 +189,6 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
market_side = 0
|
||||
current_profit = (openrate - current_value) / openrate
|
||||
|
||||
# total_profit = 0
|
||||
# closed_trades = Trade.get_trades_proxy(pair=pair, is_open=False)
|
||||
# for trade in closed_trades:
|
||||
# total_profit += trade.close_profit
|
||||
|
||||
return market_side, current_profit, int(trade_duration)
|
||||
|
||||
def predict(
|
||||
@@ -209,7 +221,13 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
|
||||
def rl_model_predict(self, dataframe: DataFrame,
|
||||
dk: FreqaiDataKitchen, model: Any) -> DataFrame:
|
||||
|
||||
"""
|
||||
A helper function to make predictions in the Reinforcement learning module.
|
||||
:params:
|
||||
dataframe: DataFrame = the dataframe of features to make the predictions on
|
||||
dk: FreqaiDataKitchen = data kitchen for the current pair
|
||||
model: Any = the trained model used to run inference on the features.
|
||||
"""
|
||||
output = pd.DataFrame(np.zeros(len(dataframe)), columns=dk.label_list)
|
||||
|
||||
def _predict(window):
|
||||
@@ -274,26 +292,37 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
sets a custom reward based on profit and trade duration.
|
||||
"""
|
||||
|
||||
def calculate_reward(self, action):
|
||||
|
||||
def calculate_reward(self, action: int) -> float:
|
||||
"""
|
||||
An example reward function. This is the one function that users will likely
|
||||
wish to inject their own creativity into.
|
||||
:params:
|
||||
action: int = The action made by the agent for the current candle.
|
||||
:returns:
|
||||
float = the reward to give to the agent for current step (used for optimization
|
||||
of weights in NN)
|
||||
"""
|
||||
# first, penalize if the action is not valid
|
||||
if not self._is_valid(action):
|
||||
return -2
|
||||
|
||||
pnl = self.get_unrealized_profit()
|
||||
rew = np.sign(pnl) * (pnl + 1)
|
||||
factor = 100
|
||||
factor = 100.
|
||||
|
||||
# reward agent for entering trades
|
||||
if action in (Actions.Long_enter.value, Actions.Short_enter.value) \
|
||||
and self._position == Positions.Neutral:
|
||||
if (action in (Actions.Long_enter.value, Actions.Short_enter.value)
|
||||
and self._position == Positions.Neutral):
|
||||
return 25
|
||||
# discourage agent from not entering trades
|
||||
if action == Actions.Neutral.value and self._position == Positions.Neutral:
|
||||
return -1
|
||||
|
||||
max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
|
||||
trade_duration = self._current_tick - self._last_trade_tick
|
||||
if self._last_trade_tick:
|
||||
trade_duration = self._current_tick - self._last_trade_tick
|
||||
else:
|
||||
trade_duration = 0
|
||||
|
||||
if trade_duration <= max_trade_duration:
|
||||
factor *= 1.5
|
||||
@@ -301,8 +330,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
factor *= 0.5
|
||||
|
||||
# discourage sitting in position
|
||||
if self._position in (Positions.Short, Positions.Long) and \
|
||||
action == Actions.Neutral.value:
|
||||
if (self._position in (Positions.Short, Positions.Long) and
|
||||
action == Actions.Neutral.value):
|
||||
return -1 * trade_duration / max_trade_duration
|
||||
|
||||
# close long
|
||||
@@ -320,7 +349,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
return 0.
|
||||
|
||||
|
||||
def make_env(MyRLEnv: BaseEnvironment, env_id: str, rank: int,
|
||||
def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
|
||||
seed: int, train_df: DataFrame, price: DataFrame,
|
||||
reward_params: Dict[str, int], window_size: int, monitor: bool = False,
|
||||
config: Dict[str, Any] = {}) -> Callable:
|
||||
|
Reference in New Issue
Block a user