add documentation
This commit is contained in:
parent
3081b9402b
commit
ba5de0cd00
@ -54,10 +54,7 @@ class PyTorchModelTrainer:
|
||||
|
||||
def fit(self, data_dictionary: Dict[str, pd.DataFrame]):
|
||||
"""
|
||||
general training loop:
|
||||
- converting data to tensors
|
||||
- calculating n_epochs
|
||||
-
|
||||
General training loop.
|
||||
"""
|
||||
data_loaders_dictionary = self.create_data_loaders_dictionary(data_dictionary)
|
||||
epochs = self.calc_n_epochs(
|
||||
@ -117,6 +114,9 @@ class PyTorchModelTrainer:
|
||||
self,
|
||||
data_dictionary: Dict[str, pd.DataFrame]
|
||||
) -> Dict[str, DataLoader]:
|
||||
"""
|
||||
Converts the input data to PyTorch tensors using a data loader.
|
||||
"""
|
||||
data_loader_dictionary = {}
|
||||
for split in ['train', 'test']:
|
||||
labels_shape = data_dictionary[f'{split}_labels'].shape
|
||||
@ -141,6 +141,10 @@ class PyTorchModelTrainer:
|
||||
|
||||
@staticmethod
|
||||
def calc_n_epochs(n_obs: int, batch_size: int, n_iters: int) -> int:
|
||||
"""
|
||||
Calculates the number of epochs required to reach the maximum number
|
||||
of iterations specified in the model training parameters.
|
||||
"""
|
||||
n_batches = n_obs // batch_size
|
||||
epochs = n_iters // n_batches
|
||||
return epochs
|
||||
|
Loading…
Reference in New Issue
Block a user