fix model loading from disk bug, improve doc, clarify installation/docker instructions, add a torch tag to the freqairl docker image. Fix seriously outdated prediction_model docstrings

This commit is contained in:
robcaulk
2023-04-08 12:09:53 +02:00
parent a655524221
commit 48d3c8e62e
21 changed files with 195 additions and 83 deletions

View File

@@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional
import pandas as pd
import torch
import torch.nn as nn
from torch import nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader, TensorDataset
@@ -169,6 +169,12 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
n_batches = math.ceil(n_obs // batch_size)
epochs = math.ceil(n_iters // n_batches)
if epochs <= 10:
logger.warning("User set `max_iters` in such a way that the trainer will only perform "
f" {epochs} epochs. Please consider increasing this value accordingly")
if epochs <= 1:
logger.warning("Epochs set to 1. Please review your `max_iters` value")
epochs = 1
return epochs
def save(self, path: Path):
@@ -182,6 +188,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
"model_state_dict": self.model.state_dict(),
"optimizer_state_dict": self.optimizer.state_dict(),
"model_meta_data": self.model_meta_data,
"pytrainer": self
}, path)
def load(self, path: Path):
@@ -195,7 +202,6 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
you can access this dict from any class that inherits IFreqaiModel by calling
get_init_model method.
"""
self.model.load_state_dict(checkpoint["model_state_dict"])
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
self.model_meta_data = checkpoint["model_meta_data"]

View File

@@ -4,7 +4,7 @@ from typing import Dict, List
import pandas as pd
import torch
import torch.nn as nn
from torch import nn
class PyTorchTrainerInterface(ABC):