add optional target tensor squeezing to pytorch trainer
This commit is contained in:
parent
97339e14cf
commit
a80afc8f1b
@ -22,6 +22,7 @@ class PyTorchModelTrainer:
|
|||||||
device: str,
|
device: str,
|
||||||
init_model: Dict,
|
init_model: Dict,
|
||||||
target_tensor_type: torch.dtype,
|
target_tensor_type: torch.dtype,
|
||||||
|
squeeze_target_tensor: bool = False,
|
||||||
model_meta_data: Dict[str, Any] = {},
|
model_meta_data: Dict[str, Any] = {},
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
@ -35,11 +36,14 @@ class PyTorchModelTrainer:
|
|||||||
:param target_tensor_type: type of target tensor, for classification usually
|
:param target_tensor_type: type of target tensor, for classification usually
|
||||||
torch.long, for regressor usually torch.float.
|
torch.long, for regressor usually torch.float.
|
||||||
:param model_meta_data: Additional metadata about the model (optional).
|
:param model_meta_data: Additional metadata about the model (optional).
|
||||||
|
:param squeeze_target_tensor: controls the target shape, used for loss functions
|
||||||
|
that requires 0D or 1D.
|
||||||
:param max_iters: The number of training iterations to run.
|
:param max_iters: The number of training iterations to run.
|
||||||
iteration here refers to the number of times we call
|
iteration here refers to the number of times we call
|
||||||
self.optimizer.step(). used to calculate n_epochs.
|
self.optimizer.step(). used to calculate n_epochs.
|
||||||
:param batch_size: The size of the batches to use during training.
|
:param batch_size: The size of the batches to use during training.
|
||||||
:param max_n_eval_batches: The maximum number batches to use for evaluation.
|
:param max_n_eval_batches: The maximum number batches to use for evaluation.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
self.model = model
|
self.model = model
|
||||||
self.optimizer = optimizer
|
self.optimizer = optimizer
|
||||||
@ -50,6 +54,7 @@ class PyTorchModelTrainer:
|
|||||||
self.max_iters: int = kwargs.get("max_iters", 100)
|
self.max_iters: int = kwargs.get("max_iters", 100)
|
||||||
self.batch_size: int = kwargs.get("batch_size", 64)
|
self.batch_size: int = kwargs.get("batch_size", 64)
|
||||||
self.max_n_eval_batches: Optional[int] = kwargs.get("max_n_eval_batches", None)
|
self.max_n_eval_batches: Optional[int] = kwargs.get("max_n_eval_batches", None)
|
||||||
|
self.squeeze_target_tensor = squeeze_target_tensor
|
||||||
if init_model:
|
if init_model:
|
||||||
self.load_from_checkpoint(init_model)
|
self.load_from_checkpoint(init_model)
|
||||||
|
|
||||||
@ -124,15 +129,14 @@ class PyTorchModelTrainer:
|
|||||||
"""
|
"""
|
||||||
data_loader_dictionary = {}
|
data_loader_dictionary = {}
|
||||||
for split in ["train", "test"]:
|
for split in ["train", "test"]:
|
||||||
labels_shape = data_dictionary[f"{split}_labels"].shape
|
x = torch.from_numpy(data_dictionary[f"{split}_features"].values).float()
|
||||||
labels_view = (labels_shape[0], 1) if labels_shape[1] == 1 else labels_shape
|
y = torch.from_numpy(data_dictionary[f"{split}_labels"].values)\
|
||||||
dataset = TensorDataset(
|
|
||||||
torch.from_numpy(data_dictionary[f"{split}_features"].values).float(),
|
|
||||||
torch.from_numpy(data_dictionary[f"{split}_labels"].values)
|
|
||||||
.to(self.target_tensor_type)
|
.to(self.target_tensor_type)
|
||||||
.view(labels_view)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
if self.squeeze_target_tensor:
|
||||||
|
y = y.squeeze()
|
||||||
|
|
||||||
|
dataset = TensorDataset(x, y)
|
||||||
data_loader = DataLoader(
|
data_loader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
batch_size=self.batch_size,
|
batch_size=self.batch_size,
|
||||||
|
@ -73,6 +73,7 @@ class PyTorchMLPClassifier(PyTorchClassifier):
|
|||||||
device=self.device,
|
device=self.device,
|
||||||
init_model=init_model,
|
init_model=init_model,
|
||||||
target_tensor_type=torch.long,
|
target_tensor_type=torch.long,
|
||||||
|
squeeze_target_tensor=True,
|
||||||
**self.trainer_kwargs,
|
**self.trainer_kwargs,
|
||||||
)
|
)
|
||||||
trainer.fit(data_dictionary)
|
trainer.fit(data_dictionary)
|
||||||
|
Loading…
Reference in New Issue
Block a user