remove train loss calculation from estimate_loss

This commit is contained in:
Yinon Polak 2023-03-13 00:17:34 +02:00
parent 523a58d3d6
commit b927c9dc01

View File

@@ -66,14 +66,9 @@ class PyTorchModelTrainer:
n_iters=self.max_iters
)
for epoch in range(epochs):
# evaluation
losses = self.estimate_loss(data_loaders_dictionary, self.max_n_eval_batches)
logger.info(
f"epoch ({epoch}/{epochs}):"
f" train loss {losses['train']:.4f} ; test loss {losses['test']:.4f}"
)
# training
for batch_data in data_loaders_dictionary["train"]:
losses = []
for i, batch_data in enumerate(data_loaders_dictionary["train"]):
xb, yb = batch_data
xb = xb.to(self.device)
yb = yb.to(self.device)
@@ -83,35 +78,40 @@ class PyTorchModelTrainer:
self.optimizer.zero_grad(set_to_none=True)
loss.backward()
self.optimizer.step()
losses.append(loss.item())
train_loss = sum(losses) / len(losses)
# evaluation
test_loss = self.estimate_loss(data_loaders_dictionary, self.max_n_eval_batches, "test")
logger.info(
f"epoch ({epoch}/{epochs}):"
f" train loss {train_loss:.4f} ; test loss {test_loss:.4f}"
)
@torch.no_grad()
def estimate_loss(
    self,
    data_loader_dictionary: Dict[str, DataLoader],
    max_n_eval_batches: Optional[int],
    split: str,
) -> float:
    """Return the mean criterion loss of the model over one data split.

    The model is put in eval mode while batches are scored and is switched
    back to train mode before returning, even if scoring a batch raises.
    Gradients are disabled for the whole call by ``torch.no_grad``.

    :param data_loader_dictionary: Mapping from split name to an iterable
        yielding ``(xb, yb)`` batches.
    :param max_n_eval_batches: Upper bound on the number of batches that
        are scored. A falsy value (``None`` or ``0``) means no limit.
    :param split: Key of the split in ``data_loader_dictionary`` to score.
    :returns: Arithmetic mean of the per-batch ``loss.item()`` values.
    :raises KeyError: If ``split`` is not in ``data_loader_dictionary``.
    :raises ZeroDivisionError: If the selected split yields no batches.
    """
    self.model.eval()
    losses = []
    try:
        for i, batch in enumerate(data_loader_dictionary[split]):
            # ``i`` is zero-based, so ``>=`` stops after exactly
            # ``max_n_eval_batches`` batches have been scored.
            if max_n_eval_batches and i >= max_n_eval_batches:
                break
            xb, yb = batch
            xb = xb.to(self.device)
            yb = yb.to(self.device)
            yb_pred = self.model(xb)
            loss = self.criterion(yb_pred, yb)
            losses.append(loss.item())
    finally:
        # Always hand the model back in train mode so the caller's training
        # loop is not left silently running in eval mode.
        self.model.train()
    return sum(losses) / len(losses)
def create_data_loaders_dictionary(
self,