reformat code

This commit is contained in:
Yinon Polak
2023-03-06 17:50:02 +02:00
parent 348a08f1c4
commit e6e747bcd8
2 changed files with 16 additions and 14 deletions

View File

@@ -69,7 +69,7 @@ class PyTorchModelTrainer:
self.model.eval()
epochs = self.calc_n_epochs(
n_obs=len(data_dictionary[f'test_features']),
n_obs=len(data_dictionary['test_features']),
batch_size=self.batch_size,
n_iters=self.eval_iters
)
@@ -101,8 +101,11 @@ class PyTorchModelTrainer:
torch.from_numpy(data_dictionary[f'{split}_features'].values).float(),
torch.from_numpy(data_dictionary[f'{split}_labels'].astype(float).values)
.long()
.view(labels_view) # todo currently assuming class labels are strings ['0.0', '1.0' .. n_classes]. need to resolve it per ClassifierModel
.view(labels_view)
)
# todo currently assuming class labels are strings ['0.0', '1.0' .. n_classes].
# need to resolve it per ClassifierModel
data_loader = DataLoader(
dataset,
batch_size=self.batch_size,