diff --git a/freqtrade/freqai/RL/BaseEnvironment.py b/freqtrade/freqai/RL/BaseEnvironment.py
index 3fca6a25d..e43951142 100644
--- a/freqtrade/freqai/RL/BaseEnvironment.py
+++ b/freqtrade/freqai/RL/BaseEnvironment.py
@@ -77,6 +77,7 @@ class BaseEnvironment(gym.Env):
 
         # set here to default 5Ac, but all children envs can overwrite this
         self.actions: Type[Enum] = BaseActions
+        self.custom_info: dict = {}
 
     def reset_env(self, df: DataFrame, prices: DataFrame, window_size: int,
                   reward_kwargs: dict, starting_point=True):
diff --git a/freqtrade/freqai/RL/TensorboardCallback.py b/freqtrade/freqai/RL/TensorboardCallback.py
index 4aea9bdf5..b5b8ba23d 100644
--- a/freqtrade/freqai/RL/TensorboardCallback.py
+++ b/freqtrade/freqai/RL/TensorboardCallback.py
@@ -4,7 +4,7 @@ from typing import Any, Dict, Type, Union
 from stable_baselines3.common.callbacks import BaseCallback
 from stable_baselines3.common.logger import HParam
 
-from freqtrade.freqai.RL.BaseEnvironment import BaseActions
+from freqtrade.freqai.RL.BaseEnvironment import BaseActions, BaseEnvironment
 
 
 class TensorboardCallback(BaseCallback):
@@ -15,9 +15,8 @@ class TensorboardCallback(BaseCallback):
     def __init__(self, verbose=1, actions: Type[Enum] = BaseActions):
         super(TensorboardCallback, self).__init__(verbose)
         self.model: Any = None
-        # An alias for self.model.get_env(), the environment used for training
         self.logger = None  # type: Any
-        # self.training_env = None  # type: Union[gym.Env, VecEnv]
+        self.training_env: BaseEnvironment = None  # type: ignore
         self.actions: Type[Enum] = actions
 
     def _on_training_start(self) -> None:
@@ -43,7 +42,7 @@
         )
 
     def _on_step(self) -> bool:
-        custom_info = self.training_env.get_attr("custom_info")[0]  # type: ignore
+        custom_info = self.training_env.custom_info
         self.logger.record("_state/position", self.locals["infos"][0]["position"])
         self.logger.record("_state/trade_duration", self.locals["infos"][0]["trade_duration"])
         self.logger.record("_state/current_profit_pct", self.locals["infos"]
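
For context: the custom_info dict added to BaseEnvironment is the per-environment scratch space that TensorboardCallback._on_step now reads directly as self.training_env.custom_info, instead of going through the VecEnv lookup get_attr("custom_info")[0]. Below is a minimal sketch of the environment side under those assumptions; SketchEnv and its action-counting step logic are hypothetical illustrations, not code from the diff (only the custom_info attribute and the BaseActions members mirror the freqtrade classes touched here).

from enum import Enum


class BaseActions(Enum):
    # The default 5-action space that the "5Ac" comment in the diff refers to.
    Neutral = 0
    Long_enter = 1
    Long_exit = 2
    Short_enter = 3
    Short_exit = 4


class SketchEnv:
    # Hypothetical stand-in for BaseEnvironment, reduced to the
    # custom_info bookkeeping introduced by this diff.

    def __init__(self) -> None:
        self.actions = BaseActions
        # Added by the diff; the callback reads it via
        # self.training_env.custom_info in _on_step().
        self.custom_info: dict = {}

    def step(self, action: int) -> None:
        # Illustrative assumption: tally how often each action fires so a
        # callback could record it under a per-action tensorboard tag.
        name = self.actions(action).name
        self.custom_info[name] = self.custom_info.get(name, 0) + 1


env = SketchEnv()
for a in (0, 1, 1, 2):
    env.step(a)
print(env.custom_info)  # {'Neutral': 1, 'Long_enter': 2, 'Long_exit': 1}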