add class_name attribute to freqai interface
This commit is contained in:
		| @@ -83,6 +83,7 @@ class IFreqaiModel(ABC): | ||||
|         self.CONV_WIDTH = self.freqai_info.get('conv_width', 1) | ||||
|         if self.ft_params.get("inlier_metric_window", 0): | ||||
|             self.CONV_WIDTH = self.ft_params.get("inlier_metric_window", 0) * 2 | ||||
|         self.class_names: List[str] = []  # used in classification children classes | ||||
|         self.pair_it = 0 | ||||
|         self.pair_it_train = 0 | ||||
|         self.total_pairs = len(self.config.get("exchange", {}).get("pair_whitelist")) | ||||
|   | ||||
| @@ -22,8 +22,11 @@ class PyTorchClassifier(BasePyTorchModel): | ||||
|     User must implement fit method | ||||
|  | ||||
|     Important! | ||||
|     User must declare the target class names in the strategy, under | ||||
|     IStrategy.set_freqai_targets method. | ||||
|  | ||||
|     - User must declare the target class names in the strategy, | ||||
|     under IStrategy.set_freqai_targets method. | ||||
|  | ||||
|     for example, in your strategy: | ||||
|     ``` | ||||
|         def set_freqai_targets(self, dataframe: DataFrame, metadata: Dict, **kwargs): | ||||
|             self.freqai.class_names = ["down", "up"] | ||||
| @@ -31,7 +34,6 @@ class PyTorchClassifier(BasePyTorchModel): | ||||
|                                                   dataframe["close"], 'up', 'down') | ||||
|  | ||||
|             return dataframe | ||||
|     ``` | ||||
|     """ | ||||
|     def __init__(self, **kwargs): | ||||
|         super().__init__(**kwargs) | ||||
| @@ -55,7 +57,7 @@ class PyTorchClassifier(BasePyTorchModel): | ||||
|         if not class_names: | ||||
|             raise ValueError( | ||||
|                 "Missing class names. " | ||||
|                 "self.model.model_meta_data[\"class_names\"] is None." | ||||
|                 "self.model.model_meta_data['class_names'] is None." | ||||
|             ) | ||||
|  | ||||
|         if not self.class_name_to_index: | ||||
| @@ -136,10 +138,11 @@ class PyTorchClassifier(BasePyTorchModel): | ||||
|         self.encode_class_names(data_dictionary, dk, class_names) | ||||
|  | ||||
|     def get_class_names(self) -> List[str]: | ||||
|         if not hasattr(self, "class_names"): | ||||
|         if not self.class_names: | ||||
|             raise ValueError( | ||||
|                 "Missing attribute: self.class_names " | ||||
|                 "self.class_names is empty, " | ||||
|                 "set self.freqai.class_names = ['class a', 'class b', 'class c'] " | ||||
|                 "inside IStrategy.set_freqai_targets method." | ||||
|             ) | ||||
|  | ||||
|         return self.class_names | ||||
|   | ||||
| @@ -56,7 +56,7 @@ class PyTorchMLPClassifier(PyTorchClassifier): | ||||
|         User sets up the training and test data to fit their desired model here | ||||
|         :param data_dictionary: the dictionary constructed by DataHandler to hold | ||||
|                                 all the training and test data/labels. | ||||
|         :raises ValueError: If self.class_names is empty. | ||||
|         :raises ValueError: If self.class_names is not defined in the parent class. | ||||
|         """ | ||||
|  | ||||
|         class_names = self.get_class_names() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user