add monitor to eval env so that multiproc can save best_model
This commit is contained in:
parent 69d542d3e2
commit dd382dd370
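Background, not part of the commit itself: stable-baselines3's EvalCallback can only decide which checkpoint is "best" if the evaluation environment reports per-episode reward and length, which is what the Monitor wrapper records; without it the callback has nothing reliable to rank, so best_model.zip never appears. A minimal sketch of that interaction, assuming a plain gym CartPole environment and an SB3 release that still uses gym (the real code below uses MyRLEnv inside SubprocVecEnv):

import gym

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor

# Monitor records per-episode reward and length; EvalCallback compares the
# mean evaluation reward against the best seen so far and, when it improves,
# writes best_model.zip into best_model_save_path.
eval_env = Monitor(gym.make("CartPole-v1"), ".")
eval_callback = EvalCallback(eval_env, best_model_save_path="./",
                             eval_freq=1000, deterministic=True)

model = PPO("MlpPolicy", gym.make("CartPole-v1"), verbose=0)
model.learn(total_timesteps=5000, callback=eval_callback)

best_model = PPO.load("./best_model.zip")  # saved by the callback above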
@@ -6,6 +6,7 @@ import numpy as np
 # import pandas as pd
 import torch as th
 # from pandas import DataFrame
+from stable_baselines3.common.monitor import Monitor
 from typing import Callable
 from stable_baselines3 import PPO
 from stable_baselines3.common.callbacks import EvalCallback
@@ -20,7 +21,7 @@ logger = logging.getLogger(__name__)
 
 
 def make_env(env_id: str, rank: int, seed: int, train_df, price,
-             reward_params, window_size) -> Callable:
+             reward_params, window_size, monitor=False) -> Callable:
     """
     Utility function for multiprocessed env.
 
@@ -34,6 +35,8 @@ def make_env(env_id: str, rank: int, seed: int, train_df, price,
 
         env = MyRLEnv(df=train_df, prices=price, window_size=window_size,
                       reward_kwargs=reward_params, id=env_id, seed=seed + rank)
+        if monitor:
+            env = Monitor(env, ".")
         return env
     set_random_seed(seed)
     return _init
@@ -66,7 +69,7 @@ class ReinforcementLearningPPO_multiproc(BaseReinforcementLearningModel):
 
         eval_env_id = 'eval_env'
         eval_env = SubprocVecEnv([make_env(eval_env_id, i, 1, test_df, price_test, reward_params,
-                                           self.CONV_WIDTH) for i in range(num_cpu)])
+                                           self.CONV_WIDTH, monitor=True) for i in range(num_cpu)])
 
         path = dk.data_path
         eval_callback = EvalCallback(eval_env, best_model_save_path=f"{path}/",
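For reference, the factory/flag pattern above can be reproduced outside freqtrade. The sketch below is a stand-in, with a hypothetical make_cartpole in place of make_env/MyRLEnv and the old gym seeding API assumed: training workers stay unwrapped, while every evaluation worker is wrapped in Monitor so the EvalCallback created above has per-episode statistics to rank.

from typing import Callable

import gym

from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.vec_env import SubprocVecEnv

def make_cartpole(rank: int, seed: int = 0, monitor: bool = False) -> Callable:
    # Same shape as make_env above: return a factory that each subprocess calls.
    def _init():
        env = gym.make("CartPole-v1")
        env.seed(seed + rank)
        if monitor:
            env = Monitor(env, ".")  # one .monitor.csv per worker
        return env
    set_random_seed(seed)
    return _init

if __name__ == "__main__":
    # Training workers unwrapped, evaluation workers wrapped in Monitor.
    train_env = SubprocVecEnv([make_cartpole(i) for i in range(4)])
    eval_env = SubprocVecEnv([make_cartpole(i, seed=1, monitor=True) for i in range(4)])
    train_env.close()
    eval_env.close()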
@@ -86,12 +89,11 @@ class ReinforcementLearningPPO_multiproc(BaseReinforcementLearningModel):
            callback=eval_callback
        )
 
-        # TODO get callback working so the best model is saved. For now we save last model
-        # best_model = PPO.load(dk.data_path / "best_model.zip")
+        best_model = PPO.load(dk.data_path / "best_model.zip")
         print('Training finished!')
         eval_env.close()
 
-        return model  # best_model
+        return best_model
 
 
 class MyRLEnv(Base3ActionRLEnv):
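A caveat the commit does not address: EvalCallback writes best_model.zip only after at least one evaluation round has run, so the unconditional PPO.load above can raise FileNotFoundError if training ends before eval_freq steps are reached. A hedged fallback sketch (load_best_or_fallback is a hypothetical helper, not part of the codebase):

from pathlib import Path

from stable_baselines3 import PPO

def load_best_or_fallback(data_path, fallback_model):
    """Load best_model.zip if EvalCallback has written it, else keep the trained model."""
    best_model_path = Path(data_path) / "best_model.zip"
    if best_model_path.is_file():
        return PPO.load(best_model_path)
    # No evaluation round completed yet (eval_freq never reached), so nothing
    # was saved; return the model that was just trained instead of failing.
    return fallback_model

It could be called as best_model = load_best_or_fallback(dk.data_path, model) in place of the unconditional load above.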