learning_rate + multicpu changes
This commit is contained in:
parent
48bb51b458
commit
57c488a6f1
@ -50,19 +50,22 @@ class ReinforcementLearningPPO_multiproc(BaseReinforcementLearningModel):
|
||||
test_df = data_dictionary["test_features"]
|
||||
eval_freq = agent_params.get("eval_cycles", 4) * len(test_df)
|
||||
total_timesteps = agent_params["train_cycles"] * len(train_df)
|
||||
learning_rate = agent_params["learning_rate"]
|
||||
|
||||
# price data for model training and evaluation
|
||||
price = self.dd.historic_data[pair][f"{self.config['timeframe']}"].tail(len(train_df.index))
|
||||
price_test = self.dd.historic_data[pair][f"{self.config['timeframe']}"].tail(
|
||||
len(test_df.index))
|
||||
|
||||
env_id = "CartPole-v1"
|
||||
num_cpu = 4
|
||||
env_id = "train_env"
|
||||
train_num_cpu = 6
|
||||
train_env = SubprocVecEnv([make_env(env_id, i, 1, train_df, price, reward_params,
|
||||
self.CONV_WIDTH) for i in range(num_cpu)])
|
||||
self.CONV_WIDTH) for i in range(train_num_cpu)])
|
||||
|
||||
eval_env = SubprocVecEnv([make_env(env_id, i, 1, test_df, price_test, reward_params,
|
||||
self.CONV_WIDTH) for i in range(num_cpu)])
|
||||
eval_num_cpu = 6
|
||||
eval_env_id = 'eval_env'
|
||||
eval_env = SubprocVecEnv([make_env(eval_env_id, i, 1, test_df, price_test, reward_params,
|
||||
self.CONV_WIDTH) for i in range(eval_num_cpu)])
|
||||
|
||||
path = self.dk.data_path
|
||||
eval_callback = EvalCallback(eval_env, best_model_save_path=f"{path}/",
|
||||
@ -71,10 +74,10 @@ class ReinforcementLearningPPO_multiproc(BaseReinforcementLearningModel):
|
||||
|
||||
# model arch
|
||||
policy_kwargs = dict(activation_fn=th.nn.ReLU,
|
||||
net_arch=[256, 256, 128])
|
||||
net_arch=[512, 512, 512])
|
||||
|
||||
model = PPO('MlpPolicy', train_env, policy_kwargs=policy_kwargs,
|
||||
tensorboard_log=f"{path}/ppo/tensorboard/", learning_rate=0.00025, gamma=0.9, verbose=1
|
||||
tensorboard_log=f"{path}/ppo/tensorboard/", learning_rate=learning_rate, gamma=0.9, verbose=1
|
||||
)
|
||||
|
||||
model.learn(
|
||||
@ -83,6 +86,7 @@ class ReinforcementLearningPPO_multiproc(BaseReinforcementLearningModel):
|
||||
)
|
||||
|
||||
print('Training finished!')
|
||||
eval_env.close()
|
||||
|
||||
return model
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user