Add can_short param to base env
This commit is contained in:
parent
439914caef
commit
dde363343c
@ -45,7 +45,7 @@ class BaseEnvironment(gym.Env):
|
|||||||
def __init__(self, df: DataFrame = DataFrame(), prices: DataFrame = DataFrame(),
|
def __init__(self, df: DataFrame = DataFrame(), prices: DataFrame = DataFrame(),
|
||||||
reward_kwargs: dict = {}, window_size=10, starting_point=True,
|
reward_kwargs: dict = {}, window_size=10, starting_point=True,
|
||||||
id: str = 'baseenv-1', seed: int = 1, config: dict = {}, live: bool = False,
|
id: str = 'baseenv-1', seed: int = 1, config: dict = {}, live: bool = False,
|
||||||
fee: float = 0.0015):
|
fee: float = 0.0015, can_short: bool = False):
|
||||||
"""
|
"""
|
||||||
Initializes the training/eval environment.
|
Initializes the training/eval environment.
|
||||||
:param df: dataframe of features
|
:param df: dataframe of features
|
||||||
@ -58,6 +58,7 @@ class BaseEnvironment(gym.Env):
|
|||||||
:param config: Typical user configuration file
|
:param config: Typical user configuration file
|
||||||
:param live: Whether or not this environment is active in dry/live/backtesting
|
:param live: Whether or not this environment is active in dry/live/backtesting
|
||||||
:param fee: The fee to use for environmental interactions.
|
:param fee: The fee to use for environmental interactions.
|
||||||
|
:param can_short: Whether or not the environment can short
|
||||||
"""
|
"""
|
||||||
self.config = config
|
self.config = config
|
||||||
self.rl_config = config['freqai']['rl_config']
|
self.rl_config = config['freqai']['rl_config']
|
||||||
@ -73,6 +74,7 @@ class BaseEnvironment(gym.Env):
|
|||||||
# set here to default 5Ac, but all children envs can override this
|
# set here to default 5Ac, but all children envs can override this
|
||||||
self.actions: Type[Enum] = BaseActions
|
self.actions: Type[Enum] = BaseActions
|
||||||
self.tensorboard_metrics: dict = {}
|
self.tensorboard_metrics: dict = {}
|
||||||
|
self.can_short = can_short
|
||||||
self.live = live
|
self.live = live
|
||||||
if not self.live and self.add_state_info:
|
if not self.live and self.add_state_info:
|
||||||
self.add_state_info = False
|
self.add_state_info = False
|
||||||
|
@ -165,7 +165,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
env_info = {"window_size": self.CONV_WIDTH,
|
env_info = {"window_size": self.CONV_WIDTH,
|
||||||
"reward_kwargs": self.reward_params,
|
"reward_kwargs": self.reward_params,
|
||||||
"config": self.config,
|
"config": self.config,
|
||||||
"live": self.live}
|
"live": self.live,
|
||||||
|
"can_short": self.can_short}
|
||||||
if self.data_provider:
|
if self.data_provider:
|
||||||
env_info["fee"] = self.data_provider._exchange \
|
env_info["fee"] = self.data_provider._exchange \
|
||||||
.get_fee(symbol=self.data_provider.current_whitelist()[0]) # type: ignore
|
.get_fee(symbol=self.data_provider.current_whitelist()[0]) # type: ignore
|
||||||
|
@ -133,6 +133,7 @@ class IFreqaiModel(ABC):
|
|||||||
self.live = strategy.dp.runmode in (RunMode.DRY_RUN, RunMode.LIVE)
|
self.live = strategy.dp.runmode in (RunMode.DRY_RUN, RunMode.LIVE)
|
||||||
self.dd.set_pair_dict_info(metadata)
|
self.dd.set_pair_dict_info(metadata)
|
||||||
self.data_provider = strategy.dp
|
self.data_provider = strategy.dp
|
||||||
|
self.can_short = strategy.can_short
|
||||||
|
|
||||||
if self.live:
|
if self.live:
|
||||||
self.inference_timer('start')
|
self.inference_timer('start')
|
||||||
|
Loading…
Reference in New Issue
Block a user