1752 lines
62 KiB
Python
1752 lines
62 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
|
#
|
|
# Use of this software is governed by the terms and conditions of the
|
|
# NVIDIA End User License Agreement (EULA), available at:
|
|
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
|
#
|
|
# Any use, reproduction, disclosure, or distribution of this software
|
|
# and related documentation outside the scope permitted by the EULA
|
|
# is strictly prohibited.
|
|
|
|
"""
|
|
This module defines the `DSLPreprocessor` class, which acts as a Python preprocessor.
|
|
It uses Python's AST and rewrites specific Python statements such as `for` and `if-else`.
|
|
|
|
The preprocessor operates on the following constructs:
|
|
- `for` loops:
|
|
- Rewrites `for` loops with the `@loop_selector` decorator.
|
|
- Supports `range`, `range_dynamic`, and `range_constexpr` for loop iteration.
|
|
- `if-elif-else` statements:
|
|
- Rewrites conditional statements with the `@if_selector` decorator.
|
|
- Supports `dynamic_expr` and `const_expr` in the condition expressions.
|
|
|
|
Additionally, both `for` loops and `if-else` statements require `yield`
|
|
operation generation. The preprocessor handles this by:
|
|
- Using a `ScopeManager` to track symbols across different scopes during AST traversal.
|
|
- Identifying read-only, read-write, and active variables for DSL constructs.
|
|
- Generating `yield` operations for symbols that are classified as read-write or write.
|
|
|
|
It is designed to be generic and can handle `for` and `if` constructs from other dialects.
|
|
In such cases, the user's DSL should implement `@loop_selector` and `@if_selector`
|
|
to generate dialect-specific operations for `for` and `if` statements.
|
|
"""
|
|
|
|
import ast
|
|
import importlib
|
|
import inspect
|
|
import textwrap
|
|
import warnings
|
|
from dataclasses import dataclass
|
|
from typing import List, Set, Dict, Any, Callable, Optional
|
|
from types import ModuleType
|
|
from collections import OrderedDict
|
|
|
|
from .common import *
|
|
from .utils.logger import log
|
|
|
|
|
|
class OrderedSet:
    """
    A deterministic set implementation for ordered operations.

    Elements are kept as the keys of a ``dict``, so iteration follows
    insertion order and membership tests are hash lookups.
    """

    def __init__(self, iterable=None):
        # ``None`` (or any falsy iterable) produces an empty set.
        self._dict = dict.fromkeys(iterable or [])

    def add(self, item):
        """Insert ``item``; an item that is already present keeps its position."""
        self._dict[item] = None

    def __iter__(self):
        return iter(self._dict)

    def __contains__(self, item):
        # Without this, ``x in ordered_set`` falls back to a linear scan
        # through ``__iter__``, which makes ``-``, ``&`` and ``intersections``
        # quadratic when the right-hand side is another OrderedSet.
        return item in self._dict

    def __repr__(self):
        return f"{type(self).__name__}({list(self._dict)!r})"

    def __and__(self, other):
        """Return the elements of this set that are also in ``other``, in this set's order."""
        return OrderedSet(key for key in self._dict if key in other)

    def __or__(self, other):
        """Return this set's elements followed by the new elements of ``other``."""
        merged = self._dict.copy()
        merged.update(dict.fromkeys(other))
        return OrderedSet(merged)

    def __sub__(self, other):
        """Return the elements of this set that are not in ``other``, in this set's order."""
        return OrderedSet(key for key in self._dict if key not in other)

    def intersections(self, others):
        """Compute the intersection of this set with multiple other sets.

        :param others: A sequence of sets to compute intersections with
        :type others: List[Set[str]]
        :return: A new ordered set containing elements that appear in this set
            and at least one of the other sets
        """
        result = OrderedSet()
        for key in self._dict:
            # The containers are probed from last to first; only membership
            # matters for the result.
            if any(key in other for other in reversed(others)):
                result.add(key)
        return result
|
|
|
|
|
|
@dataclass
class ScopeManager:
    """
    Stack of symbol scopes used while walking an AST.

    Entering the manager as a context manager opens a fresh, empty scope;
    leaving it discards the innermost scope again.
    """

    # One set of symbol names per open scope, outermost first.
    scopes: List[Set[str]]

    @classmethod
    def create(cls) -> "ScopeManager":
        """Build a manager that has no open scope yet."""
        return cls(scopes=[])

    def add_to_scope(self, name: str) -> None:
        """Record ``name`` in the innermost open scope."""
        innermost = self.scopes[-1]
        innermost.add(name)

    def get_active_symbols(self) -> List[Set[str]]:
        """Return a new list holding the currently open scope sets.

        The list is a shallow copy: the sets themselves are shared with the
        manager.
        """
        return self.scopes[:]

    def __enter__(self) -> "ScopeManager":
        fresh_scope = set()
        self.scopes.append(fresh_scope)
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Drop the scope opened by the matching __enter__.
        self.scopes.pop()
|
|
|
|
|
|
class DSLPreprocessor(ast.NodeTransformer):
    """
    Python AST transformer used as the DSL preprocessor. It handles:

    - `for` loops, rewritten with the `@loop_selector` decorator.
    - `if-elif-else` statements, rewritten with the `@if_selector` decorator.
    - `yield` generation for symbols classified as read-write or write.
    """

    # Decorator names attached to generated region functions.
    DECORATOR_FOR_STATEMENT = "loop_selector"
    DECORATOR_IF_STATEMENT = "if_selector"
    DECORATOR_WHILE_STATEMENT = "while_selector"

    # Names of the helper callables that rewritten code refers to.
    IF_EXECUTOR = "if_executor"
    WHILE_EXECUTOR = "while_executor"
    ASSERT_EXECUTOR = "assert_executor"
    COMPARE_EXECUTOR = "compare_executor"
    ANY_EXECUTOR = "any_executor"
    ALL_EXECUTOR = "all_executor"
    BOOL_CAST = "bool_cast"
    IMPLICIT_DOWNCAST_NUMERIC_TYPE = "implicitDowncastNumericType"

    # Callables recognised as the iterable of a `for` loop.
    SUPPORTED_FOR_RANGE_STATEMENTS = {"range", "range_dynamic", "range_constexpr"}
|
|
|
|
def __init__(self):
    """Create a preprocessor with empty bookkeeping state."""
    super().__init__()

    # Counters used when naming generated helpers; ``counter`` keeps the
    # generated loop function names unique across multiple loops.
    self.counter = 0
    self.function_counter = 0
    self.function_depth = 0

    # Symbol and function tracking.
    self.scope_manager = ScopeManager.create()
    self.processed_functions = set()
    self.local_closures = set()

    # Context of the function being transformed, used in diagnostics.
    self.file_name = "<unknown filename>"
    self.function_name = "<unknown function>"
    self.class_name = None
    self.function_globals = None
|
|
|
|
def _get_module_imports(self, decorated_func):
|
|
"""Extract imports from the module containing the decorated function"""
|
|
imports = OrderedDict()
|
|
# Get the module containing the decorated function
|
|
if module := inspect.getmodule(decorated_func):
|
|
try:
|
|
# Get the module source code
|
|
source = inspect.getsource(module)
|
|
module_ast = ast.parse(source)
|
|
|
|
# Extract imports from the full module
|
|
alias = lambda n: n.asname if n.asname else n.name
|
|
for node in ast.walk(module_ast):
|
|
if isinstance(node, ast.Import):
|
|
for name in node.names:
|
|
imports[(name.name, None)] = alias(name)
|
|
elif isinstance(node, ast.ImportFrom):
|
|
module_name = node.module
|
|
if node.level > 0:
|
|
# Handle relative imports
|
|
package_name = module.__package__.rsplit(
|
|
".", node.level - 1
|
|
)[0]
|
|
module_name = f"{package_name}.{module_name}"
|
|
for name in node.names:
|
|
imports[(module_name, name.name)] = alias(name)
|
|
except (IOError, TypeError):
|
|
pass
|
|
|
|
return imports
|
|
|
|
def exec(self, function_name, original_function, code_object, exec_globals):
    """Execute transformed code and return the resulting function object.

    The imports of the module that defines ``original_function`` are
    replayed into ``exec_globals`` before ``code_object`` is executed.

    :param function_name: Name to look up in ``exec_globals`` afterwards
    :param original_function: Function whose module imports are replayed
    :param code_object: Compiled transformed code
    :param exec_globals: Globals dict the code runs in (mutated in place)
    :return: ``exec_globals[function_name]`` or ``None`` if it is not defined
    :raises ImportError: If one of the module's imports cannot be replayed
    """
    # Get imports from the original module
    module_imports = self._get_module_imports(original_function)

    # Import all required modules
    for (module_path, attr_name), alias_name in module_imports.items():
        try:
            module = importlib.import_module(module_path)
            if not attr_name:
                exec_globals[alias_name] = module
                if "." in alias_name:
                    # ``import a.b.c`` (no ``as``) binds the top-level package
                    # ``a``; add it unless the globals already define that name.
                    top_level = alias_name.partition(".")[0]
                    if top_level not in exec_globals:
                        exec_globals[top_level] = importlib.import_module(top_level)
            elif attr_name == "*":
                if hasattr(module, "__all__"):
                    attrs = module.__all__
                else:
                    attrs = [name for name in dir(module) if not name.startswith("_")]
                for attr in attrs:
                    exec_globals[attr] = getattr(module, attr)
            else:
                exec_globals[alias_name] = getattr(module, attr_name)
        except (ImportError, AttributeError) as e:
            # Keep the original failure attached as the cause.
            raise ImportError(f"Failed to import {module_path}: {str(e)}") from e

    # Execute the transformed code
    log().info(
        "ASTPreprocessor Executing transformed code for function [%s]",
        function_name,
    )
    exec(code_object, exec_globals)
    return exec_globals.get(function_name)
|
|
|
|
@staticmethod
def print_ast(transformed_tree=None):
    """Print the unparsed source of ``transformed_tree`` between two banner lines."""
    rule = "-" * 40
    print("#", rule, "Transformed AST", rule)
    print(ast.unparse(transformed_tree))
    print("#", rule, "End Transformed AST", rule)
|
|
|
|
def make_func_param_name(self, base_name, used_names):
    """Return ``base_name`` if it is free, else the first free ``base_name_<i>`` (i = 0, 1, ...)."""
    candidate = base_name
    suffix = 0
    while candidate in used_names:
        candidate = f"{base_name}_{suffix}"
        suffix += 1
    return candidate
|
|
|
|
def transform_function(self, func_name, function_pointer):
    """
    Parse the source of ``function_pointer`` and run this transformer over it.

    :param func_name: Name used in log messages
    :param function_pointer: Function object to transform
    :return: Statement list of the transformed module body, or ``[]`` when the
        function was already processed or ``check_decorator`` rejects it
    """
    # Each function is transformed at most once.
    if function_pointer in self.processed_functions:
        log().info(
            "ASTPreprocessor Skipping already processed function [%s]", func_name
        )
        return []

    # Parse the dedented source, then shift line numbers so they match the
    # real source file.
    source_file = inspect.getsourcefile(function_pointer)
    source_lines, first_line = inspect.getsourcelines(function_pointer)
    tree = ast.parse(textwrap.dedent("".join(source_lines)), filename=source_file)
    ast.increment_lineno(tree, first_line - 1)

    # Only suitably decorated functions are rewritten.
    if not self.check_decorator(tree.body[0]):
        log().info(
            "[%s] - Skipping function due to missing decorator",
            func_name,
        )
        return []

    self.processed_functions.add(function_pointer)
    log().info("ASTPreprocessor Transforming function [%s]", func_name)

    rewritten = self.visit(tree)
    ast.fix_missing_locations(rewritten)
    return rewritten.body
|
|
|
|
def check_early_exit(self, tree, kind):
    """
    Reject a region that contains an early exit.

    The children of ``tree`` are scanned for ``return`` and ``raise`` at any
    depth, and for ``break`` / ``continue`` that are not enclosed by a loop
    nested inside ``tree``. ``break`` / ``continue`` are never reported when
    ``kind`` is ``"if"``.

    :param tree: AST node of the region to scan
    :param kind: Kind of the region, e.g. ``"for"`` or ``"if"``
    :raises DSLAstPreprocessorError: If an early exit is found
    """

    class _ExitFinder(ast.NodeVisitor):
        def __init__(self, kind):
            self.kind = kind
            self.loop_nest_level = 0
            self.has_early_exit = False
            self.early_exit_node = None
            self.early_exit_type = None

        def _flag(self, node, exit_type):
            # A later hit overwrites an earlier one.
            self.has_early_exit = True
            self.early_exit_node = node
            self.early_exit_type = exit_type

        def _flag_loop_exit(self, node, exit_type):
            # break/continue inside an inner loop, or inside an `if` region,
            # is not treated as an early exit.
            if self.kind == "if" or self.loop_nest_level != 0:
                return
            self._flag(node, exit_type)

        def _visit_inner_loop(self, node):
            self.loop_nest_level += 1
            self.generic_visit(node)
            self.loop_nest_level -= 1

        # Early exit is not allowed in any level of dynamic control flow
        def visit_Return(self, node):
            self._flag(node, "return")

        def visit_Raise(self, node):
            self._flag(node, "raise")

        def visit_Break(self, node):
            self._flag_loop_exit(node, "break")

        def visit_Continue(self, node):
            self._flag_loop_exit(node, "continue")

        def visit_For(self, node):
            self._visit_inner_loop(node)

        def visit_While(self, node):
            self._visit_inner_loop(node)

    finder = _ExitFinder(kind)
    # generic_visit: ``tree`` itself is the region and is not counted as a
    # nested loop.
    finder.generic_visit(tree)
    if finder.has_early_exit:
        raise DSLAstPreprocessorError(
            message=f"Early exit ({finder.early_exit_type}) is not allowed in `{self.function_name}`"
            + (f" in `{self.class_name}`" if self.class_name else ""),
            filename=self.file_name,
            snippet=ast.unparse(tree),
            suggestion=(
                "If predicates are constant expression, write like "
                "`if const_expr(...)` or `for ... in range_constexpr(...)`. "
                "In that case, early exit will be executed by Python "
                "interpreter, so it's supported."
            ),
        )
|
|
|
|
def is_node_constexpr(self, node) -> bool:
    """
    Tell whether an ``if`` / ``while`` node is guarded by ``const_expr(...)``.

    True only when ``node`` is an ``ast.If`` or ``ast.While`` whose test is a
    call to ``const_expr``, written either as a bare name or as the final
    attribute of a dotted name.
    """
    if not isinstance(node, (ast.If, ast.While)):
        return False

    condition = node.test
    if not isinstance(condition, ast.Call):
        return False

    callee = condition.func
    if isinstance(callee, ast.Attribute):
        return callee.attr == "const_expr"
    if isinstance(callee, ast.Name):
        return callee.id == "const_expr"
    return False
|
|
|
|
def _get_range_kind(self, iter_node):
|
|
"""
|
|
Return "range", "range_dynamic", "range_constexpr" or None for the iterable
|
|
"""
|
|
if isinstance(iter_node, ast.Call):
|
|
func = iter_node.func
|
|
if (
|
|
isinstance(func, ast.Name)
|
|
and func.id in self.SUPPORTED_FOR_RANGE_STATEMENTS
|
|
):
|
|
return func.id, True
|
|
if (
|
|
isinstance(func, ast.Attribute)
|
|
and func.attr in self.SUPPORTED_FOR_RANGE_STATEMENTS
|
|
):
|
|
return func.attr, False
|
|
return None, None
|
|
|
|
def transform(self, original_function, exec_globals):
    """
    Run the preprocessor on ``original_function`` and wrap the result in a module.

    Records the function's source file and ``exec_globals`` on the instance
    before delegating to ``transform_function``.

    :return: ``ast.Module`` whose body is the transformed statement list
    """
    self.file_name = inspect.getsourcefile(original_function)
    self.function_globals = exec_globals

    body = self.transform_function(original_function.__name__, original_function)
    module = ast.Module(body=body, type_ignores=[])
    return ast.fix_missing_locations(module)
|
|
|
|
def analyze_region_variables(
    self, node: Union[ast.For, ast.If], active_symbols: List[Set[str]]
):
    """
    Analyze variables in a code region to identify read-only, written,
    and active variables for DSL constructs.

    A name counts as written when it is stored to, or when a method is called
    on it (``obj.method(...)`` marks ``obj`` as written, since objects are
    assumed mutable); ``self`` is exempt from the method-call rule.

    :param node: Region to analyze
    :param active_symbols: Scope sets of the symbols visible before the region
    :return: ``(used_args, iter_args, flattened_args)`` as lists: active names
        that are only read, active names that are written, and the two
        concatenated in that order
    :raises DSLAstPreprocessorError: If the region calls a local closure
    """

    # OrderedSet keeps insertion order so the generated IR is identical on
    # every run.
    read_args = OrderedSet()
    write_args = OrderedSet()
    local_closure = self.local_closures
    file_name = self.file_name
    region_node = node

    class RegionAnalyzer(ast.NodeVisitor):

        def visit_Name(self, node):
            """
            Mark every load as read, and every store as write.
            """
            if isinstance(node.ctx, ast.Load):
                read_args.add(node.id)
            elif isinstance(node.ctx, ast.Store):
                write_args.add(node.id)

        @staticmethod
        def get_call_base(func_node):
            """Return the root name of an attribute chain (``a`` for ``a.b.c``), else None."""
            if not isinstance(func_node, ast.Attribute):
                # A bare-name call has no receiver object.
                return None
            base = func_node.value
            while isinstance(base, ast.Attribute):
                base = base.value
            # The root may be something else (lambda, call, etc.).
            return base.id if isinstance(base, ast.Name) else None

        def visit_Call(self, node):
            base_name = RegionAnalyzer.get_call_base(node.func)

            if isinstance(node.func, ast.Name):
                func_name = node.func.id
                if func_name in local_closure:
                    raise DSLAstPreprocessorError(
                        f"Function `{func_name}` is a closure and is not supported in for/if statements",
                        filename=file_name,
                        snippet=ast.unparse(region_node),
                    )

            # Classes are mutable by default. Mark them as write. If they are
            # dataclass(frozen=True), treat them as read in runtime.
            # NOTE: compare with `!=`; `not in ("self")` would be a substring
            # test against the string "self" and would also skip receivers
            # named "s", "e", "l", "f", "se", ...
            if base_name is not None and base_name != "self":
                write_args.add(base_name)

            self.generic_visit(node)

    analyzer = RegionAnalyzer()
    analyzer.visit(ast.Module(body=node))

    # A name can be both loaded and stored; such names count as stores only.
    read_only_args = read_args - write_args

    used_args = read_only_args.intersections(active_symbols)
    iter_args = write_args.intersections(active_symbols)
    flattened_args = used_args | iter_args

    return list(used_args), list(iter_args), list(flattened_args)
|
|
|
|
def extract_range_args(self, iter_node):
    """
    Split the positional arguments of a range-like call into start/stop/step.

    Missing ``start`` and ``step`` default to the constants 0 and 1. Every
    returned expression has been passed through ``self.visit``.

    :param iter_node: The ``ast.Call`` used as the loop iterable
    :return: ``(start, stop, step, has_step)``; ``has_step`` is True only when
        three positional arguments were given
    :raises DSLAstPreprocessorError: If the call has 0 or more than 3
        positional arguments
    """
    args = iter_node.args
    count = len(args)

    if count not in (1, 2, 3):
        raise DSLAstPreprocessorError(
            "Unsupported number of arguments in range", filename=self.file_name
        )

    if count == 3:
        return self.visit(args[0]), self.visit(args[1]), self.visit(args[2]), True

    # One or two arguments: the step is implicit, the start possibly too.
    if count == 1:
        start = self.visit(ast.Constant(value=0))
    else:
        start = self.visit(args[0])
    stop = self.visit(args[-1])
    step = self.visit(ast.Constant(value=1))
    return (start, stop, step, False)
|
|
|
|
def extract_unroll_args(self, iter_node):
    """
    Read the ``unroll`` and ``unroll_full`` keyword arguments of a range call.

    :return: ``(unroll, unroll_full)`` as AST expressions; absent keywords
        become the constants ``-1`` and ``False``
    """
    by_name = {}
    for keyword in iter_node.keywords:
        by_name[keyword.arg] = keyword.value

    unroll = by_name.get("unroll", ast.Constant(value=-1))
    unroll_full = by_name.get("unroll_full", ast.Constant(value=False))
    return (unroll, unroll_full)
|
|
|
|
def extract_pipelining_args(self, iter_node):
    """Return the ``pipelining`` keyword expression of a range call, or a constant ``None``."""
    by_name = {}
    for keyword in iter_node.keywords:
        by_name[keyword.arg] = keyword.value
    return by_name.get("pipelining", ast.Constant(value=None))
|
|
|
|
def create_loop_function(
    self,
    func_name,
    node,
    start,
    stop,
    step,
    unroll,
    unroll_full,
    pipelining,
    used_args,
    iter_args,
    flattened_args,
):
    """
    Build the function definition that holds the body of a dynamic ``for`` loop.

    The function takes the loop target followed by ``flattened_args`` as
    parameters, contains the visited loop body, returns the list of
    ``iter_args`` (or nothing when there are none), and is decorated with a
    call to ``DECORATOR_FOR_STATEMENT`` carrying the range bounds and options.

    :return: ``ast.FunctionDef`` located at the position of ``node``
    """

    def load_list(names):
        # [name1, name2, ...] as a list display of variable loads.
        return ast.List(
            elts=[ast.Name(id=name, ctx=ast.Load()) for name in names],
            ctx=ast.Load(),
        )

    # Parameters: the loop variable first, then every region variable.
    param_names = [node.target.id, *flattened_args]
    params = [ast.arg(arg=name, annotation=None) for name in param_names]

    # Recursively transform the loop body; a visit may yield several statements.
    body = []
    for stmt in node.body:
        visited = self.visit(stmt)
        if isinstance(visited, list):
            body.extend(visited)
        else:
            body.append(visited)

    # The loop-carried values are handed back through the return statement.
    if len(iter_args) > 0:
        body.append(ast.Return(value=load_list(iter_args)))
    else:
        body.append(ast.Return())

    iter_arg_names = ast.List(
        elts=[ast.Constant(value=name) for name in iter_args],
        ctx=ast.Load(),
    )
    decorator_call = ast.Call(
        func=ast.Name(id=self.DECORATOR_FOR_STATEMENT, ctx=ast.Load()),
        args=[start, stop, step],
        keywords=[
            ast.keyword(arg="unroll", value=unroll),
            ast.keyword(arg="unroll_full", value=unroll_full),
            ast.keyword(arg="pipelining", value=pipelining),
            ast.keyword(arg="used_args", value=load_list(used_args)),
            ast.keyword(arg="iter_args", value=load_list(iter_args)),
            ast.keyword(arg="iter_arg_names", value=iter_arg_names),
        ],
    )
    decorator = ast.copy_location(decorator_call, node)

    signature = ast.arguments(
        posonlyargs=[],
        args=params,
        kwonlyargs=[],
        kw_defaults=[],
        defaults=[],
    )
    loop_function = ast.FunctionDef(
        name=func_name,
        args=signature,
        body=body,
        decorator_list=[decorator],
    )
    return ast.copy_location(loop_function, node)
|
|
|
|
def create_loop_call(self, func_name, iter_args):
    """
    Build the statement that consumes the result bound to ``func_name``.

    The generated code refers to ``func_name`` by name only (it emits no
    call): with no ``iter_args`` the statement is a bare expression, with one
    it is a plain assignment, and with several it is a tuple-unpacking
    assignment.
    """
    result = ast.Name(id=func_name, ctx=ast.Load())

    if len(iter_args) == 0:
        return ast.Expr(value=result)

    stores = [ast.Name(id=name, ctx=ast.Store()) for name in iter_args]
    if len(stores) == 1:
        target = stores[0]
    else:
        target = ast.Tuple(elts=stores, ctx=ast.Store())
    return ast.Assign(targets=[target], value=result)
|
|
|
|
def visit_BoolOp(self, node):
    """
    Rewrite ``a and b`` / ``a or b`` into calls to ``and_`` / ``or_``.

    Short-circuiting is kept explicit: each step becomes

        lhs if type(lhs) == bool and lhs == <short-circuit value> else helper(lhs, rhs)

    where the short-circuit value is ``False`` for ``and`` and ``True`` for
    ``or``. The conditional expression is evaluated by the Python interpreter
    rather than turned into IR. Chains with more than two operands are folded
    from left to right.

    :raises DSLAstPreprocessorError: If the operator is neither ``and`` nor ``or``
    """
    # Children first.
    self.generic_visit(node)

    if isinstance(node.op, ast.And):
        sentinel, helper_name = False, "and_"
    elif isinstance(node.op, ast.Or):
        sentinel, helper_name = True, "or_"
    else:
        # BoolOp should be either And or Or
        raise DSLAstPreprocessorError(
            f"Unsupported boolean operation: {node.op}",
            filename=self.file_name,
            snippet=ast.unparse(node),
        )

    short_circuit_value = ast.Constant(value=sentinel)
    helper_func = ast.Name(id=helper_name, ctx=ast.Load())

    def is_short_circuit(operand):
        # type(operand) == bool and operand == <short-circuit value>
        type_is_bool = ast.Compare(
            left=ast.Call(
                func=ast.Name(id="type", ctx=ast.Load()),
                args=[operand],
                keywords=[],
            ),
            ops=[ast.Eq()],
            comparators=[ast.Name(id="bool", ctx=ast.Load())],
        )
        equals_sentinel = ast.Compare(
            left=operand,
            ops=[ast.Eq()],
            comparators=[short_circuit_value],
        )
        return ast.BoolOp(op=ast.And(), values=[type_is_bool, equals_sentinel])

    folded = node.values[0]
    for rhs in node.values[1:]:
        guard = is_short_circuit(folded)
        combined = ast.Call(func=helper_func, args=[folded, rhs], keywords=[])
        folded = ast.IfExp(test=guard, body=folded, orelse=combined)

    return ast.copy_location(folded, node)
|
|
|
|
def visit_UnaryOp(self, node):
    """Rewrite ``not x`` into the call ``not_(x)``; other unary operators are kept."""
    # Children first.
    self.generic_visit(node)

    if not isinstance(node.op, ast.Not):
        return node

    negation = ast.Call(
        func=ast.Name(id="not_", ctx=ast.Load()),
        args=[node.operand],
        keywords=[],
    )
    return ast.copy_location(negation, node)
|
|
|
|
@staticmethod
|
|
def _insert_range_value_check(node):
|
|
"""
|
|
Insert a check for range arguments
|
|
"""
|
|
range_inputs = node.iter.args
|
|
check_call = ast.copy_location(
|
|
ast.Call(
|
|
func=ast.Name(id="range_value_check", ctx=ast.Load()),
|
|
args=range_inputs,
|
|
keywords=[],
|
|
),
|
|
node.iter,
|
|
)
|
|
node.iter = ast.copy_location(
|
|
ast.Call(
|
|
func=ast.Name(id="range", ctx=ast.Load()),
|
|
args=[ast.Starred(value=check_call, ctx=ast.Load())],
|
|
keywords=[],
|
|
),
|
|
node.iter,
|
|
)
|
|
|
|
def visit_For(self, node):
    """
    Transform a ``for`` statement.

    - Loops over ``range_constexpr(...)`` or over a non-range iterable stay
      Python loops; only their children are visited, and a
      ``range_constexpr`` iterable is rewritten to a value-checked ``range``.
    - Loops over ``range_dynamic(...)`` emit a ``DeprecationWarning`` and are
      then handled like ``range(...)`` loops.
    - All other supported range loops are lowered by ``transform_for_loop``;
      a bare built-in ``range(...)`` is additionally preceded by a
      ``range_perf_warning(...)`` call.

    :return: The original node for static loops, otherwise the list of
        statements that replaces the loop
    """
    # Symbols visible before the loop opens its own scope.
    active_symbols = self.scope_manager.get_active_symbols()

    with self.scope_manager:
        if isinstance(node.target, ast.Name):
            self.scope_manager.add_to_scope(node.target.id)

        # For static for loop (for with range_constexpr or not range based for),
        # the preprocessor keeps the loop.
        range_kind, is_builtin_range = self._get_range_kind(node.iter)
        if range_kind is None or range_kind == "range_constexpr":
            self.generic_visit(node)
            if range_kind == "range_constexpr":
                # Rewrite range_constexpr to range
                node.iter.func.id = "range"
                self._insert_range_value_check(node)
            return node

        if range_kind == "range_dynamic":
            # Force the warning to be shown, but only inside this block:
            # catch_warnings restores the caller's filter configuration on
            # exit instead of permanently overriding it.
            with warnings.catch_warnings():
                warnings.simplefilter("always", DeprecationWarning)
                warnings.warn_explicit(
                    "range_dynamic is deprecated and will be removed in the future, please remove it.",
                    category=DeprecationWarning,
                    filename=self.file_name,
                    lineno=node.iter.lineno,
                )

        warning_call = None
        if range_kind == "range" and is_builtin_range:
            # Warn about possible performance regression due to behavior change
            warning_call = ast.Expr(
                ast.Call(
                    func=ast.Name(id="range_perf_warning", ctx=ast.Load()),
                    args=[
                        ast.Constant(value=self.file_name),
                        ast.Constant(value=node.iter.lineno),
                    ]
                    + node.iter.args,
                    keywords=[],
                )
            )
            ast.copy_location(warning_call, node.iter)

        new_for_node = self.transform_for_loop(node, active_symbols)

        if warning_call is None:
            return new_for_node
        return [warning_call] + new_for_node
|
|
|
|
@staticmethod
|
|
def _hoist_expr_to_assignments(expr, name):
|
|
return ast.copy_location(
|
|
ast.Assign(targets=[ast.Name(id=name, ctx=ast.Store())], value=expr), expr
|
|
)
|
|
|
|
def _build_select_and_assign(self, *, name, test, body, orelse, location):
|
|
node = ast.copy_location(
|
|
ast.Assign(
|
|
targets=[ast.Name(id=name, ctx=ast.Store())],
|
|
value=ast.IfExp(
|
|
test=test,
|
|
body=body,
|
|
orelse=orelse,
|
|
),
|
|
),
|
|
location,
|
|
)
|
|
self.generic_visit(node)
|
|
return node
|
|
|
|
def _handle_negative_step(self, node, start_expr, stop_expr, step_expr):
|
|
# hoist start, stop, step to assignments
|
|
start_ori_name = f"start_ori_{self.counter}"
|
|
start = self._hoist_expr_to_assignments(start_expr, start_ori_name)
|
|
stop_ori_name = f"stop_ori_{self.counter}"
|
|
stop = self._hoist_expr_to_assignments(stop_expr, stop_ori_name)
|
|
step_ori_name = f"step_ori_{self.counter}"
|
|
step = self._hoist_expr_to_assignments(step_expr, step_ori_name)
|
|
|
|
extra_exprs = [start, stop, step]
|
|
|
|
# Handle possible negative step, generates the following code in Python:
|
|
# isNegative = step < 0
|
|
isNegative_name = f"isNegative_{self.counter}"
|
|
isNegative = ast.copy_location(
|
|
ast.Assign(
|
|
targets=[ast.Name(id=isNegative_name, ctx=ast.Store())],
|
|
value=ast.Compare(
|
|
left=ast.Name(id=step_ori_name, ctx=ast.Load()),
|
|
ops=[ast.Lt()],
|
|
comparators=[ast.Constant(value=0)],
|
|
),
|
|
),
|
|
step,
|
|
)
|
|
|
|
# start = stop if isNegative else start
|
|
start_name = f"start_{self.counter}"
|
|
start = self._build_select_and_assign(
|
|
name=start_name,
|
|
test=ast.Name(id=isNegative_name, ctx=ast.Load()),
|
|
body=ast.Name(id=stop_ori_name, ctx=ast.Load()),
|
|
orelse=ast.Name(id=start_ori_name, ctx=ast.Load()),
|
|
location=start,
|
|
)
|
|
|
|
# stop = start if isNegative else stop
|
|
stop_name = f"stop_{self.counter}"
|
|
stop = self._build_select_and_assign(
|
|
name=stop_name,
|
|
test=ast.Name(id=isNegative_name, ctx=ast.Load()),
|
|
body=ast.Name(id=start_ori_name, ctx=ast.Load()),
|
|
orelse=ast.Name(id=stop_ori_name, ctx=ast.Load()),
|
|
location=stop,
|
|
)
|
|
|
|
# step = -step if isNegative else step
|
|
step_name = f"step_{self.counter}"
|
|
step = self._build_select_and_assign(
|
|
name=step_name,
|
|
test=ast.Name(id=isNegative_name, ctx=ast.Load()),
|
|
body=ast.UnaryOp(
|
|
op=ast.USub(), operand=ast.Name(id=step_ori_name, ctx=ast.Load())
|
|
),
|
|
orelse=ast.Name(id=step_ori_name, ctx=ast.Load()),
|
|
location=step,
|
|
)
|
|
|
|
# offset = start + stop if isNegative else 0
|
|
offset_name = f"offset_{self.counter}"
|
|
offset = self._build_select_and_assign(
|
|
name=offset_name,
|
|
test=ast.Name(id=isNegative_name, ctx=ast.Load()),
|
|
body=ast.BinOp(
|
|
op=ast.Add(),
|
|
left=ast.Name(id=start_name, ctx=ast.Load()),
|
|
right=ast.Name(id=stop_name, ctx=ast.Load()),
|
|
),
|
|
orelse=ast.Constant(value=0),
|
|
location=node,
|
|
)
|
|
|
|
extra_exprs.append(isNegative)
|
|
extra_exprs.append(start)
|
|
extra_exprs.append(stop)
|
|
extra_exprs.append(step)
|
|
extra_exprs.append(offset)
|
|
|
|
# Add this to begining of loop body
|
|
# for i in range(start, stop, step):
|
|
# i = offset - i if isNegative else i
|
|
assert isinstance(node.target, ast.Name)
|
|
|
|
target_name = node.target.id
|
|
target = self._build_select_and_assign(
|
|
name=target_name,
|
|
test=ast.Name(id=isNegative_name, ctx=ast.Load()),
|
|
body=ast.BinOp(
|
|
op=ast.Sub(),
|
|
left=ast.Name(id=offset_name, ctx=ast.Load()),
|
|
right=ast.Name(id=target_name, ctx=ast.Load()),
|
|
),
|
|
orelse=ast.Name(id=target_name, ctx=ast.Load()),
|
|
location=node.target,
|
|
)
|
|
|
|
node.body.insert(0, target)
|
|
|
|
return (
|
|
ast.Name(id=start_name, ctx=ast.Load()),
|
|
ast.Name(id=stop_name, ctx=ast.Load()),
|
|
ast.Name(id=step_name, ctx=ast.Load()),
|
|
extra_exprs,
|
|
)
|
|
|
|
def transform_for_loop(self, node, active_symbols):
    """
    Lower a dynamic ``for`` loop into a decorated loop-body function plus the
    statements that feed it and consume its result.

    If the loop target was already an active symbol before the loop, its
    value is carried through an extra ``loop_carried_var_<N>`` variable that
    is initialised before the loop, updated at the end of every iteration and
    copied back to the target afterwards.

    :param node: The ``for`` node (its body may be extended in place)
    :param active_symbols: Scope sets of symbols visible before the loop; the
        list is modified in place, the sets themselves are not
    :return: List of statements replacing the loop
    :raises DSLAstPreprocessorError: If the loop has an early exit or an
        ``else`` clause
    """
    # Check for early exit and raise exception
    self.check_early_exit(node, "for")
    if node.orelse:
        raise DSLAstPreprocessorError(
            "dynamic for loop with else is not supported",
            filename=self.file_name,
            snippet=ast.unparse(node),
        )

    # The loop target is passed to the loop body as its own parameter, so it
    # must not also show up as a region variable. Strip just that one name
    # from every scope that holds it; all other symbols of those scopes stay
    # active. Filtered copies are used because the sets are shared with the
    # scope manager.
    target_var_name = None
    target_var_is_active_before_loop = False
    if isinstance(node.target, ast.Name):
        target_var_name = node.target.id
        for index, scope in enumerate(active_symbols):
            if target_var_name in scope:
                target_var_is_active_before_loop = True
                active_symbols[index] = set(scope) - {target_var_name}

    if target_var_is_active_before_loop:
        # Initialize an extra loop carried variable
        loop_carried_var_name = f"loop_carried_var_{self.counter}"
        pre_loop_expr = ast.copy_location(
            ast.Assign(
                targets=[ast.Name(id=loop_carried_var_name, ctx=ast.Store())],
                value=ast.Name(id=target_var_name, ctx=ast.Load()),
            ),
            node,
        )
        # Update the carried variable at the end of every iteration.
        node.body.append(
            ast.copy_location(
                ast.Assign(
                    targets=[ast.Name(id=loop_carried_var_name, ctx=ast.Store())],
                    value=ast.Name(id=target_var_name, ctx=ast.Load()),
                ),
                node,
            )
        )
        active_symbols.append({loop_carried_var_name})

    start_expr, stop_expr, step_expr, has_step = self.extract_range_args(node.iter)
    unroll, unroll_full = self.extract_unroll_args(node.iter)
    pipelining = self.extract_pipelining_args(node.iter)
    used_args, iter_args, flat_args = self.analyze_region_variables(
        node, active_symbols
    )

    if has_step:
        # An explicit step may be negative; generate the handling code.
        start, stop, step, exprs = self._handle_negative_step(
            node, start_expr, stop_expr, step_expr
        )
    else:
        start, stop, step, exprs = start_expr, stop_expr, step_expr, []

    if target_var_is_active_before_loop:
        exprs.append(pre_loop_expr)

    func_name = f"loop_body_{self.counter}"
    self.counter += 1

    func_def = self.create_loop_function(
        func_name,
        node,
        start,
        stop,
        step,
        unroll,
        unroll_full,
        pipelining,
        used_args,
        iter_args,
        flat_args,
    )

    assign = ast.copy_location(self.create_loop_call(func_name, iter_args), node)

    exprs = exprs + [func_def, assign]

    if target_var_is_active_before_loop:
        # Copy the carried value back into the target variable
        exprs.append(
            ast.copy_location(
                ast.Assign(
                    targets=[ast.Name(id=target_var_name, ctx=ast.Store())],
                    value=ast.Name(id=loop_carried_var_name, ctx=ast.Load()),
                ),
                node,
            )
        )

    return exprs
|
|
|
|
def visit_Name(self, node):
    """Visit the children of a name node and keep the node itself unchanged."""
    self.generic_visit(node)
    unchanged = node
    return unchanged
|
|
|
|
def visit_Assert(self, node):
    """
    Rewrite ``assert test, msg`` into ``assert_executor(test=..., msg=...)``.

    The test and (if present) the message are visited first; the ``msg``
    keyword is omitted when the assert has no message. The new expression
    statement takes the source location of the original assert.
    """
    keywords = [ast.keyword(arg="test", value=self.visit(node.test))]
    if node.msg:
        keywords.append(ast.keyword(arg="msg", value=self.visit(node.msg)))

    executor_call = ast.Call(
        func=ast.Name(id=self.ASSERT_EXECUTOR, ctx=ast.Load()),
        args=[],
        keywords=keywords,
    )
    # Propagate line number from original node to new node
    return ast.copy_location(ast.Expr(executor_call), node)
|
|
|
|
def visit_Call(self, node):
    """
    Rewrite selected calls after visiting their children.

    - ``bool(x)`` becomes ``bool_cast(x)``; ``any(x)`` / ``all(x)`` become
      ``any_executor(x)`` / ``all_executor(x)``. Only the first positional
      argument is forwarded. Calls without positional arguments are left
      unchanged.
    - ``mod.func(...)`` where ``mod`` is a module in ``self.function_globals``
      whose package name ends with ``._mlir.dialects`` gets every positional
      and keyword argument wrapped in ``implicitDowncastNumericType(...)``.

    All other calls are returned unchanged.
    """
    # Keep the callee as written; generic_visit may replace child nodes.
    func = node.func
    self.generic_visit(node)

    # Rewrite call to some built-in functions
    if isinstance(func, ast.Name):
        helper_by_builtin = {
            "bool": self.BOOL_CAST,
            "any": self.ANY_EXECUTOR,
            "all": self.ALL_EXECUTOR,
        }
        helper_name = helper_by_builtin.get(func.id)
        # Without a positional argument there is nothing to forward
        # (``bool()`` is a valid call); leave such calls to Python.
        if helper_name is not None and node.args:
            return ast.copy_location(
                ast.Call(
                    func=ast.Name(id=helper_name, ctx=ast.Load()),
                    args=[node.args[0]],
                    keywords=[],
                ),
                node,
            )
        return node

    if isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name):
        # function_globals is only set once ``transform`` has run.
        module = (self.function_globals or {}).get(func.value.id)
        # ``__package__`` can be None; treat that as "no package".
        package = getattr(module, "__package__", None) or ""
        if isinstance(module, ModuleType) and package.endswith("._mlir.dialects"):

            def create_downcast_call(arg):
                return ast.copy_location(
                    ast.Call(
                        func=ast.Name(
                            id=self.IMPLICIT_DOWNCAST_NUMERIC_TYPE, ctx=ast.Load()
                        ),
                        args=[arg],
                        keywords=[],
                    ),
                    arg,
                )

            args = [create_downcast_call(arg) for arg in node.args]
            kwargs = [
                ast.copy_location(
                    ast.keyword(
                        arg=kwarg.arg,
                        value=create_downcast_call(kwarg.value),
                    ),
                    kwarg,
                )
                for kwarg in node.keywords
            ]
            return ast.copy_location(
                ast.Call(func=func, args=args, keywords=kwargs), node
            )

    return node
|
|
|
|
def visit_ClassDef(self, node):
    """Track the name of the class currently being traversed.

    ``self.class_name`` is set to ``node.name`` while the class body is
    visited. Afterwards the value that was active before entering the class
    is restored (``None`` for a top-level class), so that members following a
    nested class still see the enclosing class name. The restore also happens
    when traversal raises.
    """
    enclosing_class = getattr(self, "class_name", None)
    self.class_name = node.name
    try:
        self.generic_visit(node)
    finally:
        self.class_name = enclosing_class
    return node
|
|
|
|
def _visit_target(self, target):
    """Register every plain name bound by an assignment target.

    Handles a single ``Name``, arbitrarily nested ``Tuple``/``List`` targets
    and starred elements (``a, *rest = ...``). Targets that do not bind a
    local name (attributes, subscripts) are ignored. Names are added to the
    scope manager in left-to-right source order.
    """
    pending = [target]
    while pending:
        current = pending.pop()
        if isinstance(current, ast.Name):
            self.scope_manager.add_to_scope(current.id)
        elif isinstance(current, (ast.Tuple, ast.List)):
            # Reverse so that popping from the end yields source order.
            pending.extend(reversed(current.elts))
        elif isinstance(current, ast.Starred):
            pending.append(current.value)
|
|
|
|
def visit_Assign(self, node):
    """Record the names bound by an assignment, then visit its children.

    Every entry of ``node.targets`` is passed to ``_visit_target`` before the
    generic traversal runs; the node itself is returned unchanged.
    """
    register_target = self._visit_target
    for lhs in node.targets:
        register_target(lhs)

    self.generic_visit(node)
    return node
|
|
|
|
def visit_AugAssign(self, node):
    """Record the target of an augmented assignment, then visit its children.

    The single ``node.target`` is passed to ``_visit_target``; the node is
    returned unchanged.
    """
    lhs = node.target
    self._visit_target(lhs)

    self.generic_visit(node)
    return node
|
|
|
|
def check_decorator(self, node: ast.AST) -> bool:
    """
    Check if the function has the correct decorator for preprocessing.

    Returns ``True`` for a ``FunctionDef`` decorated with ``<x>.jit`` or
    ``<x>.kernel``, either bare or called without keyword arguments. When the
    decorator is called with a ``preprocess=`` keyword, the literal value of
    that keyword is returned. A ``preprocess`` value that is not a literal is
    ignored. Anything else yields ``False``.
    """
    if not isinstance(node, ast.FunctionDef):
        return False

    dsl_decorators = ("jit", "kernel")

    for d in node.decorator_list:
        if isinstance(d, ast.Call):
            if not (
                isinstance(d.func, ast.Attribute) and d.func.attr in dsl_decorators
            ):
                continue
            if not d.keywords:
                return True
            for keyword in d.keywords:
                if keyword.arg != "preprocess":
                    continue
                if isinstance(keyword.value, ast.Constant):
                    return keyword.value.value
                try:
                    return ast.literal_eval(keyword.value)
                except (
                    ValueError,
                    TypeError,
                    SyntaxError,
                    MemoryError,
                    RecursionError,
                ):
                    # Not a literal expression: skip it, but do not swallow
                    # KeyboardInterrupt/SystemExit like a bare ``except``.
                    continue
        elif isinstance(d, ast.Attribute):
            if d.attr in dsl_decorators:
                return True

    return False
|
|
|
|
def remove_dsl_decorator(self, decorator_list):
    """
    Return ``decorator_list`` without the ``.jit`` / ``.kernel`` decorators.

    A decorator is dropped when it is an attribute access named ``jit`` or
    ``kernel`` (``@x.jit``) or a call of such an attribute (``@x.jit(...)``).
    All other decorators are kept, in their original order.
    """
    dsl_names = ["jit", "kernel"]

    def is_dsl_decorator(decorator):
        # For ``@x.jit(...)`` inspect the callee, for ``@x.jit`` the node itself.
        candidate = decorator.func if isinstance(decorator, ast.Call) else decorator
        return isinstance(candidate, ast.Attribute) and candidate.attr in dsl_names

    return [d for d in decorator_list if not is_dsl_decorator(d)]
|
|
|
|
def visit_FunctionDef(self, node):
    """Visit a function definition inside a fresh scope.

    Bookkeeping performed here:

    * ``function_counter`` is incremented and ``function_name`` updated.
    * A function defined while another one is being visited is recorded in
      ``local_closures``.
    * The function name and all of its parameters (regular, positional-only,
      keyword-only, ``*args`` and ``**kwargs``) are added to the new scope.
    * ``function_depth`` is raised during the traversal of the body and is
      restored even if that traversal raises.
    * ``.jit`` / ``.kernel`` decorators are removed from the node.
    """
    with self.scope_manager:
        self.function_counter += 1
        self.function_name = node.name
        if self.function_depth > 0:
            self.local_closures.add(node.name)

        self.function_depth += 1
        try:
            # Add function name and arguments
            self.scope_manager.add_to_scope(node.name)
            signature = node.args
            for arg in signature.args:
                self.scope_manager.add_to_scope(arg.arg)

            # Parameters outside ``args.args`` are local bindings as well.
            other_params = list(getattr(signature, "posonlyargs", []))
            other_params.extend(signature.kwonlyargs)
            for special in (signature.vararg, signature.kwarg):
                if special is not None:
                    other_params.append(special)
            for arg in other_params:
                self.scope_manager.add_to_scope(arg.arg)

            self.generic_visit(node)
        finally:
            self.function_depth -= 1

        # Remove .jit and .kernel decorators
        node.decorator_list = self.remove_dsl_decorator(node.decorator_list)
    return node
|
|
|
|
def visit_With(self, node):
    """Visit a ``with`` statement inside its own scope.

    Each ``as <name>`` binding that is a plain name is added to the new scope
    before the statement's children are traversed. The node is returned
    unchanged.
    """
    with self.scope_manager:
        bound_targets = (item.optional_vars for item in node.items)
        for bound in bound_targets:
            if isinstance(bound, ast.Name):
                self.scope_manager.add_to_scope(bound.id)
        self.generic_visit(node)

    return node
|
|
|
|
def visit_While(self, node):
    """Rewrite a ``while`` loop into a region function plus its call.

    The symbols active *before* the loop are captured first, then a new scope
    is opened. A loop that ``is_node_constexpr`` accepts is only traversed and
    returned as is. Otherwise early exits are checked, the region variables
    are analysed and the statement is replaced by two statements: the
    ``while_region_<n>`` function definition and the loop call statement.
    """
    symbols_before_loop = self.scope_manager.get_active_symbols()

    with self.scope_manager:
        if not self.is_node_constexpr(node):
            # Check for early exit and raise exception
            self.check_early_exit(node, "while")

            used_args, yield_args, flat_args = self.analyze_region_variables(
                node, symbols_before_loop
            )

            region_name = f"while_region_{self.counter}"
            self.counter += 1

            region_def = self.create_while_function(
                region_name, node, used_args, yield_args, flat_args
            )
            region_call = ast.copy_location(
                self.create_loop_call(region_name, yield_args), node
            )
            return [region_def, region_call]

        # Constexpr doesn't get preprocessed
        self.generic_visit(node)
        return node
|
|
|
|
def visit_Try(self, node):
    """Traverse a ``try`` statement inside its own scope; the node is kept."""
    scope = self.scope_manager
    with scope:
        self.generic_visit(node)
    return node
|
|
|
|
def visit_ExceptHandler(self, node):
    """Traverse an ``except`` clause inside its own scope.

    When the clause binds the exception (``except E as name``), that name is
    added to the new scope before the handler body is visited.
    """
    scope = self.scope_manager
    with scope:
        exception_var = node.name
        if exception_var:
            scope.add_to_scope(exception_var)
        self.generic_visit(node)
    return node
|
|
|
|
def create_if_call(self, func_name, yield_args, flat_args):
    """Build the statement that consumes the result of an if-region.

    The generated code refers to ``func_name`` as a plain name, not a call:

    * no ``yield_args``  -> expression statement ``func_name``
    * one yield argument -> ``x = func_name``
    * several            -> ``(x, y, ...) = func_name``

    ``flat_args`` is accepted for interface compatibility and is not used.
    """
    region_ref = ast.Name(id=func_name, ctx=ast.Load())
    if not yield_args:
        return ast.Expr(value=region_ref)

    if len(yield_args) == 1:
        destination = ast.Name(id=yield_args[0], ctx=ast.Store())
    else:
        destination = ast.Tuple(
            elts=[ast.Name(id=name, ctx=ast.Store()) for name in yield_args],
            ctx=ast.Store(),
        )
    return ast.Assign(targets=[destination], value=region_ref)
|
|
|
|
def visit_IfExp(self, node):
    """
    Rewrite an inline conditional (``x if condition else y``).

    After visiting the children, the expression is replaced by

        <original> if type(condition) == bool else select_(condition, x, y)

    so that a plain Python ``bool`` predicate keeps Python's short-circuit
    evaluation and only other predicate types go through ``select_``.
    """
    self.generic_visit(node)

    predicate_is_python_bool = ast.Compare(
        left=ast.Call(
            func=ast.Name(id="type", ctx=ast.Load()),
            args=[node.test],
            keywords=[],
        ),
        ops=[ast.Eq()],
        comparators=[ast.Name(id="bool", ctx=ast.Load())],
    )
    select_call = ast.Call(
        func=ast.Name(id="select_", ctx=ast.Load()),
        args=[node.test, node.body, node.orelse],
        keywords=[],
    )
    guarded = ast.IfExp(
        test=predicate_is_python_bool,
        body=node,  # the original ternary, evaluated natively
        orelse=select_call,
    )
    return ast.copy_location(guarded, node)
|
|
|
|
# Maps the class name of an ``ast`` comparison operator to its source spelling.
cmpops = {
    # value comparisons
    "Eq": "==",
    "NotEq": "!=",
    "Lt": "<",
    "LtE": "<=",
    "Gt": ">",
    "GtE": ">=",
    # identity
    "Is": "is",
    "IsNot": "is not",
    # membership
    "In": "in",
    "NotIn": "not in",
}

def compare_ops_to_str(self, node):
    """Return an ``ast.List`` of string constants, one per operator in ``node.ops``.

    Each operator is translated through ``self.cmpops`` using its class name.
    """
    spelled = []
    for operator in node.ops:
        symbol = self.cmpops[type(operator).__name__]
        spelled.append(ast.Constant(value=symbol))
    return ast.List(elts=spelled, ctx=ast.Load())
|
|
|
|
def visit_Compare(self, node):
    """Rewrite a comparison chain into a call to the compare executor.

    ``a OP1 b OP2 c`` becomes
    ``<COMPARE_EXECUTOR>(left=a, comparators=[b, c], ops=["OP1", "OP2"])``.
    """
    self.generic_visit(node)

    operator_names = self.compare_ops_to_str(node)
    right_hand_sides = ast.List(elts=node.comparators, ctx=ast.Load())

    executor_call = ast.Call(
        func=ast.Name(id=self.COMPARE_EXECUTOR, ctx=ast.Load()),
        args=[],
        keywords=[
            ast.keyword(arg="left", value=node.left),
            ast.keyword(arg="comparators", value=right_hand_sides),
            ast.keyword(arg="ops", value=operator_names),
        ],
    )
    return ast.copy_location(executor_call, node)
|
|
|
|
def visit_If(self, node):
    """Rewrite an ``if`` statement into a region function plus its use.

    The symbols active *before* the statement are captured first, then a new
    scope is opened. A statement that ``is_node_constexpr`` accepts is only
    traversed and returned as is. Otherwise early exits are checked, the
    region variables are analysed and the statement is replaced by two
    statements: the ``if_region_<n>`` function definition and the statement
    built by ``create_if_call``.
    """
    symbols_before_if = self.scope_manager.get_active_symbols()

    with self.scope_manager:
        if not self.is_node_constexpr(node):
            # Check for early exit and raise exception
            self.check_early_exit(node, "if")

            used_args, yield_args, flat_args = self.analyze_region_variables(
                node, symbols_before_if
            )

            region_name = f"if_region_{self.counter}"
            self.counter += 1

            region_def = self.create_if_function(
                region_name, node, used_args, yield_args, flat_args
            )
            region_use = ast.copy_location(
                self.create_if_call(region_name, yield_args, flat_args), node
            )
            return [region_def, region_use]

        # non-dynamic expr doesn't get preprocessed
        self.generic_visit(node)
        return node
|
|
|
|
def create_if_function(
|
|
self, func_name, node, used_args, yield_args, flattened_args
|
|
):
|
|
test_expr = self.visit(node.test)
|
|
pred_name = self.make_func_param_name("pred", flattened_args)
|
|
func_args = [ast.arg(arg=pred_name, annotation=None)]
|
|
func_args += [ast.arg(arg=var, annotation=None) for var in flattened_args]
|
|
func_args_then_else = [
|
|
ast.arg(arg=var, annotation=None) for var in flattened_args
|
|
]
|
|
|
|
then_body = []
|
|
for stmt in node.body:
|
|
transformed_stmt = self.visit(stmt) # Recursively visit inner statements
|
|
if isinstance(transformed_stmt, list):
|
|
then_body.extend(transformed_stmt)
|
|
else:
|
|
then_body.append(transformed_stmt)
|
|
|
|
# Create common return list for all blocks
|
|
return_list = ast.List(
|
|
elts=[ast.Name(id=var, ctx=ast.Load()) for var in yield_args],
|
|
ctx=ast.Load(),
|
|
)
|
|
|
|
# Create common function arguments
|
|
func_decorator_arguments = ast.arguments(
|
|
posonlyargs=[], args=func_args, kwonlyargs=[], kw_defaults=[], defaults=[]
|
|
)
|
|
func_then_else_arguments = ast.arguments(
|
|
posonlyargs=[],
|
|
args=func_args_then_else,
|
|
kwonlyargs=[],
|
|
kw_defaults=[],
|
|
defaults=[],
|
|
)
|
|
|
|
then_block_name = f"then_block_{self.counter}"
|
|
else_block_name = f"else_block_{self.counter}"
|
|
elif_region_name = f"elif_region_{self.counter}"
|
|
self.counter += 1
|
|
|
|
# Create then block
|
|
then_block = ast.copy_location(
|
|
ast.FunctionDef(
|
|
name=then_block_name,
|
|
args=func_then_else_arguments,
|
|
body=then_body + [ast.Return(value=return_list)],
|
|
decorator_list=[],
|
|
),
|
|
node,
|
|
)
|
|
|
|
# Decorator keywords
|
|
decorator_keywords = [
|
|
ast.keyword(
|
|
arg="pred", value=test_expr
|
|
), # ast.Name(id="pred", ctx=ast.Load())
|
|
ast.keyword(
|
|
arg="used_args",
|
|
value=ast.List(
|
|
elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in used_args],
|
|
ctx=ast.Load(),
|
|
),
|
|
),
|
|
ast.keyword(
|
|
arg="yield_args",
|
|
value=ast.List(
|
|
elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in yield_args],
|
|
ctx=ast.Load(),
|
|
),
|
|
),
|
|
]
|
|
|
|
# Create decorator
|
|
decorator = ast.copy_location(
|
|
ast.Call(
|
|
func=ast.Name(id=self.DECORATOR_IF_STATEMENT, ctx=ast.Load()),
|
|
args=[],
|
|
keywords=decorator_keywords,
|
|
),
|
|
node,
|
|
)
|
|
|
|
# Executor keywords
|
|
execute_keywords = [
|
|
ast.keyword(arg="pred", value=ast.Name(id=pred_name, ctx=ast.Load())),
|
|
ast.keyword(
|
|
arg="used_args",
|
|
value=ast.List(
|
|
elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in used_args],
|
|
ctx=ast.Load(),
|
|
),
|
|
),
|
|
ast.keyword(
|
|
arg="yield_args",
|
|
value=ast.List(
|
|
elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in yield_args],
|
|
ctx=ast.Load(),
|
|
),
|
|
),
|
|
ast.keyword(
|
|
arg="yield_arg_names",
|
|
value=ast.List(
|
|
elts=[ast.Constant(value=arg) for arg in yield_args],
|
|
ctx=ast.Load(),
|
|
),
|
|
),
|
|
ast.keyword(
|
|
arg="then_block", value=ast.Name(id=then_block_name, ctx=ast.Load())
|
|
),
|
|
]
|
|
|
|
# Handle different cases
|
|
if not yield_args and node.orelse == []:
|
|
# No yield_args case - only then_block needed
|
|
execute_call = ast.copy_location(
|
|
ast.Call(
|
|
func=ast.copy_location(
|
|
ast.Name(id=self.IF_EXECUTOR, ctx=ast.Load()), node
|
|
),
|
|
args=[],
|
|
keywords=execute_keywords,
|
|
),
|
|
node,
|
|
)
|
|
func_body = [then_block, ast.Return(value=execute_call)]
|
|
else:
|
|
# Create else block based on node.orelse
|
|
if node.orelse:
|
|
if len(node.orelse) == 1 and isinstance(node.orelse[0], ast.If):
|
|
# Handle elif case
|
|
elif_node = node.orelse[0]
|
|
nested_if_name = elif_region_name
|
|
# Recursion for nested elif
|
|
nested_if = self.create_if_function(
|
|
nested_if_name, elif_node, used_args, yield_args, flattened_args
|
|
)
|
|
else_block = ast.FunctionDef(
|
|
name=else_block_name,
|
|
args=func_then_else_arguments,
|
|
body=[
|
|
nested_if,
|
|
ast.Return(
|
|
value=ast.Name(id=nested_if_name, ctx=ast.Load())
|
|
),
|
|
],
|
|
decorator_list=[],
|
|
)
|
|
else:
|
|
|
|
else_body = []
|
|
for stmt in node.orelse:
|
|
transformed_stmt = self.visit(
|
|
stmt
|
|
) # Recursively visit inner statements
|
|
if isinstance(transformed_stmt, list):
|
|
else_body.extend(transformed_stmt)
|
|
else:
|
|
else_body.append(transformed_stmt)
|
|
|
|
# Regular else block
|
|
else_block = ast.FunctionDef(
|
|
name=else_block_name,
|
|
args=func_then_else_arguments,
|
|
body=else_body + [ast.Return(value=return_list)],
|
|
decorator_list=[],
|
|
)
|
|
else:
|
|
# Default else block
|
|
else_block = ast.FunctionDef(
|
|
name=else_block_name,
|
|
args=func_then_else_arguments,
|
|
body=[ast.Return(value=return_list)],
|
|
decorator_list=[],
|
|
)
|
|
|
|
# Add else_block to execute keywords
|
|
execute_keywords.append(
|
|
ast.keyword(
|
|
arg="else_block", value=ast.Name(id=else_block_name, ctx=ast.Load())
|
|
)
|
|
)
|
|
|
|
execute_call = ast.copy_location(
|
|
ast.Call(
|
|
func=ast.Name(id=self.IF_EXECUTOR, ctx=ast.Load()),
|
|
args=[],
|
|
keywords=execute_keywords,
|
|
),
|
|
node,
|
|
)
|
|
func_body = [
|
|
then_block,
|
|
ast.copy_location(else_block, node),
|
|
ast.Return(value=execute_call),
|
|
]
|
|
|
|
return ast.copy_location(
|
|
ast.FunctionDef(
|
|
name=func_name,
|
|
args=func_decorator_arguments,
|
|
body=func_body,
|
|
decorator_list=[decorator],
|
|
),
|
|
node,
|
|
)
|
|
|
|
def create_while_function(
|
|
self, func_name, node, used_args, yield_args, flattened_args
|
|
):
|
|
"""Create a while function that looks like:
|
|
|
|
@while_selector(pred, used_args=[], yield_args=[])
|
|
def while_region(pred, flattened_args):
|
|
def while_before_block(*used_args, *yield_args):
|
|
# Note that during eval of pred can possibly alter yield_args
|
|
return *pred, yield_args
|
|
def while_after_block(*used_args, yield_args):
|
|
...loop_body_transformed...
|
|
return yield_args
|
|
return self.while_executor(pred, used_args, yield_args,
|
|
while_before_block, while_after_block, constexpr)
|
|
yield_args = while_region(pred, flattened_args)
|
|
|
|
Which will later be executed as psuedo-code:
|
|
|
|
# Dynamic mode:
|
|
scf.WhileOp(types(yield_args), yield_args)
|
|
with InsertionPoint(before_block):
|
|
cond, yield_args = while_before_block(*flattened_args)
|
|
scf.ConditionOp(cond, yield_args)
|
|
with InsertionPoint(after_block):
|
|
yield_args = while_after_block(yield_args)
|
|
scf.YieldOp(yield_args)
|
|
return while_op.results_
|
|
|
|
# Const mode:
|
|
cond, yield_args = while_before_block(yield_args)
|
|
while pred:
|
|
yield_args = body_block(yield_args)
|
|
cond, yield_args = while_before_block(yield_args)
|
|
return yield_args
|
|
"""
|
|
test_expr = self.visit(node.test)
|
|
pred_name = self.make_func_param_name("pred", flattened_args)
|
|
|
|
# Section: decorator construction
|
|
decorator_keywords = [
|
|
ast.keyword(arg="pred", value=test_expr),
|
|
ast.keyword(
|
|
arg="used_args",
|
|
value=ast.List(
|
|
elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in used_args],
|
|
ctx=ast.Load(),
|
|
),
|
|
),
|
|
ast.keyword(
|
|
arg="yield_args",
|
|
value=ast.List(
|
|
elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in yield_args],
|
|
ctx=ast.Load(),
|
|
),
|
|
),
|
|
]
|
|
decorator = ast.copy_location(
|
|
ast.Call(
|
|
func=ast.Name(id=self.DECORATOR_WHILE_STATEMENT, ctx=ast.Load()),
|
|
args=[],
|
|
keywords=decorator_keywords,
|
|
),
|
|
node,
|
|
)
|
|
|
|
# Section: Shared initialization for before and after blocks
|
|
while_before_block_name = f"while_before_block_{self.counter}"
|
|
while_after_block_name = f"while_after_block_{self.counter}"
|
|
self.counter += 1
|
|
block_args_args = [ast.arg(arg=var, annotation=None) for var in used_args]
|
|
block_args_args += [ast.arg(arg=var, annotation=None) for var in yield_args]
|
|
block_args = ast.arguments(
|
|
posonlyargs=[],
|
|
args=block_args_args,
|
|
kwonlyargs=[],
|
|
kw_defaults=[],
|
|
defaults=[],
|
|
)
|
|
|
|
yield_args_ast_name_list = ast.List(
|
|
elts=[ast.Name(id=var, ctx=ast.Load()) for var in yield_args],
|
|
ctx=ast.Load(),
|
|
)
|
|
|
|
# Section: while_before_block FunctionDef, which contains condition
|
|
while_before_return_list = ast.List(
|
|
elts=[test_expr, yield_args_ast_name_list],
|
|
ctx=ast.Load(),
|
|
)
|
|
while_before_stmts = [ast.Return(value=while_before_return_list)]
|
|
while_before_block = ast.copy_location(
|
|
ast.FunctionDef(
|
|
name=while_before_block_name,
|
|
args=block_args,
|
|
body=while_before_stmts,
|
|
decorator_list=[],
|
|
),
|
|
test_expr,
|
|
)
|
|
|
|
# Section: while_after_block FunctionDef, which contains loop body
|
|
while_after_stmts = []
|
|
for stmt in node.body:
|
|
transformed_stmt = self.visit(stmt) # Recursively visit inner statements
|
|
if isinstance(transformed_stmt, list):
|
|
while_after_stmts.extend(transformed_stmt)
|
|
else:
|
|
while_after_stmts.append(transformed_stmt)
|
|
while_after_stmts.append(ast.Return(value=yield_args_ast_name_list))
|
|
|
|
while_after_block = ast.copy_location(
|
|
ast.FunctionDef(
|
|
name=while_after_block_name,
|
|
args=block_args,
|
|
body=while_after_stmts,
|
|
decorator_list=[],
|
|
),
|
|
node,
|
|
)
|
|
|
|
# Section: Execute via executor
|
|
execute_keywords = [
|
|
ast.keyword(arg="pred", value=ast.Name(id=pred_name, ctx=ast.Load())),
|
|
ast.keyword(
|
|
arg="used_args",
|
|
value=ast.List(
|
|
elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in used_args],
|
|
ctx=ast.Load(),
|
|
),
|
|
),
|
|
ast.keyword(
|
|
arg="yield_args",
|
|
value=ast.List(
|
|
elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in yield_args],
|
|
ctx=ast.Load(),
|
|
),
|
|
),
|
|
ast.keyword(
|
|
arg="while_before_block",
|
|
value=ast.Name(id=while_before_block_name, ctx=ast.Load()),
|
|
),
|
|
ast.keyword(
|
|
arg="while_after_block",
|
|
value=ast.Name(id=while_after_block_name, ctx=ast.Load()),
|
|
),
|
|
ast.keyword(
|
|
arg="yield_arg_names",
|
|
value=ast.List(
|
|
elts=[ast.Constant(value=arg) for arg in yield_args],
|
|
ctx=ast.Load(),
|
|
),
|
|
),
|
|
]
|
|
|
|
execute_call = ast.Call(
|
|
func=ast.Name(id=self.WHILE_EXECUTOR, ctx=ast.Load()),
|
|
args=[],
|
|
keywords=execute_keywords,
|
|
)
|
|
|
|
# Putting everything together, FunctionDef for while_region
|
|
func_args_args = [ast.arg(arg=pred_name, annotation=None)]
|
|
func_args_args += [ast.arg(arg=var, annotation=None) for var in flattened_args]
|
|
func_args = ast.arguments(
|
|
posonlyargs=[],
|
|
args=func_args_args,
|
|
kwonlyargs=[],
|
|
kw_defaults=[],
|
|
defaults=[],
|
|
)
|
|
|
|
return ast.copy_location(
|
|
ast.FunctionDef(
|
|
name=func_name,
|
|
args=func_args,
|
|
body=[
|
|
while_before_block,
|
|
while_after_block,
|
|
ast.Return(value=execute_call),
|
|
],
|
|
decorator_list=[decorator],
|
|
),
|
|
node,
|
|
)
|