Add env_info dict to base environment
This commit is contained in:
@@ -17,6 +17,7 @@ from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.utils import set_random_seed
|
||||
from stable_baselines3.common.vec_env import SubprocVecEnv
|
||||
|
||||
from freqtrade.enums import RunMode
|
||||
from freqtrade.exceptions import OperationalException
|
||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||
from freqtrade.freqai.freqai_interface import IFreqaiModel
|
||||
@@ -24,7 +25,6 @@ from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
|
||||
from freqtrade.freqai.RL.BaseEnvironment import BaseActions, Positions
|
||||
from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback
|
||||
from freqtrade.persistence import Trade
|
||||
from freqtrade.data.dataprovider import DataProvider
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -144,18 +144,24 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
train_df = data_dictionary["train_features"]
|
||||
test_df = data_dictionary["test_features"]
|
||||
|
||||
env_info = {"live": False}
|
||||
if self.data_provider:
|
||||
env_info["live"] = self.data_provider.runmode in (RunMode.DRY_RUN, RunMode.LIVE)
|
||||
env_info["fee"] = self.data_provider._exchange \
|
||||
.get_fee(symbol=self.data_provider.current_whitelist()[0]) # type: ignore
|
||||
|
||||
self.train_env = self.MyRLEnv(df=train_df,
|
||||
prices=prices_train,
|
||||
window_size=self.CONV_WIDTH,
|
||||
reward_kwargs=self.reward_params,
|
||||
config=self.config,
|
||||
dp=self.data_provider)
|
||||
env_info=env_info)
|
||||
self.eval_env = Monitor(self.MyRLEnv(df=test_df,
|
||||
prices=prices_test,
|
||||
window_size=self.CONV_WIDTH,
|
||||
reward_kwargs=self.reward_params,
|
||||
config=self.config,
|
||||
dp=self.data_provider))
|
||||
env_info=env_info))
|
||||
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
|
||||
render=False, eval_freq=len(train_df),
|
||||
best_model_save_path=str(dk.data_path))
|
||||
@@ -385,7 +391,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
|
||||
seed: int, train_df: DataFrame, price: DataFrame,
|
||||
reward_params: Dict[str, int], window_size: int, monitor: bool = False,
|
||||
config: Dict[str, Any] = {}, dp: DataProvider = None) -> Callable:
|
||||
config: Dict[str, Any] = {}, env_info: Dict[str, Any] = {}) -> Callable:
|
||||
"""
|
||||
Utility function for multiprocessed env.
|
||||
|
||||
@@ -400,7 +406,7 @@ def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
|
||||
|
||||
env = MyRLEnv(df=train_df, prices=price, window_size=window_size,
|
||||
reward_kwargs=reward_params, id=env_id, seed=seed + rank,
|
||||
config=config, dp=dp)
|
||||
config=config, env_info=env_info)
|
||||
if monitor:
|
||||
env = Monitor(env)
|
||||
return env
|
||||
|
Reference in New Issue
Block a user