diff --git a/freqtrade/freqai/RL/Base4ActionRLEnv.py b/freqtrade/freqai/RL/Base4ActionRLEnv.py index 7818ac51e..79616d778 100644 --- a/freqtrade/freqai/RL/Base4ActionRLEnv.py +++ b/freqtrade/freqai/RL/Base4ActionRLEnv.py @@ -20,8 +20,8 @@ class Base4ActionRLEnv(BaseEnvironment): """ Base class for a 4 action environment """ - def __init__(self, *args): - super().__init__(*args) + def __init__(self, **kwargs): + super().__init__(**kwargs) self.actions = Actions def set_action_space(self): diff --git a/freqtrade/freqai/RL/BaseEnvironment.py b/freqtrade/freqai/RL/BaseEnvironment.py index e43951142..a31ded0c6 100644 --- a/freqtrade/freqai/RL/BaseEnvironment.py +++ b/freqtrade/freqai/RL/BaseEnvironment.py @@ -75,7 +75,7 @@ class BaseEnvironment(gym.Env): else: self.fee = 0.0015 - # set here to default 5Ac, but all children envs can overwrite this + # set here to default 5Ac, but all children envs can override this self.actions: Type[Enum] = BaseActions self.custom_info: dict = {} @@ -121,7 +121,6 @@ class BaseEnvironment(gym.Env): self._total_unrealized_profit: float = 1 self.history: dict = {} self.trade_history: list = [] - self.custom_info: dict = {} @abstractmethod def set_action_space(self): diff --git a/freqtrade/freqai/RL/TensorboardCallback.py b/freqtrade/freqai/RL/TensorboardCallback.py index b5b8ba23d..f590bdf84 100644 --- a/freqtrade/freqai/RL/TensorboardCallback.py +++ b/freqtrade/freqai/RL/TensorboardCallback.py @@ -42,7 +42,7 @@ class TensorboardCallback(BaseCallback): ) def _on_step(self) -> bool: - custom_info = self.training_env.custom_info + custom_info = self.training_env.get_attr("custom_info")[0] self.logger.record("_state/position", self.locals["infos"][0]["position"]) self.logger.record("_state/trade_duration", self.locals["infos"][0]["trade_duration"]) self.logger.record("_state/current_profit_pct", self.locals["infos"] diff --git a/tests/freqai/test_freqai_interface.py b/tests/freqai/test_freqai_interface.py index c53137093..f19acb018 100644 --- a/tests/freqai/test_freqai_interface.py +++ b/tests/freqai/test_freqai_interface.py @@ -237,7 +237,6 @@ def test_start_backtesting(mocker, freqai_conf, model, num_files, strat, caplog) df = freqai.cache_corr_pairlist_dfs(df, freqai.dk) for i in range(5): df[f'%-constant_{i}'] = i - # df.loc[:, f'%-constant_{i}'] = i metadata = {"pair": "LTC/BTC"} freqai.start_backtesting(df, metadata, freqai.dk)