Implement token/login and token/refresh endpoints
This commit is contained in:
		| @@ -2,11 +2,16 @@ import logging | |||||||
| import threading | import threading | ||||||
| from datetime import date, datetime | from datetime import date, datetime | ||||||
| from ipaddress import IPv4Address | from ipaddress import IPv4Address | ||||||
| from typing import Dict, Callable, Any | from typing import Any, Callable, Dict | ||||||
|  |  | ||||||
| from arrow import Arrow | from arrow import Arrow | ||||||
| from flask import Flask, jsonify, request | from flask import Flask, jsonify, request | ||||||
| from flask.json import JSONEncoder | from flask.json import JSONEncoder | ||||||
|  | from flask_jwt_extended import (JWTManager, create_access_token, | ||||||
|  |                                 create_refresh_token, get_jwt_identity, | ||||||
|  |                                 jwt_refresh_token_required, | ||||||
|  |                                 verify_jwt_in_request_optional) | ||||||
|  | from werkzeug.security import safe_str_cmp | ||||||
| from werkzeug.serving import make_server | from werkzeug.serving import make_server | ||||||
|  |  | ||||||
| from freqtrade.__init__ import __version__ | from freqtrade.__init__ import __version__ | ||||||
| @@ -38,9 +43,10 @@ class ArrowJSONEncoder(JSONEncoder): | |||||||
| def require_login(func: Callable[[Any, Any], Any]): | def require_login(func: Callable[[Any, Any], Any]): | ||||||
|  |  | ||||||
|     def func_wrapper(obj, *args, **kwargs): |     def func_wrapper(obj, *args, **kwargs): | ||||||
|  |         verify_jwt_in_request_optional() | ||||||
|         auth = request.authorization |         auth = request.authorization | ||||||
|         if auth and obj.check_auth(auth.username, auth.password): |         i = get_jwt_identity() | ||||||
|  |         if i or auth and obj.check_auth(auth.username, auth.password): | ||||||
|             return func(obj, *args, **kwargs) |             return func(obj, *args, **kwargs) | ||||||
|         else: |         else: | ||||||
|             return jsonify({"error": "Unauthorized"}), 401 |             return jsonify({"error": "Unauthorized"}), 401 | ||||||
| @@ -70,8 +76,8 @@ class ApiServer(RPC): | |||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def check_auth(self, username, password): |     def check_auth(self, username, password): | ||||||
|         return (username == self._config['api_server'].get('username') and |         return (safe_str_cmp(username, self._config['api_server'].get('username')) and | ||||||
|                 password == self._config['api_server'].get('password')) |                 safe_str_cmp(password, self._config['api_server'].get('password'))) | ||||||
|  |  | ||||||
|     def __init__(self, freqtrade) -> None: |     def __init__(self, freqtrade) -> None: | ||||||
|         """ |         """ | ||||||
| @@ -83,6 +89,11 @@ class ApiServer(RPC): | |||||||
|  |  | ||||||
|         self._config = freqtrade.config |         self._config = freqtrade.config | ||||||
|         self.app = Flask(__name__) |         self.app = Flask(__name__) | ||||||
|  |  | ||||||
|  |         # Setup the Flask-JWT-Extended extension | ||||||
|  |         self.app.config['JWT_SECRET_KEY'] = 'super-secret'  # Change this! | ||||||
|  |  | ||||||
|  |         self.jwt = JWTManager(self.app) | ||||||
|         self.app.json_encoder = ArrowJSONEncoder |         self.app.json_encoder = ArrowJSONEncoder | ||||||
|  |  | ||||||
|         # Register application handling |         # Register application handling | ||||||
| @@ -148,6 +159,10 @@ class ApiServer(RPC): | |||||||
|         self.app.register_error_handler(404, self.page_not_found) |         self.app.register_error_handler(404, self.page_not_found) | ||||||
|  |  | ||||||
|         # Actions to control the bot |         # Actions to control the bot | ||||||
|  |         self.app.add_url_rule(f'{BASE_URI}/token/login', 'login', | ||||||
|  |                               view_func=self._login, methods=['POST']) | ||||||
|  |         self.app.add_url_rule(f'{BASE_URI}/token/refresh', 'token_refresh', | ||||||
|  |                               view_func=self._refresh_token, methods=['POST']) | ||||||
|         self.app.add_url_rule(f'{BASE_URI}/start', 'start', |         self.app.add_url_rule(f'{BASE_URI}/start', 'start', | ||||||
|                               view_func=self._start, methods=['POST']) |                               view_func=self._start, methods=['POST']) | ||||||
|         self.app.add_url_rule(f'{BASE_URI}/stop', 'stop', view_func=self._stop, methods=['POST']) |         self.app.add_url_rule(f'{BASE_URI}/stop', 'stop', view_func=self._stop, methods=['POST']) | ||||||
| @@ -199,6 +214,37 @@ class ApiServer(RPC): | |||||||
|             'code': 404 |             'code': 404 | ||||||
|         }), 404 |         }), 404 | ||||||
|  |  | ||||||
|  |     @require_login | ||||||
|  |     @rpc_catch_errors | ||||||
|  |     def _login(self): | ||||||
|  |         """ | ||||||
|  |         Handler for /token/login | ||||||
|  |         Returns a JWT token | ||||||
|  |         """ | ||||||
|  |         auth = request.authorization | ||||||
|  |         if auth and self.check_auth(auth.username, auth.password): | ||||||
|  |             keystuff = {'u': auth.username} | ||||||
|  |             ret = { | ||||||
|  |                 'access_token': create_access_token(identity=keystuff), | ||||||
|  |                 'refresh_token': create_refresh_token(identity=keystuff), | ||||||
|  |             } | ||||||
|  |             return self.rest_dump(ret) | ||||||
|  |  | ||||||
|  |         return jsonify({"error": "Unauthorized"}), 401 | ||||||
|  |  | ||||||
|  |     @jwt_refresh_token_required | ||||||
|  |     @rpc_catch_errors | ||||||
|  |     def _refresh_token(self): | ||||||
|  |         """ | ||||||
|  |         Handler for /token/refresh | ||||||
|  |         Returns a JWT token based on a JWT refresh token | ||||||
|  |         """ | ||||||
|  |         current_user = get_jwt_identity() | ||||||
|  |         new_token = create_access_token(identity=current_user, fresh=False) | ||||||
|  |  | ||||||
|  |         ret = {'access_token': new_token} | ||||||
|  |         return self.rest_dump(ret) | ||||||
|  |  | ||||||
|     @require_login |     @require_login | ||||||
|     @rpc_catch_errors |     @rpc_catch_errors | ||||||
|     def _start(self): |     def _start(self): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user