From dd382dd3702cfe7edf2848adc9f7958d08ac62dc Mon Sep 17 00:00:00 2001
From: robcaulk
Date: Mon, 15 Aug 2022 18:56:53 +0200
Subject: [PATCH] add monitor to eval env so that multiproc can save
 best_model

---
 .../ReinforcementLearningPPO_multiproc.py     | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearningPPO_multiproc.py b/freqtrade/freqai/prediction_models/ReinforcementLearningPPO_multiproc.py
index 8370500b9..e8f67cbb8 100644
--- a/freqtrade/freqai/prediction_models/ReinforcementLearningPPO_multiproc.py
+++ b/freqtrade/freqai/prediction_models/ReinforcementLearningPPO_multiproc.py
@@ -6,6 +6,7 @@ import numpy as np
 # import pandas as pd
 import torch as th
 # from pandas import DataFrame
+from stable_baselines3.common.monitor import Monitor
 from typing import Callable
 from stable_baselines3 import PPO
 from stable_baselines3.common.callbacks import EvalCallback
@@ -20,7 +21,7 @@ logger = logging.getLogger(__name__)
 
 
 def make_env(env_id: str, rank: int, seed: int, train_df, price,
-             reward_params, window_size) -> Callable:
+             reward_params, window_size, monitor=False) -> Callable:
     """
     Utility function for multiprocessed env.
 
@@ -34,6 +35,8 @@ def make_env(env_id: str, rank: int, seed: int, train_df, price,
 
         env = MyRLEnv(df=train_df, prices=price, window_size=window_size,
                       reward_kwargs=reward_params, id=env_id, seed=seed + rank)
+        if monitor:
+            env = Monitor(env, ".")
         return env
     set_random_seed(seed)
     return _init
@@ -66,7 +69,7 @@ class ReinforcementLearningPPO_multiproc(BaseReinforcementLearningModel):
 
         eval_env_id = 'eval_env'
         eval_env = SubprocVecEnv([make_env(eval_env_id, i, 1, test_df, price_test, reward_params,
-                                           self.CONV_WIDTH) for i in range(num_cpu)])
+                                           self.CONV_WIDTH, monitor=True) for i in range(num_cpu)])
 
         path = dk.data_path
         eval_callback = EvalCallback(eval_env, best_model_save_path=f"{path}/",
@@ -86,12 +89,11 @@ class ReinforcementLearningPPO_multiproc(BaseReinforcementLearningModel):
             callback=eval_callback
         )
 
-        # TODO get callback working so the best model is saved. For now we save last model
-        # best_model = PPO.load(dk.data_path / "best_model.zip")
+        best_model = PPO.load(dk.data_path / "best_model.zip")
         print('Training finished!')
         eval_env.close()
 
-        return model  # best_model
+        return best_model
 
 
 class MyRLEnv(Base3ActionRLEnv):
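
Note (not part of the patch): the commit message ties the Monitor wrapper to EvalCallback being able to save best_model. EvalCallback relies on per-episode reward statistics to decide when to overwrite best_model.zip, and Monitor is the wrapper that records them; with SubprocVecEnv the wrapper must be applied inside each worker's factory function, which is exactly what the change above does. The sketch below shows the same pattern as a standalone script on a toy CartPole-v1 env; the env id, num_cpu, eval_freq and timestep counts are illustrative only, and it assumes stable-baselines3 1.x with classic gym (as in use at the time of this patch), not the freqtrade codebase.

    import gym
    from stable_baselines3 import PPO
    from stable_baselines3.common.callbacks import EvalCallback
    from stable_baselines3.common.monitor import Monitor
    from stable_baselines3.common.utils import set_random_seed
    from stable_baselines3.common.vec_env import SubprocVecEnv


    def make_env(env_id: str, rank: int, seed: int = 0, monitor: bool = False):
        """Factory returning a closure, one per SubprocVecEnv worker."""
        def _init():
            env = gym.make(env_id)
            if monitor:
                # Monitor records per-episode reward/length, which EvalCallback
                # uses to track performance and write best_model.zip
                env = Monitor(env, ".")
            return env
        set_random_seed(seed)
        return _init


    if __name__ == "__main__":
        num_cpu = 4
        train_env = SubprocVecEnv([make_env("CartPole-v1", i) for i in range(num_cpu)])
        # only the evaluation env needs the Monitor wrapper
        eval_env = SubprocVecEnv([make_env("CartPole-v1", i, monitor=True)
                                  for i in range(num_cpu)])

        eval_callback = EvalCallback(eval_env, best_model_save_path="./",
                                     eval_freq=500, deterministic=True)

        model = PPO("MlpPolicy", train_env, verbose=0)
        model.learn(total_timesteps=10_000, callback=eval_callback)

        # best_model.zip exists once at least one evaluation has run
        best_model = PPO.load("./best_model.zip")
        eval_env.close()
        train_env.close()

As in the patch, Monitor is given "." as its output location, so the monitor log is written to the current directory while best_model.zip goes to best_model_save_path.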