Merge pull request #7692 from wizrds/fix/ws-memory

Fix Memory Leak in Websockets
This commit is contained in:
Matthias 2022-11-03 07:17:01 +01:00 committed by GitHub
commit 0aff8c4823
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 51 additions and 39 deletions

View File

@ -127,13 +127,6 @@ async def message_endpoint(
except Exception as e:
logger.info(f"Consumer connection failed - {channel}: {e}")
logger.debug(e, exc_info=e)
finally:
await channel_manager.on_disconnect(ws)
else:
if channel:
await channel_manager.on_disconnect(ws)
await ws.close()
except RuntimeError:
# WebSocket was closed
@ -144,4 +137,5 @@ async def message_endpoint(
# Log tracebacks to keep track of what errors are happening
logger.exception(e)
finally:
await channel_manager.on_disconnect(ws)
if channel:
await channel_manager.on_disconnect(ws)

View File

@ -197,6 +197,7 @@ class ApiServer(RPCHandler):
# Get data from queue
message: WSMessageSchemaType = await async_queue.get()
logger.debug(f"Found message of type: {message.get('type')}")
async_queue.task_done()
# Broadcast it
await self._ws_channel_manager.broadcast(message)
except asyncio.CancelledError:
@ -210,6 +211,9 @@ class ApiServer(RPCHandler):
# Disconnect channels and stop the loop on cancel
await self._ws_channel_manager.disconnect_all()
self._ws_loop.stop()
# Avoid adding more items to the queue if they aren't
# going to get broadcasted.
self._ws_queue = None
def start_api(self):
"""

View File

@ -1,5 +1,6 @@
import asyncio
import logging
import time
from threading import RLock
from typing import Any, Dict, List, Optional, Type, Union
from uuid import uuid4
@ -46,7 +47,7 @@ class WebSocketChannel:
self._relay_task = asyncio.create_task(self.relay())
# Internal event to signify a closed websocket
self._closed = False
self._closed = asyncio.Event()
# Wrap the WebSocket in the Serializing class
self._wrapped_ws = self._serializer_cls(self._websocket)
@ -73,15 +74,26 @@ class WebSocketChannel:
Add the data to the queue to be sent.
:returns: True if data added to queue, False otherwise
"""
# This block only runs if the queue is full, it will wait
# until self.drain_timeout for the relay to drain the outgoing queue
# We can't use asyncio.wait_for here because the queue may have been created with a
# different eventloop
start = time.time()
while self.queue.full():
await asyncio.sleep(1)
if (time.time() - start) > self.drain_timeout:
return False
# If for some reason the queue is still full, just return False
try:
await asyncio.wait_for(
self.queue.put(data),
timeout=self.drain_timeout
)
return True
except asyncio.TimeoutError:
self.queue.put_nowait(data)
except asyncio.QueueFull:
return False
# If we got here everything is ok
return True
async def recv(self):
"""
Receive data on the wrapped websocket
@ -99,14 +111,19 @@ class WebSocketChannel:
Close the WebSocketChannel
"""
self._closed = True
try:
await self.raw_websocket.close()
except Exception:
pass
self._closed.set()
self._relay_task.cancel()
def is_closed(self) -> bool:
"""
Closed flag
"""
return self._closed
return self._closed.is_set()
def set_subscriptions(self, subscriptions: List[str] = []) -> None:
"""
@ -129,7 +146,7 @@ class WebSocketChannel:
Relay messages from the channel's queue and send them out. This is started
as a task.
"""
while True:
while not self._closed.is_set():
message = await self.queue.get()
try:
await self._send(message)

View File

@ -264,10 +264,10 @@ class ExternalMessageConsumer:
# We haven't received data yet. Check the connection and continue.
try:
# ping
ping = await channel.ping()
pong = await channel.ping()
latency = (await asyncio.wait_for(pong, timeout=self.ping_timeout) * 1000)
await asyncio.wait_for(ping, timeout=self.ping_timeout)
logger.debug(f"Connection to {channel} still alive...")
logger.info(f"Connection to {channel} still alive, latency: {latency}ms")
continue
except (websockets.exceptions.ConnectionClosed):
@ -276,7 +276,7 @@ class ExternalMessageConsumer:
await asyncio.sleep(self.sleep_time)
break
except Exception as e:
logger.warning(f"Ping error {channel} - retrying in {self.sleep_time}s")
logger.warning(f"Ping error {channel} - {e} - retrying in {self.sleep_time}s")
logger.debug(e, exc_info=e)
await asyncio.sleep(self.sleep_time)

View File

@ -18,7 +18,6 @@ import orjson
import pandas
import rapidjson
import websockets
from dateutil.relativedelta import relativedelta
logger = logging.getLogger("WebSocketClient")
@ -28,7 +27,7 @@ logger = logging.getLogger("WebSocketClient")
def setup_logging(filename: str):
logging.basicConfig(
level=logging.INFO,
level=logging.DEBUG,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler(filename),
@ -75,16 +74,15 @@ def load_config(configfile):
def readable_timedelta(delta):
"""
Convert a dateutil.relativedelta to a readable format
Convert a millisecond delta to a readable format
:param delta: A dateutil.relativedelta
:param delta: A delta between two timestamps in milliseconds
:returns: The readable time difference string
"""
attrs = ['years', 'months', 'days', 'hours', 'minutes', 'seconds', 'microseconds']
return ", ".join([
'%d %s' % (getattr(delta, attr), attr if getattr(delta, attr) > 0 else attr[:-1])
for attr in attrs if getattr(delta, attr)
])
seconds, milliseconds = divmod(delta, 1000)
minutes, seconds = divmod(seconds, 60)
return f"{int(minutes)}:{int(seconds)}.{int(milliseconds)}"
# ----------------------------------------------------------------------------
@ -170,8 +168,8 @@ class ClientProtocol:
def _calculate_time_difference(self):
old_last_received_at = self._LAST_RECEIVED_AT
self._LAST_RECEIVED_AT = time.time() * 1e6
time_delta = relativedelta(microseconds=(self._LAST_RECEIVED_AT - old_last_received_at))
self._LAST_RECEIVED_AT = time.time() * 1e3
time_delta = self._LAST_RECEIVED_AT - old_last_received_at
return readable_timedelta(time_delta)
@ -242,12 +240,10 @@ async def create_client(
):
# Try pinging
try:
pong = ws.ping()
await asyncio.wait_for(
pong,
timeout=ping_timeout
)
logger.info("Connection still alive...")
pong = await ws.ping()
latency = (await asyncio.wait_for(pong, timeout=ping_timeout) * 1000)
logger.info(f"Connection still alive, latency: {latency}ms")
continue
@ -272,6 +268,7 @@ async def create_client(
websockets.exceptions.ConnectionClosedError,
websockets.exceptions.ConnectionClosedOK
):
logger.info("Connection was closed")
# Just keep trying to connect again indefinitely
await asyncio.sleep(sleep_time)