diff --git a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
index 5e9b81108..b77f21d58 100644
--- a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
+++ b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
@@ -24,6 +24,7 @@ from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
 from freqtrade.freqai.RL.BaseEnvironment import BaseActions, Positions
 from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback
 from freqtrade.persistence import Trade
+from freqtrade.data.dataprovider import DataProvider


 logger = logging.getLogger(__name__)
@@ -384,7 +385,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
 def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
              seed: int, train_df: DataFrame, price: DataFrame,
              reward_params: Dict[str, int], window_size: int, monitor: bool = False,
-             config: Dict[str, Any] = {}) -> Callable:
+             config: Dict[str, Any] = {}, dp: DataProvider = None) -> Callable:
     """
     Utility function for multiprocessed env.

@@ -398,7 +399,8 @@ def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
     def _init() -> gym.Env:

         env = MyRLEnv(df=train_df, prices=price, window_size=window_size,
-                      reward_kwargs=reward_params, id=env_id, seed=seed + rank, config=config)
+                      reward_kwargs=reward_params, id=env_id, seed=seed + rank,
+                      config=config, dp=dp)
         if monitor:
             env = Monitor(env)
         return env
diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py b/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py
index 32a2a2076..c9b824978 100644
--- a/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py
+++ b/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py
@@ -37,14 +37,14 @@ class ReinforcementLearner_multiproc(ReinforcementLearner):
         env_id = "train_env"
         self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1,
                                                  train_df, prices_train,
                                                  self.reward_params, self.CONV_WIDTH, monitor=True,
-                                                 config=self.config) for i
+                                                 config=self.config, dp=self.data_provider) for i
                                         in range(self.max_threads)])

         eval_env_id = 'eval_env'
         self.eval_env = SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1,
                                                 test_df, prices_test,
                                                 self.reward_params, self.CONV_WIDTH, monitor=True,
-                                                config=self.config) for i
+                                                config=self.config, dp=self.data_provider) for i
                                        in range(self.max_threads)])

         self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
                                           render=False, eval_freq=len(train_df),
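
For context (not part of the patch): the sketch below shows how a user-defined environment could receive the DataProvider that make_env() now forwards. Only the dp keyword itself comes from the diff above; the subclass, its module location, and the custom_dp attribute are hypothetical names used for illustration.

# Hypothetical usage sketch, assuming freqtrade with the patch above is installed.
from freqtrade.data.dataprovider import DataProvider
from freqtrade.freqai.RL.Base5ActionRLEnv import Base5ActionRLEnv


class MyRLEnv(Base5ActionRLEnv):
    """Example environment that keeps a handle to the injected DataProvider."""

    def __init__(self, *args, dp: DataProvider = None, **kwargs):
        # Forward dp to the base environment, mirroring the dp=dp call that
        # make_env() performs in the diff; custom_dp is a made-up attribute
        # name, kept only so reward or observation logic could reach it later.
        super().__init__(*args, dp=dp, **kwargs)
        self.custom_dp = dp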