close ws channel if can't accept

This commit is contained in:
Timothy Pogue 2022-11-24 11:35:50 -07:00
parent 48242ca02b
commit 101dec461e
1 changed files with 31 additions and 25 deletions

View File

@ -125,9 +125,14 @@ class WebSocketChannel:
async def accept(self): async def accept(self):
""" """
Accept the underlying websocket connection Accept the underlying websocket connection,
if the connection has been closed before we can
accept, just close the channel.
""" """
return await self._websocket.accept() try:
return await self._websocket.accept()
except RuntimeError:
await self.close()
async def close(self): async def close(self):
""" """
@ -172,17 +177,18 @@ class WebSocketChannel:
:param **kwargs: Any extra kwargs to pass to gather :param **kwargs: Any extra kwargs to pass to gather
""" """
# Wrap the coros into tasks if they aren't already if not self.is_closed():
self._channel_tasks = [ # Wrap the coros into tasks if they aren't already
task if isinstance(task, asyncio.Task) else asyncio.create_task(task) self._channel_tasks = [
for task in tasks task if isinstance(task, asyncio.Task) else asyncio.create_task(task)
] for task in tasks
]
try: try:
return await asyncio.gather(*self._channel_tasks, **kwargs) return await asyncio.gather(*self._channel_tasks, **kwargs)
except Exception: except Exception:
# If an exception occurred, cancel the rest of the tasks # If an exception occurred, cancel the rest of the tasks
await self.cancel_channel_tasks() await self.cancel_channel_tasks()
async def cancel_channel_tasks(self): async def cancel_channel_tasks(self):
""" """
@ -191,19 +197,19 @@ class WebSocketChannel:
for task in self._channel_tasks: for task in self._channel_tasks:
task.cancel() task.cancel()
# Wait for tasks to finish cancelling # Wait for tasks to finish cancelling
try: try:
await task await task
except ( except (
asyncio.CancelledError, asyncio.CancelledError,
asyncio.TimeoutError, asyncio.TimeoutError,
WebSocketDisconnect, WebSocketDisconnect,
ConnectionClosed, ConnectionClosed,
RuntimeError RuntimeError
): ):
pass pass
except Exception as e: except Exception as e:
logger.info(f"Encountered unknown exception: {e}", exc_info=e) logger.info(f"Encountered unknown exception: {e}", exc_info=e)
self._channel_tasks = [] self._channel_tasks = []