refactor broadcasting to a queue per client
This commit is contained in:
parent
a10b2d003f
commit
3e8d8fd1b0
@ -1,6 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
import websockets
|
||||||
from fastapi import APIRouter, Depends, WebSocketDisconnect
|
from fastapi import APIRouter, Depends, WebSocketDisconnect
|
||||||
from fastapi.websockets import WebSocket, WebSocketState
|
from fastapi.websockets import WebSocket, WebSocketState
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
@ -102,7 +103,6 @@ async def message_endpoint(
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
channel = await channel_manager.on_connect(ws)
|
channel = await channel_manager.on_connect(ws)
|
||||||
|
|
||||||
if await is_websocket_alive(ws):
|
if await is_websocket_alive(ws):
|
||||||
|
|
||||||
logger.info(f"Consumer connected - {channel}")
|
logger.info(f"Consumer connected - {channel}")
|
||||||
@ -115,26 +115,34 @@ async def message_endpoint(
|
|||||||
# Process the request here
|
# Process the request here
|
||||||
await _process_consumer_request(request, channel, rpc)
|
await _process_consumer_request(request, channel, rpc)
|
||||||
|
|
||||||
except WebSocketDisconnect:
|
except (
|
||||||
|
WebSocketDisconnect,
|
||||||
|
websockets.exceptions.WebSocketException
|
||||||
|
):
|
||||||
# Handle client disconnects
|
# Handle client disconnects
|
||||||
logger.info(f"Consumer disconnected - {channel}")
|
logger.info(f"Consumer disconnected - {channel}")
|
||||||
await channel_manager.on_disconnect(ws)
|
except RuntimeError:
|
||||||
except Exception as e:
|
|
||||||
logger.info(f"Consumer connection failed - {channel}")
|
|
||||||
logger.exception(e)
|
|
||||||
# Handle cases like -
|
# Handle cases like -
|
||||||
# RuntimeError('Cannot call "send" once a closed message has been sent')
|
# RuntimeError('Cannot call "send" once a closed message has been sent')
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
logger.info(f"Consumer connection failed - {channel}")
|
||||||
|
logger.debug(e, exc_info=e)
|
||||||
|
finally:
|
||||||
await channel_manager.on_disconnect(ws)
|
await channel_manager.on_disconnect(ws)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
if channel:
|
||||||
|
await channel_manager.on_disconnect(ws)
|
||||||
await ws.close()
|
await ws.close()
|
||||||
|
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
# WebSocket was closed
|
# WebSocket was closed
|
||||||
await channel_manager.on_disconnect(ws)
|
# Do nothing
|
||||||
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to serve - {ws.client}")
|
logger.error(f"Failed to serve - {ws.client}")
|
||||||
# Log tracebacks to keep track of what errors are happening
|
# Log tracebacks to keep track of what errors are happening
|
||||||
logger.exception(e)
|
logger.exception(e)
|
||||||
|
finally:
|
||||||
await channel_manager.on_disconnect(ws)
|
await channel_manager.on_disconnect(ws)
|
||||||
|
@ -245,6 +245,7 @@ class ApiServer(RPCHandler):
|
|||||||
use_colors=False,
|
use_colors=False,
|
||||||
log_config=None,
|
log_config=None,
|
||||||
access_log=True if verbosity != 'error' else False,
|
access_log=True if verbosity != 'error' else False,
|
||||||
|
ws_ping_interval=None # We do this explicitly ourselves
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
self._server = UvicornServer(uvconfig)
|
self._server = UvicornServer(uvconfig)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from threading import RLock
|
from threading import RLock
|
||||||
from typing import List, Optional, Type
|
from typing import Any, Dict, List, Optional, Type
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from fastapi import WebSocket as FastAPIWebSocket
|
from fastapi import WebSocket as FastAPIWebSocket
|
||||||
@ -34,6 +35,8 @@ class WebSocketChannel:
|
|||||||
self._serializer_cls = serializer_cls
|
self._serializer_cls = serializer_cls
|
||||||
|
|
||||||
self._subscriptions: List[str] = []
|
self._subscriptions: List[str] = []
|
||||||
|
self.queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue()
|
||||||
|
self._relay_task = asyncio.create_task(self.relay())
|
||||||
|
|
||||||
# Internal event to signify a closed websocket
|
# Internal event to signify a closed websocket
|
||||||
self._closed = False
|
self._closed = False
|
||||||
@ -72,6 +75,7 @@ class WebSocketChannel:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
self._closed = True
|
self._closed = True
|
||||||
|
self._relay_task.cancel()
|
||||||
|
|
||||||
def is_closed(self) -> bool:
|
def is_closed(self) -> bool:
|
||||||
"""
|
"""
|
||||||
@ -95,6 +99,20 @@ class WebSocketChannel:
|
|||||||
"""
|
"""
|
||||||
return message_type in self._subscriptions
|
return message_type in self._subscriptions
|
||||||
|
|
||||||
|
async def relay(self):
|
||||||
|
"""
|
||||||
|
Relay messages from the channel's queue and send them out. This is started
|
||||||
|
as a task.
|
||||||
|
"""
|
||||||
|
while True:
|
||||||
|
message = await self.queue.get()
|
||||||
|
try:
|
||||||
|
await self.send(message)
|
||||||
|
self.queue.task_done()
|
||||||
|
except RuntimeError:
|
||||||
|
# The connection was closed, just exit the task
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
class ChannelManager:
|
class ChannelManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -155,11 +173,11 @@ class ChannelManager:
|
|||||||
with self._lock:
|
with self._lock:
|
||||||
message_type = data.get('type')
|
message_type = data.get('type')
|
||||||
for websocket, channel in self.channels.copy().items():
|
for websocket, channel in self.channels.copy().items():
|
||||||
try:
|
|
||||||
if channel.subscribed_to(message_type):
|
if channel.subscribed_to(message_type):
|
||||||
await channel.send(data)
|
if not channel.queue.full():
|
||||||
except RuntimeError:
|
channel.queue.put_nowait(data)
|
||||||
# Handle cannot send after close cases
|
else:
|
||||||
|
logger.info(f"Channel {channel} is too far behind, disconnecting")
|
||||||
await self.on_disconnect(websocket)
|
await self.on_disconnect(websocket)
|
||||||
|
|
||||||
async def send_direct(self, channel, data):
|
async def send_direct(self, channel, data):
|
||||||
|
Loading…
Reference in New Issue
Block a user