diff --git a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
index 81f8edfc4..15acde6fb 100644
--- a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
+++ b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
@@ -13,9 +13,11 @@ import torch as th
 import torch.multiprocessing
 from pandas import DataFrame
 from stable_baselines3.common.callbacks import EvalCallback
+from stable_baselines3.common.callbacks import BaseCallback
 from stable_baselines3.common.monitor import Monitor
 from stable_baselines3.common.utils import set_random_seed
 from stable_baselines3.common.vec_env import SubprocVecEnv
+from stable_baselines3.common.logger import HParam
 
 from freqtrade.exceptions import OperationalException
 from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
@@ -155,6 +157,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
                                           render=False, eval_freq=len(train_df),
                                           best_model_save_path=str(dk.data_path))
+
+        self.tensorboard_callback = TensorboardCallback()
 
     @abstractmethod
     def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
@@ -398,3 +402,48 @@ def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
         return env
     set_random_seed(seed)
     return _init
+
+class TensorboardCallback(BaseCallback):
+    """
+    Custom callback for plotting additional values in tensorboard.
+    """
+    def __init__(self, verbose=1):
+        super(TensorboardCallback, self).__init__(verbose)
+
+    def _on_training_start(self) -> None:
+        hparam_dict = {
+            "algorithm": self.model.__class__.__name__,
+            "learning_rate": self.model.learning_rate,
+            "gamma": self.model.gamma,
+            "gae_lambda": self.model.gae_lambda,
+            "batch_size": self.model.batch_size,
+            "n_steps": self.model.n_steps,
+        }
+        metric_dict = {
+            "eval/mean_reward": 0,
+            "rollout/ep_rew_mean": 0,
+            "rollout/ep_len_mean": 0,
+            "train/value_loss": 0,
+            "train/explained_variance": 0,
+        }
+        self.logger.record(
+            "hparams",
+            HParam(hparam_dict, metric_dict),
+            exclude=("stdout", "log", "json", "csv"),
+        )
+
+    def _on_step(self) -> bool:
+        custom_info = self.training_env.get_attr("custom_info")[0]
+        self.logger.record("_state/position", self.locals["infos"][0]["position"])
+        self.logger.record("_state/trade_duration", self.locals["infos"][0]["trade_duration"])
+        self.logger.record("_state/current_profit_pct", self.locals["infos"][0]["current_profit_pct"])
+        self.logger.record("_reward/total_profit", self.locals["infos"][0]["total_profit"])
+        self.logger.record("_reward/total_reward", self.locals["infos"][0]["total_reward"])
+        self.logger.record_mean("_reward/mean_trade_duration", self.locals["infos"][0]["trade_duration"])
+        self.logger.record("_actions/action", self.locals["infos"][0]["action"])
+        self.logger.record("_actions/_Invalid", custom_info["Invalid"])
+        self.logger.record("_actions/_Unknown", custom_info["Unknown"])
+        self.logger.record("_actions/Hold", custom_info["Hold"])
+        for action in Actions:
+            self.logger.record(f"_actions/{action.name}", custom_info[action.name])
+        return True
\ No newline at end of file
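
Note on usage: the diff only instantiates `self.tensorboard_callback`; a concrete learner still has to hand it to Stable-Baselines3 during training, since `learn()` accepts a list of callbacks. The sketch below is an illustrative assumption of that wiring, not part of this diff: the `MyReinforcementLearner` class name, the `PPO` algorithm choice, the tensorboard log path, and the `rl_config`/`model_training_parameters` keys are all hypothetical.

```python
# Minimal sketch (assumed, not from this diff): a learner subclass passing both
# the existing EvalCallback and the new TensorboardCallback to model.learn().
from pathlib import Path
from typing import Any, Dict

from stable_baselines3 import PPO

from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.RL.BaseReinforcementLearningModel import BaseReinforcementLearningModel


class MyReinforcementLearner(BaseReinforcementLearningModel):  # hypothetical subclass

    def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
        train_df = data_dictionary["train_features"]
        # Assumed config key: number of passes over the training window.
        total_timesteps = self.freqai_info["rl_config"]["train_cycles"] * len(train_df)

        model = PPO(
            "MlpPolicy",
            self.train_env,
            tensorboard_log=str(Path(dk.full_path) / "tensorboard"),  # assumed log location
            **self.freqai_info.get("model_training_parameters", {}),
        )

        # Stable-Baselines3 runs every callback in the list on each step, so the
        # EvalCallback and the new TensorboardCallback log side by side.
        model.learn(
            total_timesteps=int(total_timesteps),
            callback=[self.eval_callback, self.tensorboard_callback],
        )
        return model
```

Keeping the callback on the base class means any learner subclass can opt into the extra TensorBoard metrics simply by appending `self.tensorboard_callback` to the callback list it already passes to `learn()`.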