add monitor to eval env so that multiproc can save best_model
This commit is contained in:
parent
69d542d3e2
commit
dd382dd370
@ -6,6 +6,7 @@ import numpy as np
|
||||
# import pandas as pd
|
||||
import torch as th
|
||||
# from pandas import DataFrame
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from typing import Callable
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.callbacks import EvalCallback
|
||||
@ -20,7 +21,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def make_env(env_id: str, rank: int, seed: int, train_df, price,
|
||||
reward_params, window_size) -> Callable:
|
||||
reward_params, window_size, monitor=False) -> Callable:
|
||||
"""
|
||||
Utility function for multiprocessed env.
|
||||
|
||||
@ -34,6 +35,8 @@ def make_env(env_id: str, rank: int, seed: int, train_df, price,
|
||||
|
||||
env = MyRLEnv(df=train_df, prices=price, window_size=window_size,
|
||||
reward_kwargs=reward_params, id=env_id, seed=seed + rank)
|
||||
if monitor:
|
||||
env = Monitor(env, ".")
|
||||
return env
|
||||
set_random_seed(seed)
|
||||
return _init
|
||||
@ -66,7 +69,7 @@ class ReinforcementLearningPPO_multiproc(BaseReinforcementLearningModel):
|
||||
|
||||
eval_env_id = 'eval_env'
|
||||
eval_env = SubprocVecEnv([make_env(eval_env_id, i, 1, test_df, price_test, reward_params,
|
||||
self.CONV_WIDTH) for i in range(num_cpu)])
|
||||
self.CONV_WIDTH, monitor=True) for i in range(num_cpu)])
|
||||
|
||||
path = dk.data_path
|
||||
eval_callback = EvalCallback(eval_env, best_model_save_path=f"{path}/",
|
||||
@ -86,12 +89,11 @@ class ReinforcementLearningPPO_multiproc(BaseReinforcementLearningModel):
|
||||
callback=eval_callback
|
||||
)
|
||||
|
||||
# TODO get callback working so the best model is saved. For now we save last model
|
||||
# best_model = PPO.load(dk.data_path / "best_model.zip")
|
||||
best_model = PPO.load(dk.data_path / "best_model.zip")
|
||||
print('Training finished!')
|
||||
eval_env.close()
|
||||
|
||||
return model # best_model
|
||||
return best_model
|
||||
|
||||
|
||||
class MyRLEnv(Base3ActionRLEnv):
|
||||
|
Loading…
Reference in New Issue
Block a user