remove unused code and fix coding conventions
This commit is contained in:
@@ -81,7 +81,10 @@ class ReinforcementLearningPPO_multiproc(BaseReinforcementLearningModel):
|
||||
net_arch=[512, 512, 512])
|
||||
|
||||
model = PPO('MlpPolicy', train_env, policy_kwargs=policy_kwargs,
|
||||
tensorboard_log=f"{path}/ppo/tensorboard/", learning_rate=learning_rate, gamma=0.9, verbose=1
|
||||
tensorboard_log=f"{path}/ppo/tensorboard/",
|
||||
learning_rate=learning_rate,
|
||||
gamma=0.9,
|
||||
verbose=1
|
||||
)
|
||||
|
||||
model.learn(
|
||||
|
@@ -4,7 +4,8 @@ import torch as th
|
||||
import numpy as np
|
||||
import gym
|
||||
from typing import Callable
|
||||
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnNoModelImprovement, StopTrainingOnRewardThreshold
|
||||
from stable_baselines3.common.callbacks import (
|
||||
EvalCallback, StopTrainingOnNoModelImprovement, StopTrainingOnRewardThreshold)
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.vec_env import SubprocVecEnv
|
||||
from stable_baselines3.common.utils import set_random_seed
|
||||
@@ -18,6 +19,7 @@ from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def make_env(env_id: str, rank: int, seed: int, train_df, price,
|
||||
reward_params, window_size, monitor=False) -> Callable:
|
||||
"""
|
||||
@@ -39,6 +41,7 @@ def make_env(env_id: str, rank: int, seed: int, train_df, price,
|
||||
set_random_seed(seed)
|
||||
return _init
|
||||
|
||||
|
||||
class ReinforcementLearningTDQN_multiproc(BaseReinforcementLearningModel):
|
||||
"""
|
||||
User created Reinforcement Learning Model prediction model.
|
||||
@@ -69,11 +72,22 @@ class ReinforcementLearningTDQN_multiproc(BaseReinforcementLearningModel):
|
||||
self.CONV_WIDTH, monitor=True) for i in range(num_cpu)])
|
||||
|
||||
path = dk.data_path
|
||||
stop_train_callback = StopTrainingOnNoModelImprovement(max_no_improvement_evals=5, min_evals=10, verbose=2)
|
||||
stop_train_callback = StopTrainingOnNoModelImprovement(
|
||||
max_no_improvement_evals=5,
|
||||
min_evals=10,
|
||||
verbose=2
|
||||
)
|
||||
callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-200, verbose=2)
|
||||
eval_callback = EvalCallback(eval_env, best_model_save_path=f"{path}/",
|
||||
log_path=f"{path}/tdqn/logs/", eval_freq=int(eval_freq),
|
||||
deterministic=True, render=True, callback_after_eval=stop_train_callback, callback_on_new_best=callback_on_best, verbose=2)
|
||||
eval_callback = EvalCallback(
|
||||
eval_env, best_model_save_path=f"{path}/",
|
||||
log_path=f"{path}/tdqn/logs/",
|
||||
eval_freq=int(eval_freq),
|
||||
deterministic=True,
|
||||
render=True,
|
||||
callback_after_eval=stop_train_callback,
|
||||
callback_on_new_best=callback_on_best,
|
||||
verbose=2
|
||||
)
|
||||
# model arch
|
||||
policy_kwargs = dict(activation_fn=th.nn.ReLU,
|
||||
net_arch=[512, 512, 512])
|
||||
|
Reference in New Issue
Block a user