From 49e087df5b79d90381ac9ee64d0eaac88e044f74 Mon Sep 17 00:00:00 2001 From: Matthias Date: Sun, 20 Mar 2022 13:07:06 +0100 Subject: [PATCH] Allow Strategy subclassing in different files by enabling local imports --- docs/strategy-advanced.md | 9 ++-- freqtrade/resolvers/iresolver.py | 58 ++++++++++++++++------- tests/strategy/strats/strategy_test_v2.py | 2 +- 3 files changed, 45 insertions(+), 24 deletions(-) diff --git a/docs/strategy-advanced.md b/docs/strategy-advanced.md index 3793abacf..fa1c09560 100644 --- a/docs/strategy-advanced.md +++ b/docs/strategy-advanced.md @@ -164,16 +164,15 @@ class MyAwesomeStrategy2(MyAwesomeStrategy): Both attributes and methods may be overridden, altering behavior of the original strategy in a way you need. !!! Note "Parent-strategy in different files" - If you have the parent-strategy in a different file, you'll need to add the following to the top of your "child"-file to ensure proper loading, otherwise freqtrade may not be able to load the parent strategy correctly. + If you have the parent-strategy in a different file, you can still import the strategy. + Assuming `myawesomestrategy.py` is the filename, and `MyAwesomeStrategy` the strategy you need to import: ``` python - import sys - from pathlib import Path - sys.path.append(str(Path(__file__).parent)) - from myawesomestrategy import MyAwesomeStrategy ``` + This is the recommended way to derive strategies to avoid problems with hyperopt parameter files. + ## Embedding Strategies Freqtrade provides you with an easy way to embed the strategy into your configuration file. diff --git a/freqtrade/resolvers/iresolver.py b/freqtrade/resolvers/iresolver.py index c6f97c976..8d132da70 100644 --- a/freqtrade/resolvers/iresolver.py +++ b/freqtrade/resolvers/iresolver.py @@ -6,6 +6,7 @@ This module load custom objects import importlib.util import inspect import logging +import sys from pathlib import Path from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union @@ -15,6 +16,22 @@ from freqtrade.exceptions import OperationalException logger = logging.getLogger(__name__) +class PathModifier: + def __init__(self, path: Path): + self.path = path + + def __enter__(self): + """Inject path to allow importing with relative imports.""" + sys.path.insert(0, str(self.path)) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Undo insertion of local path.""" + str_path = str(self.path) + if str_path in sys.path: + sys.path.remove(str_path) + + class IResolver: """ This class contains all the logic to load custom classes @@ -57,27 +74,32 @@ class IResolver: # Generate spec based on absolute path # Pass object_name as first argument to have logging print a reasonable name. - spec = importlib.util.spec_from_file_location(object_name or "", str(module_path)) - if not spec: - return iter([None]) + with PathModifier(module_path.parent): - module = importlib.util.module_from_spec(spec) - try: - spec.loader.exec_module(module) # type: ignore # importlib does not use typehints - except (ModuleNotFoundError, SyntaxError, ImportError, NameError) as err: - # Catch errors in case a specific module is not installed - logger.warning(f"Could not import {module_path} due to '{err}'") - if enum_failed: + spec = importlib.util.spec_from_file_location(module_path.stem or "", str(module_path)) + if not spec: return iter([None]) - valid_objects_gen = ( - (obj, inspect.getsource(module)) for - name, obj in inspect.getmembers( - module, inspect.isclass) if ((object_name is None or object_name == name) - and issubclass(obj, cls.object_type) - and obj is not cls.object_type) - ) - return valid_objects_gen + module = importlib.util.module_from_spec(spec) + try: + spec.loader.exec_module(module) # type: ignore # importlib does not use typehints + except (ModuleNotFoundError, SyntaxError, ImportError, NameError) as err: + # Catch errors in case a specific module is not installed + logger.warning(f"Could not import {module_path} due to '{err}'") + if enum_failed: + return iter([None]) + + valid_objects_gen = ( + (obj, inspect.getsource(module)) for + name, obj in inspect.getmembers( + module, inspect.isclass) if ((object_name is None or object_name == name) + and issubclass(obj, cls.object_type) + and obj is not cls.object_type + and obj.__module__ == module_path.stem or "" + ) + ) + # The __module__ check ensures we only use strategies that are defined in this folder. + return valid_objects_gen @classmethod def _search_object(cls, directory: Path, *, object_name: str, add_source: bool = False diff --git a/tests/strategy/strats/strategy_test_v2.py b/tests/strategy/strats/strategy_test_v2.py index c57becdad..59f1f569e 100644 --- a/tests/strategy/strats/strategy_test_v2.py +++ b/tests/strategy/strats/strategy_test_v2.py @@ -7,7 +7,7 @@ from pandas import DataFrame import freqtrade.vendor.qtpylib.indicators as qtpylib from freqtrade.persistence import Trade -from freqtrade.strategy.interface import IStrategy +from freqtrade.strategy import IStrategy class StrategyTestV2(IStrategy):