Use combination of thread-local and asyncio-aware session context

This commit is contained in:
Matthias 2023-03-17 20:44:00 +01:00
parent b0a7b64d44
commit 62c8dd98d5
3 changed files with 38 additions and 8 deletions

View File

@ -2,7 +2,9 @@
This module contains the class to persist trades into SQLite This module contains the class to persist trades into SQLite
""" """
import logging import logging
from typing import Any, Dict import threading
from contextvars import ContextVar
from typing import Any, Dict, Final, Optional
from sqlalchemy import create_engine, inspect from sqlalchemy import create_engine, inspect
from sqlalchemy.exc import NoSuchModuleError from sqlalchemy.exc import NoSuchModuleError
@ -19,6 +21,22 @@ from freqtrade.persistence.trade_model import Order, Trade
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
REQUEST_ID_CTX_KEY: Final[str] = 'request_id'
_request_id_ctx_var: ContextVar[Optional[str]] = ContextVar(REQUEST_ID_CTX_KEY, default=None)
def get_request_or_thread_id() -> Optional[str]:
"""
Helper method to get either async context (for fastapi requests), or thread id
"""
id = _request_id_ctx_var.get()
if id is None:
# when not in request context - use thread id
id = str(threading.current_thread().ident)
return id
_SQL_DOCS_URL = 'http://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls' _SQL_DOCS_URL = 'http://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls'
@ -53,8 +71,9 @@ def init_db(db_url: str) -> None:
# https://docs.sqlalchemy.org/en/13/orm/contextual.html#thread-local-scope # https://docs.sqlalchemy.org/en/13/orm/contextual.html#thread-local-scope
# Scoped sessions proxy requests to the appropriate thread-local session. # Scoped sessions proxy requests to the appropriate thread-local session.
# We should use the scoped_session object - not a seperately initialized version # Since we also use fastAPI, we need to make it aware of the request id, too
Trade.session = scoped_session(sessionmaker(bind=engine, autoflush=False)) Trade.session = scoped_session(sessionmaker(
bind=engine, autoflush=False), scopefunc=get_request_or_thread_id)
Order.session = Trade.session Order.session = Trade.session
PairLock.session = Trade.session PairLock.session = Trade.session

View File

@ -1,9 +1,11 @@
from typing import Any, Dict, Iterator, Optional from typing import Any, Dict, Iterator, Optional
from uuid import uuid4
from fastapi import Depends from fastapi import Depends
from freqtrade.enums import RunMode from freqtrade.enums import RunMode
from freqtrade.persistence import Trade from freqtrade.persistence import Trade
from freqtrade.persistence.models import _request_id_ctx_var
from freqtrade.rpc.rpc import RPC, RPCException from freqtrade.rpc.rpc import RPC, RPCException
from .webserver import ApiServer from .webserver import ApiServer
@ -15,12 +17,19 @@ def get_rpc_optional() -> Optional[RPC]:
return None return None
def get_rpc() -> Optional[Iterator[RPC]]: async def get_rpc() -> Optional[Iterator[RPC]]:
_rpc = get_rpc_optional() _rpc = get_rpc_optional()
if _rpc: if _rpc:
request_id = str(uuid4())
ctx_token = _request_id_ctx_var.set(request_id)
Trade.rollback() Trade.rollback()
yield _rpc try:
Trade.rollback() yield _rpc
finally:
Trade.session.remove()
_request_id_ctx_var.reset(ctx_token)
else: else:
raise RPCException('Bot is not in the correct state') raise RPCException('Bot is not in the correct state')

View File

@ -674,8 +674,9 @@ def test_monthly_handle(default_conf_usdt, update, ticker, fee, mocker, time_mac
assert str('Monthly Profit over the last 6 months</b>:') in msg_mock.call_args_list[0][0][0] assert str('Monthly Profit over the last 6 months</b>:') in msg_mock.call_args_list[0][0][0]
def test_profit_handle(default_conf_usdt, update, ticker_usdt, ticker_sell_up, fee, def test_telegram_profit_handle(
limit_sell_order_usdt, mocker) -> None: default_conf_usdt, update, ticker_usdt, ticker_sell_up, fee,
limit_sell_order_usdt, mocker) -> None:
mocker.patch('freqtrade.rpc.rpc.CryptoToFiatConverter._find_price', return_value=1.1) mocker.patch('freqtrade.rpc.rpc.CryptoToFiatConverter._find_price', return_value=1.1)
mocker.patch.multiple( mocker.patch.multiple(
EXMS, EXMS,
@ -710,6 +711,7 @@ def test_profit_handle(default_conf_usdt, update, ticker_usdt, ticker_sell_up, f
# Update the ticker with a market going up # Update the ticker with a market going up
mocker.patch(f'{EXMS}.fetch_ticker', ticker_sell_up) mocker.patch(f'{EXMS}.fetch_ticker', ticker_sell_up)
# Simulate fulfilled LIMIT_SELL order for trade # Simulate fulfilled LIMIT_SELL order for trade
trade = Trade.session.scalars(select(Trade)).first()
oobj = Order.parse_from_ccxt_object( oobj = Order.parse_from_ccxt_object(
limit_sell_order_usdt, limit_sell_order_usdt['symbol'], 'sell') limit_sell_order_usdt, limit_sell_order_usdt['symbol'], 'sell')
trade.orders.append(oobj) trade.orders.append(oobj)