use one iteration on all test and train data for evaluation

Yinon Polak 2023-03-12 12:48:15 +02:00
parent 8a9f2aedbb
commit 1cf0e7be24
2 changed files with 16 additions and 16 deletions


@@ -1,6 +1,6 @@
 import logging
 from pathlib import Path
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 import pandas as pd
 import torch
@@ -21,7 +21,7 @@ class PyTorchModelTrainer:
             device: str,
             batch_size: int,
             max_iters: int,
-            eval_iters: int,
+            max_n_eval_batches: int,
             init_model: Dict,
             model_meta_data: Dict[str, Any] = {},
     ):
@@ -34,7 +34,7 @@ class PyTorchModelTrainer:
         :param max_iters: The number of training iterations to run.
          iteration here refers to the number of times we call
          self.optimizer.step(). used to calculate n_epochs.
-        :param eval_iters: The number of iterations used to estimate the loss.
+        :param max_n_eval_batches: The maximum number of batches to use for evaluation.
         :param init_model: A dictionary containing the initial model/optimizer
          state_dict and model_meta_data saved by self.save() method.
         :param model_meta_data: Additional metadata about the model (optional).
@@ -46,7 +46,7 @@ class PyTorchModelTrainer:
         self.device = device
         self.max_iters = max_iters
         self.batch_size = batch_size
-        self.eval_iters = eval_iters
+        self.max_n_eval_batches = max_n_eval_batches
         if init_model:
             self.load_from_checkpoint(init_model)
@@ -67,7 +67,7 @@ class PyTorchModelTrainer:
         )
         for epoch in range(epochs):
             # evaluation
-            losses = self.estimate_loss(data_loaders_dictionary, data_dictionary)
+            losses = self.estimate_loss(data_loaders_dictionary, self.max_n_eval_batches)
             logger.info(
                 f"epoch ({epoch}/{epochs}):"
                 f" train loss {losses['train']:.4f} ; test loss {losses['test']:.4f}"
@@ -88,27 +88,27 @@ class PyTorchModelTrainer:
     def estimate_loss(
             self,
             data_loader_dictionary: Dict[str, DataLoader],
-            data_dictionary: Dict[str, pd.DataFrame]
+            max_n_eval_batches: Optional[int]
     ) -> Dict[str, float]:
         self.model.eval()
-        epochs = self.calc_n_epochs(
-            n_obs=len(data_dictionary["test_features"]),
-            batch_size=self.batch_size,
-            n_iters=self.eval_iters
-        )
         loss_dictionary = {}
+        n_batches = 0
         for split in ["train", "test"]:
-            losses = torch.zeros(epochs)
+            losses = []
             for i, batch in enumerate(data_loader_dictionary[split]):
+                if max_n_eval_batches and i > max_n_eval_batches:
+                    n_batches += 1
+                    break
+
                 xb, yb = batch
                 xb = xb.to(self.device)
                 yb = yb.to(self.device)
                 yb_pred = self.model(xb)
                 loss = self.criterion(yb_pred, yb)
-                losses[i] = loss.item()
-            loss_dictionary[split] = losses.mean().item()
+                losses.append(loss.item())
+            loss_dictionary[split] = sum(losses) / len(losses)
         self.model.train()
         return loss_dictionary
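Two details of the new loop are worth noting. With max_n_eval_batches left at None, the truthiness check never fires and both loaders are consumed in full, which is the behaviour the commit title describes: one pass over all test and train data. When a cap is set, however, the guard i > max_n_eval_batches only breaks after one extra batch has been processed, and the n_batches counter is incremented but never read. A minimal standalone sketch of the intended capping, assuming the surrounding model, criterion, and device names and using i >= max_n_eval_batches so exactly the capped number of batches is evaluated:

from typing import Dict, List, Optional

import torch
from torch.utils.data import DataLoader

def estimate_loss_sketch(
        model: torch.nn.Module,
        criterion: torch.nn.Module,
        data_loader_dictionary: Dict[str, DataLoader],
        max_n_eval_batches: Optional[int],
        device: str = "cpu",
) -> Dict[str, float]:
    # Sketch, not the freqtrade API: mean loss per split, optionally capped.
    model.eval()
    loss_dictionary: Dict[str, float] = {}
    with torch.no_grad():  # evaluation does not need gradients
        for split in ["train", "test"]:
            losses: List[float] = []
            for i, batch in enumerate(data_loader_dictionary[split]):
                if max_n_eval_batches is not None and i >= max_n_eval_batches:
                    break  # exactly max_n_eval_batches batches evaluated
                xb, yb = batch
                yb_pred = model(xb.to(device))
                losses.append(criterion(yb_pred, yb.to(device)).item())
            loss_dictionary[split] = sum(losses) / len(losses)
    model.train()
    return loss_dictionary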


@@ -39,7 +39,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
         self.max_iters = model_training_parameters.get("max_iters", 100)
         self.batch_size = model_training_parameters.get("batch_size", 64)
         self.learning_rate = model_training_parameters.get("learning_rate", 3e-4)
-        self.eval_iters = model_training_parameters.get("eval_iters", 10)
+        self.max_n_eval_batches = model_training_parameters.get("max_n_eval_batches", None)
         self.class_name_to_index = None
         self.index_to_class_name = None
@@ -79,7 +79,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
             device=self.device,
             batch_size=self.batch_size,
             max_iters=self.max_iters,
-            eval_iters=self.eval_iters,
+            max_n_eval_batches=self.max_n_eval_batches,
             init_model=init_model
         )
         trainer.fit(data_dictionary)
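Because the default is now None rather than the old eval_iters default of 10, evaluation covers the full data unless a cap is configured explicitly. A hypothetical model_training_parameters fragment (the three existing keys and their defaults are taken from the diff above; the value 10 is only an example):

model_training_parameters = {
    "max_iters": 100,
    "batch_size": 64,
    "learning_rate": 3e-4,
    "max_n_eval_batches": 10,  # omit to evaluate on every batch (default None)
}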