Rename custom_info to tensorboard_metrics
This commit is contained in:
@@ -46,9 +46,9 @@ class Base4ActionRLEnv(BaseEnvironment):
|
||||
self._done = True
|
||||
|
||||
self._update_unrealized_total_profit()
|
||||
|
||||
step_reward = self.calculate_reward(action)
|
||||
self.total_reward += step_reward
|
||||
self.tensorboard_metrics[self.actions._member_names_[action]] += 1
|
||||
|
||||
trade_type = None
|
||||
if self.is_tradesignal(action):
|
||||
|
@@ -49,6 +49,7 @@ class Base5ActionRLEnv(BaseEnvironment):
|
||||
self._update_unrealized_total_profit()
|
||||
step_reward = self.calculate_reward(action)
|
||||
self.total_reward += step_reward
|
||||
self.tensorboard_metrics[self.actions._member_names_[action]] += 1
|
||||
|
||||
trade_type = None
|
||||
if self.is_tradesignal(action):
|
||||
|
@@ -77,7 +77,7 @@ class BaseEnvironment(gym.Env):
|
||||
|
||||
# set here to default 5Ac, but all children envs can override this
|
||||
self.actions: Type[Enum] = BaseActions
|
||||
self.custom_info: dict = {}
|
||||
self.tensorboard_metrics: dict = {}
|
||||
|
||||
def reset_env(self, df: DataFrame, prices: DataFrame, window_size: int,
|
||||
reward_kwargs: dict, starting_point=True):
|
||||
@@ -136,10 +136,10 @@ class BaseEnvironment(gym.Env):
|
||||
"""
|
||||
Reset is called at the beginning of every episode
|
||||
"""
|
||||
# custom_info is used for episodic reports and tensorboard logging
|
||||
self.custom_info: dict = {}
|
||||
# tensorboard_metrics is used for episodic reports and tensorboard logging
|
||||
self.tensorboard_metrics: dict = {}
|
||||
for action in self.actions:
|
||||
self.custom_info[action.name] = 0
|
||||
self.tensorboard_metrics[action.name] = 0
|
||||
|
||||
self._done = False
|
||||
|
||||
|
@@ -44,16 +44,16 @@ class TensorboardCallback(BaseCallback):
|
||||
def _on_step(self) -> bool:
|
||||
|
||||
local_info = self.locals["infos"][0]
|
||||
custom_info = self.training_env.get_attr("custom_info")[0]
|
||||
tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0]
|
||||
|
||||
for info in local_info:
|
||||
if info not in ["episode", "terminal_observation"]:
|
||||
self.logger.record(f"_info/{info}", local_info[info])
|
||||
|
||||
for info in custom_info:
|
||||
for info in tensorboard_metrics:
|
||||
if info in [action.name for action in self.actions]:
|
||||
self.logger.record(f"_actions/{info}", custom_info[info])
|
||||
self.logger.record(f"_actions/{info}", tensorboard_metrics[info])
|
||||
else:
|
||||
self.logger.record(f"_custom/{info}", custom_info[info])
|
||||
self.logger.record(f"_custom/{info}", tensorboard_metrics[info])
|
||||
|
||||
return True
|
||||
|
Reference in New Issue
Block a user