526 lines
15 KiB
Python
526 lines
15 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
|
#
|
|
# Use of this software is governed by the terms and conditions of the
|
|
# NVIDIA End User License Agreement (EULA), available at:
|
|
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
|
#
|
|
# Any use, reproduction, disclosure, or distribution of this software
|
|
# and related documentation outside the scope permitted by the EULA
|
|
# is strictly prohibited.
|
|
|
|
"""
|
|
This module provides helper functions that are generated by the preprocessor.
|
|
The preprocessor reads through Python's AST and changes the input code.
|
|
"""
|
|
|
|
from typing import Callable, Iterator, Optional, overload
|
|
from typing_extensions import deprecated
|
|
import warnings
|
|
|
|
from .utils.logger import log
|
|
from .common import *
|
|
|
|
from ._mlir_helpers.arith import ArithValue
|
|
|
|
class Executor:
    """
    The Executor class handles dynamic and compile-time (constexpr) execution
    of "for" loops and "if-else-elif" statements.

    The IR-generating behaviour is not implemented here: it is injected via
    :meth:`set_functions`. Each ``*_execute`` method asserts that its callback
    has been registered and then delegates to it.

    Methods:
        set_functions: Registers the callbacks used to detect dynamic
            expressions and to generate loops, conditionals and comparisons.
        for_constexpr: Runs a loop body eagerly in Python for compile-time
            (constexpr) bounds.
        for_execute: Generates an MLIR for op via the registered callback.
        while_execute: Generates an MLIR while op via the registered callback.
        if_execute: Generates an MLIR if op via the registered callback.
    """

    def __init__(self):
        # All callbacks stay unset (None) until set_functions() is called.
        self._is_dynamic_expression = None
        self._loop_execute_range_dynamic = None
        self._if_dynamic = None
        self._while_dynamic = None
        self._compare_executor = None
        self._any_executor = None
        self._all_executor = None

    def set_functions(
        self,
        is_dynamic_expression: Callable,
        loop_execute_range_dynamic: Callable,
        if_dynamic: Callable,
        while_dynamic: Callable,
        compare_executor: Callable,
        any_executor: Optional[Callable] = None,
        all_executor: Optional[Callable] = None,
    ):
        """Register the callbacks used by this executor.

        :param is_dynamic_expression: Predicate telling whether a value is a
            dynamic (non compile-time) expression.
        :param loop_execute_range_dynamic: Callback used by :meth:`for_execute`.
        :param if_dynamic: Callback used by :meth:`if_execute`.
        :param while_dynamic: Callback used by :meth:`while_execute`.
        :param compare_executor: Callback for chained comparisons.
        :param any_executor: Optional dynamic ``any`` implementation.
        :param all_executor: Optional dynamic ``all`` implementation.
        """
        self._is_dynamic_expression = is_dynamic_expression
        self._loop_execute_range_dynamic = loop_execute_range_dynamic
        self._if_dynamic = if_dynamic
        self._while_dynamic = while_dynamic
        self._compare_executor = compare_executor
        self._any_executor = any_executor
        self._all_executor = all_executor

    @staticmethod
    def _list_or_empty(x):
        """Return ``x`` unchanged, or a new empty list when ``x`` is None.

        Replaces mutable ``[]`` defaults so each call hands the registered
        callbacks its own list.
        """
        return [] if x is None else x

    @staticmethod
    def convert_to_list(x):
        """This function is used to convert x to a list.
        If x is None, return an empty list.
        If x is not a list, return a list containing x.
        Otherwise, return x itself.
        """
        if x is None:
            return []
        if not isinstance(x, list):
            return [x]
        return x

    @staticmethod
    def converge_ret_val(res):
        """This function is used to converge res (the return value) of the function.
        If res is None, return None.
        If res is a list and has only one element, return the element.
        Otherwise, return res itself.
        """
        if res is None:
            return res
        elif isinstance(res, list) and len(res) == 1:
            return res[0]
        return res

    @staticmethod
    def for_constexpr(
        func: Callable,
        start: int,
        stop: int,
        step: int,
        used_args: list,
        iter_args: list,
    ):
        """Run ``func`` eagerly for each index of a compile-time range.

        Each call receives ``(i, *used_args, *carried)``. ``carried`` starts
        as ``iter_args`` and is replaced by the call's return value, normalised
        with :meth:`convert_to_list`.

        :return: The final carried values, passed through :meth:`converge_ret_val`.
        """
        # This module rebinds the name ``range`` to the DSL placeholder class,
        # whose constructor always raises, so the builtin must be named explicitly.
        import builtins

        log().debug("start [%s] stop [%s] step [%s]", start, stop, step)
        loop_results = iter_args
        log().debug("iter_args [%s]", iter_args)
        for i in builtins.range(start, stop, step):
            # Log the values actually carried into this iteration.
            log().debug("i [%s] iter_args [%s]", i, loop_results)
            loop_results = func(i, *used_args, *loop_results)
            log().debug("loop_results [%s]", loop_results)
            loop_results = Executor.convert_to_list(loop_results)

        log().debug("done loop_results [%s]", loop_results)
        return Executor.converge_ret_val(loop_results)

    def for_execute(
        self,
        func,
        start,
        stop,
        step,
        used_args: Optional[list] = None,
        iter_args: Optional[list] = None,
        iter_arg_names: Optional[list] = None,
        unroll=-1,
        unroll_full=False,
        pipelining=None,
    ):
        """Generate a dynamic for loop via the registered loop callback.

        ``None`` list arguments are replaced by fresh empty lists.
        """
        assert (
            self._loop_execute_range_dynamic
        ), "Functions must be set before execution."
        log().debug("start [%s] stop [%s] step [%s]", start, stop, step)

        return self._loop_execute_range_dynamic(
            func,
            start,
            stop,
            step,
            self._list_or_empty(used_args),
            self._list_or_empty(iter_args),
            self._list_or_empty(iter_arg_names),
            unroll,
            unroll_full,
            pipelining,
        )

    def if_execute(
        self,
        pred,
        then_block: Callable,
        else_block: Optional[Callable] = None,
        used_args: Optional[list] = None,
        yield_args: Optional[list] = None,
        yield_arg_names: Optional[list] = None,
    ):
        """Generate a dynamic if via the registered if callback.

        ``None`` list arguments are replaced by fresh empty lists.
        """
        assert self._if_dynamic, "Functions must be set before execution."

        # MLIR generation
        return self._if_dynamic(
            pred,
            then_block,
            else_block,
            self._list_or_empty(used_args),
            self._list_or_empty(yield_args),
            self._list_or_empty(yield_arg_names),
        )

    def while_execute(
        self,
        pred,
        while_before_block: Callable,
        while_after_block: Callable,
        used_args: Optional[list] = None,
        yield_args: Optional[list] = None,
        yield_arg_names: Optional[list] = None,
    ):
        """Generate a dynamic while loop via the registered while callback.

        ``pred`` is accepted for interface symmetry but is not forwarded to the
        callback. ``None`` list arguments are replaced by fresh empty lists.
        """
        assert self._while_dynamic, "Functions must be set before execution."

        # MLIR generation
        return self._while_dynamic(
            while_before_block,
            while_after_block,
            self._list_or_empty(used_args),
            self._list_or_empty(yield_args),
            self._list_or_empty(yield_arg_names),
        )
|
|
|
|
|
|
# =============================================================================
|
|
# Decorator
|
|
# =============================================================================
|
|
|
|
# Module-level Executor instance used by the selector/executor helpers in this
# module; its callbacks must be registered via set_functions() before use.
executor: Executor = Executor()
|
|
|
|
|
|
def loop_selector(
    start,
    stop,
    step,
    *,
    used_args=None,
    iter_args=None,
    iter_arg_names=None,
    unroll=-1,
    unroll_full=False,
    pipelining=None,
):
    """Build a decorator that emits a dynamic ``for`` loop around a body function.

    Bounds that are DSL ``Integer`` wrappers are lowered with ``ir_value()``;
    other bounds pass through unchanged. The returned ``ir_loop(func)``
    forwards the body and all loop options to ``executor.for_execute`` and
    returns its result.

    :param start: Loop start bound.
    :param stop: Loop stop bound.
    :param step: Loop step.
    :param used_args: Forwarded to ``executor.for_execute``; defaults to a
        fresh empty list.
    :param iter_args: Forwarded to ``executor.for_execute``; defaults to a
        fresh empty list.
    :param iter_arg_names: Forwarded to ``executor.for_execute``; defaults to
        a fresh empty list.
    :param unroll: Loop option, forwarded unchanged.
    :param unroll_full: Loop option, forwarded unchanged.
    :param pipelining: Loop option, forwarded unchanged.
    :return: The decorator ``ir_loop``.
    """
    # Fresh lists per call: these are handed on to the registered loop
    # callback, so a shared ``[]`` default could leak state between loops.
    used_args = [] if used_args is None else used_args
    iter_args = [] if iter_args is None else iter_args
    iter_arg_names = [] if iter_arg_names is None else iter_arg_names

    log().debug(
        "start [%s] stop [%s] step [%s] used_args [%s] iter_args [%s] unroll [%s] unroll_full [%s] pipelining [%s]",
        start,
        stop,
        step,
        used_args,
        iter_args,
        unroll,
        unroll_full,
        pipelining,
    )
    from .typing import Integer

    def _maybe_upcast(value):
        # Integer wrappers are lowered to their IR value; anything else is kept.
        if isinstance(value, Integer):
            value = value.ir_value()

        return value

    start = _maybe_upcast(start)
    stop = _maybe_upcast(stop)
    step = _maybe_upcast(step)

    def ir_loop(func):
        return executor.for_execute(
            func,
            start,
            stop,
            step,
            used_args,
            iter_args,
            iter_arg_names,
            unroll,
            unroll_full,
            pipelining,
        )

    return ir_loop
|
|
|
|
|
|
def if_selector(pred, used_args=[], yield_args=[]):
    """Build a decorator that invokes an ``if`` body with the (unwrapped) predicate.

    A DSL ``Numeric`` predicate is replaced by its ``.value``; any other
    predicate is used as-is. The returned decorator calls
    ``func(condition, *used_args, *yield_args)`` and returns its result.
    The default lists are only unpacked, never mutated.
    """
    log().debug("pred [%s] used_args [%s] yield_args [%s]", pred, used_args, yield_args)

    from .typing import Numeric

    condition = pred.value if isinstance(pred, Numeric) else pred

    def ir_if(func):
        return func(condition, *used_args, *yield_args)

    return ir_if
|
|
|
|
|
|
def while_selector(pred, used_args=[], yield_args=[]):
    """Build a decorator that invokes a ``while`` body with its arguments.

    The returned ``ir_while_loop(func)`` calls
    ``func(pred, *used_args, *yield_args)`` and returns its result.
    The default lists are only unpacked, never mutated.
    """

    def ir_while_loop(func):
        call_args = (pred, *used_args, *yield_args)
        return func(*call_args)

    return ir_while_loop
|
|
|
|
|
|
def while_executor(
    pred,
    while_before_block: Callable,
    while_after_block: Callable,
    used_args: Optional[list] = None,
    yield_args: Optional[list] = None,
    yield_arg_names: Optional[list] = None,
):
    """Emit a ``while`` loop through the module-level ``executor``.

    Thin wrapper over ``executor.while_execute``. Omitted (``None``) list
    arguments become a fresh empty list per call, rather than a shared ``[]``
    default, because the lists are handed on to the registered while callback.
    """
    return executor.while_execute(
        pred,
        while_before_block,
        while_after_block,
        [] if used_args is None else used_args,
        [] if yield_args is None else yield_args,
        [] if yield_arg_names is None else yield_arg_names,
    )
|
|
|
|
|
|
def if_executor(
    pred,
    then_block: Callable,
    else_block: Optional[Callable] = None,
    used_args: Optional[list] = None,
    yield_args: Optional[list] = None,
    yield_arg_names: Optional[list] = None,
):
    """Emit an ``if`` through the module-level ``executor``.

    Thin wrapper over ``executor.if_execute``. Omitted (``None``) list
    arguments become a fresh empty list per call, rather than a shared ``[]``
    default, because the lists are handed on to the registered if callback.
    """
    return executor.if_execute(
        pred,
        then_block,
        else_block,
        [] if used_args is None else used_args,
        [] if yield_args is None else yield_args,
        [] if yield_arg_names is None else yield_arg_names,
    )
|
|
|
|
|
|
# =============================================================================
|
|
# Range
|
|
# =============================================================================
|
|
|
|
|
|
class range:
    """
    Placeholder for a dynamically iterated loop range in the DSL.

    It mirrors the call signature of Python's built-in ``range`` but is never
    meant to be constructed at runtime. The preprocessor rewrites uses of it
    into dynamic-loop constructs, and constructing or iterating it directly
    raises ``DSLRuntimeError``.

    Accepted call forms:
      - ``range(stop, ...)``
      - ``range(start, stop, step, ...)``

    Loop-optimization keywords:
      - unroll: Number of iterations to unroll (0 or 1 = no unrolling)
      - unroll_full: Whether to fully unroll the loop
      - pipelining: Compiler generated pipeline configuration
    """

    @overload
    def __new__(cls, stop, unroll=0, unroll_full=False, pipelining=None): ...

    @overload
    def __new__(cls, start, stop, step, unroll=0, unroll_full=False, pipelining=None): ...

    def __new__(cls, *args, **kwargs):
        # Getting here means the loop was not rewritten by the preprocessor.
        raise DSLRuntimeError("dynamic range should be always preprocessed to IR")

    def __iter__(self) -> Iterator[int]:
        # Not reachable through normal construction, since __new__ always raises;
        # presumably kept so type checkers treat ``range`` as iterable.
        raise DSLRuntimeError("dynamic range should be always preprocessed to IR")
|
|
|
|
|
|
@deprecated(
    "range_dynamic is deprecated and will be removed in the future, please remove it."
)
def range_dynamic(*args, **kwargs):
    """Deprecated placeholder; always raises because it must be rewritten to IR by the preprocessor."""
    message = "range_dynamic should be always preprocessed to IR"
    raise DSLRuntimeError(message)
|
|
|
|
|
|
def range_constexpr(*args):
    """Placeholder for a compile-time loop range; always raises if called directly.

    Uses are expected to be rewritten by the preprocessor before execution.
    """
    message = "range_constexpr should be preprocessed by preprocessor."
    raise DSLRuntimeError(message)
|
|
|
|
# =============================================================================
|
|
# If expressions
|
|
# =============================================================================
|
|
|
|
|
|
def const_expr(expression):
    """
    Require ``expression`` to be a compile-time (Python) value and return it.

    - A ``Numeric`` wrapper whose ``.value`` is a Python ``int``/``float``/
      ``bool`` is unwrapped, and that Python value is returned.
    - A ``Numeric`` wrapper holding anything else, or any value that
      ``executor._is_dynamic_expression`` reports as dynamic, raises
      ``DSLRuntimeError``.
    - Any other value is returned unchanged; it is not converted to ``bool``.
    """
    from .typing import Numeric

    if isinstance(expression, Numeric):
        inner = expression.value
        if isinstance(inner, (int, float, bool)):
            return inner
        # A Numeric that does not hold a plain Python scalar is dynamic.
        is_dynamic = True
    else:
        is_dynamic = executor._is_dynamic_expression(expression)

    if not is_dynamic:
        return expression

    raise DSLRuntimeError(
        f"The function `const_expr({expression})` received a dynamic expression (non compile-time constant).",
        context={
            "If your expression depends on dynamic values": "Remove `const_expr()`",
        },
    )
|
|
|
|
|
|
@deprecated(
    "dynamic_expr is deprecated and will be removed in the future, please remove it."
)
def dynamic_expr(expression):
    """Deprecated no-op marker: return ``expression`` unchanged.

    Kept only for backward compatibility; it is marked ``@deprecated`` and
    callers should drop the wrapper.
    """
    return expression
|
|
|
|
|
|
# =============================================================================
|
|
# Assertion & casting
|
|
# =============================================================================
|
|
|
|
|
|
def assert_executor(test, msg=None):
    """Execute a compile-time ``assert test, msg``.

    If ``test`` is a dynamic expression, a ``Numeric`` is first converted with
    ``test.to(bool)``. Any other dynamic value, or a failed conversion, raises
    ``DSLRuntimeError`` suggesting a runtime assert. Otherwise Python's
    ``assert test, msg`` is evaluated, so it is skipped under ``python -O``.

    :raises DSLRuntimeError: if ``test`` is a non-constexpr expression.
    :raises AssertionError: if the (constexpr) ``test`` is falsy.
    """
    from .typing import Numeric

    fail = False
    # Implicitly converting a dynamic expression to bool is not allowed,
    # so check for None explicitly instead of relying on truthiness.
    if test is not None and executor._is_dynamic_expression(test):
        if isinstance(test, Numeric):
            try:
                test = test.to(bool)
            except Exception:
                # Conversion failed: report it as a non-constexpr expression
                # below. Deliberately not a bare ``except`` so
                # KeyboardInterrupt/SystemExit still propagate.
                fail = True
        else:
            fail = True

    if fail:
        raise DSLRuntimeError(
            "Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.",
            suggestion="Please replace with runtime assert.",
        )
    assert test, msg
|
|
|
|
|
|
def bool_cast(value):
    """Convert a compile-time value to ``bool``, rejecting dynamic (IR) expressions.

    :raises DSLRuntimeError: if ``executor._is_dynamic_expression(value)`` is truthy.
    """
    if not executor._is_dynamic_expression(value):
        return bool(value)
    raise DSLRuntimeError(
        "Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.",
        suggestion="Please explicitly convert to boolean with expressions like comparision.",
    )
|
|
|
|
def compare_executor(left, comparators, ops):
    """
    Evaluate a comparison chain through the registered compare callback.

    Args:
        left: The leftmost value in the comparison chain
        comparators: A list of values to compare against
        ops: A list of comparison operators to apply

    Returns:
        Whatever the registered ``executor._compare_executor`` returns for
        ``(left, comparators, ops)``.

    Raises:
        AssertionError: If the callback has not been registered yet
    """
    dynamic_compare = executor._compare_executor
    assert dynamic_compare is not None, "Function must be set before execution."
    return dynamic_compare(left, comparators, ops)
|
|
|
|
|
|
def any_executor(iterable):
    """Evaluate ``any`` over ``iterable``, using the dynamic implementation when needed.

    The registered ``executor._any_executor`` is used only when one is set
    and ``iterable`` is reported as a dynamic expression. Otherwise the
    builtin ``any`` is used.

    :param iterable: An iterable to check if any elements evaluate to True
    :type iterable: Iterable
    :return: boolean of Python value or IR value
    :rtype: bool or cutlass.Boolean
    """
    dynamic_any = executor._any_executor
    if dynamic_any and executor._is_dynamic_expression(iterable):
        return dynamic_any(iterable)
    return any(iterable)
|
|
|
|
|
|
def all_executor(iterable):
    """Evaluate ``all`` over ``iterable``, using the dynamic implementation when needed.

    The registered ``executor._all_executor`` is used only when one is set
    and ``iterable`` is reported as a dynamic expression. Otherwise the
    builtin ``all`` is used.

    :param iterable: An iterable to check if all elements evaluate to True
    :type iterable: Iterable
    :return: boolean of Python value or IR value
    :rtype: bool or cutlass.Boolean
    """
    dynamic_all = executor._all_executor
    if dynamic_all and executor._is_dynamic_expression(iterable):
        return dynamic_all(iterable)
    return all(iterable)
|
|
|
|
|
|
# =============================================================================
|
|
# Control flow checks
|
|
# =============================================================================
|
|
def range_value_check(*args):
    """
    Ensure all `range_constexpr` bounds are compile-time constants (Python ints).

    Accepts the positional forms of Python's ``range``: ``(stop)``,
    ``(start, stop)`` or ``(start, stop, step)``. Each argument is converted
    with ``__index__()``. A ``UserWarning`` is emitted when the loop would run
    64 or more iterations, since such static loops may be very slow to compile.

    :return: ``(start, stop, step)`` as Python ints.
    :raises DSLRuntimeError: if an argument is not an index-like compile-time
        constant, the argument count is not 1 to 3, or ``step`` is zero.
    """
    # Keep the try narrow: only the conversion decides "is this constexpr?".
    # Step-zero errors and warnings escalated to errors (``-W error``) must not
    # be reported as non-constexpr arguments.
    try:
        values = tuple(arg.__index__() for arg in args)
    except Exception as err:
        raise DSLRuntimeError(
            "`range_constexpr` requires constexpr (compile-time constant) for all arguments.",
            suggestion="Use `range` instead of `range_constexpr`.",
        ) from err

    if len(values) == 1:
        start, stop, step = 0, values[0], 1
    elif len(values) == 2:
        start, stop, step = values[0], values[1], 1
    elif len(values) == 3:
        start, stop, step = values
    else:
        raise DSLRuntimeError(
            f"`range_constexpr` expected 1 to 3 arguments, got {len(values)}."
        )

    if step == 0:
        raise DSLRuntimeError("`range_constexpr` step must not be zero.")

    # Same count as len(builtins.range(start, stop, step)), including 0 for
    # empty or wrong-direction ranges. Computed directly because this module
    # rebinds the name ``range`` to the DSL placeholder class.
    if step > 0:
        range_length = max(0, (stop - start + step - 1) // step)
    else:
        range_length = max(0, (start - stop - step - 1) // -step)

    if range_length >= 64:
        warnings.warn(
            f"This static loop has {range_length} iterations, which may be very slow to compile, consider using `cutlass.range(..., unroll_full=True)` instead.",
            category=UserWarning,
            stacklevel=2,
        )

    return (start, stop, step)
|
|
|
|
|
|
def range_perf_warning(filename, lineno, *args):
    """Warn at ``filename:lineno`` when a ``range`` loop has only static bounds.

    No warning is issued if any of ``args`` is reported as a dynamic
    expression by ``executor._is_dynamic_expression``.
    """
    if any(executor._is_dynamic_expression(arg) for arg in args):
        return
    warnings.warn_explicit(
        "This loop is no longer unrolled and may cause performance regression. "
        "Use `range(..., unroll_full=True)` for full unrolling, or switch to `range_constexpr` when bounds are compile-time constants.",
        category=UserWarning,
        filename=filename,
        lineno=lineno,
    )
|