diff --git a/freqtrade/freqai/base_models/PyTorchModelTrainer.py b/freqtrade/freqai/base_models/PyTorchModelTrainer.py index fc0a7600e..d02f1d896 100644 --- a/freqtrade/freqai/base_models/PyTorchModelTrainer.py +++ b/freqtrade/freqai/base_models/PyTorchModelTrainer.py @@ -104,8 +104,6 @@ class PyTorchModelTrainer: .long() .view(labels_view) ) - # todo currently assuming class labels are strings ['0.0', '1.0' .. n_classes]. - # need to resolve it per ClassifierModel data_loader = DataLoader( dataset, diff --git a/freqtrade/freqai/prediction_models/PyTorchClassifierMultiTarget.py b/freqtrade/freqai/prediction_models/PyTorchClassifierMultiTarget.py index aead0e46c..e58fa9cff 100644 --- a/freqtrade/freqai/prediction_models/PyTorchClassifierMultiTarget.py +++ b/freqtrade/freqai/prediction_models/PyTorchClassifierMultiTarget.py @@ -24,16 +24,18 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel): def __init__(self, **kwargs): super().__init__(**kwargs) - self.multiclass_names = self.freqai_info["multiclass_target_names"] + self.multiclass_names = self.freqai_info.get("multiclass_target_names", None) + logger.info(f"setting multiclass_names: {self.multiclass_names}") if not self.multiclass_names: raise OperationalException( - "Missing 'multiclass_names' in freqai_info," - " multi class pytorch model requires predefined list of" - " class names matching the strategy being used" + "Missing 'multiclass_names' in freqai_info, " + "multi class pytorch classifier model requires predefined list of " + "class names matching the strategy being used." ) self.class_name_to_index = {s: i for i, s in enumerate(self.multiclass_names)} self.index_to_class_name = {i: s for i, s in enumerate(self.multiclass_names)} + logger.info(f"class_name_to_index: {self.class_name_to_index}") model_training_parameters = self.freqai_info["model_training_parameters"] self.n_hidden = model_training_parameters.get("n_hidden", 1024) @@ -48,7 +50,6 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel): :param tensor_dictionary: the dictionary constructed by DataHandler to hold all the training and test data/labels. """ - self.encode_classes_name(data_dictionary, dk) n_features = data_dictionary['train_features'].shape[-1] model = PyTorchMLPModel( @@ -124,9 +125,12 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel): non_defined_labels = set(labels) - set(self.multiclass_names) if len(non_defined_labels) != 0: raise OperationalException( - f"Found non defined labels {non_defined_labels} ", - f"expecting labels {self.multiclass_names}" + f"Found non defined labels: {non_defined_labels}, ", + f"expecting labels: {self.multiclass_names}" ) - def decode_classes_name(self, classes: List[int]) -> List[str]: - return list(map(lambda x: self.index_to_class_name[x], classes)) \ No newline at end of file + def decode_classes_name(self, classes: torch.Tensor[int]) -> List[str]: + """ + decode class name int -> str + """ + return list(map(lambda x: self.index_to_class_name[x.item()], classes))