164 lines
5.6 KiB
Python
164 lines
5.6 KiB
Python
import logging
from pathlib import Path
from typing import Any, Dict, Optional

import pandas as pd
import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader, TensorDataset
|
|
|
|
|
|
# Module-level logger, named after this module (``__name__``) so its level and
# handlers can be configured through the standard logging hierarchy.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class PyTorchModelTrainer:
    """
    A class for training PyTorch models.

    Wraps a model, an optimizer and a loss function, and runs a mini-batch
    training loop over pandas DataFrames. The estimated train/test loss is
    logged at the start of every epoch. Checkpoints can be saved to and
    restored from disk.
    """

    def __init__(
        self,
        model: nn.Module,
        optimizer: Optimizer,
        criterion: nn.Module,
        device: str,
        batch_size: int,
        max_iters: int,
        eval_iters: int,
        init_model: Dict,
        model_meta_data: Optional[Dict[str, Any]] = None,
    ):
        """
        :param model: The PyTorch model to be trained. This class only moves
            the data batches to ``device``; it never calls ``model.to``.
        :param optimizer: The optimizer to use for training.
        :param criterion: The loss function to use for training.
        :param device: The device to use for training (e.g. 'cpu', 'cuda').
        :param batch_size: The size of the batches to use during training and
            evaluation.
        :param max_iters: The number of training iterations to run.
            Here an iteration is one call to ``self.optimizer.step()``. The
            value is used to derive the number of epochs
            (see :meth:`calc_n_epochs`).
        :param eval_iters: The maximum number of batches per split used to
            estimate the loss (see :meth:`estimate_loss`).
        :param init_model: A checkpoint dictionary, as written by :meth:`save`,
            with keys ``model_state_dict``, ``optimizer_state_dict`` and
            ``model_meta_data``. If it is falsy (e.g. ``{}``), no state is
            restored.
        :param model_meta_data: Additional metadata about the model (optional).
            Each instance gets its own fresh empty dict when omitted.
        """
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        # Avoid a shared mutable default: every trainer gets its own dict.
        self.model_meta_data = model_meta_data if model_meta_data is not None else {}
        self.device = device
        self.max_iters = max_iters
        self.batch_size = batch_size
        self.eval_iters = eval_iters

        if init_model:
            # Overwrites model/optimizer state and self.model_meta_data.
            self.load_from_checkpoint(init_model)

    def fit(self, data_dictionary: Dict[str, pd.DataFrame]):
        """
        General training loop:

        - convert the train/test DataFrames into shuffled DataLoaders,
        - derive the number of epochs from ``max_iters``
          (see :meth:`calc_n_epochs`),
        - at the start of every epoch, log the estimated train/test loss,
        - then perform one optimizer step per training batch.

        :param data_dictionary: Must contain the keys ``train_features``,
            ``train_labels``, ``test_features`` and ``test_labels``.
        :raises ValueError: if there are fewer training rows than
            ``batch_size`` (no full batch can be formed).
        """
        data_loaders_dictionary = self.create_data_loaders_dictionary(data_dictionary)
        epochs = self.calc_n_epochs(
            n_obs=len(data_dictionary["train_features"]),
            batch_size=self.batch_size,
            n_iters=self.max_iters,
        )
        for epoch in range(epochs):
            # evaluation (estimate_loss leaves the model in train mode afterwards)
            losses = self.estimate_loss(data_loaders_dictionary, data_dictionary)
            logger.info(
                f"epoch ({epoch}/{epochs}):"
                f" train loss {losses['train']:.4f} ; test loss {losses['test']:.4f}"
            )
            # training
            for xb, yb in data_loaders_dictionary["train"]:
                xb = xb.to(self.device)
                yb = yb.to(self.device)
                yb_pred = self.model(xb)
                loss = self.criterion(yb_pred, yb)

                self.optimizer.zero_grad(set_to_none=True)
                loss.backward()
                self.optimizer.step()

    @torch.no_grad()
    def estimate_loss(
        self,
        data_loader_dictionary: Dict[str, DataLoader],
        data_dictionary: Dict[str, pd.DataFrame],
    ) -> Dict[str, float]:
        """
        Estimate the mean loss on the train and test splits without tracking
        gradients.

        At most ``self.eval_iters`` batches are drawn from each split's loader.
        The result is the plain average of the per-batch losses that were
        actually computed, so there is no zero padding and no out-of-range
        indexing. The model is switched to eval mode during the estimate and
        is always put back into train mode afterwards, even if an exception
        is raised.

        :param data_loader_dictionary: Loaders keyed by ``'train'`` and ``'test'``.
        :param data_dictionary: Not used. It is accepted for backward
            compatibility with existing callers.
        :return: ``{'train': mean_loss, 'test': mean_loss}``. A split that
            yields no batches reports ``nan``.
        """
        self.model.eval()
        loss_dictionary: Dict[str, float] = {}
        try:
            for split in ["train", "test"]:
                batch_losses = []
                for batch_idx, (xb, yb) in enumerate(data_loader_dictionary[split]):
                    if batch_idx >= self.eval_iters:
                        break
                    xb = xb.to(self.device)
                    yb = yb.to(self.device)
                    yb_pred = self.model(xb)
                    batch_losses.append(self.criterion(yb_pred, yb).item())

                loss_dictionary[split] = (
                    sum(batch_losses) / len(batch_losses) if batch_losses else float("nan")
                )
        finally:
            self.model.train()

        return loss_dictionary

    def create_data_loaders_dictionary(
        self,
        data_dictionary: Dict[str, pd.DataFrame],
    ) -> Dict[str, DataLoader]:
        """
        Build a shuffled DataLoader for the ``'train'`` and ``'test'`` splits.

        Features are converted to float32 tensors. Labels are converted to
        float and then to int64 (``.long()`` truncates any fractional part).
        A single label column, or a 1-D Series, is flattened to shape
        ``(n_rows,)``. Multi-column labels keep their 2-D shape.
        Incomplete final batches are dropped (``drop_last=True``).

        :param data_dictionary: Must contain ``{split}_features`` and
            ``{split}_labels`` for ``split`` in ``('train', 'test')``.
        :return: ``{'train': DataLoader, 'test': DataLoader}``.
        """
        data_loader_dictionary = {}
        for split in ["train", "test"]:
            features = data_dictionary[f"{split}_features"]
            labels = data_dictionary[f"{split}_labels"]
            labels_shape = labels.shape
            # Flatten a single label column; a 1-D Series has no shape[1].
            if len(labels_shape) == 1 or labels_shape[1] == 1:
                labels_view = labels_shape[0]
            else:
                labels_view = labels_shape
            dataset = TensorDataset(
                torch.from_numpy(features.values).float(),
                torch.from_numpy(labels.astype(float).values)
                .long()
                .view(labels_view),
            )

            data_loader = DataLoader(
                dataset,
                batch_size=self.batch_size,
                shuffle=True,
                drop_last=True,
                num_workers=0,
            )
            data_loader_dictionary[split] = data_loader

        return data_loader_dictionary

    @staticmethod
    def calc_n_epochs(n_obs: int, batch_size: int, n_iters: int) -> int:
        """
        Number of passes over the data needed for roughly ``n_iters`` optimizer
        steps.

        One epoch yields ``n_obs // batch_size`` batches, which matches the
        ``drop_last=True`` loaders built by this class. The epoch count is
        rounded down, so slightly fewer than ``n_iters`` steps may run. It is
        never 0 when ``n_iters > 0``, otherwise training would silently be
        skipped.

        :raises ValueError: if ``batch_size`` is not positive, or if
            ``n_obs < batch_size`` (not a single full batch can be formed).
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        n_batches = n_obs // batch_size
        if n_batches == 0:
            raise ValueError(
                f"Not enough observations ({n_obs}) to form a single batch "
                f"of size {batch_size}"
            )
        epochs = n_iters // n_batches
        if n_iters > 0:
            # Floor division gives 0 when n_iters < n_batches; run at least one epoch.
            epochs = max(epochs, 1)
        return epochs

    def save(self, path: Path):
        """
        Write model state, optimizer state and metadata to ``path``
        with ``torch.save``.
        """
        torch.save({
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "model_meta_data": self.model_meta_data,
        }, path)

    def load_from_file(self, path: Path):
        """
        Load a checkpoint written by :meth:`save` and restore it into this
        trainer.

        Tensors are mapped onto ``self.device``, so a checkpoint saved on a GPU
        can be loaded on a CPU-only host.

        NOTE(security): ``torch.load`` may unpickle arbitrary objects,
        depending on the installed torch version's ``weights_only`` default.
        Only load trusted checkpoint files.

        :return: ``self`` (see :meth:`load_from_checkpoint`).
        """
        checkpoint = torch.load(path, map_location=self.device)
        return self.load_from_checkpoint(checkpoint)

    def load_from_checkpoint(self, checkpoint: Dict):
        """
        Restore model state, optimizer state and metadata from a checkpoint
        dict.

        :param checkpoint: Dict with keys ``model_state_dict``,
            ``optimizer_state_dict`` and ``model_meta_data``. A missing key
            raises ``KeyError``.
        :return: ``self``, to allow chaining.
        """
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.model_meta_data = checkpoint["model_meta_data"]
        return self
|