Expose raw (pre-normalization) training features to the environment for use in calculate_reward
This commit is contained in:
parent
154b6711b3
commit
8873a565ee
@ -177,7 +177,7 @@ As you begin to modify the strategy and the prediction model, you will quickly r
|
|||||||
factor = 100
|
factor = 100
|
||||||
|
|
||||||
# you can use feature values from dataframe
|
# you can use feature values from dataframe
|
||||||
rsi_now = self.df[f"%-rsi-period-10_shift-1_{self.pair}_"
|
rsi_now = self.raw_features[f"%-rsi-period-10_shift-1_{self.pair}_"
|
||||||
f"{self.config['timeframe']}"].iloc[self._current_tick]
|
f"{self.config['timeframe']}"].iloc[self._current_tick]
|
||||||
|
|
||||||
# reward agent for entering trades
|
# reward agent for entering trades
|
||||||
|
@ -45,7 +45,8 @@ class BaseEnvironment(gym.Env):
|
|||||||
def __init__(self, df: DataFrame = DataFrame(), prices: DataFrame = DataFrame(),
|
def __init__(self, df: DataFrame = DataFrame(), prices: DataFrame = DataFrame(),
|
||||||
reward_kwargs: dict = {}, window_size=10, starting_point=True,
|
reward_kwargs: dict = {}, window_size=10, starting_point=True,
|
||||||
id: str = 'baseenv-1', seed: int = 1, config: dict = {}, live: bool = False,
|
id: str = 'baseenv-1', seed: int = 1, config: dict = {}, live: bool = False,
|
||||||
fee: float = 0.0015, can_short: bool = False, pair: str = ""):
|
fee: float = 0.0015, can_short: bool = False, pair: str = "",
|
||||||
|
df_raw: DataFrame = DataFrame()):
|
||||||
"""
|
"""
|
||||||
Initializes the training/eval environment.
|
Initializes the training/eval environment.
|
||||||
:param df: dataframe of features
|
:param df: dataframe of features
|
||||||
@ -67,6 +68,7 @@ class BaseEnvironment(gym.Env):
|
|||||||
self.max_drawdown: float = 1 - self.rl_config.get('max_training_drawdown_pct', 0.8)
|
self.max_drawdown: float = 1 - self.rl_config.get('max_training_drawdown_pct', 0.8)
|
||||||
self.compound_trades: bool = config['stake_amount'] == 'unlimited'
|
self.compound_trades: bool = config['stake_amount'] == 'unlimited'
|
||||||
self.pair: str = pair
|
self.pair: str = pair
|
||||||
|
self.raw_features: DataFrame = df_raw
|
||||||
if self.config.get('fee', None) is not None:
|
if self.config.get('fee', None) is not None:
|
||||||
self.fee = self.config['fee']
|
self.fee = self.config['fee']
|
||||||
else:
|
else:
|
||||||
@ -94,8 +96,7 @@ class BaseEnvironment(gym.Env):
|
|||||||
:param reward_kwargs: extra config settings assigned by user in `rl_config`
|
:param reward_kwargs: extra config settings assigned by user in `rl_config`
|
||||||
:param starting_point: start at edge of window or not
|
:param starting_point: start at edge of window or not
|
||||||
"""
|
"""
|
||||||
self.df: DataFrame = df
|
self.signal_features: DataFrame = df
|
||||||
self.signal_features: DataFrame = self.df
|
|
||||||
self.prices: DataFrame = prices
|
self.prices: DataFrame = prices
|
||||||
self.window_size: int = window_size
|
self.window_size: int = window_size
|
||||||
self.starting_point: bool = starting_point
|
self.starting_point: bool = starting_point
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import copy
|
||||||
import importlib
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
@ -50,6 +51,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
self.eval_callback: Optional[EvalCallback] = None
|
self.eval_callback: Optional[EvalCallback] = None
|
||||||
self.model_type = self.freqai_info['rl_config']['model_type']
|
self.model_type = self.freqai_info['rl_config']['model_type']
|
||||||
self.rl_config = self.freqai_info['rl_config']
|
self.rl_config = self.freqai_info['rl_config']
|
||||||
|
self.df_raw: DataFrame = DataFrame()
|
||||||
self.continual_learning = self.freqai_info.get('continual_learning', False)
|
self.continual_learning = self.freqai_info.get('continual_learning', False)
|
||||||
if self.model_type in SB3_MODELS:
|
if self.model_type in SB3_MODELS:
|
||||||
import_str = 'stable_baselines3'
|
import_str = 'stable_baselines3'
|
||||||
@ -107,6 +109,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
|
|
||||||
data_dictionary: Dict[str, Any] = dk.make_train_test_datasets(
|
data_dictionary: Dict[str, Any] = dk.make_train_test_datasets(
|
||||||
features_filtered, labels_filtered)
|
features_filtered, labels_filtered)
|
||||||
|
self.df_raw = copy.deepcopy(data_dictionary["train_features"])
|
||||||
dk.fit_labels() # FIXME useless for now, but just satiating append methods
|
dk.fit_labels() # FIXME useless for now, but just satiating append methods
|
||||||
|
|
||||||
# normalize all data based on train_dataset only
|
# normalize all data based on train_dataset only
|
||||||
@ -167,7 +170,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
"config": self.config,
|
"config": self.config,
|
||||||
"live": self.live,
|
"live": self.live,
|
||||||
"can_short": self.can_short,
|
"can_short": self.can_short,
|
||||||
"pair": pair}
|
"pair": pair,
|
||||||
|
"df_raw": self.df_raw}
|
||||||
if self.data_provider:
|
if self.data_provider:
|
||||||
env_info["fee"] = self.data_provider._exchange \
|
env_info["fee"] = self.data_provider._exchange \
|
||||||
.get_fee(symbol=self.data_provider.current_whitelist()[0]) # type: ignore
|
.get_fee(symbol=self.data_provider.current_whitelist()[0]) # type: ignore
|
||||||
@ -365,8 +369,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
factor = 100.
|
factor = 100.
|
||||||
|
|
||||||
# you can use feature values from dataframe
|
# you can use feature values from dataframe
|
||||||
rsi_now = self.df[f"%-rsi-period-10_shift-1_{self.pair}_"
|
rsi_now = self.raw_features[f"%-rsi-period-10_shift-1_{self.pair}_"
|
||||||
f"{self.config['timeframe']}"].iloc[self._current_tick]
|
f"{self.config['timeframe']}"].iloc[self._current_tick]
|
||||||
|
|
||||||
# reward agent for entering trades
|
# reward agent for entering trades
|
||||||
if (action in (Actions.Long_enter.value, Actions.Short_enter.value)
|
if (action in (Actions.Long_enter.value, Actions.Short_enter.value)
|
||||||
|
Loading…
Reference in New Issue
Block a user