2023-03-06 14:16:45 +00:00
|
|
|
import logging
|
2023-03-21 10:29:05 +00:00
|
|
|
import math
|
2023-03-06 14:16:45 +00:00
|
|
|
from pathlib import Path
|
2023-03-28 11:40:23 +00:00
|
|
|
from typing import Any, Dict, List, Optional
|
2023-03-06 14:16:45 +00:00
|
|
|
|
2023-03-08 14:03:36 +00:00
|
|
|
import pandas as pd
|
2023-03-06 14:16:45 +00:00
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
2023-03-08 14:03:36 +00:00
|
|
|
from torch.optim import Optimizer
|
|
|
|
from torch.utils.data import DataLoader, TensorDataset
|
|
|
|
|
2023-03-06 14:16:45 +00:00
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)


class PyTorchModelTrainer:
    """
    Training helper wrapping a PyTorch model, optimizer and loss function.

    - create_data_loaders_dictionary: turns the DataHandler ``data_dictionary``
      (``{split}_features`` / ``{split}_labels`` DataFrames) into DataLoaders.
    - fit: runs the epoch/batch training loop, optionally evaluating on "test".
    - estimate_loss: computes the mean loss over (a prefix of) a split.
    - save / load_from_file / load_from_checkpoint: checkpoint persistence.
    """

    def __init__(
        self,
        model: nn.Module,
        optimizer: Optimizer,
        criterion: nn.Module,
        device: str,
        init_model: Dict,
        target_tensor_type: torch.dtype,
        squeeze_target_tensor: bool = False,
        model_meta_data: Optional[Dict[str, Any]] = None,
        **kwargs
    ):
        """
        :param model: The PyTorch model to be trained.
        :param optimizer: The optimizer to use for training.
        :param criterion: The loss function to use for training.
        :param device: The device to use for training (e.g. 'cpu', 'cuda').
        :param init_model: A dictionary containing the initial model/optimizer
            state_dict and model_meta_data saved by self.save() method.
            When empty (falsy), no checkpoint is loaded.
        :param target_tensor_type: type of target tensor, for classification usually
            torch.long, for regressor usually torch.float.
        :param squeeze_target_tensor: controls the target shape, used for loss functions
            that require 1D targets; a trailing label axis of size 1 is removed,
            e.g. (N, 1) -> (N,).
        :param model_meta_data: Additional metadata about the model (optional).
            Defaults to a new empty dict for each trainer.
        :param max_iters: The number of training iterations to run.
            iteration here refers to the number of times we call
            self.optimizer.step(). used to calculate n_epochs. Default 100.
        :param batch_size: The size of the batches to use during training. Default 64.
        :param max_n_eval_batches: The maximum number of batches to use for evaluation.
            None (the default) or 0 evaluates on the whole split.
        """
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        # A literal `{}` default would be one dict shared by every trainer
        # instance (and mutated through any of them); build a fresh one instead.
        self.model_meta_data: Dict[str, Any] = (
            model_meta_data if model_meta_data is not None else {}
        )
        self.device = device
        self.target_tensor_type = target_tensor_type
        self.max_iters: int = kwargs.get("max_iters", 100)
        self.batch_size: int = kwargs.get("batch_size", 64)
        self.max_n_eval_batches: Optional[int] = kwargs.get("max_n_eval_batches", None)
        self.squeeze_target_tensor = squeeze_target_tensor
        if init_model:
            # Overrides model/optimizer state and model_meta_data from the checkpoint.
            self.load_from_checkpoint(init_model)

    def fit(self, data_dictionary: Dict[str, pd.DataFrame], splits: List[str]):
        """
        :param data_dictionary: the dictionary constructed by DataHandler to hold
            all the training and test data/labels.
        :param splits: splits to use in training, splits must contain "train",
            optional "test" could be added by setting freqai.data_split_parameters.test_size > 0
            in the config file.

        For each epoch (the count comes from calc_n_epochs), every "train" batch:
         - Calculates the predicted output for the batch using the PyTorch model.
         - Calculates the loss between the predicted and actual output using a loss function.
         - Computes the gradients of the loss with respect to the model's parameters using
           backpropagation.
         - Updates the model's parameters using an optimizer.
        After each epoch the mean train loss (and, if "test" is in splits, the
        test loss from estimate_loss) is logged.
        """
        data_loaders_dictionary = self.create_data_loaders_dictionary(data_dictionary, splits)
        epochs = self.calc_n_epochs(
            n_obs=len(data_dictionary["train_features"]),
            batch_size=self.batch_size,
            n_iters=self.max_iters
        )
        for epoch in range(1, epochs + 1):
            # training
            # Make sure dropout / batch-norm run in training mode even if the
            # model was handed over (or left) in eval mode.
            self.model.train()
            losses = []
            for xb, yb in data_loaders_dictionary["train"]:
                xb = xb.to(self.device)
                yb = yb.to(self.device)
                yb_pred = self.model(xb)
                loss = self.criterion(yb_pred, yb)

                self.optimizer.zero_grad(set_to_none=True)
                loss.backward()
                self.optimizer.step()
                losses.append(loss.item())

            train_loss = sum(losses) / len(losses)
            log_message = f"epoch {epoch}/{epochs}: train loss {train_loss:.4f}"

            # evaluation
            if "test" in splits:
                test_loss = self.estimate_loss(
                    data_loaders_dictionary,
                    self.max_n_eval_batches,
                    "test"
                )
                log_message += f" ; test loss {test_loss:.4f}"

            logger.info(log_message)

    @torch.no_grad()
    def estimate_loss(
        self,
        data_loader_dictionary: Dict[str, DataLoader],
        max_n_eval_batches: Optional[int],
        split: str,
    ) -> float:
        """
        Return the mean per-batch loss of the model on ``split``, without gradients.

        :param data_loader_dictionary: mapping of split name to an iterable of
            (features, targets) batches.
        :param max_n_eval_batches: evaluate at most this many batches; None or 0
            means evaluate every batch of the split.
        :param split: key into data_loader_dictionary, e.g. "test".
        :return: arithmetic mean of the per-batch losses.

        The model is switched to eval mode for the duration and is always put
        back into train mode afterwards, even if evaluation raises.
        """
        self.model.eval()
        try:
            losses = []
            for i, (xb, yb) in enumerate(data_loader_dictionary[split]):
                # `>=` so that exactly max_n_eval_batches batches are evaluated.
                if max_n_eval_batches and i >= max_n_eval_batches:
                    break

                xb = xb.to(self.device)
                yb = yb.to(self.device)
                yb_pred = self.model(xb)
                loss = self.criterion(yb_pred, yb)
                losses.append(loss.item())
        finally:
            self.model.train()
        return sum(losses) / len(losses)

    def create_data_loaders_dictionary(
        self,
        data_dictionary: Dict[str, pd.DataFrame],
        splits: List[str]
    ) -> Dict[str, DataLoader]:
        """
        Converts the input data to PyTorch tensors using a data loader.

        For every split, ``{split}_features`` becomes a float tensor and
        ``{split}_labels`` a tensor of ``self.target_tensor_type``; they are
        wrapped in a shuffling DataLoader with ``self.batch_size``.

        :return: mapping of split name to its DataLoader.
        """
        data_loader_dictionary = {}
        for split in splits:
            x = torch.from_numpy(data_dictionary[f"{split}_features"].values).float()
            y = torch.from_numpy(
                data_dictionary[f"{split}_labels"].values
            ).to(self.target_tensor_type)

            if self.squeeze_target_tensor and y.dim() > 1:
                # Remove only the trailing label axis, (N, 1) -> (N,). A bare
                # squeeze() would also drop the sample axis when N == 1 and
                # yield a 0-d tensor that TensorDataset cannot index.
                y = y.squeeze(-1)

            dataset = TensorDataset(x, y)
            data_loader = DataLoader(
                dataset,
                batch_size=self.batch_size,
                shuffle=True,
                # Drop the trailing partial batch only when at least one full
                # batch exists; otherwise a split smaller than batch_size would
                # produce no batches at all and the loss averaging would fail.
                drop_last=len(dataset) >= self.batch_size,
                num_workers=0,
            )
            data_loader_dictionary[split] = data_loader

        return data_loader_dictionary

    @staticmethod
    def calc_n_epochs(n_obs: int, batch_size: int, n_iters: int) -> int:
        """
        Calculates the number of epochs required to reach the maximum number
        of iterations specified in the model training parameters.

        the motivation here is that `max_iters` is easier to optimize and keep stable,
        across different n_obs - the number of data points.

        :param n_obs: number of training samples.
        :param batch_size: samples per batch.
        :param n_iters: desired number of optimizer steps.
        :return: number of epochs, always at least 1.
        """
        # Batches actually produced per epoch by create_data_loaders_dictionary:
        # the trailing partial batch is dropped, but a split smaller than
        # batch_size still yields one batch. Clamping also avoids dividing by 0.
        n_batches = max(n_obs // batch_size, 1)
        # True division + ceil rounds up, so at least n_iters steps are taken
        # and training never degenerates to zero epochs.
        epochs = max(math.ceil(n_iters / n_batches), 1)
        return epochs

    def save(self, path: Path):
        """
        - Saving any nn.Module state_dict
        - Saving model_meta_data, this dict should contain any additional data that the
          user needs to store. e.g class_names for classification models.
        """
        torch.save({
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "model_meta_data": self.model_meta_data,
        }, path)

    def load_from_file(self, path: Path):
        """
        Load a checkpoint written by save() and apply it to this trainer.

        Stored tensors are mapped onto ``self.device`` so that a checkpoint
        saved on a GPU can be restored on a CPU-only machine.
        NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.

        :return: self, to allow chaining.
        """
        checkpoint = torch.load(path, map_location=self.device)
        return self.load_from_checkpoint(checkpoint)

    def load_from_checkpoint(self, checkpoint: Dict):
        """
        when using continual_learning, DataDrawer will load the dictionary
        (containing state dicts and model_meta_data) by calling torch.load(path).
        you can access this dict from any class that inherits IFreqaiModel by calling
        get_init_model method.

        :param checkpoint: dict with "model_state_dict", "optimizer_state_dict"
            and "model_meta_data" keys, as written by save().
        :return: self, to allow chaining.
        """
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.model_meta_data = checkpoint["model_meta_data"]
        return self
|