From d9d99931792d0cf21c1e83f2b75da330f3ea8bdf Mon Sep 17 00:00:00 2001 From: Yinon Polak Date: Mon, 3 Apr 2023 17:06:39 +0300 Subject: [PATCH] add documentation --- freqtrade/freqai/base_models/BasePyTorchModel.py | 4 ++++ freqtrade/freqai/torch/PyTorchDataConvertor.py | 7 +++++++ freqtrade/freqai/torch/PyTorchModelTrainer.py | 1 + 3 files changed, 12 insertions(+) diff --git a/freqtrade/freqai/base_models/BasePyTorchModel.py b/freqtrade/freqai/base_models/BasePyTorchModel.py index d017f1fec..8177b8eb8 100644 --- a/freqtrade/freqai/base_models/BasePyTorchModel.py +++ b/freqtrade/freqai/base_models/BasePyTorchModel.py @@ -76,4 +76,8 @@ class BasePyTorchModel(IFreqaiModel, ABC): @property @abstractmethod def data_convertor(self) -> PyTorchDataConvertor: + """ + a class responsible for converting `*_features` & `*_labels` pandas dataframes + to pytorch tensors. + """ raise NotImplementedError("Abstract property") diff --git a/freqtrade/freqai/torch/PyTorchDataConvertor.py b/freqtrade/freqai/torch/PyTorchDataConvertor.py index 5982a1b48..e7d5c3ffe 100644 --- a/freqtrade/freqai/torch/PyTorchDataConvertor.py +++ b/freqtrade/freqai/torch/PyTorchDataConvertor.py @@ -6,6 +6,10 @@ import torch class PyTorchDataConvertor(ABC): + """ + This class is responsible for converting `*_features` & `*_labels` pandas dataframes + to pytorch tensors. + """ @abstractmethod def convert_x(self, df: pd.DataFrame, device: Optional[str] = None) -> Tuple[torch.Tensor, ...]: @@ -25,6 +29,9 @@ class PyTorchDataConvertor(ABC): class DefaultPyTorchDataConvertor(PyTorchDataConvertor): + """ + A default conversion that keeps features dataframe shapes. + """ def __init__( self, diff --git a/freqtrade/freqai/torch/PyTorchModelTrainer.py b/freqtrade/freqai/torch/PyTorchModelTrainer.py index 09de6f940..6449d98b5 100644 --- a/freqtrade/freqai/torch/PyTorchModelTrainer.py +++ b/freqtrade/freqai/torch/PyTorchModelTrainer.py @@ -36,6 +36,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface): :param init_model: A dictionary containing the initial model/optimizer state_dict and model_meta_data saved by self.save() method. :param model_meta_data: Additional metadata about the model (optional). + :param data_convertor: convertor from pd.DataFrame to torch.tensor. :param max_iters: The number of training iterations to run. iteration here refers to the number of times we call self.optimizer.step(). used to calculate n_epochs.