change documentation and small bugfix
This commit is contained in:
		| @@ -104,8 +104,6 @@ class PyTorchModelTrainer: | ||||
|                 .long() | ||||
|                 .view(labels_view) | ||||
|             ) | ||||
|             # todo currently assuming class labels are strings ['0.0', '1.0' .. n_classes]. | ||||
|             #  need to resolve it per ClassifierModel | ||||
|  | ||||
|             data_loader = DataLoader( | ||||
|                 dataset, | ||||
|   | ||||
| @@ -24,16 +24,18 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel): | ||||
|  | ||||
|     def __init__(self, **kwargs): | ||||
|         super().__init__(**kwargs) | ||||
|         self.multiclass_names = self.freqai_info["multiclass_target_names"] | ||||
|         self.multiclass_names = self.freqai_info.get("multiclass_target_names", None) | ||||
|         logger.info(f"setting multiclass_names: {self.multiclass_names}") | ||||
|         if not self.multiclass_names: | ||||
|             raise OperationalException( | ||||
|                 "Missing 'multiclass_names' in freqai_info," | ||||
|                 " multi class pytorch model requires predefined list of" | ||||
|                 " class names matching the strategy being used" | ||||
|                 "Missing 'multiclass_names' in freqai_info, " | ||||
|                 "multi class pytorch classifier model requires predefined list of " | ||||
|                 "class names matching the strategy being used." | ||||
|             ) | ||||
|  | ||||
|         self.class_name_to_index = {s: i for i, s in enumerate(self.multiclass_names)} | ||||
|         self.index_to_class_name = {i: s for i, s in enumerate(self.multiclass_names)} | ||||
|         logger.info(f"class_name_to_index: {self.class_name_to_index}") | ||||
|  | ||||
|         model_training_parameters = self.freqai_info["model_training_parameters"] | ||||
|         self.n_hidden = model_training_parameters.get("n_hidden", 1024) | ||||
| @@ -48,7 +50,6 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel): | ||||
|         :param tensor_dictionary: the dictionary constructed by DataHandler to hold | ||||
|                                 all the training and test data/labels. | ||||
|         """ | ||||
|  | ||||
|         self.encode_classes_name(data_dictionary, dk) | ||||
|         n_features = data_dictionary['train_features'].shape[-1] | ||||
|         model = PyTorchMLPModel( | ||||
| @@ -124,9 +125,12 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel): | ||||
|         non_defined_labels = set(labels) - set(self.multiclass_names) | ||||
|         if len(non_defined_labels) != 0: | ||||
|             raise OperationalException( | ||||
|                 f"Found non defined labels {non_defined_labels} ", | ||||
|                 f"expecting labels {self.multiclass_names}" | ||||
|                 f"Found non defined labels: {non_defined_labels}, ", | ||||
|                 f"expecting labels: {self.multiclass_names}" | ||||
|             ) | ||||
|  | ||||
|     def decode_classes_name(self, classes: List[int]) -> List[str]: | ||||
|         return list(map(lambda x: self.index_to_class_name[x], classes)) | ||||
|     def decode_classes_name(self, classes: torch.Tensor[int]) -> List[str]: | ||||
|         """ | ||||
|         decode class name int -> str | ||||
|         """ | ||||
|         return list(map(lambda x: self.index_to_class_name[x.item()], classes)) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user