Add 3ac test
This commit is contained in:
parent
7727f31507
commit
a8c9aa01fb
@ -34,6 +34,7 @@ def is_mac() -> bool:
|
|||||||
('CatboostRegressor', False, False, False),
|
('CatboostRegressor', False, False, False),
|
||||||
('ReinforcementLearner', False, True, False),
|
('ReinforcementLearner', False, True, False),
|
||||||
('ReinforcementLearner_multiproc', False, False, False),
|
('ReinforcementLearner_multiproc', False, False, False),
|
||||||
|
('ReinforcementLearner_test_3ac', False, False, False),
|
||||||
('ReinforcementLearner_test_4ac', False, False, False)
|
('ReinforcementLearner_test_4ac', False, False, False)
|
||||||
])
|
])
|
||||||
def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca, dbscan, float32):
|
def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca, dbscan, float32):
|
||||||
@ -58,7 +59,7 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca,
|
|||||||
freqai_conf['freqai']['feature_parameters'].update({"use_SVM_to_remove_outliers": True})
|
freqai_conf['freqai']['feature_parameters'].update({"use_SVM_to_remove_outliers": True})
|
||||||
freqai_conf['freqai']['data_split_parameters'].update({'shuffle': True})
|
freqai_conf['freqai']['data_split_parameters'].update({'shuffle': True})
|
||||||
|
|
||||||
if 'test_4ac' in model:
|
if 'test_3ac' in model or 'test_4ac' in model:
|
||||||
freqai_conf["freqaimodel_path"] = str(Path(__file__).parents[1] / "freqai" / "test_models")
|
freqai_conf["freqaimodel_path"] = str(Path(__file__).parents[1] / "freqai" / "test_models")
|
||||||
|
|
||||||
if 'ReinforcementLearner' in model:
|
if 'ReinforcementLearner' in model:
|
||||||
@ -68,7 +69,7 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca,
|
|||||||
freqai_conf['freqai']['feature_parameters'].update({"use_SVM_to_remove_outliers": True})
|
freqai_conf['freqai']['feature_parameters'].update({"use_SVM_to_remove_outliers": True})
|
||||||
freqai_conf['freqai']['data_split_parameters'].update({'shuffle': True})
|
freqai_conf['freqai']['data_split_parameters'].update({'shuffle': True})
|
||||||
|
|
||||||
if 'test_4ac' in model:
|
if 'test_3ac' in model or 'test_4ac' in model:
|
||||||
freqai_conf["freqaimodel_path"] = str(Path(__file__).parents[1] / "freqai" / "test_models")
|
freqai_conf["freqaimodel_path"] = str(Path(__file__).parents[1] / "freqai" / "test_models")
|
||||||
|
|
||||||
strategy = get_patched_freqai_strategy(mocker, freqai_conf)
|
strategy = get_patched_freqai_strategy(mocker, freqai_conf)
|
||||||
|
65
tests/freqai/test_models/ReinforcementLearner_test_3ac.py
Normal file
65
tests/freqai/test_models/ReinforcementLearner_test_3ac.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
|
||||||
|
from freqtrade.freqai.RL.Base3ActionRLEnv import Actions, Base3ActionRLEnv, Positions
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ReinforcementLearner_test_3ac(ReinforcementLearner):
|
||||||
|
"""
|
||||||
|
User created Reinforcement Learning Model prediction model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class MyRLEnv(Base3ActionRLEnv):
|
||||||
|
"""
|
||||||
|
User can override any function in BaseRLEnv and gym.Env. Here the user
|
||||||
|
sets a custom reward based on profit and trade duration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def calculate_reward(self, action: int) -> float:
|
||||||
|
|
||||||
|
# first, penalize if the action is not valid
|
||||||
|
if not self._is_valid(action):
|
||||||
|
return -2
|
||||||
|
|
||||||
|
pnl = self.get_unrealized_profit()
|
||||||
|
rew = np.sign(pnl) * (pnl + 1)
|
||||||
|
factor = 100.
|
||||||
|
|
||||||
|
# reward agent for entering trades
|
||||||
|
if (action in (Actions.Buy.value, Actions.Sell.value)
|
||||||
|
and self._position == Positions.Neutral):
|
||||||
|
return 25
|
||||||
|
# discourage agent from not entering trades
|
||||||
|
if action == Actions.Neutral.value and self._position == Positions.Neutral:
|
||||||
|
return -1
|
||||||
|
|
||||||
|
max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
|
||||||
|
trade_duration = self._current_tick - self._last_trade_tick # type: ignore
|
||||||
|
|
||||||
|
if trade_duration <= max_trade_duration:
|
||||||
|
factor *= 1.5
|
||||||
|
elif trade_duration > max_trade_duration:
|
||||||
|
factor *= 0.5
|
||||||
|
|
||||||
|
# discourage sitting in position
|
||||||
|
if self._position in (Positions.Short, Positions.Long) and (
|
||||||
|
action == Actions.Neutral.value
|
||||||
|
or (action == Actions.Sell.value and self._position == Positions.Short)
|
||||||
|
or (action == Actions.Buy.value and self._position == Positions.Long)
|
||||||
|
):
|
||||||
|
return -1 * trade_duration / max_trade_duration
|
||||||
|
|
||||||
|
# close position
|
||||||
|
if (action == Actions.Buy.value and self._position == Positions.Short) or (
|
||||||
|
action == Actions.Sell.value and self._position == Positions.Long
|
||||||
|
):
|
||||||
|
if pnl > self.profit_aim * self.rr:
|
||||||
|
factor *= self.rl_config["model_reward_parameters"].get("win_reward_factor", 2)
|
||||||
|
return float(rew * factor)
|
||||||
|
|
||||||
|
return 0.
|
Loading…
Reference in New Issue
Block a user