From 77c360b264c9dee489081c2761cc3be4ba0b01d1 Mon Sep 17 00:00:00 2001
From: robcaulk
Date: Fri, 23 Sep 2022 19:17:27 +0200
Subject: [PATCH] improve typing, improve docstrings, ensure global tests pass

---
 freqtrade/freqai/RL/Base4ActionRLEnv.py       | 13 +++-
 freqtrade/freqai/RL/Base5ActionRLEnv.py       | 11 +++
 freqtrade/freqai/RL/BaseEnvironment.py        | 22 ++++--
 .../RL/BaseReinforcementLearningModel.py      | 75 +++++++++++++------
 .../prediction_models/ReinforcementLearner.py | 20 ++++-
 .../ReinforcementLearner_multiproc.py         | 19 +++--
 tests/freqai/test_freqai_interface.py         |  4 +-
 7 files changed, 124 insertions(+), 40 deletions(-)

diff --git a/freqtrade/freqai/RL/Base4ActionRLEnv.py b/freqtrade/freqai/RL/Base4ActionRLEnv.py
index bd5785b85..b4fe78b71 100644
--- a/freqtrade/freqai/RL/Base4ActionRLEnv.py
+++ b/freqtrade/freqai/RL/Base4ActionRLEnv.py
@@ -25,6 +25,17 @@ class Base4ActionRLEnv(BaseEnvironment):
         self.action_space = spaces.Discrete(len(Actions))
 
     def step(self, action: int):
+        """
+        Logic for a single step (incrementing one candle in time)
+        by the agent
+        :param: action: int = the action type that the agent plans
+            to take for the current step.
+        :returns:
+        observation = current state of environment
+        step_reward = the reward from `calculate_reward()`
+        _done = if the agent "died" or if the candles finished
+        info = dict passed back to openai gym lib
+        """
         self._done = False
         self._current_tick += 1
 
@@ -92,7 +103,6 @@ class Base4ActionRLEnv(BaseEnvironment):
         return observation, step_reward, self._done, info
 
     def is_tradesignal(self, action: int):
-        # trade signal
         """
         Determine if the signal is a trade signal
         e.g.: agent wants a Actions.Long_exit while it is in a Positions.short
@@ -107,7 +117,6 @@ class Base4ActionRLEnv(BaseEnvironment):
                 (action == Actions.Long_enter.value and self._position == Positions.Short))
 
     def _is_valid(self, action: int):
-        # trade signal
         """
         Determine if the signal is valid.
         e.g.: agent wants a Actions.Long_exit while it is in a Positions.short
diff --git a/freqtrade/freqai/RL/Base5ActionRLEnv.py b/freqtrade/freqai/RL/Base5ActionRLEnv.py
index e0a38f9d1..80543bf72 100644
--- a/freqtrade/freqai/RL/Base5ActionRLEnv.py
+++ b/freqtrade/freqai/RL/Base5ActionRLEnv.py
@@ -60,6 +60,17 @@ class Base5ActionRLEnv(BaseEnvironment):
         return self._get_observation()
 
     def step(self, action: int):
+        """
+        Logic for a single step (incrementing one candle in time)
+        by the agent
+        :param: action: int = the action type that the agent plans
+            to take for the current step.
+        :returns:
+        observation = current state of environment
+        step_reward = the reward from `calculate_reward()`
+        _done = if the agent "died" or if the candles finished
+        info = dict passed back to openai gym lib
+        """
         self._done = False
         self._current_tick += 1
 
diff --git a/freqtrade/freqai/RL/BaseEnvironment.py b/freqtrade/freqai/RL/BaseEnvironment.py
index 200b7d138..6474483c6 100644
--- a/freqtrade/freqai/RL/BaseEnvironment.py
+++ b/freqtrade/freqai/RL/BaseEnvironment.py
@@ -43,6 +43,10 @@ class BaseEnvironment(gym.Env):
 
     def reset_env(self, df: DataFrame, prices: DataFrame, window_size: int,
                   reward_kwargs: dict, starting_point=True):
+        """
+        Resets the environment when the agent fails (in our case, if the drawdown
+        exceeds the user set max_training_drawdown_pct)
+        """
         self.df = df
         self.signal_features = self.df
         self.prices = prices
@@ -133,13 +137,18 @@ class BaseEnvironment(gym.Env):
         return features_and_state
 
     def get_trade_duration(self):
+        """
+        Get the trade duration if the agent is in a trade
+        """
         if self._last_trade_tick is None:
             return 0
         else:
             return self._current_tick - self._last_trade_tick
 
     def get_unrealized_profit(self):
-
+        """
+        Get the unrealized profit if the agent is in a trade
+        """
         if self._last_trade_tick is None:
             return 0.
 
@@ -158,7 +167,6 @@ class BaseEnvironment(gym.Env):
 
     @abstractmethod
     def is_tradesignal(self, action: int):
-        # trade signal
         """
         Determine if the signal is a trade signal. This is
         unique to the actions in the environment, and therefore must be
@@ -167,7 +175,6 @@ class BaseEnvironment(gym.Env):
         return
 
     def _is_valid(self, action: int):
-        # trade signal
         """
         Determine if the signal is valid.This is
         unique to the actions in the environment, and therefore must be
@@ -191,8 +198,13 @@ class BaseEnvironment(gym.Env):
     @abstractmethod
     def calculate_reward(self, action):
         """
-        Reward is created by BaseReinforcementLearningModel and can
-        be inherited/edited by the user made ReinforcementLearner file.
+        An example reward function. This is the one function that users will likely
+        wish to inject their own creativity into.
+        :params:
+        action: int = The action made by the agent for the current candle.
+        :returns:
+        float = the reward to give to the agent for current step (used for optimization
+            of weights in NN)
         """
 
     def _update_unrealized_total_profit(self):
diff --git a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
index d10bf4dc3..c82fd1ea9 100644
--- a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
+++ b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
@@ -2,7 +2,7 @@ import logging
 from abc import abstractmethod
 from datetime import datetime, timezone
 from pathlib import Path
-from typing import Any, Callable, Dict, Tuple
+from typing import Any, Callable, Dict, Tuple, Type, Union
 
 import gym
 import numpy as np
@@ -19,8 +19,9 @@ from freqtrade.exceptions import OperationalException
 from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
 from freqtrade.freqai.freqai_interface import IFreqaiModel
 from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
-from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions
+from freqtrade.freqai.RL.BaseEnvironment import Positions
 from freqtrade.persistence import Trade
+from stable_baselines3.common.vec_env import SubprocVecEnv
 
 
 logger = logging.getLogger(__name__)
@@ -33,15 +34,15 @@ SB3_CONTRIB_MODELS = ['TRPO', 'ARS', 'RecurrentPPO', 'MaskablePPO']
 
 class BaseReinforcementLearningModel(IFreqaiModel):
     """
-    User created Reinforcement Learning Model prediction model.
+    User created Reinforcement Learning Model prediction class
     """
 
     def __init__(self, **kwargs):
         super().__init__(config=kwargs['config'])
         th.set_num_threads(self.freqai_info['rl_config'].get('thread_count', 4))
         self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
-        self.train_env: BaseEnvironment = None
-        self.eval_env: BaseEnvironment = None
+        self.train_env: Union[SubprocVecEnv, gym.Env] = None
+        self.eval_env: Union[SubprocVecEnv, gym.Env] = None
         self.eval_callback: EvalCallback = None
         self.model_type = self.freqai_info['rl_config']['model_type']
         self.rl_config = self.freqai_info['rl_config']
@@ -126,6 +127,13 @@ class BaseReinforcementLearningModel(IFreqaiModel):
                                         dk: FreqaiDataKitchen):
         """
         User can override this if they are using a custom MyRLEnv
+        :params:
+        data_dictionary: dict = common data dictionary containing train and test
+            features/labels/weights.
+        prices_train/test: DataFrame = dataframe comprised of the prices to be used in the
+            environment during training
+            or testing
+        dk: FreqaiDataKitchen = the datakitchen for the current pair
         """
         train_df = data_dictionary["train_features"]
         test_df = data_dictionary["test_features"]
@@ -148,15 +156,24 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         """
         return
 
-    def get_state_info(self, pair: str):
+    def get_state_info(self, pair: str) -> Tuple[float, float, int]:
+        """
+        State info during dry/live/backtesting which is fed back
+        into the model.
+        :param:
+        pair: str = COIN/STAKE to get the environment information for
+        :returns:
+        market_side: float = representing short, long, or neutral for
+            pair
+        trade_duration: int = the number of candles that the trade has
+            been open for
+        """
         open_trades = Trade.get_trades_proxy(is_open=True)
         market_side = 0.5
         current_profit: float = 0
         trade_duration = 0
         for trade in open_trades:
             if trade.pair == pair:
-                # FIXME: get_rate and trade_udration shouldn't work with backtesting,
-                # we need to use candle dates and prices to compute that.
                 if self.strategy.dp._exchange is None:  # type: ignore
                     logger.error('No exchange available.')
                 else:
@@ -172,11 +189,6 @@ class BaseReinforcementLearningModel(IFreqaiModel):
                     market_side = 0
                     current_profit = (openrate - current_value) / openrate
 
-        # total_profit = 0
-        # closed_trades = Trade.get_trades_proxy(pair=pair, is_open=False)
-        # for trade in closed_trades:
-        #     total_profit += trade.close_profit
-
         return market_side, current_profit, int(trade_duration)
 
     def predict(
@@ -209,7 +221,13 @@ class BaseReinforcementLearningModel(IFreqaiModel):
 
     def rl_model_predict(self, dataframe: DataFrame, dk: FreqaiDataKitchen,
                          model: Any) -> DataFrame:
-
+        """
+        A helper function to make predictions in the Reinforcement learning module.
+        :params:
+        dataframe: DataFrame = the dataframe of features to make the predictions on
+        dk: FreqaiDatakitchen = data kitchen for the current pair
+        model: Any = the trained model used to inference the features.
+        """
         output = pd.DataFrame(np.zeros(len(dataframe)), columns=dk.label_list)
 
         def _predict(window):
@@ -274,26 +292,37 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         sets a custom reward based on profit and trade duration.
         """
 
-        def calculate_reward(self, action):
-
+        def calculate_reward(self, action: int) -> float:
+            """
+            An example reward function. This is the one function that users will likely
+            wish to inject their own creativity into.
+            :params:
+            action: int = The action made by the agent for the current candle.
+            :returns:
+            float = the reward to give to the agent for current step (used for optimization
+                of weights in NN)
+            """
             # first, penalize if the action is not valid
             if not self._is_valid(action):
                 return -2
 
             pnl = self.get_unrealized_profit()
             rew = np.sign(pnl) * (pnl + 1)
-            factor = 100
+            factor = 100.
 
             # reward agent for entering trades
-            if action in (Actions.Long_enter.value, Actions.Short_enter.value) \
-                    and self._position == Positions.Neutral:
+            if (action in (Actions.Long_enter.value, Actions.Short_enter.value)
+                    and self._position == Positions.Neutral):
                 return 25
             # discourage agent from not entering trades
             if action == Actions.Neutral.value and self._position == Positions.Neutral:
                 return -1
 
             max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
-            trade_duration = self._current_tick - self._last_trade_tick
+            if self._last_trade_tick:
+                trade_duration = self._current_tick - self._last_trade_tick
+            else:
+                trade_duration = 0
 
             if trade_duration <= max_trade_duration:
                 factor *= 1.5
@@ -301,8 +330,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
                 factor *= 0.5
 
             # discourage sitting in position
-            if self._position in (Positions.Short, Positions.Long) and \
-                    action == Actions.Neutral.value:
+            if (self._position in (Positions.Short, Positions.Long) and
+                    action == Actions.Neutral.value):
                 return -1 * trade_duration / max_trade_duration
 
             # close long
             if action == Actions.Long_exit.value and self._position == Positions.Long:
                 if pnl > self.profit_aim * self.rr:
                     factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
                 return float(rew * factor)
 
             return 0.
 
 
-def make_env(MyRLEnv: BaseEnvironment, env_id: str, rank: int,
+def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
              seed: int, train_df: DataFrame, price: DataFrame,
              reward_params: Dict[str, int], window_size: int, monitor: bool = False,
              config: Dict[str, Any] = {}) -> Callable:
diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearner.py b/freqtrade/freqai/prediction_models/ReinforcementLearner.py
index 2e5c9f97b..00afd61d4 100644
--- a/freqtrade/freqai/prediction_models/ReinforcementLearner.py
+++ b/freqtrade/freqai/prediction_models/ReinforcementLearner.py
@@ -19,7 +19,15 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
     """
 
     def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
-
+        """
+        User customizable fit method
+        :params:
+        data_dictionary: dict = common data dictionary containing all train/test
+            features/labels/weights.
+        dk: FreqaiDatakitchen = data kitchen for current pair.
+        :returns:
+        model: Any = trained model to be used for inference in dry/live/backtesting
+        """
         train_df = data_dictionary["train_features"]
         total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)
 
@@ -59,7 +67,15 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
         """
 
         def calculate_reward(self, action):
-
+            """
+            An example reward function. This is the one function that users will likely
+            wish to inject their own creativity into.
+            :params:
+            action: int = The action made by the agent for the current candle.
+            :returns:
+            float = the reward to give to the agent for current step (used for optimization
+                of weights in NN)
+            """
             # first, penalize if the action is not valid
             if not self._is_valid(action):
                 return -2
diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py b/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py
index c14511921..5b2ea2ef5 100644
--- a/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py
+++ b/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py
@@ -6,7 +6,7 @@ from typing import Any, Dict  # , Tuple
 import torch as th
 from stable_baselines3.common.callbacks import EvalCallback
 from stable_baselines3.common.vec_env import SubprocVecEnv
-
+from pandas import DataFrame
 from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
 from freqtrade.freqai.RL.BaseReinforcementLearningModel import (BaseReinforcementLearningModel,
                                                                 make_env)
@@ -55,11 +55,18 @@ class ReinforcementLearner_multiproc(BaseReinforcementLearningModel):
 
         return model
 
-    def set_train_and_eval_environments(self, data_dictionary, prices_train, prices_test, dk):
+    def set_train_and_eval_environments(self, data_dictionary: Dict[str, Any],
+                                        prices_train: DataFrame, prices_test: DataFrame,
+                                        dk: FreqaiDataKitchen):
         """
-        If user has particular environment configuration needs, they can do that by
-        overriding this function. In the present case, the user wants to setup training
-        environments for multiple workers.
+        User can override this if they are using a custom MyRLEnv
+        :params:
+        data_dictionary: dict = common data dictionary containing train and test
+            features/labels/weights.
+        prices_train/test: DataFrame = dataframe comprised of the prices to be used in
+            the environment during training
+            or testing
+        dk: FreqaiDataKitchen = the datakitchen for the current pair
         """
         train_df = data_dictionary["train_features"]
         test_df = data_dictionary["test_features"]
@@ -79,4 +86,4 @@ class ReinforcementLearner_multiproc(BaseReinforcementLearningModel):
                                              in range(num_cpu)])
         self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
                                           render=False, eval_freq=len(train_df),
-                                          best_model_save_path=dk.data_path)
+                                          best_model_save_path=str(dk.data_path))
diff --git a/tests/freqai/test_freqai_interface.py b/tests/freqai/test_freqai_interface.py
index f0af90f18..1bc30a670 100644
--- a/tests/freqai/test_freqai_interface.py
+++ b/tests/freqai/test_freqai_interface.py
@@ -244,7 +244,7 @@ def test_start_backtesting(mocker, freqai_conf, model, num_files, strat):
     model_folders = [x for x in freqai.dd.full_path.iterdir() if x.is_dir()]
 
     assert len(model_folders) == num_files
-    
+
     Trade.use_db = True
     shutil.rmtree(Path(freqai.dk.full_path))
 
@@ -297,7 +297,7 @@ def test_start_backtesting_from_existing_folder(mocker, freqai_conf, caplog):
 
     assert len(model_folders) == 6
 
-    # without deleting the exiting folder structure, re-run
+    # without deleting the existing folder structure, re-run
 
     freqai_conf.update({"timerange": "20180120-20180130"})
     strategy = get_patched_freqai_strategy(mocker, freqai_conf)
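The docstrings added throughout this patch present calculate_reward() as the one place where users are expected to inject their own creativity. Purely as an illustration (it is not part of the patch), a custom environment following the same pattern as the MyRLEnv examples above could override it as sketched here; the class name MyCoolRLEnv is hypothetical, and only helpers that already appear in this patch (_is_valid(), get_unrealized_profit(), rl_config, Actions, Positions) are used:

from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
from freqtrade.freqai.RL.BaseEnvironment import Positions


class MyCoolRLEnv(Base5ActionRLEnv):
    """
    Hypothetical custom environment: identical to the MyRLEnv examples in this
    patch except for the reward shaping in calculate_reward().
    """

    def calculate_reward(self, action: int) -> float:
        # penalize invalid moves, as the examples in this patch do
        if not self._is_valid(action):
            return -2.

        pnl = self.get_unrealized_profit()
        win_factor = self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)

        # pay out the (scaled) unrealized profit when the agent closes a position
        if action == Actions.Long_exit.value and self._position == Positions.Long:
            return float(pnl * win_factor)
        if action == Actions.Short_exit.value and self._position == Positions.Short:
            return float(pnl * win_factor)

        # small penalty for sitting flat and doing nothing
        if action == Actions.Neutral.value and self._position == Positions.Neutral:
            return -1.

        return 0.

How such an environment is wired into a model (for example via the MyRLEnv hook shown in ReinforcementLearner.py) depends on the freqai version in use, so treat this strictly as a sketch of the reward interface.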