add optional target tensor squeezing to pytorch trainer
@@ -22,6 +22,7 @@ class PyTorchModelTrainer:
             device: str,
             init_model: Dict,
             target_tensor_type: torch.dtype,
+            squeeze_target_tensor: bool = False,
             model_meta_data: Dict[str, Any] = {},
             **kwargs
     ):
@@ -35,11 +36,14 @@ class PyTorchModelTrainer:
         :param target_tensor_type: type of target tensor, for classification usually
             torch.long, for regressor usually torch.float.
         :param model_meta_data: Additional metadata about the model (optional).
+        :param squeeze_target_tensor: controls the target shape, used for loss functions
+            that require a 0D or 1D target tensor.
         :param max_iters: The number of training iterations to run.
             Iteration here refers to the number of times we call
             self.optimizer.step(); used to calculate n_epochs.
         :param batch_size: The size of the batches to use during training.
         :param max_n_eval_batches: The maximum number of batches to use for evaluation.
+
         """
         self.model = model
         self.optimizer = optimizer
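The squeeze matters because some PyTorch loss functions are strict about target dimensionality: nn.CrossEntropyLoss, for example, expects 1D class-index targets for a (batch, n_classes) input, while nn.MSELoss accepts a 2D float target matching the prediction shape. A minimal illustration, with shapes and names chosen purely for demonstration:

import torch
import torch.nn as nn

logits = torch.randn(8, 3)                     # (batch, n_classes) model output
targets = torch.zeros(8, 1, dtype=torch.long)  # (batch, 1) column of class indices

# CrossEntropyLoss wants class indices of shape (batch,); a (batch, 1) target
# raises a RuntimeError, so the extra dimension has to be squeezed away first.
loss = nn.CrossEntropyLoss()(logits, targets.squeeze())  # squeezed target -> (8,)

# MSELoss, by contrast, accepts a 2D float target matching the prediction shape,
# which is why squeezing stays optional (False by default) for regressors.
mse = nn.MSELoss()(torch.randn(8, 1), torch.zeros(8, 1))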
@@ -50,6 +54,7 @@ class PyTorchModelTrainer:
         self.max_iters: int = kwargs.get("max_iters", 100)
         self.batch_size: int = kwargs.get("batch_size", 64)
         self.max_n_eval_batches: Optional[int] = kwargs.get("max_n_eval_batches", None)
+        self.squeeze_target_tensor = squeeze_target_tensor
         if init_model:
             self.load_from_checkpoint(init_model)
@@ -124,15 +129,14 @@ class PyTorchModelTrainer:
         """
         data_loader_dictionary = {}
         for split in ["train", "test"]:
-            labels_shape = data_dictionary[f"{split}_labels"].shape
-            labels_view = (labels_shape[0], 1) if labels_shape[1] == 1 else labels_shape
-            dataset = TensorDataset(
-                torch.from_numpy(data_dictionary[f"{split}_features"].values).float(),
-                torch.from_numpy(data_dictionary[f"{split}_labels"].values)
-                .to(self.target_tensor_type)
-                .view(labels_view)
-            )
-
+            x = torch.from_numpy(data_dictionary[f"{split}_features"].values).float()
+            y = torch.from_numpy(data_dictionary[f"{split}_labels"].values)\
+                .to(self.target_tensor_type)
+
+            if self.squeeze_target_tensor:
+                y = y.squeeze()
+
+            dataset = TensorDataset(x, y)
             data_loader = DataLoader(
                 dataset,
                 batch_size=self.batch_size,
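A minimal sketch of the new code path in isolation, assuming the labels arrive as an (N, 1) column; the arrays, sizes, and DataLoader arguments below are illustrative rather than taken from the real pipeline:

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

# Labels as an (N, 1) column, the shape a single-label DataFrame typically yields.
labels = np.random.randint(0, 3, size=(100, 1))
features = np.random.rand(100, 5)

x = torch.from_numpy(features).float()       # (100, 5)
y = torch.from_numpy(labels).to(torch.long)  # (100, 1)
y = y.squeeze()                              # (100,) after squeezing

loader = DataLoader(TensorDataset(x, y), batch_size=64)
xb, yb = next(iter(loader))
print(xb.shape, yb.shape)                    # torch.Size([64, 5]) torch.Size([64])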
@@ -73,6 +73,7 @@ class PyTorchMLPClassifier(PyTorchClassifier):
             device=self.device,
             init_model=init_model,
             target_tensor_type=torch.long,
+            squeeze_target_tensor=True,
             **self.trainer_kwargs,
         )
         trainer.fit(data_dictionary)
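Passing squeeze_target_tensor=True here lines up with target_tensor_type=torch.long: a classification criterion such as nn.CrossEntropyLoss consumes 1D class indices. A rough sketch of one training step under that convention, where the model, optimizer, and batch are hypothetical stand-ins for the real trainer wiring:

import torch
import torch.nn as nn

# Hypothetical stand-ins for the real model/trainer wiring.
model = nn.Sequential(nn.Linear(5, 16), nn.ReLU(), nn.Linear(16, 3))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

xb = torch.randn(64, 5)          # one feature batch
yb = torch.randint(0, 3, (64,))  # 1D long targets, as squeezing produces

optimizer.zero_grad()
loss = criterion(model(xb), yb)  # works because yb is (64,), not (64, 1)
loss.backward()
optimizer.step()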