From e8f040bfbd37108b50dab712716a5abc1ccfc2ec Mon Sep 17 00:00:00 2001 From: Yinon Polak Date: Mon, 20 Mar 2023 20:38:43 +0200 Subject: [PATCH] add class_name attribute to freqai interface --- freqtrade/freqai/freqai_interface.py | 1 + .../freqai/prediction_models/PyTorchClassifier.py | 15 +++++++++------ .../prediction_models/PyTorchMLPClassifier.py | 2 +- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/freqtrade/freqai/freqai_interface.py b/freqtrade/freqai/freqai_interface.py index 8a1ac436b..470ae1911 100644 --- a/freqtrade/freqai/freqai_interface.py +++ b/freqtrade/freqai/freqai_interface.py @@ -83,6 +83,7 @@ class IFreqaiModel(ABC): self.CONV_WIDTH = self.freqai_info.get('conv_width', 1) if self.ft_params.get("inlier_metric_window", 0): self.CONV_WIDTH = self.ft_params.get("inlier_metric_window", 0) * 2 + self.class_names: List[str] = [] # used in classification children classes self.pair_it = 0 self.pair_it_train = 0 self.total_pairs = len(self.config.get("exchange", {}).get("pair_whitelist")) diff --git a/freqtrade/freqai/prediction_models/PyTorchClassifier.py b/freqtrade/freqai/prediction_models/PyTorchClassifier.py index b14a89b38..e47021a55 100644 --- a/freqtrade/freqai/prediction_models/PyTorchClassifier.py +++ b/freqtrade/freqai/prediction_models/PyTorchClassifier.py @@ -22,8 +22,11 @@ class PyTorchClassifier(BasePyTorchModel): User must implement fit method Important! - User must declare the target class names in the strategy, under - IStrategy.set_freqai_targets method. + + - User must declare the target class names in the strategy, + under IStrategy.set_freqai_targets method. + + for example, in your strategy: ``` def set_freqai_targets(self, dataframe: DataFrame, metadata: Dict, **kwargs): self.freqai.class_names = ["down", "up"] @@ -31,7 +34,6 @@ class PyTorchClassifier(BasePyTorchModel): dataframe["close"], 'up', 'down') return dataframe - ``` """ def __init__(self, **kwargs): super().__init__(**kwargs) @@ -55,7 +57,7 @@ class PyTorchClassifier(BasePyTorchModel): if not class_names: raise ValueError( "Missing class names. " - "self.model.model_meta_data[\"class_names\"] is None." + "self.model.model_meta_data['class_names'] is None." ) if not self.class_name_to_index: @@ -136,10 +138,11 @@ class PyTorchClassifier(BasePyTorchModel): self.encode_class_names(data_dictionary, dk, class_names) def get_class_names(self) -> List[str]: - if not hasattr(self, "class_names"): + if not self.class_names: raise ValueError( - "Missing attribute: self.class_names " + "self.class_names is empty, " "set self.freqai.class_names = ['class a', 'class b', 'class c'] " "inside IStrategy.set_freqai_targets method." ) + return self.class_names diff --git a/freqtrade/freqai/prediction_models/PyTorchMLPClassifier.py b/freqtrade/freqai/prediction_models/PyTorchMLPClassifier.py index 6b7d9c034..373b81a82 100644 --- a/freqtrade/freqai/prediction_models/PyTorchMLPClassifier.py +++ b/freqtrade/freqai/prediction_models/PyTorchMLPClassifier.py @@ -56,7 +56,7 @@ class PyTorchMLPClassifier(PyTorchClassifier): User sets up the training and test data to fit their desired model here :param data_dictionary: the dictionary constructed by DataHandler to hold all the training and test data/labels. - :raises ValueError: If self.class_names is empty. + :raises ValueError: If self.class_names is not defined in the parent class. """ class_names = self.get_class_names()