reorganize/generalize tensorboard callback

This commit is contained in:
robcaulk
2022-12-04 13:54:30 +01:00
parent b2edc58089
commit 24766928ba
7 changed files with 125 additions and 88 deletions

View File

@@ -88,33 +88,6 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
User can override any function in BaseRLEnv and gym.Env. Here the user
sets a custom reward based on profit and trade duration.
"""
def reset(self):
    """Reset the environment and zero every custom tensorboard counter."""
    # Scalar counters surfaced to tensorboard; all start each episode at 0.
    counter_keys = [
        "Invalid",
        "Hold",
        "Unknown",
        "pnl_factor",
        "duration_factor",
        "reward_exit",
        "reward_hold",
    ]
    self.custom_info = {key: 0 for key in counter_keys}
    # One counter per available action as well (keyed by the action name).
    for action in Actions:
        self.custom_info[action.name] = 0
    return super().reset()
def step(self, action: int):
    """Advance one step via the parent env, replacing ``info`` with a
    richer snapshot of the current episode state."""
    observation, step_reward, done, info = super().step(action)
    # Overwrite the parent's info dict with the fields we want logged.
    info = {
        "tick": self._current_tick,
        "action": action,
        "total_reward": self.total_reward,
        "total_profit": self._total_profit,
        "position": self._position.value,
        "trade_duration": self.get_trade_duration(),
        "current_profit_pct": self.get_unrealized_profit(),
    }
    return observation, step_reward, done, info
def calculate_reward(self, action: int) -> float:
"""

View File

@@ -1,14 +1,14 @@
import logging
from typing import Any, Dict # , Tuple
from typing import Any, Dict
# import numpy.typing as npt
from pandas import DataFrame
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.vec_env import SubprocVecEnv
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
from freqtrade.freqai.RL.BaseReinforcementLearningModel import TensorboardCallback, make_env
from freqtrade.freqai.RL.BaseReinforcementLearningModel import make_env
from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback
logger = logging.getLogger(__name__)
@@ -50,4 +50,5 @@ class ReinforcementLearner_multiproc(ReinforcementLearner):
render=False, eval_freq=len(train_df),
best_model_save_path=str(dk.data_path))
self.tensorboard_callback = TensorboardCallback()
actions = self.train_env.env_method("get_actions")[0]
self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions)