From 0475b7cb1838f46bcbb31771eca0b3c3cb6ed940 Mon Sep 17 00:00:00 2001
From: sonnhfit
Date: Tue, 16 Aug 2022 09:30:35 +0700
Subject: [PATCH] remove unused code and fix coding conventions

---
 freqtrade/freqai/RL/Base3ActionRLEnv.py       |  7 ------
 freqtrade/freqai/RL/Base5ActionRLEnv.py       | 14 -----------
 .../ReinforcementLearningExample3ac.py        |  1 -
 .../ReinforcementLearningExample5ac.py        |  1 -
 .../ReinforcementLearningPPO_multiproc.py     |  5 +++-
 .../ReinforcementLearningTDQN_multiproc.py    | 24 +++++++++++++++----
 6 files changed, 23 insertions(+), 29 deletions(-)

diff --git a/freqtrade/freqai/RL/Base3ActionRLEnv.py b/freqtrade/freqai/RL/Base3ActionRLEnv.py
index 5e8bff024..bf7b2fc7b 100644
--- a/freqtrade/freqai/RL/Base3ActionRLEnv.py
+++ b/freqtrade/freqai/RL/Base3ActionRLEnv.py
@@ -71,10 +71,6 @@ class Base3ActionRLEnv(gym.Env):
         self.history = None
         self.trade_history = []
 
-        self.r_t_change = 0.
-
-        self.returns_report = []
-
     def seed(self, seed: int = 1):
         self.np_random, seed = seeding.np_random(seed)
         return [seed]
@@ -101,9 +97,6 @@ class Base3ActionRLEnv(gym.Env):
         self._profits = [(self._start_tick, 1)]
         self.close_trade_profit = []
 
-        self.r_t_change = 0.
-
-        self.returns_report = []
 
         return self._get_observation()
 
diff --git a/freqtrade/freqai/RL/Base5ActionRLEnv.py b/freqtrade/freqai/RL/Base5ActionRLEnv.py
index 01fb77481..00b031e54 100644
--- a/freqtrade/freqai/RL/Base5ActionRLEnv.py
+++ b/freqtrade/freqai/RL/Base5ActionRLEnv.py
@@ -73,11 +73,6 @@ class Base5ActionRLEnv(gym.Env):
         self.history = None
         self.trade_history = []
 
-        # self.A_t, self.B_t = 0.000639, 0.00001954
-        self.r_t_change = 0.
-
-        self.returns_report = []
-
     def seed(self, seed=None):
         self.np_random, seed = seeding.np_random(seed)
         return [seed]
@@ -104,9 +99,6 @@ class Base5ActionRLEnv(gym.Env):
         self._profits = [(self._start_tick, 1)]
         self.close_trade_profit = []
 
-        self.r_t_change = 0.
-
-        self.returns_report = []
 
         return self._get_observation()
 
@@ -178,12 +170,6 @@ class Base5ActionRLEnv(gym.Env):
 
         return observation, step_reward, self._done, info
 
-    # def processState(self, state):
-    #     return state.to_numpy()
-
-    # def convert_mlp_Policy(self, obs_):
-    #     pass
-
     def _get_observation(self):
         return self.signal_features[(self._current_tick - self.window_size):self._current_tick]
 
diff --git a/freqtrade/freqai/example_strats/ReinforcementLearningExample3ac.py b/freqtrade/freqai/example_strats/ReinforcementLearningExample3ac.py
index 1976620fb..be7a8973b 100644
--- a/freqtrade/freqai/example_strats/ReinforcementLearningExample3ac.py
+++ b/freqtrade/freqai/example_strats/ReinforcementLearningExample3ac.py
@@ -62,7 +62,6 @@ class ReinforcementLearningExample3ac(IStrategy):
 
         coin = pair.split('/')[0]
 
-
         if informative is None:
             informative = self.dp.get_pair_dataframe(pair, tf)
 
diff --git a/freqtrade/freqai/example_strats/ReinforcementLearningExample5ac.py b/freqtrade/freqai/example_strats/ReinforcementLearningExample5ac.py
index 8c19cc0fa..0ecea92a9 100644
--- a/freqtrade/freqai/example_strats/ReinforcementLearningExample5ac.py
+++ b/freqtrade/freqai/example_strats/ReinforcementLearningExample5ac.py
@@ -62,7 +62,6 @@ class ReinforcementLearningExample5ac(IStrategy):
 
         coin = pair.split('/')[0]
 
-
         if informative is None:
             informative = self.dp.get_pair_dataframe(pair, tf)
 
diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearningPPO_multiproc.py b/freqtrade/freqai/prediction_models/ReinforcementLearningPPO_multiproc.py
index e8f67cbb8..26099a9e3 100644
--- a/freqtrade/freqai/prediction_models/ReinforcementLearningPPO_multiproc.py
+++ b/freqtrade/freqai/prediction_models/ReinforcementLearningPPO_multiproc.py
@@ -81,7 +81,10 @@ class ReinforcementLearningPPO_multiproc(BaseReinforcementLearningModel):
                              net_arch=[512, 512, 512])
 
         model = PPO('MlpPolicy', train_env, policy_kwargs=policy_kwargs,
-                    tensorboard_log=f"{path}/ppo/tensorboard/", learning_rate=learning_rate, gamma=0.9, verbose=1
+                    tensorboard_log=f"{path}/ppo/tensorboard/",
+                    learning_rate=learning_rate,
+                    gamma=0.9,
+                    verbose=1
                     )
 
         model.learn(
diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearningTDQN_multiproc.py b/freqtrade/freqai/prediction_models/ReinforcementLearningTDQN_multiproc.py
index d05184d87..dd34c96c1 100644
--- a/freqtrade/freqai/prediction_models/ReinforcementLearningTDQN_multiproc.py
+++ b/freqtrade/freqai/prediction_models/ReinforcementLearningTDQN_multiproc.py
@@ -4,7 +4,8 @@ import torch as th
 import numpy as np
 import gym
 from typing import Callable
-from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnNoModelImprovement, StopTrainingOnRewardThreshold
+from stable_baselines3.common.callbacks import (
+    EvalCallback, StopTrainingOnNoModelImprovement, StopTrainingOnRewardThreshold)
 from stable_baselines3.common.monitor import Monitor
 from stable_baselines3.common.vec_env import SubprocVecEnv
 from stable_baselines3.common.utils import set_random_seed
@@ -18,6 +19,7 @@ from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
 
 logger = logging.getLogger(__name__)
 
+
 def make_env(env_id: str, rank: int, seed: int, train_df, price,
              reward_params, window_size, monitor=False) -> Callable:
     """
@@ -39,6 +41,7 @@ def make_env(env_id: str, rank: int, seed: int, train_df, price,
     set_random_seed(seed)
     return _init
 
+
 class ReinforcementLearningTDQN_multiproc(BaseReinforcementLearningModel):
     """
     User created Reinforcement Learning Model prediction model.
@@ -69,11 +72,22 @@ class ReinforcementLearningTDQN_multiproc(BaseReinforcementLearningModel):
                                            self.CONV_WIDTH, monitor=True) for i in range(num_cpu)])
         path = dk.data_path
-        stop_train_callback = StopTrainingOnNoModelImprovement(max_no_improvement_evals=5, min_evals=10, verbose=2)
+        stop_train_callback = StopTrainingOnNoModelImprovement(
+            max_no_improvement_evals=5,
+            min_evals=10,
+            verbose=2
+        )
         callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-200, verbose=2)
-        eval_callback = EvalCallback(eval_env, best_model_save_path=f"{path}/",
-                                     log_path=f"{path}/tdqn/logs/", eval_freq=int(eval_freq),
-                                     deterministic=True, render=True, callback_after_eval=stop_train_callback, callback_on_new_best=callback_on_best, verbose=2)
+        eval_callback = EvalCallback(
+            eval_env, best_model_save_path=f"{path}/",
+            log_path=f"{path}/tdqn/logs/",
+            eval_freq=int(eval_freq),
+            deterministic=True,
+            render=True,
+            callback_after_eval=stop_train_callback,
+            callback_on_new_best=callback_on_best,
+            verbose=2
+        )
 
         # model arch
         policy_kwargs = dict(activation_fn=th.nn.ReLU,
                              net_arch=[512, 512, 512])
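
Note (illustrative, not part of the patch): the sketch below shows how the
pieces reformatted above fit together -- make_env() factories feeding a
SubprocVecEnv, and an EvalCallback chaining StopTrainingOnNoModelImprovement
with StopTrainingOnRewardThreshold. It substitutes a stock Gym environment
("CartPole-v1") for the freqtrade trading env, and the seed handling,
eval_freq, reward_threshold and total_timesteps values are placeholders, not
values taken from this change.

    # Illustrative sketch only; assumes stable-baselines3 >= 1.5 and the classic gym API.
    from typing import Callable

    import gym
    from stable_baselines3 import PPO
    from stable_baselines3.common.callbacks import (
        EvalCallback, StopTrainingOnNoModelImprovement, StopTrainingOnRewardThreshold)
    from stable_baselines3.common.monitor import Monitor
    from stable_baselines3.common.utils import set_random_seed
    from stable_baselines3.common.vec_env import SubprocVecEnv


    def make_env(env_id: str, rank: int, seed: int = 0) -> Callable:
        # Factory returning a closure, as in the patch: one env instance per subprocess.
        def _init() -> gym.Env:
            env = Monitor(gym.make(env_id))
            # The patch passes seed + rank into the trading env constructor;
            # for a stock gym env we just seed the action space instead.
            env.action_space.seed(seed + rank)
            return env
        set_random_seed(seed)
        return _init


    if __name__ == "__main__":
        num_cpu = 4
        train_env = SubprocVecEnv([make_env("CartPole-v1", i, 1) for i in range(num_cpu)])
        eval_env = SubprocVecEnv([make_env("CartPole-v1", i, 1) for i in range(num_cpu)])

        # Stop training when evaluations stop improving, or once a new best
        # model clears the reward threshold.
        stop_train_callback = StopTrainingOnNoModelImprovement(
            max_no_improvement_evals=5,
            min_evals=10,
            verbose=2
        )
        callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=195, verbose=2)
        eval_callback = EvalCallback(
            eval_env,
            eval_freq=1000,
            deterministic=True,
            callback_after_eval=stop_train_callback,
            callback_on_new_best=callback_on_best,
            verbose=2
        )

        model = PPO('MlpPolicy', train_env, gamma=0.9, verbose=1)
        model.learn(total_timesteps=20_000, callback=eval_callback)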