Add ground work for TensorFlow models, add protections from common mistakes

This commit is contained in:
robcaulk
2022-07-12 18:09:17 +02:00
parent fea63fba12
commit ef409dd345
4 changed files with 44 additions and 21 deletions

View File

@@ -57,6 +57,7 @@ class FreqaiDataKitchen:
self.live = live
self.pair = pair
self.svm_model: linear_model.SGDOneClassSVM = None
self.keras = self.freqai_config.get("keras", False)
self.set_all_pairs()
if not self.live:
self.full_timerange = self.create_fulltimerange(
@@ -92,7 +93,7 @@ class FreqaiDataKitchen:
return
def save_data(self, model: Any, coin: str = "", keras_model=False, label=None) -> None:
def save_data(self, model: Any, coin: str = "", label=None) -> None:
"""
Saves all data associated with a model for a single sub-train time range
:params:
@@ -106,7 +107,7 @@ class FreqaiDataKitchen:
save_path = Path(self.data_path)
# Save the trained model
if not keras_model:
if not self.keras:
dump(model, save_path / f"{self.model_filename}_model.joblib")
else:
model.save(save_path / f"{self.model_filename}_model.h5")
@@ -140,7 +141,7 @@ class FreqaiDataKitchen:
return
def load_data(self, coin: str = "", keras_model=False) -> Any:
def load_data(self, coin: str = "") -> Any:
"""
loads all data required to make a prediction on a sub-train time range
:returns:
@@ -174,7 +175,7 @@ class FreqaiDataKitchen:
# try to access model in memory instead of loading object from disk to save time
if self.live and self.model_filename in self.dd.model_dictionary:
model = self.dd.model_dictionary[self.model_filename]
elif not keras_model:
elif not self.keras:
model = load(self.data_path / str(self.model_filename + "_model.joblib"))
else:
from tensorflow import keras
@@ -559,6 +560,13 @@ class FreqaiDataKitchen:
predict: bool = If true, inference an existing SVM model, else construct one
"""
if self.keras:
logger.warning("SVM outlier removal not currently supported for Keras based models. "
"Skipping user requested function.")
if predict:
self.do_predict = np.ones(len(self.data_dictionary["prediction_features"]))
return
if predict:
assert self.svm_model, "No svm model available for outlier removal"
y_pred = self.svm_model.predict(self.data_dictionary["prediction_features"])