refactor broadcasting to a queue per client

This commit is contained in:
Timothy Pogue 2022-10-09 15:04:52 -06:00
parent a10b2d003f
commit 3e8d8fd1b0
3 changed files with 42 additions and 15 deletions

View File

@ -1,6 +1,7 @@
import logging
from typing import Any, Dict
import websockets
from fastapi import APIRouter, Depends, WebSocketDisconnect
from fastapi.websockets import WebSocket, WebSocketState
from pydantic import ValidationError
@ -102,7 +103,6 @@ async def message_endpoint(
"""
try:
channel = await channel_manager.on_connect(ws)
if await is_websocket_alive(ws):
logger.info(f"Consumer connected - {channel}")
@ -115,26 +115,34 @@ async def message_endpoint(
# Process the request here
await _process_consumer_request(request, channel, rpc)
except WebSocketDisconnect:
except (
WebSocketDisconnect,
websockets.exceptions.WebSocketException
):
# Handle client disconnects
logger.info(f"Consumer disconnected - {channel}")
await channel_manager.on_disconnect(ws)
except Exception as e:
logger.info(f"Consumer connection failed - {channel}")
logger.exception(e)
except RuntimeError:
# Handle cases like -
# RuntimeError('Cannot call "send" once a closed message has been sent')
pass
except Exception as e:
logger.info(f"Consumer connection failed - {channel}")
logger.debug(e, exc_info=e)
finally:
await channel_manager.on_disconnect(ws)
else:
if channel:
await channel_manager.on_disconnect(ws)
await ws.close()
except RuntimeError:
# WebSocket was closed
await channel_manager.on_disconnect(ws)
# Do nothing
pass
except Exception as e:
logger.error(f"Failed to serve - {ws.client}")
# Log tracebacks to keep track of what errors are happening
logger.exception(e)
finally:
await channel_manager.on_disconnect(ws)

View File

@ -245,6 +245,7 @@ class ApiServer(RPCHandler):
use_colors=False,
log_config=None,
access_log=True if verbosity != 'error' else False,
ws_ping_interval=None # We do this explicitly ourselves
)
try:
self._server = UvicornServer(uvconfig)

View File

@ -1,6 +1,7 @@
import asyncio
import logging
from threading import RLock
from typing import List, Optional, Type
from typing import Any, Dict, List, Optional, Type
from uuid import uuid4
from fastapi import WebSocket as FastAPIWebSocket
@ -34,6 +35,8 @@ class WebSocketChannel:
self._serializer_cls = serializer_cls
self._subscriptions: List[str] = []
self.queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue()
self._relay_task = asyncio.create_task(self.relay())
# Internal event to signify a closed websocket
self._closed = False
@ -72,6 +75,7 @@ class WebSocketChannel:
"""
self._closed = True
self._relay_task.cancel()
def is_closed(self) -> bool:
"""
@ -95,6 +99,20 @@ class WebSocketChannel:
"""
return message_type in self._subscriptions
async def relay(self):
"""
Relay messages from the channel's queue and send them out. This is started
as a task.
"""
while True:
message = await self.queue.get()
try:
await self.send(message)
self.queue.task_done()
except RuntimeError:
# The connection was closed, just exit the task
return
class ChannelManager:
def __init__(self):
@ -155,11 +173,11 @@ class ChannelManager:
with self._lock:
message_type = data.get('type')
for websocket, channel in self.channels.copy().items():
try:
if channel.subscribed_to(message_type):
await channel.send(data)
except RuntimeError:
# Handle cannot send after close cases
if not channel.queue.full():
channel.queue.put_nowait(data)
else:
logger.info(f"Channel {channel} is too far behind, disconnecting")
await self.on_disconnect(websocket)
async def send_direct(self, channel, data):