add documentation
This commit is contained in:
@@ -6,6 +6,10 @@ import torch
|
||||
|
||||
|
||||
class PyTorchDataConvertor(ABC):
|
||||
"""
|
||||
This class is responsible for converting `*_features` & `*_labels` pandas dataframes
|
||||
to pytorch tensors.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def convert_x(self, df: pd.DataFrame, device: Optional[str] = None) -> Tuple[torch.Tensor, ...]:
|
||||
@@ -25,6 +29,9 @@ class PyTorchDataConvertor(ABC):
|
||||
|
||||
|
||||
class DefaultPyTorchDataConvertor(PyTorchDataConvertor):
|
||||
"""
|
||||
A default conversion that keeps features dataframe shapes.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user