Use combination of thread-local and asyncio-aware session context

This commit is contained in:
Matthias 2023-03-17 20:44:00 +01:00
parent b0a7b64d44
commit 62c8dd98d5
3 changed files with 38 additions and 8 deletions

View File

@ -2,7 +2,9 @@
This module contains the class to persist trades into SQLite
"""
import logging
from typing import Any, Dict
import threading
from contextvars import ContextVar
from typing import Any, Dict, Final, Optional
from sqlalchemy import create_engine, inspect
from sqlalchemy.exc import NoSuchModuleError
@ -19,6 +21,22 @@ from freqtrade.persistence.trade_model import Order, Trade
logger = logging.getLogger(__name__)
REQUEST_ID_CTX_KEY: Final[str] = 'request_id'
_request_id_ctx_var: ContextVar[Optional[str]] = ContextVar(REQUEST_ID_CTX_KEY, default=None)
def get_request_or_thread_id() -> Optional[str]:
"""
Helper method to get either async context (for fastapi requests), or thread id
"""
id = _request_id_ctx_var.get()
if id is None:
# when not in request context - use thread id
id = str(threading.current_thread().ident)
return id
_SQL_DOCS_URL = 'http://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls'
@ -53,8 +71,9 @@ def init_db(db_url: str) -> None:
# https://docs.sqlalchemy.org/en/13/orm/contextual.html#thread-local-scope
# Scoped sessions proxy requests to the appropriate thread-local session.
# We should use the scoped_session object - not a seperately initialized version
Trade.session = scoped_session(sessionmaker(bind=engine, autoflush=False))
# Since we also use fastAPI, we need to make it aware of the request id, too
Trade.session = scoped_session(sessionmaker(
bind=engine, autoflush=False), scopefunc=get_request_or_thread_id)
Order.session = Trade.session
PairLock.session = Trade.session

View File

@ -1,9 +1,11 @@
from typing import Any, Dict, Iterator, Optional
from uuid import uuid4
from fastapi import Depends
from freqtrade.enums import RunMode
from freqtrade.persistence import Trade
from freqtrade.persistence.models import _request_id_ctx_var
from freqtrade.rpc.rpc import RPC, RPCException
from .webserver import ApiServer
@ -15,12 +17,19 @@ def get_rpc_optional() -> Optional[RPC]:
return None
def get_rpc() -> Optional[Iterator[RPC]]:
async def get_rpc() -> Optional[Iterator[RPC]]:
_rpc = get_rpc_optional()
if _rpc:
request_id = str(uuid4())
ctx_token = _request_id_ctx_var.set(request_id)
Trade.rollback()
yield _rpc
Trade.rollback()
try:
yield _rpc
finally:
Trade.session.remove()
_request_id_ctx_var.reset(ctx_token)
else:
raise RPCException('Bot is not in the correct state')

View File

@ -674,8 +674,9 @@ def test_monthly_handle(default_conf_usdt, update, ticker, fee, mocker, time_mac
assert str('Monthly Profit over the last 6 months</b>:') in msg_mock.call_args_list[0][0][0]
def test_profit_handle(default_conf_usdt, update, ticker_usdt, ticker_sell_up, fee,
limit_sell_order_usdt, mocker) -> None:
def test_telegram_profit_handle(
default_conf_usdt, update, ticker_usdt, ticker_sell_up, fee,
limit_sell_order_usdt, mocker) -> None:
mocker.patch('freqtrade.rpc.rpc.CryptoToFiatConverter._find_price', return_value=1.1)
mocker.patch.multiple(
EXMS,
@ -710,6 +711,7 @@ def test_profit_handle(default_conf_usdt, update, ticker_usdt, ticker_sell_up, f
# Update the ticker with a market going up
mocker.patch(f'{EXMS}.fetch_ticker', ticker_sell_up)
# Simulate fulfilled LIMIT_SELL order for trade
trade = Trade.session.scalars(select(Trade)).first()
oobj = Order.parse_from_ccxt_object(
limit_sell_order_usdt, limit_sell_order_usdt['symbol'], 'sell')
trade.orders.append(oobj)