Allow Strategy subclassing in different files by enabling local imports

This commit is contained in:
Matthias 2022-03-20 13:07:06 +01:00
parent fcec071a08
commit 49e087df5b
3 changed files with 45 additions and 24 deletions

View File

@ -164,16 +164,15 @@ class MyAwesomeStrategy2(MyAwesomeStrategy):
Both attributes and methods may be overridden, altering behavior of the original strategy in a way you need. Both attributes and methods may be overridden, altering behavior of the original strategy in a way you need.
!!! Note "Parent-strategy in different files" !!! Note "Parent-strategy in different files"
If you have the parent-strategy in a different file, you'll need to add the following to the top of your "child"-file to ensure proper loading, otherwise freqtrade may not be able to load the parent strategy correctly. If you have the parent-strategy in a different file, you can still import the strategy.
Assuming `myawesomestrategy.py` is the filename, and `MyAwesomeStrategy` the strategy you need to import:
``` python ``` python
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent))
from myawesomestrategy import MyAwesomeStrategy from myawesomestrategy import MyAwesomeStrategy
``` ```
This is the recommended way to derive strategies to avoid problems with hyperopt parameter files.
## Embedding Strategies ## Embedding Strategies
Freqtrade provides you with an easy way to embed the strategy into your configuration file. Freqtrade provides you with an easy way to embed the strategy into your configuration file.

View File

@ -6,6 +6,7 @@ This module load custom objects
import importlib.util import importlib.util
import inspect import inspect
import logging import logging
import sys
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union
@ -15,6 +16,22 @@ from freqtrade.exceptions import OperationalException
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PathModifier:
def __init__(self, path: Path):
self.path = path
def __enter__(self):
"""Inject path to allow importing with relative imports."""
sys.path.insert(0, str(self.path))
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Undo insertion of local path."""
str_path = str(self.path)
if str_path in sys.path:
sys.path.remove(str_path)
class IResolver: class IResolver:
""" """
This class contains all the logic to load custom classes This class contains all the logic to load custom classes
@ -57,27 +74,32 @@ class IResolver:
# Generate spec based on absolute path # Generate spec based on absolute path
# Pass object_name as first argument to have logging print a reasonable name. # Pass object_name as first argument to have logging print a reasonable name.
spec = importlib.util.spec_from_file_location(object_name or "", str(module_path)) with PathModifier(module_path.parent):
if not spec:
return iter([None])
module = importlib.util.module_from_spec(spec) spec = importlib.util.spec_from_file_location(module_path.stem or "", str(module_path))
try: if not spec:
spec.loader.exec_module(module) # type: ignore # importlib does not use typehints
except (ModuleNotFoundError, SyntaxError, ImportError, NameError) as err:
# Catch errors in case a specific module is not installed
logger.warning(f"Could not import {module_path} due to '{err}'")
if enum_failed:
return iter([None]) return iter([None])
valid_objects_gen = ( module = importlib.util.module_from_spec(spec)
(obj, inspect.getsource(module)) for try:
name, obj in inspect.getmembers( spec.loader.exec_module(module) # type: ignore # importlib does not use typehints
module, inspect.isclass) if ((object_name is None or object_name == name) except (ModuleNotFoundError, SyntaxError, ImportError, NameError) as err:
and issubclass(obj, cls.object_type) # Catch errors in case a specific module is not installed
and obj is not cls.object_type) logger.warning(f"Could not import {module_path} due to '{err}'")
) if enum_failed:
return valid_objects_gen return iter([None])
valid_objects_gen = (
(obj, inspect.getsource(module)) for
name, obj in inspect.getmembers(
module, inspect.isclass) if ((object_name is None or object_name == name)
and issubclass(obj, cls.object_type)
and obj is not cls.object_type
and obj.__module__ == module_path.stem or ""
)
)
# The __module__ check ensures we only use strategies that are defined in this folder.
return valid_objects_gen
@classmethod @classmethod
def _search_object(cls, directory: Path, *, object_name: str, add_source: bool = False def _search_object(cls, directory: Path, *, object_name: str, add_source: bool = False

View File

@ -7,7 +7,7 @@ from pandas import DataFrame
import freqtrade.vendor.qtpylib.indicators as qtpylib import freqtrade.vendor.qtpylib.indicators as qtpylib
from freqtrade.persistence import Trade from freqtrade.persistence import Trade
from freqtrade.strategy.interface import IStrategy from freqtrade.strategy import IStrategy
class StrategyTestV2(IStrategy): class StrategyTestV2(IStrategy):