First step toward cleaning output and enabling multi-model training per pair
This commit is contained in:
@@ -88,7 +88,7 @@ class FreqaiDataKitchen:
|
||||
|
||||
return
|
||||
|
||||
def save_data(self, model: Any, coin: str = '') -> None:
|
||||
def save_data(self, model: Any, coin: str = '', keras_model=False) -> None:
|
||||
"""
|
||||
Saves all data associated with a model for a single sub-train time range
|
||||
:params:
|
||||
@@ -102,7 +102,10 @@ class FreqaiDataKitchen:
|
||||
save_path = Path(self.data_path)
|
||||
|
||||
# Save the trained model
|
||||
dump(model, save_path / str(self.model_filename + "_model.joblib"))
|
||||
if not keras_model:
|
||||
dump(model, save_path / str(self.model_filename + "_model.joblib"))
|
||||
else:
|
||||
model.save(save_path / str(self.model_filename + "_model.h5"))
|
||||
|
||||
if self.svm_model is not None:
|
||||
dump(self.svm_model, save_path / str(self.model_filename + "_svm_model.joblib"))
|
||||
@@ -144,7 +147,7 @@ class FreqaiDataKitchen:
|
||||
|
||||
return
|
||||
|
||||
def load_data(self, coin: str = '') -> Any:
|
||||
def load_data(self, coin: str = '', keras_model=False) -> Any:
|
||||
"""
|
||||
loads all data required to make a prediction on a sub-train time range
|
||||
:returns:
|
||||
@@ -190,8 +193,11 @@ class FreqaiDataKitchen:
|
||||
# try to access model in memory instead of loading object from disk to save time
|
||||
if self.live and self.model_filename in self.data_drawer.model_dictionary:
|
||||
model = self.data_drawer.model_dictionary[self.model_filename]
|
||||
else:
|
||||
elif not keras_model:
|
||||
model = load(self.data_path / str(self.model_filename + "_model.joblib"))
|
||||
else:
|
||||
from tensorflow import keras
|
||||
model = keras.models.load_model(self.data_path / str(self.model_filename + "_model.h5"))
|
||||
|
||||
if Path(self.data_path / str(self.model_filename +
|
||||
"_svm_model.joblib")).resolve().exists():
|
||||
@@ -287,7 +293,11 @@ class FreqaiDataKitchen:
|
||||
training_filter
|
||||
): # we don't care about total row number (total no. datapoints) in training, we only care
|
||||
# about removing any row with NaNs
|
||||
drop_index_labels = pd.isnull(labels)
|
||||
# if labels has multiple columns (user wants to train multiple models), detect that here and flag any row with a NaN in any label column
|
||||
if labels.shape[1] == 1:
|
||||
drop_index_labels = pd.isnull(labels)
|
||||
else:
|
||||
drop_index_labels = pd.isnull(labels).any(1)
|
||||
drop_index_labels = drop_index_labels.replace(True, 1).replace(False, 0)
|
||||
filtered_dataframe = filtered_dataframe[
|
||||
(drop_index == 0) & (drop_index_labels == 0)
|
||||
|
||||
Reference in New Issue
Block a user