# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.

"""
This module provides helper functions that are generated by the preprocessor.
The preprocessor reads through Python's AST and rewrites the input code so
that it calls the helpers defined here.
"""

# NOTE: `builtins` is imported explicitly because this module defines its own
# `range` class further down, which shadows the builtin at module scope.
import builtins
import warnings
from typing import Callable, Iterator, Optional, overload
from typing_extensions import deprecated

from .utils.logger import log
from .common import *

from ._mlir_helpers.arith import ArithValue


class Executor:
    """
    The Executor class handles dynamic and compile-time (constexpr) execution
    of "for" loops and "if-else-elif" statements.

    Methods:
        set_functions: Assigns the functions for checking loop bounds and
            conditional evaluation.
        for_execute: Generates MLIR for OP
        while_execute: Generates MLIR while OP
        if_execute: generate MLIR if OP
    """

    def __init__(self):
        # All callbacks start unset; `set_functions` must be called before
        # any of the `*_execute` methods are used.
        self._is_dynamic_expression = None
        self._loop_execute_range_dynamic = None
        self._if_dynamic = None
        self._while_dynamic = None
        self._compare_executor = None
        self._any_executor = None
        self._all_executor = None

    def set_functions(
        self,
        is_dynamic_expression: Callable,
        loop_execute_range_dynamic: Callable,
        if_dynamic: Callable,
        while_dynamic: Callable,
        compare_executor: Callable,
        any_executor: Optional[Callable] = None,
        all_executor: Optional[Callable] = None,
    ):
        """Register the callbacks this executor delegates to.

        `any_executor` and `all_executor` are optional; when they are left as
        None the module-level `any_executor` / `all_executor` helpers fall
        back to Python's builtin `any` / `all`.
        """
        self._is_dynamic_expression = is_dynamic_expression
        self._loop_execute_range_dynamic = loop_execute_range_dynamic
        self._if_dynamic = if_dynamic
        self._while_dynamic = while_dynamic
        self._compare_executor = compare_executor
        self._any_executor = any_executor
        self._all_executor = all_executor

    @staticmethod
    def convert_to_list(x):
        """This function is used to convert x to a list.
        If x is None, return an empty list.
        If x is not a list, return a list containing x.
        Otherwise, return x itself.
        """
        if x is None:
            return []
        if not isinstance(x, list):
            return [x]
        return x

    @staticmethod
    def converge_ret_val(res):
        """This function is used to converge res (the return value) of the function.
        If res is None, return None.
        If res is a list and has only one element, return the element.
        Otherwise, return res itself.
        """
        if res is None:
            return res
        elif isinstance(res, list) and len(res) == 1:
            return res[0]
        return res

    @staticmethod
    def for_constexpr(
        func: Callable,
        start: int,
        stop: int,
        step: int,
        used_args: list,
        iter_args: list,
    ):
        """Run `func` once per index of a Python range, threading loop results.

        Each iteration calls `func(i, *used_args, *loop_results)`, where
        `loop_results` starts as `iter_args` and is replaced by the (list
        normalised) return value of the previous iteration. The final results
        are passed through `converge_ret_val`.
        """
        log().debug("start [%s] stop [%s] step [%s]", start, stop, step)
        loop_results = iter_args
        log().debug("iter_args [%s]", iter_args)
        # `builtins.range` is required here: the bare name `range` resolves to
        # the DSL `range` class defined later in this module, whose `__new__`
        # always raises.
        for i in builtins.range(start, stop, step):
            log().debug("i [%s] iter_args [%s]", i, iter_args)
            loop_results = func(i, *used_args, *loop_results)
            log().debug("loop_results [%s]", loop_results)
            if loop_results is None:
                loop_results = []
            if not isinstance(loop_results, list):
                loop_results = [loop_results]

        log().debug("done loop_results [%s]", loop_results)
        return Executor.converge_ret_val(loop_results)

    def for_execute(
        self,
        func,
        start,
        stop,
        step,
        used_args=None,
        iter_args=None,
        iter_arg_names=None,
        unroll=-1,
        unroll_full=False,
        pipelining=None,
    ):
        """Delegate a `for` loop to the registered dynamic-loop callback.

        List arguments left as None are replaced by fresh empty lists so no
        list object is shared between calls.

        Raises:
            AssertionError: If `set_functions` has not registered the callback.
        """
        assert (
            self._loop_execute_range_dynamic
        ), "Functions must be set before execution."
        used_args = [] if used_args is None else used_args
        iter_args = [] if iter_args is None else iter_args
        iter_arg_names = [] if iter_arg_names is None else iter_arg_names
        log().debug("start [%s] stop [%s] step [%s]", start, stop, step)

        return self._loop_execute_range_dynamic(
            func,
            start,
            stop,
            step,
            used_args,
            iter_args,
            iter_arg_names,
            unroll,
            unroll_full,
            pipelining,
        )

    def if_execute(
        self,
        pred,
        then_block: Callable,
        else_block: Optional[Callable] = None,
        used_args=None,
        yield_args=None,
        yield_arg_names=None,
    ):
        """Delegate an `if` statement to the registered `if` callback.

        Raises:
            AssertionError: If `set_functions` has not registered the callback.
        """
        assert self._if_dynamic, "Functions must be set before execution."
        used_args = [] if used_args is None else used_args
        yield_args = [] if yield_args is None else yield_args
        yield_arg_names = [] if yield_arg_names is None else yield_arg_names

        # MLIR generation
        return self._if_dynamic(
            pred, then_block, else_block, used_args, yield_args, yield_arg_names
        )

    def while_execute(
        self,
        pred,
        while_before_block: Callable,
        while_after_block: Callable,
        used_args=None,
        yield_args=None,
        yield_arg_names=None,
    ):
        """Delegate a `while` loop to the registered `while` callback.

        `pred` is accepted as part of the signature but is not forwarded to
        the callback; only the two blocks and the argument lists are.

        Raises:
            AssertionError: If `set_functions` has not registered the callback.
        """
        assert self._while_dynamic, "Functions must be set before execution."
        used_args = [] if used_args is None else used_args
        yield_args = [] if yield_args is None else yield_args
        yield_arg_names = [] if yield_arg_names is None else yield_arg_names

        # MLIR generation
        return self._while_dynamic(
            while_before_block,
            while_after_block,
            used_args,
            yield_args,
            yield_arg_names,
        )


# =============================================================================
# Decorator
# =============================================================================

executor = Executor()


def loop_selector(
    start,
    stop,
    step,
    *,
    used_args=None,
    iter_args=None,
    iter_arg_names=None,
    unroll=-1,
    unroll_full=False,
    pipelining=None,
):
    """Return a decorator that runs the decorated loop body via `executor`.

    `Integer` bounds are converted with `ir_value()` before being handed to
    `executor.for_execute`; other values are passed through unchanged.
    """
    used_args = [] if used_args is None else used_args
    iter_args = [] if iter_args is None else iter_args
    iter_arg_names = [] if iter_arg_names is None else iter_arg_names
    log().debug(
        "start [%s] stop [%s] step [%s] used_args [%s] iter_args [%s] unroll [%s] unroll_full [%s] pipelining [%s]",
        start,
        stop,
        step,
        used_args,
        iter_args,
        unroll,
        unroll_full,
        pipelining,
    )
    from .typing import Integer

    def _maybe_upcast(value):
        if isinstance(value, Integer):
            value = value.ir_value()

        return value

    start = _maybe_upcast(start)
    stop = _maybe_upcast(stop)
    step = _maybe_upcast(step)

    def ir_loop(func):
        return executor.for_execute(
            func,
            start,
            stop,
            step,
            used_args,
            iter_args,
            iter_arg_names,
            unroll,
            unroll_full,
            pipelining,
        )

    return ir_loop


def if_selector(pred, used_args=None, yield_args=None):
    """Return a decorator that calls the decorated function with `pred`.

    The decorated function is invoked as `func(pred, *used_args, *yield_args)`.
    A `Numeric` predicate is replaced by its `.value` first.
    """
    used_args = [] if used_args is None else used_args
    yield_args = [] if yield_args is None else yield_args
    log().debug("pred [%s] used_args [%s] yield_args [%s]", pred, used_args, yield_args)
    # Handle Numeric types here?

    from .typing import Numeric

    if isinstance(pred, Numeric):
        pred = pred.value

    def ir_loop(func):
        return func(pred, *used_args, *yield_args)

    return ir_loop


def while_selector(pred, used_args=None, yield_args=None):
    """Return a decorator that calls the decorated function with `pred`.

    The decorated function is invoked as `func(pred, *used_args, *yield_args)`.
    """
    used_args = [] if used_args is None else used_args
    yield_args = [] if yield_args is None else yield_args

    def ir_while_loop(func):
        return func(pred, *used_args, *yield_args)

    return ir_while_loop


def while_executor(
    pred,
    while_before_block: Callable,
    while_after_block: Callable,
    used_args=None,
    yield_args=None,
    yield_arg_names=None,
):
    """Forward a `while` loop to the module-level `executor`."""
    used_args = [] if used_args is None else used_args
    yield_args = [] if yield_args is None else yield_args
    yield_arg_names = [] if yield_arg_names is None else yield_arg_names
    return executor.while_execute(
        pred,
        while_before_block,
        while_after_block,
        used_args,
        yield_args,
        yield_arg_names,
    )


def if_executor(
    pred,
    then_block: Callable,
    else_block: Optional[Callable] = None,
    used_args=None,
    yield_args=None,
    yield_arg_names=None,
):
    """Forward an `if` statement to the module-level `executor`."""
    used_args = [] if used_args is None else used_args
    yield_args = [] if yield_args is None else yield_args
    yield_arg_names = [] if yield_arg_names is None else yield_arg_names
    return executor.if_execute(
        pred, then_block, else_block, used_args, yield_args, yield_arg_names
    )


# =============================================================================
# Range
# =============================================================================


class range:
    """
    A range-like object for dynamic loop iteration in the DSL.

    This class provides a range interface similar to Python's built-in range,
    but is designed to be preprocessed into constructs for dynamic loop
    execution. Constructing or iterating it directly always raises
    `DSLRuntimeError`.

    NOTE: this name shadows the builtin `range` for the rest of the module;
    code in this file that needs the builtin must use `builtins.range`.

    The class supports both single-argument (stop) and three-argument
    (start, stop, step) constructors with additional parameters for loop
    optimization:

    - unroll: Number of iterations to unroll (0 or 1 = no unrolling)
    - unroll_full: Whether to fully unroll the loop
    - pipelining: Compiler generated pipeline configuration
    """

    @overload
    def __new__(cls, stop, unroll=0, unroll_full=False, pipelining=None):
        pass

    @overload
    def __new__(cls, start, stop, step, unroll=0, unroll_full=False, pipelining=None):
        pass

    def __new__(cls, *args, **kwargs):
        raise DSLRuntimeError("dynamic range should be always preprocessed to IR")

    def __iter__(self) -> Iterator[int]:
        raise DSLRuntimeError("dynamic range should be always preprocessed to IR")


@deprecated(
    "range_dynamic is deprecated and will be removed in the future, please remove it."
)
def range_dynamic(*args, **kwargs):
    """Deprecated placeholder; always raises `DSLRuntimeError` when called."""
    raise DSLRuntimeError("range_dynamic should be always preprocessed to IR")


def range_constexpr(*args):
    """Placeholder for the preprocessor; always raises `DSLRuntimeError` when called."""
    raise DSLRuntimeError("range_constexpr should be preprocessed by preprocessor.")


# =============================================================================
# If expressions
# =============================================================================


def const_expr(expression):
    """
    This function is used to check if the expression is a python value.
    If the expression is a python value, return the boolean value of the expression.
    If the expression is a dynamic expression, raise an error.

    For a `Numeric` wrapping a Python int/float/bool, the wrapped `.value` is
    returned. Any other non-dynamic expression is returned unchanged.

    Raises:
        DSLRuntimeError: If the expression is dynamic (not a compile-time
            constant).
    """
    from .typing import Numeric

    failed = False

    if isinstance(expression, Numeric):
        if isinstance(expression.value, (int, float, bool)):
            return expression.value
        else:
            failed = True
    elif executor._is_dynamic_expression(expression):
        failed = True

    if failed:
        raise DSLRuntimeError(
            f"The function `const_expr({expression})` received a dynamic expression (non compile-time constant).",
            context={
                "If your expression depends on dynamic values": "Remove `const_expr()`",
            },
        )
    return expression


@deprecated(
    "dynamic_expr is deprecated and will be removed in the future, please remove it."
)
def dynamic_expr(expression):
    """Deprecated identity function; returns `expression` unchanged."""
    return expression


# =============================================================================
# Assertion & casting
# =============================================================================


def assert_executor(test, msg=None):
    """Evaluate a Python `assert` on `test`, rejecting dynamic expressions.

    A dynamic `Numeric` is first converted with `test.to(bool)`; if that
    conversion fails, or if `test` is any other dynamic expression, a
    `DSLRuntimeError` is raised instead of asserting.

    Raises:
        AssertionError: If `test` is falsy (standard `assert` semantics).
        DSLRuntimeError: If `test` cannot be reduced to a Python value.
    """
    from .typing import Numeric

    fail = False
    # Implicit convert dynamic expression to bool is not allowed
    # So here explicitly do a None check
    if test is not None and executor._is_dynamic_expression(test):
        if isinstance(test, Numeric):
            try:
                test = test.to(bool)
            except Exception:
                # Conversion failures mean the value is not a constexpr.
                # `Exception` (not a bare except) so KeyboardInterrupt and
                # SystemExit are not swallowed.
                fail = True
        else:
            fail = True

    if not fail:
        assert test, msg
    else:
        raise DSLRuntimeError(
            "Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.",
            suggestion = "Please replace with runtime assert."
        )


def bool_cast(value):
    """Return `bool(value)` for Python values.

    Raises:
        DSLRuntimeError: If `value` is a dynamic expression.
    """
    if executor._is_dynamic_expression(value):
        raise DSLRuntimeError(
            "Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.",
            suggestion = "Please explicitly convert to boolean with expressions like comparision."
        )
    return bool(value)


def compare_executor(left, comparators, ops):
    """
    Executes comparison operations with a left operand and a list of comparators.

    Args:
        left: The leftmost value in the comparison chain
        comparators: A list of values to compare against
        ops: A list of comparison operators to apply

    Returns:
        The result of the comparison chain

    Raises:
        AssertionError: If the executor function is not set before execution
    """
    assert (
        executor._compare_executor is not None
    ), "Function must be set before execution."
    return executor._compare_executor(left, comparators, ops)


def any_executor(iterable):
    """Executes the 'any' operation on an iterable, handling both dynamic and static expressions.

    :param iterable: An iterable to check if any elements evaluate to True
    :type iterable: Iterable
    :return: boolean of Python value or IR value
    :rtype: bool or cutlass.Boolean

    """
    if executor._any_executor and executor._is_dynamic_expression(iterable):
        return executor._any_executor(iterable)
    else:
        return any(iterable)


def all_executor(iterable):
    """Executes the 'all' operation on an iterable, handling both dynamic and static expressions.

    :param iterable: An iterable to check if all elements evaluate to True
    :type iterable: Iterable
    :return: boolean of Python value or IR value
    :rtype: bool or cutlass.Boolean
    """
    if executor._all_executor and executor._is_dynamic_expression(iterable):
        return executor._all_executor(iterable)
    else:
        return all(iterable)


# =============================================================================
# Control flow checks
# =============================================================================
def range_value_check(*args):
    """
    Ensure all `range_constexpr` bounds are compile-time constants (Python ints).

    Each argument is converted with `__index__()`. One, two or three arguments
    are interpreted like the builtin `range` (stop / start, stop / start,
    stop, step); any other count leaves the defaults `(0, 0, 1)`.

    Returns:
        The normalised `(start, end, step)` tuple.

    Raises:
        DSLRuntimeError: If an argument cannot be converted to a Python int,
            or if the step is zero.

    Warns:
        UserWarning: If the static loop would run 64 or more iterations.
    """
    # Only the conversion is guarded: anything failing here means the
    # argument is not a compile-time integer. Keeping the `try` narrow avoids
    # relabelling unrelated errors (e.g. a warning promoted to an exception).
    try:
        args = tuple(arg.__index__() for arg in args)
    except Exception as err:
        raise DSLRuntimeError(
            "`range_constexpr` requires constexpr (compile-time constant) for all arguments.",
            suggestion="Use `range` instead of `range_constexpr`.",
        ) from err

    start = 0
    end = 0
    step = 1
    if len(args) == 1:
        end = args[0]
    elif len(args) == 2:
        start, end = args
    elif len(args) == 3:
        start, end, step = args

    if step == 0:
        raise DSLRuntimeError("`range_constexpr` step must not be zero.")

    # Compute range size and warn if it's too large. The builtin range gives
    # the exact iteration count, including 0 when the step points away from
    # `end`. `builtins.range` is used because `range` is shadowed in this
    # module.
    range_length = len(builtins.range(start, end, step))
    if range_length >= 64:
        warnings.warn(
            f"This static loop has {range_length} iterations, which may be very slow to compile, consider using `cutlass.range(..., unroll_full=True)` instead.",
            category=UserWarning,
            stacklevel=2,
        )

    return (start, end, step)


def range_perf_warning(filename, lineno, *args):
    """Warn at `filename:lineno` when none of `args` is a dynamic expression.

    The warning tells the user that the loop is no longer unrolled and points
    to `unroll_full=True` or `range_constexpr` as alternatives.
    """
    has_dynamic_expr = any(executor._is_dynamic_expression(arg) for arg in args)
    if not has_dynamic_expr:
        warnings.warn_explicit(
            (
                "This loop is no longer unrolled and may cause performance regression. "
                "Use `range(..., unroll_full=True)` for full unrolling, or switch to `range_constexpr` when bounds are compile-time constants."
            ),
            category=UserWarning,
            filename=filename,
            lineno=lineno,
        )