v4.1 release
This commit is contained in:
@ -23,6 +23,11 @@ from ..base_dsl.ast_helpers import (
|
||||
dynamic_expr,
|
||||
assert_executor,
|
||||
bool_cast,
|
||||
compare_executor,
|
||||
any_executor,
|
||||
all_executor,
|
||||
range_value_check,
|
||||
range_perf_warning,
|
||||
)
|
||||
|
||||
from ..base_dsl import *
|
||||
|
||||
@ -20,6 +20,7 @@ from inspect import isclass
|
||||
import functools
|
||||
import pkgutil
|
||||
from dataclasses import is_dataclass
|
||||
from collections.abc import Sequence
|
||||
|
||||
from ..base_dsl import *
|
||||
from ..base_dsl import compiler
|
||||
@ -51,6 +52,11 @@ from ..base_dsl.ast_helpers import (
|
||||
while_executor,
|
||||
assert_executor,
|
||||
bool_cast,
|
||||
compare_executor,
|
||||
any_executor,
|
||||
all_executor,
|
||||
range_value_check,
|
||||
range_perf_warning,
|
||||
)
|
||||
from ..base_dsl.runtime.dlpack_runtime import (
|
||||
get_cute_tensor_c_pointer,
|
||||
@ -67,18 +73,6 @@ from .cutlass_ast_decorators import (
|
||||
_while_execute_dynamic,
|
||||
)
|
||||
|
||||
# =============================================================================
|
||||
# Set the AST decorator
|
||||
# =============================================================================
|
||||
|
||||
# Set the DSL specific functions
|
||||
executor.set_functions(
|
||||
is_dynamic_expression,
|
||||
_loop_execute_range_dynamic,
|
||||
_if_execute_dynamic,
|
||||
_while_execute_dynamic,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Cutlass DSL Base Abstract Class
|
||||
@ -1023,7 +1017,6 @@ def select_(cond, if_value, else_value):
|
||||
)
|
||||
return value
|
||||
|
||||
# Non-DSL dynamic cond should be handled before this.
|
||||
if const_expr(not is_dynamic_expression(cond)):
|
||||
raise DSLRuntimeError("Conditional expression must be dynamic")
|
||||
|
||||
@ -1089,6 +1082,7 @@ def for_generate(
|
||||
iter_args: Optional[Sequence[ir.Value]] = None,
|
||||
*,
|
||||
unroll: LoopUnroll = None,
|
||||
pipelining=None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
@ -1126,6 +1120,9 @@ def for_generate(
|
||||
if unroll is not None:
|
||||
for_op.attributes["loop_annotation"] = unroll
|
||||
|
||||
if pipelining is not None:
|
||||
for_op.attributes["cutlass.pipelining"] = _createI32Attr(pipelining)
|
||||
|
||||
iv = for_op.induction_variable
|
||||
new_results = new_from_mlir_values(iter_args, for_op.results)
|
||||
new_iter_args = new_from_mlir_values(iter_args, for_op.inner_iter_args)
|
||||
@ -1319,3 +1316,122 @@ def while_generate(
|
||||
Generate a WhileLoopContext for a dynamic loop.
|
||||
"""
|
||||
return WhileLoopContext(inputs, condition, loc=loc, ip=ip)
|
||||
|
||||
|
||||
def equal(lhs, rhs):
|
||||
if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs):
|
||||
return lhs == rhs
|
||||
|
||||
# Both sequence
|
||||
if isinstance(lhs, Sequence) and isinstance(rhs, Sequence):
|
||||
# Short-circuit for unequal length
|
||||
if len(lhs) != len(rhs):
|
||||
return False
|
||||
return all_(equal(l, r) for l, r in zip(lhs, rhs))
|
||||
return lhs == rhs
|
||||
|
||||
|
||||
def in_(lhs, rhs, op):
|
||||
if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs):
|
||||
return lhs in rhs
|
||||
|
||||
if not isinstance(rhs, Sequence):
|
||||
raise DSLRuntimeError(
|
||||
f"'{op}' not supported between instances of {type(lhs)} and {type(rhs)}"
|
||||
)
|
||||
|
||||
return any_(equal(lhs, r) for r in rhs)
|
||||
|
||||
|
||||
def _lt_gt(lhs, rhs, op):
|
||||
def native_lt_gt(lhs, rhs, op):
|
||||
if op == "<":
|
||||
return lhs < rhs
|
||||
elif op == ">":
|
||||
return lhs > rhs
|
||||
else:
|
||||
raise DSLRuntimeError(f"Unsupported comparison operator: {op}")
|
||||
|
||||
if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs):
|
||||
return native_lt_gt(lhs, rhs, op)
|
||||
|
||||
# Both sequence, comparisons other than == and != do not allow mixing different types of sequences
|
||||
if (
|
||||
isinstance(lhs, Sequence)
|
||||
and isinstance(rhs, Sequence)
|
||||
and type(lhs) == type(rhs)
|
||||
):
|
||||
unequal_found = False
|
||||
comp_results = []
|
||||
mask = []
|
||||
for l, r in zip(lhs, rhs):
|
||||
is_equal = equal(l, r)
|
||||
mask.append(not_(or_(is_equal, unequal_found)))
|
||||
unequal_found = not_(is_equal)
|
||||
comp_results.append(_lt_gt(l, r, op))
|
||||
|
||||
result = any_(and_(r, m) for r, m in zip(comp_results, mask))
|
||||
|
||||
if len(lhs) != len(rhs):
|
||||
# Ref https://docs.python.org/3/tutorial/datastructures.html#comparing-sequences-and-other-types
|
||||
# If one sequence is an initial sub-sequence of the other, the shorter sequence is the smaller (lesser) one
|
||||
has_valid_mask = any_(mask)
|
||||
if op == "<":
|
||||
length_result = len(lhs) < len(rhs)
|
||||
elif op == ">":
|
||||
length_result = len(lhs) > len(rhs)
|
||||
if type(has_valid_mask) == bool:
|
||||
return result if has_valid_mask else length_result
|
||||
else:
|
||||
return select_(has_valid_mask, result, length_result)
|
||||
else:
|
||||
return result
|
||||
else:
|
||||
return native_lt_gt(lhs, rhs, op)
|
||||
|
||||
|
||||
def greater_than(lhs, rhs):
|
||||
return _lt_gt(lhs, rhs, ">")
|
||||
|
||||
|
||||
def less_than(lhs, rhs):
|
||||
return _lt_gt(lhs, rhs, "<")
|
||||
|
||||
|
||||
def _compare_executor(left, comparators, ops):
|
||||
result = left
|
||||
for comparator, op in zip(comparators, ops):
|
||||
# 'is' and 'is not' are pure python operators
|
||||
if op == "is":
|
||||
result = result is comparator
|
||||
elif op == "is not":
|
||||
result = result is not comparator
|
||||
elif op in ["in", "not in"]:
|
||||
result = in_(left, comparator, op)
|
||||
elif op in ["==", "!="]:
|
||||
result = equal(left, comparator)
|
||||
elif op in ["<", ">="]:
|
||||
result = less_than(left, comparator)
|
||||
elif op in [">", "<="]:
|
||||
result = greater_than(left, comparator)
|
||||
else:
|
||||
raise DSLRuntimeError(f"Unsupported comparison operator: {op}")
|
||||
# Invert the result for NotIn, NotEq, GtE, LtE
|
||||
if op in ["not in", "!=", ">=", "<="]:
|
||||
result = not_(result)
|
||||
return result
|
||||
|
||||
# =============================================================================
|
||||
# Set the AST decorator
|
||||
# =============================================================================
|
||||
|
||||
# Set the DSL specific functions
|
||||
executor.set_functions(
|
||||
is_dynamic_expression,
|
||||
_loop_execute_range_dynamic,
|
||||
_if_execute_dynamic,
|
||||
_while_execute_dynamic,
|
||||
_compare_executor,
|
||||
any_,
|
||||
all_,
|
||||
)
|
||||
|
||||
@ -10,6 +10,7 @@
|
||||
# is strictly prohibited.
|
||||
|
||||
from typing import List, Tuple
|
||||
from types import NoneType
|
||||
from cutlass._mlir import ir
|
||||
from cutlass._mlir.dialects import scf, arith
|
||||
from cutlass._mlir.extras import types as T
|
||||
@ -145,6 +146,30 @@ class ScfGenerator:
|
||||
# Use the provided terminator generator
|
||||
block_term_op_builder[builder](region_result)
|
||||
else:
|
||||
# For standard yield op, check result
|
||||
for arg, result, name in zip(
|
||||
mix_iter_args,
|
||||
(
|
||||
region_result
|
||||
if isinstance(region_result, list)
|
||||
else [region_result]
|
||||
),
|
||||
mix_iter_arg_names,
|
||||
):
|
||||
if isinstance(arg, NoneType) and not isinstance(
|
||||
result, NoneType
|
||||
):
|
||||
raise DSLRuntimeError(
|
||||
(
|
||||
f"`{name}` is None prior to this `{op_type_name}`, "
|
||||
f"and update to non-None value inside of this `{op_type_name}` is not supported."
|
||||
),
|
||||
suggestion=(
|
||||
f"Please make sure `{name}` is not None prior to this `{op_type_name}`, "
|
||||
f"or mark this `{op_type_name}` with "
|
||||
f"`{'range' if op_type_name == 'for' else 'const_expr'}`."
|
||||
),
|
||||
)
|
||||
# Normalize region_result
|
||||
region_result_list = ScfGenerator._normalize_region_result_to_list(
|
||||
region_result
|
||||
@ -200,6 +225,7 @@ def _loop_execute_range_dynamic(
|
||||
mix_iter_arg_names: List[str] = [],
|
||||
unroll: int = -1,
|
||||
unroll_full: bool = False,
|
||||
pipelining: int = None,
|
||||
):
|
||||
"""
|
||||
Example: build an scf.for with optional unroll, using our universal helper.
|
||||
@ -236,6 +262,18 @@ def _loop_execute_range_dynamic(
|
||||
unroll_attr = LoopUnroll(count=unroll)
|
||||
log().debug("Unroll attribute: %s", unroll_attr)
|
||||
|
||||
pipelining_attr = None
|
||||
if pipelining is not None:
|
||||
if pipelining >= 0:
|
||||
pipelining_attr = ir.IntegerAttr.get(
|
||||
ir.IntegerType.get_signless(32), pipelining
|
||||
)
|
||||
else:
|
||||
raise DSLRuntimeError(
|
||||
f"Pipelining must be non-negative, got {pipelining}"
|
||||
)
|
||||
log().debug("Pipelining attribute: %s", pipelining_attr)
|
||||
|
||||
log().debug(
|
||||
"Creating scf.ForOp \n\t\tstart=%s: type : %s\n\t\tstop=%s: type : %s\n\t\tstep=%s: type : %s",
|
||||
start_,
|
||||
@ -265,6 +303,9 @@ def _loop_execute_range_dynamic(
|
||||
if unroll_attr is not None:
|
||||
for_op.attributes["loop_annotation"] = unroll_attr
|
||||
|
||||
if pipelining_attr is not None:
|
||||
for_op.attributes["cutlass.pipelining"] = pipelining_attr
|
||||
|
||||
return for_op
|
||||
|
||||
def for_body_builder(
|
||||
|
||||
Reference in New Issue
Block a user