refactor StrategyResolver to work with class names

This commit is contained in:
gcarq 2018-03-24 20:44:04 +01:00
parent 6e5c14a95b
commit b4d2a3f495
12 changed files with 85 additions and 112 deletions

View File

@ -82,7 +82,7 @@ class Arguments(object):
'-s', '--strategy', '-s', '--strategy',
help='specify strategy file (default: %(default)s)', help='specify strategy file (default: %(default)s)',
dest='strategy', dest='strategy',
default='default_strategy', default='DefaultStrategy',
type=str, type=str,
metavar='PATH', metavar='PATH',
) )

View File

@ -14,7 +14,7 @@ class Constants(object):
TICKER_INTERVAL = 5 # min TICKER_INTERVAL = 5 # min
HYPEROPT_EPOCH = 100 # epochs HYPEROPT_EPOCH = 100 # epochs
RETRY_TIMEOUT = 30 # sec RETRY_TIMEOUT = 30 # sec
DEFAULT_STRATEGY = 'default_strategy' DEFAULT_STRATEGY = 'DefaultStrategy'
# Required json-schema for user specified config # Required json-schema for user specified config
CONF_SCHEMA = { CONF_SCHEMA = {

View File

@ -7,8 +7,6 @@ import freqtrade.vendor.qtpylib.indicators as qtpylib
from freqtrade.indicator_helpers import fishers_inverse from freqtrade.indicator_helpers import fishers_inverse
from freqtrade.strategy.interface import IStrategy from freqtrade.strategy.interface import IStrategy
class_name = 'DefaultStrategy'
class DefaultStrategy(IStrategy): class DefaultStrategy(IStrategy):
""" """

View File

@ -3,11 +3,11 @@
""" """
This module load custom strategies This module load custom strategies
""" """
import importlib import importlib.util
import inspect
import os import os
import sys
from collections import OrderedDict from collections import OrderedDict
from typing import Optional, Dict from typing import Optional, Dict, Type
from pandas import DataFrame from pandas import DataFrame
@ -15,8 +15,6 @@ from freqtrade.constants import Constants
from freqtrade.logger import Logger from freqtrade.logger import Logger
from freqtrade.strategy.interface import IStrategy from freqtrade.strategy.interface import IStrategy
sys.path.insert(0, r'../../user_data/strategies')
class StrategyResolver(object): class StrategyResolver(object):
""" """
@ -38,7 +36,7 @@ class StrategyResolver(object):
else: else:
strategy = Constants.DEFAULT_STRATEGY strategy = Constants.DEFAULT_STRATEGY
# Load the strategy # Try to load the strategy
self._load_strategy(strategy) self._load_strategy(strategy)
# Set attributes # Set attributes
@ -72,26 +70,27 @@ class StrategyResolver(object):
def _load_strategy(self, strategy_name: str) -> None: def _load_strategy(self, strategy_name: str) -> None:
""" """
Search and load the custom strategy. If no strategy found, fallback on the default strategy Search and loads the specified strategy.
Set the object into self.custom_strategy
:param strategy_name: name of the module to import :param strategy_name: name of the module to import
:return: None :return: None
""" """
try: try:
# Start by sanitizing the file name (remove any extensions) current_path = os.path.dirname(os.path.realpath(__file__))
strategy_name = self._sanitize_module_name(filename=strategy_name) abs_paths = [
os.path.join(current_path, '..', '..', 'user_data', 'strategies'),
# Search where can be the strategy file current_path,
path = self._search_strategy(filename=strategy_name) ]
for path in abs_paths:
# Load the strategy self.custom_strategy = self._search_strategy(path, strategy_name)
self.custom_strategy = self._load_class(path + strategy_name) if self.custom_strategy:
self.logger.info('Using resolved strategy %s from \'%s\'', strategy_name, path)
return None
raise ImportError('not found')
# Fallback to the default strategy # Fallback to the default strategy
except (ImportError, TypeError) as error: except (ImportError, TypeError) as error:
self.logger.error( self.logger.error(
"Impossible to load Strategy 'user_data/strategies/%s.py'. This file does not exist" "Impossible to load Strategy '%s'. This class does not exist"
" or contains Python code errors", " or contains Python code errors",
strategy_name strategy_name
) )
@ -100,50 +99,44 @@ class StrategyResolver(object):
error error
) )
def _load_class(self, filename: str) -> IStrategy:
"""
Import a strategy as a module
:param filename: path to the strategy (path from freqtrade/strategy/)
:return: return the strategy class
"""
module = importlib.import_module(filename, __package__)
custom_strategy = getattr(module, module.class_name)
self.logger.info("Load strategy class: %s (%s.py)", module.class_name, filename)
return custom_strategy()
@staticmethod @staticmethod
def _sanitize_module_name(filename: str) -> str: def _get_valid_strategies(module_path: str, strategy_name: str) -> Optional[Type[IStrategy]]:
""" """
Remove any extension from filename Returns a list of all possible strategies for the given module_path
:param filename: filename to sanatize :param module_path: absolute path to the module
:return: return the filename without extensions :param strategy_name: Class name of the strategy
:return: Tuple with (name, class) or None
""" """
filename = os.path.basename(filename)
filename = os.path.splitext(filename)[0]
return filename
@staticmethod # Generate spec based on absolute path
def _search_strategy(filename: str) -> str: spec = importlib.util.spec_from_file_location('user_data.strategies', module_path)
""" module = importlib.util.module_from_spec(spec)
Search for the Strategy file in different folder spec.loader.exec_module(module)
1. search into the user_data/strategies folder
2. search into the freqtrade/strategy folder
3. if nothing found, return None
:param strategy_name: module name to search
:return: module path where is the strategy
"""
pwd = os.path.dirname(os.path.realpath(__file__)) + '/'
user_data = os.path.join(pwd, '..', '..', 'user_data', 'strategies', filename + '.py')
strategy_folder = os.path.join(pwd, filename + '.py')
path = None valid_strategies_gen = (
if os.path.isfile(user_data): obj for name, obj in inspect.getmembers(module, inspect.isclass)
path = 'user_data.strategies.' if strategy_name == name and IStrategy in obj.__bases__
elif os.path.isfile(strategy_folder): )
path = '.' return next(valid_strategies_gen, None)
return path def _search_strategy(self, directory: str, strategy_name: str) -> Optional[IStrategy]:
"""
Search for the strategy_name in the given directory
:param directory: relative or absolute directory path
:return: name of the strategy class
"""
self.logger.debug('Searching for strategy %s in \'%s\'', strategy_name, directory)
for entry in os.listdir(directory):
# Only consider python files
if not entry.endswith('.py'):
self.logger.debug('Ignoring %s', entry)
continue
strategy = StrategyResolver._get_valid_strategies(
os.path.abspath(os.path.join(directory, entry)), strategy_name
)
if strategy:
return strategy()
return None
def populate_indicators(self, dataframe: DataFrame) -> DataFrame: def populate_indicators(self, dataframe: DataFrame) -> DataFrame:
""" """

View File

@ -174,7 +174,7 @@ def test_setup_configuration_without_arguments(mocker, default_conf, caplog) ->
args = [ args = [
'--config', 'config.json', '--config', 'config.json',
'--strategy', 'default_strategy', '--strategy', 'DefaultStrategy',
'backtesting' 'backtesting'
] ]
@ -215,7 +215,7 @@ def test_setup_configuration_with_arguments(mocker, default_conf, caplog) -> Non
args = [ args = [
'--config', 'config.json', '--config', 'config.json',
'--strategy', 'default_strategy', '--strategy', 'DefaultStrategy',
'--datadir', '/foo/bar', '--datadir', '/foo/bar',
'backtesting', 'backtesting',
'--ticker-interval', '1', '--ticker-interval', '1',
@ -277,7 +277,7 @@ def test_start(mocker, default_conf, caplog) -> None:
)) ))
args = [ args = [
'--config', 'config.json', '--config', 'config.json',
'--strategy', 'default_strategy', '--strategy', 'DefaultStrategy',
'backtesting' 'backtesting'
] ]
args = get_args(args) args = get_args(args)
@ -498,7 +498,7 @@ def test_backtest_ticks(default_conf):
def test_backtest_clash_buy_sell(default_conf): def test_backtest_clash_buy_sell(default_conf):
# Override the default buy trend function in our default_strategy # Override the default buy trend function in our DefaultStrategy
def fun(dataframe=None): def fun(dataframe=None):
buy_value = 1 buy_value = 1
sell_value = 1 sell_value = 1
@ -510,7 +510,7 @@ def test_backtest_clash_buy_sell(default_conf):
def test_backtest_only_sell(default_conf): def test_backtest_only_sell(default_conf):
# Override the default buy trend function in our default_strategy # Override the default buy trend function in our DefaultStrategy
def fun(dataframe=None): def fun(dataframe=None):
buy_value = 0 buy_value = 0
sell_value = 1 sell_value = 1
@ -578,12 +578,12 @@ def test_backtest_start_live(default_conf, mocker, caplog):
args.live = True args.live = True
args.datadir = None args.datadir = None
args.export = None args.export = None
args.strategy = 'default_strategy' args.strategy = 'DefaultStrategy'
args.timerange = '-100' # needed due to MagicMock malleability args.timerange = '-100' # needed due to MagicMock malleability
args = [ args = [
'--config', 'config.json', '--config', 'config.json',
'--strategy', 'default_strategy', '--strategy', 'DefaultStrategy',
'backtesting', 'backtesting',
'--ticker-interval', '1', '--ticker-interval', '1',
'--live', '--live',

View File

@ -57,12 +57,12 @@ def test_start(mocker, default_conf, caplog) -> None:
)) ))
args = [ args = [
'--config', 'config.json', '--config', 'config.json',
'--strategy', 'default_strategy', '--strategy', 'DefaultStrategy',
'hyperopt', 'hyperopt',
'--epochs', '5' '--epochs', '5'
] ]
args = get_args(args) args = get_args(args)
StrategyResolver({'strategy': 'default_strategy'}) StrategyResolver({'strategy': 'DefaultStrategy'})
start(args) start(args)
import pprint import pprint
@ -80,7 +80,7 @@ def test_loss_calculation_prefer_correct_trade_count() -> None:
Test Hyperopt.calculate_loss() Test Hyperopt.calculate_loss()
""" """
hyperopt = _HYPEROPT hyperopt = _HYPEROPT
StrategyResolver({'strategy': 'default_strategy'}) StrategyResolver({'strategy': 'DefaultStrategy'})
correct = hyperopt.calculate_loss(1, hyperopt.target_trades, 20) correct = hyperopt.calculate_loss(1, hyperopt.target_trades, 20)
over = hyperopt.calculate_loss(1, hyperopt.target_trades + 100, 20) over = hyperopt.calculate_loss(1, hyperopt.target_trades + 100, 20)
@ -171,7 +171,7 @@ def test_fmin_best_results(mocker, default_conf, caplog) -> None:
mocker.patch('freqtrade.optimize.hyperopt.hyperopt_optimize_conf', return_value=conf) mocker.patch('freqtrade.optimize.hyperopt.hyperopt_optimize_conf', return_value=conf)
mocker.patch('freqtrade.logger.Logger.set_format', MagicMock()) mocker.patch('freqtrade.logger.Logger.set_format', MagicMock())
StrategyResolver({'strategy': 'default_strategy'}) StrategyResolver({'strategy': 'DefaultStrategy'})
hyperopt = Hyperopt(conf) hyperopt = Hyperopt(conf)
hyperopt.trials = create_trials(mocker) hyperopt.trials = create_trials(mocker)
hyperopt.tickerdata_to_dataframe = MagicMock() hyperopt.tickerdata_to_dataframe = MagicMock()
@ -215,7 +215,7 @@ def test_fmin_throw_value_error(mocker, default_conf, caplog) -> None:
conf.update({'spaces': 'all'}) conf.update({'spaces': 'all'})
mocker.patch('freqtrade.optimize.hyperopt.hyperopt_optimize_conf', return_value=conf) mocker.patch('freqtrade.optimize.hyperopt.hyperopt_optimize_conf', return_value=conf)
mocker.patch('freqtrade.logger.Logger.set_format', MagicMock()) mocker.patch('freqtrade.logger.Logger.set_format', MagicMock())
StrategyResolver({'strategy': 'default_strategy'}) StrategyResolver({'strategy': 'DefaultStrategy'})
hyperopt = Hyperopt(conf) hyperopt = Hyperopt(conf)
hyperopt.trials = create_trials(mocker) hyperopt.trials = create_trials(mocker)
hyperopt.tickerdata_to_dataframe = MagicMock() hyperopt.tickerdata_to_dataframe = MagicMock()
@ -258,7 +258,7 @@ def test_resuming_previous_hyperopt_results_succeeds(mocker, default_conf) -> No
mocker.patch('freqtrade.optimize.hyperopt.hyperopt_optimize_conf', return_value=conf) mocker.patch('freqtrade.optimize.hyperopt.hyperopt_optimize_conf', return_value=conf)
mocker.patch('freqtrade.logger.Logger.set_format', MagicMock()) mocker.patch('freqtrade.logger.Logger.set_format', MagicMock())
StrategyResolver({'strategy': 'default_strategy'}) StrategyResolver({'strategy': 'DefaultStrategy'})
hyperopt = Hyperopt(conf) hyperopt = Hyperopt(conf)
hyperopt.trials = trials hyperopt.trials = trials
hyperopt.tickerdata_to_dataframe = MagicMock() hyperopt.tickerdata_to_dataframe = MagicMock()

View File

@ -4,7 +4,7 @@ import pytest
from pandas import DataFrame from pandas import DataFrame
from freqtrade.analyze import Analyze from freqtrade.analyze import Analyze
from freqtrade.strategy.default_strategy import DefaultStrategy, class_name from freqtrade.strategy.default_strategy import DefaultStrategy
@pytest.fixture @pytest.fixture
@ -13,10 +13,6 @@ def result():
return Analyze.parse_ticker_dataframe(json.load(data_file)) return Analyze.parse_ticker_dataframe(json.load(data_file))
def test_default_strategy_class_name():
assert class_name == DefaultStrategy.__name__
def test_default_strategy_structure(): def test_default_strategy_structure():
assert hasattr(DefaultStrategy, 'minimal_roi') assert hasattr(DefaultStrategy, 'minimal_roi')
assert hasattr(DefaultStrategy, 'stoploss') assert hasattr(DefaultStrategy, 'stoploss')

View File

@ -5,20 +5,10 @@ import logging
from freqtrade.strategy.resolver import StrategyResolver from freqtrade.strategy.resolver import StrategyResolver
def test_sanitize_module_name():
assert StrategyResolver._sanitize_module_name('default_strategy') == 'default_strategy'
assert StrategyResolver._sanitize_module_name('default_strategy.py') == 'default_strategy'
assert StrategyResolver._sanitize_module_name('../default_strategy.py') == 'default_strategy'
assert StrategyResolver._sanitize_module_name('../default_strategy') == 'default_strategy'
assert StrategyResolver._sanitize_module_name('.default_strategy') == '.default_strategy'
assert StrategyResolver._sanitize_module_name('foo-bar') == 'foo-bar'
assert StrategyResolver._sanitize_module_name('foo/bar') == 'bar'
def test_search_strategy(): def test_search_strategy():
assert StrategyResolver._search_strategy('default_strategy') == '.' assert StrategyResolver._search_strategy('DefaultStrategy') == '.'
assert StrategyResolver._search_strategy('test_strategy') == 'user_data.strategies.' assert StrategyResolver._search_strategy('TestStrategy') == 'user_data.strategies.'
assert StrategyResolver._search_strategy('super_duper') is None assert StrategyResolver._search_strategy('NotFoundStrategy') is None
def test_strategy_structure(): def test_strategy_structure():
@ -32,7 +22,7 @@ def test_load_strategy(result):
strategy.logger = logging.getLogger(__name__) strategy.logger = logging.getLogger(__name__)
assert not hasattr(StrategyResolver, 'custom_strategy') assert not hasattr(StrategyResolver, 'custom_strategy')
strategy._load_strategy('test_strategy') strategy._load_strategy('TestStrategy')
assert not hasattr(StrategyResolver, 'custom_strategy') assert not hasattr(StrategyResolver, 'custom_strategy')
@ -47,13 +37,13 @@ def test_load_not_found_strategy(caplog):
assert not hasattr(StrategyResolver, 'custom_strategy') assert not hasattr(StrategyResolver, 'custom_strategy')
strategy._load_strategy('NotFoundStrategy') strategy._load_strategy('NotFoundStrategy')
error_msg = "Impossible to load Strategy 'user_data/strategies/{}.py'. This file does not " \ error_msg = "Impossible to load Strategy '{}'. This class does not " \
"exist or contains Python code errors".format('NotFoundStrategy') "exist or contains Python code errors".format('NotFoundStrategy')
assert ('test_strategy', logging.ERROR, error_msg) in caplog.record_tuples assert ('test_strategy', logging.ERROR, error_msg) in caplog.record_tuples
def test_strategy(result): def test_strategy(result):
strategy = StrategyResolver({'strategy': 'default_strategy'}) strategy = StrategyResolver({'strategy': 'DefaultStrategy'})
assert hasattr(strategy.custom_strategy, 'minimal_roi') assert hasattr(strategy.custom_strategy, 'minimal_roi')
assert strategy.minimal_roi[0] == 0.04 assert strategy.minimal_roi[0] == 0.04
@ -76,7 +66,7 @@ def test_strategy(result):
def test_strategy_override_minimal_roi(caplog): def test_strategy_override_minimal_roi(caplog):
caplog.set_level(logging.INFO) caplog.set_level(logging.INFO)
config = { config = {
'strategy': 'default_strategy', 'strategy': 'DefaultStrategy',
'minimal_roi': { 'minimal_roi': {
"0": 0.5 "0": 0.5
} }
@ -94,7 +84,7 @@ def test_strategy_override_minimal_roi(caplog):
def test_strategy_override_stoploss(caplog): def test_strategy_override_stoploss(caplog):
caplog.set_level(logging.INFO) caplog.set_level(logging.INFO)
config = { config = {
'strategy': 'default_strategy', 'strategy': 'DefaultStrategy',
'stoploss': -0.5 'stoploss': -0.5
} }
strategy = StrategyResolver(config) strategy = StrategyResolver(config)
@ -111,7 +101,7 @@ def test_strategy_override_ticker_interval(caplog):
caplog.set_level(logging.INFO) caplog.set_level(logging.INFO)
config = { config = {
'strategy': 'default_strategy', 'strategy': 'DefaultStrategy',
'ticker_interval': 60 'ticker_interval': 60
} }
strategy = StrategyResolver(config) strategy = StrategyResolver(config)
@ -134,7 +124,7 @@ def test_strategy_fallback_default_strategy():
def test_strategy_singleton(): def test_strategy_singleton():
strategy1 = StrategyResolver({'strategy': 'default_strategy'}) strategy1 = StrategyResolver({'strategy': 'DefaultStrategy'})
assert hasattr(strategy1.custom_strategy, 'minimal_roi') assert hasattr(strategy1.custom_strategy, 'minimal_roi')
assert strategy1.minimal_roi[0] == 0.04 assert strategy1.minimal_roi[0] == 0.04

View File

@ -16,7 +16,7 @@ from freqtrade.optimize.__init__ import load_tickerdata_file
from freqtrade.tests.conftest import log_has from freqtrade.tests.conftest import log_has
# Avoid to reinit the same object again and again # Avoid to reinit the same object again and again
_ANALYZE = Analyze({'strategy': 'default_strategy'}) _ANALYZE = Analyze({'strategy': 'DefaultStrategy'})
def test_signaltype_object() -> None: def test_signaltype_object() -> None:

View File

@ -99,7 +99,7 @@ def test_load_config(default_conf, mocker) -> None:
validated_conf = configuration.load_config() validated_conf = configuration.load_config()
assert 'strategy' in validated_conf assert 'strategy' in validated_conf
assert validated_conf['strategy'] == 'default_strategy' assert validated_conf['strategy'] == 'DefaultStrategy'
assert 'dynamic_whitelist' not in validated_conf assert 'dynamic_whitelist' not in validated_conf
assert 'dry_run_db' not in validated_conf assert 'dry_run_db' not in validated_conf
@ -114,7 +114,7 @@ def test_load_config_with_params(default_conf, mocker) -> None:
args = [ args = [
'--dynamic-whitelist', '10', '--dynamic-whitelist', '10',
'--strategy', 'test_strategy', '--strategy', 'TestStrategy',
'--dry-run-db' '--dry-run-db'
] ]
args = Arguments(args, '').get_parsed_arg() args = Arguments(args, '').get_parsed_arg()
@ -125,7 +125,7 @@ def test_load_config_with_params(default_conf, mocker) -> None:
assert 'dynamic_whitelist' in validated_conf assert 'dynamic_whitelist' in validated_conf
assert validated_conf['dynamic_whitelist'] == 10 assert validated_conf['dynamic_whitelist'] == 10
assert 'strategy' in validated_conf assert 'strategy' in validated_conf
assert validated_conf['strategy'] == 'test_strategy' assert validated_conf['strategy'] == 'TestStrategy'
assert 'dry_run_db' in validated_conf assert 'dry_run_db' in validated_conf
assert validated_conf['dry_run_db'] is True assert validated_conf['dry_run_db'] is True
@ -140,7 +140,7 @@ def test_show_info(default_conf, mocker, caplog) -> None:
args = [ args = [
'--dynamic-whitelist', '10', '--dynamic-whitelist', '10',
'--strategy', 'test_strategy', '--strategy', 'TestStrategy',
'--dry-run-db' '--dry-run-db'
] ]
args = Arguments(args, '').get_parsed_arg() args = Arguments(args, '').get_parsed_arg()
@ -184,7 +184,7 @@ def test_setup_configuration_without_arguments(mocker, default_conf, caplog) ->
args = [ args = [
'--config', 'config.json', '--config', 'config.json',
'--strategy', 'default_strategy', '--strategy', 'DefaultStrategy',
'backtesting' 'backtesting'
] ]
@ -228,7 +228,7 @@ def test_setup_configuration_with_arguments(mocker, default_conf, caplog) -> Non
args = [ args = [
'--config', 'config.json', '--config', 'config.json',
'--strategy', 'default_strategy', '--strategy', 'DefaultStrategy',
'--datadir', '/foo/bar', '--datadir', '/foo/bar',
'backtesting', 'backtesting',
'--ticker-interval', '1', '--ticker-interval', '1',

View File

@ -15,19 +15,19 @@ def load_dataframe_pair(pairs):
assert isinstance(pairs[0], str) assert isinstance(pairs[0], str)
dataframe = ld[pairs[0]] dataframe = ld[pairs[0]]
analyze = Analyze({'strategy': 'default_strategy'}) analyze = Analyze({'strategy': 'DefaultStrategy'})
dataframe = analyze.analyze_ticker(dataframe) dataframe = analyze.analyze_ticker(dataframe)
return dataframe return dataframe
def test_dataframe_load(): def test_dataframe_load():
StrategyResolver({'strategy': 'default_strategy'}) StrategyResolver({'strategy': 'DefaultStrategy'})
dataframe = load_dataframe_pair(_pairs) dataframe = load_dataframe_pair(_pairs)
assert isinstance(dataframe, pandas.core.frame.DataFrame) assert isinstance(dataframe, pandas.core.frame.DataFrame)
def test_dataframe_columns_exists(): def test_dataframe_columns_exists():
StrategyResolver({'strategy': 'default_strategy'}) StrategyResolver({'strategy': 'DefaultStrategy'})
dataframe = load_dataframe_pair(_pairs) dataframe = load_dataframe_pair(_pairs)
assert 'high' in dataframe.columns assert 'high' in dataframe.columns
assert 'low' in dataframe.columns assert 'low' in dataframe.columns

View File

@ -10,10 +10,6 @@ import freqtrade.vendor.qtpylib.indicators as qtpylib
import numpy # noqa import numpy # noqa
# Update this variable if you change the class name
class_name = 'TestStrategy'
# This class is a sample. Feel free to customize it. # This class is a sample. Feel free to customize it.
class TestStrategy(IStrategy): class TestStrategy(IStrategy):
""" """