Add env_info dict to base environment
This commit is contained in:
parent
2285ca7d2a
commit
2018da0767
@ -11,9 +11,6 @@ from gym import spaces
|
|||||||
from gym.utils import seeding
|
from gym.utils import seeding
|
||||||
from pandas import DataFrame
|
from pandas import DataFrame
|
||||||
|
|
||||||
from freqtrade.data.dataprovider import DataProvider
|
|
||||||
from freqtrade.enums import RunMode
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -48,7 +45,7 @@ class BaseEnvironment(gym.Env):
|
|||||||
def __init__(self, df: DataFrame = DataFrame(), prices: DataFrame = DataFrame(),
|
def __init__(self, df: DataFrame = DataFrame(), prices: DataFrame = DataFrame(),
|
||||||
reward_kwargs: dict = {}, window_size=10, starting_point=True,
|
reward_kwargs: dict = {}, window_size=10, starting_point=True,
|
||||||
id: str = 'baseenv-1', seed: int = 1, config: dict = {},
|
id: str = 'baseenv-1', seed: int = 1, config: dict = {},
|
||||||
dp: Optional[DataProvider] = None):
|
env_info: dict = {}):
|
||||||
"""
|
"""
|
||||||
Initializes the training/eval environment.
|
Initializes the training/eval environment.
|
||||||
:param df: dataframe of features
|
:param df: dataframe of features
|
||||||
@ -59,7 +56,7 @@ class BaseEnvironment(gym.Env):
|
|||||||
:param id: string id of the environment (used in backend for multiprocessed env)
|
:param id: string id of the environment (used in backend for multiprocessed env)
|
||||||
:param seed: Sets the seed of the environment higher in the gym.Env object
|
:param seed: Sets the seed of the environment higher in the gym.Env object
|
||||||
:param config: Typical user configuration file
|
:param config: Typical user configuration file
|
||||||
:param dp: dataprovider from freqtrade
|
:param env_info: Environment info dictionary, used to pass live status, fee, etc.
|
||||||
"""
|
"""
|
||||||
self.config = config
|
self.config = config
|
||||||
self.rl_config = config['freqai']['rl_config']
|
self.rl_config = config['freqai']['rl_config']
|
||||||
@ -71,17 +68,13 @@ class BaseEnvironment(gym.Env):
|
|||||||
self.compound_trades = config['stake_amount'] == 'unlimited'
|
self.compound_trades = config['stake_amount'] == 'unlimited'
|
||||||
if self.config.get('fee', None) is not None:
|
if self.config.get('fee', None) is not None:
|
||||||
self.fee = self.config['fee']
|
self.fee = self.config['fee']
|
||||||
elif dp is not None:
|
|
||||||
self.fee = dp._exchange.get_fee(symbol=dp.current_whitelist()[0]) # type: ignore
|
|
||||||
else:
|
else:
|
||||||
self.fee = 0.0015
|
self.fee = env_info.get('fee', 0.0015)
|
||||||
|
|
||||||
# set here to default 5Ac, but all children envs can override this
|
# set here to default 5Ac, but all children envs can override this
|
||||||
self.actions: Type[Enum] = BaseActions
|
self.actions: Type[Enum] = BaseActions
|
||||||
self.tensorboard_metrics: dict = {}
|
self.tensorboard_metrics: dict = {}
|
||||||
self.live: bool = False
|
self.live = env_info.get('live', False)
|
||||||
if dp:
|
|
||||||
self.live = dp.runmode in (RunMode.DRY_RUN, RunMode.LIVE)
|
|
||||||
if not self.live and self.add_state_info:
|
if not self.live and self.add_state_info:
|
||||||
self.add_state_info = False
|
self.add_state_info = False
|
||||||
logger.warning("add_state_info is not available in backtesting. Deactivating.")
|
logger.warning("add_state_info is not available in backtesting. Deactivating.")
|
||||||
@ -213,7 +206,7 @@ class BaseEnvironment(gym.Env):
|
|||||||
"""
|
"""
|
||||||
features_window = self.signal_features[(
|
features_window = self.signal_features[(
|
||||||
self._current_tick - self.window_size):self._current_tick]
|
self._current_tick - self.window_size):self._current_tick]
|
||||||
if self.add_state_info and self.live:
|
if self.add_state_info:
|
||||||
features_and_state = DataFrame(np.zeros((len(features_window), 3)),
|
features_and_state = DataFrame(np.zeros((len(features_window), 3)),
|
||||||
columns=['current_profit_pct',
|
columns=['current_profit_pct',
|
||||||
'position',
|
'position',
|
||||||
|
@ -17,6 +17,7 @@ from stable_baselines3.common.monitor import Monitor
|
|||||||
from stable_baselines3.common.utils import set_random_seed
|
from stable_baselines3.common.utils import set_random_seed
|
||||||
from stable_baselines3.common.vec_env import SubprocVecEnv
|
from stable_baselines3.common.vec_env import SubprocVecEnv
|
||||||
|
|
||||||
|
from freqtrade.enums import RunMode
|
||||||
from freqtrade.exceptions import OperationalException
|
from freqtrade.exceptions import OperationalException
|
||||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||||
from freqtrade.freqai.freqai_interface import IFreqaiModel
|
from freqtrade.freqai.freqai_interface import IFreqaiModel
|
||||||
@ -24,7 +25,6 @@ from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
|
|||||||
from freqtrade.freqai.RL.BaseEnvironment import BaseActions, Positions
|
from freqtrade.freqai.RL.BaseEnvironment import BaseActions, Positions
|
||||||
from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback
|
from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback
|
||||||
from freqtrade.persistence import Trade
|
from freqtrade.persistence import Trade
|
||||||
from freqtrade.data.dataprovider import DataProvider
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -144,18 +144,24 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
train_df = data_dictionary["train_features"]
|
train_df = data_dictionary["train_features"]
|
||||||
test_df = data_dictionary["test_features"]
|
test_df = data_dictionary["test_features"]
|
||||||
|
|
||||||
|
env_info = {"live": False}
|
||||||
|
if self.data_provider:
|
||||||
|
env_info["live"] = self.data_provider.runmode in (RunMode.DRY_RUN, RunMode.LIVE)
|
||||||
|
env_info["fee"] = self.data_provider._exchange \
|
||||||
|
.get_fee(symbol=self.data_provider.current_whitelist()[0]) # type: ignore
|
||||||
|
|
||||||
self.train_env = self.MyRLEnv(df=train_df,
|
self.train_env = self.MyRLEnv(df=train_df,
|
||||||
prices=prices_train,
|
prices=prices_train,
|
||||||
window_size=self.CONV_WIDTH,
|
window_size=self.CONV_WIDTH,
|
||||||
reward_kwargs=self.reward_params,
|
reward_kwargs=self.reward_params,
|
||||||
config=self.config,
|
config=self.config,
|
||||||
dp=self.data_provider)
|
env_info=env_info)
|
||||||
self.eval_env = Monitor(self.MyRLEnv(df=test_df,
|
self.eval_env = Monitor(self.MyRLEnv(df=test_df,
|
||||||
prices=prices_test,
|
prices=prices_test,
|
||||||
window_size=self.CONV_WIDTH,
|
window_size=self.CONV_WIDTH,
|
||||||
reward_kwargs=self.reward_params,
|
reward_kwargs=self.reward_params,
|
||||||
config=self.config,
|
config=self.config,
|
||||||
dp=self.data_provider))
|
env_info=env_info))
|
||||||
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
|
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
|
||||||
render=False, eval_freq=len(train_df),
|
render=False, eval_freq=len(train_df),
|
||||||
best_model_save_path=str(dk.data_path))
|
best_model_save_path=str(dk.data_path))
|
||||||
@ -385,7 +391,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
|
def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
|
||||||
seed: int, train_df: DataFrame, price: DataFrame,
|
seed: int, train_df: DataFrame, price: DataFrame,
|
||||||
reward_params: Dict[str, int], window_size: int, monitor: bool = False,
|
reward_params: Dict[str, int], window_size: int, monitor: bool = False,
|
||||||
config: Dict[str, Any] = {}, dp: DataProvider = None) -> Callable:
|
config: Dict[str, Any] = {}, env_info: Dict[str, Any] = {}) -> Callable:
|
||||||
"""
|
"""
|
||||||
Utility function for multiprocessed env.
|
Utility function for multiprocessed env.
|
||||||
|
|
||||||
@ -400,7 +406,7 @@ def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
|
|||||||
|
|
||||||
env = MyRLEnv(df=train_df, prices=price, window_size=window_size,
|
env = MyRLEnv(df=train_df, prices=price, window_size=window_size,
|
||||||
reward_kwargs=reward_params, id=env_id, seed=seed + rank,
|
reward_kwargs=reward_params, id=env_id, seed=seed + rank,
|
||||||
config=config, dp=dp)
|
config=config, env_info=env_info)
|
||||||
if monitor:
|
if monitor:
|
||||||
env = Monitor(env)
|
env = Monitor(env)
|
||||||
return env
|
return env
|
||||||
|
@ -5,6 +5,7 @@ from pandas import DataFrame
|
|||||||
from stable_baselines3.common.callbacks import EvalCallback
|
from stable_baselines3.common.callbacks import EvalCallback
|
||||||
from stable_baselines3.common.vec_env import SubprocVecEnv
|
from stable_baselines3.common.vec_env import SubprocVecEnv
|
||||||
|
|
||||||
|
from freqtrade.enums import RunMode
|
||||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||||
from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
|
from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
|
||||||
from freqtrade.freqai.RL.BaseReinforcementLearningModel import make_env
|
from freqtrade.freqai.RL.BaseReinforcementLearningModel import make_env
|
||||||
@ -34,17 +35,23 @@ class ReinforcementLearner_multiproc(ReinforcementLearner):
|
|||||||
train_df = data_dictionary["train_features"]
|
train_df = data_dictionary["train_features"]
|
||||||
test_df = data_dictionary["test_features"]
|
test_df = data_dictionary["test_features"]
|
||||||
|
|
||||||
|
env_info = {"live": False}
|
||||||
|
if self.data_provider:
|
||||||
|
env_info["live"] = self.data_provider.runmode in (RunMode.DRY_RUN, RunMode.LIVE)
|
||||||
|
env_info["fee"] = self.data_provider._exchange \
|
||||||
|
.get_fee(symbol=self.data_provider.current_whitelist()[0]) # type: ignore
|
||||||
|
|
||||||
env_id = "train_env"
|
env_id = "train_env"
|
||||||
self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1, train_df, prices_train,
|
self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1, train_df, prices_train,
|
||||||
self.reward_params, self.CONV_WIDTH, monitor=True,
|
self.reward_params, self.CONV_WIDTH, monitor=True,
|
||||||
config=self.config, dp=self.data_provider) for i
|
config=self.config, env_info=env_info) for i
|
||||||
in range(self.max_threads)])
|
in range(self.max_threads)])
|
||||||
|
|
||||||
eval_env_id = 'eval_env'
|
eval_env_id = 'eval_env'
|
||||||
self.eval_env = SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1,
|
self.eval_env = SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1,
|
||||||
test_df, prices_test,
|
test_df, prices_test,
|
||||||
self.reward_params, self.CONV_WIDTH, monitor=True,
|
self.reward_params, self.CONV_WIDTH, monitor=True,
|
||||||
config=self.config, dp=self.data_provider) for i
|
config=self.config, env_info=env_info) for i
|
||||||
in range(self.max_threads)])
|
in range(self.max_threads)])
|
||||||
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
|
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
|
||||||
render=False, eval_freq=len(train_df),
|
render=False, eval_freq=len(train_df),
|
||||||
|
Loading…
Reference in New Issue
Block a user