Add can_short param to base env

Author: Emre
Date: 2022-12-16 22:16:19 +03:00
Parent: 439914caef
Commit: dde363343c
3 changed files with 6 additions and 2 deletions


@@ -45,7 +45,7 @@ class BaseEnvironment(gym.Env):
     def __init__(self, df: DataFrame = DataFrame(), prices: DataFrame = DataFrame(),
                  reward_kwargs: dict = {}, window_size=10, starting_point=True,
                  id: str = 'baseenv-1', seed: int = 1, config: dict = {}, live: bool = False,
-                 fee: float = 0.0015):
+                 fee: float = 0.0015, can_short: bool = False):
         """
         Initializes the training/eval environment.
         :param df: dataframe of features
@@ -58,6 +58,7 @@ class BaseEnvironment(gym.Env):
         :param config: Typical user configuration file
         :param live: Whether or not this environment is active in dry/live/backtesting
         :param fee: The fee to use for environmental interactions.
+        :param can_short: Whether or not the environment can short
         """
         self.config = config
         self.rl_config = config['freqai']['rl_config']
@@ -73,6 +74,7 @@ class BaseEnvironment(gym.Env):
         # set here to default 5Ac, but all children envs can override this
         self.actions: Type[Enum] = BaseActions
         self.tensorboard_metrics: dict = {}
+        self.can_short = can_short
         self.live = live
         if not self.live and self.add_state_info:
             self.add_state_info = False


@@ -165,7 +165,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         env_info = {"window_size": self.CONV_WIDTH,
                     "reward_kwargs": self.reward_params,
                     "config": self.config,
-                    "live": self.live}
+                    "live": self.live,
+                    "can_short": self.can_short}
         if self.data_provider:
             env_info["fee"] = self.data_provider._exchange \
                 .get_fee(symbol=self.data_provider.current_whitelist()[0])  # type: ignore


@@ -133,6 +133,7 @@ class IFreqaiModel(ABC):
         self.live = strategy.dp.runmode in (RunMode.DRY_RUN, RunMode.LIVE)
         self.dd.set_pair_dict_info(metadata)
         self.data_provider = strategy.dp
+        self.can_short = strategy.can_short
         if self.live:
             self.inference_timer('start')
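
For context (not part of the diff): a minimal sketch of how the new flag is expected to flow once this commit is applied, assuming a standard FreqAI setup. The strategy class name below is hypothetical; only the can_short attribute is relevant here.

from freqtrade.strategy import IStrategy


class MyShortEnabledRLStrategy(IStrategy):
    # strategy.can_short is read in IFreqaiModel.start() (last hunk above),
    # placed into env_info by BaseReinforcementLearningModel, and finally
    # passed to BaseEnvironment.__init__(..., can_short=True).
    can_short = True

Child environments that subclass BaseEnvironment can then consult self.can_short to decide whether to expose short-side actions.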