use a dictionary to make code more readable

This commit is contained in:
robcaulk 2022-12-15 12:25:33 +01:00
parent d3ad5cb722
commit 7b4abd5ef5
3 changed files with 32 additions and 34 deletions

View File

@ -44,8 +44,8 @@ class BaseEnvironment(gym.Env):
def __init__(self, df: DataFrame = DataFrame(), prices: DataFrame = DataFrame(), def __init__(self, df: DataFrame = DataFrame(), prices: DataFrame = DataFrame(),
reward_kwargs: dict = {}, window_size=10, starting_point=True, reward_kwargs: dict = {}, window_size=10, starting_point=True,
id: str = 'baseenv-1', seed: int = 1, config: dict = {}, id: str = 'baseenv-1', seed: int = 1, config: dict = {}, live: bool = False,
env_info: dict = {}): fee: float = 0.0015):
""" """
Initializes the training/eval environment. Initializes the training/eval environment.
:param df: dataframe of features :param df: dataframe of features
@ -67,12 +67,12 @@ class BaseEnvironment(gym.Env):
if self.config.get('fee', None) is not None: if self.config.get('fee', None) is not None:
self.fee = self.config['fee'] self.fee = self.config['fee']
else: else:
self.fee = env_info.get('fee', 0.0015) self.fee = fee
# set here to default 5Ac, but all children envs can override this # set here to default 5Ac, but all children envs can override this
self.actions: Type[Enum] = BaseActions self.actions: Type[Enum] = BaseActions
self.tensorboard_metrics: dict = {} self.tensorboard_metrics: dict = {}
self.live = env_info.get('live', False) self.live = live
if not self.live and self.add_state_info: if not self.live and self.add_state_info:
self.add_state_info = False self.add_state_info = False
logger.warning("add_state_info is not available in backtesting. Deactivating.") logger.warning("add_state_info is not available in backtesting. Deactivating.")

View File

@ -17,7 +17,6 @@ from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.utils import set_random_seed from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.vec_env import SubprocVecEnv from stable_baselines3.common.vec_env import SubprocVecEnv
from freqtrade.enums import RunMode
from freqtrade.exceptions import OperationalException from freqtrade.exceptions import OperationalException
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.freqai_interface import IFreqaiModel from freqtrade.freqai.freqai_interface import IFreqaiModel
@ -144,24 +143,14 @@ class BaseReinforcementLearningModel(IFreqaiModel):
train_df = data_dictionary["train_features"] train_df = data_dictionary["train_features"]
test_df = data_dictionary["test_features"] test_df = data_dictionary["test_features"]
env_info = {"live": False} env_info = self.pack_env_dict()
if self.data_provider:
env_info["live"] = self.data_provider.runmode in (RunMode.DRY_RUN, RunMode.LIVE)
env_info["fee"] = self.data_provider._exchange \
.get_fee(symbol=self.data_provider.current_whitelist()[0]) # type: ignore
self.train_env = self.MyRLEnv(df=train_df, self.train_env = self.MyRLEnv(df=train_df,
prices=prices_train, prices=prices_train,
window_size=self.CONV_WIDTH, **env_info)
reward_kwargs=self.reward_params,
config=self.config,
env_info=env_info)
self.eval_env = Monitor(self.MyRLEnv(df=test_df, self.eval_env = Monitor(self.MyRLEnv(df=test_df,
prices=prices_test, prices=prices_test,
window_size=self.CONV_WIDTH, **env_info))
reward_kwargs=self.reward_params,
config=self.config,
env_info=env_info))
self.eval_callback = EvalCallback(self.eval_env, deterministic=True, self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
render=False, eval_freq=len(train_df), render=False, eval_freq=len(train_df),
best_model_save_path=str(dk.data_path)) best_model_save_path=str(dk.data_path))
@ -169,6 +158,20 @@ class BaseReinforcementLearningModel(IFreqaiModel):
actions = self.train_env.get_actions() actions = self.train_env.get_actions()
self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions) self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions)
def pack_env_dict(self) -> Dict[str, Any]:
"""
Create dictionary of environment arguments
"""
env_info = {"window_size": self.CONV_WIDTH,
"reward_kwargs": self.reward_params,
"config": self.config,
"live": self.live}
if self.data_provider:
env_info["fee"] = self.data_provider._exchange \
.get_fee(symbol=self.data_provider.current_whitelist()[0]) # type: ignore
return env_info
@abstractmethod @abstractmethod
def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs): def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
""" """
@ -390,8 +393,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int, def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
seed: int, train_df: DataFrame, price: DataFrame, seed: int, train_df: DataFrame, price: DataFrame,
reward_params: Dict[str, int], window_size: int, monitor: bool = False, monitor: bool = False,
config: Dict[str, Any] = {}, env_info: Dict[str, Any] = {}) -> Callable: env_info: Dict[str, Any] = {}) -> Callable:
""" """
Utility function for multiprocessed env. Utility function for multiprocessed env.
@ -404,9 +407,8 @@ def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
def _init() -> gym.Env: def _init() -> gym.Env:
env = MyRLEnv(df=train_df, prices=price, window_size=window_size, env = MyRLEnv(df=train_df, prices=price, id=env_id, seed=seed + rank,
reward_kwargs=reward_params, id=env_id, seed=seed + rank, **env_info)
config=config, env_info=env_info)
if monitor: if monitor:
env = Monitor(env) env = Monitor(env)
return env return env

View File

@ -5,7 +5,6 @@ from pandas import DataFrame
from stable_baselines3.common.callbacks import EvalCallback from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.vec_env import SubprocVecEnv from stable_baselines3.common.vec_env import SubprocVecEnv
from freqtrade.enums import RunMode
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
from freqtrade.freqai.RL.BaseReinforcementLearningModel import make_env from freqtrade.freqai.RL.BaseReinforcementLearningModel import make_env
@ -35,23 +34,20 @@ class ReinforcementLearner_multiproc(ReinforcementLearner):
train_df = data_dictionary["train_features"] train_df = data_dictionary["train_features"]
test_df = data_dictionary["test_features"] test_df = data_dictionary["test_features"]
env_info = {"live": False} env_info = self.pack_env_dict()
if self.data_provider:
env_info["live"] = self.data_provider.runmode in (RunMode.DRY_RUN, RunMode.LIVE)
env_info["fee"] = self.data_provider._exchange \
.get_fee(symbol=self.data_provider.current_whitelist()[0]) # type: ignore
env_id = "train_env" env_id = "train_env"
self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1, train_df, prices_train, self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1,
self.reward_params, self.CONV_WIDTH, monitor=True, train_df, prices_train,
config=self.config, env_info=env_info) for i monitor=True,
env_info=env_info) for i
in range(self.max_threads)]) in range(self.max_threads)])
eval_env_id = 'eval_env' eval_env_id = 'eval_env'
self.eval_env = SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1, self.eval_env = SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1,
test_df, prices_test, test_df, prices_test,
self.reward_params, self.CONV_WIDTH, monitor=True, monitor=True,
config=self.config, env_info=env_info) for i env_info=env_info) for i
in range(self.max_threads)]) in range(self.max_threads)])
self.eval_callback = EvalCallback(self.eval_env, deterministic=True, self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
render=False, eval_freq=len(train_df), render=False, eval_freq=len(train_df),