add todo - currently assuming class labels are strings ['0.0', '1.0' .. n_classes]. need to resolve it per ClassifierModel
This commit is contained in:
parent
b1ac2bf515
commit
348a08f1c4
@ -101,7 +101,7 @@ class PyTorchModelTrainer:
|
|||||||
torch.from_numpy(data_dictionary[f'{split}_features'].values).float(),
|
torch.from_numpy(data_dictionary[f'{split}_features'].values).float(),
|
||||||
torch.from_numpy(data_dictionary[f'{split}_labels'].astype(float).values)
|
torch.from_numpy(data_dictionary[f'{split}_labels'].astype(float).values)
|
||||||
.long()
|
.long()
|
||||||
.view(labels_view)
|
.view(labels_view) # todo currently assuming class labels are strings ['0.0', '1.0' .. n_classes]. need to resolve it per ClassifierModel
|
||||||
)
|
)
|
||||||
data_loader = DataLoader(
|
data_loader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
|
Loading…
Reference in New Issue
Block a user