add documentation

This commit is contained in:
Yinon Polak 2023-03-09 11:21:10 +02:00
parent 3081b9402b
commit ba5de0cd00

View File

@ -54,10 +54,7 @@ class PyTorchModelTrainer:
def fit(self, data_dictionary: Dict[str, pd.DataFrame]):
"""
general training loop:
- converting data to tensors
- calculating n_epochs
-
General training loop.
"""
data_loaders_dictionary = self.create_data_loaders_dictionary(data_dictionary)
epochs = self.calc_n_epochs(
@ -117,6 +114,9 @@ class PyTorchModelTrainer:
self,
data_dictionary: Dict[str, pd.DataFrame]
) -> Dict[str, DataLoader]:
"""
Converts the input data to PyTorch tensors using a data loader.
"""
data_loader_dictionary = {}
for split in ['train', 'test']:
labels_shape = data_dictionary[f'{split}_labels'].shape
@ -141,6 +141,10 @@ class PyTorchModelTrainer:
@staticmethod
def calc_n_epochs(n_obs: int, batch_size: int, n_iters: int) -> int:
"""
Calculates the number of epochs required to reach the maximum number
of iterations specified in the model training parameters.
"""
n_batches = n_obs // batch_size
epochs = n_iters // n_batches
return epochs