From a43d436f9863543fc9a010c256a8b0659f087713 Mon Sep 17 00:00:00 2001 From: Matthias Date: Mon, 21 Oct 2019 19:43:44 +0200 Subject: [PATCH] Move decorators out of API Class --- freqtrade/rpc/api_server.py | 54 ++++++++++++++++++++----------------- 1 file changed, 29 insertions(+), 25 deletions(-) diff --git a/freqtrade/rpc/api_server.py b/freqtrade/rpc/api_server.py index 711202b27..95ee51eaf 100644 --- a/freqtrade/rpc/api_server.py +++ b/freqtrade/rpc/api_server.py @@ -2,7 +2,7 @@ import logging import threading from datetime import date, datetime from ipaddress import IPv4Address -from typing import Dict +from typing import Dict, Callable, Any from arrow import Arrow from flask import Flask, jsonify, request @@ -34,6 +34,34 @@ class ArrowJSONEncoder(JSONEncoder): return JSONEncoder.default(self, obj) +# Type should really be Callable[[ApiServer, Any], Any], but that will create a circular dependency +def require_login(func: Callable[[Any, Any], Any]): + + def func_wrapper(obj, *args, **kwargs): + + auth = request.authorization + if auth and obj.check_auth(auth.username, auth.password): + return func(obj, *args, **kwargs) + else: + return jsonify({"error": "Unauthorized"}), 401 + + return func_wrapper + + +# Type should really be Callable[[ApiServer], Any], but that will create a circular dependency +def rpc_catch_errors(func: Callable[[Any], Any]): + + def func_wrapper(obj, *args, **kwargs): + + try: + return func(obj, *args, **kwargs) + except RPCException as e: + logger.exception("API Error calling %s: %s", func.__name__, e) + return obj.rest_error(f"Error querying {func.__name__}: {e}") + + return func_wrapper + + class ApiServer(RPC): """ This class runs api server and provides rpc.rpc functionality to it @@ -41,34 +69,10 @@ class ApiServer(RPC): This class starts a none blocking thread the api server runs within """ - def rpc_catch_errors(func): - - def func_wrapper(self, *args, **kwargs): - - try: - return func(self, *args, **kwargs) - except RPCException as e: - logger.exception("API Error calling %s: %s", func.__name__, e) - return self.rest_error(f"Error querying {func.__name__}: {e}") - - return func_wrapper - def check_auth(self, username, password): return (username == self._config['api_server'].get('username') and password == self._config['api_server'].get('password')) - def require_login(func): - - def func_wrapper(self, *args, **kwargs): - - auth = request.authorization - if auth and self.check_auth(auth.username, auth.password): - return func(self, *args, **kwargs) - else: - return jsonify({"error": "Unauthorized"}), 401 - - return func_wrapper - def __init__(self, freqtrade) -> None: """ Init the api server, and init the super class RPC