Change storage loc and fix test fail

This commit is contained in:
th0rntwig 2022-10-23 16:24:02 +02:00
parent 033c5bd441
commit 49ff51f11f
2 changed files with 6 additions and 6 deletions

View File

@ -71,7 +71,6 @@ class FreqaiDataKitchen:
self.data_path = Path()
self.label_list: List = []
self.training_features_list: List = []
self.constant_features_list: List = []
self.model_filename: str = ""
self.backtesting_results_path = Path()
self.backtest_predictions_folder: str = "backtesting_predictions"
@ -211,10 +210,10 @@ class FreqaiDataKitchen:
const_cols = list((filtered_df.nunique() == 1).loc[lambda x: x].index)
if const_cols:
filtered_df = filtered_df.filter(filtered_df.columns.difference(const_cols))
self.constant_features_list = const_cols
self.data['constant_features_list'] = const_cols
logger.warning(f"Removed features {const_cols} with constant values.")
else:
self.constant_features_list = []
self.data['constant_features_list'] = []
# we don't care about total row number (total no. datapoints) in training, we only care
# about removing any row with NaNs
# if labels has multiple columns (user wants to train multiple modelEs), we detect here
@ -245,7 +244,7 @@ class FreqaiDataKitchen:
self.data["filter_drop_index_training"] = drop_index
else:
if len(self.constant_features_list):
if len(self.data['constant_features_list']):
filtered_df = self.check_pred_labels(filtered_df)
# we are backtesting so we need to preserve row number to send back to strategy,
# so now we use do_predict to avoid any prediction based on a NaN
@ -472,7 +471,7 @@ class FreqaiDataKitchen:
:params:
:df_predictions: incoming predictions
"""
constant_labels = self.constant_features_list
constant_labels = self.data['constant_features_list']
df_predictions = df_predictions.filter(
df_predictions.columns.difference(constant_labels)
)

View File

@ -125,7 +125,8 @@ def test_normalize_data(mocker, freqai_conf):
freqai = make_data_dictionary(mocker, freqai_conf)
data_dict = freqai.dk.data_dictionary
freqai.dk.normalize_data(data_dict)
assert len(freqai.dk.data) == 32
assert any('_max' in entry for entry in freqai.dk.data.keys())
assert any('_min' in entry for entry in freqai.dk.data.keys())
def test_filter_features(mocker, freqai_conf):