initial rework separating server and client impl

This commit is contained in:
Timothy Pogue
2022-08-29 13:41:15 -06:00
parent 8c4e68b8eb
commit 7952e0df25
25 changed files with 1329 additions and 1068 deletions

View File

@@ -1,8 +1,10 @@
import logging
import secrets
from datetime import datetime, timedelta
from typing import Any, Dict, Union
import jwt
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, HTTPException, WebSocket, status
from fastapi.security import OAuth2PasswordBearer
from fastapi.security.http import HTTPBasic, HTTPBasicCredentials
@@ -10,6 +12,8 @@ from freqtrade.rpc.api_server.api_schemas import AccessAndRefreshToken, AccessTo
from freqtrade.rpc.api_server.deps import get_api_config
logger = logging.getLogger(__name__)
ALGORITHM = "HS256"
router_login = APIRouter()
@@ -44,6 +48,24 @@ def get_user_from_token(token, secret_key: str, token_type: str = "access"):
return username
# This should be reimplemented to better realign with the existing tools provided
# by FastAPI regarding API Tokens
async def get_ws_token(
ws: WebSocket,
token: Union[str, None] = None,
api_config: Dict[str, Any] = Depends(get_api_config)
):
secret_ws_token = api_config['ws_token']
if token == secret_ws_token:
# Just return the token if it matches
return token
else:
logger.debug("Denying websocket request")
# If it doesn't match, close the websocket connection
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
def create_token(data: dict, secret_key: str, token_type: str = "access") -> str:
to_encode = data.copy()
if token_type == "access":

View File

@@ -0,0 +1,52 @@
import logging
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
from freqtrade.rpc.api_server.deps import get_channel_manager
from freqtrade.rpc.api_server.ws.utils import is_websocket_alive
logger = logging.getLogger(__name__)
# Private router, protected by API Key authentication
router = APIRouter()
@router.websocket("/message/ws")
async def message_endpoint(
ws: WebSocket,
channel_manager=Depends(get_channel_manager)
):
try:
if is_websocket_alive(ws):
logger.info(f"Consumer connected - {ws.client}")
# TODO:
# Return a channel ID, pass that instead of ws to the rest of the methods
channel = await channel_manager.on_connect(ws)
# Keep connection open until explicitly closed, and sleep
try:
while not channel.is_closed():
request = await channel.recv()
# This is where we'd parse the request. For now this should only
# be a list of topics to subscribe too. List[str]
# Maybe allow the consumer to update the topics subscribed
# during runtime?
logger.info(f"Consumer request - {request}")
except WebSocketDisconnect:
# Handle client disconnects
logger.info(f"Consumer disconnected - {ws.client}")
await channel_manager.on_disconnect(ws)
except Exception as e:
logger.info(f"Consumer connection failed - {ws.client}")
logger.exception(e)
# Handle cases like -
# RuntimeError('Cannot call "send" once a closed message has been sent')
await channel_manager.on_disconnect(ws)
except Exception:
logger.error(f"Failed to serve - {ws.client}")
await channel_manager.on_disconnect(ws)

View File

@@ -41,6 +41,10 @@ def get_exchange(config=Depends(get_config)):
return ApiServer._exchange
def get_channel_manager():
return ApiServer._channel_manager
def is_webserver_mode(config=Depends(get_config)):
if config['runmode'] != RunMode.WEBSERVER:
raise RPCException('Bot is not in the correct state')

View File

@@ -1,15 +1,20 @@
import asyncio
import logging
from ipaddress import IPv4Address
from threading import Thread
from typing import Any, Dict
import orjson
import uvicorn
from fastapi import Depends, FastAPI
from fastapi.middleware.cors import CORSMiddleware
# Look into alternatives
from janus import Queue as ThreadedQueue
from starlette.responses import JSONResponse
from freqtrade.exceptions import OperationalException
from freqtrade.rpc.api_server.uvicorn_threaded import UvicornServer
from freqtrade.rpc.api_server.ws.channel import ChannelManager
from freqtrade.rpc.rpc import RPC, RPCException, RPCHandler
@@ -43,6 +48,10 @@ class ApiServer(RPCHandler):
_config: Dict[str, Any] = {}
# Exchange - only available in webserver mode.
_exchange = None
# websocket message queue stuff
_channel_manager = None
_thread = None
_loop = None
def __new__(cls, *args, **kwargs):
"""
@@ -64,10 +73,15 @@ class ApiServer(RPCHandler):
return
self._standalone: bool = standalone
self._server = None
self._queue = None
self._background_task = None
ApiServer.__initialized = True
api_config = self._config['api_server']
ApiServer._channel_manager = ChannelManager()
self.app = FastAPI(title="Freqtrade API",
docs_url='/docs' if api_config.get('enable_openapi', False) else None,
redoc_url=None,
@@ -95,6 +109,18 @@ class ApiServer(RPCHandler):
logger.info("Stopping API Server")
self._server.cleanup()
if self._thread and self._loop:
logger.info("Stopping API Server background tasks")
if self._background_task:
# Cancel the queue task
self._background_task.cancel()
# Finally stop the loop
self._loop.call_soon_threadsafe(self._loop.stop)
self._thread.join()
@classmethod
def shutdown(cls):
cls.__initialized = False
@@ -104,7 +130,10 @@ class ApiServer(RPCHandler):
cls._rpc = None
def send_msg(self, msg: Dict[str, str]) -> None:
pass
if self._queue:
logger.info(f"Adding message to queue: {msg}")
sync_q = self._queue.sync_q
sync_q.put(msg)
def handle_rpc_exception(self, request, exc):
logger.exception(f"API Error calling: {exc}")
@@ -114,10 +143,12 @@ class ApiServer(RPCHandler):
)
def configure_app(self, app: FastAPI, config):
from freqtrade.rpc.api_server.api_auth import http_basic_or_jwt_token, router_login
from freqtrade.rpc.api_server.api_auth import (get_ws_token, http_basic_or_jwt_token,
router_login)
from freqtrade.rpc.api_server.api_backtest import router as api_backtest
from freqtrade.rpc.api_server.api_v1 import router as api_v1
from freqtrade.rpc.api_server.api_v1 import router_public as api_v1_public
from freqtrade.rpc.api_server.api_ws import router as ws_router
from freqtrade.rpc.api_server.web_ui import router_ui
app.include_router(api_v1_public, prefix="/api/v1")
@@ -128,6 +159,9 @@ class ApiServer(RPCHandler):
app.include_router(api_backtest, prefix="/api/v1",
dependencies=[Depends(http_basic_or_jwt_token)],
)
app.include_router(ws_router, prefix="/api/v1",
dependencies=[Depends(get_ws_token)]
)
app.include_router(router_login, prefix="/api/v1", tags=["auth"])
# UI Router MUST be last!
app.include_router(router_ui, prefix='')
@@ -142,6 +176,43 @@ class ApiServer(RPCHandler):
app.add_exception_handler(RPCException, self.handle_rpc_exception)
def start_message_queue(self):
# Create a new loop, as it'll be just for the background thread
self._loop = asyncio.new_event_loop()
# Start the thread
if not self._thread:
self._thread = Thread(target=self._loop.run_forever)
self._thread.start()
else:
raise RuntimeError("Threaded loop is already running")
# Finally, submit the coro to the thread
self._background_task = asyncio.run_coroutine_threadsafe(
self._broadcast_queue_data(), loop=self._loop)
async def _broadcast_queue_data(self):
# Instantiate the queue in this coroutine so it's attached to our loop
self._queue = ThreadedQueue()
async_queue = self._queue.async_q
try:
while True:
logger.debug("Getting queue data...")
# Get data from queue
data = await async_queue.get()
logger.debug(f"Found data: {data}")
# Broadcast it
await self._channel_manager.broadcast(data)
# Sleep, make this configurable?
await asyncio.sleep(0.1)
except asyncio.CancelledError:
# Silently stop
pass
# For testing, shouldn't happen when stable
except Exception as e:
logger.info(f"Exception happened in background task: {e}")
def start_api(self):
"""
Start API ... should be run in thread.
@@ -179,6 +250,7 @@ class ApiServer(RPCHandler):
if self._standalone:
self._server.run()
else:
self.start_message_queue()
self._server.run_in_thread()
except Exception:
logger.exception("Api server failed to start.")

View File

@@ -0,0 +1,146 @@
import logging
from threading import RLock
from typing import Type
from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy
from freqtrade.rpc.api_server.ws.serializer import ORJSONWebSocketSerializer, WebSocketSerializer
from freqtrade.rpc.api_server.ws.types import WebSocketType
logger = logging.getLogger(__name__)
class WebSocketChannel:
"""
Object to help facilitate managing a websocket connection
"""
def __init__(
self,
websocket: WebSocketType,
serializer_cls: Type[WebSocketSerializer] = ORJSONWebSocketSerializer
):
# The WebSocket object
self._websocket = WebSocketProxy(websocket)
# The Serializing class for the WebSocket object
self._serializer_cls = serializer_cls
# Internal event to signify a closed websocket
self._closed = False
# Wrap the WebSocket in the Serializing class
self._wrapped_ws = self._serializer_cls(self._websocket)
async def send(self, data):
"""
Send data on the wrapped websocket
"""
# logger.info(f"Serialized Send - {self._wrapped_ws._serialize(data)}")
await self._wrapped_ws.send(data)
async def recv(self):
"""
Receive data on the wrapped websocket
"""
return await self._wrapped_ws.recv()
async def ping(self):
"""
Ping the websocket
"""
return await self._websocket.ping()
async def close(self):
"""
Close the WebSocketChannel
"""
self._closed = True
def is_closed(self):
return self._closed
class ChannelManager:
def __init__(self):
self.channels = dict()
self._lock = RLock() # Re-entrant Lock
async def on_connect(self, websocket: WebSocketType):
"""
Wrap websocket connection into Channel and add to list
:param websocket: The WebSocket object to attach to the Channel
"""
if hasattr(websocket, "accept"):
try:
await websocket.accept()
except RuntimeError:
# The connection was closed before we could accept it
return
ws_channel = WebSocketChannel(websocket)
with self._lock:
self.channels[websocket] = ws_channel
return ws_channel
async def on_disconnect(self, websocket: WebSocketType):
"""
Call close on the channel if it's not, and remove from channel list
:param websocket: The WebSocket objet attached to the Channel
"""
with self._lock:
channel = self.channels.get(websocket)
if channel:
logger.debug(f"Disconnecting channel - {channel}")
if not channel.is_closed():
await channel.close()
del self.channels[websocket]
async def disconnect_all(self):
"""
Disconnect all Channels
"""
with self._lock:
for websocket, channel in self.channels.items():
if not channel.is_closed():
await channel.close()
self.channels = dict()
async def broadcast(self, data):
"""
Broadcast data on all Channels
:param data: The data to send
"""
with self._lock:
logger.debug(f"Broadcasting data: {data}")
for websocket, channel in self.channels.items():
try:
await channel.send(data)
except RuntimeError:
# Handle cannot send after close cases
await self.on_disconnect(websocket)
async def send_direct(self, channel, data):
"""
Send data directly through direct_channel only
:param direct_channel: The WebSocketChannel object to send data through
:param data: The data to send
"""
# We iterate over the channels to get reference to the websocket object
# so we can disconnect incase of failure
await channel.send(data)
def has_channels(self):
"""
Flag for more than 0 channels
"""
return len(self.channels) > 0

View File

@@ -0,0 +1,61 @@
from typing import Union
from fastapi import WebSocket as FastAPIWebSocket
from websockets import WebSocketClientProtocol as WebSocket
from freqtrade.rpc.api_server.ws.types import WebSocketType
class WebSocketProxy:
"""
WebSocketProxy object to bring the FastAPIWebSocket and websockets.WebSocketClientProtocol
under the same API
"""
def __init__(self, websocket: WebSocketType):
self._websocket: Union[FastAPIWebSocket, WebSocket] = websocket
async def send(self, data):
"""
Send data on the wrapped websocket
"""
if isinstance(data, str):
data = data.encode()
if hasattr(self._websocket, "send_bytes"):
await self._websocket.send_bytes(data)
else:
await self._websocket.send(data)
async def recv(self):
"""
Receive data on the wrapped websocket
"""
if hasattr(self._websocket, "receive_bytes"):
return await self._websocket.receive_bytes()
else:
return await self._websocket.recv()
async def ping(self):
"""
Ping the websocket, not supported by FastAPI WebSockets
"""
if hasattr(self._websocket, "ping"):
return await self._websocket.ping()
return False
async def close(self, code: int = 1000):
"""
Close the websocket connection, only supported by FastAPI WebSockets
"""
if hasattr(self._websocket, "close"):
return await self._websocket.close(code)
pass
async def accept(self):
"""
Accept the WebSocket connection, only support by FastAPI WebSockets
"""
if hasattr(self._websocket, "accept"):
return await self._websocket.accept()
pass

View File

@@ -0,0 +1,65 @@
import json
import logging
from abc import ABC, abstractmethod
import msgpack
import orjson
from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy
logger = logging.getLogger(__name__)
class WebSocketSerializer(ABC):
def __init__(self, websocket: WebSocketProxy):
self._websocket: WebSocketProxy = websocket
@abstractmethod
def _serialize(self, data):
raise NotImplementedError()
@abstractmethod
def _deserialize(self, data):
raise NotImplementedError()
async def send(self, data: bytes):
await self._websocket.send(self._serialize(data))
async def recv(self) -> bytes:
data = await self._websocket.recv()
return self._deserialize(data)
async def close(self, code: int = 1000):
await self._websocket.close(code)
# Going to explore using MsgPack as the serialization,
# as that might be the best method for sending pandas
# dataframes over the wire
class JSONWebSocketSerializer(WebSocketSerializer):
def _serialize(self, data):
return json.dumps(data)
def _deserialize(self, data):
return json.loads(data)
class ORJSONWebSocketSerializer(WebSocketSerializer):
ORJSON_OPTIONS = orjson.OPT_NAIVE_UTC | orjson.OPT_SERIALIZE_NUMPY
def _serialize(self, data):
return orjson.dumps(data, option=self.ORJSON_OPTIONS)
def _deserialize(self, data):
return orjson.loads(data, option=self.ORJSON_OPTIONS)
class MsgPackWebSocketSerializer(WebSocketSerializer):
def _serialize(self, data):
return msgpack.packb(data, use_bin_type=True)
def _deserialize(self, data):
return msgpack.unpackb(data, raw=False)

View File

@@ -0,0 +1,8 @@
from typing import Any, Dict, TypeVar
from fastapi import WebSocket as FastAPIWebSocket
from websockets import WebSocketClientProtocol as WebSocket
WebSocketType = TypeVar("WebSocketType", FastAPIWebSocket, WebSocket)
MessageType = Dict[str, Any]

View File

@@ -0,0 +1,12 @@
from fastapi import WebSocket
# fastapi does not make this available through it, so import directly from starlette
from starlette.websockets import WebSocketState
async def is_websocket_alive(ws: WebSocket) -> bool:
if (
ws.application_state == WebSocketState.CONNECTED and
ws.client_state == WebSocketState.CONNECTED
):
return True
return False