diff --git a/freqtrade/persistence/models.py b/freqtrade/persistence/models.py index eee07e61c..2315c0acc 100644 --- a/freqtrade/persistence/models.py +++ b/freqtrade/persistence/models.py @@ -2,7 +2,9 @@ This module contains the class to persist trades into SQLite """ import logging -from typing import Any, Dict +import threading +from contextvars import ContextVar +from typing import Any, Dict, Final, Optional from sqlalchemy import create_engine, inspect from sqlalchemy.exc import NoSuchModuleError @@ -19,6 +21,22 @@ from freqtrade.persistence.trade_model import Order, Trade logger = logging.getLogger(__name__) +REQUEST_ID_CTX_KEY: Final[str] = 'request_id' +_request_id_ctx_var: ContextVar[Optional[str]] = ContextVar(REQUEST_ID_CTX_KEY, default=None) + + +def get_request_or_thread_id() -> Optional[str]: + """ + Helper method to get either async context (for fastapi requests), or thread id + """ + id = _request_id_ctx_var.get() + if id is None: + # when not in request context - use thread id + id = str(threading.current_thread().ident) + + return id + + _SQL_DOCS_URL = 'http://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls' @@ -53,8 +71,9 @@ def init_db(db_url: str) -> None: # https://docs.sqlalchemy.org/en/13/orm/contextual.html#thread-local-scope # Scoped sessions proxy requests to the appropriate thread-local session. - # We should use the scoped_session object - not a seperately initialized version - Trade.session = scoped_session(sessionmaker(bind=engine, autoflush=False)) + # Since we also use fastAPI, we need to make it aware of the request id, too + Trade.session = scoped_session(sessionmaker( + bind=engine, autoflush=False), scopefunc=get_request_or_thread_id) Order.session = Trade.session PairLock.session = Trade.session diff --git a/freqtrade/rpc/api_server/deps.py b/freqtrade/rpc/api_server/deps.py index aed97367b..eb41d728d 100644 --- a/freqtrade/rpc/api_server/deps.py +++ b/freqtrade/rpc/api_server/deps.py @@ -1,9 +1,11 @@ from typing import Any, Dict, Iterator, Optional +from uuid import uuid4 from fastapi import Depends from freqtrade.enums import RunMode from freqtrade.persistence import Trade +from freqtrade.persistence.models import _request_id_ctx_var from freqtrade.rpc.rpc import RPC, RPCException from .webserver import ApiServer @@ -15,12 +17,19 @@ def get_rpc_optional() -> Optional[RPC]: return None -def get_rpc() -> Optional[Iterator[RPC]]: +async def get_rpc() -> Optional[Iterator[RPC]]: + _rpc = get_rpc_optional() if _rpc: + request_id = str(uuid4()) + ctx_token = _request_id_ctx_var.set(request_id) Trade.rollback() - yield _rpc - Trade.rollback() + try: + yield _rpc + finally: + Trade.session.remove() + _request_id_ctx_var.reset(ctx_token) + else: raise RPCException('Bot is not in the correct state') diff --git a/tests/rpc/test_rpc_telegram.py b/tests/rpc/test_rpc_telegram.py index b1859f581..521e3b66d 100644 --- a/tests/rpc/test_rpc_telegram.py +++ b/tests/rpc/test_rpc_telegram.py @@ -674,8 +674,9 @@ def test_monthly_handle(default_conf_usdt, update, ticker, fee, mocker, time_mac assert str('Monthly Profit over the last 6 months:') in msg_mock.call_args_list[0][0][0] -def test_profit_handle(default_conf_usdt, update, ticker_usdt, ticker_sell_up, fee, - limit_sell_order_usdt, mocker) -> None: +def test_telegram_profit_handle( + default_conf_usdt, update, ticker_usdt, ticker_sell_up, fee, + limit_sell_order_usdt, mocker) -> None: mocker.patch('freqtrade.rpc.rpc.CryptoToFiatConverter._find_price', return_value=1.1) mocker.patch.multiple( EXMS, @@ -710,6 +711,7 @@ def test_profit_handle(default_conf_usdt, update, ticker_usdt, ticker_sell_up, f # Update the ticker with a market going up mocker.patch(f'{EXMS}.fetch_ticker', ticker_sell_up) # Simulate fulfilled LIMIT_SELL order for trade + trade = Trade.session.scalars(select(Trade)).first() oobj = Order.parse_from_ccxt_object( limit_sell_order_usdt, limit_sell_order_usdt['symbol'], 'sell') trade.orders.append(oobj)