add documentation

This commit is contained in:
Yinon Polak 2023-04-03 17:06:39 +03:00
parent 7b494c8333
commit d9d9993179
3 changed files with 12 additions and 0 deletions

View File

@ -76,4 +76,8 @@ class BasePyTorchModel(IFreqaiModel, ABC):
@property @property
@abstractmethod @abstractmethod
def data_convertor(self) -> PyTorchDataConvertor: def data_convertor(self) -> PyTorchDataConvertor:
"""
a class responsible for converting `*_features` & `*_labels` pandas dataframes
to pytorch tensors.
"""
raise NotImplementedError("Abstract property") raise NotImplementedError("Abstract property")

View File

@ -6,6 +6,10 @@ import torch
class PyTorchDataConvertor(ABC): class PyTorchDataConvertor(ABC):
"""
This class is responsible for converting `*_features` & `*_labels` pandas dataframes
to pytorch tensors.
"""
@abstractmethod @abstractmethod
def convert_x(self, df: pd.DataFrame, device: Optional[str] = None) -> Tuple[torch.Tensor, ...]: def convert_x(self, df: pd.DataFrame, device: Optional[str] = None) -> Tuple[torch.Tensor, ...]:
@ -25,6 +29,9 @@ class PyTorchDataConvertor(ABC):
class DefaultPyTorchDataConvertor(PyTorchDataConvertor): class DefaultPyTorchDataConvertor(PyTorchDataConvertor):
"""
A default conversion that keeps features dataframe shapes.
"""
def __init__( def __init__(
self, self,

View File

@ -36,6 +36,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
:param init_model: A dictionary containing the initial model/optimizer :param init_model: A dictionary containing the initial model/optimizer
state_dict and model_meta_data saved by self.save() method. state_dict and model_meta_data saved by self.save() method.
:param model_meta_data: Additional metadata about the model (optional). :param model_meta_data: Additional metadata about the model (optional).
:param data_convertor: convertor from pd.DataFrame to torch.tensor.
:param max_iters: The number of training iterations to run. :param max_iters: The number of training iterations to run.
iteration here refers to the number of times we call iteration here refers to the number of times we call
self.optimizer.step(). used to calculate n_epochs. self.optimizer.step(). used to calculate n_epochs.