fix flake8

This commit is contained in:
smarmau 2022-12-03 22:31:02 +11:00 committed by GitHub
parent d6f45a12ae
commit b2edc58089
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -12,12 +12,11 @@ import pandas as pd
import torch as th
import torch.multiprocessing
from pandas import DataFrame
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
from stable_baselines3.common.logger import HParam
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.logger import HParam
from freqtrade.exceptions import OperationalException
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
@ -403,6 +402,7 @@ def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
set_random_seed(seed)
return _init
class TensorboardCallback(BaseCallback):
"""
Custom callback for plotting additional values in tensorboard.
@ -434,16 +434,18 @@ class TensorboardCallback(BaseCallback):
def _on_step(self) -> bool:
custom_info = self.training_env.get_attr("custom_info")[0]
self.logger.record(f"_state/position", self.locals["infos"][0]["position"])
self.logger.record(f"_state/trade_duration", self.locals["infos"][0]["trade_duration"])
self.logger.record(f"_state/current_profit_pct", self.locals["infos"][0]["current_profit_pct"])
self.logger.record(f"_reward/total_profit", self.locals["infos"][0]["total_profit"])
self.logger.record(f"_reward/total_reward", self.locals["infos"][0]["total_reward"])
self.logger.record_mean(f"_reward/mean_trade_duration", self.locals["infos"][0]["trade_duration"])
self.logger.record(f"_actions/action", self.locals["infos"][0]["action"])
self.logger.record(f"_actions/_Invalid", custom_info["Invalid"])
self.logger.record(f"_actions/_Unknown", custom_info["Unknown"])
self.logger.record(f"_actions/Hold", custom_info["Hold"])
self.logger.record("_state/position", self.locals["infos"][0]["position"])
self.logger.record("_state/trade_duration", self.locals["infos"][0]["trade_duration"])
self.logger.record("_state/current_profit_pct", self.locals["infos"]
[0]["current_profit_pct"])
self.logger.record("_reward/total_profit", self.locals["infos"][0]["total_profit"])
self.logger.record("_reward/total_reward", self.locals["infos"][0]["total_reward"])
self.logger.record_mean("_reward/mean_trade_duration", self.locals["infos"]
[0]["trade_duration"])
self.logger.record("_actions/action", self.locals["infos"][0]["action"])
self.logger.record("_actions/_Invalid", custom_info["Invalid"])
self.logger.record("_actions/_Unknown", custom_info["Unknown"])
self.logger.record("_actions/Hold", custom_info["Hold"])
for action in Actions:
self.logger.record(f"_actions/{action.name}", custom_info[action.name])
return True