round up divisions in calc_n_epochs

This commit is contained in:
Yinon Polak 2023-03-21 12:29:05 +02:00
parent 443263803c
commit 97339e14cf
1 changed files with 6 additions and 2 deletions

View File

@ -1,4 +1,5 @@
import logging
import math
from pathlib import Path
from typing import Any, Dict, Optional
@ -148,10 +149,13 @@ class PyTorchModelTrainer:
"""
Calculates the number of epochs required to reach the maximum number
of iterations specified in the model training parameters.
the motivation here is that `max_iters` is easier to optimize and keep stable,
across different n_obs - the number of data points.
"""
n_batches = n_obs // batch_size
epochs = n_iters // n_batches
n_batches = math.ceil(n_obs // batch_size)
epochs = math.ceil(n_iters // n_batches)
return epochs
def save(self, path: Path):