add monitor to eval env so that multiproc can save best_model

robcaulk 2022-08-15 18:56:53 +02:00
parent 69d542d3e2
commit dd382dd370


@@ -6,6 +6,7 @@ import numpy as np
 # import pandas as pd
 import torch as th
 # from pandas import DataFrame
+from stable_baselines3.common.monitor import Monitor
 from typing import Callable
 from stable_baselines3 import PPO
 from stable_baselines3.common.callbacks import EvalCallback
@@ -20,7 +21,7 @@ logger = logging.getLogger(__name__)
 
 
 def make_env(env_id: str, rank: int, seed: int, train_df, price,
-             reward_params, window_size) -> Callable:
+             reward_params, window_size, monitor=False) -> Callable:
     """
     Utility function for multiprocessed env.
 
@@ -34,6 +35,8 @@ def make_env(env_id: str, rank: int, seed: int, train_df, price,
 
         env = MyRLEnv(df=train_df, prices=price, window_size=window_size,
                       reward_kwargs=reward_params, id=env_id, seed=seed + rank)
+        if monitor:
+            env = Monitor(env, ".")
         return env
     set_random_seed(seed)
     return _init
@@ -66,7 +69,7 @@ class ReinforcementLearningPPO_multiproc(BaseReinforcementLearningModel):
         eval_env_id = 'eval_env'
         eval_env = SubprocVecEnv([make_env(eval_env_id, i, 1, test_df, price_test, reward_params,
-                                           self.CONV_WIDTH) for i in range(num_cpu)])
+                                           self.CONV_WIDTH, monitor=True) for i in range(num_cpu)])
 
         path = dk.data_path
         eval_callback = EvalCallback(eval_env, best_model_save_path=f"{path}/",
@@ -86,12 +89,11 @@ class ReinforcementLearningPPO_multiproc(BaseReinforcementLearningModel):
            callback=eval_callback
            )
-        # TODO get callback working so the best model is saved. For now we save last model
-        # best_model = PPO.load(dk.data_path / "best_model.zip")
+        best_model = PPO.load(dk.data_path / "best_model.zip")
 
         print('Training finished!')
         eval_env.close()
 
-        return model  # best_model
+        return best_model
 
 
 class MyRLEnv(Base3ActionRLEnv):
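
For context, a minimal standalone sketch of the same Monitor + EvalCallback wiring. This is an assumption-laden illustration, not the project's code: it targets stable-baselines3 1.x with classic gym, uses "CartPole-v1" as a stand-in for MyRLEnv, and drops the freqtrade-specific arguments (train_df, price, reward_params, CONV_WIDTH, dk.data_path). The point it demonstrates is that each eval worker env must be wrapped in Monitor before vectorization so EvalCallback sees episode statistics and can write best_model.zip.

    # Sketch only: CartPole-v1 stands in for MyRLEnv; paths/hyperparameters are placeholders.
    import gym

    from stable_baselines3 import PPO
    from stable_baselines3.common.callbacks import EvalCallback
    from stable_baselines3.common.monitor import Monitor
    from stable_baselines3.common.utils import set_random_seed
    from stable_baselines3.common.vec_env import SubprocVecEnv


    def make_env(env_id: str, rank: int, seed: int = 0, monitor: bool = False):
        def _init():
            env = gym.make(env_id)
            if monitor:
                # Monitor records per-episode reward/length; EvalCallback needs
                # these statistics to decide when to overwrite best_model.zip.
                env = Monitor(env, ".")
            return env
        set_random_seed(seed)
        return _init


    if __name__ == "__main__":
        num_cpu = 4
        train_env = SubprocVecEnv([make_env("CartPole-v1", i) for i in range(num_cpu)])
        # Only the eval workers are wrapped in Monitor, mirroring monitor=True above.
        eval_env = SubprocVecEnv([make_env("CartPole-v1", i, seed=1, monitor=True)
                                  for i in range(num_cpu)])

        eval_callback = EvalCallback(eval_env, best_model_save_path="./",
                                     eval_freq=500, deterministic=True, render=False)

        model = PPO("MlpPolicy", train_env, verbose=0)
        model.learn(total_timesteps=10_000, callback=eval_callback)

        # EvalCallback wrote the best checkpoint into best_model_save_path.
        best_model = PPO.load("best_model.zip")
        eval_env.close()
        train_env.close()

Without the Monitor wrapper, SubprocVecEnv workers return no episode-reward info, so EvalCallback cannot rank evaluations and never saves a best model, which is what this commit fixes.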