improve typing, improve docstrings, ensure global tests pass

robcaulk 2022-09-23 19:17:27 +02:00
parent 9c361f4422
commit 77c360b264
7 changed files with 124 additions and 40 deletions

View File

@ -25,6 +25,17 @@ class Base4ActionRLEnv(BaseEnvironment):
self.action_space = spaces.Discrete(len(Actions))
def step(self, action: int):
"""
Logic for a single step (incrementing one candle in time)
by the agent
:param: action: int = the action type that the agent plans
to take for the current step.
:returns:
observation = current state of environment
step_reward = the reward from `calculate_reward()`
_done = if the agent "died" or if the candles finished
info = dict passed back to openai gym lib
"""
self._done = False
self._current_tick += 1
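For orientation, the step() contract documented above is the standard OpenAI gym interaction loop. A minimal sketch of how a caller might consume it (illustrative only: it assumes an already constructed env instance and substitutes a random placeholder policy for a trained agent):

    obs = env.reset()
    done = False
    total_reward = 0.0
    while not done:
        action = env.action_space.sample()  # placeholder policy; SB3 would normally supply this
        obs, step_reward, done, info = env.step(action)
        total_reward += step_reward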
@ -92,7 +103,6 @@ class Base4ActionRLEnv(BaseEnvironment):
return observation, step_reward, self._done, info
def is_tradesignal(self, action: int):
# trade signal
"""
Determine if the signal is a trade signal
e.g.: the agent wants an Actions.Long_exit while it is in a Positions.Short
@ -107,7 +117,6 @@ class Base4ActionRLEnv(BaseEnvironment):
(action == Actions.Long_enter.value and self._position == Positions.Short))
def _is_valid(self, action: int):
# trade signal
"""
Determine if the signal is valid.
e.g.: the agent wants an Actions.Long_exit while it is in a Positions.Short

View File

@ -60,6 +60,17 @@ class Base5ActionRLEnv(BaseEnvironment):
return self._get_observation()
def step(self, action: int):
"""
Logic for a single step (incrementing one candle in time)
by the agent
:param: action: int = the action type that the agent plans
to take for the current step.
:returns:
observation = current state of environment
step_reward = the reward from `calculate_reward()`
_done = if the agent "died" or if the candles finished
info = dict passed back to openai gym lib
"""
self._done = False
self._current_tick += 1
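For context, the five discrete actions this environment works through show up again in the calculate_reward() hunks later in this commit (Neutral, Long_enter, Long_exit, Short_enter, ...). A sketch of what such an enum looks like; the member names follow those referenced elsewhere in the diff, while Short_exit and the integer values are assumptions made for illustration:

    from enum import Enum

    class Actions(Enum):
        # Illustrative layout only; values are not taken from the upstream module.
        Neutral = 0
        Long_enter = 1
        Long_exit = 2
        Short_enter = 3
        Short_exit = 4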

View File

@ -43,6 +43,10 @@ class BaseEnvironment(gym.Env):
def reset_env(self, df: DataFrame, prices: DataFrame, window_size: int,
reward_kwargs: dict, starting_point=True):
"""
Resets the environment when the agent fails (in our case, if the drawdown
exceeds the user-set max_training_drawdown_pct)
"""
self.df = df
self.signal_features = self.df
self.prices = prices
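As a usage note, reset_env() re-initializes the environment state from the dataframes passed in, so the model side can refresh an existing environment with a new data window instead of rebuilding it. An illustrative call (argument values are assumptions, not taken from this commit):

    self.train_env.reset_env(train_df, prices_train, window_size=self.CONV_WIDTH,
                             reward_kwargs=self.reward_params)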
@ -133,13 +137,18 @@ class BaseEnvironment(gym.Env):
return features_and_state
def get_trade_duration(self):
"""
Get the trade duration if the agent is in a trade
"""
if self._last_trade_tick is None:
return 0
else:
return self._current_tick - self._last_trade_tick
def get_unrealized_profit(self):
"""
Get the unrealized profit if the agent is in a trade
"""
if self._last_trade_tick is None:
return 0.
@ -158,7 +167,6 @@ class BaseEnvironment(gym.Env):
@abstractmethod
def is_tradesignal(self, action: int):
# trade signal
"""
Determine if the signal is a trade signal. This is
unique to the actions in the environment, and therefore must be
@ -167,7 +175,6 @@ class BaseEnvironment(gym.Env):
return
def _is_valid(self, action: int):
# trade signal
"""
Determine if the signal is valid. This is
unique to the actions in the environment, and therefore must be
@ -191,8 +198,13 @@ class BaseEnvironment(gym.Env):
@abstractmethod
def calculate_reward(self, action):
"""
Reward is created by BaseReinforcementLearningModel and can
be inherited/edited by the user made ReinforcementLearner file.
An example reward function. This is the one function that users will likely
wish to inject their own creativity into.
:params:
action: int = The action made by the agent for the current candle.
:returns:
float = the reward to give to the agent for the current step (used for optimization
of weights in NN)
"""
def _update_unrealized_total_profit(self):

View File

@ -2,7 +2,7 @@ import logging
from abc import abstractmethod
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Callable, Dict, Tuple
from typing import Any, Callable, Dict, Tuple, Type, Union
import gym
import numpy as np
@ -19,8 +19,9 @@ from freqtrade.exceptions import OperationalException
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.freqai_interface import IFreqaiModel
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
from freqtrade.freqai.RL.BaseEnvironment import Positions
from freqtrade.persistence import Trade
from stable_baselines3.common.vec_env import SubprocVecEnv
logger = logging.getLogger(__name__)
@ -33,15 +34,15 @@ SB3_CONTRIB_MODELS = ['TRPO', 'ARS', 'RecurrentPPO', 'MaskablePPO']
class BaseReinforcementLearningModel(IFreqaiModel):
"""
User created Reinforcement Learning Model prediction model.
User created Reinforcement Learning Model prediction class
"""
def __init__(self, **kwargs):
super().__init__(config=kwargs['config'])
th.set_num_threads(self.freqai_info['rl_config'].get('thread_count', 4))
self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
self.train_env: BaseEnvironment = None
self.eval_env: BaseEnvironment = None
self.train_env: Union[SubprocVecEnv, gym.Env] = None
self.eval_env: Union[SubprocVecEnv, gym.Env] = None
self.eval_callback: EvalCallback = None
self.model_type = self.freqai_info['rl_config']['model_type']
self.rl_config = self.freqai_info['rl_config']
@ -126,6 +127,13 @@ class BaseReinforcementLearningModel(IFreqaiModel):
dk: FreqaiDataKitchen):
"""
User can override this if they are using a custom MyRLEnv
:params:
data_dictionary: dict = common data dictionary containing train and test
features/labels/weights.
prices_train/test: DataFrame = dataframe comprised of the prices to be used in the
environment during training
or testing
dk: FreqaiDataKitchen = the datakitchen for the current pair
"""
train_df = data_dictionary["train_features"]
test_df = data_dictionary["test_features"]
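In the default (single-process) implementation, the rest of this method builds one training and one evaluation environment from these dataframes and attaches an EvalCallback. A hedged sketch of that continuation (keyword names inferred from reset_env() and from the EvalCallback call shown later in this commit; self.MyRLEnv and self.CONV_WIDTH are assumed attribute names):

    self.train_env = self.MyRLEnv(df=train_df, prices=prices_train,
                                  window_size=self.CONV_WIDTH,
                                  reward_kwargs=self.reward_params, config=self.config)
    self.eval_env = self.MyRLEnv(df=test_df, prices=prices_test,
                                 window_size=self.CONV_WIDTH,
                                 reward_kwargs=self.reward_params, config=self.config)
    self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
                                      render=False, eval_freq=len(train_df),
                                      best_model_save_path=str(dk.data_path))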
@ -148,15 +156,24 @@ class BaseReinforcementLearningModel(IFreqaiModel):
"""
return
def get_state_info(self, pair: str):
def get_state_info(self, pair: str) -> Tuple[float, float, int]:
"""
State info during dry/live/backtesting which is fed back
into the model.
:param:
pair: str = COIN/STAKE to get the environment information for
:returns:
market_side: float = representing short, long, or neutral for
the pair
current_profit: float = the unrealized profit of the open trade for the pair
trade_duration: int = the number of candles that the trade has
been open for
"""
open_trades = Trade.get_trades_proxy(is_open=True)
market_side = 0.5
current_profit: float = 0
trade_duration = 0
for trade in open_trades:
if trade.pair == pair:
# FIXME: get_rate and trade_duration shouldn't work with backtesting,
# we need to use candle dates and prices to compute that.
if self.strategy.dp._exchange is None: # type: ignore
logger.error('No exchange available.')
else:
@ -172,11 +189,6 @@ class BaseReinforcementLearningModel(IFreqaiModel):
market_side = 0
current_profit = (openrate - current_value) / openrate
# total_profit = 0
# closed_trades = Trade.get_trades_proxy(pair=pair, is_open=False)
# for trade in closed_trades:
# total_profit += trade.close_profit
return market_side, current_profit, int(trade_duration)
def predict(
@ -209,7 +221,13 @@ class BaseReinforcementLearningModel(IFreqaiModel):
def rl_model_predict(self, dataframe: DataFrame,
dk: FreqaiDataKitchen, model: Any) -> DataFrame:
"""
A helper function to make predictions in the Reinforcement Learning module.
:params:
dataframe: DataFrame = the dataframe of features to make the predictions on
dk: FreqaiDataKitchen = data kitchen for the current pair
model: Any = the trained model used to run inference on the features.
"""
output = pd.DataFrame(np.zeros(len(dataframe)), columns=dk.label_list)
def _predict(window):
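The body of _predict is truncated by this hunk; broadly it applies the trained model over a rolling window of the feature dataframe. A sketch of that pattern (illustrative restatement, not the upstream body; self.CONV_WIDTH and the exact window handling are assumptions, while model.predict(..., deterministic=True) is the standard stable-baselines3 call returning an (action, state) pair):

    def _predict(window):
        observations = dataframe.iloc[window.index]  # features for the current window
        res, _ = model.predict(observations, deterministic=True)
        return res

    output = output.rolling(window=self.CONV_WIDTH).apply(_predict)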
@ -274,26 +292,37 @@ class BaseReinforcementLearningModel(IFreqaiModel):
sets a custom reward based on profit and trade duration.
"""
def calculate_reward(self, action):
def calculate_reward(self, action: int) -> float:
"""
An example reward function. This is the one function that users will likely
wish to inject their own creativity into.
:params:
action: int = The action made by the agent for the current candle.
:returns:
float = the reward to give to the agent for the current step (used for optimization
of weights in NN)
"""
# first, penalize if the action is not valid
if not self._is_valid(action):
return -2
pnl = self.get_unrealized_profit()
rew = np.sign(pnl) * (pnl + 1)
factor = 100
factor = 100.
# reward agent for entering trades
if action in (Actions.Long_enter.value, Actions.Short_enter.value) \
and self._position == Positions.Neutral:
if (action in (Actions.Long_enter.value, Actions.Short_enter.value)
and self._position == Positions.Neutral):
return 25
# discourage agent from not entering trades
if action == Actions.Neutral.value and self._position == Positions.Neutral:
return -1
max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
trade_duration = self._current_tick - self._last_trade_tick
if self._last_trade_tick:
trade_duration = self._current_tick - self._last_trade_tick
else:
trade_duration = 0
if trade_duration <= max_trade_duration:
factor *= 1.5
@ -301,8 +330,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
factor *= 0.5
# discourage sitting in position
if self._position in (Positions.Short, Positions.Long) and \
action == Actions.Neutral.value:
if (self._position in (Positions.Short, Positions.Long) and
action == Actions.Neutral.value):
return -1 * trade_duration / max_trade_duration
# close long
@ -320,7 +349,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
return 0.
def make_env(MyRLEnv: BaseEnvironment, env_id: str, rank: int,
def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
seed: int, train_df: DataFrame, price: DataFrame,
reward_params: Dict[str, int], window_size: int, monitor: bool = False,
config: Dict[str, Any] = {}) -> Callable:
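For reference, the callable that make_env returns follows the usual stable-baselines3 worker-factory pattern: build one environment instance per rank, optionally wrap it in a Monitor, and seed it. A sketch of that body (illustrative; the env keyword names are assumed from the environment signatures earlier in this commit):

    from stable_baselines3.common.monitor import Monitor
    from stable_baselines3.common.utils import set_random_seed

    def _init() -> gym.Env:
        env = MyRLEnv(df=train_df, prices=price, window_size=window_size,
                      reward_kwargs=reward_params, id=env_id, seed=seed + rank,
                      config=config)  # keyword names assumed
        if monitor:
            env = Monitor(env)
        return env

    set_random_seed(seed)
    return _init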

View File

@ -19,7 +19,15 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
"""
def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
"""
User customizable fit method
:params:
data_dictionary: dict = common data dictionary containing all train/test
features/labels/weights.
dk: FreqaiDataKitchen = data kitchen for the current pair.
:returns:
model: Any = trained model to be used for inference in dry/live/backtesting
"""
train_df = data_dictionary["train_features"]
total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)
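From here, fit() typically instantiates the configured stable-baselines3 model class on self.train_env and trains it with the eval callback set up earlier. A hedged sketch using PPO as a stand-in (the concrete class is selected by rl_config["model_type"], and the policy choice here is an assumption):

    from stable_baselines3 import PPO

    model = PPO("MlpPolicy", self.train_env, verbose=1)
    model.learn(total_timesteps=int(total_timesteps),
                callback=self.eval_callback)
    return model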
@ -59,7 +67,15 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
"""
def calculate_reward(self, action):
"""
An example reward function. This is the one function that users will likely
wish to inject their own creativity into.
:params:
action: int = The action made by the agent for the current candle.
:returns:
float = the reward to give to the agent for the current step (used for optimization
of weights in NN)
"""
# first, penalize if the action is not valid
if not self._is_valid(action):
return -2

View File

@ -6,7 +6,7 @@ from typing import Any, Dict # , Tuple
import torch as th
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.vec_env import SubprocVecEnv
from pandas import DataFrame
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.RL.BaseReinforcementLearningModel import (BaseReinforcementLearningModel,
make_env)
@ -55,11 +55,18 @@ class ReinforcementLearner_multiproc(BaseReinforcementLearningModel):
return model
def set_train_and_eval_environments(self, data_dictionary, prices_train, prices_test, dk):
def set_train_and_eval_environments(self, data_dictionary: Dict[str, Any],
prices_train: DataFrame, prices_test: DataFrame,
dk: FreqaiDataKitchen):
"""
If user has particular environment configuration needs, they can do that by
overriding this function. In the present case, the user wants to setup training
environments for multiple workers.
User can override this if they are using a custom MyRLEnv
:params:
data_dictionary: dict = common data dictionary containing train and test
features/labels/weights.
prices_train/test: DataFrame = dataframe comprised of the prices to be used in
the environment during training
or testing
dk: FreqaiDataKitchen = the datakitchen for the current pair
"""
train_df = data_dictionary["train_features"]
test_df = data_dictionary["test_features"]
@ -79,4 +86,4 @@ class ReinforcementLearner_multiproc(BaseReinforcementLearningModel):
in range(num_cpu)])
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
render=False, eval_freq=len(train_df),
best_model_save_path=dk.data_path)
best_model_save_path=str(dk.data_path))
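The fragment ending in range(num_cpu)]) above is the tail of one of these vectorized-environment constructions; each of train_env and eval_env is built by calling make_env once per worker, roughly as sketched below (the "cpu_count" key, env_id value, and seed are assumptions):

    env_id = "train_env"
    num_cpu = int(self.freqai_info["rl_config"].get("cpu_count", 2))
    self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1, train_df,
                                             prices_train, self.reward_params,
                                             self.CONV_WIDTH, monitor=True,
                                             config=self.config)
                                    for i in range(num_cpu)])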

View File

@ -244,7 +244,7 @@ def test_start_backtesting(mocker, freqai_conf, model, num_files, strat):
model_folders = [x for x in freqai.dd.full_path.iterdir() if x.is_dir()]
assert len(model_folders) == num_files
Trade.use_db = True
shutil.rmtree(Path(freqai.dk.full_path))
@ -297,7 +297,7 @@ def test_start_backtesting_from_existing_folder(mocker, freqai_conf, caplog):
assert len(model_folders) == 6
# without deleting the exiting folder structure, re-run
# without deleting the existing folder structure, re-run
freqai_conf.update({"timerange": "20180120-20180130"})
strategy = get_patched_freqai_strategy(mocker, freqai_conf)