support jwt token in place of ws token

This commit is contained in:
Timothy Pogue 2022-09-09 11:38:42 -06:00
parent 09679cc798
commit 6cbc03a96a
3 changed files with 34 additions and 14 deletions

View File

@ -4,7 +4,7 @@ from datetime import datetime, timedelta
from typing import Any, Dict, Union from typing import Any, Dict, Union
import jwt import jwt
from fastapi import APIRouter, Depends, HTTPException, WebSocket, status from fastapi import APIRouter, Depends, HTTPException, Query, WebSocket, status
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from fastapi.security.http import HTTPBasic, HTTPBasicCredentials from fastapi.security.http import HTTPBasic, HTTPBasicCredentials
@ -29,7 +29,8 @@ httpbasic = HTTPBasic(auto_error=False)
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False) oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
def get_user_from_token(token, secret_key: str, token_type: str = "access"): def get_user_from_token(token, secret_key: str, token_type: str = "access",
raise_on_error: bool = True) -> Union[bool, str]:
credentials_exception = HTTPException( credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials", detail="Could not validate credentials",
@ -39,12 +40,21 @@ def get_user_from_token(token, secret_key: str, token_type: str = "access"):
payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM]) payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM])
username: str = payload.get("identity", {}).get('u') username: str = payload.get("identity", {}).get('u')
if username is None: if username is None:
raise credentials_exception if raise_on_error:
raise credentials_exception
else:
return False
if payload.get("type") != token_type: if payload.get("type") != token_type:
raise credentials_exception if raise_on_error:
raise credentials_exception
else:
return False
except jwt.PyJWTError: except jwt.PyJWTError:
raise credentials_exception if raise_on_error:
raise credentials_exception
else:
return False
return username return username
@ -53,14 +63,18 @@ def get_user_from_token(token, secret_key: str, token_type: str = "access"):
# https://github.com/tiangolo/fastapi/blob/master/fastapi/security/api_key.py # https://github.com/tiangolo/fastapi/blob/master/fastapi/security/api_key.py
async def get_ws_token( async def get_ws_token(
ws: WebSocket, ws: WebSocket,
token: Union[str, None] = None, ws_token: Union[str, None] = Query(..., alias="token"),
api_config: Dict[str, Any] = Depends(get_api_config) api_config: Dict[str, Any] = Depends(get_api_config)
): ):
secret_ws_token = api_config['ws_token'] secret_ws_token = api_config.get('ws_token', 'secret_ws_t0ken.')
secret_jwt_key = api_config.get('jwt_secret_key', 'super-secret')
if token == secret_ws_token: if secrets.compare_digest(secret_ws_token, ws_token):
# Just return the token if it matches # Just return the token if it matches
return token return ws_token
elif user := get_user_from_token(ws_token, secret_jwt_key, raise_on_error=False):
# If the token is a jwt, and it's valid return the user
return user
else: else:
logger.info("Denying websocket request") logger.info("Denying websocket request")
# If it doesn't match, close the websocket connection # If it doesn't match, close the websocket connection

View File

@ -132,6 +132,10 @@ async def message_endpoint(
else: else:
await ws.close() await ws.close()
except RuntimeError:
# WebSocket was closed
await channel_manager.on_disconnect(ws)
except Exception as e: except Exception as e:
logger.error(f"Failed to serve - {ws.client}") logger.error(f"Failed to serve - {ws.client}")
# Log tracebacks to keep track of what errors are happening # Log tracebacks to keep track of what errors are happening

View File

@ -161,18 +161,20 @@ def test_api_auth():
def test_api_ws_auth(botclient): def test_api_ws_auth(botclient):
ftbot, client = botclient ftbot, client = botclient
def url(token): return f"/api/v1/message/ws?token={token}"
bad_token = "bad-ws_token" bad_token = "bad-ws_token"
url = f"/api/v1/message/ws?token={bad_token}"
with pytest.raises(WebSocketDisconnect): with pytest.raises(WebSocketDisconnect):
with client.websocket_connect(url) as websocket: with client.websocket_connect(url(bad_token)) as websocket:
websocket.receive() websocket.receive()
good_token = _TEST_WS_TOKEN good_token = _TEST_WS_TOKEN
url = f"/api/v1/message/ws?token={good_token}" with client.websocket_connect(url(good_token)) as websocket:
pass
with client.websocket_connect(url) as websocket: jwt_secret = ftbot.config['api_server'].get('jwt_secret_key', 'super-secret')
jwt_token = create_token({'identity': {'u': 'Freqtrade'}}, jwt_secret)
with client.websocket_connect(url(jwt_token)) as websocket:
pass pass