Merge pull request #7692 from wizrds/fix/ws-memory
Fix Memory Leak in Websockets
This commit is contained in:
commit
0aff8c4823
@ -127,13 +127,6 @@ async def message_endpoint(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.info(f"Consumer connection failed - {channel}: {e}")
|
logger.info(f"Consumer connection failed - {channel}: {e}")
|
||||||
logger.debug(e, exc_info=e)
|
logger.debug(e, exc_info=e)
|
||||||
finally:
|
|
||||||
await channel_manager.on_disconnect(ws)
|
|
||||||
|
|
||||||
else:
|
|
||||||
if channel:
|
|
||||||
await channel_manager.on_disconnect(ws)
|
|
||||||
await ws.close()
|
|
||||||
|
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
# WebSocket was closed
|
# WebSocket was closed
|
||||||
@ -144,4 +137,5 @@ async def message_endpoint(
|
|||||||
# Log tracebacks to keep track of what errors are happening
|
# Log tracebacks to keep track of what errors are happening
|
||||||
logger.exception(e)
|
logger.exception(e)
|
||||||
finally:
|
finally:
|
||||||
await channel_manager.on_disconnect(ws)
|
if channel:
|
||||||
|
await channel_manager.on_disconnect(ws)
|
||||||
|
@ -197,6 +197,7 @@ class ApiServer(RPCHandler):
|
|||||||
# Get data from queue
|
# Get data from queue
|
||||||
message: WSMessageSchemaType = await async_queue.get()
|
message: WSMessageSchemaType = await async_queue.get()
|
||||||
logger.debug(f"Found message of type: {message.get('type')}")
|
logger.debug(f"Found message of type: {message.get('type')}")
|
||||||
|
async_queue.task_done()
|
||||||
# Broadcast it
|
# Broadcast it
|
||||||
await self._ws_channel_manager.broadcast(message)
|
await self._ws_channel_manager.broadcast(message)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
@ -210,6 +211,9 @@ class ApiServer(RPCHandler):
|
|||||||
# Disconnect channels and stop the loop on cancel
|
# Disconnect channels and stop the loop on cancel
|
||||||
await self._ws_channel_manager.disconnect_all()
|
await self._ws_channel_manager.disconnect_all()
|
||||||
self._ws_loop.stop()
|
self._ws_loop.stop()
|
||||||
|
# Avoid adding more items to the queue if they aren't
|
||||||
|
# going to get broadcasted.
|
||||||
|
self._ws_queue = None
|
||||||
|
|
||||||
def start_api(self):
|
def start_api(self):
|
||||||
"""
|
"""
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
from threading import RLock
|
from threading import RLock
|
||||||
from typing import Any, Dict, List, Optional, Type, Union
|
from typing import Any, Dict, List, Optional, Type, Union
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
@ -46,7 +47,7 @@ class WebSocketChannel:
|
|||||||
self._relay_task = asyncio.create_task(self.relay())
|
self._relay_task = asyncio.create_task(self.relay())
|
||||||
|
|
||||||
# Internal event to signify a closed websocket
|
# Internal event to signify a closed websocket
|
||||||
self._closed = False
|
self._closed = asyncio.Event()
|
||||||
|
|
||||||
# Wrap the WebSocket in the Serializing class
|
# Wrap the WebSocket in the Serializing class
|
||||||
self._wrapped_ws = self._serializer_cls(self._websocket)
|
self._wrapped_ws = self._serializer_cls(self._websocket)
|
||||||
@ -73,15 +74,26 @@ class WebSocketChannel:
|
|||||||
Add the data to the queue to be sent.
|
Add the data to the queue to be sent.
|
||||||
:returns: True if data added to queue, False otherwise
|
:returns: True if data added to queue, False otherwise
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# This block only runs if the queue is full, it will wait
|
||||||
|
# until self.drain_timeout for the relay to drain the outgoing queue
|
||||||
|
# We can't use asyncio.wait_for here because the queue may have been created with a
|
||||||
|
# different eventloop
|
||||||
|
start = time.time()
|
||||||
|
while self.queue.full():
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
if (time.time() - start) > self.drain_timeout:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# If for some reason the queue is still full, just return False
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(
|
self.queue.put_nowait(data)
|
||||||
self.queue.put(data),
|
except asyncio.QueueFull:
|
||||||
timeout=self.drain_timeout
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# If we got here everything is ok
|
||||||
|
return True
|
||||||
|
|
||||||
async def recv(self):
|
async def recv(self):
|
||||||
"""
|
"""
|
||||||
Receive data on the wrapped websocket
|
Receive data on the wrapped websocket
|
||||||
@ -99,14 +111,19 @@ class WebSocketChannel:
|
|||||||
Close the WebSocketChannel
|
Close the WebSocketChannel
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self._closed = True
|
try:
|
||||||
|
await self.raw_websocket.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self._closed.set()
|
||||||
self._relay_task.cancel()
|
self._relay_task.cancel()
|
||||||
|
|
||||||
def is_closed(self) -> bool:
|
def is_closed(self) -> bool:
|
||||||
"""
|
"""
|
||||||
Closed flag
|
Closed flag
|
||||||
"""
|
"""
|
||||||
return self._closed
|
return self._closed.is_set()
|
||||||
|
|
||||||
def set_subscriptions(self, subscriptions: List[str] = []) -> None:
|
def set_subscriptions(self, subscriptions: List[str] = []) -> None:
|
||||||
"""
|
"""
|
||||||
@ -129,7 +146,7 @@ class WebSocketChannel:
|
|||||||
Relay messages from the channel's queue and send them out. This is started
|
Relay messages from the channel's queue and send them out. This is started
|
||||||
as a task.
|
as a task.
|
||||||
"""
|
"""
|
||||||
while True:
|
while not self._closed.is_set():
|
||||||
message = await self.queue.get()
|
message = await self.queue.get()
|
||||||
try:
|
try:
|
||||||
await self._send(message)
|
await self._send(message)
|
||||||
|
@ -264,10 +264,10 @@ class ExternalMessageConsumer:
|
|||||||
# We haven't received data yet. Check the connection and continue.
|
# We haven't received data yet. Check the connection and continue.
|
||||||
try:
|
try:
|
||||||
# ping
|
# ping
|
||||||
ping = await channel.ping()
|
pong = await channel.ping()
|
||||||
|
latency = (await asyncio.wait_for(pong, timeout=self.ping_timeout) * 1000)
|
||||||
|
|
||||||
await asyncio.wait_for(ping, timeout=self.ping_timeout)
|
logger.info(f"Connection to {channel} still alive, latency: {latency}ms")
|
||||||
logger.debug(f"Connection to {channel} still alive...")
|
|
||||||
|
|
||||||
continue
|
continue
|
||||||
except (websockets.exceptions.ConnectionClosed):
|
except (websockets.exceptions.ConnectionClosed):
|
||||||
@ -276,7 +276,7 @@ class ExternalMessageConsumer:
|
|||||||
await asyncio.sleep(self.sleep_time)
|
await asyncio.sleep(self.sleep_time)
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Ping error {channel} - retrying in {self.sleep_time}s")
|
logger.warning(f"Ping error {channel} - {e} - retrying in {self.sleep_time}s")
|
||||||
logger.debug(e, exc_info=e)
|
logger.debug(e, exc_info=e)
|
||||||
await asyncio.sleep(self.sleep_time)
|
await asyncio.sleep(self.sleep_time)
|
||||||
|
|
||||||
|
@ -18,7 +18,6 @@ import orjson
|
|||||||
import pandas
|
import pandas
|
||||||
import rapidjson
|
import rapidjson
|
||||||
import websockets
|
import websockets
|
||||||
from dateutil.relativedelta import relativedelta
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger("WebSocketClient")
|
logger = logging.getLogger("WebSocketClient")
|
||||||
@ -28,7 +27,7 @@ logger = logging.getLogger("WebSocketClient")
|
|||||||
|
|
||||||
def setup_logging(filename: str):
|
def setup_logging(filename: str):
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO,
|
level=logging.DEBUG,
|
||||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||||
handlers=[
|
handlers=[
|
||||||
logging.FileHandler(filename),
|
logging.FileHandler(filename),
|
||||||
@ -75,16 +74,15 @@ def load_config(configfile):
|
|||||||
|
|
||||||
def readable_timedelta(delta):
|
def readable_timedelta(delta):
|
||||||
"""
|
"""
|
||||||
Convert a dateutil.relativedelta to a readable format
|
Convert a millisecond delta to a readable format
|
||||||
|
|
||||||
:param delta: A dateutil.relativedelta
|
:param delta: A delta between two timestamps in milliseconds
|
||||||
:returns: The readable time difference string
|
:returns: The readable time difference string
|
||||||
"""
|
"""
|
||||||
attrs = ['years', 'months', 'days', 'hours', 'minutes', 'seconds', 'microseconds']
|
seconds, milliseconds = divmod(delta, 1000)
|
||||||
return ", ".join([
|
minutes, seconds = divmod(seconds, 60)
|
||||||
'%d %s' % (getattr(delta, attr), attr if getattr(delta, attr) > 0 else attr[:-1])
|
|
||||||
for attr in attrs if getattr(delta, attr)
|
return f"{int(minutes)}:{int(seconds)}.{int(milliseconds)}"
|
||||||
])
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------
|
||||||
|
|
||||||
@ -170,8 +168,8 @@ class ClientProtocol:
|
|||||||
|
|
||||||
def _calculate_time_difference(self):
|
def _calculate_time_difference(self):
|
||||||
old_last_received_at = self._LAST_RECEIVED_AT
|
old_last_received_at = self._LAST_RECEIVED_AT
|
||||||
self._LAST_RECEIVED_AT = time.time() * 1e6
|
self._LAST_RECEIVED_AT = time.time() * 1e3
|
||||||
time_delta = relativedelta(microseconds=(self._LAST_RECEIVED_AT - old_last_received_at))
|
time_delta = self._LAST_RECEIVED_AT - old_last_received_at
|
||||||
|
|
||||||
return readable_timedelta(time_delta)
|
return readable_timedelta(time_delta)
|
||||||
|
|
||||||
@ -242,12 +240,10 @@ async def create_client(
|
|||||||
):
|
):
|
||||||
# Try pinging
|
# Try pinging
|
||||||
try:
|
try:
|
||||||
pong = ws.ping()
|
pong = await ws.ping()
|
||||||
await asyncio.wait_for(
|
latency = (await asyncio.wait_for(pong, timeout=ping_timeout) * 1000)
|
||||||
pong,
|
|
||||||
timeout=ping_timeout
|
logger.info(f"Connection still alive, latency: {latency}ms")
|
||||||
)
|
|
||||||
logger.info("Connection still alive...")
|
|
||||||
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -272,6 +268,7 @@ async def create_client(
|
|||||||
websockets.exceptions.ConnectionClosedError,
|
websockets.exceptions.ConnectionClosedError,
|
||||||
websockets.exceptions.ConnectionClosedOK
|
websockets.exceptions.ConnectionClosedOK
|
||||||
):
|
):
|
||||||
|
logger.info("Connection was closed")
|
||||||
# Just keep trying to connect again indefinitely
|
# Just keep trying to connect again indefinitely
|
||||||
await asyncio.sleep(sleep_time)
|
await asyncio.sleep(sleep_time)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user