add pytorch data convertor

Yinon Polak
2023-04-03 15:19:10 +03:00
parent 5a7ca35c6b
commit bd3b70293f
9 changed files with 168 additions and 40 deletions


@@ -9,11 +9,13 @@
 import torch.nn as nn
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, TensorDataset
+from freqtrade.freqai.torch.PyTorchDataConvertor import PyTorchDataConvertor
+from freqtrade.freqai.torch.PyTorchTrainerInterface import PyTorchTrainerInterface

 logger = logging.getLogger(__name__)


-class PyTorchModelTrainer:
+class PyTorchModelTrainer(PyTorchTrainerInterface):
     def __init__(
         self,
         model: nn.Module,
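This commit also has the trainer implement PyTorchTrainerInterface. The interface file itself is not part of this excerpt; judging only from the methods this diff touches (the training loop, the self.save() mentioned in the docstring, and the load rename at the bottom), a plausible sketch of what it requires might look like the following, where the exact signatures are assumptions:

    from abc import ABC, abstractmethod
    from pathlib import Path
    from typing import Dict, List

    import pandas as pd


    class PyTorchTrainerInterface(ABC):
        # Assumed abstract surface; the real
        # freqtrade/freqai/torch/PyTorchTrainerInterface.py may differ.

        @abstractmethod
        def fit(self, data_dictionary: Dict[str, pd.DataFrame], splits: List[str]) -> None:
            ...

        @abstractmethod
        def save(self, path: Path) -> None:
            ...

        @abstractmethod
        def load(self, path: Path) -> "PyTorchTrainerInterface":
            ...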
@@ -21,8 +23,7 @@ class PyTorchModelTrainer:
         criterion: nn.Module,
         device: str,
         init_model: Dict,
-        target_tensor_type: torch.dtype,
-        squeeze_target_tensor: bool = False,
+        data_convertor: PyTorchDataConvertor,
         model_meta_data: Dict[str, Any] = {},
         **kwargs
     ):
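The new data_convertor argument replaces the target_tensor_type / squeeze_target_tensor pair: instead of the trainer hard-coding how DataFrames become tensors, a PyTorchDataConvertor object now owns that conversion. The convertor class lives in freqtrade/freqai/torch/PyTorchDataConvertor.py (per the import above) but is not shown in this excerpt; a minimal sketch of the interface the trainer relies on, with the convert_x/convert_y names taken from the create_data_loaders_dictionary hunk below and everything else assumed:

    from abc import ABC, abstractmethod
    from typing import Tuple

    import pandas as pd
    import torch


    class PyTorchDataConvertor(ABC):
        # Assumed interface; only convert_x/convert_y and their
        # tuple-of-tensors return shape are visible in this diff.

        @abstractmethod
        def convert_x(self, df: pd.DataFrame) -> Tuple[torch.Tensor, ...]:
            """Convert a features DataFrame into one or more input tensors."""

        @abstractmethod
        def convert_y(self, df: pd.DataFrame) -> Tuple[torch.Tensor, ...]:
            """Convert a labels DataFrame into one or more target tensors."""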
@@ -33,11 +34,7 @@ class PyTorchModelTrainer:
         :param device: The device to use for training (e.g. 'cpu', 'cuda').
         :param init_model: A dictionary containing the initial model/optimizer
             state_dict and model_meta_data saved by self.save() method.
-        :param target_tensor_type: type of target tensor, for classification usually
-            torch.long, for regressor usually torch.float.
         :param model_meta_data: Additional metadata about the model (optional).
-        :param squeeze_target_tensor: controls the target shape, used for loss functions
-            that requires 0D or 1D.
         :param max_iters: The number of training iterations to run.
             iteration here refers to the number of times we call
             self.optimizer.step(). used to calculate n_epochs.
@@ -49,11 +46,10 @@ class PyTorchModelTrainer:
         self.criterion = criterion
         self.model_meta_data = model_meta_data
         self.device = device
-        self.target_tensor_type = target_tensor_type
         self.max_iters: int = kwargs.get("max_iters", 100)
         self.batch_size: int = kwargs.get("batch_size", 64)
         self.max_n_eval_batches: Optional[int] = kwargs.get("max_n_eval_batches", None)
-        self.squeeze_target_tensor = squeeze_target_tensor
+        self.data_convertor = data_convertor

         if init_model:
             self.load_from_checkpoint(init_model)
@@ -81,9 +77,12 @@ class PyTorchModelTrainer:
             # training
             losses = []
             for i, batch_data in enumerate(data_loaders_dictionary["train"]):
-                xb, yb = batch_data
-                xb = xb.to(self.device)
-                yb = yb.to(self.device)
+
+                for tensor in batch_data:
+                    tensor.to(self.device)
+
+                xb = batch_data[:-1]
+                yb = batch_data[-1]

                 yb_pred = self.model(xb)
                 loss = self.criterion(yb_pred, yb)
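One caveat in the new device-transfer loop: torch.Tensor.to() is not an in-place operation, it returns a new tensor, so "for tensor in batch_data: tensor.to(self.device)" discards the moved copies and the batch stays on whatever device the DataLoader produced it on. A variant that keeps the moved tensors (the same applies to the evaluation hunk below) could look like:

    # .to() returns a new tensor; keep the results instead of dropping them
    batch_data = [t.to(self.device) for t in batch_data]
    xb = batch_data[:-1]  # every tensor but the last is a model input
    yb = batch_data[-1]   # the last tensor is the target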
@@ -115,14 +114,16 @@ class PyTorchModelTrainer:
         self.model.eval()
         n_batches = 0
         losses = []
-        for i, batch in enumerate(data_loader_dictionary[split]):
+        for i, batch_data in enumerate(data_loader_dictionary[split]):
             if max_n_eval_batches and i > max_n_eval_batches:
                 n_batches += 1
                 break

-            xb, yb = batch
-            xb = xb.to(self.device)
-            yb = yb.to(self.device)
+            for tensor in batch_data:
+                tensor.to(self.device)
+
+            xb = batch_data[:-1]
+            yb = batch_data[-1]
             yb_pred = self.model(xb)
             loss = self.criterion(yb_pred, yb)
             losses.append(loss.item())
@@ -140,14 +141,9 @@ class PyTorchModelTrainer:
         """
         data_loader_dictionary = {}
         for split in splits:
-            x = torch.from_numpy(data_dictionary[f"{split}_features"].values).float()
-            y = torch.from_numpy(data_dictionary[f"{split}_labels"].values)\
-                .to(self.target_tensor_type)
-
-            if self.squeeze_target_tensor:
-                y = y.squeeze()
-
-            dataset = TensorDataset(x, y)
+            x = self.data_convertor.convert_x(data_dictionary[f"{split}_features"])
+            y = self.data_convertor.convert_y(data_dictionary[f"{split}_labels"])
+            dataset = TensorDataset(*x, *y)
             data_loader = DataLoader(
                 dataset,
                 batch_size=self.batch_size,
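Since convert_x and convert_y each return a tuple of tensors, TensorDataset(*x, *y) splices both into one dataset whose batches come back as (input tensors..., target tensor). That ordering is what lets the training and evaluation loops above take batch_data[:-1] as inputs and batch_data[-1] as the target. For illustration, a hypothetical single-tensor regression convertor under the interface sketched earlier:

    class RegressorConvertor(PyTorchDataConvertor):
        # Hypothetical example, not part of this commit.

        def convert_x(self, df: pd.DataFrame) -> Tuple[torch.Tensor, ...]:
            return (torch.from_numpy(df.values).float(),)

        def convert_y(self, df: pd.DataFrame) -> Tuple[torch.Tensor, ...]:
            return (torch.from_numpy(df.values).float(),)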
@@ -186,7 +182,7 @@ class PyTorchModelTrainer:
             "model_meta_data": self.model_meta_data,
         }, path)

-    def load_from_file(self, path: Path):
+    def load(self, path: Path):
         checkpoint = torch.load(path)
         return self.load_from_checkpoint(checkpoint)
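Renaming load_from_file to load matches the method name assumed in the PyTorchTrainerInterface sketch near the top. A hypothetical round trip, assuming an already constructed trainer:

    trainer.save(Path("trainer_checkpoint.zip"))  # writes the state shown in the save hunk
    trainer.load(Path("trainer_checkpoint.zip"))  # torch.load, then load_from_checkpoint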