v4.1 release

This commit is contained in:
Junkai-Wu
2025-07-03 20:07:53 +08:00
committed by GitHub
parent b995f93317
commit a1aaf2300a
155 changed files with 18407 additions and 6068 deletions

View File

@ -0,0 +1,259 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
Demonstrating JIT GEMM Implementation with Static Shape Wrapper
This example illustrates how to invoke a JIT-compiled GEMM implementation through a wrapper function
with static shapes. It showcases the integration between PyTorch and CuTe tensors in a JIT context.
Key features demonstrated:
1. Seamless conversion between PyTorch and CuTe tensors using the JitArgument protocol
2. Integration of static shape GEMM operations within a JIT-compiled wrapper function
Core components:
- BufferWithLayout: Handles memory buffer management with configurable stride ordering
- tensor_op_gemm_wrapper: JIT-compiled entry point that orchestrates the GEMM operation
Usage:
.. code-block:: bash
python examples/ampere/call_from_jit.py
Default configuration:
- Batch dimension (L): 16
- Matrix dimensions: M=512, N=256, K=128
- Precision: Float16 inputs with Float32 accumulation
Requirements:
- CUDA-capable GPU
- PyTorch with CUDA support
"""
import os
import sys
from typing import Type, Tuple
import torch
import cutlass
import cutlass.cute as cute
from cutlass.torch import dtype as torch_dtype
from cutlass.cute.runtime import make_ptr
# Add the current directory to sys.path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from tensorop_gemm import TensorOpGemm
class BufferWithLayout:
def __init__(self, ptr: cute.Pointer, stride_order: tuple[int, int, int]):
self.ptr = ptr
# static properties
self.stride_order = stride_order
def to_tensor(
self, shape: tuple[int, int, int], *, loc=None, ip=None
) -> cute.Tensor:
assert len(shape) == len(self.stride_order), (
f"Shape {shape} and stride_order {self.stride_order} must have the "
"same rank."
)
layout = cute.make_ordered_layout(shape, self.stride_order)
# permute (l, mn, k) -> (mn, k, l)
res = cute.make_tensor(self.ptr, cute.select(layout, mode=[1, 2, 0]))
return res
# Implement JitArgument Protocol and DynamicExpression Protocol
def __c_pointers__(self):
"""Get the C pointers for the underlying pointer.
This method is part of the JitArgument Protocol and returns the C pointers
from the underlying pointer object.
This is required when a user defines a custom data type that can be passed to a JIT function.
When the JIT-compiled function is called, the JIT executor calls this method to obtain raw pointers
to the underlying data object.
The following condition must be satisfied:
len(__c_pointers__()) == len(__get_mlir_types__()) == len(__extract_mlir_values__())
:return: The C pointers from the underlying pointer object
:rtype: Any
"""
return self.ptr.__c_pointers__()
def __get_mlir_types__(self):
"""Get the MLIR types for the underlying pointer.
This method is part of the JitArgument Protocol and returns the MLIR types
used by the compiler to generate code. They must match the types of the underlying pointers
returned by __c_pointers__().
:return: The MLIR types from the underlying pointer object
:rtype: Any
"""
return self.ptr.__get_mlir_types__()
def __extract_mlir_values__(self):
"""Extract MLIR values from the underlying pointer.
This method is part of the DynamicExpression Protocol and extracts MLIR values
from the underlying pointer object.
It is used by the compiler to generate a function call in MLIR to another JIT function.
It must match the types returned by __get_mlir_types__().
:return: The MLIR values extracted from the underlying pointer object
:rtype: Any
"""
return self.ptr.__extract_mlir_values__()
def __new_from_mlir_values__(self, values):
"""Create a new BufferWithLayout instance from MLIR values.
This method is part of the JitArgument & DynamicExpression Protocol and creates a new
BufferWithLayout instance with its pointer initialized from the given MLIR values.
It is used by the compiler to generate the MLIR function body called by the JIT function.
The values must match the types returned by __c_pointers__() and __get_mlir_types__():
the code generator takes the function arguments and reconstructs a Python object that is
legal inside the function body.
:param values: MLIR values to initialize the underlying pointer
:type values: Any
:return: A new BufferWithLayout instance with pointer initialized from values
:rtype: BufferWithLayout
"""
return BufferWithLayout(
self.ptr.__new_from_mlir_values__(values), self.stride_order
)
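# Illustrative sketch (hypothetical type, not used by this example): any custom argument
# type can implement the same JitArgument / DynamicExpression protocols by delegating to
# its dynamic members, keeping __c_pointers__, __get_mlir_types__ and
# __extract_mlir_values__ consistent in length and order, while static members
# (here, `tag`) are simply carried along at compile time.
class TaggedPointer:
    def __init__(self, ptr: cute.Pointer, tag: str):
        self.ptr = ptr  # dynamic: lowered to MLIR values
        self.tag = tag  # static: baked into the compiled function
    def __c_pointers__(self):
        return self.ptr.__c_pointers__()
    def __get_mlir_types__(self):
        return self.ptr.__get_mlir_types__()
    def __extract_mlir_values__(self):
        return self.ptr.__extract_mlir_values__()
    def __new_from_mlir_values__(self, values):
        return TaggedPointer(self.ptr.__new_from_mlir_values__(values), self.tag)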
@cute.jit
def tensor_op_gemm_wrapper(
buffer_a: BufferWithLayout,
buffer_b: BufferWithLayout,
buffer_c: BufferWithLayout,
mnkl: cutlass.Constexpr[tuple[int, int, int, int]],
acc_dtype: Type[cutlass.Numeric],
atom_layout_mnk: cutlass.Constexpr[tuple[int, int, int]],
):
print(f"\n[DSL INFO] Input Parameters:")
print(f"[DSL INFO] mnkl: {mnkl}")
print(f"[DSL INFO] buffer_a: {buffer_a}")
print(f"[DSL INFO] buffer_b: {buffer_b}")
print(f"[DSL INFO] buffer_c: {buffer_c}")
print(f"[DSL INFO] acc_dtype: {acc_dtype}")
print(f"[DSL INFO] atom_layout_mnk: {atom_layout_mnk}")
mA = buffer_a.to_tensor(cute.select(mnkl, mode=[3, 0, 2]))
mB = buffer_b.to_tensor(cute.select(mnkl, mode=[3, 1, 2]))
mC = buffer_c.to_tensor(cute.select(mnkl, mode=[3, 0, 1]))
print(f"\n[DSL INFO] Created Tensors:")
print(f"[DSL INFO] mA = {mA}")
print(f"[DSL INFO] mB = {mB}")
print(f"[DSL INFO] mC = {mC}")
tensor_op_gemm = TensorOpGemm(
buffer_a.ptr.value_type,
buffer_c.ptr.value_type,
acc_dtype,
atom_layout_mnk,
)
print(f"\n[DSL INFO] Created TensorOpGemm instance")
print(f"[DSL INFO] Input dtype: {buffer_a.ptr.value_type}")
print(f"[DSL INFO] Output dtype: {buffer_c.ptr.value_type}")
print(f"[DSL INFO] Accumulation dtype: {acc_dtype}")
print(f"[DSL INFO] Atom layout: {atom_layout_mnk}")
# No explicit compilation is needed inside a jit function
tensor_op_gemm(mA, mB, mC)
print(f"\n[DSL INFO] Executed TensorOpGemm")
def run_tensor_op_gemm_wrapper(mnkl: Tuple[int, int, int, int]):
print(f"\nRunning TensorOpGemm test with:")
print(f"Tensor dimensions: {mnkl}")
ab_dtype = cutlass.Float16
c_dtype = cutlass.Float16
a = torch.randn(
mnkl[3], mnkl[0], mnkl[2], dtype=torch_dtype(ab_dtype), device="cuda"
)
b = torch.randn(
mnkl[3], mnkl[1], mnkl[2], dtype=torch_dtype(ab_dtype), device="cuda"
)
c = torch.randn(
mnkl[3], mnkl[0], mnkl[1], dtype=torch_dtype(c_dtype), device="cuda"
)
print(f"Input tensor shapes:")
print(f"a: {a.shape}, dtype: {a.dtype}")
print(f"b: {b.shape}, dtype: {b.dtype}")
print(f"c: {c.shape}, dtype: {c.dtype}\n")
buffer_a = BufferWithLayout(
make_ptr(ab_dtype, a.data_ptr(), cute.AddressSpace.gmem),
(2, 1, 0),
)
buffer_b = BufferWithLayout(
make_ptr(ab_dtype, b.data_ptr(), cute.AddressSpace.gmem),
(2, 1, 0),
)
buffer_c = BufferWithLayout(
make_ptr(c_dtype, c.data_ptr(), cute.AddressSpace.gmem),
(2, 1, 0),
)
tensor_op_gemm_wrapper(
buffer_a,
buffer_b,
buffer_c,
mnkl, # pass the shape as a static value
# strides are not passed separately; the static stride_order is carried by each BufferWithLayout
cutlass.Float32,
(2, 2, 1),
)
torch.cuda.synchronize()
ref = torch.einsum("lmk,lnk->lmn", a, b)
torch.testing.assert_close(c, ref, atol=1e-05, rtol=1e-05)
print(f"\n[DSL INFO] Results verified successfully!")
print(f"First few elements of result: \n{c[:3, :3, :3]}")
if __name__ == "__main__":
run_tensor_op_gemm_wrapper((512, 256, 128, 16))

View File

@ -28,16 +28,17 @@
import argparse
import torch
import time
from typing import Type
import cuda.bindings.driver as cuda
import torch
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
import cutlass.cute.testing as testing
import cutlass.torch as cutlass_torch
from cutlass.cute.runtime import from_dlpack
"""
An Elementwise Addition Example using CuTe DSL.
@ -153,6 +154,7 @@ def elementwise_add_kernel(
blkC = gC[blk_coord] # (TileM,TileN)
blkCrd = cC[blk_coord] # (TileM, TileN)
# Note: these prints only run at compile/jit time
print(f"[DSL INFO] Sliced Tensors per thread block:")
print(f"[DSL INFO] blkA = {blkA.type}")
print(f"[DSL INFO] blkB = {blkB.type}")
@ -189,7 +191,7 @@ def elementwise_add_kernel(
print(f"[DSL INFO] thrC = {thrC.type}")
print(f"[DSL INFO] thrCrd = {thrCrd.type}")
for i in cutlass.range_dynamic(0, cute.size(frgPred), 1):
for i in range(0, cute.size(frgPred), 1):
val = cute.elem_less(thrCrd[i], shape)
frgPred[i] = val
@ -270,9 +272,6 @@ def run_elementwise_add(
warmup_iterations=2,
iterations=200,
):
if not torch.cuda.is_available():
raise RuntimeError(f"Ampere GPU is required to run this example!")
print(f"\nRunning Elementwise Add test with:")
print(f"Tensor dimensions: [{M}, {N}]")
print(f"Input and Output Data type: {dtype}")
@ -315,10 +314,8 @@ def run_elementwise_add(
print("Executing vector add kernel...")
# Get current CUDA stream from PyTorch
torch_stream = torch.cuda.current_stream()
# Get the raw stream pointer as a CUstream
current_stream = cuda.CUstream(torch_stream.cuda_stream)
# Get current CUstream from torch
current_stream = cutlass_torch.current_stream()
if not skip_ref_check:
compiled_func(a_tensor, b_tensor, c_tensor)
@ -329,41 +326,52 @@ def run_elementwise_add(
if not benchmark:
return
# Create CUDA events for timing
start_event = cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1]
end_event = cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1]
def generate_tensors():
if dtype.is_integer:
a = torch.randint(
0, 10, (M, N), device=torch.device("cuda"), dtype=torch_dtype
)
b = torch.randint(
0, 10, (M, N), device=torch.device("cuda"), dtype=torch_dtype
)
else:
a = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype)
b = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype)
# Warmup
for _ in range(warmup_iterations):
compiled_func(a_tensor, b_tensor, c_tensor)
c = torch.zeros_like(a)
# Use the current stream for CUDA events instead of the default stream
# Record start event
cuda.cuEventRecord(start_event, current_stream)
if not is_a_dynamic_layout:
a_tensor = from_dlpack(a).mark_layout_dynamic()
else:
a_tensor = a
# Execute the kernel
for _ in range(iterations):
compiled_func(a_tensor, b_tensor, c_tensor)
if not is_b_dynamic_layout:
b_tensor = from_dlpack(b).mark_layout_dynamic()
else:
b_tensor = b
# Record end event
cuda.cuEventRecord(end_event, current_stream)
cuda.cuEventSynchronize(end_event)
if not is_result_dynamic_layout:
c_tensor = from_dlpack(c).mark_layout_dynamic()
else:
c_tensor = c
# Calculate elapsed time
err, elapsed_time = cuda.cuEventElapsedTime(start_event, end_event)
avg_time = elapsed_time / iterations
return testing.JitArguments(a_tensor, b_tensor, c_tensor)
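# Note: rotating through several freshly generated workspaces (workspace_count=10
# below) is intended to keep inputs from staying cache-resident across iterations,
# so the measured time reflects steady-state global memory traffic.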
avg_time_us = testing.benchmark(
compiled_func,
workspace_generator=generate_tensors,
workspace_count=10,
warmup_iterations=warmup_iterations,
profiling_iterations=iterations,
)
# Print execution results
print(f"Kernel execution time: {avg_time:.4f} ms")
print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms")
print(
f"Achieved memory throughput: {(3 * a.numel() * dtype.width // 8) / (avg_time / 1000) / 1e9:.2f} GB/s"
f"Achieved memory throughput: {(3 * a.numel() * dtype.width // 8) / (avg_time_us / 1e6) / 1e9:.2f} GB/s"
)
print(f"First few elements of result: \n{c[:3, :3]}")
# Destroy events
cuda.cuEventDestroy(start_event)
cuda.cuEventDestroy(end_event)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
@ -377,6 +385,10 @@ if __name__ == "__main__":
parser.add_argument("--benchmark", action="store_true")
args = parser.parse_args()
if not torch.cuda.is_available():
raise RuntimeError(f"Ampere GPU is required to run this example!")
run_elementwise_add(
args.M,
args.N,

View File

@ -29,14 +29,15 @@
import argparse
import operator
import torch
from typing import Type
import time
from typing import Type, List
import cuda.bindings.driver as cuda
import torch
import cutlass
import cutlass.cute as cute
import cutlass.cute.testing as testing
import cutlass.torch as cutlass_torch
from cutlass.cute.runtime import from_dlpack
@ -77,8 +78,7 @@ while maintaining high performance through efficient memory access patterns.
@cute.kernel
def elementwise_apply_kernel(
op: cutlass.Constexpr,
gA: cute.Tensor,
gB: cute.Tensor,
inputs: List[cute.Tensor],
gC: cute.Tensor,
cC: cute.Tensor, # coordinate tensor
shape: cute.Shape,
@ -90,48 +90,46 @@ def elementwise_apply_kernel(
# slice for CTAs
cta_coord = ((None, None), bidx)
# logical coord -> address
ctaA = gA[cta_coord] # (TileM, TileN)
ctaB = gB[cta_coord] # (TileM, TileN)
# Leverage the meta-programming capability of the DSL to slice the tensors for each input
# All for loops below over the input tensors are fully unrolled automatically at compile time
ctaInputs = [t[cta_coord] for t in inputs] # (TileM, TileN)
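# For example, with inputs == [gA, gB] the comprehension above traces to exactly
# two tensor slices at compile time; no Python list or loop exists in the
# generated kernel.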
ctaC = gC[cta_coord] # (TileM, TileN)
ctaCrd = cC[cta_coord] # (TileM, TileN)
print(f"[DSL INFO] Sliced Tensors per thread block:")
print(f"[DSL INFO] ctaA = {ctaA.type}")
print(f"[DSL INFO] ctaB = {ctaB.type}")
for i in cutlass.range_constexpr(len(ctaInputs)):
print(f"[DSL INFO] ctaInputs{i} = {ctaInputs[i].type}")
print(f"[DSL INFO] ctaC = {ctaC.type}")
print(f"[DSL INFO] ctaCrd = {ctaCrd.type}")
# compose with CTA TV layout
# (tid, vid) -> address
tidfrgA = cute.composition(ctaA, tv_layout)
tidfrgB = cute.composition(ctaB, tv_layout)
tidfrgInputs = [cute.composition(t, tv_layout) for t in ctaInputs]
tidfrgC = cute.composition(ctaC, tv_layout)
tidfrgCrd = cute.composition(ctaCrd, tv_layout)
# print(f"{tv_layout = }")
# print(f"{tidfrgA = }")
# print(f"{tidfrgAB[0] = }")
thr_coord = (tidx, (None, None))
# slice for threads
# vid -> address
thrA = tidfrgA[thr_coord] # (V)
thrB = tidfrgB[thr_coord] # (V)
thrInputs = [t[thr_coord] for t in tidfrgInputs] # (V)
thrC = tidfrgC[thr_coord] # (V)
thrCrd = tidfrgCrd[thr_coord]
print(f"[DSL INFO] Sliced Tensors per thread:")
print(f"[DSL INFO] thrA = {thrA.type}")
print(f"[DSL INFO] thrB = {thrB.type}")
for i in cutlass.range_constexpr(len(thrInputs)):
print(f"[DSL INFO] thrInputs{i} = {thrInputs[i].type}")
print(f"[DSL INFO] thrC = {thrC.type}")
print(f"[DSL INFO] thrCrd = {thrCrd.type}")
# allocate fragments for gmem->rmem
frgA = cute.make_fragment_like(thrA, gA.element_type)
frgB = cute.make_fragment_like(thrB, gB.element_type)
frgInputs = [cute.make_fragment_like(t, t.element_type) for t in thrInputs]
frgC = cute.make_fragment_like(thrC, gC.element_type)
frgPred = cute.make_fragment(thrCrd.shape, cutlass.Boolean)
for i in cutlass.range_dynamic(cute.size(frgPred), unroll=1):
for i in cutlass.range(cute.size(frgPred), unroll=1):
frgPred[i] = cute.elem_less(thrCrd[i], shape)
# if tidx == 0 and bidx == 0:
@ -142,10 +140,13 @@ def elementwise_apply_kernel(
##########################################################
# declare the atoms which will be used later for memory copy
# Compile-time validation: expect the same element type for all input tensors so the load copy atom can be reused
assert all(t.element_type == inputs[0].element_type for t in inputs)
copy_atom_load = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
gA.element_type,
num_bits_per_copy=gA.element_type.width,
inputs[0].element_type,
num_bits_per_copy=inputs[0].element_type.width,
)
copy_atom_store = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
@ -153,12 +154,12 @@ def elementwise_apply_kernel(
num_bits_per_copy=gC.element_type.width,
)
cute.copy(copy_atom_load, thrA, frgA, pred=frgPred)
cute.copy(copy_atom_load, thrB, frgB, pred=frgPred)
for thrInput, frgInput in zip(thrInputs, frgInputs):
cute.copy(copy_atom_load, thrInput, frgInput, pred=frgPred)
# Load data before use. The compiler will optimize the copy and load
# operations to convert some memory ld/st into register uses.
result = op(frgA.load(), frgB.load())
result = op(*[frgInput.load() for frgInput in frgInputs])
# Store the result into the output register fragment frgC.
frgC.store(result)
@ -173,6 +174,7 @@ def elementwise_apply(
a: cute.Tensor,
b: cute.Tensor,
result: cute.Tensor,
stream: cuda.CUstream,
):
"""CUDA kernel applying binary operator on each element of two n-D input tensors in
CuTe Python and store to result tensor.
@ -262,8 +264,7 @@ def elementwise_apply(
# Async token(s) can also be specified as dependencies
elementwise_apply_kernel(
op,
gA,
gB,
[gA, gB], # Group input tensors into a list as a single argument
gC,
cC,
result.shape,
@ -271,6 +272,7 @@ def elementwise_apply(
).launch(
grid=[cute.size(gC, mode=[1]), 1, 1],
block=[cute.size(tv_layout, mode=[0]), 1, 1],
stream=stream,
)
@ -287,6 +289,11 @@ def run_elementwise_apply_and_verify(
if not torch.cuda.is_available():
raise RuntimeError(f"Ampere GPU is required to run this example!")
# Create a non-default CUDA stream from PyTorch
torch_stream = torch.cuda.Stream()
# Get the raw stream pointer as a CUstream
current_stream = cuda.CUstream(torch_stream.cuda_stream)
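# A non-default stream is used here because stream capture for CUDA graphs
# (enabled below via use_cuda_graphs=True) is not supported on the legacy
# default stream.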
print(f"\nRunning Elementwise Apply test with:")
print(f"Tensor dimensions: [{M}, {N}]")
print(f"Input and Output Data type: {dtype}")
@ -309,20 +316,16 @@ def run_elementwise_apply_and_verify(
if op in (operator.truediv, operator.floordiv):
b = torch.where(b == 0, torch.tensor(epsilon), b)
print("Compiling kernel with cute.compile ...")
start_time = time.time()
compiled_func = cute.compile(elementwise_apply, op, from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic())
compilation_time = time.time() - start_time
print(f"Compilation time: {compilation_time:.4f} seconds")
print("Executing elementwise apply kernel...")
# Get current CUDA stream from PyTorch
torch_stream = torch.cuda.current_stream()
# Get the raw stream pointer as a CUstream
current_stream = cuda.CUstream(torch_stream.cuda_stream)
if not skip_ref_check:
compiled_func(from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic())
elementwise_apply(
op,
from_dlpack(a),
from_dlpack(b),
from_dlpack(c).mark_layout_dynamic(),
current_stream,
)
print("Verifying results...")
torch.testing.assert_close(op(a, b), c)
print("Results verified successfully!")
@ -330,28 +333,32 @@ def run_elementwise_apply_and_verify(
if not benchmark:
return
# Create CUDA events for timing
start_event = cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1]
end_event = cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1]
compiled_func = cute.compile(
elementwise_apply,
op,
from_dlpack(a),
from_dlpack(b),
from_dlpack(c).mark_layout_dynamic(),
current_stream,
)
# Warmup
for _ in range(warmup_iterations):
compiled_func(from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic())
# The op was inlined into the kernel at compile time, so it is not passed when benchmarking
# Record start event
cuda.cuEventRecord(start_event, current_stream)
avg_time_us = testing.benchmark(
compiled_func,
kernel_arguments=testing.JitArguments(
from_dlpack(a),
from_dlpack(b),
from_dlpack(c).mark_layout_dynamic(),
current_stream,
),
warmup_iterations=warmup_iterations,
profiling_iterations=iterations,
use_cuda_graphs=True,
stream=current_stream,
)
# Execute the kernel
for _ in range(iterations):
compiled_func(from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic())
# Record end event
cuda.cuEventRecord(end_event, current_stream)
cuda.cuEventSynchronize(end_event)
# Calculate elapsed time
err, elapsed_time = cuda.cuEventElapsedTime(start_event, end_event)
avg_time = elapsed_time / iterations
avg_time = avg_time_us / 1e3
# Print execution results
print(f"Kernel execution time: {avg_time:.4f} ms")
@ -360,10 +367,6 @@ def run_elementwise_apply_and_verify(
)
print(f"First few elements of result: \n{c[:3, :3]}")
# Destroy events
cuda.cuEventDestroy(start_event)
cuda.cuEventDestroy(end_event)
if __name__ == "__main__":
parser = argparse.ArgumentParser(

View File

@ -542,13 +542,13 @@ class FlashAttentionForwardAmpere:
cutlass.Boolean,
)
# Set predicates for head_dim bounds; seqlen_q/k bounds are processed at the first tile.
for rest_v in range(tQpQ.shape[0]):
for rest_k in range(tQpQ.shape[2]):
for rest_v in cutlass.range_constexpr(tQpQ.shape[0]):
for rest_k in cutlass.range_constexpr(tQpQ.shape[2]):
tQpQ[rest_v, 0, rest_k] = cute.elem_less(
tQcQ[(0, rest_v), 0, rest_k][3], mQ.layout.shape[3]
)
for rest_v in range(tKVpKV.shape[0]):
for rest_k in range(tKVpKV.shape[2]):
for rest_v in cutlass.range_constexpr(tKVpKV.shape[0]):
for rest_k in cutlass.range_constexpr(tKVpKV.shape[2]):
tKVpKV[rest_v, 0, rest_k] = cute.elem_less(
tKVcKV[(0, rest_v), 0, rest_k][3], mK.layout.shape[3]
)
@ -556,7 +556,7 @@ class FlashAttentionForwardAmpere:
# Prefetch Prologue
# ///////////////////////////////////////////////////////////////////////////////
# Start async loads of the last mn-tile, where we take care of the mn residue
for m in range(cute.size(tQsQ.shape[1])):
for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])):
if cute.elem_less(tQcQ[0, m, 0][1], mQ.layout.shape[1]):
cute.copy(
gmem_tiled_copy_QKV,
@ -567,7 +567,7 @@ class FlashAttentionForwardAmpere:
else:
# Clear the smem tiles to account for predicated off loads
tQsQ[None, m, None].fill(0)
for n in range(cute.size(tKsK.shape[1])):
for n in cutlass.range_constexpr(cute.size(tKsK.shape[1])):
if cute.elem_less(tKVcKV[0, n, 0][1], mK.layout.shape[1]):
cute.copy(
gmem_tiled_copy_QKV,
@ -644,13 +644,13 @@ class FlashAttentionForwardAmpere:
# We also need masking on S if it's causal, for the last ceil_div(m_block_size, n_block_size) blocks.
# We will have at least 1 "masking" iteration.
mask_steps = 1
if self._is_causal:
if cutlass.const_expr(self._is_causal):
mask_steps = cute.ceil_div(self._m_block_size, self._n_block_size)
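# For example (illustrative values), with m_block_size = 128 and n_block_size = 64,
# causal masking needs ceil_div(128, 64) = 2 masking iterations; the non-causal
# path keeps mask_steps = 1.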
for n_tile in range(mask_steps):
for n_tile in cutlass.range_constexpr(mask_steps):
n_block = n_block_max - n_tile - 1
basic_params.n_block = n_block
if self._is_causal:
if cutlass.const_expr(self._is_causal):
if n_block >= 0:
self.compute_one_n_block(
basic_params,
@ -673,7 +673,7 @@ class FlashAttentionForwardAmpere:
)
# Start async loads of rest k-tiles in reverse order, no k-residue handling needed
for n_tile in cutlass.range_dynamic(mask_steps, n_block_max, 1):
for n_tile in range(mask_steps, n_block_max, 1):
n_block = n_block_max - n_tile - 1
basic_params.n_block = n_block
self.compute_one_n_block(
@ -748,13 +748,13 @@ class FlashAttentionForwardAmpere:
),
cutlass.Boolean,
)
for rest_v in range(tOpO.shape[0]):
for rest_n in range(cute.size(tOpO.shape[2])):
for rest_v in cutlass.range_constexpr(tOpO.shape[0]):
for rest_n in cutlass.range_constexpr(cute.size(tOpO.shape[2])):
tOpO[rest_v, 0, rest_n] = cute.elem_less(
tOcO[(0, rest_v), 0, rest_n][3], mO.layout.shape[3]
)
# copy acc O from rmem to gmem
for rest_m in range(cute.size(tOpO.shape[1])):
for rest_m in cutlass.range_constexpr(cute.size(tOpO.shape[1])):
if cute.elem_less(tOcO[0, rest_m, 0][1], mO.layout.shape[1]):
cute.copy(
gmem_tiled_copy_O,
@ -804,7 +804,7 @@ class FlashAttentionForwardAmpere:
# Load the smem tile V for O; the first tile needs special processing to avoid loading NaN.
# The `if` here is a constexpr and won't be generated in the IR.
if is_first_n_block:
for n in range(cute.size(gmem_copy_params.tVsV.shape[1])):
for n in cutlass.range_constexpr(cute.size(gmem_copy_params.tVsV.shape[1])):
if cute.elem_less(
gmem_copy_params.tKVcKV[0, n, 0][1],
basic_params.mK.layout.shape[1],
@ -841,7 +841,7 @@ class FlashAttentionForwardAmpere:
smem_copy_params.tSrK_copy_view[None, None, 0],
)
# mma for S
for k in range(cute.size(smem_copy_params.tSsQ.shape[2])):
for k in cutlass.range_constexpr(cute.size(smem_copy_params.tSsQ.shape[2])):
# load next QK k-block from smem to rmem for mma
k_next = (k + 1) % cute.size(smem_copy_params.tSsQ.shape[2])
cute.copy(
@ -916,7 +916,7 @@ class FlashAttentionForwardAmpere:
smem_copy_params.tOrVt_copy_view[None, None, 0],
)
# mma for O
for k in range(cute.size(tOrS.shape[2])):
for k in cutlass.range_constexpr(cute.size(tOrS.shape[2])):
# load next V k-block from smem to rmem for mma
k_next = (k + 1) % cute.size(tOrS.shape[2])
cute.copy(
@ -965,14 +965,14 @@ class FlashAttentionForwardAmpere:
acc_O_mn = self._make_acc_tensor_mn_view(mma_params.acc_O)
row_max_prev = None
# if it is not the first tile, load the row r of previous row_max and compare with row_max_cur_row.
if not is_first_n_block:
if cutlass.const_expr(not is_first_n_block):
row_max_prev = cute.make_fragment_like(
softmax_params.row_max, cutlass.Float32
)
cute.basic_copy(softmax_params.row_max, row_max_prev)
# if it is the first tile, create a mask that sets the residual of S to -inf for softmax.
tScS_mn = None
if in_mask_steps:
if cutlass.const_expr(in_mask_steps):
mcS = cute.make_identity_tensor(
(
basic_params.mQ.shape[0],
@ -990,12 +990,12 @@ class FlashAttentionForwardAmpere:
tScS_mn = self._make_acc_tensor_mn_view(tScS)
# Each iteration processes one row of acc_S
for r in range(cute.size(softmax_params.row_max)):
for r in cutlass.range_constexpr(cute.size(softmax_params.row_max)):
# mask residual of S with -inf
if in_mask_steps:
if not self._is_causal:
if cutlass.const_expr(in_mask_steps):
if cutlass.const_expr(not self._is_causal):
# traverse column index.
for c in range(cute.size(tScS_mn.shape[1])):
for c in cutlass.range_constexpr(cute.size(tScS_mn.shape[1])):
if cute.elem_less(
basic_params.mK.shape[1], tScS_mn[0, c][3] + 1
):
@ -1006,7 +1006,7 @@ class FlashAttentionForwardAmpere:
tScS_mn[r, 0][1] + 1, basic_params.mK.shape[1]
)
# traverse column index.
for c in range(cute.size(tScS_mn.shape[1])):
for c in cutlass.range_constexpr(cute.size(tScS_mn.shape[1])):
# only consider the column index, so the row index sets to 0.
if cute.elem_less(col_idx_limit, tScS_mn[0, c][3] + 1):
acc_S_mn[r, c] = -cutlass.Float32.inf
@ -1021,10 +1021,10 @@ class FlashAttentionForwardAmpere:
row_max_cur_row = self._threadquad_reduce_max(row_max_cur_row)
row_max_prev_row = None
# if it is not the first tile, load the row r of previous row_max and compare with row_max_cur_row.
if not is_first_n_block:
if cutlass.const_expr(not is_first_n_block):
row_max_prev_row = row_max_prev[r]
row_max_cur_row = cute.arch.fmax(row_max_prev_row, row_max_cur_row)
if self._is_causal:
if cutlass.const_expr(self._is_causal):
row_max_cur_row = (
0.0 if row_max_cur_row == -cutlass.Float32.inf else row_max_cur_row
)
@ -1043,7 +1043,7 @@ class FlashAttentionForwardAmpere:
cute.ReductionOp.ADD, cutlass.Float32.zero, 0
)
# if it is not the first tile, load row r of the previous row_max and subtract row_max_cur_row to update row_sum.
if not is_first_n_block:
if cutlass.const_expr(not is_first_n_block):
prev_minus_cur_exp = self._exp2f(
row_max_prev_row * softmax_params.softmax_scale_log2
- row_max_cur_row * softmax_params.softmax_scale_log2
@ -1072,7 +1072,7 @@ class FlashAttentionForwardAmpere:
"""
# do quad reduction for row_sum.
acc_O_mn = self._make_acc_tensor_mn_view(acc_O)
for r in range(cute.size(row_sum)):
for r in cutlass.range_constexpr(cute.size(row_sum)):
row_sum[r] = self._threadquad_reduce_sum(row_sum[r])
# if row_sum is zero or nan, set acc_O_mn_row to 1.0
acc_O_mn_row_is_zero_or_nan = row_sum[r] == 0.0 or row_sum[r] != row_sum[r]

View File

@ -35,6 +35,8 @@ import torch
import cutlass
import cutlass.cute as cute
import cutlass.cute.testing as testing
import cutlass.torch as cutlass_torch
import cutlass.utils as utils
from cutlass.cute.runtime import from_dlpack
@ -109,6 +111,7 @@ class SGemm:
mB: cute.Tensor,
mC: cute.Tensor,
epilogue_op: cutlass.Constexpr = lambda x: x,
stream: cuda.CUstream = cuda.CUstream(cuda.CUstream_flags.CU_STREAM_DEFAULT),
):
self.a_major_mode = utils.LayoutEnum.from_tensor(mA)
self.b_major_mode = utils.LayoutEnum.from_tensor(mB)
@ -168,7 +171,7 @@ class SGemm:
num_bits_per_copy=mB.element_type.width,
)
if self.a_major_mode == utils.LayoutEnum.COL_MAJOR:
if cutlass.const_expr(self.a_major_mode == utils.LayoutEnum.COL_MAJOR):
num_vectorized = 4 if (mA.layout.max_alignment % 16 == 0) else 1
atom_async_copy_A = cute.make_copy_atom(
cute.nvgpu.cpasync.CopyG2SOp(),
@ -182,7 +185,7 @@ class SGemm:
)
vA = cute.make_layout((num_vectorized, 1))
if self.b_major_mode == utils.LayoutEnum.COL_MAJOR:
if cutlass.const_expr(self.b_major_mode == utils.LayoutEnum.COL_MAJOR):
num_vectorized = 4 if (mB.layout.max_alignment % 16 == 0) else 1
atom_async_copy_B = cute.make_copy_atom(
cute.nvgpu.cpasync.CopyG2SOp(),
@ -222,7 +225,7 @@ class SGemm:
atoms_layout = cute.make_layout(
(self._num_threads // 16, 16, 1), stride=(16, 1, 0)
)
if self.c_major_mode == utils.LayoutEnum.COL_MAJOR:
if cutlass.const_expr(self.c_major_mode == utils.LayoutEnum.COL_MAJOR):
atoms_layout = cute.make_layout(
(16, self._num_threads // 16, 1), stride=(1, 16, 0)
)
@ -256,6 +259,7 @@ class SGemm:
grid=grid_dim,
block=[cute.size(atoms_layout), 1, 1],
smem=smem_size,
stream=stream,
)
@cute.kernel
@ -540,8 +544,8 @@ class SGemm:
# 3. Combining the smem and register pipelines results in the mainloop.
# ///////////////////////////////////////////////////////////////////////////////
for _ in cutlass.range_dynamic(k_tile_count, unroll=1):
for k_block in range(k_block_max):
for _ in range(k_tile_count):
for k_block in range(k_block_max, unroll_full=True):
if k_block == k_block_max - 1:
tCsA_p = tCsA[None, None, None, smem_pipe_read]
tCsB_p = tCsB[None, None, None, smem_pipe_read]
@ -639,7 +643,6 @@ def main(
iterations: int = 100,
skip_ref_check: bool = False,
):
torch.manual_seed(1024)
M, N, K = problem_shape
# Create and permute tensor A/B/C
@ -694,51 +697,36 @@ def main(
sgemm = SGemm()
# Get current CUDA stream from PyTorch
torch_stream = torch.cuda.current_stream()
# Get the raw stream pointer as a CUstream
current_stream = cuda.CUstream(torch_stream.cuda_stream)
print("Compiling kernel with cute.compile ...")
start_time = time.time()
gemm = cute.compile(sgemm, a_tensor, b_tensor, c_tensor)
gemm = cute.compile(sgemm, a_tensor, b_tensor, c_tensor, stream=current_stream)
compilation_time = time.time() - start_time
print(f"Compilation time: {compilation_time:.4f} seconds")
print("Executing GEMM kernel...")
# Get current CUDA stream from PyTorch
torch_stream = torch.cuda.current_stream()
# Get the raw stream pointer as a CUstream
current_stream = cuda.CUstream(torch_stream.cuda_stream)
# Create CUDA events for timing
start_event = cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1]
end_event = cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1]
# Warmup
for _ in range(warmup_iterations):
gemm(a_tensor, b_tensor, c_tensor)
# Use the current stream for CUDA events instead of the default stream
# Record start event
cuda.cuEventRecord(start_event, current_stream)
# Execute the kernel
for _ in range(iterations):
gemm(a_tensor, b_tensor, c_tensor)
# Record end event
cuda.cuEventRecord(end_event, current_stream)
cuda.cuEventSynchronize(end_event)
# Calculate elapsed time
err, elapsed_time = cuda.cuEventElapsedTime(start_event, end_event)
avg_time_us = testing.benchmark(
gemm,
kernel_arguments=testing.JitArguments(
a_tensor, b_tensor, c_tensor, current_stream
),
warmup_iterations=warmup_iterations,
profiling_iterations=iterations,
use_cuda_graphs=False,
stream=current_stream,
)
# Print execution results
print(f"Kernel execution time: {elapsed_time / iterations:.4f} ms")
# Destroy events
cuda.cuEventDestroy(start_event)
cuda.cuEventDestroy(end_event)
print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms")
if not skip_ref_check:
gemm(a_tensor, b_tensor, c_tensor)
torch.cuda.synchronize()
print("Verifying results...")
ref = torch.einsum("mk,nk->mn", a, b)
torch.testing.assert_close(c.cpu(), ref.cpu(), atol=1e-03, rtol=1e-05)
@ -768,6 +756,9 @@ if __name__ == "__main__":
args = parser.parse_args()
print("Running SIMT GEMM example:")
torch.manual_seed(1024)
main(
args.a_major,
args.b_major,

View File

@ -36,6 +36,7 @@ import torch
import cutlass
import cutlass.cute as cute
import cutlass.cute.testing as testing
import cutlass.torch as cutlass_torch
import cutlass.utils as utils
from cutlass.cute.runtime import from_dlpack
@ -48,6 +49,7 @@ A dense GEMM (C = A * B) example for the NVIDIA Ampere architecture using CUTE D
This GEMM kernel supports the following features:
- Utilizes Ampere's tensor cores for matrix multiply-accumulate (MMA) operations
- Threadblock rasterization to improve data re-use
- Supports a multi-stage pipeline to overlap computation and memory access
- Implements shared memory buffering in the epilogue to increase coalesced global memory accesses
@ -253,6 +255,22 @@ class TensorOpGemm:
# grid_dim: ((m + BLK_M - 1) // BLK_M, (n + BLK_N - 1) // BLK_N, l)
grid_dim = cute.ceil_div(mC.shape, (self.bM, self.bN, 1))
# Add threadblock rasterization to improve re-use of data
raster_factor = 1
grid_dim_n = cute.size(grid_dim[1])
# Thresholds are chosen so that rasterization does not create too many no-op CTAs
if grid_dim_n > 5:
raster_factor = 8
elif grid_dim_n > 2:
raster_factor = 4
elif grid_dim_n > 1:
raster_factor = 2
rasterization_remap_grid_dim = (
cute.size(grid_dim[0]) * raster_factor,
(cute.size(grid_dim[1]) + raster_factor - 1) // raster_factor,
cute.size(grid_dim[2]),
)
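# Worked example (illustrative sizes): with grid_dim = (8, 16, 1) the factor is 8,
# so the launched grid becomes (8 * 8, ceil(16 / 8), 1) = (64, 2, 1); the CTA count
# is unchanged. When grid_dim[1] is not a multiple of the factor, the remapped grid
# contains a few no-op CTAs, which exit early inside the kernel.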
self.kernel(
mA,
mB,
@ -264,9 +282,10 @@ class TensorOpGemm:
tiled_copy_B,
tiled_copy_C,
tiled_mma,
raster_factor,
epilogue_op,
).launch(
grid=grid_dim,
grid=rasterization_remap_grid_dim,
block=[self.num_threads, 1, 1],
smem=smem_size,
)
@ -284,436 +303,445 @@ class TensorOpGemm:
tiled_copy_B: cute.TiledCopy,
tiled_copy_C: cute.TiledCopy,
tiled_mma: cute.TiledMma,
rasterization_factor: cutlass.Int32,
epilogue_op: cutlass.Constexpr = lambda x: x,
):
# Thread index, block index
tidx, _, _ = cute.arch.thread_idx()
bidx, bidy, bidz = cute.arch.block_idx()
tiler_coord = (bidx, bidy, None)
# ///////////////////////////////////////////////////////////////////////////////
# Get the appropriate tiles for this thread block.
# gA: (BLK_M, BLK_K, k), gB: (BLK_N, BLK_K, k), gC: (BLK_M, BLK_N)
# ///////////////////////////////////////////////////////////////////////////////
gA = cute.local_tile(
mA[None, None, bidz],
tiler=self.cta_tiler,
coord=tiler_coord,
proj=(1, None, 1),
)
gB = cute.local_tile(
mB[None, None, bidz],
tiler=self.cta_tiler,
coord=tiler_coord,
proj=(None, 1, 1),
)
gC = cute.local_tile(
mC[None, None, bidz],
tiler=self.cta_tiler,
coord=tiler_coord,
proj=(1, 1, None),
grid_dim = cute.ceil_div(mC.shape, (self.bM, self.bN, 1))
offset_tile_x, offset_tile_y = self.raster_tile(
bidx, bidy, rasterization_factor
)
# Early exit if CTA is out of range
if grid_dim[0] <= offset_tile_x or grid_dim[1] <= offset_tile_y:
pass
else:
tiler_coord = (offset_tile_x, offset_tile_y, None)
# By default, if the tensor k mode does not divide into the tile k
# size, then the last tiles in the k dimension are irregular.
# Instead, make the first tiles irregular when k is irregular.
# This allows us to handle the irregular tile first to avoid
# checking for this condition within the mainloop.
# residual_k is a negative number indicating the amount needed to
# shift the pointer by in dimension k
residual_k = cute.size(mA, mode=[1]) - cutlass.Int32(self.bK) * cute.size(
gA, mode=[2]
)
# move the pointer of gA/gB in the `-k` direction
gA = cute.domain_offset((0, residual_k, 0), gA)
gB = cute.domain_offset((0, residual_k, 0), gB)
# input is 16B aligned
gA = cute.make_tensor(gA.iterator.align(16), gA.layout)
gB = cute.make_tensor(gB.iterator.align(16), gB.layout)
# Construct identity layout for sA and sB (mirrors global tensors,
# used for predication only)
mcA = cute.make_identity_tensor(mA.layout.shape)
mcB = cute.make_identity_tensor(mB.layout.shape)
cA = cute.local_tile(
mcA[None, None, bidz],
tiler=self.cta_tiler,
coord=tiler_coord,
proj=(1, None, 1),
)
cB = cute.local_tile(
mcB[None, None, bidz],
tiler=self.cta_tiler,
coord=tiler_coord,
proj=(None, 1, 1),
)
cA = cute.domain_offset((0, residual_k, 0), cA)
cB = cute.domain_offset((0, residual_k, 0), cB)
# ///////////////////////////////////////////////////////////////////////////////
# Create shared memory buffers and get the appropriate fragments for this thread.
# sA: (BLK_M, BLK_K, PIPE) , sB: (BLK_N, BLK_K, PIPE)
# tAgA: (CPY, CPY_M, CPY_K, k) , tBgB: (CPY, CPY_N, CPY_K, k)
# tAsA: (CPY, CPY_M, CPY_K, PIPE) , tBsB: (CPY, CPY_N, CPY_K, PIPE)
# ///////////////////////////////////////////////////////////////////////////////
# Shared memory buffer
smem = cutlass.utils.SmemAllocator()
sA = smem.allocate_tensor(mA.element_type, sA_layout, 16)
sB = smem.allocate_tensor(mB.element_type, sB_layout, 16)
sC = cute.make_tensor(
cute.recast_ptr(sA.iterator, dtype=self.c_dtype), sC_layout
)
thr_copy_A = tiled_copy_A.get_slice(tidx)
thr_copy_B = tiled_copy_B.get_slice(tidx)
thr_copy_C = tiled_copy_C.get_slice(tidx)
tAgA = thr_copy_A.partition_S(gA)
tAsA = thr_copy_A.partition_D(sA)
tBgB = thr_copy_B.partition_S(gB)
tBsB = thr_copy_B.partition_D(sB)
tCsC_epilogue = thr_copy_C.partition_S(sC)
tCgC_epilogue = thr_copy_C.partition_D(gC)
# Repeat the partitioning with identity layouts
tAcA = thr_copy_A.partition_S(cA)
tBcB = thr_copy_B.partition_S(cB)
# ///////////////////////////////////////////////////////////////////////////////
# Predicate: Mark indices that need to copy when problem_shape isn't a multiple
# of tile_shape
# ///////////////////////////////////////////////////////////////////////////////
# For predication over the tensors A (M/K), B (N/K), and (in the
# epilogue) C (M/N), we will compute it in a fashion similar to an
# outer product. The predication along one of the dimensions is
# evaluated and stored in a predication tensor. Then, the
# predication for the remaining dimension is handled later via an
# if/else branch at the copy.
# For A and B, predication booleans along M/N are stored in a
# predication tensor, and predication along K is handled via an if/else branch.
# Allocate predicate tensors for M and N. Predication is checked
# at the granularity of a copy atom, so the predicate tensor does not
# need separate booleans for individual elements within a copy
# atom (for example, the elements of tAgA.shape[0][0].)
tApA = cute.make_fragment(
cute.make_layout(
(
tAgA.shape[0][1],
cute.size(tAgA, mode=[1]),
cute.size(tAgA, mode=[2]),
),
stride=(cute.size(tAgA, mode=[1]), 1, 0),
),
cutlass.Boolean,
)
tBpB = cute.make_fragment(
cute.make_layout(
(
tBsB.shape[0][1],
cute.size(tBsB, mode=[1]),
cute.size(tBsB, mode=[2]),
),
stride=(cute.size(tBsB, mode=[1]), 1, 0),
),
cutlass.Boolean,
)
# Set predicates for M/N bounds
for rest_v in range(tApA.shape[0]):
for m in range(tApA.shape[1]):
tApA[rest_v, m, 0] = cute.elem_less(
tAcA[(0, rest_v), m, 0, 0][0], mA.shape[0]
)
for rest_v in range(tBpB.shape[0]):
for n in range(tBpB.shape[1]):
tBpB[rest_v, n, 0] = cute.elem_less(
tBcB[(0, rest_v), n, 0, 0][0], mB.shape[0]
)
# ///////////////////////////////////////////////////////////////////////////////
# Prefetch Prologue
# ///////////////////////////////////////////////////////////////////////////////
# Clear the smem tiles to account for predicated off loads
tAsA.fill(0)
tBsB.fill(0)
cute.arch.sync_threads()
# Start async loads for the first k-tile. Here we take care of the k residue
# via an if/else check along the k dimension. Because we shifted the identity tensor
# by residual_k and because the identity tensor is a counting tensor, the
# value of any out-of-range identity tensor element is negative
num_smem_stages = cute.size(tAsA, mode=[3])
k_tile_count = cute.size(tAgA, mode=[3])
k_tile_index = cutlass.Int32(0)
for k in range(tApA.shape[2]):
if cute.elem_less(cutlass.Int32(-1), tAcA[0, 0, k, 0][1]):
cute.copy(
tiled_copy_A,
tAgA[None, None, k, k_tile_index],
tAsA[None, None, k, 0],
pred=tApA[None, None, k],
)
for k in range(tBpB.shape[2]):
if cute.elem_less(cutlass.Int32(-1), tBcB[0, 0, k, 0][1]):
cute.copy(
tiled_copy_B,
tBgB[None, None, k, k_tile_index],
tBsB[None, None, k, 0],
pred=tBpB[None, None, k],
)
k_tile_index = k_tile_index + 1
cute.arch.cp_async_commit_group()
# Start async loads for rest of the k-tiles
for k_tile in range(1, num_smem_stages - 1):
if k_tile == k_tile_count:
tApA.fill(0)
tBpB.fill(0)
cute.copy(
tiled_copy_A,
tAgA[None, None, None, k_tile_index],
tAsA[None, None, None, k_tile],
pred=tApA,
# ///////////////////////////////////////////////////////////////////////////////
# Get the appropriate tiles for this thread block.
# gA: (BLK_M, BLK_K, k), gB: (BLK_N, BLK_K, k), gC: (BLK_M, BLK_N)
# ///////////////////////////////////////////////////////////////////////////////
gA = cute.local_tile(
mA[None, None, bidz],
tiler=self.cta_tiler,
coord=tiler_coord,
proj=(1, None, 1),
)
cute.copy(
tiled_copy_B,
tBgB[None, None, None, k_tile_index],
tBsB[None, None, None, k_tile],
pred=tBpB,
gB = cute.local_tile(
mB[None, None, bidz],
tiler=self.cta_tiler,
coord=tiler_coord,
proj=(None, 1, 1),
)
gC = cute.local_tile(
mC[None, None, bidz],
tiler=self.cta_tiler,
coord=tiler_coord,
proj=(1, 1, None),
)
# By default, if the tensor k mode does not divide into the tile k
# size, then the last tiles in the k dimension are irregular.
# Instead, make the first tiles irregular when k is irregular.
# This allows us to handle the irregular tile first to avoid
# checking for this condition within the mainloop.
# residual_k is a negative number indicating the amount needed to
# shift the pointer by in dimension k
residual_k = cute.size(mA, mode=[1]) - cutlass.Int32(self.bK) * cute.size(
gA, mode=[2]
)
# move the pointer of gA/gB in the `-k` direction
gA = cute.domain_offset((0, residual_k, 0), gA)
gB = cute.domain_offset((0, residual_k, 0), gB)
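# Worked example (illustrative sizes): with K = 100 and bK = 32 there are
# ceil(100 / 32) = 4 k-tiles, so residual_k = 100 - 32 * 4 = -28. Shifting gA/gB
# (and cA/cB below) by -28 along k makes only the first k-tile irregular; its
# out-of-range identity coordinates are negative, which the prologue's
# elem_less(-1, ...) checks filter out.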
# input is 16B aligned
gA = cute.make_tensor(gA.iterator.align(16), gA.layout)
gB = cute.make_tensor(gB.iterator.align(16), gB.layout)
# Construct identity layout for sA and sB (mirrors global tensors,
# used for predication only)
mcA = cute.make_identity_tensor(mA.layout.shape)
mcB = cute.make_identity_tensor(mB.layout.shape)
cA = cute.local_tile(
mcA[None, None, bidz],
tiler=self.cta_tiler,
coord=tiler_coord,
proj=(1, None, 1),
)
cB = cute.local_tile(
mcB[None, None, bidz],
tiler=self.cta_tiler,
coord=tiler_coord,
proj=(None, 1, 1),
)
cA = cute.domain_offset((0, residual_k, 0), cA)
cB = cute.domain_offset((0, residual_k, 0), cB)
# ///////////////////////////////////////////////////////////////////////////////
# Create shared memory buffers and get the appropriate fragments for this thread.
# sA: (BLK_M, BLK_K, PIPE) , sB: (BLK_N, BLK_K, PIPE)
# tAgA: (CPY, CPY_M, CPY_K, k) , tBgB: (CPY, CPY_N, CPY_K, k)
# tAsA: (CPY, CPY_M, CPY_K, PIPE) , tBsB: (CPY, CPY_N, CPY_K, PIPE)
# ///////////////////////////////////////////////////////////////////////////////
# Shared memory buffer
smem = cutlass.utils.SmemAllocator()
sA = smem.allocate_tensor(mA.element_type, sA_layout, 16)
sB = smem.allocate_tensor(mB.element_type, sB_layout, 16)
sC = cute.make_tensor(
cute.recast_ptr(sA.iterator, dtype=self.c_dtype), sC_layout
)
thr_copy_A = tiled_copy_A.get_slice(tidx)
thr_copy_B = tiled_copy_B.get_slice(tidx)
thr_copy_C = tiled_copy_C.get_slice(tidx)
tAgA = thr_copy_A.partition_S(gA)
tAsA = thr_copy_A.partition_D(sA)
tBgB = thr_copy_B.partition_S(gB)
tBsB = thr_copy_B.partition_D(sB)
tCsC_epilogue = thr_copy_C.partition_S(sC)
tCgC_epilogue = thr_copy_C.partition_D(gC)
# Repeat the partitioning with identity layouts
tAcA = thr_copy_A.partition_S(cA)
tBcB = thr_copy_B.partition_S(cB)
# ///////////////////////////////////////////////////////////////////////////////
# Predicate: Mark indices that need to copy when problem_shape isn't a multiple
# of tile_shape
# ///////////////////////////////////////////////////////////////////////////////
# For predication over the tensors A (M/K), B (N/K), and (in the
# epilogue) C (M/N), we will compute it in a fashion similar to an
# outer product. The predication along one of the dimensions is
# evaluated and stored in a predication tensor. Then, the
# predication for the remaining dimension is handled later via an
# if/else branch at the copy.
# For A and B, predication booleans along M/N are stored in a
# predication tensor, and predication along K is handled via an if/else branch.
# Allocate predicate tensors for M and N. Predication is checked
# at the granularity of a copy atom, so the predicate tensor does not
# need separate booleans for individual elements within a copy
# atom (for example, the elements of tAgA.shape[0][0].)
tApA = cute.make_fragment(
cute.make_layout(
(
tAgA.shape[0][1],
cute.size(tAgA, mode=[1]),
cute.size(tAgA, mode=[2]),
),
stride=(cute.size(tAgA, mode=[1]), 1, 0),
),
cutlass.Boolean,
)
tBpB = cute.make_fragment(
cute.make_layout(
(
tBsB.shape[0][1],
cute.size(tBsB, mode=[1]),
cute.size(tBsB, mode=[2]),
),
stride=(cute.size(tBsB, mode=[1]), 1, 0),
),
cutlass.Boolean,
)
# Set predicates for M/N bounds
for rest_v in range(tApA.shape[0]):
for m in range(tApA.shape[1]):
tApA[rest_v, m, 0] = cute.elem_less(
tAcA[(0, rest_v), m, 0, 0][0], mA.shape[0]
)
for rest_v in range(tBpB.shape[0]):
for n in range(tBpB.shape[1]):
tBpB[rest_v, n, 0] = cute.elem_less(
tBcB[(0, rest_v), n, 0, 0][0], mB.shape[0]
)
# ///////////////////////////////////////////////////////////////////////////////
# Prefetch Prologue
# ///////////////////////////////////////////////////////////////////////////////
# Clear the smem tiles to account for predicated off loads
tAsA.fill(0)
tBsB.fill(0)
cute.arch.sync_threads()
# Start async loads for the first k-tile. Here we take care of the k residue
# via an if/else check along the k dimension. Because we shifted the identity tensor
# by residual_k and because the identity tensor is a counting tensor, the
# value of any out-of-range identity tensor element is negative
num_smem_stages = cute.size(tAsA, mode=[3])
k_tile_count = cute.size(tAgA, mode=[3])
k_tile_index = cutlass.Int32(0)
for k in range(tApA.shape[2]):
if cute.elem_less(cutlass.Int32(-1), tAcA[0, 0, k, 0][1]):
cute.copy(
tiled_copy_A,
tAgA[None, None, k, k_tile_index],
tAsA[None, None, k, 0],
pred=tApA[None, None, k],
)
for k in range(tBpB.shape[2]):
if cute.elem_less(cutlass.Int32(-1), tBcB[0, 0, k, 0][1]):
cute.copy(
tiled_copy_B,
tBgB[None, None, k, k_tile_index],
tBsB[None, None, k, 0],
pred=tBpB[None, None, k],
)
k_tile_index = k_tile_index + 1
cute.arch.cp_async_commit_group()
# ///////////////////////////////////////////////////////////////////////////////
# Tile MMA compute thread partitions and allocate accumulators
# ///////////////////////////////////////////////////////////////////////////////
thr_mma = tiled_mma.get_slice(tidx)
tCsA = thr_mma.partition_A(sA)
tCsB = thr_mma.partition_B(sB)
tCsC = thr_mma.partition_C(sC)
tCgC = thr_mma.partition_C(gC)
tCrA = tiled_mma.make_fragment_A(tCsA[None, None, None, 0])
tCrB = tiled_mma.make_fragment_B(tCsB[None, None, None, 0])
tCrC = tiled_mma.make_fragment_C(tCgC)
# Clear the accumulator
tCrC.fill(0.0)
# Start async loads for rest of the k-tiles
for k_tile in range(1, num_smem_stages - 1):
if k_tile == k_tile_count:
tApA.fill(0)
tBpB.fill(0)
cute.copy(
tiled_copy_A,
tAgA[None, None, None, k_tile_index],
tAsA[None, None, None, k_tile],
pred=tApA,
)
cute.copy(
tiled_copy_B,
tBgB[None, None, None, k_tile_index],
tBsB[None, None, None, k_tile],
pred=tBpB,
)
k_tile_index = k_tile_index + 1
cute.arch.cp_async_commit_group()
# ///////////////////////////////////////////////////////////////////////////////
# Copy Atom A/B retiling
# ///////////////////////////////////////////////////////////////////////////////
# ///////////////////////////////////////////////////////////////////////////////
# Tile MMA compute thread partitions and allocate accumulators
# ///////////////////////////////////////////////////////////////////////////////
thr_mma = tiled_mma.get_slice(tidx)
tCsA = thr_mma.partition_A(sA)
tCsB = thr_mma.partition_B(sB)
tCsC = thr_mma.partition_C(sC)
tCgC = thr_mma.partition_C(gC)
tCrA = tiled_mma.make_fragment_A(tCsA[None, None, None, 0])
tCrB = tiled_mma.make_fragment_B(tCsB[None, None, None, 0])
tCrC = tiled_mma.make_fragment_C(tCgC)
# Clear the accumulator
tCrC.fill(0.0)
# Create the copy atoms for the copy from shared memory to register
atom_copy_s2r_A = cute.make_copy_atom(
cute.nvgpu.warp.LdMatrix8x8x16bOp(
self.a_major_mode != utils.LayoutEnum.ROW_MAJOR, 4
),
mA.element_type,
)
atom_copy_s2r_B = cute.make_copy_atom(
cute.nvgpu.warp.LdMatrix8x8x16bOp(
self.b_major_mode != utils.LayoutEnum.ROW_MAJOR, 4
),
mB.element_type,
)
# ///////////////////////////////////////////////////////////////////////////////
# Copy Atom A/B retiling
# ///////////////////////////////////////////////////////////////////////////////
# Creates the tiled copy so that it matches the thread-value layout
# expected by the tiled mma
tiled_copy_s2r_A = cute.make_tiled_copy(
atom_copy_s2r_A,
layout_tv=tiled_mma.tv_layout_A_tiled,
tiler_mn=(tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(2)),
)
tiled_copy_s2r_B = cute.make_tiled_copy(
atom_copy_s2r_B,
layout_tv=tiled_mma.tv_layout_B_tiled,
tiler_mn=(tiled_mma.get_tile_size(1), tiled_mma.get_tile_size(2)),
)
thr_copy_ldmatrix_A = tiled_copy_s2r_A.get_slice(tidx)
thr_copy_ldmatrix_B = tiled_copy_s2r_B.get_slice(tidx)
tCsA_copy_view = thr_copy_ldmatrix_A.partition_S(sA)
tCrA_copy_view = thr_copy_ldmatrix_A.retile(tCrA)
tCsB_copy_view = thr_copy_ldmatrix_B.partition_S(sB)
tCrB_copy_view = thr_copy_ldmatrix_B.retile(tCrB)
# Current pipe index in smem to read from / write to
smem_pipe_read = 0
smem_pipe_write = num_smem_stages - 1
tCsA_p = tCsA_copy_view[None, None, None, smem_pipe_read]
tCsB_p = tCsB_copy_view[None, None, None, smem_pipe_read]
# ///////////////////////////////////////////////////////////////////////////////
# PREFETCH register pipeline
# ///////////////////////////////////////////////////////////////////////////////
num_k_block = cute.size(tCrA, mode=[2])
if num_k_block > 1:
# Wait until our first prefetched tile is loaded in
cute.arch.cp_async_wait_group(num_smem_stages - 2)
cute.arch.sync_threads()
# Prefetch the first k-block rmem from the first k-tile
cute.copy(
tiled_copy_s2r_A,
tCsA_p[None, None, 0],
tCrA_copy_view[None, None, 0],
# Create the copy atoms for the copy from shared memory to register
atom_copy_s2r_A = cute.make_copy_atom(
cute.nvgpu.warp.LdMatrix8x8x16bOp(
self.a_major_mode != utils.LayoutEnum.ROW_MAJOR, 4
),
mA.element_type,
)
cute.copy(
tiled_copy_s2r_B,
tCsB_p[None, None, 0],
tCrB_copy_view[None, None, 0],
atom_copy_s2r_B = cute.make_copy_atom(
cute.nvgpu.warp.LdMatrix8x8x16bOp(
self.b_major_mode != utils.LayoutEnum.ROW_MAJOR, 4
),
mB.element_type,
)
# ///////////////////////////////////////////////////////////////////////////////
# Mainloop
# 1. Shared memory pipeline (gmem -> smem):
# The default smem pipeline depth is 3, meaning that for shared
# memory buffers, we allocate three times the size described by the
# CTA tiler. We prefetch 2 of these buffers before entering the main
# loop. Considering only the transfer from global memory to shared
# memory, the general structure of the mainloop is:
# (1) copy k-tile from gmem to smem;
# (2) perform gemm computation on k-tile;
# (3) wait for the next copy to finish.
# The `cute.arch.cp_async_wait_group(num_smem_stages - 2)` command
# waits for the number of unfinished copies to be <= 1. The advantage
# of this approach is that it allows for simultaneous production
# (i.e., step (1)) and consumption (i.e., step (2)) of smem.
# A common misconception is to prefetch N buffers and rewrite
# the pipeline logic to wait on N-1 pending copies. The disadvantage
# of this approach is that it requires fully consuming a buffer in
# order to open an empty buffer for the next copy.
# 2. Register pipeline (smem -> register):
# Similarly, the register pipeline produces i+1, consumes i, and
# produces i+2... Notably, i and i+1 do not use the same register,
# eliminating dependencies on the same register for better parallelism.
# 3. Combining the smem and register pipelines results in the mainloop.
# ///////////////////////////////////////////////////////////////////////////////
for k_tile in cutlass.range_dynamic(k_tile_count, unroll=1):
for k_block in range(num_k_block):
if k_block == num_k_block - 1:
tCsA_p = tCsA_copy_view[None, None, None, smem_pipe_read]
tCsB_p = tCsB_copy_view[None, None, None, smem_pipe_read]
cute.arch.cp_async_wait_group(num_smem_stages - 2)
cute.arch.sync_threads()
# Creates the tiled copy so that it matches the thread-value layout
# expected by the tiled mma
tiled_copy_s2r_A = cute.make_tiled_copy(
atom_copy_s2r_A,
layout_tv=tiled_mma.tv_layout_A_tiled,
tiler_mn=(tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(2)),
)
tiled_copy_s2r_B = cute.make_tiled_copy(
atom_copy_s2r_B,
layout_tv=tiled_mma.tv_layout_B_tiled,
tiler_mn=(tiled_mma.get_tile_size(1), tiled_mma.get_tile_size(2)),
)
# Load A, B from shared memory to registers for k_block + 1
k_block_next = (k_block + 1) % num_k_block # static
thr_copy_ldmatrix_A = tiled_copy_s2r_A.get_slice(tidx)
thr_copy_ldmatrix_B = tiled_copy_s2r_B.get_slice(tidx)
tCsA_copy_view = thr_copy_ldmatrix_A.partition_S(sA)
tCrA_copy_view = thr_copy_ldmatrix_A.retile(tCrA)
tCsB_copy_view = thr_copy_ldmatrix_B.partition_S(sB)
tCrB_copy_view = thr_copy_ldmatrix_B.retile(tCrB)
# Current pipe index in smem to read from / write to
smem_pipe_read = 0
smem_pipe_write = num_smem_stages - 1
tCsA_p = tCsA_copy_view[None, None, None, smem_pipe_read]
tCsB_p = tCsB_copy_view[None, None, None, smem_pipe_read]
# ///////////////////////////////////////////////////////////////////////////////
# PREFETCH register pipeline
# ///////////////////////////////////////////////////////////////////////////////
num_k_block = cute.size(tCrA, mode=[2])
if num_k_block > 1:
# Wait until our first prefetched tile is loaded in
cute.arch.cp_async_wait_group(num_smem_stages - 2)
cute.arch.sync_threads()
# Prefetch the first k-block rmem from the first k-tile
cute.copy(
tiled_copy_s2r_A,
tCsA_p[None, None, k_block_next],
tCrA_copy_view[None, None, k_block_next],
tCsA_p[None, None, 0],
tCrA_copy_view[None, None, 0],
)
cute.copy(
tiled_copy_s2r_B,
tCsB_p[None, None, k_block_next],
tCrB_copy_view[None, None, k_block_next],
tCsB_p[None, None, 0],
tCrB_copy_view[None, None, 0],
)
# Fetch next A: To better interleave global memory access and compute
# instructions, we intentionally use the sequence: copy A, perform GEMM,
# then copy B.
if k_block == 0:
if k_tile + num_smem_stages - 1 < k_tile_count:
cute.copy(
tiled_copy_A,
tAgA[None, None, None, k_tile_index],
tAsA[None, None, None, smem_pipe_write],
pred=tApA,
)
# ///////////////////////////////////////////////////////////////////////////////
# Mainloop
# 1. Shared memory pipeline (gmem -> smem):
# The default smem pipeline depth is 3, meaning that for shared
# memory buffers, we allocate three times the size described by the
# CTA tiler. We prefetch 2 of these buffers before entering the main
# loop. Considering only the transfer from global memory to shared
# memory, the general structure of the mainloop is:
# (1) copy k-tile from gmem to smem;
# (2) perform gemm computation on k-tile;
# (3) wait for the next copy to finish.
# The `cute.arch.cp_async_wait_group(num_smem_stages - 2)` command
# waits for the number of unfinished copies to be <= 1. The advantage
# of this approach is that it allows for simultaneous production
# (i.e., step (1)) and consumption (i.e., step (2)) of smem.
# A common misconception is to prefetch N buffers and rewrite
# the pipeline logic to wait on N-1 pending copies. The disadvantage
# of this approach is that it requires fully consuming a buffer in
# order to open an empty buffer for the next copy.
# 2. Register pipeline (smem -> register):
# Similarly, the register pipeline produces i+1, consumes i, and
# produces i+2... Notably, i and i+1 do not use the same register,
# eliminating dependencies on the same register for better parallelism.
# 3. Combining the smem and register pipelines results in the mainloop.
# ///////////////////////////////////////////////////////////////////////////////
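# Illustrative steady-state schedule (an informal sketch, not part of the
# kernel), assuming num_smem_stages == 3 with buffers numbered 0..2:
#   prologue : copies into buffers 0 and 1 are issued (2 copies in flight)
#   k_tile 0 : buffer 0 is ready; issue the copy into buffer 2, run the gemm
#              on buffer 0, then wait until at most one copy is pending
#   k_tile 1 : buffer 1 is ready; issue the copy into buffer 0, run the gemm
#              on buffer 1, then wait again
#   ...
# so the copy that fills buffer i+2 overlaps the gemm that drains buffer i.
# The register pipeline applies the same idea one level down: while the MMA
# consumes k-block i from registers, the s2r copies for k-block i+1 are
# already in flight.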
for k_tile in range(k_tile_count):
for k_block in cutlass.range(num_k_block, unroll_full=True):
if k_block == num_k_block - 1:
tCsA_p = tCsA_copy_view[None, None, None, smem_pipe_read]
tCsB_p = tCsB_copy_view[None, None, None, smem_pipe_read]
cute.arch.cp_async_wait_group(num_smem_stages - 2)
cute.arch.sync_threads()
# Load A, B from shared memory to registers for k_block + 1
k_block_next = (k_block + 1) % num_k_block  # static
cute.copy(
tiled_copy_s2r_A,
tCsA_p[None, None, k_block_next],
tCrA_copy_view[None, None, k_block_next],
)
cute.copy(
tiled_copy_s2r_B,
tCsB_p[None, None, k_block_next],
tCrB_copy_view[None, None, k_block_next],
)
# Fetch next A: To better interleave global memory access and compute
# instructions, we intentionally use the sequence: copy A, perform GEMM,
# then copy B.
if k_block == 0:
if k_tile + num_smem_stages - 1 < k_tile_count:
cute.copy(
tiled_copy_A,
tAgA[None, None, None, k_tile_index],
tAsA[None, None, None, smem_pipe_write],
pred=tApA,
)
# Thread-level register gemm for k_block
cute.gemm(
tiled_mma,
tCrC,
tCrA[None, None, k_block],
tCrB[None, None, k_block],
tCrC,
)
# Fetch next B and update smem pipeline read/write
if k_block == 0:
if k_tile + num_smem_stages - 1 < k_tile_count:
cute.copy(
tiled_copy_B,
tBgB[None, None, None, k_tile_index],
tBsB[None, None, None, smem_pipe_write],
pred=tBpB,
)
k_tile_index = k_tile_index + 1
cute.arch.cp_async_commit_group()
smem_pipe_write = smem_pipe_read
smem_pipe_read = smem_pipe_read + 1
if smem_pipe_read == num_smem_stages:
smem_pipe_read = 0
# Sync before epilogue
cute.arch.cp_async_wait_group(0)
cute.arch.sync_threads()
# ///////////////////////////////////////////////////////////////////////////////
# Epilogue with fusion
# ///////////////////////////////////////////////////////////////////////////////
tCrD = cute.make_fragment_like(tCrC, self.c_dtype)
tCrD[None] = epilogue_op(tCrC.load()).to(self.c_dtype)
# Copy results of D back to shared memory
cute.autovec_copy(tCrD, tCsC)
# Create counting tensor for C
ceilM, ceilN, _ = cute.ceil_div(mC.shape, (self.bM, self.bN, 1))
mcC = cute.make_identity_tensor(
(
cute.size(ceilM) * self.cta_tiler[0],
cute.size(ceilN) * self.cta_tiler[1],
1,
)
)
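# mcC holds per-element (m, n) coordinates over whole CTA tiles (shapes
# rounded up to tile multiples), so out-of-bounds positions can later be
# detected against mC.shape when the predicates are built below.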
cC = cute.local_tile(
mcC[None, None, bidz],
tiler=self.cta_tiler,
coord=tiler_coord,
proj=(1, 1, None),
)
tCcC = thr_copy_C.partition_S(cC)
tCrC_epilogue = cute.make_fragment_like(tCsC_epilogue)
# Wait for all writes to shared memory to finish before starting copies
# using the new layouts
cute.arch.sync_threads()
cute.autovec_copy(tCsC_epilogue, tCrC_epilogue)
# Create predication tensor for m
tCpC = cute.make_fragment(
cute.make_layout(
(
tCgC_epilogue.shape[0][1],
cute.size(tCgC_epilogue, mode=[1]),
cute.size(tCgC_epilogue, mode=[2]),
),
stride=(cute.size(tCgC_epilogue, mode=[1]), 1, 0),
),
cutlass.Boolean,
)
for rest_v in range(tCpC.shape[0]):
for m in range(tCpC.shape[1]):
tCpC[rest_v, m, 0] = cute.elem_less(
tCcC[(0, rest_v), m, 0][0], mC.shape[0]
)
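# tCpC caches only the m-boundary predicate (broadcast along n via the zero
# stride above); the n boundary is checked once per column in the loop below
# before each copy is issued.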
# Copy to global memory using better vectorization
for rest_v in range(tCpC.shape[0]):
for n in range(tCpC.shape[2]):
if cute.elem_less(tCcC[(0, rest_v), 0, n][1], mC.shape[1]):
cute.copy(
tiled_copy_C,
tCrC_epilogue[None, None, n],
tCgC_epilogue[None, None, n],
pred=tCpC[None, None, n],
)
return
def _make_smem_layout_AB(self, dtype, major_mode, copy_bits, smem_tiler):
@ -811,6 +839,11 @@ class TensorOpGemm:
tiler_mn, layout_tv = cute.make_layout_tv(thread_layout, value_layout)
return cute.make_tiled_copy(atom_copy, layout_tv, tiler_mn)
def raster_tile(self, i, j, f):
new_i = i // f
new_j = (i % f) + (j * f)
return (new_i, new_j)
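# A quick worked example (illustrative only, not part of the diff): with a
# rasterization factor f = 2, consecutive tile indices map as
#   raster_tile(0, 0, 2) -> (0, 0)   raster_tile(1, 0, 2) -> (0, 1)
#   raster_tile(2, 0, 2) -> (1, 0)   raster_tile(3, 0, 2) -> (1, 1)
# i.e. groups of f consecutive i values are folded into the j dimension,
# presumably so that neighboring CTAs work on nearby C tiles.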
def run_tensor_op_gemm(
a_major: str,
@ -892,15 +925,18 @@ def run_tensor_op_gemm(
print("Executing GEMM kernel...")
# Warmup
for _ in range(warmup_iterations):
gemm(a_tensor, b_tensor, c_tensor)
avg_time_us = testing.benchmark(
gemm,
kernel_arguments=testing.JitArguments(a_tensor, b_tensor, c_tensor),
warmup_iterations=warmup_iterations,
profiling_iterations=iterations,
use_cuda_graphs=False,
)
# Execute the kernel
for _ in range(iterations):
gemm(a_tensor, b_tensor, c_tensor)
print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms")
if not skip_ref_check:
gemm(a_tensor, b_tensor, c_tensor)
print("Verifying results...")
torch.testing.assert_close(c.cpu(), ref.cpu(), atol=1e-03, rtol=1e-05)
print("Results verified successfully!")

View File

@ -35,6 +35,7 @@ import torch
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
import cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
import cutlass.torch as cutlass_torch
import cutlass.utils.blackwell_helpers as sm100_utils
@ -211,7 +212,7 @@ class DenseGemmKernel:
self.occupancy = 1
self.threads_per_cta = 128
self.num_smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
def _setup_attributes(self):
"""Set up configurations that are dependent on GEMM inputs
@ -283,7 +284,7 @@ class DenseGemmKernel:
self.epi_tile,
self.c_dtype,
self.c_layout,
self.num_smem_capacity,
self.smem_capacity,
self.occupancy,
self.use_tma_store,
)
@ -308,7 +309,7 @@ class DenseGemmKernel:
self.epi_tile,
self.num_c_stage,
)
if cutlass.const_expr(self.use_tma_store)
if self.use_tma_store
else None
)
@ -372,9 +373,11 @@ class DenseGemmKernel:
atom_thr_size = cute.size(tiled_mma.thr_id.shape)
# Setup TMA load for A
a_op = self._get_tma_atom_kind(atom_thr_size, self.is_a_mcast)
a_op = sm100_utils.cluster_shape_to_tma_atom_A(
self.cluster_shape_mn, tiled_mma.thr_id
)
a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tma_tile_atom_A(
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
a_op,
a,
a_smem_layout,
@ -387,9 +390,11 @@ class DenseGemmKernel:
)
# Setup TMA load for B
b_op = self._get_tma_atom_kind(atom_thr_size, self.is_b_mcast)
b_op = sm100_utils.cluster_shape_to_tma_atom_B(
self.cluster_shape_mn, tiled_mma.thr_id
)
b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tma_tile_atom_B(
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
b_op,
b,
b_smem_layout,
@ -413,7 +418,7 @@ class DenseGemmKernel:
cute.make_identity_layout(c.shape), self.epi_tile
)
epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0))
tma_atom_c, tma_tensor_c = cpasync.make_tma_tile_atom(
tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
cpasync.CopyBulkTensorTileS2GOp(),
c,
epi_smem_layout,
@ -426,9 +431,7 @@ class DenseGemmKernel:
self.buffer_align_bytes = 1024
c_smem_size = (
cute.cosize(self.c_smem_layout_staged.outer)
if cutlass.const_expr(self.use_tma_store)
else 0
cute.cosize(self.c_smem_layout_staged.outer) if self.use_tma_store else 0
)
# Define shared storage for kernel
@ -472,7 +475,7 @@ class DenseGemmKernel:
tma_atom_b,
tma_tensor_b,
tma_atom_c,
tma_tensor_c if cutlass.const_expr(self.use_tma_store) else c,
tma_tensor_c if self.use_tma_store else c,
self.cluster_layout_vmnk,
self.a_smem_layout_staged,
self.b_smem_layout_staged,
@ -556,12 +559,12 @@ class DenseGemmKernel:
tmem_holding_buf = storage.tmem_holding_buf
# Initialize mainloop ab_pipeline (barrier) and states
ab_pipeline_producer_group = utils.CooperativeGroup(utils.Agent.Thread)
ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
ab_pipeline_consumer_group = utils.CooperativeGroup(
utils.Agent.Thread, num_tma_producer
ab_pipeline_consumer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread, num_tma_producer
)
ab_pipeline = utils.PipelineTmaUmma.create(
ab_pipeline = pipeline.PipelineTmaUmma.create(
barrier_storage=storage.ab_full_mbar_ptr.data_ptr(),
num_stages=self.num_ab_stage,
producer_group=ab_pipeline_producer_group,
@ -569,30 +572,30 @@ class DenseGemmKernel:
tx_count=self.num_tma_load_bytes,
cta_layout_vmnk=cluster_layout_vmnk,
)
ab_producer_state = utils.make_pipeline_state(
utils.PipelineUserType.Producer, self.num_ab_stage
ab_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.num_ab_stage
)
ab_consumer_state = utils.make_pipeline_state(
utils.PipelineUserType.Consumer, self.num_ab_stage
ab_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_ab_stage
)
# Initialize acc_pipeline (barrier) and states
acc_pipeline_producer_group = utils.CooperativeGroup(utils.Agent.Thread)
acc_pipeline_consumer_group = utils.CooperativeGroup(
utils.Agent.Thread, self.threads_per_cta, self.threads_per_cta
acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
acc_pipeline_consumer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread, self.threads_per_cta, self.threads_per_cta
)
acc_pipeline = utils.PipelineUmmaAsync.create(
acc_pipeline = pipeline.PipelineUmmaAsync.create(
barrier_storage=storage.acc_full_mbar_ptr.data_ptr(),
num_stages=self.num_acc_stage,
producer_group=acc_pipeline_producer_group,
consumer_group=acc_pipeline_consumer_group,
cta_layout_vmnk=cluster_layout_vmnk,
)
acc_producer_state = utils.make_pipeline_state(
utils.PipelineUserType.Producer, self.num_acc_stage
acc_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.num_acc_stage
)
acc_consumer_state = utils.make_pipeline_state(
utils.PipelineUserType.Consumer, self.num_acc_stage
acc_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_acc_stage
)
# Tensor memory dealloc barrier init
@ -600,7 +603,7 @@ class DenseGemmKernel:
if warp_idx == 0:
num_tmem_dealloc_threads = 32
with cute.arch.elect_one():
cute.arch.mbarrier_init_arrive_cnt(
cute.arch.mbarrier_init(
tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads
)
cute.arch.mbarrier_init_fence()
@ -617,7 +620,7 @@ class DenseGemmKernel:
storage.sC.get_tensor(
c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner
)
if cutlass.const_expr(self.use_tma_store)
if self.use_tma_store
else None
)
# (MMA, MMA_M, MMA_K, STAGE)
@ -634,7 +637,7 @@ class DenseGemmKernel:
#
a_full_mcast_mask = None
b_full_mcast_mask = None
if self.is_a_mcast or self.is_b_mcast or use_2cta_instrs:
if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs):
a_full_mcast_mask = cpasync.create_tma_multicast_mask(
cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2
)
@ -645,15 +648,15 @@ class DenseGemmKernel:
#
# Local_tile partition global tensors
#
# (bM, bK, loopM, loopK, loopL)
# (bM, bK, RestM, RestK, RestL)
gA_mkl = cute.local_tile(
mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None)
)
# (bN, bK, loopN, loopK, loopL)
# (bN, bK, RestN, RestK, RestL)
gB_nkl = cute.local_tile(
mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None)
)
# (bM, bN, loopM, loopN, loopL)
# (bM, bN, RestM, RestN, RestL)
gC_mnl = cute.local_tile(
mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None)
)
@ -663,11 +666,11 @@ class DenseGemmKernel:
# Partition global tensor for TiledMMA_A/B/C
#
thr_mma = tiled_mma.get_slice(mma_tile_coord_v)
# (MMA, MMA_M, MMA_K, loopM, loopK, loopL)
# (MMA, MMA_M, MMA_K, RestM, RestK, RestL)
tCgA = thr_mma.partition_A(gA_mkl)
# (MMA, MMA_N, MMA_K, loopN, loopK, loopL)
# (MMA, MMA_N, MMA_K, RestN, RestK, RestL)
tCgB = thr_mma.partition_B(gB_nkl)
# (MMA, MMA_M, MMA_N, loopM, loopN, loopL)
# (MMA, MMA_M, MMA_N, RestM, RestN, RestL)
tCgC = thr_mma.partition_C(gC_mnl)
#
@ -678,7 +681,7 @@ class DenseGemmKernel:
cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape
)
# ((atom_v, rest_v), STAGE)
# ((atom_v, rest_v), loopM, loopK, loopL)
# ((atom_v, rest_v), RestM, RestK, RestL)
tAsA, tAgA = cpasync.tma_partition(
tma_atom_a,
block_in_cluster_coord_vmnk[2],
@ -691,7 +694,7 @@ class DenseGemmKernel:
cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape
)
# ((atom_v, rest_v), STAGE)
# ((atom_v, rest_v), loopN, loopK, loopL)
# ((atom_v, rest_v), RestN, RestK, RestL)
tBsB, tBgB = cpasync.tma_partition(
tma_atom_b,
block_in_cluster_coord_vmnk[1],
@ -771,9 +774,9 @@ class DenseGemmKernel:
#
# Slice to per mma tile index
#
# ((atom_v, rest_v), loopK)
# ((atom_v, rest_v), RestK)
tAgA = tAgA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])]
# ((atom_v, rest_v), loopK)
# ((atom_v, rest_v), RestK)
tBgB = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])]
if cutlass.const_expr(self.use_tma_store):
# ((ATOM_V, REST_V), EPI_M, EPI_N)
@ -797,7 +800,7 @@ class DenseGemmKernel:
#
# Prefetch TMA load A/B
#
for prefetch_idx in cutlass.range_dynamic(prefetch_k_block_cnt, unroll=1):
for prefetch_idx in cutlass.range(prefetch_k_block_cnt, unroll=1):
# Conditionally wait for AB buffer empty
ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status)
@ -833,7 +836,7 @@ class DenseGemmKernel:
#
# MMA mainloop
#
for k_block in cutlass.range_dynamic(0, k_block_cnt, 1, unroll=1):
for k_block in range(k_block_cnt):
# Conditionally wait for AB buffer empty
ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status)
@ -860,7 +863,7 @@ class DenseGemmKernel:
# tCtAcc += tCrA * tCrB
num_kphases = cute.size(tCrA, mode=[2])
for kphase_idx in range(num_kphases):
for kphase_idx in cutlass.range(num_kphases, unroll_full=True):
kphase_coord = (None, None, kphase_idx, ab_consumer_state.index)
cute.gemm(
@ -917,10 +920,10 @@ class DenseGemmKernel:
c_pipeline = None
if cutlass.const_expr(self.use_tma_store):
# Initialize tma store c_pipeline
c_producer_group = utils.CooperativeGroup(
utils.Agent.Thread, self.threads_per_cta, self.threads_per_cta
c_producer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread, self.threads_per_cta, self.threads_per_cta
)
c_pipeline = utils.PipelineTmaStore.create(
c_pipeline = pipeline.PipelineTmaStore.create(
num_stages=self.num_c_stage,
producer_group=c_producer_group,
)
@ -929,7 +932,7 @@ class DenseGemmKernel:
# Store accumulator to global memory in subtiles
#
subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])
for subtile_idx in cutlass.range_dynamic(subtile_cnt):
for subtile_idx in range(subtile_cnt):
#
# Load accumulator from tensor memory buffer to register
#
@ -1007,7 +1010,7 @@ class DenseGemmKernel:
#
if warp_idx == 0:
# Reverse prefetch_k_block_cnt times to next available buffer
for i in cutlass.range_dynamic(prefetch_k_block_cnt):
for i in range(prefetch_k_block_cnt):
ab_producer_state.reverse()
ab_pipeline.producer_tail(ab_producer_state)
return
@ -1063,11 +1066,11 @@ class DenseGemmKernel:
# (T2R, T2R_M, T2R_N, EPI_M, EPI_M)
tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi)
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL)
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL)
gC_mnl_epi = cute.flat_divide(
gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile
)
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL)
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL)
tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi)
# (T2R, T2R_M, T2R_N)
tTR_rAcc = cute.make_fragment(
@ -1149,7 +1152,7 @@ class DenseGemmKernel:
- tTR_gC: The partitioned global tensor C
:rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]
"""
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL)
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL)
gC_epi = cute.flat_divide(
gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile
)
@ -1158,7 +1161,7 @@ class DenseGemmKernel:
sC_for_tma_partition = cute.group_modes(sC, 0, 2)
gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2)
# ((ATOM_V, REST_V), EPI_M, EPI_N)
# ((ATOM_V, REST_V), EPI_M, EPI_N, loopM, loopN, loopL)
# ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL)
bSG_sC, bSG_gC = cpasync.tma_partition(
tma_atom_c,
0,
@ -1169,7 +1172,7 @@ class DenseGemmKernel:
return tma_atom_c, bSG_sC, bSG_gC
else:
tiled_copy_t2r = atom
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL)
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL)
thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
tTR_gC = thr_copy_t2r.partition_D(gC_epi)
# (T2R, T2R_M, T2R_N)
@ -1188,7 +1191,7 @@ class DenseGemmKernel:
epi_tile: cute.Tile,
c_dtype: Type[cutlass.Numeric],
c_layout: utils.LayoutEnum,
num_smem_capacity: int,
smem_capacity: int,
occupancy: int,
use_tma_store: bool,
) -> Tuple[int, int, int]:
@ -1208,8 +1211,8 @@ class DenseGemmKernel:
:type c_dtype: type[cutlass.Numeric]
:param c_layout: Layout enum of operand C in global memory.
:type c_layout: utils.LayoutEnum
:param num_smem_capacity: Total available shared memory capacity in bytes.
:type num_smem_capacity: int
:param smem_capacity: Total available shared memory capacity in bytes.
:type smem_capacity: int
:param occupancy: Target number of CTAs per SM (occupancy).
:type occupancy: int
:param use_tma_store: Whether TMA store is enabled.
@ -1263,7 +1266,7 @@ class DenseGemmKernel:
# Subtract reserved bytes and initial C stages bytes
# Divide remaining by bytes needed per A/B stage
num_ab_stage = (
num_smem_capacity - (occupancy + 1) * (mbar_helpers_bytes + c_bytes)
smem_capacity - (occupancy + 1) * (mbar_helpers_bytes + c_bytes)
) // ab_bytes_per_stage
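# A hypothetical sanity check (all numbers invented for illustration): with
# smem_capacity = 224 KiB, occupancy = 1, mbar_helpers_bytes + c_bytes = 4 KiB
# and ab_bytes_per_stage = 32 KiB, this yields (224 - 2 * 4) // 32 = 6 A/B stages.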
# Refine epilogue stages:
@ -1271,7 +1274,7 @@ class DenseGemmKernel:
# Add remaining unused smem to epilogue
if use_tma_store:
num_c_stage += (
num_smem_capacity
smem_capacity
- ab_bytes_per_stage * num_ab_stage
- (occupancy + 1) * (mbar_helpers_bytes + c_bytes)
) // ((occupancy + 1) * c_bytes_per_stage)
@ -1309,36 +1312,6 @@ class DenseGemmKernel:
return grid
@staticmethod
def _get_tma_atom_kind(
atom_sm_cnt: cutlass.Int32, mcast: cutlass.Boolean
) -> Union[
cpasync.CopyBulkTensorTileG2SMulticastOp, cpasync.CopyBulkTensorTileG2SOp
]:
"""
Select the appropriate TMA copy atom based on the number of SMs and the multicast flag.
:param atom_sm_cnt: The number of SMs
:type atom_sm_cnt: cutlass.Int32
:param mcast: The multicast flag
:type mcast: cutlass.Boolean
:return: The appropriate TMA copy atom kind
:rtype: cpasync.CopyBulkTensorTileG2SMulticastOp or cpasync.CopyBulkTensorTileG2SOp
:raise ValueError: If the atom_sm_cnt is invalid
"""
if atom_sm_cnt == 2 and mcast:
return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.TWO)
elif atom_sm_cnt == 2 and not mcast:
return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.TWO)
elif atom_sm_cnt == 1 and mcast:
return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.ONE)
elif atom_sm_cnt == 1 and not mcast:
return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE)
raise ValueError(f"Invalid atom_sm_cnt: {atom_sm_cnt} and {mcast}")
@staticmethod
def _compute_num_tmem_alloc_cols(
tiled_mma: cute.TiledMma, mma_tiler: Tuple[int, int, int]

View File

@ -37,6 +37,7 @@ import cutlass.cute as cute
from cutlass.cute.nvgpu import cpasync, tcgen05
import cutlass.torch as cutlass_torch
import cutlass.utils as utils
import cutlass.pipeline as pipeline
import cutlass.utils.blackwell_helpers as sm100_utils
from cutlass.cute.runtime import from_dlpack
@ -225,7 +226,7 @@ class PersistentDenseGemmKernel:
self.cta_sync_bar_id = 0
self.epilog_sync_bar_id = 1
self.tmem_ptr_sync_bar_id = 2
self.num_smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
def _setup_attributes(self):
"""Set up configurations that are dependent on GEMM inputs
@ -297,7 +298,7 @@ class PersistentDenseGemmKernel:
self.epi_tile,
self.c_dtype,
self.c_layout,
self.num_smem_capacity,
self.smem_capacity,
self.occupancy,
self.use_tma_store,
)
@ -389,9 +390,11 @@ class PersistentDenseGemmKernel:
atom_thr_size = cute.size(tiled_mma.thr_id.shape)
# Setup TMA load for A
a_op = self._get_tma_atom_kind(atom_thr_size, self.is_a_mcast)
a_op = sm100_utils.cluster_shape_to_tma_atom_A(
self.cluster_shape_mn, tiled_mma.thr_id
)
a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tma_tile_atom_A(
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
a_op,
a,
a_smem_layout,
@ -404,9 +407,11 @@ class PersistentDenseGemmKernel:
)
# Setup TMA load for B
b_op = self._get_tma_atom_kind(atom_thr_size, self.is_b_mcast)
b_op = sm100_utils.cluster_shape_to_tma_atom_B(
self.cluster_shape_mn, tiled_mma.thr_id
)
b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tma_tile_atom_B(
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
b_op,
b,
b_smem_layout,
@ -430,7 +435,7 @@ class PersistentDenseGemmKernel:
cute.make_identity_layout(c.shape), self.epi_tile
)
epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0))
tma_atom_c, tma_tensor_c = cpasync.make_tma_tile_atom(
tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
cpasync.CopyBulkTensorTileS2GOp(),
c,
epi_smem_layout,
@ -571,12 +576,12 @@ class PersistentDenseGemmKernel:
tmem_holding_buf = storage.tmem_holding_buf
# Initialize mainloop ab_pipeline (barrier) and states
ab_pipeline_producer_group = utils.CooperativeGroup(utils.Agent.Thread)
ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
ab_pipeline_consumer_group = utils.CooperativeGroup(
utils.Agent.Thread, num_tma_producer
ab_pipeline_consumer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread, num_tma_producer
)
ab_pipeline = utils.PipelineTmaUmma.create(
ab_pipeline = pipeline.PipelineTmaUmma.create(
barrier_storage=storage.ab_full_mbar_ptr.data_ptr(),
num_stages=self.num_ab_stage,
producer_group=ab_pipeline_producer_group,
@ -586,14 +591,14 @@ class PersistentDenseGemmKernel:
)
# Initialize acc_pipeline (barrier) and states
acc_pipeline_producer_group = utils.CooperativeGroup(utils.Agent.Thread)
acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
num_acc_consumer_threads = len(self.epilog_warp_id) * (
2 if use_2cta_instrs else 1
)
acc_pipeline_consumer_group = utils.CooperativeGroup(
utils.Agent.Thread, num_acc_consumer_threads
acc_pipeline_consumer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread, num_acc_consumer_threads
)
acc_pipeline = utils.PipelineUmmaAsync.create(
acc_pipeline = pipeline.PipelineUmmaAsync.create(
barrier_storage=storage.acc_full_mbar_ptr.data_ptr(),
num_stages=self.num_acc_stage,
producer_group=acc_pipeline_producer_group,
@ -606,7 +611,7 @@ class PersistentDenseGemmKernel:
if warp_idx == self.tma_warp_id:
num_tmem_dealloc_threads = 32
with cute.arch.elect_one():
cute.arch.mbarrier_init_arrive_cnt(
cute.arch.mbarrier_init(
tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads
)
cute.arch.mbarrier_init_fence()
@ -640,7 +645,7 @@ class PersistentDenseGemmKernel:
#
a_full_mcast_mask = None
b_full_mcast_mask = None
if self.is_a_mcast or self.is_b_mcast or use_2cta_instrs:
if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs):
a_full_mcast_mask = cpasync.create_tma_multicast_mask(
cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2
)
@ -651,15 +656,15 @@ class PersistentDenseGemmKernel:
#
# Local_tile partition global tensors
#
# (bM, bK, loopM, loopK, loopL)
# (bM, bK, RestM, RestK, RestL)
gA_mkl = cute.local_tile(
mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None)
)
# (bN, bK, loopN, loopK, loopL)
# (bN, bK, RestN, RestK, RestL)
gB_nkl = cute.local_tile(
mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None)
)
# (bM, bN, loopM, loopN, loopL)
# (bM, bN, RestM, RestN, RestL)
gC_mnl = cute.local_tile(
mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None)
)
@ -669,11 +674,11 @@ class PersistentDenseGemmKernel:
# Partition global tensor for TiledMMA_A/B/C
#
thr_mma = tiled_mma.get_slice(mma_tile_coord_v)
# (MMA, MMA_M, MMA_K, loopM, loopK, loopL)
# (MMA, MMA_M, MMA_K, RestM, RestK, RestL)
tCgA = thr_mma.partition_A(gA_mkl)
# (MMA, MMA_N, MMA_K, loopN, loopK, loopL)
# (MMA, MMA_N, MMA_K, RestN, RestK, RestL)
tCgB = thr_mma.partition_B(gB_nkl)
# (MMA, MMA_M, MMA_N, loopM, loopN, loopL)
# (MMA, MMA_M, MMA_N, RestM, RestN, RestL)
tCgC = thr_mma.partition_C(gC_mnl)
#
@ -684,7 +689,7 @@ class PersistentDenseGemmKernel:
cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape
)
# ((atom_v, rest_v), STAGE)
# ((atom_v, rest_v), loopM, loopK, loopL)
# ((atom_v, rest_v), RestM, RestK, RestL)
tAsA, tAgA = cpasync.tma_partition(
tma_atom_a,
block_in_cluster_coord_vmnk[2],
@ -697,7 +702,7 @@ class PersistentDenseGemmKernel:
cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape
)
# ((atom_v, rest_v), STAGE)
# ((atom_v, rest_v), loopM, loopK, loopL)
# ((atom_v, rest_v), RestM, RestK, RestL)
tBsB, tBgB = cpasync.tma_partition(
tma_atom_b,
block_in_cluster_coord_vmnk[1],
@ -743,12 +748,11 @@ class PersistentDenseGemmKernel:
)
work_tile = tile_sched.initial_work_tile_info()
ab_producer_state = utils.make_pipeline_state(
utils.PipelineUserType.Producer, self.num_ab_stage
ab_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.num_ab_stage
)
while work_tile.is_valid_tile:
# Get tile coord from tile scheduler
cur_tile_coord = work_tile.tile_idx
mma_tile_coord_mnl = (
@ -760,11 +764,11 @@ class PersistentDenseGemmKernel:
#
# Slice to per mma tile index
#
# ((atom_v, rest_v), loopK)
# ((atom_v, rest_v), RestK)
tAgA_slice = tAgA[
(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])
]
# ((atom_v, rest_v), loopK)
# ((atom_v, rest_v), RestK)
tBgB_slice = tBgB[
(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])
]
@ -779,7 +783,7 @@ class PersistentDenseGemmKernel:
#
# Tma load loop
#
for k_block in cutlass.range_dynamic(0, k_block_cnt, 1, unroll=1):
for k_block in cutlass.range(0, k_block_cnt, 1, unroll=1):
# Conditionally wait for AB buffer empty
ab_pipeline.producer_acquire(
ab_producer_state, peek_ab_empty_status
@ -852,15 +856,14 @@ class PersistentDenseGemmKernel:
)
work_tile = tile_sched.initial_work_tile_info()
ab_consumer_state = utils.make_pipeline_state(
utils.PipelineUserType.Consumer, self.num_ab_stage
ab_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_ab_stage
)
acc_producer_state = utils.make_pipeline_state(
utils.PipelineUserType.Producer, self.num_acc_stage
acc_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.num_acc_stage
)
while work_tile.is_valid_tile:
# Get tile coord from tile scheduler
cur_tile_coord = work_tile.tile_idx
mma_tile_coord_mnl = (
@ -895,7 +898,7 @@ class PersistentDenseGemmKernel:
#
# Mma mainloop
#
for k_block in cutlass.range_dynamic(0, k_block_cnt, 1, unroll=1):
for k_block in range(k_block_cnt):
if is_leader_cta:
# Conditionally wait for AB buffer full
ab_pipeline.consumer_wait(
@ -904,7 +907,7 @@ class PersistentDenseGemmKernel:
# tCtAcc += tCrA * tCrB
num_kphases = cute.size(tCrA, mode=[2])
for kphase_idx in range(num_kphases):
for kphase_idx in cutlass.range(num_kphases, unroll_full=True):
kphase_coord = (
None,
None,
@ -989,10 +992,12 @@ class PersistentDenseGemmKernel:
# Partition for epilogue
#
epi_tidx = tidx
tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = (
self.epilog_tmem_copy_and_partition(
epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs
)
(
tiled_copy_t2r,
tTR_tAcc_base,
tTR_rAcc,
) = self.epilog_tmem_copy_and_partition(
epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs
)
tTR_rC = None
@ -1008,16 +1013,20 @@ class PersistentDenseGemmKernel:
tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition(
tiled_copy_t2r, tTR_rC, epi_tidx, sC
)
tma_atom_c, bSG_sC, bSG_gC_partitioned = (
self.epilog_gmem_copy_and_partition(
epi_tidx, tma_atom_c, tCgC, epi_tile, sC
)
(
tma_atom_c,
bSG_sC,
bSG_gC_partitioned,
) = self.epilog_gmem_copy_and_partition(
epi_tidx, tma_atom_c, tCgC, epi_tile, sC
)
else:
simt_atom, tTR_rC, tTR_gC_partitioned = (
self.epilog_gmem_copy_and_partition(
epi_tidx, tiled_copy_t2r, tCgC, epi_tile, sC
)
(
simt_atom,
tTR_rC,
tTR_gC_partitioned,
) = self.epilog_gmem_copy_and_partition(
epi_tidx, tiled_copy_t2r, tCgC, epi_tile, sC
)
#
@ -1028,25 +1037,24 @@ class PersistentDenseGemmKernel:
)
work_tile = tile_sched.initial_work_tile_info()
acc_consumer_state = utils.make_pipeline_state(
utils.PipelineUserType.Consumer, self.num_acc_stage
acc_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_acc_stage
)
c_pipeline = None
if cutlass.const_expr(self.use_tma_store):
# Threads/warps participating in tma store pipeline
c_producer_group = utils.CooperativeGroup(
utils.Agent.Thread,
c_producer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread,
32 * len(self.epilog_warp_id),
32 * len(self.epilog_warp_id),
)
c_pipeline = utils.PipelineTmaStore.create(
c_pipeline = pipeline.PipelineTmaStore.create(
num_stages=self.num_c_stage,
producer_group=c_producer_group,
)
while work_tile.is_valid_tile:
# Get tile coord from tile scheduler
cur_tile_coord = work_tile.tile_idx
mma_tile_coord_mnl = (
@ -1105,7 +1113,7 @@ class PersistentDenseGemmKernel:
#
subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])
num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt
for subtile_idx in cutlass.range_dynamic(subtile_cnt):
for subtile_idx in cutlass.range(subtile_cnt):
#
# Load accumulator from tensor memory buffer to register
#
@ -1259,11 +1267,11 @@ class PersistentDenseGemmKernel:
# (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE)
tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi)
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL)
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL)
gC_mnl_epi = cute.flat_divide(
gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile
)
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL)
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL)
tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi)
# (T2R, T2R_M, T2R_N)
tTR_rAcc = cute.make_fragment(
@ -1346,7 +1354,7 @@ class PersistentDenseGemmKernel:
- tTR_gC: The partitioned global tensor C
:rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]
"""
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL)
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL)
gC_epi = cute.flat_divide(
gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile
)
@ -1355,7 +1363,7 @@ class PersistentDenseGemmKernel:
sC_for_tma_partition = cute.group_modes(sC, 0, 2)
gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2)
# ((ATOM_V, REST_V), EPI_M, EPI_N)
# ((ATOM_V, REST_V), EPI_M, EPI_N, loopM, loopN, loopL)
# ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL)
bSG_sC, bSG_gC = cpasync.tma_partition(
tma_atom_c,
0,
@ -1366,7 +1374,7 @@ class PersistentDenseGemmKernel:
return tma_atom_c, bSG_sC, bSG_gC
else:
tiled_copy_t2r = atom
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL)
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL)
thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
tTR_gC = thr_copy_t2r.partition_D(gC_epi)
# (T2R, T2R_M, T2R_N)
@ -1385,7 +1393,7 @@ class PersistentDenseGemmKernel:
epi_tile: cute.Tile,
c_dtype: Type[cutlass.Numeric],
c_layout: utils.LayoutEnum,
num_smem_capacity: int,
smem_capacity: int,
occupancy: int,
use_tma_store: bool,
) -> Tuple[int, int, int]:
@ -1405,8 +1413,8 @@ class PersistentDenseGemmKernel:
:type c_dtype: type[cutlass.Numeric]
:param c_layout: Layout enum of operand C.
:type c_layout: utils.LayoutEnum
:param num_smem_capacity: Total available shared memory capacity in bytes.
:type num_smem_capacity: int
:param smem_capacity: Total available shared memory capacity in bytes.
:type smem_capacity: int
:param occupancy: Target number of CTAs per SM (occupancy).
:type occupancy: int
:param use_tma_store: Whether TMA store is enabled.
@ -1461,7 +1469,7 @@ class PersistentDenseGemmKernel:
# Subtract reserved bytes and initial C stages bytes
# Divide remaining by bytes needed per A/B stage
num_ab_stage = (
num_smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes)
smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes)
) // ab_bytes_per_stage
# Refine epilogue stages:
@ -1469,7 +1477,7 @@ class PersistentDenseGemmKernel:
# Add remaining unused smem to epilogue
if use_tma_store:
num_c_stage += (
num_smem_capacity
smem_capacity
- occupancy * ab_bytes_per_stage * num_ab_stage
- occupancy * (mbar_helpers_bytes + c_bytes)
) // (occupancy * c_bytes_per_stage)
@ -1512,36 +1520,6 @@ class PersistentDenseGemmKernel:
return tile_sched_params, grid
@staticmethod
def _get_tma_atom_kind(
atom_sm_cnt: cutlass.Int32, mcast: cutlass.Boolean
) -> Union[
cpasync.CopyBulkTensorTileG2SMulticastOp, cpasync.CopyBulkTensorTileG2SOp
]:
"""
Select the appropriate TMA copy atom based on the number of SMs and the multicast flag.
:param atom_sm_cnt: The number of SMs
:type atom_sm_cnt: cutlass.Int32
:param mcast: The multicast flag
:type mcast: cutlass.Boolean
:return: The appropriate TMA copy atom kind
:rtype: cpasync.CopyBulkTensorTileG2SMulticastOp or cpasync.CopyBulkTensorTileG2SOp
:raise ValueError: If the atom_sm_cnt is invalid
"""
if atom_sm_cnt == 2 and mcast:
return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.TWO)
elif atom_sm_cnt == 2 and not mcast:
return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.TWO)
elif atom_sm_cnt == 1 and mcast:
return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.ONE)
elif atom_sm_cnt == 1 and not mcast:
return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE)
raise ValueError(f"Invalid atom_sm_cnt: {atom_sm_cnt} and {mcast}")
@staticmethod
def _compute_num_tmem_alloc_cols(
tiled_mma: cute.TiledMma,

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -40,7 +40,6 @@ import cutlass.utils as utils
from cutlass.cute.nvgpu import cpasync, tcgen05
import cutlass.utils.blackwell_helpers as sm100_utils
import cutlass.torch as cutlass_torch
from cutlass.cute.runtime import from_dlpack
"""
A grouped GEMM example for the NVIDIA Blackwell SM100 architecture using CUTE DSL
@ -89,7 +88,6 @@ there are also the following constrains:
class GroupedGemmKernel:
def __init__(
self,
acc_dtype: type[cutlass.Numeric],
@ -159,7 +157,7 @@ class GroupedGemmKernel:
self.tmem_ptr_sync_bar_id = 2
# Barrier ID used by MMA/TMA warps to signal A/B tensormap initialization completion
self.tensormap_ab_init_bar_id = 4
self.num_smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
self.num_tma_load_bytes = 0
def _setup_attributes(self):
@ -217,18 +215,20 @@ class GroupedGemmKernel:
)
# Setup A/B/C stage count in shared memory and ACC stage count in tensor memory
self.num_acc_stage, self.num_ab_stage, self.num_epi_stage = (
self._compute_stages(
tiled_mma,
self.mma_tiler,
self.a_dtype,
self.b_dtype,
self.epi_tile,
self.c_dtype,
self.c_layout,
self.num_smem_capacity,
self.occupancy,
)
(
self.num_acc_stage,
self.num_ab_stage,
self.num_epi_stage,
) = self._compute_stages(
tiled_mma,
self.mma_tiler,
self.a_dtype,
self.b_dtype,
self.epi_tile,
self.c_dtype,
self.c_layout,
self.smem_capacity,
self.occupancy,
)
self.a_smem_layout_staged = sm100_utils.make_smem_layout_a(
@ -355,9 +355,11 @@ class GroupedGemmKernel:
atom_thr_size = cute.size(tiled_mma.thr_id.shape)
# Setup TMA load for A
a_op = self._get_tma_atom_kind(atom_thr_size, self.is_a_mcast)
a_op = sm100_utils.cluster_shape_to_tma_atom_A(
self.cluster_shape_mn, tiled_mma.thr_id
)
a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tma_tile_atom_A(
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
a_op,
initial_a,
a_smem_layout,
@ -367,9 +369,11 @@ class GroupedGemmKernel:
)
# Setup TMA load for B
b_op = self._get_tma_atom_kind(atom_thr_size, self.is_b_mcast)
b_op = sm100_utils.cluster_shape_to_tma_atom_B(
self.cluster_shape_mn, tiled_mma.thr_id
)
b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tma_tile_atom_B(
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
b_op,
initial_b,
b_smem_layout,
@ -389,7 +393,7 @@ class GroupedGemmKernel:
cute.make_identity_layout(initial_c.shape), self.epi_tile
)
epi_smem_layout = cute.slice_(self.epi_smem_layout_staged, (None, None, 0))
tma_atom_c, tma_tensor_c = cpasync.make_tma_tile_atom(
tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
cpasync.CopyBulkTensorTileS2GOp(),
initial_c,
epi_smem_layout,
@ -403,9 +407,7 @@ class GroupedGemmKernel:
self.buffer_align_bytes = 1024
self.size_tensormap_in_i64 = (
0
if cutlass.const_expr(
self.tensormap_update_mode == utils.TensorMapUpdateMode.GMEM
)
if self.tensormap_update_mode == utils.TensorMapUpdateMode.GMEM
else GroupedGemmKernel.num_tensormaps
* GroupedGemmKernel.bytes_per_tensormap
// 8
@ -564,16 +566,16 @@ class GroupedGemmKernel:
for k_stage in range(self.num_ab_stage):
num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
with cute.arch.elect_one():
cute.arch.mbarrier_init_arrive_cnt(ab_full_mbar_ptr + k_stage, 1)
cute.arch.mbarrier_init_arrive_cnt(
cute.arch.mbarrier_init(ab_full_mbar_ptr + k_stage, 1)
cute.arch.mbarrier_init(
ab_empty_mbar_ptr + k_stage, num_tma_producer
)
# Accumulator barrier init
if warp_idx == self.mma_warp_id:
for acc_stage in range(self.num_acc_stage):
with cute.arch.elect_one():
cute.arch.mbarrier_init_arrive_cnt(acc_full_mbar_ptr + acc_stage, 1)
cute.arch.mbarrier_init_arrive_cnt(
cute.arch.mbarrier_init(acc_full_mbar_ptr + acc_stage, 1)
cute.arch.mbarrier_init(
acc_empty_mbar_ptr + acc_stage, 8 if use_2cta_instrs else 4
)
# Tensor memory dealloc barrier init
@ -581,7 +583,7 @@ class GroupedGemmKernel:
if warp_idx == self.tma_warp_id:
num_tmem_dealloc_threads = 32
with cute.arch.elect_one():
cute.arch.mbarrier_init_arrive_cnt(
cute.arch.mbarrier_init(
tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads
)
cute.arch.mbarrier_init_fence()
@ -612,7 +614,7 @@ class GroupedGemmKernel:
a_full_mcast_mask = None
b_full_mcast_mask = None
ab_empty_mcast_mask = None
if self.is_a_mcast or self.is_b_mcast or use_2cta_instrs:
if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs):
a_full_mcast_mask = cpasync.create_tma_multicast_mask(
cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2
)
@ -621,7 +623,7 @@ class GroupedGemmKernel:
)
ab_empty_mcast_mask = a_full_mcast_mask | b_full_mcast_mask
acc_full_mcast_mask = None
if use_2cta_instrs:
if cutlass.const_expr(use_2cta_instrs):
acc_full_mcast_mask = cute.make_layout_image_mask(
cluster_layout_vmnk, block_in_cluster_coord_vmnk, mode=0
)
@ -646,15 +648,15 @@ class GroupedGemmKernel:
#
# Local_tile partition global tensors
#
# (bM, bK, loopM, loopK, loopL)
# (bM, bK, RestM, RestK, RestL)
gA_mkl = cute.local_tile(
mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None)
)
# (bN, bK, loopN, loopK, loopL)
# (bN, bK, RestN, RestK, RestL)
gB_nkl = cute.local_tile(
mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None)
)
# (bM, bN, loopM, loopN, loopL)
# (bM, bN, RestM, RestN, RestL)
gC_mnl = cute.local_tile(
mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None)
)
@ -663,11 +665,11 @@ class GroupedGemmKernel:
# Partition global tensor for TiledMMA_A/B/C
#
thr_mma = tiled_mma.get_slice(mma_tile_coord_v)
# (MMA, MMA_M, MMA_K, loopM, loopK, loopL)
# (MMA, MMA_M, MMA_K, RestM, RestK, RestL)
tCgA = thr_mma.partition_A(gA_mkl)
# (MMA, MMA_N, MMA_K, loopN, loopK, loopL)
# (MMA, MMA_N, MMA_K, RestN, RestK, RestL)
tCgB = thr_mma.partition_B(gB_nkl)
# (MMA, MMA_M, MMA_N, loopM, loopN, loopL)
# (MMA, MMA_M, MMA_N, RestM, RestN, RestL)
tCgC = thr_mma.partition_C(gC_mnl)
#
@ -677,7 +679,7 @@ class GroupedGemmKernel:
cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape
)
# ((atom_v, rest_v), STAGE)
# ((atom_v, rest_v), loopM, loopK, loopL)
# ((atom_v, rest_v), RestM, RestK, RestL)
tAsA, tAgA = cpasync.tma_partition(
tma_atom_a,
block_in_cluster_coord_vmnk[2],
@ -690,7 +692,7 @@ class GroupedGemmKernel:
cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape
)
# ((atom_v, rest_v), STAGE)
# ((atom_v, rest_v), loopM, loopK, loopL)
# ((atom_v, rest_v), RestM, RestK, RestL)
tBsB, tBgB = cpasync.tma_partition(
tma_atom_b,
block_in_cluster_coord_vmnk[1],
@ -849,11 +851,11 @@ class GroupedGemmKernel:
#
# Slice to per mma tile index
#
# ((atom_v, rest_v), loopK)
# ((atom_v, rest_v), RestK)
tAgA_slice = tAgA[
(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])
]
# ((atom_v, rest_v), loopK)
# ((atom_v, rest_v), RestK)
tBgB_slice = tBgB[
(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])
]
@ -867,7 +869,7 @@ class GroupedGemmKernel:
tma_wr_ab_empty_phase = (
num_prev_k_blk + tma_wr_k_block
) // self.num_ab_stage % 2 ^ 1
peek_ab_empty_status = cute.arch.conditional_mbarrier_try_wait(
peek_ab_empty_status = cute.arch.mbarrier_conditional_try_wait(
tma_wr_k_block < cur_k_block_cnt,
ab_empty_mbar_ptr + smem_wr_buffer,
tma_wr_ab_empty_phase,
@ -879,7 +881,7 @@ class GroupedGemmKernel:
#
# Tma load loop
#
for k_block in cutlass.range_dynamic(0, cur_k_block_cnt, 1, unroll=1):
for k_block in cutlass.range(0, cur_k_block_cnt, 1, unroll=1):
tma_wr_k_block_next = tma_wr_k_block + 1
smem_wr_buffer_next = (
num_prev_k_blk + tma_wr_k_block_next
@ -898,10 +900,10 @@ class GroupedGemmKernel:
ab_empty_mbar_ptr + smem_wr_buffer, tma_wr_ab_empty_phase
)
# Init AB buffer full transaction byte
# Arrive AB buffer and expect full transaction bytes
if is_leader_cta:
with cute.arch.elect_one():
cute.arch.mbarrier_init_tx_bytes(
cute.arch.mbarrier_arrive_and_expect_tx(
smem_full_mbar_ptr, self.num_tma_load_bytes
)
@ -930,7 +932,7 @@ class GroupedGemmKernel:
)
# Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt + k_block + 1
peek_ab_empty_status = cute.arch.conditional_mbarrier_try_wait(
peek_ab_empty_status = cute.arch.mbarrier_conditional_try_wait(
tma_wr_k_block_next < cur_k_block_cnt,
ab_empty_mbar_ptr + smem_wr_buffer_next,
tma_wr_ab_empty_phase_next,
@ -999,11 +1001,12 @@ class GroupedGemmKernel:
while work_tile.is_valid_tile:
cur_tile_coord = work_tile.tile_idx
# MMA warp is only interested in number of tiles along K dimension
cur_k_block_cnt, cur_group_idx = (
group_gemm_ts_helper.search_cluster_tile_count_k(
cur_tile_coord,
problem_sizes_mnkl,
)
(
cur_k_block_cnt,
cur_group_idx,
) = group_gemm_ts_helper.search_cluster_tile_count_k(
cur_tile_coord,
problem_sizes_mnkl,
)
# Set tensor memory buffer for current tile
acc_buf_idx = tile_sched.num_tiles_executed % self.num_acc_stage
@ -1022,7 +1025,7 @@ class GroupedGemmKernel:
mma_rd_ab_full_phase = (
(num_prev_k_blk + mma_rd_k_block) // self.num_ab_stage % 2
)
peek_ab_full_status = cute.arch.conditional_mbarrier_try_wait(
peek_ab_full_status = cute.arch.mbarrier_conditional_try_wait(
need_check_rd_buffer_full,
ab_full_mbar_ptr + smem_rd_buffer,
mma_rd_ab_full_phase,
@ -1047,7 +1050,7 @@ class GroupedGemmKernel:
#
# Mma mainloop
#
for k_block in cutlass.range_dynamic(0, cur_k_block_cnt, 1, unroll=1):
for k_block in range(cur_k_block_cnt):
mma_rd_k_block_next = cutlass.Int32(k_block + 1)
smem_rd_buffer_next = (
num_prev_k_blk + mma_rd_k_block_next
@ -1066,7 +1069,7 @@ class GroupedGemmKernel:
# tCtAcc += tCrA * tCrB
num_kphases = cute.size(tCrA, mode=[2])
for kphase_idx in range(num_kphases):
for kphase_idx in cutlass.range(num_kphases, unroll_full=True):
kphase_coord = (None, None, kphase_idx, smem_rd_buffer)
cute.gemm(
@ -1092,7 +1095,7 @@ class GroupedGemmKernel:
mma_rd_k_block_next < cur_k_block_cnt and is_leader_cta
)
peek_ab_full_status = cute.arch.conditional_mbarrier_try_wait(
peek_ab_full_status = cute.arch.mbarrier_conditional_try_wait(
need_check_rd_buffer_full,
ab_full_mbar_ptr + smem_rd_buffer_next,
mma_rd_ab_full_phase_next,
@ -1161,19 +1164,23 @@ class GroupedGemmKernel:
#
# Partition for epilogue
#
tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = (
self.epilog_tmem_copy_and_partition(
epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs
)
(
tiled_copy_t2r,
tTR_tAcc_base,
tTR_rAcc,
) = self.epilog_tmem_copy_and_partition(
epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs
)
tTR_rC = cute.make_fragment(tTR_rAcc.shape, self.c_dtype)
tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition(
tiled_copy_t2r, tTR_rC, epi_tidx, sC
)
tma_atom_c, bSG_sC, bSG_gC_partitioned = (
self.epilog_gmem_copy_and_partition(tma_atom_c, tCgC, epi_tile, sC)
)
(
tma_atom_c,
bSG_sC,
bSG_gC_partitioned,
) = self.epilog_gmem_copy_and_partition(tma_atom_c, tCgC, epi_tile, sC)
#
# Persistent tile scheduling loop
@ -1270,7 +1277,7 @@ class GroupedGemmKernel:
#
subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])
num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt
for subtile_idx in cutlass.range_dynamic(subtile_cnt):
for subtile_idx in range(subtile_cnt):
#
# Load accumulator from tensor memory buffer to register
#
@ -1493,11 +1500,11 @@ class GroupedGemmKernel:
# (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE)
tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi)
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL)
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL)
gC_mnl_epi = cute.flat_divide(
gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile
)
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL)
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL)
tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi)
# (T2R, T2R_M, T2R_N)
tTR_rAcc = cute.make_fragment(
@ -1569,14 +1576,14 @@ class GroupedGemmKernel:
- tCgC: The destination global memory tensor partitioned for the TMA operation.
:rtype: tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]
"""
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL)
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL)
gC_epi = cute.flat_divide(
gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile
)
sC_for_tma_partition = cute.group_modes(sC, 0, 2)
gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2)
# ((ATOM_V, REST_V), EPI_M, EPI_N)
# ((ATOM_V, REST_V), EPI_M, EPI_N, loopM, loopN, loopL)
# ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL)
bSG_sC, bSG_gC = cpasync.tma_partition(
tma_atom_c,
0,
@ -1595,7 +1602,7 @@ class GroupedGemmKernel:
epi_tile: cute.Tile,
c_dtype: type[cutlass.Numeric],
c_layout: utils.LayoutEnum,
num_smem_capacity: int,
smem_capacity: int,
occupancy: int,
) -> tuple[int, int, int]:
"""Computes the number of stages for accumulator, A/B operands, and epilogue based on heuristics.
@ -1614,8 +1621,8 @@ class GroupedGemmKernel:
:type c_dtype: type[cutlass.Numeric]
:param c_layout: Layout enum of operand C in global memory.
:type c_layout: utils.LayoutEnum
:param num_smem_capacity: Total available shared memory capacity in bytes.
:type num_smem_capacity: int
:param smem_capacity: Total available shared memory capacity in bytes.
:type smem_capacity: int
:param occupancy: Target number of CTAs per SM (occupancy).
:type occupancy: int
@ -1658,7 +1665,7 @@ class GroupedGemmKernel:
# Subtract reserved bytes and initial epilogue bytes
# Divide remaining by bytes needed per A/B stage
num_ab_stage = (
num_smem_capacity // occupancy
smem_capacity // occupancy
- GroupedGemmKernel.reserved_smem_bytes
- epi_bytes
) // ab_bytes_per_stage
@ -1667,7 +1674,7 @@ class GroupedGemmKernel:
# Calculate remaining smem after allocating for A/B stages and reserved bytes
# Add remaining unused smem to epilogue
remaining_smem = (
num_smem_capacity
smem_capacity
- occupancy * ab_bytes_per_stage * num_ab_stage
- occupancy * (GroupedGemmKernel.reserved_smem_bytes + epi_bytes)
)
@ -1775,20 +1782,6 @@ class GroupedGemmKernel:
epi_bytes = cute.size_in_bytes(c_dtype, epi_smem_layout_staged)
return ab_bytes + epi_bytes
@staticmethod
def _get_tma_atom_kind(atom_sm_cnt: int, mcast: bool):
"""Select the appropriate TMA copy atom based on the number of SMs and the multicast flag."""
if atom_sm_cnt == 2 and mcast:
return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.TWO)
elif atom_sm_cnt == 2 and not mcast:
return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.TWO)
elif atom_sm_cnt == 1 and mcast:
return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.ONE)
elif atom_sm_cnt == 1 and not mcast:
return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE)
raise ValueError(f"Invalid atom_sm_cnt: {atom_sm_cnt} and {mcast}")
@staticmethod
def _compute_num_tmem_alloc_cols(
tiled_mma: cute.TiledMma,
@ -1909,8 +1902,6 @@ def run_grouped_gemm(
if not torch.cuda.is_available():
raise RuntimeError("GPU is required to run this example!")
torch.manual_seed(2025)
# Create tensor and return the pointer, tensor, and stride
def create_tensor_and_stride(
l: int,
@ -1920,42 +1911,17 @@ def run_grouped_gemm(
dtype: type[cutlass.Numeric],
is_dynamic_layout: bool = True,
) -> tuple[int, torch.Tensor, cute.Tensor, torch.Tensor, tuple[int, int]]:
# is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l)
# else: (l, mode0, mode1) -> (mode0, mode1, l)
shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1)
permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0)
# omit stride for L mode as it is always 1 for grouped GEMM
strides = (1, mode0) if is_mode0_major else (mode1, 1)
assert dtype in {cutlass.Float16, cutlass.BFloat16, cutlass.Float32}
is_unsigned = False
torch_dtype = cutlass_torch.dtype(dtype)
torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor(
shape,
torch_dtype,
permute_order=permute_order,
init_type=cutlass_torch.TensorInitType.RANDOM,
init_config=cutlass_torch.RandomInitConfig(
min_val=0 if is_unsigned else -2, max_val=4 if is_unsigned else 2
),
torch_tensor_cpu = cutlass_torch.matrix(l, mode0, mode1, is_mode0_major, dtype)
cute_tensor, torch_tensor = cutlass_torch.cute_tensor_like(
torch_tensor_cpu, dtype, is_dynamic_layout, assumed_align=16
)
torch_tensor = torch_tensor_cpu.cuda()
f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32)
cute_tensor = from_dlpack(torch_tensor, assumed_align=16)
if is_dynamic_layout:
cute_tensor = cute_tensor.mark_layout_dynamic(
leading_dim=(0 if is_mode0_major else 1)
)
cute_tensor = cutlass_torch.convert_cute_tensor(
f32_torch_tensor,
return (
torch_tensor.data_ptr(),
torch_tensor,
cute_tensor,
dtype,
is_dynamic_layout=is_dynamic_layout,
torch_tensor_cpu,
torch_tensor.stride()[:-1],
)
# Get pointer of the tensor
ptr = torch_tensor.data_ptr()
return ptr, torch_tensor, cute_tensor, f32_torch_tensor, strides
# iterate all groups and create tensors for each group
torch_fp32_tensors_abc = []
@ -1964,15 +1930,27 @@ def run_grouped_gemm(
strides_abc = []
ptrs_abc = []
for _, (m, n, k, l) in enumerate(problem_sizes_mnkl):
ptr_a, torch_tensor_a, cute_tensor_a, tensor_fp32_a, stride_mk_a = (
create_tensor_and_stride(l, m, k, a_major == "m", ab_dtype)
)
ptr_b, torch_tensor_b, cute_tensor_b, tensor_fp32_b, stride_nk_b = (
create_tensor_and_stride(l, n, k, b_major == "n", ab_dtype)
)
ptr_c, torch_tensor_c, cute_tensor_c, tensor_fp32_c, stride_mn_c = (
create_tensor_and_stride(l, m, n, c_major == "m", c_dtype)
)
(
ptr_a,
torch_tensor_a,
cute_tensor_a,
tensor_fp32_a,
stride_mk_a,
) = create_tensor_and_stride(l, m, k, a_major == "m", ab_dtype)
(
ptr_b,
torch_tensor_b,
cute_tensor_b,
tensor_fp32_b,
stride_nk_b,
) = create_tensor_and_stride(l, n, k, b_major == "n", ab_dtype)
(
ptr_c,
torch_tensor_c,
cute_tensor_c,
tensor_fp32_c,
stride_mn_c,
) = create_tensor_and_stride(l, m, n, c_major == "m", c_dtype)
ptrs_abc.append([ptr_a, ptr_b, ptr_c])
torch_tensors_abc.append([torch_tensor_a, torch_tensor_b, torch_tensor_c])
torch_fp32_tensors_abc.append([tensor_fp32_a, tensor_fp32_b, tensor_fp32_c])
@ -2005,19 +1983,16 @@ def run_grouped_gemm(
)
# Prepare tensormap buffer for each SM
num_tensormap_buffers = sm_count
tensormap_pytorch_tensor = (
torch.empty(
(
num_tensormap_buffers,
GroupedGemmKernel.num_tensormaps,
GroupedGemmKernel.bytes_per_tensormap // 8,
),
dtype=torch.int64,
)
.fill_(0)
.cuda()
tensormap_shape = (
num_tensormap_buffers,
GroupedGemmKernel.num_tensormaps,
GroupedGemmKernel.bytes_per_tensormap // 8,
)
tensor_of_tensormap, tensor_of_tensormap_torch = cutlass_torch.cute_tensor_like(
torch.empty(tensormap_shape, dtype=torch.int64),
cutlass.Int64,
is_dynamic_layout=False,
)
tensormap_cute_tensor = from_dlpack(tensormap_pytorch_tensor, assumed_align=16)
grouped_gemm = GroupedGemmKernel(
acc_dtype,
@ -2027,23 +2002,30 @@ def run_grouped_gemm(
tensormap_update_mode,
)
# Convert integer list to torch tensor and cute tensor
def convert_list_to_tensor(l, dtype) -> tuple[torch.Tensor, cute.Tensor]:
torch_tensor = torch.tensor(l, dtype=dtype).cuda()
cute_tensor = from_dlpack(torch_tensor, assumed_align=16)
return torch_tensor, cute_tensor
# layout (num_groups, 4):(4, 1)
problem_sizes_mnkl_torch_tensor, problem_sizes_mnkl_cute_tensor = (
convert_list_to_tensor(problem_sizes_mnkl, torch.int32)
(
tensor_of_dim_size_mnkl,
tensor_of_dim_size_mnkl_torch,
) = cutlass_torch.cute_tensor_like(
torch.tensor(problem_sizes_mnkl, dtype=torch.int32),
cutlass.Int32,
is_dynamic_layout=False,
assumed_align=16,
)
# layout (num_groups, 3, 2):(6, 2, 1)
strides_abc_torch_tensor, strides_abc_cute_tensor = convert_list_to_tensor(
strides_abc, torch.int32
tensor_of_strides_abc, tensor_of_strides_abc_torch = cutlass_torch.cute_tensor_like(
torch.tensor(strides_abc, dtype=torch.int32),
cutlass.Int32,
is_dynamic_layout=False,
assumed_align=16,
)
# layout (num_groups,3):(3, 1)
ptrs_abc_torch_tensor, ptrs_abc_cute_tensor = convert_list_to_tensor(
ptrs_abc, torch.int64
tensor_of_ptrs_abc, tensor_of_ptrs_abc_torch = cutlass_torch.cute_tensor_like(
torch.tensor(ptrs_abc, dtype=torch.int64),
cutlass.Int64,
is_dynamic_layout=False,
assumed_align=16,
)
# Compute total number of cluster tiles we need to compute for given grouped GEMM problem
@ -2077,10 +2059,9 @@ def run_grouped_gemm(
problem_sizes_mnkl, cluster_tile_shape_mn
)
# Get current CUDA stream from PyTorch
torch_stream = torch.cuda.current_stream()
# Get the raw stream pointer as a CUstream
current_stream = cuda.CUstream(torch_stream.cuda_stream)
# Initialize Stream
current_stream = cutlass_torch.default_stream()
# Compile grouped GEMM kernel
compiled_grouped_gemm = cute.compile(
grouped_gemm,
@ -2088,11 +2069,11 @@ def run_grouped_gemm(
initial_cute_tensors_abc[1],
initial_cute_tensors_abc[2],
num_groups,
problem_sizes_mnkl_cute_tensor,
strides_abc_cute_tensor,
ptrs_abc_cute_tensor,
tensor_of_dim_size_mnkl,
tensor_of_strides_abc,
tensor_of_ptrs_abc,
total_num_clusters,
tensormap_cute_tensor,
tensor_of_tensormap,
max_active_clusters,
current_stream,
)
@ -2104,10 +2085,10 @@ def run_grouped_gemm(
initial_cute_tensors_abc[0],
initial_cute_tensors_abc[1],
initial_cute_tensors_abc[2],
problem_sizes_mnkl_cute_tensor,
strides_abc_cute_tensor,
ptrs_abc_cute_tensor,
tensormap_cute_tensor,
tensor_of_dim_size_mnkl,
tensor_of_strides_abc,
tensor_of_ptrs_abc,
tensor_of_tensormap,
current_stream,
)
# Execution
@ -2116,28 +2097,27 @@ def run_grouped_gemm(
initial_cute_tensors_abc[0],
initial_cute_tensors_abc[1],
initial_cute_tensors_abc[2],
problem_sizes_mnkl_cute_tensor,
strides_abc_cute_tensor,
ptrs_abc_cute_tensor,
tensormap_cute_tensor,
tensor_of_dim_size_mnkl,
tensor_of_strides_abc,
tensor_of_ptrs_abc,
tensor_of_tensormap,
current_stream,
)
torch.cuda.synchronize()
# Compute reference result
if not skip_ref_check:
refs = []
for a, b, _ in torch_fp32_tensors_abc:
ref = (torch.einsum("mkl,nkl->mnl", a, b)).cpu()
refs.append(ref)
for i, ((_, _, c), ref) in enumerate(zip(torch_tensors_abc, refs)):
for i, (a, b, c) in enumerate(torch_tensors_abc):
ref = torch.einsum(
"mkl,nkl->mnl",
a.cpu().to(dtype=torch.float32),
b.cpu().to(dtype=torch.float32),
)
print(f"checking group {i}")
if c_dtype == cutlass.Float32:
ref_c = ref
else:
ref_c = ref.to(cutlass_torch.dtype(c_dtype))
torch.testing.assert_close(
c.cpu(),
ref_c,
ref.to(cutlass_torch.dtype(c_dtype)),
atol=tolerance,
rtol=1e-05,
)
@ -2266,6 +2246,8 @@ if __name__ == "__main__":
else:
tensormap_update_mode = utils.TensorMapUpdateMode.SMEM
torch.manual_seed(2025)
run_grouped_gemm(
args.num_groups,
args.problem_sizes_mnkl,

File diff suppressed because it is too large

View File

@ -0,0 +1,397 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import torch
import torch.nn.functional as F
def ssd_reference_fp32_all(x, a, delta, B, C, Y_out, Fstate_out, D, has_d, d_has_hdim):
"""
Rearrange tensor dimensions from the CUDA layout to the reference layout, then directly call Tri Dao's SSD implementation
Arguments:
X/x: (D, L, C, H, B):(C*L, 1, L, D*C*L, H*D*C*L)
A/delta: (L, C, H, B):(1, L, C*L, H*C*L)
a: (H):(1)
B/C: (L, N, C, G, B):(1, C*L, L, N*C*L, G*N*C*L)
D: (1, H):(0, 1) or (D, H):(1, D)
has_d: bool
d_has_hdim: bool
Return:
Y_out: (L, D, C, H, B):(1, C*L, L, D*C*L, H*D*C*L)
Fstate_out: (D, N, H, B):(N, 1, D*N, H*D*N)
"""
assert x.dtype == a.dtype == delta.dtype == B.dtype == C.dtype
A = delta * a.view(1, 1, -1, 1)
X = x * delta.unsqueeze(0)
# Rearrange from the cutlass layout to Tri Dao's layout
block_len = A.shape[0]
initial_states = None
# A: l c h b-> b c l h
A = A.permute(3, 1, 0, 2)
# X: p l c h b -> b c l h p
X = X.permute(4, 2, 1, 3, 0)
# B: l n c g b -> b c l g n
B = B.permute(4, 2, 0, 3, 1)
# C: l n c g b -> b c l g n
C = C.permute(4, 2, 0, 3, 1)
# X/A/B/C: b c l ... -> b (c l) ...
X, A, B, C = [x.reshape(x.shape[0], -1, *x.shape[3:]) for x in (X, A, B, C)]
# Ngroup (g to h) mapping
B_val, CL_val, G_val, N_val = B.shape
H_val = X.shape[2]
ngroup_ratio = H_val // G_val
# B/C: (B, CL, H, N)
h_to_g_mapping = torch.arange(H_val, device=B.device) // ngroup_ratio
B = B.gather(2, h_to_g_mapping.view(1, 1, -1, 1).expand(B_val, CL_val, -1, N_val))
C = C.gather(2, h_to_g_mapping.view(1, 1, -1, 1).expand(B_val, CL_val, -1, N_val))
###################################################################
# Call reference implementation from Tri Dao ssd_minimal_discrete
Y, final_state = ssd_minimal_discrete_fp32_all(
X, A, B, C, block_len, initial_states
)
###################################################################
if has_d:
D_val = Y.shape[3]
if not d_has_hdim:
D = D.expand(D_val, -1)
Y = Y + torch.einsum("bchp,ph->bchp", X, D)
# Rearrange from Tri Dao's layout back to the cutlass layout
# Y: b (c l) h p -> b c l h p
Y = Y.reshape(Y.shape[0], -1, block_len, Y.shape[2], Y.shape[3])
# Y: b c l h p -> l p c h b
Y = Y.permute(2, 4, 1, 3, 0)
# Fstate_out: b h p n -> p n h b
Fstate_out.copy_(final_state.permute(2, 3, 1, 0))
Y_out.copy_(Y)
return
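# A minimal usage sketch for ssd_reference_fp32_all (kept as a comment so the
# module stays import-only). All sizes are hypothetical and only chosen to
# satisfy the docstring layouts: P = d_head, L = block length, C = chunks,
# H = heads, G = groups, N = d_state, BS = batch.
#
#     P, L, C, H, G, N, BS = 4, 8, 2, 4, 2, 8, 2
#     x = torch.randn(P, L, C, H, BS)
#     a = torch.randn(H)
#     delta = torch.rand(L, C, H, BS)
#     B_in = torch.randn(L, N, C, G, BS)
#     C_in = torch.randn(L, N, C, G, BS)
#     D_in = torch.randn(1, H)
#     Y_out = torch.empty(L, P, C, H, BS)
#     Fstate_out = torch.empty(P, N, H, BS)
#     ssd_reference_fp32_all(x, a, delta, B_in, C_in, Y_out, Fstate_out, D_in,
#                            has_d=True, d_has_hdim=False)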
def ssd_reference_lowprecision_intermediates(
x, a, delta, B, C, Y_out, Fstate_out, intermediate_dtype, D, has_d, d_has_hdim
):
"""
Rearrange tensor dimensions from the CUDA layout to the reference layout, then call a version of the SSD implementation that keeps intermediates in a reduced-precision dtype
Arguments:
X/x: (D, L, C, H, B):(C*L, 1, L, D*C*L, H*D*C*L)
A/delta: (L, C, H, B):(1, L, C*L, H*C*L)
a: (H):(1)
B/C: (L, N, C, G, B):(1, C*L, L, N*C*L, G*N*C*L)
intermediate_dtype: input and intermediate data type
D: (1, H):(0, 1) or (D, H):(1, D)
has_d: bool
d_has_hdim: bool
Return:
Y_out: (L, D, C, H, B):(1, C*L, L, D*C*L, H*D*C*L)
Fstate_out: (D, N, H, B):(N, 1, D*N, H*D*N)
"""
assert x.dtype == a.dtype == delta.dtype == B.dtype == C.dtype
A = delta * a.view(1, 1, -1, 1)
# Rearrange from the cutlass layout to Tri Dao's layout
block_len = A.shape[0]
initial_states = None
# A: l c h b-> b c l h
A = A.permute(3, 1, 0, 2)
# delta: l c h b-> b c l h
delta = delta.permute(3, 1, 0, 2)
# x: p l c h b -> b c l h p
x = x.permute(4, 2, 1, 3, 0)
# B: l n c g b -> b c l g n
B = B.permute(4, 2, 0, 3, 1)
# C: l n c g b -> b c l g n
C = C.permute(4, 2, 0, 3, 1)
# x/A/delta/B/C: b c l ... -> b (c l) ...
x, A, delta, B, C = [
tensor.reshape(tensor.shape[0], -1, *tensor.shape[3:])
for tensor in (x, A, delta, B, C)
]
# Ngroup (g to h) mapping
B_val, CL_val, G_val, N_val = B.shape
H_val = x.shape[2]
ngroup_ratio = H_val // G_val
# B/C: (B, CL, H, N)
h_to_g_mapping = torch.arange(H_val, device=B.device) // ngroup_ratio
B = B.gather(2, h_to_g_mapping.view(1, 1, -1, 1).expand(B_val, CL_val, -1, N_val))
C = C.gather(2, h_to_g_mapping.view(1, 1, -1, 1).expand(B_val, CL_val, -1, N_val))
# Type convert input tensors to input dtype (same as intermediate dtype)
x = x.to(intermediate_dtype).to(torch.float32)
A = A.to(intermediate_dtype).to(torch.float32)
delta = delta.to(intermediate_dtype).to(torch.float32)
B = B.to(intermediate_dtype).to(torch.float32)
C = C.to(intermediate_dtype).to(torch.float32)
#########################################################################
# Call reference implementation ssd_minimal_discrete_bf16_intermediates
Y, final_state = ssd_minimal_discrete_lowprecision_intermediates(
x, A, delta, B, C, block_len, intermediate_dtype, initial_states
)
#########################################################################
if has_d:
D = D.to(intermediate_dtype).to(torch.float32)
D_val = Y.shape[3]
if not d_has_hdim:
D = D.expand(D_val, -1)
Y = Y + torch.einsum("bchp,ph->bchp", x, D)
# Type convert output tensors to output dtype (same as intermediate dtype)
Y = Y.to(intermediate_dtype).to(torch.float32)
final_state = final_state.to(intermediate_dtype).to(torch.float32)
# Rearrange from Tri Dao's layout back to the cutlass layout
# Y: b (c l) h p -> b c l h p
Y = Y.reshape(Y.shape[0], -1, block_len, Y.shape[2], Y.shape[3])
# Y: b c l h p -> l p c h b
Y = Y.permute(2, 4, 1, 3, 0)
# Fstate_out: b h p n -> p n h b
Fstate_out.copy_(final_state.permute(2, 3, 1, 0))
Y_out.copy_(Y)
return
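# Usage mirrors ssd_reference_fp32_all with one extra argument, e.g.
# intermediate_dtype=torch.bfloat16; a runnable comparison of the two reference
# paths is sketched at the end of this file.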
def analyze_relative_diffs(actual, expected):
"""
Print statistics of relative differences between actual and expected tensors
"""
# Calculate relative differences
abs_diff = (actual - expected).abs()
rel_diff = abs_diff / (torch.maximum(expected.abs(), actual.abs()) + 0.00001)
total_elements = rel_diff.numel()
# Handle special cases first
nan_mask = torch.isnan(rel_diff)
inf_mask = torch.isinf(rel_diff)
nan_count = nan_mask.sum().item()
inf_count = inf_mask.sum().item()
# Find position and value of maximum relative difference among finite entries
finite_mask = ~nan_mask & ~inf_mask
# Mask NaN/Inf entries with -inf so argmax indexes into the original flattened tensor
rel_diff_finite = rel_diff.masked_fill(~finite_mask, float("-inf"))
max_rel_diff = rel_diff_finite.max() if finite_mask.any() else float("nan")
max_rel_diff_pos = rel_diff_finite.flatten().argmax() if finite_mask.any() else -1
# Print max relative difference info
print(f"Maximum relative difference:")
print(f"Position: {max_rel_diff_pos}")
print(f"Value: {max_rel_diff:.6e}")
print(f"Actual value: {actual.flatten()[max_rel_diff_pos]}")
print(f"Expected value: {expected.flatten()[max_rel_diff_pos]}")
print(f"NaN values: {nan_count} ({100.0 * nan_count / total_elements:.2f}%)")
print(f"Inf values: {inf_count} ({100.0 * inf_count / total_elements:.2f}%)\n")
# Check different rtol thresholds
rtol_levels = [1e-5, 1e-4, 1e-3, 1e-2, 5e-02, 1e-01]
for i, rtol in enumerate(rtol_levels):
if i == 0:
mask = rel_diff <= rtol
else:
mask = (rel_diff <= rtol) & (rel_diff > rtol_levels[i - 1])
count = mask.sum().item()
percentage = (count / total_elements) * 100
if i == 0:
print(f"Elements with rtol <= {rtol:.0e}: {count} ({percentage:.2f}%)")
else:
print(
f"Elements with {rtol_levels[i-1]:.0e} < rtol <= {rtol:.0e}: {count} ({percentage:.2f}%)"
)
# Print elements exceeding the largest rtol
mask = rel_diff > rtol_levels[-1]
count = mask.sum().item()
percentage = (count / total_elements) * 100
print(f"Elements with rtol > {rtol_levels[-1]:.0e}: {count} ({percentage:.2f}%)\n")
def segsum(x):
"""
More stable segment sum calculation.
x: b h c l
"""
T = x.size(-1)
# x: b h c l -> b h c l l
x = x.unsqueeze(-1).expand(*x.shape, T)
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1)
x = x.masked_fill(~mask, 0)
x_segsum = torch.cumsum(x, dim=-2)
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
return x_segsum
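# Worked example along the last dimension: for x = [a, b, c] (T = 3),
#     segsum(x) = [[  0, -inf, -inf],
#                  [  b,    0, -inf],
#                  [b+c,    c,    0]]
# i.e. entry (i, j) is sum(x[j+1 : i+1]) on and below the diagonal and -inf
# above it, so exp(segsum(x)) becomes the lower-triangular decay matrix used
# in the chunked scan below.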
def ssd_minimal_discrete_fp32_all(X, A, B, C, block_len, initial_states=None):
"""
This is the same as https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/ssd_minimal.py
(all accumulation and intermediate results in fp32)
Arguments:
X: (batch(B), length(C*L), n_heads(H), d_head(D))
A: (batch(B), length(C*L), n_heads(H))
B: (batch(B), length(C*L), n_heads(H), d_state(N))
C: (batch(B), length(C*L), n_heads(H), d_state(N))
Return:
Y: (batch(B), length(C*L), n_heads(H), d_head(D))
final_state: (B, H, D, N)
"""
assert X.dtype == A.dtype == B.dtype == C.dtype
assert X.shape[1] % block_len == 0
# Rearrange into blocks/chunks
# X/A/B/C:b (c l) ... -> b c l ...
X, A, B, C = [
x.reshape(x.shape[0], -1, block_len, *x.shape[2:]) for x in (X, A, B, C)
]
# A: b c l h -> b h c l
A = A.permute(0, 3, 1, 2)
# A_cumsum: (B, H, C, L)
A_cumsum = torch.cumsum(A, dim=-1)
# 1. Compute the output for each intra-chunk (diagonal blocks)
segsum_A = segsum(A)
L = torch.exp(segsum_A)
Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)
# 2. Compute the state for each intra-chunk
# (right term of low-rank factorization of off-diagonal blocks; B terms)
decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
# (middle term of factorization of off-diag blocks; A terms)
if initial_states is None:
initial_states = torch.zeros_like(states[:, :1])
states = torch.cat([initial_states, states], dim=1)
decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
states, final_state = new_states[:, :-1], new_states[:, -1]
# 4. Compute state -> output conversion per chunk
# (left term of low-rank factorization of off-diagonal blocks; C terms)
state_decay_out = torch.exp(A_cumsum)
Y_off = torch.einsum("bclhn,bchpn,bhcl->bclhp", C, states, state_decay_out)
# Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
# Y: b c l h p -> b (c l) h p
Y = (Y_diag + Y_off).reshape(Y_diag.shape[0], -1, Y_diag.shape[3], Y_diag.shape[4])
return Y, final_state
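# In chunked form, steps 1-4 above evaluate the discrete state-space recurrence
#     h_t = exp(A_t) * h_{t-1} + B_t * X_t,    y_t = C_t . h_t
# (the X passed in by ssd_reference_fp32_all is already scaled by delta):
# step 1 is the within-chunk (diagonal) contribution, steps 2-3 build and
# propagate per-chunk states across chunk boundaries, and step 4 turns the
# carried-in states back into outputs (off-diagonal contribution).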
def ssd_minimal_discrete_lowprecision_intermediates(
X, A, delta, B, C, block_len, intermediate_dtype, initial_states=None
):
"""
This is adapted from ssd_minimal_discrete_fp32_all, with two differences:
1. accumulation stays in fp32, but the intermediates Q/b_tmem/P are kept in intermediate_dtype
2. delta is not pre-multiplied into X; instead, delta is applied when generating Q/b_tmem to match the GPU implementation
Arguments:
X: (batch(B), length(C*L), n_heads(H), d_head(D))
A: (batch(B), length(C*L), n_heads(H))
delta: (batch(B), length(C*L), n_heads(H))
B: (batch(B), length(C*L), n_heads(H), d_state(N))
C: (batch(B), length(C*L), n_heads(H), d_state(N))
Return:
Y: (batch(B), length(C*L), n_heads(H), d_head(D))
final_state: (B, H, D, N)
"""
assert X.dtype == A.dtype == B.dtype == C.dtype
assert X.shape[1] % block_len == 0
# Rearrange into blocks/chunks
# X/A/delta/B/C: b (c l) ... -> b c l ...
X, A, delta, B, C = [
x.reshape(x.shape[0], -1, block_len, *x.shape[2:]) for x in (X, A, delta, B, C)
]
# A: b c l h -> b h c l
A = A.permute(0, 3, 1, 2)
# delta: b c l h -> b h c l
delta = delta.permute(0, 3, 1, 2)
# A_cumsum: (B, H, C, L)
A_cumsum = torch.cumsum(A, dim=-1)
# 1. Compute the output for each intra-chunk (diagonal blocks)
segsum_A = segsum(A)
L = torch.exp(segsum_A)
intra_acc_0 = torch.einsum("bclhn,bcshn->bclhs", C, B)
Q = torch.einsum("bclhs,bhcls,bhcs->bclhs", intra_acc_0, L, delta)
Y_diag = torch.einsum(
"bclhs,bcshp->bclhp", Q.to(intermediate_dtype).to(torch.float32), X
)
# 2. Compute the state for each intra-chunk
# (right term of low-rank factorization of off-diagonal blocks; B terms)
decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
b_tmem = torch.einsum("bclhn,bhcl,bhcl->bclhn", B, decay_states, delta)
states = torch.einsum(
"bclhn,bclhp->bchpn", b_tmem.to(intermediate_dtype).to(torch.float32), X
)
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
# (middle term of factorization of off-diag blocks; A terms)
if initial_states is None:
initial_states = torch.zeros_like(states[:, :1])
states = torch.cat([initial_states, states], dim=1)
decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
states, final_state = new_states[:, :-1], new_states[:, -1]
# 4. Compute state -> output conversion per chunk
# (left term of low-rank factorization of off-diagonal blocks; C terms)
state_decay_out = torch.exp(A_cumsum)
Y_off_tmp = torch.einsum(
"bclhn,bchpn->bclhp", C, states.to(intermediate_dtype).to(torch.float32)
)
Y_off = torch.einsum("bclhp,bhcl->bclhp", Y_off_tmp, state_decay_out)
# Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
# Y: b c l h p -> b (c l) h p
Y = (Y_diag + Y_off).reshape(
Y_diag.shape[0], -1, Y_diag.shape[3], Y_diag.shape[4]
) # b (c l) h p
return Y, final_state
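# A minimal, self-contained comparison of the two reference paths on random
# host tensors. The problem sizes below are hypothetical and only need to
# respect the layouts documented above (P = d_head, L = block length,
# C = chunks, H = heads, G = groups, N = d_state, BS = batch).
if __name__ == "__main__":
    torch.manual_seed(0)
    P, L, C, H, G, N, BS = 4, 8, 2, 4, 2, 8, 2
    x = torch.randn(P, L, C, H, BS)
    a = torch.randn(H)
    delta = torch.rand(L, C, H, BS)
    B_in = torch.randn(L, N, C, G, BS)
    C_in = torch.randn(L, N, C, G, BS)
    D_in = torch.randn(1, H)
    y_fp32 = torch.empty(L, P, C, H, BS)
    f_fp32 = torch.empty(P, N, H, BS)
    y_bf16 = torch.empty(L, P, C, H, BS)
    f_bf16 = torch.empty(P, N, H, BS)
    ssd_reference_fp32_all(
        x, a, delta, B_in, C_in, y_fp32, f_fp32, D_in, True, False
    )
    ssd_reference_lowprecision_intermediates(
        x, a, delta, B_in, C_in, y_bf16, f_bf16, torch.bfloat16, D_in, True, False
    )
    # Quantify how far the reduced-precision intermediates drift from fp32
    analyze_relative_diffs(y_bf16, y_fp32)
    analyze_relative_diffs(f_bf16, f_fp32)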

View File

@ -0,0 +1,200 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from typing import Tuple
from cutlass.cutlass_dsl import (
Boolean,
Integer,
Int32,
min,
extract_mlir_values,
new_from_mlir_values,
dsl_user_op,
)
from cutlass._mlir import ir
import cutlass.cute as cute
from cutlass.utils import WorkTileInfo
class Mamba2SSDTileSchedulerParams:
def __init__(
self,
problem_shape_ntiles: int,
eh: int,
ngroup_ratio: int,
*,
loc=None,
ip=None,
):
self.problem_shape_ntiles = problem_shape_ntiles
self.eh = eh
self.ngroup_ratio = ngroup_ratio
self._loc = loc
def __extract_mlir_values__(self):
values, self._values_pos = [], []
for obj in [self.problem_shape_ntiles, self.eh, self.ngroup_ratio]:
obj_values = extract_mlir_values(obj)
values += obj_values
self._values_pos.append(len(obj_values))
return values
def __new_from_mlir_values__(self, values):
obj_list = []
for obj, n_items in zip(
[self.problem_shape_ntiles, self.eh, self.ngroup_ratio], self._values_pos
):
obj_list.append(new_from_mlir_values(obj, values[:n_items]))
values = values[n_items:]
return Mamba2SSDTileSchedulerParams(*(tuple(obj_list)), loc=self._loc)
@dsl_user_op
def get_grid_shape(
self, max_active_clusters: Int32, *, loc=None, ip=None
) -> Tuple[Integer, Integer, Integer]:
return (min(self.problem_shape_ntiles, max_active_clusters), 1, 1)
class Mamba2SSDTileScheduler:
def __init__(
self,
params: Mamba2SSDTileSchedulerParams,
num_persistent_ctas: Int32,
current_work_linear_idx: Int32,
num_tiles_executed: Int32,
):
self.params = params
self.num_persistent_ctas = num_persistent_ctas
self._current_work_linear_idx = current_work_linear_idx
self._num_tiles_executed = num_tiles_executed
def __extract_mlir_values__(self) -> list[ir.Value]:
values = extract_mlir_values(self.num_persistent_ctas)
values.extend(extract_mlir_values(self._current_work_linear_idx))
values.extend(extract_mlir_values(self._num_tiles_executed))
return values
def __new_from_mlir_values__(
self, values: list[ir.Value]
) -> "Mamba2SSDTileScheduler":
assert len(values) == 3
new_num_persistent_ctas = new_from_mlir_values(
self.num_persistent_ctas, [values[0]]
)
new_current_work_linear_idx = new_from_mlir_values(
self._current_work_linear_idx, [values[1]]
)
new_num_tiles_executed = new_from_mlir_values(
self._num_tiles_executed, [values[2]]
)
return Mamba2SSDTileScheduler(
self.params,
new_num_persistent_ctas,
new_current_work_linear_idx,
new_num_tiles_executed,
)
# called by host
@dsl_user_op
@staticmethod
def create(
params: Mamba2SSDTileSchedulerParams,
block_idx: Tuple[Integer, Integer, Integer],
grid_dim: Tuple[Integer, Integer, Integer],
*,
loc=None,
ip=None,
):
# The number of persistent CTAs equals the total launched grid size
num_persistent_ctas = Int32(cute.size(grid_dim, loc=loc, ip=ip))
bidx, bidy, bidz = block_idx
# Initialize the work index to this block's linear index in the grid
current_work_linear_idx = Int32(bidx)
# Initialize number of tiles executed to zero
num_tiles_executed = Int32(0)
return Mamba2SSDTileScheduler(
params,
num_persistent_ctas,
current_work_linear_idx,
num_tiles_executed,
)
# called by host
@staticmethod
def get_grid_shape(
params: Mamba2SSDTileSchedulerParams,
max_active_clusters: Int32,
*,
loc=None,
ip=None,
) -> Tuple[Integer, Integer, Integer]:
return params.get_grid_shape(max_active_clusters, loc=loc, ip=ip)
# private method
def _get_current_work_for_linear_idx(
self, current_work_linear_idx: Int32, *, loc=None, ip=None
) -> WorkTileInfo:
is_valid = current_work_linear_idx < cute.size(
self.params.problem_shape_ntiles, loc=loc, ip=ip
)
eh_idx = current_work_linear_idx % self.params.eh
b_idx = current_work_linear_idx // self.params.eh
g_idx = eh_idx // self.params.ngroup_ratio
# cur_tile_coord is (b_idx, eh_idx, g_idx)
cur_tile_coord = tuple(Int32(x) for x in (b_idx, eh_idx, g_idx))
return WorkTileInfo(cur_tile_coord, is_valid)
@dsl_user_op
def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
return self._get_current_work_for_linear_idx(
self._current_work_linear_idx, loc=loc, ip=ip
)
@dsl_user_op
def initial_work_tile_info(self, *, loc=None, ip=None) -> WorkTileInfo:
return self.get_current_work(loc=loc, ip=ip)
@dsl_user_op
def advance_to_next_work(self, *, advance_count: int = 1, loc=None, ip=None):
self._current_work_linear_idx += Int32(advance_count) * Int32(
self.num_persistent_ctas
)
self._num_tiles_executed += Int32(1)
@property
def num_tiles_executed(self) -> Int32:
return self._num_tiles_executed
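# Worked example of the work decomposition (hypothetical sizes): with eh = 8
# head tiles per batch, ngroup_ratio = 4 and batch size 2, there are
# problem_shape_ntiles = 16 work tiles. The CTA holding linear index 13 maps to
#     eh_idx = 13 % 8 = 5,   b_idx = 13 // 8 = 1,   g_idx = 5 // 4 = 1
# and advance_to_next_work then strides the linear index by the number of
# persistent CTAs, i.e. the min(problem_shape_ntiles, max_active_clusters)
# returned by get_grid_shape.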

View File

@ -30,7 +30,7 @@ cmake_minimum_required(VERSION 3.15)
project(tensor)
# Find Python
find_package(Python COMPONENTS Interpreter Development REQUIRED)
find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
# Get Python site-packages directory using Python
execute_process(

View File

@ -36,6 +36,7 @@ import torch
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
import cutlass.pipeline as pipeline
import cutlass.torch as cutlass_torch
from cutlass.cute.runtime import from_dlpack
import cutlass.utils.hopper_helpers as sm90_utils
@ -579,20 +580,25 @@ class HopperWgmmaGemmKernel:
mainloop_pipeline_array_ptr = storage.mainloop_pipeline_array_ptr.data_ptr()
# Threads/warps participating in this pipeline
mainloop_pipeline_producer_group = utils.CooperativeGroup(utils.Agent.Thread)
# Set the consumer arrive count to the number of mcast size
consumer_arrive_cnt = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
mainloop_pipeline_consumer_group = utils.CooperativeGroup(
utils.Agent.Thread, consumer_arrive_cnt
mainloop_pipeline_producer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread
)
# Each warp contributes the mcast size to the consumer arrive count
mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
num_warps = self.threads_per_cta // 32
consumer_arrive_cnt = mcast_size * num_warps
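# For example (hypothetical configuration): if A is multicast to 2 CTAs and B
# to 1 CTA, mcast_size = 2 + 1 - 1 = 2; with 128 threads per CTA (4 warps) the
# consumer arrive count is 2 * 4 = 8.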
mainloop_pipeline_consumer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread, consumer_arrive_cnt
)
mainloop_pipeline = utils.PipelineTmaAsync.create(
cta_layout_vmnk = cute.make_layout((1, *cta_layout_mnk.shape))
mainloop_pipeline = pipeline.PipelineTmaAsync.create(
barrier_storage=mainloop_pipeline_array_ptr,
num_stages=self.ab_stage,
producer_group=mainloop_pipeline_producer_group,
consumer_group=mainloop_pipeline_consumer_group,
tx_count=tma_copy_bytes,
cta_layout_vmnk=cta_layout_mnk,
cta_layout_vmnk=cta_layout_vmnk,
)
# Cluster arrive after barrier init
@ -616,11 +622,11 @@ class HopperWgmmaGemmKernel:
# ///////////////////////////////////////////////////////////////////////////////
# Local_tile partition global tensors
# ///////////////////////////////////////////////////////////////////////////////
# (bM, bK, loopK)
# (bM, bK, RestK)
gA_mkl = cute.local_tile(
mA_mkl, self.tile_shape_mnk, tile_coord_mnkl, proj=(1, None, 1)
)
# (bN, bK, loopK)
# (bN, bK, RestK)
gB_nkl = cute.local_tile(
mB_nkl, self.tile_shape_mnk, tile_coord_mnkl, proj=(None, 1, 1)
)
@ -696,14 +702,14 @@ class HopperWgmmaGemmKernel:
k_tile_cnt = cute.size(gA_mkl, mode=[2])
prefetch_k_tile_cnt = cutlass.max(cutlass.min(self.ab_stage, k_tile_cnt), 0)
mainloop_producer_state = utils.make_pipeline_state(
utils.PipelineUserType.Producer, self.ab_stage
mainloop_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.ab_stage
)
if warp_idx == 0:
# /////////////////////////////////////////////////////////////////////////////
# Prefetch TMA load
# /////////////////////////////////////////////////////////////////////////////
for prefetch_idx in cutlass.range_dynamic(prefetch_k_tile_cnt, unroll=1):
for prefetch_idx in cutlass.range(prefetch_k_tile_cnt, unroll=1):
# /////////////////////////////////////////////////////////////////////////////
# Wait for A/B buffers to be empty before loading into them
# Also sets the transaction barrier for the A/B buffers
@ -748,11 +754,11 @@ class HopperWgmmaGemmKernel:
# /////////////////////////////////////////////////////////////////////////////
k_pipe_mmas = 1
mainloop_consumer_read_state = utils.make_pipeline_state(
utils.PipelineUserType.Consumer, self.ab_stage
mainloop_consumer_read_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.ab_stage
)
mainloop_consumer_release_state = utils.make_pipeline_state(
utils.PipelineUserType.Consumer, self.ab_stage
mainloop_consumer_release_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.ab_stage
)
peek_ab_full_status = cutlass.Boolean(1)
@ -763,14 +769,14 @@ class HopperWgmmaGemmKernel:
tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, False)
num_k_blocks = cute.size(tCrA, mode=[2])
for k_tile in cutlass.range_dynamic(k_pipe_mmas, unroll=1):
for k_tile in range(k_pipe_mmas):
# Wait for A/B buffer to be ready
mainloop_pipeline.consumer_wait(
mainloop_consumer_read_state, peek_ab_full_status
)
cute.nvgpu.warpgroup.fence()
for k_block_idx in range(num_k_blocks):
for k_block_idx in cutlass.range(num_k_blocks, unroll_full=True):
k_block_coord = (
None,
None,
@ -800,7 +806,7 @@ class HopperWgmmaGemmKernel:
# /////////////////////////////////////////////////////////////////////////////
# MAINLOOP
# /////////////////////////////////////////////////////////////////////////////
for k_tile in cutlass.range_dynamic(k_pipe_mmas, k_tile_cnt, 1, unroll=1):
for k_tile in cutlass.range(k_pipe_mmas, k_tile_cnt, 1, unroll=1):
# /////////////////////////////////////////////////////////////////////////////
# Wait for TMA copies to complete
# /////////////////////////////////////////////////////////////////////////////
@ -811,7 +817,7 @@ class HopperWgmmaGemmKernel:
# WGMMA
# /////////////////////////////////////////////////////////////////////////////
cute.nvgpu.warpgroup.fence()
for k_block_idx in range(num_k_blocks):
for k_block_idx in cutlass.range(num_k_blocks, unroll_full=True):
k_block_coord = (
None,
None,
@ -949,7 +955,7 @@ class HopperWgmmaGemmKernel:
epi_tile_num = cute.size(tcgc_for_tma_partition, mode=[1])
epi_tile_shape = tcgc_for_tma_partition.shape[1]
for epi_idx in cutlass.range_dynamic(epi_tile_num, unroll=epi_tile_num):
for epi_idx in cutlass.range(epi_tile_num, unroll=epi_tile_num):
# Copy from accumulators to D registers
for epi_v in range(size_tRS_rD):
tRS_rD[epi_v] = tRS_rAcc[epi_idx * size_tRS_rD + epi_v]
@ -1213,7 +1219,7 @@ class HopperWgmmaGemmKernel:
c_cta_v_layout = cute.composition(
cute.make_identity_layout(tensor_c.shape), epi_tile
)
tma_atom_c, tma_tensor_c = cute.nvgpu.cpasync.make_tma_tile_atom(
tma_atom_c, tma_tensor_c = cute.nvgpu.cpasync.make_tiled_tma_atom(
cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp(),
tensor_c,
epi_smem_layout,
@ -1250,7 +1256,7 @@ class HopperWgmmaGemmKernel:
)
smem_layout = cute.slice_(smem_layout_staged, (None, None, 0))
tma_atom, tma_tensor = cute.nvgpu.cpasync.make_tma_tile_atom(
tma_atom, tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom(
op,
tensor,
smem_layout,

View File

@ -297,7 +297,7 @@
" assert depth <= 1, f\"Depth of coalesced layout should be <= 1, got {depth}\"\n",
"\n",
" print(\">>> 3. Checking layout functionality remains the same after the coalesce operation:\")\n",
" for i in range(original_size):\n",
" for i in cutlass.range_constexpr(original_size):\n",
" original_value = layout(i)\n",
" coalesced_value = result(i)\n",
" print(f\"Index {i}: original {original_value}, coalesced {coalesced_value}\")\n",

View File

@ -60,48 +60,7 @@
"@cute.jit\n",
"def foo(a: cutlass.Int32): # annotate `a` as 32-bit integer passed to jit function via ABI\n",
" ...\n",
"```\n",
"To differentiate between compile-time and runtime values, CuTe DSL introduces primitive types that \n",
"represent dynamic values in JIT-compiled code.\n",
"\n",
"CuTe DSL provides a comprehensive set of primitive numeric types for representing dynamic values at \n",
"runtime. These types are formally defined within the CuTe DSL typing system:\n",
"\n",
"### Integer Types\n",
"- `Int8` - 8-bit signed integer\n",
"- `Int16` - 16-bit signed integer \n",
"- `Int32` - 32-bit signed integer\n",
"- `Int64` - 64-bit signed integer\n",
"- `Int128` - 128-bit signed integer\n",
"- `Uint8` - 8-bit unsigned integer\n",
"- `Uint16` - 16-bit unsigned integer\n",
"- `Uint32` - 32-bit unsigned integer\n",
"- `Uint64` - 64-bit unsigned integer\n",
"- `Uint128` - 128-bit unsigned integer\n",
"\n",
"### Floating Point Types\n",
"- `Float16` - 16-bit floating point\n",
"- `Float32` - 32-bit floating point \n",
"- `Float64` - 64-bit floating point\n",
"- `BFloat16` - Brain Floating Point format (16-bit)\n",
"- `TFloat32` - Tensor Float32 format (reduced precision format used in tensor operations)\n",
"- `Float8E4M3` - 8-bit floating point with 4-bit exponent and 3-bit mantissa\n",
"- `Float8E5M2` - 8-bit floating point with 5-bit exponent and 2-bit mantissa\n",
"\n",
"These specialized types are designed to represent dynamic values in CuTe DSL code that will be \n",
"evaluated at runtime, in contrast to Python's built-in numeric types which are evaluated during \n",
"compilation.\n",
"\n",
"### Example usage:\n",
"\n",
"```python\n",
"x = cutlass.Int32(5) # Creates a 32-bit integer\n",
"y = cutlass.Float32(3.14) # Creates a 32-bit float\n",
"\n",
"@cute.jit\n",
"def foo(a: cutlass.Int32): # annotate `a` as 32-bit integer passed to jit function via ABI\n",
" ...\n",
"```"
"```\n"
]
},
{

View File

@ -120,7 +120,7 @@
" src_vec = src.load()\n",
" dst_vec = src_vec[indices]\n",
" print(f\"{src_vec} -> {dst_vec}\")\n",
" if isinstance(dst_vec, cute.TensorSSA):\n",
" if cutlass.const_expr(isinstance(dst_vec, cute.TensorSSA)):\n",
" dst.store(dst_vec)\n",
" cute.print_tensor(dst)\n",
" else:\n",