diff --git a/freqtrade/resolvers/hyperopt_resolver.py b/freqtrade/resolvers/hyperopt_resolver.py index e96394d69..db51c3ca5 100644 --- a/freqtrade/resolvers/hyperopt_resolver.py +++ b/freqtrade/resolvers/hyperopt_resolver.py @@ -52,14 +52,8 @@ class HyperOptResolver(IResolver): """ current_path = Path(__file__).parent.parent.joinpath('optimize').resolve() - abs_paths = [ - config['user_data_dir'].joinpath('hyperopts'), - current_path, - ] - - if extra_dir: - # Add extra hyperopt directory on top of search paths - abs_paths.insert(0, Path(extra_dir).resolve()) + abs_paths = self.build_search_paths(config, current_path=current_path, + user_subdir='hyperopts', extra_dir=extra_dir) hyperopt = self._load_object(paths=abs_paths, object_type=IHyperOpt, object_name=hyperopt_name, kwargs={'config': config}) @@ -109,14 +103,8 @@ class HyperOptLossResolver(IResolver): """ current_path = Path(__file__).parent.parent.joinpath('optimize').resolve() - abs_paths = [ - config['user_data_dir'].joinpath('hyperopts'), - current_path, - ] - - if extra_dir: - # Add extra hyperopt directory on top of search paths - abs_paths.insert(0, Path(extra_dir).resolve()) + abs_paths = self.build_search_paths(config, current_path=current_path, + user_subdir='hyperopts', extra_dir=extra_dir) hyperoptloss = self._load_object(paths=abs_paths, object_type=IHyperOptLoss, object_name=hyper_loss_name) diff --git a/freqtrade/resolvers/iresolver.py b/freqtrade/resolvers/iresolver.py index 6303d4801..51c4f7dba 100644 --- a/freqtrade/resolvers/iresolver.py +++ b/freqtrade/resolvers/iresolver.py @@ -7,7 +7,7 @@ import importlib.util import inspect import logging from pathlib import Path -from typing import Any, List, Optional, Tuple, Type, Union +from typing import Any, List, Optional, Tuple, Union, Generator logger = logging.getLogger(__name__) @@ -17,15 +17,29 @@ class IResolver: This class contains all the logic to load custom classes """ + def build_search_paths(self, config, current_path: Path, user_subdir: str, + extra_dir: Optional[str] = None) -> List[Path]: + + abs_paths = [ + config['user_data_dir'].joinpath(user_subdir), + current_path, + ] + + if extra_dir: + # Add extra directory to the top of the search paths + abs_paths.insert(0, Path(extra_dir).resolve()) + + return abs_paths + @staticmethod def _get_valid_object(object_type, module_path: Path, - object_name: str) -> Optional[Type[Any]]: + object_name: str) -> Generator[Any, None, None]: """ - Returns the first object with matching object_type and object_name in the path given. + Generator returning objects with matching object_type and object_name in the path given. :param object_type: object_type (class) :param module_path: absolute path to the module :param object_name: Class name of the object - :return: class or None + :return: generator containing matching objects """ # Generate spec based on absolute path @@ -42,7 +56,7 @@ class IResolver: obj for name, obj in inspect.getmembers(module, inspect.isclass) if object_name == name and object_type in obj.__bases__ ) - return next(valid_objects_gen, None) + return valid_objects_gen @staticmethod def _search_object(directory: Path, object_type, object_name: str, @@ -59,9 +73,9 @@ class IResolver: logger.debug('Ignoring %s', entry) continue module_path = entry.resolve() - obj = IResolver._get_valid_object( - object_type, module_path, object_name - ) + + obj = next(IResolver._get_valid_object(object_type, module_path, object_name), None) + if obj: return (obj(**kwargs), module_path) return (None, None) diff --git a/freqtrade/resolvers/pairlist_resolver.py b/freqtrade/resolvers/pairlist_resolver.py index f38253155..2ddf5de2f 100644 --- a/freqtrade/resolvers/pairlist_resolver.py +++ b/freqtrade/resolvers/pairlist_resolver.py @@ -39,10 +39,8 @@ class PairListResolver(IResolver): """ current_path = Path(__file__).parent.parent.joinpath('pairlist').resolve() - abs_paths = [ - config['user_data_dir'].joinpath('pairlist'), - current_path, - ] + abs_paths = self.build_search_paths(config, current_path=current_path, + user_subdir='pairlist', extra_dir=None) pairlist = self._load_object(paths=abs_paths, object_type=IPairList, object_name=pairlist_name, kwargs=kwargs) diff --git a/freqtrade/resolvers/strategy_resolver.py b/freqtrade/resolvers/strategy_resolver.py index b9c641853..d6fbe9a7a 100644 --- a/freqtrade/resolvers/strategy_resolver.py +++ b/freqtrade/resolvers/strategy_resolver.py @@ -124,14 +124,8 @@ class StrategyResolver(IResolver): """ current_path = Path(__file__).parent.parent.joinpath('strategy').resolve() - abs_paths = [ - config['user_data_dir'].joinpath('strategies'), - current_path, - ] - - if extra_dir: - # Add extra strategy directory on top of search paths - abs_paths.insert(0, Path(extra_dir).resolve()) + abs_paths = self.build_search_paths(config, current_path=current_path, + user_subdir='strategies', extra_dir=extra_dir) if ":" in strategy_name: logger.info("loading base64 encoded strategy")