diff --git a/freqtrade/freqai/RL/Base5ActionRLEnv.py b/freqtrade/freqai/RL/Base5ActionRLEnv.py
index b6ebcf703..0d101ee9c 100644
--- a/freqtrade/freqai/RL/Base5ActionRLEnv.py
+++ b/freqtrade/freqai/RL/Base5ActionRLEnv.py
@@ -148,7 +148,6 @@ class Base5ActionRLEnv(BaseEnvironment):
         return self._current_tick - self._last_trade_tick

     def is_tradesignal(self, action: int):
-        # trade signal
         """
         Determine if the signal is a trade signal
         e.g.: agent wants a Actions.Long_exit while it is in a Positions.short
diff --git a/freqtrade/freqai/RL/BaseEnvironment.py b/freqtrade/freqai/RL/BaseEnvironment.py
index 3b56fc2c4..bb8cd992c 100644
--- a/freqtrade/freqai/RL/BaseEnvironment.py
+++ b/freqtrade/freqai/RL/BaseEnvironment.py
@@ -10,6 +10,8 @@ from gym import spaces
 from gym.utils import seeding
 from pandas import DataFrame

+from freqtrade.data.dataprovider import DataProvider
+

 logger = logging.getLogger(__name__)

@@ -32,8 +34,21 @@ class BaseEnvironment(gym.Env):

     def __init__(self, df: DataFrame = DataFrame(), prices: DataFrame = DataFrame(),
                  reward_kwargs: dict = {}, window_size=10, starting_point=True,
-                 id: str = 'baseenv-1', seed: int = 1, config: dict = {}):
-
+                 id: str = 'baseenv-1', seed: int = 1, config: dict = {},
+                 dp: Optional[DataProvider] = None):
+        """
+        Initializes the training/eval environment.
+        :param df: dataframe of features
+        :param prices: dataframe of prices to be used in the training environment
+        :param window_size: size of window (temporal) to pass to the agent
+        :param reward_kwargs: extra config settings assigned by user in `rl_config`
+        :param starting_point: start at edge of window or not
+        :param id: string id of the environment (used in backend for multiprocessed env)
+        :param seed: Sets the seed of the environment higher in the gym.Env object
+        :param config: Typical user configuration file
+        :param dp: dataprovider from freqtrade
+        """
+        self.config = config
         self.rl_config = config['freqai']['rl_config']
         self.add_state_info = self.rl_config.get('add_state_info', False)
         self.id = id
@@ -41,12 +56,23 @@ class BaseEnvironment(gym.Env):
         self.reset_env(df, prices, window_size, reward_kwargs, starting_point)
         self.max_drawdown = 1 - self.rl_config.get('max_training_drawdown_pct', 0.8)
         self.compound_trades = config['stake_amount'] == 'unlimited'
+        if self.config.get('fee', None) is not None:
+            self.fee = self.config['fee']
+        elif dp is not None:
+            self.fee = dp._exchange.get_fee(symbol=dp.current_whitelist()[0])  # type: ignore
+        else:
+            self.fee = 0.0015

     def reset_env(self, df: DataFrame, prices: DataFrame, window_size: int,
                   reward_kwargs: dict, starting_point=True):
         """
         Resets the environment when the agent fails (in our case, if the drawdown
         exceeds the user set max_training_drawdown_pct)
+        :param df: dataframe of features
+        :param prices: dataframe of prices to be used in the training environment
+        :param window_size: size of window (temporal) to pass to the agent
+        :param reward_kwargs: extra config settings assigned by user in `rl_config`
+        :param starting_point: start at edge of window or not
         """
         self.df = df
         self.signal_features = self.df
@@ -56,8 +82,6 @@ class BaseEnvironment(gym.Env):
         self.rr = reward_kwargs["rr"]
         self.profit_aim = reward_kwargs["profit_aim"]

-        self.fee = 0.0015
-
         # # spaces
         if self.add_state_info:
             self.total_features = self.signal_features.shape[1] + 3
@@ -233,7 +257,7 @@ class BaseEnvironment(gym.Env):
     def _update_total_profit(self):
         pnl = self.get_unrealized_profit()
         if self.compound_trades:
-            # assumes unitestake and compounding
+            # assumes unit stake and compounding
             self._total_profit = self._total_profit * (1 + pnl)
         else:
             # assumes unit stake and no compounding
diff --git a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
index 85756ad8f..a8c79ce6e 100644
--- a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
+++ b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
@@ -74,10 +74,10 @@ class BaseReinforcementLearningModel(IFreqaiModel):
             self.ft_params.update({'use_SVM_to_remove_outliers': False})
             logger.warning('User tried to use SVM with RL. Deactivating SVM.')
         if self.ft_params.get('use_DBSCAN_to_remove_outliers', False):
-            self.ft_params.update({'use_SVM_to_remove_outliers': False})
+            self.ft_params.update({'use_DBSCAN_to_remove_outliers': False})
             logger.warning('User tried to use DBSCAN with RL. Deactivating DBSCAN.')
         if self.freqai_info['data_split_parameters'].get('shuffle', False):
-            self.freqai_info['data_split_parameters'].update('shuffle', False)
+            self.freqai_info['data_split_parameters'].update({'shuffle': False})
             logger.warning('User tried to shuffle training data. Setting shuffle to False')

     def train(
@@ -141,11 +141,18 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         train_df = data_dictionary["train_features"]
         test_df = data_dictionary["test_features"]

-        self.train_env = self.MyRLEnv(df=train_df, prices=prices_train, window_size=self.CONV_WIDTH,
-                                      reward_kwargs=self.reward_params, config=self.config)
-        self.eval_env = Monitor(self.MyRLEnv(df=test_df, prices=prices_test,
-                                             window_size=self.CONV_WIDTH,
-                                             reward_kwargs=self.reward_params, config=self.config))
+        self.train_env = self.MyRLEnv(df=train_df,
+                                      prices=prices_train,
+                                      window_size=self.CONV_WIDTH,
+                                      reward_kwargs=self.reward_params,
+                                      config=self.config,
+                                      dp=self.data_provider)
+        self.eval_env = Monitor(self.MyRLEnv(df=test_df,
+                                             prices=prices_test,
+                                             window_size=self.CONV_WIDTH,
+                                             reward_kwargs=self.reward_params,
+                                             config=self.config,
+                                             dp=self.data_provider))
         self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
                                           render=False, eval_freq=len(train_df),
                                           best_model_save_path=str(dk.data_path))
@@ -179,12 +186,13 @@ class BaseReinforcementLearningModel(IFreqaiModel):
             if trade.pair == pair:
                 if self.data_provider._exchange is None:  # type: ignore
                     logger.error('No exchange available.')
+                    return 0, 0, 0
                 else:
                     current_rate = self.data_provider._exchange.get_rate(  # type: ignore
                         pair, refresh=False, side="exit", is_short=trade.is_short)

                 now = datetime.now(timezone.utc).timestamp()
-                trade_duration = int((now - trade.open_date_utc) / self.base_tf_seconds)
+                trade_duration = int((now - trade.open_date_utc.timestamp()) / self.base_tf_seconds)
                 current_profit = trade.calc_profit_ratio(current_rate)

         return market_side, current_profit, int(trade_duration)
@@ -230,7 +238,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):

         def _predict(window):
             observations = dataframe.iloc[window.index]
-            if self.live:  # self.guard_state_info_if_backtest():
+            if self.live and self.rl_config.get('add_state_info', False):
                 market_side, current_profit, trade_duration = self.get_state_info(dk.pair)
                 observations['current_profit_pct'] = current_profit
                 observations['position'] = market_side
@@ -242,17 +250,6 @@ class BaseReinforcementLearningModel(IFreqaiModel):

         return output

-    # def guard_state_info_if_backtest(self):
-    #     """
-    #     Ensure that backtesting mode doesnt try to use state information.
-    #     """
-    #     if self.rl_config('add_state_info', False) and not self.live:
-    #         logger.warning('Backtesting with state info is currently unavailable '
-    #                        'turning it off.')
-    #         self.rl_config['add_state_info'] = False
-
-    #     return not self.rl_config['add_state_info']
-
     def build_ohlc_price_dataframes(self, data_dictionary: dict,
                                     pair: str, dk: FreqaiDataKitchen) -> Tuple[DataFrame,
                                                                                DataFrame]:
diff --git a/tests/freqai/test_freqai_interface.py b/tests/freqai/test_freqai_interface.py
index 6ed9dac3d..08f33add9 100644
--- a/tests/freqai/test_freqai_interface.py
+++ b/tests/freqai/test_freqai_interface.py
@@ -13,7 +13,7 @@ from freqtrade.freqai.utils import download_all_data_for_training, get_required_
 from freqtrade.optimize.backtesting import Backtesting
 from freqtrade.persistence import Trade
 from freqtrade.plugins.pairlistmanager import PairListManager
-from tests.conftest import get_patched_exchange, log_has_re
+from tests.conftest import create_mock_trades, get_patched_exchange, log_has_re
 from tests.freqai.conftest import get_patched_freqai_strategy, make_rl_config


@@ -32,7 +32,7 @@ def is_mac() -> bool:
     ('XGBoostRegressor', False, True),
     ('XGBoostRFRegressor', False, False),
     ('CatboostRegressor', False, False),
-    ('ReinforcementLearner', False, False),
+    ('ReinforcementLearner', False, True),
     ('ReinforcementLearner_multiproc', False, False),
     ('ReinforcementLearner_test_4ac', False, False)
     ])
@@ -40,7 +40,7 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca,
     if is_arm() and model == 'CatboostRegressor':
         pytest.skip("CatBoost is not supported on ARM")

-    if is_mac():
+    if is_mac() and 'Reinforcement' in model:
         pytest.skip("Reinforcement learning module not available on intel based Mac OS")

     model_save_ext = 'joblib'
@@ -53,6 +53,9 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca,
     if 'ReinforcementLearner' in model:
         model_save_ext = 'zip'
         freqai_conf = make_rl_config(freqai_conf)
+        # test the RL guardrails
+        freqai_conf['freqai']['feature_parameters'].update({"use_SVM_to_remove_outliers": True})
+        freqai_conf['freqai']['data_split_parameters'].update({'shuffle': True})

     if 'test_4ac' in model:
         freqai_conf["freqaimodel_path"] = str(Path(__file__).parents[1] / "freqai" / "test_models")
@@ -497,3 +500,43 @@ def test_download_all_data_for_training(mocker, freqai_conf, caplog, tmpdir):
         "Downloading",
         caplog,
     )
+
+
+@pytest.mark.usefixtures("init_persistence")
+@pytest.mark.parametrize('dp_exists', [(False), (True)])
+def test_get_state_info(mocker, freqai_conf, dp_exists, caplog, tickers):
+
+    if is_mac():
+        pytest.skip("Reinforcement learning module not available on intel based Mac OS")
+
+    freqai_conf.update({"freqaimodel": "ReinforcementLearner"})
+    freqai_conf.update({"timerange": "20180110-20180130"})
+    freqai_conf.update({"strategy": "freqai_rl_test_strat"})
+    freqai_conf = make_rl_config(freqai_conf)
+    freqai_conf['entry_pricing']['price_side'] = 'same'
+    freqai_conf['exit_pricing']['price_side'] = 'same'
+
+    strategy = get_patched_freqai_strategy(mocker, freqai_conf)
+    exchange = get_patched_exchange(mocker, freqai_conf)
+    ticker_mock = MagicMock(return_value=tickers()['ETH/BTC'])
+    mocker.patch("freqtrade.exchange.Exchange.fetch_ticker", ticker_mock)
+    strategy.dp = DataProvider(freqai_conf, exchange)
+
+    if not dp_exists:
+        strategy.dp._exchange = None
+
+    strategy.freqai_info = freqai_conf.get("freqai", {})
+    freqai = strategy.freqai
+    freqai.data_provider = strategy.dp
+    freqai.live = True
+
+    Trade.use_db = True
+    create_mock_trades(MagicMock(return_value=0.0025), False, True)
+    freqai.get_state_info("ADA/BTC")
+    freqai.get_state_info("ETH/BTC")
+
+    if not dp_exists:
+        assert log_has_re(
+            "No exchange available",
+            caplog,
+        )
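
Note on the fee handling added to BaseEnvironment.__init__: the fee is resolved in three steps. An explicit `fee` entry in the user config wins; otherwise the fee is queried from the exchange through the dataprovider for the first whitelisted pair; and the hard-coded 0.0015 remains the fallback when neither is available. Below is a minimal standalone sketch of that resolution order; the `resolve_fee` helper and the `SimpleNamespace` stand-ins are illustrative only and not part of freqtrade.

```python
from types import SimpleNamespace
from typing import Optional


def resolve_fee(config: dict, dp: Optional[SimpleNamespace]) -> float:
    """Mirror of the fee fallback order used in BaseEnvironment.__init__."""
    if config.get('fee', None) is not None:
        return config['fee']          # explicit user override from the config
    if dp is not None:
        # ask the exchange (via the dataprovider) for the first whitelisted pair's fee
        return dp._exchange.get_fee(symbol=dp.current_whitelist()[0])
    return 0.0015                     # conservative default when nothing else is known


# Illustrative stand-ins only; the real environment receives a freqtrade DataProvider.
fake_dp = SimpleNamespace(
    _exchange=SimpleNamespace(get_fee=lambda symbol: 0.001),
    current_whitelist=lambda: ['BTC/USDT'],
)
assert resolve_fee({'fee': 0.002}, fake_dp) == 0.002   # config fee wins
assert resolve_fee({}, fake_dp) == 0.001               # falls back to the exchange fee
assert resolve_fee({}, None) == 0.0015                 # last-resort default
```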