Improve logic for token validation

This commit is contained in:
Matthias 2022-09-10 14:19:11 +02:00
parent 2f6a61521f
commit b344f78d00
3 changed files with 16 additions and 29 deletions

View File

@ -29,8 +29,7 @@ httpbasic = HTTPBasic(auto_error=False)
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False) oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
def get_user_from_token(token, secret_key: str, token_type: str = "access", def get_user_from_token(token, secret_key: str, token_type: str = "access") -> str:
raise_on_error: bool = True) -> Union[bool, str]:
credentials_exception = HTTPException( credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials", detail="Could not validate credentials",
@ -40,28 +39,19 @@ def get_user_from_token(token, secret_key: str, token_type: str = "access",
payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM]) payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM])
username: str = payload.get("identity", {}).get('u') username: str = payload.get("identity", {}).get('u')
if username is None: if username is None:
if raise_on_error:
raise credentials_exception raise credentials_exception
else:
return False
if payload.get("type") != token_type: if payload.get("type") != token_type:
if raise_on_error:
raise credentials_exception raise credentials_exception
else:
return False
except jwt.PyJWTError: except jwt.PyJWTError:
if raise_on_error:
raise credentials_exception raise credentials_exception
else:
return False
return username return username
# This should be reimplemented to better realign with the existing tools provided # This should be reimplemented to better realign with the existing tools provided
# by FastAPI regarding API Tokens # by FastAPI regarding API Tokens
# https://github.com/tiangolo/fastapi/blob/master/fastapi/security/api_key.py # https://github.com/tiangolo/fastapi/blob/master/fastapi/security/api_key.py
async def get_ws_token( async def validate_ws_token(
ws: WebSocket, ws: WebSocket,
ws_token: Union[str, None] = Query(..., alias="token"), ws_token: Union[str, None] = Query(..., alias="token"),
api_config: Dict[str, Any] = Depends(get_api_config) api_config: Dict[str, Any] = Depends(get_api_config)
@ -72,10 +62,13 @@ async def get_ws_token(
if secrets.compare_digest(secret_ws_token, ws_token): if secrets.compare_digest(secret_ws_token, ws_token):
# Just return the token if it matches # Just return the token if it matches
return ws_token return ws_token
elif user := get_user_from_token(ws_token, secret_jwt_key, raise_on_error=False):
# If the token is a jwt, and it's valid return the user
return user
else: else:
try:
user = get_user_from_token(ws_token, secret_jwt_key)
return user
# If the token is a jwt, and it's valid return the user
except HTTPException:
pass
logger.info("Denying websocket request") logger.info("Denying websocket request")
# If it doesn't match, close the websocket connection # If it doesn't match, close the websocket connection
await ws.close(code=status.WS_1008_POLICY_VIOLATION) await ws.close(code=status.WS_1008_POLICY_VIOLATION)

View File

@ -7,7 +7,7 @@ from pydantic import ValidationError
from starlette.websockets import WebSocketState from starlette.websockets import WebSocketState
from freqtrade.enums import RPCMessageType, RPCRequestType from freqtrade.enums import RPCMessageType, RPCRequestType
from freqtrade.rpc.api_server.api_auth import get_ws_token from freqtrade.rpc.api_server.api_auth import validate_ws_token
from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc
from freqtrade.rpc.api_server.ws.channel import WebSocketChannel from freqtrade.rpc.api_server.ws.channel import WebSocketChannel
from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema, from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema,
@ -96,7 +96,7 @@ async def message_endpoint(
ws: WebSocket, ws: WebSocket,
rpc: RPC = Depends(get_rpc), rpc: RPC = Depends(get_rpc),
channel_manager=Depends(get_channel_manager), channel_manager=Depends(get_channel_manager),
token: str = Depends(get_ws_token) token: str = Depends(validate_ws_token)
): ):
""" """
Message WebSocket endpoint, facilitates sending RPC messages Message WebSocket endpoint, facilitates sending RPC messages

View File

@ -158,12 +158,6 @@ def test_api_auth():
with pytest.raises(HTTPException): with pytest.raises(HTTPException):
get_user_from_token(b'not_a_token', 'secret1234') get_user_from_token(b'not_a_token', 'secret1234')
# Check returning false instead of error on bad token
assert not get_user_from_token(b'not_a_token', 'secret1234', raise_on_error=False)
# Check returning false instead of error on bad token type
assert not get_user_from_token(token, 'secret1234', token_type='refresh', raise_on_error=False)
def test_api_ws_auth(botclient): def test_api_ws_auth(botclient):
ftbot, client = botclient ftbot, client = botclient