diff --git a/freqtrade/rpc/api_server2/api_auth.py b/freqtrade/rpc/api_server2/api_auth.py index d0c975480..cf0168576 100644 --- a/freqtrade/rpc/api_server2/api_auth.py +++ b/freqtrade/rpc/api_server2/api_auth.py @@ -1,12 +1,11 @@ +from datetime import datetime, timedelta import secrets -from typing import Optional from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi.security.http import HTTPBasic, HTTPBasicCredentials -from fastapi.security.utils import get_authorization_scheme_param -from fastapi_jwt_auth import AuthJWT from pydantic import BaseModel - +import jwt +from fastapi.security import OAuth2PasswordBearer from freqtrade.rpc.api_server2.api_models import AccessAndRefreshToken, AccessToken from .deps import get_config @@ -19,50 +18,72 @@ ACCESS_TOKEN_EXPIRE_MINUTES = 30 router_login = APIRouter() -class Settings(BaseModel): - # TODO: should be set as config['api_server'].get('jwt_secret_key', 'super-secret') - authjwt_secret_key: str = "secret" - - -@AuthJWT.load_config -def get_jwt_config(): - return Settings() - - def verify_auth(config, username: str, password: str): + """Verify username/password""" return (secrets.compare_digest(username, config['api_server'].get('username')) and secrets.compare_digest(password, config['api_server'].get('password'))) -class HTTPBasicOrJWTToken(HTTPBasic): - description = "Token Or Pass auth" +httpbasic = HTTPBasic(auto_error=False) +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False) - async def __call__(self, request: Request, config=Depends(get_config) # type: ignore - ) -> Optional[str]: - header_authorization: str = request.headers.get("Authorization") - header_scheme, header_param = get_authorization_scheme_param(header_authorization) - if header_scheme.lower() == 'bearer': - AuthJWT(request).jwt_required() - elif header_scheme.lower() == 'basic': - credentials: Optional[HTTPBasicCredentials] = await HTTPBasic()(request) - if credentials and verify_auth(config, credentials.username, credentials.password): - return credentials.username - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Incorrect username or password", - ) + +def get_user_from_token(token, token_type: str = "access"): + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + username: str = payload.get("sub") + if username is None: + raise credentials_exception + if payload.get("type") != token_type: + raise credentials_exception + + except jwt.PyJWTError: + raise credentials_exception + return username + + +def create_token(data: dict, token_type: str = "access") -> str: + to_encode = data.copy() + if token_type == "access": + expire = datetime.utcnow() + timedelta(minutes=15) + elif token_type == "refresh": + expire = datetime.utcnow() + timedelta(days=30) + else: + raise ValueError() + to_encode.update({ + "exp": expire, + "iat": datetime.utcnow(), + "type": token_type, + }) + encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + return encoded_jwt + + +def http_basic_or_jwt_token(form_data: HTTPBasicCredentials = Depends(httpbasic), + token: str = Depends(oauth2_scheme), config=Depends(get_config)): + if token: + return get_user_from_token(token) + elif form_data and verify_auth(config, form_data.username, form_data.password): + return form_data.username + + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect username or password", + ) @router_login.post('/token/login', response_model=AccessAndRefreshToken) def token_login(form_data: HTTPBasicCredentials = Depends(HTTPBasic()), config=Depends(get_config)): - print(form_data) - Authorize = AuthJWT() - if verify_auth(config, form_data.username, form_data.password): - token_data = form_data.username - access_token = Authorize.create_access_token(subject=token_data) - refresh_token = Authorize.create_refresh_token(subject=token_data) + token_data = {'sub': form_data.username} + access_token = create_token(token_data) + refresh_token = create_token(token_data, token_type="refresh") return { "access_token": access_token, "refresh_token": refresh_token, @@ -76,8 +97,9 @@ def token_login(form_data: HTTPBasicCredentials = Depends(HTTPBasic()), config=D @router_login.post('/token/refresh', response_model=AccessToken) -def token_refresh(Authorize: AuthJWT = Depends()): - Authorize.jwt_refresh_token_required() - - access_token = Authorize.create_access_token(subject=Authorize.get_jwt_subject()) +def token_refresh(token: str = Depends(oauth2_scheme)): + # Refresh token + u = get_user_from_token(token, 'refresh') + token_data = {'sub': u} + access_token = create_token(token_data, token_type="access") return {'access_token': access_token} diff --git a/freqtrade/rpc/api_server2/webserver.py b/freqtrade/rpc/api_server2/webserver.py index 5f2e1d6fe..755b43127 100644 --- a/freqtrade/rpc/api_server2/webserver.py +++ b/freqtrade/rpc/api_server2/webserver.py @@ -37,13 +37,13 @@ class ApiServer(RPCHandler): def configure_app(self, app: FastAPI, config): from .api_v1 import router as api_v1 from .api_v1 import router_public as api_v1_public - from .api_auth import HTTPBasicOrJWTToken, router_login + from .api_auth import http_basic_or_jwt_token, router_login app.include_router(api_v1_public, prefix="/api/v1") app.include_router(api_v1, prefix="/api/v1", - dependencies=[Depends(HTTPBasicOrJWTToken())] + dependencies=[Depends(http_basic_or_jwt_token)], ) - app.include_router(router_login, prefix="/api/v1") + app.include_router(router_login, prefix="/api/v1", tags=["auth"]) app.add_middleware( CORSMiddleware, diff --git a/requirements.txt b/requirements.txt index 7cda9e48c..4b439079b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -35,7 +35,7 @@ flask-cors==3.0.9 # API Server fastapi==0.63.0 uvicorn==0.13.2 -fastapi_jwt_auth==0.5.0 +pyjwt==1.7.1 # Support for colorized terminal output colorama==0.4.4