Merge pull request #7692 from wizrds/fix/ws-memory

Fix Memory Leak in Websockets
This commit is contained in:
Matthias 2022-11-03 07:17:01 +01:00 committed by GitHub
commit 0aff8c4823
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 51 additions and 39 deletions

View File

@ -127,13 +127,6 @@ async def message_endpoint(
except Exception as e: except Exception as e:
logger.info(f"Consumer connection failed - {channel}: {e}") logger.info(f"Consumer connection failed - {channel}: {e}")
logger.debug(e, exc_info=e) logger.debug(e, exc_info=e)
finally:
await channel_manager.on_disconnect(ws)
else:
if channel:
await channel_manager.on_disconnect(ws)
await ws.close()
except RuntimeError: except RuntimeError:
# WebSocket was closed # WebSocket was closed
@ -144,4 +137,5 @@ async def message_endpoint(
# Log tracebacks to keep track of what errors are happening # Log tracebacks to keep track of what errors are happening
logger.exception(e) logger.exception(e)
finally: finally:
if channel:
await channel_manager.on_disconnect(ws) await channel_manager.on_disconnect(ws)

View File

@ -197,6 +197,7 @@ class ApiServer(RPCHandler):
# Get data from queue # Get data from queue
message: WSMessageSchemaType = await async_queue.get() message: WSMessageSchemaType = await async_queue.get()
logger.debug(f"Found message of type: {message.get('type')}") logger.debug(f"Found message of type: {message.get('type')}")
async_queue.task_done()
# Broadcast it # Broadcast it
await self._ws_channel_manager.broadcast(message) await self._ws_channel_manager.broadcast(message)
except asyncio.CancelledError: except asyncio.CancelledError:
@ -210,6 +211,9 @@ class ApiServer(RPCHandler):
# Disconnect channels and stop the loop on cancel # Disconnect channels and stop the loop on cancel
await self._ws_channel_manager.disconnect_all() await self._ws_channel_manager.disconnect_all()
self._ws_loop.stop() self._ws_loop.stop()
# Avoid adding more items to the queue if they aren't
# going to get broadcasted.
self._ws_queue = None
def start_api(self): def start_api(self):
""" """

View File

@ -1,5 +1,6 @@
import asyncio import asyncio
import logging import logging
import time
from threading import RLock from threading import RLock
from typing import Any, Dict, List, Optional, Type, Union from typing import Any, Dict, List, Optional, Type, Union
from uuid import uuid4 from uuid import uuid4
@ -46,7 +47,7 @@ class WebSocketChannel:
self._relay_task = asyncio.create_task(self.relay()) self._relay_task = asyncio.create_task(self.relay())
# Internal event to signify a closed websocket # Internal event to signify a closed websocket
self._closed = False self._closed = asyncio.Event()
# Wrap the WebSocket in the Serializing class # Wrap the WebSocket in the Serializing class
self._wrapped_ws = self._serializer_cls(self._websocket) self._wrapped_ws = self._serializer_cls(self._websocket)
@ -73,15 +74,26 @@ class WebSocketChannel:
Add the data to the queue to be sent. Add the data to the queue to be sent.
:returns: True if data added to queue, False otherwise :returns: True if data added to queue, False otherwise
""" """
try:
await asyncio.wait_for( # This block only runs if the queue is full, it will wait
self.queue.put(data), # until self.drain_timeout for the relay to drain the outgoing queue
timeout=self.drain_timeout # We can't use asyncio.wait_for here because the queue may have been created with a
) # different eventloop
return True start = time.time()
except asyncio.TimeoutError: while self.queue.full():
await asyncio.sleep(1)
if (time.time() - start) > self.drain_timeout:
return False return False
# If for some reason the queue is still full, just return False
try:
self.queue.put_nowait(data)
except asyncio.QueueFull:
return False
# If we got here everything is ok
return True
async def recv(self): async def recv(self):
""" """
Receive data on the wrapped websocket Receive data on the wrapped websocket
@ -99,14 +111,19 @@ class WebSocketChannel:
Close the WebSocketChannel Close the WebSocketChannel
""" """
self._closed = True try:
await self.raw_websocket.close()
except Exception:
pass
self._closed.set()
self._relay_task.cancel() self._relay_task.cancel()
def is_closed(self) -> bool: def is_closed(self) -> bool:
""" """
Closed flag Closed flag
""" """
return self._closed return self._closed.is_set()
def set_subscriptions(self, subscriptions: List[str] = []) -> None: def set_subscriptions(self, subscriptions: List[str] = []) -> None:
""" """
@ -129,7 +146,7 @@ class WebSocketChannel:
Relay messages from the channel's queue and send them out. This is started Relay messages from the channel's queue and send them out. This is started
as a task. as a task.
""" """
while True: while not self._closed.is_set():
message = await self.queue.get() message = await self.queue.get()
try: try:
await self._send(message) await self._send(message)

View File

@ -264,10 +264,10 @@ class ExternalMessageConsumer:
# We haven't received data yet. Check the connection and continue. # We haven't received data yet. Check the connection and continue.
try: try:
# ping # ping
ping = await channel.ping() pong = await channel.ping()
latency = (await asyncio.wait_for(pong, timeout=self.ping_timeout) * 1000)
await asyncio.wait_for(ping, timeout=self.ping_timeout) logger.info(f"Connection to {channel} still alive, latency: {latency}ms")
logger.debug(f"Connection to {channel} still alive...")
continue continue
except (websockets.exceptions.ConnectionClosed): except (websockets.exceptions.ConnectionClosed):
@ -276,7 +276,7 @@ class ExternalMessageConsumer:
await asyncio.sleep(self.sleep_time) await asyncio.sleep(self.sleep_time)
break break
except Exception as e: except Exception as e:
logger.warning(f"Ping error {channel} - retrying in {self.sleep_time}s") logger.warning(f"Ping error {channel} - {e} - retrying in {self.sleep_time}s")
logger.debug(e, exc_info=e) logger.debug(e, exc_info=e)
await asyncio.sleep(self.sleep_time) await asyncio.sleep(self.sleep_time)

View File

@ -18,7 +18,6 @@ import orjson
import pandas import pandas
import rapidjson import rapidjson
import websockets import websockets
from dateutil.relativedelta import relativedelta
logger = logging.getLogger("WebSocketClient") logger = logging.getLogger("WebSocketClient")
@ -28,7 +27,7 @@ logger = logging.getLogger("WebSocketClient")
def setup_logging(filename: str): def setup_logging(filename: str):
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.DEBUG,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[ handlers=[
logging.FileHandler(filename), logging.FileHandler(filename),
@ -75,16 +74,15 @@ def load_config(configfile):
def readable_timedelta(delta): def readable_timedelta(delta):
""" """
Convert a dateutil.relativedelta to a readable format Convert a millisecond delta to a readable format
:param delta: A dateutil.relativedelta :param delta: A delta between two timestamps in milliseconds
:returns: The readable time difference string :returns: The readable time difference string
""" """
attrs = ['years', 'months', 'days', 'hours', 'minutes', 'seconds', 'microseconds'] seconds, milliseconds = divmod(delta, 1000)
return ", ".join([ minutes, seconds = divmod(seconds, 60)
'%d %s' % (getattr(delta, attr), attr if getattr(delta, attr) > 0 else attr[:-1])
for attr in attrs if getattr(delta, attr) return f"{int(minutes)}:{int(seconds)}.{int(milliseconds)}"
])
# ---------------------------------------------------------------------------- # ----------------------------------------------------------------------------
@ -170,8 +168,8 @@ class ClientProtocol:
def _calculate_time_difference(self): def _calculate_time_difference(self):
old_last_received_at = self._LAST_RECEIVED_AT old_last_received_at = self._LAST_RECEIVED_AT
self._LAST_RECEIVED_AT = time.time() * 1e6 self._LAST_RECEIVED_AT = time.time() * 1e3
time_delta = relativedelta(microseconds=(self._LAST_RECEIVED_AT - old_last_received_at)) time_delta = self._LAST_RECEIVED_AT - old_last_received_at
return readable_timedelta(time_delta) return readable_timedelta(time_delta)
@ -242,12 +240,10 @@ async def create_client(
): ):
# Try pinging # Try pinging
try: try:
pong = ws.ping() pong = await ws.ping()
await asyncio.wait_for( latency = (await asyncio.wait_for(pong, timeout=ping_timeout) * 1000)
pong,
timeout=ping_timeout logger.info(f"Connection still alive, latency: {latency}ms")
)
logger.info("Connection still alive...")
continue continue
@ -272,6 +268,7 @@ async def create_client(
websockets.exceptions.ConnectionClosedError, websockets.exceptions.ConnectionClosedError,
websockets.exceptions.ConnectionClosedOK websockets.exceptions.ConnectionClosedOK
): ):
logger.info("Connection was closed")
# Just keep trying to connect again indefinitely # Just keep trying to connect again indefinitely
await asyncio.sleep(sleep_time) await asyncio.sleep(sleep_time)