add documentation
This commit is contained in:
parent
7b494c8333
commit
d9d9993179
@ -76,4 +76,8 @@ class BasePyTorchModel(IFreqaiModel, ABC):
|
|||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def data_convertor(self) -> PyTorchDataConvertor:
|
def data_convertor(self) -> PyTorchDataConvertor:
|
||||||
|
"""
|
||||||
|
a class responsible for converting `*_features` & `*_labels` pandas dataframes
|
||||||
|
to pytorch tensors.
|
||||||
|
"""
|
||||||
raise NotImplementedError("Abstract property")
|
raise NotImplementedError("Abstract property")
|
||||||
|
@ -6,6 +6,10 @@ import torch
|
|||||||
|
|
||||||
|
|
||||||
class PyTorchDataConvertor(ABC):
|
class PyTorchDataConvertor(ABC):
|
||||||
|
"""
|
||||||
|
This class is responsible for converting `*_features` & `*_labels` pandas dataframes
|
||||||
|
to pytorch tensors.
|
||||||
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def convert_x(self, df: pd.DataFrame, device: Optional[str] = None) -> Tuple[torch.Tensor, ...]:
|
def convert_x(self, df: pd.DataFrame, device: Optional[str] = None) -> Tuple[torch.Tensor, ...]:
|
||||||
@ -25,6 +29,9 @@ class PyTorchDataConvertor(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class DefaultPyTorchDataConvertor(PyTorchDataConvertor):
|
class DefaultPyTorchDataConvertor(PyTorchDataConvertor):
|
||||||
|
"""
|
||||||
|
A default conversion that keeps features dataframe shapes.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -36,6 +36,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
|||||||
:param init_model: A dictionary containing the initial model/optimizer
|
:param init_model: A dictionary containing the initial model/optimizer
|
||||||
state_dict and model_meta_data saved by self.save() method.
|
state_dict and model_meta_data saved by self.save() method.
|
||||||
:param model_meta_data: Additional metadata about the model (optional).
|
:param model_meta_data: Additional metadata about the model (optional).
|
||||||
|
:param data_convertor: convertor from pd.DataFrame to torch.tensor.
|
||||||
:param max_iters: The number of training iterations to run.
|
:param max_iters: The number of training iterations to run.
|
||||||
iteration here refers to the number of times we call
|
iteration here refers to the number of times we call
|
||||||
self.optimizer.step(). used to calculate n_epochs.
|
self.optimizer.step(). used to calculate n_epochs.
|
||||||
|
Loading…
Reference in New Issue
Block a user