refactor StrategyResolver to work with class names
This commit is contained in:
@@ -7,8 +7,6 @@ import freqtrade.vendor.qtpylib.indicators as qtpylib
|
||||
from freqtrade.indicator_helpers import fishers_inverse
|
||||
from freqtrade.strategy.interface import IStrategy
|
||||
|
||||
class_name = 'DefaultStrategy'
|
||||
|
||||
|
||||
class DefaultStrategy(IStrategy):
|
||||
"""
|
||||
|
@@ -3,11 +3,11 @@
|
||||
"""
|
||||
This module load custom strategies
|
||||
"""
|
||||
import importlib
|
||||
import importlib.util
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
from typing import Optional, Dict
|
||||
from typing import Optional, Dict, Type
|
||||
|
||||
from pandas import DataFrame
|
||||
|
||||
@@ -15,8 +15,6 @@ from freqtrade.constants import Constants
|
||||
from freqtrade.logger import Logger
|
||||
from freqtrade.strategy.interface import IStrategy
|
||||
|
||||
sys.path.insert(0, r'../../user_data/strategies')
|
||||
|
||||
|
||||
class StrategyResolver(object):
|
||||
"""
|
||||
@@ -38,7 +36,7 @@ class StrategyResolver(object):
|
||||
else:
|
||||
strategy = Constants.DEFAULT_STRATEGY
|
||||
|
||||
# Load the strategy
|
||||
# Try to load the strategy
|
||||
self._load_strategy(strategy)
|
||||
|
||||
# Set attributes
|
||||
@@ -72,26 +70,27 @@ class StrategyResolver(object):
|
||||
|
||||
def _load_strategy(self, strategy_name: str) -> None:
|
||||
"""
|
||||
Search and load the custom strategy. If no strategy found, fallback on the default strategy
|
||||
Set the object into self.custom_strategy
|
||||
Search and loads the specified strategy.
|
||||
:param strategy_name: name of the module to import
|
||||
:return: None
|
||||
"""
|
||||
|
||||
try:
|
||||
# Start by sanitizing the file name (remove any extensions)
|
||||
strategy_name = self._sanitize_module_name(filename=strategy_name)
|
||||
|
||||
# Search where can be the strategy file
|
||||
path = self._search_strategy(filename=strategy_name)
|
||||
|
||||
# Load the strategy
|
||||
self.custom_strategy = self._load_class(path + strategy_name)
|
||||
current_path = os.path.dirname(os.path.realpath(__file__))
|
||||
abs_paths = [
|
||||
os.path.join(current_path, '..', '..', 'user_data', 'strategies'),
|
||||
current_path,
|
||||
]
|
||||
for path in abs_paths:
|
||||
self.custom_strategy = self._search_strategy(path, strategy_name)
|
||||
if self.custom_strategy:
|
||||
self.logger.info('Using resolved strategy %s from \'%s\'', strategy_name, path)
|
||||
return None
|
||||
|
||||
raise ImportError('not found')
|
||||
# Fallback to the default strategy
|
||||
except (ImportError, TypeError) as error:
|
||||
self.logger.error(
|
||||
"Impossible to load Strategy 'user_data/strategies/%s.py'. This file does not exist"
|
||||
"Impossible to load Strategy '%s'. This class does not exist"
|
||||
" or contains Python code errors",
|
||||
strategy_name
|
||||
)
|
||||
@@ -100,50 +99,44 @@ class StrategyResolver(object):
|
||||
error
|
||||
)
|
||||
|
||||
def _load_class(self, filename: str) -> IStrategy:
|
||||
"""
|
||||
Import a strategy as a module
|
||||
:param filename: path to the strategy (path from freqtrade/strategy/)
|
||||
:return: return the strategy class
|
||||
"""
|
||||
module = importlib.import_module(filename, __package__)
|
||||
custom_strategy = getattr(module, module.class_name)
|
||||
|
||||
self.logger.info("Load strategy class: %s (%s.py)", module.class_name, filename)
|
||||
return custom_strategy()
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_module_name(filename: str) -> str:
|
||||
def _get_valid_strategies(module_path: str, strategy_name: str) -> Optional[Type[IStrategy]]:
|
||||
"""
|
||||
Remove any extension from filename
|
||||
:param filename: filename to sanatize
|
||||
:return: return the filename without extensions
|
||||
Returns a list of all possible strategies for the given module_path
|
||||
:param module_path: absolute path to the module
|
||||
:param strategy_name: Class name of the strategy
|
||||
:return: Tuple with (name, class) or None
|
||||
"""
|
||||
filename = os.path.basename(filename)
|
||||
filename = os.path.splitext(filename)[0]
|
||||
return filename
|
||||
|
||||
@staticmethod
|
||||
def _search_strategy(filename: str) -> str:
|
||||
"""
|
||||
Search for the Strategy file in different folder
|
||||
1. search into the user_data/strategies folder
|
||||
2. search into the freqtrade/strategy folder
|
||||
3. if nothing found, return None
|
||||
:param strategy_name: module name to search
|
||||
:return: module path where is the strategy
|
||||
"""
|
||||
pwd = os.path.dirname(os.path.realpath(__file__)) + '/'
|
||||
user_data = os.path.join(pwd, '..', '..', 'user_data', 'strategies', filename + '.py')
|
||||
strategy_folder = os.path.join(pwd, filename + '.py')
|
||||
# Generate spec based on absolute path
|
||||
spec = importlib.util.spec_from_file_location('user_data.strategies', module_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
path = None
|
||||
if os.path.isfile(user_data):
|
||||
path = 'user_data.strategies.'
|
||||
elif os.path.isfile(strategy_folder):
|
||||
path = '.'
|
||||
valid_strategies_gen = (
|
||||
obj for name, obj in inspect.getmembers(module, inspect.isclass)
|
||||
if strategy_name == name and IStrategy in obj.__bases__
|
||||
)
|
||||
return next(valid_strategies_gen, None)
|
||||
|
||||
return path
|
||||
def _search_strategy(self, directory: str, strategy_name: str) -> Optional[IStrategy]:
|
||||
"""
|
||||
Search for the strategy_name in the given directory
|
||||
:param directory: relative or absolute directory path
|
||||
:return: name of the strategy class
|
||||
"""
|
||||
self.logger.debug('Searching for strategy %s in \'%s\'', strategy_name, directory)
|
||||
for entry in os.listdir(directory):
|
||||
# Only consider python files
|
||||
if not entry.endswith('.py'):
|
||||
self.logger.debug('Ignoring %s', entry)
|
||||
continue
|
||||
strategy = StrategyResolver._get_valid_strategies(
|
||||
os.path.abspath(os.path.join(directory, entry)), strategy_name
|
||||
)
|
||||
if strategy:
|
||||
return strategy()
|
||||
return None
|
||||
|
||||
def populate_indicators(self, dataframe: DataFrame) -> DataFrame:
|
||||
"""
|
||||
|
Reference in New Issue
Block a user