reorganize/generalize tensorboard callback

robcaulk
2022-12-04 13:54:30 +01:00
parent b2edc58089
commit 24766928ba
7 changed files with 125 additions and 88 deletions

freqtrade/freqai/RL/BaseReinforcementLearningModel.py

@@ -12,8 +12,7 @@ import pandas as pd
 import torch as th
 import torch.multiprocessing
 from pandas import DataFrame
-from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
-from stable_baselines3.common.logger import HParam
+from stable_baselines3.common.callbacks import EvalCallback
 from stable_baselines3.common.monitor import Monitor
 from stable_baselines3.common.utils import set_random_seed
 from stable_baselines3.common.vec_env import SubprocVecEnv
@@ -22,7 +21,8 @@ from freqtrade.exceptions import OperationalException
 from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
 from freqtrade.freqai.freqai_interface import IFreqaiModel
 from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
-from freqtrade.freqai.RL.BaseEnvironment import Positions
+from freqtrade.freqai.RL.BaseEnvironment import BaseActions, Positions
+from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback
 from freqtrade.persistence import Trade
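The new BaseActions import suggests the commit hoists a default action enum into BaseEnvironment, so the callback can be constructed before any concrete environment exists. A minimal sketch of what such an enum plausibly looks like, assuming it mirrors the five members of Base5ActionRLEnv's Actions enum:

    from enum import Enum


    class BaseActions(Enum):
        # Hypothetical default action space in freqtrade/freqai/RL/BaseEnvironment.py;
        # member names and values assumed to match Base5ActionRLEnv's Actions enum.
        Neutral = 0
        Long_enter = 1
        Long_exit = 2
        Short_enter = 3
        Short_exit = 4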
@@ -45,8 +45,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
             'cpu_count', 1), max(int(self.max_system_threads / 2), 1))
         th.set_num_threads(self.max_threads)
         self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
-        self.train_env: Union[SubprocVecEnv, gym.Env] = None
-        self.eval_env: Union[SubprocVecEnv, gym.Env] = None
+        self.train_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env()
+        self.eval_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env()
         self.eval_callback: Optional[EvalCallback] = None
         self.model_type = self.freqai_info['rl_config']['model_type']
         self.rl_config = self.freqai_info['rl_config']
@@ -66,6 +66,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         self.unset_outlier_removal()
         self.net_arch = self.rl_config.get('net_arch', [128, 128])
         self.dd.model_type = import_str
+        self.tensorboard_callback: TensorboardCallback = \
+            TensorboardCallback(verbose=1, actions=BaseActions)

     def unset_outlier_removal(self):
         """
@@ -157,7 +159,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
                                           render=False, eval_freq=len(train_df),
                                           best_model_save_path=str(dk.data_path))
-        self.tensorboard_callback = TensorboardCallback()
+        actions = self.train_env.get_actions()
+        self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions)

     @abstractmethod
     def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
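get_actions() does not appear elsewhere in this file, so it is presumably one of the additions to BaseEnvironment among the other changed files. A one-line accessor would suffice, assuming each environment stores its action enum on an actions attribute:

    from enum import Enum
    from typing import Type


    class BaseEnvironment:
        # Hypothetical accessor; assumes concrete environments set self.actions
        # (e.g. Base5ActionRLEnv would set it to its own Actions enum).
        actions: Type[Enum]

        def get_actions(self) -> Type[Enum]:
            return self.actions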
@@ -401,51 +404,3 @@ def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
         return env
     set_random_seed(seed)
     return _init
-
-
-class TensorboardCallback(BaseCallback):
-    """
-    Custom callback for plotting additional values in tensorboard.
-    """
-    def __init__(self, verbose=1):
-        super(TensorboardCallback, self).__init__(verbose)
-
-    def _on_training_start(self) -> None:
-        hparam_dict = {
-            "algorithm": self.model.__class__.__name__,
-            "learning_rate": self.model.learning_rate,
-            "gamma": self.model.gamma,
-            "gae_lambda": self.model.gae_lambda,
-            "batch_size": self.model.batch_size,
-            "n_steps": self.model.n_steps,
-        }
-        metric_dict = {
-            "eval/mean_reward": 0,
-            "rollout/ep_rew_mean": 0,
-            "rollout/ep_len_mean": 0,
-            "train/value_loss": 0,
-            "train/explained_variance": 0,
-        }
-        self.logger.record(
-            "hparams",
-            HParam(hparam_dict, metric_dict),
-            exclude=("stdout", "log", "json", "csv"),
-        )
-
-    def _on_step(self) -> bool:
-        custom_info = self.training_env.get_attr("custom_info")[0]
-        self.logger.record("_state/position", self.locals["infos"][0]["position"])
-        self.logger.record("_state/trade_duration", self.locals["infos"][0]["trade_duration"])
-        self.logger.record("_state/current_profit_pct", self.locals["infos"]
-                           [0]["current_profit_pct"])
-        self.logger.record("_reward/total_profit", self.locals["infos"][0]["total_profit"])
-        self.logger.record("_reward/total_reward", self.locals["infos"][0]["total_reward"])
-        self.logger.record_mean("_reward/mean_trade_duration", self.locals["infos"]
-                                [0]["trade_duration"])
-        self.logger.record("_actions/action", self.locals["infos"][0]["action"])
-        self.logger.record("_actions/_Invalid", custom_info["Invalid"])
-        self.logger.record("_actions/_Unknown", custom_info["Unknown"])
-        self.logger.record("_actions/Hold", custom_info["Hold"])
-        for action in Actions:
-            self.logger.record(f"_actions/{action.name}", custom_info[action.name])
-        return True
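The class removed above reappears in the new freqtrade/freqai/RL/TensorboardCallback.py imported at the top of this diff. That file is not shown here; the sketch below reconstructs it from the removed code, with one assumed change matching the new call sites: the action enum is injected through the constructor (actions=BaseActions by default) and iterated as self.actions instead of the hard-coded Actions from Base5ActionRLEnv.

    # Hypothetical content of freqtrade/freqai/RL/TensorboardCallback.py,
    # reconstructed from the class removed above.
    from enum import Enum
    from typing import Dict, Type, Union

    from stable_baselines3.common.callbacks import BaseCallback
    from stable_baselines3.common.logger import HParam

    from freqtrade.freqai.RL.BaseEnvironment import BaseActions


    class TensorboardCallback(BaseCallback):
        """
        Custom callback for plotting additional values in tensorboard.
        """
        def __init__(self, verbose=1, actions: Type[Enum] = BaseActions):
            super().__init__(verbose)
            # The action enum is injected so any environment, not just the
            # 5-action one, can have its per-action counters logged.
            self.actions: Type[Enum] = actions

        def _on_training_start(self) -> None:
            # Register hyperparameters and the metrics they are compared
            # against in tensorboard's HPARAMS tab (as in the removed class).
            hparam_dict = {
                "algorithm": self.model.__class__.__name__,
                "learning_rate": self.model.learning_rate,
                "gamma": self.model.gamma,
                "gae_lambda": self.model.gae_lambda,
                "batch_size": self.model.batch_size,
                "n_steps": self.model.n_steps,
            }
            metric_dict: Dict[str, Union[float, int]] = {
                "eval/mean_reward": 0,
                "rollout/ep_rew_mean": 0,
                "rollout/ep_len_mean": 0,
                "train/value_loss": 0,
                "train/explained_variance": 0,
            }
            self.logger.record(
                "hparams",
                HParam(hparam_dict, metric_dict),
                exclude=("stdout", "log", "json", "csv"),
            )

        def _on_step(self) -> bool:
            # Per-episode counters maintained by the environment, plus the
            # step info dict emitted by the first (rank 0) env; the local
            # `info` merely shortens the removed code's repeated lookups.
            custom_info = self.training_env.get_attr("custom_info")[0]
            info = self.locals["infos"][0]
            self.logger.record("_state/position", info["position"])
            self.logger.record("_state/trade_duration", info["trade_duration"])
            self.logger.record("_state/current_profit_pct", info["current_profit_pct"])
            self.logger.record("_reward/total_profit", info["total_profit"])
            self.logger.record("_reward/total_reward", info["total_reward"])
            self.logger.record_mean("_reward/mean_trade_duration", info["trade_duration"])
            self.logger.record("_actions/action", info["action"])
            self.logger.record("_actions/_Invalid", custom_info["Invalid"])
            self.logger.record("_actions/_Unknown", custom_info["Unknown"])
            self.logger.record("_actions/Hold", custom_info["Hold"])
            for action in self.actions:
                self.logger.record(f"_actions/{action.name}", custom_info[action.name])
            return True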