Merge pull request #7544 from th0rntwig/prediction-shape

Remove constant labels from prediction
This commit is contained in:
Robert Caulk 2022-10-10 21:24:25 +02:00 committed by GitHub
commit 2e34aa9f04
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 35 additions and 1 deletions

View File

@ -243,6 +243,7 @@ class FreqaiDataKitchen:
self.data["filter_drop_index_training"] = drop_index self.data["filter_drop_index_training"] = drop_index
else: else:
filtered_df = self.check_pred_labels(filtered_df)
# we are backtesting so we need to preserve row number to send back to strategy, # we are backtesting so we need to preserve row number to send back to strategy,
# so now we use do_predict to avoid any prediction based on a NaN # so now we use do_predict to avoid any prediction based on a NaN
drop_index = pd.isnull(filtered_df).any(axis=1) drop_index = pd.isnull(filtered_df).any(axis=1)
@ -462,6 +463,24 @@ class FreqaiDataKitchen:
return df return df
def check_pred_labels(self, df_predictions: DataFrame) -> DataFrame:
    """
    Align the prediction feature columns with the columns used for training.

    Columns that are present in the prediction features but absent from
    ``self.data_dictionary["train_features"]`` are dropped, and a warning
    reporting how many were removed is logged. When the prediction features
    carry no extra columns, the incoming dataframe is returned untouched.

    :param df_predictions: dataframe of features that will be fed to the
        model for prediction
    :return: ``df_predictions`` itself when it has no extra columns, otherwise
        a dataframe restricted to (and ordered like) the training columns
    :raises KeyError: raised by pandas when extra columns have to be removed
        but a training column is missing from ``df_predictions``
    """
    train_labels = self.data_dictionary["train_features"].columns
    # Only labels that exist in the prediction frame but not in the training
    # frame are counted; labels missing from the prediction frame are not.
    extra_labels = df_predictions.columns.difference(train_labels)
    num_diffs = len(extra_labels)

    if num_diffs == 0:
        return df_predictions

    # Selecting by the training labels drops the extra columns and, as a side
    # effect, puts the remaining columns in training order.
    df_predictions = df_predictions[train_labels]
    logger.warning(
        f"Removed {num_diffs} features from prediction features, "
        "these were likely considered constant values during most recent training."
    )

    return df_predictions
def principal_component_analysis(self) -> None: def principal_component_analysis(self) -> None:
""" """
Performs Principal Component Analysis on the data for dimensionality reduction Performs Principal Component Analysis on the data for dimensionality reduction

View File

@ -107,6 +107,8 @@ def make_unfiltered_dataframe(mocker, freqai_conf):
unfiltered_dataframe = freqai.dk.use_strategy_to_populate_indicators( unfiltered_dataframe = freqai.dk.use_strategy_to_populate_indicators(
strategy, corr_dataframes, base_dataframes, freqai.dk.pair strategy, corr_dataframes, base_dataframes, freqai.dk.pair
) )
for i in range(5):
unfiltered_dataframe[f'constant_{i}'] = i
unfiltered_dataframe = freqai.dk.slice_dataframe(new_timerange, unfiltered_dataframe) unfiltered_dataframe = freqai.dk.slice_dataframe(new_timerange, unfiltered_dataframe)

View File

@ -157,7 +157,7 @@ def test_extract_data_and_train_model_Classifiers(mocker, freqai_conf, model):
("CatboostClassifier", 6, "freqai_test_classifier") ("CatboostClassifier", 6, "freqai_test_classifier")
], ],
) )
def test_start_backtesting(mocker, freqai_conf, model, num_files, strat): def test_start_backtesting(mocker, freqai_conf, model, num_files, strat, caplog):
freqai_conf.get("freqai", {}).update({"save_backtest_models": True}) freqai_conf.get("freqai", {}).update({"save_backtest_models": True})
freqai_conf['runmode'] = RunMode.BACKTEST freqai_conf['runmode'] = RunMode.BACKTEST
Trade.use_db = False Trade.use_db = False
@ -181,12 +181,23 @@ def test_start_backtesting(mocker, freqai_conf, model, num_files, strat):
corr_df, base_df = freqai.dd.get_base_and_corr_dataframes(sub_timerange, "LTC/BTC", freqai.dk) corr_df, base_df = freqai.dd.get_base_and_corr_dataframes(sub_timerange, "LTC/BTC", freqai.dk)
df = freqai.dk.use_strategy_to_populate_indicators(strategy, corr_df, base_df, "LTC/BTC") df = freqai.dk.use_strategy_to_populate_indicators(strategy, corr_df, base_df, "LTC/BTC")
for i in range(5):
df[f'%-constant_{i}'] = i
# df.loc[:, f'%-constant_{i}'] = i
metadata = {"pair": "LTC/BTC"} metadata = {"pair": "LTC/BTC"}
freqai.start_backtesting(df, metadata, freqai.dk) freqai.start_backtesting(df, metadata, freqai.dk)
model_folders = [x for x in freqai.dd.full_path.iterdir() if x.is_dir()] model_folders = [x for x in freqai.dd.full_path.iterdir() if x.is_dir()]
assert len(model_folders) == num_files assert len(model_folders) == num_files
assert log_has_re(
"Removed features ",
caplog,
)
assert log_has_re(
"Removed 5 features from prediction features, ",
caplog,
)
Backtesting.cleanup() Backtesting.cleanup()
shutil.rmtree(Path(freqai.dk.full_path)) shutil.rmtree(Path(freqai.dk.full_path))
@ -256,6 +267,7 @@ def test_start_backtesting_from_existing_folder(mocker, freqai_conf, caplog):
corr_df, base_df = freqai.dd.get_base_and_corr_dataframes(sub_timerange, "LTC/BTC", freqai.dk) corr_df, base_df = freqai.dd.get_base_and_corr_dataframes(sub_timerange, "LTC/BTC", freqai.dk)
df = freqai.dk.use_strategy_to_populate_indicators(strategy, corr_df, base_df, "LTC/BTC") df = freqai.dk.use_strategy_to_populate_indicators(strategy, corr_df, base_df, "LTC/BTC")
freqai.start_backtesting(df, metadata, freqai.dk) freqai.start_backtesting(df, metadata, freqai.dk)
assert log_has_re( assert log_has_re(
@ -312,6 +324,7 @@ def test_follow_mode(mocker, freqai_conf):
freqai.dd.load_all_pair_histories(timerange, freqai.dk) freqai.dd.load_all_pair_histories(timerange, freqai.dk)
df = strategy.dp.get_pair_dataframe('ADA/BTC', '5m') df = strategy.dp.get_pair_dataframe('ADA/BTC', '5m')
freqai.start_live(df, metadata, strategy, freqai.dk) freqai.start_live(df, metadata, strategy, freqai.dk)
assert len(freqai.dk.return_dataframe.index) == 5702 assert len(freqai.dk.return_dataframe.index) == 5702