add dp to multiproc
This commit is contained in:
parent
350cebb0a8
commit
2285ca7d2a
@ -24,6 +24,7 @@ from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
|
|||||||
from freqtrade.freqai.RL.BaseEnvironment import BaseActions, Positions
|
from freqtrade.freqai.RL.BaseEnvironment import BaseActions, Positions
|
||||||
from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback
|
from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback
|
||||||
from freqtrade.persistence import Trade
|
from freqtrade.persistence import Trade
|
||||||
|
from freqtrade.data.dataprovider import DataProvider
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -384,7 +385,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
|
def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
|
||||||
seed: int, train_df: DataFrame, price: DataFrame,
|
seed: int, train_df: DataFrame, price: DataFrame,
|
||||||
reward_params: Dict[str, int], window_size: int, monitor: bool = False,
|
reward_params: Dict[str, int], window_size: int, monitor: bool = False,
|
||||||
config: Dict[str, Any] = {}) -> Callable:
|
config: Dict[str, Any] = {}, dp: DataProvider = None) -> Callable:
|
||||||
"""
|
"""
|
||||||
Utility function for multiprocessed env.
|
Utility function for multiprocessed env.
|
||||||
|
|
||||||
@ -398,7 +399,8 @@ def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
|
|||||||
def _init() -> gym.Env:
|
def _init() -> gym.Env:
|
||||||
|
|
||||||
env = MyRLEnv(df=train_df, prices=price, window_size=window_size,
|
env = MyRLEnv(df=train_df, prices=price, window_size=window_size,
|
||||||
reward_kwargs=reward_params, id=env_id, seed=seed + rank, config=config)
|
reward_kwargs=reward_params, id=env_id, seed=seed + rank,
|
||||||
|
config=config, dp=dp)
|
||||||
if monitor:
|
if monitor:
|
||||||
env = Monitor(env)
|
env = Monitor(env)
|
||||||
return env
|
return env
|
||||||
|
@ -37,14 +37,14 @@ class ReinforcementLearner_multiproc(ReinforcementLearner):
|
|||||||
env_id = "train_env"
|
env_id = "train_env"
|
||||||
self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1, train_df, prices_train,
|
self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1, train_df, prices_train,
|
||||||
self.reward_params, self.CONV_WIDTH, monitor=True,
|
self.reward_params, self.CONV_WIDTH, monitor=True,
|
||||||
config=self.config) for i
|
config=self.config, dp=self.data_provider) for i
|
||||||
in range(self.max_threads)])
|
in range(self.max_threads)])
|
||||||
|
|
||||||
eval_env_id = 'eval_env'
|
eval_env_id = 'eval_env'
|
||||||
self.eval_env = SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1,
|
self.eval_env = SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1,
|
||||||
test_df, prices_test,
|
test_df, prices_test,
|
||||||
self.reward_params, self.CONV_WIDTH, monitor=True,
|
self.reward_params, self.CONV_WIDTH, monitor=True,
|
||||||
config=self.config) for i
|
config=self.config, dp=self.data_provider) for i
|
||||||
in range(self.max_threads)])
|
in range(self.max_threads)])
|
||||||
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
|
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
|
||||||
render=False, eval_freq=len(train_df),
|
render=False, eval_freq=len(train_df),
|
||||||
|
Loading…
Reference in New Issue
Block a user