expose environment reward parameters to the user config
This commit is contained in:
@@ -110,10 +110,10 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
# environments
|
||||
if not self.train_env:
|
||||
self.train_env = MyRLEnv(df=train_df, prices=prices_train, window_size=self.CONV_WIDTH,
|
||||
reward_kwargs=self.reward_params)
|
||||
reward_kwargs=self.reward_params, config=self.config)
|
||||
self.eval_env = Monitor(MyRLEnv(df=test_df, prices=prices_test,
|
||||
window_size=self.CONV_WIDTH,
|
||||
reward_kwargs=self.reward_params), ".")
|
||||
reward_kwargs=self.reward_params, config=self.config), ".")
|
||||
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
|
||||
render=False, eval_freq=eval_freq,
|
||||
best_model_save_path=dk.data_path)
|
||||
@@ -239,7 +239,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
||||
|
||||
|
||||
def make_env(env_id: str, rank: int, seed: int, train_df, price,
|
||||
reward_params, window_size, monitor=False) -> Callable:
|
||||
reward_params, window_size, monitor=False, config={}) -> Callable:
|
||||
"""
|
||||
Utility function for multiprocessed env.
|
||||
|
||||
@@ -252,7 +252,7 @@ def make_env(env_id: str, rank: int, seed: int, train_df, price,
|
||||
def _init() -> gym.Env:
|
||||
|
||||
env = MyRLEnv(df=train_df, prices=price, window_size=window_size,
|
||||
reward_kwargs=reward_params, id=env_id, seed=seed + rank)
|
||||
reward_kwargs=reward_params, id=env_id, seed=seed + rank, config=config)
|
||||
if monitor:
|
||||
env = Monitor(env, ".")
|
||||
return env
|
||||
@@ -277,16 +277,16 @@ class MyRLEnv(Base5ActionRLEnv):
|
||||
current_price = self.add_exit_fee(self.prices.iloc[self._current_tick].open)
|
||||
factor = 1
|
||||
if self.close_trade_profit and self.close_trade_profit[-1] > self.profit_aim * self.rr:
|
||||
factor = 2
|
||||
factor = self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
|
||||
return float((np.log(current_price) - np.log(last_trade_price)) * factor)
|
||||
|
||||
# close short
|
||||
if action == Actions.Short_exit.value and self._position == Positions.Short:
|
||||
last_trade_price = self.add_entry_fee(self.prices.iloc[self._last_trade_tick].open)
|
||||
current_price = self.add_exit_fee(self.prices.iloc[self._current_tick].open)
|
||||
last_trade_price = self.add_exit_fee(self.prices.iloc[self._last_trade_tick].open)
|
||||
current_price = self.add_entry_fee(self.prices.iloc[self._current_tick].open)
|
||||
factor = 1
|
||||
if self.close_trade_profit and self.close_trade_profit[-1] > self.profit_aim * self.rr:
|
||||
factor = 2
|
||||
factor = self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
|
||||
return float(np.log(last_trade_price) - np.log(current_price) * factor)
|
||||
|
||||
return 0.
|
||||
|
||||
Reference in New Issue
Block a user