keep Trade.session private

This commit is contained in:
Matthias 2023-03-02 06:55:33 +01:00
parent 8103656ae1
commit 103bd9e2f2
5 changed files with 13 additions and 13 deletions

View File

@ -20,7 +20,7 @@ def start_convert_db(args: Dict[str, Any]) -> None:
config = setup_utils_configuration(args, RunMode.UTIL_NO_EXCHANGE) config = setup_utils_configuration(args, RunMode.UTIL_NO_EXCHANGE)
init_db(config['db_url']) init_db(config['db_url'])
session_target = Trade.session session_target = Trade._session
init_db(config['db_url_from']) init_db(config['db_url_from'])
logger.info("Starting db migration.") logger.info("Starting db migration.")

View File

@ -54,12 +54,12 @@ def init_db(db_url: str) -> None:
# https://docs.sqlalchemy.org/en/13/orm/contextual.html#thread-local-scope # https://docs.sqlalchemy.org/en/13/orm/contextual.html#thread-local-scope
# Scoped sessions proxy requests to the appropriate thread-local session. # Scoped sessions proxy requests to the appropriate thread-local session.
# We should use the scoped_session object - not a seperately initialized version # We should use the scoped_session object - not a seperately initialized version
Trade.session = scoped_session(sessionmaker(bind=engine, autoflush=False)) Trade._session = scoped_session(sessionmaker(bind=engine, autoflush=False))
Order.session = Trade.session Order._session = Trade._session
PairLock.session = Trade.session PairLock._session = Trade._session
Trade.query = Trade.session.query_property() Trade.query = Trade._session.query_property()
Order.query = Trade.session.query_property() Order.query = Trade._session.query_property()
PairLock.query = Trade.session.query_property() PairLock.query = Trade._session.query_property()
previous_tables = inspect(engine).get_table_names() previous_tables = inspect(engine).get_table_names()
ModelBase.metadata.create_all(engine) ModelBase.metadata.create_all(engine)

View File

@ -15,7 +15,7 @@ class PairLock(ModelBase):
""" """
__tablename__ = 'pairlocks' __tablename__ = 'pairlocks'
query: ClassVar[_QueryDescriptorType] query: ClassVar[_QueryDescriptorType]
session: ClassVar[SessionType] _session: ClassVar[SessionType]
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[int] = mapped_column(primary_key=True)

View File

@ -37,7 +37,7 @@ class Order(ModelBase):
""" """
__tablename__ = 'orders' __tablename__ = 'orders'
query: ClassVar[_QueryDescriptorType] query: ClassVar[_QueryDescriptorType]
session: ClassVar[SessionType] _session: ClassVar[SessionType]
# Uniqueness should be ensured over pair, order_id # Uniqueness should be ensured over pair, order_id
# its likely that order_id is unique per Pair on some exchanges. # its likely that order_id is unique per Pair on some exchanges.
@ -1179,7 +1179,7 @@ class Trade(ModelBase, LocalTrade):
""" """
__tablename__ = 'trades' __tablename__ = 'trades'
query: ClassVar[_QueryDescriptorType] query: ClassVar[_QueryDescriptorType]
session: ClassVar[SessionType] _session: ClassVar[SessionType]
use_db: bool = True use_db: bool = True

View File

@ -21,8 +21,8 @@ spot, margin, futures = TradingMode.SPOT, TradingMode.MARGIN, TradingMode.FUTURE
def test_init_create_session(default_conf): def test_init_create_session(default_conf):
# Check if init create a session # Check if init create a session
init_db(default_conf['db_url']) init_db(default_conf['db_url'])
assert hasattr(Trade, 'session') assert hasattr(Trade, '_session')
assert 'scoped_session' in type(Trade.session).__name__ assert 'scoped_session' in type(Trade._session).__name__
def test_init_custom_db_url(default_conf, tmpdir): def test_init_custom_db_url(default_conf, tmpdir):
@ -34,7 +34,7 @@ def test_init_custom_db_url(default_conf, tmpdir):
init_db(default_conf['db_url']) init_db(default_conf['db_url'])
assert Path(filename).is_file() assert Path(filename).is_file()
r = Trade.session.execute(text("PRAGMA journal_mode")) r = Trade._session.execute(text("PRAGMA journal_mode"))
assert r.first() == ('wal',) assert r.first() == ('wal',)