diff --git a/freqtrade/rpc/api_server/api_backtest.py b/freqtrade/rpc/api_server/api_backtest.py index b17636a7d..bc2a40d91 100644 --- a/freqtrade/rpc/api_server/api_backtest.py +++ b/freqtrade/rpc/api_server/api_backtest.py @@ -11,6 +11,7 @@ from freqtrade.configuration.config_validation import validate_config_consistenc from freqtrade.data.btanalysis import get_backtest_resultlist, load_and_merge_backtest_result from freqtrade.enums import BacktestState from freqtrade.exceptions import DependencyException +from freqtrade.misc import deep_merge_dicts from freqtrade.rpc.api_server.api_schemas import (BacktestHistoryEntry, BacktestRequest, BacktestResponse) from freqtrade.rpc.api_server.deps import get_config, is_webserver_mode @@ -37,10 +38,11 @@ async def api_start_backtest(bt_settings: BacktestRequest, background_tasks: Bac btconfig = deepcopy(config) settings = dict(bt_settings) + if settings.get('freqai', None) is not None: + settings['freqai'] = dict(settings['freqai']) # Pydantic models will contain all keys, but non-provided ones are None - for setting in settings.keys(): - if settings[setting] is not None: - btconfig[setting] = settings[setting] + + btconfig = deep_merge_dicts(settings, btconfig, allow_null_overrides=False) try: btconfig['stake_amount'] = float(btconfig['stake_amount']) except ValueError: diff --git a/freqtrade/rpc/api_server/api_schemas.py b/freqtrade/rpc/api_server/api_schemas.py index ada20230a..17dff222d 100644 --- a/freqtrade/rpc/api_server/api_schemas.py +++ b/freqtrade/rpc/api_server/api_schemas.py @@ -372,6 +372,10 @@ class StrategyListResponse(BaseModel): strategies: List[str] +class FreqAIModelListResponse(BaseModel): + freqaimodels: List[str] + + class StrategyResponse(BaseModel): strategy: str code: str @@ -410,6 +414,10 @@ class PairHistory(BaseModel): } +class BacktestFreqAIInputs(BaseModel): + identifier: str + + class BacktestRequest(BaseModel): strategy: str timeframe: Optional[str] @@ -419,6 +427,9 @@ class BacktestRequest(BaseModel): stake_amount: Optional[str] enable_protections: bool dry_run_wallet: Optional[float] + backtest_cache: Optional[str] + freqaimodel: Optional[str] + freqai: Optional[BacktestFreqAIInputs] class BacktestResponse(BaseModel): diff --git a/freqtrade/rpc/api_server/api_v1.py b/freqtrade/rpc/api_server/api_v1.py index 9e4b140e4..e26df6eea 100644 --- a/freqtrade/rpc/api_server/api_v1.py +++ b/freqtrade/rpc/api_server/api_v1.py @@ -13,12 +13,13 @@ from freqtrade.rpc import RPC from freqtrade.rpc.api_server.api_schemas import (AvailablePairs, Balances, BlacklistPayload, BlacklistResponse, Count, Daily, DeleteLockRequest, DeleteTrade, ForceEnterPayload, - ForceEnterResponse, ForceExitPayload, Health, - Locks, Logs, OpenTradeSchema, PairHistory, - PerformanceEntry, Ping, PlotConfig, Profit, - ResultMsg, ShowConfig, Stats, StatusMsg, - StrategyListResponse, StrategyResponse, SysInfo, - Version, WhitelistResponse) + ForceEnterResponse, ForceExitPayload, + FreqAIModelListResponse, Health, Locks, Logs, + OpenTradeSchema, PairHistory, PerformanceEntry, + Ping, PlotConfig, Profit, ResultMsg, ShowConfig, + Stats, StatusMsg, StrategyListResponse, + StrategyResponse, SysInfo, Version, + WhitelistResponse) from freqtrade.rpc.api_server.deps import get_config, get_exchange, get_rpc, get_rpc_optional from freqtrade.rpc.rpc import RPCException @@ -38,7 +39,8 @@ logger = logging.getLogger(__name__) # 2.17: Forceentry - leverage, partial force_exit # 2.20: Add websocket endpoints # 2.21: Add new_candle messagetype -API_VERSION = 2.21 +# 2.22: Add FreqAI to backtesting +API_VERSION = 2.22 # Public API, requires no auth. router_public = APIRouter() @@ -279,6 +281,16 @@ def get_strategy(strategy: str, config=Depends(get_config)): } +@router.get('/freqaimodels', response_model=FreqAIModelListResponse, tags=['freqai']) +def list_freqaimodels(config=Depends(get_config)): + from freqtrade.resolvers.freqaimodel_resolver import FreqaiModelResolver + strategies = FreqaiModelResolver.search_all_objects( + config, False) + strategies = sorted(strategies, key=lambda x: x['name']) + + return {'freqaimodels': [x['name'] for x in strategies]} + + @router.get('/available_pairs', response_model=AvailablePairs, tags=['candle data']) def list_available_pairs(timeframe: Optional[str] = None, stake_currency: Optional[str] = None, candletype: Optional[CandleType] = None, config=Depends(get_config)): diff --git a/tests/rpc/test_rpc_apiserver.py b/tests/rpc/test_rpc_apiserver.py index ee067f911..aea8ea059 100644 --- a/tests/rpc/test_rpc_apiserver.py +++ b/tests/rpc/test_rpc_apiserver.py @@ -1488,6 +1488,44 @@ def test_api_strategy(botclient): assert_response(rc, 500) +def test_api_freqaimodels(botclient, tmpdir, mocker): + ftbot, client = botclient + ftbot.config['user_data_dir'] = Path(tmpdir) + mocker.patch( + "freqtrade.resolvers.freqaimodel_resolver.FreqaiModelResolver.search_all_objects", + return_value=[ + {'name': 'LightGBMClassifier'}, + {'name': 'LightGBMClassifierMultiTarget'}, + {'name': 'LightGBMRegressor'}, + {'name': 'LightGBMRegressorMultiTarget'}, + {'name': 'ReinforcementLearner'}, + {'name': 'ReinforcementLearner_multiproc'}, + {'name': 'XGBoostClassifier'}, + {'name': 'XGBoostRFClassifier'}, + {'name': 'XGBoostRFRegressor'}, + {'name': 'XGBoostRegressor'}, + {'name': 'XGBoostRegressorMultiTarget'}, + ]) + + rc = client_get(client, f"{BASE_URI}/freqaimodels") + + assert_response(rc) + + assert rc.json() == {'freqaimodels': [ + 'LightGBMClassifier', + 'LightGBMClassifierMultiTarget', + 'LightGBMRegressor', + 'LightGBMRegressorMultiTarget', + 'ReinforcementLearner', + 'ReinforcementLearner_multiproc', + 'XGBoostClassifier', + 'XGBoostRFClassifier', + 'XGBoostRFRegressor', + 'XGBoostRegressor', + 'XGBoostRegressorMultiTarget' + ]} + + def test_list_available_pairs(botclient): ftbot, client = botclient