better error handling, true async sending, more readable api

This commit is contained in:
Timothy Pogue
2022-11-18 13:32:27 -07:00
parent ba493eb7a7
commit 0cb6f71c02
4 changed files with 88 additions and 71 deletions

View File

@@ -29,6 +29,7 @@ class WebSocketChannel:
# Internal event to signify a closed websocket
self._closed = asyncio.Event()
self._send_timeout_high_limit = 2
# The subscribed message types
self._subscriptions: List[str] = []
@@ -36,6 +37,9 @@ class WebSocketChannel:
# Wrap the WebSocket in the Serializing class
self._wrapped_ws = serializer_cls(self._websocket)
# The async tasks created for the channel
self._channel_tasks: List[asyncio.Task] = []
def __repr__(self):
return f"WebSocketChannel({self.channel_id}, {self.remote_addr})"
@@ -51,7 +55,14 @@ class WebSocketChannel:
"""
Send a message on the wrapped websocket
"""
await self._wrapped_ws.send(message)
# Without this sleep, messages would send to one channel
# first then another after the first one finished.
# With the sleep call, it gives control to the event
# loop to schedule other channel send methods.
await asyncio.sleep(0)
return await self._wrapped_ws.send(message)
async def recv(self):
"""
@@ -77,7 +88,6 @@ class WebSocketChannel:
"""
self._closed.set()
self._relay_task.cancel()
try:
await self._websocket.close()
@@ -106,23 +116,68 @@ class WebSocketChannel:
"""
return message_type in self._subscriptions
async def run_channel_tasks(self, *tasks, **kwargs):
"""
Create and await on the channel tasks unless an exception
was raised, then cancel them all.
:params *tasks: All coros or tasks to be run concurrently
:param **kwargs: Any extra kwargs to pass to gather
"""
# Wrap the coros into tasks if they aren't already
self._channel_tasks = [
task if isinstance(task, asyncio.Task) else asyncio.create_task(task)
for task in tasks
]
try:
await asyncio.gather(*self._channel_tasks, **kwargs)
except Exception:
# If an exception occurred, cancel the rest of the tasks and bubble up
# the error that was caught here
await self.cancel_channel_tasks()
raise
async def cancel_channel_tasks(self):
"""
Cancel and wait on all channel tasks
"""
for task in self._channel_tasks:
task.cancel()
# Wait for tasks to finish cancelling
try:
await asyncio.wait(self._channel_tasks)
except asyncio.CancelledError:
pass
self._channel_tasks = []
async def __aiter__(self):
"""
Generator for received messages
"""
while True:
try:
yield await self.recv()
except Exception:
break
# We can not catch any errors here as websocket.recv is
# the first to catch any disconnects and bubble it up
# so the connection is garbage collected right away
while not self.is_closed():
yield await self.recv()
@asynccontextmanager
async def connect(self):
"""
Context manager for safely opening and closing the websocket connection
"""
try:
await self.accept()
yield self
finally:
await self.close()
@asynccontextmanager
async def create_channel(websocket: WebSocketType, **kwargs):
"""
Context manager for safely opening and closing a WebSocketChannel
"""
channel = WebSocketChannel(websocket, **kwargs)
try:
await channel.accept()
logger.info(f"Connected to channel - {channel}")
yield channel
except Exception:
pass
finally:
await channel.close()
logger.info(f"Disconnected from channel - {channel}")

View File

@@ -17,7 +17,8 @@ class MessageStream:
async def subscribe(self):
waiter = self._waiter
while True:
message, waiter = await waiter
# Shield the future from being cancelled by a task waiting on it
message, waiter = await asyncio.shield(waiter)
yield message
__aiter__ = subscribe