stable/freqtrade/freqai/base_models/PyTorchModelTrainer.py

164 lines
5.6 KiB
Python

import logging
from pathlib import Path
from typing import Any, Dict
import pandas as pd
import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader, TensorDataset
logger = logging.getLogger(__name__)
class PyTorchModelTrainer:
def __init__(
self,
model: nn.Module,
optimizer: Optimizer,
criterion: nn.Module,
device: str,
batch_size: int,
max_iters: int,
eval_iters: int,
init_model: Dict,
model_meta_data: Dict[str, Any] = {},
):
"""
A class for training PyTorch models.
:param model: The PyTorch model to be trained.
:param optimizer: The optimizer to use for training.
:param criterion: The loss function to use for training.
:param device: The device to use for training (e.g. 'cpu', 'cuda').
:param batch_size: The size of the batches to use during training.
:param max_iters: The number of training iterations to run.
iteration here refers to the number of times we call
self.optimizer.step() . used to calculate n_epochs.
:param eval_iters: The number of iterations used to estimate the loss.
:param init_model: A dictionary containing the initial model parameters.
:param model_meta_data: Additional metadata about the model (optional).
"""
self.model = model
self.optimizer = optimizer
self.criterion = criterion
self.model_meta_data = model_meta_data
self.device = device
self.max_iters = max_iters
self.batch_size = batch_size
self.eval_iters = eval_iters
if init_model:
self.load_from_checkpoint(init_model)
def fit(self, data_dictionary: Dict[str, pd.DataFrame]):
"""
general training loop:
- converting data to tensors
- calculating n_epochs
-
"""
data_loaders_dictionary = self.create_data_loaders_dictionary(data_dictionary)
epochs = self.calc_n_epochs(
n_obs=len(data_dictionary['train_features']),
batch_size=self.batch_size,
n_iters=self.max_iters
)
for epoch in range(epochs):
# evaluation
losses = self.estimate_loss(data_loaders_dictionary, data_dictionary)
logger.info(
f"epoch ({epoch}/{epochs}):"
f" train loss {losses['train']:.4f} ; test loss {losses['test']:.4f}"
)
# training
for batch_data in data_loaders_dictionary['train']:
xb, yb = batch_data
xb = xb.to(self.device)
yb = yb.to(self.device)
yb_pred = self.model(xb)
loss = self.criterion(yb_pred, yb)
self.optimizer.zero_grad(set_to_none=True)
loss.backward()
self.optimizer.step()
@torch.no_grad()
def estimate_loss(
self,
data_loader_dictionary: Dict[str, DataLoader],
data_dictionary: Dict[str, pd.DataFrame]
) -> Dict[str, float]:
self.model.eval()
epochs = self.calc_n_epochs(
n_obs=len(data_dictionary['test_features']),
batch_size=self.batch_size,
n_iters=self.eval_iters
)
loss_dictionary = {}
for split in ['train', 'test']:
losses = torch.zeros(epochs)
for i, batch in enumerate(data_loader_dictionary[split]):
xb, yb = batch
xb = xb.to(self.device)
yb = yb.to(self.device)
yb_pred = self.model(xb)
loss = self.criterion(yb_pred, yb)
losses[i] = loss.item()
loss_dictionary[split] = losses.mean().item()
self.model.train()
return loss_dictionary
def create_data_loaders_dictionary(
self,
data_dictionary: Dict[str, pd.DataFrame]
) -> Dict[str, DataLoader]:
data_loader_dictionary = {}
for split in ['train', 'test']:
labels_shape = data_dictionary[f'{split}_labels'].shape
labels_view = labels_shape[0] if labels_shape[1] == 1 else labels_shape
dataset = TensorDataset(
torch.from_numpy(data_dictionary[f'{split}_features'].values).float(),
torch.from_numpy(data_dictionary[f'{split}_labels'].astype(float).values)
.long()
.view(labels_view)
)
data_loader = DataLoader(
dataset,
batch_size=self.batch_size,
shuffle=True,
drop_last=True,
num_workers=0,
)
data_loader_dictionary[split] = data_loader
return data_loader_dictionary
@staticmethod
def calc_n_epochs(n_obs: int, batch_size: int, n_iters: int) -> int:
n_batches = n_obs // batch_size
epochs = n_iters // n_batches
return epochs
def save(self, path: Path):
torch.save({
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'model_meta_data': self.model_meta_data,
}, path)
def load_from_file(self, path: Path):
checkpoint = torch.load(path)
return self.load_from_checkpoint(checkpoint)
def load_from_checkpoint(self, checkpoint: Dict):
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.model_meta_data = checkpoint["model_meta_data"]
return self