From b344f78d007c20b75acefa181c2a2129f4787ecd Mon Sep 17 00:00:00 2001 From: Matthias Date: Sat, 10 Sep 2022 14:19:11 +0200 Subject: [PATCH] Improve logic for token validation --- freqtrade/rpc/api_server/api_auth.py | 35 +++++++++++----------------- freqtrade/rpc/api_server/api_ws.py | 4 ++-- tests/rpc/test_rpc_apiserver.py | 6 ----- 3 files changed, 16 insertions(+), 29 deletions(-) diff --git a/freqtrade/rpc/api_server/api_auth.py b/freqtrade/rpc/api_server/api_auth.py index a2b722f0a..767a2d5b9 100644 --- a/freqtrade/rpc/api_server/api_auth.py +++ b/freqtrade/rpc/api_server/api_auth.py @@ -29,8 +29,7 @@ httpbasic = HTTPBasic(auto_error=False) oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False) -def get_user_from_token(token, secret_key: str, token_type: str = "access", - raise_on_error: bool = True) -> Union[bool, str]: +def get_user_from_token(token, secret_key: str, token_type: str = "access") -> str: credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", @@ -40,28 +39,19 @@ def get_user_from_token(token, secret_key: str, token_type: str = "access", payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM]) username: str = payload.get("identity", {}).get('u') if username is None: - if raise_on_error: - raise credentials_exception - else: - return False + raise credentials_exception if payload.get("type") != token_type: - if raise_on_error: - raise credentials_exception - else: - return False + raise credentials_exception except jwt.PyJWTError: - if raise_on_error: - raise credentials_exception - else: - return False + raise credentials_exception return username # This should be reimplemented to better realign with the existing tools provided # by FastAPI regarding API Tokens # https://github.com/tiangolo/fastapi/blob/master/fastapi/security/api_key.py -async def get_ws_token( +async def validate_ws_token( ws: WebSocket, ws_token: Union[str, None] = Query(..., alias="token"), api_config: Dict[str, Any] = Depends(get_api_config) @@ -72,13 +62,16 @@ async def get_ws_token( if secrets.compare_digest(secret_ws_token, ws_token): # Just return the token if it matches return ws_token - elif user := get_user_from_token(ws_token, secret_jwt_key, raise_on_error=False): - # If the token is a jwt, and it's valid return the user - return user else: - logger.info("Denying websocket request") - # If it doesn't match, close the websocket connection - await ws.close(code=status.WS_1008_POLICY_VIOLATION) + try: + user = get_user_from_token(ws_token, secret_jwt_key) + return user + # If the token is a jwt, and it's valid return the user + except HTTPException: + pass + logger.info("Denying websocket request") + # If it doesn't match, close the websocket connection + await ws.close(code=status.WS_1008_POLICY_VIOLATION) def create_token(data: dict, secret_key: str, token_type: str = "access") -> str: diff --git a/freqtrade/rpc/api_server/api_ws.py b/freqtrade/rpc/api_server/api_ws.py index 25d29a7ce..34b780956 100644 --- a/freqtrade/rpc/api_server/api_ws.py +++ b/freqtrade/rpc/api_server/api_ws.py @@ -7,7 +7,7 @@ from pydantic import ValidationError from starlette.websockets import WebSocketState from freqtrade.enums import RPCMessageType, RPCRequestType -from freqtrade.rpc.api_server.api_auth import get_ws_token +from freqtrade.rpc.api_server.api_auth import validate_ws_token from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc from freqtrade.rpc.api_server.ws.channel import WebSocketChannel from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema, @@ -96,7 +96,7 @@ async def message_endpoint( ws: WebSocket, rpc: RPC = Depends(get_rpc), channel_manager=Depends(get_channel_manager), - token: str = Depends(get_ws_token) + token: str = Depends(validate_ws_token) ): """ Message WebSocket endpoint, facilitates sending RPC messages diff --git a/tests/rpc/test_rpc_apiserver.py b/tests/rpc/test_rpc_apiserver.py index 17705e62e..f1aa465f0 100644 --- a/tests/rpc/test_rpc_apiserver.py +++ b/tests/rpc/test_rpc_apiserver.py @@ -158,12 +158,6 @@ def test_api_auth(): with pytest.raises(HTTPException): get_user_from_token(b'not_a_token', 'secret1234') - # Check returning false instead of error on bad token - assert not get_user_from_token(b'not_a_token', 'secret1234', raise_on_error=False) - - # Check returning false instead of error on bad token type - assert not get_user_from_token(token, 'secret1234', token_type='refresh', raise_on_error=False) - def test_api_ws_auth(botclient): ftbot, client = botclient