remove unuse code and fix coding conventions

This commit is contained in:
sonnhfit 2022-08-16 09:30:35 +07:00 committed by robcaulk
parent d60a166fbf
commit 0475b7cb18
6 changed files with 23 additions and 29 deletions

View File

@ -71,10 +71,6 @@ class Base3ActionRLEnv(gym.Env):
self.history = None
self.trade_history = []
self.r_t_change = 0.
self.returns_report = []
def seed(self, seed: int = 1):
self.np_random, seed = seeding.np_random(seed)
return [seed]
@ -101,9 +97,6 @@ class Base3ActionRLEnv(gym.Env):
self._profits = [(self._start_tick, 1)]
self.close_trade_profit = []
self.r_t_change = 0.
self.returns_report = []
return self._get_observation()

View File

@ -73,11 +73,6 @@ class Base5ActionRLEnv(gym.Env):
self.history = None
self.trade_history = []
# self.A_t, self.B_t = 0.000639, 0.00001954
self.r_t_change = 0.
self.returns_report = []
def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]
@ -104,9 +99,6 @@ class Base5ActionRLEnv(gym.Env):
self._profits = [(self._start_tick, 1)]
self.close_trade_profit = []
self.r_t_change = 0.
self.returns_report = []
return self._get_observation()
@ -178,12 +170,6 @@ class Base5ActionRLEnv(gym.Env):
return observation, step_reward, self._done, info
# def processState(self, state):
# return state.to_numpy()
# def convert_mlp_Policy(self, obs_):
# pass
def _get_observation(self):
return self.signal_features[(self._current_tick - self.window_size):self._current_tick]

View File

@ -62,7 +62,6 @@ class ReinforcementLearningExample3ac(IStrategy):
coin = pair.split('/')[0]
if informative is None:
informative = self.dp.get_pair_dataframe(pair, tf)

View File

@ -62,7 +62,6 @@ class ReinforcementLearningExample5ac(IStrategy):
coin = pair.split('/')[0]
if informative is None:
informative = self.dp.get_pair_dataframe(pair, tf)

View File

@ -81,7 +81,10 @@ class ReinforcementLearningPPO_multiproc(BaseReinforcementLearningModel):
net_arch=[512, 512, 512])
model = PPO('MlpPolicy', train_env, policy_kwargs=policy_kwargs,
tensorboard_log=f"{path}/ppo/tensorboard/", learning_rate=learning_rate, gamma=0.9, verbose=1
tensorboard_log=f"{path}/ppo/tensorboard/",
learning_rate=learning_rate,
gamma=0.9,
verbose=1
)
model.learn(

View File

@ -4,7 +4,8 @@ import torch as th
import numpy as np
import gym
from typing import Callable
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnNoModelImprovement, StopTrainingOnRewardThreshold
from stable_baselines3.common.callbacks import (
EvalCallback, StopTrainingOnNoModelImprovement, StopTrainingOnRewardThreshold)
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.utils import set_random_seed
@ -18,6 +19,7 @@ from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
logger = logging.getLogger(__name__)
def make_env(env_id: str, rank: int, seed: int, train_df, price,
reward_params, window_size, monitor=False) -> Callable:
"""
@ -39,6 +41,7 @@ def make_env(env_id: str, rank: int, seed: int, train_df, price,
set_random_seed(seed)
return _init
class ReinforcementLearningTDQN_multiproc(BaseReinforcementLearningModel):
"""
User created Reinforcement Learning Model prediction model.
@ -69,11 +72,22 @@ class ReinforcementLearningTDQN_multiproc(BaseReinforcementLearningModel):
self.CONV_WIDTH, monitor=True) for i in range(num_cpu)])
path = dk.data_path
stop_train_callback = StopTrainingOnNoModelImprovement(max_no_improvement_evals=5, min_evals=10, verbose=2)
stop_train_callback = StopTrainingOnNoModelImprovement(
max_no_improvement_evals=5,
min_evals=10,
verbose=2
)
callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-200, verbose=2)
eval_callback = EvalCallback(eval_env, best_model_save_path=f"{path}/",
log_path=f"{path}/tdqn/logs/", eval_freq=int(eval_freq),
deterministic=True, render=True, callback_after_eval=stop_train_callback, callback_on_new_best=callback_on_best, verbose=2)
eval_callback = EvalCallback(
eval_env, best_model_save_path=f"{path}/",
log_path=f"{path}/tdqn/logs/",
eval_freq=int(eval_freq),
deterministic=True,
render=True,
callback_after_eval=stop_train_callback,
callback_on_new_best=callback_on_best,
verbose=2
)
# model arch
policy_kwargs = dict(activation_fn=th.nn.ReLU,
net_arch=[512, 512, 512])