Move decorators out of API Class

This commit is contained in:
Matthias 2019-10-21 19:43:44 +02:00
parent 8872158a6a
commit a43d436f98

View File

@ -2,7 +2,7 @@ import logging
import threading import threading
from datetime import date, datetime from datetime import date, datetime
from ipaddress import IPv4Address from ipaddress import IPv4Address
from typing import Dict from typing import Dict, Callable, Any
from arrow import Arrow from arrow import Arrow
from flask import Flask, jsonify, request from flask import Flask, jsonify, request
@ -34,6 +34,34 @@ class ArrowJSONEncoder(JSONEncoder):
return JSONEncoder.default(self, obj) return JSONEncoder.default(self, obj)
# Type should really be Callable[[ApiServer, Any], Any], but that will create a circular dependency
def require_login(func: Callable[[Any, Any], Any]):
def func_wrapper(obj, *args, **kwargs):
auth = request.authorization
if auth and obj.check_auth(auth.username, auth.password):
return func(obj, *args, **kwargs)
else:
return jsonify({"error": "Unauthorized"}), 401
return func_wrapper
# Type should really be Callable[[ApiServer], Any], but that will create a circular dependency
def rpc_catch_errors(func: Callable[[Any], Any]):
def func_wrapper(obj, *args, **kwargs):
try:
return func(obj, *args, **kwargs)
except RPCException as e:
logger.exception("API Error calling %s: %s", func.__name__, e)
return obj.rest_error(f"Error querying {func.__name__}: {e}")
return func_wrapper
class ApiServer(RPC): class ApiServer(RPC):
""" """
This class runs api server and provides rpc.rpc functionality to it This class runs api server and provides rpc.rpc functionality to it
@ -41,34 +69,10 @@ class ApiServer(RPC):
This class starts a none blocking thread the api server runs within This class starts a none blocking thread the api server runs within
""" """
def rpc_catch_errors(func):
def func_wrapper(self, *args, **kwargs):
try:
return func(self, *args, **kwargs)
except RPCException as e:
logger.exception("API Error calling %s: %s", func.__name__, e)
return self.rest_error(f"Error querying {func.__name__}: {e}")
return func_wrapper
def check_auth(self, username, password): def check_auth(self, username, password):
return (username == self._config['api_server'].get('username') and return (username == self._config['api_server'].get('username') and
password == self._config['api_server'].get('password')) password == self._config['api_server'].get('password'))
def require_login(func):
def func_wrapper(self, *args, **kwargs):
auth = request.authorization
if auth and self.check_auth(auth.username, auth.password):
return func(self, *args, **kwargs)
else:
return jsonify({"error": "Unauthorized"}), 401
return func_wrapper
def __init__(self, freqtrade) -> None: def __init__(self, freqtrade) -> None:
""" """
Init the api server, and init the super class RPC Init the api server, and init the super class RPC