v4.0 update. (#2371)

This commit is contained in:
Junkai-Wu
2025-06-06 14:39:20 +08:00
committed by GitHub
parent 2e2af190bd
commit 8bdbfca682
254 changed files with 29751 additions and 1980 deletions

View File

@ -248,39 +248,55 @@ class DSLPreprocessor(ast.NodeTransformer):
# Step 3. Return the transformed tree
return combined_body
def check_early_exit(self, tree):
def check_early_exit(self, tree, kind):
"""
Checks if a given region or scope in the provided Python code has early exits.
"""
class EarlyExitChecker(ast.NodeVisitor):
def __init__(self):
def __init__(self, kind):
self.has_early_exit = False
self.early_exit_node = None
self.early_exit_type = None
self.kind = kind
self.loop_nest_level = 0
# Early exit is not allowed in any level of dynamic control flow
def visit_Return(self, node):
self.has_early_exit = True
self.early_exit_node = node
self.early_exit_type = "return"
def visit_Break(self, node):
self.has_early_exit = True
self.early_exit_node = node
self.early_exit_type = "break"
def visit_Continue(self, node):
self.has_early_exit = True
self.early_exit_node = node
self.early_exit_type = "continue"
def visit_Raise(self, node):
self.has_early_exit = True
self.early_exit_node = node
self.early_exit_type = "raise"
checker = EarlyExitChecker()
checker.visit(tree)
def visit_Break(self, node):
# For break/continue in inner loops, we don't consider it as early exit
if self.loop_nest_level == 0 and self.kind != "if":
self.has_early_exit = True
self.early_exit_node = node
self.early_exit_type = "break"
def visit_Continue(self, node):
if self.loop_nest_level == 0 and self.kind != "if":
self.has_early_exit = True
self.early_exit_node = node
self.early_exit_type = "continue"
def visit_For(self, node):
self.loop_nest_level += 1
self.generic_visit(node)
self.loop_nest_level -= 1
def visit_While(self, node):
self.loop_nest_level += 1
self.generic_visit(node)
self.loop_nest_level -= 1
checker = EarlyExitChecker(kind)
checker.generic_visit(tree)
if not checker.has_early_exit:
return
raise DSLAstPreprocessorError(
@ -591,7 +607,7 @@ class DSLPreprocessor(ast.NodeTransformer):
if self.is_supported_range_call(node):
constexpr_val = self.get_loop_constexpr(node)
# Check for early exit and raise exception
self.check_early_exit(node)
self.check_early_exit(node, "for")
start, stop, step = self.extract_range_args(node.iter)
unroll, unroll_full = self.extract_unroll_args(node.iter)
used_args, iter_args, flat_args = self.analyze_region_variables(
@ -659,37 +675,42 @@ class DSLPreprocessor(ast.NodeTransformer):
snippet=ast.unparse(node),
)
test = ast.BoolOp(
op=ast.And(),
values=[
ast.Compare(
left=ast.Call(
func=ast.Name(id="type", ctx=ast.Load()),
args=[node.values[0]],
keywords=[],
def short_circuit_eval(value, short_circuit_value):
return ast.BoolOp(
op=ast.And(),
values=[
ast.Compare(
left=ast.Call(
func=ast.Name(id="type", ctx=ast.Load()),
args=[value],
keywords=[],
),
ops=[ast.Eq()],
comparators=[ast.Name(id="bool", ctx=ast.Load())],
),
ops=[ast.Eq()],
comparators=[ast.Name(id="bool", ctx=ast.Load())],
),
ast.Compare(
left=node.values[0],
ops=[ast.Eq()],
comparators=[short_circuit_value],
),
],
)
return ast.copy_location(
ast.IfExp(
ast.Compare(
left=value,
ops=[ast.Eq()],
comparators=[short_circuit_value],
),
],
)
lhs = node.values[0]
for i in range(1, len(node.values)):
test = short_circuit_eval(lhs, short_circuit_value)
lhs = ast.IfExp(
test=test,
body=node.values[0],
body=lhs,
orelse=ast.Call(
func=helper_func,
args=node.values,
args=[lhs, node.values[i]],
keywords=[],
),
),
node,
)
)
return ast.copy_location(lhs, node)
def visit_UnaryOp(self, node):
# Visit child nodes first
@ -916,7 +937,7 @@ class DSLPreprocessor(ast.NodeTransformer):
return node
# Check for early exit and raise exception
self.check_early_exit(node)
self.check_early_exit(node, "while")
used_args, yield_args, flat_args = self.analyze_region_variables(
node, active_symbols
@ -1021,7 +1042,7 @@ class DSLPreprocessor(ast.NodeTransformer):
return node
# Check for early exit and raise exception
self.check_early_exit(node)
self.check_early_exit(node, "if")
used_args, yield_args, flat_args = self.analyze_region_variables(
node, active_symbols