Add groundwork for TensorFlow models, add protections against common mistakes
This commit is contained in:
parent
fea63fba12
commit
ef409dd345
@ -57,6 +57,7 @@ class FreqaiDataKitchen:
|
|||||||
self.live = live
|
self.live = live
|
||||||
self.pair = pair
|
self.pair = pair
|
||||||
self.svm_model: linear_model.SGDOneClassSVM = None
|
self.svm_model: linear_model.SGDOneClassSVM = None
|
||||||
|
self.keras = self.freqai_config.get("keras", False)
|
||||||
self.set_all_pairs()
|
self.set_all_pairs()
|
||||||
if not self.live:
|
if not self.live:
|
||||||
self.full_timerange = self.create_fulltimerange(
|
self.full_timerange = self.create_fulltimerange(
|
||||||
@ -92,7 +93,7 @@ class FreqaiDataKitchen:
|
|||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
def save_data(self, model: Any, coin: str = "", keras_model=False, label=None) -> None:
|
def save_data(self, model: Any, coin: str = "", label=None) -> None:
|
||||||
"""
|
"""
|
||||||
Saves all data associated with a model for a single sub-train time range
|
Saves all data associated with a model for a single sub-train time range
|
||||||
:params:
|
:params:
|
||||||
@ -106,7 +107,7 @@ class FreqaiDataKitchen:
|
|||||||
save_path = Path(self.data_path)
|
save_path = Path(self.data_path)
|
||||||
|
|
||||||
# Save the trained model
|
# Save the trained model
|
||||||
if not keras_model:
|
if not self.keras:
|
||||||
dump(model, save_path / f"{self.model_filename}_model.joblib")
|
dump(model, save_path / f"{self.model_filename}_model.joblib")
|
||||||
else:
|
else:
|
||||||
model.save(save_path / f"{self.model_filename}_model.h5")
|
model.save(save_path / f"{self.model_filename}_model.h5")
|
||||||
@ -140,7 +141,7 @@ class FreqaiDataKitchen:
|
|||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
def load_data(self, coin: str = "", keras_model=False) -> Any:
|
def load_data(self, coin: str = "") -> Any:
|
||||||
"""
|
"""
|
||||||
loads all data required to make a prediction on a sub-train time range
|
loads all data required to make a prediction on a sub-train time range
|
||||||
:returns:
|
:returns:
|
||||||
@ -174,7 +175,7 @@ class FreqaiDataKitchen:
|
|||||||
# try to access model in memory instead of loading object from disk to save time
|
# try to access model in memory instead of loading object from disk to save time
|
||||||
if self.live and self.model_filename in self.dd.model_dictionary:
|
if self.live and self.model_filename in self.dd.model_dictionary:
|
||||||
model = self.dd.model_dictionary[self.model_filename]
|
model = self.dd.model_dictionary[self.model_filename]
|
||||||
elif not keras_model:
|
elif not self.keras:
|
||||||
model = load(self.data_path / str(self.model_filename + "_model.joblib"))
|
model = load(self.data_path / str(self.model_filename + "_model.joblib"))
|
||||||
else:
|
else:
|
||||||
from tensorflow import keras
|
from tensorflow import keras
|
||||||
@ -559,6 +560,13 @@ class FreqaiDataKitchen:
|
|||||||
predict: bool = If true, inference an existing SVM model, else construct one
|
predict: bool = If true, inference an existing SVM model, else construct one
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if self.keras:
|
||||||
|
logger.warning("SVM outlier removal not currently supported for Keras based models. "
|
||||||
|
"Skipping user requested function.")
|
||||||
|
if predict:
|
||||||
|
self.do_predict = np.ones(len(self.data_dictionary["prediction_features"]))
|
||||||
|
return
|
||||||
|
|
||||||
if predict:
|
if predict:
|
||||||
assert self.svm_model, "No svm model available for outlier removal"
|
assert self.svm_model, "No svm model available for outlier removal"
|
||||||
y_pred = self.svm_model.predict(self.data_dictionary["prediction_features"])
|
y_pred = self.svm_model.predict(self.data_dictionary["prediction_features"])
|
||||||
|
@ -69,6 +69,9 @@ class IFreqaiModel(ABC):
|
|||||||
self.ready_to_scan = False
|
self.ready_to_scan = False
|
||||||
self.first = True
|
self.first = True
|
||||||
self.keras = self.freqai_info.get("keras", False)
|
self.keras = self.freqai_info.get("keras", False)
|
||||||
|
if self.keras and self.freqai_info.get("feature_parameters", {}).get("DI_threshold", 0):
|
||||||
|
self.freqai_info["feature_parameters"]["DI_threshold"] = 0
|
||||||
|
logger.warning("DI threshold is not configured for Keras models yet. Deactivating.")
|
||||||
self.CONV_WIDTH = self.freqai_info.get("conv_width", 2)
|
self.CONV_WIDTH = self.freqai_info.get("conv_width", 2)
|
||||||
|
|
||||||
def assert_config(self, config: Dict[str, Any]) -> None:
|
def assert_config(self, config: Dict[str, Any]) -> None:
|
||||||
@ -197,9 +200,9 @@ class IFreqaiModel(ABC):
|
|||||||
self.model = self.train(dataframe_train, metadata["pair"], dk)
|
self.model = self.train(dataframe_train, metadata["pair"], dk)
|
||||||
self.dd.pair_dict[metadata["pair"]]["trained_timestamp"] = trained_timestamp.stopts
|
self.dd.pair_dict[metadata["pair"]]["trained_timestamp"] = trained_timestamp.stopts
|
||||||
dk.set_new_model_names(metadata["pair"], trained_timestamp)
|
dk.set_new_model_names(metadata["pair"], trained_timestamp)
|
||||||
dk.save_data(self.model, metadata["pair"], keras_model=self.keras)
|
dk.save_data(self.model, metadata["pair"])
|
||||||
else:
|
else:
|
||||||
self.model = dk.load_data(metadata["pair"], keras_model=self.keras)
|
self.model = dk.load_data(metadata["pair"])
|
||||||
|
|
||||||
self.check_if_feature_list_matches_strategy(dataframe_train, dk)
|
self.check_if_feature_list_matches_strategy(dataframe_train, dk)
|
||||||
|
|
||||||
@ -276,7 +279,7 @@ class IFreqaiModel(ABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# load the model and associated data into the data kitchen
|
# load the model and associated data into the data kitchen
|
||||||
self.model = dk.load_data(coin=metadata["pair"], keras_model=self.keras)
|
self.model = dk.load_data(coin=metadata["pair"])
|
||||||
|
|
||||||
if not self.model:
|
if not self.model:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@ -353,13 +356,15 @@ class IFreqaiModel(ABC):
|
|||||||
of how outlier data points are dropped from the dataframe used for training.
|
of how outlier data points are dropped from the dataframe used for training.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self.freqai_info.get("feature_parameters", {}).get("principal_component_analysis"):
|
if self.freqai_info.get("feature_parameters", {}).get(
|
||||||
|
"principal_component_analysis", False
|
||||||
|
):
|
||||||
dk.principal_component_analysis()
|
dk.principal_component_analysis()
|
||||||
|
|
||||||
if self.freqai_info.get("feature_parameters", {}).get("use_SVM_to_remove_outliers"):
|
if self.freqai_info.get("feature_parameters", {}).get("use_SVM_to_remove_outliers", False):
|
||||||
dk.use_SVM_to_remove_outliers(predict=False)
|
dk.use_SVM_to_remove_outliers(predict=False)
|
||||||
|
|
||||||
if self.freqai_info.get("feature_parameters", {}).get("DI_threshold"):
|
if self.freqai_info.get("feature_parameters", {}).get("DI_threshold", 0):
|
||||||
dk.data["avg_mean_dist"] = dk.compute_distances()
|
dk.data["avg_mean_dist"] = dk.compute_distances()
|
||||||
|
|
||||||
# if self.feature_parameters["determine_statistical_distributions"]:
|
# if self.feature_parameters["determine_statistical_distributions"]:
|
||||||
@ -378,13 +383,15 @@ class IFreqaiModel(ABC):
|
|||||||
of how the do_predict vector is modified. do_predict is ultimately passed back to strategy
|
of how the do_predict vector is modified. do_predict is ultimately passed back to strategy
|
||||||
for buy signals.
|
for buy signals.
|
||||||
"""
|
"""
|
||||||
if self.freqai_info.get("feature_parameters", {}).get("principal_component_analysis"):
|
if self.freqai_info.get("feature_parameters", {}).get(
|
||||||
|
"principal_component_analysis", False
|
||||||
|
):
|
||||||
dk.pca_transform(dataframe)
|
dk.pca_transform(dataframe)
|
||||||
|
|
||||||
if self.freqai_info.get("feature_parameters", {}).get("use_SVM_to_remove_outliers"):
|
if self.freqai_info.get("feature_parameters", {}).get("use_SVM_to_remove_outliers", False):
|
||||||
dk.use_SVM_to_remove_outliers(predict=True)
|
dk.use_SVM_to_remove_outliers(predict=True)
|
||||||
|
|
||||||
if self.freqai_info.get("feature_parameters", {}).get("DI_threshold"):
|
if self.freqai_info.get("feature_parameters", {}).get("DI_threshold", 0):
|
||||||
dk.check_if_pred_in_training_spaces()
|
dk.check_if_pred_in_training_spaces()
|
||||||
|
|
||||||
# if self.feature_parameters["determine_statistical_distributions"]:
|
# if self.feature_parameters["determine_statistical_distributions"]:
|
||||||
@ -479,14 +486,15 @@ class IFreqaiModel(ABC):
|
|||||||
if self.dd.pair_dict[pair]["priority"] == 1 and self.scanning:
|
if self.dd.pair_dict[pair]["priority"] == 1 and self.scanning:
|
||||||
with self.lock:
|
with self.lock:
|
||||||
self.dd.pair_to_end_of_training_queue(pair)
|
self.dd.pair_to_end_of_training_queue(pair)
|
||||||
dk.save_data(model, coin=pair, keras_model=self.keras)
|
dk.save_data(model, coin=pair)
|
||||||
|
|
||||||
if self.freqai_info.get("purge_old_models", False):
|
if self.freqai_info.get("purge_old_models", False):
|
||||||
self.dd.purge_old_models()
|
self.dd.purge_old_models()
|
||||||
# self.retrain = False
|
# self.retrain = False
|
||||||
|
|
||||||
def set_initial_historic_predictions(self, df: DataFrame, model: Any,
|
def set_initial_historic_predictions(
|
||||||
dk: FreqaiDataKitchen, pair: str) -> None:
|
self, df: DataFrame, model: Any, dk: FreqaiDataKitchen, pair: str
|
||||||
|
) -> None:
|
||||||
trained_predictions = model.predict(df)
|
trained_predictions = model.predict(df)
|
||||||
pred_df = DataFrame(trained_predictions, columns=dk.label_list)
|
pred_df = DataFrame(trained_predictions, columns=dk.label_list)
|
||||||
for label in dk.label_list:
|
for label in dk.label_list:
|
||||||
|
@ -12,9 +12,9 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class BaseRegressionModel(IFreqaiModel):
|
class BaseRegressionModel(IFreqaiModel):
|
||||||
"""
|
"""
|
||||||
User created prediction model. The class needs to override three necessary
|
Base class for regression type models (e.g. Catboost, LightGBM, XGboost etc.).
|
||||||
functions, predict(), train(), fit(). The class inherits ModelHandler which
|
User *must* inherit from this class and set fit() and predict(). See example scripts
|
||||||
has its own DataHandler where data is held, saved, loaded, and managed.
|
such as prediction_models/CatboostPredictionModel.py for guidance.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def return_values(self, dataframe: DataFrame, dk: FreqaiDataKitchen) -> DataFrame:
|
def return_values(self, dataframe: DataFrame, dk: FreqaiDataKitchen) -> DataFrame:
|
||||||
|
@ -24,8 +24,9 @@ class FreqaiModelResolver(IResolver):
|
|||||||
object_type = IFreqaiModel
|
object_type = IFreqaiModel
|
||||||
object_type_str = "FreqaiModel"
|
object_type_str = "FreqaiModel"
|
||||||
user_subdir = USERPATH_FREQAIMODELS
|
user_subdir = USERPATH_FREQAIMODELS
|
||||||
initial_search_path = Path(__file__).parent.parent.joinpath(
|
initial_search_path = (
|
||||||
"freqai/prediction_models").resolve()
|
Path(__file__).parent.parent.joinpath("freqai/prediction_models").resolve()
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_freqaimodel(config: Dict) -> IFreqaiModel:
|
def load_freqaimodel(config: Dict) -> IFreqaiModel:
|
||||||
@ -33,6 +34,7 @@ class FreqaiModelResolver(IResolver):
|
|||||||
Load the custom class from config parameter
|
Load the custom class from config parameter
|
||||||
:param config: configuration dictionary
|
:param config: configuration dictionary
|
||||||
"""
|
"""
|
||||||
|
disallowed_models = ["BaseRegressionModel", "BaseTensorFlowModel"]
|
||||||
|
|
||||||
freqaimodel_name = config.get("freqaimodel")
|
freqaimodel_name = config.get("freqaimodel")
|
||||||
if not freqaimodel_name:
|
if not freqaimodel_name:
|
||||||
@ -40,6 +42,11 @@ class FreqaiModelResolver(IResolver):
|
|||||||
"No freqaimodel set. Please use `--freqaimodel` to "
|
"No freqaimodel set. Please use `--freqaimodel` to "
|
||||||
"specify the FreqaiModel class to use.\n"
|
"specify the FreqaiModel class to use.\n"
|
||||||
)
|
)
|
||||||
|
if freqaimodel_name in disallowed_models:
|
||||||
|
raise OperationalException(
|
||||||
|
f"{freqaimodel_name} is a baseclass and cannot be used directly. User must choose "
|
||||||
|
"an existing child class or inherit from this baseclass.\n"
|
||||||
|
)
|
||||||
freqaimodel = FreqaiModelResolver.load_object(
|
freqaimodel = FreqaiModelResolver.load_object(
|
||||||
freqaimodel_name,
|
freqaimodel_name,
|
||||||
config,
|
config,
|
||||||
|
Loading…
Reference in New Issue
Block a user