diff --git a/freqtrade/persistence.py b/freqtrade/persistence.py index b5ec4c0ec..8c23cf713 100644 --- a/freqtrade/persistence.py +++ b/freqtrade/persistence.py @@ -5,9 +5,11 @@ from typing import Optional, Dict import arrow from sqlalchemy import Boolean, Column, DateTime, Float, Integer, String, create_engine +from sqlalchemy.engine import Engine from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm.scoping import scoped_session from sqlalchemy.orm.session import sessionmaker +from sqlalchemy.pool import StaticPool logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') @@ -17,23 +19,25 @@ _CONF = {} _DECL_BASE = declarative_base() -def init(config: dict, db_url: Optional[str] = None) -> None: +def init(config: dict, engine: Optional[Engine] = None) -> None: """ Initializes this module with the given config, registers all known command handlers and starts polling for message updates :param config: config to use - :param db_url: database connector string for sqlalchemy (Optional) + :param engine: database engine for sqlalchemy (Optional) :return: None """ _CONF.update(config) - if not db_url: + if not engine: if _CONF.get('dry_run', False): - db_url = 'sqlite://' + engine = create_engine('sqlite://', + connect_args={'check_same_thread': False}, + poolclass=StaticPool, + echo=False) else: - db_url = 'sqlite:///tradesv3.sqlite' + engine = create_engine('sqlite:///tradesv3.sqlite') - engine = create_engine(db_url, echo=False) session = scoped_session(sessionmaker(bind=engine, autoflush=True, autocommit=True)) Trade.session = session() Trade.query = session.query_property() diff --git a/freqtrade/tests/test_main.py b/freqtrade/tests/test_main.py index 151ecaabc..f114b4dde 100644 --- a/freqtrade/tests/test_main.py +++ b/freqtrade/tests/test_main.py @@ -4,6 +4,7 @@ from unittest.mock import MagicMock import pytest import requests +from sqlalchemy import create_engine from freqtrade.exchange import Exchanges from freqtrade.main import create_trade, handle_trade, close_trade_if_fulfilled, init, \ @@ -20,7 +21,7 @@ def test_process_trade_creation(default_conf, ticker, mocker): validate_pairs=MagicMock(), get_ticker=ticker, buy=MagicMock(return_value='mocked_limit_buy')) - init(default_conf, 'sqlite://') + init(default_conf, create_engine('sqlite://')) trades = Trade.query.filter(Trade.is_open.is_(True)).all() assert len(trades) == 0 @@ -49,7 +50,7 @@ def test_process_exchange_failures(default_conf, ticker, mocker): validate_pairs=MagicMock(), get_ticker=ticker, buy=MagicMock(side_effect=requests.exceptions.RequestException)) - init(default_conf, 'sqlite://') + init(default_conf, create_engine('sqlite://')) result = _process() assert result is False assert sleep_mock.has_calls() @@ -64,7 +65,7 @@ def test_process_runtime_error(default_conf, ticker, mocker): validate_pairs=MagicMock(), get_ticker=ticker, buy=MagicMock(side_effect=RuntimeError)) - init(default_conf, 'sqlite://') + init(default_conf, create_engine('sqlite://')) assert get_state() == State.RUNNING result = _process() @@ -82,7 +83,7 @@ def test_process_trade_handling(default_conf, ticker, limit_buy_order, mocker): get_ticker=ticker, buy=MagicMock(return_value='mocked_limit_buy'), get_order=MagicMock(return_value=limit_buy_order)) - init(default_conf, 'sqlite://') + init(default_conf, create_engine('sqlite://')) trades = Trade.query.filter(Trade.is_open.is_(True)).all() assert len(trades) == 0 @@ -106,7 +107,7 @@ def test_create_trade(default_conf, ticker, limit_buy_order, mocker): # Save state of current whitelist whitelist = copy.deepcopy(default_conf['exchange']['pair_whitelist']) - init(default_conf, 'sqlite://') + init(default_conf, create_engine('sqlite://')) trade = create_trade(15.0) Trade.session.add(trade) Trade.session.flush() @@ -167,7 +168,7 @@ def test_handle_trade(default_conf, limit_buy_order, limit_sell_order, mocker): }), buy=MagicMock(return_value='mocked_limit_buy'), sell=MagicMock(return_value='mocked_limit_sell')) - init(default_conf, 'sqlite://') + init(default_conf, create_engine('sqlite://')) trade = create_trade(15.0) trade.update(limit_buy_order) Trade.session.add(trade) @@ -197,7 +198,7 @@ def test_close_trade(default_conf, ticker, limit_buy_order, limit_sell_order, mo buy=MagicMock(return_value='mocked_limit_buy')) # Create trade and sell it - init(default_conf, 'sqlite://') + init(default_conf, create_engine('sqlite://')) trade = create_trade(15.0) trade.update(limit_buy_order) trade.update(limit_sell_order)