Use combination of thread-local and asyncio-aware session context
This commit is contained in:
parent
b0a7b64d44
commit
62c8dd98d5
@ -2,7 +2,9 @@
|
|||||||
This module contains the class to persist trades into SQLite
|
This module contains the class to persist trades into SQLite
|
||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict
|
import threading
|
||||||
|
from contextvars import ContextVar
|
||||||
|
from typing import Any, Dict, Final, Optional
|
||||||
|
|
||||||
from sqlalchemy import create_engine, inspect
|
from sqlalchemy import create_engine, inspect
|
||||||
from sqlalchemy.exc import NoSuchModuleError
|
from sqlalchemy.exc import NoSuchModuleError
|
||||||
@ -19,6 +21,22 @@ from freqtrade.persistence.trade_model import Order, Trade
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
REQUEST_ID_CTX_KEY: Final[str] = 'request_id'
|
||||||
|
_request_id_ctx_var: ContextVar[Optional[str]] = ContextVar(REQUEST_ID_CTX_KEY, default=None)
|
||||||
|
|
||||||
|
|
||||||
|
def get_request_or_thread_id() -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Helper method to get either async context (for fastapi requests), or thread id
|
||||||
|
"""
|
||||||
|
id = _request_id_ctx_var.get()
|
||||||
|
if id is None:
|
||||||
|
# when not in request context - use thread id
|
||||||
|
id = str(threading.current_thread().ident)
|
||||||
|
|
||||||
|
return id
|
||||||
|
|
||||||
|
|
||||||
_SQL_DOCS_URL = 'http://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls'
|
_SQL_DOCS_URL = 'http://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls'
|
||||||
|
|
||||||
|
|
||||||
@ -53,8 +71,9 @@ def init_db(db_url: str) -> None:
|
|||||||
|
|
||||||
# https://docs.sqlalchemy.org/en/13/orm/contextual.html#thread-local-scope
|
# https://docs.sqlalchemy.org/en/13/orm/contextual.html#thread-local-scope
|
||||||
# Scoped sessions proxy requests to the appropriate thread-local session.
|
# Scoped sessions proxy requests to the appropriate thread-local session.
|
||||||
# We should use the scoped_session object - not a seperately initialized version
|
# Since we also use fastAPI, we need to make it aware of the request id, too
|
||||||
Trade.session = scoped_session(sessionmaker(bind=engine, autoflush=False))
|
Trade.session = scoped_session(sessionmaker(
|
||||||
|
bind=engine, autoflush=False), scopefunc=get_request_or_thread_id)
|
||||||
Order.session = Trade.session
|
Order.session = Trade.session
|
||||||
PairLock.session = Trade.session
|
PairLock.session = Trade.session
|
||||||
|
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
from typing import Any, Dict, Iterator, Optional
|
from typing import Any, Dict, Iterator, Optional
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
|
|
||||||
from freqtrade.enums import RunMode
|
from freqtrade.enums import RunMode
|
||||||
from freqtrade.persistence import Trade
|
from freqtrade.persistence import Trade
|
||||||
|
from freqtrade.persistence.models import _request_id_ctx_var
|
||||||
from freqtrade.rpc.rpc import RPC, RPCException
|
from freqtrade.rpc.rpc import RPC, RPCException
|
||||||
|
|
||||||
from .webserver import ApiServer
|
from .webserver import ApiServer
|
||||||
@ -15,12 +17,19 @@ def get_rpc_optional() -> Optional[RPC]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_rpc() -> Optional[Iterator[RPC]]:
|
async def get_rpc() -> Optional[Iterator[RPC]]:
|
||||||
|
|
||||||
_rpc = get_rpc_optional()
|
_rpc = get_rpc_optional()
|
||||||
if _rpc:
|
if _rpc:
|
||||||
|
request_id = str(uuid4())
|
||||||
|
ctx_token = _request_id_ctx_var.set(request_id)
|
||||||
Trade.rollback()
|
Trade.rollback()
|
||||||
|
try:
|
||||||
yield _rpc
|
yield _rpc
|
||||||
Trade.rollback()
|
finally:
|
||||||
|
Trade.session.remove()
|
||||||
|
_request_id_ctx_var.reset(ctx_token)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise RPCException('Bot is not in the correct state')
|
raise RPCException('Bot is not in the correct state')
|
||||||
|
|
||||||
|
@ -674,7 +674,8 @@ def test_monthly_handle(default_conf_usdt, update, ticker, fee, mocker, time_mac
|
|||||||
assert str('Monthly Profit over the last 6 months</b>:') in msg_mock.call_args_list[0][0][0]
|
assert str('Monthly Profit over the last 6 months</b>:') in msg_mock.call_args_list[0][0][0]
|
||||||
|
|
||||||
|
|
||||||
def test_profit_handle(default_conf_usdt, update, ticker_usdt, ticker_sell_up, fee,
|
def test_telegram_profit_handle(
|
||||||
|
default_conf_usdt, update, ticker_usdt, ticker_sell_up, fee,
|
||||||
limit_sell_order_usdt, mocker) -> None:
|
limit_sell_order_usdt, mocker) -> None:
|
||||||
mocker.patch('freqtrade.rpc.rpc.CryptoToFiatConverter._find_price', return_value=1.1)
|
mocker.patch('freqtrade.rpc.rpc.CryptoToFiatConverter._find_price', return_value=1.1)
|
||||||
mocker.patch.multiple(
|
mocker.patch.multiple(
|
||||||
@ -710,6 +711,7 @@ def test_profit_handle(default_conf_usdt, update, ticker_usdt, ticker_sell_up, f
|
|||||||
# Update the ticker with a market going up
|
# Update the ticker with a market going up
|
||||||
mocker.patch(f'{EXMS}.fetch_ticker', ticker_sell_up)
|
mocker.patch(f'{EXMS}.fetch_ticker', ticker_sell_up)
|
||||||
# Simulate fulfilled LIMIT_SELL order for trade
|
# Simulate fulfilled LIMIT_SELL order for trade
|
||||||
|
trade = Trade.session.scalars(select(Trade)).first()
|
||||||
oobj = Order.parse_from_ccxt_object(
|
oobj = Order.parse_from_ccxt_object(
|
||||||
limit_sell_order_usdt, limit_sell_order_usdt['symbol'], 'sell')
|
limit_sell_order_usdt, limit_sell_order_usdt['symbol'], 'sell')
|
||||||
trade.orders.append(oobj)
|
trade.orders.append(oobj)
|
||||||
|
Loading…
Reference in New Issue
Block a user