initial revision

This commit is contained in:
Timothy Pogue
2022-11-14 20:27:45 -07:00
parent a951b49541
commit 659c8c237f
7 changed files with 494 additions and 241 deletions

View File

@@ -1,16 +1,17 @@
import asyncio
import logging
from typing import Any, Dict
from fastapi import APIRouter, Depends, WebSocketDisconnect
from fastapi.websockets import WebSocket, WebSocketState
from fastapi import APIRouter, Depends
from fastapi.websockets import WebSocket, WebSocketDisconnect
from pydantic import ValidationError
from websockets.exceptions import WebSocketException
from websockets.exceptions import ConnectionClosed
from freqtrade.enums import RPCMessageType, RPCRequestType
from freqtrade.rpc.api_server.api_auth import validate_ws_token
from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc
from freqtrade.rpc.api_server.deps import get_message_stream, get_rpc
from freqtrade.rpc.api_server.ws import WebSocketChannel
from freqtrade.rpc.api_server.ws.channel import ChannelManager
from freqtrade.rpc.api_server.ws.message_stream import MessageStream
from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema,
WSRequestSchema, WSWhitelistMessage)
from freqtrade.rpc.rpc import RPC
@@ -22,23 +23,63 @@ logger = logging.getLogger(__name__)
router = APIRouter()
async def is_websocket_alive(ws: WebSocket) -> bool:
# async def is_websocket_alive(ws: WebSocket) -> bool:
# """
# Check if a FastAPI Websocket is still open
# """
# if (
# ws.application_state == WebSocketState.CONNECTED and
# ws.client_state == WebSocketState.CONNECTED
# ):
# return True
# return False
class WebSocketChannelClosed(Exception):
"""
Check if a FastAPI Websocket is still open
General WebSocket exception to signal closing the channel
"""
if (
ws.application_state == WebSocketState.CONNECTED and
ws.client_state == WebSocketState.CONNECTED
pass
async def channel_reader(channel: WebSocketChannel, rpc: RPC):
"""
Iterate over the messages from the channel and process the request
"""
try:
async for message in channel:
await _process_consumer_request(message, channel, rpc)
except (
RuntimeError,
WebSocketDisconnect,
ConnectionClosed
):
return True
return False
raise WebSocketChannelClosed
except asyncio.CancelledError:
return
async def channel_broadcaster(channel: WebSocketChannel, message_stream: MessageStream):
"""
Iterate over messages in the message stream and send them
"""
try:
async for message in message_stream:
await channel.send(message)
except (
RuntimeError,
WebSocketDisconnect,
ConnectionClosed
):
raise WebSocketChannelClosed
except asyncio.CancelledError:
return
async def _process_consumer_request(
request: Dict[str, Any],
channel: WebSocketChannel,
rpc: RPC,
channel_manager: ChannelManager
rpc: RPC
):
"""
Validate and handle a request from a websocket consumer
@@ -75,7 +116,7 @@ async def _process_consumer_request(
# Format response
response = WSWhitelistMessage(data=whitelist)
# Send it back
await channel_manager.send_direct(channel, response.dict(exclude_none=True))
await channel.send(response.dict(exclude_none=True))
elif type == RPCRequestType.ANALYZED_DF:
limit = None
@@ -86,53 +127,76 @@ async def _process_consumer_request(
# For every pair in the generator, send a separate message
for message in rpc._ws_request_analyzed_df(limit):
# Format response
response = WSAnalyzedDFMessage(data=message)
await channel_manager.send_direct(channel, response.dict(exclude_none=True))
await channel.send(response.dict(exclude_none=True))
@router.websocket("/message/ws")
async def message_endpoint(
ws: WebSocket,
websocket: WebSocket,
token: str = Depends(validate_ws_token),
rpc: RPC = Depends(get_rpc),
channel_manager=Depends(get_channel_manager),
token: str = Depends(validate_ws_token)
message_stream: MessageStream = Depends(get_message_stream)
):
"""
Message WebSocket endpoint, facilitates sending RPC messages
"""
try:
channel = await channel_manager.on_connect(ws)
if await is_websocket_alive(ws):
async with WebSocketChannel(websocket).connect() as channel:
try:
logger.info(f"Channel connected - {channel}")
logger.info(f"Consumer connected - {channel}")
channel_tasks = asyncio.gather(
channel_reader(channel, rpc),
channel_broadcaster(channel, message_stream)
)
await channel_tasks
# Keep connection open until explicitly closed, and process requests
try:
while not channel.is_closed():
request = await channel.recv()
finally:
logger.info(f"Channel disconnected - {channel}")
channel_tasks.cancel()
# Process the request here
await _process_consumer_request(request, channel, rpc, channel_manager)
except (WebSocketDisconnect, WebSocketException):
# Handle client disconnects
logger.info(f"Consumer disconnected - {channel}")
except RuntimeError:
# Handle cases like -
# RuntimeError('Cannot call "send" once a closed message has been sent')
pass
except Exception as e:
logger.info(f"Consumer connection failed - {channel}: {e}")
logger.debug(e, exc_info=e)
# @router.websocket("/message/ws")
# async def message_endpoint(
# ws: WebSocket,
# rpc: RPC = Depends(get_rpc),
# channel_manager=Depends(get_channel_manager),
# token: str = Depends(validate_ws_token)
# ):
# """
# Message WebSocket endpoint, facilitates sending RPC messages
# """
# try:
# channel = await channel_manager.on_connect(ws)
# if await is_websocket_alive(ws):
except RuntimeError:
# WebSocket was closed
# Do nothing
pass
except Exception as e:
logger.error(f"Failed to serve - {ws.client}")
# Log tracebacks to keep track of what errors are happening
logger.exception(e)
finally:
if channel:
await channel_manager.on_disconnect(ws)
# logger.info(f"Consumer connected - {channel}")
# # Keep connection open until explicitly closed, and process requests
# try:
# while not channel.is_closed():
# request = await channel.recv()
# # Process the request here
# await _process_consumer_request(request, channel, rpc, channel_manager)
# except (WebSocketDisconnect, WebSocketException):
# # Handle client disconnects
# logger.info(f"Consumer disconnected - {channel}")
# except RuntimeError:
# # Handle cases like -
# # RuntimeError('Cannot call "send" once a closed message has been sent')
# pass
# except Exception as e:
# logger.info(f"Consumer connection failed - {channel}: {e}")
# logger.debug(e, exc_info=e)
# except RuntimeError:
# # WebSocket was closed
# # Do nothing
# pass
# except Exception as e:
# logger.error(f"Failed to serve - {ws.client}")
# # Log tracebacks to keep track of what errors are happening
# logger.exception(e)
# finally:
# if channel:
# await channel_manager.on_disconnect(ws)