cleanup tensorboard callback

initrv 2022-12-07 14:37:55 +03:00
parent b9f6911a6a
commit 58604c747e
3 changed files with 21 additions and 30 deletions


@@ -137,15 +137,9 @@ class BaseEnvironment(gym.Env):
         Reset is called at the beginning of every episode
         """
         # custom_info is used for episodic reports and tensorboard logging
-        self.custom_info["Invalid"] = 0
-        self.custom_info["Hold"] = 0
-        self.custom_info["Unknown"] = 0
-        self.custom_info["pnl_factor"] = 0
-        self.custom_info["duration_factor"] = 0
-        self.custom_info["reward_exit"] = 0
-        self.custom_info["reward_hold"] = 0
+        self.custom_info: dict = {}
         for action in self.actions:
-            self.custom_info[f"{action.name}"] = 0
+            self.custom_info[action.name] = 0
 
         self._done = False
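
Standalone sketch of what the new reset-time initialization produces: one zeroed counter per action, keyed by the enum member's name. The enum below is a hypothetical stand-in (the member names appear elsewhere in this commit, the values are assumed), not the project's file:

from enum import Enum

class Actions(Enum):
    # member names as referenced in this commit; the numeric values are assumed
    Neutral = 0
    Long_enter = 1
    Long_exit = 2
    Short_enter = 3
    Short_exit = 4

# equivalent of the new reset() logic: a fresh per-episode counter per action
custom_info: dict = {action.name: 0 for action in Actions}
print(custom_info)
# {'Neutral': 0, 'Long_enter': 0, 'Long_exit': 0, 'Short_enter': 0, 'Short_exit': 0}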


@@ -42,19 +42,18 @@ class TensorboardCallback(BaseCallback):
         )
 
     def _on_step(self) -> bool:
+
+        local_info = self.locals["infos"][0]
         custom_info = self.training_env.get_attr("custom_info")[0]
-        self.logger.record("_state/position", self.locals["infos"][0]["position"])
-        self.logger.record("_state/trade_duration", self.locals["infos"][0]["trade_duration"])
-        self.logger.record("_state/current_profit_pct", self.locals["infos"]
-                           [0]["current_profit_pct"])
-        self.logger.record("_reward/total_profit", self.locals["infos"][0]["total_profit"])
-        self.logger.record("_reward/total_reward", self.locals["infos"][0]["total_reward"])
-        self.logger.record_mean("_reward/mean_trade_duration", self.locals["infos"]
-                                [0]["trade_duration"])
-        self.logger.record("_actions/action", self.locals["infos"][0]["action"])
-        self.logger.record("_actions/_Invalid", custom_info["Invalid"])
-        self.logger.record("_actions/_Unknown", custom_info["Unknown"])
-        self.logger.record("_actions/Hold", custom_info["Hold"])
-        for action in self.actions:
-            self.logger.record(f"_actions/{action.name}", custom_info[action.name])
+
+        for info in local_info:
+            if info not in ["episode", "terminal_observation"]:
+                self.logger.record(f"_info/{info}", local_info[info])
+
+        for info in custom_info:
+            if info in [action.name for action in self.actions]:
+                self.logger.record(f"_actions/{info}", custom_info[info])
+            else:
+                self.logger.record(f"_custom/{info}", custom_info[info])
+
         return True
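
Taken together, the callback stops hard-coding tag names. A sketch of the resulting method, with the BaseCallback plumbing, self.actions, and the contents of the step info dict assumed from the rest of the file:

def _on_step(self) -> bool:
    # info dict returned by the env's step() for the first (only) vectorized env
    local_info = self.locals["infos"][0]
    # per-episode counters kept on the environment (rebuilt in BaseEnvironment.reset())
    custom_info = self.training_env.get_attr("custom_info")[0]

    # every step-info entry is logged generically under "_info/<key>"
    for info in local_info:
        if info not in ["episode", "terminal_observation"]:
            self.logger.record(f"_info/{info}", local_info[info])

    # action counters land under "_actions/<name>"; any other custom_info key
    # is logged under "_custom/<key>"
    for info in custom_info:
        if info in [action.name for action in self.actions]:
            self.logger.record(f"_actions/{info}", custom_info[info])
        else:
            self.logger.record(f"_custom/{info}", custom_info[info])

    return True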


@@ -100,7 +100,6 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
             """
             # first, penalize if the action is not valid
             if not self._is_valid(action):
-                self.custom_info["Invalid"] += 1
                 return -2
 
             pnl = self.get_unrealized_profit()
@@ -109,15 +108,15 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
             # reward agent for entering trades
             if (action == Actions.Long_enter.value
                     and self._position == Positions.Neutral):
-                self.custom_info[f"{Actions.Long_enter.name}"] += 1
+                self.custom_info[Actions.Long_enter.name] += 1
                 return 25
             if (action == Actions.Short_enter.value
                     and self._position == Positions.Neutral):
-                self.custom_info[f"{Actions.Short_enter.name}"] += 1
+                self.custom_info[Actions.Short_enter.name] += 1
                 return 25
             # discourage agent from not entering trades
             if action == Actions.Neutral.value and self._position == Positions.Neutral:
-                self.custom_info[f"{Actions.Neutral.name}"] += 1
+                self.custom_info[Actions.Neutral.name] += 1
                 return -1
 
             max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
@@ -131,22 +130,21 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
             # discourage sitting in position
             if (self._position in (Positions.Short, Positions.Long) and
                     action == Actions.Neutral.value):
-                self.custom_info["Hold"] += 1
+                self.custom_info[Actions.Neutral.name] += 1
                 return -1 * trade_duration / max_trade_duration
 
             # close long
             if action == Actions.Long_exit.value and self._position == Positions.Long:
                 if pnl > self.profit_aim * self.rr:
                     factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
-                self.custom_info[f"{Actions.Long_exit.name}"] += 1
+                self.custom_info[Actions.Long_exit.name] += 1
                 return float(pnl * factor)
 
             # close short
             if action == Actions.Short_exit.value and self._position == Positions.Short:
                 if pnl > self.profit_aim * self.rr:
                     factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
-                self.custom_info[f"{Actions.Short_exit.name}"] += 1
+                self.custom_info[Actions.Short_exit.name] += 1
                 return float(pnl * factor)
 
-            self.custom_info["Unknown"] += 1
             return 0.
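
Net effect on the counters: the bespoke "Invalid", "Unknown" and "Hold" keys are gone (those paths are simply no longer counted, except that holding an open position now increments Actions.Neutral.name), and any extra key a user puts into custom_info is picked up by the callback automatically. A toy illustration of where values end up, using made-up numbers and a hypothetical user key:

# hypothetical end-of-episode custom_info; the numbers are illustrative only
custom_info = {
    "Neutral": 37,      # after this commit, also counts holds while in a position
    "Long_enter": 4,
    "Long_exit": 4,
    "Short_enter": 3,
    "Short_exit": 3,
    "my_metric": 1.23,  # hypothetical user-added entry
}

action_names = {"Neutral", "Long_enter", "Long_exit", "Short_enter", "Short_exit"}
for key, value in custom_info.items():
    prefix = "_actions" if key in action_names else "_custom"
    print(f"{prefix}/{key} -> {value}")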