refactor broadcasting to a queue per client

This commit is contained in:
Timothy Pogue 2022-10-09 15:04:52 -06:00
parent a10b2d003f
commit 3e8d8fd1b0
3 changed files with 42 additions and 15 deletions

View File

@ -1,6 +1,7 @@
import logging import logging
from typing import Any, Dict from typing import Any, Dict
import websockets
from fastapi import APIRouter, Depends, WebSocketDisconnect from fastapi import APIRouter, Depends, WebSocketDisconnect
from fastapi.websockets import WebSocket, WebSocketState from fastapi.websockets import WebSocket, WebSocketState
from pydantic import ValidationError from pydantic import ValidationError
@ -102,7 +103,6 @@ async def message_endpoint(
""" """
try: try:
channel = await channel_manager.on_connect(ws) channel = await channel_manager.on_connect(ws)
if await is_websocket_alive(ws): if await is_websocket_alive(ws):
logger.info(f"Consumer connected - {channel}") logger.info(f"Consumer connected - {channel}")
@ -115,26 +115,34 @@ async def message_endpoint(
# Process the request here # Process the request here
await _process_consumer_request(request, channel, rpc) await _process_consumer_request(request, channel, rpc)
except WebSocketDisconnect: except (
WebSocketDisconnect,
websockets.exceptions.WebSocketException
):
# Handle client disconnects # Handle client disconnects
logger.info(f"Consumer disconnected - {channel}") logger.info(f"Consumer disconnected - {channel}")
await channel_manager.on_disconnect(ws) except RuntimeError:
except Exception as e:
logger.info(f"Consumer connection failed - {channel}")
logger.exception(e)
# Handle cases like - # Handle cases like -
# RuntimeError('Cannot call "send" once a closed message has been sent') # RuntimeError('Cannot call "send" once a closed message has been sent')
pass
except Exception as e:
logger.info(f"Consumer connection failed - {channel}")
logger.debug(e, exc_info=e)
finally:
await channel_manager.on_disconnect(ws) await channel_manager.on_disconnect(ws)
else: else:
if channel:
await channel_manager.on_disconnect(ws)
await ws.close() await ws.close()
except RuntimeError: except RuntimeError:
# WebSocket was closed # WebSocket was closed
await channel_manager.on_disconnect(ws) # Do nothing
pass
except Exception as e: except Exception as e:
logger.error(f"Failed to serve - {ws.client}") logger.error(f"Failed to serve - {ws.client}")
# Log tracebacks to keep track of what errors are happening # Log tracebacks to keep track of what errors are happening
logger.exception(e) logger.exception(e)
finally:
await channel_manager.on_disconnect(ws) await channel_manager.on_disconnect(ws)

View File

@ -245,6 +245,7 @@ class ApiServer(RPCHandler):
use_colors=False, use_colors=False,
log_config=None, log_config=None,
access_log=True if verbosity != 'error' else False, access_log=True if verbosity != 'error' else False,
ws_ping_interval=None # We do this explicitly ourselves
) )
try: try:
self._server = UvicornServer(uvconfig) self._server = UvicornServer(uvconfig)

View File

@ -1,6 +1,7 @@
import asyncio
import logging import logging
from threading import RLock from threading import RLock
from typing import List, Optional, Type from typing import Any, Dict, List, Optional, Type
from uuid import uuid4 from uuid import uuid4
from fastapi import WebSocket as FastAPIWebSocket from fastapi import WebSocket as FastAPIWebSocket
@ -34,6 +35,8 @@ class WebSocketChannel:
self._serializer_cls = serializer_cls self._serializer_cls = serializer_cls
self._subscriptions: List[str] = [] self._subscriptions: List[str] = []
self.queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue()
self._relay_task = asyncio.create_task(self.relay())
# Internal event to signify a closed websocket # Internal event to signify a closed websocket
self._closed = False self._closed = False
@ -72,6 +75,7 @@ class WebSocketChannel:
""" """
self._closed = True self._closed = True
self._relay_task.cancel()
def is_closed(self) -> bool: def is_closed(self) -> bool:
""" """
@ -95,6 +99,20 @@ class WebSocketChannel:
""" """
return message_type in self._subscriptions return message_type in self._subscriptions
async def relay(self):
"""
Relay messages from the channel's queue and send them out. This is started
as a task.
"""
while True:
message = await self.queue.get()
try:
await self.send(message)
self.queue.task_done()
except RuntimeError:
# The connection was closed, just exit the task
return
class ChannelManager: class ChannelManager:
def __init__(self): def __init__(self):
@ -155,11 +173,11 @@ class ChannelManager:
with self._lock: with self._lock:
message_type = data.get('type') message_type = data.get('type')
for websocket, channel in self.channels.copy().items(): for websocket, channel in self.channels.copy().items():
try:
if channel.subscribed_to(message_type): if channel.subscribed_to(message_type):
await channel.send(data) if not channel.queue.full():
except RuntimeError: channel.queue.put_nowait(data)
# Handle cannot send after close cases else:
logger.info(f"Channel {channel} is too far behind, disconnecting")
await self.on_disconnect(websocket) await self.on_disconnect(websocket)
async def send_direct(self, channel, data): async def send_direct(self, channel, data):