support jwt token in place of ws token
This commit is contained in:
parent
09679cc798
commit
6cbc03a96a
@ -4,7 +4,7 @@ from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import jwt
|
||||
from fastapi import APIRouter, Depends, HTTPException, WebSocket, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, WebSocket, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from fastapi.security.http import HTTPBasic, HTTPBasicCredentials
|
||||
|
||||
@ -29,7 +29,8 @@ httpbasic = HTTPBasic(auto_error=False)
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
|
||||
|
||||
|
||||
def get_user_from_token(token, secret_key: str, token_type: str = "access"):
|
||||
def get_user_from_token(token, secret_key: str, token_type: str = "access",
|
||||
raise_on_error: bool = True) -> Union[bool, str]:
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
@ -39,12 +40,21 @@ def get_user_from_token(token, secret_key: str, token_type: str = "access"):
|
||||
payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("identity", {}).get('u')
|
||||
if username is None:
|
||||
if raise_on_error:
|
||||
raise credentials_exception
|
||||
else:
|
||||
return False
|
||||
if payload.get("type") != token_type:
|
||||
if raise_on_error:
|
||||
raise credentials_exception
|
||||
else:
|
||||
return False
|
||||
|
||||
except jwt.PyJWTError:
|
||||
if raise_on_error:
|
||||
raise credentials_exception
|
||||
else:
|
||||
return False
|
||||
return username
|
||||
|
||||
|
||||
@ -53,14 +63,18 @@ def get_user_from_token(token, secret_key: str, token_type: str = "access"):
|
||||
# https://github.com/tiangolo/fastapi/blob/master/fastapi/security/api_key.py
|
||||
async def get_ws_token(
|
||||
ws: WebSocket,
|
||||
token: Union[str, None] = None,
|
||||
ws_token: Union[str, None] = Query(..., alias="token"),
|
||||
api_config: Dict[str, Any] = Depends(get_api_config)
|
||||
):
|
||||
secret_ws_token = api_config['ws_token']
|
||||
secret_ws_token = api_config.get('ws_token', 'secret_ws_t0ken.')
|
||||
secret_jwt_key = api_config.get('jwt_secret_key', 'super-secret')
|
||||
|
||||
if token == secret_ws_token:
|
||||
if secrets.compare_digest(secret_ws_token, ws_token):
|
||||
# Just return the token if it matches
|
||||
return token
|
||||
return ws_token
|
||||
elif user := get_user_from_token(ws_token, secret_jwt_key, raise_on_error=False):
|
||||
# If the token is a jwt, and it's valid return the user
|
||||
return user
|
||||
else:
|
||||
logger.info("Denying websocket request")
|
||||
# If it doesn't match, close the websocket connection
|
||||
|
@ -132,6 +132,10 @@ async def message_endpoint(
|
||||
else:
|
||||
await ws.close()
|
||||
|
||||
except RuntimeError:
|
||||
# WebSocket was closed
|
||||
await channel_manager.on_disconnect(ws)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to serve - {ws.client}")
|
||||
# Log tracebacks to keep track of what errors are happening
|
||||
|
@ -161,18 +161,20 @@ def test_api_auth():
|
||||
|
||||
def test_api_ws_auth(botclient):
|
||||
ftbot, client = botclient
|
||||
def url(token): return f"/api/v1/message/ws?token={token}"
|
||||
|
||||
bad_token = "bad-ws_token"
|
||||
url = f"/api/v1/message/ws?token={bad_token}"
|
||||
|
||||
with pytest.raises(WebSocketDisconnect):
|
||||
with client.websocket_connect(url) as websocket:
|
||||
with client.websocket_connect(url(bad_token)) as websocket:
|
||||
websocket.receive()
|
||||
|
||||
good_token = _TEST_WS_TOKEN
|
||||
url = f"/api/v1/message/ws?token={good_token}"
|
||||
with client.websocket_connect(url(good_token)) as websocket:
|
||||
pass
|
||||
|
||||
with client.websocket_connect(url) as websocket:
|
||||
jwt_secret = ftbot.config['api_server'].get('jwt_secret_key', 'super-secret')
|
||||
jwt_token = create_token({'identity': {'u': 'Freqtrade'}}, jwt_secret)
|
||||
with client.websocket_connect(url(jwt_token)) as websocket:
|
||||
pass
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user