Merge pull request #7593 from th0rntwig/prediction-shape

Fix constant PCA
This commit is contained in:
Robert Caulk 2022-10-24 08:33:36 +02:00 committed by GitHub
commit 137aa1756b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 15 additions and 11 deletions

View File

@ -210,7 +210,10 @@ class FreqaiDataKitchen:
const_cols = list((filtered_df.nunique() == 1).loc[lambda x: x].index) const_cols = list((filtered_df.nunique() == 1).loc[lambda x: x].index)
if const_cols: if const_cols:
filtered_df = filtered_df.filter(filtered_df.columns.difference(const_cols)) filtered_df = filtered_df.filter(filtered_df.columns.difference(const_cols))
self.data['constant_features_list'] = const_cols
logger.warning(f"Removed features {const_cols} with constant values.") logger.warning(f"Removed features {const_cols} with constant values.")
else:
self.data['constant_features_list'] = []
# we don't care about total row number (total no. datapoints) in training, we only care # we don't care about total row number (total no. datapoints) in training, we only care
# about removing any row with NaNs # about removing any row with NaNs
# if labels has multiple columns (user wants to train multiple modelEs), we detect here # if labels has multiple columns (user wants to train multiple modelEs), we detect here
@ -241,7 +244,8 @@ class FreqaiDataKitchen:
self.data["filter_drop_index_training"] = drop_index self.data["filter_drop_index_training"] = drop_index
else: else:
filtered_df = self.check_pred_labels(filtered_df) if len(self.data['constant_features_list']):
filtered_df = self.check_pred_labels(filtered_df)
# we are backtesting so we need to preserve row number to send back to strategy, # we are backtesting so we need to preserve row number to send back to strategy,
# so now we use do_predict to avoid any prediction based on a NaN # so now we use do_predict to avoid any prediction based on a NaN
drop_index = pd.isnull(filtered_df).any(axis=1) drop_index = pd.isnull(filtered_df).any(axis=1)
@ -467,15 +471,14 @@ class FreqaiDataKitchen:
:params: :params:
:df_predictions: incoming predictions :df_predictions: incoming predictions
""" """
train_labels = self.data_dictionary["train_features"].columns constant_labels = self.data['constant_features_list']
pred_labels = df_predictions.columns df_predictions = df_predictions.filter(
num_diffs = len(pred_labels.difference(train_labels)) df_predictions.columns.difference(constant_labels)
if num_diffs != 0: )
df_predictions = df_predictions[train_labels] logger.warning(
logger.warning( f"Removed {len(constant_labels)} features from prediction features, "
f"Removed {num_diffs} features from prediction features, " f"these were considered constant values during most recent training."
f"these were likely considered constant values during most recent training." )
)
return df_predictions return df_predictions

View File

@ -125,7 +125,8 @@ def test_normalize_data(mocker, freqai_conf):
freqai = make_data_dictionary(mocker, freqai_conf) freqai = make_data_dictionary(mocker, freqai_conf)
data_dict = freqai.dk.data_dictionary data_dict = freqai.dk.data_dictionary
freqai.dk.normalize_data(data_dict) freqai.dk.normalize_data(data_dict)
assert len(freqai.dk.data) == 32 assert any('_max' in entry for entry in freqai.dk.data.keys())
assert any('_min' in entry for entry in freqai.dk.data.keys())
def test_filter_features(mocker, freqai_conf): def test_filter_features(mocker, freqai_conf):