diff --git a/freqtrade/rpc/api_server2/auth.py b/freqtrade/rpc/api_server2/auth.py new file mode 100644 index 000000000..3155e7754 --- /dev/null +++ b/freqtrade/rpc/api_server2/auth.py @@ -0,0 +1,83 @@ +from freqtrade.rpc.api_server2.models import AccessAndRefreshToken, AccessToken +import secrets +from typing import Optional + +from fastapi import APIRouter, Depends, HTTPException, Request, status +from fastapi.security.http import HTTPBasic, HTTPBasicCredentials +from fastapi.security.utils import get_authorization_scheme_param +from fastapi_jwt_auth import AuthJWT +from pydantic import BaseModel + +from .deps import get_config + + +SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7" +ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_MINUTES = 30 + +router_login = APIRouter() + + +class Settings(BaseModel): + # TODO: should be set as config['api_server'].get('jwt_secret_key', 'super-secret') + authjwt_secret_key: str = "secret" + + +@AuthJWT.load_config +def get_jwt_config(): + return Settings() + + +def verify_auth(config, username: str, password: str): + return (secrets.compare_digest(username, config['api_server'].get('username')) and + secrets.compare_digest(password, config['api_server'].get('password'))) + + +class HTTPBasicOrJWTToken(HTTPBasic): + description = "Token Or Pass auth" + + async def __call__(self, request: Request, config=Depends(get_config) # type: ignore + ) -> Optional[str]: + header_authorization: str = request.headers.get("Authorization") + header_scheme, header_param = get_authorization_scheme_param(header_authorization) + if header_scheme.lower() == 'bearer': + AuthJWT.jwt_required() + elif header_scheme.lower() == 'basic': + credentials: Optional[HTTPBasicCredentials] = await HTTPBasic()(request) + if credentials and verify_auth(config, credentials.username, credentials.password): + return credentials.username + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect username or password", + headers={"WWW-Authenticate": "Basic"}, + ) + + +@router_login.post('/token/login', response_model=AccessAndRefreshToken) +def token_login(form_data: HTTPBasicCredentials = Depends(HTTPBasic()), config=Depends(get_config)): + + print(form_data) + Authorize = AuthJWT() + + if verify_auth(config, form_data.username, form_data.password): + token_data = form_data.username + access_token = Authorize.create_access_token(subject=token_data) + refresh_token = Authorize.create_refresh_token(subject=token_data) + return { + "access_token": access_token, + "refresh_token": refresh_token, + } + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect username or password", + headers={"WWW-Authenticate": "Basic"}, + ) + + +@router_login.post('/token/refresh', response_model=AccessToken) +def token_refresh(Authorize: AuthJWT = Depends()): + Authorize.jwt_refresh_token_required() + + access_token = Authorize.create_access_token(subject=Authorize.get_jwt_subject()) + return {'access_token': access_token} diff --git a/freqtrade/rpc/api_server2/models.py b/freqtrade/rpc/api_server2/models.py index 7cd628da0..c9ddb0d6f 100644 --- a/freqtrade/rpc/api_server2/models.py +++ b/freqtrade/rpc/api_server2/models.py @@ -6,6 +6,14 @@ class Ping(BaseModel): status: str +class AccessToken(BaseModel): + access_token: str + + +class AccessAndRefreshToken(AccessToken): + refresh_token: str + + class Version(BaseModel): version: str diff --git a/freqtrade/rpc/api_server2/webserver.py b/freqtrade/rpc/api_server2/webserver.py index 23cae8b73..84f6fc222 100644 --- a/freqtrade/rpc/api_server2/webserver.py +++ b/freqtrade/rpc/api_server2/webserver.py @@ -1,18 +1,17 @@ -from typing import Any, Dict - -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware +from typing import Any, Dict, Optional import uvicorn +from fastapi import Depends, FastAPI +from fastapi.middleware.cors import CORSMiddleware -from freqtrade.rpc.rpc import RPCHandler, RPC +from freqtrade.rpc.rpc import RPC, RPCHandler from .uvicorn_threaded import UvicornServer class ApiServer(RPCHandler): - _rpc = None - _config = None + _rpc: Optional[RPC] = None + _config: Dict[str, Any] = {} def __init__(self, rpc: RPC, config: Dict[str, Any]) -> None: super().__init__(rpc, config) @@ -21,7 +20,7 @@ class ApiServer(RPCHandler): ApiServer._rpc = rpc ApiServer._config = config - self.app = FastAPI() + self.app = FastAPI(title="Freqtrade API") self.configure_app(self.app, self._config) self.start_api() @@ -35,12 +34,15 @@ class ApiServer(RPCHandler): pass def configure_app(self, app: FastAPI, config): - from .api_v1 import router_public as api_v1_public from .api_v1 import router as api_v1 + from .api_v1 import router_public as api_v1_public + from .auth import router_login, HTTPBasicOrJWTToken app.include_router(api_v1_public, prefix="/api/v1") - # TODO: Include auth dependency! - app.include_router(api_v1, prefix="/api/v1") + app.include_router(api_v1, prefix="/api/v1", + dependencies=[Depends(HTTPBasicOrJWTToken())] + ) + app.include_router(router_login, prefix="/api/v1") app.add_middleware( CORSMiddleware, diff --git a/requirements.txt b/requirements.txt index ad43c1006..7cda9e48c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -35,6 +35,7 @@ flask-cors==3.0.9 # API Server fastapi==0.63.0 uvicorn==0.13.2 +fastapi_jwt_auth==0.5.0 # Support for colorized terminal output colorama==0.4.4