close ws channel if can't accept

This commit is contained in:
Timothy Pogue 2022-11-24 11:35:50 -07:00
parent 48242ca02b
commit 101dec461e
1 changed files with 31 additions and 25 deletions

View File

@ -125,9 +125,14 @@ class WebSocketChannel:
async def accept(self):
"""
Accept the underlying websocket connection
Accept the underlying websocket connection,
if the connection has been closed before we can
accept, just close the channel.
"""
return await self._websocket.accept()
try:
return await self._websocket.accept()
except RuntimeError:
await self.close()
async def close(self):
"""
@ -172,17 +177,18 @@ class WebSocketChannel:
:param **kwargs: Any extra kwargs to pass to gather
"""
# Wrap the coros into tasks if they aren't already
self._channel_tasks = [
task if isinstance(task, asyncio.Task) else asyncio.create_task(task)
for task in tasks
]
if not self.is_closed():
# Wrap the coros into tasks if they aren't already
self._channel_tasks = [
task if isinstance(task, asyncio.Task) else asyncio.create_task(task)
for task in tasks
]
try:
return await asyncio.gather(*self._channel_tasks, **kwargs)
except Exception:
# If an exception occurred, cancel the rest of the tasks
await self.cancel_channel_tasks()
try:
return await asyncio.gather(*self._channel_tasks, **kwargs)
except Exception:
# If an exception occurred, cancel the rest of the tasks
await self.cancel_channel_tasks()
async def cancel_channel_tasks(self):
"""
@ -191,19 +197,19 @@ class WebSocketChannel:
for task in self._channel_tasks:
task.cancel()
# Wait for tasks to finish cancelling
try:
await task
except (
asyncio.CancelledError,
asyncio.TimeoutError,
WebSocketDisconnect,
ConnectionClosed,
RuntimeError
):
pass
except Exception as e:
logger.info(f"Encountered unknown exception: {e}", exc_info=e)
# Wait for tasks to finish cancelling
try:
await task
except (
asyncio.CancelledError,
asyncio.TimeoutError,
WebSocketDisconnect,
ConnectionClosed,
RuntimeError
):
pass
except Exception as e:
logger.info(f"Encountered unknown exception: {e}", exc_info=e)
self._channel_tasks = []