round up divisions in calc_n_epochs

This commit is contained in:
Yinon Polak 2023-03-21 12:29:05 +02:00
parent 443263803c
commit 97339e14cf

View File

@ -1,4 +1,5 @@
import logging import logging
import math
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
@ -148,10 +149,13 @@ class PyTorchModelTrainer:
""" """
Calculates the number of epochs required to reach the maximum number Calculates the number of epochs required to reach the maximum number
of iterations specified in the model training parameters. of iterations specified in the model training parameters.
the motivation here is that `max_iters` is easier to optimize and keep stable,
across different n_obs - the number of data points.
""" """
n_batches = n_obs // batch_size n_batches = math.ceil(n_obs // batch_size)
epochs = n_iters // n_batches epochs = math.ceil(n_iters // n_batches)
return epochs return epochs
def save(self, path: Path): def save(self, path: Path):