add optional target tensor squeezing to pytorch trainer

This commit is contained in:
Yinon Polak
2023-03-21 13:20:54 +02:00
parent 97339e14cf
commit a80afc8f1b
2 changed files with 12 additions and 7 deletions

View File

@@ -73,6 +73,7 @@ class PyTorchMLPClassifier(PyTorchClassifier):
device=self.device,
init_model=init_model,
target_tensor_type=torch.long,
squeeze_target_tensor=True,
**self.trainer_kwargs,
)
trainer.fit(data_dictionary)