add documentation

This commit is contained in:
Yinon Polak
2023-04-03 17:06:39 +03:00
parent 7b494c8333
commit d9d9993179
3 changed files with 12 additions and 0 deletions

View File

@@ -6,6 +6,10 @@ import torch
class PyTorchDataConvertor(ABC):
"""
This class is responsible for converting `*_features` & `*_labels` pandas dataframes
to pytorch tensors.
"""
@abstractmethod
def convert_x(self, df: pd.DataFrame, device: Optional[str] = None) -> Tuple[torch.Tensor, ...]:
@@ -25,6 +29,9 @@ class PyTorchDataConvertor(ABC):
class DefaultPyTorchDataConvertor(PyTorchDataConvertor):
"""
A default conversion that keeps features dataframe shapes.
"""
def __init__(
self,