Simplify backtest calling interface

This commit is contained in:
Matthias 2022-10-18 06:39:55 +02:00
parent c7fff1213c
commit c3d4fb9f1b
4 changed files with 10 additions and 16 deletions

View File

@ -152,6 +152,7 @@ class Backtesting:
# strategies which define "can_short=True" will fail to load in Spot mode. # strategies which define "can_short=True" will fail to load in Spot mode.
self._can_short = self.trading_mode != TradingMode.SPOT self._can_short = self.trading_mode != TradingMode.SPOT
self._position_stacking: bool = self.config.get('position_stacking', False) self._position_stacking: bool = self.config.get('position_stacking', False)
self.enable_protections: bool = self.config.get('enable_protections', False)
self.init_backtest() self.init_backtest()
@ -960,9 +961,8 @@ class Backtesting:
return 'short' return 'short'
return None return None
def run_protections( def run_protections(self, pair: str, current_time: datetime, side: LongShort):
self, enable_protections, pair: str, current_time: datetime, side: LongShort): if self.enable_protections:
if enable_protections:
self.protections.stop_per_pair(pair, current_time, side) self.protections.stop_per_pair(pair, current_time, side)
self.protections.global_stop(current_time, side) self.protections.global_stop(current_time, side)
@ -1070,8 +1070,7 @@ class Backtesting:
def backtest_loop( def backtest_loop(
self, row: Tuple, pair: str, current_time: datetime, end_date: datetime, self, row: Tuple, pair: str, current_time: datetime, end_date: datetime,
max_open_trades: int, enable_protections: bool, max_open_trades: int, open_trade_count_start: int) -> int:
open_trade_count_start: int) -> int:
""" """
NOTE: This method is used by Hyperopt at each iteration. Please keep it optimized. NOTE: This method is used by Hyperopt at each iteration. Please keep it optimized.
@ -1135,13 +1134,12 @@ class Backtesting:
# logger.debug(f"{pair} - Backtesting exit {trade}") # logger.debug(f"{pair} - Backtesting exit {trade}")
LocalTrade.close_bt_trade(trade) LocalTrade.close_bt_trade(trade)
self.wallets.update() self.wallets.update()
self.run_protections(enable_protections, pair, current_time, trade.trade_direction) self.run_protections(pair, current_time, trade.trade_direction)
return open_trade_count_start return open_trade_count_start
def backtest(self, processed: Dict, def backtest(self, processed: Dict,
start_date: datetime, end_date: datetime, start_date: datetime, end_date: datetime,
max_open_trades: int = 0, max_open_trades: int = 0) -> Dict[str, Any]:
enable_protections: bool = False) -> Dict[str, Any]:
""" """
Implement backtesting functionality Implement backtesting functionality
@ -1154,10 +1152,9 @@ class Backtesting:
:param start_date: backtesting timerange start datetime :param start_date: backtesting timerange start datetime
:param end_date: backtesting timerange end datetime :param end_date: backtesting timerange end datetime
:param max_open_trades: maximum number of concurrent trades, <= 0 means unlimited :param max_open_trades: maximum number of concurrent trades, <= 0 means unlimited
:param enable_protections: Should protections be enabled?
:return: DataFrame with trades (results of backtesting) :return: DataFrame with trades (results of backtesting)
""" """
self.prepare_backtest(enable_protections) self.prepare_backtest(self.enable_protections)
# Ensure wallets are uptodate (important for --strategy-list) # Ensure wallets are uptodate (important for --strategy-list)
self.wallets.update() self.wallets.update()
# Use dict of lists with data for performance # Use dict of lists with data for performance
@ -1186,8 +1183,7 @@ class Backtesting:
self.dataprovider._set_dataframe_max_index(row_index) self.dataprovider._set_dataframe_max_index(row_index)
open_trade_count_start = self.backtest_loop( open_trade_count_start = self.backtest_loop(
row, pair, current_time, end_date, max_open_trades, row, pair, current_time, end_date, max_open_trades, open_trade_count_start)
enable_protections, open_trade_count_start)
# Move time one configured time_interval ahead. # Move time one configured time_interval ahead.
self.progress.increment() self.progress.increment()
@ -1249,7 +1245,6 @@ class Backtesting:
start_date=min_date, start_date=min_date,
end_date=max_date, end_date=max_date,
max_open_trades=max_open_trades, max_open_trades=max_open_trades,
enable_protections=self.config.get('enable_protections', False),
) )
backtest_end_time = datetime.now(timezone.utc) backtest_end_time = datetime.now(timezone.utc)
results.update({ results.update({

View File

@ -257,6 +257,7 @@ class Hyperopt:
logger.debug("Hyperopt has 'protection' space") logger.debug("Hyperopt has 'protection' space")
# Enable Protections if protection space is selected. # Enable Protections if protection space is selected.
self.config['enable_protections'] = True self.config['enable_protections'] = True
self.backtesting.enable_protections = True
self.protection_space = self.custom_hyperopt.protection_space() self.protection_space = self.custom_hyperopt.protection_space()
if HyperoptTools.has_space(self.config, 'buy'): if HyperoptTools.has_space(self.config, 'buy'):
@ -338,7 +339,6 @@ class Hyperopt:
start_date=self.min_date, start_date=self.min_date,
end_date=self.max_date, end_date=self.max_date,
max_open_trades=self.max_open_trades, max_open_trades=self.max_open_trades,
enable_protections=self.config.get('enable_protections', False),
) )
backtest_end_time = datetime.now(timezone.utc) backtest_end_time = datetime.now(timezone.utc)
bt_results.update({ bt_results.update({

View File

@ -89,6 +89,7 @@ async def api_start_backtest(bt_settings: BacktestRequest, background_tasks: Bac
lastconfig['enable_protections'] = btconfig.get('enable_protections') lastconfig['enable_protections'] = btconfig.get('enable_protections')
lastconfig['dry_run_wallet'] = btconfig.get('dry_run_wallet') lastconfig['dry_run_wallet'] = btconfig.get('dry_run_wallet')
ApiServer._bt.enable_protections = btconfig.get('enable_protections', False)
ApiServer._bt.strategylist = [strat] ApiServer._bt.strategylist = [strat]
ApiServer._bt.results = {} ApiServer._bt.results = {}
ApiServer._bt.load_prior_backtest() ApiServer._bt.load_prior_backtest()

View File

@ -973,7 +973,6 @@ def test_backtest_pricecontours_protections(default_conf, fee, mocker, testdatad
start_date=min_date, start_date=min_date,
end_date=max_date, end_date=max_date,
max_open_trades=1, max_open_trades=1,
enable_protections=default_conf.get('enable_protections', False),
) )
assert len(results['results']) == numres assert len(results['results']) == numres
@ -1016,7 +1015,6 @@ def test_backtest_pricecontours(default_conf, fee, mocker, testdatadir,
start_date=min_date, start_date=min_date,
end_date=max_date, end_date=max_date,
max_open_trades=1, max_open_trades=1,
enable_protections=default_conf.get('enable_protections', False),
) )
assert len(results['results']) == expected assert len(results['results']) == expected