582 lines
16 KiB
Python
582 lines
16 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
|
#
|
|
# Use of this software is governed by the terms and conditions of the
|
|
# NVIDIA End User License Agreement (EULA), available at:
|
|
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
|
#
|
|
# Any use, reproduction, disclosure, or distribution of this software
|
|
# and related documentation outside the scope permitted by the EULA
|
|
# is strictly prohibited.
|
|
|
|
"""
|
|
This module provides helper functions that are generated by the preprocessor.
|
|
The preprocessor reads through Python's AST and changes the input code.
|
|
"""
|
|
|
|
from typing import Callable, Iterator, Optional, overload
|
|
from typing_extensions import deprecated
|
|
import warnings
|
|
import inspect
|
|
from types import BuiltinFunctionType
|
|
from functools import lru_cache
|
|
|
|
from .utils.logger import log
|
|
from .common import *
|
|
|
|
from ._mlir_helpers.arith import ArithValue
|
|
|
|
|
|
class Executor:
    """
    Dispatcher for "for" loops, "while" loops and "if-else-elif" statements.

    The class itself holds no code-generation logic: it stores callbacks
    registered through `set_functions` and forwards each request to them.

    Methods:
        set_functions: Registers the callbacks used by the other methods.
        for_execute: Forwards a "for" loop to the registered loop callback.
        while_execute: Forwards a "while" loop to the registered while callback.
        if_execute: Forwards an "if" statement to the registered if callback.
    """

    def __init__(self):
        # Every callback starts unset; `set_functions` installs them.
        self._is_dynamic_expression = None
        self._compare_executor = None
        self._any_executor = None
        self._all_executor = None
        self._builtin_redirector = None
        self._loop_execute_range_dynamic = None
        self._if_dynamic = None
        self._while_dynamic = None

    def set_functions(
        self,
        *,
        is_dynamic_expression: Callable,
        loop_execute_range_dynamic: Callable,
        if_dynamic: Callable,
        while_dynamic: Callable,
        compare_executor: Callable,
        any_executor: Callable = None,
        all_executor: Callable = None,
        builtin_redirector: Callable = None,
    ):
        """Store the given callbacks on this instance (keyword-only arguments).

        The last three callbacks are optional and default to None.
        """
        # Control-flow callbacks.
        self._loop_execute_range_dynamic = loop_execute_range_dynamic
        self._if_dynamic = if_dynamic
        self._while_dynamic = while_dynamic
        # Expression-level callbacks.
        self._is_dynamic_expression = is_dynamic_expression
        self._compare_executor = compare_executor
        self._any_executor = any_executor
        self._all_executor = all_executor
        self._builtin_redirector = builtin_redirector

    @staticmethod
    def convert_to_list(x):
        """Normalize `x` to a list.

        A list is returned as-is (same object), None becomes an empty list,
        and any other value is wrapped in a one-element list.
        """
        if isinstance(x, list):
            return x
        return [] if x is None else [x]

    @staticmethod
    def converge_ret_val(res):
        """Collapse a one-element list to its single element.

        None, longer or empty lists, and non-list values are returned unchanged.
        """
        is_single_item_list = isinstance(res, list) and len(res) == 1
        return res[0] if is_single_item_list else res

    def for_execute(
        self,
        func,
        start,
        stop,
        step,
        write_args=[],
        full_write_args_count=0,
        write_args_names=[],
        unroll=-1,
        unroll_full=False,
        prefetch_stages=None,
    ):
        """Forward a "for" loop to the registered loop callback.

        All arguments are passed through positionally, in declaration order.
        Raises AssertionError if no loop callback has been registered.
        """
        loop_handler = self._loop_execute_range_dynamic
        assert loop_handler, "Functions must be set before execution."
        log().debug("start [%s] stop [%s] step [%s]", start, stop, step)

        return loop_handler(
            func,
            start,
            stop,
            step,
            write_args,
            full_write_args_count,
            write_args_names,
            unroll,
            unroll_full,
            prefetch_stages,
        )

    def if_execute(
        self,
        pred,
        then_block: Callable,
        else_block: Optional[Callable] = None,
        write_args=[],
        full_write_args_count=0,
        write_args_names=[],
    ):
        """Forward an "if" statement to the registered if callback.

        All arguments are passed through positionally, in declaration order.
        Raises AssertionError if no if callback has been registered.
        """
        if_handler = self._if_dynamic
        assert if_handler, "Functions must be set before execution."

        # MLIR generation is delegated to the callback.
        return if_handler(
            pred,
            then_block,
            else_block,
            write_args,
            full_write_args_count,
            write_args_names,
        )

    def while_execute(
        self,
        pred,
        while_before_block: Callable,
        while_after_block: Callable,
        write_args=[],
        full_write_args_count=0,
        write_args_names=[],
    ):
        """Forward a "while" loop to the registered while callback.

        Note that `pred` is accepted but NOT forwarded: the callback receives
        only the two blocks and the write-args bookkeeping values.
        Raises AssertionError if no while callback has been registered.
        """
        while_handler = self._while_dynamic
        assert while_handler, "Functions must be set before execution."

        # MLIR generation is delegated to the callback.
        return while_handler(
            while_before_block,
            while_after_block,
            write_args,
            full_write_args_count,
            write_args_names,
        )
|
|
|
|
|
|
# =============================================================================
|
|
# Decorator
|
|
# =============================================================================
|
|
|
|
# Shared module-level Executor instance; the helper functions in this module
# dispatch through it.
executor = Executor()
|
|
|
|
|
|
def loop_selector(
    start,
    stop,
    step,
    *,
    write_args=[],
    full_write_args_count=0,
    write_args_names=[],
    unroll=-1,
    unroll_full=False,
    prefetch_stages=None,
):
    """Build a decorator that runs a loop body through `executor.for_execute`.

    `start`, `stop` and `step` that are `Integer` instances are replaced by
    their `ir_value()`; other values are used unchanged. The returned callable
    takes the loop-body function and returns whatever `for_execute` returns.
    """
    log().debug(
        "start [%s] stop [%s] step [%s] write_args [%s] full_write_args_count [%s] write_args_names [%s] unroll [%s] unroll_full [%s] prefetch_stages [%s]",
        start,
        stop,
        step,
        write_args,
        full_write_args_count,
        write_args_names,
        unroll,
        unroll_full,
        prefetch_stages,
    )
    from .typing import Integer, Numeric

    def _maybe_upcast(bound):
        # Only DSL Integer wrappers are converted; everything else passes through.
        return bound.ir_value() if isinstance(bound, Integer) else bound

    start, stop, step = (_maybe_upcast(bound) for bound in (start, stop, step))

    def ir_loop(func):
        return executor.for_execute(
            func,
            start,
            stop,
            step,
            write_args,
            full_write_args_count,
            write_args_names,
            unroll,
            unroll_full,
            prefetch_stages,
        )

    return ir_loop
|
|
|
|
|
|
def if_selector(pred, write_args=[]):
    """Build a decorator that calls the wrapped function with the predicate.

    A `Numeric` predicate is replaced by its `.value`; any other predicate is
    used as given. The returned callable invokes `func(pred, *write_args)`.
    """
    log().debug("pred [%s] write_args [%s]", pred, write_args)
    # Handle Numeric types here?

    from .typing import Numeric

    condition = pred.value if isinstance(pred, Numeric) else pred

    def ir_loop(func):
        return func(condition, *write_args)

    return ir_loop
|
|
|
|
|
|
def while_selector(pred, write_args=[]):
    """Build a decorator that calls the wrapped function as `func(pred, *write_args)`."""

    def ir_while_loop(func):
        # Assemble the positional arguments at call time: predicate first,
        # then each write argument in order.
        call_args = [pred]
        call_args.extend(write_args)
        return func(*call_args)

    return ir_while_loop
|
|
|
|
|
|
def while_executor(
    pred,
    while_before_block: Callable,
    while_after_block: Callable,
    write_args=[],
    full_write_args_count=0,
    write_args_names=[],
):
    """Module-level entry point that delegates to `executor.while_execute`.

    Every argument is handed over unchanged; the result of `while_execute`
    is returned.
    """
    return executor.while_execute(
        pred=pred,
        while_before_block=while_before_block,
        while_after_block=while_after_block,
        write_args=write_args,
        full_write_args_count=full_write_args_count,
        write_args_names=write_args_names,
    )
|
|
|
|
|
|
def if_executor(
    pred,
    then_block: Callable,
    else_block: Optional[Callable] = None,
    write_args=[],
    full_write_args_count=0,
    write_args_names=[],
):
    """Module-level entry point that delegates to `executor.if_execute`.

    Every argument is handed over unchanged; the result of `if_execute`
    is returned.
    """
    return executor.if_execute(
        pred=pred,
        then_block=then_block,
        else_block=else_block,
        write_args=write_args,
        full_write_args_count=full_write_args_count,
        write_args_names=write_args_names,
    )
|
|
|
|
|
|
# =============================================================================
|
|
# Range
|
|
# =============================================================================
|
|
|
|
|
|
class range:
    """
    A range-like marker for dynamic loop iteration in the DSL.

    It mirrors the interface of Python's built-in range but is meant to be
    rewritten by the preprocessor into dynamic-loop constructs. Constructing
    or iterating it directly always raises `DSLRuntimeError`.

    Both the single-argument (stop) and the three-argument
    (start, stop, step) forms are declared, with extra parameters for loop
    optimization:

    - unroll: Number of iterations to unroll (0 or 1 = no unrolling)
    - unroll_full: Whether to fully unroll the loop
    - prefetch_stages: Number of prefetch stages to generate
    """

    # The two overloads exist purely for type checkers / IDE signatures.
    @overload
    def __new__(cls, stop, unroll=0, unroll_full=False, prefetch_stages=None): ...

    @overload
    def __new__(
        cls, start, stop, step, unroll=0, unroll_full=False, prefetch_stages=None
    ): ...

    def __new__(cls, *args, **kwargs):
        # Reaching this at runtime means the call was not preprocessed.
        raise DSLRuntimeError("dynamic range should be always preprocessed to IR")

    def __iter__(self) -> Iterator[int]:
        # Same situation as construction: iteration must never run in Python.
        raise DSLRuntimeError("dynamic range should be always preprocessed to IR")
|
|
|
|
|
|
@deprecated(
    "range_dynamic is deprecated and will be removed in the future, please remove it."
)
def range_dynamic(*args, **kwargs):
    """Deprecated placeholder; always raises `DSLRuntimeError` when called.

    Per the error message, calls are expected to be rewritten to IR by the
    preprocessor before this body could ever run.
    """
    raise DSLRuntimeError("range_dynamic should be always preprocessed to IR")
|
|
|
|
|
|
def range_constexpr(*args):
    """Placeholder; always raises `DSLRuntimeError` when called.

    Per the error message, calls are expected to be handled by the
    preprocessor before this body could ever run.
    """
    raise DSLRuntimeError("range_constexpr should be preprocessed by preprocessor.")
|
|
|
|
|
|
# =============================================================================
|
|
# If expressions
|
|
# =============================================================================
|
|
|
|
|
|
def const_expr(expression):
    """
    Verify that `expression` is a compile-time (Python) value and return it.

    - A `Numeric` whose `.value` is a Python int, float or bool is unwrapped
      and that value is returned.
    - Any other `Numeric`, or anything the executor reports as a dynamic
      expression, raises `DSLRuntimeError`.
    - Everything else is returned unchanged.
    """
    from .typing import Numeric

    if isinstance(expression, Numeric):
        wrapped = expression.value
        if isinstance(wrapped, (int, float, bool)):
            return wrapped
        # A Numeric that does not wrap a plain Python scalar counts as dynamic.
        is_dynamic = True
    else:
        is_dynamic = executor._is_dynamic_expression(expression)

    if not is_dynamic:
        return expression

    raise DSLRuntimeError(
        f"The function `const_expr({expression})` received a dynamic expression (non compile-time constant).",
        context={
            "If your expression depends on dynamic values": "Remove `const_expr()`",
        },
    )
|
|
|
|
|
|
@deprecated(
    "dynamic_expr is deprecated and will be removed in the future, please remove it."
)
def dynamic_expr(expression):
    """Deprecated identity helper: returns `expression` unchanged."""
    return expression
|
|
|
|
|
|
# =============================================================================
|
|
# Assertion & casting
|
|
# =============================================================================
|
|
|
|
|
|
def assert_executor(test, msg=None):
    """
    Evaluate an `assert` on a compile-time (Python) value.

    Args:
        test: The asserted condition.
        msg: Optional message attached to the `AssertionError`.

    Raises:
        AssertionError: If `test` is usable as a Python value and is falsy.
        DSLRuntimeError: If `test` is a dynamic expression that is not a
            `Numeric`, or a `Numeric` whose `to(bool)` conversion fails.
    """
    from .typing import Numeric

    fail = False
    # Implicit convert dynamic expression to bool is not allowed
    # So here explicitly do a None check
    if test is not None and executor._is_dynamic_expression(test):
        if isinstance(test, Numeric):
            try:
                test = test.to(bool)
            except Exception:
                # Only ordinary errors mean "not convertible"; KeyboardInterrupt
                # and SystemExit must keep propagating.
                fail = True
        else:
            fail = True

    if not fail:
        assert test, msg
    else:
        raise DSLRuntimeError(
            "Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.",
            suggestion="Please replace with runtime assert.",
        )
|
|
|
|
|
|
def bool_cast(value):
    """Return `bool(value)` for a compile-time (Python) value.

    Raises `DSLRuntimeError` when the executor reports `value` as a dynamic
    expression.
    """
    is_dynamic = executor._is_dynamic_expression(value)
    if not is_dynamic:
        return bool(value)

    raise DSLRuntimeError(
        "Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.",
        suggestion="Please explicitly convert to boolean with expressions like comparision.",
    )
|
|
|
|
|
|
def compare_executor(left, comparators, ops):
    """
    Run a comparison chain through the registered comparison callback.

    Args:
        left: The leftmost value in the comparison chain
        comparators: A list of values to compare against
        ops: A list of comparison operators to apply

    Returns:
        Whatever the registered callback returns for the chain

    Raises:
        AssertionError: If the callback has not been registered yet
    """
    handler = executor._compare_executor
    assert handler is not None, "Function must be set before execution."
    return handler(left, comparators, ops)
|
|
|
|
|
|
def any_executor(iterable):
    """Evaluate 'any' over an iterable, for both dynamic and static expressions.

    The registered callback is used only when one is set and the executor
    reports `iterable` as dynamic; otherwise Python's built-in `any` is used.

    :param iterable: An iterable to check if any elements evaluate to True
    :type iterable: Iterable
    :return: boolean of Python value or IR value
    :rtype: bool or cutlass.Boolean
    """
    handler = executor._any_executor
    # No handler registered, or a plain Python iterable: use the built-in.
    if not handler or not executor._is_dynamic_expression(iterable):
        return any(iterable)
    return handler(iterable)
|
|
|
|
|
|
def all_executor(iterable):
    """Evaluate 'all' over an iterable, for both dynamic and static expressions.

    The registered callback is used only when one is set and the executor
    reports `iterable` as dynamic; otherwise Python's built-in `all` is used.

    :param iterable: An iterable to check if all elements evaluate to True
    :type iterable: Iterable
    :return: boolean of Python value or IR value
    :rtype: bool or cutlass.Boolean
    """
    handler = executor._all_executor
    # No handler registered, or a plain Python iterable: use the built-in.
    if not handler or not executor._is_dynamic_expression(iterable):
        return all(iterable)
    return handler(iterable)
|
|
|
|
|
|
# =============================================================================
|
|
# Control flow checks
|
|
# =============================================================================
|
|
class DSLOptimizationWarning(Warning):
    """
    Warning category used to tell the user about optimization-related issues
    in the DSL.
    """

    def __init__(self, message):
        self.message = message
        # Forward the message to the base class so that `args` is populated:
        # `repr()` then shows the message, and copy/pickle can rebuild the
        # instance by calling the class with `args`.
        super().__init__(message)

    def __str__(self):
        return self.message
|
|
|
|
|
|
def range_value_check(*args):
    """
    Ensure all `range_constexpr` bounds are compile-time constants (Python ints).

    Each argument is converted with `__index__()`. One, two or three arguments
    are interpreted like the built-in range: `(stop)`, `(start, stop)` or
    `(start, stop, step)`; any other argument count yields the defaults.

    Returns:
        The `(start, end, step)` tuple of Python ints.

    Raises:
        DSLRuntimeError: If an argument cannot be converted to a Python int,
            or if the step is zero.

    Warns:
        DSLOptimizationWarning: If the loop would run 64 or more iterations.
    """
    # Guard only the conversion: a failure here is what "not a constexpr"
    # means. Anything raised later (e.g. a warning escalated to an error by
    # the active warning filters) must not be mislabelled as such.
    try:
        args = tuple(arg.__index__() for arg in args)
    except Exception as err:
        raise DSLRuntimeError(
            "`range_constexpr` requires constexpr (compile-time constant) for all arguments.",
            suggestion="Use `range` instead of `range_constexpr`.",
        ) from err

    start, end, step = 0, 0, 1
    if len(args) == 1:
        end = args[0]
    elif len(args) == 2:
        start, end = args
    elif len(args) == 3:
        start, end, step = args

    if step == 0:
        raise DSLRuntimeError(
            "`range_constexpr` step must not be zero.",
            suggestion="Use a non-zero step.",
        )

    # Iteration count with built-in range semantics: a range whose step points
    # away from `end` is empty, so it must not trigger the size warning.
    if step > 0 and start < end:
        range_length = (end - start - 1) // step + 1
    elif step < 0 and start > end:
        range_length = (start - end - 1) // (-step) + 1
    else:
        range_length = 0

    if range_length >= 64:
        warnings.warn(
            f"This static loop has {range_length} iterations, which may be very slow to compile, consider using `cutlass.range(..., unroll_full=True)` instead.",
            category=DSLOptimizationWarning,
            stacklevel=2,
        )

    return (start, end, step)
|
|
|
|
|
|
def range_perf_warning(filename, lineno, *args):
    """Warn, attributed to `filename`:`lineno`, when no loop bound is dynamic.

    If the executor reports any of `args` as a dynamic expression, nothing is
    emitted. Otherwise a `DSLOptimizationWarning` is issued via
    `warnings.warn_explicit`.
    """
    if any(executor._is_dynamic_expression(bound) for bound in args):
        return

    message = (
        "This loop is no longer unrolled and may cause performance regression. "
        "Use `range(..., unroll_full=True)` for full unrolling, or switch to `range_constexpr` when bounds are compile-time constants."
    )
    warnings.warn_explicit(
        message,
        category=DSLOptimizationWarning,
        filename=filename,
        lineno=lineno,
    )
|
|
|
|
|
|
@lru_cache(maxsize=1)
def _get_self_module():
    """
    Return the module that owns this function.

    The lookup goes through `inspect.getmodule`; the result is memoized by
    the `lru_cache` decorator, so later calls return the cached module.
    """
    owning_module = inspect.getmodule(_get_self_module)
    return owning_module
|
|
|
|
|
|
def cf_symbol_check(symbol):
    """
    Check that `symbol` is a control-flow symbol belonging to this module.

    - For a module object, this module's name must start with the symbol's
      name.
    - For anything else, `inspect.getmodule(symbol)` must be this module.

    Raises `DSLRuntimeError` when the check fails; returns None otherwise.
    """
    # Read eagerly, so a symbol without `__name__` fails right here.
    name = symbol.__name__
    self_module = _get_self_module()

    if inspect.ismodule(symbol):
        # NOTE(review): `name` is rebound here but never read afterwards; the
        # error text below uses `symbol.__name__` directly. Confirm intent.
        name = "range"
        failed = not self_module.__name__.startswith(symbol.__name__)
    else:
        failed = inspect.getmodule(symbol) != self_module

    if not failed:
        return

    raise DSLRuntimeError(
        f"Incorrect {symbol.__name__} is used.",
        suggestion=f"Please avoid overriding `{symbol.__name__}` from DSL package.",
    )
|
|
|
|
|
|
def redirect_builtin_function(fcn):
    """
    Redirect a built-in function call to the function defined in the DSL
    package, when a redirector callback is registered.

    Anything that is not a `BuiltinFunctionType` is returned unchanged, as is
    a built-in when no redirector is set.
    """
    # Non-builtins are never redirected (and never consult the executor).
    if not isinstance(fcn, BuiltinFunctionType):
        return fcn

    redirector = executor._builtin_redirector
    if redirector:
        return redirector(fcn)
    return fcn
|