remove unuse code and fix coding conventions

This commit is contained in:
sonnhfit
2022-08-16 09:30:35 +07:00
committed by robcaulk
parent d60a166fbf
commit 0475b7cb18
6 changed files with 23 additions and 29 deletions

View File

@@ -81,7 +81,10 @@ class ReinforcementLearningPPO_multiproc(BaseReinforcementLearningModel):
net_arch=[512, 512, 512])
model = PPO('MlpPolicy', train_env, policy_kwargs=policy_kwargs,
tensorboard_log=f"{path}/ppo/tensorboard/", learning_rate=learning_rate, gamma=0.9, verbose=1
tensorboard_log=f"{path}/ppo/tensorboard/",
learning_rate=learning_rate,
gamma=0.9,
verbose=1
)
model.learn(

View File

@@ -4,7 +4,8 @@ import torch as th
import numpy as np
import gym
from typing import Callable
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnNoModelImprovement, StopTrainingOnRewardThreshold
from stable_baselines3.common.callbacks import (
EvalCallback, StopTrainingOnNoModelImprovement, StopTrainingOnRewardThreshold)
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.utils import set_random_seed
@@ -18,6 +19,7 @@ from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
logger = logging.getLogger(__name__)
def make_env(env_id: str, rank: int, seed: int, train_df, price,
reward_params, window_size, monitor=False) -> Callable:
"""
@@ -39,6 +41,7 @@ def make_env(env_id: str, rank: int, seed: int, train_df, price,
set_random_seed(seed)
return _init
class ReinforcementLearningTDQN_multiproc(BaseReinforcementLearningModel):
"""
User created Reinforcement Learning Model prediction model.
@@ -69,11 +72,22 @@ class ReinforcementLearningTDQN_multiproc(BaseReinforcementLearningModel):
self.CONV_WIDTH, monitor=True) for i in range(num_cpu)])
path = dk.data_path
stop_train_callback = StopTrainingOnNoModelImprovement(max_no_improvement_evals=5, min_evals=10, verbose=2)
stop_train_callback = StopTrainingOnNoModelImprovement(
max_no_improvement_evals=5,
min_evals=10,
verbose=2
)
callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-200, verbose=2)
eval_callback = EvalCallback(eval_env, best_model_save_path=f"{path}/",
log_path=f"{path}/tdqn/logs/", eval_freq=int(eval_freq),
deterministic=True, render=True, callback_after_eval=stop_train_callback, callback_on_new_best=callback_on_best, verbose=2)
eval_callback = EvalCallback(
eval_env, best_model_save_path=f"{path}/",
log_path=f"{path}/tdqn/logs/",
eval_freq=int(eval_freq),
deterministic=True,
render=True,
callback_after_eval=stop_train_callback,
callback_on_new_best=callback_on_best,
verbose=2
)
# model arch
policy_kwargs = dict(activation_fn=th.nn.ReLU,
net_arch=[512, 512, 512])