From 3081b9402b78406c4edee6e6ef6cdc4937b64e50 Mon Sep 17 00:00:00 2001 From: Yinon Polak Date: Thu, 9 Mar 2023 11:14:54 +0200 Subject: [PATCH] add documentation --- .../freqai/base_models/BasePyTorchModel.py | 2 +- .../freqai/base_models/PyTorchModelTrainer.py | 21 +++++++++++++++++++ .../PyTorchClassifierMultiTarget.py | 17 +++++++++++++-- 3 files changed, 37 insertions(+), 3 deletions(-) diff --git a/freqtrade/freqai/base_models/BasePyTorchModel.py b/freqtrade/freqai/base_models/BasePyTorchModel.py index efc36fdec..8e608ee1a 100644 --- a/freqtrade/freqai/base_models/BasePyTorchModel.py +++ b/freqtrade/freqai/base_models/BasePyTorchModel.py @@ -14,7 +14,7 @@ logger = logging.getLogger(__name__) class BasePyTorchModel(IFreqaiModel): """ - Base class for TensorFlow type models. + Base class for PyTorch type models. User *must* inherit from this class and set fit() and predict(). """ diff --git a/freqtrade/freqai/base_models/PyTorchModelTrainer.py b/freqtrade/freqai/base_models/PyTorchModelTrainer.py index 5ebecef34..26149e2fa 100644 --- a/freqtrade/freqai/base_models/PyTorchModelTrainer.py +++ b/freqtrade/freqai/base_models/PyTorchModelTrainer.py @@ -25,6 +25,21 @@ class PyTorchModelTrainer: init_model: Dict, model_meta_data: Dict[str, Any] = {}, ): + """ + A class for training PyTorch models. + + :param model: The PyTorch model to be trained. + :param optimizer: The optimizer to use for training. + :param criterion: The loss function to use for training. + :param device: The device to use for training (e.g. 'cpu', 'cuda'). + :param batch_size: The size of the batches to use during training. + :param max_iters: The number of training iterations to run. + iteration here refers to the number of times we call + self.optimizer.step() . used to calculate n_epochs. + :param eval_iters: The number of iterations used to estimate the loss. + :param init_model: A dictionary containing the initial model parameters. + :param model_meta_data: Additional metadata about the model (optional). + """ self.model = model self.optimizer = optimizer self.criterion = criterion @@ -38,6 +53,12 @@ class PyTorchModelTrainer: self.load_from_checkpoint(init_model) def fit(self, data_dictionary: Dict[str, pd.DataFrame]): + """ + general training loop: + - converting data to tensors + - calculating n_epochs + - + """ data_loaders_dictionary = self.create_data_loaders_dictionary(data_dictionary) epochs = self.calc_n_epochs( n_obs=len(data_dictionary['train_features']), diff --git a/freqtrade/freqai/prediction_models/PyTorchClassifierMultiTarget.py b/freqtrade/freqai/prediction_models/PyTorchClassifierMultiTarget.py index 13ec2d0bb..e8326ffe9 100644 --- a/freqtrade/freqai/prediction_models/PyTorchClassifierMultiTarget.py +++ b/freqtrade/freqai/prediction_models/PyTorchClassifierMultiTarget.py @@ -19,8 +19,19 @@ logger = logging.getLogger(__name__) class PyTorchClassifierMultiTarget(BasePyTorchModel): - + """ + A PyTorch implementation of a multi-target classifier. + """ def __init__(self, **kwargs): + """ + int: The number of nodes in the hidden layer of the neural network. + int: The maximum number of iterations to run during training. + int: The batch size to use during training. + float: The learning rate to use during training. + int: The number of training iterations between each evaluation. + dict: A dictionary mapping class names to their corresponding indices. + dict: A dictionary mapping indices to their corresponding class names. + """ super().__init__(**kwargs) model_training_parameters = self.freqai_info["model_training_parameters"] self.n_hidden = model_training_parameters.get("n_hidden", 1024) @@ -34,8 +45,10 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel): def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: """ User sets up the training and test data to fit their desired model here - :param tensor_dictionary: the dictionary constructed by DataHandler to hold + :param data_dictionary: the dictionary constructed by DataHandler to hold all the training and test data/labels. + :raises ValueError: If self.class_names is not defined in the parent class. + """ if not hasattr(self, "class_names"): raise ValueError(