diff --git a/freqtrade/strategy/strategyupdater.py b/freqtrade/strategy/strategyupdater.py index 396d57a8a..5b4bb8be0 100644 --- a/freqtrade/strategy/strategyupdater.py +++ b/freqtrade/strategy/strategyupdater.py @@ -128,6 +128,15 @@ class NameUpdater(ast.NodeTransformer): for key in field_value.keys: self.visit(key) + def check_args(self, node): + if isinstance(node.args, ast.arguments): + self.check_args(node.args) + if hasattr(node, "args"): + if isinstance(node.args, list): + for arg in node.args: + arg.arg = StrategyUpdater.name_mapping[arg.arg] + return node + def visit_Name(self, node): # if the name is in the mapping, update it if node.id in StrategyUpdater.name_mapping: @@ -152,6 +161,8 @@ class NameUpdater(ast.NodeTransformer): # if the function name is in the mapping, update it if node.name in StrategyUpdater.function_mapping: node.name = StrategyUpdater.function_mapping[node.name] + if hasattr(node, "args"): + self.check_args(node) return self.generic_visit(node) def visit_Attribute(self, node): @@ -193,10 +204,10 @@ class NameUpdater(ast.NodeTransformer): # Replace the slice attributes with the values from rename_dict node.slice.value = StrategyUpdater.rename_dict[node.slice.value] if hasattr(node.slice, "elts"): - self.visit_slice_elts(node.slice.elts) + self.visit_elts(node.slice.elts) if hasattr(node.slice, "value"): if hasattr(node.slice.value, "elts"): - self.visit_slice_elts(node.slice.value.elts) + self.visit_elts(node.slice.value.elts) # Check if the target is a Subscript object with a "value" attribute # if isinstance(target, ast.Subscript) and hasattr(target.value, "attr"): # if target.value.attr == "loc": @@ -204,12 +215,25 @@ class NameUpdater(ast.NodeTransformer): return node # elts can have elts (technically recursively) - def visit_slice_elts(self, elts): - for elt in elts: - if isinstance(elt, ast.Constant) and elt.value in StrategyUpdater.rename_dict: - elt.value = StrategyUpdater.rename_dict[elt.value] - elif hasattr(elt, "elts"): - self.visit_slice_elts(elt.elts) + def visit_elts(self, elts): + if isinstance(elts, list): + for elt in elts: + self.visit_elt(elt) + else: + self.visit_elt(elts) + + # sub function again needed since the structure itself is highly flexible ... + def visit_elt(self, elt): + if isinstance(elt, ast.Constant) and elt.value in StrategyUpdater.rename_dict: + elt.value = StrategyUpdater.rename_dict[elt.value] + if hasattr(elt, "elts"): + self.visit_elts(elt.elts) + if hasattr(elt, "args"): + if isinstance(elt.args, ast.arguments): + self.visit_elts(elt.args) + else: + for arg in elt.args: + self.visit_elts(arg) def visit_Constant(self, node): # do not update the names in import statements diff --git a/tests/test_strategy_updater.py b/tests/test_strategy_updater.py index 6997abdce..cf18fcc25 100644 --- a/tests/test_strategy_updater.py +++ b/tests/test_strategy_updater.py @@ -4,6 +4,11 @@ from freqtrade.strategy.strategyupdater import StrategyUpdater def test_strategy_updater(default_conf, caplog) -> None: + modified_code5 = StrategyUpdater.update_code(StrategyUpdater, """ +def confirm_trade_exit(sell_reason: str): + pass +""") + modified_code1 = StrategyUpdater.update_code(StrategyUpdater, """ class testClass(IStrategy): def populate_buy_trend(): @@ -31,9 +36,10 @@ ignore_roi_if_buy_signal = True forcebuy_enable = True """) modified_code4 = StrategyUpdater.update_code(StrategyUpdater, """ -dataframe.loc[reduce(lambda x, y: x & y, conditions), 'buy'] = 1 +dataframe.loc[reduce(lambda x, y: x & y, conditions), ["buy", "buy_tag"]] = (1, "buy_signal_1") dataframe.loc[reduce(lambda x, y: x & y, conditions), 'sell'] = 1 """) + assert "populate_entry_trend" in modified_code1 assert "populate_exit_trend" in modified_code1 assert "check_entry_timeout" in modified_code1 @@ -54,3 +60,6 @@ dataframe.loc[reduce(lambda x, y: x & y, conditions), 'sell'] = 1 assert "enter_long" in modified_code4 assert "exit_long" in modified_code4 + assert "enter_tag" in modified_code4 + + assert "exit_reason" in modified_code5