cleanup tensorboard callback

This commit is contained in:
initrv
2022-12-07 14:37:55 +03:00
parent b9f6911a6a
commit 58604c747e
3 changed files with 21 additions and 30 deletions

View File

@@ -137,15 +137,9 @@ class BaseEnvironment(gym.Env):
Reset is called at the beginning of every episode
"""
# custom_info is used for episodic reports and tensorboard logging
self.custom_info["Invalid"] = 0
self.custom_info["Hold"] = 0
self.custom_info["Unknown"] = 0
self.custom_info["pnl_factor"] = 0
self.custom_info["duration_factor"] = 0
self.custom_info["reward_exit"] = 0
self.custom_info["reward_hold"] = 0
self.custom_info: dict = {}
for action in self.actions:
self.custom_info[f"{action.name}"] = 0
self.custom_info[action.name] = 0
self._done = False

View File

@@ -42,19 +42,18 @@ class TensorboardCallback(BaseCallback):
)
def _on_step(self) -> bool:
local_info = self.locals["infos"][0]
custom_info = self.training_env.get_attr("custom_info")[0]
self.logger.record("_state/position", self.locals["infos"][0]["position"])
self.logger.record("_state/trade_duration", self.locals["infos"][0]["trade_duration"])
self.logger.record("_state/current_profit_pct", self.locals["infos"]
[0]["current_profit_pct"])
self.logger.record("_reward/total_profit", self.locals["infos"][0]["total_profit"])
self.logger.record("_reward/total_reward", self.locals["infos"][0]["total_reward"])
self.logger.record_mean("_reward/mean_trade_duration", self.locals["infos"]
[0]["trade_duration"])
self.logger.record("_actions/action", self.locals["infos"][0]["action"])
self.logger.record("_actions/_Invalid", custom_info["Invalid"])
self.logger.record("_actions/_Unknown", custom_info["Unknown"])
self.logger.record("_actions/Hold", custom_info["Hold"])
for action in self.actions:
self.logger.record(f"_actions/{action.name}", custom_info[action.name])
for info in local_info:
if info not in ["episode", "terminal_observation"]:
self.logger.record(f"_info/{info}", local_info[info])
for info in custom_info:
if info in [action.name for action in self.actions]:
self.logger.record(f"_actions/{info}", custom_info[info])
else:
self.logger.record(f"_custom/{info}", custom_info[info])
return True