v4.0 update. (#2371)
This commit is contained in:
@ -248,39 +248,55 @@ class DSLPreprocessor(ast.NodeTransformer):
|
||||
# Step 3. Return the transformed tree
|
||||
return combined_body
|
||||
|
||||
def check_early_exit(self, tree):
|
||||
def check_early_exit(self, tree, kind):
|
||||
"""
|
||||
Checks if a given region or scope in the provided Python code has early exits.
|
||||
"""
|
||||
|
||||
class EarlyExitChecker(ast.NodeVisitor):
|
||||
def __init__(self):
|
||||
def __init__(self, kind):
|
||||
self.has_early_exit = False
|
||||
self.early_exit_node = None
|
||||
self.early_exit_type = None
|
||||
self.kind = kind
|
||||
self.loop_nest_level = 0
|
||||
|
||||
# Early exit is not allowed in any level of dynamic control flow
|
||||
def visit_Return(self, node):
|
||||
self.has_early_exit = True
|
||||
self.early_exit_node = node
|
||||
self.early_exit_type = "return"
|
||||
|
||||
def visit_Break(self, node):
|
||||
self.has_early_exit = True
|
||||
self.early_exit_node = node
|
||||
self.early_exit_type = "break"
|
||||
|
||||
def visit_Continue(self, node):
|
||||
self.has_early_exit = True
|
||||
self.early_exit_node = node
|
||||
self.early_exit_type = "continue"
|
||||
|
||||
def visit_Raise(self, node):
|
||||
self.has_early_exit = True
|
||||
self.early_exit_node = node
|
||||
self.early_exit_type = "raise"
|
||||
|
||||
checker = EarlyExitChecker()
|
||||
checker.visit(tree)
|
||||
def visit_Break(self, node):
|
||||
# For break/continue in inner loops, we don't consider it as early exit
|
||||
if self.loop_nest_level == 0 and self.kind != "if":
|
||||
self.has_early_exit = True
|
||||
self.early_exit_node = node
|
||||
self.early_exit_type = "break"
|
||||
|
||||
def visit_Continue(self, node):
|
||||
if self.loop_nest_level == 0 and self.kind != "if":
|
||||
self.has_early_exit = True
|
||||
self.early_exit_node = node
|
||||
self.early_exit_type = "continue"
|
||||
|
||||
def visit_For(self, node):
|
||||
self.loop_nest_level += 1
|
||||
self.generic_visit(node)
|
||||
self.loop_nest_level -= 1
|
||||
|
||||
def visit_While(self, node):
|
||||
self.loop_nest_level += 1
|
||||
self.generic_visit(node)
|
||||
self.loop_nest_level -= 1
|
||||
|
||||
checker = EarlyExitChecker(kind)
|
||||
checker.generic_visit(tree)
|
||||
if not checker.has_early_exit:
|
||||
return
|
||||
raise DSLAstPreprocessorError(
|
||||
@ -591,7 +607,7 @@ class DSLPreprocessor(ast.NodeTransformer):
|
||||
if self.is_supported_range_call(node):
|
||||
constexpr_val = self.get_loop_constexpr(node)
|
||||
# Check for early exit and raise exception
|
||||
self.check_early_exit(node)
|
||||
self.check_early_exit(node, "for")
|
||||
start, stop, step = self.extract_range_args(node.iter)
|
||||
unroll, unroll_full = self.extract_unroll_args(node.iter)
|
||||
used_args, iter_args, flat_args = self.analyze_region_variables(
|
||||
@ -659,37 +675,42 @@ class DSLPreprocessor(ast.NodeTransformer):
|
||||
snippet=ast.unparse(node),
|
||||
)
|
||||
|
||||
test = ast.BoolOp(
|
||||
op=ast.And(),
|
||||
values=[
|
||||
ast.Compare(
|
||||
left=ast.Call(
|
||||
func=ast.Name(id="type", ctx=ast.Load()),
|
||||
args=[node.values[0]],
|
||||
keywords=[],
|
||||
def short_circuit_eval(value, short_circuit_value):
|
||||
return ast.BoolOp(
|
||||
op=ast.And(),
|
||||
values=[
|
||||
ast.Compare(
|
||||
left=ast.Call(
|
||||
func=ast.Name(id="type", ctx=ast.Load()),
|
||||
args=[value],
|
||||
keywords=[],
|
||||
),
|
||||
ops=[ast.Eq()],
|
||||
comparators=[ast.Name(id="bool", ctx=ast.Load())],
|
||||
),
|
||||
ops=[ast.Eq()],
|
||||
comparators=[ast.Name(id="bool", ctx=ast.Load())],
|
||||
),
|
||||
ast.Compare(
|
||||
left=node.values[0],
|
||||
ops=[ast.Eq()],
|
||||
comparators=[short_circuit_value],
|
||||
),
|
||||
],
|
||||
)
|
||||
return ast.copy_location(
|
||||
ast.IfExp(
|
||||
ast.Compare(
|
||||
left=value,
|
||||
ops=[ast.Eq()],
|
||||
comparators=[short_circuit_value],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
lhs = node.values[0]
|
||||
|
||||
for i in range(1, len(node.values)):
|
||||
test = short_circuit_eval(lhs, short_circuit_value)
|
||||
lhs = ast.IfExp(
|
||||
test=test,
|
||||
body=node.values[0],
|
||||
body=lhs,
|
||||
orelse=ast.Call(
|
||||
func=helper_func,
|
||||
args=node.values,
|
||||
args=[lhs, node.values[i]],
|
||||
keywords=[],
|
||||
),
|
||||
),
|
||||
node,
|
||||
)
|
||||
)
|
||||
|
||||
return ast.copy_location(lhs, node)
|
||||
|
||||
def visit_UnaryOp(self, node):
|
||||
# Visit child nodes first
|
||||
@ -916,7 +937,7 @@ class DSLPreprocessor(ast.NodeTransformer):
|
||||
return node
|
||||
|
||||
# Check for early exit and raise exception
|
||||
self.check_early_exit(node)
|
||||
self.check_early_exit(node, "while")
|
||||
|
||||
used_args, yield_args, flat_args = self.analyze_region_variables(
|
||||
node, active_symbols
|
||||
@ -1021,7 +1042,7 @@ class DSLPreprocessor(ast.NodeTransformer):
|
||||
return node
|
||||
|
||||
# Check for early exit and raise exception
|
||||
self.check_early_exit(node)
|
||||
self.check_early_exit(node, "if")
|
||||
|
||||
used_args, yield_args, flat_args = self.analyze_region_variables(
|
||||
node, active_symbols
|
||||
|
||||
Reference in New Issue
Block a user