v4.1 release update v2. (#2481)

This commit is contained in:
Junkai-Wu
2025-07-22 10:03:55 +08:00
committed by GitHub
parent 9baa06dd57
commit fd6cfe1ed0
179 changed files with 7878 additions and 1286 deletions

View File

@ -300,6 +300,21 @@ def if_executor(
class range:
"""
A range-like object for dynamic loop iteration in the DSL.
This class provides a range interface similar to Python's built-in range,
but is designed to be preprocessed into constructs for dynamic
loop execution.
The class supports both single-argument (stop) and three-argument
(start, stop, step) constructors with additional parameters for loop
optimization:
- unroll: Number of iterations to unroll (0 or 1 = no unrolling)
- unroll_full: Whether to fully unroll the loop
- pipelining: Compiler-generated pipeline configuration
"""
@overload
def __new__(cls, stop, unroll=0, unroll_full=False, pipelining=None):
pass
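A hedged usage sketch of the loop flavors this class enables inside a `@cute.jit` function (bounds and bodies are illustrative placeholders):
for i in cutlass.range(n):                      # dynamic trip count, lowered to an IR loop
    ...
for i in cutlass.range(0, n, 2, unroll=4):      # dynamic loop, unrolled by a factor of 4
    ...
for i in cutlass.range(16, unroll_full=True):   # fully unrolled
    ...
for i in cutlass.range_constexpr(16):           # unrolled in Python during preprocessing
    ...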
@ -460,7 +475,31 @@ def range_value_check(*args):
Ensure all `range_constexpr` bounds are compile-time constants (Python ints).
"""
try:
return tuple(arg.__index__() for arg in args)
args = tuple(arg.__index__() for arg in args)
# Compute range size and warn if it's too large
start = 0
end = 0
step = 1
if len(args) == 1:
end = args[0]
elif len(args) == 2:
start = args[0]
end = args[1]
elif len(args) == 3:
start = args[0]
end = args[1]
step = args[2]
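# Closed-form trip count; e.g. range(0, 100, 3) gives (abs(100 - 0) - 1) // 3 + 1 = 34 iterations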
range_length = (abs(end - start) - 1) // abs(step) + 1
if range_length >= 64:
warnings.warn(
f"This static loop has {range_length} iterations, which may be very slow to compile, consider using `cutlass.range(..., unroll_full=True)` instead.",
category=UserWarning,
stacklevel=2,
)
return (start, end, step)
except:
raise DSLRuntimeError(
"`range_constexpr` requires constexpr (compile-time constant) for all arguments.",
@ -477,8 +516,8 @@ def range_perf_warning(filename, lineno, *args):
if not has_dynamic_expr:
warnings.warn_explicit(
(
"The loop was previously unrolled in Python, but now it may not unroll in IR. This may cause performance regression."
"If you want to unroll the loop in Python, please use `range_constexpr` instead of `range`."
"This loop is no longer unrolled and may cause performance regression. "
"Use `range(..., unroll_full=True)` for full unrolling, or switch to `range_constexpr` when bounds are compile-time constants."
),
category=UserWarning,
filename=filename,

View File

@ -102,6 +102,8 @@ class ScopeManager:
return cls([])
def add_to_scope(self, name: str) -> None:
if name == "_":
return
self.scopes[-1].add(name)
def get_active_symbols(self) -> List[Set[str]]:
@ -361,13 +363,13 @@ class DSLPreprocessor(ast.NodeTransformer):
isinstance(func, ast.Name)
and func.id in self.SUPPORTED_FOR_RANGE_STATEMENTS
):
return func.id, True
return func.id, True, len(iter_node.keywords) != 0
if (
isinstance(func, ast.Attribute)
and func.attr in self.SUPPORTED_FOR_RANGE_STATEMENTS
):
return func.attr, False
return None, None
return func.attr, False, len(iter_node.keywords) != 0
return None, None, None
def transform(self, original_function, exec_globals):
"""
@ -378,6 +380,7 @@ class DSLPreprocessor(ast.NodeTransformer):
transformed_tree = self.transform_function(
original_function.__name__, original_function
)
self.function_globals = None
unified_tree = ast.Module(body=transformed_tree, type_ignores=[])
unified_tree = ast.fix_missing_locations(unified_tree)
@ -731,7 +734,7 @@ class DSLPreprocessor(ast.NodeTransformer):
self.scope_manager.add_to_scope(node.target.id)
# For static for-loops (range_constexpr-based or non-range-based iteration), the preprocessor keeps the loop.
range_kind, is_builtin_range = self._get_range_kind(node.iter)
range_kind, is_builtin_range, has_keyword = self._get_range_kind(node.iter)
if range_kind == "range_constexpr" or range_kind == None:
self.generic_visit(node)
if range_kind == "range_constexpr":
@ -752,7 +755,7 @@ class DSLPreprocessor(ast.NodeTransformer):
warnings.simplefilter("default", DeprecationWarning) # reset filter
warning_call = None
if range_kind == "range" and is_builtin_range:
if range_kind == "range" and is_builtin_range and not has_keyword:
# Warn about possible performance regression due to behavior change
warning_call = ast.Expr(
ast.Call(
@ -1109,6 +1112,12 @@ class DSLPreprocessor(ast.NodeTransformer):
self.generic_visit(node)
return node
def visit_Name(self, node):
self.generic_visit(node)
if node.id == "_" and isinstance(node.ctx, ast.Load):
raise DSLAstPreprocessorError("Reading '_' is not allowed")
return node
def check_decorator(self, node: ast.AST) -> bool:
"""
Check if the function has the correct decorator for preprocessing.

View File

@ -19,7 +19,9 @@ from typing import Sequence, Optional, Tuple
import os
import sys
import inspect
import argparse
from .common import DSLRuntimeError
from .utils.logger import log
_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
@ -182,7 +184,67 @@ class Compiler:
return self.jit(module, opt_level, shared_libs)
class CompileOptions:
def __init__(self, options: str = ""):
"""
This class encapsulates all compilation options relevant to function compilation.
Centralizing them in one place provides a convenient way to manage and pass
compilation options, and keeps parameters such as the optimization level and
device-assertion control configured consistently.
:param options: The compilation options as a command-line style string, parsed with argparse.
:type options: str
"""
if not isinstance(options, str):
raise DSLRuntimeError(
f"Invalid compilation `options`: {options}, it should be a string"
)
self._parser = argparse.ArgumentParser()
self._parser.add_argument("--opt-level", nargs="?", type=int, default=3)
self._parser.add_argument(
"--enable-device-assertions", action="store_true", default=False
)
try:
self._options = self._parser.parse_args(options.split())
except SystemExit as e:
# catch argparse error and raise as DSLRuntimeError
raise DSLRuntimeError(
f"Invalid compile options: '{options}'. Please check the option values and format."
)
log().info("`cute.compile` CompileOptions: options=" + options)
def to_str(self):
"""
Generate a string representation of all compilation options
which will be used in pipeline options.
"""
option_strings = []
for key, value in vars(self._options).items():
hyphen_key = key.replace("_", "-")
if isinstance(value, bool):
formatted_value = "true" if value else "false"
else:
formatted_value = str(value)
option_strings.append(f"{hyphen_key}={formatted_value}")
return " ".join(option_strings)
def compile(func, *args, **kwargs):
"""
This function is used to compile a `cute.jit` decorated function.
It processes the compile options and input parameters, performs explicit compilation, and returns the jit executor.
:param func: The function to compile. It can be a regular function, a method or a class instance.
:param args: The arguments to pass to the function.
:param kwargs: The keyword arguments to pass to the function. They may also contain an `options`
string (e.g. `--opt-level`) to control the compilation flags.
:return: The jit executor.
:raises: DSLRuntimeError if the function is not decorated with `cute.jit` or is not callable.
"""
if func is None:
raise DSLRuntimeError("Function is not set or invalid.")
@ -217,5 +279,8 @@ def compile(func, *args, **kwargs):
if not hasattr(func, "_dsl_object"):
raise DSLRuntimeError("Function is not decorated with jit decorator.")
# process compile options, extract the options and remove them from the kwargs
options = kwargs.pop("options", "")
func._dsl_object.compile_options = CompileOptions(options)
fcn_ptr = func._dsl_object._preprocess_and_execute(func)
return func._dsl_object._func(fcn_ptr, *args, **kwargs)
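A hedged end-to-end sketch of explicit compilation (the jitted function and its arguments are placeholders; only the `options` handling mirrors the code above):
import cutlass
import cutlass.cute as cute

@cute.jit
def my_func(a, b):
    ...  # placeholder body

a, b = 1.0, 2.0  # placeholder arguments
# `options` is popped from kwargs and parsed by CompileOptions' argparse parser;
# the call returns the jit executor for later launches.
executor = cute.compile(my_func, a, b, options="--opt-level 2 --enable-device-assertions")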

View File

@ -38,6 +38,7 @@ import warnings
from . import typing as t
from .env_manager import EnvironmentVarManager
from .compiler import CompileOptions
# =============================================================================
# CUDA Python
@ -232,6 +233,50 @@ def new_from_mlir_values(obj, values):
return obj
class DSLCallable:
"""
Wrapper class for a callable object used within the DSL.
DSLCallable is designed to wrap a function and provide additional
introspection utilities such as retrieving the argument specification
and signature. It ensures that the wrapped function can only be called
once, after which the reference to the function is cleared to prevent
further invocations. This is useful in scenarios where a function should
only be executed a single time within the DSL's execution model.
Attributes:
func (callable): The function to be wrapped and managed.
Methods:
__call__(*args, **kwargs): Calls the wrapped function and clears it.
get_arg_spec(): Returns the argument specification of the function.
get_signature(): Returns the signature of the function.
"""
def __init__(self, func):
self.func = func
def __call__(self, *args, **kwargs):
ret = self.__func__(*args, **kwargs)
self.func = None
return ret
@property
def __func__(self):
assert self.func is not None, "DSLCallable is already called"
return self.func
@property
def __name__(self):
return self.__func__.__name__
def get_arg_spec(self):
return inspect.getfullargspec(self.__func__)
def get_signature(self):
return inspect.signature(self.__func__)
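For illustration, the call-once contract works like this (plain-Python sketch):
def build():
    return "lowered"

wrapped = DSLCallable(build)
sig = wrapped.get_signature()   # inspect.Signature of build
result = wrapped()              # "lowered"; the wrapped reference is then cleared
# wrapped()                     # would now fail the "DSLCallable is already called" assertion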
class BaseDSL:
gpu_module = None
@ -306,6 +351,8 @@ class BaseDSL:
self.kernel_symbols = []
# used to generate unique name for gpu.launch
self.launch_inner_count = 0
# initialize default compile options
self.compile_options = CompileOptions()
if preprocess:
self.preprocessor = DSLPreprocessor()
@ -392,26 +439,24 @@ class BaseDSL:
if hasattr(func, "_transformed_ast"):
# If the function ptr is already materialized, use the existing one
func._dsl_object.frame = func._decorator_frame
if func._transformed_ast is None:
func._transformed_ast = func._dsl_object.run_preprocessor(func)
if func._transformed_ast is None:
del func._decorator_frame
del func._transformed_ast
func._dsl_object.frame = None
return func
fcn_ptr = func._dsl_object.get_function_ptr(func, func._transformed_ast)
fcn_ptr = func._dsl_object.get_function_ptr(func)
# If the function is decorated, de-decorate it
fcn_ptr = BaseDSL._get_original_function(fcn_ptr, func.__name__)
return fcn_ptr
func._dsl_object.frame = None
return DSLCallable(fcn_ptr)
return func
def jit_runner(self, frame, executor, *dargs, **dkwargs):
def jit_runner(self, executor, frame, *dargs, **dkwargs):
"""
Decorator to mark a function for JIT compilation.
"""
# Set the frame so that it can be used by the AST preprocessor
self.frame = frame
log().info("jit_runner")
def jit_runner_decorator(func):
@ -444,7 +489,7 @@ class BaseDSL:
frame = inspect.currentframe().f_back
# Instantiate the DSL Class
main_dsl = cls._get_dsl()
return main_dsl.jit_runner(frame, main_dsl._func, *dargs, **dkwargs)
return main_dsl.jit_runner(main_dsl._func, frame, *dargs, **dkwargs)
@classmethod
def kernel(cls, *dargs, **dkwargs):
@ -454,7 +499,7 @@ class BaseDSL:
frame = inspect.currentframe().f_back
# Instantiate the DSL Class
main_dsl = cls._get_dsl()
return main_dsl.jit_runner(frame, main_dsl._kernel_helper, *dargs, **dkwargs)
return main_dsl.jit_runner(main_dsl._kernel_helper, frame, *dargs, **dkwargs)
@abstractmethod
def _kernel_helper(self, func, *args, **kwargs):
@ -627,6 +672,12 @@ class BaseDSL:
pass
@abstractmethod
def _get_module_globals(self):
"""
Get the module's globals.
"""
pass
def _get_globals(self):
"""
Combines global and local variables from the current context and the
@ -639,7 +690,11 @@ class BaseDSL:
The AST preprocessor generates new Python code, and the resulting globals
dictionary is used to execute that code.
"""
pass
all_globals = self._get_module_globals().copy()
if self.frame:
all_globals.update(self.frame.f_globals)
all_globals.update(self.frame.f_locals)
return all_globals
def _is_tensor_descriptor(self, maybe_tensor_descriptor) -> bool:
return isinstance(
@ -881,20 +936,15 @@ class BaseDSL:
Get python location information and generate MLIR location
"""
frame = self.frame
if frame is None:
print("Frame is None")
if self.frame is None:
log().debug("Frame is None")
return None
file_loc = ir.Location.file(frame.f_code.co_filename, frame.f_lineno, 0)
file_loc = ir.Location.file(
self.frame.f_code.co_filename, self.frame.f_lineno, 0
)
def print_all_frames():
for i, frame in enumerate(inspect.stack()):
print(
f"Frame {i}: {frame.function} in {frame.filename}, line {frame.lineno}"
)
loc = ir.Location.name(frame.f_code.co_name, childLoc=file_loc)
loc = ir.Location.name(self.frame.f_code.co_name, childLoc=file_loc)
return loc
def compile_and_jit(self, module, pipeline, shared_libs, function_name=""):
@ -992,6 +1042,8 @@ class BaseDSL:
for attr, value in self.envar.__dict__.items():
if value is not None:
s.write(str(value).encode())
# Add compile options to the hash
s.write(self.compile_options.to_str().encode())
module_hash = self.get_version().copy()
module_hash.update(s.getvalue())
module_hash = module_hash.hexdigest()
@ -1145,6 +1197,8 @@ class BaseDSL:
self.launch_inner_count = 0
# reset num_kernels to 0 for next compilation.
self.num_kernels = 0
# reset the compile options after the compilation is done.
self.compile_options = CompileOptions()
def generate_mlir(
self,
@ -1226,9 +1280,11 @@ class BaseDSL:
return transformed_ast
return None
def get_function_ptr(self, original_function, transformed_ast):
def get_function_ptr(self, original_function):
file_name = inspect.getsourcefile(original_function)
code_object = compile(transformed_ast, filename=file_name, mode="exec")
code_object = compile(
original_function._transformed_ast, filename=file_name, mode="exec"
)
return self.preprocessor.exec(
original_function.__name__,
original_function,
@ -1236,10 +1292,6 @@ class BaseDSL:
self._get_globals(),
)
@lru_cache(maxsize=None)
def _get_function_signature(self, func):
return inspect.signature(func)
def _get_function_bound_args(self, sig, func_name, *args, **kwargs):
"""
Binds provided arguments to a function's signature and applies default values.
@ -1260,12 +1312,11 @@ class BaseDSL:
)
return bound_args
def _canonicalize_args(self, *args, **kwargs):
def _canonicalize_args(self, sig, *args, **kwargs):
"""
Canonicalize the input arguments so that returned args only contain
positional arguments and kwargs only contain keyword arguments.
"""
sig = self._get_function_signature(self.funcBody)
function_name = self.funcBody.__name__
bound_args = self._get_function_bound_args(sig, function_name, *args, **kwargs)
canonicalized_args = bound_args.args
@ -1276,8 +1327,11 @@ class BaseDSL:
if not self.funcBody:
raise DSLRuntimeError("Function body is not set.")
# Pass the actual function object to _get_function_signature.
sig = self._get_function_signature(self.funcBody)
# Pass the actual function object to inspect.signature to get the signature.
if isinstance(self.funcBody, DSLCallable):
sig = self.funcBody.get_signature()
else:
sig = inspect.signature(self.funcBody)
function_name = self.funcBody.__name__
bound_args = self._get_function_bound_args(sig, function_name, *args, **kwargs)
@ -1292,6 +1346,8 @@ class BaseDSL:
f"Missing required argument in `{function_name}`: '{param.name}'"
)
return sig
def _func(self, funcBody, *args, **kwargs):
"""Decorator for MLIR functions.
It cuts the boilerplate code, does the following:
@ -1324,13 +1380,16 @@ class BaseDSL:
self.print_warning("Cache is disabled as user wants to compile only.")
# Check the number of arguments
self._check_arg_count(*args, **kwargs)
sig = self._check_arg_count(*args, **kwargs)
args_spec = inspect.getfullargspec(funcBody)
if isinstance(funcBody, DSLCallable):
args_spec = funcBody.get_arg_spec()
else:
args_spec = inspect.getfullargspec(funcBody)
# Canonicalize the input arguments
canonicalized_args, canonicalized_kwargs = self._canonicalize_args(
*args, **kwargs
sig, *args, **kwargs
)
# Simple name mangling
@ -1528,7 +1587,10 @@ class BaseDSL:
kernelGenHelper = dkwargs.get("kernelGenHelper", None)
kernel_name = funcBody.__name__
args_spec = inspect.getfullargspec(funcBody)
if isinstance(funcBody, DSLCallable):
args_spec = funcBody.get_arg_spec()
else:
args_spec = inspect.getfullargspec(funcBody)
self.funcBody = funcBody
# Give each kernel a unique name. (The same kernel may be
@ -1568,11 +1630,11 @@ class BaseDSL:
), "kernelGenHelper should be explicitly specified!"
# check arguments
self._check_arg_count(*args, **kwargs)
sig = self._check_arg_count(*args, **kwargs)
# Canonicalize the input arguments
canonicalized_args, canonicalized_kwargs = self._canonicalize_args(
*args, **kwargs
sig, *args, **kwargs
)
kernel_operands, kernel_types, kernel_arg_attrs = (

View File

@ -527,7 +527,16 @@ class IntegerMeta(NumericMeta):
return 2**cls.width - 1
def recast_width(cls, width):
return eval(f"Int{width}")
type_map = {
8: Int8,
16: Int16,
32: Int32,
64: Int64,
128: Int128,
}
if width not in type_map:
raise TypeError(f"Unsupported width: {width}")
return type_map[width]
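For illustration, the table-based lookup behaves as follows (a sketch of the new behavior):
assert Int32.recast_width(64) is Int64
# Int32.recast_width(24) now raises TypeError("Unsupported width: 24")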
class FloatMeta(NumericMeta):
@ -603,7 +612,14 @@ class FloatMeta(NumericMeta):
return cls._mantissa_width
def recast_width(cls, width):
return eval(f"Float{width}")
type_map = {
16: Float16,
32: Float32,
64: Float64,
}
if width not in type_map:
raise TypeError(f"Unsupported width: {width}")
return type_map[width]
def _arith_signless_to_int(a, target_type):

View File

@ -118,6 +118,9 @@ from .core import (
make_tiled_copy,
make_tiled_copy_S,
make_tiled_copy_D,
make_tiled_copy_A,
make_tiled_copy_B,
make_tiled_copy_C,
make_tiled_copy_C_atom,
basic_copy,
basic_copy_if,

View File

@ -90,6 +90,7 @@ __all__ = [
#
"alloc_smem",
"get_dyn_smem",
"get_dyn_smem_size",
#
# tmem.py
#

View File

@ -26,7 +26,19 @@ from cutlass._mlir.dialects.nvvm import (
RoundingModeKind,
)
from ..typing import Int, Boolean, Int32, Float32, Numeric, as_numeric
from ..typing import (
Int,
Boolean,
Int16,
Uint16,
Int32,
Uint32,
Int64,
Float32,
BFloat16,
Numeric,
as_numeric,
)
WARP_SIZE = 32
FULL_MASK = 0xFFFFFFFF
@ -190,19 +202,97 @@ def shuffle_sync_op(
"""
if not isinstance(value, Numeric):
value = as_numeric(value)
return type(value)(
nvvm.shfl_sync(
type(value).mlir_type,
if value.width > 64:
raise ValueError("shuffle_sync only supports values up to 64 bits")
orig_type = type(value)
if value.width < 32:
if value.dtype.is_float:
value = value.to(Float32)
else:
if value.signed:
value = value.to(Int32)
else:
value = value.to(Uint32)
return orig_type(
nvvm.shfl_sync(
type(value).mlir_type,
Int32(mask).ir_value(loc=loc, ip=ip),
value.ir_value(loc=loc, ip=ip),
Int32(offset).ir_value(loc=loc, ip=ip),
Int32(mask_and_clamp).ir_value(loc=loc, ip=ip),
kind,
loc=loc,
ip=ip,
)
)
elif value.width == 32:
return orig_type(
nvvm.shfl_sync(
type(value).mlir_type,
Int32(mask).ir_value(loc=loc, ip=ip),
value.ir_value(loc=loc, ip=ip),
Int32(offset).ir_value(loc=loc, ip=ip),
Int32(mask_and_clamp).ir_value(loc=loc, ip=ip),
kind,
loc=loc,
ip=ip,
)
)
else:
if value.width != 64:
raise ValueError(
"shuffle_sync only supports 64 bits values when the bit width is larger than 32"
)
value = llvm.bitcast(
T.i64(), value.to(ir.Value, loc=loc, ip=ip), loc=loc, ip=ip
)
# extract low 32 bits
low_32_bits = llvm.trunc(
T.i32(), value, llvm.IntegerOverflowFlags.none, loc=loc, ip=ip
)
# extract high 32 bits
high_32_bits = llvm.lshr(
value, Int64(32).ir_value(loc=loc, ip=ip), loc=loc, ip=ip
)
high_32_bits = llvm.trunc(
T.i32(), high_32_bits, llvm.IntegerOverflowFlags.none, loc=loc, ip=ip
)
low_32_bits_shfl = nvvm.shfl_sync(
T.i32(),
Int32(mask).ir_value(loc=loc, ip=ip),
low_32_bits,
Int32(offset).ir_value(loc=loc, ip=ip),
Int32(mask_and_clamp).ir_value(loc=loc, ip=ip),
kind,
loc=loc,
ip=ip,
)
high_32_bits_shfl = nvvm.shfl_sync(
T.i32(),
Int32(mask).ir_value(loc=loc, ip=ip),
high_32_bits,
Int32(offset).ir_value(loc=loc, ip=ip),
Int32(mask_and_clamp).ir_value(loc=loc, ip=ip),
kind,
loc=loc,
ip=ip,
)
# combine low and high 32 bits
low_64_bit = llvm.zext(T.i64(), low_32_bits_shfl, loc=loc, ip=ip)
high_64_bit = llvm.zext(T.i64(), high_32_bits_shfl, loc=loc, ip=ip)
shfl_res = llvm.shl(
high_64_bit,
Int64(32).ir_value(loc=loc, ip=ip),
llvm.IntegerOverflowFlags.none,
loc=loc,
ip=ip,
)
shfl_res = llvm.or_(shfl_res, low_64_bit, loc=loc, ip=ip)
shfl_res = llvm.bitcast(orig_type.mlir_type, shfl_res, loc=loc, ip=ip)
return orig_type(shfl_res)
shuffle_sync = partial(shuffle_sync_op, kind=nvvm.ShflKind.idx)
shuffle_sync_up = partial(shuffle_sync_op, kind=nvvm.ShflKind.up)

View File

@ -94,3 +94,15 @@ def get_dyn_smem(
alignment,
)
return _cute_nvgpu_ir.arch_get_dyn_smem(ptr=ptr_ty, loc=loc, ip=ip)
@dsl_user_op
def get_dyn_smem_size(*, loc=None, ip=None) -> int:
"""
Gets the size in bytes of the dynamic shared memory that was specified at kernel launch time.
This can be used for bounds checking during shared memory allocation.
:return: The size of dynamic shared memory in bytes
:rtype: int
"""
return _cute_nvgpu_ir.arch_get_dyn_smem_size(loc=loc, ip=ip)
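A hedged sketch of the bounds check this enables inside a kernel (the storage type and required size are illustrative; the function is assumed to be exported alongside get_dyn_smem):
# inside a device function
available = get_dyn_smem_size()          # bytes of dynamic smem granted at launch
# e.g. compare against a cute.struct's static footprint before allocating:
# SharedStorage.__sizeof__() <= available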

View File

@ -31,6 +31,7 @@ from typing import (
Optional,
)
from enum import Enum, auto
from typing_extensions import deprecated
from cutlass.cutlass_dsl import (
const,
@ -517,10 +518,14 @@ class ScaledBasis:
sb3 = ScaledBasis(4, [0, 1]) # 4 * E([0, 1])
# Scaled basis elements are commonly used in layout strides
layout = make_layout((4, 8), stride=(ScaledBasis(1, 0), ScaledBasis(1, 1)))
layout = make_layout((4, 8), stride=(ScaledBasis(2, 0), ScaledBasis(1, 1)))
# This creates a layout with strides (1@0, 1@1) representing
# This creates a layout with strides (2@0, 1@1) representing
# a coordinate system where each dimension has its own basis
# Example: Mapping coordinates to indices using the layout
coord = (2, 3)
idx = crd2idx(coord, layout) # Maps (2, 3) to (4, 3)
"""
def __init__(self, value, mode) -> None:
@ -712,8 +717,9 @@ class Swizzle(ir.Value):
e.g. Given
0bxxxxxxxxxxxxxxxxYYxxxxxxxxxZZxxx
the result is
0bxxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxx where AA = ZZ xor YY
0bxxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxx where AA = ZZ `xor` YY
"""
@ -897,7 +903,7 @@ class _Layout(Layout):
@ir.register_value_caster(_cute_ir.ComposedLayoutType.get_static_typeid(), replace=True)
class ComposedLayout(ir.Value):
"""ComposedLayout represents the functional composition of layouts in CuTe.
r"""ComposedLayout represents the functional composition of layouts in CuTe.
A ComposedLayout is formed by the composition of three components:
inner o offset o outer, where:
@ -907,7 +913,10 @@ class ComposedLayout(ir.Value):
- outer: The outer layout that is applied first
ComposedLayout implements the functional composition operation where:
R(c) := (inner o offset o outer)(c) := inner(offset + outer(c))
.. math::
R(c) := (inner \circ offset \circ outer)(c) := inner(offset + outer(c))
This composition allows for complex transformations of coordinates and indices,
enabling operations like tiling, partitioning, and reshaping of data.
@ -1670,7 +1679,10 @@ def print_tensor(tensor: Tensor, *, verbose: bool = False, loc=None, ip=None):
:type ip: insertion pointer, optional
:raises NotImplementedError: If the tensor type doesn't support trivial dereferencing
Example output:
**Example output:**
.. code-block:: text
tensor(raw_ptr<@..., Float32, generic, align(4)> o (8,5):(5,1), data=
[[-0.4326, -0.5434, 0.1238, 0.7132, 0.8042],
[-0.8462, 0.9871, 0.4389, 0.7298, 0.6948],
@ -1973,7 +1985,8 @@ def find_if(
:return: Index if found at top level, tuple of indices showing nested position, or None if not found
:rtype: Union[int, Tuple[int, ...], None]
Examples:
**Examples:**
.. code-block:: python
# Find the first position of x in t
@ -2186,6 +2199,23 @@ def is_congruent(
) -> bool:
"""
Returns whether a is congruent to b.
Congruence is an equivalence relation between hierarchical structures.
Two objects are congruent if:
* They have the same rank, AND
* They are both non-tuple values, OR
* They are both tuples AND all corresponding elements are congruent.
Congruence requires type matching at each level -- scalar values match with
scalar values, and tuples match with tuples of the same rank.
:param a: First object to compare
:type a: Union[XTuple, Layout, ComposedLayout, Tensor]
:param b: Second object to compare
:type b: Union[XTuple, Layout, ComposedLayout, Tensor]
:return: True if a and b are congruent, False otherwise
:rtype: bool
"""
if isinstance(a, (Layout, ComposedLayout, Tensor)):
a = a.shape
@ -2204,6 +2234,22 @@ def is_weakly_congruent(
) -> bool:
"""
Returns whether a is weakly congruent to b.
Weak congruence is a partial order on hierarchical structures.
Object X is weakly congruent to object Y if:
* X is a non-tuple value, OR
* X and Y are both tuples of the same rank AND all corresponding elements are weakly congruent.
Weak congruence allows scalar values to match with tuples, making it useful
for determining whether an object has a hierarchical structure "up to" another.
:param a: First object to compare
:type a: Union[XTuple, Layout, ComposedLayout, Tensor]
:param b: Second object to compare
:type b: Union[XTuple, Layout, ComposedLayout, Tensor]
:return: True if a and b are weakly congruent, False otherwise
:rtype: bool
"""
if isinstance(a, (Layout, ComposedLayout, Tensor)):
a = a.shape
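For illustration, with plain nested tuples the two relations compare as follows (a sketch derived from the definitions above):
is_congruent((1, (2, 3)), (4, (5, 6)))      # True: identical nesting profile
is_congruent(1, (2, 3))                     # False: scalar vs. tuple
is_weakly_congruent(1, (2, 3))              # True: a scalar matches any structure
is_weakly_congruent((1, 2), ((3, 4), 5))    # True: elementwise weakly congruent
is_weakly_congruent(((3, 4), 5), (1, 2))    # False: a tuple cannot match a scalar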
@ -2261,8 +2307,11 @@ def get(input, mode: List[int], *, loc=None, ip=None):
**Examples:**
For a layout like ((4,8),2):((16,1),8), get with mode=[0,1] would extract
the element 8 from the shape component.
.. code-block:: python
layout = make_layout(((4, 8), (16, 1), 8), stride=((1, 4), (32, 0), 512))
sub_layout = get(layout, mode=[0, 1]) # 8:4
sub_layout = get(layout, mode=[1]) # (16, 1):(32, 0)
"""
# Empty mode returns input and terminates the recursive call
if not mode:
@ -5065,6 +5114,11 @@ def make_layout_tv(
* 2 elements per thread
"""
if not isinstance(thr_layout, Layout):
raise TypeError(f"expected a Layout for thr_layout, but got {type(thr_layout)}")
if not isinstance(val_layout, Layout):
raise TypeError(f"expected a Layout for val_layout, but got {type(val_layout)}")
# Take the raked_products to compute the Layout_MN
# (M,N) -> (thr_idx, val_idx)
layout_mn = raked_product(thr_layout, val_layout, loc=loc, ip=ip)
@ -5081,8 +5135,52 @@ def make_layout_tv(
return (tiler_mn, layout_tv)
def _make_tiled_copy(atom, layout_tv, tiler_mn, *, loc=None, ip=None):
if type(tiler_mn) is tuple:
tiler_mn = _pack_tile(tiler_mn, loc=loc, ip=ip)
assert isinstance(tiler_mn, ir.Value) and _cute_ir.TileType.isinstance(
tiler_mn.type
), f"tiler_mn must be a Tile, but got {type(tiler_mn)}"
assert is_static(layout_tv.type) and is_static(
tiler_mn.type
), "layout tv and tiler mn must be static"
tiled_copy_ty = _cute_nvgpu_ir.TiledCopyType.get(
atom.type, layout_tv.type, tiler_mn.type
)
val = _cute_ir.make_tiled_copy(tiled_copy_ty, atom._trait.value, loc=loc, ip=ip)
# Instead of modifying atom which might have been provided by the user, create a brand new
# trait instance and replace the Atom ir.Value with the tiled one
trait = new_from_mlir_values(atom._trait, [val])
return TiledCopy(atom.op, trait)
@deprecated("Use make_tiled_copy_tv instead")
def make_tiled_copy(atom, layout_tv, tiler_mn, *, loc=None, ip=None):
"""Create a tiled type given a TV partitioner and tiler.
:param atom: Copy atom, e.g. simt_copy, simt_async_copy, tma_load, etc.
:type atom: CopyAtom
:param layout_tv: Thread-value layout
:type layout_tv: Layout
:param tiler_mn: Tile size
:type tiler_mn: Tiler
:param loc: Source location for MLIR, defaults to None
:type loc: Optional[Location], optional
:param ip: Insertion point, defaults to None
:type ip: Optional[InsertionPoint], optional
:return: A tiled copy for the partitioner
:rtype: TiledCopy
"""
return _make_tiled_copy(atom, layout_tv, tiler_mn, loc=loc, ip=ip)
@dsl_user_op
def make_tiled_copy_tv(atom, thr_layout, val_layout, *, loc=None, ip=None) -> TiledCopy:
def make_tiled_copy_tv(
atom: CopyAtom, thr_layout: Layout, val_layout: Layout, *, loc=None, ip=None
) -> TiledCopy:
"""Create a tiled copy given separate thread and value layouts.
A TV partitioner is inferred based on the input layouts. The input thread layout
@ -5105,30 +5203,17 @@ def make_tiled_copy_tv(atom, thr_layout, val_layout, *, loc=None, ip=None) -> Ti
tiler_mn, layout_tv = make_layout_tv(thr_layout, val_layout, loc=loc, ip=ip)
tiler_mn = _pack_tile(product_each(tiler_mn, loc=loc, ip=ip), loc=loc, ip=ip)
if not is_static(layout_tv.type) or not is_static(tiler_mn.type):
raise ValueError(
f"expects layout tv and tiler mn, but got {layout_tv.type} and {tiler_mn.type}"
)
tiled_copy_ty = _cute_nvgpu_ir.TiledCopyType.get(
atom.type, layout_tv.type, tiler_mn.type
)
val = _cute_ir.make_tiled_copy(tiled_copy_ty, atom._trait.value, loc=loc, ip=ip)
# Instead of modifying atom which might have been provided by the user, create a brand new
# trait instance and replace the Atom ir.Value with the tiled one
trait = new_from_mlir_values(atom._trait, [val])
return TiledCopy(atom.op, trait)
return _make_tiled_copy(atom, layout_tv, tiler_mn, loc=loc, ip=ip)
@dsl_user_op
def make_tiled_copy(atom, layout_tv, tiler_mn, *, loc=None, ip=None):
"""Create a tiled type given a TV partitioner and tiler.
def make_tiled_copy_A(atom, tiled_mma, *, loc=None, ip=None):
"""Create a tiled copy out of the copy_atom that matches the A-Layout of tiled_mma.
:param atom: Copy atom, e.g. smit_copy and simt_async_copy, tma_load, etc.
:param atom: Copy atom
:type atom: CopyAtom
:param layout_tv: Thread-value layout
:type layout_tv: Layout
:param tiler_mn: Tile size
:type tiler_mn: Tiler
:param tiled_mma: Tiled MMA
:type tiled_mma: TiledMma
:param loc: Source location for MLIR, defaults to None
:type loc: Optional[Location], optional
:param ip: Insertion point, defaults to None
@ -5138,21 +5223,65 @@ def make_tiled_copy(atom, layout_tv, tiler_mn, *, loc=None, ip=None):
:rtype: TiledCopy
"""
# tiler_mn = pack_tuple(tiler_mn, make_tile)
if type(tiler_mn) is tuple:
tiler_mn = _pack_tile(tiler_mn, loc=loc, ip=ip)
assert is_static(layout_tv.type) and is_static(
tiler_mn.type
), "layout tv and tiler mn must be static"
tiled_copy_ty = _cute_nvgpu_ir.TiledCopyType.get(
atom.type, layout_tv.type, tiler_mn.type
return _make_tiled_copy(
atom,
tiled_mma.tv_layout_A_tiled,
(tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(2)),
loc=loc,
ip=ip,
)
@dsl_user_op
def make_tiled_copy_B(atom, tiled_mma, *, loc=None, ip=None):
"""Create a tiled copy out of the copy_atom that matches the B-Layout of tiled_mma.
:param atom: Copy atom
:type atom: CopyAtom
:param tiled_mma: Tiled MMA
:type tiled_mma: TiledMma
:param loc: Source location for MLIR, defaults to None
:type loc: Optional[Location], optional
:param ip: Insertion point, defaults to None
:type ip: Optional[InsertionPoint], optional
:return: A tiled copy for the partitioner
:rtype: TiledCopy
"""
return _make_tiled_copy(
atom,
tiled_mma.tv_layout_B_tiled,
(tiled_mma.get_tile_size(1), tiled_mma.get_tile_size(2)),
loc=loc,
ip=ip,
)
@dsl_user_op
def make_tiled_copy_C(atom, tiled_mma, *, loc=None, ip=None):
"""Create a tiled copy out of the copy_atom that matches the C-Layout of tiled_mma.
:param atom: Copy atom
:type atom: CopyAtom
:param tiled_mma: Tiled MMA
:type tiled_mma: TiledMma
:param loc: Source location for MLIR, defaults to None
:type loc: Optional[Location], optional
:param ip: Insertion point, defaults to None
:type ip: Optional[InsertionPoint], optional
:return: A tiled copy for the partitioner
:rtype: TiledCopy
"""
return _make_tiled_copy(
atom,
tiled_mma.tv_layout_C_tiled,
(tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(1)),
loc=loc,
ip=ip,
)
val = _cute_ir.make_tiled_copy(tiled_copy_ty, atom._trait.value, loc=loc, ip=ip)
# Instead of modifying atom which might have been provided by the user, create a brand new
# trait instance and replace the Atom ir.Value with the tiled one
trait = new_from_mlir_values(atom._trait, [val])
return TiledCopy(atom.op, trait)
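A hedged usage sketch of the new operand-specific helpers (the copy atoms and tiled_mma are assumed to already exist):
copy_A = make_tiled_copy_A(atom_a, tiled_mma)   # matches the A-layout of tiled_mma
copy_B = make_tiled_copy_B(atom_b, tiled_mma)   # matches the B-layout of tiled_mma
copy_C = make_tiled_copy_C(atom_c, tiled_mma)   # matches the C-layout of tiled_mma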
@dsl_user_op
@ -5172,7 +5301,7 @@ def make_tiled_copy_S(atom, tiled_copy, *, loc=None, ip=None):
:rtype: TiledCopy
"""
return make_tiled_copy(
return _make_tiled_copy(
atom, tiled_copy.layout_src_tv_tiled, tiled_copy.tiler_mn, loc=loc, ip=ip
)
@ -5194,7 +5323,7 @@ def make_tiled_copy_D(atom, tiled_copy, *, loc=None, ip=None):
:rtype: TiledCopy
"""
return make_tiled_copy(
return _make_tiled_copy(
atom, tiled_copy.layout_dst_tv_tiled, tiled_copy.tiler_mn, loc=loc, ip=ip
)
@ -5273,7 +5402,7 @@ def make_tiled_copy_C_atom(atom: CopyAtom, mma: TiledMma, *, loc=None, ip=None):
tiler_mn = _pack_tile(tiler, loc=loc, ip=ip)
return make_tiled_copy(atom, layout_tv, tiler_mn, loc=loc, ip=ip)
return _make_tiled_copy(atom, layout_tv, tiler_mn, loc=loc, ip=ip)
####################################################################################################
@ -5297,7 +5426,7 @@ def gemm(
) -> None:
"""The GEMM algorithm.
Computes ``D <- AB + C`` where ``C`` and ``D`` can alias. Note that some MMA Atoms (e.g.
Computes ``D <- A * B + C`` where ``C`` and ``D`` can alias. Note that some MMA Atoms (e.g.
warpgroup-wide or tcgen05 MMAs) require manually setting an "accumulate" boolean field.
All tensors must be partitioned according to the provided MMA Atom.
@ -6416,7 +6545,8 @@ class struct:
"""
Decorator to abstract C structure in Python DSL.
Usage:
**Usage:**
.. code-block::
# Supports base_dsl scalar int/float elements, array and nested struct:
@ -6424,12 +6554,15 @@ class struct:
class complex:
real : cutlass.Float32
imag : cutlass.Float32
@cute.struct
class StorageA:
mbarA : cute.struct.MemRange[cutlass.Int64, stage]
compA : complex
intA : cutlass.Int16
# Supports aligment for its elements:
@cute.struct
class StorageB:
@ -6442,6 +6575,7 @@ class struct:
x: cute.struct.Align[cutlass.Int32, 16]
compA: cute.struct.Align[complex, 16]
# Statically get size and alignment:
size = StorageB.__sizeof__()
align = StorageB.__alignof__()

View File

@ -94,12 +94,6 @@ def make_tiled_tma_atom(
ip=ip,
)
# Wrap smem_layout in a composed layout to make it a TMA-friendly layout
if isinstance(smem_layout, Layout):
smem_layout = core.make_composed_layout(
core.make_swizzle(0, 4, 3), 0, smem_layout
)
if isinstance(op, CopyBulkTensorTileG2SOp):
if num_multicast != 1:
raise ValueError(

View File

@ -37,7 +37,7 @@ from .cpasync.copy import (
def make_tiled_tma_atom_A(
op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp],
gmem_tensor: Tensor,
smem_layout: Layout,
smem_layout: Union[Layout, core.ComposedLayout],
mma_tiler_mnk: Shape,
tiled_mma: core.TiledMma,
cluster_shape_vmnk: Shape,
@ -76,7 +76,7 @@ def make_tiled_tma_atom_A(
:param gmem_tensor: The GMEM tensor to be loaded by this copy atom
:type gmem_tensor: Tensor
:param smem_layout: Shared memory layout to load the tensor into (PDSL)
:type smem_layout: Layout
:type smem_layout: Union[Layout, core.ComposedLayout]
:param mma_tiler_mnk: The MMA Tiler shape (TILE_M, TILE_N, TILE_K) in MNK dimensions
:type mma_tiler_mnk: Shape
:param tiled_mma: The TiledMMA that will consume the load as operands
@ -142,7 +142,7 @@ def make_tiled_tma_atom_A(
def make_tiled_tma_atom_B(
op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp],
gmem_tensor: Tensor,
smem_layout: Layout,
smem_layout: Union[Layout, core.ComposedLayout],
mma_tiler_mnk: Shape,
tiled_mma: core.TiledMma,
cluster_shape_vmnk: Shape,
@ -181,7 +181,7 @@ def make_tiled_tma_atom_B(
:param gmem_tensor: The GMEM tensor to be loaded by this copy atom
:type gmem_tensor: Tensor
:param smem_layout: Shared memory layout to load the tensor into (PDSL)
:type smem_layout: Layout
:type smem_layout: Union[Layout, core.ComposedLayout]
:param mma_tiler_mnk: The MMA Tiler shape (TILE_M, TILE_N, TILE_K) in MNK dimensions
:type mma_tiler_mnk: Shape
:param tiled_mma: The TiledMMA that will consume the load as operands

View File

@ -42,6 +42,9 @@ __all__ = [
"MmaF16BF16Op",
"MmaI8Op",
"MmaFP8Op",
"MmaMXF8Op",
"MmaMXF4Op",
"MmaMXF4NVF4Op",
"SmemLayoutAtomKind",
#
# helpers.py
@ -54,4 +57,6 @@ __all__ = [
"get_tmem_copy_properties",
"find_tmem_tensor_col_offset",
"make_tmem_copy",
"make_s2t_copy",
"get_s2t_smem_desc_tensor",
]

View File

@ -23,6 +23,8 @@ from ..common import OpError
from ...core import CopyOp, Trait
from ...typing import Numeric
from .mma import CtaGroup
class Repetition(enum.Enum):
"""
@ -469,3 +471,193 @@ class St32x32bOp(_StBase):
class St32x32bTrait(Trait):
pass
@dataclass(frozen=True)
class _S2TCopyBase(CopyOp):
cta_group: CtaGroup
admissible_archs = [
"sm_100a",
"sm_100f",
]
def __post_init__(self) -> None:
# Arch verification
arch = CuTeDSL._get_dsl().envar.arch
if arch not in self.admissible_archs:
raise OpError(
self,
f"expects arch to be one of {self.admissible_archs}, but got {arch}",
suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
)
# Verify that the user provided enum values
if not isinstance(self.cta_group, CtaGroup):
raise OpError(
self,
"expects the 'cta_group' Op parameter to be a tcgen05.CtaGroup instance",
)
def __str__(self) -> str:
res = (
f"tcgen05 {self.__class__.__name__[:-2]} Copy Operation"
+ f"\n CTA group = {self.cta_group}"
)
return res
@dataclass(frozen=True)
class Cp128x256bOp(_S2TCopyBase):
"""
128x256b SMEM to TMEM Copy Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=tcgen05#tcgen05-instructions-tcgen05-cp>`__.
This Operation corresponds to the ``.128x256b`` qualifier.
"""
def _make_trait(
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
) -> "Cp128x256bTrait":
ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get(
copy_internal_type.mlir_type,
128,
256,
self.cta_group.value,
_cute_nvgpu_ir.CopyS2TBroadcast.none,
)
return Cp128x256bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
class Cp128x256bTrait(Trait):
pass
@dataclass(frozen=True)
class Cp128x128bOp(_S2TCopyBase):
"""
128x128b SMEM to TMEM Copy Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=tcgen05#tcgen05-instructions-tcgen05-cp>`__.
This Operation corresponds to the ``.128x128b`` qualifier.
"""
def _make_trait(
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
) -> "Cp128x128bTrait":
ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get(
copy_internal_type.mlir_type,
128,
128,
self.cta_group.value,
_cute_nvgpu_ir.CopyS2TBroadcast.none,
)
return Cp128x128bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
class Cp128x128bTrait(Trait):
pass
@dataclass(frozen=True)
class Cp4x256bOp(_S2TCopyBase):
"""
4x256b SMEM to TMEM Copy Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=tcgen05#tcgen05-instructions-tcgen05-cp>`__.
This Operation corresponds to the ``.4x256b`` qualifier.
"""
def _make_trait(
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
) -> "Cp4x256bTrait":
ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get(
copy_internal_type.mlir_type,
4,
256,
self.cta_group.value,
_cute_nvgpu_ir.CopyS2TBroadcast.none,
)
return Cp4x256bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
class Cp4x256bTrait(Trait):
pass
@dataclass(frozen=True)
class Cp4x32x128bOp(_S2TCopyBase):
"""
32x128b SMEM to TMEM Copy Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=tcgen05#tcgen05-instructions-tcgen05-cp>`__.
This Operation corresponds to the ``.32x128b`` qualifier with ``warpx4`` broadcast qualifier enabled.
"""
def _make_trait(
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
) -> "Cp4x32x128bTrait":
ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get(
copy_internal_type.mlir_type,
32,
128,
self.cta_group.value,
_cute_nvgpu_ir.CopyS2TBroadcast.x4,
)
return Cp4x32x128bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
class Cp4x32x128bTrait(Trait):
pass
@dataclass(frozen=True)
class Cp2x64x128b0213Op(_S2TCopyBase):
"""
64x128b SMEM to TMEM Copy Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=tcgen05#tcgen05-instructions-tcgen05-cp>`__.
This Operation corresponds to the ``.64x128b`` qualifier with ``.warpx2::02_13`` broadcast qualifier enabled.
"""
def _make_trait(
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
) -> "Cp2x64x128b0213Trait":
ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get(
copy_internal_type.mlir_type,
64,
128,
self.cta_group.value,
_cute_nvgpu_ir.CopyS2TBroadcast.lw_0213,
)
return Cp2x64x128b0213Trait(_cute_ir.atom(ty, loc=loc, ip=ip))
class Cp2x64x128b0213Trait(Trait):
pass
@dataclass(frozen=True)
class Cp2x64x128b0123Op(_S2TCopyBase):
"""
64x128b SMEM to TMEM Copy Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=tcgen05#tcgen05-instructions-tcgen05-cp>`__.
This Operation corresponds to the ``.64x128b`` qualifier with ``.warpx2::01_23`` broadcast qualifier enabled.
"""
def _make_trait(
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
) -> "Cp2x64x128b0123Trait":
ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get(
copy_internal_type.mlir_type,
64,
128,
self.cta_group.value,
_cute_nvgpu_ir.CopyS2TBroadcast.lw_0123,
)
return Cp2x64x128b0123Trait(_cute_ir.atom(ty, loc=loc, ip=ip))
class Cp2x64x128b0123Trait(Trait):
pass

View File

@ -299,3 +299,30 @@ def make_tmem_copy(
)
new_trait = type(atom._trait)(tiled_copy_val)
return core.TiledCopy(atom.op, new_trait)
@dsl_user_op
def make_s2t_copy(
atom: core.CopyAtom, tmem_tensor: Tensor, *, loc=None, ip=None
) -> core.TiledCopy:
"""
Makes a Tiled Copy instance from a TMEM Copy Atom and a TMEM tensor.
"""
tiled_copy_val = _cute_nvgpu_ir.atom_make_s2t_copy(
atom._trait.value, tmem_tensor.value, loc=loc, ip=ip
)
new_trait = type(atom._trait)(tiled_copy_val)
return core.TiledCopy(atom.op, new_trait)
@dsl_user_op
def get_s2t_smem_desc_tensor(
atom: core.CopyAtom, smem_tensor: Tensor, *, loc=None, ip=None
) -> Tensor:
"""
Returns the SMEM descriptor tensor from a S2T copy atom and a SMEM tensor.
"""
smem_desc_tensor = _cute_nvgpu_ir.atom_get_copy_s2t_smem_desc_view(
atom._trait.value, smem_tensor.value, loc=loc, ip=ip
)
return smem_desc_tensor

View File

@ -20,9 +20,12 @@ import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
from cutlass._mlir import ir
from ..common import OpError
from ...core import MmaOp, Trait, _pack_shape, rank, depth, _Tensor
from ... import core
from ...core import Trait, _pack_shape, rank, depth, _Tensor
from ...typing import (
Shape,
Float4E2M1FN,
Float8E8M0FNU,
Float8E5M2,
Float8E4M3FN,
Float16,
@ -35,6 +38,7 @@ from ...typing import (
Int32,
Numeric,
AddressSpace,
Pointer,
)
@ -104,7 +108,6 @@ class CtaGroup(enum.Enum):
def __repr__(self) -> str:
return f"<{self.__class__.__name__}.{self.name}>"
class Field(enum.Enum):
"""
An enumeration for the fields of the MMA Atom that can be modified at runtime.
@ -113,6 +116,8 @@ class Field(enum.Enum):
NEGATE_A = "neg_a"
NEGATE_B = "neg_b"
ACCUMULATE = "accum_c"
SFA = "sf_a"
SFB = "sf_b"
def __str__(self) -> str:
return f"{self.__class__.__name__}.{self.name}"
@ -124,9 +129,9 @@ class Field(enum.Enum):
return self.value
# Base class for all tcgen05 MMA Ops used to factor out some internal code
# Base class for all tcgen05 MMA Ops with syntax `tcgen05.mma.cta_group.kind` used to factor out some internal code
@dataclass(frozen=True)
class MmaOp(MmaOp):
class MmaOp(core.MmaOp):
a_dtype: Type[Numeric]
b_dtype: Type[Numeric]
acc_dtype: Type[Numeric]
@ -256,6 +261,155 @@ class MmaTrait(Trait):
)
# Base class for all tcgen05 BlockScaled MMA Ops with syntax `tcgen05.mma.cta_group.kind.block_scale` used to factor out some internal code
@dataclass(frozen=True)
class BlockScaledMmaOp(core.MmaOp):
a_dtype: Type[Numeric]
b_dtype: Type[Numeric]
acc_dtype: Float32
sf_dtype: Type[Numeric]
sf_vec_size: int
shape_mnk: Shape
cta_group: CtaGroup
a_src: OperandSource
a_major_mode: OperandMajorMode
b_major_mode: OperandMajorMode
admissible_archs = [
"sm_100a",
]
def __post_init__(self) -> None:
# Verify arch
arch = CuTeDSL._get_dsl().envar.arch
if arch not in self.admissible_archs:
raise OpError(
self,
f"expects arch to be one of {self.admissible_archs}, but got {arch}",
suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
)
# Verify that the user provided enum values
if not isinstance(self.cta_group, CtaGroup):
raise OpError(
self,
"expects the 'cta_group' Op parameter to be a tcgen05.CtaGroup instance",
)
if not isinstance(self.a_src, OperandSource):
raise OpError(
self,
"expects the 'a_src' Op parameter to be a tcgen05.OperandSource instance",
)
if not isinstance(self.a_major_mode, OperandMajorMode):
raise OpError(
self,
"expects the 'a_major_mode' Op parameter to be a tcgen05.OperandMajorMode instance",
)
if not isinstance(self.b_major_mode, OperandMajorMode):
raise OpError(
self,
"expects the 'b_major_mode' Op parameter to be a tcgen05.OperandMajorMode instance",
)
# Verify the instruction shape
if (rank(self.shape_mnk) not in [2, 3]) or (depth(self.shape_mnk) != 1):
raise OpError(
self,
f"expected a flat rank 2 or 3 tuple for the 'shape_mnk' Op parameter, "
f"but got {self.shape_mnk}",
)
m, n = self.shape_mnk[0], self.shape_mnk[1]
if self.cta_group == CtaGroup.ONE:
if m != 128:
raise OpError(self, f"expects the M-mode to be 128, but got {m}")
if (n < 8) or (n > 256) or (n % 8 != 0):
raise OpError(
self,
f"expects the N-mode to satisfy 8 <= N <= 256 and N % 8 == 0, but got {n}",
)
else:
if m not in [128, 256]:
raise OpError(self, f"expects the M-mode to be 128 or 256, but got {m}")
if (n < 16) or (n > 256) or (n % 16 != 0):
raise OpError(
self,
f"expects the N-mode to satisfy 16 <= N <= 256 and N % 16 == 0, but got {n}",
)
if self.sf_vec_size not in [16, 32]:
raise OpError(
self,
f"expects the scale factor vector size to be 16 or 32, but got {self.sf_vec_size}",
)
def __str__(self) -> str:
return (
self.__class__.descriptive_name # type: ignore
+ f"\n A data type = {self.a_dtype}"
+ f"\n B data type = {self.b_dtype}"
+ f"\n Accumulator data type = {self.acc_dtype}"
+ f"\n Scale factor data type = {self.sf_dtype}"
+ f"\n Scale factor vector size = {self.sf_vec_size}"
+ f"\n CTA group = {self.cta_group}"
+ f"\n A source location = {self.a_src}"
+ f"\n A major mode = {self.a_major_mode}"
+ f"\n B major mode = {self.b_major_mode}"
+ f"\n Instruction shape MNK = {self.shape_mnk}"
)
def _verify_fragment_A(self, input: _Tensor, *, loc=None, ip=None):
if input.memspace == AddressSpace.smem and isinstance(
input.layout.type, _cute_ir.ComposedLayoutType
):
raise OpError(
self,
f"Expected affine layout for {self._make_trait()}'s operand A, "
f"but got composed layout instead: {input.layout}"
f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr",
)
return True
def _verify_fragment_B(self, input: _Tensor, *, loc=None, ip=None):
if input.memspace == AddressSpace.smem and isinstance(
input.layout.type, _cute_ir.ComposedLayoutType
):
raise OpError(
self,
f"Expected affine layout for {self._make_trait()}'s operand B, "
f"but got composed layout instead: {input.layout}"
f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr",
)
return True
class BlockScaledMmaTraits(Trait):
admissible_fields = [
Field.ACCUMULATE,
Field.NEGATE_A,
Field.NEGATE_B,
Field.SFA,
Field.SFB,
]
def set(self, field, value, *, loc=None, ip=None) -> None:
if field not in self.admissible_fields:
raise ValueError(
f"expects field to be one of {self.admissible_fields}, but got {field}"
)
if field in [Field.ACCUMULATE, Field.NEGATE_A, Field.NEGATE_B]:
value = Boolean(value).ir_value(loc=loc, ip=ip)
elif field in [Field.SFA, Field.SFB]:
if not isinstance(value, Pointer):
raise ValueError(
f"expects value to be a pointer for {field}, but got {type(value).__name__}"
)
value = value.value
field_name = f"#cute_nvgpu.atom_mma_field_sm100_block_scaled<{field._to_ir_field_name()}>"
attr = ir.Attribute.parse(field_name)
self.value = _cute_nvgpu_ir.atom_set_value(
self.value, attr, value, loc=loc, ip=ip
)
#
# TF32 MMA
#
@ -602,6 +756,262 @@ class MmaFP8Trait(MmaTrait):
pass
#
# MXF8F6F4 MMA
#
@dataclass(frozen=True)
class MmaMXF8Op(BlockScaledMmaOp):
"""
MXF8 tcgen05 BlockScaled MMA Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma>`__.
This Operation corresponds to the ``.kind::mxf8f6f4`` qualifier.
"""
descriptive_name = "tcgen05 MXF8 BlockScaled MMA Operation"
def __init__(
self,
ab_dtype: Type[Numeric],
instruction_shape: Shape,
cta_group: CtaGroup,
a_src: OperandSource,
a_major_mode: OperandMajorMode,
b_major_mode: OperandMajorMode,
) -> None:
super().__init__(
ab_dtype,
ab_dtype,
Float32,
Float8E8M0FNU,
32,
instruction_shape,
cta_group,
a_src,
a_major_mode,
b_major_mode,
)
self._verify()
def _verify(self) -> None:
# Input data type verification
if self.a_dtype not in [Float8E5M2, Float8E4M3FN]:
raise OpError(
self,
"expects the 'ab_dtype' Op parameter to be one of Float8E5M2 or Float8E4M3FN",
)
assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same"
# Instruction shape verification
instruction_k = 32
if rank(self.shape_mnk) == 2:
object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k))
if self.shape_mnk[2] != instruction_k:
raise OpError(
self,
f"expects the instruction extent in the K-mode to be {instruction_k}, "
f"but got {self.shape_mnk[2]}",
)
def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaMXF8Trait":
shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
ty = _cute_nvgpu_ir.MmaAtomSM100UMMABlockScaledType.get(
shape_mnk.type.attribute,
self.cta_group.value,
self.a_major_mode._to_ir(),
self.b_major_mode._to_ir(),
self.a_dtype.mlir_type,
self.b_dtype.mlir_type,
self.acc_dtype.mlir_type,
self.sf_dtype.mlir_type,
self.a_src._to_ir(),
self.sf_vec_size,
)
return MmaMXF8Trait(
_cute_nvgpu_ir.make_sm100_mma_bs(
ty,
Boolean(False).ir_value(loc=loc, ip=ip),
Boolean(False).ir_value(loc=loc, ip=ip),
Boolean(False).ir_value(loc=loc, ip=ip),
core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value,
core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value,
loc=loc,
ip=ip,
)
)
class MmaMXF8Trait(BlockScaledMmaTraits):
pass
#
# MXF4 MMA
#
@dataclass(frozen=True)
class MmaMXF4Op(BlockScaledMmaOp):
"""
MXF4 tcgen05 BlockScaled MMA Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma>`__.
This Operation corresponds to the ``.kind::mxf4`` qualifier.
"""
descriptive_name = "tcgen05 MXF4 BlockScaled MMA Operation"
def __init__(
self,
instruction_shape: Shape,
cta_group: CtaGroup,
a_src: OperandSource,
) -> None:
super().__init__(
Float4E2M1FN,
Float4E2M1FN,
Float32,
Float8E8M0FNU,
32,
instruction_shape,
cta_group,
a_src,
OperandMajorMode.K,
OperandMajorMode.K,
)
self._verify()
def _verify(self) -> None:
# Instruction shape verification
instruction_k = 64
if rank(self.shape_mnk) == 2:
object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k))
if self.shape_mnk[2] != instruction_k:
raise OpError(
self,
f"expects the instruction extent in the K-mode to be {instruction_k}, "
f"but got {self.shape_mnk[2]}",
)
def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaMXF4Trait":
shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
ty = _cute_nvgpu_ir.MmaAtomSM100UMMABlockScaledType.get(
shape_mnk.type.attribute,
self.cta_group.value,
self.a_major_mode._to_ir(),
self.b_major_mode._to_ir(),
self.a_dtype.mlir_type,
self.b_dtype.mlir_type,
self.acc_dtype.mlir_type,
self.sf_dtype.mlir_type,
self.a_src._to_ir(),
self.sf_vec_size,
)
return MmaMXF4Trait(
_cute_nvgpu_ir.make_sm100_mma_bs(
ty,
Boolean(False).ir_value(loc=loc, ip=ip),
Boolean(False).ir_value(loc=loc, ip=ip),
Boolean(False).ir_value(loc=loc, ip=ip),
core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value,
core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value,
loc=loc,
ip=ip,
)
)
class MmaMXF4Trait(BlockScaledMmaTraits):
pass
#
# MXF4NVF4 MMA
#
@dataclass(frozen=True)
class MmaMXF4NVF4Op(BlockScaledMmaOp):
"""
MXF4NVF4 tcgen05 BlockScaled MMA Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma>`__.
This Operation corresponds to the ``.kind::mxf4nvf4`` qualifier.
"""
descriptive_name = "tcgen05 MXF4NVF4 BlockScaled MMA Operation"
def __init__(
self,
sf_dtype: Type[Numeric],
instruction_shape: Shape,
cta_group: CtaGroup,
a_src: OperandSource,
) -> None:
super().__init__(
Float4E2M1FN,
Float4E2M1FN,
Float32,
sf_dtype,
16,
instruction_shape,
cta_group,
a_src,
OperandMajorMode.K,
OperandMajorMode.K,
)
self._verify()
def _verify(self) -> None:
# Scale Factor data type verification
if self.sf_dtype not in [Float8E8M0FNU, Float8E4M3FN]:
raise OpError(
self,
"expects the 'sf_dtype' Op parameter to be one of Float8E8M0FNU",
)
# Instruction shape verification
instruction_k = 64
if rank(self.shape_mnk) == 2:
object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k))
if self.shape_mnk[2] != instruction_k:
raise OpError(
self,
f"expects the instruction extent in the K-mode to be {instruction_k}, "
f"but got {self.shape_mnk[2]}",
)
def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaMXF4NVF4Trait":
shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
ty = _cute_nvgpu_ir.MmaAtomSM100UMMABlockScaledType.get(
shape_mnk.type.attribute,
self.cta_group.value,
self.a_major_mode._to_ir(),
self.b_major_mode._to_ir(),
self.a_dtype.mlir_type,
self.b_dtype.mlir_type,
self.acc_dtype.mlir_type,
self.sf_dtype.mlir_type,
self.a_src._to_ir(),
self.sf_vec_size,
)
return MmaMXF4NVF4Trait(
_cute_nvgpu_ir.make_sm100_mma_bs(
ty,
Boolean(False).ir_value(loc=loc, ip=ip),
Boolean(False).ir_value(loc=loc, ip=ip),
Boolean(False).ir_value(loc=loc, ip=ip),
core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value,
core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value,
loc=loc,
ip=ip,
)
)
class MmaMXF4NVF4Trait(BlockScaledMmaTraits):
pass
####################################################################################################
#
# SMEM layout atoms

View File

@ -28,11 +28,12 @@ import cutlass.base_dsl.jit_executor
import cutlass.cute as cute
from cutlass._mlir.dialects import builtin, cf, nvvm, vector
from cutlass.cute import core, nvgpu
from cutlass.cutlass_dsl import Constexpr, CuTeDSL, T, t
from cutlass.cutlass_dsl import Constexpr, CuTeDSL, T, t, dsl_user_op
def assert_(cond, msg=None):
cf.assert_(t.Boolean(cond).ir_value(), msg if msg else "")
@dsl_user_op
def assert_(cond, msg=None, *, loc=None, ip=None):
cf.assert_(t.Boolean(cond).ir_value(), msg if msg else "", loc=loc, ip=ip)
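A minimal usage sketch of the wrapped helper (condition and message are illustrative; as a dsl_user_op it now also threads loc/ip through):
assert_(num_valid > 0, "expected at least one valid element")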
def _maybe_recast_tensor_from_f4(src: core.Tensor, tv_layout: core.Layout):
@ -214,7 +215,14 @@ def convert(src: core.Tensor, dst: core.Tensor):
dst.shape
), "Shape of src and dst tensors should be the same rank."
# find leading mode
leading_mode = np.argmin([np.min(s) for s in src.stride])
leading_mode = [
idx
for idx, (shape, stride) in enumerate(zip(src.shape, src.stride))
if shape > 1 and stride == 1
]
if len(leading_mode) != 1:
raise ValueError(f"Leading mode should be unique, but got {leading_mode}")
leading_mode = leading_mode[0]
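# e.g. a row-major (8, 4):(4, 1) tensor has exactly one leading mode: mode 1 (stride 1, extent > 1)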
elem_per_copy = 2
@ -345,7 +353,7 @@ def benchmark(
callable: Callable,
*,
warmup_iterations: int = 10,
profiling_iterations: int = 100,
iterations: int = 100,
stream: Optional[cuda_driver.CUstream] = None,
kernel_arguments: Optional[JitArguments] = None,
workspace_generator: Optional[Callable[[], JitArguments]] = None,
@ -365,7 +373,7 @@ def benchmark(
pass
time_us = benchmark(user_function, kernel_arguments=JitArguments(a, b, c, stream),
warmup_iterations=10, profiling_iterations=100
warmup_iterations=10, iterations=100,
stream=stream)
To prevent skewing results by repeatedly accessing the L2 cache, use the workspace_count and workspace_generator
@ -388,7 +396,7 @@ def benchmark(
workspace_generator=workspace_generator,
workspace_count=10,
warmup_iterations=10000,
profiling_iterations=1000)
iterations=1000)
To benchmark you may always configure the function being profiled (callable), the warmup iterations, and
the number of profiling iterations.
@ -402,8 +410,8 @@ def benchmark(
:type callable: Callable
:param warmup_iterations: Number of warmup iterations, defaults to 10
:type warmup_iterations: int, optional
:param profiling_iterations: Number of benchmark iterations, defaults to 100
:type profiling_iterations: int, optional
:param iterations: Number of benchmark iterations, defaults to 100
:type iterations: int, optional
:param stream: Stream kernel is launched in, defaults to CUDA stream default
:type stream: CUstream, None
:param kernel_arguments: Kernel arguments to launch callable with, defaults to None
@ -502,7 +510,7 @@ def benchmark(
stream, cuda_runtime.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal
)
_cuda_success(err, "Error on stream capture")
_loop_and_call_kernel(profiling_iterations, workspace_index)
_loop_and_call_kernel(iterations, workspace_index)
err, gprofile = cuda_runtime.cudaStreamEndCapture(stream)
_cuda_success(err, "Error on stream capture")
@ -557,7 +565,7 @@ def benchmark(
# Record start event
err = cuda_driver.cuEventRecord(start_event, stream)
_cuda_success(err, "Error on recording event")
_loop_and_call_kernel(profiling_iterations, workspace_index)
_loop_and_call_kernel(iterations, workspace_index)
# Record end event
err = cuda_driver.cuEventRecord(end_event, stream)
_cuda_success(err, "Error on recording event")
@ -573,6 +581,30 @@ def benchmark(
err = cuda_driver.cuEventDestroy(end_event)
_cuda_success(err, "Error on destroying event")
return elapsed_time / profiling_iterations * 1e3
return elapsed_time / iterations * 1e3
def get_workspace_count(
one_workspace_bytes: int, warmup_iterations: int, iterations: int
) -> int:
"""Calculate the number of workspaces needed to fill L2 cache.
:param one_workspace_bytes: Size of one workspace in bytes
:type one_workspace_bytes: int
:param warmup_iterations: Number of warmup iterations
:type warmup_iterations: int
:param iterations: Number of iterations
:type iterations: int
:return: Number of workspaces needed
:rtype: int
"""
num_l2_cache_bytes = cutlass.utils.HardwareInfo().get_l2_cache_size_in_bytes()
return max(
1,
min(
warmup_iterations + iterations, # Don't create more workspaces than needed
(num_l2_cache_bytes + one_workspace_bytes - 1)
// one_workspace_bytes, # Ceiling division
),
)
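A minimal sketch of combining the renamed `iterations` parameter with `get_workspace_count`; `my_kernel`, the argument factories, and the workspace byte count are hypothetical placeholders:
workspace_bytes = 64 * 1024 * 1024  # hypothetical size of one set of kernel buffers
count = get_workspace_count(workspace_bytes, warmup_iterations=10, iterations=100)

def workspace_generator():
    # Return fresh JitArguments so consecutive launches touch different buffers
    # and do not replay a warm L2 cache.
    return JitArguments(make_a(), make_b(), make_c(), stream)

time_us = benchmark(
    my_kernel,
    workspace_generator=workspace_generator,
    workspace_count=count,
    warmup_iterations=10,
    iterations=100,
    stream=stream,
)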

View File

@ -91,33 +91,6 @@ class PipelineAsync:
- D: Data ready (producer has written data to buffer)
- R: Consumer reading (consumer is consuming data from buffer)
**Example:**
.. code-block:: python
# Create pipeline with 5 stages
pipeline = PipelineAsync.create(
num_stages=5, # number of pipeline stages
producer_group=producer_warp,
consumer_group=consumer_warp
barrier_storage=smem_ptr, # smem pointer for array of mbarriers in shared memory
)
# Producer side
producer = pipeline.make_pipeline_producer(producer_warp)
for i in range(num_iterations):
producer.acquire() # Wait for buffer to be empty
# Write data to pipeline buffer
producer.commit() # Signal buffer is full
producer.advance() # Move index to next stage
# Consumer side
consumer = pipeline.make_pipeline_consumer(consumer_warp)
for i in range(num_iterations):
consumer.wait() # Wait for buffer to be full
# Read data from pipeline buffer
consumer.release() # Signal buffer is empty
consumer.advance() # Move index to next stage
"""
sync_object_full: SyncObject
@ -259,16 +232,6 @@ class PipelineAsync:
state.advance()
self.producer_acquire(state)
# Util methods to manage producer and consumer
def make_pipeline_producer(self, group: CooperativeGroup):
state = make_pipeline_state(PipelineUserType.Producer, self.num_stages)
return PipelineProducer(self, state, group)
def make_pipeline_consumer(self, group: CooperativeGroup):
state = make_pipeline_state(PipelineUserType.Consumer, self.num_stages)
return PipelineConsumer(self, state, group)
@dataclass(frozen=True)
class PipelineTmaAsync(PipelineAsync):
"""
@ -593,211 +556,3 @@ class PipelineTmaStore(PipelineAsync):
self.sync_object_full.tail()
#################################################################
# Utilities to help users of the pipeline simplify the workflow
#################################################################
class PipelineProducer:
"""A class representing a producer in an asynchronous pipeline.
The Producer class manages the producer side of an asynchronous pipeline, handling
synchronization and state management for producing data. It provides methods for
acquiring, committing, and advancing through pipeline stages.
:ivar _pipeline: The asynchronous pipeline this producer belongs to
:type _pipeline: PipelineAsync
:ivar _state: The current state of the producer in the pipeline
:type _state: PipelineState
:ivar _group: The cooperative group this producer operates in
:type _group: CooperativeGroup
**Examples:**
.. code-block:: python
pipeline = PipelineAsync.create(...)
producer = pipeline.create_producer(producer_group, stages)
for i in range(iterations):
producer.acquire() # Wait for buffer to be empty
# Produce data
producer.commit() # Signal data is ready
producer.advance() # Move to next stage
"""
_pipeline: PipelineAsync
_state: PipelineState
_group: CooperativeGroup
def __init__(self, pipeline, state, group: CooperativeGroup):
"""Initialize a new Producer instance.
:param pipeline: The pipeline this producer belongs to
:type pipeline: PipelineAsync
:param state: Initial pipeline state
:type state: PipelineState
:param group: The cooperative group for synchronization
:type group: CooperativeGroup
"""
self._pipeline = pipeline
self._state = state
self._group = group
@property
def index(self):
"""Get the index of the current pipeline stage."""
return self._state.index
def get_barrier(self):
"""Get the barrier pointer for the current pipeline stage.
:return: Pointer to the barrier for the current stage
:rtype: cute.Pointer
"""
return self._pipeline.producer_get_barrier(self._state)
def acquire(self):
"""Wait for the current buffer to be empty before producing data.
This is a blocking operation.
"""
self._pipeline.producer_acquire(self._state)
def try_acquire(self):
"""Try to acquire the current buffer without blocking.
:return: True if acquisition was successful, False otherwise
:rtype: bool
"""
self._pipeline.producer_try_acquire(self._state)
def commit(self):
"""Signal that data production is complete for the current stage.
This allows consumers to start processing the data.
"""
self._pipeline.producer_commit(self._state)
def tail(self):
"""Ensure all used buffers are properly synchronized before producer exit.
This should be called before the producer finishes to avoid dangling signals.
"""
self._pipeline.producer_tail(self._state)
def advance(self):
"""Move to the next pipeline stage."""
self._state.advance()
def __extract_mlir_values__(self):
"""Extract MLIR values from the current state.
:return: List of MLIR values representing the current state
:rtype: list
"""
# TODO: need to handle pipeline as well
return self._state.__extract_mlir_values__()
def __new_from_mlir_values__(self, values):
"""Create a new Producer instance from MLIR values.
:param values: MLIR values to initialize the state
:type values: Any
:return: New Producer instance with state initialized from values
:rtype: Producer
"""
return PipelineProducer(
self._pipeline, self._state.__new_from_mlir_values__(values), self._group
)
class PipelineConsumer:
"""A class representing a consumer in an asynchronous pipeline.
The Consumer class manages the consumer side of an asynchronous pipeline, handling
synchronization and state management for consuming data. It provides methods for
waiting, releasing, and advancing through pipeline stages.
:ivar _pipeline: The asynchronous pipeline this consumer belongs to
:type _pipeline: PipelineAsync
:ivar _state: The current state of the consumer in the pipeline
:type _state: PipelineState
:ivar _group: The cooperative group this consumer operates in
:type _group: CooperativeGroup
**Examples:**
.. code-block:: python
pipeline = PipelineAsync.create(...)
consumer = pipeline.create_consumer(consumer_group, stages)
for i in range(iterations):
consumer.wait() # Wait for data to be ready
# Consume data
consumer.release() # Signal buffer is empty
consumer.advance() # Move to next stage
"""
_pipeline: PipelineAsync
_state: PipelineState
_group: CooperativeGroup
def __init__(self, pipeline, state: PipelineState, group: CooperativeGroup):
"""Initialize a new Consumer instance.
:param pipeline: The pipeline this consumer belongs to
:type pipeline: PipelineAsync
:param state: Initial pipeline state
:type state: PipelineState
:param group: The cooperative group for synchronization
:type group: CooperativeGroup
"""
self._pipeline = pipeline
self._group = group
self._state = state
@property
def index(self):
"""Get the index of the current pipeline stage."""
return self._state.index
def wait(self):
"""Wait for data to be ready in the current buffer.
This is a blocking operation.
"""
self._pipeline.consumer_wait(self._state)
def try_wait(self):
"""Try to check if data is ready without blocking.
:return: True if data is ready, False otherwise
:rtype: bool
"""
self._pipeline.consumer_try_wait(self._state)
def release(self):
"""Signal that data consumption is complete for the current stage.
This allows producers to start producing new data.
"""
self._pipeline.consumer_release(self._state)
def advance(self):
"""Move to the next pipeline stage."""
self._state.advance()
def __extract_mlir_values__(self):
"""Extract MLIR values from the current state.
:return: List of MLIR values representing the current state
:rtype: list
"""
return self._state.__extract_mlir_values__()
def __new_from_mlir_values__(self, values):
"""Create a new Consumer instance from MLIR values.
:param values: MLIR values to initialize the state
:type values: Any
:return: New Consumer instance with state initialized from values
:rtype: Consumer
"""
# TODO: need to call pipeline.__new_from_mlir_values__ recursively
return PipelineConsumer(
self._pipeline, self._state.__new_from_mlir_values__(values), self._group
)

View File

@ -29,7 +29,7 @@ from cutlass.cute.typing import (
from cutlass.cute.runtime import from_dlpack
import cutlass.cute as cute
import torch
from cuda import cuda
import cuda.bindings.driver as cuda
def dtype(ty: Type[Numeric]):

View File

@ -28,12 +28,22 @@ from .blackwell_helpers import (
make_smem_layout_b,
make_smem_layout_epi,
make_trivial_tiled_mma,
make_blockscaled_trivial_tiled_mma,
)
from .hopper_helpers import (
sm90_get_smem_store_op,
)
from .blockscaled_layout import (
BlockScaledBasicChunk,
tile_atom_to_shape_SF,
make_smem_layout_sfa,
make_smem_layout_sfb,
make_tmem_layout_sfa,
make_tmem_layout_sfb,
)
from .grouped_gemm_tile_scheduler_helper import (
GroupSearchResult,
GroupedGemmGroupSearchState,
@ -50,7 +60,12 @@ from .smem_allocator import SmemAllocator
from .layout import LayoutEnum
from .smem_capacity import (
get_smem_capacity_in_bytes,
)
__all__ = [
"get_smem_capacity_in_bytes",
"SmemAllocator",
"LayoutEnum",
"WorkTileInfo",

View File

@ -10,14 +10,22 @@
# is strictly prohibited.
from enum import Enum
from typing_extensions import deprecated
import warnings
@deprecated("Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead")
class SmemCapacity(Enum):
SM80_SMEM_CAPACITY_BYTES = (164 - 1) * 1024
SM86_SMEM_CAPACITY_BYTES = (100 - 1) * 1024
SM89_SMEM_CAPACITY_BYTES = (100 - 1) * 1024
warnings.warn(
"SMEM_CAPACITY is deprecated: Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead",
DeprecationWarning,
stacklevel=2,
)
# Dictionary to map compute capability to SMEM capacity
SMEM_CAPACITY = {
"sm80": SmemCapacity.SM80_SMEM_CAPACITY_BYTES.value,

View File

@ -12,6 +12,8 @@
from enum import Enum
from math import log2, ceil
from typing import List, Type, Union, Tuple
from typing_extensions import deprecated
import warnings
from cutlass.cutlass_dsl import (
Float16,
@ -22,6 +24,7 @@ from cutlass.cutlass_dsl import (
Int8,
Float8E4M3FN,
Float8E5M2,
Float4E2M1FN,
Numeric,
NumericMeta,
dsl_user_op,
@ -34,6 +37,9 @@ from cutlass.cute.nvgpu.tcgen05 import (
MmaTF32Op,
MmaI8Op,
MmaFP8Op,
MmaMXF8Op,
MmaMXF4Op,
MmaMXF4NVF4Op,
OperandSource,
OperandMajorMode,
CtaGroup,
@ -58,6 +64,24 @@ from cutlass.cute.nvgpu.cpasync import (
from cutlass.utils.layout import LayoutEnum
@deprecated("Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead")
class SmemCapacity(Enum):
SM100_SMEM_CAPACITY_BYTES = (228 - 1) * 1024
SM120_SMEM_CAPACITY_BYTES = (100 - 1) * 1024
warnings.warn(
"SMEM_CAPACITY is deprecated: Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead",
DeprecationWarning,
stacklevel=2,
)
# Dictionary to map compute capability to SMEM capacity
SMEM_CAPACITY = {
"sm100": SmemCapacity.SM100_SMEM_CAPACITY_BYTES.value,
"sm120": SmemCapacity.SM120_SMEM_CAPACITY_BYTES.value,
}
@dsl_user_op
def compute_epilogue_tile_shape(
cta_tile_shape: cute.Shape,
@ -822,18 +846,6 @@ def make_smem_layout_epi(
return epi_smem_layout_staged
class SmemCapacity(Enum):
SM100_SMEM_CAPACITY_BYTES = (228 - 1) * 1024
SM120_SMEM_CAPACITY_BYTES = (100 - 1) * 1024
# Dictionary to map compute capability to SMEM capacity
SMEM_CAPACITY = {
"sm100": SmemCapacity.SM100_SMEM_CAPACITY_BYTES.value,
"sm120": SmemCapacity.SM120_SMEM_CAPACITY_BYTES.value,
}
@dsl_user_op
def make_trivial_tiled_mma(
ab_dtype: Type[Numeric],
@ -917,6 +929,76 @@ def make_trivial_tiled_mma(
return cute.make_tiled_mma(cute.make_mma_atom(mma_op))
@dsl_user_op
def make_blockscaled_trivial_tiled_mma(
ab_dtype: Type[Numeric],
a_leading_mode: OperandMajorMode,
b_leading_mode: OperandMajorMode,
sf_dtype: Type[Numeric],
sf_vec_size: int,
cta_group: CtaGroup,
mma_tiler_mn: Tuple[int, int],
a_source: OperandSource = OperandSource.SMEM,
*,
loc=None,
ip=None,
) -> cute.TiledMma:
"""Make a BlockScaled tiled MMA atom with given data type, leading dimension, cta group and mma tile shape.
By default, the MMA atom is created with SMEM operand source for A.
:param ab_dtype: Data type of operands A and B.
:type ab_dtype: type[Numeric]
:param a_leading_mode: Leading dimension of operand A (1 for K, 0 for M/N).
:type a_leading_mode: tcgen05.OperandMajorMode
:param b_leading_mode: Leading dimension of operand B (1 for K, 0 for M/N).
:type b_leading_mode: tcgen05.OperandMajorMode
:param sf_dtype: Data type of the Scale Factor.
:type sf_dtype: type[Numeric]
:param sf_vec_size: The vector size of the Scale Factor.
:type sf_vec_size: int
:param cta_group: The CTA group to use.
:type cta_group: tcgen05.CtaGroup
:param mma_tiler_mn: The (M, N) shape of the MMA tiler; the K extent is fixed by the selected MMA op.
:type mma_tiler_mn: Tuple[int, int]
:param a_source: The source of operand A (SMEM by default or TMEM).
:type a_source: OperandSource
:return: A tiled MMA atom.
:rtype: cute.TiledMma
:raises TypeError: If the data type is not supported.
:raises ValueError: If the scale factor vector size is not supported.
"""
if ab_dtype in {Float8E4M3FN, Float8E5M2}:
mma_op = MmaMXF8Op(
ab_dtype,
(*mma_tiler_mn, 32),
cta_group,
a_source,
a_leading_mode,
b_leading_mode,
)
elif ab_dtype == Float4E2M1FN:
if sf_vec_size == 32:
mma_op = MmaMXF4Op(
(*mma_tiler_mn, 64),
cta_group,
a_source,
)
elif sf_vec_size == 16:
mma_op = MmaMXF4NVF4Op(
sf_dtype,
(*mma_tiler_mn, 64),
cta_group,
a_source,
)
else:
raise ValueError(f"unsupported sf_vec_size, got {sf_vec_size}")
else:
raise TypeError(f"unsupported ab_dtype, got {ab_dtype}")
return cute.make_tiled_mma(cute.make_mma_atom(mma_op))
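A minimal construction sketch for the NVF4 branch of this helper; the tile shape, scale-factor dtype, and enum choices are illustrative rather than prescriptive:
tiled_mma = make_blockscaled_trivial_tiled_mma(
    ab_dtype=Float4E2M1FN,
    a_leading_mode=OperandMajorMode.K,
    b_leading_mode=OperandMajorMode.K,
    sf_dtype=Float8E4M3FN,  # illustrative scale-factor type
    sf_vec_size=16,         # selects the MmaMXF4NVF4Op branch
    cta_group=CtaGroup.ONE,
    mma_tiler_mn=(128, 128),
)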
@dsl_user_op
def cluster_shape_to_tma_atom_A(
cluster_shape_mnk: cute.Shape, atom_thr_id: cute.Layout, *, loc=None, ip=None

View File

@ -0,0 +1,287 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from dataclasses import dataclass, field
from typing import Union
from cutlass.cutlass_dsl import dsl_user_op
import cutlass.cute as cute
from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode
import cutlass._mlir.dialects.cute as _cute_ir
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
@dataclass(frozen=True)
class BlockScaledBasicChunk:
"""
The basic scale factor atom layout determined by the tcgen05 BlockScaled MMA Ops.
This class represents the fixed layout pattern for scale factors used in
tcgen05 BlockScaled MMA Ops. The layout is determined by the
instruction specification and cannot be modified.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-a-layout-1x>`_.
"""
sf_vec_size: int
major_mode: OperandMajorMode = OperandMajorMode.K
_layout: cute.Layout = field(init=False, repr=False)
def __post_init__(self) -> None:
if self.major_mode == OperandMajorMode.K:
# K-major layout: (AtomMN, AtomK)
atom_shape = ((32, 4), (self.sf_vec_size, 4))
atom_stride = ((16, 4), (0, 1))
else:
# MN-major layout: (AtomK, AtomMN)
atom_shape = ((self.sf_vec_size, 4), (32, 4))
atom_stride = ((0, 1), (16, 4))
object.__setattr__(
self, "_layout", cute.make_layout(atom_shape, stride=atom_stride)
)
@property
def layout(self) -> cute.Layout:
"""
Get the layout for this block scaled chunk.
:return: The layout representing the scale factor atom
:rtype: cute.Layout
"""
return self._layout
@dsl_user_op
def tile_atom_to_shape_SF(
Shape: cute.Shape,
sf_vec_size: int,
*,
loc=None,
ip=None,
) -> cute.Layout:
"""
A helper function that builds the dynamic SFA/SFB layout by tiling the scale factor atom layout to the dynamic A/B shape.
:param Shape: The shape of the A/B tensor
:param sf_vec_size: Scale factor vector size
:return: The layout of the SFA/SFB tensor
:rtype: cute.Layout
"""
# ((Atom_MN, Rest_MN),(Atom_K, Rest_K),RestL)
sf_layout = cute.tile_to_shape(
BlockScaledBasicChunk(sf_vec_size).layout, Shape, (2, 1, 3)
)
return sf_layout
@dsl_user_op
def make_smem_layout_sfa(
tiled_mma: cute.TiledMma,
mma_tiler_mnk: cute.Tile,
sf_vec_size: int,
num_stages: int,
*,
loc=None,
ip=None,
) -> cute.Layout:
"""
Make smem layout for SFA based on:
1. BlockScaledBasicChunk
2. MMA tiler shape
3. Scale factor vector size
4. Number of stages
:param tiled_mma: The tiled MMA
:type tiled_mma: cute.TiledMma
:param mma_tiler_mnk: The mma tiler shape
:type mma_tiler_mnk: cute.Tile
:param sf_vec_size: The scale factor vector size
:type sf_vec_size: int
:param num_stages: The number of stages
:type num_stages: int
:return: Smem layout for SFA
:rtype: cute.Layout
"""
# (CTA_Tile_Shape_M, MMA_Tile_Shape_K)
sfa_tile_shape = (
mma_tiler_mnk[0] // cute.size(tiled_mma.thr_id.shape),
mma_tiler_mnk[2],
)
# ((Atom_M, Rest_M),(Atom_K, Rest_K))
smem_layout = cute.tile_to_shape(
BlockScaledBasicChunk(sf_vec_size).layout,
sfa_tile_shape,
(2, 1),
)
mma_tile_inst_k = 4
# (CTA_Tile_Shape_M, MMA_Inst_Shape_K)
sfa_tile_shape = cute.shape_div(sfa_tile_shape, (1, mma_tile_inst_k))
# ((Atom_Inst_M, Atom_Inst_K), MMA_M, MMA_K)
smem_layout = cute.tiled_divide(smem_layout, sfa_tile_shape)
atom_m = 128
tiler_inst = ((atom_m, sf_vec_size),)
# (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K)
smem_layout = cute.logical_divide(smem_layout, tiler_inst)
# (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K, STAGE)
sfa_smem_layout_staged = cute.append(
smem_layout,
cute.make_layout(
num_stages, stride=cute.cosize(cute.filter_zeros(smem_layout))
),
)
return sfa_smem_layout_staged
@dsl_user_op
def make_smem_layout_sfb(
tiled_mma: cute.TiledMma,
mma_tiler_mnk: cute.Tile,
sf_vec_size: int,
num_stages: int,
*,
loc=None,
ip=None,
) -> cute.Layout:
"""
Make smem layout for SFB based on:
1. BlockScaledBasicChunk
2. MMA tiler shape
3. Scale factor vector size
4. Number of stages
:param tiled_mma: The tiled MMA
:type tiled_mma: cute.TiledMma
:param mma_tiler_mnk: The mma tiler shape
:type mma_tiler_mnk: cute.Tile
:param sf_vec_size: The scale factor vector size
:type sf_vec_size: int
:param num_stages: The number of stages
:type num_stages: int
:return: Smem layout for SFB
:rtype: cute.Layout
"""
# (Round_Up(CTA_Tile_Shape_N, 128), MMA_Tile_Shape_K)
sfb_tile_shape = (
cute.round_up(mma_tiler_mnk[1], 128),
mma_tiler_mnk[2],
)
# ((Atom_N, Rest_N),(Atom_K, Rest_K))
smem_layout = cute.tile_to_shape(
BlockScaledBasicChunk(sf_vec_size).layout,
sfb_tile_shape,
(2, 1),
)
mma_tile_inst_k = 4
# (CTA_Tile_Shape_N, MMA_Inst_Shape_K)
sfb_tile_shape = cute.shape_div(sfb_tile_shape, (1, mma_tile_inst_k))
# ((Atom_Inst_N, Atom_Inst_K), MMA_N, MMA_K)
smem_layout = cute.tiled_divide(smem_layout, sfb_tile_shape)
atom_n = 128
tiler_inst = ((atom_n, sf_vec_size),)
# (((Atom_Inst_N, Rest_N),(Atom_Inst_K, Rest_K)), MMA_N, MMA_K)
smem_layout = cute.logical_divide(smem_layout, tiler_inst)
# (((Atom_Inst_N, Rest_N),(Atom_Inst_K, Rest_K)), MMA_N, MMA_K, STAGE)
sfb_smem_layout_staged = cute.append(
smem_layout,
cute.make_layout(
num_stages, stride=cute.cosize(cute.filter_zeros(smem_layout))
),
)
return sfb_smem_layout_staged
@dsl_user_op
def make_tmem_layout_sfa(
tiled_mma: cute.TiledMma,
mma_tiler_mnk: cute.Tile,
sf_vec_size: int,
smem_layout: cute.Layout,
*,
loc=None,
ip=None,
) -> cute.Layout:
"""Make tmem layout for SFA based on:
1. SFA smem layout per stage
2. CTA tile shape M
3. Tiled MMA atom thread size
4. Scale factor vector size
:param tiled_mma: The tiled MMA
:type tiled_mma: cute.TiledMma
:param mma_tiler_mnk: The mma tiler shape
:type mma_tiler_mnk: cute.Tile
:param sf_vec_size: The scale factor vector size
:type sf_vec_size: int
:param smem_layout: The smem layout of SFA per stage
:type smem_layout: cute.Layout
:return: TMEM layout for SFA
:rtype: cute.Layout
"""
atom_thr_size = cute.size(tiled_mma.thr_id.shape)
cta_tile_shape_m = mma_tiler_mnk[0] // atom_thr_size
sfa_layout_ty = _cute_nvgpu_ir.make_tmem_layout_sfa(
smem_layout, cta_tile_shape_m, atom_thr_size, sf_vec_size
)
return _cute_ir.static(sfa_layout_ty, loc=loc, ip=ip)
@dsl_user_op
def make_tmem_layout_sfb(
tiled_mma: cute.TiledMma,
mma_tiler_mnk: cute.Tile,
sf_vec_size: int,
smem_layout: cute.Layout,
*,
loc=None,
ip=None,
) -> cute.Layout:
"""Make tmem layout for SFB based on:
1. SFB smem layout per stage
2. CTA tile shape M
3. Tiled MMA atom thread size
4. Scale factor vector size
:param tiled_mma: The tiled MMA
:type tiled_mma: cute.TiledMma
:param mma_tiler_mnk: The mma tiler shape
:type mma_tiler_mnk: cute.Tile
:param sf_vec_size: The scale factor vector size
:type sf_vec_size: int
:param smem_layout: The smem layout of SFB per stage
:type smem_layout: cute.Layout
:return: TMEM layout for SFB
:rtype: cute.Layout
"""
atom_thr_size = cute.size(tiled_mma.thr_id.shape)
cta_tile_shape_m = mma_tiler_mnk[0] // atom_thr_size
sfb_layout_ty = _cute_nvgpu_ir.make_tmem_layout_sfb(
smem_layout, cta_tile_shape_m, atom_thr_size, sf_vec_size
)
return _cute_ir.static(sfb_layout_ty, loc=loc, ip=ip)

View File

@ -11,6 +11,8 @@
from typing import Type, Tuple
from enum import Enum
from typing_extensions import deprecated
import warnings
from cutlass.utils.layout import LayoutEnum
from cutlass.cutlass_dsl import (
@ -34,6 +36,23 @@ from cutlass.cute.nvgpu.warpgroup import (
OperandSource,
)
@deprecated("Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead")
class SmemCapacity(Enum):
SM90_SMEM_CAPACITY_BYTES = (228 - 1) * 1024
warnings.warn(
"SMEM_CAPACITY is deprecated: Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead",
DeprecationWarning,
stacklevel=2,
)
# Dictionary to map compute capability to SMEM capacity
SMEM_CAPACITY = {
"sm90": SmemCapacity.SM90_SMEM_CAPACITY_BYTES.value,
}
@dsl_user_op
def sm90_get_smem_store_op(
layout_d: LayoutEnum,
@ -79,15 +98,6 @@ def sm90_get_smem_store_op(
return cute.make_copy_atom(CopyUniversalOp(), elem_ty_d, loc=loc, ip=ip)
class SmemCapacity(Enum):
SM90_SMEM_CAPACITY_BYTES = (228 - 1) * 1024
# Dictionary to map compute capability to SMEM capacity
SMEM_CAPACITY = {
"sm90": SmemCapacity.SM90_SMEM_CAPACITY_BYTES.value,
}
def make_trivial_tiled_mma(
a_dtype: Type[Numeric],
b_dtype: Type[Numeric],

View File

@ -14,7 +14,7 @@ from typing import Type, Union, overload
from cutlass.cutlass_dsl import Int8, Numeric, NumericMeta
import cutlass.cute as cute
from cutlass.cute.arch import get_dyn_smem
from cutlass.cute.arch import get_dyn_smem, get_dyn_smem_size
class SmemAllocator:
@ -60,6 +60,7 @@ class SmemAllocator:
:return: Pointer to the start of the allocated memory block or struct instance
:rtype: cute.Pointer
:raises ValueError: If size is negative or alignment is less than 1
:raises RuntimeError: If allocation would exceed available shared memory
"""
if isinstance(size_or_type, cute.struct):
alignment = max(byte_alignment, size_or_type.__alignof__())
@ -80,6 +81,14 @@ class SmemAllocator:
byte_alignment - self._allocated_bytes % byte_alignment
)
self._allocated_bytes += num_bytes
# Check bounds against available dynamic shared memory
cute.testing.assert_(
self._allocated_bytes <= get_dyn_smem_size(),
f"Allocation failed: shared memory allocation exceeds available memory set in kernel launch. "
f"Allocated bytes: {self._allocated_bytes} bytes. "
f"Please reduce the allocation or set a larger smem size in kernel launch.",
)
return ptr
def allocate_array(self, element_type: Type[Numeric], num_elems: int = 1):

View File

@ -0,0 +1,26 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
SMEM_CAPACITY_MAP = {
"sm_120": (100 - 1) * 1024,
"sm_100": (228 - 1) * 1024,
"sm_90": (228 - 1) * 1024,
"sm_80": (164 - 1) * 1024,
"sm_86": (100 - 1) * 1024,
"sm_89": (100 - 1) * 1024,
}
def get_smem_capacity_in_bytes(compute_capability: str) -> int:
if compute_capability not in SMEM_CAPACITY_MAP:
raise ValueError(f"Unsupported compute capability: {compute_capability}")
return SMEM_CAPACITY_MAP[compute_capability]
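A minimal usage sketch of the new helper, which the deprecation warnings above point to as the replacement for the per-module SMEM_CAPACITY dictionaries; note the underscore in the key ("sm_90", not "sm90"):
smem_bytes = get_smem_capacity_in_bytes("sm_90")
assert smem_bytes == (228 - 1) * 1024  # 232448 bytes on sm_90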

View File

@ -159,7 +159,11 @@ class CutlassBaseDSL(BaseDSL):
pipeline = super()._get_pipeline(pipeline)
if pipeline == None:
# The output format is required to be cubin as we launch the CUDA module at the Python level.
return "builtin.module(cute-to-nvvm{cubin-format=bin opt-level=3})"
return (
"builtin.module(cute-to-nvvm{cubin-format=bin "
+ self.compile_options.to_str()
+ "})"
)
return pipeline
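For reference, a sketch of the string this now produces when the configured compile options render to, say, "opt-level=3" (the actual output of compile_options.to_str() is an assumption here):
# Hypothetical compile_options.to_str() == "opt-level=3"
pipeline = "builtin.module(cute-to-nvvm{cubin-format=bin " + "opt-level=3" + "})"
# -> "builtin.module(cute-to-nvvm{cubin-format=bin opt-level=3})", matching the previous hard-coded default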
@ -294,13 +298,8 @@ class CutlassBaseDSL(BaseDSL):
self, _CutlassIrKernelGenHelper, funcBody, *args, **kwargs
)
def _get_globals(self):
caller_globals = self.frame.f_globals
caller_locals = self.frame.f_locals
all_globals = globals().copy()
all_globals.update(caller_globals)
all_globals.update(caller_locals)
return all_globals
def _get_module_globals(self):
return globals()
def _preprocess_launch_config_args(self, args, kwargs):
"""Helper to preprocess args and kwargs for LaunchConfig"""
@ -459,7 +458,10 @@ class KernelLauncher:
def _check_func_args(self, funcBody, *func_args, **func_kwargs):
# Get function signature
sig = inspect.signature(funcBody)
if isinstance(funcBody, DSLCallable):
sig = funcBody.get_signature()
else:
sig = inspect.signature(funcBody)
# func_args and func_kwargs should match funcBody's signature,
# no extra or missing arguments.
@ -485,6 +487,7 @@ class KernelLauncher:
ret, name = kernel_generator(*self.func_args, **self.func_kwargs, config=config)
self.dsl.kernel_symbols.append(name)
self.dsl.frame = None
return ret.launch_op_ret
def __call__(self, *args, **kwargs):
@ -537,14 +540,18 @@ def pack_from_irvalue(
mixed_values[idx] = obj
elif not isinstance(obj, type) and hasattr(obj, "__new_from_mlir_values__"):
mixed_values[idx] = obj.__new_from_mlir_values__(chunk)
elif isinstance(chunk, list) and chunk[0] is None:
mixed_values[idx] = class_types[idx]
else:
try:
if isinstance(chunk, list) and chunk[0] is None:
mixed_values[idx] = class_types[idx]
else:
if len(chunk) == 1:
try:
mixed_values[idx] = t.as_numeric(chunk[0])
except DSLRuntimeError as e:
mixed_values[idx] = chunk[0]
except ValueError:
# Suppress the conversion error and try new_from_mlir_values below
pass
if mixed_values[idx] is None:
mixed_values[idx] = new_from_mlir_values(obj, chunk)
log().debug("------------------ ")
for idx, packed in enumerate(mixed_values):

View File

@ -1,3 +1,3 @@
# Use `pip install -r requirements.txt` to install a wheel consistent with
# the present state of the GitHub repository
nvidia-cutlass-dsl==4.1.0.dev0
nvidia-cutlass-dsl==4.1.0