v4.2 tag release. (#2638)
@@ -17,6 +17,7 @@ from ..base_dsl.ast_helpers import (
    if_executor,
    while_selector,
    while_executor,
    range,
    range_constexpr,
    range_dynamic,
    const_expr,
@@ -28,6 +29,8 @@ from ..base_dsl.ast_helpers import (
    all_executor,
    range_value_check,
    range_perf_warning,
    cf_symbol_check,
    redirect_builtin_function,
)

from ..base_dsl import *
@@ -38,5 +41,4 @@ from ..base_dsl._mlir_helpers.op import dsl_user_op
from ..base_dsl.runtime import *
from ..base_dsl.runtime import cuda as cuda_helpers
from ..base_dsl.compiler import compile
from ..base_dsl.runtime.dlpack_runtime import *
from ..base_dsl.runtime.jit_arg_adapters import *
@@ -15,12 +15,14 @@ regarding to that dialect.
"""

# Local module imports
from typing import Callable, Union, Type, List, Union, Sequence, ForwardRef
from inspect import isclass
from itertools import chain
from types import GenericAlias, SimpleNamespace, UnionType
from typing import Callable, Union, Type, List, Union, Sequence, ForwardRef, Any
import functools
import pkgutil
from dataclasses import is_dataclass
from dataclasses import is_dataclass, fields
from collections.abc import Sequence
import builtins

from ..base_dsl import *
from ..base_dsl import compiler
@@ -51,20 +53,15 @@ from ..base_dsl.ast_helpers import (
    while_selector,
    while_executor,
    assert_executor,
    const_expr,
    dynamic_expr,
    bool_cast,
    compare_executor,
    any_executor,
    all_executor,
    range_value_check,
    range_perf_warning,
)
from ..base_dsl.runtime.dlpack_runtime import (
    get_cute_tensor_c_pointer,
    get_tensor_desc_shape_all,
    get_tensor_desc_stride_all,
    get_tensor_desc_element_type,
    get_tensor_desc_is_in_device,
    get_tensor_desc_assumed_align,
    cf_symbol_check,
)

from .cutlass_ast_decorators import (
@@ -73,6 +70,16 @@ from .cutlass_ast_decorators import (
    _while_execute_dynamic,
)

from .tree_utils import (
    is_constexpr_field,
    tree_flatten,
    tree_unflatten,
    PyTreeDef,
    is_frozen_dataclass,
    DSLTreeFlattenError,
)
from ..base_dsl.runtime.jit_arg_adapters import JitArgAdapterRegistry


# =============================================================================
# Cutlass DSL Base Abstract Class
@@ -125,6 +132,46 @@ def is_cute_algebra_type(arg_spec):
    return False


def _get_c_pointers_cutlass(obj):
    """
    This is an extended version of `get_c_pointers` that supports dataclasses, SimpleNamespace, and dict.
    """
    if hasattr(obj, "__c_pointers__"):
        return obj.__c_pointers__()
    elif isinstance(obj, (tuple, list)):
        return list(chain.from_iterable(_get_c_pointers_cutlass(x) for x in obj))
    elif isinstance(obj, SimpleNamespace):
        return list(
            chain.from_iterable(
                _get_c_pointers_cutlass(x) for x in obj.__dict__.values()
            )
        )
    elif isinstance(obj, dict):
        return list(
            chain.from_iterable(_get_c_pointers_cutlass(x) for x in obj.values())
        )
    elif is_dataclass(obj):
        return list(
            chain.from_iterable(
                _get_c_pointers_cutlass(getattr(obj, f.name))
                for f in fields(obj)
                if not is_constexpr_field(f)
            )
        )
    elif isinstance(obj, set):
        raise DSLRuntimeError(
            "Sets are not supported in get_c_pointers to ensure order preservation",
            context="The DSL attempted to generate JIT function argument(s) for an argument of type set but failed.",
            suggestion="Consider using a list or tuple instead",
        )
    else:
        # Try to find a registered adapter
        adapter = JitArgAdapterRegistry.get_registered_adapter(type(obj))
        if adapter is not None:
            return _get_c_pointers_cutlass(adapter(obj))
        return []
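
For illustration, a minimal sketch of the flattening order, assuming a hypothetical dataclass whose fields expose `__c_pointers__` (the names below are illustrative, not part of this commit):

from dataclasses import dataclass

@dataclass
class _FakeArg:
    addr: int

    def __c_pointers__(self):
        # stand-in for a real list of C pointers
        return [self.addr]

@dataclass
class _Params:
    a: _FakeArg
    b: _FakeArg

# Nested containers are walked depth-first; field/container order is preserved:
# _get_c_pointers_cutlass({"p": _Params(_FakeArg(16), _FakeArg(32)), "q": [_FakeArg(48)]})
# -> [16, 32, 48]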


class CutlassBaseDSL(BaseDSL):
    """This abstract class provides a DSL for Cutlass."""

@@ -137,16 +184,25 @@ class CutlassBaseDSL(BaseDSL):
        preprocess: bool = False,
    ):
        super().__init__(
            name,
            compiler_provider,
            pass_sm_arch_name,
            device_compilation_only,
            preprocess,
            name=name,
            dsl_package_name=["cutlass"],
            compiler_provider=compiler_provider,
            pass_sm_arch_name=pass_sm_arch_name,
            device_compilation_only=device_compilation_only,
            preprocess=preprocess,
        )
        self._smem_usage_tracker: tuple = None

    # This method is not useful for cutlass_dsl, so we only provide a dummy implementation.
    def _is_tensor_descriptor(self, maybe_tensor_descriptor) -> bool:
        return False

    # This method is not useful for cutlass_dsl, so we only provide a dummy implementation.
    def _handle_tensor_descriptor(
        self, maybe_tensor, arg_name: str, need_gpu_memory: bool
    ) -> Any:
        return False

    def _build_gpu_module(self, attrs):
        self.gpu_module = gpu.GPUModuleOp(ir.StringAttr.get("kernels"))
        with ir.InsertionPoint(self.gpu_module.bodyRegion.blocks.append(*[])):
@@ -229,8 +285,43 @@ class CutlassBaseDSL(BaseDSL):

        return version_hash

    @staticmethod
    def track_smem_allocator(allocator, callback):
        """
        Tracks shared memory usage for kernel functions by finding the parent
        DSL object on the call stack and registering the allocator with it.
        """
        frame = inspect.currentframe().f_back
        while frame:
            obj = frame.f_locals.get("self", None)
            if obj and isinstance(obj, CutlassBaseDSL):
                obj._set_smem_tracking(allocator, callback)
                return
            frame = frame.f_back
        warnings.warn("Cannot find a parent DSL for the allocator!", UserWarning)

    def _set_smem_tracking(self, allocator, callback):
        # Register an allocator and callback for the current DSL
        self._smem_usage_tracker = (allocator, callback)

    def _reset_smem_tracking(self):
        # Clear the allocator and callback for the current DSL
        self._smem_usage_tracker = None

    def _get_smem_usage(self) -> int:
        # Treat the allocator's final allocated byte count as the smem usage
        if not self._smem_usage_tracker:
            return 0
        allocator, callback = self._smem_usage_tracker
        return callback(allocator)
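
A hedged sketch of how an allocator might opt into this tracking; the allocator class and callback are illustrative, not part of this commit:

class _SmemAllocator:
    # hypothetical allocator that records its high-water mark in bytes
    def __init__(self):
        self.allocated_bytes = 0

    def allocate(self, nbytes):
        self.allocated_bytes += nbytes

# Called from kernel-construction code that has a CutlassBaseDSL `self`
# somewhere up the Python call stack; the stack walk above locates it:
# allocator = _SmemAllocator()
# CutlassBaseDSL.track_smem_allocator(allocator, lambda a: a.allocated_bytes)
# ... later the DSL reads usage via _get_smem_usage() -> callback(allocator)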

    def _kernel_helper(self, funcBody, *args, **kwargs):
        class _CutlassIrKernelGenHelper(BaseDSL._KernelGenHelper):
            def __init__(self, dsl: CutlassBaseDSL):
                super().__init__()
                self.dsl = dsl
                self.dsl._reset_smem_tracking()

            def generate_func_op(self, arg_types, arg_attrs, kernel_name, loc=None):
                super().generate_func_op(arg_types, arg_attrs, kernel_name)
                self.func_op = func.FuncOp(
@@ -272,6 +363,17 @@ class CutlassBaseDSL(BaseDSL):
                if cfg.has_cluster:
                    cfg.cluster = [to_index(size) for size in cfg.cluster]

                smem_usage = self.dsl._get_smem_usage()
                if any(not isinstance(x, int) for x in [cfg.smem, smem_usage]):
                    pass  # a dynamic in-kernel value cannot be compared with the launch config in Python
                elif cfg.auto_smem:
                    cfg.smem = smem_usage
                elif smem_usage > cfg.smem:
                    warnings.warn(
                        f"Potential error: specified kernel launch smem bytes "
                        f"({cfg.smem}) is smaller than kernel usage ({smem_usage})!",
                        UserWarning,
                    )
                cfg.smem = const(cfg.smem)

                if not isinstance(cfg.async_deps, (list, tuple)):
@@ -295,12 +397,13 @@ class CutlassBaseDSL(BaseDSL):
                    return token if is_async else None

        return KernelLauncher(
            self, _CutlassIrKernelGenHelper, funcBody, *args, **kwargs
            self,
            lambda: _CutlassIrKernelGenHelper(self),
            funcBody,
            *args,
            **kwargs,
        )

    def _get_module_globals(self):
        return globals()

    def _preprocess_launch_config_args(self, args, kwargs):
        """Helper to preprocess args and kwargs for LaunchConfig"""
        if "stream" in kwargs:
@@ -316,7 +419,10 @@ class CutlassBaseDSL(BaseDSL):
        Validates if the arg is really of the annotated type.
        """

        if is_arg_spec_constexpr(arg_annotation, arg_name, arg_index, None):
        if (
            is_arg_spec_constexpr(arg_annotation, arg_name, arg_index, None)
            or arg_annotation is Any
        ):
            pass
        else:
            origin = get_origin(arg_annotation)
@@ -329,11 +435,12 @@ class CutlassBaseDSL(BaseDSL):
                    f"expects argument #{arg_index+1} ({arg_name}) to be Type[{expected_base}], but got {arg}"
                )
            # Handle Union types and generic types
            elif origin is Union:
            elif origin is Union or isinstance(arg_annotation, UnionType):
                # For Union types, check if arg matches any of the allowed types
                allowed_types = get_args(arg_annotation)
                if not any(
                    (isinstance(ty, type) and isinstance(arg, ty))
                    (ty is Any)
                    or (isinstance(ty, type) and isinstance(arg, ty))
                    or (get_origin(ty) is tuple and isinstance(arg, tuple))
                    for ty in allowed_types
                ):
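
For context, a small sketch of why the extra `UnionType` test is needed: the PEP 604 `X | Y` spelling is not caught by `origin is Union` (illustrative, not part of this commit):

from types import UnionType
from typing import Union, get_args, get_origin

# On CPython 3.10-3.13:
#   get_origin(Union[int, float]) is Union   -> True  (typing.Union spelling)
#   isinstance(int | float, UnionType)       -> True  (PEP 604 spelling)
#   get_origin(int | float) is Union         -> False (so `origin is Union` alone misses it)
#   get_args(int | float) == (int, float)    -> True  (get_args works for both spellings)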
@@ -381,6 +488,26 @@ class CutlassBaseDSL(BaseDSL):
                jit_exec_arg.extend(get_c_pointers(arg) if is_host else dyn_vals)
            else:
                jit_exec_arg = jit_arg_type = jit_arg_attr = None
        elif not hasattr(arg, "__extract_mlir_values__") and not hasattr(
            arg, "__new_from_mlir_values__"
        ):
            # Try tree_flatten
            try:
                dyn_vals, _ = tree_flatten(arg)
            except DSLTreeFlattenError:
                # If flattening fails, just return the original arg
                return jit_exec_arg, jit_arg_type, jit_arg_attr

            if dyn_vals:
                jit_arg_type.extend([v.type for v in dyn_vals])
                jit_arg_attr.extend([default_attr] * len(dyn_vals))
                jit_exec_arg.extend(
                    _get_c_pointers_cutlass(arg) if is_host else dyn_vals
                )
            else:
                # If tree_flatten yields an empty list, treat the arg as constexpr
                # (e.g., a dataclass whose fields are all constexpr, or an empty tuple or list)
                jit_exec_arg = jit_arg_type = jit_arg_attr = None
        return jit_exec_arg, jit_arg_type, jit_arg_attr

    def _generate_execution_arguments_for_known_types(
@@ -396,6 +523,17 @@ class CutlassBaseDSL(BaseDSL):
            blk_args = fop_args[iv_block_args : iv_block_args + n_args]
            ir_arg.append(new_from_mlir_values(arg, blk_args))
            iv_block_args += n_args
        elif not hasattr(arg, "__extract_mlir_values__") and not hasattr(
            arg, "__new_from_mlir_values__"
        ):
            # Try tree_unflatten
            try:
                dyn_vals, tree_def = tree_flatten(arg)
                block_args = fop_args[iv_block_args : iv_block_args + len(dyn_vals)]
                ir_arg.append(tree_unflatten(tree_def, block_args))
                iv_block_args += len(dyn_vals)
            except DSLTreeFlattenError:
                return ir_arg, iv_block_args

        return ir_arg, iv_block_args
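
A hedged sketch of the flatten/unflatten round trip both paths rely on (`Pair` is illustrative; `tree_flatten`/`tree_unflatten` come from the tree_utils module added later in this diff):

from dataclasses import dataclass

@dataclass
class Pair:
    x: float
    y: float

# leaves carry the dynamic values; tree_def records the container structure
leaves, tree_def = tree_flatten(Pair(1.0, 2.0))
# unflatten rebuilds the same structure around (possibly new) leaf values,
# e.g. the MLIR block arguments of an scf region
rebuilt = tree_unflatten(tree_def, leaves)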

@@ -458,10 +596,7 @@ class KernelLauncher:

    def _check_func_args(self, funcBody, *func_args, **func_kwargs):
        # Get the function signature
        if isinstance(funcBody, DSLCallable):
            sig = funcBody.get_signature()
        else:
            sig = inspect.signature(funcBody)
        sig = inspect.signature(funcBody)

        # func_args and func_kwargs should match funcBody's signature,
        # no extra or missing arguments.
@@ -473,6 +608,12 @@ class KernelLauncher:
                cause=e,
            )

    def smem_usage(self) -> int:
        """
        Return the smem usage for this kernel; only available after `launch`.
        """
        return self.dsl._get_smem_usage()
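
A brief usage sketch; the launcher construction is schematic, only the call order matters:

# launcher = dsl.kernel(my_kernel)         # hypothetical way to obtain a KernelLauncher
# launcher.launch(grid=grid, block=block)
# print(launcher.smem_usage())             # meaningful only after launch()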

    def launch(self, *args, **kwargs):
        self.dsl.frame = inspect.currentframe().f_back
        self.dsl._preprocess_launch_config_args(args, kwargs)
@@ -497,134 +638,151 @@ class KernelLauncher:
# =============================================================================
# Utils
# =============================================================================


def is_frozen_dataclass(obj_or_cls) -> bool:
def _filter_readonly_frozen_dataclass(
    iter_args: List[Any], items_to_filter: List[Any], full_write_args_count: int
) -> List[Any]:
    """
    Return True if obj_or_cls is a dataclass (class or instance) declared with frozen=True,
    otherwise False.
    """
    if not isinstance(obj_or_cls, type):
        # If it's an instance, get its class
        obj_or_cls = obj_or_cls.__class__
    Filter items based on whether the corresponding iter_args are frozen dataclasses.

    # Must be a dataclass, and __dataclass_params__.frozen must be True
    return (
        is_dataclass(obj_or_cls)
        and getattr(obj_or_cls, "__dataclass_params__", None) is not None
        and obj_or_cls.__dataclass_params__.frozen
    This function filters items (which can be values or names) based on the same
    logic: keep items if they correspond to full-write arguments (index < full_write_args_count)
    or if the corresponding iter_arg is not a frozen dataclass.

    Args:
        iter_args: List of arguments to check for frozen dataclass status
        items_to_filter: List of items to filter (values or names)
        full_write_args_count: Number of arguments that are always written (not read-only)

    Returns:
        Filtered list of items

    Examples:
        # Filter values (original remove_read_only_frozen_dataclass behavior)
        filtered_values = _filter_readonly_frozen_dataclass(iter_args, iter_args, full_write_args_count)

        # Filter names (original filter_readonly_frozen_dataclass_names behavior)
        filtered_names = _filter_readonly_frozen_dataclass(iter_args, iter_args_names, full_write_args_count)
    """
    return [
        item
        for i, item in enumerate(items_to_filter)
        if i < full_write_args_count or not is_frozen_dataclass(iter_args[i])
    ]


def remove_read_only_frozen_dataclass(
    iter_args: List[Any], full_write_args_count: int
) -> List[Any]:
    """Filter out frozen dataclass arguments that are not full-write arguments."""
    return _filter_readonly_frozen_dataclass(
        iter_args, iter_args, full_write_args_count
    )


def filter_readonly_frozen_dataclass_names(
    iter_args: List[Any], iter_args_names: List[str], full_write_args_count: int
) -> List[str]:
    """Filter names based on whether the corresponding iter_args are frozen dataclasses."""
    return _filter_readonly_frozen_dataclass(
        iter_args, iter_args_names, full_write_args_count
    )


def insert_read_only_frozen_dataclass(
    iter_args: List[Any], original_iter_args: List[Any], full_write_args_count: int
) -> List[Any]:
    """
    Insert read-only frozen dataclass arguments back into the iteration arguments.

    This function takes the new iteration arguments and the original arguments,
    and preserves frozen dataclass instances from the original arguments while
    using the new arguments for non-frozen dataclass instances.

    Args:
        iter_args: New iteration arguments to use for non-frozen dataclass instances
        original_iter_args: Original iteration arguments to preserve frozen dataclass instances from
        full_write_args_count: Number of arguments that are always written (not read-only)

    Returns:
        List of arguments with frozen dataclass instances preserved from original
    """
    # Take full-write arguments from the new iter_args
    full_write_args = (
        iter_args[:full_write_args_count] if full_write_args_count > 0 else []
    )

    # Process remaining arguments: preserve frozen dataclasses from original, use new for others
    remaining_original = original_iter_args[full_write_args_count:]
    remaining_new = iter_args[full_write_args_count:]

    def process_remaining_arg(original_arg, new_arg_iter):
        """Process a single remaining argument, preserving a frozen dataclass if present"""
        return original_arg if is_frozen_dataclass(original_arg) else next(new_arg_iter)

    # Walk the original args, drawing replacements from the new args only for non-frozen entries
    new_arg_iter = iter(remaining_new)
    processed_remaining = [
        process_remaining_arg(orig_arg, new_arg_iter) for orig_arg in remaining_original
    ]

    return full_write_args + processed_remaining
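
A small illustration of the merge behavior (hypothetical values, not from this commit):

from dataclasses import dataclass

@dataclass(frozen=True)
class Cfg:
    n: int

original = [10, Cfg(1), 20]   # loop-carried args; one frozen dataclass among them
updated = [11, 21]            # new values exist only for the non-frozen entries
# With full_write_args_count=0 the frozen instance is carried over unchanged:
# insert_read_only_frozen_dataclass(updated, original, 0) -> [11, Cfg(1), 21]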


def unpack_to_irvalue(
    mixed_values: List[Any], body_name: str, full_write_args_count: int
) -> Tuple[List[ir.Value], PyTreeDef]:
    log().debug("===--- Values UNPack")
    for idx, packed in enumerate(mixed_values):
        log().debug("[%d]: will-unpacked: [type:%s] %s", idx, type(packed), packed)

    try:
        unpacked_values, treedef = tree_flatten(
            remove_read_only_frozen_dataclass(mixed_values, full_write_args_count)
        )
    except DSLTreeFlattenError as e:
        raise DSLRuntimeError(
            f"The '{body_name}' statement encountered a user-defined Python object, which cannot be automatically converted into a dynamic expression.",
            context={
                e.message: (
                    f"All expressions within '{body_name}' must be dynamic expressions, "
                    "mixing Python objects and dynamic expressions is not supported. "
                    "The DSL failed to convert the Python object into dynamic expressions."
                )
            },
            suggestion=(
                f"Please ensure '{e.type_str}' implements '{DynamicExpression.__name__}' or is marked with `dataclass`, "
                f"so it can be treated as a valid dynamic expression, or mark '{body_name}' as a constant expression if the conditions are Python objects."
            ),
        )

    log().debug("------------------ ")
    for idx, unpacked in enumerate(unpacked_values):
        log().debug("[%d]: unpacked values: %s", idx, unpacked)
    log().debug("treedef: %s", treedef)
    log().debug("------------------ ")

    return unpacked_values, treedef


def pack_from_irvalue(
    ir_values: List["ir.Value"],
    indices: Dict[int, Tuple[int, int]],
    class_types: List[Any],
    pytree_def: PyTreeDef,
    mixed_values: List[Any],
    full_write_args_count: int,
) -> List[Any]:
    """
    Packs MLIR values into a list of mixed values.
    """
    log().debug("===--- Values Pack (%d)", len(ir_values))
    for idx, packed in enumerate(ir_values):
        log().debug("[%d]: will-packed: %s", idx, ir_values)
    for idx, unpacked in indices.items():
        log().debug("[%d]: indices: %s", idx, unpacked)
    for idx, c in enumerate(class_types):
        log().debug("[%d]: obj-types: %s", idx, type(c))

    mixed_values = [None] * len(indices)
    for idx, (start, length) in sorted(indices.items()):
        chunk = ir_values[start : start + length]
        obj = class_types[idx]
        if is_frozen_dataclass(obj):
            mixed_values[idx] = obj
        elif not isinstance(obj, type) and hasattr(obj, "__new_from_mlir_values__"):
            mixed_values[idx] = obj.__new_from_mlir_values__(chunk)
        elif isinstance(chunk, list) and chunk[0] is None:
            mixed_values[idx] = class_types[idx]
        else:
            if len(chunk) == 1:
                try:
                    mixed_values[idx] = t.as_numeric(chunk[0])
                except ValueError:
                    # Suppress the conversion error and try new_from_mlir_values below
                    pass

            if mixed_values[idx] is None:
                mixed_values[idx] = new_from_mlir_values(obj, chunk)

    log().debug("------------------ ")
    for idx, packed in enumerate(mixed_values):
        log().debug("[%d]: packed: %s", idx, packed)
    log().debug("------------------ ")
    return mixed_values


def unpack_to_irvalue(
    mixed_values: List[Any], body_name: str
) -> Tuple[List[ir.Value], List[Any], Dict[int, Tuple[int, int]], List[Any]]:
    """
    Unpacks mixed values into ir.Value values.
    """
    unpacked_values = []
    ir_values = []
    indices = {}
    class_types = []
    current_offset = 0

    log().debug("===--- Values UNPack (%d)", len(mixed_values))
    for idx, packed in enumerate(mixed_values):
        log().debug("[%d]: will-unpacked: [type:%s] %s", idx, type(packed), packed)
    for idx, item in enumerate(mixed_values):
        class_types.append(item)
        try:
            if is_frozen_dataclass(item):
                extracted_vals = [None]
            else:
                extracted_vals = extract_mlir_values(item)
                # it's constexpr (a Python value), so we create an MLIR value for it
                if extracted_vals == []:
                    if item is None:
                        extracted_vals = [None]
                    else:
                        dyn_expr = t.as_numeric(item)
                        extracted_vals = extract_mlir_values(dyn_expr)
                        ir_values.extend(extracted_vals)
                else:
                    ir_values.extend(extracted_vals)

            unpacked_values.extend(extracted_vals)
            length = len(extracted_vals)
            indices[idx] = (current_offset, length)
            current_offset += length
        except Exception as e:
            raise DSLRuntimeError(
                f"The '{body_name}' statement encountered a user-defined Python object, which cannot be automatically converted into a dynamic expression (aka MLIR value).",
                context={
                    item: (
                        f"All expressions within '{body_name}' must be dynamic expressions, "
                        "mixing Python objects and dynamic expressions (aka MLIR values) is not supported. "
                        "The DSL failed to convert the Python object into MLIR values."
                    )
                },
                suggestion=(
                    f"Please ensure '{item}' implements '{DynamicExpression.__name__}', "
                    f"so it can be treated as a valid dynamic expression, or mark '{body_name}' as a constant expression if the conditions are Python objects."
                ),
            ) from e

    log().debug("------------------ ")
    for idx, unpacked in enumerate(unpacked_values):
        log().debug("[%d]: unpacked values: %s", idx, unpacked)
    for idx, unpacked in enumerate(ir_values):
        log().debug("[%d]: unpacked ir_values: %s", idx, unpacked)
    for idx, unpacked in indices.items():
        log().debug("[%d]: indices: %s", idx, unpacked)
    for idx, unpacked in enumerate(class_types):
        log().debug("[%d]: initial-class-types: %s", idx, unpacked)
    for idx, value in enumerate(ir_values):
        log().debug("[%d]: will-packed: %s", idx, value)
    log().debug("treedef: %s", pytree_def)
    log().debug("------------------ ")

    return ir_values, unpacked_values, indices, class_types
    unflattened = tree_unflatten(pytree_def, ir_values)
    return insert_read_only_frozen_dataclass(
        unflattened, mixed_values, full_write_args_count
    )


def to_index(value):
@@ -1015,8 +1173,8 @@ def any_(iterable):

def select_(cond, if_value, else_value):
    def _as_scalar(value):
        if const_expr(isinstance(value, list)):
            if const_expr(len(value) == 1):
        if isinstance(value, list):
            if len(value) == 1:
                return value[0]
            else:
                raise DSLRuntimeError(
@@ -1024,16 +1182,16 @@ def select_(cond, if_value, else_value):
                )
        return value

    if const_expr(not is_dynamic_expression(cond)):
    if not is_dynamic_expression(cond):
        raise DSLRuntimeError("Conditional expression must be dynamic")

    # Extract MLIR values
    cond = extract_mlir_values(cond)
    if const_expr(is_dynamic_expression(if_value)):
    if is_dynamic_expression(if_value):
        if_value = extract_mlir_values(if_value)
    else:
        if_value = const(if_value)
    if const_expr(is_dynamic_expression(else_value)):
    if is_dynamic_expression(else_value):
        else_value = extract_mlir_values(else_value)
    else:
        else_value = const(else_value)
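
For orientation, a hedged note on what `select_` expresses (values illustrative):

# select_(pred, a, b) picks one of two already-computed values by the dynamic
# predicate (a value-level select, not control flow), e.g.:
# x = select_(less_than(i, n), x * 2, x)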
@@ -1089,7 +1247,7 @@ def for_generate(
    iter_args: Optional[Sequence[ir.Value]] = None,
    *,
    unroll: LoopUnroll = None,
    pipelining=None,
    prefetch_stages=None,
    loc=None,
    ip=None,
):
@@ -1127,8 +1285,8 @@ def for_generate(
    if unroll is not None:
        for_op.attributes["loop_annotation"] = unroll

    if pipelining is not None:
        for_op.attributes["cutlass.pipelining"] = _createI32Attr(pipelining)
    if prefetch_stages is not None:
        for_op.attributes["cutlass.pipelining"] = _createI32Attr(prefetch_stages)

    iv = for_op.induction_variable
    new_results = new_from_mlir_values(iter_args, for_op.results)
@@ -1155,11 +1313,11 @@ def not_(lhs: Union[ir.Value, bool], *, loc=None, ip=None):
    """
    res = None
    # Handle Python bool first to prevent infinite recursion
    if const_expr(type(lhs) == bool):
    if type(lhs) == bool:
        res = lhs ^ True
    elif const_expr(hasattr(lhs, "__dsl_not__")):
    elif hasattr(lhs, "__dsl_not__"):
        res = lhs.__dsl_not__(loc=loc, ip=ip)
    elif const_expr(is_dynamic_expression(lhs)):
    elif is_dynamic_expression(lhs):
        # If lhs is an MLIR value, compute `not` using xor
        res = arith.XOrIOp(lhs, const(1, lhs.type)).result
    else:
@@ -1338,29 +1496,59 @@ def equal(lhs, rhs):
    return lhs == rhs


def in_(lhs, rhs, op):
def not_equal(lhs, rhs):
    if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs):
        return lhs != rhs

    # Both sequences
    if isinstance(lhs, Sequence) and isinstance(rhs, Sequence):
        # Short-circuit for unequal lengths
        if len(lhs) != len(rhs):
            return True
        return any_(not_equal(l, r) for l, r in zip(lhs, rhs))

    if hasattr(lhs, "__ne__"):
        return lhs != rhs
    elif hasattr(rhs, "__ne__"):
        return rhs != lhs
    else:
        return not_(equal(lhs, rhs))
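
A short illustration of the sequence path (a, b, c, d stand for dynamic values):

# Unequal lengths decide immediately, without element comparisons:
# not_equal((a, b), (a,)) -> True
# Equal lengths reduce elementwise through any_:
# not_equal((a, b), (c, d)) -> any_([not_equal(a, c), not_equal(b, d)])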


def in_(lhs, rhs):
    if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs):
        return lhs in rhs

    if not isinstance(rhs, Sequence):
        raise DSLRuntimeError(
            f"'{op}' not supported between instances of {type(lhs)} and {type(rhs)}"
            f"'in' not supported between instances of {type(lhs)} and {type(rhs)}"
        )

    return any_(equal(lhs, r) for r in rhs)
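
So a dynamic membership test lowers to a reduction of equality checks, e.g. (illustrative):

# x in (1, 2, 3) with dynamic x becomes
# any_([equal(x, 1), equal(x, 2), equal(x, 3)])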


def _lt_gt(lhs, rhs, op):
    def native_lt_gt(lhs, rhs, op):
        if op == "<":
            return lhs < rhs
        elif op == ">":
            return lhs > rhs
        else:
            raise DSLRuntimeError(f"Unsupported comparison operator: {op}")
def _lte_gte(lhs, rhs, op):
    def native_lte_gte(lhs, rhs, op):
        match op:
            case "<":
                return lhs < rhs
            case "<=":
                if hasattr(lhs, "__le__"):
                    return lhs <= rhs
                else:
                    return not_(lhs > rhs)
            case ">":
                return lhs > rhs
            case ">=":
                if hasattr(lhs, "__ge__"):
                    return lhs >= rhs
                else:
                    return not_(lhs < rhs)
            case _:
                raise DSLRuntimeError(f"Unsupported comparison operator: {op}")

    if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs):
        return native_lt_gt(lhs, rhs, op)
        return native_lte_gte(lhs, rhs, op)

    # Both sequences; comparisons other than == and != do not allow mixing different sequence types
    if (
@@ -1375,7 +1563,7 @@ def _lt_gt(lhs, rhs, op):
        is_equal = equal(l, r)
        mask.append(not_(or_(is_equal, unequal_found)))
        unequal_found = not_(is_equal)
        comp_results.append(_lt_gt(l, r, op))
        comp_results.append(_lte_gte(l, r, op))

    result = any_(and_(r, m) for r, m in zip(comp_results, mask))

@@ -1383,62 +1571,126 @@ def _lt_gt(lhs, rhs, op):
        # Ref https://docs.python.org/3/tutorial/datastructures.html#comparing-sequences-and-other-types
        # If one sequence is an initial sub-sequence of the other, the shorter sequence is the smaller (lesser) one
        has_valid_mask = any_(mask)
        if op == "<":
            length_result = len(lhs) < len(rhs)
        elif op == ">":
            length_result = len(lhs) > len(rhs)
        match op:
            case "<":
                length_result = len(lhs) < len(rhs)
            case ">":
                length_result = len(lhs) > len(rhs)
            case "<=":
                length_result = len(lhs) <= len(rhs)
            case ">=":
                length_result = len(lhs) >= len(rhs)
        if type(has_valid_mask) == bool:
            return result if has_valid_mask else length_result
        else:
            return select_(has_valid_mask, result, length_result)
    else:
        return result
        if op in {"<=", ">="}:
            # If no element ever differed, the result is True
            return select_(unequal_found, result, True)
        else:
            return result
    else:
        return native_lt_gt(lhs, rhs, op)
        return native_lte_gte(lhs, rhs, op)
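
For intuition, the sequence branch mirrors Python's lexicographic rules; schematically (plain stand-ins for dynamic values):

# Comparing (a0, a1) < (b0, b1) elementwise:
#   mask[i] is true only while every earlier element pair was equal
#   result = any_(and_(compare(a_i, b_i), mask[i]) for each position i)
# For "<=" and ">=", if no element pair ever differed the sequences are
# equal, so select_(unequal_found, result, True) returns True.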


def greater_than(lhs, rhs):
    return _lt_gt(lhs, rhs, ">")
    return _lte_gte(lhs, rhs, ">")


def greater_equal(lhs, rhs):
    return _lte_gte(lhs, rhs, ">=")


def less_than(lhs, rhs):
    return _lt_gt(lhs, rhs, "<")
    return _lte_gte(lhs, rhs, "<")


def less_equal(lhs, rhs):
    return _lte_gte(lhs, rhs, "<=")


def _compare_dispatch(lhs, rhs, op):
    """
    Dispatches the comparison operation between lhs and rhs based on the given operator.

    :param lhs: The left-hand side operand for the comparison.
    :param rhs: The right-hand side operand for the comparison.
    :param op: The comparison operator as a string. Supported operators are:
        - "is", "is not": Python identity comparisons.
        - "in", "not in": Membership tests.
        - "==", "!=": Equality and inequality.
        - "<", ">", "<=", ">=": Relational comparisons.
    :return: The result of the comparison, which may be a boolean or a DSL-specific type.
    :raises DSLRuntimeError: If the operator is not supported.
    """
    match op:
        # 'is' and 'is not' are pure Python operators
        case "is":
            return lhs is rhs
        case "is not":
            return lhs is not rhs
        case "in":
            return in_(lhs, rhs)
        case "not in":
            return not_(in_(lhs, rhs))
        case "==":
            return equal(lhs, rhs)
        case "!=":
            return not_equal(lhs, rhs)
        case "<":
            return less_than(lhs, rhs)
        case ">":
            return greater_than(lhs, rhs)
        case ">=":
            return greater_equal(lhs, rhs)
        case "<=":
            return less_equal(lhs, rhs)
        case _:
            raise DSLRuntimeError(f"Unsupported comparison operator: {op}")


def _compare_executor(left, comparators, ops):
    result = left
    # Fast path for a single comparison
    if len(comparators) == 1:
        return _compare_dispatch(left, comparators[0], ops[0])

    # Chained comparison: dispatch in a loop
    result = True
    current = left
    for comparator, op in zip(comparators, ops):
        # 'is' and 'is not' are pure Python operators
        if op == "is":
            result = result is comparator
        elif op == "is not":
            result = result is not comparator
        elif op in ["in", "not in"]:
            result = in_(left, comparator, op)
        elif op in ["==", "!="]:
            result = equal(left, comparator)
        elif op in ["<", ">="]:
            result = less_than(left, comparator)
        elif op in [">", "<="]:
            result = greater_than(left, comparator)
        else:
            raise DSLRuntimeError(f"Unsupported comparison operator: {op}")
        # Invert the result for NotIn, NotEq, GtE, LtE
        if op in ["not in", "!=", ">=", "<="]:
            result = not_(result)
        cmp_result = _compare_dispatch(current, comparator, op)
        result = and_(result, cmp_result)
        current = comparator

    return result
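
A brief sketch of how a chained comparison decomposes (plain stand-ins):

# a < b <= c executes as
#   and_(_compare_dispatch(a, b, "<"), _compare_dispatch(b, c, "<="))
# each link compares the previous comparator to the next, and the links
# are combined with and_, matching Python's chaining semantics.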


def _builtin_redirector(fcn):
    if fcn == builtins.max:
        return max
    elif fcn == builtins.min:
        return min
    elif fcn == builtins.any:
        return any_
    elif fcn == builtins.all:
        return all_
    else:
        raise DSLRuntimeError(f"Unsupported built-in function: {fcn}")


# =============================================================================
# Set the AST decorator
# =============================================================================

# Set the DSL-specific functions
executor.set_functions(
    is_dynamic_expression,
    _loop_execute_range_dynamic,
    _if_execute_dynamic,
    _while_execute_dynamic,
    _compare_executor,
    any_,
    all_,
    is_dynamic_expression=is_dynamic_expression,
    loop_execute_range_dynamic=_loop_execute_range_dynamic,
    if_dynamic=_if_execute_dynamic,
    while_dynamic=_while_execute_dynamic,
    compare_executor=_compare_executor,
    any_executor=any_,
    all_executor=all_,
    builtin_redirector=_builtin_redirector,
)
@@ -14,13 +14,22 @@ from types import NoneType
from cutlass._mlir import ir
from cutlass._mlir.dialects import scf, arith
from cutlass._mlir.extras import types as T
from collections.abc import Sequence

from ..base_dsl.dsl import extract_mlir_values, new_from_mlir_values
from ..base_dsl.dsl import is_dynamic_expression
from ..base_dsl.ast_helpers import *
from ..base_dsl.utils.logger import log
from ..base_dsl import typing as t
from ..base_dsl.typing import Int32, Float32, Boolean, Numeric, get_mlir_types
from ..base_dsl.typing import (
    Int32,
    Float32,
    Boolean,
    Numeric,
    get_mlir_types,
    as_numeric,
)
from . import cutlass as cutlass_dsl
from .tree_utils import PyTreeDef, check_tree_equal

# =============================================================================
# AST Helpers
@@ -57,14 +66,6 @@ class ScfGenerator:
    def __init__(self):
        pass

    @staticmethod
    def fill_none(ir_values, unpacked_values):
        i = 0
        for idx, item in enumerate(unpacked_values):
            if item is not None:
                unpacked_values[idx] = ir_values[i]
                i += 1

    @staticmethod
    def _normalize_region_result_to_list(region_result: Any) -> List[Any]:
        """
@@ -82,34 +83,109 @@ class ScfGenerator:
        return region_result_list

    @staticmethod
    def check_region_result(region_values, ir_values):
        for i, (expected_value, actual_value) in enumerate(
            zip(ir_values, region_values)
    def _check_region_result(original_value, region_value, arg_name, op_type_name):
        """
        Validate that a region result maintains the same type as the original value.

        This method checks for type consistency between the original value passed to a dynamic
        SCF operation (like for, if, while) and the value returned from the operation's region.

        Args:
            original_value: The value before entering the SCF operation region
            region_value: The value returned from the SCF operation region
            arg_name: Name of the argument being checked (for error reporting)
            op_type_name: Type of SCF operation (e.g., 'for', 'if', 'while') for error reporting

        Raises:
            DSLRuntimeError: If the region value has a different type than the original value.
                The error includes suggestions for using compile-time control flow instead.

        Note:
            This method performs relaxed type checking that allows inheritance relationships.
            For example, a child class can be returned where a parent class was expected.
            However, fundamental type changes (like None to non-None, different sequence types,
            or different numeric types) are not allowed in dynamic SCF operations.
        """

        def get_type_name(value):
            if isinstance(value, NoneType):
                return "None"
            elif isinstance(value, Sequence):
                return f"{type(value).__name__}<{len(value)}>"
            else:
                return type(value).__name__

        # Check for type mismatches
        type_mismatch = False
        old_type_name = None
        new_type_name = None

        # Handle None type changes
        if isinstance(original_value, NoneType) != isinstance(region_value, NoneType):
            type_mismatch = True
            old_type_name = get_type_name(original_value)
            new_type_name = get_type_name(region_value)
        # Handle sequence type/length changes
        elif isinstance(original_value, Sequence) and isinstance(
            region_value, Sequence
        ):
            expected_value_type = get_mlir_types(expected_value)
            actual_value_type = get_mlir_types(actual_value)
            if expected_value_type != actual_value_type:
                return False, i, expected_value_type, actual_value_type
        return True, -1, None, None
            if type(original_value) != type(region_value) or len(original_value) != len(
                region_value
            ):
                type_mismatch = True
                old_type_name = get_type_name(original_value)
                new_type_name = get_type_name(region_value)
        # Handle numeric type changes
        elif isinstance(
            original_value, (Numeric, ArithValue, ir.Value, int, float, bool)
        ) or isinstance(
            region_value, (Numeric, ArithValue, ir.Value, int, float, bool)
        ):
            try:
                original_numeric = as_numeric(original_value)
                region_numeric = as_numeric(region_value)
                if original_numeric.dtype != region_numeric.dtype:
                    type_mismatch = True
                    old_type_name = original_numeric.dtype.__name__
                    new_type_name = region_numeric.dtype.__name__
            except Exception:
                pass
        # Handle general type changes (relaxed for inheritance)
        elif type(original_value) != type(region_value):
            old_type = type(original_value)
            new_type = type(region_value)
            if not (issubclass(old_type, new_type) or issubclass(new_type, old_type)):
                type_mismatch = True
                old_type_name = old_type.__name__
                new_type_name = new_type.__name__

        if type_mismatch:
            raise DSLRuntimeError(
                f"`{arg_name}` is {old_type_name} prior to this `{op_type_name}`, "
                f"and updating it to {new_type_name} inside this `{op_type_name}` is not supported.",
                suggestion=(
                    f"Please avoid changing its type inside a dynamic `{op_type_name}`, "
                    f"or change to compile-time control flow by marking this `{op_type_name}` with "
                    f"`{'range_constexpr' if op_type_name == 'for' else 'const_expr'}`."
                ),
            )
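
To make the failure mode concrete, a hedged sketch of the kind of user code this check rejects (schematic, not part of this commit):

# x = Int32(0)
# if dynamic_pred:          # lowered to a dynamic scf.if
#     x = Float32(1.0)      # rejected: dtype changes Int32 -> Float32 across the region
# Marking the branch with const_expr keeps it in Python and avoids the check.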

    def scf_execute_dynamic(
        self,
        op_type_name: str,
        used_args: List[Any],
        mix_iter_args: List[Any],
        full_write_args_count: int,
        mix_iter_arg_names: List[str],
        create_op_func: Callable[
            [List[ir.Value], Dict[int, Tuple[int, int]], List[Any]], ir.Operation
        ],
        create_op_func: Callable[[List[ir.Value]], ir.Operation],
        region_builders: List[
            Callable[
                [
                    "ir.Operation",
                    List["ir.Value"],  # block_args
                    List[Any],  # used_args
                    List["ir.Value"],  # dyn_yield_ops
                    Dict[int, Tuple[int, int]],
                    PyTreeDef,
                    List[Any],
                    int,
                ],
                Any,
            ]
@@ -119,11 +195,11 @@ class ScfGenerator:
        block_term_op_builder: Dict[Callable, Callable] = {},
    ) -> Any:
        # 1) Unpack
        ir_values, dyn_unpacked_values, dyn_indices, dyn_class_types = (
            cutlass_dsl.unpack_to_irvalue(mix_iter_args, op_type_name)
        ir_values, pytree_def = cutlass_dsl.unpack_to_irvalue(
            mix_iter_args, op_type_name, full_write_args_count
        )
        # 2) Create the SCF op
        op = create_op_func(ir_values, dyn_indices, dyn_class_types)
        op = create_op_func(ir_values)
        log().debug("Generated scf.%s \n[%s]", op_type_name, op)

        # 3) Build the regions
@@ -135,76 +211,61 @@ class ScfGenerator:
            region_result = builder(
                op,
                block_args,
                used_args,
                dyn_unpacked_values,
                dyn_indices,
                dyn_class_types,
                ir_values,
                pytree_def,
                mix_iter_args,
                full_write_args_count,
            )

            # Use the custom terminator if provided for this builder, otherwise use the default YieldOp
            if builder in block_term_op_builder:
                # Use the provided terminator generator
                block_term_op_builder[builder](region_result)
                block_term_op_builder[builder](region_result, full_write_args_count)
            else:
                # For the standard yield op, check the result
                for arg, result, name in zip(
                    mix_iter_args,
                    (
                        region_result
                        if isinstance(region_result, list)
                        else [region_result]
                    ),
                    mix_iter_arg_names,
                ):
                    if isinstance(arg, NoneType) and not isinstance(
                        result, NoneType
                    ):
                        raise DSLRuntimeError(
                            (
                                f"`{name}` is None prior to this `{op_type_name}`, "
                                f"and updating it to a non-None value inside this `{op_type_name}` is not supported."
                            ),
                            suggestion=(
                                f"Please make sure `{name}` is not None prior to this `{op_type_name}`, "
                                f"or mark this `{op_type_name}` with "
                                f"`{'range' if op_type_name == 'for' else 'const_expr'}`."
                            ),
                        )
                # Normalize region_result
                region_result_list = ScfGenerator._normalize_region_result_to_list(
                    region_result
                )
                # Default behavior - generate YieldOp
                region_values, unpacked_values, _, _ = (
                    cutlass_dsl.unpack_to_irvalue(region_result_list, op_type_name)
                )

                is_match, mismatch_idx, expected_type, actual_type = (
                    ScfGenerator.check_region_result(region_values, ir_values)
                )

                if not is_match:
                    # From the unpacked index, we need to find the original index
                    original_idx = -1
                    for unpacked_idx, (original_idx, length) in dyn_indices.items():
                        if (
                            mismatch_idx >= original_idx
                            and mismatch_idx < original_idx + length
                        ):
                            original_idx = unpacked_idx
                            break
                    raise DSLRuntimeError(
                        f"`{op_type_name}` expects {expected_type} type for variable `{mix_iter_arg_names[original_idx]}`, but got {actual_type}.",
                        suggestion=f"Please make sure `{mix_iter_arg_names[original_idx]}` type is not changed inside of `{op_type_name}`.",
                # For the standard yield op, check the result
                for arg, result, name in zip(
                    mix_iter_args,
                    region_result_list,
                    mix_iter_arg_names,
                ):
                    ScfGenerator._check_region_result(
                        arg, result, name, op_type_name
                    )

                # Default behavior - generate YieldOp
                region_values, yield_pytree_def = cutlass_dsl.unpack_to_irvalue(
                    region_result_list, op_type_name, full_write_args_count
                )

                mismatch = check_tree_equal(pytree_def, yield_pytree_def)
                if mismatch != -1:
                    # Get the argument name
                    filtered_arg_names = (
                        cutlass_dsl.filter_readonly_frozen_dataclass_names(
                            mix_iter_args, mix_iter_arg_names, full_write_args_count
                        )
                    )

                    raise DSLRuntimeError(
                        f"`{filtered_arg_names[mismatch]}` is structured differently after this `{op_type_name}`.",
                        suggestion=(
                            f"Please avoid changing the type structure inside a dynamic `{op_type_name}`, "
                            f"or change to compile-time control flow by marking this `{op_type_name}` with "
                            f"`{'range_constexpr' if op_type_name == 'for' else 'const_expr'}`."
                        ),
                    )

                scf.YieldOp(region_values)

        log().debug("Completed scf.%s \n[%s]", op_type_name, op)
        ScfGenerator.fill_none(op.results, unpacked_values)

        # 4) Pack final results
        final_results = cutlass_dsl.pack_from_irvalue(
            unpacked_values, dyn_indices, dyn_class_types
            op.results, pytree_def, mix_iter_args, full_write_args_count
        )

        # 5) Return in a nice pattern
@@ -215,28 +276,32 @@ class ScfGenerator:
        return final_results


def _attr_const_check(attr, expected_type, attr_name):
    # Use strict type equality to prevent `bool` being accepted where `int` is required.
    if is_dynamic_expression(attr) or type(attr) is not expected_type:
        raise DSLRuntimeError(
            f"loop attribute `{attr_name}` must be a Python value of type `{expected_type.__name__}`, got `{type(attr).__name__}`."
        )
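
A quick illustration of why strict type equality is used (since `bool` subclasses `int` in Python):

# isinstance(True, int) is True, so an isinstance check would accept
# `unroll=True`; `type(True) is int` is False and rejects it:
# _attr_const_check(4, int, "unroll")       # ok
# _attr_const_check(True, int, "unroll")    # raises DSLRuntimeError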


def _loop_execute_range_dynamic(
    func: Callable,
    start: Any,
    stop: Any,
    step: Any,
    used_args: List[Any] = [],
    mix_iter_args: List[Any] = [],
    full_write_args_count: int = 0,
    mix_iter_arg_names: List[str] = [],
    unroll: int = -1,
    unroll_full: bool = False,
    pipelining: int = None,
    prefetch_stages: int = None,
):
    """
    Example: build an scf.for with optional unroll, using our universal helper.
    """
    scf_gen = ScfGenerator()

    def create_for_op(
        dyn_yield_ops: List[ir.Value],
        dyn_indices: Dict[int, Tuple[int, int]],
        dyn_class_types: List[Any],
    ):
    def create_for_op(dyn_yield_ops: List[ir.Value]):
        for d in dyn_yield_ops:
            if not isinstance(d, ir.Value):
                raise DSLRuntimeError(
@@ -254,6 +319,10 @@ def _loop_execute_range_dynamic(
        stop_ = stop_.ir_value()
        step_ = step_.ir_value()

        # Attributes must be pure Python values, so add a check
        _attr_const_check(unroll, int, "unroll")
        _attr_const_check(unroll_full, bool, "unroll_full")

        # Possibly attach unroll attributes
        unroll_attr = None
        if unroll_full:
@@ -262,17 +331,18 @@ def _loop_execute_range_dynamic(
            unroll_attr = LoopUnroll(count=unroll)
        log().debug("Unroll attribute: %s", unroll_attr)

        pipelining_attr = None
        if pipelining is not None:
            if pipelining >= 0:
                pipelining_attr = ir.IntegerAttr.get(
                    ir.IntegerType.get_signless(32), pipelining
        prefetch_stages_attr = None
        if prefetch_stages is not None:
            _attr_const_check(prefetch_stages, int, "prefetch_stages")
            if prefetch_stages >= 0:
                prefetch_stages_attr = ir.IntegerAttr.get(
                    ir.IntegerType.get_signless(32), prefetch_stages
                )
            else:
                raise DSLRuntimeError(
                    f"Pipelining must be non-negative, got {pipelining}"
                    f"loop attribute `prefetch_stages` must be non-negative, got `{prefetch_stages}`."
                )
        log().debug("Pipelining attribute: %s", pipelining_attr)
        log().debug("prefetch_stages attribute: %s", prefetch_stages_attr)

        log().debug(
            "Creating scf.ForOp \n\t\tstart=%s: type : %s\n\t\tstop=%s: type : %s\n\t\tstep=%s: type : %s",
@@ -303,47 +373,48 @@ def _loop_execute_range_dynamic(
        if unroll_attr is not None:
            for_op.attributes["loop_annotation"] = unroll_attr

        if pipelining_attr is not None:
            for_op.attributes["cutlass.pipelining"] = pipelining_attr
        if prefetch_stages_attr is not None:
            for_op.attributes["cutlass.pipelining"] = prefetch_stages_attr

        return for_op

    def for_body_builder(
        op, block_args, used_args, dyn_yield_ops, dyn_indices, dyn_class_types
        op,
        block_args,
        _,
        pytree_def,
        mix_iter_args,
        full_write_args_count,
    ):
        # Insert the induction variable at the beginning
        dyn_yield_ops.insert(0, block_args[0])
        ScfGenerator.fill_none(block_args, dyn_yield_ops)
        block_args = dyn_yield_ops
        # scf.ForOp block_args are typically [induction_var, iter_args...]
        # But MLIR also gives you op.induction_variable
        iv = t.as_numeric(op.induction_variable)
        log().debug(
            "For body builder: %s block_args: %s used_args: %s",
            "For body builder: %s block_args: %s full_write_args_count: %s",
            iv,
            block_args,
            used_args,
            full_write_args_count,
        )
        if len(block_args) <= 1:
        # block_args[1:] are iteration variables
        func_args = []
        func_args.extend(
            cutlass_dsl.pack_from_irvalue(
                block_args[1:], pytree_def, mix_iter_args, full_write_args_count
            )
        )
        if not func_args:
            # No iteration arguments, or only the induction var
            func(iv, *used_args)
            func(iv)
            return []  # yield nothing
        else:
            # block_args[1:] are iteration variables
            func_args = [*used_args]
            func_args.extend(
                cutlass_dsl.pack_from_irvalue(
                    block_args[1:], dyn_indices, dyn_class_types
                )
            )
            updated_func_args = func(iv, *func_args)
            return updated_func_args

    # Now call the universal SCF executor with a single region builder
    return scf_gen.scf_execute_dynamic(
        op_type_name="for",
        used_args=used_args,
        mix_iter_args=mix_iter_args,
        full_write_args_count=full_write_args_count,
        mix_iter_arg_names=mix_iter_arg_names,
        create_op_func=create_for_op,
        region_builders=[for_body_builder],
@@ -354,8 +425,8 @@ def _if_execute_dynamic(
    pred: "ir.Value",
    then_block: Callable,
    else_block: Callable = None,
    used_args: List[Any] = [],
    mix_yield_args: List[Any] = [],
    full_write_args_count: int = 0,
    mix_yield_arg_names: List[str] = [],
    if_constexpr=None,  # ignoring for brevity
):
@@ -364,11 +435,7 @@ def _if_execute_dynamic(
    """
    scf_gen = ScfGenerator()

    def create_if_op(
        dyn_yield_ops: List[ir.Value],
        dyn_indices: Dict[int, Tuple[int, int]],
        dyn_class_types: List[Any],
    ):
    def create_if_op(dyn_yield_ops: List[ir.Value]):
        # Assume the final result types match the dynamic yields
        result_types = [arg.type for arg in dyn_yield_ops]

@@ -387,11 +454,18 @@ def _if_execute_dynamic(
        return if_op

    def then_builder(
        if_op, block_args, used_args, dyn_yield_ops, dyn_indices, dyn_class_types
        if_op,
        _,
        dyn_yield_ops,
        pytree_def,
        mix_iter_args,
        full_write_args_count,
    ):
        flat_args = [*used_args]
        flat_args = []
        flat_args.extend(
            cutlass_dsl.pack_from_irvalue(dyn_yield_ops, dyn_indices, dyn_class_types)
            cutlass_dsl.pack_from_irvalue(
                dyn_yield_ops, pytree_def, mix_iter_args, full_write_args_count
            )
        )
        return then_block(*flat_args)

@@ -400,12 +474,17 @@ def _if_execute_dynamic(
    if else_block is not None:

        def else_builder(
            if_op, block_args, used_args, dyn_yield_ops, dyn_indices, dyn_class_types
            if_op,
            _,
            dyn_yield_ops,
            pytree_def,
            mix_iter_args,
            full_write_args_count,
        ):
            flat_args = [*used_args]
            flat_args = []
            flat_args.extend(
                cutlass_dsl.pack_from_irvalue(
                    dyn_yield_ops, dyn_indices, dyn_class_types
                    dyn_yield_ops, pytree_def, mix_iter_args, full_write_args_count
                )
            )
            return else_block(*flat_args)
@@ -414,8 +493,8 @@ def _if_execute_dynamic(

    return scf_gen.scf_execute_dynamic(
        op_type_name="if",
        used_args=used_args,
        mix_iter_args=mix_yield_args,
        full_write_args_count=full_write_args_count,
        mix_iter_arg_names=mix_yield_arg_names,
        create_op_func=create_if_op,
        region_builders=region_builders,
@@ -425,9 +504,9 @@ def _if_execute_dynamic(
def _while_execute_dynamic(
    while_before_block: Callable,
    while_after_block: Callable = None,
    used_args=[],
    yield_args=[],
    yield_arg_names=[],
    write_args=[],
    full_write_args_count=0,
    write_args_names=[],
):
    """
    Create and return an SCF WhileOp for dynamic loops.
@@ -436,8 +515,7 @@ def _while_execute_dynamic(
    Args:
        while_before_block: Function that returns (condition, updated_values)
        while_after_block: Function that returns updated values
        used_args: Additional arguments used in the loop body
        yield_args: Values that are updated in the loop
        write_args: Values that are updated in the loop

    See create_while_function in ast_preprocessor.py for details on the input structure.
    """
@@ -445,11 +523,7 @@ def _while_execute_dynamic(
    while_op_type_name = "while"
    scf_gen = ScfGenerator()

    def create_while_op(
        dyn_yield_ops: List[ir.Value],
        dyn_indices: Dict[int, Tuple[int, int]],
        dyn_class_types: List[Any],
    ):
    def create_while_op(dyn_yield_ops: List[ir.Value]):
        # Create the while operation with the types from the yield args
        result_types = [arg.type for arg in dyn_yield_ops]
        try:
@@ -468,14 +542,19 @@ def _while_execute_dynamic(
        ) from e

    def before_block_builder(
        op, block_args, used_args, dyn_yield_ops, dyn_indices, dyn_class_types
        op,
        block_args,
        _,
        pytree_def,
        mix_iter_args,
        full_write_args_count,
    ):
        # Build the before (condition) block
        ScfGenerator.fill_none(block_args, dyn_yield_ops)
        block_args = dyn_yield_ops
        flat_args = [*used_args]
        flat_args = []
        flat_args.extend(
            cutlass_dsl.pack_from_irvalue(block_args, dyn_indices, dyn_class_types)
            cutlass_dsl.pack_from_irvalue(
                block_args, pytree_def, mix_iter_args, full_write_args_count
            )
        )

        log().debug("before block args: %s", flat_args)
@@ -493,18 +572,15 @@ def _while_execute_dynamic(

        return cond, before_results

    def before_block_terminator(cond_and_results):
    def before_block_terminator(cond_and_results, full_write_args_count):
        # Generate a condition op instead of a yield op
        cond = cond_and_results[0]
        before_result_list = ScfGenerator._normalize_region_result_to_list(
            cond_and_results[1]
        )
        ir_cond_list, _, _, _ = cutlass_dsl.unpack_to_irvalue(
            [cond], while_op_type_name
        )
        ir_cond = ir_cond_list[0]
        ir_results_list, _, _, _ = cutlass_dsl.unpack_to_irvalue(
            before_result_list, while_op_type_name
        ir_cond = as_numeric(cond).ir_value()
        ir_results_list, pytree_def = cutlass_dsl.unpack_to_irvalue(
            before_result_list, while_op_type_name, full_write_args_count
        )
        log().debug(
            "creating scf.ConditionOp with [%s], [%s]",
@@ -514,14 +590,19 @@ def _while_execute_dynamic(
        scf.ConditionOp(ir_cond, ir_results_list)

    def after_block_builder(
        op, block_args, used_args, dyn_yield_ops, dyn_indices, dyn_class_types
        op,
        block_args,
        _,
        pytree_def,
        mix_iter_args,
        full_write_args_count,
    ):
        # Build the after (body) block
        ScfGenerator.fill_none(block_args, dyn_yield_ops)
        block_args = dyn_yield_ops
        flat_args = [*used_args]
        flat_args = []
        flat_args.extend(
            cutlass_dsl.pack_from_irvalue(block_args, dyn_indices, dyn_class_types)
            cutlass_dsl.pack_from_irvalue(
                block_args, pytree_def, mix_iter_args, full_write_args_count
            )
        )

        log().debug("after block args: %s", flat_args)
@@ -541,9 +622,9 @@ def _while_execute_dynamic(
    # Call the universal SCF executor with two region builders
    return scf_gen.scf_execute_dynamic(
        op_type_name=while_op_type_name,
        used_args=used_args,
        mix_iter_args=yield_args,
        mix_iter_arg_names=yield_arg_names,
        mix_iter_args=write_args,
        full_write_args_count=full_write_args_count,
        mix_iter_arg_names=write_args_names,
        create_op_func=create_while_op,
        region_builders=[before_block_builder, after_block_builder],
        block_term_op_builder={
763 python/CuTeDSL/cutlass_dsl/tree_utils.py Normal file
@ -0,0 +1,763 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.

from typing import Callable, Any, Iterable, Iterator, NamedTuple, Union, get_origin
import dataclasses
import itertools as it
from types import SimpleNamespace

from ..base_dsl.typing import as_numeric, Numeric, Constexpr
from ..base_dsl._mlir_helpers.arith import ArithValue
from ..base_dsl.common import DSLBaseError
from .._mlir import ir

# =============================================================================
# Tree Utils
# =============================================================================


class DSLTreeFlattenError(DSLBaseError):
    """Exception raised when tree flattening fails due to unsupported types."""

    def __init__(self, msg: str, type_str: str):
        super().__init__(msg)
        self.type_str = type_str


def unzip2(pairs: Iterable[tuple[Any, Any]]) -> tuple[list[Any], list[Any]]:
    """Unzip a sequence of pairs into two lists."""
    lst1, lst2 = [], []
    for x1, x2 in pairs:
        lst1.append(x1)
        lst2.append(x2)
    return lst1, lst2


def get_fully_qualified_class_name(x: Any) -> str:
    """
    Get the fully qualified class name of an object.

    Args:
        x: Any object

    Returns:
        str: Fully qualified class name in format 'module.class_name'

    Example:
        >>> get_fully_qualified_class_name([1, 2, 3])
        'builtins.list'
    """
    return f"{x.__class__.__module__}.{x.__class__.__qualname__}"


def is_frozen_dataclass(obj_or_cls: Any) -> bool:
    """
    Check if an object or class is a frozen dataclass.

    Args:
        obj_or_cls: Either a dataclass instance or class

    Returns:
        bool: True if the object/class is a dataclass declared with frozen=True,
            False otherwise

    Example:
        >>> from dataclasses import dataclass
        >>> @dataclass(frozen=True)
        ... class Point:
        ...     x: int
        ...     y: int
        >>> is_frozen_dataclass(Point)
        True
        >>> is_frozen_dataclass(Point(1, 2))
        True
    """
    cls = obj_or_cls if isinstance(obj_or_cls, type) else obj_or_cls.__class__

    return (
        dataclasses.is_dataclass(cls)
        and getattr(cls, "__dataclass_params__", None) is not None
        and cls.__dataclass_params__.frozen
    )


def is_dynamic_expression(x: Any) -> bool:
    """
    Check if an object implements the DynamicExpression protocol.

    Objects implementing this protocol must have both `__extract_mlir_values__`
    and `__new_from_mlir_values__` methods.

    Args:
        x: Any object to check

    Returns:
        bool: True if the object implements the DynamicExpression protocol,
            False otherwise
    """
    return all(
        hasattr(x, attr)
        for attr in ("__extract_mlir_values__", "__new_from_mlir_values__")
    )


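# Illustrative sketch: a minimal (hypothetical) `_ExampleScalar` type that
# satisfies the DynamicExpression protocol; only the two dunder methods above
# are required, so duck typing is enough.
class _ExampleScalar:
    def __init__(self, value):
        self.value = value

    def __extract_mlir_values__(self):
        # One flat MLIR value per scalar.
        return [self.value]

    def __new_from_mlir_values__(self, values):
        return _ExampleScalar(values[0])


assert is_dynamic_expression(_ExampleScalar(0))

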
def is_constexpr_field(field: dataclasses.Field) -> bool:
    """
    Check if a field is a constexpr field, i.e. annotated with `Constexpr`
    or a subscripted form such as `Constexpr[...]`.
    """
    if field.type is Constexpr:
        return True
    elif get_origin(field.type) is Constexpr:
        return True
    return False


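# Illustrative sketch (hypothetical `_ExampleConfig`, assuming `Constexpr` may
# be used as a bare field annotation): only Constexpr-annotated fields are
# reported as constexpr.
@dataclasses.dataclass
class _ExampleConfig:
    num_stages: Constexpr
    offset: int


_example_fields = dataclasses.fields(_ExampleConfig)
assert is_constexpr_field(_example_fields[0])
assert not is_constexpr_field(_example_fields[1])

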
# =============================================================================
# PyTreeDef
# =============================================================================

class NodeType(NamedTuple):
    """
    Represents a node in a pytree structure.

    Attributes:
        name: String representation of the node type
        to_iterable: Function to convert node to iterable form
        from_iterable: Function to reconstruct node from iterable form
    """

    name: str
    to_iterable: Callable
    from_iterable: Callable


class PyTreeDef(NamedTuple):
    """
    Represents the structure definition of a pytree.

    Attributes:
        node_type: The type of this node
        node_metadata: SimpleNamespace metadata associated with this node
        child_treedefs: Tuple of child tree definitions
    """

    node_type: NodeType
    node_metadata: SimpleNamespace
    child_treedefs: tuple["PyTreeDef", ...]


@dataclasses.dataclass(frozen=True)
class Leaf:
    """
    Represents a leaf node in a pytree structure.

    Attributes:
        is_numeric: Whether this leaf contains a `Numeric` value
        is_none: Whether this leaf represents None
        node_metadata: SimpleNamespace metadata associated with this leaf
        ir_type_str: String representation of the IR type
    """

    is_numeric: bool = False
    is_none: bool = False
    node_metadata: SimpleNamespace = None
    ir_type_str: str = None


# =============================================================================
# Default to_iterable and from_iterable
# =============================================================================


def extract_dataclass_members(x: Any) -> tuple[list[str], list[Any], list[str]]:
    """
    Extract non-method, non-function attributes from a dataclass instance.

    Args:
        x: A dataclass instance

    Returns:
        tuple: (field_names, field_values, constexpr_field_names) lists
    """
    fields = [field.name for field in dataclasses.fields(x)]

    # If the dataclass has extra fields, raise an error
    for k in x.__dict__.keys():
        if k not in fields:
            raise DSLTreeFlattenError(
                f"`{x}` has extra field `{k}`",
                type_str=get_fully_qualified_class_name(x),
            )

    if not fields:
        return [], [], []

    # record constexpr fields
    members = []
    constexpr_fields = []
    for field in dataclasses.fields(x):
        if is_constexpr_field(field):
            constexpr_fields.append(field.name)
            fields.remove(field.name)
            v = getattr(x, field.name)
            if is_dynamic_expression(v):
                raise DSLTreeFlattenError(
                    f"`{x}` has dynamic expression field `{field.name}` with a Constexpr type annotation `{field.type}`",
                    type_str=get_fully_qualified_class_name(x),
                )
        else:
            members.append(getattr(x, field.name))

    return fields, members, constexpr_fields


def default_dataclass_to_iterable(x: Any) -> tuple[SimpleNamespace, list[Any]]:
    """
    Convert a dataclass instance to iterable form for tree flattening.

    Extracts the dataclass fields (excluding Constexpr-annotated ones, whose
    names are recorded separately in the metadata) and returns their values
    along with metadata about the dataclass.

    Args:
        x: A dataclass instance

    Returns:
        tuple: (metadata, members) where metadata contains type info and field names,
            and members is the list of attribute values
    """
    fields, members, constexpr_fields = extract_dataclass_members(x)

    metadata = SimpleNamespace(
        type_str=get_fully_qualified_class_name(x),
        fields=fields,
        constexpr_fields=constexpr_fields,
        original_obj=x,
    )
    return metadata, members


def set_dataclass_attributes(
    instance: Any,
    fields: list[str],
    values: Iterable[Any],
    constexpr_fields: list[str],
) -> Any:
    """
    Set attributes on a dataclass instance.

    Args:
        instance: The dataclass instance
        fields: List of field names
        values: Iterable of field values
        constexpr_fields: List of constexpr field names whose current values
            are carried over unchanged

    Returns:
        A new instance (via dataclasses.replace) with the attributes set, or
        the original instance when there are no fields
    """
    if not fields:
        return instance

    kwargs = dict(zip(fields, values))
    for field in constexpr_fields:
        kwargs[field] = getattr(instance, field)
    return dataclasses.replace(instance, **kwargs)


def default_dataclass_from_iterable(
    metadata: SimpleNamespace, children: Iterable[Any]
) -> Any:
    """
    Reconstruct a dataclass instance from iterable form.

    Handles both regular and frozen dataclasses appropriately.

    Args:
        metadata: Metadata containing type information and field names
        children: Iterable of attribute values to reconstruct the instance

    Returns:
        The reconstructed dataclass instance
    """
    instance = metadata.original_obj

    new_instance = set_dataclass_attributes(
        instance, metadata.fields, children, metadata.constexpr_fields
    )
    metadata.original_obj = new_instance
    return new_instance


def dynamic_expression_to_iterable(x: Any) -> tuple[SimpleNamespace, list[Any]]:
    """
    Convert a dynamic expression to iterable form.

    Uses the object's `__extract_mlir_values__` method to extract MLIR values.

    Args:
        x: A dynamic expression object

    Returns:
        tuple: (metadata, mlir_values) where metadata marks this as a dynamic expression
            and mlir_values are the extracted MLIR values
    """
    return (
        SimpleNamespace(is_dynamic_expression=1, original_obj=x),
        x.__extract_mlir_values__(),
    )


def dynamic_expression_from_iterable(
    metadata: SimpleNamespace, children: Iterable[Any]
) -> Any:
    """
    Reconstruct a dynamic expression from iterable form.

    Uses the object's `__new_from_mlir_values__` method to reconstruct from MLIR values.

    Args:
        metadata: Metadata containing the original object
        children: Iterable of MLIR values to reconstruct from

    Returns:
        The reconstructed dynamic expression object
    """
    return metadata.original_obj.__new_from_mlir_values__(list(children))


def default_dict_to_iterable(x: Any) -> tuple[SimpleNamespace, list[Any]]:
    """
    Convert a dict or SimpleNamespace to iterable form.
    """
    if isinstance(x, SimpleNamespace):
        keys = list(x.__dict__.keys())
        values = list(x.__dict__.values())
    else:
        keys = list(x.keys())
        values = list(x.values())

    return (
        SimpleNamespace(
            type_str=get_fully_qualified_class_name(x), original_obj=x, fields=keys
        ),
        values,
    )


def default_dict_from_iterable(
    metadata: SimpleNamespace, children: Iterable[Any]
) -> Any:
    """
    Reconstruct a dict or SimpleNamespace from iterable form, updating the
    original object in place.
    """
    instance = metadata.original_obj
    fields = metadata.fields
    is_simple_namespace = isinstance(instance, SimpleNamespace)

    for k, v in zip(fields, children):
        if is_simple_namespace:
            setattr(instance, k, v)
        else:
            instance[k] = v

    return instance


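# Illustrative check: the same helper pair serves both dict and
# SimpleNamespace; `_example_ns` is a throwaway value using None leaves.
_example_ns = SimpleNamespace(x=None, y=None)
_example_meta, _example_vals = default_dict_to_iterable(_example_ns)
assert _example_vals == [None, None]
assert default_dict_from_iterable(_example_meta, _example_vals) is _example_ns

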
# =============================================================================
# Register pytree nodes
# =============================================================================

_node_types: dict[type, NodeType] = {}


def register_pytree_node(ty: type, to_iter: Callable, from_iter: Callable) -> NodeType:
    """
    Register a new node type for pytree operations.

    Args:
        ty: The type to register
        to_iter: Function to convert instances of this type to iterable form
        from_iter: Function to reconstruct instances of this type from iterable form

    Returns:
        NodeType: The created NodeType instance
    """
    nt = NodeType(str(ty), to_iter, from_iter)
    _node_types[ty] = nt
    return nt


def register_default_node_types() -> None:
    """Register default node types for pytree operations."""
    default_registrations = [
        (
            tuple,
            lambda t: (SimpleNamespace(length=len(t)), list(t)),
            lambda _, xs: tuple(xs),
        ),
        (
            list,
            lambda l: (SimpleNamespace(length=len(l)), list(l)),
            lambda _, xs: list(xs),
        ),
        (
            dict,
            default_dict_to_iterable,
            default_dict_from_iterable,
        ),
        (
            SimpleNamespace,
            default_dict_to_iterable,
            default_dict_from_iterable,
        ),
    ]

    for ty, to_iter, from_iter in default_registrations:
        register_pytree_node(ty, to_iter, from_iter)


# Initialize default registrations
register_default_node_types()


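# Illustrative sketch (hypothetical `_ExamplePair`): containers that are
# neither dataclasses nor dynamic expressions can be made flattenable by
# registering explicit to/from-iterable callbacks, mirroring the defaults
# above.
class _ExamplePair:
    def __init__(self, first, second):
        self.first = first
        self.second = second


register_pytree_node(
    _ExamplePair,
    lambda p: (SimpleNamespace(), [p.first, p.second]),
    lambda _, children: _ExamplePair(*children),
)

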
# =============================================================================
# tree_flatten and tree_unflatten
# =============================================================================

"""
Behavior of tree_flatten and tree_unflatten, for example:

```python
a = (1, 2, 3)
b = MyClass(a=1, b=[1, 2, 3])
```

yields the following tree:

```python
tree_a = PyTreeDef(type = 'tuple',
    metadata = {length = 3},
    children = [
        Leaf(type = int),
        Leaf(type = int),
        Leaf(type = int),
    ],
)
flattened_a = [1, 2, 3]
tree_b = PyTreeDef(type = 'MyClass',
    metadata = {fields = ['a', 'b']},
    children = [
        Leaf(type = int),
        PyTreeDef(type = `list`,
            metadata = {length = 3},
            children = [
                Leaf(type = `int`),
                Leaf(type = `int`),
                Leaf(type = `int`),
            ],
        ),
    ],
)
flattened_b = [1, 1, 2, 3]
```

Pass the flattened values and the PyTreeDef to tree_unflatten to reconstruct the original structure.

```python
unflattened_a = tree_unflatten(tree_a, flattened_a)
unflattened_b = tree_unflatten(tree_b, flattened_b)
```

yields the following structure:

```python
unflattened_a = (1, 2, 3)
unflattened_b = MyClass(a=1, b=[1, 2, 3])
```

unflattened_a should be structurally identical to a, and unflattened_b should be structurally identical to b.
"""


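# Usage sketch (hypothetical `_demo_structure_roundtrip`, not invoked at
# import): None leaves flatten without creating MLIR values, so pure structure
# handling can be exercised with no active MLIR context.
def _demo_structure_roundtrip():
    original = {"a": (None, None), "b": [None]}
    leaves, treedef = tree_flatten(original)  # None leaves yield no values
    rebuilt = tree_unflatten(treedef, leaves)
    assert rebuilt == {"a": (None, None), "b": [None]}
    return rebuilt

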
def tree_flatten(x: Any) -> tuple[list[Any], PyTreeDef]:
    """
    Flatten a nested structure into a flat list of values and a tree definition.

    This function recursively traverses nested data structures (trees) and
    flattens them into a linear list of leaf values, while preserving the
    structure information in a PyTreeDef.

    Args:
        x: The nested structure to flatten

    Returns:
        tuple: (flat_values, treedef) where flat_values is a list of leaf values
            and treedef is the tree structure definition

    Raises:
        DSLTreeFlattenError: If the structure contains unsupported types

    Example:
        >>> tree_flatten([1, [2, 3], 4])
        ([1, 2, 3, 4], PyTreeDef(...))
    """
    children_iter, treedef = _tree_flatten(x)
    return list(children_iter), treedef


def get_registered_node_types_or_insert(x: Any) -> NodeType | None:
    """
    Get the registered node type for an object, registering it if necessary.

    This function checks if a type is already registered for pytree operations.
    If not, it automatically registers the type based on its characteristics:
    - Dynamic expressions get registered with dynamic expression handlers
    - Dataclasses get registered with default dataclass handlers

    Args:
        x: The object to get or register a node type for

    Returns:
        NodeType or None: The registered node type, or None if the type
            cannot be registered
    """
    node_type = _node_types.get(type(x))
    if node_type:
        return node_type
    elif is_dynamic_expression(x):
        # If a class implements DynamicExpression protocol, register it before default dataclass one
        return register_pytree_node(
            type(x), dynamic_expression_to_iterable, dynamic_expression_from_iterable
        )
    elif dataclasses.is_dataclass(x):
        return register_pytree_node(
            type(x), default_dataclass_to_iterable, default_dataclass_from_iterable
        )
    else:
        return None


def create_leaf_for_value(
    x: Any,
    is_numeric: bool = False,
    is_none: bool = False,
    node_metadata: SimpleNamespace = None,
    ir_type_str: str = None,
) -> Leaf:
    """
    Create a Leaf node for a given value.

    Args:
        x: The value to create a leaf for
        is_numeric: Whether this is a numeric value
        is_none: Whether this represents None
        node_metadata: Optional metadata
        ir_type_str: Optional IR type string

    Returns:
        Leaf: The created leaf node
    """
    return Leaf(
        is_numeric=is_numeric,
        is_none=is_none,
        node_metadata=node_metadata,
        ir_type_str=ir_type_str or (str(x.type) if hasattr(x, "type") else None),
    )


def _tree_flatten(x: Any) -> tuple[Iterable[Any], PyTreeDef | Leaf]:
    """
    Internal function to flatten a tree structure.

    This is the core implementation of tree flattening that handles different
    types of objects including None, ArithValue, ir.Value, Numeric types,
    and registered pytree node types.

    Args:
        x: The object to flatten

    Returns:
        tuple: (flattened_values, treedef) where flattened_values is an iterable
            of leaf values and treedef is the tree structure

    Raises:
        DSLTreeFlattenError: If the object type is not supported
    """
    match x:
        case None:
            return [], create_leaf_for_value(x, is_none=True)

        case ArithValue() if is_dynamic_expression(x):
            v = x.__extract_mlir_values__()
            return v, create_leaf_for_value(
                x,
                node_metadata=SimpleNamespace(is_dynamic_expression=1, original_obj=x),
                ir_type_str=str(v[0].type),
            )

        case ArithValue():
            return [x], create_leaf_for_value(x, is_numeric=True)

        case ir.Value():
            return [x], create_leaf_for_value(x)

        case Numeric():
            v = x.__extract_mlir_values__()
            return v, create_leaf_for_value(
                x,
                node_metadata=SimpleNamespace(is_dynamic_expression=1, original_obj=x),
                ir_type_str=str(v[0].type),
            )

        case _:
            node_type = get_registered_node_types_or_insert(x)
            if node_type:
                node_metadata, children = node_type.to_iterable(x)
                children_flat, child_trees = unzip2(map(_tree_flatten, children))
                flattened = it.chain.from_iterable(children_flat)
                return flattened, PyTreeDef(
                    node_type, node_metadata, tuple(child_trees)
                )

            # Try to convert to numeric
            try:
                nval = as_numeric(x).ir_value()
                return [nval], create_leaf_for_value(nval, is_numeric=True)
            except Exception:
                raise DSLTreeFlattenError(
                    "Flatten Error", get_fully_qualified_class_name(x)
                )


def tree_unflatten(treedef: PyTreeDef, xs: list[Any]) -> Any:
    """
    Reconstruct a nested structure from a flat list of values and tree definition.

    This is the inverse operation of tree_flatten. It takes the flattened
    values and the tree structure definition to reconstruct the original
    nested structure.

    Args:
        treedef: The tree structure definition from tree_flatten
        xs: List of flat values to reconstruct from

    Returns:
        The reconstructed nested structure

    Example:
        >>> flat_values, treedef = tree_flatten([1, [2, 3], 4])
        >>> tree_unflatten(treedef, flat_values)
        [1, [2, 3], 4]
    """
    return _tree_unflatten(treedef, iter(xs))


def _tree_unflatten(treedef: PyTreeDef | Leaf, xs: Iterator[Any]) -> Any:
    """
    Internal function to reconstruct a tree structure.

    This is the core implementation of tree unflattening that handles
    different types of tree definitions including Leaf nodes and PyTreeDef nodes.

    Args:
        treedef: The tree structure definition
        xs: Iterator of flat values to reconstruct from

    Returns:
        The reconstructed object
    """
    match treedef:
        case Leaf(is_none=True):
            return None

        case Leaf(
            node_metadata=metadata
        ) if metadata and metadata.is_dynamic_expression:
            return metadata.original_obj.__new_from_mlir_values__([next(xs)])

        case Leaf(is_numeric=True):
            return as_numeric(next(xs))

        case Leaf():
            return next(xs)

        case PyTreeDef():
            children = (_tree_unflatten(t, xs) for t in treedef.child_treedefs)
            return treedef.node_type.from_iterable(treedef.node_metadata, children)


def _check_tree_equal(lhs: Union[PyTreeDef, Leaf], rhs: Union[PyTreeDef, Leaf]) -> bool:
    """
    Check if two tree definitions are structurally equal.

    This is a helper function for check_tree_equal that recursively compares
    tree structures.

    Args:
        lhs: Left tree definition (PyTreeDef or Leaf)
        rhs: Right tree definition (PyTreeDef or Leaf)

    Returns:
        bool: True if the trees are structurally equal, False otherwise
    """
    match (lhs, rhs):
        case (Leaf(), Leaf()):
            return lhs.is_none == rhs.is_none and lhs.ir_type_str == rhs.ir_type_str

        case (PyTreeDef(), PyTreeDef()):
            lhs_metadata = lhs.node_metadata
            rhs_metadata = rhs.node_metadata

            lhs_fields = getattr(lhs_metadata, "fields", [])
            rhs_fields = getattr(rhs_metadata, "fields", [])
            lhs_constexpr_fields = getattr(lhs_metadata, "constexpr_fields", [])
            rhs_constexpr_fields = getattr(rhs_metadata, "constexpr_fields", [])

            return (
                lhs.node_type == rhs.node_type
                and lhs_fields == rhs_fields
                and lhs_constexpr_fields == rhs_constexpr_fields
                and len(lhs.child_treedefs) == len(rhs.child_treedefs)
                and all(map(_check_tree_equal, lhs.child_treedefs, rhs.child_treedefs))
            )

        case _:
            return False


def check_tree_equal(lhs: PyTreeDef, rhs: PyTreeDef) -> int:
    """
    Check if two tree definitions are equal and return the index of first difference.

    This function compares two tree definitions and returns the index of the
    first child that differs, or -1 if they are completely equal.

    Args:
        lhs: Left tree definition
        rhs: Right tree definition

    Returns:
        int: Index of the first differing child, or -1 if trees are equal

    Example:
        >>> treedef1 = tree_flatten([1, [2, 3]])[1]
        >>> treedef2 = tree_flatten([1, (2, 3)])[1]
        >>> check_tree_equal(treedef1, treedef2)
        1  # The second child differs (list vs. tuple)
    """
    assert len(lhs.child_treedefs) == len(rhs.child_treedefs)

    def find_first_difference(
        index_and_pair: tuple[int, tuple[PyTreeDef, PyTreeDef]]
    ) -> int:
        index, (l, r) = index_and_pair
        return index if not _check_tree_equal(l, r) else -1

    differences = map(
        find_first_difference, enumerate(zip(lhs.child_treedefs, rhs.child_treedefs))
    )
    return next((diff for diff in differences if diff != -1), -1)