remove train loss calculation from estimate_loss
This commit is contained in:
parent
523a58d3d6
commit
b927c9dc01
@ -66,14 +66,9 @@ class PyTorchModelTrainer:
|
|||||||
n_iters=self.max_iters
|
n_iters=self.max_iters
|
||||||
)
|
)
|
||||||
for epoch in range(epochs):
|
for epoch in range(epochs):
|
||||||
# evaluation
|
|
||||||
losses = self.estimate_loss(data_loaders_dictionary, self.max_n_eval_batches)
|
|
||||||
logger.info(
|
|
||||||
f"epoch ({epoch}/{epochs}):"
|
|
||||||
f" train loss {losses['train']:.4f} ; test loss {losses['test']:.4f}"
|
|
||||||
)
|
|
||||||
# training
|
# training
|
||||||
for batch_data in data_loaders_dictionary["train"]:
|
losses = []
|
||||||
|
for i, batch_data in enumerate(data_loaders_dictionary["train"]):
|
||||||
xb, yb = batch_data
|
xb, yb = batch_data
|
||||||
xb = xb.to(self.device)
|
xb = xb.to(self.device)
|
||||||
yb = yb.to(self.device)
|
yb = yb.to(self.device)
|
||||||
@ -83,35 +78,40 @@ class PyTorchModelTrainer:
|
|||||||
self.optimizer.zero_grad(set_to_none=True)
|
self.optimizer.zero_grad(set_to_none=True)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
|
losses.append(loss.item())
|
||||||
|
train_loss = sum(losses) / len(losses)
|
||||||
|
|
||||||
|
# evaluation
|
||||||
|
test_loss = self.estimate_loss(data_loaders_dictionary, self.max_n_eval_batches, "test")
|
||||||
|
logger.info(
|
||||||
|
f"epoch ({epoch}/{epochs}):"
|
||||||
|
f" train loss {train_loss:.4f} ; test loss {test_loss:.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def estimate_loss(
|
def estimate_loss(
|
||||||
self,
|
self,
|
||||||
data_loader_dictionary: Dict[str, DataLoader],
|
data_loader_dictionary: Dict[str, DataLoader],
|
||||||
max_n_eval_batches: Optional[int]
|
max_n_eval_batches: Optional[int],
|
||||||
) -> Dict[str, float]:
|
split: str,
|
||||||
|
) -> float:
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
loss_dictionary = {}
|
|
||||||
n_batches = 0
|
n_batches = 0
|
||||||
for split in ["train", "test"]:
|
losses = []
|
||||||
losses = []
|
for i, batch in enumerate(data_loader_dictionary[split]):
|
||||||
for i, batch in enumerate(data_loader_dictionary[split]):
|
if max_n_eval_batches and i > max_n_eval_batches:
|
||||||
if max_n_eval_batches and i > max_n_eval_batches:
|
n_batches += 1
|
||||||
n_batches += 1
|
break
|
||||||
break
|
|
||||||
|
|
||||||
xb, yb = batch
|
xb, yb = batch
|
||||||
xb = xb.to(self.device)
|
xb = xb.to(self.device)
|
||||||
yb = yb.to(self.device)
|
yb = yb.to(self.device)
|
||||||
yb_pred = self.model(xb)
|
yb_pred = self.model(xb)
|
||||||
loss = self.criterion(yb_pred, yb)
|
loss = self.criterion(yb_pred, yb)
|
||||||
losses.append(loss.item())
|
losses.append(loss.item())
|
||||||
|
|
||||||
loss_dictionary[split] = sum(losses) / len(losses)
|
|
||||||
|
|
||||||
self.model.train()
|
self.model.train()
|
||||||
return loss_dictionary
|
return sum(losses) / len(losses)
|
||||||
|
|
||||||
def create_data_loaders_dictionary(
|
def create_data_loaders_dictionary(
|
||||||
self,
|
self,
|
||||||
|
Loading…
Reference in New Issue
Block a user