Fix auth providers

This commit is contained in:
Matthias 2020-12-26 08:48:15 +01:00
parent 86d0700884
commit 5e4c4cae06
3 changed files with 66 additions and 44 deletions

View File

@ -1,12 +1,11 @@
from datetime import datetime, timedelta
import secrets import secrets
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.security.http import HTTPBasic, HTTPBasicCredentials from fastapi.security.http import HTTPBasic, HTTPBasicCredentials
from fastapi.security.utils import get_authorization_scheme_param
from fastapi_jwt_auth import AuthJWT
from pydantic import BaseModel from pydantic import BaseModel
import jwt
from fastapi.security import OAuth2PasswordBearer
from freqtrade.rpc.api_server2.api_models import AccessAndRefreshToken, AccessToken from freqtrade.rpc.api_server2.api_models import AccessAndRefreshToken, AccessToken
from .deps import get_config from .deps import get_config
@ -19,34 +18,59 @@ ACCESS_TOKEN_EXPIRE_MINUTES = 30
router_login = APIRouter() router_login = APIRouter()
class Settings(BaseModel):
# TODO: should be set as config['api_server'].get('jwt_secret_key', 'super-secret')
authjwt_secret_key: str = "secret"
@AuthJWT.load_config
def get_jwt_config():
return Settings()
def verify_auth(config, username: str, password: str): def verify_auth(config, username: str, password: str):
"""Verify username/password"""
return (secrets.compare_digest(username, config['api_server'].get('username')) and return (secrets.compare_digest(username, config['api_server'].get('username')) and
secrets.compare_digest(password, config['api_server'].get('password'))) secrets.compare_digest(password, config['api_server'].get('password')))
class HTTPBasicOrJWTToken(HTTPBasic): httpbasic = HTTPBasic(auto_error=False)
description = "Token Or Pass auth" oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
def get_user_from_token(token, token_type: str = "access"):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise credentials_exception
if payload.get("type") != token_type:
raise credentials_exception
except jwt.PyJWTError:
raise credentials_exception
return username
def create_token(data: dict, token_type: str = "access") -> str:
to_encode = data.copy()
if token_type == "access":
expire = datetime.utcnow() + timedelta(minutes=15)
elif token_type == "refresh":
expire = datetime.utcnow() + timedelta(days=30)
else:
raise ValueError()
to_encode.update({
"exp": expire,
"iat": datetime.utcnow(),
"type": token_type,
})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
def http_basic_or_jwt_token(form_data: HTTPBasicCredentials = Depends(httpbasic),
token: str = Depends(oauth2_scheme), config=Depends(get_config)):
if token:
return get_user_from_token(token)
elif form_data and verify_auth(config, form_data.username, form_data.password):
return form_data.username
async def __call__(self, request: Request, config=Depends(get_config) # type: ignore
) -> Optional[str]:
header_authorization: str = request.headers.get("Authorization")
header_scheme, header_param = get_authorization_scheme_param(header_authorization)
if header_scheme.lower() == 'bearer':
AuthJWT(request).jwt_required()
elif header_scheme.lower() == 'basic':
credentials: Optional[HTTPBasicCredentials] = await HTTPBasic()(request)
if credentials and verify_auth(config, credentials.username, credentials.password):
return credentials.username
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password", detail="Incorrect username or password",
@ -56,13 +80,10 @@ class HTTPBasicOrJWTToken(HTTPBasic):
@router_login.post('/token/login', response_model=AccessAndRefreshToken) @router_login.post('/token/login', response_model=AccessAndRefreshToken)
def token_login(form_data: HTTPBasicCredentials = Depends(HTTPBasic()), config=Depends(get_config)): def token_login(form_data: HTTPBasicCredentials = Depends(HTTPBasic()), config=Depends(get_config)):
print(form_data)
Authorize = AuthJWT()
if verify_auth(config, form_data.username, form_data.password): if verify_auth(config, form_data.username, form_data.password):
token_data = form_data.username token_data = {'sub': form_data.username}
access_token = Authorize.create_access_token(subject=token_data) access_token = create_token(token_data)
refresh_token = Authorize.create_refresh_token(subject=token_data) refresh_token = create_token(token_data, token_type="refresh")
return { return {
"access_token": access_token, "access_token": access_token,
"refresh_token": refresh_token, "refresh_token": refresh_token,
@ -76,8 +97,9 @@ def token_login(form_data: HTTPBasicCredentials = Depends(HTTPBasic()), config=D
@router_login.post('/token/refresh', response_model=AccessToken) @router_login.post('/token/refresh', response_model=AccessToken)
def token_refresh(Authorize: AuthJWT = Depends()): def token_refresh(token: str = Depends(oauth2_scheme)):
Authorize.jwt_refresh_token_required() # Refresh token
u = get_user_from_token(token, 'refresh')
access_token = Authorize.create_access_token(subject=Authorize.get_jwt_subject()) token_data = {'sub': u}
access_token = create_token(token_data, token_type="access")
return {'access_token': access_token} return {'access_token': access_token}

View File

@ -37,13 +37,13 @@ class ApiServer(RPCHandler):
def configure_app(self, app: FastAPI, config): def configure_app(self, app: FastAPI, config):
from .api_v1 import router as api_v1 from .api_v1 import router as api_v1
from .api_v1 import router_public as api_v1_public from .api_v1 import router_public as api_v1_public
from .api_auth import HTTPBasicOrJWTToken, router_login from .api_auth import http_basic_or_jwt_token, router_login
app.include_router(api_v1_public, prefix="/api/v1") app.include_router(api_v1_public, prefix="/api/v1")
app.include_router(api_v1, prefix="/api/v1", app.include_router(api_v1, prefix="/api/v1",
dependencies=[Depends(HTTPBasicOrJWTToken())] dependencies=[Depends(http_basic_or_jwt_token)],
) )
app.include_router(router_login, prefix="/api/v1") app.include_router(router_login, prefix="/api/v1", tags=["auth"])
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,

View File

@ -35,7 +35,7 @@ flask-cors==3.0.9
# API Server # API Server
fastapi==0.63.0 fastapi==0.63.0
uvicorn==0.13.2 uvicorn==0.13.2
fastapi_jwt_auth==0.5.0 pyjwt==1.7.1
# Support for colorized terminal output # Support for colorized terminal output
colorama==0.4.4 colorama==0.4.4