remove train loss calculation from estimate_loss

This commit is contained in:
Yinon Polak 2023-03-13 00:17:34 +02:00
parent 523a58d3d6
commit b927c9dc01

View File

@ -66,14 +66,9 @@ class PyTorchModelTrainer:
n_iters=self.max_iters n_iters=self.max_iters
) )
for epoch in range(epochs): for epoch in range(epochs):
# evaluation
losses = self.estimate_loss(data_loaders_dictionary, self.max_n_eval_batches)
logger.info(
f"epoch ({epoch}/{epochs}):"
f" train loss {losses['train']:.4f} ; test loss {losses['test']:.4f}"
)
# training # training
for batch_data in data_loaders_dictionary["train"]: losses = []
for i, batch_data in enumerate(data_loaders_dictionary["train"]):
xb, yb = batch_data xb, yb = batch_data
xb = xb.to(self.device) xb = xb.to(self.device)
yb = yb.to(self.device) yb = yb.to(self.device)
@ -83,18 +78,25 @@ class PyTorchModelTrainer:
self.optimizer.zero_grad(set_to_none=True) self.optimizer.zero_grad(set_to_none=True)
loss.backward() loss.backward()
self.optimizer.step() self.optimizer.step()
losses.append(loss.item())
train_loss = sum(losses) / len(losses)
# evaluation
test_loss = self.estimate_loss(data_loaders_dictionary, self.max_n_eval_batches, "test")
logger.info(
f"epoch ({epoch}/{epochs}):"
f" train loss {train_loss:.4f} ; test loss {test_loss:.4f}"
)
@torch.no_grad() @torch.no_grad()
def estimate_loss( def estimate_loss(
self, self,
data_loader_dictionary: Dict[str, DataLoader], data_loader_dictionary: Dict[str, DataLoader],
max_n_eval_batches: Optional[int] max_n_eval_batches: Optional[int],
) -> Dict[str, float]: split: str,
) -> float:
self.model.eval() self.model.eval()
loss_dictionary = {}
n_batches = 0 n_batches = 0
for split in ["train", "test"]:
losses = [] losses = []
for i, batch in enumerate(data_loader_dictionary[split]): for i, batch in enumerate(data_loader_dictionary[split]):
if max_n_eval_batches and i > max_n_eval_batches: if max_n_eval_batches and i > max_n_eval_batches:
@ -108,10 +110,8 @@ class PyTorchModelTrainer:
loss = self.criterion(yb_pred, yb) loss = self.criterion(yb_pred, yb)
losses.append(loss.item()) losses.append(loss.item())
loss_dictionary[split] = sum(losses) / len(losses)
self.model.train() self.model.train()
return loss_dictionary return sum(losses) / len(losses)
def create_data_loaders_dictionary( def create_data_loaders_dictionary(
self, self,