persistence: simplify init and pass db_url via config dict

This commit is contained in:
gcarq 2018-06-07 05:25:53 +02:00
parent 5c1ee52815
commit 8583e89550

View File

@ -10,13 +10,11 @@ from typing import Dict, Optional, Any
import arrow
from sqlalchemy import (Boolean, Column, DateTime, Float, Integer, String,
create_engine)
from sqlalchemy.engine import Engine
from sqlalchemy import inspect
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm.scoping import scoped_session
from sqlalchemy.orm.session import sessionmaker
from sqlalchemy.pool import StaticPool
from sqlalchemy import inspect
logger = logging.getLogger(__name__)
@ -24,30 +22,30 @@ _CONF = {}
_DECL_BASE: Any = declarative_base()
def init(config: dict, engine: Optional[Engine] = None) -> None:
def init(config: Dict) -> None:
"""
Initializes this module with the given config,
registers all known command handlers
and starts polling for message updates
:param config: config to use
:param engine: database engine for sqlalchemy (Optional)
:return: None
"""
_CONF.update(config)
if not engine:
if _CONF.get('dry_run', False):
# the user wants dry run to use a DB
if _CONF.get('dry_run_db', False):
engine = create_engine('sqlite:///tradesv3.dry_run.sqlite')
# Otherwise dry run will store in memory
else:
engine = create_engine('sqlite://',
connect_args={'check_same_thread': False},
poolclass=StaticPool,
echo=False)
else:
engine = create_engine('sqlite:///tradesv3.sqlite')
db_url = _CONF.get('db_url', None)
kwargs = {}
if not db_url and _CONF.get('dry_run', False):
# Default to in-memory db if not specified
# and take care of thread ownership if in-memory db
db_url = 'sqlite://'
kwargs.update({
'connect_args': {'check_same_thread': False},
'poolclass': StaticPool,
'echo': False,
})
engine = create_engine(db_url, **kwargs)
session = scoped_session(sessionmaker(bind=engine, autoflush=True, autocommit=True))
Trade.session = session()
Trade.query = session.query_property()
@ -55,7 +53,7 @@ def init(config: dict, engine: Optional[Engine] = None) -> None:
check_migrate(engine)
# Clean dry_run DB
if _CONF.get('dry_run', False) and _CONF.get('dry_run_db', False):
if _CONF.get('dry_run', False) and db_url != 'sqlite://':
clean_dry_run_db()