v4.1 release

This commit is contained in:
Junkai-Wu
2025-07-03 20:07:53 +08:00
committed by GitHub
parent b995f93317
commit a1aaf2300a
155 changed files with 18407 additions and 6068 deletions

View File

@ -23,6 +23,11 @@ from ..base_dsl.ast_helpers import (
dynamic_expr,
assert_executor,
bool_cast,
compare_executor,
any_executor,
all_executor,
range_value_check,
range_perf_warning,
)
from ..base_dsl import *

View File

@ -20,6 +20,7 @@ from inspect import isclass
import functools
import pkgutil
from dataclasses import is_dataclass
from collections.abc import Sequence
from ..base_dsl import *
from ..base_dsl import compiler
@ -51,6 +52,11 @@ from ..base_dsl.ast_helpers import (
while_executor,
assert_executor,
bool_cast,
compare_executor,
any_executor,
all_executor,
range_value_check,
range_perf_warning,
)
from ..base_dsl.runtime.dlpack_runtime import (
get_cute_tensor_c_pointer,
@ -67,18 +73,6 @@ from .cutlass_ast_decorators import (
_while_execute_dynamic,
)
# =============================================================================
# Set the AST decorator
# =============================================================================
# Set the DSL specific functions
executor.set_functions(
is_dynamic_expression,
_loop_execute_range_dynamic,
_if_execute_dynamic,
_while_execute_dynamic,
)
# =============================================================================
# Cutlass DSL Base Abstract Class
@ -1023,7 +1017,6 @@ def select_(cond, if_value, else_value):
)
return value
# Non-DSL dynamic cond should be handled before this.
if const_expr(not is_dynamic_expression(cond)):
raise DSLRuntimeError("Conditional expression must be dynamic")
@ -1089,6 +1082,7 @@ def for_generate(
iter_args: Optional[Sequence[ir.Value]] = None,
*,
unroll: LoopUnroll = None,
pipelining=None,
loc=None,
ip=None,
):
@ -1126,6 +1120,9 @@ def for_generate(
if unroll is not None:
for_op.attributes["loop_annotation"] = unroll
if pipelining is not None:
for_op.attributes["cutlass.pipelining"] = _createI32Attr(pipelining)
iv = for_op.induction_variable
new_results = new_from_mlir_values(iter_args, for_op.results)
new_iter_args = new_from_mlir_values(iter_args, for_op.inner_iter_args)
@ -1319,3 +1316,122 @@ def while_generate(
Generate a WhileLoopContext for a dynamic loop.
"""
return WhileLoopContext(inputs, condition, loc=loc, ip=ip)
def equal(lhs, rhs):
if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs):
return lhs == rhs
# Both sequence
if isinstance(lhs, Sequence) and isinstance(rhs, Sequence):
# Short-circuit for unequal length
if len(lhs) != len(rhs):
return False
return all_(equal(l, r) for l, r in zip(lhs, rhs))
return lhs == rhs
def in_(lhs, rhs, op):
if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs):
return lhs in rhs
if not isinstance(rhs, Sequence):
raise DSLRuntimeError(
f"'{op}' not supported between instances of {type(lhs)} and {type(rhs)}"
)
return any_(equal(lhs, r) for r in rhs)
def _lt_gt(lhs, rhs, op):
def native_lt_gt(lhs, rhs, op):
if op == "<":
return lhs < rhs
elif op == ">":
return lhs > rhs
else:
raise DSLRuntimeError(f"Unsupported comparison operator: {op}")
if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs):
return native_lt_gt(lhs, rhs, op)
# Both sequence, comparisons other than == and != do not allow mixing different types of sequences
if (
isinstance(lhs, Sequence)
and isinstance(rhs, Sequence)
and type(lhs) == type(rhs)
):
unequal_found = False
comp_results = []
mask = []
for l, r in zip(lhs, rhs):
is_equal = equal(l, r)
mask.append(not_(or_(is_equal, unequal_found)))
unequal_found = not_(is_equal)
comp_results.append(_lt_gt(l, r, op))
result = any_(and_(r, m) for r, m in zip(comp_results, mask))
if len(lhs) != len(rhs):
# Ref https://docs.python.org/3/tutorial/datastructures.html#comparing-sequences-and-other-types
# If one sequence is an initial sub-sequence of the other, the shorter sequence is the smaller (lesser) one
has_valid_mask = any_(mask)
if op == "<":
length_result = len(lhs) < len(rhs)
elif op == ">":
length_result = len(lhs) > len(rhs)
if type(has_valid_mask) == bool:
return result if has_valid_mask else length_result
else:
return select_(has_valid_mask, result, length_result)
else:
return result
else:
return native_lt_gt(lhs, rhs, op)
def greater_than(lhs, rhs):
return _lt_gt(lhs, rhs, ">")
def less_than(lhs, rhs):
return _lt_gt(lhs, rhs, "<")
def _compare_executor(left, comparators, ops):
result = left
for comparator, op in zip(comparators, ops):
# 'is' and 'is not' are pure python operators
if op == "is":
result = result is comparator
elif op == "is not":
result = result is not comparator
elif op in ["in", "not in"]:
result = in_(left, comparator, op)
elif op in ["==", "!="]:
result = equal(left, comparator)
elif op in ["<", ">="]:
result = less_than(left, comparator)
elif op in [">", "<="]:
result = greater_than(left, comparator)
else:
raise DSLRuntimeError(f"Unsupported comparison operator: {op}")
# Invert the result for NotIn, NotEq, GtE, LtE
if op in ["not in", "!=", ">=", "<="]:
result = not_(result)
return result
# =============================================================================
# Set the AST decorator
# =============================================================================
# Set the DSL specific functions
executor.set_functions(
is_dynamic_expression,
_loop_execute_range_dynamic,
_if_execute_dynamic,
_while_execute_dynamic,
_compare_executor,
any_,
all_,
)

View File

@ -10,6 +10,7 @@
# is strictly prohibited.
from typing import List, Tuple
from types import NoneType
from cutlass._mlir import ir
from cutlass._mlir.dialects import scf, arith
from cutlass._mlir.extras import types as T
@ -145,6 +146,30 @@ class ScfGenerator:
# Use the provided terminator generator
block_term_op_builder[builder](region_result)
else:
# For standard yield op, check result
for arg, result, name in zip(
mix_iter_args,
(
region_result
if isinstance(region_result, list)
else [region_result]
),
mix_iter_arg_names,
):
if isinstance(arg, NoneType) and not isinstance(
result, NoneType
):
raise DSLRuntimeError(
(
f"`{name}` is None prior to this `{op_type_name}`, "
f"and update to non-None value inside of this `{op_type_name}` is not supported."
),
suggestion=(
f"Please make sure `{name}` is not None prior to this `{op_type_name}`, "
f"or mark this `{op_type_name}` with "
f"`{'range' if op_type_name == 'for' else 'const_expr'}`."
),
)
# Normalize region_result
region_result_list = ScfGenerator._normalize_region_result_to_list(
region_result
@ -200,6 +225,7 @@ def _loop_execute_range_dynamic(
mix_iter_arg_names: List[str] = [],
unroll: int = -1,
unroll_full: bool = False,
pipelining: int = None,
):
"""
Example: build an scf.for with optional unroll, using our universal helper.
@ -236,6 +262,18 @@ def _loop_execute_range_dynamic(
unroll_attr = LoopUnroll(count=unroll)
log().debug("Unroll attribute: %s", unroll_attr)
pipelining_attr = None
if pipelining is not None:
if pipelining >= 0:
pipelining_attr = ir.IntegerAttr.get(
ir.IntegerType.get_signless(32), pipelining
)
else:
raise DSLRuntimeError(
f"Pipelining must be non-negative, got {pipelining}"
)
log().debug("Pipelining attribute: %s", pipelining_attr)
log().debug(
"Creating scf.ForOp \n\t\tstart=%s: type : %s\n\t\tstop=%s: type : %s\n\t\tstep=%s: type : %s",
start_,
@ -265,6 +303,9 @@ def _loop_execute_range_dynamic(
if unroll_attr is not None:
for_op.attributes["loop_annotation"] = unroll_attr
if pipelining_attr is not None:
for_op.attributes["cutlass.pipelining"] = pipelining_attr
return for_op
def for_body_builder(