remove unused code and fix coding conventions
commit 0475b7cb18
parent d60a166fbf
@@ -71,10 +71,6 @@ class Base3ActionRLEnv(gym.Env):
         self.history = None
         self.trade_history = []
 
-        self.r_t_change = 0.
-
-        self.returns_report = []
-
     def seed(self, seed: int = 1):
         self.np_random, seed = seeding.np_random(seed)
         return [seed]
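Note: the seed() method kept as context above is the standard gym seeding idiom. A minimal self-contained sketch of that idiom (DummyEnv and its action/observation spaces are illustrative, not part of this commit):

import gym
from gym import spaces
from gym.utils import seeding


class DummyEnv(gym.Env):
    # Illustrative env showing the seeding pattern used by the base RL envs.

    def __init__(self):
        self.action_space = spaces.Discrete(3)
        self.observation_space = spaces.Box(low=-1, high=1, shape=(4,))
        self.np_random = None

    def seed(self, seed: int = 1):
        # seeding.np_random returns a seeded RNG plus the seed actually used
        self.np_random, seed = seeding.np_random(seed)
        return [seed]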
@@ -101,9 +97,6 @@ class Base3ActionRLEnv(gym.Env):
 
         self._profits = [(self._start_tick, 1)]
         self.close_trade_profit = []
-        self.r_t_change = 0.
-
-        self.returns_report = []
 
         return self._get_observation()
 
@@ -73,11 +73,6 @@ class Base5ActionRLEnv(gym.Env):
         self.history = None
         self.trade_history = []
 
-        # self.A_t, self.B_t = 0.000639, 0.00001954
-        self.r_t_change = 0.
-
-        self.returns_report = []
-
     def seed(self, seed=None):
         self.np_random, seed = seeding.np_random(seed)
         return [seed]
@@ -104,9 +99,6 @@ class Base5ActionRLEnv(gym.Env):
 
         self._profits = [(self._start_tick, 1)]
         self.close_trade_profit = []
-        self.r_t_change = 0.
-
-        self.returns_report = []
 
         return self._get_observation()
 
@@ -178,12 +170,6 @@ class Base5ActionRLEnv(gym.Env):
 
         return observation, step_reward, self._done, info
 
-    # def processState(self, state):
-    #     return state.to_numpy()
-
-    # def convert_mlp_Policy(self, obs_):
-    #     pass
-
     def _get_observation(self):
         return self.signal_features[(self._current_tick - self.window_size):self._current_tick]
 
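Note: _get_observation() above returns a rolling window of the precomputed signal features. A rough standalone sketch of that slicing, assuming the features live in a pandas DataFrame with a default integer index (column names and sizes are illustrative):

import pandas as pd

window_size = 10
current_tick = 50
# stand-in for self.signal_features
signal_features = pd.DataFrame({"rsi": range(100), "ema": range(100)})

# last `window_size` rows up to (and excluding) the current tick, as in the env
observation = signal_features[(current_tick - window_size):current_tick]
print(observation.shape)  # (10, 2)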
|
@@ -62,7 +62,6 @@ class ReinforcementLearningExample3ac(IStrategy):
 
         coin = pair.split('/')[0]
 
-
         if informative is None:
             informative = self.dp.get_pair_dataframe(pair, tf)
 
@@ -62,7 +62,6 @@ class ReinforcementLearningExample5ac(IStrategy):
 
         coin = pair.split('/')[0]
 
-
         if informative is None:
             informative = self.dp.get_pair_dataframe(pair, tf)
 
@@ -81,7 +81,10 @@ class ReinforcementLearningPPO_multiproc(BaseReinforcementLearningModel):
                              net_arch=[512, 512, 512])
 
         model = PPO('MlpPolicy', train_env, policy_kwargs=policy_kwargs,
-                    tensorboard_log=f"{path}/ppo/tensorboard/", learning_rate=learning_rate, gamma=0.9, verbose=1
+                    tensorboard_log=f"{path}/ppo/tensorboard/",
+                    learning_rate=learning_rate,
+                    gamma=0.9,
+                    verbose=1
                     )
 
         model.learn(
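Note: the reflowed PPO call keeps the same keyword arguments. A self-contained sketch of the same constructor shape, using CartPole as a stand-in for the freqtrade training env (the tensorboard path and learning-rate value are illustrative):

import gym
import torch as th
from stable_baselines3 import PPO

train_env = gym.make("CartPole-v1")
policy_kwargs = dict(activation_fn=th.nn.ReLU,
                     net_arch=[512, 512, 512])

model = PPO('MlpPolicy', train_env, policy_kwargs=policy_kwargs,
            tensorboard_log="./ppo/tensorboard/",
            learning_rate=0.00025,
            gamma=0.9,
            verbose=1
            )
model.learn(total_timesteps=10_000)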
|
@@ -4,7 +4,8 @@ import torch as th
 import numpy as np
 import gym
 from typing import Callable
-from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnNoModelImprovement, StopTrainingOnRewardThreshold
+from stable_baselines3.common.callbacks import (
+    EvalCallback, StopTrainingOnNoModelImprovement, StopTrainingOnRewardThreshold)
 from stable_baselines3.common.monitor import Monitor
 from stable_baselines3.common.vec_env import SubprocVecEnv
 from stable_baselines3.common.utils import set_random_seed
@@ -18,6 +19,7 @@ from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
 
 logger = logging.getLogger(__name__)
 
+
 def make_env(env_id: str, rank: int, seed: int, train_df, price,
              reward_params, window_size, monitor=False) -> Callable:
     """
@@ -39,6 +41,7 @@ def make_env(env_id: str, rank: int, seed: int, train_df, price,
     set_random_seed(seed)
     return _init
 
+
 class ReinforcementLearningTDQN_multiproc(BaseReinforcementLearningModel):
     """
     User created Reinforcement Learning Model prediction model.
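Note: make_env() here is the usual stable-baselines3 multiprocessing pattern: each rank builds and seeds its own environment inside a closure, optionally wraps it in Monitor, and SubprocVecEnv runs the closures in worker processes. A simplified sketch of that pattern with a toy env id (the freqtrade-specific arguments train_df, price, reward_params and window_size are omitted):

import gym
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.vec_env import SubprocVecEnv


def make_env(env_id: str, rank: int, seed: int = 0, monitor: bool = False):
    # Return a closure that builds one seeded environment instance.
    def _init():
        env = gym.make(env_id)
        env.seed(seed + rank)
        if monitor:
            env = Monitor(env)
        return env
    set_random_seed(seed)
    return _init


if __name__ == "__main__":
    num_cpu = 4
    train_env = SubprocVecEnv([make_env("CartPole-v1", i, seed=42, monitor=True)
                               for i in range(num_cpu)])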
@@ -69,11 +72,22 @@ class ReinforcementLearningTDQN_multiproc(BaseReinforcementLearningModel):
                                   self.CONV_WIDTH, monitor=True) for i in range(num_cpu)])
 
         path = dk.data_path
-        stop_train_callback = StopTrainingOnNoModelImprovement(max_no_improvement_evals=5, min_evals=10, verbose=2)
+        stop_train_callback = StopTrainingOnNoModelImprovement(
+            max_no_improvement_evals=5,
+            min_evals=10,
+            verbose=2
+        )
         callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-200, verbose=2)
-        eval_callback = EvalCallback(eval_env, best_model_save_path=f"{path}/",
-                                     log_path=f"{path}/tdqn/logs/", eval_freq=int(eval_freq),
-                                     deterministic=True, render=True, callback_after_eval=stop_train_callback, callback_on_new_best=callback_on_best, verbose=2)
+        eval_callback = EvalCallback(
+            eval_env, best_model_save_path=f"{path}/",
+            log_path=f"{path}/tdqn/logs/",
+            eval_freq=int(eval_freq),
+            deterministic=True,
+            render=True,
+            callback_after_eval=stop_train_callback,
+            callback_on_new_best=callback_on_best,
+            verbose=2
+        )
         # model arch
         policy_kwargs = dict(activation_fn=th.nn.ReLU,
                              net_arch=[512, 512, 512])
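Note: the two stopping callbacks built above are chained through EvalCallback and then passed to model.learn(). A minimal end-to-end sketch with a toy env and stable-baselines3's stock DQN standing in for the project's TDQN (paths, eval_freq and the reward threshold are illustrative):

import gym
from stable_baselines3 import DQN
from stable_baselines3.common.callbacks import (
    EvalCallback, StopTrainingOnNoModelImprovement, StopTrainingOnRewardThreshold)

train_env = gym.make("CartPole-v1")
eval_env = gym.make("CartPole-v1")

# stop if evaluation reward has not improved for 5 consecutive evals (after at least 10)
stop_train_callback = StopTrainingOnNoModelImprovement(
    max_no_improvement_evals=5, min_evals=10, verbose=2)
# stop early once the best eval reward crosses the threshold
callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=195, verbose=2)
eval_callback = EvalCallback(
    eval_env, best_model_save_path="./best/",
    log_path="./logs/",
    eval_freq=500,
    deterministic=True,
    render=False,
    callback_after_eval=stop_train_callback,
    callback_on_new_best=callback_on_best,
    verbose=2
)

model = DQN('MlpPolicy', train_env, verbose=1)
model.learn(total_timesteps=10_000, callback=eval_callback)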
|