add pytorch regressor example
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
@@ -20,6 +20,7 @@ class PyTorchModelTrainer:
|
||||
criterion: nn.Module,
|
||||
device: str,
|
||||
init_model: Dict,
|
||||
target_tensor_type: torch.dtype,
|
||||
model_meta_data: Dict[str, Any] = {},
|
||||
**kwargs
|
||||
):
|
||||
@@ -30,6 +31,8 @@ class PyTorchModelTrainer:
|
||||
:param device: The device to use for training (e.g. 'cpu', 'cuda').
|
||||
:param init_model: A dictionary containing the initial model/optimizer
|
||||
state_dict and model_meta_data saved by self.save() method.
|
||||
:param target_tensor_type: type of target tensor, for classification usually
|
||||
torch.long, for regressor usually torch.float.
|
||||
:param model_meta_data: Additional metadata about the model (optional).
|
||||
:param max_iters: The number of training iterations to run.
|
||||
iteration here refers to the number of times we call
|
||||
@@ -42,6 +45,7 @@ class PyTorchModelTrainer:
|
||||
self.criterion = criterion
|
||||
self.model_meta_data = model_meta_data
|
||||
self.device = device
|
||||
self.target_tensor_type = target_tensor_type
|
||||
self.max_iters: int = kwargs.get("max_iters", 100)
|
||||
self.batch_size: int = kwargs.get("batch_size", 64)
|
||||
self.max_n_eval_batches: Optional[int] = kwargs.get("max_n_eval_batches", None)
|
||||
@@ -123,8 +127,8 @@ class PyTorchModelTrainer:
|
||||
labels_view = labels_shape[0] if labels_shape[1] == 1 else labels_shape
|
||||
dataset = TensorDataset(
|
||||
torch.from_numpy(data_dictionary[f"{split}_features"].values).float(),
|
||||
torch.from_numpy(data_dictionary[f"{split}_labels"].astype(float).values)
|
||||
.long()
|
||||
torch.from_numpy(data_dictionary[f"{split}_labels"].values)
|
||||
.to(self.target_tensor_type)
|
||||
.view(labels_view)
|
||||
)
|
||||
|
||||
|
Reference in New Issue
Block a user