add documentation
This commit is contained in:
parent
3081b9402b
commit
ba5de0cd00
@ -54,10 +54,7 @@ class PyTorchModelTrainer:
|
|||||||
|
|
||||||
def fit(self, data_dictionary: Dict[str, pd.DataFrame]):
|
def fit(self, data_dictionary: Dict[str, pd.DataFrame]):
|
||||||
"""
|
"""
|
||||||
general training loop:
|
General training loop.
|
||||||
- converting data to tensors
|
|
||||||
- calculating n_epochs
|
|
||||||
-
|
|
||||||
"""
|
"""
|
||||||
data_loaders_dictionary = self.create_data_loaders_dictionary(data_dictionary)
|
data_loaders_dictionary = self.create_data_loaders_dictionary(data_dictionary)
|
||||||
epochs = self.calc_n_epochs(
|
epochs = self.calc_n_epochs(
|
||||||
@ -117,6 +114,9 @@ class PyTorchModelTrainer:
|
|||||||
self,
|
self,
|
||||||
data_dictionary: Dict[str, pd.DataFrame]
|
data_dictionary: Dict[str, pd.DataFrame]
|
||||||
) -> Dict[str, DataLoader]:
|
) -> Dict[str, DataLoader]:
|
||||||
|
"""
|
||||||
|
Converts the input data to PyTorch tensors using a data loader.
|
||||||
|
"""
|
||||||
data_loader_dictionary = {}
|
data_loader_dictionary = {}
|
||||||
for split in ['train', 'test']:
|
for split in ['train', 'test']:
|
||||||
labels_shape = data_dictionary[f'{split}_labels'].shape
|
labels_shape = data_dictionary[f'{split}_labels'].shape
|
||||||
@ -141,6 +141,10 @@ class PyTorchModelTrainer:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def calc_n_epochs(n_obs: int, batch_size: int, n_iters: int) -> int:
|
def calc_n_epochs(n_obs: int, batch_size: int, n_iters: int) -> int:
|
||||||
|
"""
|
||||||
|
Calculates the number of epochs required to reach the maximum number
|
||||||
|
of iterations specified in the model training parameters.
|
||||||
|
"""
|
||||||
n_batches = n_obs // batch_size
|
n_batches = n_obs // batch_size
|
||||||
epochs = n_iters // n_batches
|
epochs = n_iters // n_batches
|
||||||
return epochs
|
return epochs
|
||||||
|
Loading…
Reference in New Issue
Block a user