custom info to tensorboard_metrics

initrv 2022-12-11 15:37:45 +03:00
parent 58604c747e
commit cb8fc3c8c7
5 changed files with 10 additions and 15 deletions


@@ -46,9 +46,9 @@ class Base4ActionRLEnv(BaseEnvironment):
             self._done = True

         self._update_unrealized_total_profit()
         step_reward = self.calculate_reward(action)
         self.total_reward += step_reward
-        self.custom_info[self.actions._member_names_[action]] += 1
+        self.tensorboard_metrics[self.actions._member_names_[action]] += 1

         trade_type = None
         if self.is_tradesignal(action):


@@ -49,6 +49,7 @@ class Base5ActionRLEnv(BaseEnvironment):
         self._update_unrealized_total_profit()
         step_reward = self.calculate_reward(action)
         self.total_reward += step_reward
+        self.tensorboard_metrics[self.actions._member_names_[action]] += 1

         trade_type = None
         if self.is_tradesignal(action):


@@ -77,7 +77,7 @@ class BaseEnvironment(gym.Env):
         # set here to default 5Ac, but all children envs can override this
         self.actions: Type[Enum] = BaseActions
-        self.custom_info: dict = {}
+        self.tensorboard_metrics: dict = {}

     def reset_env(self, df: DataFrame, prices: DataFrame, window_size: int,
                   reward_kwargs: dict, starting_point=True):
@@ -136,10 +136,10 @@ class BaseEnvironment(gym.Env):
         """
         Reset is called at the beginning of every episode
         """
-        # custom_info is used for episodic reports and tensorboard logging
-        self.custom_info: dict = {}
+        # tensorboard_metrics is used for episodic reports and tensorboard logging
+        self.tensorboard_metrics: dict = {}
         for action in self.actions:
-            self.custom_info[action.name] = 0
+            self.tensorboard_metrics[action.name] = 0

         self._done = False
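For context, the counting mechanism these hunks rely on can be reproduced standalone. This is an illustrative sketch, not freqtrade code; the Actions enum below is a stand-in with the same shape as the 5-action environment's enum. Enum._member_names_ lists member names in definition order, so indexing it with the integer action recovers the same key that reset() created via action.name:

from enum import Enum

class Actions(Enum):
    Neutral = 0
    Long_enter = 1
    Long_exit = 2
    Short_enter = 3
    Short_exit = 4

# what reset() does after this commit: one integer counter per action name
tensorboard_metrics = {action.name: 0 for action in Actions}

# what the line added to step() does: map the integer action back to its
# member name and bump the matching counter
action = 1  # e.g. the agent chose Long_enter
tensorboard_metrics[Actions._member_names_[action]] += 1

assert tensorboard_metrics["Long_enter"] == 1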


@@ -44,16 +44,16 @@ class TensorboardCallback(BaseCallback):
     def _on_step(self) -> bool:

         local_info = self.locals["infos"][0]
-        custom_info = self.training_env.get_attr("custom_info")[0]
+        tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0]

         for info in local_info:
             if info not in ["episode", "terminal_observation"]:
                 self.logger.record(f"_info/{info}", local_info[info])

-        for info in custom_info:
+        for info in tensorboard_metrics:
             if info in [action.name for action in self.actions]:
-                self.logger.record(f"_actions/{info}", custom_info[info])
+                self.logger.record(f"_actions/{info}", tensorboard_metrics[info])
             else:
-                self.logger.record(f"_custom/{info}", custom_info[info])
+                self.logger.record(f"_custom/{info}", tensorboard_metrics[info])

         return True
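The callback routes each key of the per-environment tensorboard_metrics dict by name: keys matching an action name become "_actions/<name>" series, everything else lands under "_custom/<key>". A hedged standalone sketch of that routing; route() is a hypothetical helper and "win_streak" a hypothetical extra key, neither is freqtrade API:

def route(key: str, action_names: set) -> str:
    # mirrors the callback's split between action counters and other metrics
    return f"_actions/{key}" if key in action_names else f"_custom/{key}"

action_names = {"Neutral", "Long_enter", "Long_exit", "Short_enter", "Short_exit"}
assert route("Long_enter", action_names) == "_actions/Long_enter"
assert route("win_streak", action_names) == "_custom/win_streak"  # hypothetical extra key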


@@ -108,15 +108,12 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
             # reward agent for entering trades
             if (action == Actions.Long_enter.value
                     and self._position == Positions.Neutral):
-                self.custom_info[Actions.Long_enter.name] += 1
                 return 25
             if (action == Actions.Short_enter.value
                     and self._position == Positions.Neutral):
-                self.custom_info[Actions.Short_enter.name] += 1
                 return 25
             # discourage agent from not entering trades
             if action == Actions.Neutral.value and self._position == Positions.Neutral:
-                self.custom_info[Actions.Neutral.name] += 1
                 return -1

             max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
@@ -130,21 +127,18 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
             # discourage sitting in position
             if (self._position in (Positions.Short, Positions.Long) and
                     action == Actions.Neutral.value):
-                self.custom_info[Actions.Neutral.name] += 1
                 return -1 * trade_duration / max_trade_duration

             # close long
             if action == Actions.Long_exit.value and self._position == Positions.Long:
                 if pnl > self.profit_aim * self.rr:
                     factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
-                self.custom_info[Actions.Long_exit.name] += 1
                 return float(pnl * factor)

             # close short
             if action == Actions.Short_exit.value and self._position == Positions.Short:
                 if pnl > self.profit_aim * self.rr:
                     factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
-                self.custom_info[Actions.Short_exit.name] += 1
                 return float(pnl * factor)

             return 0.
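With the counters removed, what remains of the exit branches is pure reward shaping: a profitable exit is scaled up by win_reward_factor (default 2 per the hunk above) once pnl clears profit_aim * rr. A hedged standalone sketch of that arithmetic; exit_reward and the numeric values are illustrative, not freqtrade code:

def exit_reward(pnl: float, profit_aim: float, rr: float,
                factor: float = 1.0, win_reward_factor: float = 2.0) -> float:
    # exits that beat the profit target get their reward scaled up
    if pnl > profit_aim * rr:
        factor *= win_reward_factor
    return float(pnl * factor)

assert exit_reward(pnl=0.04, profit_aim=0.02, rr=1.0) == 0.08  # above target: doubled
assert exit_reward(pnl=0.01, profit_aim=0.02, rr=1.0) == 0.01  # below target: unscaled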