fix flake8
This commit is contained in:
parent
d6f45a12ae
commit
b2edc58089
@ -12,12 +12,11 @@ import pandas as pd
|
|||||||
import torch as th
|
import torch as th
|
||||||
import torch.multiprocessing
|
import torch.multiprocessing
|
||||||
from pandas import DataFrame
|
from pandas import DataFrame
|
||||||
from stable_baselines3.common.callbacks import EvalCallback
|
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
|
||||||
from stable_baselines3.common.callbacks import BaseCallback
|
from stable_baselines3.common.logger import HParam
|
||||||
from stable_baselines3.common.monitor import Monitor
|
from stable_baselines3.common.monitor import Monitor
|
||||||
from stable_baselines3.common.utils import set_random_seed
|
from stable_baselines3.common.utils import set_random_seed
|
||||||
from stable_baselines3.common.vec_env import SubprocVecEnv
|
from stable_baselines3.common.vec_env import SubprocVecEnv
|
||||||
from stable_baselines3.common.logger import HParam
|
|
||||||
|
|
||||||
from freqtrade.exceptions import OperationalException
|
from freqtrade.exceptions import OperationalException
|
||||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||||
@ -157,7 +156,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
|
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
|
||||||
render=False, eval_freq=len(train_df),
|
render=False, eval_freq=len(train_df),
|
||||||
best_model_save_path=str(dk.data_path))
|
best_model_save_path=str(dk.data_path))
|
||||||
|
|
||||||
self.tensorboard_callback = TensorboardCallback()
|
self.tensorboard_callback = TensorboardCallback()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -403,6 +402,7 @@ def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
|
|||||||
set_random_seed(seed)
|
set_random_seed(seed)
|
||||||
return _init
|
return _init
|
||||||
|
|
||||||
|
|
||||||
class TensorboardCallback(BaseCallback):
|
class TensorboardCallback(BaseCallback):
|
||||||
"""
|
"""
|
||||||
Custom callback for plotting additional values in tensorboard.
|
Custom callback for plotting additional values in tensorboard.
|
||||||
@ -422,7 +422,7 @@ class TensorboardCallback(BaseCallback):
|
|||||||
metric_dict = {
|
metric_dict = {
|
||||||
"eval/mean_reward": 0,
|
"eval/mean_reward": 0,
|
||||||
"rollout/ep_rew_mean": 0,
|
"rollout/ep_rew_mean": 0,
|
||||||
"rollout/ep_len_mean":0 ,
|
"rollout/ep_len_mean": 0,
|
||||||
"train/value_loss": 0,
|
"train/value_loss": 0,
|
||||||
"train/explained_variance": 0,
|
"train/explained_variance": 0,
|
||||||
}
|
}
|
||||||
@ -431,19 +431,21 @@ class TensorboardCallback(BaseCallback):
|
|||||||
HParam(hparam_dict, metric_dict),
|
HParam(hparam_dict, metric_dict),
|
||||||
exclude=("stdout", "log", "json", "csv"),
|
exclude=("stdout", "log", "json", "csv"),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _on_step(self) -> bool:
|
def _on_step(self) -> bool:
|
||||||
custom_info = self.training_env.get_attr("custom_info")[0]
|
custom_info = self.training_env.get_attr("custom_info")[0]
|
||||||
self.logger.record(f"_state/position", self.locals["infos"][0]["position"])
|
self.logger.record("_state/position", self.locals["infos"][0]["position"])
|
||||||
self.logger.record(f"_state/trade_duration", self.locals["infos"][0]["trade_duration"])
|
self.logger.record("_state/trade_duration", self.locals["infos"][0]["trade_duration"])
|
||||||
self.logger.record(f"_state/current_profit_pct", self.locals["infos"][0]["current_profit_pct"])
|
self.logger.record("_state/current_profit_pct", self.locals["infos"]
|
||||||
self.logger.record(f"_reward/total_profit", self.locals["infos"][0]["total_profit"])
|
[0]["current_profit_pct"])
|
||||||
self.logger.record(f"_reward/total_reward", self.locals["infos"][0]["total_reward"])
|
self.logger.record("_reward/total_profit", self.locals["infos"][0]["total_profit"])
|
||||||
self.logger.record_mean(f"_reward/mean_trade_duration", self.locals["infos"][0]["trade_duration"])
|
self.logger.record("_reward/total_reward", self.locals["infos"][0]["total_reward"])
|
||||||
self.logger.record(f"_actions/action", self.locals["infos"][0]["action"])
|
self.logger.record_mean("_reward/mean_trade_duration", self.locals["infos"]
|
||||||
self.logger.record(f"_actions/_Invalid", custom_info["Invalid"])
|
[0]["trade_duration"])
|
||||||
self.logger.record(f"_actions/_Unknown", custom_info["Unknown"])
|
self.logger.record("_actions/action", self.locals["infos"][0]["action"])
|
||||||
self.logger.record(f"_actions/Hold", custom_info["Hold"])
|
self.logger.record("_actions/_Invalid", custom_info["Invalid"])
|
||||||
|
self.logger.record("_actions/_Unknown", custom_info["Unknown"])
|
||||||
|
self.logger.record("_actions/Hold", custom_info["Hold"])
|
||||||
for action in Actions:
|
for action in Actions:
|
||||||
self.logger.record(f"_actions/{action.name}", custom_info[action.name])
|
self.logger.record(f"_actions/{action.name}", custom_info[action.name])
|
||||||
return True
|
return True
|
||||||
|
Loading…
Reference in New Issue
Block a user