add tensorboard category
This commit is contained in:
parent
b23841fbfe
commit
82cb107520
@ -248,13 +248,13 @@ FreqAI also provides a built in episodic summary logger called `self.tensorboard
|
|||||||
"""
|
"""
|
||||||
def calculate_reward(self, action: int) -> float:
|
def calculate_reward(self, action: int) -> float:
|
||||||
if not self._is_valid(action):
|
if not self._is_valid(action):
|
||||||
self.tensorboard_log("is_valid")
|
self.tensorboard_log("invalid")
|
||||||
return -2
|
return -2
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
!!! Note
|
!!! Note
|
||||||
The `self.tensorboard_log()` function is designed for tracking incremented objects only, i.e. events and actions inside the training environment. If the event of interest is a float, the float can be passed as the second argument, e.g. `self.tensorboard_log("float_metric1", 0.23)` would add 0.23 to `float_metric1`. In this case you can also disable incrementing using the `inc=False` parameter.
|
The `self.tensorboard_log()` function is designed for tracking incremented objects only, i.e. events and actions inside the training environment. If the event of interest is a float, the float can be passed as the second argument, e.g. `self.tensorboard_log("float_metric1", 0.23)`. In this case, the metric is set to the passed value rather than being incremented.
|
||||||
|
|
||||||
### Choosing a base environment
|
### Choosing a base environment
|
||||||
|
|
||||||
|
@ -47,7 +47,7 @@ class Base3ActionRLEnv(BaseEnvironment):
|
|||||||
self._update_unrealized_total_profit()
|
self._update_unrealized_total_profit()
|
||||||
step_reward = self.calculate_reward(action)
|
step_reward = self.calculate_reward(action)
|
||||||
self.total_reward += step_reward
|
self.total_reward += step_reward
|
||||||
self.tensorboard_log(self.actions._member_names_[action])
|
self.tensorboard_log(self.actions._member_names_[action], category="actions")
|
||||||
|
|
||||||
trade_type = None
|
trade_type = None
|
||||||
if self.is_tradesignal(action):
|
if self.is_tradesignal(action):
|
||||||
|
@ -48,7 +48,7 @@ class Base4ActionRLEnv(BaseEnvironment):
|
|||||||
self._update_unrealized_total_profit()
|
self._update_unrealized_total_profit()
|
||||||
step_reward = self.calculate_reward(action)
|
step_reward = self.calculate_reward(action)
|
||||||
self.total_reward += step_reward
|
self.total_reward += step_reward
|
||||||
self.tensorboard_log(self.actions._member_names_[action])
|
self.tensorboard_log(self.actions._member_names_[action], category="actions")
|
||||||
|
|
||||||
trade_type = None
|
trade_type = None
|
||||||
if self.is_tradesignal(action):
|
if self.is_tradesignal(action):
|
||||||
|
@ -49,7 +49,7 @@ class Base5ActionRLEnv(BaseEnvironment):
|
|||||||
self._update_unrealized_total_profit()
|
self._update_unrealized_total_profit()
|
||||||
step_reward = self.calculate_reward(action)
|
step_reward = self.calculate_reward(action)
|
||||||
self.total_reward += step_reward
|
self.total_reward += step_reward
|
||||||
self.tensorboard_log(self.actions._member_names_[action])
|
self.tensorboard_log(self.actions._member_names_[action], category="actions")
|
||||||
|
|
||||||
trade_type = None
|
trade_type = None
|
||||||
if self.is_tradesignal(action):
|
if self.is_tradesignal(action):
|
||||||
|
@ -137,7 +137,8 @@ class BaseEnvironment(gym.Env):
|
|||||||
self.np_random, seed = seeding.np_random(seed)
|
self.np_random, seed = seeding.np_random(seed)
|
||||||
return [seed]
|
return [seed]
|
||||||
|
|
||||||
def tensorboard_log(self, metric: str, value: Union[int, float] = 1, inc: bool = True):
|
def tensorboard_log(self, metric: str, value: Optional[Union[int, float]] = None,
|
||||||
|
category: str = "custom"):
|
||||||
"""
|
"""
|
||||||
Function builds the tensorboard_metrics dictionary
|
Function builds the tensorboard_metrics dictionary
|
||||||
to be parsed by the TensorboardCallback. This
|
to be parsed by the TensorboardCallback. This
|
||||||
@ -149,17 +150,23 @@ class BaseEnvironment(gym.Env):
|
|||||||
|
|
||||||
def calculate_reward(self, action: int) -> float:
|
def calculate_reward(self, action: int) -> float:
|
||||||
if not self._is_valid(action):
|
if not self._is_valid(action):
|
||||||
self.tensorboard_log("is_valid")
|
self.tensorboard_log("invalid")
|
||||||
return -2
|
return -2
|
||||||
|
|
||||||
:param metric: metric to be tracked and incremented
|
:param metric: metric to be tracked and incremented
|
||||||
:param value: value to increment `metric` by
|
:param value: `metric` value
|
||||||
:param inc: sets whether the `value` is incremented or not
|
:param category: `metric` category
|
||||||
"""
|
"""
|
||||||
if not inc or metric not in self.tensorboard_metrics:
|
increment = True if not value else False
|
||||||
self.tensorboard_metrics[metric] = value
|
value = 1 if increment else value
|
||||||
|
|
||||||
|
if category not in self.tensorboard_metrics:
|
||||||
|
self.tensorboard_metrics[category] = {}
|
||||||
|
|
||||||
|
if not increment or metric not in self.tensorboard_metrics[category]:
|
||||||
|
self.tensorboard_metrics[category][metric] = value
|
||||||
else:
|
else:
|
||||||
self.tensorboard_metrics[metric] += value
|
self.tensorboard_metrics[category][metric] += value
|
||||||
|
|
||||||
def reset_tensorboard_log(self):
|
def reset_tensorboard_log(self):
|
||||||
self.tensorboard_metrics = {}
|
self.tensorboard_metrics = {}
|
||||||
|
@ -46,14 +46,12 @@ class TensorboardCallback(BaseCallback):
|
|||||||
local_info = self.locals["infos"][0]
|
local_info = self.locals["infos"][0]
|
||||||
tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0]
|
tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0]
|
||||||
|
|
||||||
for info in local_info:
|
for metric in local_info:
|
||||||
if info not in ["episode", "terminal_observation"]:
|
if metric not in ["episode", "terminal_observation"]:
|
||||||
self.logger.record(f"_info/{info}", local_info[info])
|
self.logger.record(f"info/{metric}", local_info[metric])
|
||||||
|
|
||||||
for info in tensorboard_metrics:
|
for category in tensorboard_metrics:
|
||||||
if info in [action.name for action in self.actions]:
|
for metric in tensorboard_metrics[category]:
|
||||||
self.logger.record(f"_actions/{info}", tensorboard_metrics[info])
|
self.logger.record(f"{category}/{metric}", tensorboard_metrics[category][metric])
|
||||||
else:
|
|
||||||
self.logger.record(f"_custom/{info}", tensorboard_metrics[info])
|
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
@ -100,7 +100,7 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
|
|||||||
"""
|
"""
|
||||||
# first, penalize if the action is not valid
|
# first, penalize if the action is not valid
|
||||||
if not self._is_valid(action):
|
if not self._is_valid(action):
|
||||||
self.tensorboard_log("is_valid")
|
self.tensorboard_log("invalid", category="actions")
|
||||||
return -2
|
return -2
|
||||||
|
|
||||||
pnl = self.get_unrealized_profit()
|
pnl = self.get_unrealized_profit()
|
||||||
|
Loading…
Reference in New Issue
Block a user