add test coverage, fix bug in base environment. Ensure proper fee is used.

robcaulk
2022-11-13 15:31:37 +01:00
parent 81f800a79b
commit af9e400562
4 changed files with 92 additions and 29 deletions


@@ -148,7 +148,6 @@ class Base5ActionRLEnv(BaseEnvironment):
         return self._current_tick - self._last_trade_tick
 
     def is_tradesignal(self, action: int):
-        # trade signal
         """
         Determine if the signal is a trade signal
         e.g.: agent wants an Actions.Long_exit while it is in a Positions.Short
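Note: the docstring added above describes validating an action against the agent's current position. A rough, hypothetical sketch of that idea (simplified from what Base5ActionRLEnv actually does; the enum values and branch structure here are illustrative assumptions, not the upstream body):

```python
from enum import Enum

class Actions(Enum):
    Neutral = 0
    Long_enter = 1
    Long_exit = 2
    Short_enter = 3
    Short_exit = 4

class Positions(Enum):
    Short = 0
    Long = 1
    Neutral = 0.5

def is_tradesignal(action: int, position: Positions) -> bool:
    # Hypothetical simplification: an action is a trade signal only when it
    # is legal for the current position, e.g. Long_exit while short is not.
    if action == Actions.Long_exit.value:
        return position == Positions.Long
    if action == Actions.Short_exit.value:
        return position == Positions.Short
    if action in (Actions.Long_enter.value, Actions.Short_enter.value):
        return position == Positions.Neutral
    return False  # Actions.Neutral never opens or closes a trade
```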


@@ -10,6 +10,8 @@ from gym import spaces
 from gym.utils import seeding
 from pandas import DataFrame
 
+from freqtrade.data.dataprovider import DataProvider
+
 
 logger = logging.getLogger(__name__)
 
@@ -32,8 +34,21 @@ class BaseEnvironment(gym.Env):
 
     def __init__(self, df: DataFrame = DataFrame(), prices: DataFrame = DataFrame(),
                  reward_kwargs: dict = {}, window_size=10, starting_point=True,
-                 id: str = 'baseenv-1', seed: int = 1, config: dict = {}):
+                 id: str = 'baseenv-1', seed: int = 1, config: dict = {},
+                 dp: Optional[DataProvider] = None):
+        """
+        Initializes the training/eval environment.
+        :param df: dataframe of features
+        :param prices: dataframe of prices to be used in the training environment
+        :param window_size: size of window (temporal) to pass to the agent
+        :param reward_kwargs: extra config settings assigned by user in `rl_config`
+        :param starting_point: start at edge of window or not
+        :param id: string id of the environment (used in backend for multiprocessed env)
+        :param seed: Sets the seed of the environment higher in the gym.Env object
+        :param config: Typical user configuration file
+        :param dp: dataprovider from freqtrade
+        """
         self.config = config
         self.rl_config = config['freqai']['rl_config']
         self.add_state_info = self.rl_config.get('add_state_info', False)
         self.id = id
@@ -41,12 +56,23 @@ class BaseEnvironment(gym.Env):
         self.reset_env(df, prices, window_size, reward_kwargs, starting_point)
         self.max_drawdown = 1 - self.rl_config.get('max_training_drawdown_pct', 0.8)
         self.compound_trades = config['stake_amount'] == 'unlimited'
+        if self.config.get('fee', None) is not None:
+            self.fee = self.config['fee']
+        elif dp is not None:
+            self.fee = dp._exchange.get_fee(symbol=dp.current_whitelist()[0])  # type: ignore
+        else:
+            self.fee = 0.0015
 
     def reset_env(self, df: DataFrame, prices: DataFrame, window_size: int,
                   reward_kwargs: dict, starting_point=True):
         """
         Resets the environment when the agent fails (in our case, if the drawdown
         exceeds the user set max_training_drawdown_pct)
+        :param df: dataframe of features
+        :param prices: dataframe of prices to be used in the training environment
+        :param window_size: size of window (temporal) to pass to the agent
+        :param reward_kwargs: extra config settings assigned by user in `rl_config`
+        :param starting_point: start at edge of window or not
+        """
         self.df = df
         self.signal_features = self.df
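The fee-resolution order above is the substantive fix in this commit: an explicit `fee` in the user config wins, otherwise the fee is queried from the exchange for the first whitelisted pair, and only then does the hard-coded 0.0015 fallback apply (previously it applied unconditionally, see the next hunk). A minimal standalone sketch of that precedence, with a hypothetical `get_exchange_fee` callable standing in for the DataProvider/exchange lookup:

```python
from typing import Callable, Optional

def resolve_fee(config: dict,
                get_exchange_fee: Optional[Callable[[], float]] = None) -> float:
    if config.get('fee', None) is not None:
        return config['fee']          # 1) user-pinned fee from the config
    if get_exchange_fee is not None:
        return get_exchange_fee()     # 2) live fee queried from the exchange
    return 0.0015                     # 3) hard-coded fallback

assert resolve_fee({'fee': 0.001}) == 0.001
assert resolve_fee({}, lambda: 0.002) == 0.002
assert resolve_fee({}) == 0.0015
```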
@@ -56,8 +82,6 @@ class BaseEnvironment(gym.Env):
         self.rr = reward_kwargs["rr"]
         self.profit_aim = reward_kwargs["profit_aim"]
-        self.fee = 0.0015
-
 
         # # spaces
         if self.add_state_info:
             self.total_features = self.signal_features.shape[1] + 3
@@ -233,7 +257,7 @@ class BaseEnvironment(gym.Env):
     def _update_total_profit(self):
         pnl = self.get_unrealized_profit()
         if self.compound_trades:
-            # assumes unitestake and compounding
+            # assumes unit stake and compounding
             self._total_profit = self._total_profit * (1 + pnl)
         else:
             # assumes unit stake and no compounding
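The two branches implement different accounting: with `stake_amount: unlimited`, profit compounds multiplicatively, while a fixed stake accumulates additively. A small worked example (the additive body is not visible in this hunk, so `+ pnl` is an assumption consistent with the comments):

```python
total_profit = 1.10   # running profit factor before the trade closes
pnl = 0.05            # unrealized profit ratio of the closing trade

compounded = total_profit * (1 + pnl)   # stake grows with the account
unit_stake = total_profit + pnl         # assumed additive branch: fixed stake

print(compounded)  # 1.155
print(unit_stake)  # 1.15
```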


@@ -74,10 +74,10 @@ class BaseReinforcementLearningModel(IFreqaiModel):
             self.ft_params.update({'use_SVM_to_remove_outliers': False})
             logger.warning('User tried to use SVM with RL. Deactivating SVM.')
         if self.ft_params.get('use_DBSCAN_to_remove_outliers', False):
-            self.ft_params.update({'use_SVM_to_remove_outliers': False})
+            self.ft_params.update({'use_DBSCAN_to_remove_outliers': False})
             logger.warning('User tried to use DBSCAN with RL. Deactivating DBSCAN.')
         if self.freqai_info['data_split_parameters'].get('shuffle', False):
-            self.freqai_info['data_split_parameters'].update('shuffle', False)
+            self.freqai_info['data_split_parameters'].update({'shuffle': False})
             logger.warning('User tried to shuffle training data. Setting shuffle to False')
 
     def train(
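Both fixes in the hunk above are one-token bugs: the DBSCAN branch was disabling the SVM key instead of its own, and `dict.update` was being called with two positional arguments, which raises at runtime because it accepts a mapping (or an iterable of key/value pairs), not a separate key and value:

```python
params = {'shuffle': True}

try:
    params.update('shuffle', False)        # the old, broken call
except TypeError as exc:
    print(exc)  # update expected at most 1 argument, got 2

params.update({'shuffle': False})          # corrected form from the diff
assert params == {'shuffle': False}
```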
@@ -141,11 +141,18 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         train_df = data_dictionary["train_features"]
         test_df = data_dictionary["test_features"]
 
-        self.train_env = self.MyRLEnv(df=train_df, prices=prices_train, window_size=self.CONV_WIDTH,
-                                      reward_kwargs=self.reward_params, config=self.config)
-        self.eval_env = Monitor(self.MyRLEnv(df=test_df, prices=prices_test,
-                                             window_size=self.CONV_WIDTH,
-                                             reward_kwargs=self.reward_params, config=self.config))
+        self.train_env = self.MyRLEnv(df=train_df,
+                                      prices=prices_train,
+                                      window_size=self.CONV_WIDTH,
+                                      reward_kwargs=self.reward_params,
+                                      config=self.config,
+                                      dp=self.data_provider)
+        self.eval_env = Monitor(self.MyRLEnv(df=test_df,
+                                             prices=prices_test,
+                                             window_size=self.CONV_WIDTH,
+                                             reward_kwargs=self.reward_params,
+                                             config=self.config,
+                                             dp=self.data_provider))
         self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
                                           render=False, eval_freq=len(train_df),
                                           best_model_save_path=str(dk.data_path))
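Passing `dp=self.data_provider` is what lets the environments pick up the real exchange fee at construction time. For context, the `Monitor`/`EvalCallback` pattern above is standard stable-baselines3 usage; a minimal, self-contained illustration, where CartPole and the step counts stand in for `MyRLEnv` and `len(train_df)`:

```python
import gym
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor

train_env = gym.make('CartPole-v1')
eval_env = Monitor(gym.make('CartPole-v1'))       # Monitor-wrapped, as above

eval_callback = EvalCallback(eval_env, deterministic=True, render=False,
                             eval_freq=1_000,     # steps between evaluations
                             best_model_save_path='./best_model')

model = PPO('MlpPolicy', train_env)
model.learn(total_timesteps=5_000, callback=eval_callback)
```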
@@ -179,12 +186,13 @@ class BaseReinforcementLearningModel(IFreqaiModel):
             if trade.pair == pair:
                 if self.data_provider._exchange is None:  # type: ignore
                     logger.error('No exchange available.')
+                    return 0, 0, 0
                 else:
                     current_rate = self.data_provider._exchange.get_rate(  # type: ignore
                         pair, refresh=False, side="exit", is_short=trade.is_short)
 
                 now = datetime.now(timezone.utc).timestamp()
-                trade_duration = int((now - trade.open_date_utc) / self.base_tf_seconds)
+                trade_duration = int((now - trade.open_date_utc.timestamp()) / self.base_tf_seconds)
                 current_profit = trade.calc_profit_ratio(current_rate)
 
         return market_side, current_profit, int(trade_duration)
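The `trade_duration` change fixes a type error rather than a rounding issue: `now` is a float of epoch seconds while `trade.open_date_utc` is a timezone-aware datetime, so the old subtraction could never succeed. A small reproduction (`base_tf_seconds = 300` assumes a 5m base timeframe):

```python
from datetime import datetime, timedelta, timezone

open_date_utc = datetime.now(timezone.utc) - timedelta(minutes=15)
now = datetime.now(timezone.utc).timestamp()   # float epoch seconds
base_tf_seconds = 300                          # assumed 5m base timeframe

try:
    int((now - open_date_utc) / base_tf_seconds)            # old form
except TypeError as exc:
    print(exc)  # unsupported operand type(s) for -: 'float' and 'datetime.datetime'

duration = int((now - open_date_utc.timestamp()) / base_tf_seconds)
print(duration)  # 3 -> the trade has been open for three candles
```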
@@ -230,7 +238,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
 
         def _predict(window):
             observations = dataframe.iloc[window.index]
-            if self.live:  # self.guard_state_info_if_backtest():
+            if self.live and self.rl_config.get('add_state_info', False):
                 market_side, current_profit, trade_duration = self.get_state_info(dk.pair)
                 observations['current_profit_pct'] = current_profit
                 observations['position'] = market_side
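Two things happen in this condition: `rl_config` is a plain dict, so it must be read with `.get()` rather than called, and state columns are only attached when running live, since backtesting has no open trade to query. A simplified stand-in for the gate (the values written here are placeholders):

```python
rl_config = {'add_state_info': True}   # subset of config['freqai']['rl_config']

def maybe_add_state_info(observations: dict, live: bool) -> dict:
    # rl_config('add_state_info', False) would raise TypeError here:
    # a dict is not callable, so .get() is the correct lookup.
    if live and rl_config.get('add_state_info', False):
        observations['current_profit_pct'] = 0.02   # placeholder state values
        observations['position'] = 1.0
        observations['trade_duration'] = 5
    return observations

print(maybe_add_state_info({}, live=True))   # state columns attached
print(maybe_add_state_info({}, live=False))  # {} -> untouched in backtests
```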
@@ -242,17 +250,6 @@ class BaseReinforcementLearningModel(IFreqaiModel):
 
             return output
 
-    # def guard_state_info_if_backtest(self):
-    #     """
-    #     Ensure that backtesting mode doesnt try to use state information.
-    #     """
-    #     if self.rl_config('add_state_info', False) and not self.live:
-    #         logger.warning('Backtesting with state info is currently unavailable '
-    #                        'turning it off.')
-    #         self.rl_config['add_state_info'] = False
-    #     return not self.rl_config['add_state_info']
-
-
     def build_ohlc_price_dataframes(self, data_dictionary: dict,
                                     pair: str, dk: FreqaiDataKitchen) -> Tuple[DataFrame,
                                                                                DataFrame]: