stable/freqtrade/freqai/RL/TensorboardCallback.py

58 lines
2.0 KiB
Python
Raw Permalink Normal View History

from enum import Enum
from typing import Any, Dict, Type, Union
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import HParam
2022-12-04 13:10:33 +00:00
from freqtrade.freqai.RL.BaseEnvironment import BaseActions, BaseEnvironment
class TensorboardCallback(BaseCallback):
    """
    Custom callback for plotting additional values in tensorboard and
    episodic summary reports.

    On training start it records the algorithm's hyperparameters; on every
    step it forwards the first environment's ``info`` entries and the
    environment's ``tensorboard_metrics`` to the logger.
    """

    # ``info`` keys that are never forwarded to tensorboard.
    _IGNORED_INFO_KEYS = frozenset({"episode", "terminal_observation"})

    def __init__(self, verbose: int = 1, actions: Type[Enum] = BaseActions):
        """
        :param verbose: verbosity level, passed straight to ``BaseCallback``.
        :param actions: the action Enum of the environment; stored on
            ``self.actions``.
        """
        super().__init__(verbose)
        self.model: Any = None
        # NOTE(review): newer stable-baselines3 releases expose ``logger`` and
        # ``training_env`` as read-only properties on BaseCallback, where these
        # assignments would raise; confirm against the pinned SB3 version.
        self.logger = None  # type: Any
        self.training_env: BaseEnvironment = None  # type: ignore
        self.actions: Type[Enum] = actions

    def _on_training_start(self) -> None:
        """
        Record the hyperparameters once, together with the metrics to pair
        them with, so they appear in tensorboard's HPARAMS view. The record
        is excluded from the stdout, log, json and csv outputs.
        """
        learning_rate = self.model.learning_rate
        if callable(learning_rate):
            # SB3 accepts either a constant or a schedule
            # ``f(progress_remaining) -> lr``. Tensorboard hparams only accept
            # int/float/str/bool values, so log the initial rate
            # (progress_remaining == 1.0) instead of the function object.
            learning_rate = float(learning_rate(1.0))

        hparam_dict: Dict[str, Any] = {
            "algorithm": self.model.__class__.__name__,
            "learning_rate": learning_rate,
            # Disabled algorithm-specific hyperparameters; presumably left out
            # because not every SB3 algorithm defines these attributes.
            # "gamma": self.model.gamma,
            # "gae_lambda": self.model.gae_lambda,
            # "batch_size": self.model.batch_size,
            # "n_steps": self.model.n_steps,
        }
        # Metrics shown next to the hparams. The 0 values are placeholders:
        # tensorboard pairs them with the scalar tags of the same name.
        metric_dict: Dict[str, Union[float, int]] = {
            "eval/mean_reward": 0,
            "rollout/ep_rew_mean": 0,
            "rollout/ep_len_mean": 0,
            "train/value_loss": 0,
            "train/explained_variance": 0,
        }
        self.logger.record(
            "hparams",
            HParam(hparam_dict, metric_dict),
            exclude=("stdout", "log", "json", "csv"),
        )

    def _on_step(self) -> bool:
        """
        Forward per-step information to the logger.

        Every entry of the first environment's ``info`` dict (except the keys
        in ``_IGNORED_INFO_KEYS``) is recorded under ``info/<key>``. Every
        ``tensorboard_metrics[category][metric]`` value is recorded under
        ``<category>/<metric>``.

        :return: always True, so this callback never stops training.
        """
        # NOTE(review): with several (vectorised) environments, only index 0's
        # info dict and tensorboard metrics are logged.
        local_info = self.locals["infos"][0]
        tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0]

        for key, value in local_info.items():
            if key not in self._IGNORED_INFO_KEYS:
                self.logger.record(f"info/{key}", value)

        for category, metrics in tensorboard_metrics.items():
            for metric, value in metrics.items():
                self.logger.record(f"{category}/{metric}", value)

        return True