Add documentation for `tensorboard_log` and change how users interact with it

This commit is contained in:
robcaulk
2022-12-11 15:31:29 +01:00
parent cb8fc3c8c7
commit 0fd8e214e4
5 changed files with 57 additions and 7 deletions

View File

@@ -48,7 +48,7 @@ class Base4ActionRLEnv(BaseEnvironment):
self._update_unrealized_total_profit()
step_reward = self.calculate_reward(action)
self.total_reward += step_reward
self.tensorboard_metrics[self.actions._member_names_[action]] += 1
self.tensorboard_log(self.actions._member_names_[action])
trade_type = None
if self.is_tradesignal(action):

View File

@@ -49,7 +49,7 @@ class Base5ActionRLEnv(BaseEnvironment):
self._update_unrealized_total_profit()
step_reward = self.calculate_reward(action)
self.total_reward += step_reward
self.tensorboard_metrics[self.actions._member_names_[action]] += 1
self.tensorboard_log(self.actions._member_names_[action])
trade_type = None
if self.is_tradesignal(action):

View File

@@ -2,7 +2,7 @@ import logging
import random
from abc import abstractmethod
from enum import Enum
from typing import Optional, Type
from typing import Optional, Type, Union
import gym
import numpy as np
@@ -132,14 +132,37 @@ class BaseEnvironment(gym.Env):
self.np_random, seed = seeding.np_random(seed)
return [seed]
def tensorboard_log(self, metric: str, inc: Union[int, float] = 1):
    """
    Function builds the tensorboard_metrics dictionary
    to be parsed by the TensorboardCallback. This
    function is designed for tracking incremented objects,
    events, actions inside the training environment.
    For example, a user can call this to track the
    frequency of occurrence of an `is_valid` call in
    their `calculate_reward()`:

    def calculate_reward(self, action: int) -> float:
        if not self._is_valid(action):
            self.tensorboard_log("is_valid")
            return -2

    :param metric: metric to be tracked and incremented
    :param inc: value to increment `metric` by (defaults to 1)
    """
    # dict.get() defaults unseen metrics to 0, so the first call stores
    # `inc` and subsequent calls accumulate — one lookup path instead of
    # a membership test followed by a separate indexed assignment.
    self.tensorboard_metrics[metric] = self.tensorboard_metrics.get(metric, 0) + inc
def reset_tensorboard_log(self):
    """Discard all accumulated tensorboard metrics by rebinding a fresh dict."""
    self.tensorboard_metrics = dict()
def reset(self):
"""
Reset is called at the beginning of every episode
"""
# tensorboard_metrics is used for episodic reports and tensorboard logging
self.tensorboard_metrics: dict = {}
for action in self.actions:
self.tensorboard_metrics[action.name] = 0
self.reset_tensorboard_log()
self._done = False

View File

@@ -100,6 +100,7 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
"""
# first, penalize if the action is not valid
if not self._is_valid(action):
self.tensorboard_log("is_valid")
return -2
pnl = self.get_unrealized_profit()