use one iteration over all test and train data for evaluation

Yinon Polak 2023-03-12 12:48:15 +02:00
parent 8a9f2aedbb
commit 1cf0e7be24
2 changed files with 16 additions and 16 deletions

@@ -1,6 +1,6 @@
 import logging
 from pathlib import Path
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 import pandas as pd
 import torch
@@ -21,7 +21,7 @@ class PyTorchModelTrainer:
             device: str,
             batch_size: int,
             max_iters: int,
-            eval_iters: int,
+            max_n_eval_batches: int,
             init_model: Dict,
             model_meta_data: Dict[str, Any] = {},
     ):
@@ -34,7 +34,7 @@ class PyTorchModelTrainer:
         :param max_iters: The number of training iterations to run.
             iteration here refers to the number of times we call
             self.optimizer.step(). used to calculate n_epochs.
-        :param eval_iters: The number of iterations used to estimate the loss.
+        :param max_n_eval_batches: The maximum number of batches to use for evaluation.
         :param init_model: A dictionary containing the initial model/optimizer
             state_dict and model_meta_data saved by self.save() method.
         :param model_meta_data: Additional metadata about the model (optional).
@@ -46,7 +46,7 @@ class PyTorchModelTrainer:
         self.device = device
         self.max_iters = max_iters
         self.batch_size = batch_size
-        self.eval_iters = eval_iters
+        self.max_n_eval_batches = max_n_eval_batches
         if init_model:
             self.load_from_checkpoint(init_model)
@@ -67,7 +67,7 @@ class PyTorchModelTrainer:
         )
         for epoch in range(epochs):
             # evaluation
-            losses = self.estimate_loss(data_loaders_dictionary, data_dictionary)
+            losses = self.estimate_loss(data_loaders_dictionary, self.max_n_eval_batches)
             logger.info(
                 f"epoch ({epoch}/{epochs}):"
                 f" train loss {losses['train']:.4f} ; test loss {losses['test']:.4f}"
@@ -88,27 +88,27 @@ class PyTorchModelTrainer:
     def estimate_loss(
             self,
             data_loader_dictionary: Dict[str, DataLoader],
-            data_dictionary: Dict[str, pd.DataFrame]
+            max_n_eval_batches: Optional[int]
     ) -> Dict[str, float]:
         self.model.eval()
-        epochs = self.calc_n_epochs(
-            n_obs=len(data_dictionary["test_features"]),
-            batch_size=self.batch_size,
-            n_iters=self.eval_iters
-        )
         loss_dictionary = {}
+        n_batches = 0
         for split in ["train", "test"]:
-            losses = torch.zeros(epochs)
+            losses = []
             for i, batch in enumerate(data_loader_dictionary[split]):
+                if max_n_eval_batches and i > max_n_eval_batches:
+                    n_batches += 1
+                    break
                 xb, yb = batch
                 xb = xb.to(self.device)
                 yb = yb.to(self.device)
                 yb_pred = self.model(xb)
                 loss = self.criterion(yb_pred, yb)
-                losses[i] = loss.item()
+                losses.append(loss.item())
-            loss_dictionary[split] = losses.mean().item()
+            loss_dictionary[split] = sum(losses) / len(losses)
         self.model.train()
         return loss_dictionary

@@ -39,7 +39,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
         self.max_iters = model_training_parameters.get("max_iters", 100)
         self.batch_size = model_training_parameters.get("batch_size", 64)
         self.learning_rate = model_training_parameters.get("learning_rate", 3e-4)
-        self.eval_iters = model_training_parameters.get("eval_iters", 10)
+        self.max_n_eval_batches = model_training_parameters.get("max_n_eval_batches", None)
         self.class_name_to_index = None
         self.index_to_class_name = None
@@ -79,7 +79,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
             device=self.device,
             batch_size=self.batch_size,
             max_iters=self.max_iters,
-            eval_iters=self.eval_iters,
+            max_n_eval_batches=self.max_n_eval_batches,
             init_model=init_model
         )
         trainer.fit(data_dictionary)
trainer.fit(data_dictionary)
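
A hedged example of feeding the new knob through model_training_parameters; the dictionary below is illustrative and only mirrors the .get() defaults read above, not a documented freqtrade configuration:

# Illustrative model_training_parameters; keys mirror the .get() calls above.
model_training_parameters = {
    "max_iters": 100,
    "batch_size": 64,
    "learning_rate": 3e-4,
    # Omit this key (or set it to None) to evaluate on all batches each epoch;
    # set an integer to cap per-epoch evaluation cost on large datasets.
    "max_n_eval_batches": 50,
}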