custom info to tensorboard_metrics

This commit is contained in:
initrv
2022-12-11 15:37:45 +03:00
parent 58604c747e
commit cb8fc3c8c7
5 changed files with 10 additions and 15 deletions

View File

@@ -44,16 +44,16 @@ class TensorboardCallback(BaseCallback):
def _on_step(self) -> bool:
local_info = self.locals["infos"][0]
custom_info = self.training_env.get_attr("custom_info")[0]
tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0]
for info in local_info:
if info not in ["episode", "terminal_observation"]:
self.logger.record(f"_info/{info}", local_info[info])
for info in custom_info:
for info in tensorboard_metrics:
if info in [action.name for action in self.actions]:
self.logger.record(f"_actions/{info}", custom_info[info])
self.logger.record(f"_actions/{info}", tensorboard_metrics[info])
else:
self.logger.record(f"_custom/{info}", custom_info[info])
self.logger.record(f"_custom/{info}", tensorboard_metrics[info])
return True