fix ws token auth

This commit is contained in:
Timothy Pogue 2022-09-08 11:25:30 -06:00
parent fac6626459
commit b9e7af1ce2
4 changed files with 28 additions and 7 deletions

View File

@ -62,7 +62,7 @@ async def get_ws_token(
# Just return the token if it matches # Just return the token if it matches
return token return token
else: else:
logger.debug("Denying websocket request") logger.info("Denying websocket request")
# If it doesn't match, close the websocket connection # If it doesn't match, close the websocket connection
await ws.close(code=status.WS_1008_POLICY_VIOLATION) await ws.close(code=status.WS_1008_POLICY_VIOLATION)

View File

@ -6,6 +6,7 @@ from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
from starlette.websockets import WebSocketState from starlette.websockets import WebSocketState
from freqtrade.enums import RPCMessageType, RPCRequestType from freqtrade.enums import RPCMessageType, RPCRequestType
from freqtrade.rpc.api_server.api_auth import get_ws_token
from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc
from freqtrade.rpc.api_server.ws.channel import WebSocketChannel from freqtrade.rpc.api_server.ws.channel import WebSocketChannel
from freqtrade.rpc.api_server.ws.schema import (ValidationError, WSAnalyzedDFMessage, from freqtrade.rpc.api_server.ws.schema import (ValidationError, WSAnalyzedDFMessage,
@ -95,6 +96,7 @@ async def message_endpoint(
ws: WebSocket, ws: WebSocket,
rpc: RPC = Depends(get_rpc), rpc: RPC = Depends(get_rpc),
channel_manager=Depends(get_channel_manager), channel_manager=Depends(get_channel_manager),
token: str = Depends(get_ws_token)
): ):
""" """
Message WebSocket endpoint, facilitates sending RPC messages Message WebSocket endpoint, facilitates sending RPC messages
@ -105,6 +107,9 @@ async def message_endpoint(
# Return a channel ID, pass that instead of ws to the rest of the methods # Return a channel ID, pass that instead of ws to the rest of the methods
channel = await channel_manager.on_connect(ws) channel = await channel_manager.on_connect(ws)
if not channel:
return
logger.info(f"Consumer connected - {channel}") logger.info(f"Consumer connected - {channel}")
# Keep connection open until explicitly closed, and process requests # Keep connection open until explicitly closed, and process requests

View File

@ -139,8 +139,7 @@ class ApiServer(RPCHandler):
) )
def configure_app(self, app: FastAPI, config): def configure_app(self, app: FastAPI, config):
from freqtrade.rpc.api_server.api_auth import (get_ws_token, http_basic_or_jwt_token, from freqtrade.rpc.api_server.api_auth import http_basic_or_jwt_token, router_login
router_login)
from freqtrade.rpc.api_server.api_backtest import router as api_backtest from freqtrade.rpc.api_server.api_backtest import router as api_backtest
from freqtrade.rpc.api_server.api_v1 import router as api_v1 from freqtrade.rpc.api_server.api_v1 import router as api_v1
from freqtrade.rpc.api_server.api_v1 import router_public as api_v1_public from freqtrade.rpc.api_server.api_v1 import router_public as api_v1_public
@ -155,9 +154,7 @@ class ApiServer(RPCHandler):
app.include_router(api_backtest, prefix="/api/v1", app.include_router(api_backtest, prefix="/api/v1",
dependencies=[Depends(http_basic_or_jwt_token)], dependencies=[Depends(http_basic_or_jwt_token)],
) )
app.include_router(ws_router, prefix="/api/v1", app.include_router(ws_router, prefix="/api/v1")
dependencies=[Depends(get_ws_token)]
)
app.include_router(router_login, prefix="/api/v1", tags=["auth"]) app.include_router(router_login, prefix="/api/v1", tags=["auth"])
# UI Router MUST be last! # UI Router MUST be last!
app.include_router(router_ui, prefix='') app.include_router(router_ui, prefix='')

View File

@ -10,7 +10,7 @@ from unittest.mock import ANY, MagicMock, PropertyMock
import pandas as pd import pandas as pd
import pytest import pytest
import uvicorn import uvicorn
from fastapi import FastAPI from fastapi import FastAPI, WebSocketDisconnect
from fastapi.exceptions import HTTPException from fastapi.exceptions import HTTPException
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from requests.auth import _basic_auth_str from requests.auth import _basic_auth_str
@ -31,6 +31,7 @@ from tests.conftest import (CURRENT_TEST_STRATEGY, create_mock_trades, get_mock_
BASE_URI = "/api/v1" BASE_URI = "/api/v1"
_TEST_USER = "FreqTrader" _TEST_USER = "FreqTrader"
_TEST_PASS = "SuperSecurePassword1!" _TEST_PASS = "SuperSecurePassword1!"
_TEST_WS_TOKEN = "secret_Ws_t0ken"
@pytest.fixture @pytest.fixture
@ -44,6 +45,7 @@ def botclient(default_conf, mocker):
"CORS_origins": ['http://example.com'], "CORS_origins": ['http://example.com'],
"username": _TEST_USER, "username": _TEST_USER,
"password": _TEST_PASS, "password": _TEST_PASS,
"ws_token": _TEST_WS_TOKEN
}}) }})
ftbot = get_patched_freqtradebot(mocker, default_conf) ftbot = get_patched_freqtradebot(mocker, default_conf)
@ -155,6 +157,23 @@ def test_api_auth():
get_user_from_token(b'not_a_token', 'secret1234') get_user_from_token(b'not_a_token', 'secret1234')
def test_api_ws_auth(botclient):
ftbot, client = botclient
bad_token = "bad-ws_token"
url = f"/api/v1/message/ws?token={bad_token}"
with pytest.raises(WebSocketDisconnect):
with client.websocket_connect(url) as websocket:
websocket.receive()
good_token = _TEST_WS_TOKEN
url = f"/api/v1/message/ws?token={good_token}"
with client.websocket_connect(url) as websocket:
websocket.send(1)
def test_api_unauthorized(botclient): def test_api_unauthorized(botclient):
ftbot, client = botclient ftbot, client = botclient
rc = client.get(f"{BASE_URI}/ping") rc = client.get(f"{BASE_URI}/ping")