support jwt token in place of ws token

This commit is contained in:
Timothy Pogue 2022-09-09 11:38:42 -06:00
parent 09679cc798
commit 6cbc03a96a
3 changed files with 34 additions and 14 deletions

View File

@ -4,7 +4,7 @@ from datetime import datetime, timedelta
from typing import Any, Dict, Union
import jwt
from fastapi import APIRouter, Depends, HTTPException, WebSocket, status
from fastapi import APIRouter, Depends, HTTPException, Query, WebSocket, status
from fastapi.security import OAuth2PasswordBearer
from fastapi.security.http import HTTPBasic, HTTPBasicCredentials
@ -29,7 +29,8 @@ httpbasic = HTTPBasic(auto_error=False)
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
def get_user_from_token(token, secret_key: str, token_type: str = "access"):
def get_user_from_token(token, secret_key: str, token_type: str = "access",
raise_on_error: bool = True) -> Union[bool, str]:
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
@ -39,12 +40,21 @@ def get_user_from_token(token, secret_key: str, token_type: str = "access"):
payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM])
username: str = payload.get("identity", {}).get('u')
if username is None:
if raise_on_error:
raise credentials_exception
else:
return False
if payload.get("type") != token_type:
if raise_on_error:
raise credentials_exception
else:
return False
except jwt.PyJWTError:
if raise_on_error:
raise credentials_exception
else:
return False
return username
@ -53,14 +63,18 @@ def get_user_from_token(token, secret_key: str, token_type: str = "access"):
# https://github.com/tiangolo/fastapi/blob/master/fastapi/security/api_key.py
async def get_ws_token(
ws: WebSocket,
token: Union[str, None] = None,
ws_token: Union[str, None] = Query(..., alias="token"),
api_config: Dict[str, Any] = Depends(get_api_config)
):
secret_ws_token = api_config['ws_token']
secret_ws_token = api_config.get('ws_token', 'secret_ws_t0ken.')
secret_jwt_key = api_config.get('jwt_secret_key', 'super-secret')
if token == secret_ws_token:
if secrets.compare_digest(secret_ws_token, ws_token):
# Just return the token if it matches
return token
return ws_token
elif user := get_user_from_token(ws_token, secret_jwt_key, raise_on_error=False):
# If the token is a jwt, and it's valid return the user
return user
else:
logger.info("Denying websocket request")
# If it doesn't match, close the websocket connection

View File

@ -132,6 +132,10 @@ async def message_endpoint(
else:
await ws.close()
except RuntimeError:
# WebSocket was closed
await channel_manager.on_disconnect(ws)
except Exception as e:
logger.error(f"Failed to serve - {ws.client}")
# Log tracebacks to keep track of what errors are happening

View File

@ -161,18 +161,20 @@ def test_api_auth():
def test_api_ws_auth(botclient):
ftbot, client = botclient
def url(token): return f"/api/v1/message/ws?token={token}"
bad_token = "bad-ws_token"
url = f"/api/v1/message/ws?token={bad_token}"
with pytest.raises(WebSocketDisconnect):
with client.websocket_connect(url) as websocket:
with client.websocket_connect(url(bad_token)) as websocket:
websocket.receive()
good_token = _TEST_WS_TOKEN
url = f"/api/v1/message/ws?token={good_token}"
with client.websocket_connect(url(good_token)) as websocket:
pass
with client.websocket_connect(url) as websocket:
jwt_secret = ftbot.config['api_server'].get('jwt_secret_key', 'super-secret')
jwt_token = create_token({'identity': {'u': 'Freqtrade'}}, jwt_secret)
with client.websocket_connect(url(jwt_token)) as websocket:
pass