v4.1 release

Junkai-Wu
2025-07-03 20:07:53 +08:00
committed by GitHub
parent b995f93317
commit a1aaf2300a
155 changed files with 18407 additions and 6068 deletions

View File

@ -15,6 +15,8 @@ The preprocessor reads through Python's AST and changes the input code.
"""
from typing import Callable, Iterator, Optional, overload
from typing_extensions import deprecated
import warnings
from .utils.logger import log
from .common import *
@ -30,13 +32,9 @@ class Executor:
set_functions: Assigns the functions for checking loop bounds and
conditional evaluation.
for_dynamic: Generates MLIR for OP
for_constexpr: Executes a for loop at JIT compile-time
for_execute: Decides whether to execute the loop at compile-time or generate MLIR for OP based on the provided bounds.
if_dynamic: Generates MLIR if OP
if_constexpr: Executes an if statement at JIT compile-time with the Python interpreter
if_execute: Decides whether to execute the if statement at compile-time or generate MLIR if OP based on the predicate.
for_execute: Generates MLIR for OP
while_execute: Generates MLIR while OP
if_execute: Generates MLIR if OP
"""
def __init__(self):
@ -44,6 +42,9 @@ class Executor:
self._loop_execute_range_dynamic = None
self._if_dynamic = None
self._while_dynamic = None
self._compare_executor = None
self._any_executor = None
self._all_executor = None
def set_functions(
self,
@ -51,11 +52,17 @@ class Executor:
loop_execute_range_dynamic: Callable,
if_dynamic: Callable,
while_dynamic: Callable,
compare_executor: Callable,
any_executor: Callable = None,
all_executor: Callable = None,
):
self._is_dynamic_expression = is_dynamic_expression
self._loop_execute_range_dynamic = loop_execute_range_dynamic
self._if_dynamic = if_dynamic
self._while_dynamic = while_dynamic
self._compare_executor = compare_executor
self._any_executor = any_executor
self._all_executor = all_executor
@staticmethod
def convert_to_list(x):
@ -83,31 +90,6 @@ class Executor:
return res[0]
return res
def for_dynamic(
self,
func: Callable,
start,
stop,
step,
used_args: list,
iter_args: list,
iter_arg_names: list,
unroll=bool,
unroll_full=int,
):
log().debug("start [%s] stop [%s] step [%s]", start, stop, step)
return self._loop_execute_range_dynamic(
func,
start,
stop,
step,
used_args,
iter_args,
iter_arg_names,
unroll,
unroll_full,
)
@staticmethod
def for_constexpr(
func: Callable,
@ -143,44 +125,14 @@ class Executor:
iter_arg_names=[],
unroll=-1,
unroll_full=False,
is_range_constexpr=None,
pipelining=None,
):
assert (
self._loop_execute_range_dynamic and self._is_dynamic_expression
self._loop_execute_range_dynamic
), "Functions must be set before execution."
log().debug("start [%s] stop [%s] step [%s]", start, stop, step)
any_dynamic_expression = (
self._is_dynamic_expression(start)
or self._is_dynamic_expression(stop)
or self._is_dynamic_expression(step)
)
if is_range_constexpr is None:
if not any_dynamic_expression:
return self.for_constexpr(func, start, stop, step, used_args, iter_args)
else:
return self.for_dynamic(
func,
start,
stop,
step,
used_args,
iter_args,
iter_arg_names,
unroll,
unroll_full,
)
# Ensure bounds are compile-time constants for constexpr execution
if is_range_constexpr:
if any_dynamic_expression:
raise DSLRuntimeError(
"Loop bounds must be constexpr (compile-time constants)"
)
return self.for_constexpr(func, start, stop, step, used_args, iter_args)
# MLIR generation
return self.for_dynamic(
return self._loop_execute_range_dynamic(
func,
start,
stop,
@ -190,40 +142,9 @@ class Executor:
iter_arg_names,
unroll,
unroll_full,
pipelining,
)
def if_dynamic(
self,
pred,
then_block: Callable,
else_block: Optional[Callable] = None,
used_args=[],
yield_args=[],
yield_arg_names=[],
):
return self._if_dynamic(
pred, then_block, else_block, used_args, yield_args, yield_arg_names
)
@staticmethod
def if_constexpr(
pred,
then_block: Callable,
else_block: Optional[Callable] = None,
used_args=[],
yield_args=[],
):
if pred:
log().debug(" running then block [%s]", yield_args)
res = then_block(*used_args, *yield_args)
log().debug("result [%s]", res)
return Executor.converge_ret_val(res)
elif else_block is not None:
log().debug("running else [%s]", yield_args)
res = else_block(*used_args, *yield_args)
log().debug("result [%s]", res)
return Executor.converge_ret_val(res)
def if_execute(
self,
pred,
@ -232,94 +153,14 @@ class Executor:
used_args=[],
yield_args=[],
yield_arg_names=[],
if_constexpr=None,
):
assert (
self._if_dynamic and self._is_dynamic_expression
), "Functions must be set before execution."
is_if_constexpr = not self._is_dynamic_expression(pred)
if if_constexpr is None:
if is_if_constexpr:
return self.if_constexpr(
pred, then_block, else_block, used_args, yield_args
)
else:
return self.if_dynamic(
pred, then_block, else_block, used_args, yield_args, yield_arg_names
)
# Ensure bounds are compile-time constants for constexpr execution
if if_constexpr:
if not is_if_constexpr:
raise DSLRuntimeError(
"If predicate must be constexpr (compile-time constants)"
)
return self.if_constexpr(
pred, then_block, else_block, used_args, yield_args
)
assert self._if_dynamic, "Functions must be set before execution."
# MLIR generation
return self.if_dynamic(
return self._if_dynamic(
pred, then_block, else_block, used_args, yield_args, yield_arg_names
)
def while_dynamic(
self,
while_before_block: Callable,
while_after_block: Callable,
used_args=[],
yield_args=[],
yield_arg_names=[],
):
return self._while_dynamic(
while_before_block,
while_after_block,
used_args,
yield_args,
yield_arg_names,
)
@staticmethod
def while_constexpr(
while_before_block,
while_after_block,
used_args=[],
yield_args=[],
):
log().debug(
"while_constexpr begin %s", while_before_block.__qualname__
)
cond, loop_results = while_before_block(*used_args, *yield_args)
while cond:
loop_results = Executor.convert_to_list(loop_results)
log().debug(
"calling while_after [%s], [%s]",
used_args,
loop_results,
)
loop_results = while_after_block(*used_args, *loop_results)
log().debug(
"while after [%s]", loop_results
)
loop_results = Executor.convert_to_list(loop_results)
log().debug(
"calling while_before [%s], [%s]",
used_args,
loop_results,
)
cond, loop_results = while_before_block(*used_args, *loop_results)
log().debug(
"while_before cond, results [%s], [%s]",
cond,
loop_results,
)
log().debug(
"while_constexpr results %s", loop_results
)
return Executor.converge_ret_val(loop_results)
def while_execute(
self,
pred,
@ -328,26 +169,11 @@ class Executor:
used_args=[],
yield_args=[],
yield_arg_names=[],
while_constexpr=None,
):
assert (
self._while_dynamic and self._is_dynamic_expression
), "Functions must be set before execution."
is_while_constexpr = not self._is_dynamic_expression(pred)
# Ensure bounds are compile-time constants for constexpr execution
if while_constexpr:
if not is_while_constexpr:
raise DSLRuntimeError(
"While predicate must be constexpr (compile-time constants)"
)
return self.while_constexpr(
while_before_block, while_after_block, used_args, yield_args
)
assert self._while_dynamic, "Functions must be set before execution."
# MLIR generation
return self.while_dynamic(
return self._while_dynamic(
while_before_block,
while_after_block,
used_args,
@ -367,15 +193,16 @@ def loop_selector(
start,
stop,
step,
*,
used_args=[],
iter_args=[],
iter_arg_names=[],
unroll=-1,
unroll_full=False,
constexpr=None,
pipelining=None,
):
log().debug(
"start [%s] stop [%s] step [%s] used_args [%s] iter_args [%s] unroll [%s] unroll_full [%s] constexpr [%s]",
"start [%s] stop [%s] step [%s] used_args [%s] iter_args [%s] unroll [%s] unroll_full [%s] pipelining [%s]",
start,
stop,
step,
@ -383,7 +210,7 @@ def loop_selector(
iter_args,
unroll,
unroll_full,
constexpr,
pipelining,
)
from .typing import Integer, Numeric
@ -408,7 +235,7 @@ def loop_selector(
iter_arg_names,
unroll,
unroll_full,
constexpr,
pipelining,
)
return ir_loop
@ -443,7 +270,6 @@ def while_executor(
used_args=[],
yield_args=[],
yield_arg_names=[],
constexpr=None,
):
return executor.while_execute(
pred,
@ -452,7 +278,6 @@ def while_executor(
used_args,
yield_args,
yield_arg_names,
constexpr,
)
@ -463,10 +288,9 @@ def if_executor(
used_args=[],
yield_args=[],
yield_arg_names=[],
constexpr=None,
):
return executor.if_execute(
pred, then_block, else_block, used_args, yield_args, yield_arg_names, constexpr
pred, then_block, else_block, used_args, yield_args, yield_arg_names
)
@ -475,75 +299,70 @@ def if_executor(
# =============================================================================
class range_dynamic:
class range:
@overload
def __new__(cls, stop, unroll=0, unroll_full=False):
def __new__(cls, stop, unroll=0, unroll_full=False, pipelining=None):
pass
@overload
def __new__(cls, start, stop, step, unroll=0, unroll_full=False):
def __new__(cls, start, stop, step, unroll=0, unroll_full=False, pipelining=None):
pass
def __new__(cls, *args, **kwargs):
raise DSLRuntimeError("range_dynamic should be always preprocessed to IR")
class range_constexpr:
def __init__(self, *args):
if len(args) == 1:
self.start = 0
self.stop = args[0]
self.step = 1
elif len(args) == 2:
self.start, self.stop = args
self.step = 1
elif len(args) == 3:
self.start, self.stop, self.step = args
else:
raise DSLRuntimeError(
"range_constexpr supports up to 3 arguments (start, stop, step)"
)
# Ensure the arguments are compile-time constants (if required)
for arg_name, arg_value in [
("step", self.step),
("start", self.start),
("stop", self.stop),
]:
if executor._is_dynamic_expression(arg_value):
raise DSLRuntimeError(
f"`range_constexpr` requires `constexpr` (non-IR Values) for all arguments, "
f"but `{arg_name}` is not. If the arguments are dynamic, use `range`; the DSL "
f"will handle them during runtime. ",
suggestion="Use `range` instead of `range_constexpr`.",
)
raise DSLRuntimeError("dynamic range should be always preprocessed to IR")
def __iter__(self) -> Iterator[int]:
current = self.start
while current < self.stop:
yield current
current += self.step
raise DSLRuntimeError("dynamic range should be always preprocessed to IR")
@deprecated(
"range_dynamic is deprecated and will be removed in the future, please remove it."
)
def range_dynamic(*args, **kwargs):
raise DSLRuntimeError("range_dynamic should be always preprocessed to IR")
def range_constexpr(*args):
raise DSLRuntimeError("range_constexpr should be preprocessed by preprocessor.")
# =============================================================================
# If expressions
# =============================================================================
def const_expr(expression):
if executor._is_dynamic_expression(expression):
"""
This function is used to check if the expression is a python value.
If the expression is a python value, return the boolean value of the expression.
If the expression is a dynamic expression, raise an error.
"""
from .typing import Numeric
failed = False
if isinstance(expression, Numeric):
if isinstance(expression.value, (int, float, bool)):
return expression.value
else:
failed = True
elif executor._is_dynamic_expression(expression):
failed = True
if failed:
raise DSLRuntimeError(
f"The function `const_expr({expression})` received a dynamic expression (non compile-time constant).",
context={
"const_expr": "Accepts only constexpr (compile-time constant)",
"If your expression depends on dynamic values": "Avoid marking it as `const_expr()`",
"If the expression could be either dynamic or constexpr": "Omit explicit `const_expr()` marker; the DSL will infer the correct handling automatically",
"If your expression depends on dynamic values": "Remove `const_expr()`",
},
)
return expression
@deprecated(
"dynamic_expr is deprecated and will be removed in the future, please remove it."
)
def dynamic_expr(expression):
raise DSLRuntimeError("dynamic_expr should be always preprocessed to IR")
return expression
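As an editorial illustration (not part of this diff), the intended usage inside a traced function is roughly the following; `dynamic_value` is a hypothetical IR-backed value such as a kernel argument:
# Inside a jit-traced function body:
n = 4                          # plain Python int, i.e. a compile-time constant
if const_expr(n > 2):          # the interpreter evaluates this branch at trace time
    stages = 3                 # no scf.if is emitted for this branch
# const_expr(dynamic_value > 2) would instead raise the DSLRuntimeError above.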
# =============================================================================
@ -582,3 +401,86 @@ def bool_cast(value):
suggestion = "Please explicitly convert to boolean with expressions like comparision."
)
return bool(value)
def compare_executor(left, comparators, ops):
"""
Executes comparison operations with a left operand and a list of comparators.
Args:
left: The leftmost value in the comparison chain
comparators: A list of values to compare against
ops: A list of comparison operators to apply
Returns:
The result of the comparison chain
Raises:
AssertionError: If the executor function is not set before execution
"""
assert (
executor._compare_executor is not None
), "Function must be set before execution."
return executor._compare_executor(left, comparators, ops)
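As a standalone editorial sketch (not part of this diff) of the chained-comparison semantics delegated here; the helper name and the `operator`-based encoding are hypothetical:
import operator

def compare_chain_sketch(left, comparators, ops):
    # `a < b <= c` corresponds to left=a, comparators=[b, c], ops=[lt, le];
    # every link is evaluated (no short-circuiting in this sketch).
    result = True
    for op, right in zip(ops, comparators):
        result = result and op(left, right)
        left = right
    return result

assert compare_chain_sketch(1, [2, 2], [operator.lt, operator.le]) is True
assert compare_chain_sketch(1, [2, 1], [operator.lt, operator.le]) is False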
def any_executor(iterable):
"""Executes the 'any' operation on an iterable, handling both dynamic and static expressions.
:param iterable: An iterable to check if any elements evaluate to True
:type iterable: Iterable
:return: a Python bool or an IR Boolean value
:rtype: bool or cutlass.Boolean
"""
if executor._any_executor and executor._is_dynamic_expression(iterable):
return executor._any_executor(iterable)
else:
return any(iterable)
def all_executor(iterable):
"""Executes the 'all' operation on an iterable, handling both dynamic and static expressions.
:param iterable: An iterable to check if all elements evaluate to True
:type iterable: Iterable
:return: a Python bool or an IR Boolean value
:rtype: bool or cutlass.Boolean
"""
if executor._all_executor and executor._is_dynamic_expression(iterable):
return executor._all_executor(iterable)
else:
return all(iterable)
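A standalone editorial mirror of the dispatch rule shared by `any_executor` and `all_executor` (names below are illustrative, not from this diff):
def any_executor_sketch(iterable, dynamic_any=None, is_dynamic=lambda x: False):
    # Use the registered dynamic executor only when one exists and the
    # iterable actually contains dynamic (IR) values; otherwise fall back
    # to Python's builtin semantics.
    if dynamic_any and is_dynamic(iterable):
        return dynamic_any(iterable)
    return any(iterable)

assert any_executor_sketch([False, True]) is True
assert any_executor_sketch([False, False]) is False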
# =============================================================================
# Control flow checks
# =============================================================================
def range_value_check(*args):
"""
Ensure all `range_constexpr` bounds are compile-time constants (Python ints).
"""
try:
return tuple(arg.__index__() for arg in args)
except:
raise DSLRuntimeError(
"`range_constexpr` requires constexpr (compile-time constant) for all arguments.",
suggestion="Use `range` instead of `range_constexpr`.",
)
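For example (a standalone editorial check with plain Python values, not part of this diff):
# Anything defining __index__ (e.g. plain ints) passes through unchanged:
assert tuple(a.__index__() for a in (0, 10, 2)) == (0, 10, 2)
# A float has no __index__, so range_value_check turns the resulting
# AttributeError into a DSLRuntimeError.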
def range_perf_warning(filename, lineno, *args):
has_dynamic_expr = False
for arg in args:
if executor._is_dynamic_expression(arg):
has_dynamic_expr = True
break
if not has_dynamic_expr:
warnings.warn_explicit(
(
"The loop was previously unrolled in Python, but now it may not unroll in IR. This may cause performance regression."
"If you want to unroll the loop in Python, please use `range_constexpr` instead of `range`."
),
category=UserWarning,
filename=filename,
lineno=lineno,
)
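For orientation, a hedged editorial sketch of the two loop flavors this module now distinguishes; the exact spellings (`cutlass.range`, `cutlass.range_constexpr`, `cute.jit`, `cutlass.Int32`) are assumed from typical CuTe DSL usage and are not taken from this diff:
import cutlass
import cutlass.cute as cute

@cute.jit
def loop_flavors(n: cutlass.Int32):
    acc = cutlass.Int32(0)
    # Compile-time bounds: unrolled by the Python interpreter during tracing.
    for i in cutlass.range_constexpr(4):
        acc += i
    # Possibly dynamic bounds: lowered to an scf.for in the generated IR.
    # (Using cutlass.range with purely constexpr bounds is what triggers
    # range_perf_warning above.)
    for i in cutlass.range(n):
        acc += 1
    return acc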

File diff suppressed because it is too large

View File

@ -164,16 +164,17 @@ def _mlir_type_to_numpy_type(type):
def is_dynamic_expression(value):
"""
Check if the value is an MLIR's SSA value.
Given `value`, check whether it is itself an IR value, or recurse through it to check whether it contains one.
"""
# Case 1: If the value has MLIR's SSA value, return True
# Case 2: If the value supports __extract_mlir_values__ then it's possible to get SSA value
return (
isinstance(value, ir.Value)
or hasattr(value, "__extract_mlir_values__")
or len(extract_mlir_values(value)) > 0
)
if isinstance(value, (tuple, list)):
for x in value:
if is_dynamic_expression(x):
return True
elif isinstance(value, (ir.Value, ir.BlockArgumentList)) or hasattr(
value, "__extract_mlir_values__"
):
return True
return False
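A standalone editorial mirror of the updated recursive check; the `_FakeIRValue` class below is a hypothetical stand-in for `ir.Value`:
class _FakeIRValue:
    """Hypothetical stand-in for ir.Value in this standalone sketch."""

def is_dynamic_sketch(value):
    # Recurse into tuples/lists; otherwise treat IR values and objects that
    # expose __extract_mlir_values__ as dynamic.
    if isinstance(value, (tuple, list)):
        return any(is_dynamic_sketch(x) for x in value)
    return isinstance(value, _FakeIRValue) or hasattr(value, "__extract_mlir_values__")

assert is_dynamic_sketch((1, [2, _FakeIRValue()])) is True
assert is_dynamic_sketch((1, 2.0, "x")) is False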
def extract_mlir_values(obj):
"""
@ -726,6 +727,7 @@ class BaseDSL:
)
jit_arg_types, jit_arg_attrs, jit_exec_args = [], [], []
jit_adapted_args = []
default_attr = ir.DictAttr.get({})
input_args = [*args, *kwargs.values()]
@ -759,7 +761,9 @@ class BaseDSL:
# If not any known type, try JIT argument adapter
# to convert the argument
adapter = JitArgAdapterRegistry.get_registered_adapter(type(arg))
arg = adapter(arg) if adapter else arg
if adapter:
arg = adapter(arg)
jit_adapted_args.append(arg)
if is_host:
jit_exec_arg.extend(get_c_pointers(arg))
@ -798,14 +802,14 @@ class BaseDSL:
jit_arg_types.extend(jit_arg_type)
jit_arg_attrs.extend(jit_arg_attr)
return jit_exec_args, jit_arg_types, jit_arg_attrs
return jit_exec_args, jit_arg_types, jit_arg_attrs, jit_adapted_args
def generate_mlir_function_types(
self, func, function_name, input_args, kwargs, args_spec: inspect.FullArgSpec
):
"""Convert input arguments to MLIR function signature also convert numpy arrays to memref."""
exe_args, types, _ = self._generate_jit_func_args(
exe_args, types, attrs, adapted_args = self._generate_jit_func_args(
func, function_name, input_args, kwargs, args_spec, is_host=True
)
@ -816,7 +820,7 @@ class BaseDSL:
types
), "expects the same number of arguments and function parameters"
return exe_args, types
return exe_args, types, adapted_args
@dataclass
class LaunchConfig:
@ -1158,7 +1162,7 @@ class BaseDSL:
"""Generate MLIR module and compile iself.T_provider."""
with ir.Context(), ir.Location.unknown():
# Convert input arguments to MLIR arguments
exe_args, func_types = self.generate_mlir_function_types(
exe_args, func_types, adapted_args = self.generate_mlir_function_types(
funcBody, function_name, args, kwargs, args_spec
)
@ -1476,7 +1480,7 @@ class BaseDSL:
if self.device_compilation_only:
return kernel_operands, kernel_arg_types, kernel_arg_attrs
kernel_operands, kernel_arg_types, kernel_arg_attrs = (
kernel_operands, kernel_arg_types, kernel_arg_attrs, _ = (
self._generate_jit_func_args(
kernel_func, kernel_name, args, kwargs, args_spec, is_host=False
)
@ -1586,12 +1590,14 @@ class BaseDSL:
if self.device_compilation_only:
log().debug("Generating cuda-python arguments")
# Convert input arguments to MLIR arguments
self.exe_args, kernel_types = self.generate_mlir_function_types(
funcBody,
kernel_name,
canonicalized_args,
canonicalized_kwargs,
args_spec,
self.exe_args, kernel_types, _ = (
self.generate_mlir_function_types(
funcBody,
kernel_name,
canonicalized_args,
canonicalized_kwargs,
args_spec,
)
)
helper = kernelGenHelper()

View File

@ -78,11 +78,9 @@ def detect_gpu_arch(prefix):
major, minor = arch
suffix = ""
if major >= 9 and minor >= 0:
if major >= 9:
suffix = "a"
elif minor != 0:
# e.g. sm_86 belongs to the sm_80 family
minor = 0
return f"sm_{major}{minor}{suffix}"

View File

@ -12,23 +12,24 @@
"""
This module provides jit executor related classes
"""
import io
import inspect
import ctypes
import numpy as np
import inspect
import io
from typing import get_origin
import numpy as np
# MLIR modules imports
from .._mlir import ir
# Local modules imports
from .utils.timer import timer
from .utils.logger import log
from . import typing as t
from .common import DSLRuntimeError
from .runtime import cuda as cuda_helpers
from .runtime.jit_arg_adapters import JitArgAdapterRegistry, is_arg_spec_constexpr
from .typing import get_c_pointers
from . import typing as t
# MLIR modules imports
from .._mlir import ir
from .utils.logger import log
from .utils.timer import timer
class CudaSingleModule:
@ -64,6 +65,7 @@ class JitExecutor:
self.args_spec = args_spec
self.function_name = function_name
if args_spec is not None:
self.original_args_spec = args_spec
self.args_spec = self.filter_runtime_arg_spec(args_spec)
# cuda kernels
self.cuda_modules = cuda_modules
@ -135,6 +137,29 @@ class JitExecutor:
for module in set(cuda_modules):
cuda_helpers.unload_cubin_module(module)
def get_constexpr_args(self) -> list[dict[str, int | str]]:
"""
Return the constexpr arguments that were pruned from the original function signature.
:return: A list of dicts, each containing the argument index (``argument_index``) and the argument name (``argument_name``); keyword-only constexpr arguments have an ``argument_index`` of None.
:rtype: list[dict[str, int | str]]
"""
if self.original_args_spec is None:
return list()
constexpr_args = list()
for i, arg_name in enumerate(self.original_args_spec.args):
if arg_name not in self.args_spec.args:
constexpr_args.append({"argument_index": i, "argument_name": arg_name})
if self.original_args_spec.kwonlyargs:
for kwarg in self.original_args_spec.kwonlyargs:
if kwarg not in self.args_spec.kwonlyargs:
constexpr_args.append(
{"argument_index": None, "argument_name": kwarg}
)
return constexpr_args
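A hedged editorial example of the returned structure; the signature and argument names below are hypothetical:
# For `def kernel(a, stages: cutlass.Constexpr, *, tile_shape: cutlass.Constexpr)`,
# where `stages` and `tile_shape` were pruned from the runtime signature:
# [
#     {"argument_index": 1, "argument_name": "stages"},
#     {"argument_index": None, "argument_name": "tile_shape"},
# ]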
def generate_execution_args(self, args, kwargs, args_spec: inspect.FullArgSpec):
"""
This function is the pruned version of `generate_mlir_function_types`, which only generates execution args
@ -175,6 +200,7 @@ class JitExecutor:
)
exe_args = []
adapted_args = []
input_args = rectified_args + list(rectified_kwargs.values())
input_arg_names = args_spec.args + args_spec.kwonlyargs
for arg, arg_name in zip(input_args, input_arg_names):
@ -193,13 +219,16 @@ class JitExecutor:
adapter = JitArgAdapterRegistry.get_registered_adapter(type(arg))
if adapter:
arg = adapter(arg)
adapted_args.append(arg)
exe_args.extend(get_c_pointers(arg))
return exe_args
return exe_args, adapted_args
def __call__(self, *args, **kwargs):
exe_args = self.generate_execution_args(args, kwargs, self.args_spec)
exe_args, adapted_args = self.generate_execution_args(
args, kwargs, self.args_spec
)
self.run_compiled_program(exe_args)

View File

@ -46,29 +46,75 @@ from .._mlir.dialects import arith, math
@runtime_checkable
class DynamicExpression(Protocol):
"""
This is a protocol class that provides a common interface
to generate user-defined dynamic expressions.
"""Protocol defining the interface for object holding dynamic values in the DSL.
The DSL checks this protocol to determine if a class is a dynamic expression (SSA value) or not.
This protocol enables classes to represent dynamic values in the DSL. Classes implementing
this protocol can be used in JIT-compiled functions and dynamic value generation.
It is required for custom data types to work correctly with the following JIT features:
* as function argument to call another JIT function from JIT function
* as return value from JIT function
* for constructions like if-else, while-loop, etc.
:param value: The MLIR operation result value to initialize the object with
:type value: ir.Value
**Required Methods**
* ``__extract_mlir_values__``: Extract MLIR values from the object
* ``__new_from_mlir_values__``: Create new instance from MLIR values
**Implementation Example**
To implement a custom data type that works with the DSL:
.. code-block:: python
class CustomData(metaclass=DslType):
def __init__(self, int_value):
self.int_value = int_value
def __extract_mlir_values__(self):
return [self.int_value]
def __new_from_mlir_values__(self, values):
return CustomData(values[0])
**Usage in JIT Functions**
When used in JIT-compiled functions, the DSL automatically extracts MLIR values:
.. code-block:: python
@jit
def caller():
x = CustomData(1)
return foo(x)
This generates MLIR like:
.. code-block:: mlir
func @caller() -> i32 {
%0 = func.call @foo(%arg0) : (i32) -> i32
return %0 : i32
}
"""
def __extract_mlir_values__(self):
"""
Generate a dynamic expression for the current object.
"""Extract MLIR values from this object.
:return: List of MLIR values
:return: List of MLIR values representing this object's data
:rtype: List[ir.Value]
"""
raise NotImplementedError
def __new_from_mlir_values__(self, values):
"""
Create a new object from MLIR values.
"""Create a new instance from MLIR values.
:param values: List of MLIR values
:param values: List of MLIR values to construct the object from
:type values: List[ir.Value]
:return: A new instance of the class that implements this protocol
:return: New instance of the implementing class
:rtype: Any
"""
raise NotImplementedError
@ -77,50 +123,73 @@ class DynamicExpression(Protocol):
@runtime_checkable
class JitArgument(Protocol):
"""
This is a protocol class that provides a common interface
for JIT function arguments generation for Python to call JIT functions.
Protocol class defining the interface for JIT function argument generation.
The DSL checks this protocol to determine if a class is capable of providing information
needed for generating JIT function arguments.
This protocol enables classes to provide the necessary information for generating
JIT function arguments and allows the DSL JIT executor to call JIT-compiled functions.
See the breakdown below of how a JitArgument-protocol-based JIT function call works.
**Required Methods**
* ``__c_pointers__``: Returns ctypes pointers for runtime execution
* ``__get_mlir_types__``: Returns MLIR types for function definition
* ``__new_from_mlir_values__``: Creates new instances from MLIR values
**Example**
.. code-block:: python
class CustomData:
def __init__(self, int_value, ...):
self.int_value = int_value
...
def __c_pointers__(self):
return [ctypes.pointer(ctypes.c_int32(self.int_value)), ...]
def __get_mlir_types__(self):
return [ir.IntegerType.get(32), ...]
def __new_from_mlir_values__(self, values):
return CustomData(values[0], ...)
@jit
def foo(x: CustomData):
return x.int_value + 1
a = x.int_value + 1
...
# Emit: `%c0 = arith.constant(1, i32)`
c1 = const(1, Int32)
# `c1` tracks `%c0` defined outside of function body of `foo`
# `%c0` can't be used directly in function body of `foo`
x = CustomData(c1, ...)
# `CustomData` is an argument of `foo`
foo(CustomData(1, ...))
When called like ``y = foo(x)``, the following steps occur:
1. JIT compiler generates MLIR function definition using ``__get_mlir_types__``:
1. JIT compiler generates MLIR function definition using ``__get_mlir_types__``
.. code-block:: mlir
func @foo(%arg0: i32, ...) -> i32 {
func.func @foo(%arg0: i32, ...) {
...
return
}
2. Function is traced in Python, wrapping MLIR values with ``__new_from_mlir_values__``:
2. The JIT function can't use values from Python directly, so the object is reconstructed from
MLIR values, i.e. `%arg0`, with ``__new_from_mlir_values__`` and passed to `foo`.
The following code demonstrates how the JIT compiler reconstructs the object and hands it back to Python for tracing.
.. code-block:: python
# Implementation of IR tracing
new_x = CustomData(ir.Value(%arg0), ...)
y = foo(new_x)
# `x.int_value` is %arg0 rather than `c1` defined outside
# `x.int_value` is %arg0 rather than `c1` defined by Python.
3. For Python runtime execution, JIT engine invokes compiled function using ``__c_pointers__``:
3. For Python runtime execution, the JIT engine invokes the compiled function using ``__c_pointers__``,
which point to the underlying data objects passed to the JIT-compiled function.
.. code-block:: python
jit_engine.invoke(foo, concat([x.__c_pointers__(), ...]))
jit_engine.invoke(compiled_foo, concat([x.__c_pointers__(), ...]))
"""
def __c_pointers__(self):
@ -224,47 +293,6 @@ class DslType(type):
:property mlir_type: Returns the corresponding MLIR type for this DSL type
:type mlir_type: Any
**Examples**
Define a custom data type:
.. code-block:: python
class CustomData(metaclass=DslType, ...):
def __init__(self, int_value, ...):
self.int_value = int_value
...
def __str__(cls):
return "CustomData[int, ...]"
def __c_pointers__(self):
return [ctypes.pointer(ctypes.c_int32(self.int_value)), ...]
def __get_mlir_types__(self):
return [_T.i32(), ...]
def __extract_mlir_values__(self):
return [self.int_value, ...]
def __new_from_mlir_values__(self, values):
return CustomData(values[0], ...)
For JIT function calls, MLIR values are extracted with ``__extract_mlir_values__``:
.. code-block:: python
@jit
def caller():
x = CustomData(1, ...)
return foo(x)
.. code-block:: mlir
func @caller() -> i32 {
%0 = func.call @foo(%arg0, ...) : (i32, ...) -> i32
return %0 : i32
}
"""
_is_abstract: bool
@ -946,9 +974,12 @@ class Numeric(metaclass=NumericMeta, is_abstract=True):
:return: The result of the logical not operation
:rtype: Boolean
"""
ty = type(self)
zero_val = arith.constant(ty.mlir_type, ty.zero)
return self.__eq__(ty(zero_val), loc=loc, ip=ip)
if isinstance(self.value, (int, float, bool)):
return not self.value
else:
ty = type(self)
zero_val = arith.constant(ty.mlir_type, ty.zero)
return self.__eq__(ty(zero_val), loc=loc, ip=ip)
def __dsl_and__(self, other, *, loc=None, ip=None):
"""DSL implementation of Python's `and` operator.
@ -1057,6 +1088,15 @@ class Numeric(metaclass=NumericMeta, is_abstract=True):
],
)
def __index__(self):
if isinstance(self.value, (int, float, bool)):
return self.value
else:
raise DSLRuntimeError(
f"'{type(self.value)}' object cannot be interpreted as an integer",
suggestion="Mark the loop as dynamic with `dynamic_expr` or `range_dynamic` and decorate the parent function with `jit` decorator",
)
def __neg__(self, *, loc=None, ip=None):
if isinstance(self, (bool, int, float)):
return type(self)(-self.value) # type: ignore
@ -1813,7 +1853,7 @@ class IRVariadic:
def __init__(self, operands):
"""
Create a list of variadic operands. `operands` must be SSA values.
Create a list of variadic operands. `operands` must be dynamic values.
"""
self.operands = operands

View File

@ -68,7 +68,9 @@ from .core import (
select,
front,
is_major,
leading_dim,
find,
find_if,
coalesce,
group_modes,
cosize,
@ -221,7 +223,9 @@ __all__ = [
"select",
"front",
"is_major",
"leading_dim",
"find",
"find_if",
"coalesce",
"group_modes",
"cosize",

View File

@ -25,12 +25,13 @@ __all__ = [
#
# mbar.py
#
"mbarrier_init_arrive_cnt",
"mbarrier_init",
"mbarrier_init_fence",
"mbarrier_init_tx_bytes",
"mbarrier_arrive_and_expect_tx",
"mbarrier_expect_tx",
"mbarrier_wait",
"mbarrier_try_wait",
"conditional_mbarrier_try_wait",
"mbarrier_conditional_try_wait",
"mbarrier_arrive",
#
# nvvm_wrappers.py
@ -51,6 +52,7 @@ __all__ = [
"shuffle_sync_down",
"shuffle_sync_bfly",
"barrier",
"barrier_arrive",
"sync_threads",
"sync_warp",
"fence_acq_rel_cta",

View File

@ -69,7 +69,16 @@ def elect_one(*, loc=None, ip=None) -> IfOpRegion:
pass
"""
arch = CuTeDSL._get_dsl().envar.arch
check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch")
check_value_in(
arch,
[
"sm_90",
"sm_90a",
"sm_100a",
"sm_100f",
],
"arch",
)
is_thread_leader = nvvm.elect_sync(T.bool())
if_op = scf.IfOp(is_thread_leader, loc=loc, ip=ip)
return IfOpRegion(if_op.then_block, loc=loc, ip=ip)

View File

@ -8,6 +8,7 @@
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from typing import Optional
from cutlass.cutlass_dsl import CuTeDSL, T, if_generate, dsl_user_op
@ -26,7 +27,7 @@ from ...impl_utils import check_value_in
@dsl_user_op
def mbarrier_init_arrive_cnt(mbar_ptr: Pointer, cnt: Int, *, loc=None, ip=None) -> None:
def mbarrier_init(mbar_ptr: Pointer, cnt: Int, *, loc=None, ip=None) -> None:
"""
Initializes an mbarrier with the specified thread arrival count.
@ -46,16 +47,25 @@ def mbarrier_init_fence(*, loc=None, ip=None) -> None:
A fence operation that applies to the mbarrier initializations.
"""
arch = CuTeDSL._get_dsl().envar.arch
check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch")
check_value_in(
arch,
[
"sm_90",
"sm_90a",
"sm_100a",
"sm_100f",
],
"arch",
)
nvvm.fence_mbarrier_init(loc=loc, ip=ip)
@dsl_user_op
def mbarrier_init_tx_bytes(
def mbarrier_arrive_and_expect_tx(
mbar_ptr: Pointer, bytes: Int, peer_cta_rank_in_cluster=None, *, loc=None, ip=None
) -> None:
"""
Initializes a mbarrier with the specified number of transaction bytes.
Arrives on an mbarrier and expects a specified number of transaction bytes.
:param mbar_ptr: A pointer to the mbarrier in SMEM
:type mbar_ptr: Pointer
@ -66,7 +76,16 @@ def mbarrier_init_tx_bytes(
SMEM.
"""
arch = CuTeDSL._get_dsl().envar.arch
check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch")
check_value_in(
arch,
[
"sm_90",
"sm_90a",
"sm_100a",
"sm_100f",
],
"arch",
)
mbar_llvm_ptr = mbar_ptr.llvm_ptr
if peer_cta_rank_in_cluster is not None:
@ -91,6 +110,56 @@ def mbarrier_init_tx_bytes(
)
@dsl_user_op
def mbarrier_expect_tx(
mbar_ptr: Pointer, bytes: Int, peer_cta_rank_in_cluster=None, *, loc=None, ip=None
) -> None:
"""
Expects a specified number of transaction bytes without an arrive.
:param mbar_ptr: A pointer to the mbarrier in SMEM
:type mbar_ptr: Pointer
:param bytes: The number of transaction bytes
:type bytes: Int
:param peer_cta_rank_in_cluster: An optional CTA rank in cluster. If provided, the pointer to
the mbarrier is converted to a remote address in the peer CTA's
SMEM.
"""
arch = CuTeDSL._get_dsl().envar.arch
check_value_in(
arch,
[
"sm_90",
"sm_90a",
"sm_100a",
"sm_100f",
],
"arch",
)
mbar_llvm_ptr = mbar_ptr.llvm_ptr
if peer_cta_rank_in_cluster is not None:
mbar_llvm_ptr = nvvm.mapa(
mbar_llvm_ptr.type,
mbar_llvm_ptr,
Int32(peer_cta_rank_in_cluster).ir_value(loc=loc, ip=ip),
loc=loc,
ip=ip,
)
space = nvvm.MBarrierSpaceKind.CLUSTER
else:
space = nvvm.MBarrierSpaceKind.CTA
nvvm.mbarrier_txn(
mbar_llvm_ptr,
Int32(bytes).ir_value(loc=loc, ip=ip),
kind=nvvm.MBarrierTxnKind.EXPECT_TX,
space=space,
loc=loc,
ip=ip,
)
@dsl_user_op
def mbarrier_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> None:
"""
@ -102,7 +171,16 @@ def mbarrier_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> None:
:type phase: Int
"""
arch = CuTeDSL._get_dsl().envar.arch
check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch")
check_value_in(
arch,
[
"sm_90",
"sm_90a",
"sm_100a",
"sm_100f",
],
"arch",
)
timeout_ns = 10000000
# This NVVM Op is a spin-loop wrapping the mbarrier.try_wait.parity.shared.b64 PTX
@ -129,7 +207,16 @@ def mbarrier_try_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> Bo
:rtype: Boolean
"""
arch = CuTeDSL._get_dsl().envar.arch
check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch")
check_value_in(
arch,
[
"sm_90",
"sm_90a",
"sm_100a",
"sm_100f",
],
"arch",
)
return Boolean(
nvvm.mbarrier_wait_parity(
@ -144,7 +231,7 @@ def mbarrier_try_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> Bo
@dsl_user_op
def conditional_mbarrier_try_wait(
def mbarrier_conditional_try_wait(
cond, mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None
) -> Boolean:
"""
@ -159,7 +246,16 @@ def conditional_mbarrier_try_wait(
:rtype: Boolean
"""
arch = CuTeDSL._get_dsl().envar.arch
check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch")
check_value_in(
arch,
[
"sm_90",
"sm_90a",
"sm_100a",
"sm_100f",
],
"arch",
)
return if_generate(
cond,
lambda: mbarrier_try_wait(mbar_ptr, phase, loc=loc, ip=ip),
@ -171,7 +267,11 @@ def conditional_mbarrier_try_wait(
@dsl_user_op
def mbarrier_arrive(
mbar_ptr: Pointer, peer_cta_rank_in_cluster: Int = None, *, loc=None, ip=None
mbar_ptr: Pointer,
peer_cta_rank_in_cluster: Optional[Int] = None,
*,
loc=None,
ip=None,
) -> None:
"""
Arrives on an mbarrier.
@ -185,7 +285,16 @@ def mbarrier_arrive(
mbar_llvm_ptr = mbar_ptr.llvm_ptr
if peer_cta_rank_in_cluster is not None:
arch = CuTeDSL._get_dsl().envar.arch
check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch")
check_value_in(
arch,
[
"sm_90",
"sm_90a",
"sm_100a",
"sm_100f",
],
"arch",
)
mbar_llvm_ptr = nvvm.mapa_shared_cluster(
mbar_llvm_ptr.type,

View File

@ -225,6 +225,25 @@ def barrier(*, barrier_id=None, number_of_threads=None, loc=None, ip=None) -> No
barrier_id=barrier_id, number_of_threads=number_of_threads, loc=loc, ip=ip
)
@dsl_user_op
def barrier_arrive(
*, barrier_id=None, number_of_threads=None, loc=None, ip=None
) -> None:
if barrier_id is not None:
barrier_id = Int32(barrier_id).ir_value(loc=loc, ip=ip)
if number_of_threads is None:
raise ValueError(
"barrier_arrive needs pass number_of_threads to arrive the barrier",
)
number_of_threads = Int32(number_of_threads).ir_value(loc=loc, ip=ip)
nvvm.barrier_arrive(
barrier_id=barrier_id, number_of_threads=number_of_threads, loc=loc, ip=ip
)
@dsl_user_op
def sync_threads(*, loc=None, ip=None) -> None:
"""
@ -545,3 +564,20 @@ def exp2(a: Union[float, Float32], *, loc=None, ip=None) -> Float32:
asm_dialect=llvm.AsmDialect.AD_ATT,
)
)
# TODO: add `fastmath` flag for this op
@dsl_user_op
def exp(a: Union[float, Float32], *, loc=None, ip=None) -> Float32:
LOG2_E = 1.4426950408889634
return exp2(a * LOG2_E, loc=loc, ip=ip)
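The identity relied on here is exp(x) = 2^(x * log2(e)); a quick standalone editorial sanity check (not part of this diff):
import math
x = 0.75
assert abs(2.0 ** (x * 1.4426950408889634) - math.exp(x)) < 1e-12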
# TODO: add `fastmath` flag for this op
@dsl_user_op
def exp_packed_f32x2(
a: Tuple[Float32, Float32], *, loc=None, ip=None
) -> Tuple[Float32, Float32]:
LOG2_E = Float32(1.4426950408889634)
b = mul_packed_f32x2(a, (LOG2_E, LOG2_E), loc=loc, ip=ip)
return exp2(b[0], loc=loc, ip=ip), exp2(b[1], loc=loc, ip=ip)

File diff suppressed because it is too large

View File

@ -26,7 +26,7 @@ __all__ = [
#
# helpers.py
#
"make_tma_tile_atom",
"make_tiled_tma_atom",
"tma_partition",
"create_tma_multicast_mask",
"prefetch_descriptor",

View File

@ -127,7 +127,12 @@ class CopyBulkTensorTileG2SOp(CopyOp):
cta_group: CtaGroup = CtaGroup.ONE
admissible_archs = ["sm_90", "sm_90a", "sm_100a"]
admissible_archs = [
"sm_90",
"sm_90a",
"sm_100a",
"sm_100f",
]
def __post_init__(self) -> None:
if not isinstance(self.cta_group, CtaGroup):
@ -159,7 +164,7 @@ class CopyBulkTensorTileG2SOp(CopyOp):
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
) -> "CopyBulkTensorTileG2SNonExecTrait":
raise NotImplementedError(
"Use cpasync.make_tma_tile_atom to obtain a copy Atom for TMA"
"Use cpasync.make_tiled_tma_atom to obtain a copy Atom for TMA"
)
def _to_ir(self) -> _cute_nvgpu_ir.TiledTmaLoadEnum:
@ -224,7 +229,12 @@ class CopyBulkTensorTileG2SMulticastOp(CopyOp):
cta_group: CtaGroup = CtaGroup.ONE
admissible_archs = ["sm_90", "sm_90a", "sm_100a"]
admissible_archs = [
"sm_90",
"sm_90a",
"sm_100a",
"sm_100f",
]
def __post_init__(self):
if not isinstance(self.cta_group, CtaGroup):
@ -256,7 +266,7 @@ class CopyBulkTensorTileG2SMulticastOp(CopyOp):
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
) -> "CopyBulkTensorTileG2SMulticastNonExecTrait":
raise NotImplementedError(
"Use cpasync.make_tma_tile_atom to obtain a copy Atom for TMA"
"Use cpasync.make_tiled_tma_atom to obtain a copy Atom for TMA"
)
def _to_ir(self) -> _cute_nvgpu_ir.TiledTmaLoadEnum:
@ -326,7 +336,12 @@ class CopyBulkTensorTileS2GOp(CopyOp):
This Operation uses TMA in the ``.tile`` mode.
"""
admissible_archs = ["sm_90", "sm_90a", "sm_100a"]
admissible_archs = [
"sm_90",
"sm_90a",
"sm_100a",
"sm_100f",
]
def __post_init__(self):
# Arch verification
@ -345,7 +360,7 @@ class CopyBulkTensorTileS2GOp(CopyOp):
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
) -> "CopyBulkTensorTileS2GTrait":
raise NotImplementedError(
"Use cpasync.make_tma_tile_atom to obtain a copy Atom for TMA"
"Use cpasync.make_tiled_tma_atom to obtain a copy Atom for TMA"
)

View File

@ -29,14 +29,14 @@ from .copy import (
@dsl_user_op
def make_tma_tile_atom(
def make_tiled_tma_atom(
op: Union[
CopyBulkTensorTileG2SOp,
CopyBulkTensorTileG2SMulticastOp,
CopyBulkTensorTileS2GOp,
],
gmem_tensor: Tensor,
smem_layout: Layout,
smem_layout: Union[Layout, core.ComposedLayout],
cta_tiler: Tiler,
num_multicast: int = 1,
*,
@ -45,7 +45,7 @@ def make_tma_tile_atom(
ip=None,
) -> Tuple[core.CopyAtom, Tensor]:
"""
Makes a TMA Copy Atom in the ``.tile`` mode to copy tiles of a GMEM tensor to/from and SMEM
Makes a TMA Copy Atom in the ``.tile`` mode to copy tiles of a GMEM tensor to/from SMEM
buffer with the given Layout.
Given
@ -71,7 +71,7 @@ def make_tma_tile_atom(
:param gmem_tensor: The GMEM tensor involved in the Copy
:type gmem_tensor: Tensor
:param smem_layout: The SMEM layout to construct the Copy Atom for
:type smem_layout: Layout
:type smem_layout: Union[Layout, core.ComposedLayout]
:param cta_tiler: The CTA Tiler to use
:type cta_tiler: Tiler
:param num_multicast: The multicast factor
@ -94,6 +94,12 @@ def make_tma_tile_atom(
ip=ip,
)
# Wrap smem_layout in a composed layout to make it a TMA-friendly layout
if isinstance(smem_layout, Layout):
smem_layout = core.make_composed_layout(
core.make_swizzle(0, 4, 3), 0, smem_layout
)
if isinstance(op, CopyBulkTensorTileG2SOp):
if num_multicast != 1:
raise ValueError(

View File

@ -34,7 +34,7 @@ from .cpasync.copy import (
@dsl_user_op
def make_tma_tile_atom_A(
def make_tiled_tma_atom_A(
op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp],
gmem_tensor: Tensor,
smem_layout: Layout,
@ -46,6 +46,51 @@ def make_tma_tile_atom_A(
loc=None,
ip=None,
) -> Tuple[core.CopyAtom, Tensor]:
"""
Makes a TMA Copy atom mapping to ``.tile`` mode for ``cp.async.bulk.tensor`` PTX operation
accounting for the MK projections of the TiledMMA for A tensor loads.
Given
- a GMEM tensor
- a SMEM layout
- a MMA Tiler
- a TiledMma
- a Cluster-level shape
this function figures out the bulk tensor asynchronous copy instruction to use with the maximum
"TMA vector length" to copy tiles of the GMEM tensor to an SMEM buffer with the provided
layout and consistent with the provided Tiler & tiled_mma (considering the M-mode & K-mode).
The Cluster-level shape is used to determine the multicast factor across the N-mode for A tensor loads.
This function returns two results:
1. the Copy Atom
2. the so-called TMA tensor used to map logical coordinates of the GMEM tensor to coordinates
that the TMA unit can consume. TMA tensors have so-called basis stride elements so that the
associated layout can output coordinates. Otherwise, TMA tensors can be partitioned
similarly to any other CuTe tensors using the algebra.
:param op: The Copy Operation to construct an Atom for
:type op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp]
:param gmem_tensor: The GMEM tensor to be loaded by this copy atom
:type gmem_tensor: Tensor
:param smem_layout: Shared memory layout to load the tensor into (PDSL)
:type smem_layout: Layout
:param mma_tiler_mnk: The MMA Tiler shape (TILE_M, TILE_N, TILE_K) in MNK dimensions
:type mma_tiler_mnk: Shape
:param tiled_mma: The TiledMMA that will consume the load as operands
:type tiled_mma: core.TiledMma
:param cluster_shape_vmnk: The Cluster-level shape in VMNK dimensions
:type cluster_shape_vmnk: Shape
:param internal_type: An optional parameter for the internal data type to use when the element
type does not match the copy type
:type internal_type: Type[Numeric]
:return: A copy atom for this operation and the associated TMA coord tensor
:rtype: Tuple[core.CopyAtom, Tensor]
"""
if internal_type is not None:
if not isinstance(internal_type, NumericMeta):
raise TypeError(f"internal_type must be a Numeric, but got {internal_type}")
@ -54,7 +99,7 @@ def make_tma_tile_atom_A(
op,
[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp],
"op",
"make_tma_tile_atom_A",
"make_tiled_tma_atom_A",
)
ident = core.make_identity_layout(gmem_tensor.shape, loc=loc, ip=ip)
@ -94,7 +139,7 @@ def make_tma_tile_atom_A(
@dsl_user_op
def make_tma_tile_atom_B(
def make_tiled_tma_atom_B(
op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp],
gmem_tensor: Tensor,
smem_layout: Layout,
@ -106,6 +151,51 @@ def make_tma_tile_atom_B(
loc=None,
ip=None,
) -> Tuple[core.CopyAtom, Tensor]:
"""
Makes a TMA Copy atom mapping to ``.tile`` mode for ``cp.async.bulk.tensor`` PTX operation
accounting for the NK projections of the TiledMMA for B tensor loads.
Given
- a GMEM tensor
- a SMEM layout
- a MMA Tiler
- a TiledMma
- a Cluster-level shape
this function figures out the bulk tensor asynchronous copy instruction to use with the maximum
"TMA vector length" to copy tiles of the GMEM tensor to an SMEM buffer with the provided
layout and consistent with the provided Tiler & tiled_mma (considering the N-mode & K-mode).
The Cluster-level shape is used to determine the multicast factor across the M-mode for B tensor loads.
This function returns two results:
1. the Copy Atom
2. the so-called TMA tensor used to map logical coordinates of the GMEM tensor to coordinates
that the TMA unit can consume. TMA tensors have so-called basis stride elements so that the
associated layout can output coordinates. Otherwise, TMA tensors can be partitioned
similarly to any other CuTe tensors using the algebra.
:param op: The Copy Operation to construct an Atom for
:type op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp]
:param gmem_tensor: The GMEM tensor to be loaded by this copy atom
:type gmem_tensor: Tensor
:param smem_layout: Shared memory layout to load the tensor into (PDSL)
:type smem_layout: Layout
:param mma_tiler_mnk: The MMA Tiler shape (TILE_M, TILE_N, TILE_K) in MNK dimensions
:type mma_tiler_mnk: Shape
:param tiled_mma: The TiledMMA that will consume the load as operands
:type tiled_mma: core.TiledMma
:param cluster_shape_vmnk: The Cluster-level shape in VMNK dimensions
:type cluster_shape_vmnk: Shape
:param internal_type: An optional parameter for the internal data type to use when the element
type does not match the copy type
:type internal_type: Type[Numeric]
:return: A Copy Atom for this Operation and the associated TMA tensor
:rtype: Tuple[core.CopyAtom, Tensor]
"""
if internal_type is not None:
if not isinstance(internal_type, NumericMeta):
raise TypeError(f"internal_type must be a Numeric, but got {internal_type}")
@ -114,7 +204,7 @@ def make_tma_tile_atom_B(
op,
[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp],
"op",
"make_tma_tile_atom_B",
"make_tiled_tma_atom_B",
)
ident = core.make_identity_layout(gmem_tensor.shape, loc=loc, ip=ip)
@ -154,6 +244,6 @@ def make_tma_tile_atom_B(
__all__ = [
"make_tma_tile_atom_A",
"make_tma_tile_atom_B",
"make_tiled_tma_atom_A",
"make_tiled_tma_atom_B",
]

View File

@ -98,7 +98,10 @@ class _LdBase(CopyOp):
repeat: Repetition = Repetition.x1
pack: Pack = Pack.NONE
admissible_archs = ["sm_100a"]
admissible_archs = [
"sm_100a",
"sm_100f",
]
def __post_init__(self) -> None:
# Arch verification
@ -284,7 +287,10 @@ class _StBase(CopyOp):
repeat: Repetition
unpack: Unpack = Unpack.NONE
admissible_archs = ["sm_100a"]
admissible_archs = [
"sm_100a",
"sm_100f",
]
def __post_init__(self) -> None:
# Arch verification

View File

@ -136,7 +136,10 @@ class MmaOp(MmaOp):
a_major_mode: OperandMajorMode
b_major_mode: OperandMajorMode
admissible_archs = ["sm_100a"]
admissible_archs = [
"sm_100a",
"sm_100f",
]
def __post_init__(self) -> None:
# Verify arch

View File

@ -339,10 +339,10 @@ class MmaF8Op(MmaOp):
"expects the 'b_dtype' Op parameter to be one of Float8E5M2 or Float8E4M3FN",
)
# Accumulator data type verification
if self.acc_dtype != Float32:
if self.acc_dtype not in [Float16, Float32]:
raise OpError(
self,
"expects the 'acc_dtype' Op parameter to be Float32",
"expects the 'acc_dtype' Op parameter to be one of Float16 or Float32",
)
# Verify the instruction shape
instruction_k = 32

View File

@ -20,6 +20,7 @@ from typing import Union
from cutlass._mlir import ir
import cutlass._mlir.dialects.cute as _cute_ir
from cutlass.base_dsl.dsl import is_dynamic_expression
from cutlass.cutlass_dsl import TensorFormat, JitArgAdapterRegistry
# Local modules imports
@ -45,7 +46,8 @@ from .typing import (
BFloat16,
Float8E5M2,
)
from .core import find, _Tensor as CoreTensor
from . import core
from .core import _Tensor as CoreTensor
class _Pointer(Pointer):
@ -131,6 +133,9 @@ class _Pointer(Pointer):
def memspace(self):
return self._addr_space
def align(self, min_align: int, *, loc=None, ip=None) -> Pointer:
raise NotImplementedError("align is not supported in runtime")
def verify(self, expected_py_type):
if expected_py_type is Pointer:
return True
@ -361,7 +366,7 @@ class _Tensor(Tensor):
* If nested leading dimensions are found, returns a tuple of indices
* If no leading dimension is found, returns None
"""
return find(1, self.stride, exclude_when=(1, self.shape))
return core.leading_dim(self.shape, self.stride)
def fill(self, value: Numeric):
raise TypeError(f"fill function is not supported in runtime")
@ -479,12 +484,8 @@ class TensorAdapter:
Convert a DLPack protocol supported tensor/array to a cute tensor.
"""
# Need reference these capsules to avoid being garbage collected
tensor_capsules = []
def __init__(self, arg):
self._arg = from_dlpack(arg).mark_layout_dynamic()
self.tensor_capsules.append(self._arg)
def __new_from_mlir_values__(self, values):
return self._arg.__new_from_mlir_values__(values)

View File

@ -9,29 +9,26 @@
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
import random
import numpy as np
import functools
import hashlib
from cutlass.cutlass_dsl import (
const,
T,
CuTeDSL,
BaseDSL,
t,
Constexpr,
detect_gpu_arch,
)
import cutlass._mlir.dialects.cute as _cute_ir
import cutlass._mlir.ir as ir
from cutlass._mlir.dialects import nvvm, cf, vector, builtin
from cutlass.cute import core
from cutlass.cute import nvgpu
from typing import Type
import inspect
import logging
import os
from enum import Enum
from inspect import isclass
from itertools import product
from time import time
from typing import Any, Callable, Dict, List, Optional, Type, Union
import cuda.bindings.driver as cuda_driver
import cuda.bindings.runtime as cuda_runtime
import numpy as np
import cutlass._mlir.ir as ir
import cutlass.base_dsl.jit_executor
import cutlass.cute as cute
from cutlass._mlir.dialects import builtin, cf, nvvm, vector
from cutlass.cute import core, nvgpu
from cutlass.cutlass_dsl import Constexpr, CuTeDSL, T, t
def assert_(cond, msg=None):
@ -248,9 +245,10 @@ def sample_pytest(rand_cfg=None):
import functools
import os
import random
import pytest
import sys
import pytest
seed, sample_ratio = rand_cfg
random.seed(seed)
@ -270,3 +268,311 @@ def sample_pytest(rand_cfg=None):
return wrapper
return decorator
#########################################
# Benchmarking utilities
#########################################
class JitArguments:
"""
A type to hold both args and kwargs for passing to a kernel while benchmarking.
"""
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
def _cuda_success(
err: Union[tuple, cuda_runtime.cudaError_t, cuda_driver.CUresult], message: str
):
"""
Helper function to check CUDA API errors.
"""
if isinstance(err, tuple):
_cuda_success(err[0], message)
elif isinstance(err, cuda_runtime.cudaError_t):
error_message = cuda_runtime.cudaGetErrorString(err)[1].decode("utf-8")
if err != cuda_runtime.cudaError_t.cudaSuccess:
raise RuntimeError(f"{message} : {error_message}")
elif isinstance(err, cuda_driver.CUresult):
if err != cuda_driver.CUresult.CUDA_SUCCESS:
error_message = cuda_driver.cuGetErrorString(err)[1].decode("utf-8")
raise RuntimeError(f"{message} : {error_message}")
else:
raise TypeError(
f"{err} is an unexpected type : it should be a cudaError_t or CUresult"
)
def _does_kernel_use_stream(
kernel: Callable, stream: cuda_driver.CUstream, *args, **kwargs
):
"""
This function checks if the kernel uses the provided non-default stream.
It does this by capturing the stream and then checking if any kernels were launched.
:param kernel: The kernel to check
:type kernel: Callable
:param stream: The stream to check
:type stream: cuda_driver.CUstream
:return: True if the kernel uses the stream, False otherwise
:rtype: bool
"""
assert int(stream) != int(
cuda_driver.CUstream_flags.CU_STREAM_DEFAULT
), "Stream must be a non-default stream"
err = cuda_runtime.cudaStreamBeginCapture(
stream, cuda_runtime.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal
)
_cuda_success(err, "Error on stream capture")
kernel(*args, **kwargs)
err, graph = cuda_runtime.cudaStreamEndCapture(stream)
_cuda_success(err, "Error on stream capture")
# Get the number of nodes in the captured graph to check whether any kernels were launched
err, _, num_nodes = cuda_runtime.cudaGraphGetNodes(graph)
_cuda_success(err, "Error on querying graph")
return num_nodes > 0
def benchmark(
callable: Callable,
*,
warmup_iterations: int = 10,
profiling_iterations: int = 100,
stream: Optional[cuda_driver.CUstream] = None,
kernel_arguments: Optional[JitArguments] = None,
workspace_generator: Optional[Callable[[], JitArguments]] = None,
workspace_count: int = 1,
use_cuda_graphs: bool = False,
) -> float:
"""Benchmarks a callable function with the specified parameters.
For example,
.. code-block:: python
from cutlass.cute.testing import benchmark
@cute.jit
def user_function(a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda_driver.CUstream):
# contents of the function
pass
time_us = benchmark(user_function, kernel_arguments=JitArguments(a, b, c, stream),
warmup_iterations=10, profiling_iterations=100,
stream=stream)
To prevent skewing results by repeatedly accessing the L2 cache, use the workspace_count and workspace_generator
parameters to cycle through a number of different workspaces.
.. code-block:: python
from cutlass.cute.testing import benchmark
@cute.jit
def user_function(a: cute.Tensor, b: cute.Tensor, c: cute.Tensor):
# contents of the function
pass
def workspace_generator():
# create a, b, and c
return JitArguments(a, b, c)
time_us = benchmark(user_function,
workspace_generator=workspace_generator,
workspace_count=10,
warmup_iterations=10000,
profiling_iterations=1000)
When benchmarking, you can always configure the function being profiled (callable), the number of warmup
iterations, and the number of profiling iterations.
Whenever the kernel being benchmarked runs in a non-default stream, the stream must be provided through the stream parameter.
To use CUDA graphs, the callable must be a compiled @cute.jit annotated function.
When using CUDA graphs, the kernel must be launched in a non-default stream.
:param callable: The function to benchmark
:type callable: Callable
:param warmup_iterations: Number of warmup iterations, defaults to 10
:type warmup_iterations: int, optional
:param profiling_iterations: Number of benchmark iterations, defaults to 100
:type profiling_iterations: int, optional
:param stream: Stream kernel is launched in, defaults to CUDA stream default
:type stream: CUstream, None
:param kernel_arguments: Kernel arguments to launch callable with, defaults to None
:type kernel_arguments: JitArguments, None
:param workspace_generator: Function that returns kernel arguments, defaults to None
:type workspace_generator: Callable
:param workspace_count: Number of workspaces (arguments) to loop through; cycling through enough workspaces keeps the L2 cache cold
:type workspace_count: int, optional
:param use_cuda_graphs: Whether to use cuda graphs, defaults to False
:type use_cuda_graphs: bool, optional
:return: The benchmark time in microseconds
:rtype: float
"""
if stream is None:
stream = cuda_driver.CUstream(cuda_driver.CUstream_flags.CU_STREAM_DEFAULT)
if workspace_count < 1:
raise ValueError("workspace_count must be at least 1")
time_us = float("nan")
if workspace_generator == None:
# If no workspace generator is provided, we need a single workspace
if workspace_count != 1:
raise ValueError("Need a single workspace if not providing a generator")
# If no workspace generator is provided, we need a kernel_argument
if kernel_arguments == None:
raise ValueError(
"Please pass a kernel argument if not providing a generator"
)
workspace_generator = lambda: kernel_arguments
workspaces = [workspace_generator() for _ in range(workspace_count)]
for workspace in workspaces:
if type(workspace) != JitArguments:
raise TypeError(
"workspace_generator and/or kernel_arguments should use JitArguments type"
)
def _loop_and_call_kernel(iterations: int, workspace_index: int = 0):
for _ in range(iterations):
current_workspace = workspaces[workspace_index]
callable(*current_workspace.args, **current_workspace.kwargs)
workspace_index = (workspace_index + 1) % workspace_count
return workspace_index
# Create CUDA events for timing
err, start_event = cuda_driver.cuEventCreate(
cuda_driver.CUevent_flags.CU_EVENT_DEFAULT
)
_cuda_success(err, "Error on creating event")
err, end_event = cuda_driver.cuEventCreate(
cuda_driver.CUevent_flags.CU_EVENT_DEFAULT
)
_cuda_success(err, "Error on creating event")
elapsed_time = float("nan")
if use_cuda_graphs:
# Check if the callable is a JitExecutor
if not isinstance(callable, cutlass.base_dsl.jit_executor.JitExecutor):
raise TypeError("Function must be precompiled to be used with CUDA Graphs")
# Check if the stream is a non-default stream
if int(stream) == int(cuda_driver.CUstream_flags.CU_STREAM_DEFAULT):
raise ValueError(
"Measuring with CUDA Graphs requires executing in a non-default stream"
)
workspace_index = 0
# Capture warmup graph
err = cuda_runtime.cudaStreamBeginCapture(
stream, cuda_runtime.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal
)
_cuda_success(err, "Error on stream capture")
workspace_index = _loop_and_call_kernel(warmup_iterations)
err, gwarm = cuda_runtime.cudaStreamEndCapture(stream)
_cuda_success(err, "Error on stream capture")
# Get number of nodes in warmup graph to check it matches what is expected
err, _, num_nodes = cuda_runtime.cudaGraphGetNodes(gwarm)
_cuda_success(err, "Error on querying graph")
# Assertion is >= since we may launch multiple kernels in one host function
if num_nodes < warmup_iterations:
raise ValueError(
f"CUDA stream passed to benchmark does not match the stream the kernel was launched in"
)
# Capture profiling graph
err = cuda_runtime.cudaStreamBeginCapture(
stream, cuda_runtime.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal
)
_cuda_success(err, "Error on stream capture")
_loop_and_call_kernel(profiling_iterations, workspace_index)
err, gprofile = cuda_runtime.cudaStreamEndCapture(stream)
_cuda_success(err, "Error on stream capture")
# Instantiate graphs
err, gwarm = cuda_runtime.cudaGraphInstantiate(gwarm, 0)
_cuda_success(err, "Error on graph instantiation")
err, gprofile = cuda_runtime.cudaGraphInstantiate(gprofile, 0)
_cuda_success(err, "Error on graph instantiation")
# Launch warmup graph
err = cuda_runtime.cudaGraphLaunch(gwarm, stream)
_cuda_success(err, "Error on graph launch")
# Record start time
err = cuda_driver.cuEventRecord(start_event, stream)
_cuda_success(err, "Error on recording event")
# Launch profiling graph
err = cuda_runtime.cudaGraphLaunch(gprofile, stream)
_cuda_success(err, "Error on graph launch")
# Record end time
err = cuda_driver.cuEventRecord(end_event, stream)
_cuda_success(err, "Error on recording event")
err = cuda_driver.cuEventSynchronize(end_event)
_cuda_success(err, "Error on synchronizing event")
# Get elapsed time
err, elapsed_time = cuda_driver.cuEventElapsedTime(start_event, end_event)
_cuda_success(err, "Error on querying event")
# Destroy graphs
err = cuda_runtime.cudaGraphExecDestroy(gwarm)
_cuda_success(err, "Error on destroying graph")
err = cuda_runtime.cudaGraphExecDestroy(gprofile)
_cuda_success(err, "Error on destroying graph")
else:
if int(stream) != int(
cuda_driver.CUstream_flags.CU_STREAM_DEFAULT
) and not _does_kernel_use_stream(
callable, stream, *workspaces[0].args, **workspaces[0].kwargs
):
raise ValueError(
"CUDA stream passed to benchmark does not match the stream the kernel was launched in"
)
# Not using graphs
# Warmup
workspace_index = _loop_and_call_kernel(warmup_iterations)
# Record start event
err = cuda_driver.cuEventRecord(start_event, stream)
_cuda_success(err, "Error on recording event")
_loop_and_call_kernel(profiling_iterations, workspace_index)
# Record end event
err = cuda_driver.cuEventRecord(end_event, stream)
_cuda_success(err, "Error on recording event")
# Synchronize end event
err = cuda_driver.cuEventSynchronize(end_event)
_cuda_success(err, "Error on synchronizing event")
err, elapsed_time = cuda_driver.cuEventElapsedTime(start_event, end_event)
_cuda_success(err, "Error on querying event")
# Destroy events
err = cuda_driver.cuEventDestroy(start_event)
_cuda_success(err, "Error on destroying event")
err = cuda_driver.cuEventDestroy(end_event)
_cuda_success(err, "Error on destroying event")
return elapsed_time / profiling_iterations * 1e3
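A minimal usage sketch of the benchmarking helper above (editorial, not part of the diff). The helper's public name is not visible in this hunk, so ``benchmark`` below is a placeholder; ``compiled_kernel`` is assumed to be a precompiled JitExecutor and ``JitArguments`` is assumed to wrap its positional launch arguments. The keyword names mirror those used in the helper's body.
# Hedged sketch: placeholder names per the note above.
args = JitArguments(a, b, c)          # assumption: positional wrapping of launch args
avg_us = benchmark(                   # placeholder name for the helper defined above
    compiled_kernel,
    kernel_arguments=args,
    warmup_iterations=10,
    profiling_iterations=100,
)
print(f"average kernel time: {avg_us:.2f} us")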

View File

@ -68,6 +68,8 @@ class Pointer(ABC):
@property
def dtype(self) -> Type[Numeric]: ...
def align(self, min_align: int) -> "Pointer": ...
def __get_mlir_types__(self) -> List[ir.Type]: ...
def __extract_mlir_values__(self) -> List[ir.Value]: ...

View File

@ -0,0 +1,62 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from .helpers import (
Agent,
CooperativeGroup,
PipelineOp,
SyncObject,
MbarrierArray,
NamedBarrier,
TmaStoreFence,
PipelineUserType,
PipelineState,
make_pipeline_state,
pipeline_init_wait,
arrive,
arrive_unaligned,
wait,
wait_unaligned,
arrive_and_wait,
sync,
)
from .sm90 import (
PipelineAsync,
PipelineTmaAsync,
PipelineTmaMultiConsumersAsync,
PipelineTmaStore,
)
from .sm100 import (
PipelineTmaUmma,
PipelineAsyncUmma,
PipelineUmmaAsync,
)
__all__ = [
"Agent",
"CooperativeGroup",
"PipelineOp",
"SyncObject",
"MbarrierArray",
"NamedBarrier",
"TmaStoreFence",
"PipelineUserType",
"PipelineState",
"PipelineAsync",
"PipelineTmaAsync",
"PipelineTmaUmma",
"PipelineTmaMultiConsumersAsync",
"PipelineAsyncUmma",
"PipelineUmmaAsync",
"PipelineTmaStore",
]
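For orientation, a hedged sketch of how these exports fit together when describing a pipeline's two agents and seeding their iteration state; the sizes below are illustrative, and the state helpers are normally used inside JIT-compiled device code.
from cutlass.pipeline import (
    Agent,
    CooperativeGroup,
    PipelineUserType,
    make_pipeline_state,
)

# One producing warp and one consuming warpgroup (illustrative sizes).
producer_group = CooperativeGroup(Agent.Thread, size=32, alignment=32)
consumer_group = CooperativeGroup(Agent.Thread, size=128, alignment=128)

# Producers start on phase 1 (buffers empty), consumers on phase 0.
producer_state = make_pipeline_state(PipelineUserType.Producer, 4)
consumer_state = make_pipeline_state(PipelineUserType.Consumer, 4)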

View File

@ -0,0 +1,645 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
import enum
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Union
import warnings
import cutlass.cute as cute
from cutlass.cutlass_dsl import Boolean, Int32, Int64, if_generate
from cutlass._mlir.dialects import llvm
import cutlass._mlir.dialects.cute as _cute_ir
##############################################################################
# Agent class
##############################################################################
class Agent(enum.Enum):
"""
Agent indicates what is participating in the pipeline synchronization.
"""
# Arbitrary grouping of N threads
Thread = enum.auto()
# Same as Thread, but includes all threads in the block
ThreadBlock = enum.auto()
# Same as Thread, but includes all threads in the cluster
ThreadBlockCluster = enum.auto()
class CooperativeGroup:
"""
CooperativeGroup contains size and alignment restrictions for an Agent.
"""
def __init__(self, agent: Agent, size: int = 1, alignment: int = 1):
if agent is Agent.Thread:
assert size > 0
if size == 32:
assert (
size == alignment
), "Error: Alignment does not match number of threads in a warp."
elif size == 128:
assert (
size == alignment
), "Error: Alignment does not match number of threads in a warpgroup."
elif agent is Agent.ThreadBlock:
raise NotImplementedError("Error: Not yet supported.")
elif agent is Agent.ThreadBlockCluster:
raise NotImplementedError("Error: Not yet supported.")
else:
# Should never reach this state
size = 0
if size <= 0:
raise ValueError(
"Error: The number of threads in a CooperativeGroup must be more than 0."
)
# Size indicates how many threads are participating in this CooperativeGroup
self.size = size
# Agent indicates the type of thread group
self.agent = agent
class PipelineOp(enum.Enum):
"""
PipelineOp assigns an operation to an agent corresponding to a specific hardware feature.
"""
# async-threads
AsyncThread = enum.auto()
# Blackwell (SM100a) MMA instruction
TCGen05Mma = enum.auto()
# Tensor Memory Accelerator load
TmaLoad = enum.auto()
# TMA Store consuming smem produced by AsyncThread
TmaStore = enum.auto()
# Composite of multiple PipelineOps
Composite = enum.auto()
def _get_pipeline_op(type_str):
return PipelineOp(type_str)
##############################################################################
# SyncObject class
##############################################################################
class SyncObject(ABC):
"""Abstract base class for hardware synchronization primitives.
This class defines the interface for different types of hardware synchronization
mechanisms including shared memory barriers, named barriers, and fences.
"""
@abstractmethod
def arrive(self) -> None:
pass
@abstractmethod
def wait(self) -> None:
pass
@abstractmethod
def arrive_and_wait(self) -> None:
pass
@abstractmethod
def arrive_and_drop(self) -> None:
pass
@abstractmethod
def get_barrier(self) -> Union[cute.Pointer, int, None]:
pass
@abstractmethod
def max(self) -> Union[int, None]:
pass
class MbarrierArray(SyncObject):
"""
MbarrierArray implements an abstraction for an array of smem barriers.
"""
def __init__(
self,
barrier_storage: cute.Pointer,
num_stages: int,
agent: tuple[PipelineOp, CooperativeGroup],
tx_count: int = 0,
) -> None:
self.barrier_storage = barrier_storage
self.tx_count = tx_count
self.num_stages = num_stages
self.op_type, self.cg = agent
self.arrive_count = self.cg.size
if self.num_stages <= 0:
raise ValueError("Error: Mbarrier stage count must be greater than 0.")
if self.arrive_count <= 0:
raise ValueError("Error: Mbarrier arrive count must be greater than 0.")
if self.op_type is PipelineOp.TmaLoad and self.tx_count < 0:
raise ValueError(
"Error: Mbarrier tx count must not be less than 0 for TMA ops."
)
# Store mbarrier base pointer
self.mbarrier_base = self.barrier_storage
# Mbarrier initialization in constructor
self.mbarrier_init()
def recast_to_new_op_type(self, new_op_type: PipelineOp) -> "MbarrierArray":
"""
Creates a copy of MbarrierArray with a different op_type without re-initializing barriers
"""
# Create new instance without initialization
new_mbarrier_array = object.__new__(MbarrierArray)
# Copy all attributes directly
new_mbarrier_array.barrier_storage = self.barrier_storage
new_mbarrier_array.op_type = new_op_type
new_mbarrier_array.cg = self.cg
new_mbarrier_array.num_stages = self.num_stages
new_mbarrier_array.tx_count = self.tx_count
new_mbarrier_array.arrive_count = self.arrive_count
new_mbarrier_array.mbarrier_base = self.mbarrier_base
return new_mbarrier_array
# Mbarrier initialization
def mbarrier_init(self) -> None:
"""
Initializes an array of mbarriers using warp 0.
"""
def then_body():
for index in range(self.num_stages):
cute.arch.mbarrier_init(self.get_barrier(index), self.arrive_count)
warp_idx = cute.arch.warp_idx()
warp_idx = cute.arch.make_warp_uniform(warp_idx)
if_generate(warp_idx == 0, then_body)
def arrive(
self,
index: int,
dst: int,
cta_group: Optional[cute.nvgpu.tcgen05.CtaGroup] = None,
) -> None:
"""Select the arrive corresponding to this MbarrierArray's PipelineOp.
:param index: Index of the mbarrier in the array to arrive on
:type index: int
:param dst: Destination parameter for selective arrival, which can be either a mask or destination cta rank.
When None, both ``TCGen05Mma`` and ``AsyncThread`` will arrive on their local mbarrier.
- For ``TCGen05Mma``, ``dst`` serves as a multicast mask (e.g., 0b1011 allows arrive signal to be multicast to CTAs
in the cluster with rank = 0, 1, and 3).
- For ``AsyncThread``, ``dst`` serves as a destination cta rank (e.g., 3 means threads will arrive on
the mbarrier with rank = 3 in the cluster).
:type dst: int | None
:param cta_group: CTA group for ``TCGen05Mma``, defaults to None for other op types
:type cta_group: ``cute.nvgpu.tcgen05.CtaGroup``, optional
"""
if self.op_type is PipelineOp.AsyncThread:
self.arrive_mbarrier(index, dst)
elif self.op_type is PipelineOp.TCGen05Mma:
assert (
cta_group is not None
), "Error: CTA group must be provided for TCGen05Mma."
self.arrive_tcgen05mma(index, dst, cta_group)
elif self.op_type in [PipelineOp.TmaLoad]:
self.arrive_and_expect_tx(index, self.tx_count)
else:
assert (
False
), f"Error: MbarrierArray is not supported for PipelineOp: {_get_pipeline_op(self.op_type)}."
def arrive_mbarrier(self, index: int, dst_rank: Optional[int] = None) -> None:
if dst_rank is None:
cute.arch.mbarrier_arrive(self.get_barrier(index))
else:
cute.arch.mbarrier_arrive(self.get_barrier(index), dst_rank)
def arrive_tcgen05mma(
self, index: int, mask: Optional[int], cta_group: cute.nvgpu.tcgen05.CtaGroup
) -> None:
if mask is None:
with cute.arch.elect_one():
cute.nvgpu.tcgen05.commit(self.get_barrier(index))
else:
with cute.arch.elect_one():
cute.nvgpu.tcgen05.commit(self.get_barrier(index), mask, cta_group)
def arrive_and_expect_tx(self, index: int, tx_count: int) -> None:
with cute.arch.elect_one():
cute.arch.mbarrier_arrive_and_expect_tx(self.get_barrier(index), tx_count)
def try_wait(self, index: int, phase: int) -> Boolean:
return cute.arch.mbarrier_try_wait(self.get_barrier(index), phase)
def wait(self, index: int, phase: int) -> None:
cute.arch.mbarrier_wait(self.get_barrier(index), phase)
def arrive_and_wait(
self,
index: int,
phase: int,
dst: int,
cta_group: Optional[cute.nvgpu.tcgen05.CtaGroup] = None,
) -> None:
self.arrive(index, dst, cta_group)
self.wait(index, phase)
def arrive_and_drop(self) -> None:
raise NotImplementedError("Error: Not yet supported.")
def get_barrier(self, index: int) -> cute.Pointer:
return self.mbarrier_base + index
def max(self) -> int:
# Transaction barriers have a maximum arrive count of 511 (2^9 - 1).
# Non-transaction barriers have a maximum arrive count of 1,048,575 (2^20 - 1).
return 511
def __extract_mlir_values__(self):
return [self.barrier_storage]
def __new_from_mlir_values__(self, values):
return MbarrierArray(
values[0], self.num_stages, (self.op_type, self.cg), self.tx_count
)
@dataclass(frozen=True)
class NamedBarrier(SyncObject):
"""
NamedBarrier is an abstraction for named barriers managed by hardware.
There are 16 named barriers available, with barrier_ids 0-15.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-bar>`__.
"""
barrier_id: int
num_threads: int
def __post_init__(self) -> None:
if self.barrier_id < 0 or self.barrier_id >= 16:
raise ValueError("Error: NamedBarrier ID must be between 0 and 16.")
if self.barrier_id == 0:
warnings.warn(
"NamedBarrier ID 0 is by other driver APIs (i.e. sync_threads()) and should not be used."
)
def arrive(self) -> None:
"""
The aligned flavor of arrive is used when all threads in the CTA will execute the
same instruction. See PTX documentation.
"""
cute.arch.barrier_arrive(
barrier_id=self.barrier_id, number_of_threads=self.num_threads
)
def arrive_unaligned(self) -> None:
"""
The unaligned flavor of arrive can be used with an arbitrary number of threads in the CTA.
"""
llvm.inline_asm(
None,
[Int32(self.barrier_id).ir_value(), Int32(self.num_threads).ir_value()],
"barrier.arrive $0, $1;",
"r,r",
has_side_effects=True,
is_align_stack=False,
asm_dialect=llvm.AsmDialect.AD_ATT,
)
def wait(self) -> None:
"""
NamedBarriers do not have a standalone wait like mbarriers, only an arrive_and_wait.
If synchronizing two warps in a producer/consumer pairing, the arrive count would be
32 using mbarriers but 64 using NamedBarriers. Only threads from either the producer
or consumer are counted for mbarriers, while all threads participating in the sync
are counted for NamedBarriers.
"""
warnings.warn(
"NamedBarrier wait also arrives on the barrier. Routing call to NamedBarrier.arrive_and_wait()."
)
self.arrive_and_wait()
def wait_unaligned(self) -> None:
warnings.warn(
"NamedBarrier wait also arrives on the barrier. Routing call to NamedBarrier.arrive_and_wait()."
)
llvm.inline_asm(
None,
[Int32(self.barrier_id).ir_value(), Int32(self.num_threads).ir_value()],
"barrier.sync $0, $1;",
"r,r",
has_side_effects=True,
is_align_stack=False,
asm_dialect=llvm.AsmDialect.AD_ATT,
)
def arrive_and_wait(self) -> None:
cute.arch.barrier(
barrier_id=self.barrier_id, number_of_threads=self.num_threads
)
def arrive_and_drop(self) -> None:
raise NotImplementedError("Error: Not supported.")
def sync(self) -> None:
cute.arch.barrier(barrier_id=self.barrier_id)
def get_barrier(self) -> int:
return self.barrier_id
def max(self) -> int:
# Named barriers have a maximum arrive count of 4095 (2^12 - 1).
return 4095
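A small sketch of the arrive-count rule spelled out in ``wait`` above: when a producer warp and a consumer warp rendezvous through a named barrier, all 64 participating threads count toward the arrive count, whereas an mbarrier would count only one side's 32.
# Editorial sketch: both warps call arrive_and_wait() on the same named barrier.
bar = NamedBarrier(barrier_id=1, num_threads=64)

# in the producer warp, after the smem tile is written
bar.arrive_and_wait()

# in the consumer warp, before the smem tile is read
bar.arrive_and_wait()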
class TmaStoreFence(SyncObject):
"""
TmaStoreFence is used for a multi-stage epilogue buffer.
"""
def __init__(self, num_stages: int = 0) -> None:
if num_stages <= 0:
raise ValueError("Mbarrier stage count must be greater than 0.")
self.num_stages = num_stages
def arrive(self) -> None:
cute.arch.cp_async_bulk_commit_group()
def wait(self) -> None:
cute.arch.cp_async_bulk_wait_group(self.num_stages - 1, read=True)
def arrive_and_wait(self) -> None:
self.arrive()
self.wait()
def arrive_and_drop(self) -> None:
raise NotImplementedError("Error: Not supported.")
# TmaStoreFence doesn't have mbarriers
def get_barrier(self) -> None:
assert (
False
), "Error: TmaStoreFence doesn't use mbarriers and cannot return a barrier."
def max(self) -> None:
raise NotImplementedError("Error: Not supported.")
def tail(self) -> None:
cute.arch.cp_async_bulk_wait_group(0, read=True)
##############################################################################
# PipelineState class
##############################################################################
class PipelineUserType(enum.Enum):
Producer = enum.auto()
Consumer = enum.auto()
class PipelineState:
"""
Pipeline state contains an index and phase bit corresponding to the current position in the circular buffer.
"""
def __init__(self, stages: int, count, index, phase):
self._stages = stages
self._count = count
self._index = index
self._phase = phase
def clone(self) -> "PipelineState":
return PipelineState(self.stages, self._count, self.index, self.phase)
@property
def index(self) -> Int32:
return self._index
@property
def count(self) -> Int32:
return self._count
@property
def stages(self) -> int:
return self._stages
@property
def phase(self) -> Int32:
return self._phase
def reset_count(self):
self._count = Int32(0)
def advance(self):
self._index += 1
self._count += 1
def then_body(index, phase):
new_index = Int32(0)
new_phase = phase ^ 1
return new_index, new_phase
def else_body(index, phase):
return index, phase
self._index, self._phase = if_generate(
self._index == self.stages,
then_body,
else_body,
[self.index, self.phase],
[Int32, Int32],
)
def reverse(self):
self._index -= 1
self._count -= 1
def then_body(index, phase):
new_index = Int32(self.stages - 1)
new_phase = phase ^ 1
return new_index, new_phase
def else_body(index, phase):
return index, phase
self._index, self._phase = if_generate(
self._index == -1,
then_body,
else_body,
[self.index, self.phase],
[Int32, Int32],
)
def __get_mlir_types__(self):
return [self._count.type, self._index.type, self._phase.type]
def __extract_mlir_values__(self):
count = self._count
index = self._index
phase = self._phase
return [count.ir_value(), index.ir_value(), phase.ir_value()]
# This can be overridden by derived classes
def __new_from_mlir_values__(self, values):
return PipelineState(
self.stages, Int32(values[0]), Int32(values[1]), Int32(values[2])
)
def make_pipeline_state(type: PipelineUserType, stages: int):
"""
Creates a pipeline state. Producers are assumed to start with an empty buffer and have a flipped phase bit of 1.
"""
if type is PipelineUserType.Producer:
return PipelineState(
stages,
Int32(0),
Int32(0),
Int32(1),
)
elif type is PipelineUserType.Consumer:
return PipelineState(
stages,
Int32(0),
Int32(0),
Int32(0),
)
else:
assert (
False
), "Error: invalid PipelineUserType specified for make_pipeline_state."
##############################################################################
# Helper functions
##############################################################################
def pipeline_init_wait(cta_layout_vmnk: Optional[cute.Layout] = None):
"""
Fences the mbarrier init and syncs the threadblock or cluster
"""
cute.arch.mbarrier_init_fence()
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
# If not using clusters, sync the threadblock
_sync(Agent.ThreadBlock)
else:
# If using clusters, sync the cluster
_sync(Agent.ThreadBlockCluster)
def _sync(group: Agent):
"""
Syncs all threads within an agent.
"""
if group is Agent.Thread:
raise NotImplementedError("Error: Not supported.")
elif group is Agent.ThreadBlock:
cute.arch.sync_threads()
elif group is Agent.ThreadBlockCluster:
cute.arch.cluster_arrive()
cute.arch.cluster_wait()
else:
assert (
False
), "Error: No explicit sync instruction exists. Please use barriers (named / mbarrier) instead."
def _mbarrier_i64_to_ptr(val: Int64) -> cute.Pointer:
"""
Converts a smem pointer of type Int64 to cute.Pointer with 8B alignment
"""
return cute.make_ptr(
Int64,
val.ir_value(),
mem_space=_cute_ir.AddressSpace.smem,
assumed_align=8,
)
# NamedBarrier free functions
def arrive(barrier_id: int, num_threads: int):
"""
The aligned flavor of arrive is used when all threads in the CTA will execute the
same instruction. See PTX documentation.
"""
cute.arch.barrier_arrive(barrier_id=barrier_id, number_of_threads=num_threads)
def arrive_unaligned(barrier_id: int, num_threads: int):
"""
The unaligned flavor of arrive can be used with an arbitrary number of threads in the CTA.
"""
llvm.inline_asm(
None,
[Int32(barrier_id).ir_value(), Int32(num_threads).ir_value()],
"barrier.arrive $0, $1;",
"r,r",
has_side_effects=True,
is_align_stack=False,
asm_dialect=llvm.AsmDialect.AD_ATT,
)
def wait(barrier_id: int, num_threads: int):
"""
NamedBarriers do not have a standalone wait like mbarriers, only an arrive_and_wait.
If synchronizing two warps in a producer/consumer pairing, the arrive count would be
32 using mbarriers but 64 using NamedBarriers. Only threads from either the producer
or consumer are counted for mbarriers, while all threads participating in the sync
are counted for NamedBarriers.
"""
warnings.warn(
"NamedBarrier wait also arrives on the barrier. Routing call to NamedBarrier.arrive_and_wait()."
)
arrive_and_wait(barrier_id, num_threads)
def wait_unaligned(barrier_id: int, num_threads: int):
warnings.warn(
"NamedBarrier wait also arrives on the barrier. Routing call to NamedBarrier.arrive_and_wait()."
)
llvm.inline_asm(
None,
[Int32(barrier_id).ir_value(), Int32(num_threads).ir_value()],
"barrier.sync $0, $1;",
"r,r",
has_side_effects=True,
is_align_stack=False,
asm_dialect=llvm.AsmDialect.AD_ATT,
)
def arrive_and_wait(barrier_id: int, num_threads: int):
cute.arch.barrier(barrier_id=barrier_id, number_of_threads=num_threads)
def sync(barrier_id: int = 0):
cute.arch.barrier(barrier_id=barrier_id)

View File

@ -0,0 +1,452 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
import enum
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Union
import warnings
import cutlass.cute as cute
from cutlass.cutlass_dsl import Boolean, if_generate
from cutlass.pipeline import (
CooperativeGroup,
PipelineOp,
PipelineState,
pipeline_init_wait,
PipelineAsync,
)
##############################################################################
# Pipeline classes
##############################################################################
@dataclass(frozen=True)
class PipelineTmaUmma(PipelineAsync):
"""
PipelineTmaUmma is used for TMA producers and UMMA consumers (e.g. Blackwell mainloops).
"""
is_leader_cta: bool
cta_group: cute.nvgpu.tcgen05.CtaGroup
@staticmethod
def _compute_mcast_arrival_mask(cta_layout_vmnk: cute.Layout):
"""
Computes a mask for signaling arrivals to multicasting threadblocks.
"""
cta_rank_in_cluster = cute.arch.make_warp_uniform(
cute.arch.block_idx_in_cluster()
)
cta_in_cluster_coord_vmnk = cta_layout_vmnk.get_flat_coord(cta_rank_in_cluster)
tma_mcast_mask_a = cute.nvgpu.cpasync.create_tma_multicast_mask(
cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=2
)
tma_mcast_mask_b = cute.nvgpu.cpasync.create_tma_multicast_mask(
cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=1
)
block_in_cluster_coord_vmnk_peer = (
cta_in_cluster_coord_vmnk[0] ^ 1,
*cta_in_cluster_coord_vmnk[1:],
)
tma_mcast_mask_a_peer = cute.nvgpu.cpasync.create_tma_multicast_mask(
cta_layout_vmnk, block_in_cluster_coord_vmnk_peer, mcast_mode=2
)
tma_mcast_mask_b_peer = cute.nvgpu.cpasync.create_tma_multicast_mask(
cta_layout_vmnk, block_in_cluster_coord_vmnk_peer, mcast_mode=1
)
return (
tma_mcast_mask_a
| tma_mcast_mask_b
| tma_mcast_mask_a_peer
| tma_mcast_mask_b_peer
)
@staticmethod
def _compute_is_leader_cta(cta_layout_vmnk: cute.Layout):
"""
Computes leader threadblocks for 2CTA kernels. For 1CTA, all threadblocks are leaders.
"""
bidx, bidy, _ = cute.arch.block_idx()
mma_coord_vmnk = (
bidx % cute.size(cta_layout_vmnk, mode=[0]),
bidx // cute.size(cta_layout_vmnk, mode=[0]),
bidy,
None,
)
return mma_coord_vmnk[0] == 0
@staticmethod
def create(
*,
num_stages: int,
producer_group: CooperativeGroup,
consumer_group: CooperativeGroup,
tx_count: int,
barrier_storage: cute.Pointer = None,
cta_layout_vmnk: Optional[cute.Layout] = None,
):
"""
This helper function computes any necessary attributes and returns an instance of PipelineTmaUmma.
:param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
:type barrier_storage: cute.Pointer
:param num_stages: Number of buffer stages for this pipeline
:type num_stages: Int32
:param producer_group: CooperativeGroup for the producer agent
:type producer_group: CooperativeGroup
:param consumer_group: CooperativeGroup for the consumer agent
:type consumer_group: CooperativeGroup
:param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
:type tx_count: int
:param cta_layout_vmnk: Layout of the cluster shape
:type cta_layout_vmnk: cute.Layout | None
"""
if not isinstance(barrier_storage, cute.Pointer):
raise ValueError(
f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
)
producer_type = PipelineOp.TmaLoad
consumer_type = PipelineOp.TCGen05Mma
producer = (producer_type, producer_group)
consumer = (consumer_type, consumer_group)
sync_object_full = PipelineAsync._make_sync_object(
barrier_storage.align(min_align=8), num_stages, producer, tx_count
)
sync_object_empty = PipelineAsync._make_sync_object(
barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
)
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
# No mcast mask if not using clusters
producer_mask = None
# All threadblocks are leaders if not using clusters
is_leader_cta = True
else:
producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(cta_layout_vmnk)
is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk)
cta_group = (
cute.nvgpu.tcgen05.CtaGroup.ONE
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1
else cute.nvgpu.tcgen05.CtaGroup.TWO
)
consumer_mask = producer_mask
pipeline_init_wait(cta_layout_vmnk)
return PipelineTmaUmma(
sync_object_full,
sync_object_empty,
num_stages,
producer_mask,
consumer_mask,
is_leader_cta,
cta_group,
)
def consumer_release(self, state: PipelineState):
"""
UMMA consumer releases the buffer as empty; cta_group must be provided.
"""
self.sync_object_empty.arrive(state.index, self.consumer_mask, self.cta_group)
def producer_acquire(
self, state: PipelineState, try_acquire_token: Optional[Boolean] = None
):
"""
TMA producer acquire conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks.
"""
if_generate(
try_acquire_token is None or try_acquire_token == 0,
lambda: self.sync_object_empty.wait(state.index, state.phase),
)
if_generate(
self.is_leader_cta,
lambda: self.sync_object_full.arrive(state.index, self.producer_mask),
)
def producer_commit(self, state: PipelineState):
"""
TMA producer commit is a noop since TMA instruction itself updates the transaction count.
"""
pass
@dataclass(frozen=True)
class PipelineAsyncUmma(PipelineAsync):
"""
PipelineAsyncUmma is used for AsyncThread producers and UMMA consumers (e.g. Blackwell input fusion pipelines).
"""
cta_group: cute.nvgpu.tcgen05.CtaGroup
@staticmethod
def _compute_leading_cta_rank(cta_v_size):
"""
Computes the leading CTA rank.
"""
cta_rank_in_cluster = cute.arch.make_warp_uniform(
cute.arch.block_idx_in_cluster()
)
return cta_rank_in_cluster // cta_v_size * cta_v_size
@staticmethod
def _compute_is_leader_cta(cta_layout_vmnk: cute.Layout):
"""
Computes leader threadblocks for 2CTA kernels. For 1CTA, all threadblocks are leaders.
"""
bidx, bidy, _ = cute.arch.block_idx()
mma_coord_vmnk = (
bidx % cute.size(cta_layout_vmnk, mode=[0]),
bidx // cute.size(cta_layout_vmnk, mode=[0]),
bidy,
None,
)
return mma_coord_vmnk[0] == 0
@staticmethod
def _compute_peer_cta_mask(cta_layout_vmnk: cute.Layout):
"""
Computes a mask for signaling arrivals to multicasting threadblocks.
"""
cta_rank_in_cluster = cute.arch.make_warp_uniform(
cute.arch.block_idx_in_cluster()
)
cta_in_cluster_coord_vmnk = cta_layout_vmnk.get_flat_coord(cta_rank_in_cluster)
mask_self = cute.nvgpu.cpasync.create_tma_multicast_mask(
cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=0
)
block_in_cluster_coord_vmnk_peer = (
cta_in_cluster_coord_vmnk[0] ^ 1,
*cta_in_cluster_coord_vmnk[1:],
)
mask_peer = cute.nvgpu.cpasync.create_tma_multicast_mask(
cta_layout_vmnk, block_in_cluster_coord_vmnk_peer, mcast_mode=0
)
return mask_self | mask_peer
@staticmethod
def create(
*,
num_stages: int,
producer_group: CooperativeGroup,
consumer_group: CooperativeGroup,
barrier_storage: cute.Pointer = None,
cta_layout_vmnk: Optional[cute.Layout] = None,
):
"""
This helper function computes any necessary attributes and returns an instance of PipelineAsyncUmma.
:param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
:type barrier_storage: cute.Pointer
:param num_stages: Number of buffer stages for this pipeline
:type num_stages: Int32
:param producer_group: CooperativeGroup for the producer agent
:type producer_group: CooperativeGroup
:param consumer_group: CooperativeGroup for the consumer agent
:type consumer_group: CooperativeGroup
:param cta_layout_vmnk: Layout of the cluster shape
:type cta_layout_vmnk: cute.Layout | None
"""
if not isinstance(barrier_storage, cute.Pointer):
raise ValueError(
f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
)
producer_type = PipelineOp.AsyncThread
consumer_type = PipelineOp.TCGen05Mma
producer = (producer_type, producer_group)
consumer = (consumer_type, consumer_group)
sync_object_full = PipelineAsync._make_sync_object(
barrier_storage.align(min_align=8),
num_stages,
producer,
)
sync_object_empty = PipelineAsync._make_sync_object(
barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
)
cta_v_size = (
cute.size(cta_layout_vmnk, mode=[0]) if cta_layout_vmnk is not None else 1
)
cta_group = (
cute.nvgpu.tcgen05.CtaGroup.ONE
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1
else cute.nvgpu.tcgen05.CtaGroup.TWO
)
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1:
# No mcast mask if we're not using 2CTA tcgen05 MMA
producer_mask = None
consumer_mask = None
else:
# If we're using 2CTA UMMAs, producer will arrive the mbar on leading CTA
# We need to get the target cta_rank
producer_mask = PipelineAsyncUmma._compute_leading_cta_rank(cta_v_size)
# consumer needs to get the mask to signal
consumer_mask = PipelineAsyncUmma._compute_peer_cta_mask(cta_layout_vmnk)
pipeline_init_wait(cta_layout_vmnk)
return PipelineAsyncUmma(
sync_object_full,
sync_object_empty,
num_stages,
producer_mask,
consumer_mask,
cta_group,
)
def consumer_release(self, state: PipelineState):
"""
UMMA consumer releases the buffer as empty; cta_group must be provided.
"""
self.sync_object_empty.arrive(state.index, self.consumer_mask, self.cta_group)
@dataclass(frozen=True)
class PipelineUmmaAsync(PipelineAsync):
"""
PipelineUmmaAsync is used for UMMA producers and AsyncThread consumers (e.g. Blackwell accumulator pipelines).
"""
cta_group: cute.nvgpu.tcgen05.CtaGroup
@staticmethod
def _compute_tmem_sync_mask(cta_layout_vmnk: cute.Layout):
"""
Computes a mask to signal completion of tmem buffers for 2CTA kernels.
"""
cta_rank_in_cluster = cute.arch.make_warp_uniform(
cute.arch.block_idx_in_cluster()
)
cta_in_cluster_coord_vmnk = cta_layout_vmnk.get_flat_coord(cta_rank_in_cluster)
return cute.make_layout_image_mask(
cta_layout_vmnk, cta_in_cluster_coord_vmnk, mode=0
)
@staticmethod
def _compute_peer_cta_rank():
"""
Computes a mask to signal release of tmem buffers for 2CTA kernels.
"""
cta_rank_in_cluster = cute.arch.make_warp_uniform(
cute.arch.block_idx_in_cluster()
)
return cta_rank_in_cluster // 2 * 2
@staticmethod
def create(
*,
num_stages: int,
producer_group: CooperativeGroup,
consumer_group: CooperativeGroup,
barrier_storage: cute.Pointer = None,
cta_layout_vmnk: Optional[cute.Layout] = None,
):
"""
This helper function computes any necessary attributes and returns an instance of PipelineUmmaAsync.
:param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
:type barrier_storage: cute.Pointer
:param num_stages: Number of buffer stages for this pipeline
:type num_stages: Int32
:param producer_group: CooperativeGroup for the producer agent
:type producer_group: CooperativeGroup
:param consumer_group: CooperativeGroup for the consumer agent
:type consumer_group: CooperativeGroup
:param cta_layout_vmnk: Layout of the cluster shape
:type cta_layout_vmnk: cute.Layout | None
"""
if not isinstance(barrier_storage, cute.Pointer):
raise ValueError(
f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
)
producer_type = PipelineOp.TCGen05Mma
consumer_type = PipelineOp.AsyncThread
producer = (producer_type, producer_group)
consumer = (consumer_type, consumer_group)
sync_object_full = PipelineAsync._make_sync_object(
barrier_storage.align(min_align=8), num_stages, producer
)
sync_object_empty = PipelineAsync._make_sync_object(
barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
)
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
# Set mask to None if not using clusters (i.e. 1CTA kernels)
producer_mask = None
else:
producer_mask = PipelineUmmaAsync._compute_tmem_sync_mask(cta_layout_vmnk)
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1:
# Set mask to None if not using 2CTA instructions
consumer_mask = None
else:
consumer_mask = PipelineUmmaAsync._compute_peer_cta_rank()
cta_group = (
cute.nvgpu.tcgen05.CtaGroup.ONE
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1
else cute.nvgpu.tcgen05.CtaGroup.TWO
)
pipeline_init_wait(cta_layout_vmnk)
return PipelineUmmaAsync(
sync_object_full,
sync_object_empty,
num_stages,
producer_mask,
consumer_mask,
cta_group,
)
def producer_commit(self, state: PipelineState):
"""
UMMA producer commits the buffer as full; cta_group must be provided.
"""
self.sync_object_full.arrive(state.index, self.producer_mask, self.cta_group)
def producer_tail(self, state: PipelineState):
"""
Make sure the last used buffer empty signal is visible to producer.
Producer tail is usually executed by producer before exit, to avoid dangling
mbarrier arrive signals after kernel exit.
:param state: The pipeline state that points to next useful buffer
:type state: PipelineState
"""
cta_rank_in_cluster = cute.arch.make_warp_uniform(
cute.arch.block_idx_in_cluster()
)
is_leader_cta = cta_rank_in_cluster % 2 == 0
def then_body():
# Assume state points to the next useful buffer,
# so we only need to advance num_stages - 1 times to reach the last used buffer
for i in range(self.num_stages - 1):
state.advance()
self.producer_acquire(state)
if_generate(is_leader_cta, then_body)
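To connect the classes above, a hedged construction sketch for the TMA-to-UMMA mainloop pipeline. It assumes execution inside a JIT-compiled kernel, that ``Agent`` and ``CooperativeGroup`` are imported from ``cutlass.pipeline``, that ``smem_barrier_ptr`` points to shared memory reserved for 2 * num_stages mbarriers, and that ``bytes_per_stage`` is the per-stage TMA transaction size.
# Editorial sketch; names flagged as assumptions in the note above.
producer_group = CooperativeGroup(Agent.Thread)   # single elected TMA-issuing thread
consumer_group = CooperativeGroup(Agent.Thread)   # single elected MMA-issuing thread
mainloop_pipeline = PipelineTmaUmma.create(
    num_stages=4,
    producer_group=producer_group,
    consumer_group=consumer_group,
    tx_count=bytes_per_stage,          # assumed per-stage TMA byte count
    barrier_storage=smem_barrier_ptr,  # assumed cute.Pointer into smem
    cta_layout_vmnk=None,              # 1CTA case: every block is a leader
)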

View File

@ -0,0 +1,803 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
import enum
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Union
import warnings
import cutlass.cute as cute
from cutlass.cutlass_dsl import Boolean, Int32, if_generate
from cutlass.pipeline import (
CooperativeGroup,
PipelineOp,
SyncObject,
MbarrierArray,
TmaStoreFence,
PipelineUserType,
PipelineState,
make_pipeline_state,
pipeline_init_wait,
)
##############################################################################
# Pipeline classes
##############################################################################
@dataclass(frozen=True)
class PipelineAsync:
"""PipelineAsync is a generic pipeline class where both the producer and consumer are
AsyncThreads. It also serves as a base class for specialized pipeline classes.
This class implements a producer-consumer pipeline pattern where both sides operate
asynchronously. The pipeline maintains synchronization state using barrier objects
to coordinate between producer and consumer threads.
The pipeline state transitions of one pipeline entry (mbarrier) can be represented as:
.. table:: Pipeline State Transitions
:widths: auto
+-----------+-----------+-----------+-----------+-----------+-----------+
| Barrier | State | p.acquire | p.commit | c.wait | c.release |
+===========+===========+===========+===========+===========+===========+
| empty_bar | empty | <Return> | n/a | n/a | - |
+-----------+-----------+-----------+-----------+-----------+-----------+
| empty_bar | wait | <Block> | n/a | n/a | -> empty |
+-----------+-----------+-----------+-----------+-----------+-----------+
| full_bar | wait | n/a | -> full | <Block > | n/a |
+-----------+-----------+-----------+-----------+-----------+-----------+
| full_bar | full | n/a | - | <Return> | n/a |
+-----------+-----------+-----------+-----------+-----------+-----------+
Where:
- p: producer
- c: consumer
- <Block>: This action blocks until the other side transitions the barrier into a state that allows it to proceed
- e.g. ``p.acquire()`` is blocked until ``empty_bar`` transition to ``empty`` state by ``c.release()``
.. code-block:: text
Array of mbarriers as circular buffer:
Advance Direction
<-------------------
Producer Consumer
| ^
V |
+-----------------+
--|X|X|W|D|D|D|D|R|X|<-.
/ +-----------------+ \\
| |
`------------------------'
Where:
- X: Empty buffer (initial state)
- W: Producer writing (producer is waiting for buffer to be empty)
- D: Data ready (producer has written data to buffer)
- R: Consumer reading (consumer is consuming data from buffer)
**Example:**
.. code-block:: python
# Create pipeline with 5 stages
pipeline = PipelineAsync.create(
num_stages=5, # number of pipeline stages
producer_group=producer_warp,
consumer_group=consumer_warp,
barrier_storage=smem_ptr, # smem pointer for array of mbarriers in shared memory
)
# Producer side
producer = pipeline.make_pipeline_producer(producer_warp)
for i in range(num_iterations):
producer.acquire() # Wait for buffer to be empty
# Write data to pipeline buffer
producer.commit() # Signal buffer is full
producer.advance() # Move index to next stage
# Consumer side
consumer = pipeline.make_pipeline_consumer(consumer_warp)
for i in range(num_iterations):
consumer.wait() # Wait for buffer to be full
# Read data from pipeline buffer
consumer.release() # Signal buffer is empty
consumer.advance() # Move index to next stage
"""
sync_object_full: SyncObject
sync_object_empty: SyncObject
num_stages: int
producer_mask: Optional[Int32]
consumer_mask: Optional[Int32]
@staticmethod
def _make_sync_object(
barrier_storage: cute.Pointer,
num_stages: int,
agent: tuple[PipelineOp, CooperativeGroup],
tx_count: int = 0,
) -> SyncObject:
"""
Returns a SyncObject corresponding to an agent's PipelineOp.
"""
if agent[0] in [
PipelineOp.AsyncThread,
PipelineOp.TmaLoad,
PipelineOp.TCGen05Mma,
PipelineOp.Composite,
]:
return MbarrierArray(
barrier_storage=barrier_storage,
num_stages=num_stages,
agent=agent,
tx_count=tx_count,
)
elif agent[0] is PipelineOp.TmaStore:
# Path taken for AsyncTmaStore
return TmaStoreFence(num_stages=num_stages)
else:
assert False, "Error: Invalid PipelineOp specified."
@staticmethod
def create(
*,
num_stages: int,
producer_group: CooperativeGroup,
consumer_group: CooperativeGroup,
barrier_storage: cute.Pointer = None,
producer_mask: Int32 = None,
consumer_mask: Int32 = None,
):
"""Creates and initializes a new PipelineAsync instance.
This helper function computes necessary attributes and returns an instance of PipelineAsync
with the specified configuration for producer and consumer synchronization.
:param barrier_storage: Pointer to the shared memory address for this pipeline's mbarriers
:type barrier_storage: cute.Pointer
:param num_stages: Number of buffer stages for this pipeline
:type num_stages: int
:param producer_group: `CooperativeGroup` for the producer agent
:type producer_group: CooperativeGroup
:param consumer_group: `CooperativeGroup` for the consumer agent
:type consumer_group: CooperativeGroup
:param producer_mask: Mask for signaling arrives for the producer agent, defaults to ``None``
:type producer_mask: Int32, optional
:param consumer_mask: Mask for signaling arrives for the consumer agent, defaults to ``None``
:type consumer_mask: Int32, optional
:return: A new PipelineAsync instance
:rtype: PipelineAsync
:raises ValueError: If barrier_storage is not a cute.Pointer instance
"""
if not isinstance(barrier_storage, cute.Pointer):
raise ValueError(
f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
)
producer_type = PipelineOp.AsyncThread
consumer_type = PipelineOp.AsyncThread
producer = (producer_type, producer_group)
consumer = (consumer_type, consumer_group)
sync_object_full = PipelineAsync._make_sync_object(
barrier_storage.align(min_align=8), num_stages, producer
)
sync_object_empty = PipelineAsync._make_sync_object(
barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
)
pipeline_init_wait()
return PipelineAsync(
sync_object_full,
sync_object_empty,
num_stages,
producer_mask,
consumer_mask,
)
def producer_acquire(
self, state: PipelineState, try_acquire_token: Optional[Boolean] = None
):
if_generate(
try_acquire_token is None or try_acquire_token == 0,
lambda: self.sync_object_empty.wait(state.index, state.phase),
)
def producer_try_acquire(self, state: PipelineState):
return self.sync_object_empty.try_wait(state.index, state.phase)
def producer_commit(self, state: PipelineState):
self.sync_object_full.arrive(state.index, self.producer_mask)
def consumer_wait(
self, state: PipelineState, try_wait_token: Optional[Boolean] = None
):
if_generate(
try_wait_token is None or try_wait_token == 0,
lambda: self.sync_object_full.wait(state.index, state.phase),
)
def consumer_try_wait(self, state: PipelineState):
return self.sync_object_full.try_wait(state.index, state.phase)
def consumer_release(self, state: PipelineState):
self.sync_object_empty.arrive(state.index, self.consumer_mask)
def producer_get_barrier(self, state: PipelineState) -> cute.Pointer:
return self.sync_object_full.get_barrier(state.index)
def producer_tail(self, state: PipelineState):
"""
Make sure the last used buffer empty signal is visible to producer.
Producer tail is usually executed by producer before exit, to avoid dangling
mbarrier arrive signals after kernel exit.
:param state: The pipeline state that points to next useful buffer
:type state: PipelineState
"""
# Assume state points to the next useful buffer,
# so we only need to advance num_stages - 1 times to reach the last used buffer
for i in range(self.num_stages - 1):
state.advance()
self.producer_acquire(state)
# Util methods to manage producer and consumer
def make_pipeline_producer(self, group: CooperativeGroup):
state = make_pipeline_state(PipelineUserType.Producer, self.num_stages)
return PipelineProducer(self, state, group)
def make_pipeline_consumer(self, group: CooperativeGroup):
state = make_pipeline_state(PipelineUserType.Consumer, self.num_stages)
return PipelineConsumer(self, state, group)
@dataclass(frozen=True)
class PipelineTmaAsync(PipelineAsync):
"""
PipelineTmaAsync is used for TMA producers and AsyncThread consumers (e.g. Hopper mainloops).
"""
is_signalling_thread: Boolean
@staticmethod
@cute.jit
def init_empty_barrier_arrive_signal(cta_layout_vmnk: cute.Layout, tidx: Int32):
"""
Initialize the empty barrier arrive signal
This function returns the destination cta rank and a boolean indicating whether the current thread should signal the empty barrier
"""
# Logic to optimally schedule Empty Arrives
cluster_shape_vmnk = cta_layout_vmnk.shape
cta_rank_in_cluster = cute.arch.make_warp_uniform(
cute.arch.block_idx_in_cluster()
)
tidx = tidx % 32
is_signalling_thread = tidx < cute.size(cluster_shape_vmnk)
dst_rank = tidx % cute.size(cluster_shape_vmnk)
dst_cta_coord = cta_layout_vmnk.get_hier_coord(dst_rank)
cur_cta_coord = cta_layout_vmnk.get_hier_coord(cta_rank_in_cluster)
is_same_row = (
dst_cta_coord[0] == cur_cta_coord[0]
and dst_cta_coord[1] == cur_cta_coord[1]
and dst_cta_coord[3] == cur_cta_coord[3]
)
is_same_col = (
dst_cta_coord[0] == cur_cta_coord[0]
and dst_cta_coord[2] == cur_cta_coord[2]
and dst_cta_coord[3] == cur_cta_coord[3]
)
is_same_row_or_col = is_same_row or is_same_col
is_signalling_thread_final = is_signalling_thread and is_same_row_or_col
return dst_rank, is_signalling_thread_final
@staticmethod
def create(
*,
num_stages: int,
producer_group: CooperativeGroup,
consumer_group: CooperativeGroup,
tx_count: int,
barrier_storage: cute.Pointer = None,
cta_layout_vmnk: Optional[cute.Layout] = None,
tidx: Optional[Int32] = None,
):
"""
This helper function computes any necessary attributes and returns an instance of PipelineTmaAsync.
:param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
:type barrier_storage: cute.Pointer
:param num_stages: Number of buffer stages for this pipeline
:type num_stages: Int32
:param producer_group: CooperativeGroup for the producer agent
:type producer_group: CooperativeGroup
:param consumer_group: CooperativeGroup for the consumer agent
:type consumer_group: CooperativeGroup
:param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
:type tx_count: int
:param cta_layout_vmnk: Layout of the cluster shape
:type cta_layout_vmnk: cute.Layout | None
:param tidx: thread index of the consumer async threads
:type tidx: Int32 | None
"""
if not isinstance(barrier_storage, cute.Pointer):
raise ValueError(
f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
)
producer_type = PipelineOp.TmaLoad
consumer_type = PipelineOp.AsyncThread
producer = (producer_type, producer_group)
consumer = (consumer_type, consumer_group)
sync_object_full = PipelineAsync._make_sync_object(
barrier_storage.align(min_align=8), num_stages, producer, tx_count
)
sync_object_empty = PipelineAsync._make_sync_object(
barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
)
if tidx is None:
tidx, _, _ = cute.arch.thread_idx()
if cta_layout_vmnk is None:
cta_layout_vmnk = cute.make_layout((1, 1, 1, 1))
(
dst_rank,
is_signalling_thread,
) = PipelineTmaAsync.init_empty_barrier_arrive_signal(cta_layout_vmnk, tidx)
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
dst_rank = None
producer_mask = None
pipeline_init_wait(cta_layout_vmnk)
return PipelineTmaAsync(
sync_object_full,
sync_object_empty,
num_stages,
producer_mask,
dst_rank,
is_signalling_thread,
)
def producer_acquire(
self, state: PipelineState, try_acquire_token: Optional[Boolean] = None
):
"""
TMA producer acquire conditionally waits on buffer empty and sets the transaction barrier.
"""
if_generate(
try_acquire_token is None or try_acquire_token == 0,
lambda: self.sync_object_empty.wait(state.index, state.phase),
)
self.sync_object_full.arrive(state.index, self.producer_mask)
def producer_commit(self, state: PipelineState):
"""
TMA producer commit is a noop since TMA instruction itself updates the transaction count.
"""
pass
def consumer_release(self, state: PipelineState):
"""
TMA consumer release conditionally signals the empty buffer to the producer.
"""
if_generate(
self.is_signalling_thread,
lambda: self.sync_object_empty.arrive(state.index, self.consumer_mask),
)
@dataclass(frozen=True)
class PipelineTmaMultiConsumersAsync(PipelineAsync):
"""
PipelineTmaMultiConsumersAsync is used for TMA producers and UMMA+Async consumers.
"""
is_leader_cta: bool
sync_object_empty_umma: SyncObject
sync_object_empty_async: SyncObject
cta_group: cute.nvgpu.tcgen05.CtaGroup
@staticmethod
def create(
*,
num_stages: int,
producer_group: CooperativeGroup,
consumer_group_umma: CooperativeGroup,
consumer_group_async: CooperativeGroup,
tx_count: int,
barrier_storage: cute.Pointer = None,
cta_layout_vmnk: Optional[cute.Layout] = None,
):
"""
This helper function computes any necessary attributes and returns an instance of PipelineTmaMultiConsumersAsync.
:param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
:type barrier_storage: cute.Pointer
:param num_stages: Number of buffer stages for this pipeline
:type num_stages: Int32
:param producer_group: CooperativeGroup for the producer agent
:type producer_group: CooperativeGroup
:param consumer_group_umma: CooperativeGroup for the UMMA consumer agent
:type consumer_group_umma: CooperativeGroup
:param consumer_group_async: CooperativeGroup for the AsyncThread consumer agent
:type consumer_group_async: CooperativeGroup
:param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
:type tx_count: int
:param cta_layout_vmnk: Layout of the cluster shape
:type cta_layout_vmnk: cute.Layout | None
"""
if not isinstance(barrier_storage, cute.Pointer):
raise ValueError(
f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
)
producer_type = PipelineOp.TmaLoad
consumer_type = PipelineOp.Composite
consumer_type_umma = PipelineOp.TCGen05Mma
consumer_type_async = PipelineOp.AsyncThread
if consumer_group_umma.agent != consumer_group_async.agent:
raise ValueError(
"UMMA and AsyncThread consumer groups must be the same agent"
)
if cta_layout_vmnk is not None and cute.size(cta_layout_vmnk) != 1:
raise ValueError(
f"PipelineTmaMultiConsumersAsync is not verified for cta_layout_vmnk != 1, cta_layout_vmnk:{cta_layout_vmnk}"
)
consumer_group = CooperativeGroup(
consumer_group_umma.agent,
consumer_group_umma.size + consumer_group_async.size,
)
producer = (producer_type, producer_group)
consumer = (consumer_type, consumer_group)
sync_object_full = PipelineAsync._make_sync_object(
barrier_storage.align(min_align=8), num_stages, producer, tx_count
)
sync_object_empty = PipelineAsync._make_sync_object(
barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
)
sync_object_empty_umma = sync_object_empty.recast_to_new_op_type(
consumer_type_umma
)
sync_object_empty_async = sync_object_empty.recast_to_new_op_type(
consumer_type_async
)
# No mcast mask if not using clusters
producer_mask = None
consumer_mask = None
# All threadblocks are leaders if not using clusters
is_leader_cta = True
cta_group = (
cute.nvgpu.tcgen05.CtaGroup.ONE
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1
else cute.nvgpu.tcgen05.CtaGroup.TWO
)
pipeline_init_wait(cta_layout_vmnk)
return PipelineTmaMultiConsumersAsync(
sync_object_full,
sync_object_empty,
num_stages,
producer_mask,
consumer_mask,
is_leader_cta,
sync_object_empty_umma,
sync_object_empty_async,
cta_group,
)
def producer_acquire(
self, state: PipelineState, try_acquire_token: Optional[Boolean] = None
):
"""
TMA producer acquire waits on buffer empty and sets the transaction barrier for leader threadblocks.
"""
if_generate(
try_acquire_token is None or try_acquire_token == 0,
lambda: self.sync_object_empty.wait(state.index, state.phase),
)
if_generate(
self.is_leader_cta,
lambda: self.sync_object_full.arrive(state.index, self.producer_mask),
)
def producer_commit(self, state: PipelineState):
"""
TMA producer commit is a noop since TMA instruction itself updates the transaction count.
"""
pass
def consumer_release(self, state: PipelineState, op_type: PipelineOp):
if op_type == PipelineOp.TCGen05Mma:
self.sync_object_empty_umma.arrive(
state.index, self.consumer_mask, self.cta_group
)
elif op_type == PipelineOp.AsyncThread:
self.sync_object_empty_async.arrive(state.index, self.consumer_mask)
else:
raise ValueError(f"Invalid PipelineOp specified. op_type:{op_type}")
@dataclass(frozen=True)
class PipelineTmaStore(PipelineAsync):
"""
PipelineTmaStore is used for synchronizing TMA stores in the epilogue. It does not use mbarriers.
"""
@staticmethod
def create(
*,
num_stages: int,
producer_group: CooperativeGroup,
):
"""
This helper function computes any necessary attributes and returns an instance of PipelineTmaStore.
:param num_stages: Number of buffer stages for this pipeline
:type num_stages: Int32
:param producer_group: CooperativeGroup for the producer agent
:type producer_group: CooperativeGroup
"""
producer_type = PipelineOp.TmaStore
producer = (producer_type, producer_group)
sync_object_full = PipelineAsync._make_sync_object(None, num_stages, producer)
return PipelineTmaStore(sync_object_full, None, num_stages, None, None)
def producer_acquire(self):
self.sync_object_full.wait()
def producer_commit(self):
self.sync_object_full.arrive()
def consumer_wait(self):
assert False, "Error: PipelineTmaStore does not have a consumer agent."
def consumer_release(self):
assert False, "Error: PipelineTmaStore does not have a consumer agent."
def producer_tail(self):
self.sync_object_full.tail()
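A short epilogue-side sketch of the fence-based pipeline above, assuming ``epi_group`` is the epilogue CooperativeGroup, ``num_epi_tiles`` is the tile count, and the TMA store for each tile is issued before ``producer_commit``.
# Editorial sketch; loop bound and group are assumptions.
tma_store_pipeline = PipelineTmaStore.create(num_stages=2, producer_group=epi_group)
for _ in range(num_epi_tiles):
    tma_store_pipeline.producer_acquire()  # wait until a staged smem buffer can be reused
    # ... stage the next tile in smem and issue its TMA store ...
    tma_store_pipeline.producer_commit()   # commit the outstanding bulk-copy group
tma_store_pipeline.producer_tail()         # drain all outstanding stores before exit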
#################################################################
# Utilities to help user of pipeline to simplify the workflow
#################################################################
class PipelineProducer:
"""A class representing a producer in an asynchronous pipeline.
The Producer class manages the producer side of an asynchronous pipeline, handling
synchronization and state management for producing data. It provides methods for
acquiring, committing, and advancing through pipeline stages.
:ivar _pipeline: The asynchronous pipeline this producer belongs to
:type _pipeline: PipelineAsync
:ivar _state: The current state of the producer in the pipeline
:type _state: PipelineState
:ivar _group: The cooperative group this producer operates in
:type _group: CooperativeGroup
**Examples:**
.. code-block:: python
pipeline = PipelineAsync.create(...)
producer = pipeline.make_pipeline_producer(producer_group)
for i in range(iterations):
producer.acquire() # Wait for buffer to be empty
# Produce data
producer.commit() # Signal data is ready
producer.advance() # Move to next stage
"""
_pipeline: PipelineAsync
_state: PipelineState
_group: CooperativeGroup
def __init__(self, pipeline, state, group: CooperativeGroup):
"""Initialize a new Producer instance.
:param pipeline: The pipeline this producer belongs to
:type pipeline: PipelineAsync
:param state: Initial pipeline state
:type state: PipelineState
:param group: The cooperative group for synchronization
:type group: CooperativeGroup
"""
self._pipeline = pipeline
self._state = state
self._group = group
@property
def index(self):
"""Get the index of the current pipeline stage."""
return self._state.index
def get_barrier(self):
"""Get the barrier pointer for the current pipeline stage.
:return: Pointer to the barrier for the current stage
:rtype: cute.Pointer
"""
return self._pipeline.producer_get_barrier(self._state)
def acquire(self):
"""Wait for the current buffer to be empty before producing data.
This is a blocking operation.
"""
self._pipeline.producer_acquire(self._state)
def try_acquire(self):
"""Try to acquire the current buffer without blocking.
:return: True if acquisition was successful, False otherwise
:rtype: bool
"""
return self._pipeline.producer_try_acquire(self._state)
def commit(self):
"""Signal that data production is complete for the current stage.
This allows consumers to start processing the data.
"""
self._pipeline.producer_commit(self._state)
def tail(self):
"""Ensure all used buffers are properly synchronized before producer exit.
This should be called before the producer finishes to avoid dangling signals.
"""
self._pipeline.producer_tail(self._state)
def advance(self):
"""Move to the next pipeline stage."""
self._state.advance()
def __extract_mlir_values__(self):
"""Extract MLIR values from the current state.
:return: List of MLIR values representing the current state
:rtype: list
"""
# TODO: need to handle pipeline as well
return self._state.__extract_mlir_values__()
def __new_from_mlir_values__(self, values):
"""Create a new Producer instance from MLIR values.
:param values: MLIR values to initialize the state
:type values: Any
:return: New Producer instance with state initialized from values
:rtype: Producer
"""
return PipelineProducer(
self._pipeline, self._state.__new_from_mlir_values__(values), self._group
)
class PipelineConsumer:
"""A class representing a consumer in an asynchronous pipeline.
The Consumer class manages the consumer side of an asynchronous pipeline, handling
synchronization and state management for consuming data. It provides methods for
waiting, releasing, and advancing through pipeline stages.
:ivar _pipeline: The asynchronous pipeline this consumer belongs to
:type _pipeline: PipelineAsync
:ivar _state: The current state of the consumer in the pipeline
:type _state: PipelineState
:ivar _group: The cooperative group this consumer operates in
:type _group: CooperativeGroup
**Examples:**
.. code-block:: python
pipeline = PipelineAsync.create(...)
consumer = pipeline.create_consumer(consumer_group, stages)
for i in range(iterations):
consumer.wait() # Wait for data to be ready
# Consume data
consumer.release() # Signal buffer is empty
consumer.advance() # Move to next stage
"""
_pipeline: PipelineAsync
_state: PipelineState
_group: CooperativeGroup
def __init__(self, pipeline, state: PipelineState, group: CooperativeGroup):
"""Initialize a new Consumer instance.
:param pipeline: The pipeline this consumer belongs to
:type pipeline: PipelineAsync
:param state: Initial pipeline state
:type state: PipelineState
:param group: The cooperative group for synchronization
:type group: CooperativeGroup
"""
self._pipeline = pipeline
self._group = group
self._state = state
@property
def index(self):
"""Get the index of the current pipeline stage."""
return self._state.index
def wait(self):
"""Wait for data to be ready in the current buffer.
This is a blocking operation.
"""
self._pipeline.consumer_wait(self._state)
def try_wait(self):
"""Try to check if data is ready without blocking.
:return: True if data is ready, False otherwise
:rtype: bool
"""
return self._pipeline.consumer_try_wait(self._state)
def release(self):
"""Signal that data consumption is complete for the current stage.
This allows producers to start producing new data.
"""
self._pipeline.consumer_release(self._state)
def advance(self):
"""Move to the next pipeline stage."""
self._state.advance()
def __extract_mlir_values__(self):
"""Extract MLIR values from the current state.
:return: List of MLIR values representing the current state
:rtype: list
"""
return self._state.__extract_mlir_values__()
def __new_from_mlir_values__(self, values):
"""Create a new Consumer instance from MLIR values.
:param values: MLIR values to initialize the state
:type values: Any
:return: New Consumer instance with state initialized from values
:rtype: Consumer
"""
# TODO: need to call pipeline.__new_from_mlir_values__ recursively
return PipelineConsumer(
self._pipeline, self._state.__new_from_mlir_values__(values), self._group
)

View File

@ -29,6 +29,7 @@ from cutlass.cute.typing import (
from cutlass.cute.runtime import from_dlpack
import cutlass.cute as cute
import torch
from cuda import cuda
def dtype(ty: Type[Numeric]):
@ -94,12 +95,13 @@ def create_and_permute_torch_tensor(
init_config: Optional[
Union[RandomInitConfig, ScalarInitConfig, GaussianInitConfig]
] = None,
device: Optional[torch.device] = None,
) -> "torch.Tensor":
"""
Create a torch tensor with specified shape and dtype. Optionally permute it and initialize it with specified init type and config
"""
init_dtype = torch.int32 if init_type == TensorInitType.RANDOM else torch.float32
init_torch_tensor = torch.empty(*shape, dtype=init_dtype)
init_torch_tensor = torch.empty(*shape, dtype=init_dtype, device=device)
if init_type == TensorInitType.SKIP:
assert init_config is None
f32_torch_tensor = init_torch_tensor
@ -167,3 +169,122 @@ def convert_cute_tensor(
# Copy and convert from f32 cute tensor to dtype cute tensor
cute.testing.convert(fp32_cute_tensor, cute_tensor)
return cute_tensor
def default_stream() -> cuda.CUstream:
"""
Get default CUstream from torch stream
"""
torch_stream = torch.cuda.default_stream()
stream = cuda.CUstream(torch_stream.cuda_stream)
return stream
def current_stream() -> cuda.CUstream:
"""
Get current CUstream from torch stream
"""
torch_stream = torch.cuda.current_stream()
stream = cuda.CUstream(torch_stream.cuda_stream)
return stream
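# Illustrative usage (not part of the original helpers; `compiled_fn` and the tensor
# arguments are hypothetical): pass the torch-backed CUstream to a compiled kernel so
# that it runs on the stream torch is currently using.
#
#     stream = current_stream()
#     compiled_fn(mA, mB, mC, stream)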
def matrix(
l: int,
mode0: int,
mode1: int,
is_mode0_major: bool,
cutlass_dtype: Type[Numeric],
init_type: TensorInitType = TensorInitType.RANDOM,
init_config: Optional[
Union[RandomInitConfig, ScalarInitConfig, GaussianInitConfig]
] = None,
device: Optional[torch.device] = None,
) -> torch.Tensor:
"""
Create a torch tensor for a (possibly batched) matrix.
:param l: batch size of the matrix (number of matrices in the L mode)
:param mode0: mode0 of the matrix
:param mode1: mode1 of the matrix
:param is_mode0_major: whether the matrix is mode0 major
:param cutlass_dtype: cutlass dtype of the matrix
:param init_type: type of initialization
:param init_config: configuration for initialization
:param device: target torch device
"""
shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1)
permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0)
if cutlass_dtype.is_float and cutlass_dtype.width <= 8:
torch_dtype = torch.int8
else:
torch_dtype = dtype(cutlass_dtype)
if init_type == TensorInitType.RANDOM and init_config is None:
if torch_dtype.is_signed:
min_val = -2
max_val = 2
else:
min_val = 0
max_val = 4
init_config = RandomInitConfig(min_val=min_val, max_val=max_val)
# Create dtype torch tensor
torch_tensor = create_and_permute_torch_tensor(
shape,
torch_dtype,
permute_order=permute_order,
init_type=init_type,
init_config=init_config,
device=device,
)
return torch_tensor
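# Example (illustrative sketch): build an fp16 operand with mode0=256 and mode1=64 for a
# single problem (l=1). Float16 is assumed to be the cutlass Numeric type imported above.
#
#     a_torch = matrix(1, 256, 64, is_mode0_major=False, cutlass_dtype=Float16)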
def cute_tensor_like(
data_ref: torch.Tensor,
cutlass_dtype: Type[Numeric],
is_dynamic_layout: bool,
assumed_align: Optional[int] = None,
) -> tuple[Tensor, torch.Tensor]:
"""
Create a cute tensor using a torch tensor as the data source.
:param data_ref: torch tensor as the data source
:param cutlass_dtype: cutlass dtype of the cute tensor
:param is_dynamic_layout: whether the cute tensor uses dynamic layout
:param assumed_align: assumed alignment of the cute tensor
"""
# allocate device buffer for cute tensor
if cutlass_dtype.is_float and cutlass_dtype.width <= 8:
torch_dtype = torch.int8
else:
torch_dtype = dtype(cutlass_dtype)
torch_tensor = torch.empty_like(data_ref, dtype=torch_dtype, device="cuda")
# create cute tensor using the device buffer
cute_tensor = from_dlpack(torch_tensor, assumed_align=assumed_align)
cute_tensor.element_type = cutlass_dtype
if is_dynamic_layout:
for i, stride in enumerate(torch_tensor.stride()):
if stride == 1:
leading_dim = i
break
cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=leading_dim)
# initialize the cute tensor data
if cutlass_dtype.is_float and cutlass_dtype.width <= 8:
cute_tensor = convert_cute_tensor(
data_ref.to(dtype=torch.float32),
cute_tensor,
cutlass_dtype,
is_dynamic_layout,
)
else:
torch_tensor.copy_(data_ref.to(dtype=torch_dtype))
return cute_tensor, torch_tensor
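# Example (illustrative sketch): mirror a reference torch tensor as a cute tensor with a
# dynamic layout. The returned torch tensor owns the device buffer that backs the cute
# tensor; Float16 is assumed to be the imported cutlass Numeric type.
#
#     a_ref = matrix(1, 256, 64, is_mode0_major=False, cutlass_dtype=Float16)
#     a_cute, a_dev = cute_tensor_like(a_ref, Float16, is_dynamic_layout=True, assumed_align=16)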

View File

@ -15,20 +15,6 @@ from .static_persistent_tile_scheduler import (
StaticPersistentTileScheduler,
)
from .pipeline import (
Agent,
CooperativeGroup,
PipelineUserType,
PipelineState,
make_pipeline_state,
PipelineAsync,
PipelineTmaAsync,
PipelineTmaUmma,
PipelineUmmaAsync,
PipelineTmaStore,
pipeline_init_wait,
)
from .hardware_info import (
HardwareInfo,
)
@ -65,6 +51,8 @@ from .smem_allocator import SmemAllocator
from .layout import LayoutEnum
__all__ = [
"SmemAllocator",
"LayoutEnum",
"WorkTileInfo",
"PersistentTileSchedulerParams",
"StaticPersistentTileScheduler",

View File

@ -51,8 +51,13 @@ from cutlass.cute.nvgpu.tcgen05 import (
is_tmem_load,
get_tmem_copy_properties,
)
from cutlass.cute.nvgpu.cpasync import (
CopyBulkTensorTileG2SMulticastOp,
CopyBulkTensorTileG2SOp,
)
from cutlass.utils.layout import LayoutEnum
@dsl_user_op
def compute_epilogue_tile_shape(
cta_tile_shape: cute.Shape,
@ -716,6 +721,7 @@ def make_smem_layout_b(
return b_smem_layout_staged
@dsl_user_op
def get_smem_layout_atom_epi(
layout: LayoutEnum,
@ -827,6 +833,7 @@ SMEM_CAPACITY = {
"sm120": SmemCapacity.SM120_SMEM_CAPACITY_BYTES.value,
}
@dsl_user_op
def make_trivial_tiled_mma(
ab_dtype: Type[Numeric],
@ -908,3 +915,139 @@ def make_trivial_tiled_mma(
raise TypeError(f"unsupported ab_dtype, got {ab_dtype}")
return cute.make_tiled_mma(cute.make_mma_atom(mma_op))
@dsl_user_op
def cluster_shape_to_tma_atom_A(
cluster_shape_mnk: cute.Shape, atom_thr_id: cute.Layout, *, loc=None, ip=None
) -> Union[CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileG2SOp]:
"""
Select the appropriate TMA copy atom for A based on the number of SMs and the multicast flag.
:param cluster_shape_mnk: The shape of the cluster
:type cluster_shape_mnk: cute.Shape
:param atom_thr_id: The thread ID of the atom
:type atom_thr_id: cute.Layout
:return: The appropriate TMA copy atom kind
:rtype: cpasync.CopyBulkTensorTileG2SMulticastOp or cpasync.CopyBulkTensorTileG2SOp
:raise ValueError: If the atom_sm_cnt is invalid
:raise ValueError: If the cluster shape is not divisible by the atom SM count
"""
atom_sm_cnt = cute.size(atom_thr_id, loc=loc, ip=ip)
mcast = not (cute.size(cluster_shape_mnk, mode=[1], loc=loc, ip=ip) == 1)
cluster_size = cute.size(cluster_shape_mnk, loc=loc, ip=ip)
if not isinstance(cluster_size, int) or not isinstance(atom_sm_cnt, int):
raise ValueError(
f"Dynamic cluster shape or atom SM count is not supported: {cluster_shape_mnk} and {atom_thr_id}"
)
if cute.size(cluster_shape_mnk, mode=[0], loc=loc, ip=ip) % atom_sm_cnt != 0:
raise ValueError(
f"Cluster shape not divisible by MMA size: {cluster_shape_mnk} and {atom_thr_id}"
)
if atom_sm_cnt == 2 and mcast:
return CopyBulkTensorTileG2SMulticastOp(CtaGroup.TWO)
elif atom_sm_cnt == 2 and not mcast:
return CopyBulkTensorTileG2SOp(CtaGroup.TWO)
elif atom_sm_cnt == 1 and mcast:
return CopyBulkTensorTileG2SMulticastOp(CtaGroup.ONE)
elif atom_sm_cnt == 1 and not mcast:
return CopyBulkTensorTileG2SOp(CtaGroup.ONE)
raise ValueError(
f"Unsupported Configuration for SM100 TMA: {cluster_shape_mnk} and {atom_thr_id}"
)
@dsl_user_op
def cluster_shape_to_tma_atom_B(
cluster_shape_mnk: cute.Shape, atom_thr_id: cute.Layout, *, loc=None, ip=None
) -> Union[CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileG2SOp]:
"""
Select the appropriate TMA copy atom for B based on the number of SMs and the multicast flag.
:param cluster_shape_mnk: The shape of the cluster
:type cluster_shape_mnk: cute.Shape
:param atom_thr_id: The thread ID of the atom
:type atom_thr_id: cute.Layout
:return: The appropriate TMA copy atom kind
:rtype: cpasync.CopyBulkTensorTileG2SMulticastOp or cpasync.CopyBulkTensorTileG2SOp
:raise ValueError: If the atom_sm_cnt is invalid
:raise ValueError: If the cluster shape is not divisible by the atom SM count
"""
atom_sm_cnt = cute.size(atom_thr_id, loc=loc, ip=ip)
mcast = not (cute.size(cluster_shape_mnk, mode=[0], loc=loc, ip=ip) == atom_sm_cnt)
cluster_size = cute.size(cluster_shape_mnk, loc=loc, ip=ip)
if not isinstance(cluster_size, int) or not isinstance(atom_sm_cnt, int):
raise ValueError(
f"Dynamic cluster shape or atom SM count is not supported: {cluster_shape_mnk} and {atom_thr_id}"
)
if cute.size(cluster_shape_mnk, mode=[0], loc=loc, ip=ip) % atom_sm_cnt != 0:
raise ValueError(
f"Cluster shape not divisible by MMA size: {cluster_shape_mnk} and {atom_thr_id}"
)
if atom_sm_cnt == 2 and mcast:
return CopyBulkTensorTileG2SMulticastOp(CtaGroup.TWO)
elif atom_sm_cnt == 2 and not mcast:
return CopyBulkTensorTileG2SOp(CtaGroup.TWO)
elif atom_sm_cnt == 1 and mcast:
return CopyBulkTensorTileG2SMulticastOp(CtaGroup.ONE)
elif atom_sm_cnt == 1 and not mcast:
return CopyBulkTensorTileG2SOp(CtaGroup.ONE)
raise ValueError(
f"Unsupported Configuration for SM100 TMA: {cluster_shape_mnk} and {atom_thr_id}"
)
@dsl_user_op
def cluster_shape_to_tma_atom_SFB(
cluster_shape_mnk: cute.Shape, atom_thr_id: cute.Layout, *, loc=None, ip=None
) -> Union[CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileG2SOp]:
"""
Select the appropriate TMA copy atom for SFB based on the number of SMs and the multicast flag.
:param cluster_shape_mnk: The shape of the cluster
:type cluster_shape_mnk: cute.Shape
:param atom_thr_id: The thread ID of the atom
:type atom_thr_id: cute.Layout
:return: The appropriate TMA copy atom kind
:rtype: cpasync.CopyBulkTensorTileG2SMulticastOp or cpasync.CopyBulkTensorTileG2SOp
:raise ValueError: If the atom_sm_cnt is invalid
:raise ValueError: If the cluster shape is not divisible by the atom SM count
"""
atom_sm_cnt = cute.size(atom_thr_id, loc=loc, ip=ip)
mcast = not (cute.size(cluster_shape_mnk, mode=[0], loc=loc, ip=ip) == 1)
cluster_size = cute.size(cluster_shape_mnk, loc=loc, ip=ip)
if not isinstance(cluster_size, int) or not isinstance(atom_sm_cnt, int):
raise ValueError(
f"Dynamic cluster shape or atom SM count is not supported: {cluster_shape_mnk} and {atom_thr_id}"
)
if cute.size(cluster_shape_mnk, mode=[0], loc=loc, ip=ip) % atom_sm_cnt != 0:
raise ValueError(
f"Cluster shape not divisible by MMA size: {cluster_shape_mnk} and {atom_thr_id}"
)
if atom_sm_cnt == 2:
return CopyBulkTensorTileG2SMulticastOp(CtaGroup.TWO)
elif atom_sm_cnt == 1 and mcast:
return CopyBulkTensorTileG2SMulticastOp(CtaGroup.ONE)
elif atom_sm_cnt == 1 and not mcast:
return CopyBulkTensorTileG2SOp(CtaGroup.ONE)
raise ValueError(
f"Unsupported Configuration for SM100 TMA: {cluster_shape_mnk} and {atom_thr_id}"
)
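# Illustrative note (worked example of the selection logic above): for
# cluster_shape_mnk = (2, 2, 1) and a 2-SM MMA atom, A sees a non-trivial N mode, so
# cluster_shape_to_tma_atom_A returns CopyBulkTensorTileG2SMulticastOp(CtaGroup.TWO);
# for cluster_shape_mnk = (1, 1, 1) with a 1-SM atom it returns the non-multicast
# CopyBulkTensorTileG2SOp(CtaGroup.ONE).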

View File

@ -96,6 +96,10 @@ def make_trivial_tiled_mma(
acc_dtype: Type[Numeric],
atom_layout_mnk: Tuple[int, int, int],
tiler_mn: Tuple[int, int],
a_source: OperandSource = OperandSource.SMEM,
*,
loc=None,
ip=None,
) -> cute.TiledMma:
"""Make a tiled MMA atom with given data type, leading dimension, cta group and mma tile shape.
By default, the MMA atom is created with SMEM operand source for A.
@ -131,7 +135,7 @@ def make_trivial_tiled_mma(
a_dtype,
acc_dtype,
(*tiler_mn, 16),
OperandSource.SMEM,
a_source,
a_leading_mode,
b_leading_mode,
)
@ -144,7 +148,7 @@ def make_trivial_tiled_mma(
b_dtype,
acc_dtype,
(*tiler_mn, 32),
OperandSource.SMEM,
a_source,
a_leading_mode,
b_leading_mode,
)

File diff suppressed because it is too large

View File

@ -18,41 +18,25 @@ from cutlass.cute.arch import get_dyn_smem
class SmemAllocator:
"""
A class for managing shared memory allocation on GPU.
"""A class for managing shared memory allocation on GPU.
This class manages a chunk of shared memory and provide APIs for sub-allocation
This class manages a chunk of shared memory and provides APIs for sub-allocation
inside the chunk.
Attributes
----------
_base : cute.Pointer as i8 typed dynamic value
The current base address of the shared memory.
:ivar _base: The current base address of the shared memory as an i8 typed dynamic value.
:type _base: cute.Pointer
:ivar _allocated_bytes: The total number of bytes allocated in shared memory.
:type _allocated_bytes: int
_allocated_bytes:
The bytes allocated in shared memory.
Methods
-------
allocate(num_bytes, alignment)
Allocates num_bytes in the shared memory with the given byte alignment.
allocate_value(value_ty, num_elems)
Allocates num_elems of value_ty values in the shared memory.
allocate_tensor(value_ty, layout, alignment)
Allocates a tensor in the shared memory with given layout and byte alignment.
Notes
-----
This class is responsible for managing the allocation of tensors in shared memory.
.. note::
This class is responsible for managing the allocation of tensors in shared memory.
The base pointer is aligned to 1024 bytes upon initialization.
"""
def __init__(self):
"""
Initializes the SmemAllocator instance with dynamic smem base ptr,
which is i8 type and aligned to 1024.
"""Initialize the SmemAllocator instance.
Creates a dynamic shared memory base pointer of type i8, aligned to 1024 bytes.
"""
self._base = get_dyn_smem(Int8, alignment=1024)
self._allocated_bytes = 0
@ -64,30 +48,19 @@ class SmemAllocator:
def allocate(self, size_or_type: cute.struct, byte_alignment: int): ...
def allocate(self, size_or_type, byte_alignment: int = 1) -> int:
"""Allocate a block of memory with specified size and alignment.
This method adjusts the base pointer to ensure proper alignment and updates
the internal state to track allocated memory.
:param size_or_type: The number of bytes to allocate or a struct class
:type size_or_type: Union[int, cute.struct]
:param byte_alignment: The byte alignment requirement, defaults to 1 (no alignment)
:type byte_alignment: int, optional
:return: Pointer to the start of the allocated memory block or struct instance
:rtype: cute.Pointer
:raises ValueError: If size is negative or alignment is less than 1
"""
Allocates a block of memory with the specified size and byte alignment.
This method adjusts the base cute.Pointer to ensure that the allocated memory
is aligned according to the specified byte alignment. It updates the internal
state to reflect the new base cute.Pointer and the total allocated bytes.
Parameters
----------
size_or_type : int or struct
The number of bytes to allocate or struct class.
byte_alignment : int
The byte alignment requirement for the allocation. Defaults to 1 (no alignment).
Returns
----------
A cute.Pointer to the start of the allocated memory block or struct instance.
Raises
----------
ValueError
If num_bytes is negative or if byte_alignmemt is less than 1.
"""
if isinstance(size_or_type, cute.struct):
alignment = max(byte_alignment, size_or_type.__alignof__())
base_ptr = self.allocate(size_or_type.__sizeof__(), alignment)
@ -110,27 +83,16 @@ class SmemAllocator:
return ptr
def allocate_array(self, element_type: Type[Numeric], num_elems: int = 1):
"""
Allocates num_elems values of element_type in shared memory.
"""Allocate an array of elements in shared memory.
This method calls allocate() to obtain a byte pointer to the start of the allocated
block, then calls cute.recast_ptr() to recast that byte cute.Pointer to element_type.
Parameters
----------
element_type : Type[Numeric]
The type of the values in the tensor.
num_elems : int, optional
The number of elements for each allocation. Defaults to 1.
Returns
----------
A value_type cute.Pointer to the start of the allocated memory block.
Raises
----------
ValueError
If num_elems is less than 1.
:param element_type: The type of elements to allocate
:type element_type: Type[Numeric]
:param num_elems: Number of elements to allocate, defaults to 1
:type num_elems: int, optional
:return: Pointer to the start of the allocated array
:rtype: cute.Pointer
:raises ValueError: If num_elems is less than 1
:raises TypeError: If element_type is not a Numeric type
"""
if num_elems < 1:
raise ValueError("num_elems must be at least 1")
@ -152,28 +114,21 @@ class SmemAllocator:
byte_alignment: int = 1,
swizzle: cute.Swizzle = None,
):
"""
Allocates a tensor in the shared memory with value type, layout and byte alignment.
"""Allocate a tensor in shared memory.
Parameters
----------
element_type : Type[Numeric]
The type of the values in the tensor.
layout : int | DynamicInt | cute.Layout | cute.ComposedLayout
The layout of the tensor.
byte_alignment : int, optional
The byte alignment requirement for the allocation. Defaults to 1 (no alignment).
swizzle : cute.Swizzle
A swizzle for the iterator (for position-dependent swizzling).
Returns
-------
tensor : cute.Tensor
The allocated tensor with specified value type, layout and byte alignment.
Notes
-----
The base address is updated to point to the next available memory location.
:param element_type: The type of elements in the tensor
:type element_type: Type[Numeric]
:param layout: The layout specification for the tensor
:type layout: Union[int, cute.Layout, cute.ComposedLayout]
:param byte_alignment: The byte alignment requirement, defaults to 1
:type byte_alignment: int, optional
:param swizzle: Swizzle for position-dependent swizzling, defaults to None
:type swizzle: cute.Swizzle, optional
:return: The allocated tensor with specified properties
:rtype: cute.Tensor
:raises TypeError: If element_type is not a Numeric type or if swizzle conflicts with layout
:raises ValueError: If allocation is not byte-aligned
:raises NotImplementedError: If dynamic layout is specified
"""
if not isinstance(element_type, NumericMeta):
raise TypeError(

View File

@ -23,6 +23,11 @@ from ..base_dsl.ast_helpers import (
dynamic_expr,
assert_executor,
bool_cast,
compare_executor,
any_executor,
all_executor,
range_value_check,
range_perf_warning,
)
from ..base_dsl import *

View File

@ -20,6 +20,7 @@ from inspect import isclass
import functools
import pkgutil
from dataclasses import is_dataclass
from collections.abc import Sequence
from ..base_dsl import *
from ..base_dsl import compiler
@ -51,6 +52,11 @@ from ..base_dsl.ast_helpers import (
while_executor,
assert_executor,
bool_cast,
compare_executor,
any_executor,
all_executor,
range_value_check,
range_perf_warning,
)
from ..base_dsl.runtime.dlpack_runtime import (
get_cute_tensor_c_pointer,
@ -67,18 +73,6 @@ from .cutlass_ast_decorators import (
_while_execute_dynamic,
)
# =============================================================================
# Set the AST decorator
# =============================================================================
# Set the DSL specific functions
executor.set_functions(
is_dynamic_expression,
_loop_execute_range_dynamic,
_if_execute_dynamic,
_while_execute_dynamic,
)
# =============================================================================
# Cutlass DSL Base Abstract Class
@ -1023,7 +1017,6 @@ def select_(cond, if_value, else_value):
)
return value
# Non-DSL dynamic cond should be handled before this.
if const_expr(not is_dynamic_expression(cond)):
raise DSLRuntimeError("Conditional expression must be dynamic")
@ -1089,6 +1082,7 @@ def for_generate(
iter_args: Optional[Sequence[ir.Value]] = None,
*,
unroll: LoopUnroll = None,
pipelining=None,
loc=None,
ip=None,
):
@ -1126,6 +1120,9 @@ def for_generate(
if unroll is not None:
for_op.attributes["loop_annotation"] = unroll
if pipelining is not None:
for_op.attributes["cutlass.pipelining"] = _createI32Attr(pipelining)
iv = for_op.induction_variable
new_results = new_from_mlir_values(iter_args, for_op.results)
new_iter_args = new_from_mlir_values(iter_args, for_op.inner_iter_args)
@ -1319,3 +1316,122 @@ def while_generate(
Generate a WhileLoopContext for a dynamic loop.
"""
return WhileLoopContext(inputs, condition, loc=loc, ip=ip)
def equal(lhs, rhs):
if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs):
return lhs == rhs
# Both operands are sequences
if isinstance(lhs, Sequence) and isinstance(rhs, Sequence):
# Short-circuit for unequal length
if len(lhs) != len(rhs):
return False
return all_(equal(l, r) for l, r in zip(lhs, rhs))
return lhs == rhs
def in_(lhs, rhs, op):
if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs):
return lhs in rhs
if not isinstance(rhs, Sequence):
raise DSLRuntimeError(
f"'{op}' not supported between instances of {type(lhs)} and {type(rhs)}"
)
return any_(equal(lhs, r) for r in rhs)
def _lt_gt(lhs, rhs, op):
def native_lt_gt(lhs, rhs, op):
if op == "<":
return lhs < rhs
elif op == ">":
return lhs > rhs
else:
raise DSLRuntimeError(f"Unsupported comparison operator: {op}")
if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs):
return native_lt_gt(lhs, rhs, op)
# Both are sequences; comparisons other than == and != do not allow mixing different sequence types
if (
isinstance(lhs, Sequence)
and isinstance(rhs, Sequence)
and type(lhs) == type(rhs)
):
unequal_found = False
comp_results = []
mask = []
for l, r in zip(lhs, rhs):
is_equal = equal(l, r)
mask.append(not_(or_(is_equal, unequal_found)))
unequal_found = not_(is_equal)
comp_results.append(_lt_gt(l, r, op))
result = any_(and_(r, m) for r, m in zip(comp_results, mask))
if len(lhs) != len(rhs):
# Ref https://docs.python.org/3/tutorial/datastructures.html#comparing-sequences-and-other-types
# If one sequence is an initial sub-sequence of the other, the shorter sequence is the smaller (lesser) one
has_valid_mask = any_(mask)
if op == "<":
length_result = len(lhs) < len(rhs)
elif op == ">":
length_result = len(lhs) > len(rhs)
if type(has_valid_mask) == bool:
return result if has_valid_mask else length_result
else:
return select_(has_valid_mask, result, length_result)
else:
return result
else:
return native_lt_gt(lhs, rhs, op)
def greater_than(lhs, rhs):
return _lt_gt(lhs, rhs, ">")
def less_than(lhs, rhs):
return _lt_gt(lhs, rhs, "<")
def _compare_executor(left, comparators, ops):
result = left
for comparator, op in zip(comparators, ops):
# 'is' and 'is not' are pure python operators
if op == "is":
result = result is comparator
elif op == "is not":
result = result is not comparator
elif op in ["in", "not in"]:
result = in_(left, comparator, op)
elif op in ["==", "!="]:
result = equal(left, comparator)
elif op in ["<", ">="]:
result = less_than(left, comparator)
elif op in [">", "<="]:
result = greater_than(left, comparator)
else:
raise DSLRuntimeError(f"Unsupported comparison operator: {op}")
# Invert the result for NotIn, NotEq, GtE, LtE
if op in ["not in", "!=", ">=", "<="]:
result = not_(result)
return result
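# Illustrative sketch (assumed preprocessor behavior): a single comparison such as
# `x != y`, where either operand may be a dynamic expression, is routed through this
# helper roughly as
#
#     result = _compare_executor(x, [y], ["!="])
#
# which computes equal(x, y) and then inverts it with not_().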
# =============================================================================
# Set the AST decorator
# =============================================================================
# Set the DSL specific functions
executor.set_functions(
is_dynamic_expression,
_loop_execute_range_dynamic,
_if_execute_dynamic,
_while_execute_dynamic,
_compare_executor,
any_,
all_,
)

View File

@ -10,6 +10,7 @@
# is strictly prohibited.
from typing import List, Tuple
from types import NoneType
from cutlass._mlir import ir
from cutlass._mlir.dialects import scf, arith
from cutlass._mlir.extras import types as T
@ -145,6 +146,30 @@ class ScfGenerator:
# Use the provided terminator generator
block_term_op_builder[builder](region_result)
else:
# For standard yield op, check result
for arg, result, name in zip(
mix_iter_args,
(
region_result
if isinstance(region_result, list)
else [region_result]
),
mix_iter_arg_names,
):
if isinstance(arg, NoneType) and not isinstance(
result, NoneType
):
raise DSLRuntimeError(
(
f"`{name}` is None prior to this `{op_type_name}`, "
f"and update to non-None value inside of this `{op_type_name}` is not supported."
),
suggestion=(
f"Please make sure `{name}` is not None prior to this `{op_type_name}`, "
f"or mark this `{op_type_name}` with "
f"`{'range' if op_type_name == 'for' else 'const_expr'}`."
),
)
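# Illustrative example of the pattern rejected above (assumed user code and loop
# helper name):
#
#     x = None
#     for i in cutlass.range(n):        # dynamic trip count, lowered to scf.for
#         x = compute_something()       # x becomes non-None inside the loop body
#
# `x` has no value before the loop, so it cannot be threaded as a loop-carried
# iter_arg, and the check above raises a DSLRuntimeError with the suggestion shown.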
# Normalize region_result
region_result_list = ScfGenerator._normalize_region_result_to_list(
region_result
@ -200,6 +225,7 @@ def _loop_execute_range_dynamic(
mix_iter_arg_names: List[str] = [],
unroll: int = -1,
unroll_full: bool = False,
pipelining: int = None,
):
"""
Example: build an scf.for with optional unroll, using our universal helper.
@ -236,6 +262,18 @@ def _loop_execute_range_dynamic(
unroll_attr = LoopUnroll(count=unroll)
log().debug("Unroll attribute: %s", unroll_attr)
pipelining_attr = None
if pipelining is not None:
if pipelining >= 0:
pipelining_attr = ir.IntegerAttr.get(
ir.IntegerType.get_signless(32), pipelining
)
else:
raise DSLRuntimeError(
f"Pipelining must be non-negative, got {pipelining}"
)
log().debug("Pipelining attribute: %s", pipelining_attr)
log().debug(
"Creating scf.ForOp \n\t\tstart=%s: type : %s\n\t\tstop=%s: type : %s\n\t\tstep=%s: type : %s",
start_,
@ -265,6 +303,9 @@ def _loop_execute_range_dynamic(
if unroll_attr is not None:
for_op.attributes["loop_annotation"] = unroll_attr
if pipelining_attr is not None:
for_op.attributes["cutlass.pipelining"] = pipelining_attr
return for_op
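# Illustrative note (the user-facing spelling is an assumption): the pipelining value is
# expected to be forwarded from the DSL loop helper, e.g. something like
#
#     for i in cutlass.range(0, n, 1, pipelining=2):
#         ...
#
# which reaches this builder as pipelining=2 and is attached to the generated scf.for
# as the "cutlass.pipelining" attribute.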
def for_body_builder(

View File

@ -1,3 +1,3 @@
# Use `pip install -r requirements.txt` with the present file to install a
# wheel consistent with the present state of the github repository
nvidia-cutlass-dsl==4.0.0
nvidia-cutlass-dsl==4.1.0.dev0

View File

@ -73,6 +73,7 @@ class EVTFrontendBase:
self.dag_ir = DAGIR(self.cc, self.element_compute)
self.compute_cnt = 0
self.layout_cnt = 0
self.imm_cnt = 0
self.pass_manager = EVTPassManager(
self.dag_ir,
@ -107,6 +108,13 @@ class EVTFrontendBase:
# Parse the input
self.parse(*args, **kwargs)
# Verify the DAG IR to ensure that "D" is the output node with out_degree = 0
if (self.cc >= 90):
if (self.dag_ir.out_degree("D") != 0):
raise RuntimeError(
f"On SM90 or higher, D is expected to be a output node with 0 users to "
f"enable smem reuse between C and D, but got {self.dag_ir.out_degree('D')}")
# Run the passes
self.pass_manager()
# Set the epilogue type
@ -187,7 +195,8 @@ class EVTFrontendBase:
except:
raise ValueError(f"{type(value).__name__} cannot be converted to float.")
name = f"imm_{value}".replace('.', '_')
name = f"imm_{value}_k{self.imm_cnt}".replace('.', '_')
self.imm_cnt += 1
load_node = LoadNode(name)
load_node.tensor = {"tensor": value, "is_constant": True}
self.add_node(load_node)

View File

@ -42,7 +42,7 @@ from cutlass_library import DataType
import cutlass
from cutlass.backend.evt.frontend.frontend_base import EVTFrontendBase
from cutlass.backend.epilogue import relu
from cutlass.backend.epilogue import identity, relu, tanh, sigmoid, silu, hardswish, gelu
from cutlass.backend.library import FunctionalOp
@ -72,10 +72,17 @@ class PythonASTFrontend(EVTFrontendBase, ast.NodeVisitor):
ast.Div: FunctionalOp.Divides,
"maximum": FunctionalOp.Maximum,
"minimum": FunctionalOp.Minimum,
"identity": identity.binding_type,
"relu": relu.binding_type,
"tanh": tanh.binding_type,
"sigmoid": sigmoid.binding_type,
"silu": silu.binding_type,
"hardswish": hardswish.binding_type,
"gelu": gelu.binding_type,
"multiply_add": FunctionalOp.MultiplyAdd,
"sum": (FunctionalOp.Plus, FunctionalOp.AtomicAdd),
"max": (FunctionalOp.Maximum, FunctionalOp.AtomicMaximum)
"max": (FunctionalOp.Maximum, FunctionalOp.AtomicMaximum),
"exp": FunctionalOp.Exp
}
return mapping[op]

View File

@ -38,7 +38,9 @@ import networkx as nx
from cutlass_library import DataType
from cutlass.backend.evt.ir.compute_nodes import ComputeNode
from cutlass.backend.evt.ir.node import NodeBase
from cutlass.backend.library import ActivationOp
from cutlass.backend.utils import device_cc
@ -59,6 +61,8 @@ class DAGIR:
self.cc = cc
self.identity_counter = 0
#
# IR manipulator
#
@ -79,7 +83,21 @@ class DAGIR:
raise SyntaxError(f"Variable '{src}' is undefined.")
if not self.has_node(dst):
raise SyntaxError(f"Variable '{dst}' is undefined.")
self._graph.add_edge(src, dst, weight=weight)
if self._graph.has_edge(src, dst):
# The DiGraph doesn't support multiple edges between two nodes
# We insert an identity node in such case as a workaround
identity_name = f"autogen_identity_{self.identity_counter}"
self.identity_counter += 1
compute_node = ComputeNode(
name=identity_name, fn=ActivationOp.Identity,
element_output=self.element_compute,
element_compute=self.element_compute)
self.add_node(compute_node)
self.add_edge(src, identity_name, 0)
self.add_edge(identity_name, dst, weight)
else:
self._graph.add_edge(src, dst, weight=weight)
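# Illustrative note (assumed example): an epilogue such as
#
#     def example_epilogue(accum):
#         return accum * accum
#
# feeds `accum` into the same multiply node twice; networkx's DiGraph would collapse
# the duplicate edge, so the second edge is rerouted through an autogen identity
# compute node by the branch above.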
def remove_node(self, node: str):
"""

View File

@ -51,15 +51,19 @@ class Tensor:
"""
The tensor abstracts the data type
"""
def __init__(self, tensor=None, element=None, shape=None, layout_tag=None, is_constant=False) -> None:
def __init__(self, tensor=None, element=None, shape=None, stride=None, layout_tag=None, is_constant=False) -> None:
if element is not None and tensor is not None:
raise Exception(f"Must not specify both element and tensor")
elif shape is not None and tensor is not None:
raise Exception(f"Must not specify both shape and tensor")
elif layout_tag is not None and tensor is not None:
raise Exception(f"Must not specify both layout_tag and tensor")
elif (element is None or layout_tag is None or shape is None) and (tensor is None) :
raise Exception(f"Must specify one of (element, shape, layout) or (tensor)")
elif (element is None or (layout_tag is None and stride is None) or shape is None) and (tensor is None):
raise Exception(f"Must specify one of (element, shape, layout/stride) or (tensor)")
elif stride is not None and tensor is not None:
raise Exception(f"Must not specify both stride and tensor")
elif stride is not None and layout_tag is not None:
raise Exception(f"Must not specify layout_tag when stride is provided")
if isinstance(tensor, Tensor):
# Directly copy all the attributes
@ -70,10 +74,13 @@ class Tensor:
else:
self.element, layout_tag = get_datatype_and_layout(tensor)
shape = get_tensor_shape(tensor)
if layout_tag == LayoutType.RowMajor:
self.layout = Layout(shape[::-1])
elif layout_tag == LayoutType.ColumnMajor:
self.layout = permutation(Layout(shape), [idx for idx in reversed(range(len(shape)))])
if stride is not None:
self.layout = Layout(shape[::-1], stride[::-1])
else:
if layout_tag == LayoutType.RowMajor:
self.layout = Layout(shape[::-1])
elif layout_tag == LayoutType.ColumnMajor:
self.layout = permutation(Layout(shape), [idx for idx in reversed(range(len(shape)))])
self.layout = canonicalization(self.layout)
self.is_constant = is_constant

View File

@ -77,11 +77,12 @@ class PassDAG2Tree(EVTPassBase):
reachable_nodes.append(set(self.dag_ir.all_reachable_nodes(parent)))
# get the common reachable objects
common_items = set.intersection(*reachable_nodes)
node_to_fuse = set.union(*reachable_nodes).difference(common_items)
lca = None
# If common ancestor exists, find the lowest one
if len(common_items) > 0:
topo_order = self.dag_ir.nodes_topological_order()
lca = None
topo_idx = -1
for item in common_items:
if lca is None:
@ -91,53 +92,74 @@ class PassDAG2Tree(EVTPassBase):
if topo_idx > topo_order.index(item):
lca = item
topo_idx = topo_order.index(item)
# The lca is the output node of the DAG node
# Get the nodes to be fused
node_to_fuse = set.union(*reachable_nodes).difference(common_items)
node_to_fuse.add(lca)
# Get all the input nodes
all_input_nodes = []
all_output_nodes = []
for node in node_to_fuse:
all_input_nodes.append(set(self.dag_ir.get_all_inputs(node)))
all_output_nodes.append(set(self.dag_ir.get_users(node)))
all_input_nodes = set.union(*all_input_nodes)
all_output_nodes = set.union(*all_output_nodes)
new_subgraph_nodes = set.union(node_to_fuse, all_input_nodes, all_output_nodes)
# Create the subgraph
subgraph_ = self.dag_ir._graph.subgraph(new_subgraph_nodes)
subgraph = DAGIR(self.dag_ir.cc)
for node in subgraph_.nodes:
meta = deepcopy(self.dag_ir.get_node_meta(node))
if node not in node_to_fuse:
meta.disabled = True
subgraph.add_node(meta)
for edge in subgraph_.edges:
subgraph.add_edge(edge[0], edge[1], self.dag_ir.get_edge_weight(edge[0], edge[1]))
# Create the fused node
dag_node = TopoVisitorNode(
name=f"dag_{lca}", subgraph=subgraph,
output_node=self.dag_ir.get_node_meta(lca))
self.dag_ir.add_node(dag_node)
# Add input edges
for idx, node in enumerate(all_input_nodes):
self.dag_ir.add_edge(node, dag_node.name, weight=idx)
# Replace all uses with DAG node (only 1 output node)
self.dag_ir.replace_all_uses_with(lca, dag_node.name)
# Remove all fused nodes
node_to_fuse.remove(lca)
for node in node_to_fuse:
self.dag_ir.remove_node(node)
else:
raise NotImplementedError("No LCA found. Consider SplitTreeVisitor.")
# There is no common ancestor for all the parents, so we pack all the reachable
# nodes into a single DAG node as a fallback. The lca is taken to be the input node
# of one of the output nodes with out_degree = 0.
potential_output_nodes = []
for node in node_to_fuse:
if self.dag_ir.out_degree(node) == 0:
potential_output_nodes.append(node)
if len(potential_output_nodes) == 0:
raise RuntimeError(f"No output node with out degree = 0 found.")
output_node = None
if (self.dag_ir.cc >= 90):
# For SM90, the lca should be the input node of D
if (not self.dag_ir.has_node("D")):
raise RuntimeError(f"D is not a node in the DAG IR.")
output_node = "D"
else:
output_node = potential_output_nodes[0]
if (output_node is None):
raise RuntimeError(f"No output node found.")
lca = self.dag_ir.get_all_inputs(output_node)[0]
node_to_fuse.remove(output_node)
# The lca is the output node of the DAG node
# Get the nodes to be fused
node_to_fuse.add(lca)
# Get all the input nodes
all_input_nodes = []
all_output_nodes = []
for node in node_to_fuse:
all_input_nodes.append(set(self.dag_ir.get_all_inputs(node)))
all_output_nodes.append(set(self.dag_ir.get_users(node)))
all_input_nodes = set.union(*all_input_nodes)
all_output_nodes = set.union(*all_output_nodes)
new_subgraph_nodes = set.union(node_to_fuse, all_input_nodes, all_output_nodes)
# Create the subgraph
subgraph_ = self.dag_ir._graph.subgraph(new_subgraph_nodes)
subgraph = DAGIR(self.dag_ir.cc)
for node in subgraph_.nodes:
meta = deepcopy(self.dag_ir.get_node_meta(node))
if node not in node_to_fuse:
meta.disabled = True
subgraph.add_node(meta)
for edge in subgraph_.edges:
subgraph.add_edge(edge[0], edge[1], self.dag_ir.get_edge_weight(edge[0], edge[1]))
# Create the fused node
dag_node = TopoVisitorNode(
name=f"dag_{lca}", subgraph=subgraph,
output_node=self.dag_ir.get_node_meta(lca))
self.dag_ir.add_node(dag_node)
# Add input edges
for idx, node in enumerate(all_input_nodes):
self.dag_ir.add_edge(node, dag_node.name, weight=idx)
# Replace all uses with DAG node (only 1 output node)
self.dag_ir.replace_all_uses_with(lca, dag_node.name)
# Remove all fused nodes
node_to_fuse.remove(lca)
for node in node_to_fuse:
self.dag_ir.remove_node(node)
def ensures(self) -> None:
# Ensure that after the pass, the resulting DAG becomes a tree

View File

@ -118,6 +118,7 @@ class FunctionalOp(enum.Enum):
Multiplies = enum_auto()
MultiplyAdd = enum_auto()
Plus = enum_auto()
Exp = enum_auto()
FunctionalOpTag = {
@ -130,6 +131,7 @@ FunctionalOpTag = {
FunctionalOp.Multiplies: "cutlass::multiplies",
FunctionalOp.MultiplyAdd: "cutlass::multiply_add",
FunctionalOp.Plus: "cutlass::plus",
FunctionalOp.Exp: "cutlass::fast_exp_op",
}

View File

@ -52,4 +52,5 @@ from cutlass.epilogue.evt_ops import (
reshape,
maximum,
minimum,
exp
)

View File

@ -73,6 +73,12 @@ def minimum(x, y):
elif is_torch_tensor(x):
return torch.minimum(x, torch.tensor(y))
def exp(x):
if is_numpy_tensor(x):
return np.exp(x)
elif is_torch_tensor(x):
return torch.exp(x)
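# Example (illustrative): with `exp` available as a reference op here and mapped to
# FunctionalOp.Exp in the EVT frontend, an epilogue definition can call it by name:
#
#     def example_evt(accum, alpha):
#         D = exp(alpha * accum)
#         return D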
##############################################################################
# Layout manipulate nodes

View File

@ -297,7 +297,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
sm100_mma_data_type_general = [
'gemm_f16_f16_f16_f16_f16',
'gemm_f16_f16_f16_void_f16',
'gemm_f16_f16_f32_f16_f16',
#'gemm_f16_f16_f32_f16_f16',
'tf32gemm_f32_f32_f32_f32_f32',
'bf16gemm_f32_f32_f32_f32_f32',
]
@ -336,7 +336,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
'gemm.*ue4m3xf4_ue4m3xf4_f32_f16_e5m2',
'gemm.*ue8m0xf4_ue8m0xf6_f32_f16_e5m2',
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1',
#'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1',
'gemm.*ue8m0xf6_ue8m0xf6_f32_f16_ue8m0xe3m2',
]
@ -547,7 +547,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
if dynamic_cluster:
if mode == "functional_L0":
runtime_cluster_shapes = [[1,1,1], [2,1,1], [2,2,1], [4,1,1], [4,4,1]]
runtime_cluster_shapes = [[1,1,1], [2,2,1]]
else:
runtime_cluster_shapes = [[1,1,1], [1,2,1], [2,1,1], [2,2,1], [1,4,1], [4,1,1], [2,4,1], [4,2,1], [4,4,1]]
cta_tile_shape_m, cta_tile_shape_n, cta_tile_shape_k = operation.tile_description.threadblock_shape