add todo - currently assuming class labels are strings ['0.0', '1.0' .. n_classes]. need to resolve it per ClassifierModel
This commit is contained in:
		| @@ -101,7 +101,7 @@ class PyTorchModelTrainer: | ||||
|                 torch.from_numpy(data_dictionary[f'{split}_features'].values).float(), | ||||
|                 torch.from_numpy(data_dictionary[f'{split}_labels'].astype(float).values) | ||||
|                 .long() | ||||
|                 .view(labels_view) | ||||
|                 .view(labels_view)  # todo currently assuming class labels are strings ['0.0', '1.0' .. n_classes]. need to resolve it per ClassifierModel | ||||
|             ) | ||||
|             data_loader = DataLoader( | ||||
|                 dataset, | ||||
|   | ||||
		Reference in New Issue
	
	Block a user