Add test coverage, fix bug in base environment, ensure proper fee is used.
commit af9e400562
parent 81f800a79b
@@ -148,7 +148,6 @@ class Base5ActionRLEnv(BaseEnvironment):
         return self._current_tick - self._last_trade_tick

     def is_tradesignal(self, action: int):
-        # trade signal
         """
         Determine if the signal is a trade signal
         e.g.: agent wants a Actions.Long_exit while it is in a Positions.short
@@ -10,6 +10,8 @@ from gym import spaces
 from gym.utils import seeding
 from pandas import DataFrame

+from freqtrade.data.dataprovider import DataProvider
+

 logger = logging.getLogger(__name__)

@@ -32,8 +34,21 @@ class BaseEnvironment(gym.Env):

     def __init__(self, df: DataFrame = DataFrame(), prices: DataFrame = DataFrame(),
                  reward_kwargs: dict = {}, window_size=10, starting_point=True,
-                 id: str = 'baseenv-1', seed: int = 1, config: dict = {}):
+                 id: str = 'baseenv-1', seed: int = 1, config: dict = {},
+                 dp: Optional[DataProvider] = None):
+        """
+        Initializes the training/eval environment.
+        :param df: dataframe of features
+        :param prices: dataframe of prices to be used in the training environment
+        :param window_size: size of window (temporal) to pass to the agent
+        :param reward_kwargs: extra config settings assigned by user in `rl_config`
+        :param starting_point: start at edge of window or not
+        :param id: string id of the environment (used in backend for multiprocessed env)
+        :param seed: Sets the seed of the environment higher in the gym.Env object
+        :param config: Typical user configuration file
+        :param dp: dataprovider from freqtrade
+        """
+        self.config = config
         self.rl_config = config['freqai']['rl_config']
         self.add_state_info = self.rl_config.get('add_state_info', False)
         self.id = id
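For orientation, a minimal sketch of how the widened constructor can be called once the `dp` parameter exists (illustrative only; the feature/price dataframes, user config and strategy object below are assumed and are not part of the commit):

    # Hypothetical caller -- Base5ActionRLEnv inherits this __init__ from BaseEnvironment.
    env = Base5ActionRLEnv(df=features_df,
                           prices=prices_df,
                           window_size=10,
                           reward_kwargs={"rr": 1, "profit_aim": 0.025},
                           config=user_config,
                           dp=strategy.dp)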
@@ -41,12 +56,23 @@ class BaseEnvironment(gym.Env):
         self.reset_env(df, prices, window_size, reward_kwargs, starting_point)
         self.max_drawdown = 1 - self.rl_config.get('max_training_drawdown_pct', 0.8)
         self.compound_trades = config['stake_amount'] == 'unlimited'
+        if self.config.get('fee', None) is not None:
+            self.fee = self.config['fee']
+        elif dp is not None:
+            self.fee = self.dp.exchange.get_fee(symbol=dp.current_whitelist()[0])
+        else:
+            self.fee = 0.0015

     def reset_env(self, df: DataFrame, prices: DataFrame, window_size: int,
                   reward_kwargs: dict, starting_point=True):
         """
         Resets the environment when the agent fails (in our case, if the drawdown
         exceeds the user set max_training_drawdown_pct)
+        :param df: dataframe of features
+        :param prices: dataframe of prices to be used in the training environment
+        :param window_size: size of window (temporal) to pass to the agent
+        :param reward_kwargs: extra config settings assigned by user in `rl_config`
+        :param starting_point: start at edge of window or not
         """
         self.df = df
         self.signal_features = self.df
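The fee now resolves with a clear precedence: an explicit `fee` in the user configuration wins, otherwise the exchange fee for the first whitelisted pair is fetched through the DataProvider, and the old hard-coded 0.0015 remains only as a last-resort default. The same order, condensed into a hypothetical free-standing helper (the names are stand-ins, not part of the commit):

    def resolve_fee(config: dict, dp=None) -> float:
        # Prefer an explicit config override, then the live exchange fee, then a default.
        if config.get('fee') is not None:
            return config['fee']
        if dp is not None:
            return dp.exchange.get_fee(symbol=dp.current_whitelist()[0])
        return 0.0015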
@@ -56,8 +82,6 @@ class BaseEnvironment(gym.Env):
         self.rr = reward_kwargs["rr"]
         self.profit_aim = reward_kwargs["profit_aim"]

-        self.fee = 0.0015
-
         # # spaces
         if self.add_state_info:
             self.total_features = self.signal_features.shape[1] + 3
@@ -233,7 +257,7 @@ class BaseEnvironment(gym.Env):
     def _update_total_profit(self):
         pnl = self.get_unrealized_profit()
         if self.compound_trades:
-            # assumes unitestake and compounding
+            # assumes unit stake and compounding
             self._total_profit = self._total_profit * (1 + pnl)
         else:
             # assumes unit stake and no compounding
@@ -74,10 +74,10 @@ class BaseReinforcementLearningModel(IFreqaiModel):
             self.ft_params.update({'use_SVM_to_remove_outliers': False})
             logger.warning('User tried to use SVM with RL. Deactivating SVM.')
         if self.ft_params.get('use_DBSCAN_to_remove_outliers', False):
-            self.ft_params.update({'use_SVM_to_remove_outliers': False})
+            self.ft_params.update({'use_DBSCAN_to_remove_outliers': False})
             logger.warning('User tried to use DBSCAN with RL. Deactivating DBSCAN.')
         if self.freqai_info['data_split_parameters'].get('shuffle', False):
-            self.freqai_info['data_split_parameters'].update('shuffle', False)
+            self.freqai_info['data_split_parameters'].update({'shuffle': False})
             logger.warning('User tried to shuffle training data. Setting shuffle to False')

     def train(
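Both changes above are plain bug fixes: the DBSCAN guardrail used to reset the SVM flag instead of its own, and `dict.update()` takes a mapping (or keyword arguments) rather than a key/value pair, so the old shuffle line could not have worked. A small self-contained illustration of the second point:

    params = {'shuffle': True}
    params.update({'shuffle': False})    # correct: pass a mapping
    params.update(shuffle=False)         # also correct: keyword form
    try:
        params.update('shuffle', False)  # the old call: two positional arguments
    except TypeError as err:
        print(err)                       # e.g. "update expected at most 1 argument, got 2"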
@@ -141,11 +141,18 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         train_df = data_dictionary["train_features"]
         test_df = data_dictionary["test_features"]

-        self.train_env = self.MyRLEnv(df=train_df, prices=prices_train, window_size=self.CONV_WIDTH,
-                                      reward_kwargs=self.reward_params, config=self.config)
-        self.eval_env = Monitor(self.MyRLEnv(df=test_df, prices=prices_test,
-                                             window_size=self.CONV_WIDTH,
-                                             reward_kwargs=self.reward_params, config=self.config))
+        self.train_env = self.MyRLEnv(df=train_df,
+                                      prices=prices_train,
+                                      window_size=self.CONV_WIDTH,
+                                      reward_kwargs=self.reward_params,
+                                      config=self.config,
+                                      dp=self.data_provider)
+        self.eval_env = Monitor(self.MyRLEnv(df=test_df,
+                                             prices=prices_test,
+                                             window_size=self.CONV_WIDTH,
+                                             reward_kwargs=self.reward_params,
+                                             config=self.config,
+                                             dp=self.data_provider))
         self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
                                           render=False, eval_freq=len(train_df),
                                           best_model_save_path=str(dk.data_path))
@@ -179,12 +186,13 @@ class BaseReinforcementLearningModel(IFreqaiModel):
             if trade.pair == pair:
                 if self.data_provider._exchange is None:  # type: ignore
                     logger.error('No exchange available.')
+                    return 0, 0, 0
                 else:
                     current_rate = self.data_provider._exchange.get_rate(  # type: ignore
                         pair, refresh=False, side="exit", is_short=trade.is_short)

                 now = datetime.now(timezone.utc).timestamp()
-                trade_duration = int((now - trade.open_date_utc) / self.base_tf_seconds)
+                trade_duration = int((now - trade.open_date_utc.timestamp()) / self.base_tf_seconds)
                 current_profit = trade.calc_profit_ratio(current_rate)

         return market_side, current_profit, int(trade_duration)
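The duration fix puts both operands of the subtraction into epoch seconds: `trade.open_date_utc` is a timezone-aware datetime, so subtracting it from the float returned by `datetime.now(timezone.utc).timestamp()` would fail. A compact sketch of the difference (the open date below is a made-up stand-in):

    from datetime import datetime, timezone

    open_date = datetime(2022, 11, 1, tzinfo=timezone.utc)  # stand-in for trade.open_date_utc
    now = datetime.now(timezone.utc).timestamp()            # float, seconds since the epoch

    # now - open_date would raise TypeError (float minus datetime is undefined)
    elapsed = now - open_date.timestamp()                   # both floats: fine
    candles = int(elapsed / 3600)                           # e.g. base_tf_seconds for a 1h timeframe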
@@ -230,7 +238,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):

         def _predict(window):
             observations = dataframe.iloc[window.index]
-            if self.live:  # self.guard_state_info_if_backtest():
+            if self.live and self.rl_config('add_state_info', False):
                 market_side, current_profit, trade_duration = self.get_state_info(dk.pair)
                 observations['current_profit_pct'] = current_profit
                 observations['position'] = market_side
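One caveat on the new condition: `self.rl_config` is assigned from `config['freqai']['rl_config']` earlier in this diff, i.e. it is a plain dict, so `self.rl_config('add_state_info', False)` calls the dict rather than reading from it. The mapping lookup is presumably what is meant (hypothetical correction, not part of this commit):

    rl_config = {'add_state_info': True}
    print(rl_config.get('add_state_info', False))  # True -- the intended lookup
    try:
        rl_config('add_state_info', False)          # what the condition above actually does
    except TypeError as err:
        print(err)                                  # "'dict' object is not callable"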
@@ -242,17 +250,6 @@ class BaseReinforcementLearningModel(IFreqaiModel):

             return output

-    # def guard_state_info_if_backtest(self):
-    #     """
-    #     Ensure that backtesting mode doesnt try to use state information.
-    #     """
-    #     if self.rl_config('add_state_info', False) and not self.live:
-    #         logger.warning('Backtesting with state info is currently unavailable '
-    #                        'turning it off.')
-    #         self.rl_config['add_state_info'] = False
-
-    #     return not self.rl_config['add_state_info']
-
     def build_ohlc_price_dataframes(self, data_dictionary: dict,
                                     pair: str, dk: FreqaiDataKitchen) -> Tuple[DataFrame,
                                                                                DataFrame]:
@@ -13,7 +13,7 @@ from freqtrade.freqai.utils import download_all_data_for_training, get_required_
 from freqtrade.optimize.backtesting import Backtesting
 from freqtrade.persistence import Trade
 from freqtrade.plugins.pairlistmanager import PairListManager
-from tests.conftest import get_patched_exchange, log_has_re
+from tests.conftest import create_mock_trades, get_patched_exchange, log_has_re
 from tests.freqai.conftest import get_patched_freqai_strategy, make_rl_config


@@ -32,7 +32,7 @@ def is_mac() -> bool:
     ('XGBoostRegressor', False, True),
     ('XGBoostRFRegressor', False, False),
     ('CatboostRegressor', False, False),
-    ('ReinforcementLearner', False, False),
+    ('ReinforcementLearner', False, True),
     ('ReinforcementLearner_multiproc', False, False),
     ('ReinforcementLearner_test_4ac', False, False)
     ])
@@ -40,7 +40,7 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca,
     if is_arm() and model == 'CatboostRegressor':
         pytest.skip("CatBoost is not supported on ARM")

-    if is_mac():
+    if is_mac() and 'Reinforcement' in model:
         pytest.skip("Reinforcement learning module not available on intel based Mac OS")

     model_save_ext = 'joblib'
@@ -53,6 +53,9 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca,
     if 'ReinforcementLearner' in model:
         model_save_ext = 'zip'
         freqai_conf = make_rl_config(freqai_conf)
+        # test the RL guardrails
+        freqai_conf['freqai']['feature_parameters'].update({"use_SVM_to_remove_outliers": True})
+        freqai_conf['freqai']['data_split_parameters'].update({'shuffle': True})

     if 'test_4ac' in model:
         freqai_conf["freqaimodel_path"] = str(Path(__file__).parents[1] / "freqai" / "test_models")
@@ -497,3 +500,43 @@ def test_download_all_data_for_training(mocker, freqai_conf, caplog, tmpdir):
         "Downloading",
         caplog,
     )
+
+
+@pytest.mark.usefixtures("init_persistence")
+@pytest.mark.parametrize('dp_exists', [(False), (True)])
+def test_get_state_info(mocker, freqai_conf, dp_exists, caplog, tickers):
+
+    if is_mac():
+        pytest.skip("Reinforcement learning module not available on intel based Mac OS")
+
+    freqai_conf.update({"freqaimodel": "ReinforcementLearner"})
+    freqai_conf.update({"timerange": "20180110-20180130"})
+    freqai_conf.update({"strategy": "freqai_rl_test_strat"})
+    freqai_conf = make_rl_config(freqai_conf)
+    freqai_conf['entry_pricing']['price_side'] = 'same'
+    freqai_conf['exit_pricing']['price_side'] = 'same'
+
+    strategy = get_patched_freqai_strategy(mocker, freqai_conf)
+    exchange = get_patched_exchange(mocker, freqai_conf)
+    ticker_mock = MagicMock(return_value=tickers()['ETH/BTC'])
+    mocker.patch("freqtrade.exchange.Exchange.fetch_ticker", ticker_mock)
+    strategy.dp = DataProvider(freqai_conf, exchange)
+
+    if not dp_exists:
+        strategy.dp._exchange = None
+
+    strategy.freqai_info = freqai_conf.get("freqai", {})
+    freqai = strategy.freqai
+    freqai.data_provider = strategy.dp
+    freqai.live = True
+
+    Trade.use_db = True
+    create_mock_trades(MagicMock(return_value=0.0025), False, True)
+    freqai.get_state_info("ADA/BTC")
+    freqai.get_state_info("ETH/BTC")
+
+    if not dp_exists:
+        assert log_has_re(
+            "No exchange available",
+            caplog,
+        )
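Taken together, the parametrised test exercises both branches of the reworked exchange handling in get_state_info: with a DataProvider-backed exchange, state info is computed from the trades created by `create_mock_trades`, and with `strategy.dp._exchange = None` the new early `return 0, 0, 0` path is taken and the "No exchange available" message is asserted in the log.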