custom info to tensorboard_metrics

parent 58604c747e
commit cb8fc3c8c7

@@ -46,9 +46,9 @@ class Base4ActionRLEnv(BaseEnvironment):
             self._done = True

         self._update_unrealized_total_profit()
-
         step_reward = self.calculate_reward(action)
         self.total_reward += step_reward
+        self.tensorboard_metrics[self.actions._member_names_[action]] += 1

         trade_type = None
         if self.is_tradesignal(action):

@@ -49,6 +49,7 @@ class Base5ActionRLEnv(BaseEnvironment):
         self._update_unrealized_total_profit()
         step_reward = self.calculate_reward(action)
         self.total_reward += step_reward
+        self.tensorboard_metrics[self.actions._member_names_[action]] += 1

         trade_type = None
         if self.is_tradesignal(action):

@@ -77,7 +77,7 @@ class BaseEnvironment(gym.Env):

         # set here to default 5Ac, but all children envs can override this
         self.actions: Type[Enum] = BaseActions
-        self.custom_info: dict = {}
+        self.tensorboard_metrics: dict = {}

     def reset_env(self, df: DataFrame, prices: DataFrame, window_size: int,
                   reward_kwargs: dict, starting_point=True):

@@ -136,10 +136,10 @@ class BaseEnvironment(gym.Env):
         """
         Reset is called at the beginning of every episode
         """
-        # custom_info is used for episodic reports and tensorboard logging
-        self.custom_info: dict = {}
+        # tensorboard_metrics is used for episodic reports and tensorboard logging
+        self.tensorboard_metrics: dict = {}
         for action in self.actions:
-            self.custom_info[action.name] = 0
+            self.tensorboard_metrics[action.name] = 0

         self._done = False


@@ -44,16 +44,16 @@ class TensorboardCallback(BaseCallback):
     def _on_step(self) -> bool:

         local_info = self.locals["infos"][0]
-        custom_info = self.training_env.get_attr("custom_info")[0]
+        tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0]

         for info in local_info:
             if info not in ["episode", "terminal_observation"]:
                 self.logger.record(f"_info/{info}", local_info[info])

-        for info in custom_info:
+        for info in tensorboard_metrics:
             if info in [action.name for action in self.actions]:
-                self.logger.record(f"_actions/{info}", custom_info[info])
+                self.logger.record(f"_actions/{info}", tensorboard_metrics[info])
             else:
-                self.logger.record(f"_custom/{info}", custom_info[info])
+                self.logger.record(f"_custom/{info}", tensorboard_metrics[info])

         return True
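
The callback above sends every key of tensorboard_metrics either to an "_actions/" tag (when the key matches an action name) or to a "_custom/" tag (anything else). Below is a minimal sketch of how that second branch could be fed, assuming a user-defined Base5ActionRLEnv subclass and a made-up "profit_hits" counter, neither of which is part of this commit:

from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions


class MyRLEnv(Base5ActionRLEnv):
    """Hypothetical user environment that also counts profit-target hits."""

    def calculate_reward(self, action: int) -> float:
        pnl = self.get_unrealized_profit()

        if action == Actions.Long_exit.value and self._position == Positions.Long:
            if pnl > self.profit_aim * self.rr:
                # "profit_hits" is an illustrative key; TensorboardCallback
                # would record it as "_custom/profit_hits"
                self.tensorboard_metrics["profit_hits"] = (
                    self.tensorboard_metrics.get("profit_hits", 0) + 1)
            return float(pnl)

        return 0.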

@@ -108,15 +108,12 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
             # reward agent for entering trades
             if (action == Actions.Long_enter.value
                     and self._position == Positions.Neutral):
-                self.custom_info[Actions.Long_enter.name] += 1
                 return 25
             if (action == Actions.Short_enter.value
                     and self._position == Positions.Neutral):
-                self.custom_info[Actions.Short_enter.name] += 1
                 return 25
             # discourage agent from not entering trades
             if action == Actions.Neutral.value and self._position == Positions.Neutral:
-                self.custom_info[Actions.Neutral.name] += 1
                 return -1

             max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)

@@ -130,21 +127,18 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
             # discourage sitting in position
             if (self._position in (Positions.Short, Positions.Long) and
                     action == Actions.Neutral.value):
-                self.custom_info[Actions.Neutral.name] += 1
                 return -1 * trade_duration / max_trade_duration

             # close long
             if action == Actions.Long_exit.value and self._position == Positions.Long:
                 if pnl > self.profit_aim * self.rr:
                     factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
-                self.custom_info[Actions.Long_exit.name] += 1
                 return float(pnl * factor)

             # close short
             if action == Actions.Short_exit.value and self._position == Positions.Short:
                 if pnl > self.profit_aim * self.rr:
                     factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
-                self.custom_info[Actions.Short_exit.name] += 1
                 return float(pnl * factor)

             return 0.
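
Net effect: per-action counting moves out of the user's calculate_reward (the removed custom_info lines above) and into the environments' step() methods, so custom reward functions no longer need to track action usage themselves. Once training runs, the counters appear in TensorBoard under "_actions/" (with any extra keys under "_custom/"), typically viewable by pointing tensorboard --logdir at the FreqAI model directory for the configured identifier.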