add documentation to pytorch data convertor
This commit is contained in:
parent
bc9454e0f9
commit
7b494c8333
@ -31,6 +31,12 @@ class DefaultPyTorchDataConvertor(PyTorchDataConvertor):
|
||||
target_tensor_type: Optional[torch.dtype] = None,
|
||||
squeeze_target_tensor: bool = False
|
||||
):
|
||||
"""
|
||||
:param target_tensor_type: type of target tensor, for classification use
|
||||
torch.long, for regressor use torch.float or torch.double.
|
||||
:param squeeze_target_tensor: controls the target shape, used for loss functions
|
||||
that requires 0D or 1D.
|
||||
"""
|
||||
self._target_tensor_type = target_tensor_type
|
||||
self._squeeze_target_tensor = squeeze_target_tensor
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user