add documentation to pytorch data convertor

This commit is contained in:
Yinon Polak 2023-04-03 16:39:49 +03:00
parent bc9454e0f9
commit 7b494c8333

View File

@ -31,6 +31,12 @@ class DefaultPyTorchDataConvertor(PyTorchDataConvertor):
target_tensor_type: Optional[torch.dtype] = None,
squeeze_target_tensor: bool = False
):
"""
:param target_tensor_type: type of target tensor, for classification use
torch.long, for regressor use torch.float or torch.double.
:param squeeze_target_tensor: controls the target shape, used for loss functions
that requires 0D or 1D.
"""
self._target_tensor_type = target_tensor_type
self._squeeze_target_tensor = squeeze_target_tensor