fix ws token auth

This commit is contained in:
Timothy Pogue 2022-09-08 11:25:30 -06:00
parent fac6626459
commit b9e7af1ce2
4 changed files with 28 additions and 7 deletions

View File

@ -62,7 +62,7 @@ async def get_ws_token(
# Just return the token if it matches
return token
else:
logger.debug("Denying websocket request")
logger.info("Denying websocket request")
# If it doesn't match, close the websocket connection
await ws.close(code=status.WS_1008_POLICY_VIOLATION)

View File

@ -6,6 +6,7 @@ from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
from starlette.websockets import WebSocketState
from freqtrade.enums import RPCMessageType, RPCRequestType
from freqtrade.rpc.api_server.api_auth import get_ws_token
from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc
from freqtrade.rpc.api_server.ws.channel import WebSocketChannel
from freqtrade.rpc.api_server.ws.schema import (ValidationError, WSAnalyzedDFMessage,
@ -95,6 +96,7 @@ async def message_endpoint(
ws: WebSocket,
rpc: RPC = Depends(get_rpc),
channel_manager=Depends(get_channel_manager),
token: str = Depends(get_ws_token)
):
"""
Message WebSocket endpoint, facilitates sending RPC messages
@ -105,6 +107,9 @@ async def message_endpoint(
# Return a channel ID, pass that instead of ws to the rest of the methods
channel = await channel_manager.on_connect(ws)
if not channel:
return
logger.info(f"Consumer connected - {channel}")
# Keep connection open until explicitly closed, and process requests

View File

@ -139,8 +139,7 @@ class ApiServer(RPCHandler):
)
def configure_app(self, app: FastAPI, config):
from freqtrade.rpc.api_server.api_auth import (get_ws_token, http_basic_or_jwt_token,
router_login)
from freqtrade.rpc.api_server.api_auth import http_basic_or_jwt_token, router_login
from freqtrade.rpc.api_server.api_backtest import router as api_backtest
from freqtrade.rpc.api_server.api_v1 import router as api_v1
from freqtrade.rpc.api_server.api_v1 import router_public as api_v1_public
@ -155,9 +154,7 @@ class ApiServer(RPCHandler):
app.include_router(api_backtest, prefix="/api/v1",
dependencies=[Depends(http_basic_or_jwt_token)],
)
app.include_router(ws_router, prefix="/api/v1",
dependencies=[Depends(get_ws_token)]
)
app.include_router(ws_router, prefix="/api/v1")
app.include_router(router_login, prefix="/api/v1", tags=["auth"])
# UI Router MUST be last!
app.include_router(router_ui, prefix='')

View File

@ -10,7 +10,7 @@ from unittest.mock import ANY, MagicMock, PropertyMock
import pandas as pd
import pytest
import uvicorn
from fastapi import FastAPI
from fastapi import FastAPI, WebSocketDisconnect
from fastapi.exceptions import HTTPException
from fastapi.testclient import TestClient
from requests.auth import _basic_auth_str
@ -31,6 +31,7 @@ from tests.conftest import (CURRENT_TEST_STRATEGY, create_mock_trades, get_mock_
BASE_URI = "/api/v1"
_TEST_USER = "FreqTrader"
_TEST_PASS = "SuperSecurePassword1!"
_TEST_WS_TOKEN = "secret_Ws_t0ken"
@pytest.fixture
@ -44,6 +45,7 @@ def botclient(default_conf, mocker):
"CORS_origins": ['http://example.com'],
"username": _TEST_USER,
"password": _TEST_PASS,
"ws_token": _TEST_WS_TOKEN
}})
ftbot = get_patched_freqtradebot(mocker, default_conf)
@ -155,6 +157,23 @@ def test_api_auth():
get_user_from_token(b'not_a_token', 'secret1234')
def test_api_ws_auth(botclient):
ftbot, client = botclient
bad_token = "bad-ws_token"
url = f"/api/v1/message/ws?token={bad_token}"
with pytest.raises(WebSocketDisconnect):
with client.websocket_connect(url) as websocket:
websocket.receive()
good_token = _TEST_WS_TOKEN
url = f"/api/v1/message/ws?token={good_token}"
with client.websocket_connect(url) as websocket:
websocket.send(1)
def test_api_unauthorized(botclient):
ftbot, client = botclient
rc = client.get(f"{BASE_URI}/ping")