refactor broadcasting to a queue per client
This commit is contained in:
parent
a10b2d003f
commit
3e8d8fd1b0
@ -1,6 +1,7 @@
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
import websockets
|
||||
from fastapi import APIRouter, Depends, WebSocketDisconnect
|
||||
from fastapi.websockets import WebSocket, WebSocketState
|
||||
from pydantic import ValidationError
|
||||
@ -102,7 +103,6 @@ async def message_endpoint(
|
||||
"""
|
||||
try:
|
||||
channel = await channel_manager.on_connect(ws)
|
||||
|
||||
if await is_websocket_alive(ws):
|
||||
|
||||
logger.info(f"Consumer connected - {channel}")
|
||||
@ -115,26 +115,34 @@ async def message_endpoint(
|
||||
# Process the request here
|
||||
await _process_consumer_request(request, channel, rpc)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
except (
|
||||
WebSocketDisconnect,
|
||||
websockets.exceptions.WebSocketException
|
||||
):
|
||||
# Handle client disconnects
|
||||
logger.info(f"Consumer disconnected - {channel}")
|
||||
await channel_manager.on_disconnect(ws)
|
||||
except Exception as e:
|
||||
logger.info(f"Consumer connection failed - {channel}")
|
||||
logger.exception(e)
|
||||
except RuntimeError:
|
||||
# Handle cases like -
|
||||
# RuntimeError('Cannot call "send" once a closed message has been sent')
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.info(f"Consumer connection failed - {channel}")
|
||||
logger.debug(e, exc_info=e)
|
||||
finally:
|
||||
await channel_manager.on_disconnect(ws)
|
||||
|
||||
else:
|
||||
if channel:
|
||||
await channel_manager.on_disconnect(ws)
|
||||
await ws.close()
|
||||
|
||||
except RuntimeError:
|
||||
# WebSocket was closed
|
||||
await channel_manager.on_disconnect(ws)
|
||||
|
||||
# Do nothing
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to serve - {ws.client}")
|
||||
# Log tracebacks to keep track of what errors are happening
|
||||
logger.exception(e)
|
||||
finally:
|
||||
await channel_manager.on_disconnect(ws)
|
||||
|
@ -245,6 +245,7 @@ class ApiServer(RPCHandler):
|
||||
use_colors=False,
|
||||
log_config=None,
|
||||
access_log=True if verbosity != 'error' else False,
|
||||
ws_ping_interval=None # We do this explicitly ourselves
|
||||
)
|
||||
try:
|
||||
self._server = UvicornServer(uvconfig)
|
||||
|
@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from threading import RLock
|
||||
from typing import List, Optional, Type
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import WebSocket as FastAPIWebSocket
|
||||
@ -34,6 +35,8 @@ class WebSocketChannel:
|
||||
self._serializer_cls = serializer_cls
|
||||
|
||||
self._subscriptions: List[str] = []
|
||||
self.queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue()
|
||||
self._relay_task = asyncio.create_task(self.relay())
|
||||
|
||||
# Internal event to signify a closed websocket
|
||||
self._closed = False
|
||||
@ -72,6 +75,7 @@ class WebSocketChannel:
|
||||
"""
|
||||
|
||||
self._closed = True
|
||||
self._relay_task.cancel()
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
"""
|
||||
@ -95,6 +99,20 @@ class WebSocketChannel:
|
||||
"""
|
||||
return message_type in self._subscriptions
|
||||
|
||||
async def relay(self):
|
||||
"""
|
||||
Relay messages from the channel's queue and send them out. This is started
|
||||
as a task.
|
||||
"""
|
||||
while True:
|
||||
message = await self.queue.get()
|
||||
try:
|
||||
await self.send(message)
|
||||
self.queue.task_done()
|
||||
except RuntimeError:
|
||||
# The connection was closed, just exit the task
|
||||
return
|
||||
|
||||
|
||||
class ChannelManager:
|
||||
def __init__(self):
|
||||
@ -155,12 +173,12 @@ class ChannelManager:
|
||||
with self._lock:
|
||||
message_type = data.get('type')
|
||||
for websocket, channel in self.channels.copy().items():
|
||||
try:
|
||||
if channel.subscribed_to(message_type):
|
||||
await channel.send(data)
|
||||
except RuntimeError:
|
||||
# Handle cannot send after close cases
|
||||
await self.on_disconnect(websocket)
|
||||
if channel.subscribed_to(message_type):
|
||||
if not channel.queue.full():
|
||||
channel.queue.put_nowait(data)
|
||||
else:
|
||||
logger.info(f"Channel {channel} is too far behind, disconnecting")
|
||||
await self.on_disconnect(websocket)
|
||||
|
||||
async def send_direct(self, channel, data):
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user