diff --git a/docs/freqai-reinforcement-learning.md b/docs/freqai-reinforcement-learning.md
index 04ca42a5d..ed6a41825 100644
--- a/docs/freqai-reinforcement-learning.md
+++ b/docs/freqai-reinforcement-learning.md
@@ -248,13 +248,13 @@ FreqAI also provides a built in episodic summary logger called `self.tensorboard
         """
         def calculate_reward(self, action: int) -> float:
             if not self._is_valid(action):
-                self.tensorboard_log("is_valid")
+                self.tensorboard_log("invalid")
                 return -2
     ```
 
 !!! Note
-    The `self.tensorboard_log()` function is designed for tracking incremented objects only i.e. events, actions inside the training environment. If the event of interest is a float, the float can be passed as the second argument e.g. `self.tensorboard_log("float_metric1", 0.23)` would add 0.23 to `float_metric`. In this case you can also disable incrementing using `inc=False` parameter.
+    The `self.tensorboard_log()` function is designed for tracking incremented objects, i.e. events and actions, inside the training environment. If the value of interest is a float, it can be passed as the second argument, e.g. `self.tensorboard_log("float_metric1", 0.23)`; in this case the metric is set to the passed value rather than incremented.
 
 ### Choosing a base environment
diff --git a/freqtrade/freqai/RL/Base3ActionRLEnv.py b/freqtrade/freqai/RL/Base3ActionRLEnv.py
index 3b5fffc58..a108d776e 100644
--- a/freqtrade/freqai/RL/Base3ActionRLEnv.py
+++ b/freqtrade/freqai/RL/Base3ActionRLEnv.py
@@ -47,7 +47,7 @@ class Base3ActionRLEnv(BaseEnvironment):
         self._update_unrealized_total_profit()
         step_reward = self.calculate_reward(action)
         self.total_reward += step_reward
-        self.tensorboard_log(self.actions._member_names_[action])
+        self.tensorboard_log(self.actions._member_names_[action], category="actions")
 
         trade_type = None
         if self.is_tradesignal(action):
diff --git a/freqtrade/freqai/RL/Base4ActionRLEnv.py b/freqtrade/freqai/RL/Base4ActionRLEnv.py
index 8f45028b1..4f093f06c 100644
--- a/freqtrade/freqai/RL/Base4ActionRLEnv.py
+++ b/freqtrade/freqai/RL/Base4ActionRLEnv.py
@@ -48,7 +48,7 @@ class Base4ActionRLEnv(BaseEnvironment):
         self._update_unrealized_total_profit()
         step_reward = self.calculate_reward(action)
         self.total_reward += step_reward
-        self.tensorboard_log(self.actions._member_names_[action])
+        self.tensorboard_log(self.actions._member_names_[action], category="actions")
 
         trade_type = None
         if self.is_tradesignal(action):
diff --git a/freqtrade/freqai/RL/Base5ActionRLEnv.py b/freqtrade/freqai/RL/Base5ActionRLEnv.py
index 22d3cae30..490ef3601 100644
--- a/freqtrade/freqai/RL/Base5ActionRLEnv.py
+++ b/freqtrade/freqai/RL/Base5ActionRLEnv.py
@@ -49,7 +49,7 @@ class Base5ActionRLEnv(BaseEnvironment):
         self._update_unrealized_total_profit()
         step_reward = self.calculate_reward(action)
         self.total_reward += step_reward
-        self.tensorboard_log(self.actions._member_names_[action])
+        self.tensorboard_log(self.actions._member_names_[action], category="actions")
 
         trade_type = None
         if self.is_tradesignal(action):
diff --git a/freqtrade/freqai/RL/BaseEnvironment.py b/freqtrade/freqai/RL/BaseEnvironment.py
index 7a4467bf7..df2a89d81 100644
--- a/freqtrade/freqai/RL/BaseEnvironment.py
+++ b/freqtrade/freqai/RL/BaseEnvironment.py
@@ -137,7 +137,8 @@ class BaseEnvironment(gym.Env):
         self.np_random, seed = seeding.np_random(seed)
         return [seed]
 
-    def tensorboard_log(self, metric: str, value: Union[int, float] = 1, inc: bool = True):
+    def tensorboard_log(self, metric: str, value: Optional[Union[int, float]] = None,
+                        category: str = "custom"):
         """
         Function builds the tensorboard_metrics dictionary
         to be parsed by the TensorboardCallback. This
@@ -149,17 +150,23 @@ class BaseEnvironment(gym.Env):
 
         def calculate_reward(self, action: int) -> float:
             if not self._is_valid(action):
-                self.tensorboard_log("is_valid")
+                self.tensorboard_log("invalid")
                 return -2
 
         :param metric: metric to be tracked and incremented
-        :param value: value to increment `metric` by
-        :param inc: sets whether the `value` is incremented or not
+        :param value: `metric` value; omit to increment the metric by 1
+        :param category: `metric` category, used to group metrics in Tensorboard
         """
-        if not inc or metric not in self.tensorboard_metrics:
-            self.tensorboard_metrics[metric] = value
+        increment = value is None  # no explicit value -> treat `metric` as a counter
+        value = 1 if increment else value
+
+        if category not in self.tensorboard_metrics:
+            self.tensorboard_metrics[category] = {}
+
+        if not increment or metric not in self.tensorboard_metrics[category]:
+            self.tensorboard_metrics[category][metric] = value
         else:
-            self.tensorboard_metrics[metric] += value
+            self.tensorboard_metrics[category][metric] += value
 
     def reset_tensorboard_log(self):
         self.tensorboard_metrics = {}
diff --git a/freqtrade/freqai/RL/TensorboardCallback.py b/freqtrade/freqai/RL/TensorboardCallback.py
index b596742e9..1828319cd 100644
--- a/freqtrade/freqai/RL/TensorboardCallback.py
+++ b/freqtrade/freqai/RL/TensorboardCallback.py
@@ -46,14 +46,12 @@ class TensorboardCallback(BaseCallback):
         local_info = self.locals["infos"][0]
         tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0]
 
-        for info in local_info:
-            if info not in ["episode", "terminal_observation"]:
-                self.logger.record(f"_info/{info}", local_info[info])
+        for metric in local_info:
+            if metric not in ["episode", "terminal_observation"]:
+                self.logger.record(f"info/{metric}", local_info[metric])
 
-        for info in tensorboard_metrics:
-            if info in [action.name for action in self.actions]:
-                self.logger.record(f"_actions/{info}", tensorboard_metrics[info])
-            else:
-                self.logger.record(f"_custom/{info}", tensorboard_metrics[info])
+        for category in tensorboard_metrics:
+            for metric in tensorboard_metrics[category]:
+                self.logger.record(f"{category}/{metric}", tensorboard_metrics[category][metric])
 
         return True
diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearner.py b/freqtrade/freqai/prediction_models/ReinforcementLearner.py
index 2a87151f9..e795703d4 100644
--- a/freqtrade/freqai/prediction_models/ReinforcementLearner.py
+++ b/freqtrade/freqai/prediction_models/ReinforcementLearner.py
@@ -100,7 +100,7 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
             """
             # first, penalize if the action is not valid
             if not self._is_valid(action):
-                self.tensorboard_log("is_valid")
+                self.tensorboard_log("invalid", category="actions")
                 return -2
 
             pnl = self.get_unrealized_profit()
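
Below is a minimal usage sketch of the reworked `tensorboard_log()` from a custom environment's `calculate_reward()`. It is an illustration only: the `unrealized_pnl` metric name and the reward values are hypothetical and not part of this change; only the call signature and the `actions`/`custom` categories come from the diff above.

```python
from freqtrade.freqai.RL.Base5ActionRLEnv import Base5ActionRLEnv


class MyRLEnv(Base5ActionRLEnv):
    """Example environment using the category-aware tensorboard_log()."""

    def calculate_reward(self, action: int) -> float:
        # Counter: no value passed, so the metric increments by 1 per call.
        # category="actions" files it under the "actions/" Tensorboard tab.
        if not self._is_valid(action):
            self.tensorboard_log("invalid", category="actions")
            return -2

        pnl = self.get_unrealized_profit()

        # Float metric: passing a value disables incrementing, so the metric
        # is overwritten each step and lands under the default "custom/" tab.
        self.tensorboard_log("unrealized_pnl", pnl)

        return float(pnl)
```

With the nested `tensorboard_metrics[category][metric]` dictionary, `TensorboardCallback` records these as `actions/invalid` and `custom/unrealized_pnl`, replacing the old flat `_actions/` and `_custom/` prefixes.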