585 lines
17 KiB
Python
585 lines
17 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
|
#
|
|
# Use of this software is governed by the terms and conditions of the
|
|
# NVIDIA End User License Agreement (EULA), available at:
|
|
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
|
#
|
|
# Any use, reproduction, disclosure, or distribution of this software
|
|
# and related documentation outside the scope permitted by the EULA
|
|
# is strictly prohibited.
|
|
|
|
"""
|
|
This module provides helper functions that are generated by the preprocessor.
|
|
The preprocessor reads through Python's AST and rewrites the input code.
|
|
"""
|
|
|
|
from typing import Callable, Iterator, Optional, overload
|
|
|
|
from .utils.logger import log
|
|
from .common import *
|
|
|
|
from ._mlir_helpers.arith import ArithValue
|
|
|
|
class Executor:
    """
    The Executor class handles dynamic and compile-time (constexpr) execution
    of "for" loops, "if-else-elif" statements and "while" loops.

    The dynamic paths delegate to callables registered through
    ``set_functions``; the constexpr paths run the supplied Python callables
    directly in the interpreter.

    Methods:
        set_functions: Assigns the functions for checking loop bounds and
            conditional evaluation.

        for_dynamic: Generates MLIR for OP
        for_constexpr: Executes a for loop at JIT compile-time
        for_execute: Decides whether to execute the loop at compile-time or
            generate MLIR for OP based on the provided bounds.

        if_dynamic: Generates MLIR if OP
        if_constexpr: Executes a if at JIT compile-time by python interpreter
        if_execute: Decides whether to execute the if statement at
            compile-time or generate MLIR if OP based on the predicate.

        while_dynamic: Delegates a while loop to the registered dynamic handler
        while_constexpr: Executes a while loop in the python interpreter
        while_execute: Runs the while loop at compile-time only when
            explicitly requested, otherwise delegates to ``while_dynamic``.
    """

    def __init__(self):
        # All hooks stay None until set_functions() is called; the *_execute
        # entry points assert that the ones they need have been registered.
        self._is_dynamic_expression = None
        self._loop_execute_range_dynamic = None
        self._if_dynamic = None
        self._while_dynamic = None

    def set_functions(
        self,
        is_dynamic_expression: Callable,
        loop_execute_range_dynamic: Callable,
        if_dynamic: Callable,
        while_dynamic: Callable,
    ):
        """Register the callables used to classify and run dynamic constructs.

        Args:
            is_dynamic_expression: Called with a single value; a truthy result
                marks that value as dynamic (not a compile-time constant).
            loop_execute_range_dynamic: Handler invoked by ``for_dynamic``.
            if_dynamic: Handler invoked by ``if_dynamic``.
            while_dynamic: Handler invoked by ``while_dynamic``.
        """
        self._is_dynamic_expression = is_dynamic_expression
        self._loop_execute_range_dynamic = loop_execute_range_dynamic
        self._if_dynamic = if_dynamic
        self._while_dynamic = while_dynamic

    @staticmethod
    def _list_or_empty(value):
        """Return a fresh empty list when value is None, else value unchanged.

        Used to replace mutable ``[]`` defaults so that no list object is
        shared between calls (the lists are forwarded to external callbacks).
        """
        return [] if value is None else value

    @staticmethod
    def convert_to_list(x):
        """This function is used to convert x to a list.
        If x is None, return an empty list.
        If x is not a list, return a list containing x.
        Otherwise, return x itself.
        """
        if x is None:
            return []
        if not isinstance(x, list):
            return [x]
        return x

    @staticmethod
    def converge_ret_val(res):
        """This function is used to converge res (the return value) of the function.
        If res is None, return None.
        If res is a list and has only one element, return the element.
        Otherwise, return res itself.
        """
        if res is None:
            return res
        elif isinstance(res, list) and len(res) == 1:
            return res[0]
        return res

    def for_dynamic(
        self,
        func: Callable,
        start,
        stop,
        step,
        used_args: list,
        iter_args: list,
        iter_arg_names: list,
        unroll: int = -1,
        unroll_full: bool = False,
    ):
        """Forward a loop to the registered dynamic range handler.

        The defaults for ``unroll`` and ``unroll_full`` mirror those of
        ``for_execute``. (They were previously the type objects ``bool`` and
        ``int``, which would have been forwarded as values when omitted.)

        Returns:
            Whatever the registered ``loop_execute_range_dynamic`` returns.
        """
        log().info("start [%s] stop [%s] step [%s]", start, stop, step)
        return self._loop_execute_range_dynamic(
            func,
            start,
            stop,
            step,
            used_args,
            iter_args,
            iter_arg_names,
            unroll,
            unroll_full,
        )

    @staticmethod
    def for_constexpr(
        func: Callable,
        start: int,
        stop: int,
        step: int,
        used_args: list,
        iter_args: list,
    ):
        """Run the loop body in the Python interpreter over ``range(start, stop, step)``.

        ``func`` is called as ``func(i, *used_args, *carried)`` where
        ``carried`` starts as ``iter_args`` and is replaced by the (list
        normalized) return value of each iteration.

        Returns:
            The final carried values, collapsed by ``converge_ret_val``.
        """
        log().info("start [%s] stop [%s] step [%s]", start, stop, step)
        loop_results = iter_args
        log().debug("iter_args [%s]", iter_args)
        for i in range(start, stop, step):
            # Log the values carried into this iteration, not the initial ones.
            log().debug("i [%s] iter_args [%s]", i, loop_results)
            loop_results = func(i, *used_args, *loop_results)
            log().debug("loop_results [%s]", loop_results)
            # Normalize so the next iteration can always unpack the results.
            loop_results = Executor.convert_to_list(loop_results)

        log().debug("done loop_results [%s]", loop_results)
        return Executor.converge_ret_val(loop_results)

    def for_execute(
        self,
        func,
        start,
        stop,
        step,
        used_args: Optional[list] = None,
        iter_args: Optional[list] = None,
        iter_arg_names: Optional[list] = None,
        unroll=-1,
        unroll_full=False,
        is_range_constexpr=None,
    ):
        """Run a for loop at compile-time or through the dynamic handler.

        Args:
            is_range_constexpr: ``None`` infers the mode from the bounds
                (constexpr only if none of start/stop/step is dynamic); a
                truthy value forces constexpr execution; any other value
                forces the dynamic path.

        Raises:
            AssertionError: if the required hooks were not registered.
            DSLRuntimeError: if constexpr execution is forced but a bound is
                dynamic.
        """
        assert (
            self._loop_execute_range_dynamic and self._is_dynamic_expression
        ), "Functions must be set before execution."
        used_args = self._list_or_empty(used_args)
        iter_args = self._list_or_empty(iter_args)
        iter_arg_names = self._list_or_empty(iter_arg_names)

        log().debug("start [%s] stop [%s] step [%s]", start, stop, step)
        any_dynamic_expression = (
            self._is_dynamic_expression(start)
            or self._is_dynamic_expression(stop)
            or self._is_dynamic_expression(step)
        )

        if is_range_constexpr is None:
            # No explicit request: infer from the bounds.
            run_constexpr = not any_dynamic_expression
        elif is_range_constexpr:
            # Ensure bounds are compile-time constants for constexpr execution
            if any_dynamic_expression:
                raise DSLRuntimeError(
                    "Loop bounds must be constexpr (compile-time constants)"
                )
            run_constexpr = True
        else:
            run_constexpr = False

        if run_constexpr:
            return self.for_constexpr(func, start, stop, step, used_args, iter_args)

        # MLIR generation
        return self.for_dynamic(
            func,
            start,
            stop,
            step,
            used_args,
            iter_args,
            iter_arg_names,
            unroll,
            unroll_full,
        )

    def if_dynamic(
        self,
        pred,
        then_block: Callable,
        else_block: Optional[Callable] = None,
        used_args: Optional[list] = None,
        yield_args: Optional[list] = None,
        yield_arg_names: Optional[list] = None,
    ):
        """Forward an if statement to the registered dynamic handler.

        Returns:
            Whatever the registered ``if_dynamic`` callable returns.
        """
        return self._if_dynamic(
            pred,
            then_block,
            else_block,
            self._list_or_empty(used_args),
            self._list_or_empty(yield_args),
            self._list_or_empty(yield_arg_names),
        )

    @staticmethod
    def if_constexpr(
        pred,
        then_block: Callable,
        else_block: Optional[Callable] = None,
        used_args: Optional[list] = None,
        yield_args: Optional[list] = None,
    ):
        """Evaluate ``pred`` in Python and run the matching block.

        Returns:
            The selected block's result collapsed by ``converge_ret_val``, or
            ``None`` when ``pred`` is falsy and there is no ``else_block``.
        """
        used_args = Executor._list_or_empty(used_args)
        yield_args = Executor._list_or_empty(yield_args)

        if pred:
            log().debug(" running then block [%s]", yield_args)
            res = then_block(*used_args, *yield_args)
            log().debug("result [%s]", res)
            return Executor.converge_ret_val(res)
        elif else_block is not None:
            log().debug("running else [%s]", yield_args)
            res = else_block(*used_args, *yield_args)
            log().debug("result [%s]", res)
            return Executor.converge_ret_val(res)
        return None

    def if_execute(
        self,
        pred,
        then_block: Callable,
        else_block: Optional[Callable] = None,
        used_args: Optional[list] = None,
        yield_args: Optional[list] = None,
        yield_arg_names: Optional[list] = None,
        if_constexpr=None,
    ):
        """Run an if statement at compile-time or through the dynamic handler.

        Args:
            if_constexpr: ``None`` infers the mode from ``pred`` (constexpr
                when the predicate is not dynamic); a truthy value forces
                constexpr execution; any other value forces the dynamic path.

        Raises:
            AssertionError: if the required hooks were not registered.
            DSLRuntimeError: if constexpr execution is forced but the
                predicate is dynamic.
        """
        assert (
            self._if_dynamic and self._is_dynamic_expression
        ), "Functions must be set before execution."
        used_args = self._list_or_empty(used_args)
        yield_args = self._list_or_empty(yield_args)
        yield_arg_names = self._list_or_empty(yield_arg_names)

        is_if_constexpr = not self._is_dynamic_expression(pred)
        if if_constexpr is None:
            # No explicit request: infer from the predicate.
            run_constexpr = is_if_constexpr
        elif if_constexpr:
            # Ensure the predicate is a compile-time constant for constexpr execution
            if not is_if_constexpr:
                raise DSLRuntimeError(
                    "If predicate must be constexpr (compile-time constants)"
                )
            run_constexpr = True
        else:
            run_constexpr = False

        if run_constexpr:
            return self.if_constexpr(
                pred, then_block, else_block, used_args, yield_args
            )

        # MLIR generation
        return self.if_dynamic(
            pred, then_block, else_block, used_args, yield_args, yield_arg_names
        )

    def while_dynamic(
        self,
        while_before_block: Callable,
        while_after_block: Callable,
        used_args: Optional[list] = None,
        yield_args: Optional[list] = None,
        yield_arg_names: Optional[list] = None,
    ):
        """Forward a while loop to the registered dynamic handler.

        Returns:
            Whatever the registered ``while_dynamic`` callable returns.
        """
        return self._while_dynamic(
            while_before_block,
            while_after_block,
            self._list_or_empty(used_args),
            self._list_or_empty(yield_args),
            self._list_or_empty(yield_arg_names),
        )

    @staticmethod
    def while_constexpr(
        while_before_block,
        while_after_block,
        used_args: Optional[list] = None,
        yield_args: Optional[list] = None,
    ):
        """Run a while loop in the Python interpreter.

        ``while_before_block`` must return a ``(cond, results)`` pair; while
        ``cond`` is truthy, ``while_after_block`` is called with the list
        normalized results and its output is fed back into
        ``while_before_block``.

        Returns:
            The results of the last ``while_before_block`` call, collapsed by
            ``converge_ret_val``.
        """
        used_args = Executor._list_or_empty(used_args)
        yield_args = Executor._list_or_empty(yield_args)

        log().debug("while_constexpr begin %s", while_before_block.__qualname__)
        cond, loop_results = while_before_block(*used_args, *yield_args)
        while cond:
            loop_results = Executor.convert_to_list(loop_results)
            log().debug(
                "calling while_after [%s], [%s]",
                used_args,
                loop_results,
            )
            loop_results = while_after_block(*used_args, *loop_results)
            log().debug("while after [%s]", loop_results)
            loop_results = Executor.convert_to_list(loop_results)
            log().debug(
                "calling while_before [%s], [%s]",
                used_args,
                loop_results,
            )
            cond, loop_results = while_before_block(*used_args, *loop_results)
            log().debug(
                "while_before cond, results [%s], [%s]",
                cond,
                loop_results,
            )

        log().debug("while_constexpr results %s", loop_results)
        return Executor.converge_ret_val(loop_results)

    def while_execute(
        self,
        pred,
        while_before_block: Callable,
        while_after_block: Callable,
        used_args: Optional[list] = None,
        yield_args: Optional[list] = None,
        yield_arg_names: Optional[list] = None,
        while_constexpr=None,
    ):
        """Run a while loop at compile-time or through the dynamic handler.

        Unlike ``for_execute`` and ``if_execute`` there is no inference when
        ``while_constexpr`` is ``None``: constexpr execution happens only when
        the flag is truthy, otherwise the loop goes to ``while_dynamic``.
        NOTE(review): presumably intentional - confirm with the preprocessor.

        Raises:
            AssertionError: if the required hooks were not registered.
            DSLRuntimeError: if constexpr execution is requested but the
                predicate is dynamic.
        """
        assert (
            self._while_dynamic and self._is_dynamic_expression
        ), "Functions must be set before execution."
        used_args = self._list_or_empty(used_args)
        yield_args = self._list_or_empty(yield_args)
        yield_arg_names = self._list_or_empty(yield_arg_names)

        is_while_constexpr = not self._is_dynamic_expression(pred)

        # Ensure the predicate is a compile-time constant for constexpr execution
        if while_constexpr:
            if not is_while_constexpr:
                raise DSLRuntimeError(
                    "While predicate must be constexpr (compile-time constants)"
                )
            return self.while_constexpr(
                while_before_block, while_after_block, used_args, yield_args
            )

        # MLIR generation
        return self.while_dynamic(
            while_before_block,
            while_after_block,
            used_args,
            yield_args,
            yield_arg_names,
        )
|
|
|
|
|
|
# =============================================================================
|
|
# Decorator
|
|
# =============================================================================
|
|
|
|
# Module-level Executor instance shared by the selector/executor helpers,
# `range_constexpr`, `const_expr`, `assert_executor` and `bool_cast` below.
executor = Executor()
|
|
|
|
|
|
def loop_selector(
    start,
    stop,
    step,
    used_args=None,
    iter_args=None,
    iter_arg_names=None,
    unroll=-1,
    unroll_full=False,
    constexpr=None,
):
    """Build a decorator that runs a loop body through ``executor.for_execute``.

    ``start``, ``stop`` and ``step`` that are ``Integer`` instances are
    replaced by their ``ir_value()`` before being forwarded; other values are
    passed through untouched.

    Args:
        used_args, iter_args, iter_arg_names: Forwarded to ``for_execute``.
            ``None`` (the default) stands for a fresh empty list so that no
            list object is shared between calls.
        unroll, unroll_full: Forwarded unchanged.
        constexpr: Forwarded as ``is_range_constexpr``.

    Returns:
        A one-argument callable; calling it with the loop body function
        executes the loop and returns the result of ``for_execute``.
    """
    used_args = [] if used_args is None else used_args
    iter_args = [] if iter_args is None else iter_args
    iter_arg_names = [] if iter_arg_names is None else iter_arg_names

    log().info(
        "start [%s] stop [%s] step [%s] used_args [%s] iter_args [%s] unroll [%s] unroll_full [%s] constexpr [%s]",
        start,
        stop,
        step,
        used_args,
        iter_args,
        unroll,
        unroll_full,
        constexpr,
    )
    # Imported at call time, as in the original code.
    # NOTE(review): presumably to avoid a circular import - confirm.
    from .typing import Integer

    def _maybe_upcast(value):
        # Integer wrappers are replaced by their IR value; anything else is kept.
        if isinstance(value, Integer):
            value = value.ir_value()

        return value

    start = _maybe_upcast(start)
    stop = _maybe_upcast(stop)
    step = _maybe_upcast(step)

    def ir_loop(func):
        return executor.for_execute(
            func,
            start,
            stop,
            step,
            used_args,
            iter_args,
            iter_arg_names,
            unroll,
            unroll_full,
            constexpr,
        )

    return ir_loop
|
|
|
|
|
|
def if_selector(pred, used_args=[], yield_args=[]):
    """Build a decorator that calls a function with the (unwrapped) predicate.

    A ``Numeric`` predicate is replaced by its ``.value``; any other predicate
    is used as given. The returned callable invokes its argument as
    ``func(pred, *used_args, *yield_args)`` and returns that result.
    """
    log().info("pred [%s] used_args [%s] yield_args [%s]", pred, used_args, yield_args)

    from .typing import Numeric

    # Numeric wrappers are unwrapped to the value they hold.
    predicate = pred.value if isinstance(pred, Numeric) else pred

    def ir_if(func):
        call_args = (predicate, *used_args, *yield_args)
        return func(*call_args)

    return ir_if
|
|
|
|
|
|
def while_selector(pred, used_args=[], yield_args=[]):
    """Build a decorator that calls a function with the predicate and arguments.

    The returned callable invokes its argument as
    ``func(pred, *used_args, *yield_args)`` and returns that result.
    """

    def ir_while_loop(func):
        # Arguments are assembled at call time, exactly when func is invoked.
        call_args = (pred, *used_args, *yield_args)
        return func(*call_args)

    return ir_while_loop
|
|
|
|
|
|
def while_executor(
    pred,
    while_before_block: Callable,
    while_after_block: Callable,
    used_args=None,
    yield_args=None,
    yield_arg_names=None,
    constexpr=None,
):
    """Run a while loop through the module-level ``executor``.

    Thin wrapper around ``executor.while_execute``; ``constexpr`` is forwarded
    as its ``while_constexpr`` argument.

    Args:
        used_args, yield_args, yield_arg_names: ``None`` (the default) stands
            for a fresh empty list so that no list object is shared between
            calls.

    Returns:
        Whatever ``executor.while_execute`` returns.
    """
    return executor.while_execute(
        pred,
        while_before_block,
        while_after_block,
        [] if used_args is None else used_args,
        [] if yield_args is None else yield_args,
        [] if yield_arg_names is None else yield_arg_names,
        constexpr,
    )
|
|
|
|
|
|
def if_executor(
    pred,
    then_block: Callable,
    else_block: Optional[Callable] = None,
    used_args=None,
    yield_args=None,
    yield_arg_names=None,
    constexpr=None,
):
    """Run an if statement through the module-level ``executor``.

    Thin wrapper around ``executor.if_execute``; ``constexpr`` is forwarded as
    its ``if_constexpr`` argument.

    Args:
        used_args, yield_args, yield_arg_names: ``None`` (the default) stands
            for a fresh empty list so that no list object is shared between
            calls.

    Returns:
        Whatever ``executor.if_execute`` returns.
    """
    return executor.if_execute(
        pred,
        then_block,
        else_block,
        [] if used_args is None else used_args,
        [] if yield_args is None else yield_args,
        [] if yield_arg_names is None else yield_arg_names,
        constexpr,
    )
|
|
|
|
|
|
# =============================================================================
|
|
# Range
|
|
# =============================================================================
|
|
|
|
|
|
class range_dynamic:
    """Marker for dynamic ranges; it can never be instantiated.

    The overloads only describe the accepted call shapes. Constructing the
    class always raises ``DSLRuntimeError``, whose message states that
    ``range_dynamic`` should always be preprocessed to IR.
    """

    @overload
    def __new__(cls, stop, unroll=0, unroll_full=False): ...

    @overload
    def __new__(cls, start, stop, step, unroll=0, unroll_full=False): ...

    def __new__(cls, *args, **kwargs):
        message = "range_dynamic should be always preprocessed to IR"
        raise DSLRuntimeError(message)
|
|
|
|
|
|
class range_constexpr:
    """Compile-time range over ``start``, ``stop``, ``step``.

    Accepts ``(stop)``, ``(start, stop)`` or ``(start, stop, step)``, with
    ``start`` defaulting to 0 and ``step`` to 1. Every argument must be a
    constexpr, i.e. ``executor._is_dynamic_expression`` must be falsy for it.

    Iteration follows the built-in ``range`` convention: a positive step
    counts up while ``current < stop``, a negative step counts down while
    ``current > stop``.

    Raises:
        DSLRuntimeError: on a wrong number of arguments, a dynamic argument,
            or a zero step.
    """

    def __init__(self, *args):
        if len(args) == 1:
            self.start = 0
            self.stop = args[0]
            self.step = 1
        elif len(args) == 2:
            self.start, self.stop = args
            self.step = 1
        elif len(args) == 3:
            self.start, self.stop, self.step = args
        else:
            raise DSLRuntimeError(
                "range_constexpr supports up to 3 arguments (start, stop, step)"
            )
        # Ensure the arguments are compile-time constants (if required)
        for arg_name, arg_value in [
            ("step", self.step),
            ("start", self.start),
            ("stop", self.stop),
        ]:
            if executor._is_dynamic_expression(arg_value):
                raise DSLRuntimeError(
                    f"`range_constexpr` requires `constexpr` (non-IR Values) for all arguments, "
                    f"but `{arg_name}` is not. If the arguments are dynamic, use `range`; the DSL "
                    f"will handle them during runtime. ",
                    suggestion="Use `range` instead of `range_constexpr`.",
                )
        # Checked only after the dynamic check so that no comparison is ever
        # performed on a dynamic value. A zero step would never terminate.
        if self.step == 0:
            raise DSLRuntimeError("`range_constexpr` requires a non-zero `step`.")

    def __iter__(self) -> Iterator[int]:
        current = self.start
        if self.step > 0:
            while current < self.stop:
                yield current
                current += self.step
        elif self.step < 0:
            # Descending range: previously this case either yielded nothing
            # or looped forever because only `current < stop` was tested.
            while current > self.stop:
                yield current
                current += self.step
|
|
|
|
|
|
# =============================================================================
|
|
# If expressions
|
|
# =============================================================================
|
|
|
|
|
|
def const_expr(expression):
    """Return ``expression`` unchanged after checking that it is a constexpr.

    Raises:
        DSLRuntimeError: if ``executor._is_dynamic_expression`` reports the
            expression as dynamic.
    """
    if not executor._is_dynamic_expression(expression):
        return expression

    message = f"The function `const_expr({expression})` received a dynamic expression (non compile-time constant)."
    guidance = {
        "const_expr": "Accepts only constexpr (compile-time constant)",
        "If your expression depends on dynamic values": "Avoid marking it as `const_expr()`",
        "If the expression could be either dynamic or constexpr": "Omit explicit `const_expr()` marker; the DSL will infer the correct handling automatically",
    }
    raise DSLRuntimeError(message, context=guidance)
|
|
|
|
|
|
def dynamic_expr(expression):
    """Marker for dynamic expressions; calling it directly always fails.

    Raises:
        DSLRuntimeError: unconditionally, with a message stating that
            ``dynamic_expr`` should always be preprocessed to IR.
    """
    message = "dynamic_expr should be always preprocessed to IR"
    raise DSLRuntimeError(message)
|
|
|
|
|
|
# =============================================================================
|
|
# Assertion & casting
|
|
# =============================================================================
|
|
|
|
|
|
def assert_executor(test, msg=None):
    """Assert ``test`` at compile time, rejecting dynamic expressions.

    A dynamic ``Numeric`` is first converted with ``test.to(bool)``; if that
    conversion fails, or if the dynamic value is not a ``Numeric``, a
    ``DSLRuntimeError`` is raised instead of asserting.

    Args:
        test: The condition to assert.
        msg: Optional assertion message.

    Raises:
        AssertionError: if the (possibly converted) ``test`` is falsy.
        DSLRuntimeError: if ``test`` is dynamic and cannot be converted.
    """
    from .typing import Numeric

    fail = False
    cause = None
    # Implicit convert dynamic expression to bool is not allowed
    # So here explicitly do a None check
    if test is not None and executor._is_dynamic_expression(test):
        if isinstance(test, Numeric):
            try:
                test = test.to(bool)
            except Exception as err:
                # Catch only ordinary errors: a bare `except:` would also
                # swallow KeyboardInterrupt and SystemExit.
                cause = err
                fail = True
        else:
            fail = True

    if not fail:
        assert test, msg
    else:
        # Keep the conversion error (if any) attached as the cause.
        raise DSLRuntimeError(
            "Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.",
            suggestion="Please replace with runtime assert.",
        ) from cause
|
|
|
|
|
|
def bool_cast(value):
    """Convert a constexpr value to ``bool``.

    Raises:
        DSLRuntimeError: if ``executor._is_dynamic_expression`` reports the
            value as dynamic; such values must be turned into a boolean
            explicitly by the caller.
    """
    if executor._is_dynamic_expression(value):
        raise DSLRuntimeError(
            "Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.",
            # Spelling fixed: the message previously read "comparision".
            suggestion="Please explicitly convert to boolean with expressions like comparison.",
        )
    return bool(value)
|