diff --git a/freqtrade/freqai/data_kitchen.py b/freqtrade/freqai/data_kitchen.py index 1f78df0f8..56c1a67ed 100644 --- a/freqtrade/freqai/data_kitchen.py +++ b/freqtrade/freqai/data_kitchen.py @@ -57,6 +57,7 @@ class FreqaiDataKitchen: self.live = live self.pair = pair self.svm_model: linear_model.SGDOneClassSVM = None + self.keras = self.freqai_config.get("keras", False) self.set_all_pairs() if not self.live: self.full_timerange = self.create_fulltimerange( @@ -92,7 +93,7 @@ class FreqaiDataKitchen: return - def save_data(self, model: Any, coin: str = "", keras_model=False, label=None) -> None: + def save_data(self, model: Any, coin: str = "", label=None) -> None: """ Saves all data associated with a model for a single sub-train time range :params: @@ -106,7 +107,7 @@ class FreqaiDataKitchen: save_path = Path(self.data_path) # Save the trained model - if not keras_model: + if not self.keras: dump(model, save_path / f"{self.model_filename}_model.joblib") else: model.save(save_path / f"{self.model_filename}_model.h5") @@ -140,7 +141,7 @@ class FreqaiDataKitchen: return - def load_data(self, coin: str = "", keras_model=False) -> Any: + def load_data(self, coin: str = "") -> Any: """ loads all data required to make a prediction on a sub-train time range :returns: @@ -174,7 +175,7 @@ class FreqaiDataKitchen: # try to access model in memory instead of loading object from disk to save time if self.live and self.model_filename in self.dd.model_dictionary: model = self.dd.model_dictionary[self.model_filename] - elif not keras_model: + elif not self.keras: model = load(self.data_path / str(self.model_filename + "_model.joblib")) else: from tensorflow import keras @@ -559,6 +560,13 @@ class FreqaiDataKitchen: predict: bool = If true, inference an existing SVM model, else construct one """ + if self.keras: + logger.warning("SVM outlier removal not currently supported for Keras based models. " + "Skipping user requested function.") + if predict: + self.do_predict = np.ones(len(self.data_dictionary["prediction_features"])) + return + if predict: assert self.svm_model, "No svm model available for outlier removal" y_pred = self.svm_model.predict(self.data_dictionary["prediction_features"]) diff --git a/freqtrade/freqai/freqai_interface.py b/freqtrade/freqai/freqai_interface.py index b03b1f3b0..56a179dc3 100644 --- a/freqtrade/freqai/freqai_interface.py +++ b/freqtrade/freqai/freqai_interface.py @@ -69,6 +69,9 @@ class IFreqaiModel(ABC): self.ready_to_scan = False self.first = True self.keras = self.freqai_info.get("keras", False) + if self.keras and self.freqai_info.get("feature_parameters", {}).get("DI_threshold", 0): + self.freqai_info["feature_parameters"]["DI_threshold"] = 0 + logger.warning("DI threshold is not configured for Keras models yet. Deactivating.") self.CONV_WIDTH = self.freqai_info.get("conv_width", 2) def assert_config(self, config: Dict[str, Any]) -> None: @@ -197,9 +200,9 @@ class IFreqaiModel(ABC): self.model = self.train(dataframe_train, metadata["pair"], dk) self.dd.pair_dict[metadata["pair"]]["trained_timestamp"] = trained_timestamp.stopts dk.set_new_model_names(metadata["pair"], trained_timestamp) - dk.save_data(self.model, metadata["pair"], keras_model=self.keras) + dk.save_data(self.model, metadata["pair"]) else: - self.model = dk.load_data(metadata["pair"], keras_model=self.keras) + self.model = dk.load_data(metadata["pair"]) self.check_if_feature_list_matches_strategy(dataframe_train, dk) @@ -276,7 +279,7 @@ class IFreqaiModel(ABC): ) # load the model and associated data into the data kitchen - self.model = dk.load_data(coin=metadata["pair"], keras_model=self.keras) + self.model = dk.load_data(coin=metadata["pair"]) if not self.model: logger.warning( @@ -353,13 +356,15 @@ class IFreqaiModel(ABC): of how outlier data points are dropped from the dataframe used for training. """ - if self.freqai_info.get("feature_parameters", {}).get("principal_component_analysis"): + if self.freqai_info.get("feature_parameters", {}).get( + "principal_component_analysis", False + ): dk.principal_component_analysis() - if self.freqai_info.get("feature_parameters", {}).get("use_SVM_to_remove_outliers"): + if self.freqai_info.get("feature_parameters", {}).get("use_SVM_to_remove_outliers", False): dk.use_SVM_to_remove_outliers(predict=False) - if self.freqai_info.get("feature_parameters", {}).get("DI_threshold"): + if self.freqai_info.get("feature_parameters", {}).get("DI_threshold", 0): dk.data["avg_mean_dist"] = dk.compute_distances() # if self.feature_parameters["determine_statistical_distributions"]: @@ -378,13 +383,15 @@ class IFreqaiModel(ABC): of how the do_predict vector is modified. do_predict is ultimately passed back to strategy for buy signals. """ - if self.freqai_info.get("feature_parameters", {}).get("principal_component_analysis"): + if self.freqai_info.get("feature_parameters", {}).get( + "principal_component_analysis", False + ): dk.pca_transform(dataframe) - if self.freqai_info.get("feature_parameters", {}).get("use_SVM_to_remove_outliers"): + if self.freqai_info.get("feature_parameters", {}).get("use_SVM_to_remove_outliers", False): dk.use_SVM_to_remove_outliers(predict=True) - if self.freqai_info.get("feature_parameters", {}).get("DI_threshold"): + if self.freqai_info.get("feature_parameters", {}).get("DI_threshold", 0): dk.check_if_pred_in_training_spaces() # if self.feature_parameters["determine_statistical_distributions"]: @@ -479,14 +486,15 @@ class IFreqaiModel(ABC): if self.dd.pair_dict[pair]["priority"] == 1 and self.scanning: with self.lock: self.dd.pair_to_end_of_training_queue(pair) - dk.save_data(model, coin=pair, keras_model=self.keras) + dk.save_data(model, coin=pair) if self.freqai_info.get("purge_old_models", False): self.dd.purge_old_models() # self.retrain = False - def set_initial_historic_predictions(self, df: DataFrame, model: Any, - dk: FreqaiDataKitchen, pair: str) -> None: + def set_initial_historic_predictions( + self, df: DataFrame, model: Any, dk: FreqaiDataKitchen, pair: str + ) -> None: trained_predictions = model.predict(df) pred_df = DataFrame(trained_predictions, columns=dk.label_list) for label in dk.label_list: diff --git a/freqtrade/freqai/prediction_models/BaseRegressionModel.py b/freqtrade/freqai/prediction_models/BaseRegressionModel.py index 260e24182..f9a9bb69f 100644 --- a/freqtrade/freqai/prediction_models/BaseRegressionModel.py +++ b/freqtrade/freqai/prediction_models/BaseRegressionModel.py @@ -12,9 +12,9 @@ logger = logging.getLogger(__name__) class BaseRegressionModel(IFreqaiModel): """ - User created prediction model. The class needs to override three necessary - functions, predict(), train(), fit(). The class inherits ModelHandler which - has its own DataHandler where data is held, saved, loaded, and managed. + Base class for regression type models (e.g. Catboost, LightGBM, XGboost etc.). + User *must* inherit from this class and set fit() and predict(). See example scripts + such as prediction_models/CatboostPredictionModel.py for guidance. """ def return_values(self, dataframe: DataFrame, dk: FreqaiDataKitchen) -> DataFrame: diff --git a/freqtrade/resolvers/freqaimodel_resolver.py b/freqtrade/resolvers/freqaimodel_resolver.py index e666b462c..0fcfca363 100644 --- a/freqtrade/resolvers/freqaimodel_resolver.py +++ b/freqtrade/resolvers/freqaimodel_resolver.py @@ -24,8 +24,9 @@ class FreqaiModelResolver(IResolver): object_type = IFreqaiModel object_type_str = "FreqaiModel" user_subdir = USERPATH_FREQAIMODELS - initial_search_path = Path(__file__).parent.parent.joinpath( - "freqai/prediction_models").resolve() + initial_search_path = ( + Path(__file__).parent.parent.joinpath("freqai/prediction_models").resolve() + ) @staticmethod def load_freqaimodel(config: Dict) -> IFreqaiModel: @@ -33,6 +34,7 @@ class FreqaiModelResolver(IResolver): Load the custom class from config parameter :param config: configuration dictionary """ + disallowed_models = ["BaseRegressionModel", "BaseTensorFlowModel"] freqaimodel_name = config.get("freqaimodel") if not freqaimodel_name: @@ -40,6 +42,11 @@ class FreqaiModelResolver(IResolver): "No freqaimodel set. Please use `--freqaimodel` to " "specify the FreqaiModel class to use.\n" ) + if freqaimodel_name in disallowed_models: + raise OperationalException( + f"{freqaimodel_name} is a baseclass and cannot be used directly. User must choose " + "an existing child class or inherit from this baseclass.\n" + ) freqaimodel = FreqaiModelResolver.load_object( freqaimodel_name, config,