use inmemory db for tests

This commit is contained in:
gcarq 2017-09-08 21:39:31 +02:00
parent 09e4c6893e
commit 689cd11a6c
4 changed files with 19 additions and 26 deletions

View File

@ -230,17 +230,18 @@ def create_trade(stake_amount: float, _exchange: exchange.Exchange) -> Optional[
is_open=True) is_open=True)
def init(config: dict) -> None: def init(config: dict, db_url: Optional[str]=None) -> None:
""" """
Initializes all modules and updates the config Initializes all modules and updates the config
:param config: config as dict :param config: config as dict
:param db_url: database connector string for sqlalchemy (Optional)
:return: None :return: None
""" """
global _conf global _conf
# Initialize all modules # Initialize all modules
telegram.init(config) telegram.init(config)
persistence.init(config) persistence.init(config, db_url)
exchange.init(config) exchange.init(config)
_conf.update(config) _conf.update(config)

View File

@ -1,4 +1,5 @@
from datetime import datetime from datetime import datetime
from typing import Optional
from sqlalchemy import Boolean, Column, DateTime, Float, Integer, String, create_engine from sqlalchemy import Boolean, Column, DateTime, Float, Integer, String, create_engine
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
@ -10,29 +11,28 @@ from sqlalchemy.types import Enum
import exchange import exchange
_db_handle = None
_conf = {} _conf = {}
Base = declarative_base() Base = declarative_base()
def init(config: dict) -> None: def init(config: dict, db_url: Optional[str]=None) -> None:
""" """
Initializes this module with the given config, Initializes this module with the given config,
registers all known command handlers registers all known command handlers
and starts polling for message updates and starts polling for message updates
:param config: config to use :param config: config to use
:param db_url: database connector string for sqlalchemy (Optional)
:return: None :return: None
""" """
global _db_handle
_conf.update(config) _conf.update(config)
if not db_url:
if _conf.get('dry_run', False): if _conf.get('dry_run', False):
_db_handle = 'sqlite:///tradesv2.dry_run.sqlite' db_url = 'sqlite:///tradesv2.dry_run.sqlite'
else: else:
_db_handle = 'sqlite:///tradesv2.sqlite' db_url = 'sqlite:///tradesv2.sqlite'
engine = create_engine(_db_handle, echo=False) engine = create_engine(db_url, echo=False)
session = scoped_session(sessionmaker(bind=engine, autoflush=True, autocommit=True)) session = scoped_session(sessionmaker(bind=engine, autoflush=True, autocommit=True))
Trade.session = session() Trade.session = session()
Trade.query = session.query_property() Trade.query = session.query_property()

View File

@ -53,7 +53,7 @@ class TestMain(unittest.TestCase):
'last': 0.07256061 'last': 0.07256061
}), }),
buy=MagicMock(return_value='mocked_order_id')): buy=MagicMock(return_value='mocked_order_id')):
init(self.conf) init(self.conf, 'sqlite://')
trade = create_trade(15.0, exchange.Exchange.BITTREX) trade = create_trade(15.0, exchange.Exchange.BITTREX)
Trade.session.add(trade) Trade.session.add(trade)
Trade.session.flush() Trade.session.flush()
@ -99,10 +99,6 @@ class TestMain(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
try:
os.remove('./tradesv2.dry_run.sqlite')
except FileNotFoundError:
pass
validate(cls.conf, conf_schema) validate(cls.conf, conf_schema)

View File

@ -62,7 +62,7 @@ class TestTelegram(unittest.TestCase):
'last': 0.07256061 'last': 0.07256061
}), }),
buy=MagicMock(return_value='mocked_order_id')): buy=MagicMock(return_value='mocked_order_id')):
init(self.conf) init(self.conf, 'sqlite://')
# Create some test data # Create some test data
trade = create_trade(15.0, exchange.Exchange.BITTREX) trade = create_trade(15.0, exchange.Exchange.BITTREX)
@ -86,7 +86,7 @@ class TestTelegram(unittest.TestCase):
'last': 0.07256061 'last': 0.07256061
}), }),
buy=MagicMock(return_value='mocked_order_id')): buy=MagicMock(return_value='mocked_order_id')):
init(self.conf) init(self.conf, 'sqlite://')
# Create some test data # Create some test data
trade = create_trade(15.0, exchange.Exchange.BITTREX) trade = create_trade(15.0, exchange.Exchange.BITTREX)
@ -115,7 +115,7 @@ class TestTelegram(unittest.TestCase):
'last': 0.07256061 'last': 0.07256061
}), }),
buy=MagicMock(return_value='mocked_order_id')): buy=MagicMock(return_value='mocked_order_id')):
init(self.conf) init(self.conf, 'sqlite://')
# Create some test data # Create some test data
trade = create_trade(15.0, exchange.Exchange.BITTREX) trade = create_trade(15.0, exchange.Exchange.BITTREX)
@ -142,7 +142,7 @@ class TestTelegram(unittest.TestCase):
'last': 0.07256061 'last': 0.07256061
}), }),
buy=MagicMock(return_value='mocked_order_id')): buy=MagicMock(return_value='mocked_order_id')):
init(self.conf) init(self.conf, 'sqlite://')
# Create some test data # Create some test data
trade = create_trade(15.0, exchange.Exchange.BITTREX) trade = create_trade(15.0, exchange.Exchange.BITTREX)
@ -164,7 +164,7 @@ class TestTelegram(unittest.TestCase):
with patch.dict('main._conf', self.conf): with patch.dict('main._conf', self.conf):
msg_mock = MagicMock() msg_mock = MagicMock()
with patch.multiple('main.telegram', _conf=self.conf, init=MagicMock(), send_msg=msg_mock): with patch.multiple('main.telegram', _conf=self.conf, init=MagicMock(), send_msg=msg_mock):
init(self.conf) init(self.conf, 'sqlite://')
update_state(State.PAUSED) update_state(State.PAUSED)
self.assertEqual(get_state(), State.PAUSED) self.assertEqual(get_state(), State.PAUSED)
@ -176,7 +176,7 @@ class TestTelegram(unittest.TestCase):
with patch.dict('main._conf', self.conf): with patch.dict('main._conf', self.conf):
msg_mock = MagicMock() msg_mock = MagicMock()
with patch.multiple('main.telegram', _conf=self.conf, init=MagicMock(), send_msg=msg_mock): with patch.multiple('main.telegram', _conf=self.conf, init=MagicMock(), send_msg=msg_mock):
init(self.conf) init(self.conf, 'sqlite://')
update_state(State.RUNNING) update_state(State.RUNNING)
self.assertEqual(get_state(), State.RUNNING) self.assertEqual(get_state(), State.RUNNING)
@ -186,10 +186,6 @@ class TestTelegram(unittest.TestCase):
self.assertIn('Stopping trader', msg_mock.call_args_list[0][0][0]) self.assertIn('Stopping trader', msg_mock.call_args_list[0][0][0])
def setUp(self): def setUp(self):
try:
os.remove('./tradesv2.dry_run.sqlite')
except FileNotFoundError:
pass
self.update = Update(0) self.update = Update(0)
self.update.message = Message(0, 0, datetime.utcnow(), Chat(0, 0)) self.update.message = Message(0, 0, datetime.utcnow(), Chat(0, 0))