Add test for __code__ loading

This commit is contained in:
Matthias 2020-09-17 07:38:56 +02:00
parent ba10bd7756
commit becccca3d1
4 changed files with 21 additions and 10 deletions

View File

@ -51,7 +51,8 @@ class IResolver:
:param object_name: Class name of the object
:param enum_failed: If True, will return None for modules which fail.
Otherwise, failing modules are skipped.
:return: generator containing matching objects
:return: generator containing tuple of matching objects
Tuple format: [Object, source]
"""
# Generate spec based on absolute path
@ -67,7 +68,8 @@ class IResolver:
return iter([None])
valid_objects_gen = (
(obj, inspect.getsource(module)) for name, obj in inspect.getmembers(
(obj, inspect.getsource(module)) for
name, obj in inspect.getmembers(
module, inspect.isclass) if ((object_name is None or object_name == name)
and issubclass(obj, cls.object_type)
and obj is not cls.object_type)
@ -75,7 +77,7 @@ class IResolver:
return valid_objects_gen
@classmethod
def _search_object(cls, directory: Path, object_name: str
def _search_object(cls, directory: Path, object_name: str, add_source: bool = False
) -> Union[Tuple[Any, Path], Tuple[None, None]]:
"""
Search for the objectname in the given directory
@ -94,12 +96,14 @@ class IResolver:
obj = next(cls._get_valid_object(module_path, object_name), None)
if obj:
obj[0].__file__ = str(entry)
if add_source:
obj[0].__code__ = obj[1]
return (obj[0], module_path)
return (None, None)
@classmethod
def _load_object(cls, paths: List[Path], object_name: str,
def _load_object(cls, paths: List[Path], object_name: str, add_source: bool = False,
kwargs: dict = {}) -> Optional[Any]:
"""
Try to load object from path list.
@ -108,7 +112,8 @@ class IResolver:
for _path in paths:
try:
(module, module_path) = cls._search_object(directory=_path,
object_name=object_name)
object_name=object_name,
add_source=add_source)
if module:
logger.info(
f"Using resolved {cls.object_type.__name__.lower()[1:]} {object_name} "

View File

@ -174,7 +174,9 @@ class StrategyResolver(IResolver):
strategy = StrategyResolver._load_object(paths=abs_paths,
object_name=strategy_name,
kwargs={'config': config})
add_source=True,
kwargs={'config': config},
)
if strategy:
strategy._populate_fun_len = len(getfullargspec(strategy.populate_indicators).args)
strategy._buy_fun_len = len(getfullargspec(strategy.populate_buy_trend).args)

View File

@ -921,7 +921,6 @@ def test_api_pair_history(botclient, ohlcv_history):
assert rc.json['data_stop_ts'] == 1515715200000
def test_api_plot_config(botclient):
ftbot, client = botclient

View File

@ -18,13 +18,15 @@ def test_search_strategy():
s, _ = StrategyResolver._search_object(
directory=default_location,
object_name='DefaultStrategy'
object_name='DefaultStrategy',
add_source=True,
)
assert issubclass(s, IStrategy)
s, _ = StrategyResolver._search_object(
directory=default_location,
object_name='NotFoundStrategy'
object_name='NotFoundStrategy',
add_source=True,
)
assert s is None
@ -53,6 +55,9 @@ def test_load_strategy(default_conf, result):
'strategy_path': str(Path(__file__).parents[2] / 'freqtrade/templates')
})
strategy = StrategyResolver.load_strategy(default_conf)
assert isinstance(strategy.__code__, str)
assert 'class SampleStrategy' in strategy.__code__
assert isinstance(strategy.__file__, str)
assert 'rsi' in strategy.advise_indicators(result, {'pair': 'ETH/BTC'})