diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearningPPO_multiproc.py b/freqtrade/freqai/prediction_models/ReinforcementLearningPPO_multiproc.py
index 1b2873334..c00784d7a 100644
--- a/freqtrade/freqai/prediction_models/ReinforcementLearningPPO_multiproc.py
+++ b/freqtrade/freqai/prediction_models/ReinforcementLearningPPO_multiproc.py
@@ -50,19 +50,22 @@ class ReinforcementLearningPPO_multiproc(BaseReinforcementLearningModel):
         test_df = data_dictionary["test_features"]
         eval_freq = agent_params.get("eval_cycles", 4) * len(test_df)
         total_timesteps = agent_params["train_cycles"] * len(train_df)
+        learning_rate = agent_params["learning_rate"]
 
         # price data for model training and evaluation
         price = self.dd.historic_data[pair][f"{self.config['timeframe']}"].tail(len(train_df.index))
         price_test = self.dd.historic_data[pair][f"{self.config['timeframe']}"].tail(
             len(test_df.index))
 
-        env_id = "CartPole-v1"
-        num_cpu = 4
+        env_id = "train_env"
+        train_num_cpu = 6
         train_env = SubprocVecEnv([make_env(env_id, i, 1, train_df, price, reward_params,
-                                            self.CONV_WIDTH) for i in range(num_cpu)])
+                                            self.CONV_WIDTH) for i in range(train_num_cpu)])
 
-        eval_env = SubprocVecEnv([make_env(env_id, i, 1, test_df, price_test, reward_params,
-                                           self.CONV_WIDTH) for i in range(num_cpu)])
+        eval_num_cpu = 6
+        eval_env_id = 'eval_env'
+        eval_env = SubprocVecEnv([make_env(eval_env_id, i, 1, test_df, price_test, reward_params,
+                                           self.CONV_WIDTH) for i in range(eval_num_cpu)])
 
         path = self.dk.data_path
         eval_callback = EvalCallback(eval_env, best_model_save_path=f"{path}/",
@@ -71,10 +74,10 @@ class ReinforcementLearningPPO_multiproc(BaseReinforcementLearningModel):
 
         # model arch
         policy_kwargs = dict(activation_fn=th.nn.ReLU,
-                             net_arch=[256, 256, 128])
+                             net_arch=[512, 512, 512])
 
         model = PPO('MlpPolicy', train_env, policy_kwargs=policy_kwargs,
-                    tensorboard_log=f"{path}/ppo/tensorboard/", learning_rate=0.00025, gamma=0.9, verbose=1
+                    tensorboard_log=f"{path}/ppo/tensorboard/", learning_rate=learning_rate, gamma=0.9, verbose=1
                     )
 
         model.learn(
@@ -83,6 +86,7 @@ class ReinforcementLearningPPO_multiproc(BaseReinforcementLearningModel):
         )
 
         print('Training finished!')
+        eval_env.close()
 
         return model
 
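Note on the new hyperparameter plumbing: the added lookup uses direct indexing, so a config that omits the key raises KeyError, unlike eval_cycles, which falls back to 4 via .get(). Below is a minimal sketch of the agent_params shape this code consumes; the concrete values and the defensive fallback variant are assumptions for illustration, not part of this diff:

    # Hypothetical sketch (values assumed, not confirmed by this diff):
    # the keys read above as agent_params[...] / agent_params.get(...).
    agent_params = {
        "train_cycles": 10,        # total_timesteps = train_cycles * len(train_df)
        "eval_cycles": 4,          # eval_freq = eval_cycles * len(test_df)
        "learning_rate": 0.00025,  # required by this diff: direct indexing, no default
    }

    # Defensive variant mirroring the eval_cycles lookup, using the previously
    # hard-coded 0.00025 as the fallback (an assumption, not what this diff does):
    learning_rate = agent_params.get("learning_rate", 0.00025)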