Merge pull request #8147 from freqtrade/add-pair-to-env
Add pair to environment for access inside calculate_reward
This commit is contained in:
commit
42c76d9e0c
@ -175,10 +175,20 @@ As you begin to modify the strategy and the prediction model, you will quickly r
|
||||
pnl = self.get_unrealized_profit()
|
||||
|
||||
factor = 100
|
||||
# reward agent for entering trades
|
||||
if action in (Actions.Long_enter.value, Actions.Short_enter.value) \
|
||||
and self._position == Positions.Neutral:
|
||||
return 25
|
||||
|
||||
# you can use feature values from dataframe
|
||||
rsi_now = self.raw_features[f"%-rsi-period-10_shift-1_{self.pair}_"
|
||||
f"{self.config['timeframe']}"].iloc[self._current_tick]
|
||||
|
||||
# reward agent for entering trades
|
||||
if (action in (Actions.Long_enter.value, Actions.Short_enter.value)
|
||||
and self._position == Positions.Neutral):
|
||||
if rsi_now < 40:
|
||||
factor = 40 / rsi_now
|
||||
else:
|
||||
factor = 1
|
||||
return 25 * factor
|
||||
|
||||
# discourage agent from not entering trades
|
||||
if action == Actions.Neutral.value and self._position == Positions.Neutral:
|
||||
return -1
|
||||
|
@ -45,7 +45,8 @@ class BaseEnvironment(gym.Env):
|
||||
def __init__(self, df: DataFrame = DataFrame(), prices: DataFrame = DataFrame(),
|
||||
reward_kwargs: dict = {}, window_size=10, starting_point=True,
|
||||
id: str = 'baseenv-1', seed: int = 1, config: dict = {}, live: bool = False,
|
||||
fee: float = 0.0015, can_short: bool = False):
|
||||
fee: float = 0.0015, can_short: bool = False, pair: str = "",
|
||||
df_raw: DataFrame = DataFrame()):
|
||||
"""
|
||||
Initializes the training/eval environment.
|
||||
:param df: dataframe of features
|
||||
@ -60,12 +61,14 @@ class BaseEnvironment(gym.Env):
|
||||
:param fee: The fee to use for environmental interactions.
|
||||
:param can_short: Whether or not the environment can short
|
||||
"""
|
||||
self.config = config
|
||||
self.rl_config = config['freqai']['rl_config']
|
||||
self.add_state_info = self.rl_config.get('add_state_info', False)
|
||||
self.id = id
|
||||
self.max_drawdown = 1 - self.rl_config.get('max_training_drawdown_pct', 0.8)
|
||||
self.compound_trades = config['stake_amount'] == 'unlimited'
|
||||
self.config: dict = config
|
||||
self.rl_config: dict = config['freqai']['rl_config']
|
||||
self.add_state_info: bool = self.rl_config.get('add_state_info', False)
|
||||
self.id: str = id
|
||||
self.max_drawdown: float = 1 - self.rl_config.get('max_training_drawdown_pct', 0.8)
|
||||
self.compound_trades: bool = config['stake_amount'] == 'unlimited'
|
||||
self.pair: str = pair
|
||||
self.raw_features: DataFrame = df_raw
|
||||
if self.config.get('fee', None) is not None:
|
||||
self.fee = self.config['fee']
|
||||
else:
|
||||
@ -74,8 +77,8 @@ class BaseEnvironment(gym.Env):
|
||||
# set here to default 5Ac, but all children envs can override this
|
||||
self.actions: Type[Enum] = BaseActions
|
||||
self.tensorboard_metrics: dict = {}
|
||||
self.can_short = can_short
|
||||
self.live = live
|
||||
self.can_short: bool = can_short
|
||||
self.live: bool = live
|
||||
if not self.live and self.add_state_info:
|
||||
self.add_state_info = False
|
||||
logger.warning("add_state_info is not available in backtesting. Deactivating.")
|
||||
@ -93,13 +96,12 @@ class BaseEnvironment(gym.Env):
|
||||
:param reward_kwargs: extra config settings assigned by user in `rl_config`
|
||||
:param starting_point: start at edge of window or not
|
||||
"""
|
||||
self.df = df
|
||||
self.signal_features = self.df
|
||||
self.prices = prices
|
||||
self.window_size = window_size
|
||||
self.starting_point = starting_point
|
||||
self.rr = reward_kwargs["rr"]
|
||||
self.profit_aim = reward_kwargs["profit_aim"]
|
||||
self.signal_features: DataFrame = df
|
||||
self.prices: DataFrame = prices
|
||||
self.window_size: int = window_size
|
||||
self.starting_point: bool = starting_point
|
||||
self.rr: float = reward_kwargs["rr"]
|
||||
self.profit_aim: float = reward_kwargs["profit_aim"]
|
||||
|
||||
# # spaces
|
||||
if self.add_state_info:
|
||||
|
@ -1,3 +1,4 @@
|
||||
import copy
|
||||
import importlib
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
@ -50,6 +51,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
self.eval_callback: Optional[EvalCallback] = None
|
||||
self.model_type = self.freqai_info['rl_config']['model_type']
|
||||
self.rl_config = self.freqai_info['rl_config']
|
||||
self.df_raw: DataFrame = DataFrame()
|
||||
self.continual_learning = self.freqai_info.get('continual_learning', False)
|
||||
if self.model_type in SB3_MODELS:
|
||||
import_str = 'stable_baselines3'
|
||||
@ -107,6 +109,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
|
||||
data_dictionary: Dict[str, Any] = dk.make_train_test_datasets(
|
||||
features_filtered, labels_filtered)
|
||||
self.df_raw = copy.deepcopy(data_dictionary["train_features"])
|
||||
dk.fit_labels() # FIXME useless for now, but just satiating append methods
|
||||
|
||||
# normalize all data based on train_dataset only
|
||||
@ -143,7 +146,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
train_df = data_dictionary["train_features"]
|
||||
test_df = data_dictionary["test_features"]
|
||||
|
||||
env_info = self.pack_env_dict()
|
||||
env_info = self.pack_env_dict(dk.pair)
|
||||
|
||||
self.train_env = self.MyRLEnv(df=train_df,
|
||||
prices=prices_train,
|
||||
@ -158,7 +161,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
actions = self.train_env.get_actions()
|
||||
self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions)
|
||||
|
||||
def pack_env_dict(self) -> Dict[str, Any]:
|
||||
def pack_env_dict(self, pair: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Create dictionary of environment arguments
|
||||
"""
|
||||
@ -166,7 +169,9 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
"reward_kwargs": self.reward_params,
|
||||
"config": self.config,
|
||||
"live": self.live,
|
||||
"can_short": self.can_short}
|
||||
"can_short": self.can_short,
|
||||
"pair": pair,
|
||||
"df_raw": self.df_raw}
|
||||
if self.data_provider:
|
||||
env_info["fee"] = self.data_provider._exchange \
|
||||
.get_fee(symbol=self.data_provider.current_whitelist()[0]) # type: ignore
|
||||
@ -347,7 +352,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
sets a custom reward based on profit and trade duration.
|
||||
"""
|
||||
|
||||
def calculate_reward(self, action: int) -> float:
|
||||
def calculate_reward(self, action: int) -> float: # noqa: C901
|
||||
"""
|
||||
An example reward function. This is the one function that users will likely
|
||||
wish to inject their own creativity into.
|
||||
@ -363,10 +368,19 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
pnl = self.get_unrealized_profit()
|
||||
factor = 100.
|
||||
|
||||
# you can use feature values from dataframe
|
||||
rsi_now = self.raw_features[f"%-rsi-period-10_shift-1_{self.pair}_"
|
||||
f"{self.config['timeframe']}"].iloc[self._current_tick]
|
||||
|
||||
# reward agent for entering trades
|
||||
if (action in (Actions.Long_enter.value, Actions.Short_enter.value)
|
||||
and self._position == Positions.Neutral):
|
||||
return 25
|
||||
if rsi_now < 40:
|
||||
factor = 40 / rsi_now
|
||||
else:
|
||||
factor = 1
|
||||
return 25 * factor
|
||||
|
||||
# discourage agent from not entering trades
|
||||
if action == Actions.Neutral.value and self._position == Positions.Neutral:
|
||||
return -1
|
||||
|
@ -34,7 +34,7 @@ class ReinforcementLearner_multiproc(ReinforcementLearner):
|
||||
train_df = data_dictionary["train_features"]
|
||||
test_df = data_dictionary["test_features"]
|
||||
|
||||
env_info = self.pack_env_dict()
|
||||
env_info = self.pack_env_dict(dk.pair)
|
||||
|
||||
env_id = "train_env"
|
||||
self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1,
|
||||
|
Loading…
Reference in New Issue
Block a user