Improve logic for token validation
This commit is contained in:
parent
2f6a61521f
commit
b344f78d00
@ -29,8 +29,7 @@ httpbasic = HTTPBasic(auto_error=False)
|
|||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
|
||||||
|
|
||||||
|
|
||||||
def get_user_from_token(token, secret_key: str, token_type: str = "access",
|
def get_user_from_token(token, secret_key: str, token_type: str = "access") -> str:
|
||||||
raise_on_error: bool = True) -> Union[bool, str]:
|
|
||||||
credentials_exception = HTTPException(
|
credentials_exception = HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="Could not validate credentials",
|
detail="Could not validate credentials",
|
||||||
@ -40,28 +39,19 @@ def get_user_from_token(token, secret_key: str, token_type: str = "access",
|
|||||||
payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM])
|
payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM])
|
||||||
username: str = payload.get("identity", {}).get('u')
|
username: str = payload.get("identity", {}).get('u')
|
||||||
if username is None:
|
if username is None:
|
||||||
if raise_on_error:
|
|
||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
else:
|
|
||||||
return False
|
|
||||||
if payload.get("type") != token_type:
|
if payload.get("type") != token_type:
|
||||||
if raise_on_error:
|
|
||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
except jwt.PyJWTError:
|
except jwt.PyJWTError:
|
||||||
if raise_on_error:
|
|
||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
else:
|
|
||||||
return False
|
|
||||||
return username
|
return username
|
||||||
|
|
||||||
|
|
||||||
# This should be reimplemented to better realign with the existing tools provided
|
# This should be reimplemented to better realign with the existing tools provided
|
||||||
# by FastAPI regarding API Tokens
|
# by FastAPI regarding API Tokens
|
||||||
# https://github.com/tiangolo/fastapi/blob/master/fastapi/security/api_key.py
|
# https://github.com/tiangolo/fastapi/blob/master/fastapi/security/api_key.py
|
||||||
async def get_ws_token(
|
async def validate_ws_token(
|
||||||
ws: WebSocket,
|
ws: WebSocket,
|
||||||
ws_token: Union[str, None] = Query(..., alias="token"),
|
ws_token: Union[str, None] = Query(..., alias="token"),
|
||||||
api_config: Dict[str, Any] = Depends(get_api_config)
|
api_config: Dict[str, Any] = Depends(get_api_config)
|
||||||
@ -72,10 +62,13 @@ async def get_ws_token(
|
|||||||
if secrets.compare_digest(secret_ws_token, ws_token):
|
if secrets.compare_digest(secret_ws_token, ws_token):
|
||||||
# Just return the token if it matches
|
# Just return the token if it matches
|
||||||
return ws_token
|
return ws_token
|
||||||
elif user := get_user_from_token(ws_token, secret_jwt_key, raise_on_error=False):
|
|
||||||
# If the token is a jwt, and it's valid return the user
|
|
||||||
return user
|
|
||||||
else:
|
else:
|
||||||
|
try:
|
||||||
|
user = get_user_from_token(ws_token, secret_jwt_key)
|
||||||
|
return user
|
||||||
|
# If the token is a jwt, and it's valid return the user
|
||||||
|
except HTTPException:
|
||||||
|
pass
|
||||||
logger.info("Denying websocket request")
|
logger.info("Denying websocket request")
|
||||||
# If it doesn't match, close the websocket connection
|
# If it doesn't match, close the websocket connection
|
||||||
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
|
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||||
|
@ -7,7 +7,7 @@ from pydantic import ValidationError
|
|||||||
from starlette.websockets import WebSocketState
|
from starlette.websockets import WebSocketState
|
||||||
|
|
||||||
from freqtrade.enums import RPCMessageType, RPCRequestType
|
from freqtrade.enums import RPCMessageType, RPCRequestType
|
||||||
from freqtrade.rpc.api_server.api_auth import get_ws_token
|
from freqtrade.rpc.api_server.api_auth import validate_ws_token
|
||||||
from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc
|
from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc
|
||||||
from freqtrade.rpc.api_server.ws.channel import WebSocketChannel
|
from freqtrade.rpc.api_server.ws.channel import WebSocketChannel
|
||||||
from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema,
|
from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema,
|
||||||
@ -96,7 +96,7 @@ async def message_endpoint(
|
|||||||
ws: WebSocket,
|
ws: WebSocket,
|
||||||
rpc: RPC = Depends(get_rpc),
|
rpc: RPC = Depends(get_rpc),
|
||||||
channel_manager=Depends(get_channel_manager),
|
channel_manager=Depends(get_channel_manager),
|
||||||
token: str = Depends(get_ws_token)
|
token: str = Depends(validate_ws_token)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Message WebSocket endpoint, facilitates sending RPC messages
|
Message WebSocket endpoint, facilitates sending RPC messages
|
||||||
|
@ -158,12 +158,6 @@ def test_api_auth():
|
|||||||
with pytest.raises(HTTPException):
|
with pytest.raises(HTTPException):
|
||||||
get_user_from_token(b'not_a_token', 'secret1234')
|
get_user_from_token(b'not_a_token', 'secret1234')
|
||||||
|
|
||||||
# Check returning false instead of error on bad token
|
|
||||||
assert not get_user_from_token(b'not_a_token', 'secret1234', raise_on_error=False)
|
|
||||||
|
|
||||||
# Check returning false instead of error on bad token type
|
|
||||||
assert not get_user_from_token(token, 'secret1234', token_type='refresh', raise_on_error=False)
|
|
||||||
|
|
||||||
|
|
||||||
def test_api_ws_auth(botclient):
|
def test_api_ws_auth(botclient):
|
||||||
ftbot, client = botclient
|
ftbot, client = botclient
|
||||||
|
Loading…
Reference in New Issue
Block a user