unsqueeze target tensor when 1 dimensional

This commit is contained in:
Yinon Polak 2023-03-21 11:42:05 +02:00
parent 9906e7d646
commit 443263803c
1 changed files with 1 additions and 1 deletions

View File

@ -124,7 +124,7 @@ class PyTorchModelTrainer:
data_loader_dictionary = {}
for split in ["train", "test"]:
labels_shape = data_dictionary[f"{split}_labels"].shape
labels_view = labels_shape[0] if labels_shape[1] == 1 else labels_shape
labels_view = (labels_shape[0], 1) if labels_shape[1] == 1 else labels_shape
dataset = TensorDataset(
torch.from_numpy(data_dictionary[f"{split}_features"].values).float(),
torch.from_numpy(data_dictionary[f"{split}_labels"].values)