add increment param for tensorboard_log

This commit is contained in:
initrv 2022-12-12 14:14:23 +03:00
parent 0f6b98b69a
commit f9b7d35900

View File

@ -139,7 +139,7 @@ class BaseEnvironment(gym.Env):
self.np_random, seed = seeding.np_random(seed) self.np_random, seed = seeding.np_random(seed)
return [seed] return [seed]
def tensorboard_log(self, metric: str, inc: Union[int, float] = 1): def tensorboard_log(self, metric: str, value: Union[int, float] = 1, inc: bool = True):
""" """
Function builds the tensorboard_metrics dictionary Function builds the tensorboard_metrics dictionary
to be parsed by the TensorboardCallback. This to be parsed by the TensorboardCallback. This
@ -155,12 +155,13 @@ class BaseEnvironment(gym.Env):
return -2 return -2
:param metric: metric to be tracked and incremented :param metric: metric to be tracked and incremented
:param inc: value to increment `metric` by :param value: value to increment `metric` by
:param inc: sets whether the `value` is incremented or not
""" """
if metric not in self.tensorboard_metrics: if not inc or metric not in self.tensorboard_metrics:
self.tensorboard_metrics[metric] = inc self.tensorboard_metrics[metric] = value
else: else:
self.tensorboard_metrics[metric] += inc self.tensorboard_metrics[metric] += value
def reset_tensorboard_log(self): def reset_tensorboard_log(self):
self.tensorboard_metrics = {} self.tensorboard_metrics = {}