support jwt token in place of ws token
This commit is contained in:
parent
09679cc798
commit
6cbc03a96a
@ -4,7 +4,7 @@ from datetime import datetime, timedelta
|
|||||||
from typing import Any, Dict, Union
|
from typing import Any, Dict, Union
|
||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
from fastapi import APIRouter, Depends, HTTPException, WebSocket, status
|
from fastapi import APIRouter, Depends, HTTPException, Query, WebSocket, status
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
from fastapi.security.http import HTTPBasic, HTTPBasicCredentials
|
from fastapi.security.http import HTTPBasic, HTTPBasicCredentials
|
||||||
|
|
||||||
@ -29,7 +29,8 @@ httpbasic = HTTPBasic(auto_error=False)
|
|||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
|
||||||
|
|
||||||
|
|
||||||
def get_user_from_token(token, secret_key: str, token_type: str = "access"):
|
def get_user_from_token(token, secret_key: str, token_type: str = "access",
|
||||||
|
raise_on_error: bool = True) -> Union[bool, str]:
|
||||||
credentials_exception = HTTPException(
|
credentials_exception = HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="Could not validate credentials",
|
detail="Could not validate credentials",
|
||||||
@ -39,12 +40,21 @@ def get_user_from_token(token, secret_key: str, token_type: str = "access"):
|
|||||||
payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM])
|
payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM])
|
||||||
username: str = payload.get("identity", {}).get('u')
|
username: str = payload.get("identity", {}).get('u')
|
||||||
if username is None:
|
if username is None:
|
||||||
raise credentials_exception
|
if raise_on_error:
|
||||||
|
raise credentials_exception
|
||||||
|
else:
|
||||||
|
return False
|
||||||
if payload.get("type") != token_type:
|
if payload.get("type") != token_type:
|
||||||
raise credentials_exception
|
if raise_on_error:
|
||||||
|
raise credentials_exception
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
except jwt.PyJWTError:
|
except jwt.PyJWTError:
|
||||||
raise credentials_exception
|
if raise_on_error:
|
||||||
|
raise credentials_exception
|
||||||
|
else:
|
||||||
|
return False
|
||||||
return username
|
return username
|
||||||
|
|
||||||
|
|
||||||
@ -53,14 +63,18 @@ def get_user_from_token(token, secret_key: str, token_type: str = "access"):
|
|||||||
# https://github.com/tiangolo/fastapi/blob/master/fastapi/security/api_key.py
|
# https://github.com/tiangolo/fastapi/blob/master/fastapi/security/api_key.py
|
||||||
async def get_ws_token(
|
async def get_ws_token(
|
||||||
ws: WebSocket,
|
ws: WebSocket,
|
||||||
token: Union[str, None] = None,
|
ws_token: Union[str, None] = Query(..., alias="token"),
|
||||||
api_config: Dict[str, Any] = Depends(get_api_config)
|
api_config: Dict[str, Any] = Depends(get_api_config)
|
||||||
):
|
):
|
||||||
secret_ws_token = api_config['ws_token']
|
secret_ws_token = api_config.get('ws_token', 'secret_ws_t0ken.')
|
||||||
|
secret_jwt_key = api_config.get('jwt_secret_key', 'super-secret')
|
||||||
|
|
||||||
if token == secret_ws_token:
|
if secrets.compare_digest(secret_ws_token, ws_token):
|
||||||
# Just return the token if it matches
|
# Just return the token if it matches
|
||||||
return token
|
return ws_token
|
||||||
|
elif user := get_user_from_token(ws_token, secret_jwt_key, raise_on_error=False):
|
||||||
|
# If the token is a jwt, and it's valid return the user
|
||||||
|
return user
|
||||||
else:
|
else:
|
||||||
logger.info("Denying websocket request")
|
logger.info("Denying websocket request")
|
||||||
# If it doesn't match, close the websocket connection
|
# If it doesn't match, close the websocket connection
|
||||||
|
@ -132,6 +132,10 @@ async def message_endpoint(
|
|||||||
else:
|
else:
|
||||||
await ws.close()
|
await ws.close()
|
||||||
|
|
||||||
|
except RuntimeError:
|
||||||
|
# WebSocket was closed
|
||||||
|
await channel_manager.on_disconnect(ws)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to serve - {ws.client}")
|
logger.error(f"Failed to serve - {ws.client}")
|
||||||
# Log tracebacks to keep track of what errors are happening
|
# Log tracebacks to keep track of what errors are happening
|
||||||
|
@ -161,18 +161,20 @@ def test_api_auth():
|
|||||||
|
|
||||||
def test_api_ws_auth(botclient):
|
def test_api_ws_auth(botclient):
|
||||||
ftbot, client = botclient
|
ftbot, client = botclient
|
||||||
|
def url(token): return f"/api/v1/message/ws?token={token}"
|
||||||
|
|
||||||
bad_token = "bad-ws_token"
|
bad_token = "bad-ws_token"
|
||||||
url = f"/api/v1/message/ws?token={bad_token}"
|
|
||||||
|
|
||||||
with pytest.raises(WebSocketDisconnect):
|
with pytest.raises(WebSocketDisconnect):
|
||||||
with client.websocket_connect(url) as websocket:
|
with client.websocket_connect(url(bad_token)) as websocket:
|
||||||
websocket.receive()
|
websocket.receive()
|
||||||
|
|
||||||
good_token = _TEST_WS_TOKEN
|
good_token = _TEST_WS_TOKEN
|
||||||
url = f"/api/v1/message/ws?token={good_token}"
|
with client.websocket_connect(url(good_token)) as websocket:
|
||||||
|
pass
|
||||||
|
|
||||||
with client.websocket_connect(url) as websocket:
|
jwt_secret = ftbot.config['api_server'].get('jwt_secret_key', 'super-secret')
|
||||||
|
jwt_token = create_token({'identity': {'u': 'Freqtrade'}}, jwt_secret)
|
||||||
|
with client.websocket_connect(url(jwt_token)) as websocket:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user