round up divisions in calc_n_epochs
This commit is contained in:
parent
443263803c
commit
97339e14cf
@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
@ -148,10 +149,13 @@ class PyTorchModelTrainer:
|
||||
"""
|
||||
Calculates the number of epochs required to reach the maximum number
|
||||
of iterations specified in the model training parameters.
|
||||
|
||||
the motivation here is that `max_iters` is easier to optimize and keep stable,
|
||||
across different n_obs - the number of data points.
|
||||
"""
|
||||
|
||||
n_batches = n_obs // batch_size
|
||||
epochs = n_iters // n_batches
|
||||
n_batches = math.ceil(n_obs // batch_size)
|
||||
epochs = math.ceil(n_iters // n_batches)
|
||||
return epochs
|
||||
|
||||
def save(self, path: Path):
|
||||
|
Loading…
Reference in New Issue
Block a user