fix ws token auth
This commit is contained in:
parent
fac6626459
commit
b9e7af1ce2
@ -62,7 +62,7 @@ async def get_ws_token(
|
|||||||
# Just return the token if it matches
|
# Just return the token if it matches
|
||||||
return token
|
return token
|
||||||
else:
|
else:
|
||||||
logger.debug("Denying websocket request")
|
logger.info("Denying websocket request")
|
||||||
# If it doesn't match, close the websocket connection
|
# If it doesn't match, close the websocket connection
|
||||||
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
|
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@ from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
|
|||||||
from starlette.websockets import WebSocketState
|
from starlette.websockets import WebSocketState
|
||||||
|
|
||||||
from freqtrade.enums import RPCMessageType, RPCRequestType
|
from freqtrade.enums import RPCMessageType, RPCRequestType
|
||||||
|
from freqtrade.rpc.api_server.api_auth import get_ws_token
|
||||||
from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc
|
from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc
|
||||||
from freqtrade.rpc.api_server.ws.channel import WebSocketChannel
|
from freqtrade.rpc.api_server.ws.channel import WebSocketChannel
|
||||||
from freqtrade.rpc.api_server.ws.schema import (ValidationError, WSAnalyzedDFMessage,
|
from freqtrade.rpc.api_server.ws.schema import (ValidationError, WSAnalyzedDFMessage,
|
||||||
@ -95,6 +96,7 @@ async def message_endpoint(
|
|||||||
ws: WebSocket,
|
ws: WebSocket,
|
||||||
rpc: RPC = Depends(get_rpc),
|
rpc: RPC = Depends(get_rpc),
|
||||||
channel_manager=Depends(get_channel_manager),
|
channel_manager=Depends(get_channel_manager),
|
||||||
|
token: str = Depends(get_ws_token)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Message WebSocket endpoint, facilitates sending RPC messages
|
Message WebSocket endpoint, facilitates sending RPC messages
|
||||||
@ -105,6 +107,9 @@ async def message_endpoint(
|
|||||||
# Return a channel ID, pass that instead of ws to the rest of the methods
|
# Return a channel ID, pass that instead of ws to the rest of the methods
|
||||||
channel = await channel_manager.on_connect(ws)
|
channel = await channel_manager.on_connect(ws)
|
||||||
|
|
||||||
|
if not channel:
|
||||||
|
return
|
||||||
|
|
||||||
logger.info(f"Consumer connected - {channel}")
|
logger.info(f"Consumer connected - {channel}")
|
||||||
|
|
||||||
# Keep connection open until explicitly closed, and process requests
|
# Keep connection open until explicitly closed, and process requests
|
||||||
|
@ -139,8 +139,7 @@ class ApiServer(RPCHandler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def configure_app(self, app: FastAPI, config):
|
def configure_app(self, app: FastAPI, config):
|
||||||
from freqtrade.rpc.api_server.api_auth import (get_ws_token, http_basic_or_jwt_token,
|
from freqtrade.rpc.api_server.api_auth import http_basic_or_jwt_token, router_login
|
||||||
router_login)
|
|
||||||
from freqtrade.rpc.api_server.api_backtest import router as api_backtest
|
from freqtrade.rpc.api_server.api_backtest import router as api_backtest
|
||||||
from freqtrade.rpc.api_server.api_v1 import router as api_v1
|
from freqtrade.rpc.api_server.api_v1 import router as api_v1
|
||||||
from freqtrade.rpc.api_server.api_v1 import router_public as api_v1_public
|
from freqtrade.rpc.api_server.api_v1 import router_public as api_v1_public
|
||||||
@ -155,9 +154,7 @@ class ApiServer(RPCHandler):
|
|||||||
app.include_router(api_backtest, prefix="/api/v1",
|
app.include_router(api_backtest, prefix="/api/v1",
|
||||||
dependencies=[Depends(http_basic_or_jwt_token)],
|
dependencies=[Depends(http_basic_or_jwt_token)],
|
||||||
)
|
)
|
||||||
app.include_router(ws_router, prefix="/api/v1",
|
app.include_router(ws_router, prefix="/api/v1")
|
||||||
dependencies=[Depends(get_ws_token)]
|
|
||||||
)
|
|
||||||
app.include_router(router_login, prefix="/api/v1", tags=["auth"])
|
app.include_router(router_login, prefix="/api/v1", tags=["auth"])
|
||||||
# UI Router MUST be last!
|
# UI Router MUST be last!
|
||||||
app.include_router(router_ui, prefix='')
|
app.include_router(router_ui, prefix='')
|
||||||
|
@ -10,7 +10,7 @@ from unittest.mock import ANY, MagicMock, PropertyMock
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
import pytest
|
import pytest
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI, WebSocketDisconnect
|
||||||
from fastapi.exceptions import HTTPException
|
from fastapi.exceptions import HTTPException
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
from requests.auth import _basic_auth_str
|
from requests.auth import _basic_auth_str
|
||||||
@ -31,6 +31,7 @@ from tests.conftest import (CURRENT_TEST_STRATEGY, create_mock_trades, get_mock_
|
|||||||
BASE_URI = "/api/v1"
|
BASE_URI = "/api/v1"
|
||||||
_TEST_USER = "FreqTrader"
|
_TEST_USER = "FreqTrader"
|
||||||
_TEST_PASS = "SuperSecurePassword1!"
|
_TEST_PASS = "SuperSecurePassword1!"
|
||||||
|
_TEST_WS_TOKEN = "secret_Ws_t0ken"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -44,6 +45,7 @@ def botclient(default_conf, mocker):
|
|||||||
"CORS_origins": ['http://example.com'],
|
"CORS_origins": ['http://example.com'],
|
||||||
"username": _TEST_USER,
|
"username": _TEST_USER,
|
||||||
"password": _TEST_PASS,
|
"password": _TEST_PASS,
|
||||||
|
"ws_token": _TEST_WS_TOKEN
|
||||||
}})
|
}})
|
||||||
|
|
||||||
ftbot = get_patched_freqtradebot(mocker, default_conf)
|
ftbot = get_patched_freqtradebot(mocker, default_conf)
|
||||||
@ -155,6 +157,23 @@ def test_api_auth():
|
|||||||
get_user_from_token(b'not_a_token', 'secret1234')
|
get_user_from_token(b'not_a_token', 'secret1234')
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_ws_auth(botclient):
|
||||||
|
ftbot, client = botclient
|
||||||
|
|
||||||
|
bad_token = "bad-ws_token"
|
||||||
|
url = f"/api/v1/message/ws?token={bad_token}"
|
||||||
|
|
||||||
|
with pytest.raises(WebSocketDisconnect):
|
||||||
|
with client.websocket_connect(url) as websocket:
|
||||||
|
websocket.receive()
|
||||||
|
|
||||||
|
good_token = _TEST_WS_TOKEN
|
||||||
|
url = f"/api/v1/message/ws?token={good_token}"
|
||||||
|
|
||||||
|
with client.websocket_connect(url) as websocket:
|
||||||
|
websocket.send(1)
|
||||||
|
|
||||||
|
|
||||||
def test_api_unauthorized(botclient):
|
def test_api_unauthorized(botclient):
|
||||||
ftbot, client = botclient
|
ftbot, client = botclient
|
||||||
rc = client.get(f"{BASE_URI}/ping")
|
rc = client.get(f"{BASE_URI}/ping")
|
||||||
|
Loading…
Reference in New Issue
Block a user