fix custom_info

This commit is contained in:
robcaulk 2022-12-05 20:22:54 +01:00
parent d8565261e1
commit 62c69bf2b5
4 changed files with 4 additions and 6 deletions

View File

@ -20,8 +20,8 @@ class Base4ActionRLEnv(BaseEnvironment):
""" """
Base class for a 4 action environment Base class for a 4 action environment
""" """
def __init__(self, *args): def __init__(self, **kwargs):
super().__init__(*args) super().__init__(**kwargs)
self.actions = Actions self.actions = Actions
def set_action_space(self): def set_action_space(self):

View File

@ -75,7 +75,7 @@ class BaseEnvironment(gym.Env):
else: else:
self.fee = 0.0015 self.fee = 0.0015
# set here to default 5Ac, but all children envs can overwrite this # set here to default 5Ac, but all children envs can override this
self.actions: Type[Enum] = BaseActions self.actions: Type[Enum] = BaseActions
self.custom_info: dict = {} self.custom_info: dict = {}
@ -121,7 +121,6 @@ class BaseEnvironment(gym.Env):
self._total_unrealized_profit: float = 1 self._total_unrealized_profit: float = 1
self.history: dict = {} self.history: dict = {}
self.trade_history: list = [] self.trade_history: list = []
self.custom_info: dict = {}
@abstractmethod @abstractmethod
def set_action_space(self): def set_action_space(self):

View File

@ -42,7 +42,7 @@ class TensorboardCallback(BaseCallback):
) )
def _on_step(self) -> bool: def _on_step(self) -> bool:
custom_info = self.training_env.custom_info custom_info = self.training_env.get_attr("custom_info")[0]
self.logger.record("_state/position", self.locals["infos"][0]["position"]) self.logger.record("_state/position", self.locals["infos"][0]["position"])
self.logger.record("_state/trade_duration", self.locals["infos"][0]["trade_duration"]) self.logger.record("_state/trade_duration", self.locals["infos"][0]["trade_duration"])
self.logger.record("_state/current_profit_pct", self.locals["infos"] self.logger.record("_state/current_profit_pct", self.locals["infos"]

View File

@ -237,7 +237,6 @@ def test_start_backtesting(mocker, freqai_conf, model, num_files, strat, caplog)
df = freqai.cache_corr_pairlist_dfs(df, freqai.dk) df = freqai.cache_corr_pairlist_dfs(df, freqai.dk)
for i in range(5): for i in range(5):
df[f'%-constant_{i}'] = i df[f'%-constant_{i}'] = i
# df.loc[:, f'%-constant_{i}'] = i
metadata = {"pair": "LTC/BTC"} metadata = {"pair": "LTC/BTC"}
freqai.start_backtesting(df, metadata, freqai.dk) freqai.start_backtesting(df, metadata, freqai.dk)