reorganize/generalize tensorboard callback
This commit is contained in:
@@ -88,33 +88,6 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
|
||||
User can override any function in BaseRLEnv and gym.Env. Here the user
|
||||
sets a custom reward based on profit and trade duration.
|
||||
"""
|
||||
def reset(self):
|
||||
|
||||
# Reset custom info
|
||||
self.custom_info = {}
|
||||
self.custom_info["Invalid"] = 0
|
||||
self.custom_info["Hold"] = 0
|
||||
self.custom_info["Unknown"] = 0
|
||||
self.custom_info["pnl_factor"] = 0
|
||||
self.custom_info["duration_factor"] = 0
|
||||
self.custom_info["reward_exit"] = 0
|
||||
self.custom_info["reward_hold"] = 0
|
||||
for action in Actions:
|
||||
self.custom_info[f"{action.name}"] = 0
|
||||
return super().reset()
|
||||
|
||||
def step(self, action: int):
|
||||
observation, step_reward, done, info = super().step(action)
|
||||
info = dict(
|
||||
tick=self._current_tick,
|
||||
action=action,
|
||||
total_reward=self.total_reward,
|
||||
total_profit=self._total_profit,
|
||||
position=self._position.value,
|
||||
trade_duration=self.get_trade_duration(),
|
||||
current_profit_pct=self.get_unrealized_profit()
|
||||
)
|
||||
return observation, step_reward, done, info
|
||||
|
||||
def calculate_reward(self, action: int) -> float:
|
||||
"""
|
||||
|
@@ -1,14 +1,14 @@
|
||||
import logging
|
||||
from typing import Any, Dict # , Tuple
|
||||
from typing import Any, Dict
|
||||
|
||||
# import numpy.typing as npt
|
||||
from pandas import DataFrame
|
||||
from stable_baselines3.common.callbacks import EvalCallback
|
||||
from stable_baselines3.common.vec_env import SubprocVecEnv
|
||||
|
||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||
from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
|
||||
from freqtrade.freqai.RL.BaseReinforcementLearningModel import TensorboardCallback, make_env
|
||||
from freqtrade.freqai.RL.BaseReinforcementLearningModel import make_env
|
||||
from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -50,4 +50,5 @@ class ReinforcementLearner_multiproc(ReinforcementLearner):
|
||||
render=False, eval_freq=len(train_df),
|
||||
best_model_save_path=str(dk.data_path))
|
||||
|
||||
self.tensorboard_callback = TensorboardCallback()
|
||||
actions = self.train_env.env_method("get_actions")[0]
|
||||
self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions)
|
||||
|
Reference in New Issue
Block a user