58 lines
2.0 KiB
Python
58 lines
2.0 KiB
Python
from enum import Enum
|
|
from typing import Any, Dict, Type, Union
|
|
|
|
from stable_baselines3.common.callbacks import BaseCallback
|
|
from stable_baselines3.common.logger import HParam
|
|
|
|
from freqtrade.freqai.RL.BaseEnvironment import BaseActions, BaseEnvironment
|
|
|
|
|
|
class TensorboardCallback(BaseCallback):
    """
    Custom callback for plotting additional values in tensorboard and
    episodic summary reports.

    On training start it records a small hyperparameter summary (algorithm
    name and learning rate). On every step it records the first entry of
    ``self.locals["infos"]`` under ``info/<key>`` and the environment's
    ``tensorboard_metrics`` under ``<category>/<metric>``.
    """

    # Keys of the step ``info`` dict that are not forwarded to the logger.
    _IGNORED_INFO_KEYS = frozenset({"episode", "terminal_observation"})

    def __init__(self, verbose=1, actions: Type[Enum] = BaseActions):
        """
        :param verbose: verbosity level, forwarded to ``BaseCallback``.
        :param actions: Enum class of actions (defaults to ``BaseActions``);
            stored as ``self.actions``, not used by this class itself.
        """
        super().__init__(verbose)
        self.model: Any = None
        self.logger = None  # type: Any
        self.training_env: BaseEnvironment = None  # type: ignore
        self.actions: Type[Enum] = actions

    def _on_training_start(self) -> None:
        """Record the algorithm name and learning rate as tensorboard hparams."""
        learning_rate = self.model.learning_rate
        # SB3 accepts either a constant or a schedule (callable) here, but the
        # tensorboard hparams writer only accepts bool/str/int/float values,
        # so a schedule is recorded by its string representation instead.
        if callable(learning_rate):
            learning_rate = str(learning_rate)

        hparam_dict = {
            "algorithm": self.model.__class__.__name__,
            "learning_rate": learning_rate,
            # NOTE(review): gamma / gae_lambda / batch_size / n_steps were left
            # out here, presumably because not every SB3 algorithm has them.
        }
        # All values are 0: these tags presumably act as placeholders pointing
        # at scalars logged elsewhere during training — confirm against the
        # SB3 HParam documentation.
        metric_dict: Dict[str, Union[float, int]] = {
            "eval/mean_reward": 0,
            "rollout/ep_rew_mean": 0,
            "rollout/ep_len_mean": 0,
            "train/value_loss": 0,
            "train/explained_variance": 0,
        }
        # Excluding the text-style formats means the hparams only reach the
        # remaining output formats (e.g. tensorboard).
        self.logger.record(
            "hparams",
            HParam(hparam_dict, metric_dict),
            exclude=("stdout", "log", "json", "csv"),
        )

    def _on_step(self) -> bool:
        """
        Forward per-step info values and environment metrics to the logger.

        :return: always ``True`` (SB3 treats ``False`` as a request to stop
            training, which this callback never makes).
        """
        # ``[0]``: only the first entry of the step infos and of the
        # ``get_attr`` result is logged (presumably the first env of the
        # vectorized environment).
        local_info = self.locals["infos"][0]
        tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0]

        for key, value in local_info.items():
            if key not in self._IGNORED_INFO_KEYS:
                self.logger.record(f"info/{key}", value)

        for category, metrics in tensorboard_metrics.items():
            for name, value in metrics.items():
                self.logger.record(f"{category}/{name}", value)

        return True
|