Files
cutlass/python/CuTeDSL/base_dsl/ast_helpers.py
2025-07-21 22:03:55 -04:00

526 lines
15 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
"""
This module provides helper functions that are generated by the preprocessor.
The preprocessor reads through Python's AST and changes the input code.
"""
from typing import Callable, Iterator, Optional, overload
from typing_extensions import deprecated
import warnings
from .utils.logger import log
from .common import *
from ._mlir_helpers.arith import ArithValue
class Executor:
    """
    Dispatches preprocessed control flow to callbacks installed at runtime.

    The Executor class handles dynamic and compile-time (constexpr) execution
    of "for" loops and "if-else-elif" statements. The ``*_execute`` methods
    contain no lowering logic of their own: each one forwards its arguments to
    a callable registered through ``set_functions``.

    Methods:
        set_functions: Assigns the functions for checking loop bounds and
            conditional evaluation.
        for_constexpr: Runs a loop body eagerly over a Python range.
        for_execute: Generates MLIR for OP
        while_execute: Generates MLIR while OP
        if_execute: Generates MLIR if OP
    """

    def __init__(self):
        # Every callback starts unset; set_functions() installs the real ones.
        self._is_dynamic_expression = None
        self._loop_execute_range_dynamic = None
        self._if_dynamic = None
        self._while_dynamic = None
        self._compare_executor = None
        self._any_executor = None
        self._all_executor = None

    def set_functions(
        self,
        is_dynamic_expression: Callable,
        loop_execute_range_dynamic: Callable,
        if_dynamic: Callable,
        while_dynamic: Callable,
        compare_executor: Callable,
        any_executor: Optional[Callable] = None,
        all_executor: Optional[Callable] = None,
    ):
        """
        Install the callbacks used by the ``*_execute`` methods and module helpers.

        Args:
            is_dynamic_expression: Predicate telling whether a value is dynamic.
            loop_execute_range_dynamic: Callback invoked by ``for_execute``.
            if_dynamic: Callback invoked by ``if_execute``.
            while_dynamic: Callback invoked by ``while_execute``.
            compare_executor: Callback for comparison chains.
            any_executor: Optional callback for ``any`` on dynamic input.
            all_executor: Optional callback for ``all`` on dynamic input.
        """
        self._is_dynamic_expression = is_dynamic_expression
        self._loop_execute_range_dynamic = loop_execute_range_dynamic
        self._if_dynamic = if_dynamic
        self._while_dynamic = while_dynamic
        self._compare_executor = compare_executor
        self._any_executor = any_executor
        self._all_executor = all_executor

    @staticmethod
    def _or_empty(value):
        """Return ``value`` unchanged, or a new empty list when it is None."""
        return [] if value is None else value

    @staticmethod
    def convert_to_list(x):
        """This function is used to convert x to a list.
        If x is None, return an empty list.
        If x is not a list, return a list containing x.
        Otherwise, return x itself.
        """
        if x is None:
            return []
        if not isinstance(x, list):
            return [x]
        return x

    @staticmethod
    def converge_ret_val(res):
        """This function is used to converge res (the return value) of the function.
        If res is None, return None.
        If res is a list and has only one element, return the element.
        Otherwise, return res itself.
        """
        if res is None:
            return res
        elif isinstance(res, list) and len(res) == 1:
            return res[0]
        return res

    @staticmethod
    def for_constexpr(
        func: Callable,
        start: int,
        stop: int,
        step: int,
        used_args: list,
        iter_args: list,
    ):
        """
        Run ``func`` once per index of ``range(start, stop, step)`` in Python.

        Each iteration calls ``func(i, *used_args, *loop_results)`` where
        ``loop_results`` starts as ``iter_args`` and is replaced by the
        (list-normalised) return value of the previous iteration.

        Returns:
            The final loop results, collapsed by ``converge_ret_val``.
        """
        # This module defines its own `range` class whose constructor always
        # raises, so the builtin has to be referenced explicitly here.
        import builtins

        log().debug("start [%s] stop [%s] step [%s]", start, stop, step)
        loop_results = iter_args
        log().debug("iter_args [%s]", iter_args)
        for i in builtins.range(start, stop, step):
            log().debug("i [%s] iter_args [%s]", i, iter_args)
            loop_results = func(i, *used_args, *loop_results)
            log().debug("loop_results [%s]", loop_results)
            # Normalise so the next iteration can always unpack the results.
            loop_results = Executor.convert_to_list(loop_results)
        log().debug("done loop_results [%s]", loop_results)
        return Executor.converge_ret_val(loop_results)

    def for_execute(
        self,
        func,
        start,
        stop,
        step,
        used_args=None,
        iter_args=None,
        iter_arg_names=None,
        unroll=-1,
        unroll_full=False,
        pipelining=None,
    ):
        """
        Forward a loop to the installed dynamic-loop callback.

        Omitted list arguments become a new empty list on every call, so the
        callback never receives a list shared between unrelated calls.

        Raises:
            AssertionError: If no dynamic-loop callback has been installed.
        """
        assert (
            self._loop_execute_range_dynamic
        ), "Functions must be set before execution."
        log().debug("start [%s] stop [%s] step [%s]", start, stop, step)
        return self._loop_execute_range_dynamic(
            func,
            start,
            stop,
            step,
            Executor._or_empty(used_args),
            Executor._or_empty(iter_args),
            Executor._or_empty(iter_arg_names),
            unroll,
            unroll_full,
            pipelining,
        )

    def if_execute(
        self,
        pred,
        then_block: Callable,
        else_block: Optional[Callable] = None,
        used_args=None,
        yield_args=None,
        yield_arg_names=None,
    ):
        """
        Forward an if/else to the installed ``if_dynamic`` callback.

        Omitted list arguments become a new empty list on every call.

        Raises:
            AssertionError: If no ``if_dynamic`` callback has been installed.
        """
        assert self._if_dynamic, "Functions must be set before execution."
        # MLIR generation
        return self._if_dynamic(
            pred,
            then_block,
            else_block,
            Executor._or_empty(used_args),
            Executor._or_empty(yield_args),
            Executor._or_empty(yield_arg_names),
        )

    def while_execute(
        self,
        pred,
        while_before_block: Callable,
        while_after_block: Callable,
        used_args=None,
        yield_args=None,
        yield_arg_names=None,
    ):
        """
        Forward a while loop to the installed ``while_dynamic`` callback.

        ``pred`` is accepted but not forwarded to the callback. Omitted list
        arguments become a new empty list on every call.

        Raises:
            AssertionError: If no ``while_dynamic`` callback has been installed.
        """
        assert self._while_dynamic, "Functions must be set before execution."
        # MLIR generation
        return self._while_dynamic(
            while_before_block,
            while_after_block,
            Executor._or_empty(used_args),
            Executor._or_empty(yield_args),
            Executor._or_empty(yield_arg_names),
        )
# =============================================================================
# Decorator
# =============================================================================
# Module-level Executor instance that the helper functions in this module delegate to.
executor = Executor()
def loop_selector(
    start,
    stop,
    step,
    *,
    used_args=None,
    iter_args=None,
    iter_arg_names=None,
    unroll=-1,
    unroll_full=False,
    pipelining=None,
):
    """
    Return a wrapper that runs a loop body through ``executor.for_execute``.

    ``start``, ``stop`` and ``step`` are converted with ``ir_value()`` when they
    are ``Integer`` instances and passed through unchanged otherwise.

    Args:
        start, stop, step: Loop bounds.
        used_args: Extra values forwarded to the loop callback (default: new empty list).
        iter_args: Loop-carried values (default: new empty list).
        iter_arg_names: Names matching ``iter_args`` (default: new empty list).
        unroll, unroll_full, pipelining: Forwarded unchanged to the executor.

    Returns:
        A callable taking the loop-body function and returning whatever
        ``executor.for_execute`` returns for it.
    """
    # Build fresh lists per call: these objects are captured by the closure and
    # handed to the executor callback, so a shared default would leak any
    # in-place mutation into later, unrelated loops.
    used_args = [] if used_args is None else used_args
    iter_args = [] if iter_args is None else iter_args
    iter_arg_names = [] if iter_arg_names is None else iter_arg_names

    log().debug(
        "start [%s] stop [%s] step [%s] used_args [%s] iter_args [%s] unroll [%s] unroll_full [%s] pipelining [%s]",
        start,
        stop,
        step,
        used_args,
        iter_args,
        unroll,
        unroll_full,
        pipelining,
    )
    from .typing import Integer

    def _maybe_upcast(value):
        # DSL Integer wrappers are replaced by their IR value.
        if isinstance(value, Integer):
            value = value.ir_value()
        return value

    start = _maybe_upcast(start)
    stop = _maybe_upcast(stop)
    step = _maybe_upcast(step)

    def ir_loop(func):
        return executor.for_execute(
            func,
            start,
            stop,
            step,
            used_args,
            iter_args,
            iter_arg_names,
            unroll,
            unroll_full,
            pipelining,
        )

    return ir_loop
def if_selector(pred, used_args=[], yield_args=[]):
    """
    Return a wrapper that calls a function with the predicate and its arguments.

    A ``Numeric`` predicate is replaced by its ``.value`` first. The returned
    callable invokes ``func(pred, *used_args, *yield_args)`` and returns the result.
    """
    log().debug("pred [%s] used_args [%s] yield_args [%s]", pred, used_args, yield_args)
    # Open question kept from the original code: handle Numeric types here?
    from .typing import Numeric

    predicate = pred.value if isinstance(pred, Numeric) else pred

    def invoke(func):
        return func(predicate, *used_args, *yield_args)

    return invoke
def while_selector(pred, used_args=[], yield_args=[]):
    """
    Return a wrapper that calls a function as ``func(pred, *used_args, *yield_args)``.
    """

    def invoke(func):
        # Arguments are flattened at call time, predicate first.
        call_args = (pred, *used_args, *yield_args)
        return func(*call_args)

    return invoke
def while_executor(
    pred,
    while_before_block: Callable,
    while_after_block: Callable,
    used_args=None,
    yield_args=None,
    yield_arg_names=None,
):
    """
    Run a while loop through the module-level executor.

    All arguments are forwarded positionally to ``executor.while_execute``.
    Omitted list arguments are replaced by a new empty list on every call so
    the executor callback never receives a list shared between calls.

    Returns:
        Whatever ``executor.while_execute`` returns.
    """
    return executor.while_execute(
        pred,
        while_before_block,
        while_after_block,
        [] if used_args is None else used_args,
        [] if yield_args is None else yield_args,
        [] if yield_arg_names is None else yield_arg_names,
    )
def if_executor(
    pred,
    then_block: Callable,
    else_block: Optional[Callable] = None,
    used_args=None,
    yield_args=None,
    yield_arg_names=None,
):
    """
    Run an if/else through the module-level executor.

    All arguments are forwarded positionally to ``executor.if_execute``.
    Omitted list arguments are replaced by a new empty list on every call so
    the executor callback never receives a list shared between calls.

    Returns:
        Whatever ``executor.if_execute`` returns.
    """
    return executor.if_execute(
        pred,
        then_block,
        else_block,
        [] if used_args is None else used_args,
        [] if yield_args is None else yield_args,
        [] if yield_arg_names is None else yield_arg_names,
    )
# =============================================================================
# Range
# =============================================================================
class range:
    """
    Placeholder for the DSL's dynamic ``range``.

    It mirrors the calling conventions of Python's built-in range so that user
    code reads naturally, but it is meant to be rewritten by the preprocessor
    into constructs for dynamic loop execution. Constructing or iterating it
    directly raises ``DSLRuntimeError``.

    Accepted call shapes are a single ``stop`` or the triple
    ``(start, stop, step)``, each with optional loop-optimization parameters:

    - unroll: Number of iterations to unroll (0 or 1 = no unrolling)
    - unroll_full: Whether to fully unroll the loop
    - pipelining: Compiler generated pipeline configuration
    """

    @overload
    def __new__(cls, stop, unroll=0, unroll_full=False, pipelining=None): ...

    @overload
    def __new__(
        cls, start, stop, step, unroll=0, unroll_full=False, pipelining=None
    ): ...

    def __new__(cls, *args, **kwargs):
        # Reaching this point means the call was not rewritten beforehand.
        raise DSLRuntimeError("dynamic range should be always preprocessed to IR")

    def __iter__(self) -> Iterator[int]:
        raise DSLRuntimeError("dynamic range should be always preprocessed to IR")
@deprecated(
    "range_dynamic is deprecated and will be removed in the future, please remove it."
)
def range_dynamic(*args, **kwargs):
    """Deprecated placeholder; calling it directly always raises ``DSLRuntimeError``."""
    message = "range_dynamic should be always preprocessed to IR"
    raise DSLRuntimeError(message)
def range_constexpr(*args):
    """Placeholder for compile-time ranges; calling it directly always raises ``DSLRuntimeError``."""
    message = "range_constexpr should be preprocessed by preprocessor."
    raise DSLRuntimeError(message)
# =============================================================================
# If expressions
# =============================================================================
def const_expr(expression):
    """
    Return ``expression`` as a compile-time (Python) value, rejecting dynamic ones.

    - A ``Numeric`` whose ``.value`` is an ``int``, ``float`` or ``bool`` yields that value.
    - Any other ``Numeric``, or anything the executor reports as a dynamic
      expression, raises ``DSLRuntimeError``.
    - Everything else is returned unchanged.
    """
    from .typing import Numeric

    if isinstance(expression, Numeric):
        wrapped = expression.value
        if isinstance(wrapped, (int, float, bool)):
            return wrapped
        # A Numeric that does not wrap a plain Python scalar counts as dynamic.
        rejected = True
    else:
        rejected = bool(executor._is_dynamic_expression(expression))

    if not rejected:
        return expression

    raise DSLRuntimeError(
        f"The function `const_expr({expression})` received a dynamic expression (non compile-time constant).",
        context={
            "If your expression depends on dynamic values": "Remove `const_expr()`",
        },
    )
@deprecated(
    "dynamic_expr is deprecated and will be removed in the future, please remove it."
)
def dynamic_expr(expression):
    """Deprecated identity helper: hands ``expression`` back untouched."""
    result = expression
    return result
# =============================================================================
# Assertion & casting
# =============================================================================
def assert_executor(test, msg=None):
    """
    Evaluate an ``assert`` whose condition must be a compile-time value.

    A dynamic ``Numeric`` condition is first converted with ``test.to(bool)``;
    the converted value is then asserted.

    Args:
        test: The asserted condition.
        msg: Optional assertion message.

    Raises:
        AssertionError: If the (possibly converted) condition is falsy.
        DSLRuntimeError: If the condition is a dynamic expression that is not a
            ``Numeric``, or whose conversion to ``bool`` fails.
    """
    from .typing import Numeric

    # Implicit conversion of a dynamic expression to bool is not allowed,
    # so None is ruled out explicitly before asking the executor.
    if test is not None and executor._is_dynamic_expression(test):
        convertible = isinstance(test, Numeric)
        if convertible:
            try:
                test = test.to(bool)
            except Exception:
                # Catch only ordinary errors: KeyboardInterrupt and SystemExit
                # must propagate rather than be reported as a DSL error.
                convertible = False
        if not convertible:
            raise DSLRuntimeError(
                "Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.",
                suggestion="Please replace with runtime assert.",
            )
    assert test, msg
def bool_cast(value):
    """
    Convert a compile-time value to ``bool``.

    Raises:
        DSLRuntimeError: If the executor reports ``value`` as a dynamic expression.
    """
    if not executor._is_dynamic_expression(value):
        return bool(value)
    raise DSLRuntimeError(
        "Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.",
        suggestion="Please explicitly convert to boolean with expressions like comparision.",
    )
def compare_executor(left, comparators, ops):
    """
    Evaluate a comparison chain through the callback installed on the executor.

    Args:
        left: The leftmost value in the comparison chain
        comparators: A list of values to compare against
        ops: A list of comparison operators to apply

    Returns:
        The result of the comparison chain

    Raises:
        AssertionError: If the executor function is not set before execution
    """
    handler = executor._compare_executor
    assert handler is not None, "Function must be set before execution."
    return handler(left, comparators, ops)
def any_executor(iterable):
    """Evaluate ``any`` over an iterable, for both dynamic and static expressions.

    The executor's ``any`` callback is used only when one is installed and the
    iterable is reported as dynamic; otherwise Python's built-in ``any`` runs.

    :param iterable: An iterable to check if any elements evaluate to True
    :type iterable: Iterable
    :return: boolean of Python value or IR value
    :rtype: bool or cutlass.Boolean
    """
    handler = executor._any_executor
    if not handler or not executor._is_dynamic_expression(iterable):
        return any(iterable)
    return handler(iterable)
def all_executor(iterable):
    """Evaluate ``all`` over an iterable, for both dynamic and static expressions.

    The executor's ``all`` callback is used only when one is installed and the
    iterable is reported as dynamic; otherwise Python's built-in ``all`` runs.

    :param iterable: An iterable to check if all elements evaluate to True
    :type iterable: Iterable
    :return: boolean of Python value or IR value
    :rtype: bool or cutlass.Boolean
    """
    handler = executor._all_executor
    if not handler or not executor._is_dynamic_expression(iterable):
        return all(iterable)
    return handler(iterable)
# =============================================================================
# Control flow checks
# =============================================================================
def range_value_check(*args):
    """
    Ensure all `range_constexpr` bounds are compile-time constants (Python ints).

    Each argument is converted with ``__index__()``. One, two or three
    arguments are read as ``(stop)``, ``(start, stop)`` or
    ``(start, stop, step)``; any other argument count leaves the defaults
    ``(0, 0, 1)`` in place.

    A ``UserWarning`` is issued when the loop would run 64 or more iterations.

    Returns:
        The normalised ``(start, end, step)`` tuple.

    Raises:
        DSLRuntimeError: If an argument cannot be converted to an index, or if
            the step is zero.
    """
    # Keep the try body to the conversion only, so that unrelated failures
    # (for example a warning promoted to an error) are not misreported as a
    # non-constexpr argument.
    try:
        args = tuple(arg.__index__() for arg in args)
    except Exception as exc:
        raise DSLRuntimeError(
            "`range_constexpr` requires constexpr (compile-time constant) for all arguments.",
            suggestion="Use `range` instead of `range_constexpr`.",
        ) from exc

    start, end, step = 0, 0, 1
    if len(args) == 1:
        (end,) = args
    elif len(args) == 2:
        start, end = args
    elif len(args) == 3:
        start, end, step = args

    if step == 0:
        raise DSLRuntimeError("`range_constexpr` step must not be zero.")

    # Iteration count with the same direction rules as Python's range: a step
    # pointing away from `end` yields an empty loop.
    if step > 0 and start < end:
        range_length = (end - start - 1) // step + 1
    elif step < 0 and start > end:
        range_length = (start - end - 1) // -step + 1
    else:
        range_length = 0

    # Compute range size and warn if it's too large
    if range_length >= 64:
        warnings.warn(
            f"This static loop has {range_length} iterations, which may be very slow to compile, consider using `cutlass.range(..., unroll_full=True)` instead.",
            category=UserWarning,
            stacklevel=2,
        )
    return (start, end, step)
def range_perf_warning(filename, lineno, *args):
    """
    Warn about a loop whose bounds contain no dynamic expression.

    When none of ``args`` is reported as dynamic by the executor, a
    ``UserWarning`` is issued with ``warnings.warn_explicit`` and attributed to
    ``filename`` / ``lineno``. Otherwise nothing happens.
    """
    if any(executor._is_dynamic_expression(arg) for arg in args):
        return
    warnings.warn_explicit(
        (
            "This loop is no longer unrolled and may cause performance regression. "
            "Use `range(..., unroll_full=True)` for full unrolling, or switch to `range_constexpr` when bounds are compile-time constants."
        ),
        category=UserWarning,
        filename=filename,
        lineno=lineno,
    )