add state/action info to callbacks
commit 075c8c23c8 (parent 0be82b4ed1)
@@ -71,7 +71,7 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
 
         model.learn(
             total_timesteps=int(total_timesteps),
-            callback=self.eval_callback
+            callback=[self.eval_callback, self.tensorboard_callback]
         )
 
         if Path(dk.data_path / "best_model.zip").is_file():
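The callback list now passed to model.learn() includes self.tensorboard_callback, which is not defined in this hunk. As a rough sketch only (assuming a stable-baselines3 BaseCallback subclass created elsewhere in fit(); the class name and the logged keys are placeholders, not the project's actual implementation), such a callback could mirror the environment's custom_info counters into TensorBoard:

from stable_baselines3.common.callbacks import BaseCallback


class TensorboardCallback(BaseCallback):
    """Log the environment's custom_info counters on every training step."""

    def _on_step(self) -> bool:
        # training_env is a VecEnv; get_attr returns one custom_info dict per sub-env
        for env_idx, custom_info in enumerate(self.training_env.get_attr("custom_info")):
            for key, value in custom_info.items():
                self.logger.record(f"custom_info_{env_idx}/{key}", value)
        return True

It would be instantiated once before training, e.g. self.tensorboard_callback = TensorboardCallback(), so the same instance can be handed to model.learn() as shown above.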
@@ -88,6 +88,33 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
         User can override any function in BaseRLEnv and gym.Env. Here the user
         sets a custom reward based on profit and trade duration.
         """
 
+        def reset(self):
+            # Reset custom info
+            self.custom_info = {}
+            self.custom_info["Invalid"] = 0
+            self.custom_info["Hold"] = 0
+            self.custom_info["Unknown"] = 0
+            self.custom_info["pnl_factor"] = 0
+            self.custom_info["duration_factor"] = 0
+            self.custom_info["reward_exit"] = 0
+            self.custom_info["reward_hold"] = 0
+            for action in Actions:
+                self.custom_info[f"{action.name}"] = 0
+            return super().reset()
+
+        def step(self, action: int):
+            observation, step_reward, done, info = super().step(action)
+            info = dict(
+                tick=self._current_tick,
+                action=action,
+                total_reward=self.total_reward,
+                total_profit=self._total_profit,
+                position=self._position.value,
+                trade_duration=self.get_trade_duration(),
+                current_profit_pct=self.get_unrealized_profit()
+            )
+            return observation, step_reward, done, info
+
         def calculate_reward(self, action: int) -> float:
             """
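With the new reset() and step() overrides, every environment step now returns a fully populated info dict. A quick way to eyeball that payload outside of a training run is to drive the environment by hand; this is only a sketch and assumes `env` is an already-constructed MyRLEnv instance (building one requires the usual FreqAI dataframe plumbing, omitted here):

# Hypothetical smoke test; `env` is assumed to be a constructed MyRLEnv.
obs = env.reset()
done = False
while not done:
    action = env.action_space.sample()  # random policy, just to exercise step()
    obs, reward, done, info = env.step(action)

print(info)             # tick, action, total_reward, total_profit, position, ...
print(env.custom_info)  # per-action counters accumulated by calculate_reward()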
@@ -100,17 +127,24 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
             """
             # first, penalize if the action is not valid
             if not self._is_valid(action):
+                self.custom_info["Invalid"] += 1
                 return -2
 
             pnl = self.get_unrealized_profit()
             factor = 100.
 
             # reward agent for entering trades
-            if (action in (Actions.Long_enter.value, Actions.Short_enter.value)
+            if (action == Actions.Long_enter.value
                     and self._position == Positions.Neutral):
+                self.custom_info[f"{Actions.Long_enter.name}"] += 1
+                return 25
+            if (action == Actions.Short_enter.value
+                    and self._position == Positions.Neutral):
+                self.custom_info[f"{Actions.Short_enter.name}"] += 1
                 return 25
             # discourage agent from not entering trades
             if action == Actions.Neutral.value and self._position == Positions.Neutral:
+                self.custom_info[f"{Actions.Neutral.name}"] += 1
                 return -1
 
             max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
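The reward shaping above reads max_trade_duration_candles, win_reward_factor, profit_aim and rr from the rl_config section of the FreqAI configuration. For orientation only, the slice of configuration these lookups assume might look like the following (values are illustrative, and mapping profit_aim/rr onto self.profit_aim/self.rr is an assumption based on the attribute names, not something shown in this diff):

# Illustrative rl_config slice; keys mirror the .get() calls in calculate_reward.
rl_config = {
    "max_trade_duration_candles": 300,   # horizon for the duration penalty
    "model_reward_parameters": {
        "rr": 1,                         # assumed source of self.rr
        "profit_aim": 0.025,             # assumed source of self.profit_aim
        "win_reward_factor": 2,          # boosts factor when pnl > profit_aim * rr
    },
}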
@@ -124,18 +158,22 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
             # discourage sitting in position
             if (self._position in (Positions.Short, Positions.Long) and
                     action == Actions.Neutral.value):
+                self.custom_info["Hold"] += 1
                 return -1 * trade_duration / max_trade_duration
 
             # close long
             if action == Actions.Long_exit.value and self._position == Positions.Long:
                 if pnl > self.profit_aim * self.rr:
                     factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
+                self.custom_info[f"{Actions.Long_exit.name}"] += 1
                 return float(pnl * factor)
 
             # close short
             if action == Actions.Short_exit.value and self._position == Positions.Short:
                 if pnl > self.profit_aim * self.rr:
                     factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
+                self.custom_info[f"{Actions.Short_exit.name}"] += 1
                 return float(pnl * factor)
 
+            self.custom_info["Unknown"] += 1
             return 0.
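Because reset() clears custom_info at the start of every episode, the counters incremented throughout calculate_reward() only cover the most recent episode. If they are needed after training without the TensorBoard callback, one option (a sketch, assuming the attribute survives whatever VecEnv wrapping FreqAI applies) is to read them back through the model's environment:

# Hypothetical post-training inspection of the last episode's counters.
counters = model.get_env().get_attr("custom_info")[0]
print(counters)  # e.g. {"Invalid": 3, "Hold": 120, "Long_enter": 12, ...}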