Merge pull request #8147 from freqtrade/add-pair-to-env
Add pair to environment for access inside calculate_reward
This commit is contained in:
commit
42c76d9e0c
@ -175,10 +175,20 @@ As you begin to modify the strategy and the prediction model, you will quickly r
|
|||||||
pnl = self.get_unrealized_profit()
|
pnl = self.get_unrealized_profit()
|
||||||
|
|
||||||
factor = 100
|
factor = 100
|
||||||
# reward agent for entering trades
|
|
||||||
if action in (Actions.Long_enter.value, Actions.Short_enter.value) \
|
# you can use feature values from dataframe
|
||||||
and self._position == Positions.Neutral:
|
rsi_now = self.raw_features[f"%-rsi-period-10_shift-1_{self.pair}_"
|
||||||
return 25
|
f"{self.config['timeframe']}"].iloc[self._current_tick]
|
||||||
|
|
||||||
|
# reward agent for entering trades
|
||||||
|
if (action in (Actions.Long_enter.value, Actions.Short_enter.value)
|
||||||
|
and self._position == Positions.Neutral):
|
||||||
|
if rsi_now < 40:
|
||||||
|
factor = 40 / rsi_now
|
||||||
|
else:
|
||||||
|
factor = 1
|
||||||
|
return 25 * factor
|
||||||
|
|
||||||
# discourage agent from not entering trades
|
# discourage agent from not entering trades
|
||||||
if action == Actions.Neutral.value and self._position == Positions.Neutral:
|
if action == Actions.Neutral.value and self._position == Positions.Neutral:
|
||||||
return -1
|
return -1
|
||||||
|
@ -45,7 +45,8 @@ class BaseEnvironment(gym.Env):
|
|||||||
def __init__(self, df: DataFrame = DataFrame(), prices: DataFrame = DataFrame(),
|
def __init__(self, df: DataFrame = DataFrame(), prices: DataFrame = DataFrame(),
|
||||||
reward_kwargs: dict = {}, window_size=10, starting_point=True,
|
reward_kwargs: dict = {}, window_size=10, starting_point=True,
|
||||||
id: str = 'baseenv-1', seed: int = 1, config: dict = {}, live: bool = False,
|
id: str = 'baseenv-1', seed: int = 1, config: dict = {}, live: bool = False,
|
||||||
fee: float = 0.0015, can_short: bool = False):
|
fee: float = 0.0015, can_short: bool = False, pair: str = "",
|
||||||
|
df_raw: DataFrame = DataFrame()):
|
||||||
"""
|
"""
|
||||||
Initializes the training/eval environment.
|
Initializes the training/eval environment.
|
||||||
:param df: dataframe of features
|
:param df: dataframe of features
|
||||||
@ -60,12 +61,14 @@ class BaseEnvironment(gym.Env):
|
|||||||
:param fee: The fee to use for environmental interactions.
|
:param fee: The fee to use for environmental interactions.
|
||||||
:param can_short: Whether or not the environment can short
|
:param can_short: Whether or not the environment can short
|
||||||
"""
|
"""
|
||||||
self.config = config
|
self.config: dict = config
|
||||||
self.rl_config = config['freqai']['rl_config']
|
self.rl_config: dict = config['freqai']['rl_config']
|
||||||
self.add_state_info = self.rl_config.get('add_state_info', False)
|
self.add_state_info: bool = self.rl_config.get('add_state_info', False)
|
||||||
self.id = id
|
self.id: str = id
|
||||||
self.max_drawdown = 1 - self.rl_config.get('max_training_drawdown_pct', 0.8)
|
self.max_drawdown: float = 1 - self.rl_config.get('max_training_drawdown_pct', 0.8)
|
||||||
self.compound_trades = config['stake_amount'] == 'unlimited'
|
self.compound_trades: bool = config['stake_amount'] == 'unlimited'
|
||||||
|
self.pair: str = pair
|
||||||
|
self.raw_features: DataFrame = df_raw
|
||||||
if self.config.get('fee', None) is not None:
|
if self.config.get('fee', None) is not None:
|
||||||
self.fee = self.config['fee']
|
self.fee = self.config['fee']
|
||||||
else:
|
else:
|
||||||
@ -74,8 +77,8 @@ class BaseEnvironment(gym.Env):
|
|||||||
# set here to default 5Ac, but all children envs can override this
|
# set here to default 5Ac, but all children envs can override this
|
||||||
self.actions: Type[Enum] = BaseActions
|
self.actions: Type[Enum] = BaseActions
|
||||||
self.tensorboard_metrics: dict = {}
|
self.tensorboard_metrics: dict = {}
|
||||||
self.can_short = can_short
|
self.can_short: bool = can_short
|
||||||
self.live = live
|
self.live: bool = live
|
||||||
if not self.live and self.add_state_info:
|
if not self.live and self.add_state_info:
|
||||||
self.add_state_info = False
|
self.add_state_info = False
|
||||||
logger.warning("add_state_info is not available in backtesting. Deactivating.")
|
logger.warning("add_state_info is not available in backtesting. Deactivating.")
|
||||||
@ -93,13 +96,12 @@ class BaseEnvironment(gym.Env):
|
|||||||
:param reward_kwargs: extra config settings assigned by user in `rl_config`
|
:param reward_kwargs: extra config settings assigned by user in `rl_config`
|
||||||
:param starting_point: start at edge of window or not
|
:param starting_point: start at edge of window or not
|
||||||
"""
|
"""
|
||||||
self.df = df
|
self.signal_features: DataFrame = df
|
||||||
self.signal_features = self.df
|
self.prices: DataFrame = prices
|
||||||
self.prices = prices
|
self.window_size: int = window_size
|
||||||
self.window_size = window_size
|
self.starting_point: bool = starting_point
|
||||||
self.starting_point = starting_point
|
self.rr: float = reward_kwargs["rr"]
|
||||||
self.rr = reward_kwargs["rr"]
|
self.profit_aim: float = reward_kwargs["profit_aim"]
|
||||||
self.profit_aim = reward_kwargs["profit_aim"]
|
|
||||||
|
|
||||||
# # spaces
|
# # spaces
|
||||||
if self.add_state_info:
|
if self.add_state_info:
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import copy
|
||||||
import importlib
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
@ -50,6 +51,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
self.eval_callback: Optional[EvalCallback] = None
|
self.eval_callback: Optional[EvalCallback] = None
|
||||||
self.model_type = self.freqai_info['rl_config']['model_type']
|
self.model_type = self.freqai_info['rl_config']['model_type']
|
||||||
self.rl_config = self.freqai_info['rl_config']
|
self.rl_config = self.freqai_info['rl_config']
|
||||||
|
self.df_raw: DataFrame = DataFrame()
|
||||||
self.continual_learning = self.freqai_info.get('continual_learning', False)
|
self.continual_learning = self.freqai_info.get('continual_learning', False)
|
||||||
if self.model_type in SB3_MODELS:
|
if self.model_type in SB3_MODELS:
|
||||||
import_str = 'stable_baselines3'
|
import_str = 'stable_baselines3'
|
||||||
@ -107,6 +109,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
|
|
||||||
data_dictionary: Dict[str, Any] = dk.make_train_test_datasets(
|
data_dictionary: Dict[str, Any] = dk.make_train_test_datasets(
|
||||||
features_filtered, labels_filtered)
|
features_filtered, labels_filtered)
|
||||||
|
self.df_raw = copy.deepcopy(data_dictionary["train_features"])
|
||||||
dk.fit_labels() # FIXME useless for now, but just satiating append methods
|
dk.fit_labels() # FIXME useless for now, but just satiating append methods
|
||||||
|
|
||||||
# normalize all data based on train_dataset only
|
# normalize all data based on train_dataset only
|
||||||
@ -143,7 +146,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
train_df = data_dictionary["train_features"]
|
train_df = data_dictionary["train_features"]
|
||||||
test_df = data_dictionary["test_features"]
|
test_df = data_dictionary["test_features"]
|
||||||
|
|
||||||
env_info = self.pack_env_dict()
|
env_info = self.pack_env_dict(dk.pair)
|
||||||
|
|
||||||
self.train_env = self.MyRLEnv(df=train_df,
|
self.train_env = self.MyRLEnv(df=train_df,
|
||||||
prices=prices_train,
|
prices=prices_train,
|
||||||
@ -158,7 +161,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
actions = self.train_env.get_actions()
|
actions = self.train_env.get_actions()
|
||||||
self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions)
|
self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions)
|
||||||
|
|
||||||
def pack_env_dict(self) -> Dict[str, Any]:
|
def pack_env_dict(self, pair: str) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Create dictionary of environment arguments
|
Create dictionary of environment arguments
|
||||||
"""
|
"""
|
||||||
@ -166,7 +169,9 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
"reward_kwargs": self.reward_params,
|
"reward_kwargs": self.reward_params,
|
||||||
"config": self.config,
|
"config": self.config,
|
||||||
"live": self.live,
|
"live": self.live,
|
||||||
"can_short": self.can_short}
|
"can_short": self.can_short,
|
||||||
|
"pair": pair,
|
||||||
|
"df_raw": self.df_raw}
|
||||||
if self.data_provider:
|
if self.data_provider:
|
||||||
env_info["fee"] = self.data_provider._exchange \
|
env_info["fee"] = self.data_provider._exchange \
|
||||||
.get_fee(symbol=self.data_provider.current_whitelist()[0]) # type: ignore
|
.get_fee(symbol=self.data_provider.current_whitelist()[0]) # type: ignore
|
||||||
@ -347,7 +352,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
sets a custom reward based on profit and trade duration.
|
sets a custom reward based on profit and trade duration.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def calculate_reward(self, action: int) -> float:
|
def calculate_reward(self, action: int) -> float: # noqa: C901
|
||||||
"""
|
"""
|
||||||
An example reward function. This is the one function that users will likely
|
An example reward function. This is the one function that users will likely
|
||||||
wish to inject their own creativity into.
|
wish to inject their own creativity into.
|
||||||
@ -363,10 +368,19 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
pnl = self.get_unrealized_profit()
|
pnl = self.get_unrealized_profit()
|
||||||
factor = 100.
|
factor = 100.
|
||||||
|
|
||||||
|
# you can use feature values from dataframe
|
||||||
|
rsi_now = self.raw_features[f"%-rsi-period-10_shift-1_{self.pair}_"
|
||||||
|
f"{self.config['timeframe']}"].iloc[self._current_tick]
|
||||||
|
|
||||||
# reward agent for entering trades
|
# reward agent for entering trades
|
||||||
if (action in (Actions.Long_enter.value, Actions.Short_enter.value)
|
if (action in (Actions.Long_enter.value, Actions.Short_enter.value)
|
||||||
and self._position == Positions.Neutral):
|
and self._position == Positions.Neutral):
|
||||||
return 25
|
if rsi_now < 40:
|
||||||
|
factor = 40 / rsi_now
|
||||||
|
else:
|
||||||
|
factor = 1
|
||||||
|
return 25 * factor
|
||||||
|
|
||||||
# discourage agent from not entering trades
|
# discourage agent from not entering trades
|
||||||
if action == Actions.Neutral.value and self._position == Positions.Neutral:
|
if action == Actions.Neutral.value and self._position == Positions.Neutral:
|
||||||
return -1
|
return -1
|
||||||
|
@ -34,7 +34,7 @@ class ReinforcementLearner_multiproc(ReinforcementLearner):
|
|||||||
train_df = data_dictionary["train_features"]
|
train_df = data_dictionary["train_features"]
|
||||||
test_df = data_dictionary["test_features"]
|
test_df = data_dictionary["test_features"]
|
||||||
|
|
||||||
env_info = self.pack_env_dict()
|
env_info = self.pack_env_dict(dk.pair)
|
||||||
|
|
||||||
env_id = "train_env"
|
env_id = "train_env"
|
||||||
self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1,
|
self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1,
|
||||||
|
Loading…
Reference in New Issue
Block a user