diff --git a/freqtrade/freqai/RL/BaseEnvironment.py b/freqtrade/freqai/RL/BaseEnvironment.py index 5a90d381e..5a5a950e7 100644 --- a/freqtrade/freqai/RL/BaseEnvironment.py +++ b/freqtrade/freqai/RL/BaseEnvironment.py @@ -139,7 +139,7 @@ class BaseEnvironment(gym.Env): self.np_random, seed = seeding.np_random(seed) return [seed] - def tensorboard_log(self, metric: str, inc: Union[int, float] = 1): + def tensorboard_log(self, metric: str, value: Union[int, float] = 1, inc: bool = True): """ Function builds the tensorboard_metrics dictionary to be parsed by the TensorboardCallback. This @@ -155,12 +155,13 @@ class BaseEnvironment(gym.Env): return -2 :param metric: metric to be tracked and incremented - :param inc: value to increment `metric` by + :param value: value to increment `metric` by + :param inc: sets whether the `value` is incremented or not """ - if metric not in self.tensorboard_metrics: - self.tensorboard_metrics[metric] = inc + if not inc or metric not in self.tensorboard_metrics: + self.tensorboard_metrics[metric] = value else: - self.tensorboard_metrics[metric] += inc + self.tensorboard_metrics[metric] += value def reset_tensorboard_log(self): self.tensorboard_metrics = {}