change documentation and small bugfix

This commit is contained in:
Yinon Polak 2023-03-08 15:38:22 +02:00
parent 76fbec0c17
commit 1805db2b07
2 changed files with 13 additions and 11 deletions

View File

@ -104,8 +104,6 @@ class PyTorchModelTrainer:
.long() .long()
.view(labels_view) .view(labels_view)
) )
# todo currently assuming class labels are strings ['0.0', '1.0' .. n_classes].
# need to resolve it per ClassifierModel
data_loader = DataLoader( data_loader = DataLoader(
dataset, dataset,

View File

@ -24,16 +24,18 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.multiclass_names = self.freqai_info["multiclass_target_names"] self.multiclass_names = self.freqai_info.get("multiclass_target_names", None)
logger.info(f"setting multiclass_names: {self.multiclass_names}")
if not self.multiclass_names: if not self.multiclass_names:
raise OperationalException( raise OperationalException(
"Missing 'multiclass_names' in freqai_info, " "Missing 'multiclass_names' in freqai_info, "
" multi class pytorch model requires predefined list of" "multi class pytorch classifier model requires predefined list of "
" class names matching the strategy being used" "class names matching the strategy being used."
) )
self.class_name_to_index = {s: i for i, s in enumerate(self.multiclass_names)} self.class_name_to_index = {s: i for i, s in enumerate(self.multiclass_names)}
self.index_to_class_name = {i: s for i, s in enumerate(self.multiclass_names)} self.index_to_class_name = {i: s for i, s in enumerate(self.multiclass_names)}
logger.info(f"class_name_to_index: {self.class_name_to_index}")
model_training_parameters = self.freqai_info["model_training_parameters"] model_training_parameters = self.freqai_info["model_training_parameters"]
self.n_hidden = model_training_parameters.get("n_hidden", 1024) self.n_hidden = model_training_parameters.get("n_hidden", 1024)
@ -48,7 +50,6 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
:param tensor_dictionary: the dictionary constructed by DataHandler to hold :param tensor_dictionary: the dictionary constructed by DataHandler to hold
all the training and test data/labels. all the training and test data/labels.
""" """
self.encode_classes_name(data_dictionary, dk) self.encode_classes_name(data_dictionary, dk)
n_features = data_dictionary['train_features'].shape[-1] n_features = data_dictionary['train_features'].shape[-1]
model = PyTorchMLPModel( model = PyTorchMLPModel(
@ -124,9 +125,12 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
non_defined_labels = set(labels) - set(self.multiclass_names) non_defined_labels = set(labels) - set(self.multiclass_names)
if len(non_defined_labels) != 0: if len(non_defined_labels) != 0:
raise OperationalException( raise OperationalException(
f"Found non defined labels {non_defined_labels} ", f"Found non defined labels: {non_defined_labels}, ",
f"expecting labels {self.multiclass_names}" f"expecting labels: {self.multiclass_names}"
) )
def decode_classes_name(self, classes: List[int]) -> List[str]: def decode_classes_name(self, classes: torch.Tensor[int]) -> List[str]:
return list(map(lambda x: self.index_to_class_name[x], classes)) """
decode class name int -> str
"""
return list(map(lambda x: self.index_to_class_name[x.item()], classes))