add optional target tensor squeezing to pytorch trainer

Yinon Polak 2023-03-21 13:20:54 +02:00
parent 97339e14cf
commit a80afc8f1b
2 changed files with 12 additions and 7 deletions
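Context for the change: classification losses such as `torch.nn.CrossEntropyLoss` expect class-index targets of shape `(N,)`, but labels read from a one-column DataFrame arrive as an `(N, 1)` tensor. A minimal sketch of the mismatch this commit addresses (toy sizes, not freqai data):

```python
import torch

logits = torch.randn(8, 3)                   # batch of 8 samples, 3 classes
targets = torch.randint(0, 3, (8, 1))        # (8, 1) column, as read from a DataFrame

loss_fn = torch.nn.CrossEntropyLoss()
# loss_fn(logits, targets)                   # RuntimeError: expected target shape (8,)
loss = loss_fn(logits, targets.squeeze())    # squeeze to (8,) and it works
print(loss)
```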

@@ -22,6 +22,7 @@ class PyTorchModelTrainer:
         device: str,
         init_model: Dict,
         target_tensor_type: torch.dtype,
+        squeeze_target_tensor: bool = False,
         model_meta_data: Dict[str, Any] = {},
         **kwargs
     ):
@@ -35,11 +36,14 @@ class PyTorchModelTrainer:
         :param target_tensor_type: type of target tensor, for classification usually
             torch.long, for regression usually torch.float.
         :param model_meta_data: Additional metadata about the model (optional).
+        :param squeeze_target_tensor: controls the target shape, used for loss functions
+            that require 0D or 1D targets.
         :param max_iters: The number of training iterations to run.
             iteration here refers to the number of times we call
             self.optimizer.step(). used to calculate n_epochs.
         :param batch_size: The size of the batches to use during training.
         :param max_n_eval_batches: The maximum number of batches to use for evaluation.
         """
         self.model = model
         self.optimizer = optimizer
@@ -50,6 +54,7 @@ class PyTorchModelTrainer:
         self.max_iters: int = kwargs.get("max_iters", 100)
         self.batch_size: int = kwargs.get("batch_size", 64)
         self.max_n_eval_batches: Optional[int] = kwargs.get("max_n_eval_batches", None)
+        self.squeeze_target_tensor = squeeze_target_tensor

         if init_model:
             self.load_from_checkpoint(init_model)
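A note on the "0D or 1D" wording in the docstring: `squeeze()` drops every size-1 dimension, so an `(N, 1)` target becomes 1D, while a single-sample `(1, 1)` batch collapses all the way to a 0D scalar tensor. A quick illustration:

```python
import torch

y = torch.zeros(64, 1, dtype=torch.long)
print(y.squeeze().shape)     # torch.Size([64]) -- the usual 1D case

# Edge case: a batch of one loses its batch dimension too,
# which is why the docstring says "0D or 1D".
print(torch.zeros(1, 1, dtype=torch.long).squeeze().shape)   # torch.Size([])
```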
@@ -124,15 +129,14 @@ class PyTorchModelTrainer:
         """
         data_loader_dictionary = {}
         for split in ["train", "test"]:
-            labels_shape = data_dictionary[f"{split}_labels"].shape
-            labels_view = (labels_shape[0], 1) if labels_shape[1] == 1 else labels_shape
-            dataset = TensorDataset(
-                torch.from_numpy(data_dictionary[f"{split}_features"].values).float(),
-                torch.from_numpy(data_dictionary[f"{split}_labels"].values)
+            x = torch.from_numpy(data_dictionary[f"{split}_features"].values).float()
+            y = torch.from_numpy(data_dictionary[f"{split}_labels"].values)\
                 .to(self.target_tensor_type)
-                .view(labels_view)
-            )
+
+            if self.squeeze_target_tensor:
+                y = y.squeeze()
+            dataset = TensorDataset(x, y)

             data_loader = DataLoader(
                 dataset,
                 batch_size=self.batch_size,
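A self-contained sketch of the rewritten loader construction, with a toy data_dictionary standing in for freqai's real one (column counts, sizes, and the single "train" split are illustrative assumptions):

```python
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, TensorDataset

data_dictionary = {
    "train_features": pd.DataFrame(np.random.rand(100, 4)),
    "train_labels": pd.DataFrame(np.random.randint(0, 3, (100, 1))),
}
split, target_tensor_type, squeeze_target_tensor = "train", torch.long, True

x = torch.from_numpy(data_dictionary[f"{split}_features"].values).float()
y = torch.from_numpy(data_dictionary[f"{split}_labels"].values).to(target_tensor_type)
if squeeze_target_tensor:
    y = y.squeeze()                          # (100, 1) -> (100,)

data_loader = DataLoader(TensorDataset(x, y), batch_size=64, shuffle=True)
xb, yb = next(iter(data_loader))
print(xb.shape, yb.shape)                    # torch.Size([64, 4]) torch.Size([64])
```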

@@ -73,6 +73,7 @@ class PyTorchMLPClassifier(PyTorchClassifier):
             device=self.device,
             init_model=init_model,
             target_tensor_type=torch.long,
+            squeeze_target_tensor=True,
             **self.trainer_kwargs,
         )
         trainer.fit(data_dictionary)
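For contrast, the regression path would leave the flag at its False default: element-wise losses such as `torch.nn.MSELoss` (named here as an illustrative assumption, not taken from the diff) compare prediction and target of identical shape, so an `(N, 1)` float target matching an `(N, 1)` output needs no squeezing:

```python
import torch

preds = torch.randn(8, 1)                    # one predicted value per sample
targets = torch.randn(8, 1)                  # (N, 1) float targets, left unsqueezed
print(torch.nn.MSELoss()(preds, targets))    # shapes already match
```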