Merge pull request #7908 from freqtrade/add-3action-rl-env
Add 3 Action RL Env
This commit is contained in:
@@ -27,20 +27,23 @@ def is_mac() -> bool:
|
||||
return "Darwin" in machine
|
||||
|
||||
|
||||
@pytest.mark.parametrize('model, pca, dbscan, float32', [
|
||||
('LightGBMRegressor', True, False, True),
|
||||
('XGBoostRegressor', False, True, False),
|
||||
('XGBoostRFRegressor', False, False, False),
|
||||
('CatboostRegressor', False, False, False),
|
||||
('ReinforcementLearner', False, True, False),
|
||||
('ReinforcementLearner_multiproc', False, False, False),
|
||||
('ReinforcementLearner_test_4ac', False, False, False)
|
||||
@pytest.mark.parametrize('model, pca, dbscan, float32, can_short', [
|
||||
('LightGBMRegressor', True, False, True, True),
|
||||
('XGBoostRegressor', False, True, False, True),
|
||||
('XGBoostRFRegressor', False, False, False, True),
|
||||
('CatboostRegressor', False, False, False, True),
|
||||
('ReinforcementLearner', False, True, False, True),
|
||||
('ReinforcementLearner_multiproc', False, False, False, True),
|
||||
('ReinforcementLearner_test_3ac', False, False, False, False),
|
||||
('ReinforcementLearner_test_3ac', False, False, False, True),
|
||||
('ReinforcementLearner_test_4ac', False, False, False, True)
|
||||
])
|
||||
def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca, dbscan, float32):
|
||||
def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca,
|
||||
dbscan, float32, can_short):
|
||||
if is_arm() and model == 'CatboostRegressor':
|
||||
pytest.skip("CatBoost is not supported on ARM")
|
||||
|
||||
if is_mac() and 'Reinforcement' in model:
|
||||
if is_mac() and not is_arm() and 'Reinforcement' in model:
|
||||
pytest.skip("Reinforcement learning module not available on intel based Mac OS")
|
||||
|
||||
model_save_ext = 'joblib'
|
||||
@@ -58,9 +61,6 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca,
|
||||
freqai_conf['freqai']['feature_parameters'].update({"use_SVM_to_remove_outliers": True})
|
||||
freqai_conf['freqai']['data_split_parameters'].update({'shuffle': True})
|
||||
|
||||
if 'test_4ac' in model:
|
||||
freqai_conf["freqaimodel_path"] = str(Path(__file__).parents[1] / "freqai" / "test_models")
|
||||
|
||||
if 'ReinforcementLearner' in model:
|
||||
model_save_ext = 'zip'
|
||||
freqai_conf = make_rl_config(freqai_conf)
|
||||
@@ -68,7 +68,7 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca,
|
||||
freqai_conf['freqai']['feature_parameters'].update({"use_SVM_to_remove_outliers": True})
|
||||
freqai_conf['freqai']['data_split_parameters'].update({'shuffle': True})
|
||||
|
||||
if 'test_4ac' in model:
|
||||
if 'test_3ac' in model or 'test_4ac' in model:
|
||||
freqai_conf["freqaimodel_path"] = str(Path(__file__).parents[1] / "freqai" / "test_models")
|
||||
|
||||
strategy = get_patched_freqai_strategy(mocker, freqai_conf)
|
||||
@@ -77,6 +77,7 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca,
|
||||
strategy.freqai_info = freqai_conf.get("freqai", {})
|
||||
freqai = strategy.freqai
|
||||
freqai.live = True
|
||||
freqai.can_short = can_short
|
||||
freqai.dk = FreqaiDataKitchen(freqai_conf)
|
||||
freqai.dk.set_paths('ADA/BTC', 10000)
|
||||
timerange = TimeRange.parse_timerange("20180110-20180130")
|
||||
|
65
tests/freqai/test_models/ReinforcementLearner_test_3ac.py
Normal file
65
tests/freqai/test_models/ReinforcementLearner_test_3ac.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
|
||||
from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
|
||||
from freqtrade.freqai.RL.Base3ActionRLEnv import Actions, Base3ActionRLEnv, Positions
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReinforcementLearner_test_3ac(ReinforcementLearner):
|
||||
"""
|
||||
User created Reinforcement Learning Model prediction model.
|
||||
"""
|
||||
|
||||
class MyRLEnv(Base3ActionRLEnv):
|
||||
"""
|
||||
User can override any function in BaseRLEnv and gym.Env. Here the user
|
||||
sets a custom reward based on profit and trade duration.
|
||||
"""
|
||||
|
||||
def calculate_reward(self, action: int) -> float:
|
||||
|
||||
# first, penalize if the action is not valid
|
||||
if not self._is_valid(action):
|
||||
return -2
|
||||
|
||||
pnl = self.get_unrealized_profit()
|
||||
rew = np.sign(pnl) * (pnl + 1)
|
||||
factor = 100.
|
||||
|
||||
# reward agent for entering trades
|
||||
if (action in (Actions.Buy.value, Actions.Sell.value)
|
||||
and self._position == Positions.Neutral):
|
||||
return 25
|
||||
# discourage agent from not entering trades
|
||||
if action == Actions.Neutral.value and self._position == Positions.Neutral:
|
||||
return -1
|
||||
|
||||
max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
|
||||
trade_duration = self._current_tick - self._last_trade_tick # type: ignore
|
||||
|
||||
if trade_duration <= max_trade_duration:
|
||||
factor *= 1.5
|
||||
elif trade_duration > max_trade_duration:
|
||||
factor *= 0.5
|
||||
|
||||
# discourage sitting in position
|
||||
if self._position in (Positions.Short, Positions.Long) and (
|
||||
action == Actions.Neutral.value
|
||||
or (action == Actions.Sell.value and self._position == Positions.Short)
|
||||
or (action == Actions.Buy.value and self._position == Positions.Long)
|
||||
):
|
||||
return -1 * trade_duration / max_trade_duration
|
||||
|
||||
# close position
|
||||
if (action == Actions.Buy.value and self._position == Positions.Short) or (
|
||||
action == Actions.Sell.value and self._position == Positions.Long
|
||||
):
|
||||
if pnl > self.profit_aim * self.rr:
|
||||
factor *= self.rl_config["model_reward_parameters"].get("win_reward_factor", 2)
|
||||
return float(rew * factor)
|
||||
|
||||
return 0.
|
Reference in New Issue
Block a user