set class names in IStrategy.set_freqai_targets method, also save class name with model meta data
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
@@ -22,11 +22,13 @@ class PyTorchModelTrainer:
|
||||
batch_size: int,
|
||||
max_iters: int,
|
||||
eval_iters: int,
|
||||
init_model: Dict
|
||||
init_model: Dict,
|
||||
model_meta_data: Dict[str, Any] = {},
|
||||
):
|
||||
self.model = model
|
||||
self.optimizer = optimizer
|
||||
self.criterion = criterion
|
||||
self.model_meta_data = model_meta_data
|
||||
self.device = device
|
||||
self.max_iters = max_iters
|
||||
self.batch_size = batch_size
|
||||
@@ -126,6 +128,7 @@ class PyTorchModelTrainer:
|
||||
torch.save({
|
||||
'model_state_dict': self.model.state_dict(),
|
||||
'optimizer_state_dict': self.optimizer.state_dict(),
|
||||
'model_meta_data': self.model_meta_data,
|
||||
}, path)
|
||||
|
||||
def load_from_file(self, path: Path):
|
||||
@@ -135,4 +138,5 @@ class PyTorchModelTrainer:
|
||||
def load_from_checkpoint(self, checkpoint: Dict):
|
||||
self.model.load_state_dict(checkpoint['model_state_dict'])
|
||||
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||
self.model_meta_data = checkpoint["model_meta_data"]
|
||||
return self
|
||||
|
Reference in New Issue
Block a user