Files
cutlass/python/CuTeDSL/base_dsl/ast_helpers.py
2025-05-13 15:55:29 -04:00

585 lines
17 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
"""
This module provides helper functions that are generated by the preprocessor.
The preprocessor reads through Python's AST and changes the input code.
"""
from typing import Callable, Iterator, Optional, overload
from .utils.logger import log
from .common import *
from ._mlir_helpers.arith import ArithValue
class Executor:
    """
    The Executor class handles dynamic and compile-time (constexpr) execution
    of "for" loops, "if-else-elif" statements and "while" loops.

    The dynamic paths are delegated to callbacks registered through
    ``set_functions``; the constexpr paths run directly in the Python
    interpreter.

    Methods:
        set_functions: Assigns the functions for checking dynamic expressions
            and for handling dynamic loops / ifs / whiles.
        for_dynamic: Generates MLIR for OP
        for_constexpr: Executes a for loop at JIT compile-time
        for_execute: Decides whether to execute the loop at compile-time or
            generate MLIR for OP based on the provided bounds.
        if_dynamic: Generates MLIR if OP
        if_constexpr: Executes a if at JIT compile-time by python interpreter
        if_execute: Decides whether to execute the if statement at
            compile-time or generate MLIR if OP based on the predicate.
        while_dynamic: Forwards a while loop to the registered dynamic callback
        while_constexpr: Executes a while loop in the python interpreter
        while_execute: Decides between while_constexpr and while_dynamic
            based on the explicit ``while_constexpr`` flag.
    """

    def __init__(self):
        # All callbacks stay unset until set_functions() is called; the
        # *_execute entry points assert that the ones they need are present.
        self._is_dynamic_expression = None
        self._loop_execute_range_dynamic = None
        self._if_dynamic = None
        self._while_dynamic = None

    def set_functions(
        self,
        is_dynamic_expression: Callable,
        loop_execute_range_dynamic: Callable,
        if_dynamic: Callable,
        while_dynamic: Callable,
    ):
        """Register the callbacks used for dynamic-expression detection and
        for the dynamic handling of for / if / while constructs."""
        self._is_dynamic_expression = is_dynamic_expression
        self._loop_execute_range_dynamic = loop_execute_range_dynamic
        self._if_dynamic = if_dynamic
        self._while_dynamic = while_dynamic

    @staticmethod
    def _list_or_empty(args):
        """Return a fresh empty list when ``args`` is None, else ``args`` as-is.

        Used to replace mutable ``[]`` default arguments, which would be a
        single list object shared by every call.
        """
        return [] if args is None else args

    @staticmethod
    def convert_to_list(x):
        """This function is used to convert x to a list.
        If x is None, return an empty list.
        If x is not a list, return a list containing x.
        Otherwise, return x itself.
        """
        if x is None:
            return []
        if not isinstance(x, list):
            return [x]
        return x

    @staticmethod
    def converge_ret_val(res):
        """This function is used to converge res (the return value) of the function.
        If res is None, return None.
        If res is a list and has only one element, return the element.
        Otherwise, return res itself.
        """
        if res is None:
            return res
        elif isinstance(res, list) and len(res) == 1:
            return res[0]
        return res

    def for_dynamic(
        self,
        func: Callable,
        start,
        stop,
        step,
        used_args: list,
        iter_args: list,
        iter_arg_names: list,
        unroll=-1,
        unroll_full=False,
    ):
        """Forward the loop to the registered ``loop_execute_range_dynamic``
        callback and return its result.

        ``unroll`` and ``unroll_full`` default to the same values as in
        ``for_execute`` (-1 and False).
        """
        log().info("start [%s] stop [%s] step [%s]", start, stop, step)
        return self._loop_execute_range_dynamic(
            func,
            start,
            stop,
            step,
            used_args,
            iter_args,
            iter_arg_names,
            unroll,
            unroll_full,
        )

    @staticmethod
    def for_constexpr(
        func: Callable,
        start: int,
        stop: int,
        step: int,
        used_args: list,
        iter_args: list,
    ):
        """Run the loop body in the Python interpreter over
        ``range(start, stop, step)``.

        ``func`` is called as ``func(i, *used_args, *loop_results)`` where
        ``loop_results`` starts as ``iter_args`` and is replaced by the
        (list-normalized) return value of each iteration. The final value is
        passed through ``converge_ret_val``.
        """
        log().info("start [%s] stop [%s] step [%s]", start, stop, step)
        loop_results = iter_args
        log().debug("iter_args [%s]", iter_args)
        for i in range(start, stop, step):
            log().debug("i [%s] iter_args [%s]", i, iter_args)
            loop_results = func(i, *used_args, *loop_results)
            log().debug("loop_results [%s]", loop_results)
            # Normalize so the next iteration can always unpack the results.
            if loop_results is None:
                loop_results = []
            if not isinstance(loop_results, list):
                loop_results = [loop_results]
        log().debug("done loop_results [%s]", loop_results)
        return Executor.converge_ret_val(loop_results)

    def for_execute(
        self,
        func,
        start,
        stop,
        step,
        used_args=None,
        iter_args=None,
        iter_arg_names=None,
        unroll=-1,
        unroll_full=False,
        is_range_constexpr=None,
    ):
        """Run a for loop either at compile-time or through the dynamic callback.

        ``is_range_constexpr`` selects the path:
          * None: inferred, constexpr only if none of start/stop/step is a
            dynamic expression.
          * truthy: constexpr is required; raises DSLRuntimeError if any
            bound is dynamic.
          * falsy (not None): always the dynamic path.

        Raises:
            AssertionError: if the required callbacks were not registered.
            DSLRuntimeError: if constexpr was requested with dynamic bounds.
        """
        assert (
            self._loop_execute_range_dynamic and self._is_dynamic_expression
        ), "Functions must be set before execution."
        used_args = Executor._list_or_empty(used_args)
        iter_args = Executor._list_or_empty(iter_args)
        iter_arg_names = Executor._list_or_empty(iter_arg_names)
        log().debug("start [%s] stop [%s] step [%s]", start, stop, step)
        any_dynamic_expression = (
            self._is_dynamic_expression(start)
            or self._is_dynamic_expression(stop)
            or self._is_dynamic_expression(step)
        )

        if is_range_constexpr is None:
            run_at_compile_time = not any_dynamic_expression
        elif is_range_constexpr:
            # Ensure bounds are compile-time constants for constexpr execution
            if any_dynamic_expression:
                raise DSLRuntimeError(
                    "Loop bounds must be constexpr (compile-time constants)"
                )
            run_at_compile_time = True
        else:
            run_at_compile_time = False

        if run_at_compile_time:
            return self.for_constexpr(func, start, stop, step, used_args, iter_args)
        # MLIR generation
        return self.for_dynamic(
            func,
            start,
            stop,
            step,
            used_args,
            iter_args,
            iter_arg_names,
            unroll,
            unroll_full,
        )

    def if_dynamic(
        self,
        pred,
        then_block: Callable,
        else_block: Optional[Callable] = None,
        used_args=None,
        yield_args=None,
        yield_arg_names=None,
    ):
        """Forward the if statement to the registered ``if_dynamic`` callback."""
        return self._if_dynamic(
            pred,
            then_block,
            else_block,
            Executor._list_or_empty(used_args),
            Executor._list_or_empty(yield_args),
            Executor._list_or_empty(yield_arg_names),
        )

    @staticmethod
    def if_constexpr(
        pred,
        then_block: Callable,
        else_block: Optional[Callable] = None,
        used_args=None,
        yield_args=None,
    ):
        """Evaluate the if statement in the Python interpreter.

        Calls ``then_block`` when ``pred`` is truthy, otherwise ``else_block``
        if one was given; the block's result goes through ``converge_ret_val``.
        Returns None when ``pred`` is falsy and there is no else block.
        """
        used_args = Executor._list_or_empty(used_args)
        yield_args = Executor._list_or_empty(yield_args)
        if pred:
            log().debug(" running then block [%s]", yield_args)
            res = then_block(*used_args, *yield_args)
            log().debug("result [%s]", res)
            return Executor.converge_ret_val(res)
        elif else_block is not None:
            log().debug("running else [%s]", yield_args)
            res = else_block(*used_args, *yield_args)
            log().debug("result [%s]", res)
            return Executor.converge_ret_val(res)
        return None

    def if_execute(
        self,
        pred,
        then_block: Callable,
        else_block: Optional[Callable] = None,
        used_args=None,
        yield_args=None,
        yield_arg_names=None,
        if_constexpr=None,
    ):
        """Run an if statement either at compile-time or through the dynamic
        callback.

        ``if_constexpr`` selects the path:
          * None: inferred, constexpr only if ``pred`` is not dynamic.
          * truthy: constexpr is required; raises DSLRuntimeError if ``pred``
            is dynamic.
          * falsy (not None): always the dynamic path.

        Raises:
            AssertionError: if the required callbacks were not registered.
            DSLRuntimeError: if constexpr was requested with a dynamic predicate.
        """
        assert (
            self._if_dynamic and self._is_dynamic_expression
        ), "Functions must be set before execution."
        used_args = Executor._list_or_empty(used_args)
        yield_args = Executor._list_or_empty(yield_args)
        yield_arg_names = Executor._list_or_empty(yield_arg_names)
        is_if_constexpr = not self._is_dynamic_expression(pred)

        if if_constexpr is None:
            run_at_compile_time = is_if_constexpr
        elif if_constexpr:
            # Ensure the predicate is a compile-time constant for constexpr execution
            if not is_if_constexpr:
                raise DSLRuntimeError(
                    "If predicate must be constexpr (compile-time constants)"
                )
            run_at_compile_time = True
        else:
            run_at_compile_time = False

        if run_at_compile_time:
            return self.if_constexpr(
                pred, then_block, else_block, used_args, yield_args
            )
        # MLIR generation
        return self.if_dynamic(
            pred, then_block, else_block, used_args, yield_args, yield_arg_names
        )

    def while_dynamic(
        self,
        while_before_block: Callable,
        while_after_block: Callable,
        used_args=None,
        yield_args=None,
        yield_arg_names=None,
    ):
        """Forward the while loop to the registered ``while_dynamic`` callback."""
        return self._while_dynamic(
            while_before_block,
            while_after_block,
            Executor._list_or_empty(used_args),
            Executor._list_or_empty(yield_args),
            Executor._list_or_empty(yield_arg_names),
        )

    @staticmethod
    def while_constexpr(
        while_before_block,
        while_after_block,
        used_args=None,
        yield_args=None,
    ):
        """Run a while loop in the Python interpreter.

        ``while_before_block`` must return a ``(cond, results)`` pair; while
        ``cond`` is truthy, ``while_after_block`` is called with the current
        results and ``while_before_block`` is re-evaluated with what it
        returns. The last results are passed through ``converge_ret_val``.
        """
        used_args = Executor._list_or_empty(used_args)
        yield_args = Executor._list_or_empty(yield_args)
        log().debug(
            "while_constexpr begin %s", while_before_block.__qualname__
        )
        cond, loop_results = while_before_block(*used_args, *yield_args)
        while cond:
            loop_results = Executor.convert_to_list(loop_results)
            log().debug(
                "calling while_after [%s], [%s]",
                used_args,
                loop_results,
            )
            loop_results = while_after_block(*used_args, *loop_results)
            log().debug(
                "while after [%s]", loop_results
            )
            loop_results = Executor.convert_to_list(loop_results)
            log().debug(
                "calling while_before [%s], [%s]",
                used_args,
                loop_results,
            )
            cond, loop_results = while_before_block(*used_args, *loop_results)
            log().debug(
                "while_before cond, results [%s], [%s]",
                cond,
                loop_results,
            )
        log().debug(
            "while_constexpr results %s", loop_results
        )
        return Executor.converge_ret_val(loop_results)

    def while_execute(
        self,
        pred,
        while_before_block: Callable,
        while_after_block: Callable,
        used_args=None,
        yield_args=None,
        yield_arg_names=None,
        while_constexpr=None,
    ):
        """Run a while loop either at compile-time or through the dynamic
        callback.

        Unlike ``for_execute`` / ``if_execute`` there is no inference here:
        the constexpr path is taken only when ``while_constexpr`` is truthy;
        None and other falsy values go to the dynamic path.

        Raises:
            AssertionError: if the required callbacks were not registered.
            DSLRuntimeError: if constexpr was requested with a dynamic predicate.
        """
        assert (
            self._while_dynamic and self._is_dynamic_expression
        ), "Functions must be set before execution."
        used_args = Executor._list_or_empty(used_args)
        yield_args = Executor._list_or_empty(yield_args)
        yield_arg_names = Executor._list_or_empty(yield_arg_names)
        is_while_constexpr = not self._is_dynamic_expression(pred)
        # Ensure the predicate is a compile-time constant for constexpr execution
        if while_constexpr:
            if not is_while_constexpr:
                raise DSLRuntimeError(
                    "While predicate must be constexpr (compile-time constants)"
                )
            return self.while_constexpr(
                while_before_block, while_after_block, used_args, yield_args
            )
        # MLIR generation
        return self.while_dynamic(
            while_before_block,
            while_after_block,
            used_args,
            yield_args,
            yield_arg_names,
        )
# =============================================================================
# Decorator
# =============================================================================
# Module-level Executor instance used by the selector/executor helpers and the
# constexpr checks in this module. Its callbacks are unset until
# set_functions() is called on it.
executor = Executor()
def loop_selector(
    start,
    stop,
    step,
    used_args=None,
    iter_args=None,
    iter_arg_names=None,
    unroll=-1,
    unroll_full=False,
    constexpr=None,
):
    """Build a decorator that runs the decorated loop body through
    ``executor.for_execute``.

    ``start``, ``stop`` and ``step`` that are ``Integer`` instances are
    replaced by their ``ir_value()`` before being handed to the executor.
    The list arguments default to fresh empty lists; all remaining
    arguments are forwarded unchanged.

    Returns:
        A function taking the loop body ``func`` and returning the result of
        ``executor.for_execute`` for it.
    """
    # None sentinels instead of shared mutable [] defaults.
    used_args = [] if used_args is None else used_args
    iter_args = [] if iter_args is None else iter_args
    iter_arg_names = [] if iter_arg_names is None else iter_arg_names

    log().info(
        "start [%s] stop [%s] step [%s] used_args [%s] iter_args [%s] unroll [%s] unroll_full [%s] constexpr [%s]",
        start,
        stop,
        step,
        used_args,
        iter_args,
        unroll,
        unroll_full,
        constexpr,
    )
    from .typing import Integer

    def _maybe_upcast(value):
        # Integer wrappers are unwrapped to their IR value; anything else
        # passes through untouched.
        if isinstance(value, Integer):
            value = value.ir_value()
        return value

    start = _maybe_upcast(start)
    stop = _maybe_upcast(stop)
    step = _maybe_upcast(step)

    def ir_loop(func):
        return executor.for_execute(
            func,
            start,
            stop,
            step,
            used_args,
            iter_args,
            iter_arg_names,
            unroll,
            unroll_full,
            constexpr,
        )

    return ir_loop
def if_selector(pred, used_args=None, yield_args=None):
    """Build a decorator that immediately calls the decorated function as
    ``func(pred, *used_args, *yield_args)`` and returns its result.

    A ``Numeric`` predicate is replaced by its ``.value`` first. The list
    arguments default to fresh empty lists.
    """
    # None sentinels instead of shared mutable [] defaults.
    used_args = [] if used_args is None else used_args
    yield_args = [] if yield_args is None else yield_args

    log().info("pred [%s] used_args [%s] yield_args [%s]", pred, used_args, yield_args)
    # Handle Numeric types here?
    from .typing import Numeric

    if isinstance(pred, Numeric):
        pred = pred.value

    def ir_loop(func):
        return func(pred, *used_args, *yield_args)

    return ir_loop
def while_selector(pred, used_args=None, yield_args=None):
    """Build a decorator that immediately calls the decorated function as
    ``func(pred, *used_args, *yield_args)`` and returns its result.

    The list arguments default to empty (None sentinels replace the shared
    mutable ``[]`` defaults).
    """
    used_args = [] if used_args is None else used_args
    yield_args = [] if yield_args is None else yield_args

    def ir_while_loop(func):
        return func(pred, *used_args, *yield_args)

    return ir_while_loop
def while_executor(
    pred,
    while_before_block: Callable,
    while_after_block: Callable,
    used_args=None,
    yield_args=None,
    yield_arg_names=None,
    constexpr=None,
):
    """Run a while loop through the module-level ``executor``.

    Thin wrapper over ``executor.while_execute``; ``constexpr`` is passed as
    its ``while_constexpr`` argument. The list arguments default to fresh
    empty lists so no list object is shared between calls.
    """
    return executor.while_execute(
        pred,
        while_before_block,
        while_after_block,
        [] if used_args is None else used_args,
        [] if yield_args is None else yield_args,
        [] if yield_arg_names is None else yield_arg_names,
        constexpr,
    )
def if_executor(
    pred,
    then_block: Callable,
    else_block: Optional[Callable] = None,
    used_args=None,
    yield_args=None,
    yield_arg_names=None,
    constexpr=None,
):
    """Run an if statement through the module-level ``executor``.

    Thin wrapper over ``executor.if_execute``; ``constexpr`` is passed as its
    ``if_constexpr`` argument. The list arguments default to fresh empty
    lists so no list object is shared between calls.
    """
    return executor.if_execute(
        pred,
        then_block,
        else_block,
        [] if used_args is None else used_args,
        [] if yield_args is None else yield_args,
        [] if yield_arg_names is None else yield_arg_names,
        constexpr,
    )
# =============================================================================
# Range
# =============================================================================
class range_dynamic:
    """Placeholder for a dynamic range.

    The overloads only describe the accepted call shapes. Actually
    constructing an instance always raises, because (per the error message)
    uses of this class should always be preprocessed to IR.
    """

    @overload
    def __new__(cls, stop, unroll=0, unroll_full=False):
        """Call shape: ``range_dynamic(stop, ...)``."""

    @overload
    def __new__(cls, start, stop, step, unroll=0, unroll_full=False):
        """Call shape: ``range_dynamic(start, stop, step, ...)``."""

    def __new__(cls, *args, **kwargs):
        # Reaching this point means the call was not rewritten beforehand.
        raise DSLRuntimeError("range_dynamic should be always preprocessed to IR")
class range_constexpr:
    """Range over compile-time constant bounds.

    Accepts ``(stop)``, ``(start, stop)`` or ``(start, stop, step)``; every
    argument must be a constexpr (non-dynamic) value. Iteration follows the
    same direction rules as the builtin ``range``: ascending for a positive
    step, descending for a negative one.

    Raises:
        DSLRuntimeError: on a wrong number of arguments, if any argument is
            a dynamic expression, or if ``step`` is zero.
    """

    def __init__(self, *args):
        if len(args) == 1:
            self.start = 0
            self.stop = args[0]
            self.step = 1
        elif len(args) == 2:
            self.start, self.stop = args
            self.step = 1
        elif len(args) == 3:
            self.start, self.stop, self.step = args
        else:
            raise DSLRuntimeError(
                "range_constexpr supports up to 3 arguments (start, stop, step)"
            )
        # Ensure the arguments are compile-time constants (if required)
        for arg_name, arg_value in [
            ("step", self.step),
            ("start", self.start),
            ("stop", self.stop),
        ]:
            if executor._is_dynamic_expression(arg_value):
                raise DSLRuntimeError(
                    f"`range_constexpr` requires `constexpr` (non-IR Values) for all arguments, "
                    f"but `{arg_name}` is not. If the arguments are dynamic, use `range`; the DSL "
                    f"will handle them during runtime. ",
                    suggestion="Use `range` instead of `range_constexpr`.",
                )
        # Checked only after the constexpr validation above, so the comparison
        # is never applied to a dynamic value. A zero step would otherwise make
        # iteration loop forever whenever start < stop.
        if self.step == 0:
            raise DSLRuntimeError("`range_constexpr` requires a non-zero `step`.")

    def __iter__(self) -> Iterator[int]:
        current = self.start
        if self.step > 0:
            while current < self.stop:
                yield current
                current += self.step
        elif self.step < 0:
            # Descending range: stop is an exclusive lower bound.
            while current > self.stop:
                yield current
                current += self.step
# =============================================================================
# If expressions
# =============================================================================
def const_expr(expression):
    """Return ``expression`` unchanged if it is a compile-time constant.

    Raises:
        DSLRuntimeError: if the module-level executor reports ``expression``
            as a dynamic expression.
    """
    if not executor._is_dynamic_expression(expression):
        return expression

    message = f"The function `const_expr({expression})` received a dynamic expression (non compile-time constant)."
    hints = {
        "const_expr": "Accepts only constexpr (compile-time constant)",
        "If your expression depends on dynamic values": "Avoid marking it as `const_expr()`",
        "If the expression could be either dynamic or constexpr": "Omit explicit `const_expr()` marker; the DSL will infer the correct handling automatically",
    }
    raise DSLRuntimeError(message, context=hints)
def dynamic_expr(expression):
    """Always raise: per the error message, calls to this function should
    have been preprocessed to IR rather than executed.

    Raises:
        DSLRuntimeError: unconditionally; ``expression`` is not inspected.
    """
    message = "dynamic_expr should be always preprocessed to IR"
    raise DSLRuntimeError(message)
# =============================================================================
# Assertion & casting
# =============================================================================
def assert_executor(test, msg=None):
    """Assert ``test`` at compile-time.

    A dynamic ``Numeric`` is first converted with ``test.to(bool)``; if that
    conversion fails, or if ``test`` is any other dynamic expression, a
    DSLRuntimeError is raised instead of asserting.

    Raises:
        AssertionError: with ``msg`` when ``test`` is falsy.
        DSLRuntimeError: when ``test`` is a dynamic expression that cannot
            be converted to bool.
    """
    from .typing import Numeric

    fail = False
    # Implicit convert dynamic expression to bool is not allowed
    # So here explicitly do a None check
    if test is not None and executor._is_dynamic_expression(test):
        if isinstance(test, Numeric):
            try:
                test = test.to(bool)
            except Exception:
                # Narrowed from a bare ``except`` so KeyboardInterrupt and
                # SystemExit are no longer swallowed.
                fail = True
        else:
            fail = True

    if fail:
        raise DSLRuntimeError(
            "Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.",
            suggestion = "Please replace with runtime assert."
        )
    assert test, msg
def bool_cast(value):
    """Convert a compile-time value to ``bool``.

    Raises:
        DSLRuntimeError: if the module-level executor reports ``value`` as a
            dynamic expression.
    """
    if not executor._is_dynamic_expression(value):
        return bool(value)

    raise DSLRuntimeError(
        "Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.",
        suggestion = "Please explicitly convert to boolean with expressions like comparision."
    )