add increment param for tensorboard_log
This commit is contained in:
parent
0f6b98b69a
commit
f9b7d35900
@ -139,7 +139,7 @@ class BaseEnvironment(gym.Env):
|
|||||||
self.np_random, seed = seeding.np_random(seed)
|
self.np_random, seed = seeding.np_random(seed)
|
||||||
return [seed]
|
return [seed]
|
||||||
|
|
||||||
def tensorboard_log(self, metric: str, inc: Union[int, float] = 1):
|
def tensorboard_log(self, metric: str, value: Union[int, float] = 1, inc: bool = True):
|
||||||
"""
|
"""
|
||||||
Function builds the tensorboard_metrics dictionary
|
Function builds the tensorboard_metrics dictionary
|
||||||
to be parsed by the TensorboardCallback. This
|
to be parsed by the TensorboardCallback. This
|
||||||
@ -155,12 +155,13 @@ class BaseEnvironment(gym.Env):
|
|||||||
return -2
|
return -2
|
||||||
|
|
||||||
:param metric: metric to be tracked and incremented
|
:param metric: metric to be tracked and incremented
|
||||||
:param inc: value to increment `metric` by
|
:param value: value to increment `metric` by
|
||||||
|
:param inc: sets whether the `value` is incremented or not
|
||||||
"""
|
"""
|
||||||
if metric not in self.tensorboard_metrics:
|
if not inc or metric not in self.tensorboard_metrics:
|
||||||
self.tensorboard_metrics[metric] = inc
|
self.tensorboard_metrics[metric] = value
|
||||||
else:
|
else:
|
||||||
self.tensorboard_metrics[metric] += inc
|
self.tensorboard_metrics[metric] += value
|
||||||
|
|
||||||
def reset_tensorboard_log(self):
|
def reset_tensorboard_log(self):
|
||||||
self.tensorboard_metrics = {}
|
self.tensorboard_metrics = {}
|
||||||
|
Loading…
Reference in New Issue
Block a user