Merge pull request #7866 from initrv/cleanup-tensorboard-callback

Cleanup tensorboard callback
Robert Caulk, 2022-12-13 09:05:46 +01:00, committed by GitHub
commit e6da646e2f
6 changed files with 73 additions and 36 deletions

View File

@@ -247,6 +247,32 @@ where `unique-id` is the `identifier` set in the `freqai` configuration file. Th
![tensorboard](assets/tensorboard.jpg)
### Custom logging
FreqAI also provides a built-in episodic summary logger called `self.tensorboard_log` for adding custom information to the Tensorboard log. By default, this function is already called once per step inside the environment to record the agent's actions. All values accumulated over the steps of a single episode are reported at the conclusion of that episode, after which all metrics are reset to 0 in preparation for the next episode.

`self.tensorboard_log` can also be called anywhere inside the environment; for example, it can be added to the `calculate_reward` function to collect more detailed information about how often the various parts of the reward are triggered:
```py
class MyRLEnv(Base5ActionRLEnv):
    """
    User-made custom environment. This class inherits from BaseEnvironment and gym.Env.
    Users can override any functions from those parent classes. Here is an example
    of a user-customized `calculate_reward()` function.
    """
    def calculate_reward(self, action: int) -> float:
        if not self._is_valid(action):
            self.tensorboard_log("is_valid")
            return -2
```
!!! Note
    The `self.tensorboard_log()` function is designed for tracking incremented objects only, i.e. events and actions inside the training environment. If the metric of interest is a float, it can be passed as the second argument, e.g. `self.tensorboard_log("float_metric1", 0.23)` would add 0.23 to `float_metric1`. In that case, incrementing can also be disabled by passing `inc=False`, which stores the value as-is instead of accumulating it.
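As a rough sketch of the float case described in the note above (not part of the shipped examples; the metric names `pnl_sum` and `current_pnl` are purely illustrative), the same pattern inside a customized `calculate_reward()` might look like this:

```py
class MyRLEnv(Base5ActionRLEnv):
    """
    Sketch only: combines an event counter with float metrics.
    """
    def calculate_reward(self, action: int) -> float:
        if not self._is_valid(action):
            # counts how many invalid actions were attempted this episode
            self.tensorboard_log("is_valid")
            return -2
        pnl = self.get_unrealized_profit()
        # accumulate a float metric over the episode...
        self.tensorboard_log("pnl_sum", pnl)
        # ...or overwrite it every step instead of summing
        self.tensorboard_log("current_pnl", pnl, inc=False)
        return float(pnl)
```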
### Choosing a base environment

FreqAI provides two base environments, `Base4ActionEnvironment` and `Base5ActionEnvironment`. As the names imply, the environments are customized for agents that can select from 4 or 5 actions. In the `Base4ActionEnvironment`, the agent can enter long, enter short, hold neutral, or exit position. Meanwhile, in the `Base5ActionEnvironment`, the agent has the same actions as Base4, but instead of a single exit action, it separates exit long and exit short. The main changes stemming from the environment selection include:

View File

@@ -46,9 +46,9 @@ class Base4ActionRLEnv(BaseEnvironment):
             self._done = True

         self._update_unrealized_total_profit()
         step_reward = self.calculate_reward(action)
         self.total_reward += step_reward
+        self.tensorboard_log(self.actions._member_names_[action])

         trade_type = None
         if self.is_tradesignal(action):

View File

@@ -49,6 +49,7 @@ class Base5ActionRLEnv(BaseEnvironment):
         self._update_unrealized_total_profit()
         step_reward = self.calculate_reward(action)
         self.total_reward += step_reward
+        self.tensorboard_log(self.actions._member_names_[action])

         trade_type = None
         if self.is_tradesignal(action):
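For readers unfamiliar with the `_member_names_` idiom used in the added lines above, here is a minimal, self-contained sketch (the action names and their ordering are illustrative, not taken from the FreqAI enums):

```py
from enum import Enum


class Actions(Enum):
    # illustrative ordering only; the real environments define their own enums
    Neutral = 0
    Long_enter = 1
    Long_exit = 2
    Short_enter = 3
    Short_exit = 4


# Enum._member_names_ is a list of member names in definition order, so an
# integer action index maps directly to a human-readable metric name.
action = 1
print(Actions._member_names_[action])  # -> "Long_enter"
```

This is why the per-step call above produces one Tensorboard counter per action name for each episode.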

View File

@@ -2,7 +2,7 @@ import logging
 import random
 from abc import abstractmethod
 from enum import Enum
-from typing import Optional, Type
+from typing import Optional, Type, Union

 import gym
 import numpy as np
@@ -78,7 +78,7 @@ class BaseEnvironment(gym.Env):
         # set here to default 5Ac, but all children envs can override this
         self.actions: Type[Enum] = BaseActions
-        self.custom_info: dict = {}
+        self.tensorboard_metrics: dict = {}
         self.live: bool = False
         if dp:
             self.live = dp.runmode in (RunMode.DRY_RUN, RunMode.LIVE)
@@ -139,20 +139,38 @@
         self.np_random, seed = seeding.np_random(seed)
         return [seed]

+    def tensorboard_log(self, metric: str, value: Union[int, float] = 1, inc: bool = True):
+        """
+        Function builds the tensorboard_metrics dictionary
+        to be parsed by the TensorboardCallback. This
+        function is designed for tracking incremented objects,
+        events, actions inside the training environment.
+        For example, a user can call this to track the
+        frequency of occurrence of an `is_valid` call in
+        their `calculate_reward()`:
+
+        def calculate_reward(self, action: int) -> float:
+            if not self._is_valid(action):
+                self.tensorboard_log("is_valid")
+                return -2
+
+        :param metric: metric to be tracked and incremented
+        :param value: value to increment `metric` by
+        :param inc: sets whether the `value` is incremented or not
+        """
+        if not inc or metric not in self.tensorboard_metrics:
+            self.tensorboard_metrics[metric] = value
+        else:
+            self.tensorboard_metrics[metric] += value
+
+    def reset_tensorboard_log(self):
+        self.tensorboard_metrics = {}
+
     def reset(self):
         """
         Reset is called at the beginning of every episode
         """
-        # custom_info is used for episodic reports and tensorboard logging
-        self.custom_info["Invalid"] = 0
-        self.custom_info["Hold"] = 0
-        self.custom_info["Unknown"] = 0
-        self.custom_info["pnl_factor"] = 0
-        self.custom_info["duration_factor"] = 0
-        self.custom_info["reward_exit"] = 0
-        self.custom_info["reward_hold"] = 0
-        for action in self.actions:
-            self.custom_info[f"{action.name}"] = 0
+        self.reset_tensorboard_log()

         self._done = False
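To make the accumulate/overwrite behaviour of the new `tensorboard_log` concrete, the following self-contained sketch mirrors its logic outside the environment (class and metric names are illustrative):

```py
from typing import Union


class MetricsSketch:
    """Toy stand-in that mirrors the tensorboard_log accumulation logic."""

    def __init__(self):
        self.tensorboard_metrics: dict = {}

    def tensorboard_log(self, metric: str, value: Union[int, float] = 1, inc: bool = True):
        # first call (or inc=False) stores the value; later calls add to it
        if not inc or metric not in self.tensorboard_metrics:
            self.tensorboard_metrics[metric] = value
        else:
            self.tensorboard_metrics[metric] += value


demo = MetricsSketch()
demo.tensorboard_log("is_valid")                   # creates the key with value 1
demo.tensorboard_log("is_valid")                   # increments it -> 2
demo.tensorboard_log("pnl_sum", 0.5)               # floats accumulate the same way
demo.tensorboard_log("pnl_sum", 0.25)              # -> 0.75
demo.tensorboard_log("last_pnl", 0.25, inc=False)  # inc=False overwrites instead
print(demo.tensorboard_metrics)  # {'is_valid': 2, 'pnl_sum': 0.75, 'last_pnl': 0.25}
```

The environment's `reset()` then calls `reset_tensorboard_log()` so each episode starts from an empty dictionary.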

View File

@@ -42,19 +42,18 @@ class TensorboardCallback(BaseCallback):
         )

     def _on_step(self) -> bool:
-        custom_info = self.training_env.get_attr("custom_info")[0]
-        self.logger.record("_state/position", self.locals["infos"][0]["position"])
-        self.logger.record("_state/trade_duration", self.locals["infos"][0]["trade_duration"])
-        self.logger.record("_state/current_profit_pct", self.locals["infos"]
-                           [0]["current_profit_pct"])
-        self.logger.record("_reward/total_profit", self.locals["infos"][0]["total_profit"])
-        self.logger.record("_reward/total_reward", self.locals["infos"][0]["total_reward"])
-        self.logger.record_mean("_reward/mean_trade_duration", self.locals["infos"]
-                                [0]["trade_duration"])
-        self.logger.record("_actions/action", self.locals["infos"][0]["action"])
-        self.logger.record("_actions/_Invalid", custom_info["Invalid"])
-        self.logger.record("_actions/_Unknown", custom_info["Unknown"])
-        self.logger.record("_actions/Hold", custom_info["Hold"])
-        for action in self.actions:
-            self.logger.record(f"_actions/{action.name}", custom_info[action.name])
+        local_info = self.locals["infos"][0]
+        tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0]
+
+        for info in local_info:
+            if info not in ["episode", "terminal_observation"]:
+                self.logger.record(f"_info/{info}", local_info[info])
+
+        for info in tensorboard_metrics:
+            if info in [action.name for action in self.actions]:
+                self.logger.record(f"_actions/{info}", tensorboard_metrics[info])
+            else:
+                self.logger.record(f"_custom/{info}", tensorboard_metrics[info])
+
         return True
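To illustrate where the collected metrics end up in Tensorboard under the new callback, here is a standalone sketch of the same partitioning (the action and metric names are examples, not values from a real run); entries from the step `info` dict, except `episode` and `terminal_observation`, land under `_info/` in the same way:

```py
# Mirrors the split done in the new _on_step: counters named after an action
# go to "_actions/...", every other tracked metric goes to "_custom/...".
action_names = ["Neutral", "Long_enter", "Long_exit", "Short_enter", "Short_exit"]
tensorboard_metrics = {"Long_enter": 12, "Neutral": 87, "is_valid": 3, "pnl_sum": 1.4}

for name, value in tensorboard_metrics.items():
    namespace = "_actions" if name in action_names else "_custom"
    print(f"{namespace}/{name} = {value}")

# _actions/Long_enter = 12
# _actions/Neutral = 87
# _custom/is_valid = 3
# _custom/pnl_sum = 1.4
```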

View File

@@ -100,7 +100,7 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
             """
             # first, penalize if the action is not valid
             if not self._is_valid(action):
-                self.custom_info["Invalid"] += 1
+                self.tensorboard_log("is_valid")
                 return -2

             pnl = self.get_unrealized_profit()
@@ -109,15 +109,12 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
             # reward agent for entering trades
             if (action == Actions.Long_enter.value
                     and self._position == Positions.Neutral):
-                self.custom_info[f"{Actions.Long_enter.name}"] += 1
                 return 25
             if (action == Actions.Short_enter.value
                     and self._position == Positions.Neutral):
-                self.custom_info[f"{Actions.Short_enter.name}"] += 1
                 return 25
             # discourage agent from not entering trades
             if action == Actions.Neutral.value and self._position == Positions.Neutral:
-                self.custom_info[f"{Actions.Neutral.name}"] += 1
                 return -1

             max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
@@ -131,22 +128,18 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
             # discourage sitting in position
             if (self._position in (Positions.Short, Positions.Long) and
                     action == Actions.Neutral.value):
-                self.custom_info["Hold"] += 1
                 return -1 * trade_duration / max_trade_duration

             # close long
             if action == Actions.Long_exit.value and self._position == Positions.Long:
                 if pnl > self.profit_aim * self.rr:
                     factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
-                self.custom_info[f"{Actions.Long_exit.name}"] += 1
                 return float(pnl * factor)

             # close short
             if action == Actions.Short_exit.value and self._position == Positions.Short:
                 if pnl > self.profit_aim * self.rr:
                     factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
-                self.custom_info[f"{Actions.Short_exit.name}"] += 1
                 return float(pnl * factor)

-            self.custom_info["Unknown"] += 1
             return 0.
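Because this change removes the hard-coded `custom_info` counters from the example reward, per-branch telemetry is now opt-in. A possible sketch of opting back in from a user-defined environment follows; the import paths reflect the module layout at the time of this commit and may need adjusting, and the metric names are arbitrary:

```py
# NOTE: import paths assumed from the freqtrade layout around this PR; adjust if needed.
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
from freqtrade.freqai.RL.BaseEnvironment import Positions


class MyRLEnv(Base5ActionRLEnv):
    """Sketch: re-add per-branch reward telemetry via tensorboard_log."""

    def calculate_reward(self, action: int) -> float:
        if not self._is_valid(action):
            self.tensorboard_log("is_valid")
            return -2

        pnl = self.get_unrealized_profit()

        # count how often the agent holds while in a position
        if (self._position in (Positions.Short, Positions.Long)
                and action == Actions.Neutral.value):
            self.tensorboard_log("hold_in_position")
            return -1

        # count exits and track the realized pnl as a float metric
        if action == Actions.Long_exit.value and self._position == Positions.Long:
            self.tensorboard_log("long_exit")
            self.tensorboard_log("pnl_sum", pnl)
            return float(pnl)

        return 0.
```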