Add env_info dict to base environment

Author: Emre
Date:   2022-12-14 22:03:05 +03:00
Parent: 2285ca7d2a
Commit: 2018da0767
3 changed files with 25 additions and 19 deletions

freqtrade/freqai/RL/BaseEnvironment.py

@@ -11,9 +11,6 @@ from gym import spaces
 from gym.utils import seeding
 from pandas import DataFrame
 
-from freqtrade.data.dataprovider import DataProvider
-from freqtrade.enums import RunMode
-
 
 logger = logging.getLogger(__name__)
 
@@ -48,7 +45,7 @@ class BaseEnvironment(gym.Env):
     def __init__(self, df: DataFrame = DataFrame(), prices: DataFrame = DataFrame(),
                  reward_kwargs: dict = {}, window_size=10, starting_point=True,
                  id: str = 'baseenv-1', seed: int = 1, config: dict = {},
-                 dp: Optional[DataProvider] = None):
+                 env_info: dict = {}):
        """
        Initializes the training/eval environment.
        :param df: dataframe of features
@@ -59,7 +56,7 @@ class BaseEnvironment(gym.Env):
        :param id: string id of the environment (used in backend for multiprocessed env)
        :param seed: Sets the seed of the environment higher in the gym.Env object
        :param config: Typical user configuration file
-       :param dp: dataprovider from freqtrade
+       :param env_info: Environment info dictionary, used to pass live status, fee, etc.
        """
        self.config = config
        self.rl_config = config['freqai']['rl_config']
@@ -71,17 +68,13 @@ class BaseEnvironment(gym.Env):
         self.compound_trades = config['stake_amount'] == 'unlimited'
         if self.config.get('fee', None) is not None:
             self.fee = self.config['fee']
-        elif dp is not None:
-            self.fee = dp._exchange.get_fee(symbol=dp.current_whitelist()[0])  # type: ignore
         else:
-            self.fee = 0.0015
+            self.fee = env_info.get('fee', 0.0015)
 
         # set here to default 5Ac, but all children envs can override this
         self.actions: Type[Enum] = BaseActions
         self.tensorboard_metrics: dict = {}
-        self.live: bool = False
-        if dp:
-            self.live = dp.runmode in (RunMode.DRY_RUN, RunMode.LIVE)
+        self.live = env_info.get('live', False)
         if not self.live and self.add_state_info:
             self.add_state_info = False
             logger.warning("add_state_info is not available in backtesting. Deactivating.")
@@ -213,7 +206,7 @@ class BaseEnvironment(gym.Env):
         """
         features_window = self.signal_features[(
             self._current_tick - self.window_size):self._current_tick]
-        if self.add_state_info and self.live:
+        if self.add_state_info:
             features_and_state = DataFrame(np.zeros((len(features_window), 3)),
                                            columns=['current_profit_pct',
                                                     'position',
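Taken together, BaseEnvironment no longer reaches into the DataProvider itself; callers gather the runtime details it needs (live flag, exchange fee) into a plain dict. A minimal sketch of the new calling convention, assuming a concrete subclass such as Base5ActionRLEnv and pre-built feature/price dataframes (the variable names below are illustrative, not taken from this commit):

# Illustrative only: build env_info by hand and pass it to the new signature.
env_info = {
    "live": False,   # dry-run/live detection is now the caller's job
    "fee": 0.001,    # optional; BaseEnvironment falls back to 0.0015 if missing
}
env = Base5ActionRLEnv(df=train_features, prices=train_prices,
                       window_size=10, reward_kwargs=reward_params,
                       config=config, env_info=env_info)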

freqtrade/freqai/RL/BaseReinforcementLearningModel.py

@@ -17,6 +17,7 @@ from stable_baselines3.common.monitor import Monitor
 from stable_baselines3.common.utils import set_random_seed
 from stable_baselines3.common.vec_env import SubprocVecEnv
 
+from freqtrade.enums import RunMode
 from freqtrade.exceptions import OperationalException
 from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
 from freqtrade.freqai.freqai_interface import IFreqaiModel
@@ -24,7 +25,6 @@ from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
 from freqtrade.freqai.RL.BaseEnvironment import BaseActions, Positions
 from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback
 from freqtrade.persistence import Trade
-from freqtrade.data.dataprovider import DataProvider
 
 
 logger = logging.getLogger(__name__)
@@ -144,18 +144,24 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         train_df = data_dictionary["train_features"]
         test_df = data_dictionary["test_features"]
 
+        env_info = {"live": False}
+        if self.data_provider:
+            env_info["live"] = self.data_provider.runmode in (RunMode.DRY_RUN, RunMode.LIVE)
+            env_info["fee"] = self.data_provider._exchange \
+                .get_fee(symbol=self.data_provider.current_whitelist()[0])  # type: ignore
+
         self.train_env = self.MyRLEnv(df=train_df,
                                       prices=prices_train,
                                       window_size=self.CONV_WIDTH,
                                       reward_kwargs=self.reward_params,
                                       config=self.config,
-                                      dp=self.data_provider)
+                                      env_info=env_info)
         self.eval_env = Monitor(self.MyRLEnv(df=test_df,
                                              prices=prices_test,
                                              window_size=self.CONV_WIDTH,
                                              reward_kwargs=self.reward_params,
                                              config=self.config,
-                                             dp=self.data_provider))
+                                             env_info=env_info))
         self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
                                           render=False, eval_freq=len(train_df),
                                           best_model_save_path=str(dk.data_path))
@@ -385,7 +391,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
 def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
              seed: int, train_df: DataFrame, price: DataFrame,
              reward_params: Dict[str, int], window_size: int, monitor: bool = False,
-             config: Dict[str, Any] = {}, dp: DataProvider = None) -> Callable:
+             config: Dict[str, Any] = {}, env_info: Dict[str, Any] = {}) -> Callable:
     """
     Utility function for multiprocessed env.
 
@@ -400,7 +406,7 @@ def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
 
         env = MyRLEnv(df=train_df, prices=price, window_size=window_size,
                       reward_kwargs=reward_params, id=env_id, seed=seed + rank,
-                      config=config, dp=dp)
+                      config=config, env_info=env_info)
         if monitor:
             env = Monitor(env)
         return env
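The env_info construction above is repeated verbatim in ReinforcementLearner_multiproc below. A small helper on the model class could centralise it; a possible sketch (the method name _pack_env_info is hypothetical and not introduced by this commit):

# Hypothetical helper, not part of this commit: one place to build env_info.
def _pack_env_info(self) -> dict:
    env_info = {"live": False}
    if self.data_provider:
        env_info["live"] = self.data_provider.runmode in (RunMode.DRY_RUN, RunMode.LIVE)
        env_info["fee"] = self.data_provider._exchange.get_fee(
            symbol=self.data_provider.current_whitelist()[0])  # type: ignore
    return env_info

Both learners would then pass env_info=self._pack_env_info() when building their environments.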

freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py

@@ -5,6 +5,7 @@ from pandas import DataFrame
 from stable_baselines3.common.callbacks import EvalCallback
 from stable_baselines3.common.vec_env import SubprocVecEnv
 
+from freqtrade.enums import RunMode
 from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
 from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
 from freqtrade.freqai.RL.BaseReinforcementLearningModel import make_env
@@ -34,17 +35,23 @@ class ReinforcementLearner_multiproc(ReinforcementLearner):
         train_df = data_dictionary["train_features"]
         test_df = data_dictionary["test_features"]
 
+        env_info = {"live": False}
+        if self.data_provider:
+            env_info["live"] = self.data_provider.runmode in (RunMode.DRY_RUN, RunMode.LIVE)
+            env_info["fee"] = self.data_provider._exchange \
+                .get_fee(symbol=self.data_provider.current_whitelist()[0])  # type: ignore
+
         env_id = "train_env"
         self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1, train_df, prices_train,
                                                  self.reward_params, self.CONV_WIDTH, monitor=True,
-                                                 config=self.config, dp=self.data_provider) for i
+                                                 config=self.config, env_info=env_info) for i
                                         in range(self.max_threads)])
 
         eval_env_id = 'eval_env'
         self.eval_env = SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1,
                                                 test_df, prices_test,
                                                 self.reward_params, self.CONV_WIDTH, monitor=True,
-                                                config=self.config, dp=self.data_provider) for i
+                                                config=self.config, env_info=env_info) for i
                                        in range(self.max_threads)])
         self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
                                           render=False, eval_freq=len(train_df),
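With these changes, an environment constructed with an empty env_info behaves like backtesting: fee falls back to 0.0015 (unless config['fee'] is set), live stays False, and add_state_info is switched off. A rough illustration of the expected defaults, where MyEnv stands for any BaseEnvironment subclass and the dataframes, reward params, and config are placeholders:

# Rough illustration of the defaults; assumes config contains no 'fee' key.
env = MyEnv(df=features_df, prices=prices_df, window_size=10,
            reward_kwargs=reward_params, config=config, env_info={})
assert env.fee == 0.0015            # no 'fee' in env_info and none in config
assert env.live is False            # no 'live' in env_info
assert env.add_state_info is False  # deactivated whenever live is False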