diff --git a/freqtrade/rpc/api_server/api_auth.py b/freqtrade/rpc/api_server/api_auth.py index 6655dbf86..0d1378b6d 100644 --- a/freqtrade/rpc/api_server/api_auth.py +++ b/freqtrade/rpc/api_server/api_auth.py @@ -62,7 +62,7 @@ async def get_ws_token( # Just return the token if it matches return token else: - logger.debug("Denying websocket request") + logger.info("Denying websocket request") # If it doesn't match, close the websocket connection await ws.close(code=status.WS_1008_POLICY_VIOLATION) diff --git a/freqtrade/rpc/api_server/api_ws.py b/freqtrade/rpc/api_server/api_ws.py index 45cc20e4d..384bd4115 100644 --- a/freqtrade/rpc/api_server/api_ws.py +++ b/freqtrade/rpc/api_server/api_ws.py @@ -6,6 +6,7 @@ from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect from starlette.websockets import WebSocketState from freqtrade.enums import RPCMessageType, RPCRequestType +from freqtrade.rpc.api_server.api_auth import get_ws_token from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc from freqtrade.rpc.api_server.ws.channel import WebSocketChannel from freqtrade.rpc.api_server.ws.schema import (ValidationError, WSAnalyzedDFMessage, @@ -95,6 +96,7 @@ async def message_endpoint( ws: WebSocket, rpc: RPC = Depends(get_rpc), channel_manager=Depends(get_channel_manager), + token: str = Depends(get_ws_token) ): """ Message WebSocket endpoint, facilitates sending RPC messages @@ -105,6 +107,9 @@ async def message_endpoint( # Return a channel ID, pass that instead of ws to the rest of the methods channel = await channel_manager.on_connect(ws) + if not channel: + return + logger.info(f"Consumer connected - {channel}") # Keep connection open until explicitly closed, and process requests diff --git a/freqtrade/rpc/api_server/webserver.py b/freqtrade/rpc/api_server/webserver.py index 6ad3f143e..73e80dd48 100644 --- a/freqtrade/rpc/api_server/webserver.py +++ b/freqtrade/rpc/api_server/webserver.py @@ -139,8 +139,7 @@ class ApiServer(RPCHandler): ) def configure_app(self, app: FastAPI, config): - from freqtrade.rpc.api_server.api_auth import (get_ws_token, http_basic_or_jwt_token, - router_login) + from freqtrade.rpc.api_server.api_auth import http_basic_or_jwt_token, router_login from freqtrade.rpc.api_server.api_backtest import router as api_backtest from freqtrade.rpc.api_server.api_v1 import router as api_v1 from freqtrade.rpc.api_server.api_v1 import router_public as api_v1_public @@ -155,9 +154,7 @@ class ApiServer(RPCHandler): app.include_router(api_backtest, prefix="/api/v1", dependencies=[Depends(http_basic_or_jwt_token)], ) - app.include_router(ws_router, prefix="/api/v1", - dependencies=[Depends(get_ws_token)] - ) + app.include_router(ws_router, prefix="/api/v1") app.include_router(router_login, prefix="/api/v1", tags=["auth"]) # UI Router MUST be last! app.include_router(router_ui, prefix='') diff --git a/tests/rpc/test_rpc_apiserver.py b/tests/rpc/test_rpc_apiserver.py index 0efcc00c1..ccfe31424 100644 --- a/tests/rpc/test_rpc_apiserver.py +++ b/tests/rpc/test_rpc_apiserver.py @@ -10,7 +10,7 @@ from unittest.mock import ANY, MagicMock, PropertyMock import pandas as pd import pytest import uvicorn -from fastapi import FastAPI +from fastapi import FastAPI, WebSocketDisconnect from fastapi.exceptions import HTTPException from fastapi.testclient import TestClient from requests.auth import _basic_auth_str @@ -31,6 +31,7 @@ from tests.conftest import (CURRENT_TEST_STRATEGY, create_mock_trades, get_mock_ BASE_URI = "/api/v1" _TEST_USER = "FreqTrader" _TEST_PASS = "SuperSecurePassword1!" +_TEST_WS_TOKEN = "secret_Ws_t0ken" @pytest.fixture @@ -44,6 +45,7 @@ def botclient(default_conf, mocker): "CORS_origins": ['http://example.com'], "username": _TEST_USER, "password": _TEST_PASS, + "ws_token": _TEST_WS_TOKEN }}) ftbot = get_patched_freqtradebot(mocker, default_conf) @@ -155,6 +157,23 @@ def test_api_auth(): get_user_from_token(b'not_a_token', 'secret1234') +def test_api_ws_auth(botclient): + ftbot, client = botclient + + bad_token = "bad-ws_token" + url = f"/api/v1/message/ws?token={bad_token}" + + with pytest.raises(WebSocketDisconnect): + with client.websocket_connect(url) as websocket: + websocket.receive() + + good_token = _TEST_WS_TOKEN + url = f"/api/v1/message/ws?token={good_token}" + + with client.websocket_connect(url) as websocket: + websocket.send(1) + + def test_api_unauthorized(botclient): ftbot, client = botclient rc = client.get(f"{BASE_URI}/ping")