add todo - currently assuming class labels are strings ['0.0', '1.0' .. n_classes]. need to resolve it per ClassifierModel

This commit is contained in:
Yinon Polak 2023-03-06 16:41:47 +02:00
parent b1ac2bf515
commit 348a08f1c4

View File

@ -101,7 +101,7 @@ class PyTorchModelTrainer:
torch.from_numpy(data_dictionary[f'{split}_features'].values).float(), torch.from_numpy(data_dictionary[f'{split}_features'].values).float(),
torch.from_numpy(data_dictionary[f'{split}_labels'].astype(float).values) torch.from_numpy(data_dictionary[f'{split}_labels'].astype(float).values)
.long() .long()
.view(labels_view) .view(labels_view) # todo currently assuming class labels are strings ['0.0', '1.0' .. n_classes]. need to resolve it per ClassifierModel
) )
data_loader = DataLoader( data_loader = DataLoader(
dataset, dataset,