Merge pull request #7899 from freqtrade/fix/multiproc-dp

Ensure data provider is passed to multiproc envs
Matthias 2022-12-15 19:31:23 +01:00 committed by GitHub
commit b915872f66
3 changed files with 40 additions and 32 deletions
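
In short: instead of handing a DataProvider to each environment, the model now packs the environment arguments into a plain dict via a new pack_env_dict() helper and forwards that same dict to the single-process environments as well as, through make_env(), to the SubprocVecEnv subprocess environments. A rough sketch of what that dict carries (placeholder values only, standing in for what pack_env_dict() actually collects from the model and its data provider):

# Sketch only, not repo code: placeholder values stand in for what
# pack_env_dict() collects from the model and its data provider.
env_info = {
    "window_size": 10,                                # self.CONV_WIDTH
    "reward_kwargs": {"rr": 1, "profit_aim": 0.025},  # self.reward_params
    "config": {},                                     # the user configuration
    "live": False,                                    # True in dry/live runs
    "fee": 0.0015,                                    # looked up from the exchange when a data provider exists
}

# Single-process:  self.MyRLEnv(df=train_df, prices=prices_train, **env_info)
# Multiprocess:    make_env(self.MyRLEnv, "train_env", i, 1, train_df, prices_train,
#                           monitor=True, env_info=env_info)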


@@ -11,9 +11,6 @@ from gym import spaces
 from gym.utils import seeding
 from pandas import DataFrame
 
-from freqtrade.data.dataprovider import DataProvider
-from freqtrade.enums import RunMode
 
 logger = logging.getLogger(__name__)
@@ -47,8 +44,8 @@ class BaseEnvironment(gym.Env):
 
     def __init__(self, df: DataFrame = DataFrame(), prices: DataFrame = DataFrame(),
                  reward_kwargs: dict = {}, window_size=10, starting_point=True,
-                 id: str = 'baseenv-1', seed: int = 1, config: dict = {},
-                 dp: Optional[DataProvider] = None):
+                 id: str = 'baseenv-1', seed: int = 1, config: dict = {}, live: bool = False,
+                 fee: float = 0.0015):
         """
         Initializes the training/eval environment.
         :param df: dataframe of features
@@ -59,32 +56,29 @@ class BaseEnvironment(gym.Env):
         :param id: string id of the environment (used in backend for multiprocessed env)
         :param seed: Sets the seed of the environment higher in the gym.Env object
         :param config: Typical user configuration file
-        :param dp: dataprovider from freqtrade
+        :param live: Whether or not this environment is active in dry/live/backtesting
+        :param fee: The fee to use for environmental interactions.
         """
         self.config = config
         self.rl_config = config['freqai']['rl_config']
         self.add_state_info = self.rl_config.get('add_state_info', False)
         self.id = id
-        self.seed(seed)
-        self.reset_env(df, prices, window_size, reward_kwargs, starting_point)
         self.max_drawdown = 1 - self.rl_config.get('max_training_drawdown_pct', 0.8)
         self.compound_trades = config['stake_amount'] == 'unlimited'
         if self.config.get('fee', None) is not None:
             self.fee = self.config['fee']
-        elif dp is not None:
-            self.fee = dp._exchange.get_fee(symbol=dp.current_whitelist()[0])  # type: ignore
         else:
-            self.fee = 0.0015
+            self.fee = fee
         # set here to default 5Ac, but all children envs can override this
         self.actions: Type[Enum] = BaseActions
         self.tensorboard_metrics: dict = {}
-        self.live: bool = False
-        if dp:
-            self.live = dp.runmode in (RunMode.DRY_RUN, RunMode.LIVE)
+        self.live = live
         if not self.live and self.add_state_info:
             self.add_state_info = False
             logger.warning("add_state_info is not available in backtesting. Deactivating.")
+        self.seed(seed)
+        self.reset_env(df, prices, window_size, reward_kwargs, starting_point)
 
     def reset_env(self, df: DataFrame, prices: DataFrame, window_size: int,
                   reward_kwargs: dict, starting_point=True):
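
Note on the fee handling above: the environment no longer queries a DataProvider itself; it receives fee as a plain float, filled in further down by pack_env_dict() from the exchange when a data provider is available. A small runnable sketch of the resulting precedence, using a hypothetical resolve_fee() helper that mirrors the __init__ logic:

# Hypothetical helper, not repo code: mirrors the fee precedence in __init__ above.
def resolve_fee(config: dict, fee: float = 0.0015) -> float:
    """An explicit config['fee'] wins; otherwise use the fee passed to the env."""
    if config.get('fee', None) is not None:
        return config['fee']
    return fee

assert resolve_fee({'fee': 0.001}) == 0.001   # user-configured fee overrides everything
assert resolve_fee({}, fee=0.002) == 0.002    # fee supplied via pack_env_dict()
assert resolve_fee({}) == 0.0015              # default when nothing else is known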
@@ -213,7 +207,7 @@ class BaseEnvironment(gym.Env):
         """
         features_window = self.signal_features[(
             self._current_tick - self.window_size):self._current_tick]
-        if self.add_state_info and self.live:
+        if self.add_state_info:
             features_and_state = DataFrame(np.zeros((len(features_window), 3)),
                                            columns=['current_profit_pct',
                                                     'position',


@@ -143,18 +143,14 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         train_df = data_dictionary["train_features"]
         test_df = data_dictionary["test_features"]
 
+        env_info = self.pack_env_dict()
+
         self.train_env = self.MyRLEnv(df=train_df,
                                       prices=prices_train,
-                                      window_size=self.CONV_WIDTH,
-                                      reward_kwargs=self.reward_params,
-                                      config=self.config,
-                                      dp=self.data_provider)
+                                      **env_info)
         self.eval_env = Monitor(self.MyRLEnv(df=test_df,
                                              prices=prices_test,
-                                             window_size=self.CONV_WIDTH,
-                                             reward_kwargs=self.reward_params,
-                                             config=self.config,
-                                             dp=self.data_provider))
+                                             **env_info))
         self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
                                           render=False, eval_freq=len(train_df),
                                           best_model_save_path=str(dk.data_path))
@@ -162,6 +158,20 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         actions = self.train_env.get_actions()
         self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions)
 
+    def pack_env_dict(self) -> Dict[str, Any]:
+        """
+        Create dictionary of environment arguments
+        """
+        env_info = {"window_size": self.CONV_WIDTH,
+                    "reward_kwargs": self.reward_params,
+                    "config": self.config,
+                    "live": self.live}
+        if self.data_provider:
+            env_info["fee"] = self.data_provider._exchange \
+                .get_fee(symbol=self.data_provider.current_whitelist()[0])  # type: ignore
+
+        return env_info
+
     @abstractmethod
     def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
         """
@@ -383,8 +393,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
 
 def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
              seed: int, train_df: DataFrame, price: DataFrame,
-             reward_params: Dict[str, int], window_size: int, monitor: bool = False,
-             config: Dict[str, Any] = {}) -> Callable:
+             monitor: bool = False,
+             env_info: Dict[str, Any] = {}) -> Callable:
     """
     Utility function for multiprocessed env.
@@ -392,13 +402,14 @@ def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
     :param num_env: (int) the number of environment you wish to have in subprocesses
     :param seed: (int) the inital seed for RNG
     :param rank: (int) index of the subprocess
+    :param env_info: (dict) all required arguments to instantiate the environment.
     :return: (Callable)
     """
     def _init() -> gym.Env:
 
-        env = MyRLEnv(df=train_df, prices=price, window_size=window_size,
-                      reward_kwargs=reward_params, id=env_id, seed=seed + rank, config=config)
+        env = MyRLEnv(df=train_df, prices=price, id=env_id, seed=seed + rank,
+                      **env_info)
 
         if monitor:
             env = Monitor(env)
         return env
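
make_env() still returns a zero-argument callable because SubprocVecEnv expects one factory per subprocess; whatever the closure captures, env_info included, is pickled and rebuilt in that subprocess, so it is best kept to plain values such as live and fee rather than live objects like a DataProvider. A stripped-down sketch of the pattern, with a toy gym environment standing in for the FreqAI one:

# Sketch only, not repo code: a toy env stands in for MyRLEnv.
from typing import Any, Callable, Dict

import gym
from stable_baselines3.common.vec_env import SubprocVecEnv


def make_env_sketch(env_id: str, rank: int, seed: int,
                    env_info: Dict[str, Any] = {}) -> Callable:
    def _init() -> gym.Env:
        # The closure (env_info included) is serialised and recreated in the subprocess;
        # in the real make_env() this is MyRLEnv(..., seed=seed + rank, **env_info).
        return gym.make(env_id, **env_info)
    return _init


if __name__ == "__main__":
    # Four subprocess copies of the same env, mirroring the multiproc learner below.
    vec_env = SubprocVecEnv([make_env_sketch("CartPole-v1", i, 1) for i in range(4)])
    vec_env.close()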


@@ -34,17 +34,20 @@ class ReinforcementLearner_multiproc(ReinforcementLearner):
         train_df = data_dictionary["train_features"]
         test_df = data_dictionary["test_features"]
 
+        env_info = self.pack_env_dict()
+
         env_id = "train_env"
-        self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1, train_df, prices_train,
-                                                 self.reward_params, self.CONV_WIDTH, monitor=True,
-                                                 config=self.config) for i
+        self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1,
+                                                 train_df, prices_train,
+                                                 monitor=True,
+                                                 env_info=env_info) for i
                                         in range(self.max_threads)])
 
         eval_env_id = 'eval_env'
         self.eval_env = SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1,
                                                 test_df, prices_test,
-                                                self.reward_params, self.CONV_WIDTH, monitor=True,
-                                                config=self.config) for i
+                                                monitor=True,
+                                                env_info=env_info) for i
                                        in range(self.max_threads)])
 
         self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
                                           render=False, eval_freq=len(train_df),