add test coverage, fix bug in base environment. Ensure proper fee is used.

This commit is contained in:
robcaulk 2022-11-13 15:31:37 +01:00
parent 81f800a79b
commit af9e400562
4 changed files with 92 additions and 29 deletions

View File

@ -148,7 +148,6 @@ class Base5ActionRLEnv(BaseEnvironment):
return self._current_tick - self._last_trade_tick return self._current_tick - self._last_trade_tick
def is_tradesignal(self, action: int): def is_tradesignal(self, action: int):
# trade signal
""" """
Determine if the signal is a trade signal Determine if the signal is a trade signal
e.g.: agent wants a Actions.Long_exit while it is in a Positions.short e.g.: agent wants a Actions.Long_exit while it is in a Positions.short

View File

@ -10,6 +10,8 @@ from gym import spaces
from gym.utils import seeding from gym.utils import seeding
from pandas import DataFrame from pandas import DataFrame
from freqtrade.data.dataprovider import DataProvider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -32,8 +34,21 @@ class BaseEnvironment(gym.Env):
def __init__(self, df: DataFrame = DataFrame(), prices: DataFrame = DataFrame(), def __init__(self, df: DataFrame = DataFrame(), prices: DataFrame = DataFrame(),
reward_kwargs: dict = {}, window_size=10, starting_point=True, reward_kwargs: dict = {}, window_size=10, starting_point=True,
id: str = 'baseenv-1', seed: int = 1, config: dict = {}): id: str = 'baseenv-1', seed: int = 1, config: dict = {},
dp: Optional[DataProvider] = None):
"""
Initializes the training/eval environment.
:param df: dataframe of features
:param prices: dataframe of prices to be used in the training environment
:param window_size: size of window (temporal) to pass to the agent
:param reward_kwargs: extra config settings assigned by user in `rl_config`
:param starting_point: start at edge of window or not
:param id: string id of the environment (used in backend for multiprocessed env)
:param seed: Sets the seed of the environment higher in the gym.Env object
:param config: Typical user configuration file
:param dp: dataprovider from freqtrade
"""
self.config = config
self.rl_config = config['freqai']['rl_config'] self.rl_config = config['freqai']['rl_config']
self.add_state_info = self.rl_config.get('add_state_info', False) self.add_state_info = self.rl_config.get('add_state_info', False)
self.id = id self.id = id
@ -41,12 +56,23 @@ class BaseEnvironment(gym.Env):
self.reset_env(df, prices, window_size, reward_kwargs, starting_point) self.reset_env(df, prices, window_size, reward_kwargs, starting_point)
self.max_drawdown = 1 - self.rl_config.get('max_training_drawdown_pct', 0.8) self.max_drawdown = 1 - self.rl_config.get('max_training_drawdown_pct', 0.8)
self.compound_trades = config['stake_amount'] == 'unlimited' self.compound_trades = config['stake_amount'] == 'unlimited'
if self.config.get('fee', None) is not None:
self.fee = self.config['fee']
elif dp is not None:
self.fee = self.dp.exchange.get_fee(symbol=dp.current_whitelist()[0])
else:
self.fee = 0.0015
def reset_env(self, df: DataFrame, prices: DataFrame, window_size: int, def reset_env(self, df: DataFrame, prices: DataFrame, window_size: int,
reward_kwargs: dict, starting_point=True): reward_kwargs: dict, starting_point=True):
""" """
Resets the environment when the agent fails (in our case, if the drawdown Resets the environment when the agent fails (in our case, if the drawdown
exceeds the user set max_training_drawdown_pct) exceeds the user set max_training_drawdown_pct)
:param df: dataframe of features
:param prices: dataframe of prices to be used in the training environment
:param window_size: size of window (temporal) to pass to the agent
:param reward_kwargs: extra config settings assigned by user in `rl_config`
:param starting_point: start at edge of window or not
""" """
self.df = df self.df = df
self.signal_features = self.df self.signal_features = self.df
@ -56,8 +82,6 @@ class BaseEnvironment(gym.Env):
self.rr = reward_kwargs["rr"] self.rr = reward_kwargs["rr"]
self.profit_aim = reward_kwargs["profit_aim"] self.profit_aim = reward_kwargs["profit_aim"]
self.fee = 0.0015
# # spaces # # spaces
if self.add_state_info: if self.add_state_info:
self.total_features = self.signal_features.shape[1] + 3 self.total_features = self.signal_features.shape[1] + 3
@ -233,7 +257,7 @@ class BaseEnvironment(gym.Env):
def _update_total_profit(self): def _update_total_profit(self):
pnl = self.get_unrealized_profit() pnl = self.get_unrealized_profit()
if self.compound_trades: if self.compound_trades:
# assumes unitestake and compounding # assumes unit stake and compounding
self._total_profit = self._total_profit * (1 + pnl) self._total_profit = self._total_profit * (1 + pnl)
else: else:
# assumes unit stake and no compounding # assumes unit stake and no compounding

View File

@ -74,10 +74,10 @@ class BaseReinforcementLearningModel(IFreqaiModel):
self.ft_params.update({'use_SVM_to_remove_outliers': False}) self.ft_params.update({'use_SVM_to_remove_outliers': False})
logger.warning('User tried to use SVM with RL. Deactivating SVM.') logger.warning('User tried to use SVM with RL. Deactivating SVM.')
if self.ft_params.get('use_DBSCAN_to_remove_outliers', False): if self.ft_params.get('use_DBSCAN_to_remove_outliers', False):
self.ft_params.update({'use_SVM_to_remove_outliers': False}) self.ft_params.update({'use_DBSCAN_to_remove_outliers': False})
logger.warning('User tried to use DBSCAN with RL. Deactivating DBSCAN.') logger.warning('User tried to use DBSCAN with RL. Deactivating DBSCAN.')
if self.freqai_info['data_split_parameters'].get('shuffle', False): if self.freqai_info['data_split_parameters'].get('shuffle', False):
self.freqai_info['data_split_parameters'].update('shuffle', False) self.freqai_info['data_split_parameters'].update({'shuffle': False})
logger.warning('User tried to shuffle training data. Setting shuffle to False') logger.warning('User tried to shuffle training data. Setting shuffle to False')
def train( def train(
@ -141,11 +141,18 @@ class BaseReinforcementLearningModel(IFreqaiModel):
train_df = data_dictionary["train_features"] train_df = data_dictionary["train_features"]
test_df = data_dictionary["test_features"] test_df = data_dictionary["test_features"]
self.train_env = self.MyRLEnv(df=train_df, prices=prices_train, window_size=self.CONV_WIDTH, self.train_env = self.MyRLEnv(df=train_df,
reward_kwargs=self.reward_params, config=self.config) prices=prices_train,
self.eval_env = Monitor(self.MyRLEnv(df=test_df, prices=prices_test,
window_size=self.CONV_WIDTH, window_size=self.CONV_WIDTH,
reward_kwargs=self.reward_params, config=self.config)) reward_kwargs=self.reward_params,
config=self.config,
dp=self.data_provider)
self.eval_env = Monitor(self.MyRLEnv(df=test_df,
prices=prices_test,
window_size=self.CONV_WIDTH,
reward_kwargs=self.reward_params,
config=self.config,
dp=self.data_provider))
self.eval_callback = EvalCallback(self.eval_env, deterministic=True, self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
render=False, eval_freq=len(train_df), render=False, eval_freq=len(train_df),
best_model_save_path=str(dk.data_path)) best_model_save_path=str(dk.data_path))
@ -179,12 +186,13 @@ class BaseReinforcementLearningModel(IFreqaiModel):
if trade.pair == pair: if trade.pair == pair:
if self.data_provider._exchange is None: # type: ignore if self.data_provider._exchange is None: # type: ignore
logger.error('No exchange available.') logger.error('No exchange available.')
return 0, 0, 0
else: else:
current_rate = self.data_provider._exchange.get_rate( # type: ignore current_rate = self.data_provider._exchange.get_rate( # type: ignore
pair, refresh=False, side="exit", is_short=trade.is_short) pair, refresh=False, side="exit", is_short=trade.is_short)
now = datetime.now(timezone.utc).timestamp() now = datetime.now(timezone.utc).timestamp()
trade_duration = int((now - trade.open_date_utc) / self.base_tf_seconds) trade_duration = int((now - trade.open_date_utc.timestamp()) / self.base_tf_seconds)
current_profit = trade.calc_profit_ratio(current_rate) current_profit = trade.calc_profit_ratio(current_rate)
return market_side, current_profit, int(trade_duration) return market_side, current_profit, int(trade_duration)
@ -230,7 +238,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
def _predict(window): def _predict(window):
observations = dataframe.iloc[window.index] observations = dataframe.iloc[window.index]
if self.live: # self.guard_state_info_if_backtest(): if self.live and self.rl_config('add_state_info', False):
market_side, current_profit, trade_duration = self.get_state_info(dk.pair) market_side, current_profit, trade_duration = self.get_state_info(dk.pair)
observations['current_profit_pct'] = current_profit observations['current_profit_pct'] = current_profit
observations['position'] = market_side observations['position'] = market_side
@ -242,17 +250,6 @@ class BaseReinforcementLearningModel(IFreqaiModel):
return output return output
# def guard_state_info_if_backtest(self):
# """
# Ensure that backtesting mode doesnt try to use state information.
# """
# if self.rl_config('add_state_info', False) and not self.live:
# logger.warning('Backtesting with state info is currently unavailable '
# 'turning it off.')
# self.rl_config['add_state_info'] = False
# return not self.rl_config['add_state_info']
def build_ohlc_price_dataframes(self, data_dictionary: dict, def build_ohlc_price_dataframes(self, data_dictionary: dict,
pair: str, dk: FreqaiDataKitchen) -> Tuple[DataFrame, pair: str, dk: FreqaiDataKitchen) -> Tuple[DataFrame,
DataFrame]: DataFrame]:

View File

@ -13,7 +13,7 @@ from freqtrade.freqai.utils import download_all_data_for_training, get_required_
from freqtrade.optimize.backtesting import Backtesting from freqtrade.optimize.backtesting import Backtesting
from freqtrade.persistence import Trade from freqtrade.persistence import Trade
from freqtrade.plugins.pairlistmanager import PairListManager from freqtrade.plugins.pairlistmanager import PairListManager
from tests.conftest import get_patched_exchange, log_has_re from tests.conftest import create_mock_trades, get_patched_exchange, log_has_re
from tests.freqai.conftest import get_patched_freqai_strategy, make_rl_config from tests.freqai.conftest import get_patched_freqai_strategy, make_rl_config
@ -32,7 +32,7 @@ def is_mac() -> bool:
('XGBoostRegressor', False, True), ('XGBoostRegressor', False, True),
('XGBoostRFRegressor', False, False), ('XGBoostRFRegressor', False, False),
('CatboostRegressor', False, False), ('CatboostRegressor', False, False),
('ReinforcementLearner', False, False), ('ReinforcementLearner', False, True),
('ReinforcementLearner_multiproc', False, False), ('ReinforcementLearner_multiproc', False, False),
('ReinforcementLearner_test_4ac', False, False) ('ReinforcementLearner_test_4ac', False, False)
]) ])
@ -40,7 +40,7 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca,
if is_arm() and model == 'CatboostRegressor': if is_arm() and model == 'CatboostRegressor':
pytest.skip("CatBoost is not supported on ARM") pytest.skip("CatBoost is not supported on ARM")
if is_mac(): if is_mac() and 'Reinforcement' in model:
pytest.skip("Reinforcement learning module not available on intel based Mac OS") pytest.skip("Reinforcement learning module not available on intel based Mac OS")
model_save_ext = 'joblib' model_save_ext = 'joblib'
@ -53,6 +53,9 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca,
if 'ReinforcementLearner' in model: if 'ReinforcementLearner' in model:
model_save_ext = 'zip' model_save_ext = 'zip'
freqai_conf = make_rl_config(freqai_conf) freqai_conf = make_rl_config(freqai_conf)
# test the RL guardrails
freqai_conf['freqai']['feature_parameters'].update({"use_SVM_to_remove_outliers": True})
freqai_conf['freqai']['data_split_parameters'].update({'shuffle': True})
if 'test_4ac' in model: if 'test_4ac' in model:
freqai_conf["freqaimodel_path"] = str(Path(__file__).parents[1] / "freqai" / "test_models") freqai_conf["freqaimodel_path"] = str(Path(__file__).parents[1] / "freqai" / "test_models")
@ -497,3 +500,43 @@ def test_download_all_data_for_training(mocker, freqai_conf, caplog, tmpdir):
"Downloading", "Downloading",
caplog, caplog,
) )
@pytest.mark.usefixtures("init_persistence")
@pytest.mark.parametrize('dp_exists', [(False), (True)])
def test_get_state_info(mocker, freqai_conf, dp_exists, caplog, tickers):
if is_mac():
pytest.skip("Reinforcement learning module not available on intel based Mac OS")
freqai_conf.update({"freqaimodel": "ReinforcementLearner"})
freqai_conf.update({"timerange": "20180110-20180130"})
freqai_conf.update({"strategy": "freqai_rl_test_strat"})
freqai_conf = make_rl_config(freqai_conf)
freqai_conf['entry_pricing']['price_side'] = 'same'
freqai_conf['exit_pricing']['price_side'] = 'same'
strategy = get_patched_freqai_strategy(mocker, freqai_conf)
exchange = get_patched_exchange(mocker, freqai_conf)
ticker_mock = MagicMock(return_value=tickers()['ETH/BTC'])
mocker.patch("freqtrade.exchange.Exchange.fetch_ticker", ticker_mock)
strategy.dp = DataProvider(freqai_conf, exchange)
if not dp_exists:
strategy.dp._exchange = None
strategy.freqai_info = freqai_conf.get("freqai", {})
freqai = strategy.freqai
freqai.data_provider = strategy.dp
freqai.live = True
Trade.use_db = True
create_mock_trades(MagicMock(return_value=0.0025), False, True)
freqai.get_state_info("ADA/BTC")
freqai.get_state_info("ETH/BTC")
if not dp_exists:
assert log_has_re(
"No exchange available",
caplog,
)