add documentation to pytorch data convertor
This commit is contained in:
parent
bc9454e0f9
commit
7b494c8333
@ -31,6 +31,12 @@ class DefaultPyTorchDataConvertor(PyTorchDataConvertor):
|
|||||||
target_tensor_type: Optional[torch.dtype] = None,
|
target_tensor_type: Optional[torch.dtype] = None,
|
||||||
squeeze_target_tensor: bool = False
|
squeeze_target_tensor: bool = False
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
:param target_tensor_type: type of target tensor, for classification use
|
||||||
|
torch.long, for regressor use torch.float or torch.double.
|
||||||
|
:param squeeze_target_tensor: controls the target shape, used for loss functions
|
||||||
|
that requires 0D or 1D.
|
||||||
|
"""
|
||||||
self._target_tensor_type = target_tensor_type
|
self._target_tensor_type = target_tensor_type
|
||||||
self._squeeze_target_tensor = squeeze_target_tensor
|
self._squeeze_target_tensor = squeeze_target_tensor
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user