diff --git a/freqtrade/freqai/RL/Base4ActionRLEnv.py b/freqtrade/freqai/RL/Base4ActionRLEnv.py index 79616d778..02e182bbd 100644 --- a/freqtrade/freqai/RL/Base4ActionRLEnv.py +++ b/freqtrade/freqai/RL/Base4ActionRLEnv.py @@ -46,9 +46,9 @@ class Base4ActionRLEnv(BaseEnvironment): self._done = True self._update_unrealized_total_profit() - step_reward = self.calculate_reward(action) self.total_reward += step_reward + self.tensorboard_metrics[self.actions._member_names_[action]] += 1 trade_type = None if self.is_tradesignal(action): diff --git a/freqtrade/freqai/RL/Base5ActionRLEnv.py b/freqtrade/freqai/RL/Base5ActionRLEnv.py index 1c09f9386..baf7dde9f 100644 --- a/freqtrade/freqai/RL/Base5ActionRLEnv.py +++ b/freqtrade/freqai/RL/Base5ActionRLEnv.py @@ -49,6 +49,7 @@ class Base5ActionRLEnv(BaseEnvironment): self._update_unrealized_total_profit() step_reward = self.calculate_reward(action) self.total_reward += step_reward + self.tensorboard_metrics[self.actions._member_names_[action]] += 1 trade_type = None if self.is_tradesignal(action): diff --git a/freqtrade/freqai/RL/BaseEnvironment.py b/freqtrade/freqai/RL/BaseEnvironment.py index 71b423844..0da13db7c 100644 --- a/freqtrade/freqai/RL/BaseEnvironment.py +++ b/freqtrade/freqai/RL/BaseEnvironment.py @@ -77,7 +77,7 @@ class BaseEnvironment(gym.Env): # set here to default 5Ac, but all children envs can override this self.actions: Type[Enum] = BaseActions - self.custom_info: dict = {} + self.tensorboard_metrics: dict = {} def reset_env(self, df: DataFrame, prices: DataFrame, window_size: int, reward_kwargs: dict, starting_point=True): @@ -136,10 +136,10 @@ class BaseEnvironment(gym.Env): """ Reset is called at the beginning of every episode """ - # custom_info is used for episodic reports and tensorboard logging - self.custom_info: dict = {} + # tensorboard_metrics is used for episodic reports and tensorboard logging + self.tensorboard_metrics: dict = {} for action in self.actions: - self.custom_info[action.name] = 0 + self.tensorboard_metrics[action.name] = 0 self._done = False diff --git a/freqtrade/freqai/RL/TensorboardCallback.py b/freqtrade/freqai/RL/TensorboardCallback.py index d03c040d4..b596742e9 100644 --- a/freqtrade/freqai/RL/TensorboardCallback.py +++ b/freqtrade/freqai/RL/TensorboardCallback.py @@ -44,16 +44,16 @@ class TensorboardCallback(BaseCallback): def _on_step(self) -> bool: local_info = self.locals["infos"][0] - custom_info = self.training_env.get_attr("custom_info")[0] + tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0] for info in local_info: if info not in ["episode", "terminal_observation"]: self.logger.record(f"_info/{info}", local_info[info]) - for info in custom_info: + for info in tensorboard_metrics: if info in [action.name for action in self.actions]: - self.logger.record(f"_actions/{info}", custom_info[info]) + self.logger.record(f"_actions/{info}", tensorboard_metrics[info]) else: - self.logger.record(f"_custom/{info}", custom_info[info]) + self.logger.record(f"_custom/{info}", tensorboard_metrics[info]) return True diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearner.py b/freqtrade/freqai/prediction_models/ReinforcementLearner.py index 1383ad15e..e015b138a 100644 --- a/freqtrade/freqai/prediction_models/ReinforcementLearner.py +++ b/freqtrade/freqai/prediction_models/ReinforcementLearner.py @@ -108,15 +108,12 @@ class ReinforcementLearner(BaseReinforcementLearningModel): # reward agent for entering trades if (action == Actions.Long_enter.value and self._position == Positions.Neutral): - self.custom_info[Actions.Long_enter.name] += 1 return 25 if (action == Actions.Short_enter.value and self._position == Positions.Neutral): - self.custom_info[Actions.Short_enter.name] += 1 return 25 # discourage agent from not entering trades if action == Actions.Neutral.value and self._position == Positions.Neutral: - self.custom_info[Actions.Neutral.name] += 1 return -1 max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300) @@ -130,21 +127,18 @@ class ReinforcementLearner(BaseReinforcementLearningModel): # discourage sitting in position if (self._position in (Positions.Short, Positions.Long) and action == Actions.Neutral.value): - self.custom_info[Actions.Neutral.name] += 1 return -1 * trade_duration / max_trade_duration # close long if action == Actions.Long_exit.value and self._position == Positions.Long: if pnl > self.profit_aim * self.rr: factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2) - self.custom_info[Actions.Long_exit.name] += 1 return float(pnl * factor) # close short if action == Actions.Short_exit.value and self._position == Positions.Short: if pnl > self.profit_aim * self.rr: factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2) - self.custom_info[Actions.Short_exit.name] += 1 return float(pnl * factor) return 0.