use one iteration on all test and train data for evaluation

parent 8a9f2aedbb
commit 1cf0e7be24
@@ -1,6 +1,6 @@
 import logging
 from pathlib import Path
-from typing import Any, Dict
+from typing import Any, Dict, Optional

 import pandas as pd
 import torch
@@ -21,7 +21,7 @@ class PyTorchModelTrainer:
             device: str,
             batch_size: int,
             max_iters: int,
-            eval_iters: int,
+            max_n_eval_batches: int,
             init_model: Dict,
             model_meta_data: Dict[str, Any] = {},
     ):
@@ -34,7 +34,7 @@ class PyTorchModelTrainer:
         :param max_iters: The number of training iterations to run.
             iteration here refers to the number of times we call
             self.optimizer.step(). used to calculate n_epochs.
-        :param eval_iters: The number of iterations used to estimate the loss.
+        :param max_n_eval_batches: The maximum number of batches to use for evaluation.
         :param init_model: A dictionary containing the initial model/optimizer
             state_dict and model_meta_data saved by self.save() method.
         :param model_meta_data: Additional metadata about the model (optional).
@@ -46,7 +46,7 @@ class PyTorchModelTrainer:
         self.device = device
         self.max_iters = max_iters
         self.batch_size = batch_size
-        self.eval_iters = eval_iters
+        self.max_n_eval_batches = max_n_eval_batches

         if init_model:
             self.load_from_checkpoint(init_model)
@@ -67,7 +67,7 @@ class PyTorchModelTrainer:
         )
         for epoch in range(epochs):
             # evaluation
-            losses = self.estimate_loss(data_loaders_dictionary, data_dictionary)
+            losses = self.estimate_loss(data_loaders_dictionary, self.max_n_eval_batches)
             logger.info(
                 f"epoch ({epoch}/{epochs}):"
                 f" train loss {losses['train']:.4f} ; test loss {losses['test']:.4f}"
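For context on the epochs value consumed above: the docstring notes that max_iters is "used to calculate n_epochs", and the call removed from estimate_loss in the next hunk shows calc_n_epochs taking n_obs, batch_size, and n_iters. A minimal sketch of that conversion, assuming a simple floor-division rule; only the argument names come from this diff, the formula itself is an assumption:

    import math

    def calc_n_epochs(n_obs: int, batch_size: int, n_iters: int) -> int:
        # One "iteration" is one optimizer.step(), i.e. one batch, so an
        # iteration budget divided by batches-per-epoch yields epochs.
        # Hypothetical implementation; the real method may differ.
        batches_per_epoch = math.ceil(n_obs / batch_size)
        return max(n_iters // batches_per_epoch, 1)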
@@ -88,27 +88,27 @@ class PyTorchModelTrainer:
     def estimate_loss(
             self,
             data_loader_dictionary: Dict[str, DataLoader],
-            data_dictionary: Dict[str, pd.DataFrame]
+            max_n_eval_batches: Optional[int]
     ) -> Dict[str, float]:

         self.model.eval()
-        epochs = self.calc_n_epochs(
-            n_obs=len(data_dictionary["test_features"]),
-            batch_size=self.batch_size,
-            n_iters=self.eval_iters
-        )
         loss_dictionary = {}
+        n_batches = 0
         for split in ["train", "test"]:
-            losses = torch.zeros(epochs)
+            losses = []
             for i, batch in enumerate(data_loader_dictionary[split]):
+                if max_n_eval_batches and i > max_n_eval_batches:
+                    n_batches += 1
+                    break
+
                 xb, yb = batch
                 xb = xb.to(self.device)
                 yb = yb.to(self.device)
                 yb_pred = self.model(xb)
                 loss = self.criterion(yb_pred, yb)
-                losses[i] = loss.item()
+                losses.append(loss.item())

-            loss_dictionary[split] = losses.mean().item()
+            loss_dictionary[split] = sum(losses) / len(losses)

         self.model.train()
         return loss_dictionary
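The rewritten estimate_loss walks each DataLoader at most once and stops early when a cap is set; with no cap (None) it covers all test and train batches, per the commit title. Note the diff's check uses i > max_n_eval_batches, so it consumes one batch beyond the cap before breaking. Below is a self-contained sketch of the same pattern, using >= for an exact cap; estimate_mean_loss and its parameters are illustrative stand-ins, not the trainer's real attributes:

    from typing import Dict, Optional

    import torch
    from torch.utils.data import DataLoader

    @torch.no_grad()
    def estimate_mean_loss(
        model: torch.nn.Module,
        criterion: torch.nn.Module,
        loaders: Dict[str, DataLoader],
        device: str,
        max_n_eval_batches: Optional[int] = None,
    ) -> Dict[str, float]:
        model.eval()
        loss_per_split = {}
        for split, loader in loaders.items():
            losses = []
            for i, (xb, yb) in enumerate(loader):
                # None means no cap: iterate the whole loader once.
                if max_n_eval_batches is not None and i >= max_n_eval_batches:
                    break
                pred = model(xb.to(device))
                losses.append(criterion(pred, yb.to(device)).item())
            loss_per_split[split] = sum(losses) / len(losses)
        model.train()
        return loss_per_split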
@@ -39,7 +39,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
         self.max_iters = model_training_parameters.get("max_iters", 100)
         self.batch_size = model_training_parameters.get("batch_size", 64)
         self.learning_rate = model_training_parameters.get("learning_rate", 3e-4)
-        self.eval_iters = model_training_parameters.get("eval_iters", 10)
+        self.max_n_eval_batches = model_training_parameters.get("max_n_eval_batches", None)
         self.class_name_to_index = None
         self.index_to_class_name = None

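Because the lookup falls back to None, the cap is opt-in: leaving the key out evaluates on every batch, which is the point of this commit. A hypothetical model_training_parameters dict built from the defaults read above:

    model_training_parameters = {
        "max_iters": 100,
        "batch_size": 64,
        "learning_rate": 3e-4,
        # Omit this key (or set None) to evaluate on all train/test
        # batches; set an int to bound evaluation cost per epoch.
        "max_n_eval_batches": None,
    }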
@@ -79,7 +79,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
             device=self.device,
             batch_size=self.batch_size,
             max_iters=self.max_iters,
-            eval_iters=self.eval_iters,
+            max_n_eval_batches=self.max_n_eval_batches,
             init_model=init_model
         )
         trainer.fit(data_dictionary)