add documentation

This commit is contained in:
Yinon Polak 2023-04-03 17:06:39 +03:00
parent 7b494c8333
commit d9d9993179
3 changed files with 12 additions and 0 deletions

View File

@ -76,4 +76,8 @@ class BasePyTorchModel(IFreqaiModel, ABC):
@property
@abstractmethod
def data_convertor(self) -> PyTorchDataConvertor:
"""
a class responsible for converting `*_features` & `*_labels` pandas dataframes
to pytorch tensors.
"""
raise NotImplementedError("Abstract property")

View File

@ -6,6 +6,10 @@ import torch
class PyTorchDataConvertor(ABC):
"""
This class is responsible for converting `*_features` & `*_labels` pandas dataframes
to pytorch tensors.
"""
@abstractmethod
def convert_x(self, df: pd.DataFrame, device: Optional[str] = None) -> Tuple[torch.Tensor, ...]:
@ -25,6 +29,9 @@ class PyTorchDataConvertor(ABC):
class DefaultPyTorchDataConvertor(PyTorchDataConvertor):
"""
A default conversion that keeps features dataframe shapes.
"""
def __init__(
self,

View File

@ -36,6 +36,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
:param init_model: A dictionary containing the initial model/optimizer
state_dict and model_meta_data saved by self.save() method.
:param model_meta_data: Additional metadata about the model (optional).
:param data_convertor: convertor from pd.DataFrame to torch.tensor.
:param max_iters: The number of training iterations to run.
iteration here refers to the number of times we call
self.optimizer.step(). used to calculate n_epochs.