add optional target tensor squeezing to pytorch trainer
This commit is contained in:
@@ -73,6 +73,7 @@ class PyTorchMLPClassifier(PyTorchClassifier):
|
||||
device=self.device,
|
||||
init_model=init_model,
|
||||
target_tensor_type=torch.long,
|
||||
squeeze_target_tensor=True,
|
||||
**self.trainer_kwargs,
|
||||
)
|
||||
trainer.fit(data_dictionary)
|
||||
|
Reference in New Issue
Block a user