Add ground work for TensorFlow models, add protections from common mistakes

This commit is contained in:
robcaulk 2022-07-12 18:09:17 +02:00
parent fea63fba12
commit ef409dd345
4 changed files with 44 additions and 21 deletions

View File

@ -57,6 +57,7 @@ class FreqaiDataKitchen:
self.live = live self.live = live
self.pair = pair self.pair = pair
self.svm_model: linear_model.SGDOneClassSVM = None self.svm_model: linear_model.SGDOneClassSVM = None
self.keras = self.freqai_config.get("keras", False)
self.set_all_pairs() self.set_all_pairs()
if not self.live: if not self.live:
self.full_timerange = self.create_fulltimerange( self.full_timerange = self.create_fulltimerange(
@ -92,7 +93,7 @@ class FreqaiDataKitchen:
return return
def save_data(self, model: Any, coin: str = "", keras_model=False, label=None) -> None: def save_data(self, model: Any, coin: str = "", label=None) -> None:
""" """
Saves all data associated with a model for a single sub-train time range Saves all data associated with a model for a single sub-train time range
:params: :params:
@ -106,7 +107,7 @@ class FreqaiDataKitchen:
save_path = Path(self.data_path) save_path = Path(self.data_path)
# Save the trained model # Save the trained model
if not keras_model: if not self.keras:
dump(model, save_path / f"{self.model_filename}_model.joblib") dump(model, save_path / f"{self.model_filename}_model.joblib")
else: else:
model.save(save_path / f"{self.model_filename}_model.h5") model.save(save_path / f"{self.model_filename}_model.h5")
@ -140,7 +141,7 @@ class FreqaiDataKitchen:
return return
def load_data(self, coin: str = "", keras_model=False) -> Any: def load_data(self, coin: str = "") -> Any:
""" """
loads all data required to make a prediction on a sub-train time range loads all data required to make a prediction on a sub-train time range
:returns: :returns:
@ -174,7 +175,7 @@ class FreqaiDataKitchen:
# try to access model in memory instead of loading object from disk to save time # try to access model in memory instead of loading object from disk to save time
if self.live and self.model_filename in self.dd.model_dictionary: if self.live and self.model_filename in self.dd.model_dictionary:
model = self.dd.model_dictionary[self.model_filename] model = self.dd.model_dictionary[self.model_filename]
elif not keras_model: elif not self.keras:
model = load(self.data_path / str(self.model_filename + "_model.joblib")) model = load(self.data_path / str(self.model_filename + "_model.joblib"))
else: else:
from tensorflow import keras from tensorflow import keras
@ -559,6 +560,13 @@ class FreqaiDataKitchen:
predict: bool = If true, inference an existing SVM model, else construct one predict: bool = If true, inference an existing SVM model, else construct one
""" """
if self.keras:
logger.warning("SVM outlier removal not currently supported for Keras based models. "
"Skipping user requested function.")
if predict:
self.do_predict = np.ones(len(self.data_dictionary["prediction_features"]))
return
if predict: if predict:
assert self.svm_model, "No svm model available for outlier removal" assert self.svm_model, "No svm model available for outlier removal"
y_pred = self.svm_model.predict(self.data_dictionary["prediction_features"]) y_pred = self.svm_model.predict(self.data_dictionary["prediction_features"])

View File

@ -69,6 +69,9 @@ class IFreqaiModel(ABC):
self.ready_to_scan = False self.ready_to_scan = False
self.first = True self.first = True
self.keras = self.freqai_info.get("keras", False) self.keras = self.freqai_info.get("keras", False)
if self.keras and self.freqai_info.get("feature_parameters", {}).get("DI_threshold", 0):
self.freqai_info["feature_parameters"]["DI_threshold"] = 0
logger.warning("DI threshold is not configured for Keras models yet. Deactivating.")
self.CONV_WIDTH = self.freqai_info.get("conv_width", 2) self.CONV_WIDTH = self.freqai_info.get("conv_width", 2)
def assert_config(self, config: Dict[str, Any]) -> None: def assert_config(self, config: Dict[str, Any]) -> None:
@ -197,9 +200,9 @@ class IFreqaiModel(ABC):
self.model = self.train(dataframe_train, metadata["pair"], dk) self.model = self.train(dataframe_train, metadata["pair"], dk)
self.dd.pair_dict[metadata["pair"]]["trained_timestamp"] = trained_timestamp.stopts self.dd.pair_dict[metadata["pair"]]["trained_timestamp"] = trained_timestamp.stopts
dk.set_new_model_names(metadata["pair"], trained_timestamp) dk.set_new_model_names(metadata["pair"], trained_timestamp)
dk.save_data(self.model, metadata["pair"], keras_model=self.keras) dk.save_data(self.model, metadata["pair"])
else: else:
self.model = dk.load_data(metadata["pair"], keras_model=self.keras) self.model = dk.load_data(metadata["pair"])
self.check_if_feature_list_matches_strategy(dataframe_train, dk) self.check_if_feature_list_matches_strategy(dataframe_train, dk)
@ -276,7 +279,7 @@ class IFreqaiModel(ABC):
) )
# load the model and associated data into the data kitchen # load the model and associated data into the data kitchen
self.model = dk.load_data(coin=metadata["pair"], keras_model=self.keras) self.model = dk.load_data(coin=metadata["pair"])
if not self.model: if not self.model:
logger.warning( logger.warning(
@ -353,13 +356,15 @@ class IFreqaiModel(ABC):
of how outlier data points are dropped from the dataframe used for training. of how outlier data points are dropped from the dataframe used for training.
""" """
if self.freqai_info.get("feature_parameters", {}).get("principal_component_analysis"): if self.freqai_info.get("feature_parameters", {}).get(
"principal_component_analysis", False
):
dk.principal_component_analysis() dk.principal_component_analysis()
if self.freqai_info.get("feature_parameters", {}).get("use_SVM_to_remove_outliers"): if self.freqai_info.get("feature_parameters", {}).get("use_SVM_to_remove_outliers", False):
dk.use_SVM_to_remove_outliers(predict=False) dk.use_SVM_to_remove_outliers(predict=False)
if self.freqai_info.get("feature_parameters", {}).get("DI_threshold"): if self.freqai_info.get("feature_parameters", {}).get("DI_threshold", 0):
dk.data["avg_mean_dist"] = dk.compute_distances() dk.data["avg_mean_dist"] = dk.compute_distances()
# if self.feature_parameters["determine_statistical_distributions"]: # if self.feature_parameters["determine_statistical_distributions"]:
@ -378,13 +383,15 @@ class IFreqaiModel(ABC):
of how the do_predict vector is modified. do_predict is ultimately passed back to strategy of how the do_predict vector is modified. do_predict is ultimately passed back to strategy
for buy signals. for buy signals.
""" """
if self.freqai_info.get("feature_parameters", {}).get("principal_component_analysis"): if self.freqai_info.get("feature_parameters", {}).get(
"principal_component_analysis", False
):
dk.pca_transform(dataframe) dk.pca_transform(dataframe)
if self.freqai_info.get("feature_parameters", {}).get("use_SVM_to_remove_outliers"): if self.freqai_info.get("feature_parameters", {}).get("use_SVM_to_remove_outliers", False):
dk.use_SVM_to_remove_outliers(predict=True) dk.use_SVM_to_remove_outliers(predict=True)
if self.freqai_info.get("feature_parameters", {}).get("DI_threshold"): if self.freqai_info.get("feature_parameters", {}).get("DI_threshold", 0):
dk.check_if_pred_in_training_spaces() dk.check_if_pred_in_training_spaces()
# if self.feature_parameters["determine_statistical_distributions"]: # if self.feature_parameters["determine_statistical_distributions"]:
@ -479,14 +486,15 @@ class IFreqaiModel(ABC):
if self.dd.pair_dict[pair]["priority"] == 1 and self.scanning: if self.dd.pair_dict[pair]["priority"] == 1 and self.scanning:
with self.lock: with self.lock:
self.dd.pair_to_end_of_training_queue(pair) self.dd.pair_to_end_of_training_queue(pair)
dk.save_data(model, coin=pair, keras_model=self.keras) dk.save_data(model, coin=pair)
if self.freqai_info.get("purge_old_models", False): if self.freqai_info.get("purge_old_models", False):
self.dd.purge_old_models() self.dd.purge_old_models()
# self.retrain = False # self.retrain = False
def set_initial_historic_predictions(self, df: DataFrame, model: Any, def set_initial_historic_predictions(
dk: FreqaiDataKitchen, pair: str) -> None: self, df: DataFrame, model: Any, dk: FreqaiDataKitchen, pair: str
) -> None:
trained_predictions = model.predict(df) trained_predictions = model.predict(df)
pred_df = DataFrame(trained_predictions, columns=dk.label_list) pred_df = DataFrame(trained_predictions, columns=dk.label_list)
for label in dk.label_list: for label in dk.label_list:

View File

@ -12,9 +12,9 @@ logger = logging.getLogger(__name__)
class BaseRegressionModel(IFreqaiModel): class BaseRegressionModel(IFreqaiModel):
""" """
User created prediction model. The class needs to override three necessary Base class for regression type models (e.g. Catboost, LightGBM, XGboost etc.).
functions, predict(), train(), fit(). The class inherits ModelHandler which User *must* inherit from this class and set fit() and predict(). See example scripts
has its own DataHandler where data is held, saved, loaded, and managed. such as prediction_models/CatboostPredictionModel.py for guidance.
""" """
def return_values(self, dataframe: DataFrame, dk: FreqaiDataKitchen) -> DataFrame: def return_values(self, dataframe: DataFrame, dk: FreqaiDataKitchen) -> DataFrame:

View File

@ -24,8 +24,9 @@ class FreqaiModelResolver(IResolver):
object_type = IFreqaiModel object_type = IFreqaiModel
object_type_str = "FreqaiModel" object_type_str = "FreqaiModel"
user_subdir = USERPATH_FREQAIMODELS user_subdir = USERPATH_FREQAIMODELS
initial_search_path = Path(__file__).parent.parent.joinpath( initial_search_path = (
"freqai/prediction_models").resolve() Path(__file__).parent.parent.joinpath("freqai/prediction_models").resolve()
)
@staticmethod @staticmethod
def load_freqaimodel(config: Dict) -> IFreqaiModel: def load_freqaimodel(config: Dict) -> IFreqaiModel:
@ -33,6 +34,7 @@ class FreqaiModelResolver(IResolver):
Load the custom class from config parameter Load the custom class from config parameter
:param config: configuration dictionary :param config: configuration dictionary
""" """
disallowed_models = ["BaseRegressionModel", "BaseTensorFlowModel"]
freqaimodel_name = config.get("freqaimodel") freqaimodel_name = config.get("freqaimodel")
if not freqaimodel_name: if not freqaimodel_name:
@ -40,6 +42,11 @@ class FreqaiModelResolver(IResolver):
"No freqaimodel set. Please use `--freqaimodel` to " "No freqaimodel set. Please use `--freqaimodel` to "
"specify the FreqaiModel class to use.\n" "specify the FreqaiModel class to use.\n"
) )
if freqaimodel_name in disallowed_models:
raise OperationalException(
f"{freqaimodel_name} is a baseclass and cannot be used directly. User must choose "
"an existing child class or inherit from this baseclass.\n"
)
freqaimodel = FreqaiModelResolver.load_object( freqaimodel = FreqaiModelResolver.load_object(
freqaimodel_name, freqaimodel_name,
config, config,