diff --git a/main.py b/main.py index af6fed2b3..974a20bcc 100755 --- a/main.py +++ b/main.py @@ -230,17 +230,18 @@ def create_trade(stake_amount: float, _exchange: exchange.Exchange) -> Optional[ is_open=True) -def init(config: dict) -> None: +def init(config: dict, db_url: Optional[str]=None) -> None: """ Initializes all modules and updates the config :param config: config as dict + :param db_url: database connector string for sqlalchemy (Optional) :return: None """ global _conf # Initialize all modules telegram.init(config) - persistence.init(config) + persistence.init(config, db_url) exchange.init(config) _conf.update(config) diff --git a/persistence.py b/persistence.py index ce1a1999a..d85a2a27c 100644 --- a/persistence.py +++ b/persistence.py @@ -1,4 +1,5 @@ from datetime import datetime +from typing import Optional from sqlalchemy import Boolean, Column, DateTime, Float, Integer, String, create_engine from sqlalchemy.ext.declarative import declarative_base @@ -10,29 +11,28 @@ from sqlalchemy.types import Enum import exchange -_db_handle = None _conf = {} - Base = declarative_base() -def init(config: dict) -> None: +def init(config: dict, db_url: Optional[str]=None) -> None: """ Initializes this module with the given config, registers all known command handlers and starts polling for message updates :param config: config to use + :param db_url: database connector string for sqlalchemy (Optional) :return: None """ - global _db_handle _conf.update(config) - if _conf.get('dry_run', False): - _db_handle = 'sqlite:///tradesv2.dry_run.sqlite' - else: - _db_handle = 'sqlite:///tradesv2.sqlite' + if not db_url: + if _conf.get('dry_run', False): + db_url = 'sqlite:///tradesv2.dry_run.sqlite' + else: + db_url = 'sqlite:///tradesv2.sqlite' - engine = create_engine(_db_handle, echo=False) + engine = create_engine(db_url, echo=False) session = scoped_session(sessionmaker(bind=engine, autoflush=True, autocommit=True)) Trade.session = session() Trade.query = session.query_property() diff --git a/test/test_main.py b/test/test_main.py index 8bd06c316..37aa8d67b 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -53,7 +53,7 @@ class TestMain(unittest.TestCase): 'last': 0.07256061 }), buy=MagicMock(return_value='mocked_order_id')): - init(self.conf) + init(self.conf, 'sqlite://') trade = create_trade(15.0, exchange.Exchange.BITTREX) Trade.session.add(trade) Trade.session.flush() @@ -99,10 +99,6 @@ class TestMain(unittest.TestCase): @classmethod def setUpClass(cls): - try: - os.remove('./tradesv2.dry_run.sqlite') - except FileNotFoundError: - pass validate(cls.conf, conf_schema) diff --git a/test/test_telegram.py b/test/test_telegram.py index fe27fe875..ac96b2e0a 100644 --- a/test/test_telegram.py +++ b/test/test_telegram.py @@ -62,7 +62,7 @@ class TestTelegram(unittest.TestCase): 'last': 0.07256061 }), buy=MagicMock(return_value='mocked_order_id')): - init(self.conf) + init(self.conf, 'sqlite://') # Create some test data trade = create_trade(15.0, exchange.Exchange.BITTREX) @@ -86,7 +86,7 @@ class TestTelegram(unittest.TestCase): 'last': 0.07256061 }), buy=MagicMock(return_value='mocked_order_id')): - init(self.conf) + init(self.conf, 'sqlite://') # Create some test data trade = create_trade(15.0, exchange.Exchange.BITTREX) @@ -115,7 +115,7 @@ class TestTelegram(unittest.TestCase): 'last': 0.07256061 }), buy=MagicMock(return_value='mocked_order_id')): - init(self.conf) + init(self.conf, 'sqlite://') # Create some test data trade = create_trade(15.0, exchange.Exchange.BITTREX) @@ -142,7 +142,7 @@ class TestTelegram(unittest.TestCase): 'last': 0.07256061 }), buy=MagicMock(return_value='mocked_order_id')): - init(self.conf) + init(self.conf, 'sqlite://') # Create some test data trade = create_trade(15.0, exchange.Exchange.BITTREX) @@ -164,7 +164,7 @@ class TestTelegram(unittest.TestCase): with patch.dict('main._conf', self.conf): msg_mock = MagicMock() with patch.multiple('main.telegram', _conf=self.conf, init=MagicMock(), send_msg=msg_mock): - init(self.conf) + init(self.conf, 'sqlite://') update_state(State.PAUSED) self.assertEqual(get_state(), State.PAUSED) @@ -176,7 +176,7 @@ class TestTelegram(unittest.TestCase): with patch.dict('main._conf', self.conf): msg_mock = MagicMock() with patch.multiple('main.telegram', _conf=self.conf, init=MagicMock(), send_msg=msg_mock): - init(self.conf) + init(self.conf, 'sqlite://') update_state(State.RUNNING) self.assertEqual(get_state(), State.RUNNING) @@ -186,10 +186,6 @@ class TestTelegram(unittest.TestCase): self.assertIn('Stopping trader', msg_mock.call_args_list[0][0][0]) def setUp(self): - try: - os.remove('./tradesv2.dry_run.sqlite') - except FileNotFoundError: - pass self.update = Update(0) self.update.message = Message(0, 0, datetime.utcnow(), Chat(0, 0))