Simplify implementation of "check_override" by extracting it to function

This commit is contained in:
Matthias 2022-03-12 10:57:03 +01:00
parent 6946203a7c
commit fe62a71f4c
2 changed files with 19 additions and 7 deletions

View File

@ -222,18 +222,22 @@ class StrategyResolver(IResolver):
if strategy:
if strategy.config.get('trading_mode', TradingMode.SPOT) != TradingMode.SPOT:
# Require new method
if type(strategy).populate_entry_trend == IStrategy.populate_entry_trend:
if check_override(strategy, IStrategy, 'populate_entry_trend'):
raise OperationalException("`populate_entry_trend` must be implemented.")
if type(strategy).populate_exit_trend == IStrategy.populate_exit_trend:
if check_override(strategy, IStrategy, 'populate_exit_trend'):
raise OperationalException("`populate_exit_trend` must be implemented.")
else:
# TODO: Implementing buy_trend and sell_trend should raise a deprecation.
if (type(strategy).populate_buy_trend == IStrategy.populate_buy_trend
and type(strategy).populate_entry_trend == IStrategy.populate_entry_trend):
# TODO: Implementing buy_trend and sell_trend should show a deprecation warning
if (
check_override(strategy, IStrategy, 'populate_buy_trend')
and check_override(strategy, IStrategy, 'populate_entry_trend')
):
raise OperationalException(
"`populate_entry_trend` or `populate_buy_trend` must be implemented.")
if (type(strategy).populate_sell_trend == IStrategy.populate_sell_trend
and type(strategy).populate_exit_trend == IStrategy.populate_exit_trend):
if (
check_override(strategy, IStrategy, 'populate_sell_trend')
and check_override(strategy, IStrategy, 'populate_exit_trend')
):
raise OperationalException(
"`populate_exit_trend` or `populate_sell_trend` must be implemented.")
@ -253,3 +257,10 @@ class StrategyResolver(IResolver):
f"Impossible to load Strategy '{strategy_name}'. This class does not exist "
"or contains Python code errors."
)
def check_override(object, parentclass, attribute):
"""
Checks if a object overrides the parent class attribute.
"""
return getattr(type(object), attribute) == getattr(parentclass, attribute)

View File

@ -417,6 +417,7 @@ def test_missing_implements(result, default_conf):
match=r"`populate_entry_trend` must be implemented.*"):
StrategyResolver.load_strategy(default_conf)
@pytest.mark.filterwarnings("ignore:deprecated")
def test_call_deprecated_function(result, default_conf, caplog):
default_location = Path(__file__).parent / "strats"