set class names in IStrategy.set_freqai_targets method, also save class name with model meta data

This commit is contained in:
Yinon Polak
2023-03-08 18:36:44 +02:00
parent 7d26df01b8
commit 1597c3aa89
2 changed files with 33 additions and 20 deletions

View File

@@ -1,6 +1,6 @@
import logging
from pathlib import Path
from typing import Dict
from typing import Any, Dict
import pandas as pd
import torch
@@ -22,11 +22,13 @@ class PyTorchModelTrainer:
batch_size: int,
max_iters: int,
eval_iters: int,
init_model: Dict
init_model: Dict,
model_meta_data: Dict[str, Any] = {},
):
self.model = model
self.optimizer = optimizer
self.criterion = criterion
self.model_meta_data = model_meta_data
self.device = device
self.max_iters = max_iters
self.batch_size = batch_size
@@ -126,6 +128,7 @@ class PyTorchModelTrainer:
torch.save({
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'model_meta_data': self.model_meta_data,
}, path)
def load_from_file(self, path: Path):
@@ -135,4 +138,5 @@ class PyTorchModelTrainer:
def load_from_checkpoint(self, checkpoint: Dict):
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.model_meta_data = checkpoint["model_meta_data"]
return self