add dp to multiproc

robcaulk 2022-12-14 18:22:20 +01:00
parent 350cebb0a8
commit 2285ca7d2a
2 changed files with 6 additions and 4 deletions

freqtrade/freqai/RL/BaseReinforcementLearningModel.py View File

@@ -24,6 +24,7 @@ from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
 from freqtrade.freqai.RL.BaseEnvironment import BaseActions, Positions
 from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback
 from freqtrade.persistence import Trade
+from freqtrade.data.dataprovider import DataProvider


 logger = logging.getLogger(__name__)
@@ -384,7 +385,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
 def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
              seed: int, train_df: DataFrame, price: DataFrame,
              reward_params: Dict[str, int], window_size: int, monitor: bool = False,
-             config: Dict[str, Any] = {}) -> Callable:
+             config: Dict[str, Any] = {}, dp: DataProvider = None) -> Callable:
     """
     Utility function for multiprocessed env.
@@ -398,7 +399,8 @@ def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
     def _init() -> gym.Env:

-        env = MyRLEnv(df=train_df, prices=price, window_size=window_size,
-                      reward_kwargs=reward_params, id=env_id, seed=seed + rank, config=config)
+        env = MyRLEnv(df=train_df, prices=price, window_size=window_size,
+                      reward_kwargs=reward_params, id=env_id, seed=seed + rank,
+                      config=config, dp=dp)
         if monitor:
             env = Monitor(env)
         return env
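
For orientation, here is a minimal, self-contained sketch of the kind of constructor this change assumes on the receiving end: an env built by make_env now also receives a dp keyword and simply keeps a handle to the (possibly None) DataProvider. The class name and attribute handling below are illustrative only, not part of this commit or of the freqtrade base environment.

from typing import Any, Dict, Optional

import gym

from freqtrade.data.dataprovider import DataProvider


class ExampleRLEnv(gym.Env):
    """Illustrative stand-in for a FreqAI RL env that accepts the new ``dp`` kwarg."""

    def __init__(self, df=None, prices=None, window_size: int = 10,
                 reward_kwargs: Optional[Dict[str, Any]] = None, id: str = "env",
                 seed: int = 1, config: Optional[Dict[str, Any]] = None,
                 dp: Optional[DataProvider] = None):
        self.df = df
        self.prices = prices
        self.window_size = window_size
        self.reward_kwargs = reward_kwargs or {}
        self.id = id
        self.config = config or {}
        # the DataProvider handle is simply stored; it may be None when the
        # env is constructed outside a running bot (e.g. in tests)
        self.dp = dp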

freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py View File

@@ -37,14 +37,14 @@ class ReinforcementLearner_multiproc(ReinforcementLearner):
         env_id = "train_env"
         self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1, train_df, prices_train,
                                                  self.reward_params, self.CONV_WIDTH, monitor=True,
-                                                 config=self.config) for i
+                                                 config=self.config, dp=self.data_provider) for i
                                         in range(self.max_threads)])

         eval_env_id = 'eval_env'
         self.eval_env = SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1,
                                                 test_df, prices_test,
                                                 self.reward_params, self.CONV_WIDTH, monitor=True,
-                                                config=self.config) for i
+                                                config=self.config, dp=self.data_provider) for i
                                        in range(self.max_threads)])

         self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
                                           render=False, eval_freq=len(train_df),
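
As a usage note, the factory pattern touched above can be reproduced in isolation: every keyword captured by a make_env-style closure, including the new dp, is serialized together with the closure and re-created inside each SubprocVecEnv worker process. The toy environment and the dictionary standing in for dp below are illustrative; only the shape of the pattern mirrors the diff.

from typing import Any, Callable, Optional

import gym
import numpy as np
from gym import spaces
from stable_baselines3.common.vec_env import SubprocVecEnv


class ToyEnv(gym.Env):
    """Illustrative env: records whatever handle is passed as ``dp``."""

    def __init__(self, seed: int = 0, dp: Optional[Any] = None):
        self.dp = dp
        self.observation_space = spaces.Box(-1.0, 1.0, shape=(1,), dtype=np.float32)
        self.action_space = spaces.Discrete(2)
        self._rng = np.random.default_rng(seed)

    def reset(self):
        return self._rng.uniform(-1, 1, size=(1,)).astype(np.float32)

    def step(self, action):
        obs = self._rng.uniform(-1, 1, size=(1,)).astype(np.float32)
        return obs, 0.0, False, {}


def make_toy_env(rank: int, seed: int, dp: Optional[Any] = None) -> Callable:
    def _init() -> gym.Env:
        # each worker process calls this closure and builds its own env copy
        return ToyEnv(seed=seed + rank, dp=dp)
    return _init


if __name__ == "__main__":
    # ``dp`` must be picklable to cross the process boundary; a dict stands in here
    vec_env = SubprocVecEnv([make_toy_env(i, 1, dp={"pair": "BTC/USDT"}) for i in range(2)])
    vec_env.reset()
    vec_env.close()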