change documentation and small bugfix
This commit is contained in:
parent
76fbec0c17
commit
1805db2b07
@ -104,8 +104,6 @@ class PyTorchModelTrainer:
|
|||||||
.long()
|
.long()
|
||||||
.view(labels_view)
|
.view(labels_view)
|
||||||
)
|
)
|
||||||
# todo currently assuming class labels are strings ['0.0', '1.0' .. n_classes].
|
|
||||||
# need to resolve it per ClassifierModel
|
|
||||||
|
|
||||||
data_loader = DataLoader(
|
data_loader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
|
@ -24,16 +24,18 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
|
|||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.multiclass_names = self.freqai_info["multiclass_target_names"]
|
self.multiclass_names = self.freqai_info.get("multiclass_target_names", None)
|
||||||
|
logger.info(f"setting multiclass_names: {self.multiclass_names}")
|
||||||
if not self.multiclass_names:
|
if not self.multiclass_names:
|
||||||
raise OperationalException(
|
raise OperationalException(
|
||||||
"Missing 'multiclass_names' in freqai_info,"
|
"Missing 'multiclass_names' in freqai_info, "
|
||||||
" multi class pytorch model requires predefined list of"
|
"multi class pytorch classifier model requires predefined list of "
|
||||||
" class names matching the strategy being used"
|
"class names matching the strategy being used."
|
||||||
)
|
)
|
||||||
|
|
||||||
self.class_name_to_index = {s: i for i, s in enumerate(self.multiclass_names)}
|
self.class_name_to_index = {s: i for i, s in enumerate(self.multiclass_names)}
|
||||||
self.index_to_class_name = {i: s for i, s in enumerate(self.multiclass_names)}
|
self.index_to_class_name = {i: s for i, s in enumerate(self.multiclass_names)}
|
||||||
|
logger.info(f"class_name_to_index: {self.class_name_to_index}")
|
||||||
|
|
||||||
model_training_parameters = self.freqai_info["model_training_parameters"]
|
model_training_parameters = self.freqai_info["model_training_parameters"]
|
||||||
self.n_hidden = model_training_parameters.get("n_hidden", 1024)
|
self.n_hidden = model_training_parameters.get("n_hidden", 1024)
|
||||||
@ -48,7 +50,6 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
|
|||||||
:param tensor_dictionary: the dictionary constructed by DataHandler to hold
|
:param tensor_dictionary: the dictionary constructed by DataHandler to hold
|
||||||
all the training and test data/labels.
|
all the training and test data/labels.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.encode_classes_name(data_dictionary, dk)
|
self.encode_classes_name(data_dictionary, dk)
|
||||||
n_features = data_dictionary['train_features'].shape[-1]
|
n_features = data_dictionary['train_features'].shape[-1]
|
||||||
model = PyTorchMLPModel(
|
model = PyTorchMLPModel(
|
||||||
@ -124,9 +125,12 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
|
|||||||
non_defined_labels = set(labels) - set(self.multiclass_names)
|
non_defined_labels = set(labels) - set(self.multiclass_names)
|
||||||
if len(non_defined_labels) != 0:
|
if len(non_defined_labels) != 0:
|
||||||
raise OperationalException(
|
raise OperationalException(
|
||||||
f"Found non defined labels {non_defined_labels} ",
|
f"Found non defined labels: {non_defined_labels}, ",
|
||||||
f"expecting labels {self.multiclass_names}"
|
f"expecting labels: {self.multiclass_names}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def decode_classes_name(self, classes: List[int]) -> List[str]:
|
def decode_classes_name(self, classes: torch.Tensor[int]) -> List[str]:
|
||||||
return list(map(lambda x: self.index_to_class_name[x], classes))
|
"""
|
||||||
|
decode class name int -> str
|
||||||
|
"""
|
||||||
|
return list(map(lambda x: self.index_to_class_name[x.item()], classes))
|
||||||
|
Loading…
Reference in New Issue
Block a user