persistence: simplify init and pass db_url via config dict

This commit is contained in:
gcarq 2018-06-07 05:25:53 +02:00
parent 5c1ee52815
commit 8583e89550
1 changed files with 17 additions and 19 deletions

View File

@ -10,13 +10,11 @@ from typing import Dict, Optional, Any
import arrow import arrow
from sqlalchemy import (Boolean, Column, DateTime, Float, Integer, String, from sqlalchemy import (Boolean, Column, DateTime, Float, Integer, String,
create_engine) create_engine)
from sqlalchemy.engine import Engine from sqlalchemy import inspect
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm.scoping import scoped_session from sqlalchemy.orm.scoping import scoped_session
from sqlalchemy.orm.session import sessionmaker from sqlalchemy.orm.session import sessionmaker
from sqlalchemy.pool import StaticPool from sqlalchemy.pool import StaticPool
from sqlalchemy import inspect
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -24,30 +22,30 @@ _CONF = {}
_DECL_BASE: Any = declarative_base() _DECL_BASE: Any = declarative_base()
def init(config: dict, engine: Optional[Engine] = None) -> None: def init(config: Dict) -> None:
""" """
Initializes this module with the given config, Initializes this module with the given config,
registers all known command handlers registers all known command handlers
and starts polling for message updates and starts polling for message updates
:param config: config to use :param config: config to use
:param engine: database engine for sqlalchemy (Optional)
:return: None :return: None
""" """
_CONF.update(config) _CONF.update(config)
if not engine:
if _CONF.get('dry_run', False):
# the user wants dry run to use a DB
if _CONF.get('dry_run_db', False):
engine = create_engine('sqlite:///tradesv3.dry_run.sqlite')
# Otherwise dry run will store in memory
else:
engine = create_engine('sqlite://',
connect_args={'check_same_thread': False},
poolclass=StaticPool,
echo=False)
else:
engine = create_engine('sqlite:///tradesv3.sqlite')
db_url = _CONF.get('db_url', None)
kwargs = {}
if not db_url and _CONF.get('dry_run', False):
# Default to in-memory db if not specified
# and take care of thread ownership if in-memory db
db_url = 'sqlite://'
kwargs.update({
'connect_args': {'check_same_thread': False},
'poolclass': StaticPool,
'echo': False,
})
engine = create_engine(db_url, **kwargs)
session = scoped_session(sessionmaker(bind=engine, autoflush=True, autocommit=True)) session = scoped_session(sessionmaker(bind=engine, autoflush=True, autocommit=True))
Trade.session = session() Trade.session = session()
Trade.query = session.query_property() Trade.query = session.query_property()
@ -55,7 +53,7 @@ def init(config: dict, engine: Optional[Engine] = None) -> None:
check_migrate(engine) check_migrate(engine)
# Clean dry_run DB # Clean dry_run DB
if _CONF.get('dry_run', False) and _CONF.get('dry_run_db', False): if _CONF.get('dry_run', False) and db_url != 'sqlite://':
clean_dry_run_db() clean_dry_run_db()