Add additional data normalization methods to freqai module, including StandardScaler, MinMaxScaler, and QuantileTransformer. Add support for pickle metadata, normalization_factory, and unit tests.

This commit is contained in:
Zohar Kol
2023-03-29 17:16:20 +03:00
parent 8a49d62068
commit 4aa9284737
6 changed files with 415 additions and 98 deletions

View File

@@ -427,6 +427,9 @@ class FreqaiDataDrawer:
with (save_path / f"{dk.model_filename}_metadata.json").open("w") as fp:
rapidjson.dump(dk.data, fp, default=self.np_encoder, number_mode=rapidjson.NM_NATIVE)
with (save_path / f"{dk.model_filename}_metadata.pkl").open("wb") as fp:
cloudpickle.dump(dk.pkl_data, fp)
return
def save_data(self, model: Any, coin: str, dk: FreqaiDataKitchen) -> None:
@@ -456,10 +459,14 @@ class FreqaiDataDrawer:
dk.data["model_filename"] = str(dk.model_filename)
dk.data["training_features_list"] = dk.training_features_list
dk.data["label_list"] = dk.label_list
# store the metadata
# store the json metadata
with (save_path / f"{dk.model_filename}_metadata.json").open("w") as fp:
rapidjson.dump(dk.data, fp, default=self.np_encoder, number_mode=rapidjson.NM_NATIVE)
# store the pickle metadata
with (save_path / f"{dk.model_filename}_metadata.pkl").open("wb") as fp:
cloudpickle.dump(dk.pkl_data, fp)
# save the train data to file so we can check preds for area of applicability later
dk.data_dictionary["train_features"].to_pickle(
save_path / f"{dk.model_filename}_trained_df.pkl"
@@ -486,6 +493,16 @@ class FreqaiDataDrawer:
return
def load_pickle_metadata(self, dk: FreqaiDataKitchen) -> None:
    """
    Load the pickled metadata that accompanies a saved model into the datakitchen.

    The file is looked up at ``dk.data_path / f"{dk.model_filename}_metadata.pkl"``.
    When it exists, its unpickled content is assigned to ``dk.pkl_data``.
    When it does not exist, ``dk`` is left untouched.

    :param dk: datakitchen providing ``data_path`` and ``model_filename``, and
               receiving the ``pkl_data`` attribute.
    """
    pickle_file_path = dk.data_path / f"{dk.model_filename}_metadata.pkl"

    # The metadata pickle file may legitimately be missing: models generated
    # before the pickle metadata feature was implemented never wrote one.
    # Skip silently for backward compatibility with those models.
    if not pickle_file_path.is_file():
        return

    # NOTE(review): unpickling executes arbitrary code embedded in the file.
    # This presumably only ever reads files written by save_data/save_metadata
    # on the same machine - confirm model folders are never taken from an
    # untrusted source.
    with pickle_file_path.open("rb") as fp:
        dk.pkl_data = cloudpickle.load(fp)
def load_metadata(self, dk: FreqaiDataKitchen) -> None:
"""
Load only metadata into datakitchen to increase performance during
@@ -496,6 +513,8 @@ class FreqaiDataDrawer:
dk.training_features_list = dk.data["training_features_list"]
dk.label_list = dk.data["label_list"]
self.load_pickle_metadata(dk)
def load_data(self, coin: str, dk: FreqaiDataKitchen) -> Any:
"""
loads all data required to make a prediction on a sub-train time range
@@ -517,6 +536,8 @@ class FreqaiDataDrawer:
with (dk.data_path / f"{dk.model_filename}_metadata.json").open("r") as fp:
dk.data = rapidjson.load(fp, number_mode=rapidjson.NM_NATIVE)
self.load_pickle_metadata(dk)
dk.data_dictionary["train_features"] = pd.read_pickle(
dk.data_path / f"{dk.model_filename}_trained_df.pkl"
)