ensure RL works with new naming scheme

This commit is contained in:
robcaulk 2022-12-28 14:52:33 +01:00
parent c2936d551b
commit 6f7eb71bbb
2 changed files with 13 additions and 9 deletions

View File

@ -280,22 +280,26 @@ class BaseReinforcementLearningModel(IFreqaiModel):
train_df = data_dictionary["train_features"]
test_df = data_dictionary["test_features"]
# %-raw_volume_gen_shift-2_ETH/USDT_1h
# price data for model training and evaluation
tf = self.config['timeframe']
ohlc_list = [f'%-{pair}raw_open_{tf}', f'%-{pair}raw_low_{tf}',
f'%-{pair}raw_high_{tf}', f'%-{pair}raw_close_{tf}']
rename_dict = {f'%-{pair}raw_open_{tf}': 'open', f'%-{pair}raw_low_{tf}': 'low',
f'%-{pair}raw_high_{tf}': ' high', f'%-{pair}raw_close_{tf}': 'close'}
ohlc_list = [f'%-raw_open_gen_{pair}_{tf}', f'%-raw_low_gen_{pair}_{tf}',
f'%-raw_high_gen_{pair}_{tf}', f'%-raw_close_gen_{pair}_{tf}']
rename_dict = {f'%-raw_open_gen_{pair}_{tf}': 'open',
f'%-raw_low_gen_{pair}_{tf}': 'low',
f'%-raw_high_gen_{pair}_{tf}': ' high',
f'%-raw_close_gen_{pair}_{tf}': 'close'}
prices_train = train_df.filter(ohlc_list, axis=1)
if prices_train.empty:
raise OperationalException('Reinforcement learning module didnt find the raw prices '
'assigned in feature_engineering_standard(). '
'Please assign them with:\n'
'informative[f"%-{pair}raw_close"] = informative["close"]\n'
'informative[f"%-{pair}raw_open"] = informative["open"]\n'
'informative[f"%-{pair}raw_high"] = informative["high"]\n'
'informative[f"%-{pair}raw_low"] = informative["low"]\n')
'dataframe["%-raw_close"] = dataframe["close"]\n'
'dataframe["%-raw_open"] = dataframe["open"]\n'
'dataframe["%-raw_high"] = dataframe["high"]\n'
'dataframe["%-raw_low"] = dataframe["low"]\n'
'inside `feature_engineering_expand_basic()`')
prices_train.rename(columns=rename_dict, inplace=True)
prices_train.reset_index(drop=True)

View File

@ -90,7 +90,7 @@ def test_use_SVM_to_remove_outliers_and_outlier_protection(mocker, freqai_conf,
freqai_conf['freqai']['feature_parameters'].update({"outlier_protection_percentage": 0.1})
freqai.dk.use_SVM_to_remove_outliers(predict=False)
assert log_has_re(
"SVM detected 7.36%",
"SVM detected 7.83%",
caplog,
)