Implement token/login and token/refresh endpoints

This commit is contained in:
Matthias 2020-05-10 10:35:38 +02:00
parent b72997fc2b
commit 8139058fcc

View File

@ -2,11 +2,16 @@ import logging
import threading import threading
from datetime import date, datetime from datetime import date, datetime
from ipaddress import IPv4Address from ipaddress import IPv4Address
from typing import Dict, Callable, Any from typing import Any, Callable, Dict
from arrow import Arrow from arrow import Arrow
from flask import Flask, jsonify, request from flask import Flask, jsonify, request
from flask.json import JSONEncoder from flask.json import JSONEncoder
from flask_jwt_extended import (JWTManager, create_access_token,
create_refresh_token, get_jwt_identity,
jwt_refresh_token_required,
verify_jwt_in_request_optional)
from werkzeug.security import safe_str_cmp
from werkzeug.serving import make_server from werkzeug.serving import make_server
from freqtrade.__init__ import __version__ from freqtrade.__init__ import __version__
@ -38,9 +43,10 @@ class ArrowJSONEncoder(JSONEncoder):
def require_login(func: Callable[[Any, Any], Any]): def require_login(func: Callable[[Any, Any], Any]):
def func_wrapper(obj, *args, **kwargs): def func_wrapper(obj, *args, **kwargs):
verify_jwt_in_request_optional()
auth = request.authorization auth = request.authorization
if auth and obj.check_auth(auth.username, auth.password): i = get_jwt_identity()
if i or auth and obj.check_auth(auth.username, auth.password):
return func(obj, *args, **kwargs) return func(obj, *args, **kwargs)
else: else:
return jsonify({"error": "Unauthorized"}), 401 return jsonify({"error": "Unauthorized"}), 401
@ -70,8 +76,8 @@ class ApiServer(RPC):
""" """
def check_auth(self, username, password): def check_auth(self, username, password):
return (username == self._config['api_server'].get('username') and return (safe_str_cmp(username, self._config['api_server'].get('username')) and
password == self._config['api_server'].get('password')) safe_str_cmp(password, self._config['api_server'].get('password')))
def __init__(self, freqtrade) -> None: def __init__(self, freqtrade) -> None:
""" """
@ -83,6 +89,11 @@ class ApiServer(RPC):
self._config = freqtrade.config self._config = freqtrade.config
self.app = Flask(__name__) self.app = Flask(__name__)
# Setup the Flask-JWT-Extended extension
self.app.config['JWT_SECRET_KEY'] = 'super-secret' # Change this!
self.jwt = JWTManager(self.app)
self.app.json_encoder = ArrowJSONEncoder self.app.json_encoder = ArrowJSONEncoder
# Register application handling # Register application handling
@ -148,6 +159,10 @@ class ApiServer(RPC):
self.app.register_error_handler(404, self.page_not_found) self.app.register_error_handler(404, self.page_not_found)
# Actions to control the bot # Actions to control the bot
self.app.add_url_rule(f'{BASE_URI}/token/login', 'login',
view_func=self._login, methods=['POST'])
self.app.add_url_rule(f'{BASE_URI}/token/refresh', 'token_refresh',
view_func=self._refresh_token, methods=['POST'])
self.app.add_url_rule(f'{BASE_URI}/start', 'start', self.app.add_url_rule(f'{BASE_URI}/start', 'start',
view_func=self._start, methods=['POST']) view_func=self._start, methods=['POST'])
self.app.add_url_rule(f'{BASE_URI}/stop', 'stop', view_func=self._stop, methods=['POST']) self.app.add_url_rule(f'{BASE_URI}/stop', 'stop', view_func=self._stop, methods=['POST'])
@ -199,6 +214,37 @@ class ApiServer(RPC):
'code': 404 'code': 404
}), 404 }), 404
@require_login
@rpc_catch_errors
def _login(self):
"""
Handler for /token/login
Returns a JWT token
"""
auth = request.authorization
if auth and self.check_auth(auth.username, auth.password):
keystuff = {'u': auth.username}
ret = {
'access_token': create_access_token(identity=keystuff),
'refresh_token': create_refresh_token(identity=keystuff),
}
return self.rest_dump(ret)
return jsonify({"error": "Unauthorized"}), 401
@jwt_refresh_token_required
@rpc_catch_errors
def _refresh_token(self):
"""
Handler for /token/refresh
Returns a JWT token based on a JWT refresh token
"""
current_user = get_jwt_identity()
new_token = create_access_token(identity=current_user, fresh=False)
ret = {'access_token': new_token}
return self.rest_dump(ret)
@require_login @require_login
@rpc_catch_errors @rpc_catch_errors
def _start(self): def _start(self):