change documentation and small bugfix
This commit is contained in:
parent
76fbec0c17
commit
1805db2b07
@ -104,8 +104,6 @@ class PyTorchModelTrainer:
|
||||
.long()
|
||||
.view(labels_view)
|
||||
)
|
||||
# todo currently assuming class labels are strings ['0.0', '1.0' .. n_classes].
|
||||
# need to resolve it per ClassifierModel
|
||||
|
||||
data_loader = DataLoader(
|
||||
dataset,
|
||||
|
@ -24,16 +24,18 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.multiclass_names = self.freqai_info["multiclass_target_names"]
|
||||
self.multiclass_names = self.freqai_info.get("multiclass_target_names", None)
|
||||
logger.info(f"setting multiclass_names: {self.multiclass_names}")
|
||||
if not self.multiclass_names:
|
||||
raise OperationalException(
|
||||
"Missing 'multiclass_names' in freqai_info, "
|
||||
" multi class pytorch model requires predefined list of"
|
||||
" class names matching the strategy being used"
|
||||
"multi class pytorch classifier model requires predefined list of "
|
||||
"class names matching the strategy being used."
|
||||
)
|
||||
|
||||
self.class_name_to_index = {s: i for i, s in enumerate(self.multiclass_names)}
|
||||
self.index_to_class_name = {i: s for i, s in enumerate(self.multiclass_names)}
|
||||
logger.info(f"class_name_to_index: {self.class_name_to_index}")
|
||||
|
||||
model_training_parameters = self.freqai_info["model_training_parameters"]
|
||||
self.n_hidden = model_training_parameters.get("n_hidden", 1024)
|
||||
@ -48,7 +50,6 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
|
||||
:param tensor_dictionary: the dictionary constructed by DataHandler to hold
|
||||
all the training and test data/labels.
|
||||
"""
|
||||
|
||||
self.encode_classes_name(data_dictionary, dk)
|
||||
n_features = data_dictionary['train_features'].shape[-1]
|
||||
model = PyTorchMLPModel(
|
||||
@ -124,9 +125,12 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
|
||||
non_defined_labels = set(labels) - set(self.multiclass_names)
|
||||
if len(non_defined_labels) != 0:
|
||||
raise OperationalException(
|
||||
f"Found non defined labels {non_defined_labels} ",
|
||||
f"expecting labels {self.multiclass_names}"
|
||||
f"Found non defined labels: {non_defined_labels}, ",
|
||||
f"expecting labels: {self.multiclass_names}"
|
||||
)
|
||||
|
||||
def decode_classes_name(self, classes: List[int]) -> List[str]:
|
||||
return list(map(lambda x: self.index_to_class_name[x], classes))
|
||||
def decode_classes_name(self, classes: torch.Tensor[int]) -> List[str]:
|
||||
"""
|
||||
decode class name int -> str
|
||||
"""
|
||||
return list(map(lambda x: self.index_to_class_name[x.item()], classes))
|
||||
|
Loading…
Reference in New Issue
Block a user