diff --git a/docs/freqai-parameter-table.md b/docs/freqai-parameter-table.md index 275062a33..f67ea8541 100644 --- a/docs/freqai-parameter-table.md +++ b/docs/freqai-parameter-table.md @@ -84,6 +84,7 @@ Mandatory parameters are marked as **Required** and have to be set in one of the | `add_state_info` | Tell FreqAI to include state information in the feature set for training and inferencing. The current state variables include trade duration, current profit, trade position. This is only available in dry/live runs, and is automatically switched to false for backtesting.
**Datatype:** bool.
Default: `False`. | `net_arch` | Network architecture which is well described in [`stable_baselines3` doc](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html#examples). In summary: `[, dict(vf=[], pi=[])]`. By default this is set to `[128, 128]`, which defines 2 shared hidden layers with 128 units each. | `randomize_starting_position` | Randomize the starting point of each episode to avoid overfitting.
**Datatype:** bool.
Default: `False`. +| `drop_ohlc_from_features` | Do not include the normalized ohlc data in the feature set passed to the agent during training (ohlc will still be used for driving the environment in all cases)
**Datatype:** Boolean.
**Default:** `False` ### Additional parameters diff --git a/docs/freqai-reinforcement-learning.md b/docs/freqai-reinforcement-learning.md index 3810aec4e..04ca42a5d 100644 --- a/docs/freqai-reinforcement-learning.md +++ b/docs/freqai-reinforcement-learning.md @@ -176,9 +176,11 @@ As you begin to modify the strategy and the prediction model, you will quickly r factor = 100 + pair = self.pair.replace(':', '') + # you can use feature values from dataframe # Assumes the shifted RSI indicator has been generated in the strategy. - rsi_now = self.raw_features[f"%-rsi-period-10_shift-1_{self.pair}_" + rsi_now = self.raw_features[f"%-rsi-period-10_shift-1_{pair}_" f"{self.config['timeframe']}"].iloc[self._current_tick] # reward agent for entering trades diff --git a/freqtrade/constants.py b/freqtrade/constants.py index 1727da92e..46e9b5cd4 100644 --- a/freqtrade/constants.py +++ b/freqtrade/constants.py @@ -588,6 +588,7 @@ CONF_SCHEMA = { "rl_config": { "type": "object", "properties": { + "drop_ohlc_from_features": {"type": "boolean", "default": False}, "train_cycles": {"type": "integer"}, "max_trade_duration_candles": {"type": "integer"}, "add_state_info": {"type": "boolean", "default": False}, diff --git a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py index a8ef69394..bac717f9f 100644 --- a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py +++ b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py @@ -114,6 +114,7 @@ class BaseReinforcementLearningModel(IFreqaiModel): # normalize all data based on train_dataset only prices_train, prices_test = self.build_ohlc_price_dataframes(dk.data_dictionary, pair, dk) + data_dictionary = dk.normalize_data(data_dictionary) # data cleaning/analysis @@ -148,12 +149,8 @@ class BaseReinforcementLearningModel(IFreqaiModel): env_info = self.pack_env_dict(dk.pair) - self.train_env = self.MyRLEnv(df=train_df, - prices=prices_train, - **env_info) - self.eval_env = Monitor(self.MyRLEnv(df=test_df, - prices=prices_test, - **env_info)) + self.train_env = self.MyRLEnv(df=train_df, prices=prices_train, **env_info) + self.eval_env = Monitor(self.MyRLEnv(df=test_df, prices=prices_test, **env_info)) self.eval_callback = EvalCallback(self.eval_env, deterministic=True, render=False, eval_freq=len(train_df), best_model_save_path=str(dk.data_path)) @@ -285,7 +282,6 @@ class BaseReinforcementLearningModel(IFreqaiModel): train_df = data_dictionary["train_features"] test_df = data_dictionary["test_features"] - # %-raw_volume_gen_shift-2_ETH/USDT_1h # price data for model training and evaluation tf = self.config['timeframe'] rename_dict = {'%-raw_open': 'open', '%-raw_low': 'low', @@ -318,6 +314,12 @@ class BaseReinforcementLearningModel(IFreqaiModel): prices_test.rename(columns=rename_dict, inplace=True) prices_test.reset_index(drop=True) + if self.rl_config["drop_ohlc_from_features"]: + train_df.drop(rename_dict.keys(), axis=1, inplace=True) + test_df.drop(rename_dict.keys(), axis=1, inplace=True) + feature_list = dk.training_features_list + feature_list = [e for e in feature_list if e not in rename_dict.keys()] + return prices_train, prices_test def load_model_from_disk(self, dk: FreqaiDataKitchen) -> Any: diff --git a/tests/freqai/conftest.py b/tests/freqai/conftest.py index 68e7ea49a..e140ee80b 100644 --- a/tests/freqai/conftest.py +++ b/tests/freqai/conftest.py @@ -78,7 +78,9 @@ def make_rl_config(conf): "rr": 1, "profit_aim": 0.02, "win_reward_factor": 2 - }} + }, + "drop_ohlc_from_features": False + } return conf diff --git a/tests/freqai/test_freqai_interface.py b/tests/freqai/test_freqai_interface.py index f8bee3659..3b370aea4 100644 --- a/tests/freqai/test_freqai_interface.py +++ b/tests/freqai/test_freqai_interface.py @@ -68,13 +68,6 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca, freqai_conf['freqai']['feature_parameters'].update({"shuffle_after_split": shuffle}) freqai_conf['freqai']['feature_parameters'].update({"buffer_train_data_candles": buffer}) - if 'ReinforcementLearner' in model: - model_save_ext = 'zip' - freqai_conf = make_rl_config(freqai_conf) - # test the RL guardrails - freqai_conf['freqai']['feature_parameters'].update({"use_SVM_to_remove_outliers": True}) - freqai_conf['freqai']['data_split_parameters'].update({'shuffle': True}) - if 'ReinforcementLearner' in model: model_save_ext = 'zip' freqai_conf = make_rl_config(freqai_conf) @@ -84,6 +77,7 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca, if 'test_3ac' in model or 'test_4ac' in model: freqai_conf["freqaimodel_path"] = str(Path(__file__).parents[1] / "freqai" / "test_models") + freqai_conf["freqai"]["rl_config"]["drop_ohlc_from_features"] = True strategy = get_patched_freqai_strategy(mocker, freqai_conf) exchange = get_patched_exchange(mocker, freqai_conf)