type hints fixes
This commit is contained in:
parent
8acdd0b47c
commit
5dd60eda36
@ -84,7 +84,7 @@ class PyTorchModelTrainer:
|
||||
loss = self.criterion(yb_pred, yb)
|
||||
losses[i] = loss.item()
|
||||
|
||||
loss_dictionary[split] = losses.mean()
|
||||
loss_dictionary[split] = losses.mean().item()
|
||||
|
||||
self.model.train()
|
||||
return loss_dictionary
|
||||
|
Loading…
Reference in New Issue
Block a user