From 58604c747e759161f25ad4c90571fbaf6a1c5233 Mon Sep 17 00:00:00 2001 From: initrv Date: Wed, 7 Dec 2022 14:37:55 +0300 Subject: [PATCH] cleanup tensorboard callback --- freqtrade/freqai/RL/BaseEnvironment.py | 10 ++----- freqtrade/freqai/RL/TensorboardCallback.py | 27 +++++++++---------- .../prediction_models/ReinforcementLearner.py | 14 +++++----- 3 files changed, 21 insertions(+), 30 deletions(-) diff --git a/freqtrade/freqai/RL/BaseEnvironment.py b/freqtrade/freqai/RL/BaseEnvironment.py index a31ded0c6..71b423844 100644 --- a/freqtrade/freqai/RL/BaseEnvironment.py +++ b/freqtrade/freqai/RL/BaseEnvironment.py @@ -137,15 +137,9 @@ class BaseEnvironment(gym.Env): Reset is called at the beginning of every episode """ # custom_info is used for episodic reports and tensorboard logging - self.custom_info["Invalid"] = 0 - self.custom_info["Hold"] = 0 - self.custom_info["Unknown"] = 0 - self.custom_info["pnl_factor"] = 0 - self.custom_info["duration_factor"] = 0 - self.custom_info["reward_exit"] = 0 - self.custom_info["reward_hold"] = 0 + self.custom_info: dict = {} for action in self.actions: - self.custom_info[f"{action.name}"] = 0 + self.custom_info[action.name] = 0 self._done = False diff --git a/freqtrade/freqai/RL/TensorboardCallback.py b/freqtrade/freqai/RL/TensorboardCallback.py index f590bdf84..d03c040d4 100644 --- a/freqtrade/freqai/RL/TensorboardCallback.py +++ b/freqtrade/freqai/RL/TensorboardCallback.py @@ -42,19 +42,18 @@ class TensorboardCallback(BaseCallback): ) def _on_step(self) -> bool: + + local_info = self.locals["infos"][0] custom_info = self.training_env.get_attr("custom_info")[0] - self.logger.record("_state/position", self.locals["infos"][0]["position"]) - self.logger.record("_state/trade_duration", self.locals["infos"][0]["trade_duration"]) - self.logger.record("_state/current_profit_pct", self.locals["infos"] - [0]["current_profit_pct"]) - self.logger.record("_reward/total_profit", self.locals["infos"][0]["total_profit"]) - self.logger.record("_reward/total_reward", self.locals["infos"][0]["total_reward"]) - self.logger.record_mean("_reward/mean_trade_duration", self.locals["infos"] - [0]["trade_duration"]) - self.logger.record("_actions/action", self.locals["infos"][0]["action"]) - self.logger.record("_actions/_Invalid", custom_info["Invalid"]) - self.logger.record("_actions/_Unknown", custom_info["Unknown"]) - self.logger.record("_actions/Hold", custom_info["Hold"]) - for action in self.actions: - self.logger.record(f"_actions/{action.name}", custom_info[action.name]) + + for info in local_info: + if info not in ["episode", "terminal_observation"]: + self.logger.record(f"_info/{info}", local_info[info]) + + for info in custom_info: + if info in [action.name for action in self.actions]: + self.logger.record(f"_actions/{info}", custom_info[info]) + else: + self.logger.record(f"_custom/{info}", custom_info[info]) + return True diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearner.py b/freqtrade/freqai/prediction_models/ReinforcementLearner.py index 47dbaf99e..1383ad15e 100644 --- a/freqtrade/freqai/prediction_models/ReinforcementLearner.py +++ b/freqtrade/freqai/prediction_models/ReinforcementLearner.py @@ -100,7 +100,6 @@ class ReinforcementLearner(BaseReinforcementLearningModel): """ # first, penalize if the action is not valid if not self._is_valid(action): - self.custom_info["Invalid"] += 1 return -2 pnl = self.get_unrealized_profit() @@ -109,15 +108,15 @@ class ReinforcementLearner(BaseReinforcementLearningModel): # reward agent for entering trades if (action == Actions.Long_enter.value and self._position == Positions.Neutral): - self.custom_info[f"{Actions.Long_enter.name}"] += 1 + self.custom_info[Actions.Long_enter.name] += 1 return 25 if (action == Actions.Short_enter.value and self._position == Positions.Neutral): - self.custom_info[f"{Actions.Short_enter.name}"] += 1 + self.custom_info[Actions.Short_enter.name] += 1 return 25 # discourage agent from not entering trades if action == Actions.Neutral.value and self._position == Positions.Neutral: - self.custom_info[f"{Actions.Neutral.name}"] += 1 + self.custom_info[Actions.Neutral.name] += 1 return -1 max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300) @@ -131,22 +130,21 @@ class ReinforcementLearner(BaseReinforcementLearningModel): # discourage sitting in position if (self._position in (Positions.Short, Positions.Long) and action == Actions.Neutral.value): - self.custom_info["Hold"] += 1 + self.custom_info[Actions.Neutral.name] += 1 return -1 * trade_duration / max_trade_duration # close long if action == Actions.Long_exit.value and self._position == Positions.Long: if pnl > self.profit_aim * self.rr: factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2) - self.custom_info[f"{Actions.Long_exit.name}"] += 1 + self.custom_info[Actions.Long_exit.name] += 1 return float(pnl * factor) # close short if action == Actions.Short_exit.value and self._position == Positions.Short: if pnl > self.profit_aim * self.rr: factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2) - self.custom_info[f"{Actions.Short_exit.name}"] += 1 + self.custom_info[Actions.Short_exit.name] += 1 return float(pnl * factor) - self.custom_info["Unknown"] += 1 return 0.