refactor private method - improve some async tests

This commit is contained in:
Matthias 2018-08-14 20:33:03 +02:00
parent 8528143ffa
commit 37e504610a
2 changed files with 39 additions and 30 deletions

View File

@ -378,7 +378,7 @@ class Exchange(object):
one_call = constants.TICKER_INTERVAL_MINUTES[tick_interval] * 60 * _LIMIT * 1000 one_call = constants.TICKER_INTERVAL_MINUTES[tick_interval] * 60 * _LIMIT * 1000
logger.debug("one_call: %s", one_call) logger.debug("one_call: %s", one_call)
input_coroutines = [self.async_get_candle_history( input_coroutines = [self._async_get_candle_history(
pair, tick_interval, since) for since in pair, tick_interval, since) for since in
range(since_ms, int(time.time() * 1000), one_call)] range(since_ms, int(time.time() * 1000), one_call)]
tickers = await asyncio.gather(*input_coroutines, return_exceptions=True) tickers = await asyncio.gather(*input_coroutines, return_exceptions=True)
@ -397,13 +397,13 @@ class Exchange(object):
# loop = asyncio.new_event_loop() # loop = asyncio.new_event_loop()
# asyncio.set_event_loop(loop) # asyncio.set_event_loop(loop)
# await self._api_async.load_markets() # await self._api_async.load_markets()
input_coroutines = [self.async_get_candle_history( input_coroutines = [self._async_get_candle_history(
symbol, tick_interval) for symbol in pairs] symbol, tick_interval) for symbol in pairs]
tickers = await asyncio.gather(*input_coroutines, return_exceptions=True) tickers = await asyncio.gather(*input_coroutines, return_exceptions=True)
# await self._api_async.close() # await self._api_async.close()
return tickers return tickers
async def async_get_candle_history(self, pair: str, tick_interval: str, async def _async_get_candle_history(self, pair: str, tick_interval: str,
since_ms: Optional[int] = None) -> Tuple[str, List]: since_ms: Optional[int] = None) -> Tuple[str, List]:
try: try:
# fetch ohlcv asynchronously # fetch ohlcv asynchronously

View File

@ -3,7 +3,8 @@
import logging import logging
from datetime import datetime from datetime import datetime
from random import randint from random import randint
from unittest.mock import MagicMock, PropertyMock import time
from unittest.mock import Mock, MagicMock, PropertyMock
import ccxt import ccxt
import pytest import pytest
@ -13,6 +14,14 @@ from freqtrade.exchange import API_RETRY_COUNT, Exchange
from freqtrade.tests.conftest import get_patched_exchange, log_has from freqtrade.tests.conftest import get_patched_exchange, log_has
# Source: https://stackoverflow.com/questions/29881236/how-to-mock-asyncio-coroutines
def get_mock_coro(return_value):
async def mock_coro(*args, **kwargs):
return return_value
return Mock(wraps=mock_coro)
async def async_load_markets(): async def async_load_markets():
return {} return {}
@ -549,10 +558,10 @@ def test_get_ticker(default_conf, mocker):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_get_candle_history(default_conf, mocker): async def test__async_get_candle_history(default_conf, mocker, caplog):
tick = [ tick = [
[ [
1511686200000, # unix timestamp ms int(time.time() * 1000), # unix timestamp ms
1, # open 1, # open
2, # high 2, # high
3, # low 3, # low
@ -563,28 +572,37 @@ async def test_async_get_candle_history(default_conf, mocker):
async def async_fetch_ohlcv(pair, timeframe, since): async def async_fetch_ohlcv(pair, timeframe, since):
return tick return tick
caplog.set_level(logging.DEBUG)
exchange = get_patched_exchange(mocker, default_conf) exchange = get_patched_exchange(mocker, default_conf)
# Monkey-patch async function # Monkey-patch async function
exchange._api_async.fetch_ohlcv = async_fetch_ohlcv exchange._api_async.fetch_ohlcv = get_mock_coro(tick)
exchange = Exchange(default_conf) exchange = Exchange(default_conf)
pair = 'ETH/BTC' pair = 'ETH/BTC'
res = await exchange.async_get_candle_history(pair, "5m") res = await exchange._async_get_candle_history(pair, "5m")
assert type(res) is tuple assert type(res) is tuple
assert len(res) == 2 assert len(res) == 2
assert res[0] == pair assert res[0] == pair
assert res[1] == tick assert res[1] == tick
assert exchange._api_async.fetch_ohlcv.call_count == 1
assert not log_has(f"Using cached klines data for {pair} ...", caplog.record_tuples)
# test caching
res = await exchange._async_get_candle_history(pair, "5m")
assert exchange._api_async.fetch_ohlcv.call_count == 1
assert log_has(f"Using cached klines data for {pair} ...", caplog.record_tuples)
# exchange = Exchange(default_conf)
await async_ccxt_exception(mocker, default_conf, MagicMock(), await async_ccxt_exception(mocker, default_conf, MagicMock(),
"async_get_candle_history", "fetch_ohlcv", "_async_get_candle_history", "fetch_ohlcv",
pair='ABCD/BTC', tick_interval=default_conf['ticker_interval']) pair='ABCD/BTC', tick_interval=default_conf['ticker_interval'])
# # reinit exchange
# del exchange
api_mock = MagicMock() # api_mock = MagicMock()
with pytest.raises(OperationalException, match=r'Could not fetch ticker data*'): # with pytest.raises(OperationalException, match=r'Could not fetch ticker data*'):
api_mock.fetch_ohlcv = MagicMock(side_effect=ccxt.BaseError) # api_mock.fetch_ohlcv = MagicMock(side_effect=ccxt.BaseError)
exchange = get_patched_exchange(mocker, default_conf, api_mock) # exchange = get_patched_exchange(mocker, default_conf, api_mock)
await exchange.async_get_candle_history(pair, "5m") # await exchange._async_get_candle_history(pair, "5m")
@pytest.mark.asyncio @pytest.mark.asyncio
@ -600,14 +618,14 @@ async def test_async_get_candles_history(default_conf, mocker):
] ]
] ]
async def async_fetch_ohlcv(pair, timeframe, since): async def mock_get_candle_hist(pair, tick_interval, since_ms=None):
return tick return (pair, tick)
exchange = get_patched_exchange(mocker, default_conf) exchange = get_patched_exchange(mocker, default_conf)
# Monkey-patch async function # Monkey-patch async function
exchange._api_async.fetch_ohlcv = async_fetch_ohlcv exchange._api_async.fetch_ohlcv = get_mock_coro(tick)
exchange._api_async.load_markets = async_load_markets exchange._async_get_candle_history = Mock(wraps=mock_get_candle_hist)
pairs = ['ETH/BTC', 'XRP/BTC'] pairs = ['ETH/BTC', 'XRP/BTC']
res = await exchange.async_get_candles_history(pairs, "5m") res = await exchange.async_get_candles_history(pairs, "5m")
@ -618,16 +636,7 @@ async def test_async_get_candles_history(default_conf, mocker):
assert res[0][1] == tick assert res[0][1] == tick
assert res[1][0] == pairs[1] assert res[1][0] == pairs[1]
assert res[1][1] == tick assert res[1][1] == tick
assert exchange._async_get_candle_history.call_count == 2
# await async_ccxt_exception(mocker, default_conf, MagicMock(),
# "async_get_candles_history", "fetch_ohlcv",
# pairs=pairs, tick_interval=default_conf['ticker_interval'])
# api_mock = MagicMock()
# with pytest.raises(OperationalException, match=r'Could not fetch ticker data*'):
# api_mock.fetch_ohlcv = MagicMock(side_effect=ccxt.BaseError)
# exchange = get_patched_exchange(mocker, default_conf, api_mock)
# await exchange.async_get_candles_history('ETH/BTC', "5m")
def make_fetch_ohlcv_mock(data): def make_fetch_ohlcv_mock(data):