Add ground work for TensorFlow models, add protections from common mistakes
This commit is contained in:
@@ -57,6 +57,7 @@ class FreqaiDataKitchen:
|
||||
self.live = live
|
||||
self.pair = pair
|
||||
self.svm_model: linear_model.SGDOneClassSVM = None
|
||||
self.keras = self.freqai_config.get("keras", False)
|
||||
self.set_all_pairs()
|
||||
if not self.live:
|
||||
self.full_timerange = self.create_fulltimerange(
|
||||
@@ -92,7 +93,7 @@ class FreqaiDataKitchen:
|
||||
|
||||
return
|
||||
|
||||
def save_data(self, model: Any, coin: str = "", keras_model=False, label=None) -> None:
|
||||
def save_data(self, model: Any, coin: str = "", label=None) -> None:
|
||||
"""
|
||||
Saves all data associated with a model for a single sub-train time range
|
||||
:params:
|
||||
@@ -106,7 +107,7 @@ class FreqaiDataKitchen:
|
||||
save_path = Path(self.data_path)
|
||||
|
||||
# Save the trained model
|
||||
if not keras_model:
|
||||
if not self.keras:
|
||||
dump(model, save_path / f"{self.model_filename}_model.joblib")
|
||||
else:
|
||||
model.save(save_path / f"{self.model_filename}_model.h5")
|
||||
@@ -140,7 +141,7 @@ class FreqaiDataKitchen:
|
||||
|
||||
return
|
||||
|
||||
def load_data(self, coin: str = "", keras_model=False) -> Any:
|
||||
def load_data(self, coin: str = "") -> Any:
|
||||
"""
|
||||
loads all data required to make a prediction on a sub-train time range
|
||||
:returns:
|
||||
@@ -174,7 +175,7 @@ class FreqaiDataKitchen:
|
||||
# try to access model in memory instead of loading object from disk to save time
|
||||
if self.live and self.model_filename in self.dd.model_dictionary:
|
||||
model = self.dd.model_dictionary[self.model_filename]
|
||||
elif not keras_model:
|
||||
elif not self.keras:
|
||||
model = load(self.data_path / str(self.model_filename + "_model.joblib"))
|
||||
else:
|
||||
from tensorflow import keras
|
||||
@@ -559,6 +560,13 @@ class FreqaiDataKitchen:
|
||||
predict: bool = If true, inference an existing SVM model, else construct one
|
||||
"""
|
||||
|
||||
if self.keras:
|
||||
logger.warning("SVM outlier removal not currently supported for Keras based models. "
|
||||
"Skipping user requested function.")
|
||||
if predict:
|
||||
self.do_predict = np.ones(len(self.data_dictionary["prediction_features"]))
|
||||
return
|
||||
|
||||
if predict:
|
||||
assert self.svm_model, "No svm model available for outlier removal"
|
||||
y_pred = self.svm_model.predict(self.data_dictionary["prediction_features"])
|
||||
|
Reference in New Issue
Block a user