let user avoid normalizing labels

This commit is contained in:
robcaulk 2022-07-23 16:14:13 +02:00
parent 50d630a155
commit c91e23dc50

View File

@ -356,7 +356,7 @@ class FreqaiDataKitchen:
return self.data_dictionary return self.data_dictionary
def normalize_data(self, data_dictionary: Dict) -> Dict[Any, Any]: def normalize_data(self, data_dictionary: Dict, do_labels: bool = True) -> Dict[Any, Any]:
""" """
Normalize all data in the data_dictionary according to the training dataset Normalize all data in the data_dictionary according to the training dataset
:params: :params:
@ -374,6 +374,11 @@ class FreqaiDataKitchen:
2 * (data_dictionary["test_features"] - train_min) / (train_max - train_min) - 1 2 * (data_dictionary["test_features"] - train_min) / (train_max - train_min) - 1
) )
for item in train_max.keys():
self.data[item + "_max"] = train_max[item]
self.data[item + "_min"] = train_min[item]
if do_labels:
train_labels_max = data_dictionary["train_labels"].max() train_labels_max = data_dictionary["train_labels"].max()
train_labels_min = data_dictionary["train_labels"].min() train_labels_min = data_dictionary["train_labels"].min()
data_dictionary["train_labels"] = ( data_dictionary["train_labels"] = (
@ -389,10 +394,6 @@ class FreqaiDataKitchen:
- 1 - 1
) )
for item in train_max.keys():
self.data[item + "_max"] = train_max[item]
self.data[item + "_min"] = train_min[item]
self.data["labels_max"] = train_labels_max.to_dict() self.data["labels_max"] = train_labels_max.to_dict()
self.data["labels_min"] = train_labels_min.to_dict() self.data["labels_min"] = train_labels_min.to_dict()