change documentation and small bugfix

This commit is contained in:
Yinon Polak 2023-03-08 15:38:22 +02:00
parent 76fbec0c17
commit 1805db2b07
2 changed files with 13 additions and 11 deletions

View File

@ -104,8 +104,6 @@ class PyTorchModelTrainer:
.long()
.view(labels_view)
)
# todo currently assuming class labels are strings ['0.0', '1.0' .. n_classes].
# need to resolve it per ClassifierModel
data_loader = DataLoader(
dataset,

View File

@ -24,16 +24,18 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.multiclass_names = self.freqai_info["multiclass_target_names"]
self.multiclass_names = self.freqai_info.get("multiclass_target_names", None)
logger.info(f"setting multiclass_names: {self.multiclass_names}")
if not self.multiclass_names:
raise OperationalException(
"Missing 'multiclass_names' in freqai_info, "
" multi class pytorch model requires predefined list of"
" class names matching the strategy being used"
"multi class pytorch classifier model requires predefined list of "
"class names matching the strategy being used."
)
self.class_name_to_index = {s: i for i, s in enumerate(self.multiclass_names)}
self.index_to_class_name = {i: s for i, s in enumerate(self.multiclass_names)}
logger.info(f"class_name_to_index: {self.class_name_to_index}")
model_training_parameters = self.freqai_info["model_training_parameters"]
self.n_hidden = model_training_parameters.get("n_hidden", 1024)
@ -48,7 +50,6 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
:param tensor_dictionary: the dictionary constructed by DataHandler to hold
all the training and test data/labels.
"""
self.encode_classes_name(data_dictionary, dk)
n_features = data_dictionary['train_features'].shape[-1]
model = PyTorchMLPModel(
@ -124,9 +125,12 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
non_defined_labels = set(labels) - set(self.multiclass_names)
if len(non_defined_labels) != 0:
raise OperationalException(
f"Found non defined labels {non_defined_labels} ",
f"expecting labels {self.multiclass_names}"
f"Found non defined labels: {non_defined_labels}, ",
f"expecting labels: {self.multiclass_names}"
)
def decode_classes_name(self, classes: List[int]) -> List[str]:
return list(map(lambda x: self.index_to_class_name[x], classes))
def decode_classes_name(self, classes: torch.Tensor[int]) -> List[str]:
"""
decode class name int -> str
"""
return list(map(lambda x: self.index_to_class_name[x.item()], classes))