don't allow to resolve object if there are more than one module containing it

This commit is contained in:
hroff-1902 2019-07-20 12:07:55 +03:00
parent 4d0cf9ec8e
commit a5fe598ff0
7 changed files with 31 additions and 8 deletions

View File

@ -14,6 +14,7 @@ class ExchangeResolver(IResolver):
""" """
This class contains all the logic to load a custom exchange class This class contains all the logic to load a custom exchange class
""" """
type_name = "Exchange"
__slots__ = ['exchange'] __slots__ = ['exchange']

View File

@ -19,6 +19,7 @@ class HyperOptResolver(IResolver):
""" """
This class contains all the logic to load custom hyperopt class This class contains all the logic to load custom hyperopt class
""" """
type_name = "Hyperopt"
__slots__ = ['hyperopt'] __slots__ = ['hyperopt']

View File

@ -9,6 +9,9 @@ import logging
from pathlib import Path from pathlib import Path
from typing import Any, Optional, Tuple, Type, Union from typing import Any, Optional, Tuple, Type, Union
from freqtrade import OperationalException
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -16,6 +19,7 @@ class IResolver(object):
""" """
This class contains all the logic to load custom classes This class contains all the logic to load custom classes
""" """
type_name = "Unknown"
@staticmethod @staticmethod
def _get_valid_object(object_type, module_path: Path, def _get_valid_object(object_type, module_path: Path,
@ -43,8 +47,8 @@ class IResolver(object):
) )
return next(valid_objects_gen, None) return next(valid_objects_gen, None)
@staticmethod @classmethod
def _search_object(directory: Path, object_type, object_name: str, def _search_object(self, directory: Path, object_type, object_name: str,
kwargs: dict = {}) -> Union[Tuple[Any, Path], Tuple[None, None]]: kwargs: dict = {}) -> Union[Tuple[Any, Path], Tuple[None, None]]:
""" """
Search for the objectname in the given directory Search for the objectname in the given directory
@ -52,6 +56,7 @@ class IResolver(object):
:return: object instance :return: object instance
""" """
logger.debug("Searching for %s %s in '%s'", object_type.__name__, object_name, directory) logger.debug("Searching for %s %s in '%s'", object_type.__name__, object_name, directory)
objs = []
for entry in directory.iterdir(): for entry in directory.iterdir():
# Only consider python files # Only consider python files
if not str(entry).endswith('.py'): if not str(entry).endswith('.py'):
@ -62,5 +67,16 @@ class IResolver(object):
object_type, module_path, object_name object_type, module_path, object_name
) )
if obj: if obj:
return (obj(**kwargs), module_path) objs.append((obj, module_path))
return (None, None) if len(objs) == 0:
return (None, None)
elif len(objs) == 1:
obj, module_path = objs[0]
return (obj(**kwargs), module_path)
else:
raise OperationalException(
f"Cannot resolve object: found more than one objects of type "
f"`{self.type_name}` with name `{object_name}`. "
"Use unique names for custom strategies, hyperopts and other custom objects "
"so that Freqtrade can be able to resolve them. "
f"Found in modules: {[str(m) for (_, m) in objs]}")

View File

@ -17,6 +17,7 @@ class PairListResolver(IResolver):
""" """
This class contains all the logic to load custom hyperopt class This class contains all the logic to load custom hyperopt class
""" """
type_name = "PairList"
__slots__ = ['pairlist'] __slots__ = ['pairlist']

View File

@ -23,6 +23,7 @@ class StrategyResolver(IResolver):
""" """
This class contains all the logic to load custom strategy class This class contains all the logic to load custom strategy class
""" """
type_name = "Strategy"
__slots__ = ['strategy'] __slots__ = ['strategy']

View File

@ -2,9 +2,8 @@ import logging
import sys import sys
from copy import deepcopy from copy import deepcopy
from freqtrade import constants
from freqtrade.strategy.interface import IStrategy from freqtrade.strategy.interface import IStrategy
# Import Default-Strategy to have hyperopt correctly resolve
from freqtrade.strategy.default_strategy import DefaultStrategy # noqa: F401
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -28,7 +27,10 @@ def import_strategy(strategy: IStrategy, config: dict) -> IStrategy:
attr = deepcopy(comb) attr = deepcopy(comb)
# Adjust module name # Adjust module name
attr['__module__'] = 'freqtrade.strategy' attr['__module__'] = (
'freqtrade.strategy.default_strategy'
if config.get('strategy') == constants.DEFAULT_STRATEGY
else 'freqtrade.strategy')
name = strategy.__class__.__name__ name = strategy.__class__.__name__
clazz = type(name, (IStrategy,), attr) clazz = type(name, (IStrategy,), attr)

View File

@ -44,7 +44,8 @@ def test_import_strategy(caplog):
def test_search_strategy(): def test_search_strategy():
default_config = {} default_config = {}
default_location = Path(__file__).parent.parent.joinpath('strategy').resolve() # Use default strategy from freqtrade/strategy, not from freqtrade/tests/strategy
default_location = Path(__file__).parent.parent.parent.joinpath('strategy').resolve()
s, _ = StrategyResolver._search_object( s, _ = StrategyResolver._search_object(
directory=default_location, directory=default_location,