fix model loading from disk bug, improve doc, clarify installation/docker instructions, add a torch tag to the freqairl docker image. Fix seriously outdated prediction_model docstrings
This commit is contained in:
@@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import nn
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
|
||||
@@ -169,6 +169,12 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
||||
|
||||
n_batches = math.ceil(n_obs // batch_size)
|
||||
epochs = math.ceil(n_iters // n_batches)
|
||||
if epochs <= 10:
|
||||
logger.warning("User set `max_iters` in such a way that the trainer will only perform "
|
||||
f" {epochs} epochs. Please consider increasing this value accordingly")
|
||||
if epochs <= 1:
|
||||
logger.warning("Epochs set to 1. Please review your `max_iters` value")
|
||||
epochs = 1
|
||||
return epochs
|
||||
|
||||
def save(self, path: Path):
|
||||
@@ -182,6 +188,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
||||
"model_state_dict": self.model.state_dict(),
|
||||
"optimizer_state_dict": self.optimizer.state_dict(),
|
||||
"model_meta_data": self.model_meta_data,
|
||||
"pytrainer": self
|
||||
}, path)
|
||||
|
||||
def load(self, path: Path):
|
||||
@@ -195,7 +202,6 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
||||
you can access this dict from any class that inherits IFreqaiModel by calling
|
||||
get_init_model method.
|
||||
"""
|
||||
|
||||
self.model.load_state_dict(checkpoint["model_state_dict"])
|
||||
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
||||
self.model_meta_data = checkpoint["model_meta_data"]
|
||||
|
@@ -4,7 +4,7 @@ from typing import Dict, List
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import nn
|
||||
|
||||
|
||||
class PyTorchTrainerInterface(ABC):
|
||||
|
Reference in New Issue
Block a user