v4.2 tag release. (#2638)

This commit is contained in:
Junkai-Wu
2025-09-16 00:21:53 +08:00
committed by GitHub
parent 56f0718a97
commit 6a35b4d22f
161 changed files with 14056 additions and 3793 deletions

View File

@ -17,6 +17,7 @@ from ..base_dsl.ast_helpers import (
if_executor,
while_selector,
while_executor,
range,
range_constexpr,
range_dynamic,
const_expr,
@ -28,6 +29,8 @@ from ..base_dsl.ast_helpers import (
all_executor,
range_value_check,
range_perf_warning,
cf_symbol_check,
redirect_builtin_function,
)
from ..base_dsl import *
@ -38,5 +41,4 @@ from ..base_dsl._mlir_helpers.op import dsl_user_op
from ..base_dsl.runtime import *
from ..base_dsl.runtime import cuda as cuda_helpers
from ..base_dsl.compiler import compile
from ..base_dsl.runtime.dlpack_runtime import *
from ..base_dsl.runtime.jit_arg_adapters import *

View File

@ -15,12 +15,14 @@ regarding to that dialect.
"""
# Local module imports
from typing import Callable, Union, Type, List, Union, Sequence, ForwardRef
from inspect import isclass
from itertools import chain
from types import GenericAlias, SimpleNamespace, UnionType
from typing import Callable, Union, Type, List, Union, Sequence, ForwardRef, Any
import functools
import pkgutil
from dataclasses import is_dataclass
from dataclasses import is_dataclass, fields
from collections.abc import Sequence
import builtins
from ..base_dsl import *
from ..base_dsl import compiler
@ -51,20 +53,15 @@ from ..base_dsl.ast_helpers import (
while_selector,
while_executor,
assert_executor,
const_expr,
dynamic_expr,
bool_cast,
compare_executor,
any_executor,
all_executor,
range_value_check,
range_perf_warning,
)
from ..base_dsl.runtime.dlpack_runtime import (
get_cute_tensor_c_pointer,
get_tensor_desc_shape_all,
get_tensor_desc_stride_all,
get_tensor_desc_element_type,
get_tensor_desc_is_in_device,
get_tensor_desc_assumed_align,
cf_symbol_check,
)
from .cutlass_ast_decorators import (
@ -73,6 +70,16 @@ from .cutlass_ast_decorators import (
_while_execute_dynamic,
)
from .tree_utils import (
is_constexpr_field,
tree_flatten,
tree_unflatten,
PyTreeDef,
is_frozen_dataclass,
DSLTreeFlattenError,
)
from ..base_dsl.runtime.jit_arg_adapters import JitArgAdapterRegistry
# =============================================================================
# Cutlass DSL Base Abstract Class
@ -125,6 +132,46 @@ def is_cute_algebra_type(arg_spec):
return False
def _get_c_pointers_cutlass(obj):
"""
This is an extended version of `get_c_pointers` that supports dataclasses, SimpleNamespace, and dict.
"""
if hasattr(obj, "__c_pointers__"):
return obj.__c_pointers__()
elif isinstance(obj, (tuple, list)):
return list(chain.from_iterable(_get_c_pointers_cutlass(x) for x in obj))
elif isinstance(obj, SimpleNamespace):
return list(
chain.from_iterable(
_get_c_pointers_cutlass(x) for x in obj.__dict__.values()
)
)
elif isinstance(obj, dict):
return list(
chain.from_iterable(_get_c_pointers_cutlass(x) for x in obj.values())
)
elif is_dataclass(obj):
return list(
chain.from_iterable(
_get_c_pointers_cutlass(getattr(obj, f.name))
for f in fields(obj)
if not is_constexpr_field(f)
)
)
elif isinstance(obj, set):
raise DSLRuntimeError(
"Sets are not supported in get_c_pointers to ensure order preservation",
context="The DSL attempted to generate JIT function argument(s) for an argument of type set but failed.",
suggestion="Consider using a list or tuple instead",
)
else:
# Try get adapter
adapter = JitArgAdapterRegistry.get_registered_adapter(type(obj))
if adapter is not None:
return _get_c_pointers_cutlass(adapter(obj))
return []
class CutlassBaseDSL(BaseDSL):
"""This abstract class provides a DSL for Cutlass."""
@ -137,16 +184,25 @@ class CutlassBaseDSL(BaseDSL):
preprocess: bool = False,
):
super().__init__(
name,
compiler_provider,
pass_sm_arch_name,
device_compilation_only,
preprocess,
name=name,
dsl_package_name=["cutlass"],
compiler_provider=compiler_provider,
pass_sm_arch_name=pass_sm_arch_name,
device_compilation_only=device_compilation_only,
preprocess=preprocess,
)
self._smem_usage_tracker: tuple = None
# this method is not useful for cutlass_dsl, so we only provide a dummy implementation.
def _is_tensor_descriptor(self, maybe_tensor_descriptor) -> bool:
    """Dummy check: cutlass_dsl has no tensor-descriptor concept, so nothing ever matches."""
    return False
# this method is not useful for cutlass_dsl, so we only provide a dummy implementation.
def _handle_tensor_descriptor(
    self, maybe_tensor, arg_name: str, need_gpu_memory: bool
) -> Any:
    """Dummy handler: always returns False since cutlass_dsl never treats an argument as a tensor descriptor."""
    return False
def _build_gpu_module(self, attrs):
self.gpu_module = gpu.GPUModuleOp(ir.StringAttr.get("kernels"))
with ir.InsertionPoint(self.gpu_module.bodyRegion.blocks.append(*[])):
@ -229,8 +285,43 @@ class CutlassBaseDSL(BaseDSL):
return version_hash
@staticmethod
def track_smem_allocator(allocator, callback):
    """
    Tracks shared memory usage for kernel functions.

    Find and set allocator to its parent dsl object: walk up the Python call
    stack looking for a frame whose local ``self`` is a CutlassBaseDSL
    instance, and register ``(allocator, callback)`` on it via
    ``_set_smem_tracking``. Warns if no owning DSL is found.

    :param allocator: the shared-memory allocator object to track
    :param callback: callable invoked later as ``callback(allocator)`` to read usage
    """
    frame = inspect.currentframe()
    # currentframe() may return None on non-CPython implementations; guard
    # before dereferencing f_back.
    frame = frame.f_back if frame is not None else None
    while frame:
        obj = frame.f_locals.get("self", None)
        # Fix: the previous `obj and isinstance(...)` truthiness test could
        # skip a valid DSL instance whose __bool__/__len__ evaluates falsy;
        # isinstance alone (which also rejects None) is the correct check.
        if isinstance(obj, CutlassBaseDSL):
            obj._set_smem_tracking(allocator, callback)
            return
        frame = frame.f_back
    warnings.warn("Cannot find parent dsl for allocator!", UserWarning)
def _set_smem_tracking(self, allocator, callback):
# Registers an allocator and callback for current dsl
self._smem_usage_tracker = (allocator, callback)
def _reset_smem_tracking(self):
# Clear an allocator and callback for current dsl
self._smem_usage_tracker = None
def _get_smem_usage(self) -> int:
# Treat final allocated bytes of allocator as smem usage
if not self._smem_usage_tracker:
return 0
allocator, callback = self._smem_usage_tracker
return callback(allocator)
def _kernel_helper(self, funcBody, *args, **kwargs):
class _CutlassIrKernelGenHelper(BaseDSL._KernelGenHelper):
def __init__(self, dsl: CutlassBaseDSL):
super().__init__()
self.dsl = dsl
self.dsl._reset_smem_tracking()
def generate_func_op(self, arg_types, arg_attrs, kernel_name, loc=None):
super().generate_func_op(arg_types, arg_attrs, kernel_name)
self.func_op = func.FuncOp(
@ -272,6 +363,17 @@ class CutlassBaseDSL(BaseDSL):
if cfg.has_cluster:
cfg.cluster = [to_index(size) for size in cfg.cluster]
smem_usage = self.dsl._get_smem_usage()
if any(not isinstance(x, int) for x in [cfg.smem, smem_usage]):
pass # cannot compare dynamic value inside kernel to launch op in py
elif cfg.auto_smem:
cfg.smem = smem_usage
elif smem_usage > cfg.smem:
warnings.warn(
f"Potential error: specified kernel launch smem bytes "
f"({cfg.smem}) is smaller than kernel usage ({smem_usage})!",
UserWarning,
)
cfg.smem = const(cfg.smem)
if not isinstance(cfg.async_deps, (list, tuple)):
@ -295,12 +397,13 @@ class CutlassBaseDSL(BaseDSL):
return token if is_async else None
return KernelLauncher(
self, _CutlassIrKernelGenHelper, funcBody, *args, **kwargs
self,
lambda: _CutlassIrKernelGenHelper(self),
funcBody,
*args,
**kwargs,
)
def _get_module_globals(self):
return globals()
def _preprocess_launch_config_args(self, args, kwargs):
"""Helper to preprocess args and kwargs for LaunchConfig"""
if "stream" in kwargs:
@ -316,7 +419,10 @@ class CutlassBaseDSL(BaseDSL):
Validates if the arg is really of the annotated type.
"""
if is_arg_spec_constexpr(arg_annotation, arg_name, arg_index, None):
if (
is_arg_spec_constexpr(arg_annotation, arg_name, arg_index, None)
or arg_annotation is Any
):
pass
else:
origin = get_origin(arg_annotation)
@ -329,11 +435,12 @@ class CutlassBaseDSL(BaseDSL):
f"expects argument #{arg_index+1} ({arg_name}) to be Type[{expected_base}], but got {arg}"
)
# Handle Union types and generic types
elif origin is Union:
elif origin is Union or isinstance(arg_annotation, UnionType):
# For Union types, check if arg matches any of the allowed types
allowed_types = get_args(arg_annotation)
if not any(
(isinstance(ty, type) and isinstance(arg, ty))
(ty is Any)
or (isinstance(ty, type) and isinstance(arg, ty))
or (get_origin(ty) is tuple and isinstance(arg, tuple))
for ty in allowed_types
):
@ -381,6 +488,26 @@ class CutlassBaseDSL(BaseDSL):
jit_exec_arg.extend(get_c_pointers(arg) if is_host else dyn_vals)
else:
jit_exec_arg = jit_arg_type = jit_arg_attr = None
elif not hasattr(arg, "__extract_mlir_values__") and not hasattr(
arg, "__new_from_mlir_values__"
):
# Try tree_flatten
try:
dyn_vals, _ = tree_flatten(arg)
except DSLTreeFlattenError:
# If fails, just return the original arg
return jit_exec_arg, jit_arg_type, jit_arg_attr
if dyn_vals:
jit_arg_type.extend([v.type for v in dyn_vals])
jit_arg_attr.extend([default_attr] * len(dyn_vals))
jit_exec_arg.extend(
_get_c_pointers_cutlass(arg) if is_host else dyn_vals
)
else:
# If tree flatten yields empty list, treat it as a constexpr thing
# Like a dataclass with all fields are constexpr, or an empty tuple or list
jit_exec_arg = jit_arg_type = jit_arg_attr = None
return jit_exec_arg, jit_arg_type, jit_arg_attr
def _generate_execution_arguments_for_known_types(
@ -396,6 +523,17 @@ class CutlassBaseDSL(BaseDSL):
blk_args = fop_args[iv_block_args : iv_block_args + n_args]
ir_arg.append(new_from_mlir_values(arg, blk_args))
iv_block_args += n_args
elif not hasattr(arg, "__extract_mlir_values__") and not hasattr(
arg, "__new_from_mlir_values__"
):
# Try tree_unflatten
try:
dyn_vals, tree_def = tree_flatten(arg)
block_args = fop_args[iv_block_args : iv_block_args + len(dyn_vals)]
ir_arg.append(tree_unflatten(tree_def, block_args))
iv_block_args += len(dyn_vals)
except DSLTreeFlattenError:
return ir_arg, iv_block_args
return ir_arg, iv_block_args
@ -458,10 +596,7 @@ class KernelLauncher:
def _check_func_args(self, funcBody, *func_args, **func_kwargs):
# Get function signature
if isinstance(funcBody, DSLCallable):
sig = funcBody.get_signature()
else:
sig = inspect.signature(funcBody)
sig = inspect.signature(funcBody)
# func_args and func_kwargs should match funcBody's signature,
# no extra or missing arguments.
@ -473,6 +608,12 @@ class KernelLauncher:
cause=e,
)
def smem_usage(self) -> int:
    """
    Check smem usage for this kernel, only available after `launch`
    """
    # Delegate to the owning DSL, which holds the allocator/callback pair.
    usage = self.dsl._get_smem_usage()
    return usage
def launch(self, *args, **kwargs):
self.dsl.frame = inspect.currentframe().f_back
self.dsl._preprocess_launch_config_args(args, kwargs)
@ -497,134 +638,151 @@ class KernelLauncher:
# =============================================================================
# Utils
# =============================================================================
def is_frozen_dataclass(obj_or_cls) -> bool:
def _filter_readonly_frozen_dataclass(
iter_args: List[Any], items_to_filter: List[Any], full_write_args_count: int
) -> List[Any]:
"""
Return True if obj_or_cls is a dataclass (class or instance) declared with frozen=True,
otherwise False.
"""
if not isinstance(obj_or_cls, type):
# If it's an instance, get its class
obj_or_cls = obj_or_cls.__class__
Filter items based on whether corresponding iter_args are frozen dataclasses.
# Must be a dataclass, and __dataclass_params__.frozen must be True
return (
is_dataclass(obj_or_cls)
and getattr(obj_or_cls, "__dataclass_params__", None) is not None
and obj_or_cls.__dataclass_params__.frozen
This function filters items (which can be values or names) based on the same
logic: keep items if they correspond to full-write arguments (index < full_write_args_count)
or if the corresponding iter_arg is not a frozen dataclass.
Args:
iter_args: List of arguments to check for frozen dataclass status
items_to_filter: List of items to filter (values or names)
full_write_args_count: Number of arguments that are always written (not read-only)
Returns:
Filtered list of items
Examples:
# Filter values (original remove_read_only_frozen_dataclass behavior)
filtered_values = _filter_readonly_frozen_dataclass(iter_args, iter_args, full_write_args_count)
# Filter names (original filter_readonly_frozen_dataclass_names behavior)
filtered_names = _filter_readonly_frozen_dataclass(iter_args, iter_args_names, full_write_args_count)
"""
return [
item
for i, item in enumerate(items_to_filter)
if i < full_write_args_count or not is_frozen_dataclass(iter_args[i])
]
def remove_read_only_frozen_dataclass(
    iter_args: List[Any], full_write_args_count: int
) -> List[Any]:
    """Filter out frozen dataclass arguments that are not full-write arguments.

    The first `full_write_args_count` entries are always kept; the remaining
    entries are kept only when they are not frozen (read-only) dataclasses.
    """
    return [
        arg
        for position, arg in enumerate(iter_args)
        if position < full_write_args_count or not is_frozen_dataclass(arg)
    ]
def filter_readonly_frozen_dataclass_names(
    iter_args: List[Any], iter_args_names: List[str], full_write_args_count: int
) -> List[str]:
    """Filter names based on whether corresponding iter_args are frozen dataclasses.

    A name is kept when its index is below `full_write_args_count` or when the
    matching entry of `iter_args` is not a frozen (read-only) dataclass.
    """
    kept = []
    for position, name in enumerate(iter_args_names):
        if position < full_write_args_count or not is_frozen_dataclass(
            iter_args[position]
        ):
            kept.append(name)
    return kept
def insert_read_only_frozen_dataclass(
    iter_args: List[Any], original_iter_args: List[Any], full_write_args_count: int
) -> List[Any]:
    """
    Insert read-only frozen dataclass arguments back into the iteration arguments.

    The first `full_write_args_count` positions always come from `iter_args`.
    For each remaining position, a frozen dataclass found in
    `original_iter_args` is preserved as-is (it was filtered out earlier and
    never threaded through the region), while every other position consumes
    the next value from `iter_args`.

    Args:
        iter_args: New iteration arguments to use for non-frozen dataclass instances
        original_iter_args: Original iteration arguments to preserve frozen dataclass instances from
        full_write_args_count: Number of arguments that are always written (not read-only)

    Returns:
        List of arguments with frozen dataclass instances preserved from original
    """
    # Leading full-write arguments are taken from the new values unchanged.
    full_write_args = (
        iter_args[:full_write_args_count] if full_write_args_count > 0 else []
    )
    # Remaining positions: frozen dataclasses come from the originals, all
    # other slots consume new values in order.
    replacements = iter(iter_args[full_write_args_count:])
    merged = []
    for original_arg in original_iter_args[full_write_args_count:]:
        if is_frozen_dataclass(original_arg):
            merged.append(original_arg)
        else:
            merged.append(next(replacements))
    return full_write_args + merged
def unpack_to_irvalue(
    mixed_values: List[Any], body_name: str, full_write_args_count: int
) -> Tuple[List[ir.Value], PyTreeDef]:
    """
    Flatten `mixed_values` into a flat list of IR values plus the pytree
    definition needed to rebuild the original structure.

    Read-only frozen dataclasses past the first `full_write_args_count`
    entries are filtered out before flattening, since they are never threaded
    through the generated region.

    :param mixed_values: values captured by the dynamic control-flow body
    :param body_name: name of the DSL statement (e.g. 'for'), used in error messages
    :param full_write_args_count: count of leading arguments that are always written
    :raises DSLRuntimeError: when a value cannot be flattened into dynamic expressions
    """
    log().debug("===--- Values UNPack")
    for idx, packed in enumerate(mixed_values):
        log().debug("[%d]: will-unpacked: [type:%s] %s", idx, type(packed), packed)
    try:
        # Drop read-only frozen dataclasses, then flatten the rest to leaves.
        unpacked_values, treedef = tree_flatten(
            remove_read_only_frozen_dataclass(mixed_values, full_write_args_count)
        )
    except DSLTreeFlattenError as e:
        # Surface the flatten failure as a user-facing DSL error with guidance.
        raise DSLRuntimeError(
            f"The '{body_name}' statement encountered a user-defined Python object, which cannot be automatically converted into an dynamic expression.",
            context={
                e.message: (
                    f"All expressions within '{body_name}' must be dynamic expressions, "
                    "mixing Python objects and dynamic expressions is not supported. "
                    "The DSL failed to convert the Python object into dynamic expressions."
                )
            },
            suggestion=(
                f"Please ensure '{e.type_str}' implements the '{DynamicExpression.__name__}' or mark with `dataclass`, "
                f"so it can be treated as a valid dynamic expression or mark '{body_name}' as a constant expression if conditions are Python objects."
            ),
        )
    log().debug("------------------ ")
    for idx, unpacked in enumerate(unpacked_values):
        log().debug("[%d]: unpacked values: %s", idx, unpacked)
    log().debug("treedef: %s", treedef)
    log().debug("------------------ ")
    return unpacked_values, treedef
def pack_from_irvalue(
ir_values: List["ir.Value"],
indices: Dict[int, Tuple[int, int]],
class_types: List[Any],
pytree_def: PyTreeDef,
mixed_values: List[Any],
full_write_args_count: int,
) -> List[Any]:
"""
Packs MLIR values into a list of mixed values.
"""
log().debug("===--- Values Pack (%d)", len(ir_values))
for idx, packed in enumerate(ir_values):
log().debug("[%d]: will-packed: %s", idx, ir_values)
for idx, unpacked in indices.items():
log().debug("[%d]: indices: %s", idx, unpacked)
for idx, c in enumerate(class_types):
log().debug("[%d]: obj-types: %s", idx, type(c))
mixed_values = [None] * len(indices)
for idx, (start, length) in sorted(indices.items()):
chunk = ir_values[start : start + length]
obj = class_types[idx]
if is_frozen_dataclass(obj):
mixed_values[idx] = obj
elif not isinstance(obj, type) and hasattr(obj, "__new_from_mlir_values__"):
mixed_values[idx] = obj.__new_from_mlir_values__(chunk)
elif isinstance(chunk, list) and chunk[0] is None:
mixed_values[idx] = class_types[idx]
else:
if len(chunk) == 1:
try:
mixed_values[idx] = t.as_numeric(chunk[0])
except ValueError:
# Suppress the conversion error and try new_from_mlir_values below
pass
if mixed_values[idx] is None:
mixed_values[idx] = new_from_mlir_values(obj, chunk)
log().debug("------------------ ")
for idx, packed in enumerate(mixed_values):
log().debug("[%d]: packed: %s", idx, packed)
log().debug("------------------ ")
return mixed_values
def unpack_to_irvalue(
mixed_values: List[Any], body_name: str
) -> Tuple[List[ir.Value], List[Any], Dict[int, Tuple[int, int]], List[Any]]:
"""
Unpacks mixed values into ir.Value values.
"""
unpacked_values = []
ir_values = []
indices = {}
class_types = []
current_offset = 0
log().debug("===--- Values UNPack (%d)", len(mixed_values))
for idx, packed in enumerate(mixed_values):
log().debug("[%d]: will-unpacked: [type:%s] %s", idx, type(packed), packed)
for idx, item in enumerate(mixed_values):
class_types.append(item)
try:
if is_frozen_dataclass(item):
extracted_vals = [None]
else:
extracted_vals = extract_mlir_values(item)
# it's consexpr (python value), so we create mlir value for it
if extracted_vals == []:
if item is None:
extracted_vals = [None]
else:
dyn_expr = t.as_numeric(item)
extracted_vals = extract_mlir_values(dyn_expr)
ir_values.extend(extracted_vals)
else:
ir_values.extend(extracted_vals)
unpacked_values.extend(extracted_vals)
length = len(extracted_vals)
indices[idx] = (current_offset, length)
current_offset += length
except Exception as e:
raise DSLRuntimeError(
f"The '{body_name}' statement encountered a user-defined Python object, which cannot be automatically converted into an dynamic expression (aka MLIR value).",
context={
item: (
f"All expressions within '{body_name}' must be dynamic expressions, "
"mixing Python objects and dynamic expressions (aka MLIR values) is not supported. "
"The DSL failed to convert the Python object into MLIR values."
)
},
suggestion=(
f"Please ensure '{item}' implements the '{DynamicExpression.__name__}', "
f"so it can be treated as a valid dynamic expression or mark '{body_name}' as a constant expression if conditions are Python objects."
),
) from e
log().debug("------------------ ")
for idx, unpacked in enumerate(unpacked_values):
log().debug("[%d]: unpacked values: %s", idx, unpacked)
for idx, unpacked in enumerate(ir_values):
log().debug("[%d]: unpacked ir_values: %s", idx, unpacked)
for idx, unpacked in indices.items():
log().debug("[%d]: indices: %s", idx, unpacked)
for idx, unpacked in enumerate(class_types):
log().debug("[%d]: initial-class-types: %s", idx, unpacked)
for idx, value in enumerate(ir_values):
log().debug("[%d]: will-packed: %s", idx, value)
log().debug("treedef: %s", pytree_def)
log().debug("------------------ ")
return ir_values, unpacked_values, indices, class_types
unflattened = tree_unflatten(pytree_def, ir_values)
return insert_read_only_frozen_dataclass(
unflattened, mixed_values, full_write_args_count
)
def to_index(value):
@ -1015,8 +1173,8 @@ def any_(iterable):
def select_(cond, if_value, else_value):
def _as_scalar(value):
if const_expr(isinstance(value, list)):
if const_expr(len(value) == 1):
if isinstance(value, list):
if len(value) == 1:
return value[0]
else:
raise DSLRuntimeError(
@ -1024,16 +1182,16 @@ def select_(cond, if_value, else_value):
)
return value
if const_expr(not is_dynamic_expression(cond)):
if not is_dynamic_expression(cond):
raise DSLRuntimeError("Conditional expression must be dynamic")
# Extract MLIR values
cond = extract_mlir_values(cond)
if const_expr(is_dynamic_expression(if_value)):
if is_dynamic_expression(if_value):
if_value = extract_mlir_values(if_value)
else:
if_value = const(if_value)
if const_expr(is_dynamic_expression(else_value)):
if is_dynamic_expression(else_value):
else_value = extract_mlir_values(else_value)
else:
else_value = const(else_value)
@ -1089,7 +1247,7 @@ def for_generate(
iter_args: Optional[Sequence[ir.Value]] = None,
*,
unroll: LoopUnroll = None,
pipelining=None,
prefetch_stages=None,
loc=None,
ip=None,
):
@ -1127,8 +1285,8 @@ def for_generate(
if unroll is not None:
for_op.attributes["loop_annotation"] = unroll
if pipelining is not None:
for_op.attributes["cutlass.pipelining"] = _createI32Attr(pipelining)
if prefetch_stages is not None:
for_op.attributes["cutlass.pipelining"] = _createI32Attr(prefetch_stages)
iv = for_op.induction_variable
new_results = new_from_mlir_values(iter_args, for_op.results)
@ -1155,11 +1313,11 @@ def not_(lhs: Union[ir.Value, bool], *, loc=None, ip=None):
"""
res = None
# Handle Python bool first to prevent infinite recursion
if const_expr(type(lhs) == bool):
if type(lhs) == bool:
res = lhs ^ True
elif const_expr(hasattr(lhs, "__dsl_not__")):
elif hasattr(lhs, "__dsl_not__"):
res = lhs.__dsl_not__(loc=loc, ip=ip)
elif const_expr(is_dynamic_expression(lhs)):
elif is_dynamic_expression(lhs):
# If lhs is MLIR value, compute not using xor
res = arith.XOrIOp(lhs, const(1, lhs.type)).result
else:
@ -1338,29 +1496,59 @@ def equal(lhs, rhs):
return lhs == rhs
def in_(lhs, rhs, op):
def not_equal(lhs, rhs):
    """
    Element-aware `!=` that handles both Python values and dynamic expressions.

    Pure Python operands use native `!=`; sequences are compared element-wise
    (length mismatch is immediately unequal); otherwise the comparison is
    delegated to the operands' `__ne__` or derived by negating `equal`.
    """
    lhs_is_dynamic = is_dynamic_expression(lhs)
    rhs_is_dynamic = is_dynamic_expression(rhs)
    if not lhs_is_dynamic and not rhs_is_dynamic:
        # Plain Python values: let the interpreter compare them directly.
        return lhs != rhs
    if isinstance(lhs, Sequence) and isinstance(rhs, Sequence):
        # Sequences of different length can never be equal — short-circuit.
        if len(lhs) != len(rhs):
            return True
        # Unequal iff any element pair is unequal.
        return any_(not_equal(l, r) for l, r in zip(lhs, rhs))
    # NOTE(review): every object defines `__ne__` in Python 3, so this branch
    # is effectively always taken and the fallbacks below look unreachable.
    if hasattr(lhs, "__ne__"):
        return lhs != rhs
    if hasattr(rhs, "__ne__"):
        return rhs != lhs
    return not_(equal(lhs, rhs))
def in_(lhs, rhs):
if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs):
return lhs in rhs
if not isinstance(rhs, Sequence):
raise DSLRuntimeError(
f"'{op}' not supported between instances of {type(lhs)} and {type(rhs)}"
f"'in' not supported between instances of {type(lhs)} and {type(rhs)}"
)
return any_(equal(lhs, r) for r in rhs)
def _lt_gt(lhs, rhs, op):
def native_lt_gt(lhs, rhs, op):
if op == "<":
return lhs < rhs
elif op == ">":
return lhs > rhs
else:
raise DSLRuntimeError(f"Unsupported comparison operator: {op}")
def _lte_gte(lhs, rhs, op):
def native_lte_gte(lhs, rhs, op):
match op:
case "<":
return lhs < rhs
case "<=":
if hasattr(lhs, "__le__"):
return lhs <= rhs
else:
return not_(lhs > rhs)
case ">":
return lhs > rhs
case ">=":
if hasattr(lhs, "__ge__"):
return lhs >= rhs
else:
return not_(lhs < rhs)
case _:
raise DSLRuntimeError(f"Unsupported comparison operator: {op}")
if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs):
return native_lt_gt(lhs, rhs, op)
return native_lte_gte(lhs, rhs, op)
# Both sequence, comparisons other than == and != do not allow mixing different types of sequences
if (
@ -1375,7 +1563,7 @@ def _lt_gt(lhs, rhs, op):
is_equal = equal(l, r)
mask.append(not_(or_(is_equal, unequal_found)))
unequal_found = not_(is_equal)
comp_results.append(_lt_gt(l, r, op))
comp_results.append(_lte_gte(l, r, op))
result = any_(and_(r, m) for r, m in zip(comp_results, mask))
@ -1383,62 +1571,126 @@ def _lt_gt(lhs, rhs, op):
# Ref https://docs.python.org/3/tutorial/datastructures.html#comparing-sequences-and-other-types
# If one sequence is an initial sub-sequence of the other, the shorter sequence is the smaller (lesser) one
has_valid_mask = any_(mask)
if op == "<":
length_result = len(lhs) < len(rhs)
elif op == ">":
length_result = len(lhs) > len(rhs)
match op:
case "<":
length_result = len(lhs) < len(rhs)
case ">":
length_result = len(lhs) > len(rhs)
case "<=":
length_result = len(lhs) <= len(rhs)
case ">=":
length_result = len(lhs) >= len(rhs)
if type(has_valid_mask) == bool:
return result if has_valid_mask else length_result
else:
return select_(has_valid_mask, result, length_result)
else:
return result
if op in {"<=", ">="}:
# If no unequal, return True
return select_(unequal_found, result, True)
else:
return result
else:
return native_lt_gt(lhs, rhs, op)
return native_lte_gte(lhs, rhs, op)
def greater_than(lhs, rhs):
return _lt_gt(lhs, rhs, ">")
return _lte_gte(lhs, rhs, ">")
def greater_equal(lhs, rhs):
    """Return lhs >= rhs via the shared relational helper `_lte_gte`."""
    return _lte_gte(lhs, rhs, ">=")
def less_than(lhs, rhs):
return _lt_gt(lhs, rhs, "<")
return _lte_gte(lhs, rhs, "<")
def less_equal(lhs, rhs):
    """Return lhs <= rhs via the shared relational helper `_lte_gte`."""
    return _lte_gte(lhs, rhs, "<=")
def _compare_dispatch(lhs, rhs, op):
"""
Dispatches the comparison operation between lhs and rhs based on the given operator.
:param lhs: The left-hand side operand for the comparison.
:param rhs: The right-hand side operand for the comparison.
:param op: The comparison operator as a string. Supported operators are:
- "is", "is not": Python identity comparisons.
- "in", "not in": Membership tests.
- "==", "!=": Equality and inequality.
- "<", ">", "<=", ">=": Relational comparisons.
:return: The result of the comparison, which may be a boolean or a DSL-specific type.
:raises DSLRuntimeError: If the operator is not supported.
"""
match op:
# 'is' and 'is not' are pure python operators
case "is":
return lhs is rhs
case "is not":
return lhs is not rhs
case "in":
return in_(lhs, rhs)
case "not in":
return not_(in_(lhs, rhs))
case "==":
return equal(lhs, rhs)
case "!=":
return not_equal(lhs, rhs)
case "<":
return less_than(lhs, rhs)
case ">":
return greater_than(lhs, rhs)
case ">=":
return greater_equal(lhs, rhs)
case "<=":
return less_equal(lhs, rhs)
case _:
raise DSLRuntimeError(f"Unsupported comparison operator: {op}")
def _compare_executor(left, comparators, ops):
result = left
# Fast path for single comparison
if len(comparators) == 1:
return _compare_dispatch(left, comparators[0], ops[0])
# Chain comparison, dispatch in a loop
result = True
current = left
for comparator, op in zip(comparators, ops):
# 'is' and 'is not' are pure python operators
if op == "is":
result = result is comparator
elif op == "is not":
result = result is not comparator
elif op in ["in", "not in"]:
result = in_(left, comparator, op)
elif op in ["==", "!="]:
result = equal(left, comparator)
elif op in ["<", ">="]:
result = less_than(left, comparator)
elif op in [">", "<="]:
result = greater_than(left, comparator)
else:
raise DSLRuntimeError(f"Unsupported comparison operator: {op}")
# Invert the result for NotIn, NotEq, GtE, LtE
if op in ["not in", "!=", ">=", "<="]:
result = not_(result)
cmp_result = _compare_dispatch(current, comparator, op)
result = and_(result, cmp_result)
current = comparator
return result
def _builtin_redirector(fcn):
if fcn == builtins.max:
return max
elif fcn == builtins.min:
return min
elif fcn == builtins.any:
return any_
elif fcn == builtins.all:
return all_
else:
raise DSLRuntimeError(f"Unsupported built-in function: {fcn}")
# =============================================================================
# Set the AST decorator
# =============================================================================
# Set the DSL specific functions
executor.set_functions(
is_dynamic_expression,
_loop_execute_range_dynamic,
_if_execute_dynamic,
_while_execute_dynamic,
_compare_executor,
any_,
all_,
is_dynamic_expression=is_dynamic_expression,
loop_execute_range_dynamic=_loop_execute_range_dynamic,
if_dynamic=_if_execute_dynamic,
while_dynamic=_while_execute_dynamic,
compare_executor=_compare_executor,
any_executor=any_,
all_executor=all_,
builtin_redirector=_builtin_redirector,
)

View File

@ -14,13 +14,22 @@ from types import NoneType
from cutlass._mlir import ir
from cutlass._mlir.dialects import scf, arith
from cutlass._mlir.extras import types as T
from collections.abc import Sequence
from ..base_dsl.dsl import extract_mlir_values, new_from_mlir_values
from ..base_dsl.dsl import is_dynamic_expression
from ..base_dsl.ast_helpers import *
from ..base_dsl.utils.logger import log
from ..base_dsl import typing as t
from ..base_dsl.typing import Int32, Float32, Boolean, Numeric, get_mlir_types
from ..base_dsl.typing import (
Int32,
Float32,
Boolean,
Numeric,
get_mlir_types,
as_numeric,
)
from . import cutlass as cutlass_dsl
from .tree_utils import PyTreeDef, check_tree_equal
# =============================================================================
# AST Helpers
@ -57,14 +66,6 @@ class ScfGenerator:
def __init__(self):
pass
@staticmethod
def fill_none(ir_values, unpacked_values):
i = 0
for idx, item in enumerate(unpacked_values):
if item is not None:
unpacked_values[idx] = ir_values[i]
i += 1
@staticmethod
def _normalize_region_result_to_list(region_result: Any) -> List[Any]:
"""
@ -82,34 +83,109 @@ class ScfGenerator:
return region_result_list
@staticmethod
def check_region_result(region_values, ir_values):
for i, (expected_value, actual_value) in enumerate(
zip(ir_values, region_values)
def _check_region_result(original_value, region_value, arg_name, op_type_name):
"""
Validate that a region result maintains the same type as the original value.
This method checks for type consistency between the original value passed to a dynamic
SCF operation (like for, if, while) and the value returned from the operation's region.
Args:
original_value: The value before entering the SCF operation region
region_value: The value returned from the SCF operation region
arg_name: Name of the argument being checked (for error reporting)
op_type_name: Type of SCF operation (e.g., 'for', 'if', 'while') for error reporting
Raises:
DSLRuntimeError: If the region value has a different type than the original value.
The error includes suggestions for using compile-time control flow instead.
Note:
This method performs relaxed type checking that allows inheritance relationships.
For example, a child class can be returned where a parent class was expected.
However, fundamental type changes (like None to non-None, different sequence types,
or different numeric types) are not allowed in dynamic SCF operations.
"""
def get_type_name(value):
if isinstance(value, NoneType):
return "None"
elif isinstance(value, Sequence):
return f"{type(value).__name__}<{len(value)}>"
else:
return type(value).__name__
# Check for type mismatches
type_mismatch = False
old_type_name = None
new_type_name = None
# Handle None type changes
if isinstance(original_value, NoneType) != isinstance(region_value, NoneType):
type_mismatch = True
old_type_name = get_type_name(original_value)
new_type_name = get_type_name(region_value)
# Handle sequence type/length changes
elif isinstance(original_value, Sequence) and isinstance(
region_value, Sequence
):
expected_value_type = get_mlir_types(expected_value)
actual_value_type = get_mlir_types(actual_value)
if expected_value_type != actual_value_type:
return False, i, expected_value_type, actual_value_type
return True, -1, None, None
if type(original_value) != type(region_value) or len(original_value) != len(
region_value
):
type_mismatch = True
old_type_name = get_type_name(original_value)
new_type_name = get_type_name(region_value)
# Handle numeric type changes
elif isinstance(
original_value, (Numeric, ArithValue, ir.Value, int, float, bool)
) or isinstance(
region_value, (Numeric, ArithValue, ir.Value, int, float, bool)
):
try:
original_numeric = as_numeric(original_value)
region_numeric = as_numeric(region_value)
if original_numeric.dtype != region_numeric.dtype:
type_mismatch = True
old_type_name = original_numeric.dtype.__name__
new_type_name = region_numeric.dtype.__name__
except Exception:
pass
# Handle general type changes (relaxed for inheritance)
elif type(original_value) != type(region_value):
old_type = type(original_value)
new_type = type(region_value)
if not (issubclass(old_type, new_type) or issubclass(new_type, old_type)):
type_mismatch = True
old_type_name = old_type.__name__
new_type_name = new_type.__name__
if type_mismatch:
raise DSLRuntimeError(
f"`{arg_name}` is {old_type_name} prior to this `{op_type_name}`, "
f"and update to {new_type_name} inside of this `{op_type_name}` is not supported.",
suggestion=(
f"Please avoid changing type inside a dynamic `{op_type_name}`, "
f"or change to compile-time control flow by marking this `{op_type_name}` with "
f"`{'range_constexpr' if op_type_name == 'for' else 'const_expr'}`."
),
)
def scf_execute_dynamic(
self,
op_type_name: str,
used_args: List[Any],
mix_iter_args: List[Any],
full_write_args_count: int,
mix_iter_arg_names: List[str],
create_op_func: Callable[
[List[ir.Value], Dict[int, Tuple[int, int]], List[Any]], ir.Operation
],
create_op_func: Callable[[List[ir.Value]], ir.Operation],
region_builders: List[
Callable[
[
"ir.Operation",
List["ir.Value"], # block_args
List[Any], # used_args
List["ir.Value"], # dyn_yield_ops
Dict[int, Tuple[int, int]],
PyTreeDef,
List[Any],
int,
],
Any,
]
@ -119,11 +195,11 @@ class ScfGenerator:
block_term_op_builder: Dict[Callable, Callable] = {},
) -> Any:
# 1) Unpack
ir_values, dyn_unpacked_values, dyn_indices, dyn_class_types = (
cutlass_dsl.unpack_to_irvalue(mix_iter_args, op_type_name)
ir_values, pytree_def = cutlass_dsl.unpack_to_irvalue(
mix_iter_args, op_type_name, full_write_args_count
)
# 2) Create the SCF op
op = create_op_func(ir_values, dyn_indices, dyn_class_types)
op = create_op_func(ir_values)
log().debug("Generated scf.%s \n[%s]", op_type_name, op)
# 3) Build the regions
@ -135,76 +211,61 @@ class ScfGenerator:
region_result = builder(
op,
block_args,
used_args,
dyn_unpacked_values,
dyn_indices,
dyn_class_types,
ir_values,
pytree_def,
mix_iter_args,
full_write_args_count,
)
# Use custom terminator if provided for this builder, otherwise use default YieldOp
if builder in block_term_op_builder:
# Use the provided terminator generator
block_term_op_builder[builder](region_result)
block_term_op_builder[builder](region_result, full_write_args_count)
else:
# For standard yield op, check result
for arg, result, name in zip(
mix_iter_args,
(
region_result
if isinstance(region_result, list)
else [region_result]
),
mix_iter_arg_names,
):
if isinstance(arg, NoneType) and not isinstance(
result, NoneType
):
raise DSLRuntimeError(
(
f"`{name}` is None prior to this `{op_type_name}`, "
f"and update to non-None value inside of this `{op_type_name}` is not supported."
),
suggestion=(
f"Please make sure `{name}` is not None prior to this `{op_type_name}`, "
f"or mark this `{op_type_name}` with "
f"`{'range' if op_type_name == 'for' else 'const_expr'}`."
),
)
# Normalize region_result
region_result_list = ScfGenerator._normalize_region_result_to_list(
region_result
)
# Default behavior - generate YieldOp
region_values, unpacked_values, _, _ = (
cutlass_dsl.unpack_to_irvalue(region_result_list, op_type_name)
)
is_match, mismatch_idx, expected_type, actual_type = (
ScfGenerator.check_region_result(region_values, ir_values)
)
if not is_match:
# From unpacked index, we need to find the original index
original_idx = -1
for unpacked_idx, (original_idx, length) in dyn_indices.items():
if (
mismatch_idx >= original_idx
and mismatch_idx < original_idx + length
):
original_idx = unpacked_idx
break
raise DSLRuntimeError(
f"`{op_type_name}` expects {expected_type} type for varible `{mix_iter_arg_names[original_idx]}`, but got {actual_type}.",
suggestion=f"Please make sure `{mix_iter_arg_names[original_idx]}` type is not changed inside of `{op_type_name}`.",
# For standard yield op, check result
for arg, result, name in zip(
mix_iter_args,
region_result_list,
mix_iter_arg_names,
):
ScfGenerator._check_region_result(
arg, result, name, op_type_name
)
# Default behavior - generate YieldOp
region_values, yield_pytree_def = cutlass_dsl.unpack_to_irvalue(
region_result_list, op_type_name, full_write_args_count
)
mismatch = check_tree_equal(pytree_def, yield_pytree_def)
if mismatch != -1:
# Get arg name
filterd_arg_names = (
cutlass_dsl.filter_readonly_frozen_dataclass_names(
mix_iter_args, mix_iter_arg_names, full_write_args_count
)
)
raise DSLRuntimeError(
f"`{filterd_arg_names[mismatch]}` is structured different after this `{op_type_name}`.",
suggestion=(
f"Please avoid changing type structure inside a dynamic `{op_type_name}`, "
f"or change to compile-time control flow by marking this `{op_type_name}` with "
f"`{'range_constexpr' if op_type_name == 'for' else 'const_expr'}`."
),
)
scf.YieldOp(region_values)
log().debug("Completed scf.%s \n[%s]", op_type_name, op)
ScfGenerator.fill_none(op.results, unpacked_values)
# 4) Pack final results
final_results = cutlass_dsl.pack_from_irvalue(
unpacked_values, dyn_indices, dyn_class_types
op.results, pytree_def, mix_iter_args, full_write_args_count
)
# 5) Return in a nice pattern
@ -215,28 +276,32 @@ class ScfGenerator:
return final_results
def _attr_const_check(attr, expected_type, attr_name):
    """Validate that a loop attribute is a compile-time Python value of exactly `expected_type`.

    Raises:
        DSLRuntimeError: If `attr` is a dynamic (MLIR) expression, or its type is
            not exactly `expected_type`.
    """
    # Use strict type equality to prevent `bool` being accepted where `int` is required.
    if is_dynamic_expression(attr) or type(attr) is not expected_type:
        raise DSLRuntimeError(
            f"loop attribute `{attr_name}` must be a Python value of type `{expected_type.__name__}`, got `{type(attr).__name__}`."
        )
def _loop_execute_range_dynamic(
func: Callable,
start: Any,
stop: Any,
step: Any,
used_args: List[Any] = [],
mix_iter_args: List[Any] = [],
full_write_args_count: int = 0,
mix_iter_arg_names: List[str] = [],
unroll: int = -1,
unroll_full: bool = False,
pipelining: int = None,
prefetch_stages: int = None,
):
"""
Example: build an scf.for with optional unroll, using our universal helper.
"""
scf_gen = ScfGenerator()
def create_for_op(
dyn_yield_ops: List[ir.Value],
dyn_indices: Dict[int, Tuple[int, int]],
dyn_class_types: List[Any],
):
def create_for_op(dyn_yield_ops: List[ir.Value]):
for d in dyn_yield_ops:
if not isinstance(d, ir.Value):
raise DSLRuntimeError(
@ -254,6 +319,10 @@ def _loop_execute_range_dynamic(
stop_ = stop_.ir_value()
step_ = step_.ir_value()
# Attributes must be pure Python value, add a check
_attr_const_check(unroll, int, "unroll")
_attr_const_check(unroll_full, bool, "unroll_full")
# Possibly attach unroll attributes
unroll_attr = None
if unroll_full:
@ -262,17 +331,18 @@ def _loop_execute_range_dynamic(
unroll_attr = LoopUnroll(count=unroll)
log().debug("Unroll attribute: %s", unroll_attr)
pipelining_attr = None
if pipelining is not None:
if pipelining >= 0:
pipelining_attr = ir.IntegerAttr.get(
ir.IntegerType.get_signless(32), pipelining
prefetch_stages_attr = None
if prefetch_stages is not None:
_attr_const_check(prefetch_stages, int, "prefetch_stages")
if prefetch_stages >= 0:
prefetch_stages_attr = ir.IntegerAttr.get(
ir.IntegerType.get_signless(32), prefetch_stages
)
else:
raise DSLRuntimeError(
f"Pipelining must be non-negative, got {pipelining}"
f"loop attribute `prefetch_stages` must be non-negative, got `{prefetch_stages}`."
)
log().debug("Pipelining attribute: %s", pipelining_attr)
log().debug("prefetch_stages attribute: %s", prefetch_stages_attr)
log().debug(
"Creating scf.ForOp \n\t\tstart=%s: type : %s\n\t\tstop=%s: type : %s\n\t\tstep=%s: type : %s",
@ -303,47 +373,48 @@ def _loop_execute_range_dynamic(
if unroll_attr is not None:
for_op.attributes["loop_annotation"] = unroll_attr
if pipelining_attr is not None:
for_op.attributes["cutlass.pipelining"] = pipelining_attr
if prefetch_stages_attr is not None:
for_op.attributes["cutlass.pipelining"] = prefetch_stages_attr
return for_op
def for_body_builder(
op, block_args, used_args, dyn_yield_ops, dyn_indices, dyn_class_types
op,
block_args,
_,
pytree_def,
mix_iter_args,
full_write_args_count,
):
# Insert induction variable at the beginning
dyn_yield_ops.insert(0, block_args[0])
ScfGenerator.fill_none(block_args, dyn_yield_ops)
block_args = dyn_yield_ops
# scf.ForOp block_args are typically [induction_var, iter_args...]
# But MLIR also gives you op.induction_variable
iv = t.as_numeric(op.induction_variable)
log().debug(
"For body builder: %s block_args: %s used_args: %s",
"For body builder: %s block_args: %s full_write_args_count: %s",
iv,
block_args,
used_args,
full_write_args_count,
)
if len(block_args) <= 1:
# block_args[1:] are iteration variables
func_args = []
func_args.extend(
cutlass_dsl.pack_from_irvalue(
block_args[1:], pytree_def, mix_iter_args, full_write_args_count
)
)
if not func_args:
# No iteration arguments, or only the induction var
func(iv, *used_args)
func(iv)
return [] # yield nothing
else:
# block_args[1:] are iteration variables
func_args = [*used_args]
func_args.extend(
cutlass_dsl.pack_from_irvalue(
block_args[1:], dyn_indices, dyn_class_types
)
)
updated_func_args = func(iv, *func_args)
return updated_func_args
# Now call the universal SCF executor with a single region builder
return scf_gen.scf_execute_dynamic(
op_type_name="for",
used_args=used_args,
mix_iter_args=mix_iter_args,
full_write_args_count=full_write_args_count,
mix_iter_arg_names=mix_iter_arg_names,
create_op_func=create_for_op,
region_builders=[for_body_builder],
@ -354,8 +425,8 @@ def _if_execute_dynamic(
pred: "ir.Value",
then_block: Callable,
else_block: Callable = None,
used_args: List[Any] = [],
mix_yield_args: List[Any] = [],
full_write_args_count: int = 0,
mix_yield_arg_names: List[str] = [],
if_constexpr=None, # ignoring for brevity
):
@ -364,11 +435,7 @@ def _if_execute_dynamic(
"""
scf_gen = ScfGenerator()
def create_if_op(
dyn_yield_ops: List[ir.Value],
dyn_indices: Dict[int, Tuple[int, int]],
dyn_class_types: List[Any],
):
def create_if_op(dyn_yield_ops: List[ir.Value]):
# Assume final result types match the dynamic yields
result_types = [arg.type for arg in dyn_yield_ops]
@ -387,11 +454,18 @@ def _if_execute_dynamic(
return if_op
def then_builder(
if_op, block_args, used_args, dyn_yield_ops, dyn_indices, dyn_class_types
if_op,
_,
dyn_yield_ops,
pytree_def,
mix_iter_args,
full_write_args_count,
):
flat_args = [*used_args]
flat_args = []
flat_args.extend(
cutlass_dsl.pack_from_irvalue(dyn_yield_ops, dyn_indices, dyn_class_types)
cutlass_dsl.pack_from_irvalue(
dyn_yield_ops, pytree_def, mix_iter_args, full_write_args_count
)
)
return then_block(*flat_args)
@ -400,12 +474,17 @@ def _if_execute_dynamic(
if else_block is not None:
def else_builder(
if_op, block_args, used_args, dyn_yield_ops, dyn_indices, dyn_class_types
if_op,
_,
dyn_yield_ops,
pytree_def,
mix_iter_args,
full_write_args_count,
):
flat_args = [*used_args]
flat_args = []
flat_args.extend(
cutlass_dsl.pack_from_irvalue(
dyn_yield_ops, dyn_indices, dyn_class_types
dyn_yield_ops, pytree_def, mix_iter_args, full_write_args_count
)
)
return else_block(*flat_args)
@ -414,8 +493,8 @@ def _if_execute_dynamic(
return scf_gen.scf_execute_dynamic(
op_type_name="if",
used_args=used_args,
mix_iter_args=mix_yield_args,
full_write_args_count=full_write_args_count,
mix_iter_arg_names=mix_yield_arg_names,
create_op_func=create_if_op,
region_builders=region_builders,
@ -425,9 +504,9 @@ def _if_execute_dynamic(
def _while_execute_dynamic(
while_before_block: Callable,
while_after_block: Callable = None,
used_args=[],
yield_args=[],
yield_arg_names=[],
write_args=[],
full_write_args_count=0,
write_args_names=[],
):
"""
Create and return an SCF WhileOp for dynamic loops.
@ -436,8 +515,7 @@ def _while_execute_dynamic(
Args:
while_before_block: Function that returns (condition, updated_values)
while_after_block: Function that returns updated values
used_args: Additional arguments used in the loop body
yield_args: Values that are updated in the loop
write_args: Values that are updated in the loop
See create_while_function in ast_preprocessor.py for details on the input structure.
"""
@ -445,11 +523,7 @@ def _while_execute_dynamic(
while_op_type_name = "while"
scf_gen = ScfGenerator()
def create_while_op(
dyn_yield_ops: List[ir.Value],
dyn_indices: Dict[int, Tuple[int, int]],
dyn_class_types: List[Any],
):
def create_while_op(dyn_yield_ops: List[ir.Value]):
# Create the while operation with the types from yield_args
result_types = [arg.type for arg in dyn_yield_ops]
try:
@ -468,14 +542,19 @@ def _while_execute_dynamic(
) from e
def before_block_builder(
op, block_args, used_args, dyn_yield_ops, dyn_indices, dyn_class_types
op,
block_args,
_,
pytree_def,
mix_iter_args,
full_write_args_count,
):
# Build the before (condition) block
ScfGenerator.fill_none(block_args, dyn_yield_ops)
block_args = dyn_yield_ops
flat_args = [*used_args]
flat_args = []
flat_args.extend(
cutlass_dsl.pack_from_irvalue(block_args, dyn_indices, dyn_class_types)
cutlass_dsl.pack_from_irvalue(
block_args, pytree_def, mix_iter_args, full_write_args_count
)
)
log().debug("before block args: %s", flat_args)
@ -493,18 +572,15 @@ def _while_execute_dynamic(
return cond, before_results
def before_block_terminator(cond_and_results):
def before_block_terminator(cond_and_results, full_write_args_count):
# Generate a condition op instead of yield op
cond = cond_and_results[0]
before_result_list = ScfGenerator._normalize_region_result_to_list(
cond_and_results[1]
)
ir_cond_list, _, _, _ = cutlass_dsl.unpack_to_irvalue(
[cond], while_op_type_name
)
ir_cond = ir_cond_list[0]
ir_results_list, _, _, _ = cutlass_dsl.unpack_to_irvalue(
before_result_list, while_op_type_name
ir_cond = as_numeric(cond).ir_value()
ir_results_list, pytree_def = cutlass_dsl.unpack_to_irvalue(
before_result_list, while_op_type_name, full_write_args_count
)
log().debug(
"creating scf.ConditionOp with [%s], [%s]",
@ -514,14 +590,19 @@ def _while_execute_dynamic(
scf.ConditionOp(ir_cond, ir_results_list)
def after_block_builder(
op, block_args, used_args, dyn_yield_ops, dyn_indices, dyn_class_types
op,
block_args,
_,
pytree_def,
mix_iter_args,
full_write_args_count,
):
# Build the after (body) block
ScfGenerator.fill_none(block_args, dyn_yield_ops)
block_args = dyn_yield_ops
flat_args = [*used_args]
flat_args = []
flat_args.extend(
cutlass_dsl.pack_from_irvalue(block_args, dyn_indices, dyn_class_types)
cutlass_dsl.pack_from_irvalue(
block_args, pytree_def, mix_iter_args, full_write_args_count
)
)
log().debug("after block args: %s", flat_args)
@ -541,9 +622,9 @@ def _while_execute_dynamic(
# Call the universal SCF executor with two region builders
return scf_gen.scf_execute_dynamic(
op_type_name=while_op_type_name,
used_args=used_args,
mix_iter_args=yield_args,
mix_iter_arg_names=yield_arg_names,
mix_iter_args=write_args,
full_write_args_count=full_write_args_count,
mix_iter_arg_names=write_args_names,
create_op_func=create_while_op,
region_builders=[before_block_builder, after_block_builder],
block_term_op_builder={

View File

@ -0,0 +1,763 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from typing import Callable, Any, Iterable, Iterator, NamedTuple, Union, get_origin
import dataclasses
import itertools as it
from types import SimpleNamespace
from ..base_dsl.typing import as_numeric, Numeric, Constexpr
from ..base_dsl._mlir_helpers.arith import ArithValue
from ..base_dsl.common import DSLBaseError
from .._mlir import ir
# =============================================================================
# Tree Utils
# =============================================================================
class DSLTreeFlattenError(DSLBaseError):
    """Exception raised when tree flattening fails due to unsupported types."""

    def __init__(self, msg: str, type_str: str):
        super().__init__(msg)
        # Fully qualified name of the offending type, kept for error reporting.
        self.type_str = type_str
def unzip2(pairs: Iterable[tuple[Any, Any]]) -> tuple[list[Any], list[Any]]:
    """Split an iterable of pairs into two parallel lists (inverse of zip for pairs)."""
    firsts: list[Any] = []
    seconds: list[Any] = []
    for first, second in pairs:
        firsts.append(first)
        seconds.append(second)
    return firsts, seconds
def get_fully_qualified_class_name(x: Any) -> str:
    """
    Return the fully qualified class name of *x* as 'module.qualname'.

    Example:
        >>> get_fully_qualified_class_name([1, 2, 3])
        'builtins.list'
    """
    cls = x.__class__
    return f"{cls.__module__}.{cls.__qualname__}"
def is_frozen_dataclass(obj_or_cls: Any) -> bool:
    """
    Return True if *obj_or_cls* (an instance or a class) is a dataclass
    declared with ``frozen=True``.

    Example:
        >>> from dataclasses import dataclass
        >>> @dataclass(frozen=True)
        ... class Point:
        ...     x: int
        ...     y: int
        >>> is_frozen_dataclass(Point)
        True
        >>> is_frozen_dataclass(Point(1, 2))
        True
    """
    cls = obj_or_cls if isinstance(obj_or_cls, type) else obj_or_cls.__class__
    if not dataclasses.is_dataclass(cls):
        return False
    params = getattr(cls, "__dataclass_params__", None)
    return params is not None and params.frozen
def is_dynamic_expression(x: Any) -> bool:
    """
    Return True if *x* implements the DynamicExpression protocol, i.e. it has
    both `__extract_mlir_values__` and `__new_from_mlir_values__` methods.
    """
    required = ("__extract_mlir_values__", "__new_from_mlir_values__")
    return all(hasattr(x, name) for name in required)
def is_constexpr_field(field: dataclasses.Field) -> bool:
    """
    Return True if the dataclass field is annotated as ``Constexpr``, either
    directly or as a parameterized generic whose origin is ``Constexpr``.
    """
    annotation = field.type
    return annotation is Constexpr or get_origin(annotation) is Constexpr
# =============================================================================
# PyTreeDef
# =============================================================================
class NodeType(NamedTuple):
    """
    Handler bundle for one registered pytree node type.

    Attributes:
        name: String representation of the node type.
        to_iterable: Function converting a node instance to
            ``(metadata, children)`` form for flattening.
        from_iterable: Function reconstructing a node instance from the
            ``(metadata, children)`` produced by ``to_iterable``.
    """

    name: str
    to_iterable: Callable
    from_iterable: Callable
class PyTreeDef(NamedTuple):
    """
    Represents the structure definition of a pytree interior node.

    Attributes:
        node_type: The registered NodeType handlers for this node.
        node_metadata: SimpleNamespace metadata captured by ``to_iterable``
            (e.g. field names, original object) needed for reconstruction.
        child_treedefs: Tuple of child tree definitions, one per child.
    """

    node_type: NodeType
    node_metadata: SimpleNamespace
    child_treedefs: tuple["PyTreeDef", ...]
@dataclasses.dataclass(frozen=True)
class Leaf:
    """
    Represents a leaf node in a pytree structure.

    Attributes:
        is_numeric: Whether this leaf contains a `Numeric` value
        is_none: Whether this leaf represents None
        node_metadata: SimpleNamespace metadata associated with this leaf
            (set for dynamic-expression leaves; holds the original object)
        ir_type_str: String representation of the IR type, used for
            structural equality checks between tree definitions
    """

    is_numeric: bool = False
    is_none: bool = False
    node_metadata: SimpleNamespace = None
    ir_type_str: str = None
# =============================================================================
# Default to_iterable and from_iterable
# =============================================================================
def extract_dataclass_members(x: Any) -> tuple[list[str], list[Any], list[str]]:
    """
    Extract dataclass field values, separating Constexpr-annotated fields.

    Args:
        x: A dataclass instance.

    Returns:
        tuple: ``(field_names, field_values, constexpr_field_names)`` where
        ``field_names`` lists the non-constexpr fields, ``field_values`` their
        current values, and ``constexpr_field_names`` the Constexpr-annotated
        field names (carried over unchanged on reconstruction).

    Raises:
        DSLTreeFlattenError: If the instance carries attributes that are not
            declared dataclass fields, or a Constexpr-annotated field holds a
            dynamic expression.
    """
    field_names = [f.name for f in dataclasses.fields(x)]
    # Reject instances that grew attributes outside the declared schema; those
    # would be silently dropped by a flatten/unflatten round trip.
    # getattr with a default also tolerates slots dataclasses (no __dict__).
    for k in getattr(x, "__dict__", {}):
        if k not in field_names:
            raise DSLTreeFlattenError(
                f"`{x}` has extra field `{k}`",
                type_str=get_fully_qualified_class_name(x),
            )
    if not field_names:
        # BUG FIX: previously returned a 2-tuple here, breaking callers that
        # unpack three values when the dataclass declares no fields.
        return [], [], []
    members = []
    constexpr_fields = []
    for field in dataclasses.fields(x):
        if is_constexpr_field(field):
            constexpr_fields.append(field.name)
            field_names.remove(field.name)
            v = getattr(x, field.name)
            # A Constexpr field must hold a compile-time value, never an MLIR value.
            if is_dynamic_expression(v):
                raise DSLTreeFlattenError(
                    f"`{x}` has dynamic expression field `{field.name}` with a Constexpr type annotation `{field.type}`",
                    type_str=get_fully_qualified_class_name(x),
                )
        else:
            members.append(getattr(x, field.name))
    return field_names, members, constexpr_fields
def default_dataclass_to_iterable(x: Any) -> tuple[SimpleNamespace, list[Any]]:
    """
    Convert a dataclass instance to ``(metadata, children)`` form for flattening.

    Constexpr-annotated fields are excluded from the children and recorded in
    the metadata so they can be carried over unchanged on reconstruction.

    Args:
        x: A dataclass instance.

    Returns:
        tuple: ``(metadata, members)`` where metadata records type info, field
        names, constexpr field names and the original object.
    """
    names, values, constexpr_names = extract_dataclass_members(x)
    meta = SimpleNamespace(
        type_str=get_fully_qualified_class_name(x),
        fields=names,
        constexpr_fields=constexpr_names,
        original_obj=x,
    )
    return meta, values
def set_dataclass_attributes(
    instance: Any,
    fields: list[str],
    values: Iterable[Any],
    constexpr_fields: list[str],
) -> Any:
    """
    Rebuild a dataclass instance with updated field values.

    Args:
        instance: The source dataclass instance.
        fields: Names of fields to update.
        values: New values, positionally matched with ``fields``.
        constexpr_fields: Names of Constexpr fields whose current values on
            ``instance`` are preserved verbatim.

    Returns:
        A new instance via ``dataclasses.replace`` (works for frozen and
        non-frozen dataclasses alike), or ``instance`` itself when there is
        nothing to update.
    """
    if not fields:
        return instance
    replacements = dict(zip(fields, values))
    for name in constexpr_fields:
        replacements[name] = getattr(instance, name)
    return dataclasses.replace(instance, **replacements)
def default_dataclass_from_iterable(
    metadata: SimpleNamespace, children: Iterable[Any]
) -> Any:
    """
    Reconstruct a dataclass instance from flattened children.

    ``metadata.original_obj`` supplies the constexpr field values; it is also
    updated in place to point at the freshly built instance so subsequent
    unflatten calls see the latest object.

    Args:
        metadata: Metadata produced by ``default_dataclass_to_iterable``.
        children: Iterable of updated field values.

    Returns:
        The reconstructed dataclass instance.
    """
    rebuilt = set_dataclass_attributes(
        metadata.original_obj,
        metadata.fields,
        children,
        metadata.constexpr_fields,
    )
    metadata.original_obj = rebuilt
    return rebuilt
def dynamic_expression_to_iterable(x: Any) -> tuple[SimpleNamespace, list[Any]]:
    """
    Convert a DynamicExpression object to ``(metadata, mlir_values)`` form.

    Uses the object's ``__extract_mlir_values__`` method; the metadata flags
    the node as a dynamic expression and keeps a reference to the original
    object for reconstruction.
    """
    metadata = SimpleNamespace(is_dynamic_expression=1, original_obj=x)
    return metadata, x.__extract_mlir_values__()
def dynamic_expression_from_iterable(
    metadata: SimpleNamespace, children: Iterable[Any]
) -> Any:
    """
    Reconstruct a DynamicExpression object from its extracted MLIR values.

    Delegates to ``__new_from_mlir_values__`` on the original object stored in
    the metadata.
    """
    values = list(children)
    return metadata.original_obj.__new_from_mlir_values__(values)
def default_dict_to_iterable(x: Any) -> tuple[SimpleNamespace, list[Any]]:
    """
    Convert a dict or SimpleNamespace to ``(metadata, values)`` form.

    The metadata records the key order so reconstruction can re-associate
    values with their keys.
    """
    mapping = x.__dict__ if isinstance(x, SimpleNamespace) else x
    keys = list(mapping.keys())
    vals = list(mapping.values())
    meta = SimpleNamespace(
        type_str=get_fully_qualified_class_name(x), original_obj=x, fields=keys
    )
    return meta, vals
def default_dict_from_iterable(
    metadata: SimpleNamespace, children: Iterable[Any]
) -> Any:
    """
    Reconstruct a dict or SimpleNamespace from flattened values.

    Mutates the original object in place (keys/attributes are reassigned in
    their recorded order) and returns it.
    """
    target = metadata.original_obj
    if isinstance(target, SimpleNamespace):
        for key, value in zip(metadata.fields, children):
            setattr(target, key, value)
    else:
        for key, value in zip(metadata.fields, children):
            target[key] = value
    return target
# =============================================================================
# Register pytree nodes
# =============================================================================
_node_types: dict[type, NodeType] = {}
def register_pytree_node(ty: type, to_iter: Callable, from_iter: Callable) -> NodeType:
    """
    Register pytree handlers for a type.

    Args:
        ty: The type to register.
        to_iter: Converts an instance of ``ty`` to ``(metadata, children)``.
        from_iter: Rebuilds an instance of ``ty`` from ``(metadata, children)``.

    Returns:
        NodeType: The handler bundle stored in the registry.
    """
    node = NodeType(str(ty), to_iter, from_iter)
    _node_types[ty] = node
    return node
def register_default_node_types() -> None:
    """Register pytree handlers for the built-in container types."""
    register_pytree_node(
        tuple,
        lambda t: (SimpleNamespace(length=len(t)), list(t)),
        lambda _, xs: tuple(xs),
    )
    register_pytree_node(
        list,
        lambda l: (SimpleNamespace(length=len(l)), list(l)),
        lambda _, xs: list(xs),
    )
    register_pytree_node(dict, default_dict_to_iterable, default_dict_from_iterable)
    register_pytree_node(
        SimpleNamespace, default_dict_to_iterable, default_dict_from_iterable
    )


# Populate the registry at import time so flatten/unflatten work out of the box.
register_default_node_types()
# =============================================================================
# tree_flatten and tree_unflatten
# =============================================================================
"""
Behavior of tree_flatten and tree_unflatten, for example:
```python
a = (1, 2, 3)
b = MyClass(a=1, b =[1,2,3])
```
yields the following tree:
```python
tree_a = PyTreeDef(type = 'tuple',
metadata = {length = 3},
children = [
Leaf(type = int),
Leaf(type = int),
Leaf(type = int),
],
)
flattened_a = [1, 2, 3]
tree_b = PyTreeDef(type = 'MyClass',
metadata = {fields = ['a','b']},
children = [
PyTreeDef(type = `list`,
metadata = {length = 3},
children = [
Leaf(type=`int`),
Leaf(type=`int`),
Leaf(type=`int`),
],
),
Leaf(type=int),
],
)
flattened_b = [1, 1, 2, 3]
```
Passing the flattened values and PyTreeDef to tree_unflatten to reconstruct the original structure.
``` python
unflattened_a = tree_unflatten(tree_a, flattened_a)
unflattened_b = tree_unflatten(tree_b, flattened_b)
```
yields the following structure:
``` python
unflattened_a = (1, 2, 3)
unflattened_b = MyClass(a=1, b =[1,2,3])
```
unflattened_a should be structurally identical to a, and unflattened_b should be structurally identical to b.
"""
def tree_flatten(x: Any) -> tuple[list[Any], PyTreeDef]:
    """
    Flatten a nested structure into a flat list of leaf values plus a tree
    definition that preserves enough structure to rebuild the original object
    via ``tree_unflatten``.

    Args:
        x: The nested structure to flatten.

    Returns:
        tuple: ``(flat_values, treedef)``.

    Raises:
        DSLTreeFlattenError: If the structure contains unsupported types.

    Example:
        >>> tree_flatten([1, [2, 3], 4])
        ([1, 2, 3, 4], PyTreeDef(...))
    """
    flat_iter, treedef = _tree_flatten(x)
    # _tree_flatten may hand back a lazy chain; materialize it for callers.
    return list(flat_iter), treedef
def get_registered_node_types_or_insert(x: Any) -> NodeType | None:
    """
    Look up the pytree NodeType for *x*, auto-registering it when possible.

    Unregistered types are registered on the fly: DynamicExpression objects
    get the dynamic-expression handlers (checked first, since a dataclass may
    also implement the protocol), and plain dataclasses get the default
    dataclass handlers.

    Args:
        x: The object to get or register a node type for.

    Returns:
        NodeType or None: The registered node type, or None when the type
        cannot be handled as an interior pytree node.
    """
    registered = _node_types.get(type(x))
    if registered is not None:
        return registered
    if is_dynamic_expression(x):
        return register_pytree_node(
            type(x), dynamic_expression_to_iterable, dynamic_expression_from_iterable
        )
    if dataclasses.is_dataclass(x):
        return register_pytree_node(
            type(x), default_dataclass_to_iterable, default_dataclass_from_iterable
        )
    return None
def create_leaf_for_value(
    x: Any,
    is_numeric: bool = False,
    is_none: bool = False,
    node_metadata: SimpleNamespace = None,
    ir_type_str: str = None,
) -> Leaf:
    """
    Build a Leaf describing the value *x*.

    When ``ir_type_str`` is not supplied, it is derived from ``x.type`` if the
    value exposes one (MLIR values do), otherwise left as None.

    Args:
        x: The value to create a leaf for.
        is_numeric: Whether this is a numeric value.
        is_none: Whether this represents None.
        node_metadata: Optional metadata (set for dynamic-expression leaves).
        ir_type_str: Optional explicit IR type string.

    Returns:
        Leaf: The created leaf node.
    """
    if not ir_type_str:
        ir_type_str = str(x.type) if hasattr(x, "type") else None
    return Leaf(
        is_numeric=is_numeric,
        is_none=is_none,
        node_metadata=node_metadata,
        ir_type_str=ir_type_str,
    )
def _tree_flatten(x: Any) -> tuple[Iterable[Any], PyTreeDef | Leaf]:
    """
    Internal function to flatten a tree structure.

    This is the core implementation of tree flattening that handles different
    types of objects including None, ArithValue, ir.Value, Numeric types,
    and registered pytree node types.

    Args:
        x: The object to flatten

    Returns:
        tuple: (flattened_values, treedef) where flattened_values is an iterable
               of leaf values and treedef is the tree structure

    Raises:
        DSLTreeFlattenError: If the object type is not supported
    """
    # NOTE: case order matters — the guarded ArithValue case must precede the
    # bare one, and leaf cases must precede the generic container fallback.
    match x:
        case None:
            # None contributes no flat values; only the structural marker.
            return [], create_leaf_for_value(x, is_none=True)
        case ArithValue() if is_dynamic_expression(x):
            # ArithValue implementing the DynamicExpression protocol: keep the
            # original object in the metadata so it can rebuild itself later.
            v = x.__extract_mlir_values__()
            return v, create_leaf_for_value(
                x,
                node_metadata=SimpleNamespace(is_dynamic_expression=1, original_obj=x),
                ir_type_str=str(v[0].type),
            )
        case ArithValue():
            return [x], create_leaf_for_value(x, is_numeric=True)
        case ir.Value():
            # Raw MLIR value passes through unchanged.
            return [x], create_leaf_for_value(x)
        case Numeric():
            # Numeric wrappers are treated like dynamic expressions.
            v = x.__extract_mlir_values__()
            return v, create_leaf_for_value(
                x,
                node_metadata=SimpleNamespace(is_dynamic_expression=1, original_obj=x),
                ir_type_str=str(v[0].type),
            )
        case _:
            # Containers / dataclasses: recurse through registered handlers.
            node_type = get_registered_node_types_or_insert(x)
            if node_type:
                node_metadata, children = node_type.to_iterable(x)
                children_flat, child_trees = unzip2(map(_tree_flatten, children))
                # Lazily chain the children's flat values; tree_flatten materializes.
                flattened = it.chain.from_iterable(children_flat)
                return flattened, PyTreeDef(
                    node_type, node_metadata, tuple(child_trees)
                )
            # Try to convert to numeric
            try:
                nval = as_numeric(x).ir_value()
                return [nval], create_leaf_for_value(nval, is_numeric=True)
            except Exception:
                raise DSLTreeFlattenError(
                    "Flatten Error", get_fully_qualified_class_name(x)
                )
def tree_unflatten(treedef: PyTreeDef, xs: list[Any]) -> Any:
    """
    Reconstruct a nested structure from flat values and a tree definition.

    This is the inverse of ``tree_flatten``: the values in *xs* are consumed
    in order, guided by *treedef*, to rebuild the original nesting.

    Args:
        treedef: The tree structure definition from ``tree_flatten``.
        xs: Flat leaf values to reconstruct from.

    Returns:
        The reconstructed nested structure.

    Example:
        >>> flat_values, treedef = tree_flatten([1, [2, 3], 4])
        >>> tree_unflatten(treedef, flat_values)
        [1, [2, 3], 4]
    """
    value_stream = iter(xs)
    return _tree_unflatten(treedef, value_stream)
def _tree_unflatten(treedef: PyTreeDef | Leaf, xs: Iterator[Any]) -> Any:
    """
    Internal function to reconstruct a tree structure.

    This is the core implementation of tree unflattening that handles
    different types of tree definitions including Leaf nodes and PyTreeDef nodes.
    Consumes values from *xs* in the same order `_tree_flatten` emitted them.

    Args:
        treedef: The tree structure definition
        xs: Iterator of flat values to reconstruct from

    Returns:
        The reconstructed object
    """
    # NOTE: case order matters — the specific Leaf patterns must precede the
    # bare Leaf() fallback.
    match treedef:
        case Leaf(is_none=True):
            # None leaves consumed no flat values when flattening.
            return None
        case Leaf(
            node_metadata=metadata
        ) if metadata and metadata.is_dynamic_expression:
            # Dynamic-expression leaf: the original object rebuilds itself
            # from exactly one MLIR value.
            return metadata.original_obj.__new_from_mlir_values__([next(xs)])
        case Leaf(is_numeric=True):
            # Wrap raw arith values back into the Numeric abstraction.
            return as_numeric(next(xs))
        case Leaf():
            return next(xs)
        case PyTreeDef():
            # Lazy generator: from_iterable drives consumption of xs in order.
            children = (_tree_unflatten(t, xs) for t in treedef.child_treedefs)
            return treedef.node_type.from_iterable(treedef.node_metadata, children)
def _check_tree_equal(lhs: Union[PyTreeDef, Leaf], rhs: Union[PyTreeDef, Leaf]) -> bool:
    """
    Recursively compare two tree definitions for structural equality.

    Two leaves match when their None-ness and IR type string agree. Two
    interior nodes match when their node type, recorded field names,
    constexpr field names, and all children match. A leaf never matches an
    interior node.

    Args:
        lhs: Left tree definition (PyTreeDef or Leaf).
        rhs: Right tree definition (PyTreeDef or Leaf).

    Returns:
        bool: True if the trees are structurally equal, False otherwise.
    """
    if isinstance(lhs, Leaf) and isinstance(rhs, Leaf):
        return lhs.is_none == rhs.is_none and lhs.ir_type_str == rhs.ir_type_str
    if isinstance(lhs, PyTreeDef) and isinstance(rhs, PyTreeDef):
        if lhs.node_type != rhs.node_type:
            return False
        lhs_meta, rhs_meta = lhs.node_metadata, rhs.node_metadata
        if getattr(lhs_meta, "fields", []) != getattr(rhs_meta, "fields", []):
            return False
        if getattr(lhs_meta, "constexpr_fields", []) != getattr(
            rhs_meta, "constexpr_fields", []
        ):
            return False
        if len(lhs.child_treedefs) != len(rhs.child_treedefs):
            return False
        return all(map(_check_tree_equal, lhs.child_treedefs, rhs.child_treedefs))
    return False
def check_tree_equal(lhs: PyTreeDef, rhs: PyTreeDef) -> int:
    """
    Compare two tree definitions child-by-child and locate the first difference.

    Args:
        lhs: Left tree definition
        rhs: Right tree definition

    Returns:
        int: Index of the first differing child, or -1 if trees are equal.
        When one tree has extra children, the index of the first unmatched
        child is returned.

    Example:
        >>> treedef1 = tree_flatten([1, [2, 3]])[1]
        >>> treedef2 = tree_flatten([1, [2, 4]])[1]
        >>> check_tree_equal(treedef1, treedef2)
        1  # The second child differs
    """
    # BUG FIX: the previous implementation `assert`-ed equal child counts.
    # Under `python -O` the assert is stripped and zip() silently truncates,
    # so trees of different arity could compare as equal. Report the first
    # unmatched index explicitly instead of asserting.
    for index, (left, right) in enumerate(zip(lhs.child_treedefs, rhs.child_treedefs)):
        if not _check_tree_equal(left, right):
            return index
    left_count = len(lhs.child_treedefs)
    right_count = len(rhs.child_treedefs)
    if left_count != right_count:
        # The shorter tree's length is the index of the first extra/missing child.
        return min(left_count, right_count)
    return -1