Merge pull request #7908 from freqtrade/add-3action-rl-env

Add 3 Action RL Env
Robert Caulk 2022-12-19 14:47:57 +01:00 committed by GitHub
commit cc30210b3f
7 changed files with 215 additions and 19 deletions


@@ -275,12 +275,12 @@ FreqAI also provides a built in episodic summary logger called `self.tensorboard
### Choosing a base environment

FreqAI provides two base environments, `Base4ActionEnvironment` and `Base5ActionEnvironment`. As the names imply, the environments are customized for agents that can select from 4 or 5 actions. In the `Base4ActionEnvironment`, the agent can enter long, enter short, hold neutral, or exit position. Meanwhile, in the `Base5ActionEnvironment`, the agent has the same actions as Base4, but instead of a single exit action, it separates exit long and exit short. The main changes stemming from the environment selection include:
FreqAI provides three base environments, `Base3ActionRLEnvironment`, `Base4ActionEnvironment` and `Base5ActionEnvironment`. As the names imply, the environments are customized for agents that can select from 3, 4 or 5 actions. The `Base3ActionRLEnvironment` is the simplest; the agent can select from hold, long, or short. This environment can also be used for long-only bots (it automatically follows the `can_short` flag from the strategy), where long is the enter condition and short is the exit condition. Meanwhile, in the `Base4ActionEnvironment`, the agent can enter long, enter short, hold neutral, or exit position. Finally, in the `Base5ActionEnvironment`, the agent has the same actions as Base4, but instead of a single exit action, it separates exit long and exit short. The main changes stemming from the environment selection include:

* the actions available in the `calculate_reward`
* the actions consumed by the user strategy

Both of the FreqAI provided environments inherit from an action/position agnostic environment object called the `BaseEnvironment`, which contains all shared logic. The architecture is designed to be easily customized. The simplest customization is the `calculate_reward()` (see details [here](#creating-a-custom-reward-function)). However, the customizations can be further extended into any of the functions inside the environment. You can do this by simply overriding those functions inside your `MyRLEnv` in the prediction model file. Or for more advanced customizations, it is encouraged to create an entirely new environment inherited from `BaseEnvironment`.
All of the FreqAI provided environments inherit from an action/position agnostic environment object called the `BaseEnvironment`, which contains all shared logic. The architecture is designed to be easily customized. The simplest customization is the `calculate_reward()` (see details [here](#creating-a-custom-reward-function)). However, the customizations can be further extended into any of the functions inside the environment. You can do this by simply overriding those functions inside your `MyRLEnv` in the prediction model file. Or for more advanced customizations, it is encouraged to create an entirely new environment inherited from `BaseEnvironment`.

!!! Note
    FreqAI does not provide by default, a long-only training environment. However, creating one should be as simple as copy-pasting one of the built in environments and removing the `short` actions (and all associated references to those).
    Only the `Base3ActionRLEnv` can do long-only training/trading (set the user strategy attribute `can_short = False`).
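
For illustration, a minimal long-only sketch (not part of this commit): the strategy sets `can_short = False` and maps the three action values (`Neutral = 0`, `Buy = 1`, `Sell = 2`) onto entry/exit signals. The class name and the omitted strategy boilerplate are hypothetical, and the `do_predict`/`&-action` columns are the ones used in the FreqAI RL example strategies.

```python
from pandas import DataFrame

from freqtrade.strategy import IStrategy


class LongOnlyRLStrategyExample(IStrategy):  # hypothetical name, fragment only
    # Picked up by FreqAI (strategy.can_short) and forwarded to Base3ActionRLEnv,
    # so the agent is trained without short entries.
    can_short = False

    def populate_entry_trend(self, df: DataFrame, metadata: dict) -> DataFrame:
        # Actions.Buy (1) is the enter-long condition
        df.loc[(df["do_predict"] == 1) & (df["&-action"] == 1), "enter_long"] = 1
        return df

    def populate_exit_trend(self, df: DataFrame, metadata: dict) -> DataFrame:
        # Actions.Sell (2) doubles as the exit condition when can_short is False
        df.loc[(df["do_predict"] == 1) & (df["&-action"] == 2), "exit_long"] = 1
        return df
```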


@@ -0,0 +1,125 @@
import logging
from enum import Enum

from gym import spaces

from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions


logger = logging.getLogger(__name__)


class Actions(Enum):
    Neutral = 0
    Buy = 1
    Sell = 2


class Base3ActionRLEnv(BaseEnvironment):
    """
    Base class for a 3 action environment
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.actions = Actions

    def set_action_space(self):
        self.action_space = spaces.Discrete(len(Actions))

    def step(self, action: int):
        """
        Logic for a single step (incrementing one candle in time)
        by the agent
        :param: action: int = the action type that the agent plans
            to take for the current step.
        :returns:
            observation = current state of environment
            step_reward = the reward from `calculate_reward()`
            _done = if the agent "died" or if the candles finished
            info = dict passed back to openai gym lib
        """
        self._done = False
        self._current_tick += 1

        if self._current_tick == self._end_tick:
            self._done = True

        self._update_unrealized_total_profit()
        step_reward = self.calculate_reward(action)
        self.total_reward += step_reward
        self.tensorboard_log(self.actions._member_names_[action])

        trade_type = None
        if self.is_tradesignal(action):
            if action == Actions.Buy.value:
                if self._position == Positions.Short:
                    self._update_total_profit()
                self._position = Positions.Long
                trade_type = "long"
                self._last_trade_tick = self._current_tick
            elif action == Actions.Sell.value and self.can_short:
                if self._position == Positions.Long:
                    self._update_total_profit()
                self._position = Positions.Short
                trade_type = "short"
                self._last_trade_tick = self._current_tick
            elif action == Actions.Sell.value and not self.can_short:
                self._update_total_profit()
                self._position = Positions.Neutral
                trade_type = "neutral"
                self._last_trade_tick = None
            else:
                print("case not defined")

            if trade_type is not None:
                self.trade_history.append(
                    {'price': self.current_price(), 'index': self._current_tick,
                     'type': trade_type})

        if (self._total_profit < self.max_drawdown or
                self._total_unrealized_profit < self.max_drawdown):
            self._done = True

        self._position_history.append(self._position)

        info = dict(
            tick=self._current_tick,
            action=action,
            total_reward=self.total_reward,
            total_profit=self._total_profit,
            position=self._position.value,
            trade_duration=self.get_trade_duration(),
            current_profit_pct=self.get_unrealized_profit()
        )

        observation = self._get_observation()

        self._update_history(info)

        return observation, step_reward, self._done, info

    def is_tradesignal(self, action: int) -> bool:
        """
        Determine if the signal is a trade signal
        e.g.: agent wants an Actions.Buy while it is in a Positions.Short
        """
        return (
            (action == Actions.Buy.value and self._position == Positions.Neutral)
            or (action == Actions.Sell.value and self._position == Positions.Long)
            or (action == Actions.Sell.value and self._position == Positions.Neutral
                and self.can_short)
            or (action == Actions.Buy.value and self._position == Positions.Short
                and self.can_short)
        )

    def _is_valid(self, action: int) -> bool:
        """
        Determine if the signal is valid.
        e.g.: agent wants an Actions.Sell while it is in a Positions.Long
        """
        if self.can_short:
            return action in [Actions.Buy.value, Actions.Sell.value, Actions.Neutral.value]
        else:
            if action == Actions.Sell.value and self._position != Positions.Long:
                return False
            return True
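
As an aside, a minimal illustrative check (not part of this commit) of how `can_short` changes what the environment accepts: the `_PositionStub` helper is hypothetical scaffolding, and the unbound-method calls exist purely to exercise `is_tradesignal()`/`_is_valid()` without constructing a full environment.

```python
from freqtrade.freqai.RL.Base3ActionRLEnv import Actions, Base3ActionRLEnv, Positions


class _PositionStub:
    """Hypothetical stand-in carrying only the attributes the two checks read."""
    can_short = False
    _position = Positions.Neutral


stub = _PositionStub()

# With can_short=False and no open position, Sell is not a valid action ...
assert not Base3ActionRLEnv._is_valid(stub, Actions.Sell.value)
# ... while Buy is a trade signal (enter long).
assert Base3ActionRLEnv.is_tradesignal(stub, Actions.Buy.value)

# Once a long position is open, Sell becomes the exit signal.
stub._position = Positions.Long
assert Base3ActionRLEnv._is_valid(stub, Actions.Sell.value)
assert Base3ActionRLEnv.is_tradesignal(stub, Actions.Sell.value)
```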


@@ -45,7 +45,7 @@ class BaseEnvironment(gym.Env):
    def __init__(self, df: DataFrame = DataFrame(), prices: DataFrame = DataFrame(),
                 reward_kwargs: dict = {}, window_size=10, starting_point=True,
                 id: str = 'baseenv-1', seed: int = 1, config: dict = {}, live: bool = False,
                 fee: float = 0.0015):
                 fee: float = 0.0015, can_short: bool = False):
        """
        Initializes the training/eval environment.
        :param df: dataframe of features
@@ -58,6 +58,7 @@ class BaseEnvironment(gym.Env):
        :param config: Typical user configuration file
        :param live: Whether or not this environment is active in dry/live/backtesting
        :param fee: The fee to use for environmental interactions.
        :param can_short: Whether or not the environment can short
        """
        self.config = config
        self.rl_config = config['freqai']['rl_config']
@@ -73,6 +74,7 @@ class BaseEnvironment(gym.Env):
        # set here to default 5Ac, but all children envs can override this
        self.actions: Type[Enum] = BaseActions
        self.tensorboard_metrics: dict = {}
        self.can_short = can_short
        self.live = live
        if not self.live and self.add_state_info:
            self.add_state_info = False


@@ -165,7 +165,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
        env_info = {"window_size": self.CONV_WIDTH,
                    "reward_kwargs": self.reward_params,
                    "config": self.config,
                    "live": self.live}
                    "live": self.live,
                    "can_short": self.can_short}
        if self.data_provider:
            env_info["fee"] = self.data_provider._exchange \
                .get_fee(symbol=self.data_provider.current_whitelist()[0])  # type: ignore


@@ -104,6 +104,7 @@ class IFreqaiModel(ABC):
        self.metadata: Dict[str, Any] = self.dd.load_global_metadata_from_disk()
        self.data_provider: Optional[DataProvider] = None
        self.max_system_threads = max(int(psutil.cpu_count() * 2 - 2), 1)
        self.can_short = True  # overridden in start() with strategy.can_short

        record_params(config, self.full_path)
@@ -133,6 +134,7 @@
        self.live = strategy.dp.runmode in (RunMode.DRY_RUN, RunMode.LIVE)
        self.dd.set_pair_dict_info(metadata)
        self.data_provider = strategy.dp
        self.can_short = strategy.can_short

        if self.live:
            self.inference_timer('start')


@@ -27,20 +27,23 @@ def is_mac() -> bool:
    return "Darwin" in machine


@pytest.mark.parametrize('model, pca, dbscan, float32', [
    ('LightGBMRegressor', True, False, True),
    ('XGBoostRegressor', False, True, False),
    ('XGBoostRFRegressor', False, False, False),
    ('CatboostRegressor', False, False, False),
    ('ReinforcementLearner', False, True, False),
    ('ReinforcementLearner_multiproc', False, False, False),
    ('ReinforcementLearner_test_4ac', False, False, False)
@pytest.mark.parametrize('model, pca, dbscan, float32, can_short', [
    ('LightGBMRegressor', True, False, True, True),
    ('XGBoostRegressor', False, True, False, True),
    ('XGBoostRFRegressor', False, False, False, True),
    ('CatboostRegressor', False, False, False, True),
    ('ReinforcementLearner', False, True, False, True),
    ('ReinforcementLearner_multiproc', False, False, False, True),
    ('ReinforcementLearner_test_3ac', False, False, False, False),
    ('ReinforcementLearner_test_3ac', False, False, False, True),
    ('ReinforcementLearner_test_4ac', False, False, False, True)
    ])
def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca, dbscan, float32):
def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca,
                                               dbscan, float32, can_short):
    if is_arm() and model == 'CatboostRegressor':
        pytest.skip("CatBoost is not supported on ARM")

    if is_mac() and 'Reinforcement' in model:
    if is_mac() and not is_arm() and 'Reinforcement' in model:
        pytest.skip("Reinforcement learning module not available on intel based Mac OS")

    model_save_ext = 'joblib'
@@ -58,9 +61,6 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca,
        freqai_conf['freqai']['feature_parameters'].update({"use_SVM_to_remove_outliers": True})
        freqai_conf['freqai']['data_split_parameters'].update({'shuffle': True})

    if 'test_4ac' in model:
        freqai_conf["freqaimodel_path"] = str(Path(__file__).parents[1] / "freqai" / "test_models")

    if 'ReinforcementLearner' in model:
        model_save_ext = 'zip'
        freqai_conf = make_rl_config(freqai_conf)
@@ -68,7 +68,7 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca,
        freqai_conf['freqai']['feature_parameters'].update({"use_SVM_to_remove_outliers": True})
        freqai_conf['freqai']['data_split_parameters'].update({'shuffle': True})

    if 'test_4ac' in model:
    if 'test_3ac' in model or 'test_4ac' in model:
        freqai_conf["freqaimodel_path"] = str(Path(__file__).parents[1] / "freqai" / "test_models")

    strategy = get_patched_freqai_strategy(mocker, freqai_conf)
@@ -77,6 +77,7 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca,
    strategy.freqai_info = freqai_conf.get("freqai", {})
    freqai = strategy.freqai
    freqai.live = True
    freqai.can_short = can_short
    freqai.dk = FreqaiDataKitchen(freqai_conf)
    freqai.dk.set_paths('ADA/BTC', 10000)
    timerange = TimeRange.parse_timerange("20180110-20180130")


@@ -0,0 +1,65 @@
import logging

import numpy as np

from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
from freqtrade.freqai.RL.Base3ActionRLEnv import Actions, Base3ActionRLEnv, Positions


logger = logging.getLogger(__name__)


class ReinforcementLearner_test_3ac(ReinforcementLearner):
    """
    User created Reinforcement Learning Model prediction model.
    """

    class MyRLEnv(Base3ActionRLEnv):
        """
        User can override any function in BaseRLEnv and gym.Env. Here the user
        sets a custom reward based on profit and trade duration.
        """

        def calculate_reward(self, action: int) -> float:

            # first, penalize if the action is not valid
            if not self._is_valid(action):
                return -2

            pnl = self.get_unrealized_profit()
            rew = np.sign(pnl) * (pnl + 1)
            factor = 100.

            # reward agent for entering trades
            if (action in (Actions.Buy.value, Actions.Sell.value)
                    and self._position == Positions.Neutral):
                return 25
            # discourage agent from not entering trades
            if action == Actions.Neutral.value and self._position == Positions.Neutral:
                return -1

            max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
            trade_duration = self._current_tick - self._last_trade_tick  # type: ignore

            if trade_duration <= max_trade_duration:
                factor *= 1.5
            elif trade_duration > max_trade_duration:
                factor *= 0.5

            # discourage sitting in position
            if self._position in (Positions.Short, Positions.Long) and (
                action == Actions.Neutral.value
                or (action == Actions.Sell.value and self._position == Positions.Short)
                or (action == Actions.Buy.value and self._position == Positions.Long)
            ):
                return -1 * trade_duration / max_trade_duration

            # close position
            if (action == Actions.Buy.value and self._position == Positions.Short) or (
                action == Actions.Sell.value and self._position == Positions.Long
            ):
                if pnl > self.profit_aim * self.rr:
                    factor *= self.rl_config["model_reward_parameters"].get("win_reward_factor", 2)
                return float(rew * factor)

            return 0.
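
For reference, a sketch (not part of this commit) of the `rl_config` keys this reward function reads; the values shown are illustrative, with `rr` and `profit_aim` reaching the environment via the `reward_kwargs` passed in `env_info` above.

```python
# Illustrative only: keys consumed by the calculate_reward() above.
rl_config = {
    "max_trade_duration_candles": 300,   # cut-off for the trade-duration factor
    "model_reward_parameters": {
        "rr": 1,                         # becomes self.rr in the environment
        "profit_aim": 0.025,             # becomes self.profit_aim
        "win_reward_factor": 2,          # boosts rewards when pnl > profit_aim * rr
    },
}
```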