Adding tests

added more code inside NameUpdater to grab more variables.
This commit is contained in:
hippocritical
2023-01-01 18:57:38 +01:00
parent a51e44eea3
commit 762dd4f024
2 changed files with 42 additions and 9 deletions

View File

@@ -128,6 +128,15 @@ class NameUpdater(ast.NodeTransformer):
for key in field_value.keys:
self.visit(key)
def check_args(self, node):
if isinstance(node.args, ast.arguments):
self.check_args(node.args)
if hasattr(node, "args"):
if isinstance(node.args, list):
for arg in node.args:
arg.arg = StrategyUpdater.name_mapping[arg.arg]
return node
def visit_Name(self, node):
# if the name is in the mapping, update it
if node.id in StrategyUpdater.name_mapping:
@@ -152,6 +161,8 @@ class NameUpdater(ast.NodeTransformer):
# if the function name is in the mapping, update it
if node.name in StrategyUpdater.function_mapping:
node.name = StrategyUpdater.function_mapping[node.name]
if hasattr(node, "args"):
self.check_args(node)
return self.generic_visit(node)
def visit_Attribute(self, node):
@@ -193,10 +204,10 @@ class NameUpdater(ast.NodeTransformer):
# Replace the slice attributes with the values from rename_dict
node.slice.value = StrategyUpdater.rename_dict[node.slice.value]
if hasattr(node.slice, "elts"):
self.visit_slice_elts(node.slice.elts)
self.visit_elts(node.slice.elts)
if hasattr(node.slice, "value"):
if hasattr(node.slice.value, "elts"):
self.visit_slice_elts(node.slice.value.elts)
self.visit_elts(node.slice.value.elts)
# Check if the target is a Subscript object with a "value" attribute
# if isinstance(target, ast.Subscript) and hasattr(target.value, "attr"):
# if target.value.attr == "loc":
@@ -204,12 +215,25 @@ class NameUpdater(ast.NodeTransformer):
return node
# elts can have elts (technically recursively)
def visit_slice_elts(self, elts):
for elt in elts:
if isinstance(elt, ast.Constant) and elt.value in StrategyUpdater.rename_dict:
elt.value = StrategyUpdater.rename_dict[elt.value]
elif hasattr(elt, "elts"):
self.visit_slice_elts(elt.elts)
def visit_elts(self, elts):
if isinstance(elts, list):
for elt in elts:
self.visit_elt(elt)
else:
self.visit_elt(elts)
# sub function again needed since the structure itself is highly flexible ...
def visit_elt(self, elt):
if isinstance(elt, ast.Constant) and elt.value in StrategyUpdater.rename_dict:
elt.value = StrategyUpdater.rename_dict[elt.value]
if hasattr(elt, "elts"):
self.visit_elts(elt.elts)
if hasattr(elt, "args"):
if isinstance(elt.args, ast.arguments):
self.visit_elts(elt.args)
else:
for arg in elt.args:
self.visit_elts(arg)
def visit_Constant(self, node):
# do not update the names in import statements