round up divisions in calc_n_epochs
This commit is contained in:
parent
443263803c
commit
97339e14cf
@ -1,4 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import math
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
@ -148,10 +149,13 @@ class PyTorchModelTrainer:
|
|||||||
"""
|
"""
|
||||||
Calculates the number of epochs required to reach the maximum number
|
Calculates the number of epochs required to reach the maximum number
|
||||||
of iterations specified in the model training parameters.
|
of iterations specified in the model training parameters.
|
||||||
|
|
||||||
|
the motivation here is that `max_iters` is easier to optimize and keep stable,
|
||||||
|
across different n_obs - the number of data points.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
n_batches = n_obs // batch_size
|
n_batches = math.ceil(n_obs // batch_size)
|
||||||
epochs = n_iters // n_batches
|
epochs = math.ceil(n_iters // n_batches)
|
||||||
return epochs
|
return epochs
|
||||||
|
|
||||||
def save(self, path: Path):
|
def save(self, path: Path):
|
||||||
|
Loading…
Reference in New Issue
Block a user