remove unused code and fix coding conventions
This commit is contained in:
parent d60a166fbf
commit 0475b7cb18

@@ -71,10 +71,6 @@ class Base3ActionRLEnv(gym.Env):
        self.history = None
        self.trade_history = []

-        self.r_t_change = 0.
-
-        self.returns_report = []
-
    def seed(self, seed: int = 1):
        self.np_random, seed = seeding.np_random(seed)
        return [seed]
@@ -101,9 +97,6 @@ class Base3ActionRLEnv(gym.Env):

        self._profits = [(self._start_tick, 1)]
        self.close_trade_profit = []
-        self.r_t_change = 0.
-
-        self.returns_report = []

        return self._get_observation()

@@ -73,11 +73,6 @@ class Base5ActionRLEnv(gym.Env):
        self.history = None
        self.trade_history = []

-        # self.A_t, self.B_t = 0.000639, 0.00001954
-        self.r_t_change = 0.
-
-        self.returns_report = []
-
    def seed(self, seed=None):
        self.np_random, seed = seeding.np_random(seed)
        return [seed]
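Note: both environments keep the stock gym seeding idiom seen in the seed() methods above: gym.utils.seeding.np_random(seed) returns a seeded random generator together with the seed that was actually applied. A minimal standalone sketch of that idiom, assuming the classic gym API used elsewhere in this commit (the MinimalEnv class and the sampling call are illustrative only):

from gym.utils import seeding


class MinimalEnv:
    def seed(self, seed=None):
        # Same pattern as Base3ActionRLEnv / Base5ActionRLEnv: keep the seeded
        # generator on the instance and report the seed that was used.
        self.np_random, seed = seeding.np_random(seed)
        return [seed]


env = MinimalEnv()
print(env.seed(1))                   # the applied seed, wrapped in a list
print(env.np_random.uniform(0, 1))   # deterministic draw from the seeded generator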
@@ -104,9 +99,6 @@ class Base5ActionRLEnv(gym.Env):

        self._profits = [(self._start_tick, 1)]
        self.close_trade_profit = []
-        self.r_t_change = 0.
-
-        self.returns_report = []

        return self._get_observation()

@@ -178,12 +170,6 @@ class Base5ActionRLEnv(gym.Env):

        return observation, step_reward, self._done, info

-    # def processState(self, state):
-    #     return state.to_numpy()
-
-    # def convert_mlp_Policy(self, obs_):
-    #     pass
-
    def _get_observation(self):
        return self.signal_features[(self._current_tick - self.window_size):self._current_tick]

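Note: _get_observation() above returns a sliding window of signal_features: the window_size rows that end at (and exclude) self._current_tick. A minimal sketch of that slice on a hypothetical two-column feature frame (column names and values are illustrative, not taken from the project):

import pandas as pd

# Stand-in for self.signal_features
signal_features = pd.DataFrame({"feature_a": range(10), "feature_b": range(10)})

window_size = 3
current_tick = 7

# Same positional slice as _get_observation(): rows 4, 5 and 6
observation = signal_features[(current_tick - window_size):current_tick]
assert len(observation) == window_size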
@@ -62,7 +62,6 @@ class ReinforcementLearningExample3ac(IStrategy):

        coin = pair.split('/')[0]

-
        if informative is None:
            informative = self.dp.get_pair_dataframe(pair, tf)

@@ -62,7 +62,6 @@ class ReinforcementLearningExample5ac(IStrategy):

        coin = pair.split('/')[0]

-
        if informative is None:
            informative = self.dp.get_pair_dataframe(pair, tf)

@@ -81,7 +81,10 @@ class ReinforcementLearningPPO_multiproc(BaseReinforcementLearningModel):
                             net_arch=[512, 512, 512])

        model = PPO('MlpPolicy', train_env, policy_kwargs=policy_kwargs,
-                    tensorboard_log=f"{path}/ppo/tensorboard/", learning_rate=learning_rate, gamma=0.9, verbose=1
+                    tensorboard_log=f"{path}/ppo/tensorboard/",
+                    learning_rate=learning_rate,
+                    gamma=0.9,
+                    verbose=1
                    )

        model.learn(
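Note: the reformatted block above follows the usual stable-baselines3 flow of constructing a PPO model on the training env and then calling learn(). A self-contained sketch of that flow, assuming a placeholder gym env, learning rate, log path and timestep budget rather than the project's real values:

import gym
import torch as th
from stable_baselines3 import PPO

train_env = gym.make("CartPole-v1")  # placeholder; the project uses a custom trading env

policy_kwargs = dict(activation_fn=th.nn.ReLU,
                     net_arch=[512, 512, 512])

model = PPO('MlpPolicy', train_env, policy_kwargs=policy_kwargs,
            tensorboard_log="/tmp/ppo/tensorboard/",
            learning_rate=0.0003,
            gamma=0.9,
            verbose=1
            )

# Training is kicked off the same way as in the diff's model.learn(...) call.
model.learn(total_timesteps=10_000)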
@@ -4,7 +4,8 @@ import torch as th
import numpy as np
import gym
from typing import Callable
-from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnNoModelImprovement, StopTrainingOnRewardThreshold
+from stable_baselines3.common.callbacks import (
+    EvalCallback, StopTrainingOnNoModelImprovement, StopTrainingOnRewardThreshold)
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.utils import set_random_seed
@@ -18,6 +19,7 @@ from freqtrade.freqai.data_kitchen import FreqaiDataKitchen

logger = logging.getLogger(__name__)

+
def make_env(env_id: str, rank: int, seed: int, train_df, price,
             reward_params, window_size, monitor=False) -> Callable:
    """
@@ -39,6 +41,7 @@ def make_env(env_id: str, rank: int, seed: int, train_df, price,
        set_random_seed(seed)
    return _init

+
class ReinforcementLearningTDQN_multiproc(BaseReinforcementLearningModel):
    """
    User created Reinforcement Learning Model prediction model.
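Note: make_env() above returns the inner _init closure rather than an environment instance; SubprocVecEnv expects a list of such zero-argument factories so that each worker process constructs its own env. A rough sketch of that pattern under a placeholder env id (the project's real factory also passes the training dataframe, prices, reward parameters and window size):

from typing import Callable

import gym
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.vec_env import SubprocVecEnv


def make_env(env_id: str, rank: int, seed: int) -> Callable:
    """Return a factory that builds one env per subprocess."""
    def _init() -> gym.Env:
        env = gym.make(env_id)   # placeholder for the custom trading env
        env.seed(seed + rank)    # classic gym seeding, one distinct seed per worker
        return env
    set_random_seed(seed)
    return _init


if __name__ == "__main__":
    num_cpu = 4
    # Mirrors the list comprehension in the diff: one factory per worker.
    train_env = SubprocVecEnv([make_env("CartPole-v1", i, seed=42) for i in range(num_cpu)])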
@@ -69,11 +72,22 @@ class ReinforcementLearningTDQN_multiproc(BaseReinforcementLearningModel):
                                             self.CONV_WIDTH, monitor=True) for i in range(num_cpu)])

        path = dk.data_path
-        stop_train_callback = StopTrainingOnNoModelImprovement(max_no_improvement_evals=5, min_evals=10, verbose=2)
+        stop_train_callback = StopTrainingOnNoModelImprovement(
+            max_no_improvement_evals=5,
+            min_evals=10,
+            verbose=2
+        )
        callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-200, verbose=2)
-        eval_callback = EvalCallback(eval_env, best_model_save_path=f"{path}/",
-                                     log_path=f"{path}/tdqn/logs/", eval_freq=int(eval_freq),
-                                     deterministic=True, render=True, callback_after_eval=stop_train_callback, callback_on_new_best=callback_on_best, verbose=2)
+        eval_callback = EvalCallback(
+            eval_env, best_model_save_path=f"{path}/",
+            log_path=f"{path}/tdqn/logs/",
+            eval_freq=int(eval_freq),
+            deterministic=True,
+            render=True,
+            callback_after_eval=stop_train_callback,
+            callback_on_new_best=callback_on_best,
+            verbose=2
+        )
        # model arch
        policy_kwargs = dict(activation_fn=th.nn.ReLU,
                             net_arch=[512, 512, 512])
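Note: callbacks like the EvalCallback assembled above only take effect once they are handed to learn(); the two stop-training callbacks hang off the evaluation callback rather than being passed to the model directly. A hedged sketch of that wiring with a stock DQN model and placeholder envs, paths and timestep budget (the project's TDQN model itself is not shown):

import gym
from stable_baselines3 import DQN
from stable_baselines3.common.callbacks import (
    EvalCallback, StopTrainingOnNoModelImprovement, StopTrainingOnRewardThreshold)
from stable_baselines3.common.monitor import Monitor

train_env = Monitor(gym.make("CartPole-v1"))  # placeholder envs
eval_env = Monitor(gym.make("CartPole-v1"))

stop_train_callback = StopTrainingOnNoModelImprovement(
    max_no_improvement_evals=5,
    min_evals=10,
    verbose=2
)
callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-200, verbose=2)
eval_callback = EvalCallback(
    eval_env, best_model_save_path="/tmp/best/",
    log_path="/tmp/logs/",
    eval_freq=500,
    deterministic=True,
    callback_after_eval=stop_train_callback,
    callback_on_new_best=callback_on_best,
    verbose=2
)

model = DQN("MlpPolicy", train_env, verbose=1)
# The evaluation / early-stopping machinery runs only because it is passed here.
model.learn(total_timesteps=10_000, callback=eval_callback)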