diff --git a/freqtrade/rpc/api_server/api_auth.py b/freqtrade/rpc/api_server/api_auth.py index 0d1378b6d..a2b722f0a 100644 --- a/freqtrade/rpc/api_server/api_auth.py +++ b/freqtrade/rpc/api_server/api_auth.py @@ -4,7 +4,7 @@ from datetime import datetime, timedelta from typing import Any, Dict, Union import jwt -from fastapi import APIRouter, Depends, HTTPException, WebSocket, status +from fastapi import APIRouter, Depends, HTTPException, Query, WebSocket, status from fastapi.security import OAuth2PasswordBearer from fastapi.security.http import HTTPBasic, HTTPBasicCredentials @@ -29,7 +29,8 @@ httpbasic = HTTPBasic(auto_error=False) oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False) -def get_user_from_token(token, secret_key: str, token_type: str = "access"): +def get_user_from_token(token, secret_key: str, token_type: str = "access", + raise_on_error: bool = True) -> Union[bool, str]: credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", @@ -39,12 +40,21 @@ def get_user_from_token(token, secret_key: str, token_type: str = "access"): payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM]) username: str = payload.get("identity", {}).get('u') if username is None: - raise credentials_exception + if raise_on_error: + raise credentials_exception + else: + return False if payload.get("type") != token_type: - raise credentials_exception + if raise_on_error: + raise credentials_exception + else: + return False except jwt.PyJWTError: - raise credentials_exception + if raise_on_error: + raise credentials_exception + else: + return False return username @@ -53,14 +63,18 @@ def get_user_from_token(token, secret_key: str, token_type: str = "access"): # https://github.com/tiangolo/fastapi/blob/master/fastapi/security/api_key.py async def get_ws_token( ws: WebSocket, - token: Union[str, None] = None, + ws_token: Union[str, None] = Query(..., alias="token"), api_config: Dict[str, Any] = Depends(get_api_config) ): - secret_ws_token = api_config['ws_token'] + secret_ws_token = api_config.get('ws_token', 'secret_ws_t0ken.') + secret_jwt_key = api_config.get('jwt_secret_key', 'super-secret') - if token == secret_ws_token: + if secrets.compare_digest(secret_ws_token, ws_token): # Just return the token if it matches - return token + return ws_token + elif user := get_user_from_token(ws_token, secret_jwt_key, raise_on_error=False): + # If the token is a jwt, and it's valid return the user + return user else: logger.info("Denying websocket request") # If it doesn't match, close the websocket connection diff --git a/freqtrade/rpc/api_server/api_ws.py b/freqtrade/rpc/api_server/api_ws.py index 16d5ef9a7..25d29a7ce 100644 --- a/freqtrade/rpc/api_server/api_ws.py +++ b/freqtrade/rpc/api_server/api_ws.py @@ -132,6 +132,10 @@ async def message_endpoint( else: await ws.close() + except RuntimeError: + # WebSocket was closed + await channel_manager.on_disconnect(ws) + except Exception as e: logger.error(f"Failed to serve - {ws.client}") # Log tracebacks to keep track of what errors are happening diff --git a/tests/rpc/test_rpc_apiserver.py b/tests/rpc/test_rpc_apiserver.py index 6a37e7cdd..f1aa465f0 100644 --- a/tests/rpc/test_rpc_apiserver.py +++ b/tests/rpc/test_rpc_apiserver.py @@ -161,18 +161,20 @@ def test_api_auth(): def test_api_ws_auth(botclient): ftbot, client = botclient + def url(token): return f"/api/v1/message/ws?token={token}" bad_token = "bad-ws_token" - url = f"/api/v1/message/ws?token={bad_token}" - with pytest.raises(WebSocketDisconnect): - with client.websocket_connect(url) as websocket: + with client.websocket_connect(url(bad_token)) as websocket: websocket.receive() good_token = _TEST_WS_TOKEN - url = f"/api/v1/message/ws?token={good_token}" + with client.websocket_connect(url(good_token)) as websocket: + pass - with client.websocket_connect(url) as websocket: + jwt_secret = ftbot.config['api_server'].get('jwt_secret_key', 'super-secret') + jwt_token = create_token({'identity': {'u': 'Freqtrade'}}, jwt_secret) + with client.websocket_connect(url(jwt_token)) as websocket: pass