diff --git a/freqtrade/persistence.py b/freqtrade/persistence.py index f9a7d1e3c..63c29dc4f 100644 --- a/freqtrade/persistence.py +++ b/freqtrade/persistence.py @@ -10,13 +10,11 @@ from typing import Dict, Optional, Any import arrow from sqlalchemy import (Boolean, Column, DateTime, Float, Integer, String, create_engine) -from sqlalchemy.engine import Engine +from sqlalchemy import inspect from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm.scoping import scoped_session from sqlalchemy.orm.session import sessionmaker from sqlalchemy.pool import StaticPool -from sqlalchemy import inspect - logger = logging.getLogger(__name__) @@ -24,30 +22,30 @@ _CONF = {} _DECL_BASE: Any = declarative_base() -def init(config: dict, engine: Optional[Engine] = None) -> None: +def init(config: Dict) -> None: """ Initializes this module with the given config, registers all known command handlers and starts polling for message updates :param config: config to use - :param engine: database engine for sqlalchemy (Optional) :return: None """ _CONF.update(config) - if not engine: - if _CONF.get('dry_run', False): - # the user wants dry run to use a DB - if _CONF.get('dry_run_db', False): - engine = create_engine('sqlite:///tradesv3.dry_run.sqlite') - # Otherwise dry run will store in memory - else: - engine = create_engine('sqlite://', - connect_args={'check_same_thread': False}, - poolclass=StaticPool, - echo=False) - else: - engine = create_engine('sqlite:///tradesv3.sqlite') + db_url = _CONF.get('db_url', None) + kwargs = {} + + if not db_url and _CONF.get('dry_run', False): + # Default to in-memory db if not specified + # and take care of thread ownership if in-memory db + db_url = 'sqlite://' + kwargs.update({ + 'connect_args': {'check_same_thread': False}, + 'poolclass': StaticPool, + 'echo': False, + }) + + engine = create_engine(db_url, **kwargs) session = scoped_session(sessionmaker(bind=engine, autoflush=True, autocommit=True)) Trade.session = session() Trade.query = session.query_property() @@ -55,7 +53,7 @@ def init(config: dict, engine: Optional[Engine] = None) -> None: check_migrate(engine) # Clean dry_run DB - if _CONF.get('dry_run', False) and _CONF.get('dry_run_db', False): + if _CONF.get('dry_run', False) and db_url != 'sqlite://': clean_dry_run_db()