better error handling, true async sending, more readable api
This commit is contained in:
@@ -29,6 +29,7 @@ class WebSocketChannel:
|
||||
|
||||
# Internal event to signify a closed websocket
|
||||
self._closed = asyncio.Event()
|
||||
self._send_timeout_high_limit = 2
|
||||
|
||||
# The subscribed message types
|
||||
self._subscriptions: List[str] = []
|
||||
@@ -36,6 +37,9 @@ class WebSocketChannel:
|
||||
# Wrap the WebSocket in the Serializing class
|
||||
self._wrapped_ws = serializer_cls(self._websocket)
|
||||
|
||||
# The async tasks created for the channel
|
||||
self._channel_tasks: List[asyncio.Task] = []
|
||||
|
||||
def __repr__(self):
|
||||
return f"WebSocketChannel({self.channel_id}, {self.remote_addr})"
|
||||
|
||||
@@ -51,7 +55,14 @@ class WebSocketChannel:
|
||||
"""
|
||||
Send a message on the wrapped websocket
|
||||
"""
|
||||
await self._wrapped_ws.send(message)
|
||||
|
||||
# Without this sleep, messages would send to one channel
|
||||
# first then another after the first one finished.
|
||||
# With the sleep call, it gives control to the event
|
||||
# loop to schedule other channel send methods.
|
||||
await asyncio.sleep(0)
|
||||
|
||||
return await self._wrapped_ws.send(message)
|
||||
|
||||
async def recv(self):
|
||||
"""
|
||||
@@ -77,7 +88,6 @@ class WebSocketChannel:
|
||||
"""
|
||||
|
||||
self._closed.set()
|
||||
self._relay_task.cancel()
|
||||
|
||||
try:
|
||||
await self._websocket.close()
|
||||
@@ -106,23 +116,68 @@ class WebSocketChannel:
|
||||
"""
|
||||
return message_type in self._subscriptions
|
||||
|
||||
async def run_channel_tasks(self, *tasks, **kwargs):
|
||||
"""
|
||||
Create and await on the channel tasks unless an exception
|
||||
was raised, then cancel them all.
|
||||
|
||||
:params *tasks: All coros or tasks to be run concurrently
|
||||
:param **kwargs: Any extra kwargs to pass to gather
|
||||
"""
|
||||
|
||||
# Wrap the coros into tasks if they aren't already
|
||||
self._channel_tasks = [
|
||||
task if isinstance(task, asyncio.Task) else asyncio.create_task(task)
|
||||
for task in tasks
|
||||
]
|
||||
|
||||
try:
|
||||
await asyncio.gather(*self._channel_tasks, **kwargs)
|
||||
except Exception:
|
||||
# If an exception occurred, cancel the rest of the tasks and bubble up
|
||||
# the error that was caught here
|
||||
await self.cancel_channel_tasks()
|
||||
raise
|
||||
|
||||
async def cancel_channel_tasks(self):
|
||||
"""
|
||||
Cancel and wait on all channel tasks
|
||||
"""
|
||||
for task in self._channel_tasks:
|
||||
task.cancel()
|
||||
|
||||
# Wait for tasks to finish cancelling
|
||||
try:
|
||||
await asyncio.wait(self._channel_tasks)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
self._channel_tasks = []
|
||||
|
||||
async def __aiter__(self):
|
||||
"""
|
||||
Generator for received messages
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
yield await self.recv()
|
||||
except Exception:
|
||||
break
|
||||
# We can not catch any errors here as websocket.recv is
|
||||
# the first to catch any disconnects and bubble it up
|
||||
# so the connection is garbage collected right away
|
||||
while not self.is_closed():
|
||||
yield await self.recv()
|
||||
|
||||
@asynccontextmanager
|
||||
async def connect(self):
|
||||
"""
|
||||
Context manager for safely opening and closing the websocket connection
|
||||
"""
|
||||
try:
|
||||
await self.accept()
|
||||
yield self
|
||||
finally:
|
||||
await self.close()
|
||||
|
||||
@asynccontextmanager
|
||||
async def create_channel(websocket: WebSocketType, **kwargs):
|
||||
"""
|
||||
Context manager for safely opening and closing a WebSocketChannel
|
||||
"""
|
||||
channel = WebSocketChannel(websocket, **kwargs)
|
||||
try:
|
||||
await channel.accept()
|
||||
logger.info(f"Connected to channel - {channel}")
|
||||
|
||||
yield channel
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
await channel.close()
|
||||
logger.info(f"Disconnected from channel - {channel}")
|
||||
|
@@ -17,7 +17,8 @@ class MessageStream:
|
||||
async def subscribe(self):
|
||||
waiter = self._waiter
|
||||
while True:
|
||||
message, waiter = await waiter
|
||||
# Shield the future from being cancelled by a task waiting on it
|
||||
message, waiter = await asyncio.shield(waiter)
|
||||
yield message
|
||||
|
||||
__aiter__ = subscribe
|
||||
|
Reference in New Issue
Block a user