add pytorch regressor example

This commit is contained in:
Yinon Polak
2023-03-20 17:06:33 +02:00
parent 601c37f862
commit 54db239175
5 changed files with 137 additions and 14 deletions

View File

@@ -1,6 +1,6 @@
import logging
from pathlib import Path
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Type
import pandas as pd
import torch
@@ -20,6 +20,7 @@ class PyTorchModelTrainer:
criterion: nn.Module,
device: str,
init_model: Dict,
target_tensor_type: torch.dtype,
model_meta_data: Dict[str, Any] = {},
**kwargs
):
@@ -30,6 +31,8 @@ class PyTorchModelTrainer:
:param device: The device to use for training (e.g. 'cpu', 'cuda').
:param init_model: A dictionary containing the initial model/optimizer
state_dict and model_meta_data saved by self.save() method.
:param target_tensor_type: type of target tensor, for classification usually
torch.long, for regressor usually torch.float.
:param model_meta_data: Additional metadata about the model (optional).
:param max_iters: The number of training iterations to run.
iteration here refers to the number of times we call
@@ -42,6 +45,7 @@ class PyTorchModelTrainer:
self.criterion = criterion
self.model_meta_data = model_meta_data
self.device = device
self.target_tensor_type = target_tensor_type
self.max_iters: int = kwargs.get("max_iters", 100)
self.batch_size: int = kwargs.get("batch_size", 64)
self.max_n_eval_batches: Optional[int] = kwargs.get("max_n_eval_batches", None)
@@ -123,8 +127,8 @@ class PyTorchModelTrainer:
labels_view = labels_shape[0] if labels_shape[1] == 1 else labels_shape
dataset = TensorDataset(
torch.from_numpy(data_dictionary[f"{split}_features"].values).float(),
torch.from_numpy(data_dictionary[f"{split}_labels"].astype(float).values)
.long()
torch.from_numpy(data_dictionary[f"{split}_labels"].values)
.to(self.target_tensor_type)
.view(labels_view)
)