Adding tests
This commit is contained in:
@@ -70,7 +70,6 @@ class StrategyUpdater:
|
||||
|
||||
# update the code
|
||||
new_code = StrategyUpdater.update_code(self, old_code)
|
||||
|
||||
# write the modified code to the destination folder
|
||||
with open(source_file, 'w') as f:
|
||||
f.write(new_code)
|
||||
@@ -82,8 +81,7 @@ class StrategyUpdater:
|
||||
tree = ast.parse(code)
|
||||
|
||||
# use the AST to update the code
|
||||
updated_code = self.modify_ast(
|
||||
tree)
|
||||
updated_code = self.modify_ast(self, tree)
|
||||
|
||||
# return the modified code without executing it
|
||||
return updated_code
|
||||
@@ -107,18 +105,8 @@ class NameUpdater(ast.NodeTransformer):
|
||||
# traverse the AST recursively by calling the visitor method for each child node
|
||||
if hasattr(node, "_fields"):
|
||||
for field_name, field_value in ast.iter_fields(node):
|
||||
if not isinstance(field_value, ast.AST):
|
||||
continue # to avoid unnecessary loops
|
||||
self.visit(field_value)
|
||||
self.generic_visit(field_value)
|
||||
self.check_fields(field_value)
|
||||
self.check_strategy_and_config_settings(node, field_value)
|
||||
# add this check to handle the case where field_value is a slice
|
||||
if isinstance(field_value, ast.Slice):
|
||||
self.visit(field_value)
|
||||
# add this check to handle the case where field_value is a target
|
||||
if isinstance(field_value, ast.expr_context):
|
||||
self.visit(field_value)
|
||||
self.check_fields(field_value)
|
||||
|
||||
def check_fields(self, field_value):
|
||||
if isinstance(field_value, list):
|
||||
@@ -139,10 +127,6 @@ class NameUpdater(ast.NodeTransformer):
|
||||
target.id == "unfilledtimeout"):
|
||||
for key in field_value.keys:
|
||||
self.visit(key)
|
||||
# Check if the target is a Subscript object with a "value" attribute
|
||||
if isinstance(target, ast.Subscript) and hasattr(target.value, "attr"):
|
||||
if target.value.attr == "loc":
|
||||
self.visit(target)
|
||||
|
||||
def visit_Name(self, node):
|
||||
# if the name is in the mapping, update it
|
||||
@@ -154,11 +138,14 @@ class NameUpdater(ast.NodeTransformer):
|
||||
# do not update the names in import statements
|
||||
return node
|
||||
|
||||
# This function is currently never successfully triggered
|
||||
# since freqtrade currently only allows valid code to be processed.
|
||||
# The module .hyper does not anymore exist and by that fails to even
|
||||
# reach this function to be updated currently.
|
||||
def visit_ImportFrom(self, node):
|
||||
# do not update the names in import statements
|
||||
if hasattr(node, "module"):
|
||||
if node.module == "freqtrade.strategy.hyper":
|
||||
node.module = "freqtrade.strategy"
|
||||
# if hasattr(node, "module"):
|
||||
# if node.module == "freqtrade.strategy.hyper":
|
||||
# node.module = "freqtrade.strategy"
|
||||
return node
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
@@ -210,6 +197,10 @@ class NameUpdater(ast.NodeTransformer):
|
||||
if hasattr(node.slice, "value"):
|
||||
if hasattr(node.slice.value, "elts"):
|
||||
self.visit_slice_elts(node.slice.value.elts)
|
||||
# Check if the target is a Subscript object with a "value" attribute
|
||||
# if isinstance(target, ast.Subscript) and hasattr(target.value, "attr"):
|
||||
# if target.value.attr == "loc":
|
||||
# self.visit(target)
|
||||
return node
|
||||
|
||||
# elts can have elts (technically recursively)
|
||||
|
Reference in New Issue
Block a user