From ba5de0cd00423570cc484a44ffa281718bda47a9 Mon Sep 17 00:00:00 2001 From: Yinon Polak Date: Thu, 9 Mar 2023 11:21:10 +0200 Subject: [PATCH] add documentation --- freqtrade/freqai/base_models/PyTorchModelTrainer.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/freqtrade/freqai/base_models/PyTorchModelTrainer.py b/freqtrade/freqai/base_models/PyTorchModelTrainer.py index 26149e2fa..41d26e31a 100644 --- a/freqtrade/freqai/base_models/PyTorchModelTrainer.py +++ b/freqtrade/freqai/base_models/PyTorchModelTrainer.py @@ -54,10 +54,7 @@ class PyTorchModelTrainer: def fit(self, data_dictionary: Dict[str, pd.DataFrame]): """ - general training loop: - - converting data to tensors - - calculating n_epochs - - + General training loop. """ data_loaders_dictionary = self.create_data_loaders_dictionary(data_dictionary) epochs = self.calc_n_epochs( @@ -117,6 +114,9 @@ class PyTorchModelTrainer: self, data_dictionary: Dict[str, pd.DataFrame] ) -> Dict[str, DataLoader]: + """ + Converts the input data to PyTorch tensors using a data loader. + """ data_loader_dictionary = {} for split in ['train', 'test']: labels_shape = data_dictionary[f'{split}_labels'].shape @@ -141,6 +141,10 @@ class PyTorchModelTrainer: @staticmethod def calc_n_epochs(n_obs: int, batch_size: int, n_iters: int) -> int: + """ + Calculates the number of epochs required to reach the maximum number + of iterations specified in the model training parameters. + """ n_batches = n_obs // batch_size epochs = n_iters // n_batches return epochs