stable/tests/freqai/test_models/ReinforcementLearner_test_4ac.py

67 lines
2.4 KiB
Python
Raw Normal View History

2022-09-23 07:19:16 +00:00
import logging
import numpy as np
from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
2022-11-24 17:57:01 +00:00
from freqtrade.freqai.RL.Base4ActionRLEnv import Actions, Base4ActionRLEnv, Positions
2022-09-23 07:19:16 +00:00
# Module-level logger named after this module; not referenced by the code
# visible in this file.
logger = logging.getLogger(__name__)
class ReinforcementLearner_test_4ac(ReinforcementLearner):
    """
    User created Reinforcement Learning Model prediction model.

    Test fixture: a ``ReinforcementLearner`` whose environment is built on the
    4-action ``Base4ActionRLEnv`` with a custom reward function.
    """

    class MyRLEnv(Base4ActionRLEnv):
        """
        User can override any function in BaseRLEnv and gym.Env. Here the user
        sets a custom reward based on profit and trade duration.
        """

        def calculate_reward(self, action: int) -> float:
            """
            Compute the reward for taking ``action`` in the current state.

            Reward schedule:
              * invalid action (``self._is_valid`` is False)        -> -2.0
              * Long_enter / Short_enter while flat (Neutral)       -> 25.0
              * Neutral action while flat                           -> -1.0
              * Neutral action while in a long/short position       ->
                ``-trade_duration / max_trade_duration``
              * Exit while in a long/short position                 ->
                ``sign(pnl) * (pnl + 1) * factor``, where ``factor`` is
                100 * 1.5 if the trade lasted at most
                ``max_trade_duration_candles`` (default 300) ticks, else
                100 * 0.5, further multiplied by ``win_reward_factor``
                (default 2) when ``pnl > profit_aim * rr``
              * anything else                                        -> 0.0

            :param action: integer value of an ``Actions`` member.
            :return: the reward for this step.
            """
            # first, penalize if the action is not valid
            if not self._is_valid(action):
                return -2.

            pnl = self.get_unrealized_profit()

            # reward agent for entering trades
            if (action in (Actions.Long_enter.value, Actions.Short_enter.value)
                    and self._position == Positions.Neutral):
                return 25.
            # discourage agent from not entering trades
            if action == Actions.Neutral.value and self._position == Positions.Neutral:
                return -1.

            max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
            trade_duration = self._current_tick - self._last_trade_tick  # type: ignore

            in_position = self._position in (Positions.Short, Positions.Long)

            # discourage sitting in position
            if in_position and action == Actions.Neutral.value:
                return -1 * trade_duration / max_trade_duration

            # close long or short: both sides share the same reward shaping
            if in_position and action == Actions.Exit.value:
                # NOTE(review): for pnl < 0 this equals -(1 + pnl), so larger
                # losses are penalised *less* (and pnl < -1 would even be
                # rewarded). Kept as-is: it is this fixture's reward shaping.
                rew = np.sign(pnl) * (pnl + 1)
                # favour trades closed within the duration limit
                factor = 100. * (1.5 if trade_duration <= max_trade_duration else 0.5)
                if pnl > self.profit_aim * self.rr:
                    # fall back to the default instead of raising KeyError when
                    # rl_config has no 'model_reward_parameters' section
                    reward_params = self.rl_config.get('model_reward_parameters', {})
                    factor *= reward_params.get('win_reward_factor', 2)
                return float(rew * factor)

            return 0.