v4.1 release

examples/python/CuTeDSL/ampere/call_from_jit.py (new file, 259 lines)
@@ -0,0 +1,259 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:

# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.

# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.

# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""
Demonstrating JIT GEMM Implementation with Static Shape Wrapper

This example illustrates how to invoke a JIT-compiled GEMM implementation through a wrapper function
with static shapes. It showcases the integration between PyTorch and CuTe tensors in a JIT context.

Key features demonstrated:
1. Seamless conversion between PyTorch and CuTe tensors using the JitArgument protocol
2. Integration of static shape GEMM operations within a JIT-compiled wrapper function

Core components:
- BufferWithLayout: Handles memory buffer management with configurable stride ordering
- tensor_op_gemm_wrapper: JIT-compiled entry point that orchestrates the GEMM operation

Usage:

.. code-block:: bash

    python examples/ampere/call_from_jit.py

Default configuration:
- Batch dimension (L): 16
- Matrix dimensions: M=512, N=256, K=128
- Precision: Float16 inputs with Float32 accumulation

Requirements:
- CUDA-capable GPU
- PyTorch with CUDA support
"""

import os
import sys
from typing import Type, Tuple

import torch

import cutlass
import cutlass.cute as cute
from cutlass.torch import dtype as torch_dtype
from cutlass.cute.runtime import make_ptr


# Add the current directory to sys.path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from tensorop_gemm import TensorOpGemm


class BufferWithLayout:
    def __init__(self, ptr: cute.Pointer, stride_order: tuple[int, int, int]):
        self.ptr = ptr

        # static properties
        self.stride_order = stride_order

    def to_tensor(
        self, shape: tuple[int, int, int], *, loc=None, ip=None
    ) -> cute.Tensor:
        assert len(shape) == len(self.stride_order), (
            f"Shape {shape} and stride_order {self.stride_order} must have the "
            "same rank."
        )
        layout = cute.make_ordered_layout(shape, self.stride_order)
        # permute (l, mn, k) -> (mn, k, l)
        res = cute.make_tensor(self.ptr, cute.select(layout, mode=[1, 2, 0]))
        return res
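
    # Illustrative walk-through (editorial example using the default 512x256x128x16
    # problem): for buffer_a, to_tensor receives shape (L, M, K) = (16, 512, 128) and
    # stride_order (2, 1, 0), i.e. the last mode is fastest and L is slowest. Assuming
    # make_ordered_layout assigns the unit stride to the mode whose order value is 0,
    # which matches the contiguous torch tensors created below, the ordered layout is
    # (16, 512, 128):(65536, 128, 1), and cute.select(..., mode=[1, 2, 0]) re-views it
    # as (M, K, L) = (512, 128, 16):(128, 1, 65536), the (mn, k, l) ordering that
    # TensorOpGemm expects.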

    # Implement JitArgument Protocol and DynamicExpression Protocol

    def __c_pointers__(self):
        """Get the C pointers for the underlying pointer.

        This method is part of the JitArgument Protocol and returns the C pointers
        from the underlying pointer object.

        Implementing it is required for a user-defined data type to be passable to a
        JIT function. When the JIT-compiled function is called, the JIT executor calls
        this method to get the raw pointers to the underlying data object.

        The following condition must be satisfied:

        len(__c_pointers__()) == len(__get_mlir_types__()) == len(__extract_mlir_values__())

        :return: The C pointers from the underlying pointer object
        :rtype: Any
        """
        return self.ptr.__c_pointers__()

    def __get_mlir_types__(self):
        """Get the MLIR types for the underlying pointer.

        This method is part of the JitArgument Protocol and returns the MLIR types
        the compiler uses to generate code. They must match the pointers
        returned by __c_pointers__().

        :return: The MLIR types from the underlying pointer object
        :rtype: Any
        """
        return self.ptr.__get_mlir_types__()

    def __extract_mlir_values__(self):
        """Extract MLIR values from the underlying pointer.

        This method is part of the DynamicExpression Protocol and extracts MLIR values
        from the underlying pointer object.

        It is used by the compiler to generate an MLIR call to another JIT function.
        The values must match the types returned by __get_mlir_types__().

        :return: The MLIR values extracted from the underlying pointer object
        :rtype: Any
        """
        return self.ptr.__extract_mlir_values__()

    def __new_from_mlir_values__(self, values):
        """Create a new BufferWithLayout instance from MLIR values.

        This method is part of the JitArgument & DynamicExpression Protocols and creates
        a new BufferWithLayout instance with its pointer initialized from the given MLIR
        values.

        It is used by the compiler when generating the MLIR function body called by the
        JIT function, and must match the types returned by __c_pointers__() and
        __get_mlir_types__(). The code generator takes the function arguments and
        reconstructs a Python object that is legal inside the function body.

        :param values: MLIR values to initialize the underlying pointer
        :type values: Any
        :return: A new BufferWithLayout instance with its pointer initialized from values
        :rtype: BufferWithLayout
        """
        return BufferWithLayout(
            self.ptr.__new_from_mlir_values__(values), self.stride_order
        )
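
# How the protocol methods above are exercised, as described in their docstrings:
# the JIT executor calls __c_pointers__() at call time to obtain raw pointers,
# the code generator uses __get_mlir_types__() to declare the matching parameter
# types, __extract_mlir_values__() supplies the MLIR values when one JIT function
# calls another, and __new_from_mlir_values__() rebuilds a usable BufferWithLayout
# inside the generated function body. All four must describe the same flat list of
# values, hence the invariant
# len(__c_pointers__()) == len(__get_mlir_types__()) == len(__extract_mlir_values__()).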


@cute.jit
def tensor_op_gemm_wrapper(
    buffer_a: BufferWithLayout,
    buffer_b: BufferWithLayout,
    buffer_c: BufferWithLayout,
    mnkl: cutlass.Constexpr[tuple[int, int, int, int]],
    acc_dtype: Type[cutlass.Numeric],
    atom_layout_mnk: cutlass.Constexpr[tuple[int, int, int]],
):
    print(f"\n[DSL INFO] Input Parameters:")
    print(f"[DSL INFO] mnkl: {mnkl}")
    print(f"[DSL INFO] buffer_a: {buffer_a}")
    print(f"[DSL INFO] buffer_b: {buffer_b}")
    print(f"[DSL INFO] buffer_c: {buffer_c}")
    print(f"[DSL INFO] acc_dtype: {acc_dtype}")
    print(f"[DSL INFO] atom_layout_mnk: {atom_layout_mnk}")

    mA = buffer_a.to_tensor(cute.select(mnkl, mode=[3, 0, 2]))
    mB = buffer_b.to_tensor(cute.select(mnkl, mode=[3, 1, 2]))
    mC = buffer_c.to_tensor(cute.select(mnkl, mode=[3, 0, 1]))

    print(f"\n[DSL INFO] Created Tensors:")
    print(f"[DSL INFO] mA = {mA}")
    print(f"[DSL INFO] mB = {mB}")
    print(f"[DSL INFO] mC = {mC}")

    tensor_op_gemm = TensorOpGemm(
        buffer_a.ptr.value_type,
        buffer_c.ptr.value_type,
        acc_dtype,
        atom_layout_mnk,
    )
    print(f"\n[DSL INFO] Created TensorOpGemm instance")
    print(f"[DSL INFO] Input dtype: {buffer_a.ptr.value_type}")
    print(f"[DSL INFO] Output dtype: {buffer_c.ptr.value_type}")
    print(f"[DSL INFO] Accumulation dtype: {acc_dtype}")
    print(f"[DSL INFO] Atom layout: {atom_layout_mnk}")

    # No need to compile inside a jit function
    tensor_op_gemm(mA, mB, mC)
    print(f"\n[DSL INFO] Executed TensorOpGemm")


def run_tensor_op_gemm_wrapper(mnkl: Tuple[int, int, int, int]):
    print(f"\nRunning TensorOpGemm test with:")
    print(f"Tensor dimensions: {mnkl}")

    ab_dtype = cutlass.Float16
    c_dtype = cutlass.Float16

    a = torch.randn(
        mnkl[3], mnkl[0], mnkl[2], dtype=torch_dtype(ab_dtype), device="cuda"
    )
    b = torch.randn(
        mnkl[3], mnkl[1], mnkl[2], dtype=torch_dtype(ab_dtype), device="cuda"
    )
    c = torch.randn(
        mnkl[3], mnkl[0], mnkl[1], dtype=torch_dtype(c_dtype), device="cuda"
    )

    print(f"Input tensor shapes:")
    print(f"a: {a.shape}, dtype: {a.dtype}")
    print(f"b: {b.shape}, dtype: {b.dtype}")
    print(f"c: {c.shape}, dtype: {c.dtype}\n")

    buffer_a = BufferWithLayout(
        make_ptr(ab_dtype, a.data_ptr(), cute.AddressSpace.gmem),
        (2, 1, 0),
    )
    buffer_b = BufferWithLayout(
        make_ptr(ab_dtype, b.data_ptr(), cute.AddressSpace.gmem),
        (2, 1, 0),
    )
    buffer_c = BufferWithLayout(
        make_ptr(c_dtype, c.data_ptr(), cute.AddressSpace.gmem),
        (2, 1, 0),
    )

    tensor_op_gemm_wrapper(
        buffer_a,
        buffer_b,
        buffer_c,
        mnkl,  # pass the shape as a static value
        # no strides are passed; each buffer's static stride_order determines its layout
        cutlass.Float32,
        (2, 2, 1),
    )
    torch.cuda.synchronize()

    ref = torch.einsum("lmk,lnk->lmn", a, b)
    torch.testing.assert_close(c, ref, atol=1e-05, rtol=1e-05)
    print(f"\n[DSL INFO] Results verified successfully!")
    print(f"First few elements of result: \n{c[:3, :3, :3]}")


if __name__ == "__main__":
    run_tensor_op_gemm_wrapper((512, 256, 128, 16))
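
# Layout conventions used above, derived from the tensor shapes and the reference
# einsum "lmk,lnk->lmn": a is (L, M, K), b is (L, N, K) and c is (L, M, N), all
# allocated contiguously by torch, so the stride-1 mode is K for a/b and N for c.
# The shared stride_order (2, 1, 0) encodes exactly that ordering (last mode fastest,
# batch mode L slowest), and the GEMM computes
# C[l, m, n] = sum_k A[l, m, k] * B[l, n, k] with Float16 inputs accumulated in
# Float32, matching the module docstring.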

@@ -28,16 +28,17 @@


import argparse
import torch
import time
from typing import Type

import cuda.bindings.driver as cuda
import torch

import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
import cutlass.cute.testing as testing
import cutlass.torch as cutlass_torch
from cutlass.cute.runtime import from_dlpack

"""
An Elementwise Addition Example using CuTe DSL.
@@ -153,6 +154,7 @@ def elementwise_add_kernel(
    blkC = gC[blk_coord]  # (TileM,TileN)
    blkCrd = cC[blk_coord]  # (TileM, TileN)

    # Note: these prints only run at compile/jit time
    print(f"[DSL INFO] Sliced Tensors per thread block:")
    print(f"[DSL INFO] blkA = {blkA.type}")
    print(f"[DSL INFO] blkB = {blkB.type}")
@@ -189,7 +191,7 @@ def elementwise_add_kernel(
    print(f"[DSL INFO] thrC = {thrC.type}")
    print(f"[DSL INFO] thrCrd = {thrCrd.type}")

    for i in cutlass.range_dynamic(0, cute.size(frgPred), 1):
    for i in range(0, cute.size(frgPred), 1):
        val = cute.elem_less(thrCrd[i], shape)
        frgPred[i] = val
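    # A note on the loop idioms touched throughout this release (as suggested by the
    # changes in these hunks): a plain Python `range`, as above, takes over the role
    # of the former cutlass.range_dynamic, cutlass.range keeps explicit unroll
    # control, cutlass.range_constexpr forces a loop to be unrolled while tracing,
    # and cutlass.const_expr marks a branch condition as compile-time so the untaken
    # side is not generated in the IR.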

@@ -270,9 +272,6 @@ def run_elementwise_add(
    warmup_iterations=2,
    iterations=200,
):
    if not torch.cuda.is_available():
        raise RuntimeError(f"Ampere GPU is required to run this example!")

    print(f"\nRunning Elementwise Add test with:")
    print(f"Tensor dimensions: [{M}, {N}]")
    print(f"Input and Output Data type: {dtype}")
@@ -315,10 +314,8 @@ def run_elementwise_add(

    print("Executing vector add kernel...")

    # Get current CUDA stream from PyTorch
    torch_stream = torch.cuda.current_stream()
    # Get the raw stream pointer as a CUstream
    current_stream = cuda.CUstream(torch_stream.cuda_stream)
    # Get current CUstream from torch
    current_stream = cutlass_torch.current_stream()

    if not skip_ref_check:
        compiled_func(a_tensor, b_tensor, c_tensor)
@@ -329,41 +326,52 @@ def run_elementwise_add(
    if not benchmark:
        return

    # Create CUDA events for timing
    start_event = cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1]
    end_event = cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1]
    def generate_tensors():
        if dtype.is_integer:
            a = torch.randint(
                0, 10, (M, N), device=torch.device("cuda"), dtype=torch_dtype
            )
            b = torch.randint(
                0, 10, (M, N), device=torch.device("cuda"), dtype=torch_dtype
            )
        else:
            a = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype)
            b = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype)

    # Warmup
    for _ in range(warmup_iterations):
        compiled_func(a_tensor, b_tensor, c_tensor)
        c = torch.zeros_like(a)

    # Use the current stream for CUDA events instead of the default stream
    # Record start event
    cuda.cuEventRecord(start_event, current_stream)
        if not is_a_dynamic_layout:
            a_tensor = from_dlpack(a).mark_layout_dynamic()
        else:
            a_tensor = a

    # Execute the kernel
    for _ in range(iterations):
        compiled_func(a_tensor, b_tensor, c_tensor)
        if not is_b_dynamic_layout:
            b_tensor = from_dlpack(b).mark_layout_dynamic()
        else:
            b_tensor = b

    # Record end event
    cuda.cuEventRecord(end_event, current_stream)
    cuda.cuEventSynchronize(end_event)
        if not is_result_dynamic_layout:
            c_tensor = from_dlpack(c).mark_layout_dynamic()
        else:
            c_tensor = c

    # Calculate elapsed time
    err, elapsed_time = cuda.cuEventElapsedTime(start_event, end_event)
    avg_time = elapsed_time / iterations
        return testing.JitArguments(a_tensor, b_tensor, c_tensor)

    avg_time_us = testing.benchmark(
        compiled_func,
        workspace_generator=generate_tensors,
        workspace_count=10,
        warmup_iterations=warmup_iterations,
        profiling_iterations=iterations,
    )
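    # Reading of the call above (inferred from the arguments, not from documentation):
    # generate_tensors builds one set of kernel arguments, and testing.benchmark
    # rotates through workspace_count such sets while timing profiling_iterations
    # launches after warmup_iterations warmup launches, presumably so repeated runs
    # do not keep hitting the same cached buffers.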

    # Print execution results
    print(f"Kernel execution time: {avg_time:.4f} ms")
    print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms")
    print(
        f"Achieved memory throughput: {(3 * a.numel() * dtype.width // 8) / (avg_time / 1000) / 1e9:.2f} GB/s"
        f"Achieved memory throughput: {(3 * a.numel() * dtype.width // 8) / (avg_time_us / 1e6) / 1e9:.2f} GB/s"
    )
    print(f"First few elements of result: \n{c[:3, :3]}")

    # Destroy events
    cuda.cuEventDestroy(start_event)
    cuda.cuEventDestroy(end_event)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
@@ -377,6 +385,10 @@ if __name__ == "__main__":
    parser.add_argument("--benchmark", action="store_true")

    args = parser.parse_args()

    if not torch.cuda.is_available():
        raise RuntimeError(f"Ampere GPU is required to run this example!")

    run_elementwise_add(
        args.M,
        args.N,

@@ -29,14 +29,15 @@

import argparse
import operator
import torch
from typing import Type
import time
from typing import Type, List

import cuda.bindings.driver as cuda
import torch

import cutlass
import cutlass.cute as cute
import cutlass.cute.testing as testing
import cutlass.torch as cutlass_torch
from cutlass.cute.runtime import from_dlpack

@@ -77,8 +78,7 @@ while maintaining high performance through efficient memory access patterns.
@cute.kernel
def elementwise_apply_kernel(
    op: cutlass.Constexpr,
    gA: cute.Tensor,
    gB: cute.Tensor,
    inputs: List[cute.Tensor],
    gC: cute.Tensor,
    cC: cute.Tensor,  # coordinate tensor
    shape: cute.Shape,
@@ -90,48 +90,46 @@ def elementwise_apply_kernel(
    # slice for CTAs
    cta_coord = ((None, None), bidx)
    # logical coord -> address
    ctaA = gA[cta_coord]  # (TileM, TileN)
    ctaB = gB[cta_coord]  # (TileM, TileN)
    # Leverage the meta-programming capability of the DSL to slice the tensors for each input
    # All for loops below on input tensors would be fully unrolled automatically at compile time
    ctaInputs = [t[cta_coord] for t in inputs]  # (TileM, TileN)
    ctaC = gC[cta_coord]  # (TileM, TileN)
    ctaCrd = cC[cta_coord]  # (TileM, TileN)

    print(f"[DSL INFO] Sliced Tensors per thread block:")
    print(f"[DSL INFO] ctaA = {ctaA.type}")
    print(f"[DSL INFO] ctaB = {ctaB.type}")
    for i in cutlass.range_constexpr(len(ctaInputs)):
        print(f"[DSL INFO] ctaInputs{i} = {ctaInputs[i].type}")
    print(f"[DSL INFO] ctaC = {ctaC.type}")
    print(f"[DSL INFO] ctaCrd = {ctaCrd.type}")

    # compose with CTA TV layout
    # (tid, vid) -> address
    tidfrgA = cute.composition(ctaA, tv_layout)
    tidfrgB = cute.composition(ctaB, tv_layout)
    tidfrgInputs = [cute.composition(t, tv_layout) for t in ctaInputs]
    tidfrgC = cute.composition(ctaC, tv_layout)
    tidfrgCrd = cute.composition(ctaCrd, tv_layout)
    # print(f"{tv_layout = }")
    # print(f"{tidfrgA = }")
    # print(f"{tidfrgAB[0] = }")

    thr_coord = (tidx, (None, None))

    # slice for threads
    # vid -> address
    thrA = tidfrgA[thr_coord]  # (V)
    thrB = tidfrgB[thr_coord]  # (V)
    thrInputs = [t[thr_coord] for t in tidfrgInputs]  # (V)
    thrC = tidfrgC[thr_coord]  # (V)
    thrCrd = tidfrgCrd[thr_coord]

    print(f"[DSL INFO] Sliced Tensors per thread:")
    print(f"[DSL INFO] thrA = {thrA.type}")
    print(f"[DSL INFO] thrB = {thrB.type}")
    for i in cutlass.range_constexpr(len(thrInputs)):
        print(f"[DSL INFO] thrInputs{i} = {thrInputs[i].type}")
    print(f"[DSL INFO] thrC = {thrC.type}")
    print(f"[DSL INFO] thrCrd = {thrCrd.type}")

    # allocate fragments for gmem->rmem
    frgA = cute.make_fragment_like(thrA, gA.element_type)
    frgB = cute.make_fragment_like(thrB, gB.element_type)
    frgInputs = [cute.make_fragment_like(t, t.element_type) for t in thrInputs]
    frgC = cute.make_fragment_like(thrC, gC.element_type)
    frgPred = cute.make_fragment(thrCrd.shape, cutlass.Boolean)

    for i in cutlass.range_dynamic(cute.size(frgPred), unroll=1):
    for i in cutlass.range(cute.size(frgPred), unroll=1):
        frgPred[i] = cute.elem_less(thrCrd[i], shape)

    # if tidx == 0 and bidx == 0:
@@ -142,10 +140,13 @@ def elementwise_apply_kernel(
    ##########################################################

    # declare the atoms which will be used later for memory copy
    # Compile time validation: expect same element type for all input tensors so as to reuse the copy atom for load
    assert all(t.element_type == inputs[0].element_type for t in inputs)

    copy_atom_load = cute.make_copy_atom(
        cute.nvgpu.CopyUniversalOp(),
        gA.element_type,
        num_bits_per_copy=gA.element_type.width,
        inputs[0].element_type,
        num_bits_per_copy=inputs[0].element_type.width,
    )
    copy_atom_store = cute.make_copy_atom(
        cute.nvgpu.CopyUniversalOp(),
@@ -153,12 +154,12 @@ def elementwise_apply_kernel(
        num_bits_per_copy=gC.element_type.width,
    )

    cute.copy(copy_atom_load, thrA, frgA, pred=frgPred)
    cute.copy(copy_atom_load, thrB, frgB, pred=frgPred)
    for thrInput, frgInput in zip(thrInputs, frgInputs):
        cute.copy(copy_atom_load, thrInput, frgInput, pred=frgPred)

    # Load data before use. The compiler will optimize the copy and load
    # operations to convert some memory ld/st into register uses.
    result = op(frgA.load(), frgB.load())
    result = op(*[frgInput.load() for frgInput in frgInputs])

    # Save the results back to registers. Here we reuse b's registers.
    frgC.store(result)
@@ -173,6 +174,7 @@ def elementwise_apply(
    a: cute.Tensor,
    b: cute.Tensor,
    result: cute.Tensor,
    stream: cuda.CUstream,
):
    """CUDA kernel applying binary operator on each element of two n-D input tensors in
    CuTe Python and store to result tensor.
@@ -262,8 +264,7 @@ def elementwise_apply(
    # Async token(s) can also be specified as dependencies
    elementwise_apply_kernel(
        op,
        gA,
        gB,
        [gA, gB],  # Group input tensors into a list as a single argument
        gC,
        cC,
        result.shape,
@@ -271,6 +272,7 @@ def elementwise_apply(
    ).launch(
        grid=[cute.size(gC, mode=[1]), 1, 1],
        block=[cute.size(tv_layout, mode=[0]), 1, 1],
        stream=stream,
    )


@@ -287,6 +289,11 @@ def run_elementwise_apply_and_verify(
    if not torch.cuda.is_available():
        raise RuntimeError(f"Ampere GPU is required to run this example!")

    # Create non default CUDA stream from PyTorch
    torch_stream = torch.cuda.Stream()
    # Get the raw stream pointer as a CUstream
    current_stream = cuda.CUstream(torch_stream.cuda_stream)

    print(f"\nRunning Elementwise Apply test with:")
    print(f"Tensor dimensions: [{M}, {N}]")
    print(f"Input and Output Data type: {dtype}")
@@ -309,20 +316,16 @@ def run_elementwise_apply_and_verify(
    if op in (operator.truediv, operator.floordiv):
        b = torch.where(b == 0, torch.tensor(epsilon), b)

    print("Compiling kernel with cute.compile ...")
    start_time = time.time()
    compiled_func = cute.compile(elementwise_apply, op, from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic())
    compilation_time = time.time() - start_time
    print(f"Compilation time: {compilation_time:.4f} seconds")

    print("Executing elementwise apply kernel...")
    # Get current CUDA stream from PyTorch
    torch_stream = torch.cuda.current_stream()
    # Get the raw stream pointer as a CUstream
    current_stream = cuda.CUstream(torch_stream.cuda_stream)

    if not skip_ref_check:
        compiled_func(from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic())
        elementwise_apply(
            op,
            from_dlpack(a),
            from_dlpack(b),
            from_dlpack(c).mark_layout_dynamic(),
            current_stream,
        )
        print("Verifying results...")
        torch.testing.assert_close(op(a, b), c)
        print("Results verified successfully!")
@@ -330,28 +333,32 @@ def run_elementwise_apply_and_verify(
    if not benchmark:
        return

    # Create CUDA events for timing
    start_event = cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1]
    end_event = cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1]
    compiled_func = cute.compile(
        elementwise_apply,
        op,
        from_dlpack(a),
        from_dlpack(b),
        from_dlpack(c).mark_layout_dynamic(),
        current_stream,
    )

    # Warmup
    for _ in range(warmup_iterations):
        compiled_func(from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic())
    # When compiled we inlined op in the kernel, so we do not pass it when benchmarking

    # Record start event
    cuda.cuEventRecord(start_event, current_stream)
    avg_time_us = testing.benchmark(
        compiled_func,
        kernel_arguments=testing.JitArguments(
            from_dlpack(a),
            from_dlpack(b),
            from_dlpack(c).mark_layout_dynamic(),
            current_stream,
        ),
        warmup_iterations=warmup_iterations,
        profiling_iterations=iterations,
        use_cuda_graphs=True,
        stream=current_stream,
    )

    # Execute the kernel
    for _ in range(iterations):
        compiled_func(from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic())

    # Record end event
    cuda.cuEventRecord(end_event, current_stream)
    cuda.cuEventSynchronize(end_event)

    # Calculate elapsed time
    err, elapsed_time = cuda.cuEventElapsedTime(start_event, end_event)
    avg_time = elapsed_time / iterations
    avg_time = avg_time_us / 1e3

    # Print execution results
    print(f"Kernel execution time: {avg_time:.4f} ms")
@@ -360,10 +367,6 @@ def run_elementwise_apply_and_verify(
    )
    print(f"First few elements of result: \n{c[:3, :3]}")

    # Destroy events
    cuda.cuEventDestroy(start_event)
    cuda.cuEventDestroy(end_event)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(

@@ -542,13 +542,13 @@ class FlashAttentionForwardAmpere:
            cutlass.Boolean,
        )
        # Set predicates for head_dim bounds, seqlen_q/k bounds is processed at the first tile.
        for rest_v in range(tQpQ.shape[0]):
            for rest_k in range(tQpQ.shape[2]):
        for rest_v in cutlass.range_constexpr(tQpQ.shape[0]):
            for rest_k in cutlass.range_constexpr(tQpQ.shape[2]):
                tQpQ[rest_v, 0, rest_k] = cute.elem_less(
                    tQcQ[(0, rest_v), 0, rest_k][3], mQ.layout.shape[3]
                )
        for rest_v in range(tKVpKV.shape[0]):
            for rest_k in range(tKVpKV.shape[2]):
        for rest_v in cutlass.range_constexpr(tKVpKV.shape[0]):
            for rest_k in cutlass.range_constexpr(tKVpKV.shape[2]):
                tKVpKV[rest_v, 0, rest_k] = cute.elem_less(
                    tKVcKV[(0, rest_v), 0, rest_k][3], mK.layout.shape[3]
                )
@@ -556,7 +556,7 @@ class FlashAttentionForwardAmpere:
        # Prefetch Prologue
        # ///////////////////////////////////////////////////////////////////////////////
        # Start async loads of the last mn-tile, where we take care of the mn residue
        for m in range(cute.size(tQsQ.shape[1])):
        for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])):
            if cute.elem_less(tQcQ[0, m, 0][1], mQ.layout.shape[1]):
                cute.copy(
                    gmem_tiled_copy_QKV,
@@ -567,7 +567,7 @@ class FlashAttentionForwardAmpere:
            else:
                # Clear the smem tiles to account for predicated off loads
                tQsQ[None, m, None].fill(0)
        for n in range(cute.size(tKsK.shape[1])):
        for n in cutlass.range_constexpr(cute.size(tKsK.shape[1])):
            if cute.elem_less(tKVcKV[0, n, 0][1], mK.layout.shape[1]):
                cute.copy(
                    gmem_tiled_copy_QKV,
@@ -644,13 +644,13 @@ class FlashAttentionForwardAmpere:
        # We also need masking on S if it's causal, for the last ceil_div(m_block_size, n_block_size) blocks.
        # We will have at least 1 "masking" iteration.
        mask_steps = 1
        if self._is_causal:
        if cutlass.const_expr(self._is_causal):
            mask_steps = cute.ceil_div(self._m_block_size, self._n_block_size)

        for n_tile in range(mask_steps):
        for n_tile in cutlass.range_constexpr(mask_steps):
            n_block = n_block_max - n_tile - 1
            basic_params.n_block = n_block
            if self._is_causal:
            if cutlass.const_expr(self._is_causal):
                if n_block >= 0:
                    self.compute_one_n_block(
                        basic_params,
@@ -673,7 +673,7 @@ class FlashAttentionForwardAmpere:
            )

        # Start async loads of rest k-tiles in reverse order, no k-residue handling needed
        for n_tile in cutlass.range_dynamic(mask_steps, n_block_max, 1):
        for n_tile in range(mask_steps, n_block_max, 1):
            n_block = n_block_max - n_tile - 1
            basic_params.n_block = n_block
            self.compute_one_n_block(
@@ -748,13 +748,13 @@ class FlashAttentionForwardAmpere:
            ),
            cutlass.Boolean,
        )
        for rest_v in range(tOpO.shape[0]):
            for rest_n in range(cute.size(tOpO.shape[2])):
        for rest_v in cutlass.range_constexpr(tOpO.shape[0]):
            for rest_n in cutlass.range_constexpr(cute.size(tOpO.shape[2])):
                tOpO[rest_v, 0, rest_n] = cute.elem_less(
                    tOcO[(0, rest_v), 0, rest_n][3], mO.layout.shape[3]
                )
        # copy acc O from rmem to gmem
        for rest_m in range(cute.size(tOpO.shape[1])):
        for rest_m in cutlass.range_constexpr(cute.size(tOpO.shape[1])):
            if cute.elem_less(tOcO[0, rest_m, 0][1], mO.layout.shape[1]):
                cute.copy(
                    gmem_tiled_copy_O,
@@ -804,7 +804,7 @@ class FlashAttentionForwardAmpere:
        # load smem tile V for O, special process for the first tile to avoid loading nan.
        # The `if` here is a constexpr, won't be generated in the IR.
        if is_first_n_block:
            for n in range(cute.size(gmem_copy_params.tVsV.shape[1])):
            for n in cutlass.range_constexpr(cute.size(gmem_copy_params.tVsV.shape[1])):
                if cute.elem_less(
                    gmem_copy_params.tKVcKV[0, n, 0][1],
                    basic_params.mK.layout.shape[1],
@@ -841,7 +841,7 @@ class FlashAttentionForwardAmpere:
            smem_copy_params.tSrK_copy_view[None, None, 0],
        )
        # mma for S
        for k in range(cute.size(smem_copy_params.tSsQ.shape[2])):
        for k in cutlass.range_constexpr(cute.size(smem_copy_params.tSsQ.shape[2])):
            # load next QK k-block from smem to rmem for mma
            k_next = (k + 1) % cute.size(smem_copy_params.tSsQ.shape[2])
            cute.copy(
@@ -916,7 +916,7 @@ class FlashAttentionForwardAmpere:
            smem_copy_params.tOrVt_copy_view[None, None, 0],
        )
        # mma for O
        for k in range(cute.size(tOrS.shape[2])):
        for k in cutlass.range_constexpr(cute.size(tOrS.shape[2])):
            # load next V k-block from smem to rmem for mma
            k_next = (k + 1) % cute.size(tOrS.shape[2])
            cute.copy(
@@ -965,14 +965,14 @@ class FlashAttentionForwardAmpere:
        acc_O_mn = self._make_acc_tensor_mn_view(mma_params.acc_O)
        row_max_prev = None
        # if it is not the first tile, load the row r of previous row_max and compare with row_max_cur_row.
        if not is_first_n_block:
        if cutlass.const_expr(not is_first_n_block):
            row_max_prev = cute.make_fragment_like(
                softmax_params.row_max, cutlass.Float32
            )
            cute.basic_copy(softmax_params.row_max, row_max_prev)
        # if it is the first tile, create a mask for residual of S to -inf for softmax.
        tScS_mn = None
        if in_mask_steps:
        if cutlass.const_expr(in_mask_steps):
            mcS = cute.make_identity_tensor(
                (
                    basic_params.mQ.shape[0],
@@ -990,12 +990,12 @@ class FlashAttentionForwardAmpere:
            tScS_mn = self._make_acc_tensor_mn_view(tScS)

        # Each iteration processes one row of acc_S
        for r in range(cute.size(softmax_params.row_max)):
        for r in cutlass.range_constexpr(cute.size(softmax_params.row_max)):
            # mask residual of S with -inf
            if in_mask_steps:
                if not self._is_causal:
            if cutlass.const_expr(in_mask_steps):
                if cutlass.const_expr(not self._is_causal):
                    # traverse column index.
                    for c in range(cute.size(tScS_mn.shape[1])):
                    for c in cutlass.range_constexpr(cute.size(tScS_mn.shape[1])):
                        if cute.elem_less(
                            basic_params.mK.shape[1], tScS_mn[0, c][3] + 1
                        ):
@@ -1006,7 +1006,7 @@ class FlashAttentionForwardAmpere:
                        tScS_mn[r, 0][1] + 1, basic_params.mK.shape[1]
                    )
                    # traverse column index.
                    for c in range(cute.size(tScS_mn.shape[1])):
                    for c in cutlass.range_constexpr(cute.size(tScS_mn.shape[1])):
                        # only consider the column index, so the row index sets to 0.
                        if cute.elem_less(col_idx_limit, tScS_mn[0, c][3] + 1):
                            acc_S_mn[r, c] = -cutlass.Float32.inf
@@ -1021,10 +1021,10 @@ class FlashAttentionForwardAmpere:
            row_max_cur_row = self._threadquad_reduce_max(row_max_cur_row)
            row_max_prev_row = None
            # if it is not the first tile, load the row r of previous row_max and compare with row_max_cur_row.
            if not is_first_n_block:
            if cutlass.const_expr(not is_first_n_block):
                row_max_prev_row = row_max_prev[r]
                row_max_cur_row = cute.arch.fmax(row_max_prev_row, row_max_cur_row)
            if self._is_causal:
            if cutlass.const_expr(self._is_causal):
                row_max_cur_row = (
                    0.0 if row_max_cur_row == -cutlass.Float32.inf else row_max_cur_row
                )
@@ -1043,7 +1043,7 @@ class FlashAttentionForwardAmpere:
                cute.ReductionOp.ADD, cutlass.Float32.zero, 0
            )
            # if it is not the first tile, load the row r of previous row_max and minus row_max_cur_row to update row_sum.
            if not is_first_n_block:
            if cutlass.const_expr(not is_first_n_block):
                prev_minus_cur_exp = self._exp2f(
                    row_max_prev_row * softmax_params.softmax_scale_log2
                    - row_max_cur_row * softmax_params.softmax_scale_log2
@@ -1072,7 +1072,7 @@ class FlashAttentionForwardAmpere:
        """
        # do quad reduction for row_sum.
        acc_O_mn = self._make_acc_tensor_mn_view(acc_O)
        for r in range(cute.size(row_sum)):
        for r in cutlass.range_constexpr(cute.size(row_sum)):
            row_sum[r] = self._threadquad_reduce_sum(row_sum[r])
            # if row_sum is zero or nan, set acc_O_mn_row to 1.0
            acc_O_mn_row_is_zero_or_nan = row_sum[r] == 0.0 or row_sum[r] != row_sum[r]

@@ -35,6 +35,8 @@ import torch

import cutlass
import cutlass.cute as cute
import cutlass.cute.testing as testing
import cutlass.torch as cutlass_torch
import cutlass.utils as utils
from cutlass.cute.runtime import from_dlpack

@@ -109,6 +111,7 @@ class SGemm:
        mB: cute.Tensor,
        mC: cute.Tensor,
        epilogue_op: cutlass.Constexpr = lambda x: x,
        stream: cuda.CUstream = cuda.CUstream(cuda.CUstream_flags.CU_STREAM_DEFAULT),
    ):
        self.a_major_mode = utils.LayoutEnum.from_tensor(mA)
        self.b_major_mode = utils.LayoutEnum.from_tensor(mB)
@@ -168,7 +171,7 @@ class SGemm:
            num_bits_per_copy=mB.element_type.width,
        )

        if self.a_major_mode == utils.LayoutEnum.COL_MAJOR:
        if cutlass.const_expr(self.a_major_mode == utils.LayoutEnum.COL_MAJOR):
            num_vectorized = 4 if (mA.layout.max_alignment % 16 == 0) else 1
            atom_async_copy_A = cute.make_copy_atom(
                cute.nvgpu.cpasync.CopyG2SOp(),
@@ -182,7 +185,7 @@ class SGemm:
            )
            vA = cute.make_layout((num_vectorized, 1))

        if self.b_major_mode == utils.LayoutEnum.COL_MAJOR:
        if cutlass.const_expr(self.b_major_mode == utils.LayoutEnum.COL_MAJOR):
            num_vectorized = 4 if (mB.layout.max_alignment % 16 == 0) else 1
            atom_async_copy_B = cute.make_copy_atom(
                cute.nvgpu.cpasync.CopyG2SOp(),
@@ -222,7 +225,7 @@ class SGemm:
        atoms_layout = cute.make_layout(
            (self._num_threads // 16, 16, 1), stride=(16, 1, 0)
        )
        if self.c_major_mode == utils.LayoutEnum.COL_MAJOR:
        if cutlass.const_expr(self.c_major_mode == utils.LayoutEnum.COL_MAJOR):
            atoms_layout = cute.make_layout(
                (16, self._num_threads // 16, 1), stride=(1, 16, 0)
            )
@@ -256,6 +259,7 @@ class SGemm:
            grid=grid_dim,
            block=[cute.size(atoms_layout), 1, 1],
            smem=smem_size,
            stream=stream,
        )

    @cute.kernel
@@ -540,8 +544,8 @@ class SGemm:
        # 3. Combining the smem and register pipelines results in the mainloop.
        # ///////////////////////////////////////////////////////////////////////////////

        for _ in cutlass.range_dynamic(k_tile_count, unroll=1):
            for k_block in range(k_block_max):
        for _ in range(k_tile_count):
            for k_block in range(k_block_max, unroll_full=True):
                if k_block == k_block_max - 1:
                    tCsA_p = tCsA[None, None, None, smem_pipe_read]
                    tCsB_p = tCsB[None, None, None, smem_pipe_read]
@@ -639,7 +643,6 @@ def main(
    iterations: int = 100,
    skip_ref_check: bool = False,
):
    torch.manual_seed(1024)
    M, N, K = problem_shape

    # Create and permute tensor A/B/C
@@ -694,51 +697,36 @@ def main(

    sgemm = SGemm()

    # Get current CUDA stream from PyTorch
    torch_stream = torch.cuda.current_stream()
    # Get the raw stream pointer as a CUstream
    current_stream = cuda.CUstream(torch_stream.cuda_stream)

    print("Compiling kernel with cute.compile ...")
    start_time = time.time()
    gemm = cute.compile(sgemm, a_tensor, b_tensor, c_tensor)
    gemm = cute.compile(sgemm, a_tensor, b_tensor, c_tensor, stream=current_stream)
    compilation_time = time.time() - start_time
    print(f"Compilation time: {compilation_time:.4f} seconds")

    print("Executing GEMM kernel...")

    # Get current CUDA stream from PyTorch
    torch_stream = torch.cuda.current_stream()

    # Get the raw stream pointer as a CUstream
    current_stream = cuda.CUstream(torch_stream.cuda_stream)

    # Create CUDA events for timing
    start_event = cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1]
    end_event = cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1]

    # Warmup
    for _ in range(warmup_iterations):
        gemm(a_tensor, b_tensor, c_tensor)

    # Use the current stream for CUDA events instead of the default stream
    # Record start event
    cuda.cuEventRecord(start_event, current_stream)

    # Execute the kernel
    for _ in range(iterations):
        gemm(a_tensor, b_tensor, c_tensor)

    # Record end event
    cuda.cuEventRecord(end_event, current_stream)
    cuda.cuEventSynchronize(end_event)

    # Calculate elapsed time
    err, elapsed_time = cuda.cuEventElapsedTime(start_event, end_event)
    avg_time_us = testing.benchmark(
        gemm,
        kernel_arguments=testing.JitArguments(
            a_tensor, b_tensor, c_tensor, current_stream
        ),
        warmup_iterations=warmup_iterations,
        profiling_iterations=iterations,
        use_cuda_graphs=False,
        stream=current_stream,
    )

    # Print execution results
    print(f"Kernel execution time: {elapsed_time / iterations:.4f} ms")

    # Destroy events
    cuda.cuEventDestroy(start_event)
    cuda.cuEventDestroy(end_event)
    print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms")

    if not skip_ref_check:
        gemm(a_tensor, b_tensor, c_tensor)
        torch.cuda.synchronize()
        print("Verifying results...")
        ref = torch.einsum("mk,nk->mn", a, b)
        torch.testing.assert_close(c.cpu(), ref.cpu(), atol=1e-03, rtol=1e-05)
@@ -768,6 +756,9 @@ if __name__ == "__main__":

    args = parser.parse_args()
    print("Running SIMT GEMM example:")

    torch.manual_seed(1024)

    main(
        args.a_major,
        args.b_major,

@@ -36,6 +36,7 @@ import torch

import cutlass
import cutlass.cute as cute
import cutlass.cute.testing as testing
import cutlass.torch as cutlass_torch
import cutlass.utils as utils
from cutlass.cute.runtime import from_dlpack
@@ -48,6 +49,7 @@ A dense GEMM (C = A * B) example for the NVIDIA Ampere architecture using CuTe DSL

This GEMM kernel supports the following features:
- Utilizes Ampere's tensor cores for matrix multiply-accumulate (MMA) operations
- Threadblock rasterization to improve data re-use
- Supports multi-stage pipeline to overlap computation and memory access
- Implements shared memory buffering for epilogue to increase coalesced global memory access

@@ -253,6 +255,22 @@ class TensorOpGemm:
        # grid_dim: ((m + BLK_M - 1) // BLK_M, (n + BLK_N - 1) // BLK_N, l)
        grid_dim = cute.ceil_div(mC.shape, (self.bM, self.bN, 1))

        # Add threadblock rasterization to improve re-use of data
        raster_factor = 1
        grid_dim_n = cute.size(grid_dim[1])
        # Thresholds picked so that it doesn't cause too many no-op CTAs
        if grid_dim_n > 5:
            raster_factor = 8
        elif grid_dim_n > 2:
            raster_factor = 4
        elif grid_dim_n > 1:
            raster_factor = 2
        rasterization_remap_grid_dim = (
            cute.size(grid_dim[0]) * raster_factor,
            (cute.size(grid_dim[1]) + raster_factor - 1) // raster_factor,
            cute.size(grid_dim[2]),
        )
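        # The kernel recovers per-block tile coordinates via self.raster_tile(bidx,
        # bidy, rasterization_factor); that helper is not shown in this diff. A
        # hypothetical sketch of such a remap, for illustration only:
        #     offset_tile_x = bidx // raster_factor
        #     offset_tile_y = bidy * raster_factor + bidx % raster_factor
        # i.e. groups of raster_factor consecutive blocks cover a compact patch of
        # neighbouring output tiles, improving reuse of the global-memory tiles they
        # load. Because ceil_div rounds the y extent up, a remapped coordinate can
        # fall outside the logical grid, which is why the kernel checks it against
        # grid_dim and lets out-of-range CTAs fall through without work.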

        self.kernel(
            mA,
            mB,
@@ -264,9 +282,10 @@ class TensorOpGemm:
            tiled_copy_B,
            tiled_copy_C,
            tiled_mma,
            raster_factor,
            epilogue_op,
        ).launch(
            grid=grid_dim,
            grid=rasterization_remap_grid_dim,
            block=[self.num_threads, 1, 1],
            smem=smem_size,
        )
@@ -284,436 +303,445 @@ class TensorOpGemm:
        tiled_copy_B: cute.TiledCopy,
        tiled_copy_C: cute.TiledCopy,
        tiled_mma: cute.TiledMma,
        rasterization_factor: cutlass.Int32,
        epilogue_op: cutlass.Constexpr = lambda x: x,
    ):
        # Thread index, block index
        tidx, _, _ = cute.arch.thread_idx()
        bidx, bidy, bidz = cute.arch.block_idx()
        tiler_coord = (bidx, bidy, None)

        # ///////////////////////////////////////////////////////////////////////////////
        # Get the appropriate tiles for this thread block.
        # gA: (BLK_M, BLK_K, k), gB: (BLK_N, BLK_K, k), gC: (BLK_M, BLK_N)
        # ///////////////////////////////////////////////////////////////////////////////
        gA = cute.local_tile(
            mA[None, None, bidz],
            tiler=self.cta_tiler,
            coord=tiler_coord,
            proj=(1, None, 1),
        )
        gB = cute.local_tile(
            mB[None, None, bidz],
            tiler=self.cta_tiler,
            coord=tiler_coord,
            proj=(None, 1, 1),
        )
        gC = cute.local_tile(
            mC[None, None, bidz],
            tiler=self.cta_tiler,
            coord=tiler_coord,
            proj=(1, 1, None),
        grid_dim = cute.ceil_div(mC.shape, (self.bM, self.bN, 1))
        offset_tile_x, offset_tile_y = self.raster_tile(
            bidx, bidy, rasterization_factor
        )
        # Early exit if CTA is out of range
        if grid_dim[0] <= offset_tile_x or grid_dim[1] <= offset_tile_y:
            pass
        else:
            tiler_coord = (offset_tile_x, offset_tile_y, None)

        # By default, if the tensor k mode does not divide into the tile k
        # size, then last tiles in the k dimension are irregular.
        # Instead, make the first tiles irregular when k is irregular.
        # This allows us to handle the irregular tile first to avoid
        # checking for this condition within the mainloop.

        # residual_k is a negative number indicating the amount needed to
        # shift the pointer by in dimension k
        residual_k = cute.size(mA, mode=[1]) - cutlass.Int32(self.bK) * cute.size(
            gA, mode=[2]
        )
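        # Worked example (illustrative numbers only): with K = 250 and bK = 32 the
        # k mode is covered by ceil_div(250, 32) = 8 tiles, so
        # residual_k = 250 - 32 * 8 = -6. Shifting gA/gB by (0, residual_k, 0)
        # makes the first k-tile the partial one; its leading 6 k-indices fall
        # before the start of the tensor and are predicated off via the shifted
        # identity-tensor check below, so the remaining k-tiles in the mainloop
        # are full and need no bounds checks.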
|
||||
|
||||
# move the pointer of gA/gB in the `-k` direction
|
||||
gA = cute.domain_offset((0, residual_k, 0), gA)
|
||||
gB = cute.domain_offset((0, residual_k, 0), gB)
|
||||
# input is 16B aligned
|
||||
gA = cute.make_tensor(gA.iterator.align(16), gA.layout)
|
||||
gB = cute.make_tensor(gB.iterator.align(16), gB.layout)
|
||||
|
||||
# Construct identity layout for sA and sB (mirrors global tensors,
|
||||
# used for predication only)
|
||||
mcA = cute.make_identity_tensor(mA.layout.shape)
|
||||
mcB = cute.make_identity_tensor(mB.layout.shape)
|
||||
cA = cute.local_tile(
|
||||
mcA[None, None, bidz],
|
||||
tiler=self.cta_tiler,
|
||||
coord=tiler_coord,
|
||||
proj=(1, None, 1),
|
||||
)
|
||||
cB = cute.local_tile(
|
||||
mcB[None, None, bidz],
|
||||
tiler=self.cta_tiler,
|
||||
coord=tiler_coord,
|
||||
proj=(None, 1, 1),
|
||||
)
|
||||
|
||||
cA = cute.domain_offset((0, residual_k, 0), cA)
|
||||
cB = cute.domain_offset((0, residual_k, 0), cB)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Create shared memory buffers and get the appropriate fragments for this thread.
|
||||
# sA: (BLK_M, BLK_K, PIPE) , sB: (BLK_N, BLK_K, PIPE)
|
||||
# tAgA: (CPY, CPY_M, CPY_K, k) , tBgB: (CPY, CPY_N, CPY_K, k)
|
||||
# tAsA: (CPY, CPY_M, CPY_K, PIPE) , tBsB: (CPY, CPY_N, CPY_K, PIPE)
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Shared memory buffer
|
||||
smem = cutlass.utils.SmemAllocator()
|
||||
|
||||
sA = smem.allocate_tensor(mA.element_type, sA_layout, 16)
|
||||
sB = smem.allocate_tensor(mB.element_type, sB_layout, 16)
|
||||
sC = cute.make_tensor(
|
||||
cute.recast_ptr(sA.iterator, dtype=self.c_dtype), sC_layout
|
||||
)
|
||||
|
||||
thr_copy_A = tiled_copy_A.get_slice(tidx)
|
||||
thr_copy_B = tiled_copy_B.get_slice(tidx)
|
||||
thr_copy_C = tiled_copy_C.get_slice(tidx)
|
||||
tAgA = thr_copy_A.partition_S(gA)
|
||||
tAsA = thr_copy_A.partition_D(sA)
|
||||
tBgB = thr_copy_B.partition_S(gB)
|
||||
tBsB = thr_copy_B.partition_D(sB)
|
||||
tCsC_epilogue = thr_copy_C.partition_S(sC)
|
||||
tCgC_epilogue = thr_copy_C.partition_D(gC)
|
||||
|
||||
# Repeat the partitioning with identity layouts
|
||||
tAcA = thr_copy_A.partition_S(cA)
|
||||
tBcB = thr_copy_B.partition_S(cB)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Predicate: Mark indices that need to copy when problem_shape isn't a multiple
|
||||
# of tile_shape
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
# For predication over the tensors A (M/K), B (N/K), and (in the
|
||||
# epilogue) C (M/N), we will compute it in a fashion similar to an
|
||||
# outer product. The predication along one of the dimensions is
|
||||
# evaluated and stored in a predication tensor. Then, the
|
||||
# predication for the remaining dimension is handled later via an
|
||||
# if/else branch at the copy.
|
||||
# For A and B, predication booleans along M/N are stored in a
|
||||
# predication tensor and along K is handled via a if/else branch.
|
||||
|
||||
# Allocate predicate tensors for M and N. Predication is checked
|
||||
# at the granularity of a copy atom, so the predicate tensor does not
|
||||
# need separate booleans for individual elements within a copy
|
||||
# atom (for example, the elements of tAgA.shape[0][0].)
|
||||
tApA = cute.make_fragment(
|
||||
cute.make_layout(
|
||||
(
|
||||
tAgA.shape[0][1],
|
||||
cute.size(tAgA, mode=[1]),
|
||||
cute.size(tAgA, mode=[2]),
|
||||
),
|
||||
stride=(cute.size(tAgA, mode=[1]), 1, 0),
|
||||
),
|
||||
cutlass.Boolean,
|
||||
)
|
||||
tBpB = cute.make_fragment(
|
||||
cute.make_layout(
|
||||
(
|
||||
tBsB.shape[0][1],
|
||||
cute.size(tBsB, mode=[1]),
|
||||
cute.size(tBsB, mode=[2]),
|
||||
),
|
||||
stride=(cute.size(tBsB, mode=[1]), 1, 0),
|
||||
),
|
||||
cutlass.Boolean,
|
||||
)
|
||||
# Set predicates for M/N bounds
|
||||
for rest_v in range(tApA.shape[0]):
|
||||
for m in range(tApA.shape[1]):
|
||||
tApA[rest_v, m, 0] = cute.elem_less(
|
||||
tAcA[(0, rest_v), m, 0, 0][0], mA.shape[0]
|
||||
)
|
||||
for rest_v in range(tBpB.shape[0]):
|
||||
for n in range(tBpB.shape[1]):
|
||||
tBpB[rest_v, n, 0] = cute.elem_less(
|
||||
tBcB[(0, rest_v), n, 0, 0][0], mB.shape[0]
|
||||
)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Prefetch Prologue
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Clear the smem tiles to account for predicated off loads
|
||||
tAsA.fill(0)
|
||||
tBsB.fill(0)
|
||||
cute.arch.sync_threads()
|
||||
# Start async loads for the first k-tile. Here we take care of the k residue
|
||||
# via if/else check along the k dimension. Because we shifted the identity tensor
|
||||
# by the residue_k and because the identity tensor is a counting tensor, the
|
||||
# values of any identity tensor element that is poison is less than -1
|
||||
num_smem_stages = cute.size(tAsA, mode=[3])
|
||||
k_tile_count = cute.size(tAgA, mode=[3])
|
||||
k_tile_index = cutlass.Int32(0)
|
||||
|
||||
for k in range(tApA.shape[2]):
|
||||
if cute.elem_less(cutlass.Int32(-1), tAcA[0, 0, k, 0][1]):
|
||||
cute.copy(
|
||||
tiled_copy_A,
|
||||
tAgA[None, None, k, k_tile_index],
|
||||
tAsA[None, None, k, 0],
|
||||
pred=tApA[None, None, k],
|
||||
)
|
||||
for k in range(tBpB.shape[2]):
|
||||
if cute.elem_less(cutlass.Int32(-1), tBcB[0, 0, k, 0][1]):
|
||||
cute.copy(
|
||||
tiled_copy_B,
|
||||
tBgB[None, None, k, k_tile_index],
|
||||
tBsB[None, None, k, 0],
|
||||
pred=tBpB[None, None, k],
|
||||
)
|
||||
k_tile_index = k_tile_index + 1
|
||||
cute.arch.cp_async_commit_group()
|
||||
|
||||
# Start async loads for rest of the k-tiles
|
||||
for k_tile in range(1, num_smem_stages - 1):
|
||||
if k_tile == k_tile_count:
|
||||
tApA.fill(0)
|
||||
tBpB.fill(0)
|
||||
cute.copy(
|
||||
tiled_copy_A,
|
||||
tAgA[None, None, None, k_tile_index],
|
||||
tAsA[None, None, None, k_tile],
|
||||
pred=tApA,
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Get the appropriate tiles for this thread block.
|
||||
# gA: (BLK_M, BLK_N, k), gB: (BLK_N, BLK_K, k), gC: (BLK_M, BLK_N)
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
gA = cute.local_tile(
|
||||
mA[None, None, bidz],
|
||||
tiler=self.cta_tiler,
|
||||
coord=tiler_coord,
|
||||
proj=(1, None, 1),
|
||||
)
|
||||
cute.copy(
|
||||
tiled_copy_B,
|
||||
tBgB[None, None, None, k_tile_index],
|
||||
tBsB[None, None, None, k_tile],
|
||||
pred=tBpB,
|
||||
gB = cute.local_tile(
|
||||
mB[None, None, bidz],
|
||||
tiler=self.cta_tiler,
|
||||
coord=tiler_coord,
|
||||
proj=(None, 1, 1),
|
||||
)
|
||||
gC = cute.local_tile(
|
||||
mC[None, None, bidz],
|
||||
tiler=self.cta_tiler,
|
||||
coord=tiler_coord,
|
||||
proj=(1, 1, None),
|
||||
)
|
||||
|
||||
# By default, if the tensor k mode does not divide into the tile k
|
||||
# size, then last tiles in the k dimension are irregular.
|
||||
# Instead, make the first tiles irregular when k is irregular.
|
||||
# This allows us to handle the irregular tile first to avoid
|
||||
# checking for this condition within the mainloop.
|
||||
|
||||
# residual_k is a negative number indicating the amount needed to
|
||||
# shift the pointer by in dimension k
|
||||
residual_k = cute.size(mA, mode=[1]) - cutlass.Int32(self.bK) * cute.size(
|
||||
gA, mode=[2]
|
||||
)
|
||||
|
||||
# move the pointer of gA/gB in the `-k` direction
|
||||
gA = cute.domain_offset((0, residual_k, 0), gA)
|
||||
gB = cute.domain_offset((0, residual_k, 0), gB)
|
||||
# input is 16B aligned
|
||||
gA = cute.make_tensor(gA.iterator.align(16), gA.layout)
|
||||
gB = cute.make_tensor(gB.iterator.align(16), gB.layout)
|
||||
|
||||
# Construct identity layout for sA and sB (mirrors global tensors,
|
||||
# used for predication only)
|
||||
mcA = cute.make_identity_tensor(mA.layout.shape)
|
||||
mcB = cute.make_identity_tensor(mB.layout.shape)
|
||||
cA = cute.local_tile(
|
||||
mcA[None, None, bidz],
|
||||
tiler=self.cta_tiler,
|
||||
coord=tiler_coord,
|
||||
proj=(1, None, 1),
|
||||
)
|
||||
cB = cute.local_tile(
|
||||
mcB[None, None, bidz],
|
||||
tiler=self.cta_tiler,
|
||||
coord=tiler_coord,
|
||||
proj=(None, 1, 1),
|
||||
)
|
||||
|
||||
cA = cute.domain_offset((0, residual_k, 0), cA)
|
||||
cB = cute.domain_offset((0, residual_k, 0), cB)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Create shared memory buffers and get the appropriate fragments for this thread.
|
||||
# sA: (BLK_M, BLK_K, PIPE) , sB: (BLK_N, BLK_K, PIPE)
|
||||
# tAgA: (CPY, CPY_M, CPY_K, k) , tBgB: (CPY, CPY_N, CPY_K, k)
|
||||
# tAsA: (CPY, CPY_M, CPY_K, PIPE) , tBsB: (CPY, CPY_N, CPY_K, PIPE)
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Shared memory buffer
|
||||
smem = cutlass.utils.SmemAllocator()
|
||||
|
||||
sA = smem.allocate_tensor(mA.element_type, sA_layout, 16)
|
||||
sB = smem.allocate_tensor(mB.element_type, sB_layout, 16)
|
||||
sC = cute.make_tensor(
|
||||
cute.recast_ptr(sA.iterator, dtype=self.c_dtype), sC_layout
|
||||
)
|
||||
|
||||
thr_copy_A = tiled_copy_A.get_slice(tidx)
|
||||
thr_copy_B = tiled_copy_B.get_slice(tidx)
|
||||
thr_copy_C = tiled_copy_C.get_slice(tidx)
|
||||
tAgA = thr_copy_A.partition_S(gA)
|
||||
tAsA = thr_copy_A.partition_D(sA)
|
||||
tBgB = thr_copy_B.partition_S(gB)
|
||||
tBsB = thr_copy_B.partition_D(sB)
|
||||
tCsC_epilogue = thr_copy_C.partition_S(sC)
|
||||
tCgC_epilogue = thr_copy_C.partition_D(gC)
|
||||
|
||||
# Repeat the partitioning with identity layouts
|
||||
tAcA = thr_copy_A.partition_S(cA)
|
||||
tBcB = thr_copy_B.partition_S(cB)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Predicate: Mark indices that need to copy when problem_shape isn't a multiple
|
||||
# of tile_shape
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
# For predication over the tensors A (M/K), B (N/K), and (in the
|
||||
# epilogue) C (M/N), we will compute it in a fashion similar to an
|
||||
# outer product. The predication along one of the dimensions is
|
||||
# evaluated and stored in a predication tensor. Then, the
|
||||
# predication for the remaining dimension is handled later via an
|
||||
# if/else branch at the copy.
|
||||
# For A and B, predication booleans along M/N are stored in a
|
||||
# predication tensor and along K is handled via a if/else branch.
|
||||
|
||||
# Allocate predicate tensors for M and N. Predication is checked
|
||||
# at the granularity of a copy atom, so the predicate tensor does not
|
||||
# need separate booleans for individual elements within a copy
|
||||
# atom (for example, the elements of tAgA.shape[0][0].)
|
||||
tApA = cute.make_fragment(
|
||||
cute.make_layout(
|
||||
(
|
||||
tAgA.shape[0][1],
|
||||
cute.size(tAgA, mode=[1]),
|
||||
cute.size(tAgA, mode=[2]),
|
||||
),
|
||||
stride=(cute.size(tAgA, mode=[1]), 1, 0),
|
||||
),
|
||||
cutlass.Boolean,
|
||||
)
|
||||
tBpB = cute.make_fragment(
|
||||
cute.make_layout(
|
||||
(
|
||||
tBsB.shape[0][1],
|
||||
cute.size(tBsB, mode=[1]),
|
||||
cute.size(tBsB, mode=[2]),
|
||||
),
|
||||
stride=(cute.size(tBsB, mode=[1]), 1, 0),
|
||||
),
|
||||
cutlass.Boolean,
|
||||
)
|
||||
# Set predicates for M/N bounds
|
||||
for rest_v in range(tApA.shape[0]):
|
||||
for m in range(tApA.shape[1]):
|
||||
tApA[rest_v, m, 0] = cute.elem_less(
|
||||
tAcA[(0, rest_v), m, 0, 0][0], mA.shape[0]
|
||||
)
|
||||
for rest_v in range(tBpB.shape[0]):
|
||||
for n in range(tBpB.shape[1]):
|
||||
tBpB[rest_v, n, 0] = cute.elem_less(
|
||||
tBcB[(0, rest_v), n, 0, 0][0], mB.shape[0]
|
||||
)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
# Prefetch Prologue
# ///////////////////////////////////////////////////////////////////////////////
# Clear the smem tiles to account for predicated-off loads
tAsA.fill(0)
tBsB.fill(0)
cute.arch.sync_threads()
# Start async loads for the first k-tile. Here we take care of the k residue
# via an if/else check along the k dimension. Because the identity tensor is a
# counting tensor and we shifted it by residue_k, any identity tensor element
# that is out of bounds (poison) has a negative value.
num_smem_stages = cute.size(tAsA, mode=[3])
k_tile_count = cute.size(tAgA, mode=[3])
k_tile_index = cutlass.Int32(0)

for k in range(tApA.shape[2]):
if cute.elem_less(cutlass.Int32(-1), tAcA[0, 0, k, 0][1]):
cute.copy(
tiled_copy_A,
tAgA[None, None, k, k_tile_index],
tAsA[None, None, k, 0],
pred=tApA[None, None, k],
)
for k in range(tBpB.shape[2]):
if cute.elem_less(cutlass.Int32(-1), tBcB[0, 0, k, 0][1]):
cute.copy(
tiled_copy_B,
tBgB[None, None, k, k_tile_index],
tBsB[None, None, k, 0],
pred=tBpB[None, None, k],
)
k_tile_index = k_tile_index + 1
cute.arch.cp_async_commit_group()

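# Intuition for the residue check above (illustrative): the identity tensor is a
# counting tensor that was shifted by residue_k, so k coordinates that fall
# outside the real problem extent map to negative values, and the
# cute.elem_less(cutlass.Int32(-1), ...) test skips those copy-atom columns.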
# ///////////////////////////////////////////////////////////////////////////////
# Tile MMA compute thread partitions and allocate accumulators
# ///////////////////////////////////////////////////////////////////////////////
thr_mma = tiled_mma.get_slice(tidx)
tCsA = thr_mma.partition_A(sA)
tCsB = thr_mma.partition_B(sB)
tCsC = thr_mma.partition_C(sC)
tCgC = thr_mma.partition_C(gC)
tCrA = tiled_mma.make_fragment_A(tCsA[None, None, None, 0])
tCrB = tiled_mma.make_fragment_B(tCsB[None, None, None, 0])
tCrC = tiled_mma.make_fragment_C(tCgC)
# Clear the accumulator
tCrC.fill(0.0)
# Start async loads for rest of the k-tiles
for k_tile in range(1, num_smem_stages - 1):
if k_tile == k_tile_count:
tApA.fill(0)
tBpB.fill(0)
cute.copy(
tiled_copy_A,
tAgA[None, None, None, k_tile_index],
tAsA[None, None, None, k_tile],
pred=tApA,
)
cute.copy(
tiled_copy_B,
tBgB[None, None, None, k_tile_index],
tBsB[None, None, None, k_tile],
pred=tBpB,
)
k_tile_index = k_tile_index + 1
cute.arch.cp_async_commit_group()

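# Once k_tile_index has advanced past the last k-tile (k_tile == k_tile_count),
# the predicates are cleared above so the remaining prologue copies are fully
# predicated off and load nothing.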
# ///////////////////////////////////////////////////////////////////////////////
# Copy Atom A/B retiling
# ///////////////////////////////////////////////////////////////////////////////
# Create the copy atoms for the copy from shared memory to register
atom_copy_s2r_A = cute.make_copy_atom(
cute.nvgpu.warp.LdMatrix8x8x16bOp(
self.a_major_mode != utils.LayoutEnum.ROW_MAJOR, 4
),
mA.element_type,
)
atom_copy_s2r_B = cute.make_copy_atom(
cute.nvgpu.warp.LdMatrix8x8x16bOp(
self.b_major_mode != utils.LayoutEnum.ROW_MAJOR, 4
),
mB.element_type,
)

# Creates the tiled copy so that it matches the thread-value layout
# expected by the tiled mma
tiled_copy_s2r_A = cute.make_tiled_copy(
atom_copy_s2r_A,
layout_tv=tiled_mma.tv_layout_A_tiled,
tiler_mn=(tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(2)),
)
tiled_copy_s2r_B = cute.make_tiled_copy(
atom_copy_s2r_B,
layout_tv=tiled_mma.tv_layout_B_tiled,
tiler_mn=(tiled_mma.get_tile_size(1), tiled_mma.get_tile_size(2)),
)

thr_copy_ldmatrix_A = tiled_copy_s2r_A.get_slice(tidx)
thr_copy_ldmatrix_B = tiled_copy_s2r_B.get_slice(tidx)
tCsA_copy_view = thr_copy_ldmatrix_A.partition_S(sA)
tCrA_copy_view = thr_copy_ldmatrix_A.retile(tCrA)
tCsB_copy_view = thr_copy_ldmatrix_B.partition_S(sB)
tCrB_copy_view = thr_copy_ldmatrix_B.retile(tCrB)

# Current pipe index in smem to read from / write to
smem_pipe_read = 0
smem_pipe_write = num_smem_stages - 1

tCsA_p = tCsA_copy_view[None, None, None, smem_pipe_read]
tCsB_p = tCsB_copy_view[None, None, None, smem_pipe_read]

# ///////////////////////////////////////////////////////////////////////////////
# PREFETCH register pipeline
# ///////////////////////////////////////////////////////////////////////////////
num_k_block = cute.size(tCrA, mode=[2])
if num_k_block > 1:
# Wait until our first prefetched tile is loaded in
cute.arch.cp_async_wait_group(num_smem_stages - 2)
cute.arch.sync_threads()
# Prefetch the first k-block rmem from the first k-tile
cute.copy(
tiled_copy_s2r_A,
tCsA_p[None, None, 0],
tCrA_copy_view[None, None, 0],
)
cute.copy(
tiled_copy_s2r_B,
tCsB_p[None, None, 0],
tCrB_copy_view[None, None, 0],
)

# ///////////////////////////////////////////////////////////////////////////////
# Mainloop
# 1. Shared memory pipeline (gmem -> smem):
# The default smem pipeline depth is 3, meaning that for shared
# memory buffers, we allocate three times the size described by the
# CTA tiler. We prefetch 2 of these buffers before entering the main
# loop. Considering only the transfer from global memory to shared
# memory, the general structure of the mainloop is:
# (1) copy k-tile from gmem to smem;
# (2) perform gemm computation on k-tile;
# (3) wait for the next copy to finish.
# The `cute.arch.cp_async_wait_group(num_smem_stages - 2)` command
# waits until the number of unfinished copies is <= 1. The advantage
# of this approach is that it allows simultaneous production
# (i.e., step (1)) and consumption (i.e., step (2)) of smem.
# A common misconception is to prefetch N buffers and rewrite
# the pipeline logic to wait on N-1 pending copies. The disadvantage
# of that approach is that a buffer must be fully consumed before an
# empty buffer opens up for the next copy.
# 2. Register pipeline (smem -> register):
# Similarly, the register pipeline produces block i+1 while consuming
# block i, then produces i+2, and so on. Notably, i and i+1 do not use
# the same registers, eliminating same-register dependencies for better
# parallelism.
# 3. Combining the smem and register pipelines results in the mainloop.
# ///////////////////////////////////////////////////////////////////////////////
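# Illustrative sketch (not from the source): with num_smem_stages = 3 the
# gmem -> smem pipeline behaves like a small ring buffer. Using hypothetical
# helpers issue_copy() / gemm_on() / wait_until_pending() purely for illustration:
#
#     for k_tile in range(k_tile_count):
#         issue_copy(stage=smem_pipe_write)         # step (1): produce a buffer
#         gemm_on(stage=smem_pipe_read)             # step (2): consume a buffer
#         wait_until_pending(num_smem_stages - 2)   # step (3): <= 1 copy in flight
#         smem_pipe_write = smem_pipe_read
#         smem_pipe_read = (smem_pipe_read + 1) % num_smem_stages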
for k_tile in range(k_tile_count):
for k_block in cutlass.range(num_k_block, unroll_full=True):
if k_block == num_k_block - 1:
tCsA_p = tCsA_copy_view[None, None, None, smem_pipe_read]
tCsB_p = tCsB_copy_view[None, None, None, smem_pipe_read]
cute.arch.cp_async_wait_group(num_smem_stages - 2)
cute.arch.sync_threads()

# Load A, B from shared memory to registers for k_block + 1
k_block_next = (k_block + 1) % num_k_block  # static
cute.copy(
tiled_copy_s2r_A,
tCsA_p[None, None, k_block_next],
tCrA_copy_view[None, None, k_block_next],
)
cute.copy(
tiled_copy_s2r_B,
tCsB_p[None, None, k_block_next],
tCrB_copy_view[None, None, k_block_next],
)

# Fetch next A: To better interleave global memory access and compute
# instructions, we intentionally use the sequence: copy A, perform GEMM,
# then copy B.
if k_block == 0:
if k_tile + num_smem_stages - 1 < k_tile_count:
cute.copy(
tiled_copy_A,
tAgA[None, None, None, k_tile_index],
tAsA[None, None, None, smem_pipe_write],
pred=tApA,
)

# Thread-level register gemm for k_block
cute.gemm(
tiled_mma,
tCrC,
tCrA[None, None, k_block],
tCrB[None, None, k_block],
tCrC,
)

# Fetch next B and update smem pipeline read/write
if k_block == 0:
if k_tile + num_smem_stages - 1 < k_tile_count:
cute.copy(
tiled_copy_B,
tBgB[None, None, None, k_tile_index],
tBsB[None, None, None, smem_pipe_write],
pred=tBpB,
)
k_tile_index = k_tile_index + 1
cute.arch.cp_async_commit_group()
smem_pipe_write = smem_pipe_read
smem_pipe_read = smem_pipe_read + 1
if smem_pipe_read == num_smem_stages:
smem_pipe_read = 0

# Sync before epilogue
cute.arch.cp_async_wait_group(0)
cute.arch.sync_threads()

# ///////////////////////////////////////////////////////////////////////////////
# Epilogue with fusion
# ///////////////////////////////////////////////////////////////////////////////
tCrD = cute.make_fragment_like(tCrC, self.c_dtype)
tCrD[None] = epilogue_op(tCrC.load()).to(self.c_dtype)

# Copy results of D back to shared memory
cute.autovec_copy(tCrD, tCsC)

# Create counting tensor for C
ceilM, ceilN, _ = cute.ceil_div(mC.shape, (self.bM, self.bN, 1))
mcC = cute.make_identity_tensor(
(
cute.size(ceilM) * self.cta_tiler[0],
cute.size(ceilN) * self.cta_tiler[1],
1,
)
)
cC = cute.local_tile(
mcC[None, None, bidz],
tiler=self.cta_tiler,
coord=tiler_coord,
proj=(1, 1, None),
)
tCcC = thr_copy_C.partition_S(cC)

tCrC_epilogue = cute.make_fragment_like(tCsC_epilogue)
# Wait for all writes to shared memory to finish before starting copies
# using the new layouts
cute.arch.sync_threads()
cute.autovec_copy(tCsC_epilogue, tCrC_epilogue)

# Create predication tensor for m
tCpC = cute.make_fragment(
cute.make_layout(
(
tCgC_epilogue.shape[0][1],
cute.size(tCgC_epilogue, mode=[1]),
cute.size(tCgC_epilogue, mode=[2]),
),
stride=(cute.size(tCgC_epilogue, mode=[1]), 1, 0),
),
cutlass.Boolean,
)
for rest_v in range(tCpC.shape[0]):
for m in range(tCpC.shape[1]):
tCpC[rest_v, m, 0] = cute.elem_less(
tCcC[(0, rest_v), m, 0][0], mC.shape[0]
)

# Copy to global memory using better vectorization
for rest_v in range(tCpC.shape[0]):
for n in range(tCpC.shape[2]):
if cute.elem_less(tCcC[(0, rest_v), 0, n][1], mC.shape[1]):
cute.copy(
tiled_copy_C,
tCrC_epilogue[None, None, n],
tCgC_epilogue[None, None, n],
pred=tCpC[None, None, n],
)
return

def _make_smem_layout_AB(self, dtype, major_mode, copy_bits, smem_tiler):
@ -811,6 +839,11 @@ class TensorOpGemm:
tiler_mn, layout_tv = cute.make_layout_tv(thread_layout, value_layout)
return cute.make_tiled_copy(atom_copy, layout_tv, tiler_mn)

def raster_tile(self, i, j, f):
new_i = i // f
new_j = (i % f) + (j * f)
return (new_i, new_j)

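# Illustrative example of the tile rastering above (the mapping is computed from
# the code; the L2-reuse motivation is an assumption): with f = 2 and j = 0,
#
#     [self.raster_tile(i, 0, 2) for i in range(4)]  ->  [(0, 0), (0, 1), (1, 0), (1, 1)]
#
# so consecutive linear tile indices are regrouped into 2-wide column groups,
# which can improve L2 reuse for neighboring CTAs.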
def run_tensor_op_gemm(
a_major: str,
@ -892,15 +925,18 @@ def run_tensor_op_gemm(

print("Executing GEMM kernel...")

avg_time_us = testing.benchmark(
gemm,
kernel_arguments=testing.JitArguments(a_tensor, b_tensor, c_tensor),
warmup_iterations=warmup_iterations,
profiling_iterations=iterations,
use_cuda_graphs=False,
)

print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms")

if not skip_ref_check:
gemm(a_tensor, b_tensor, c_tensor)
print("Verifying results...")
torch.testing.assert_close(c.cpu(), ref.cpu(), atol=1e-03, rtol=1e-05)
print("Results verified successfully!")

@ -35,6 +35,7 @@ import torch
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.utils as utils
|
||||
import cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
import cutlass.torch as cutlass_torch
|
||||
import cutlass.utils.blackwell_helpers as sm100_utils
|
||||
@ -211,7 +212,7 @@ class DenseGemmKernel:
|
||||
|
||||
self.occupancy = 1
|
||||
self.threads_per_cta = 128
|
||||
self.num_smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
|
||||
self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
|
||||
|
||||
def _setup_attributes(self):
|
||||
"""Set up configurations that are dependent on GEMM inputs
|
||||
@ -283,7 +284,7 @@ class DenseGemmKernel:
|
||||
self.epi_tile,
|
||||
self.c_dtype,
|
||||
self.c_layout,
|
||||
self.num_smem_capacity,
|
||||
self.smem_capacity,
|
||||
self.occupancy,
|
||||
self.use_tma_store,
|
||||
)
|
||||
@ -308,7 +309,7 @@ class DenseGemmKernel:
|
||||
self.epi_tile,
|
||||
self.num_c_stage,
|
||||
)
|
||||
if cutlass.const_expr(self.use_tma_store)
|
||||
if self.use_tma_store
|
||||
else None
|
||||
)
|
||||
|
||||
@ -372,9 +373,11 @@ class DenseGemmKernel:
|
||||
atom_thr_size = cute.size(tiled_mma.thr_id.shape)
|
||||
|
||||
# Setup TMA load for A
|
||||
a_op = self._get_tma_atom_kind(atom_thr_size, self.is_a_mcast)
|
||||
a_op = sm100_utils.cluster_shape_to_tma_atom_A(
|
||||
self.cluster_shape_mn, tiled_mma.thr_id
|
||||
)
|
||||
a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
|
||||
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tma_tile_atom_A(
|
||||
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
a_op,
|
||||
a,
|
||||
a_smem_layout,
|
||||
@ -387,9 +390,11 @@ class DenseGemmKernel:
|
||||
)
|
||||
|
||||
# Setup TMA load for B
|
||||
b_op = self._get_tma_atom_kind(atom_thr_size, self.is_b_mcast)
|
||||
b_op = sm100_utils.cluster_shape_to_tma_atom_B(
|
||||
self.cluster_shape_mn, tiled_mma.thr_id
|
||||
)
|
||||
b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
|
||||
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tma_tile_atom_B(
|
||||
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
b_op,
|
||||
b,
|
||||
b_smem_layout,
|
||||
@ -413,7 +418,7 @@ class DenseGemmKernel:
|
||||
cute.make_identity_layout(c.shape), self.epi_tile
|
||||
)
|
||||
epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0))
|
||||
tma_atom_c, tma_tensor_c = cpasync.make_tma_tile_atom(
|
||||
tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
|
||||
cpasync.CopyBulkTensorTileS2GOp(),
|
||||
c,
|
||||
epi_smem_layout,
|
||||
@ -426,9 +431,7 @@ class DenseGemmKernel:
|
||||
self.buffer_align_bytes = 1024
|
||||
|
||||
c_smem_size = (
|
||||
cute.cosize(self.c_smem_layout_staged.outer)
|
||||
if cutlass.const_expr(self.use_tma_store)
|
||||
else 0
|
||||
cute.cosize(self.c_smem_layout_staged.outer) if self.use_tma_store else 0
|
||||
)
|
||||
|
||||
# Define shared storage for kernel
|
||||
@ -472,7 +475,7 @@ class DenseGemmKernel:
|
||||
tma_atom_b,
|
||||
tma_tensor_b,
|
||||
tma_atom_c,
|
||||
tma_tensor_c if cutlass.const_expr(self.use_tma_store) else c,
|
||||
tma_tensor_c if self.use_tma_store else c,
|
||||
self.cluster_layout_vmnk,
|
||||
self.a_smem_layout_staged,
|
||||
self.b_smem_layout_staged,
|
||||
@ -556,12 +559,12 @@ class DenseGemmKernel:
|
||||
tmem_holding_buf = storage.tmem_holding_buf
|
||||
|
||||
# Initialize mainloop ab_pipeline (barrier) and states
|
||||
ab_pipeline_producer_group = utils.CooperativeGroup(utils.Agent.Thread)
|
||||
ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
|
||||
num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
|
||||
ab_pipeline_consumer_group = utils.CooperativeGroup(
|
||||
utils.Agent.Thread, num_tma_producer
|
||||
ab_pipeline_consumer_group = pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, num_tma_producer
|
||||
)
|
||||
ab_pipeline = utils.PipelineTmaUmma.create(
|
||||
ab_pipeline = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=storage.ab_full_mbar_ptr.data_ptr(),
|
||||
num_stages=self.num_ab_stage,
|
||||
producer_group=ab_pipeline_producer_group,
|
||||
@ -569,30 +572,30 @@ class DenseGemmKernel:
|
||||
tx_count=self.num_tma_load_bytes,
|
||||
cta_layout_vmnk=cluster_layout_vmnk,
|
||||
)
|
||||
ab_producer_state = utils.make_pipeline_state(
|
||||
utils.PipelineUserType.Producer, self.num_ab_stage
|
||||
ab_producer_state = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.Producer, self.num_ab_stage
|
||||
)
|
||||
ab_consumer_state = utils.make_pipeline_state(
|
||||
utils.PipelineUserType.Consumer, self.num_ab_stage
|
||||
ab_consumer_state = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.Consumer, self.num_ab_stage
|
||||
)
|
||||
|
||||
# Initialize acc_pipeline (barrier) and states
|
||||
acc_pipeline_producer_group = utils.CooperativeGroup(utils.Agent.Thread)
|
||||
acc_pipeline_consumer_group = utils.CooperativeGroup(
|
||||
utils.Agent.Thread, self.threads_per_cta, self.threads_per_cta
|
||||
acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
|
||||
acc_pipeline_consumer_group = pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, self.threads_per_cta, self.threads_per_cta
|
||||
)
|
||||
acc_pipeline = utils.PipelineUmmaAsync.create(
|
||||
acc_pipeline = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=storage.acc_full_mbar_ptr.data_ptr(),
|
||||
num_stages=self.num_acc_stage,
|
||||
producer_group=acc_pipeline_producer_group,
|
||||
consumer_group=acc_pipeline_consumer_group,
|
||||
cta_layout_vmnk=cluster_layout_vmnk,
|
||||
)
|
||||
acc_producer_state = utils.make_pipeline_state(
|
||||
utils.PipelineUserType.Producer, self.num_acc_stage
|
||||
acc_producer_state = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.Producer, self.num_acc_stage
|
||||
)
|
||||
acc_consumer_state = utils.make_pipeline_state(
|
||||
utils.PipelineUserType.Consumer, self.num_acc_stage
|
||||
acc_consumer_state = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.Consumer, self.num_acc_stage
|
||||
)
|
||||
|
||||
# Tensor memory dealloc barrier init
|
||||
@ -600,7 +603,7 @@ class DenseGemmKernel:
|
||||
if warp_idx == 0:
|
||||
num_tmem_dealloc_threads = 32
|
||||
with cute.arch.elect_one():
|
||||
cute.arch.mbarrier_init_arrive_cnt(
|
||||
cute.arch.mbarrier_init(
|
||||
tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads
|
||||
)
|
||||
cute.arch.mbarrier_init_fence()
|
||||
@ -617,7 +620,7 @@ class DenseGemmKernel:
|
||||
storage.sC.get_tensor(
|
||||
c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner
|
||||
)
|
||||
if cutlass.const_expr(self.use_tma_store)
|
||||
if self.use_tma_store
|
||||
else None
|
||||
)
|
||||
# (MMA, MMA_M, MMA_K, STAGE)
|
||||
@ -634,7 +637,7 @@ class DenseGemmKernel:
|
||||
#
|
||||
a_full_mcast_mask = None
|
||||
b_full_mcast_mask = None
|
||||
if self.is_a_mcast or self.is_b_mcast or use_2cta_instrs:
|
||||
if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs):
|
||||
a_full_mcast_mask = cpasync.create_tma_multicast_mask(
|
||||
cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2
|
||||
)
|
||||
@ -645,15 +648,15 @@ class DenseGemmKernel:
|
||||
#
|
||||
# Local_tile partition global tensors
|
||||
#
|
||||
# (bM, bK, loopM, loopK, loopL)
|
||||
# (bM, bK, RestM, RestK, RestL)
|
||||
gA_mkl = cute.local_tile(
|
||||
mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None)
|
||||
)
|
||||
# (bN, bK, loopN, loopK, loopL)
|
||||
# (bN, bK, RestN, RestK, RestL)
|
||||
gB_nkl = cute.local_tile(
|
||||
mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None)
|
||||
)
|
||||
# (bM, bN, loopM, loopN, loopL)
|
||||
# (bM, bN, RestM, RestN, RestL)
|
||||
gC_mnl = cute.local_tile(
|
||||
mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None)
|
||||
)
|
||||
@ -663,11 +666,11 @@ class DenseGemmKernel:
|
||||
# Partition global tensor for TiledMMA_A/B/C
|
||||
#
|
||||
thr_mma = tiled_mma.get_slice(mma_tile_coord_v)
|
||||
# (MMA, MMA_M, MMA_K, loopM, loopK, loopL)
|
||||
# (MMA, MMA_M, MMA_K, RestM, RestK, RestL)
|
||||
tCgA = thr_mma.partition_A(gA_mkl)
|
||||
# (MMA, MMA_N, MMA_K, loopN, loopK, loopL)
|
||||
# (MMA, MMA_N, MMA_K, RestN, RestK, RestL)
|
||||
tCgB = thr_mma.partition_B(gB_nkl)
|
||||
# (MMA, MMA_M, MMA_N, loopM, loopN, loopL)
|
||||
# (MMA, MMA_M, MMA_N, RestM, RestN, RestL)
|
||||
tCgC = thr_mma.partition_C(gC_mnl)
|
||||
|
||||
#
|
||||
@ -678,7 +681,7 @@ class DenseGemmKernel:
|
||||
cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape
|
||||
)
|
||||
# ((atom_v, rest_v), STAGE)
|
||||
# ((atom_v, rest_v), loopM, loopK, loopL)
|
||||
# ((atom_v, rest_v), RestM, RestK, RestL)
|
||||
tAsA, tAgA = cpasync.tma_partition(
|
||||
tma_atom_a,
|
||||
block_in_cluster_coord_vmnk[2],
|
||||
@ -691,7 +694,7 @@ class DenseGemmKernel:
|
||||
cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape
|
||||
)
|
||||
# ((atom_v, rest_v), STAGE)
|
||||
# ((atom_v, rest_v), loopN, loopK, loopL)
|
||||
# ((atom_v, rest_v), RestN, RestK, RestL)
|
||||
tBsB, tBgB = cpasync.tma_partition(
|
||||
tma_atom_b,
|
||||
block_in_cluster_coord_vmnk[1],
|
||||
@ -771,9 +774,9 @@ class DenseGemmKernel:
|
||||
#
|
||||
# Slice to per mma tile index
|
||||
#
|
||||
# ((atom_v, rest_v), loopK)
|
||||
# ((atom_v, rest_v), RestK)
|
||||
tAgA = tAgA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])]
|
||||
# ((atom_v, rest_v), loopK)
|
||||
# ((atom_v, rest_v), RestK)
|
||||
tBgB = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])]
|
||||
if cutlass.const_expr(self.use_tma_store):
|
||||
# ((ATOM_V, REST_V), EPI_M, EPI_N)
|
||||
@ -797,7 +800,7 @@ class DenseGemmKernel:
|
||||
#
|
||||
# Prefetch TMA load A/B
|
||||
#
|
||||
for prefetch_idx in cutlass.range_dynamic(prefetch_k_block_cnt, unroll=1):
|
||||
for prefetch_idx in cutlass.range(prefetch_k_block_cnt, unroll=1):
|
||||
# Conditionally wait for AB buffer empty
|
||||
ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status)
|
||||
|
||||
@ -833,7 +836,7 @@ class DenseGemmKernel:
|
||||
#
|
||||
# MMA mainloop
|
||||
#
|
||||
for k_block in cutlass.range_dynamic(0, k_block_cnt, 1, unroll=1):
|
||||
for k_block in range(k_block_cnt):
|
||||
# Conditionally wait for AB buffer empty
|
||||
ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status)
|
||||
|
||||
@ -860,7 +863,7 @@ class DenseGemmKernel:
|
||||
|
||||
# tCtAcc += tCrA * tCrB
|
||||
num_kphases = cute.size(tCrA, mode=[2])
|
||||
for kphase_idx in range(num_kphases):
|
||||
for kphase_idx in cutlass.range(num_kphases, unroll_full=True):
|
||||
kphase_coord = (None, None, kphase_idx, ab_consumer_state.index)
|
||||
|
||||
cute.gemm(
|
||||
@ -917,10 +920,10 @@ class DenseGemmKernel:
|
||||
c_pipeline = None
|
||||
if cutlass.const_expr(self.use_tma_store):
|
||||
# Initialize tma store c_pipeline
|
||||
c_producer_group = utils.CooperativeGroup(
|
||||
utils.Agent.Thread, self.threads_per_cta, self.threads_per_cta
|
||||
c_producer_group = pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, self.threads_per_cta, self.threads_per_cta
|
||||
)
|
||||
c_pipeline = utils.PipelineTmaStore.create(
|
||||
c_pipeline = pipeline.PipelineTmaStore.create(
|
||||
num_stages=self.num_c_stage,
|
||||
producer_group=c_producer_group,
|
||||
)
|
||||
@ -929,7 +932,7 @@ class DenseGemmKernel:
|
||||
# Store accumulator to global memory in subtiles
|
||||
#
|
||||
subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])
|
||||
for subtile_idx in cutlass.range_dynamic(subtile_cnt):
|
||||
for subtile_idx in range(subtile_cnt):
|
||||
#
|
||||
# Load accumulator from tensor memory buffer to register
|
||||
#
|
||||
@ -1007,7 +1010,7 @@ class DenseGemmKernel:
|
||||
#
|
||||
if warp_idx == 0:
|
||||
# Reverse prefetch_k_block_cnt times to next available buffer
|
||||
for i in cutlass.range_dynamic(prefetch_k_block_cnt):
|
||||
for i in range(prefetch_k_block_cnt):
|
||||
ab_producer_state.reverse()
|
||||
ab_pipeline.producer_tail(ab_producer_state)
|
||||
return
|
||||
@ -1063,11 +1066,11 @@ class DenseGemmKernel:
|
||||
# (T2R, T2R_M, T2R_N, EPI_M, EPI_M)
|
||||
tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi)
|
||||
|
||||
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL)
|
||||
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL)
|
||||
gC_mnl_epi = cute.flat_divide(
|
||||
gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile
|
||||
)
|
||||
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL)
|
||||
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL)
|
||||
tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi)
|
||||
# (T2R, T2R_M, T2R_N)
|
||||
tTR_rAcc = cute.make_fragment(
|
||||
@ -1149,7 +1152,7 @@ class DenseGemmKernel:
|
||||
- tTR_gC: The partitioned global tensor C
|
||||
:rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]
|
||||
"""
|
||||
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL)
|
||||
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL)
|
||||
gC_epi = cute.flat_divide(
|
||||
gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile
|
||||
)
|
||||
@ -1158,7 +1161,7 @@ class DenseGemmKernel:
|
||||
sC_for_tma_partition = cute.group_modes(sC, 0, 2)
|
||||
gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2)
|
||||
# ((ATOM_V, REST_V), EPI_M, EPI_N)
|
||||
# ((ATOM_V, REST_V), EPI_M, EPI_N, loopM, loopN, loopL)
|
||||
# ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL)
|
||||
bSG_sC, bSG_gC = cpasync.tma_partition(
|
||||
tma_atom_c,
|
||||
0,
|
||||
@ -1169,7 +1172,7 @@ class DenseGemmKernel:
|
||||
return tma_atom_c, bSG_sC, bSG_gC
|
||||
else:
|
||||
tiled_copy_t2r = atom
|
||||
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL)
|
||||
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL)
|
||||
thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
|
||||
tTR_gC = thr_copy_t2r.partition_D(gC_epi)
|
||||
# (T2R, T2R_M, T2R_N)
|
||||
@ -1188,7 +1191,7 @@ class DenseGemmKernel:
|
||||
epi_tile: cute.Tile,
|
||||
c_dtype: Type[cutlass.Numeric],
|
||||
c_layout: utils.LayoutEnum,
|
||||
num_smem_capacity: int,
|
||||
smem_capacity: int,
|
||||
occupancy: int,
|
||||
use_tma_store: bool,
|
||||
) -> Tuple[int, int, int]:
|
||||
@ -1208,8 +1211,8 @@ class DenseGemmKernel:
|
||||
:type c_dtype: type[cutlass.Numeric]
|
||||
:param c_layout: Layout enum of operand C in global memory.
|
||||
:type c_layout: utils.LayoutEnum
|
||||
:param num_smem_capacity: Total available shared memory capacity in bytes.
|
||||
:type num_smem_capacity: int
|
||||
:param smem_capacity: Total available shared memory capacity in bytes.
|
||||
:type smem_capacity: int
|
||||
:param occupancy: Target number of CTAs per SM (occupancy).
|
||||
:type occupancy: int
|
||||
:param use_tma_store: Whether TMA store is enabled.
|
||||
@ -1263,7 +1266,7 @@ class DenseGemmKernel:
|
||||
# Subtract reserved bytes and initial C stages bytes
|
||||
# Divide remaining by bytes needed per A/B stage
|
||||
num_ab_stage = (
|
||||
num_smem_capacity - (occupancy + 1) * (mbar_helpers_bytes + c_bytes)
|
||||
smem_capacity - (occupancy + 1) * (mbar_helpers_bytes + c_bytes)
|
||||
) // ab_bytes_per_stage
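# Worked example with purely hypothetical numbers: smem_capacity = 228 * 1024,
# occupancy = 1, mbar_helpers_bytes = 1024, c_bytes = 16384 and
# ab_bytes_per_stage = 32768 give
# (233472 - 2 * (1024 + 16384)) // 32768 = 6 A/B stages.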
|
||||
|
||||
# Refine epilogue stages:
|
||||
@ -1271,7 +1274,7 @@ class DenseGemmKernel:
|
||||
# Add remaining unused smem to epilogue
|
||||
if use_tma_store:
|
||||
num_c_stage += (
|
||||
num_smem_capacity
|
||||
smem_capacity
|
||||
- ab_bytes_per_stage * num_ab_stage
|
||||
- (occupancy + 1) * (mbar_helpers_bytes + c_bytes)
|
||||
) // ((occupancy + 1) * c_bytes_per_stage)
|
||||
@ -1309,36 +1312,6 @@ class DenseGemmKernel:
|
||||
|
||||
return grid
|
||||
|
||||
@staticmethod
|
||||
def _get_tma_atom_kind(
|
||||
atom_sm_cnt: cutlass.Int32, mcast: cutlass.Boolean
|
||||
) -> Union[
|
||||
cpasync.CopyBulkTensorTileG2SMulticastOp, cpasync.CopyBulkTensorTileG2SOp
|
||||
]:
|
||||
"""
|
||||
Select the appropriate TMA copy atom based on the number of SMs and the multicast flag.
|
||||
|
||||
:param atom_sm_cnt: The number of SMs
|
||||
:type atom_sm_cnt: cutlass.Int32
|
||||
:param mcast: The multicast flag
|
||||
:type mcast: cutlass.Boolean
|
||||
|
||||
:return: The appropriate TMA copy atom kind
|
||||
:rtype: cpasync.CopyBulkTensorTileG2SMulticastOp or cpasync.CopyBulkTensorTileG2SOp
|
||||
|
||||
:raise ValueError: If the atom_sm_cnt is invalid
|
||||
"""
|
||||
if atom_sm_cnt == 2 and mcast:
|
||||
return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.TWO)
|
||||
elif atom_sm_cnt == 2 and not mcast:
|
||||
return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.TWO)
|
||||
elif atom_sm_cnt == 1 and mcast:
|
||||
return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.ONE)
|
||||
elif atom_sm_cnt == 1 and not mcast:
|
||||
return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE)
|
||||
|
||||
raise ValueError(f"Invalid atom_sm_cnt: {atom_sm_cnt} and {mcast}")
|
||||
|
||||
@staticmethod
|
||||
def _compute_num_tmem_alloc_cols(
|
||||
tiled_mma: cute.TiledMma, mma_tiler: Tuple[int, int, int]
|
||||
|
||||
@ -37,6 +37,7 @@ import cutlass.cute as cute
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
import cutlass.torch as cutlass_torch
|
||||
import cutlass.utils as utils
|
||||
import cutlass.pipeline as pipeline
|
||||
import cutlass.utils.blackwell_helpers as sm100_utils
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
|
||||
@ -225,7 +226,7 @@ class PersistentDenseGemmKernel:
|
||||
self.cta_sync_bar_id = 0
|
||||
self.epilog_sync_bar_id = 1
|
||||
self.tmem_ptr_sync_bar_id = 2
|
||||
self.num_smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
|
||||
self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
|
||||
|
||||
def _setup_attributes(self):
|
||||
"""Set up configurations that are dependent on GEMM inputs
|
||||
@ -297,7 +298,7 @@ class PersistentDenseGemmKernel:
|
||||
self.epi_tile,
|
||||
self.c_dtype,
|
||||
self.c_layout,
|
||||
self.num_smem_capacity,
|
||||
self.smem_capacity,
|
||||
self.occupancy,
|
||||
self.use_tma_store,
|
||||
)
|
||||
@ -389,9 +390,11 @@ class PersistentDenseGemmKernel:
|
||||
atom_thr_size = cute.size(tiled_mma.thr_id.shape)
|
||||
|
||||
# Setup TMA load for A
|
||||
a_op = self._get_tma_atom_kind(atom_thr_size, self.is_a_mcast)
|
||||
a_op = sm100_utils.cluster_shape_to_tma_atom_A(
|
||||
self.cluster_shape_mn, tiled_mma.thr_id
|
||||
)
|
||||
a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
|
||||
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tma_tile_atom_A(
|
||||
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
a_op,
|
||||
a,
|
||||
a_smem_layout,
|
||||
@ -404,9 +407,11 @@ class PersistentDenseGemmKernel:
|
||||
)
|
||||
|
||||
# Setup TMA load for B
|
||||
b_op = self._get_tma_atom_kind(atom_thr_size, self.is_b_mcast)
|
||||
b_op = sm100_utils.cluster_shape_to_tma_atom_B(
|
||||
self.cluster_shape_mn, tiled_mma.thr_id
|
||||
)
|
||||
b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
|
||||
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tma_tile_atom_B(
|
||||
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
b_op,
|
||||
b,
|
||||
b_smem_layout,
|
||||
@ -430,7 +435,7 @@ class PersistentDenseGemmKernel:
|
||||
cute.make_identity_layout(c.shape), self.epi_tile
|
||||
)
|
||||
epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0))
|
||||
tma_atom_c, tma_tensor_c = cpasync.make_tma_tile_atom(
|
||||
tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
|
||||
cpasync.CopyBulkTensorTileS2GOp(),
|
||||
c,
|
||||
epi_smem_layout,
|
||||
@ -571,12 +576,12 @@ class PersistentDenseGemmKernel:
|
||||
tmem_holding_buf = storage.tmem_holding_buf
|
||||
|
||||
# Initialize mainloop ab_pipeline (barrier) and states
|
||||
ab_pipeline_producer_group = utils.CooperativeGroup(utils.Agent.Thread)
|
||||
ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
|
||||
num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
|
||||
ab_pipeline_consumer_group = utils.CooperativeGroup(
|
||||
utils.Agent.Thread, num_tma_producer
|
||||
ab_pipeline_consumer_group = pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, num_tma_producer
|
||||
)
|
||||
ab_pipeline = utils.PipelineTmaUmma.create(
|
||||
ab_pipeline = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=storage.ab_full_mbar_ptr.data_ptr(),
|
||||
num_stages=self.num_ab_stage,
|
||||
producer_group=ab_pipeline_producer_group,
|
||||
@ -586,14 +591,14 @@ class PersistentDenseGemmKernel:
|
||||
)
|
||||
|
||||
# Initialize acc_pipeline (barrier) and states
|
||||
acc_pipeline_producer_group = utils.CooperativeGroup(utils.Agent.Thread)
|
||||
acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
|
||||
num_acc_consumer_threads = len(self.epilog_warp_id) * (
|
||||
2 if use_2cta_instrs else 1
|
||||
)
|
||||
acc_pipeline_consumer_group = utils.CooperativeGroup(
|
||||
utils.Agent.Thread, num_acc_consumer_threads
|
||||
acc_pipeline_consumer_group = pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, num_acc_consumer_threads
|
||||
)
|
||||
acc_pipeline = utils.PipelineUmmaAsync.create(
|
||||
acc_pipeline = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=storage.acc_full_mbar_ptr.data_ptr(),
|
||||
num_stages=self.num_acc_stage,
|
||||
producer_group=acc_pipeline_producer_group,
|
||||
@ -606,7 +611,7 @@ class PersistentDenseGemmKernel:
|
||||
if warp_idx == self.tma_warp_id:
|
||||
num_tmem_dealloc_threads = 32
|
||||
with cute.arch.elect_one():
|
||||
cute.arch.mbarrier_init_arrive_cnt(
|
||||
cute.arch.mbarrier_init(
|
||||
tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads
|
||||
)
|
||||
cute.arch.mbarrier_init_fence()
|
||||
@ -640,7 +645,7 @@ class PersistentDenseGemmKernel:
|
||||
#
|
||||
a_full_mcast_mask = None
|
||||
b_full_mcast_mask = None
|
||||
if self.is_a_mcast or self.is_b_mcast or use_2cta_instrs:
|
||||
if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs):
|
||||
a_full_mcast_mask = cpasync.create_tma_multicast_mask(
|
||||
cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2
|
||||
)
|
||||
@ -651,15 +656,15 @@ class PersistentDenseGemmKernel:
|
||||
#
|
||||
# Local_tile partition global tensors
|
||||
#
|
||||
# (bM, bK, loopM, loopK, loopL)
|
||||
# (bM, bK, RestM, RestK, RestL)
|
||||
gA_mkl = cute.local_tile(
|
||||
mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None)
|
||||
)
|
||||
# (bN, bK, loopN, loopK, loopL)
|
||||
# (bN, bK, RestN, RestK, RestL)
|
||||
gB_nkl = cute.local_tile(
|
||||
mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None)
|
||||
)
|
||||
# (bM, bN, loopM, loopN, loopL)
|
||||
# (bM, bN, RestM, RestN, RestL)
|
||||
gC_mnl = cute.local_tile(
|
||||
mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None)
|
||||
)
|
||||
@ -669,11 +674,11 @@ class PersistentDenseGemmKernel:
|
||||
# Partition global tensor for TiledMMA_A/B/C
|
||||
#
|
||||
thr_mma = tiled_mma.get_slice(mma_tile_coord_v)
|
||||
# (MMA, MMA_M, MMA_K, loopM, loopK, loopL)
|
||||
# (MMA, MMA_M, MMA_K, RestM, RestK, RestL)
|
||||
tCgA = thr_mma.partition_A(gA_mkl)
|
||||
# (MMA, MMA_N, MMA_K, loopN, loopK, loopL)
|
||||
# (MMA, MMA_N, MMA_K, RestN, RestK, RestL)
|
||||
tCgB = thr_mma.partition_B(gB_nkl)
|
||||
# (MMA, MMA_M, MMA_N, loopM, loopN, loopL)
|
||||
# (MMA, MMA_M, MMA_N, RestM, RestN, RestL)
|
||||
tCgC = thr_mma.partition_C(gC_mnl)
|
||||
|
||||
#
|
||||
@ -684,7 +689,7 @@ class PersistentDenseGemmKernel:
|
||||
cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape
|
||||
)
|
||||
# ((atom_v, rest_v), STAGE)
|
||||
# ((atom_v, rest_v), loopM, loopK, loopL)
|
||||
# ((atom_v, rest_v), RestM, RestK, RestL)
|
||||
tAsA, tAgA = cpasync.tma_partition(
|
||||
tma_atom_a,
|
||||
block_in_cluster_coord_vmnk[2],
|
||||
@ -697,7 +702,7 @@ class PersistentDenseGemmKernel:
|
||||
cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape
|
||||
)
|
||||
# ((atom_v, rest_v), STAGE)
|
||||
# ((atom_v, rest_v), loopM, loopK, loopL)
|
||||
# ((atom_v, rest_v), RestM, RestK, RestL)
|
||||
tBsB, tBgB = cpasync.tma_partition(
|
||||
tma_atom_b,
|
||||
block_in_cluster_coord_vmnk[1],
|
||||
@ -743,12 +748,11 @@ class PersistentDenseGemmKernel:
|
||||
)
|
||||
work_tile = tile_sched.initial_work_tile_info()
|
||||
|
||||
ab_producer_state = utils.make_pipeline_state(
|
||||
utils.PipelineUserType.Producer, self.num_ab_stage
|
||||
ab_producer_state = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.Producer, self.num_ab_stage
|
||||
)
|
||||
|
||||
while work_tile.is_valid_tile:
|
||||
|
||||
# Get tile coord from tile scheduler
|
||||
cur_tile_coord = work_tile.tile_idx
|
||||
mma_tile_coord_mnl = (
|
||||
@ -760,11 +764,11 @@ class PersistentDenseGemmKernel:
|
||||
#
|
||||
# Slice to per mma tile index
|
||||
#
|
||||
# ((atom_v, rest_v), loopK)
|
||||
# ((atom_v, rest_v), RestK)
|
||||
tAgA_slice = tAgA[
|
||||
(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])
|
||||
]
|
||||
# ((atom_v, rest_v), loopK)
|
||||
# ((atom_v, rest_v), RestK)
|
||||
tBgB_slice = tBgB[
|
||||
(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])
|
||||
]
|
||||
@ -779,7 +783,7 @@ class PersistentDenseGemmKernel:
|
||||
#
|
||||
# Tma load loop
|
||||
#
|
||||
for k_block in cutlass.range_dynamic(0, k_block_cnt, 1, unroll=1):
|
||||
for k_block in cutlass.range(0, k_block_cnt, 1, unroll=1):
|
||||
# Conditionally wait for AB buffer empty
|
||||
ab_pipeline.producer_acquire(
|
||||
ab_producer_state, peek_ab_empty_status
|
||||
@ -852,15 +856,14 @@ class PersistentDenseGemmKernel:
|
||||
)
|
||||
work_tile = tile_sched.initial_work_tile_info()
|
||||
|
||||
ab_consumer_state = utils.make_pipeline_state(
|
||||
utils.PipelineUserType.Consumer, self.num_ab_stage
|
||||
ab_consumer_state = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.Consumer, self.num_ab_stage
|
||||
)
|
||||
acc_producer_state = utils.make_pipeline_state(
|
||||
utils.PipelineUserType.Producer, self.num_acc_stage
|
||||
acc_producer_state = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.Producer, self.num_acc_stage
|
||||
)
|
||||
|
||||
while work_tile.is_valid_tile:
|
||||
|
||||
# Get tile coord from tile scheduler
|
||||
cur_tile_coord = work_tile.tile_idx
|
||||
mma_tile_coord_mnl = (
|
||||
@ -895,7 +898,7 @@ class PersistentDenseGemmKernel:
|
||||
#
|
||||
# Mma mainloop
|
||||
#
|
||||
for k_block in cutlass.range_dynamic(0, k_block_cnt, 1, unroll=1):
|
||||
for k_block in range(k_block_cnt):
|
||||
if is_leader_cta:
|
||||
# Conditionally wait for AB buffer full
|
||||
ab_pipeline.consumer_wait(
|
||||
@ -904,7 +907,7 @@ class PersistentDenseGemmKernel:
|
||||
|
||||
# tCtAcc += tCrA * tCrB
|
||||
num_kphases = cute.size(tCrA, mode=[2])
|
||||
for kphase_idx in range(num_kphases):
|
||||
for kphase_idx in cutlass.range(num_kphases, unroll_full=True):
|
||||
kphase_coord = (
|
||||
None,
|
||||
None,
|
||||
@ -989,10 +992,12 @@ class PersistentDenseGemmKernel:
|
||||
# Partition for epilogue
|
||||
#
|
||||
epi_tidx = tidx
|
||||
tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = (
|
||||
self.epilog_tmem_copy_and_partition(
|
||||
epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs
|
||||
)
|
||||
(
|
||||
tiled_copy_t2r,
|
||||
tTR_tAcc_base,
|
||||
tTR_rAcc,
|
||||
) = self.epilog_tmem_copy_and_partition(
|
||||
epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs
|
||||
)
|
||||
|
||||
tTR_rC = None
|
||||
@ -1008,16 +1013,20 @@ class PersistentDenseGemmKernel:
|
||||
tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition(
|
||||
tiled_copy_t2r, tTR_rC, epi_tidx, sC
|
||||
)
|
||||
tma_atom_c, bSG_sC, bSG_gC_partitioned = (
|
||||
self.epilog_gmem_copy_and_partition(
|
||||
epi_tidx, tma_atom_c, tCgC, epi_tile, sC
|
||||
)
|
||||
(
|
||||
tma_atom_c,
|
||||
bSG_sC,
|
||||
bSG_gC_partitioned,
|
||||
) = self.epilog_gmem_copy_and_partition(
|
||||
epi_tidx, tma_atom_c, tCgC, epi_tile, sC
|
||||
)
|
||||
else:
|
||||
simt_atom, tTR_rC, tTR_gC_partitioned = (
|
||||
self.epilog_gmem_copy_and_partition(
|
||||
epi_tidx, tiled_copy_t2r, tCgC, epi_tile, sC
|
||||
)
|
||||
(
|
||||
simt_atom,
|
||||
tTR_rC,
|
||||
tTR_gC_partitioned,
|
||||
) = self.epilog_gmem_copy_and_partition(
|
||||
epi_tidx, tiled_copy_t2r, tCgC, epi_tile, sC
|
||||
)
|
||||
|
||||
#
|
||||
@ -1028,25 +1037,24 @@ class PersistentDenseGemmKernel:
|
||||
)
|
||||
work_tile = tile_sched.initial_work_tile_info()
|
||||
|
||||
acc_consumer_state = utils.make_pipeline_state(
|
||||
utils.PipelineUserType.Consumer, self.num_acc_stage
|
||||
acc_consumer_state = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.Consumer, self.num_acc_stage
|
||||
)
|
||||
|
||||
c_pipeline = None
|
||||
if cutlass.const_expr(self.use_tma_store):
|
||||
# Threads/warps participating in tma store pipeline
|
||||
c_producer_group = utils.CooperativeGroup(
|
||||
utils.Agent.Thread,
|
||||
c_producer_group = pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread,
|
||||
32 * len(self.epilog_warp_id),
|
||||
32 * len(self.epilog_warp_id),
|
||||
)
|
||||
c_pipeline = utils.PipelineTmaStore.create(
|
||||
c_pipeline = pipeline.PipelineTmaStore.create(
|
||||
num_stages=self.num_c_stage,
|
||||
producer_group=c_producer_group,
|
||||
)
|
||||
|
||||
while work_tile.is_valid_tile:
|
||||
|
||||
# Get tile coord from tile scheduler
|
||||
cur_tile_coord = work_tile.tile_idx
|
||||
mma_tile_coord_mnl = (
|
||||
@ -1105,7 +1113,7 @@ class PersistentDenseGemmKernel:
|
||||
#
|
||||
subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])
|
||||
num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt
|
||||
for subtile_idx in cutlass.range_dynamic(subtile_cnt):
|
||||
for subtile_idx in cutlass.range(subtile_cnt):
|
||||
#
|
||||
# Load accumulator from tensor memory buffer to register
|
||||
#
|
||||
@ -1259,11 +1267,11 @@ class PersistentDenseGemmKernel:
|
||||
# (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE)
|
||||
tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi)
|
||||
|
||||
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL)
|
||||
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL)
|
||||
gC_mnl_epi = cute.flat_divide(
|
||||
gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile
|
||||
)
|
||||
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL)
|
||||
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL)
|
||||
tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi)
|
||||
# (T2R, T2R_M, T2R_N)
|
||||
tTR_rAcc = cute.make_fragment(
|
||||
@ -1346,7 +1354,7 @@ class PersistentDenseGemmKernel:
|
||||
- tTR_gC: The partitioned global tensor C
|
||||
:rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]
|
||||
"""
|
||||
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL)
|
||||
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL)
|
||||
gC_epi = cute.flat_divide(
|
||||
gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile
|
||||
)
|
||||
@ -1355,7 +1363,7 @@ class PersistentDenseGemmKernel:
|
||||
sC_for_tma_partition = cute.group_modes(sC, 0, 2)
|
||||
gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2)
|
||||
# ((ATOM_V, REST_V), EPI_M, EPI_N)
|
||||
# ((ATOM_V, REST_V), EPI_M, EPI_N, loopM, loopN, loopL)
|
||||
# ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL)
|
||||
bSG_sC, bSG_gC = cpasync.tma_partition(
|
||||
tma_atom_c,
|
||||
0,
|
||||
@ -1366,7 +1374,7 @@ class PersistentDenseGemmKernel:
|
||||
return tma_atom_c, bSG_sC, bSG_gC
|
||||
else:
|
||||
tiled_copy_t2r = atom
|
||||
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL)
|
||||
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL)
|
||||
thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
|
||||
tTR_gC = thr_copy_t2r.partition_D(gC_epi)
|
||||
# (T2R, T2R_M, T2R_N)
|
||||
@ -1385,7 +1393,7 @@ class PersistentDenseGemmKernel:
|
||||
epi_tile: cute.Tile,
|
||||
c_dtype: Type[cutlass.Numeric],
|
||||
c_layout: utils.LayoutEnum,
|
||||
num_smem_capacity: int,
|
||||
smem_capacity: int,
|
||||
occupancy: int,
|
||||
use_tma_store: bool,
|
||||
) -> Tuple[int, int, int]:
|
||||
@ -1405,8 +1413,8 @@ class PersistentDenseGemmKernel:
|
||||
:type c_dtype: type[cutlass.Numeric]
|
||||
:param c_layout: Layout enum of operand C.
|
||||
:type c_layout: utils.LayoutEnum
|
||||
:param num_smem_capacity: Total available shared memory capacity in bytes.
|
||||
:type num_smem_capacity: int
|
||||
:param smem_capacity: Total available shared memory capacity in bytes.
|
||||
:type smem_capacity: int
|
||||
:param occupancy: Target number of CTAs per SM (occupancy).
|
||||
:type occupancy: int
|
||||
:param use_tma_store: Whether TMA store is enabled.
|
||||
@ -1461,7 +1469,7 @@ class PersistentDenseGemmKernel:
|
||||
# Subtract reserved bytes and initial C stages bytes
|
||||
# Divide remaining by bytes needed per A/B stage
|
||||
num_ab_stage = (
|
||||
num_smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes)
|
||||
smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes)
|
||||
) // ab_bytes_per_stage
|
||||
|
||||
# Refine epilogue stages:
|
||||
@ -1469,7 +1477,7 @@ class PersistentDenseGemmKernel:
|
||||
# Add remaining unused smem to epilogue
|
||||
if use_tma_store:
|
||||
num_c_stage += (
|
||||
num_smem_capacity
|
||||
smem_capacity
|
||||
- occupancy * ab_bytes_per_stage * num_ab_stage
|
||||
- occupancy * (mbar_helpers_bytes + c_bytes)
|
||||
) // (occupancy * c_bytes_per_stage)
|
||||
@ -1512,36 +1520,6 @@ class PersistentDenseGemmKernel:
|
||||
|
||||
return tile_sched_params, grid
|
||||
|
||||
@staticmethod
|
||||
def _get_tma_atom_kind(
|
||||
atom_sm_cnt: cutlass.Int32, mcast: cutlass.Boolean
|
||||
) -> Union[
|
||||
cpasync.CopyBulkTensorTileG2SMulticastOp, cpasync.CopyBulkTensorTileG2SOp
|
||||
]:
|
||||
"""
|
||||
Select the appropriate TMA copy atom based on the number of SMs and the multicast flag.
|
||||
|
||||
:param atom_sm_cnt: The number of SMs
|
||||
:type atom_sm_cnt: cutlass.Int32
|
||||
:param mcast: The multicast flag
|
||||
:type mcast: cutlass.Boolean
|
||||
|
||||
:return: The appropriate TMA copy atom kind
|
||||
:rtype: cpasync.CopyBulkTensorTileG2SMulticastOp or cpasync.CopyBulkTensorTileG2SOp
|
||||
|
||||
:raise ValueError: If the atom_sm_cnt is invalid
|
||||
"""
|
||||
if atom_sm_cnt == 2 and mcast:
|
||||
return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.TWO)
|
||||
elif atom_sm_cnt == 2 and not mcast:
|
||||
return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.TWO)
|
||||
elif atom_sm_cnt == 1 and mcast:
|
||||
return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.ONE)
|
||||
elif atom_sm_cnt == 1 and not mcast:
|
||||
return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE)
|
||||
|
||||
raise ValueError(f"Invalid atom_sm_cnt: {atom_sm_cnt} and {mcast}")
|
||||
|
||||
@staticmethod
|
||||
def _compute_num_tmem_alloc_cols(
|
||||
tiled_mma: cute.TiledMma,
|
||||
|
||||
1852
examples/python/CuTeDSL/blackwell/dense_gemm_software_pipeline.py
Normal file
File diff suppressed because it is too large
@ -40,7 +40,6 @@ import cutlass.utils as utils
from cutlass.cute.nvgpu import cpasync, tcgen05
import cutlass.utils.blackwell_helpers as sm100_utils
import cutlass.torch as cutlass_torch
from cutlass.cute.runtime import from_dlpack

"""
A grouped GEMM example for the NVIDIA Blackwell SM100 architecture using CUTE DSL
@ -89,7 +88,6 @@ there are also the following constrains:


class GroupedGemmKernel:

def __init__(
self,
acc_dtype: type[cutlass.Numeric],
@ -159,7 +157,7 @@ class GroupedGemmKernel:
self.tmem_ptr_sync_bar_id = 2
# Barrier ID used by MMA/TMA warps to signal A/B tensormap initialization completion
self.tensormap_ab_init_bar_id = 4
self.num_smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
self.num_tma_load_bytes = 0

def _setup_attributes(self):
@ -217,18 +215,20 @@ class GroupedGemmKernel:
)

# Setup A/B/C stage count in shared memory and ACC stage count in tensor memory
self.num_acc_stage, self.num_ab_stage, self.num_epi_stage = (
self._compute_stages(
tiled_mma,
self.mma_tiler,
self.a_dtype,
self.b_dtype,
self.epi_tile,
self.c_dtype,
self.c_layout,
self.num_smem_capacity,
self.occupancy,
)
(
self.num_acc_stage,
self.num_ab_stage,
self.num_epi_stage,
) = self._compute_stages(
tiled_mma,
self.mma_tiler,
self.a_dtype,
self.b_dtype,
self.epi_tile,
self.c_dtype,
self.c_layout,
self.smem_capacity,
self.occupancy,
)

self.a_smem_layout_staged = sm100_utils.make_smem_layout_a(
@ -355,9 +355,11 @@ class GroupedGemmKernel:
atom_thr_size = cute.size(tiled_mma.thr_id.shape)

# Setup TMA load for A
a_op = self._get_tma_atom_kind(atom_thr_size, self.is_a_mcast)
a_op = sm100_utils.cluster_shape_to_tma_atom_A(
self.cluster_shape_mn, tiled_mma.thr_id
)
a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tma_tile_atom_A(
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
a_op,
initial_a,
a_smem_layout,
@ -367,9 +369,11 @@ class GroupedGemmKernel:
)

# Setup TMA load for B
b_op = self._get_tma_atom_kind(atom_thr_size, self.is_b_mcast)
b_op = sm100_utils.cluster_shape_to_tma_atom_B(
self.cluster_shape_mn, tiled_mma.thr_id
)
b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tma_tile_atom_B(
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
b_op,
initial_b,
b_smem_layout,
@ -389,7 +393,7 @@ class GroupedGemmKernel:
cute.make_identity_layout(initial_c.shape), self.epi_tile
)
epi_smem_layout = cute.slice_(self.epi_smem_layout_staged, (None, None, 0))
tma_atom_c, tma_tensor_c = cpasync.make_tma_tile_atom(
tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
cpasync.CopyBulkTensorTileS2GOp(),
initial_c,
epi_smem_layout,
@ -403,9 +407,7 @@ class GroupedGemmKernel:
self.buffer_align_bytes = 1024
self.size_tensormap_in_i64 = (
0
if cutlass.const_expr(
self.tensormap_update_mode == utils.TensorMapUpdateMode.GMEM
)
if self.tensormap_update_mode == utils.TensorMapUpdateMode.GMEM
else GroupedGemmKernel.num_tensormaps
* GroupedGemmKernel.bytes_per_tensormap
// 8
@ -564,16 +566,16 @@ class GroupedGemmKernel:
for k_stage in range(self.num_ab_stage):
num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
with cute.arch.elect_one():
cute.arch.mbarrier_init_arrive_cnt(ab_full_mbar_ptr + k_stage, 1)
cute.arch.mbarrier_init_arrive_cnt(
cute.arch.mbarrier_init(ab_full_mbar_ptr + k_stage, 1)
cute.arch.mbarrier_init(
ab_empty_mbar_ptr + k_stage, num_tma_producer
)
# Accumulator barrier init
if warp_idx == self.mma_warp_id:
for acc_stage in range(self.num_acc_stage):
with cute.arch.elect_one():
cute.arch.mbarrier_init_arrive_cnt(acc_full_mbar_ptr + acc_stage, 1)
cute.arch.mbarrier_init_arrive_cnt(
cute.arch.mbarrier_init(acc_full_mbar_ptr + acc_stage, 1)
cute.arch.mbarrier_init(
acc_empty_mbar_ptr + acc_stage, 8 if use_2cta_instrs else 4
)
# Tensor memory dealloc barrier init
@ -581,7 +583,7 @@ class GroupedGemmKernel:
if warp_idx == self.tma_warp_id:
num_tmem_dealloc_threads = 32
with cute.arch.elect_one():
cute.arch.mbarrier_init_arrive_cnt(
cute.arch.mbarrier_init(
tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads
)
cute.arch.mbarrier_init_fence()
@ -612,7 +614,7 @@ class GroupedGemmKernel:
a_full_mcast_mask = None
b_full_mcast_mask = None
ab_empty_mcast_mask = None
if self.is_a_mcast or self.is_b_mcast or use_2cta_instrs:
if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs):
a_full_mcast_mask = cpasync.create_tma_multicast_mask(
cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2
)
@ -621,7 +623,7 @@ class GroupedGemmKernel:
)
ab_empty_mcast_mask = a_full_mcast_mask | b_full_mcast_mask
acc_full_mcast_mask = None
if use_2cta_instrs:
if cutlass.const_expr(use_2cta_instrs):
acc_full_mcast_mask = cute.make_layout_image_mask(
cluster_layout_vmnk, block_in_cluster_coord_vmnk, mode=0
)
@ -646,15 +648,15 @@ class GroupedGemmKernel:
#
# Local_tile partition global tensors
#
# (bM, bK, loopM, loopK, loopL)
# (bM, bK, RestM, RestK, RestL)
gA_mkl = cute.local_tile(
mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None)
)
# (bN, bK, loopN, loopK, loopL)
# (bN, bK, RestN, RestK, RestL)
gB_nkl = cute.local_tile(
mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None)
)
# (bM, bN, loopM, loopN, loopL)
# (bM, bN, RestM, RestN, RestL)
gC_mnl = cute.local_tile(
mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None)
)
@ -663,11 +665,11 @@ class GroupedGemmKernel:
# Partition global tensor for TiledMMA_A/B/C
#
thr_mma = tiled_mma.get_slice(mma_tile_coord_v)
# (MMA, MMA_M, MMA_K, loopM, loopK, loopL)
# (MMA, MMA_M, MMA_K, RestM, RestK, RestL)
tCgA = thr_mma.partition_A(gA_mkl)
# (MMA, MMA_N, MMA_K, loopN, loopK, loopL)
# (MMA, MMA_N, MMA_K, RestN, RestK, RestL)
tCgB = thr_mma.partition_B(gB_nkl)
# (MMA, MMA_M, MMA_N, loopM, loopN, loopL)
# (MMA, MMA_M, MMA_N, RestM, RestN, RestL)
tCgC = thr_mma.partition_C(gC_mnl)

#
@ -677,7 +679,7 @@ class GroupedGemmKernel:
cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape
)
# ((atom_v, rest_v), STAGE)
# ((atom_v, rest_v), loopM, loopK, loopL)
# ((atom_v, rest_v), RestM, RestK, RestL)
tAsA, tAgA = cpasync.tma_partition(
tma_atom_a,
block_in_cluster_coord_vmnk[2],
@ -690,7 +692,7 @@ class GroupedGemmKernel:
cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape
)
# ((atom_v, rest_v), STAGE)
# ((atom_v, rest_v), loopM, loopK, loopL)
# ((atom_v, rest_v), RestM, RestK, RestL)
tBsB, tBgB = cpasync.tma_partition(
tma_atom_b,
block_in_cluster_coord_vmnk[1],
@ -849,11 +851,11 @@ class GroupedGemmKernel:
#
# Slice to per mma tile index
#
# ((atom_v, rest_v), loopK)
# ((atom_v, rest_v), RestK)
tAgA_slice = tAgA[
(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])
]
# ((atom_v, rest_v), loopK)
# ((atom_v, rest_v), RestK)
tBgB_slice = tBgB[
(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])
]
@ -867,7 +869,7 @@ class GroupedGemmKernel:
tma_wr_ab_empty_phase = (
num_prev_k_blk + tma_wr_k_block
) // self.num_ab_stage % 2 ^ 1
peek_ab_empty_status = cute.arch.conditional_mbarrier_try_wait(
peek_ab_empty_status = cute.arch.mbarrier_conditional_try_wait(
tma_wr_k_block < cur_k_block_cnt,
ab_empty_mbar_ptr + smem_wr_buffer,
tma_wr_ab_empty_phase,
@ -879,7 +881,7 @@ class GroupedGemmKernel:
#
# Tma load loop
#
for k_block in cutlass.range_dynamic(0, cur_k_block_cnt, 1, unroll=1):
for k_block in cutlass.range(0, cur_k_block_cnt, 1, unroll=1):
tma_wr_k_block_next = tma_wr_k_block + 1
smem_wr_buffer_next = (
num_prev_k_blk + tma_wr_k_block_next
@ -898,10 +900,10 @@ class GroupedGemmKernel:
ab_empty_mbar_ptr + smem_wr_buffer, tma_wr_ab_empty_phase
)

# Init AB buffer full transaction byte
# Arrive AB buffer and expect full transaction bytes
if is_leader_cta:
with cute.arch.elect_one():
cute.arch.mbarrier_init_tx_bytes(
cute.arch.mbarrier_arrive_and_expect_tx(
smem_full_mbar_ptr, self.num_tma_load_bytes
)

@ -930,7 +932,7 @@ class GroupedGemmKernel:
)

# Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt + k_block + 1
peek_ab_empty_status = cute.arch.conditional_mbarrier_try_wait(
peek_ab_empty_status = cute.arch.mbarrier_conditional_try_wait(
tma_wr_k_block_next < cur_k_block_cnt,
ab_empty_mbar_ptr + smem_wr_buffer_next,
tma_wr_ab_empty_phase_next,
@ -999,11 +1001,12 @@ class GroupedGemmKernel:
while work_tile.is_valid_tile:
cur_tile_coord = work_tile.tile_idx
# MMA warp is only interested in number of tiles along K dimension
cur_k_block_cnt, cur_group_idx = (
group_gemm_ts_helper.search_cluster_tile_count_k(
cur_tile_coord,
problem_sizes_mnkl,
)
(
cur_k_block_cnt,
cur_group_idx,
) = group_gemm_ts_helper.search_cluster_tile_count_k(
cur_tile_coord,
problem_sizes_mnkl,
)
# Set tensor memory buffer for current tile
acc_buf_idx = tile_sched.num_tiles_executed % self.num_acc_stage
@ -1022,7 +1025,7 @@ class GroupedGemmKernel:
mma_rd_ab_full_phase = (
(num_prev_k_blk + mma_rd_k_block) // self.num_ab_stage % 2
)
peek_ab_full_status = cute.arch.conditional_mbarrier_try_wait(
peek_ab_full_status = cute.arch.mbarrier_conditional_try_wait(
need_check_rd_buffer_full,
ab_full_mbar_ptr + smem_rd_buffer,
mma_rd_ab_full_phase,
@ -1047,7 +1050,7 @@ class GroupedGemmKernel:
#
# Mma mainloop
#
for k_block in cutlass.range_dynamic(0, cur_k_block_cnt, 1, unroll=1):
for k_block in range(cur_k_block_cnt):
mma_rd_k_block_next = cutlass.Int32(k_block + 1)
smem_rd_buffer_next = (
num_prev_k_blk + mma_rd_k_block_next
@ -1066,7 +1069,7 @@ class GroupedGemmKernel:

# tCtAcc += tCrA * tCrB
num_kphases = cute.size(tCrA, mode=[2])
for kphase_idx in range(num_kphases):
for kphase_idx in cutlass.range(num_kphases, unroll_full=True):
kphase_coord = (None, None, kphase_idx, smem_rd_buffer)

cute.gemm(
@ -1092,7 +1095,7 @@ class GroupedGemmKernel:
mma_rd_k_block_next < cur_k_block_cnt and is_leader_cta
)

peek_ab_full_status = cute.arch.conditional_mbarrier_try_wait(
peek_ab_full_status = cute.arch.mbarrier_conditional_try_wait(
need_check_rd_buffer_full,
ab_full_mbar_ptr + smem_rd_buffer_next,
mma_rd_ab_full_phase_next,
@ -1161,19 +1164,23 @@ class GroupedGemmKernel:
#
# Partition for epilogue
#
tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = (
self.epilog_tmem_copy_and_partition(
epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs
)
(
tiled_copy_t2r,
tTR_tAcc_base,
tTR_rAcc,
) = self.epilog_tmem_copy_and_partition(
epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs
)

tTR_rC = cute.make_fragment(tTR_rAcc.shape, self.c_dtype)
tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition(
tiled_copy_t2r, tTR_rC, epi_tidx, sC
)
tma_atom_c, bSG_sC, bSG_gC_partitioned = (
self.epilog_gmem_copy_and_partition(tma_atom_c, tCgC, epi_tile, sC)
)
(
tma_atom_c,
bSG_sC,
bSG_gC_partitioned,
) = self.epilog_gmem_copy_and_partition(tma_atom_c, tCgC, epi_tile, sC)

#
# Persistent tile scheduling loop
@ -1270,7 +1277,7 @@ class GroupedGemmKernel:
#
subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])
num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt
for subtile_idx in cutlass.range_dynamic(subtile_cnt):
for subtile_idx in range(subtile_cnt):
#
# Load accumulator from tensor memory buffer to register
#
@ -1493,11 +1500,11 @@ class GroupedGemmKernel:
# (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE)
tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi)

# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL)
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL)
gC_mnl_epi = cute.flat_divide(
gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile
)
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL)
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL)
tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi)
# (T2R, T2R_M, T2R_N)
tTR_rAcc = cute.make_fragment(
@ -1569,14 +1576,14 @@ class GroupedGemmKernel:
- tCgC: The destination global memory tensor partitioned for the TMA operation.
:rtype: tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]
"""
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL)
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL)
gC_epi = cute.flat_divide(
gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile
)
sC_for_tma_partition = cute.group_modes(sC, 0, 2)
gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2)
# ((ATOM_V, REST_V), EPI_M, EPI_N)
# ((ATOM_V, REST_V), EPI_M, EPI_N, loopM, loopN, loopL)
# ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL)
bSG_sC, bSG_gC = cpasync.tma_partition(
tma_atom_c,
0,
@ -1595,7 +1602,7 @@ class GroupedGemmKernel:
epi_tile: cute.Tile,
c_dtype: type[cutlass.Numeric],
c_layout: utils.LayoutEnum,
num_smem_capacity: int,
smem_capacity: int,
occupancy: int,
) -> tuple[int, int, int]:
"""Computes the number of stages for accumulator, A/B operands, and epilogue based on heuristics.
@ -1614,8 +1621,8 @@ class GroupedGemmKernel:
:type c_dtype: type[cutlass.Numeric]
:param c_layout: Layout enum of operand C in global memory.
:type c_layout: utils.LayoutEnum
:param num_smem_capacity: Total available shared memory capacity in bytes.
:type num_smem_capacity: int
:param smem_capacity: Total available shared memory capacity in bytes.
:type smem_capacity: int
:param occupancy: Target number of CTAs per SM (occupancy).
:type occupancy: int

@ -1658,7 +1665,7 @@ class GroupedGemmKernel:
# Subtract reserved bytes and initial epilogue bytes
# Divide remaining by bytes needed per A/B stage
num_ab_stage = (
num_smem_capacity // occupancy
smem_capacity // occupancy
- GroupedGemmKernel.reserved_smem_bytes
- epi_bytes
) // ab_bytes_per_stage
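As a quick sanity check of this heuristic, here is a sketch with invented numbers (the reserved and per-stage byte counts below are placeholders, not values from the kernel):

smem_capacity = 232448            # assumed SMEM per SM, in bytes
occupancy = 1
reserved_smem_bytes = 3072        # placeholder
epi_bytes = 16384                 # placeholder initial epilogue footprint
ab_bytes_per_stage = 36864        # placeholder A+B bytes per stage
num_ab_stage = (smem_capacity // occupancy
                - reserved_smem_bytes
                - epi_bytes) // ab_bytes_per_stage    # -> 5 stages
# The leftover (212992 - 184320 = 28672 bytes) is what the next hunk folds back into the epilogue.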
@ -1667,7 +1674,7 @@ class GroupedGemmKernel:
# Calculate remaining smem after allocating for A/B stages and reserved bytes
# Add remaining unused smem to epilogue
remaining_smem = (
num_smem_capacity
smem_capacity
- occupancy * ab_bytes_per_stage * num_ab_stage
- occupancy * (GroupedGemmKernel.reserved_smem_bytes + epi_bytes)
)
@ -1775,20 +1782,6 @@ class GroupedGemmKernel:
epi_bytes = cute.size_in_bytes(c_dtype, epi_smem_layout_staged)
return ab_bytes + epi_bytes

@staticmethod
def _get_tma_atom_kind(atom_sm_cnt: int, mcast: bool):
"""Select the appropriate TMA copy atom based on the number of SMs and the multicast flag."""
if atom_sm_cnt == 2 and mcast:
return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.TWO)
elif atom_sm_cnt == 2 and not mcast:
return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.TWO)
elif atom_sm_cnt == 1 and mcast:
return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.ONE)
elif atom_sm_cnt == 1 and not mcast:
return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE)

raise ValueError(f"Invalid atom_sm_cnt: {atom_sm_cnt} and {mcast}")

@staticmethod
def _compute_num_tmem_alloc_cols(
tiled_mma: cute.TiledMma,
@ -1909,8 +1902,6 @@ def run_grouped_gemm(
if not torch.cuda.is_available():
raise RuntimeError("GPU is required to run this example!")

torch.manual_seed(2025)

# Create tensor and return the pointer, tensor, and stride
def create_tensor_and_stride(
l: int,
@ -1920,42 +1911,17 @@ def run_grouped_gemm(
dtype: type[cutlass.Numeric],
is_dynamic_layout: bool = True,
) -> tuple[int, torch.Tensor, cute.Tensor, torch.Tensor, tuple[int, int]]:
# is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l)
# else: (l, mode0, mode1) -> (mode0, mode1, l)
shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1)
permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0)
# omit stride for L mode as it is always 1 for grouped GEMM
strides = (1, mode0) if is_mode0_major else (mode1, 1)
assert dtype in {cutlass.Float16, cutlass.BFloat16, cutlass.Float32}
is_unsigned = False

torch_dtype = cutlass_torch.dtype(dtype)
torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor(
shape,
torch_dtype,
permute_order=permute_order,
init_type=cutlass_torch.TensorInitType.RANDOM,
init_config=cutlass_torch.RandomInitConfig(
min_val=0 if is_unsigned else -2, max_val=4 if is_unsigned else 2
),
torch_tensor_cpu = cutlass_torch.matrix(l, mode0, mode1, is_mode0_major, dtype)
cute_tensor, torch_tensor = cutlass_torch.cute_tensor_like(
torch_tensor_cpu, dtype, is_dynamic_layout, assumed_align=16
)
torch_tensor = torch_tensor_cpu.cuda()
f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32)

cute_tensor = from_dlpack(torch_tensor, assumed_align=16)
if is_dynamic_layout:
cute_tensor = cute_tensor.mark_layout_dynamic(
leading_dim=(0 if is_mode0_major else 1)
)
cute_tensor = cutlass_torch.convert_cute_tensor(
f32_torch_tensor,
return (
torch_tensor.data_ptr(),
torch_tensor,
cute_tensor,
dtype,
is_dynamic_layout=is_dynamic_layout,
torch_tensor_cpu,
torch_tensor.stride()[:-1],
)
# Get pointer of the tensor
ptr = torch_tensor.data_ptr()
return ptr, torch_tensor, cute_tensor, f32_torch_tensor, strides

# iterate all groups and create tensors for each group
torch_fp32_tensors_abc = []
@ -1964,15 +1930,27 @@ def run_grouped_gemm(
strides_abc = []
ptrs_abc = []
for _, (m, n, k, l) in enumerate(problem_sizes_mnkl):
ptr_a, torch_tensor_a, cute_tensor_a, tensor_fp32_a, stride_mk_a = (
create_tensor_and_stride(l, m, k, a_major == "m", ab_dtype)
)
ptr_b, torch_tensor_b, cute_tensor_b, tensor_fp32_b, stride_nk_b = (
create_tensor_and_stride(l, n, k, b_major == "n", ab_dtype)
)
ptr_c, torch_tensor_c, cute_tensor_c, tensor_fp32_c, stride_mn_c = (
create_tensor_and_stride(l, m, n, c_major == "m", c_dtype)
)
(
ptr_a,
torch_tensor_a,
cute_tensor_a,
tensor_fp32_a,
stride_mk_a,
) = create_tensor_and_stride(l, m, k, a_major == "m", ab_dtype)
(
ptr_b,
torch_tensor_b,
cute_tensor_b,
tensor_fp32_b,
stride_nk_b,
) = create_tensor_and_stride(l, n, k, b_major == "n", ab_dtype)
(
ptr_c,
torch_tensor_c,
cute_tensor_c,
tensor_fp32_c,
stride_mn_c,
) = create_tensor_and_stride(l, m, n, c_major == "m", c_dtype)
ptrs_abc.append([ptr_a, ptr_b, ptr_c])
torch_tensors_abc.append([torch_tensor_a, torch_tensor_b, torch_tensor_c])
torch_fp32_tensors_abc.append([tensor_fp32_a, tensor_fp32_b, tensor_fp32_c])
@ -2005,19 +1983,16 @@ def run_grouped_gemm(
)
# Prepare tensormap buffer for each SM
num_tensormap_buffers = sm_count
tensormap_pytorch_tensor = (
torch.empty(
(
num_tensormap_buffers,
GroupedGemmKernel.num_tensormaps,
GroupedGemmKernel.bytes_per_tensormap // 8,
),
dtype=torch.int64,
)
.fill_(0)
.cuda()
tensormap_shape = (
num_tensormap_buffers,
GroupedGemmKernel.num_tensormaps,
GroupedGemmKernel.bytes_per_tensormap // 8,
)
tensor_of_tensormap, tensor_of_tensormap_torch = cutlass_torch.cute_tensor_like(
torch.empty(tensormap_shape, dtype=torch.int64),
cutlass.Int64,
is_dynamic_layout=False,
)
tensormap_cute_tensor = from_dlpack(tensormap_pytorch_tensor, assumed_align=16)

grouped_gemm = GroupedGemmKernel(
acc_dtype,
@ -2027,23 +2002,30 @@ def run_grouped_gemm(
tensormap_update_mode,
)

# Convert integer list to torch tensor and cute tensor
def convert_list_to_tensor(l, dtype) -> tuple[torch.Tensor, cute.Tensor]:
torch_tensor = torch.tensor(l, dtype=dtype).cuda()
cute_tensor = from_dlpack(torch_tensor, assumed_align=16)
return torch_tensor, cute_tensor

# layout (num_groups, 4):(4, 1)
problem_sizes_mnkl_torch_tensor, problem_sizes_mnkl_cute_tensor = (
convert_list_to_tensor(problem_sizes_mnkl, torch.int32)
(
tensor_of_dim_size_mnkl,
tensor_of_dim_size_mnkl_torch,
) = cutlass_torch.cute_tensor_like(
torch.tensor(problem_sizes_mnkl, dtype=torch.int32),
cutlass.Int32,
is_dynamic_layout=False,
assumed_align=16,
)
# layout (num_groups, 3, 2):(6, 2, 1)
strides_abc_torch_tensor, strides_abc_cute_tensor = convert_list_to_tensor(
strides_abc, torch.int32
tensor_of_strides_abc, tensor_of_strides_abc_torch = cutlass_torch.cute_tensor_like(
torch.tensor(strides_abc, dtype=torch.int32),
cutlass.Int32,
is_dynamic_layout=False,
assumed_align=16,
)

# layout (num_groups,3):(3, 1)
ptrs_abc_torch_tensor, ptrs_abc_cute_tensor = convert_list_to_tensor(
ptrs_abc, torch.int64
tensor_of_ptrs_abc, tensor_of_ptrs_abc_torch = cutlass_torch.cute_tensor_like(
torch.tensor(ptrs_abc, dtype=torch.int64),
cutlass.Int64,
is_dynamic_layout=False,
assumed_align=16,
)
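The pattern above (wrap host-side Python lists in torch tensors, then expose them to the kernel through cutlass_torch.cute_tensor_like) is the same for all three metadata arrays; a condensed sketch of just that step, with hypothetical values:

# Hypothetical two-group problem; illustrates only the metadata packing.
problem_sizes_mnkl = [[256, 512, 128, 1], [384, 256, 256, 1]]
tensor_of_dim_size_mnkl, _ = cutlass_torch.cute_tensor_like(
    torch.tensor(problem_sizes_mnkl, dtype=torch.int32),  # shape (num_groups, 4)
    cutlass.Int32,
    is_dynamic_layout=False,
    assumed_align=16,
)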
# Compute total number of cluster tiles we need to compute for given grouped GEMM problem
@ -2077,10 +2059,9 @@ def run_grouped_gemm(
problem_sizes_mnkl, cluster_tile_shape_mn
)

# Get current CUDA stream from PyTorch
torch_stream = torch.cuda.current_stream()
# Get the raw stream pointer as a CUstream
current_stream = cuda.CUstream(torch_stream.cuda_stream)
# Initialize Stream
current_stream = cutlass_torch.default_stream()

# Compile grouped GEMM kernel
compiled_grouped_gemm = cute.compile(
grouped_gemm,
@ -2088,11 +2069,11 @@ def run_grouped_gemm(
initial_cute_tensors_abc[1],
initial_cute_tensors_abc[2],
num_groups,
problem_sizes_mnkl_cute_tensor,
strides_abc_cute_tensor,
ptrs_abc_cute_tensor,
tensor_of_dim_size_mnkl,
tensor_of_strides_abc,
tensor_of_ptrs_abc,
total_num_clusters,
tensormap_cute_tensor,
tensor_of_tensormap,
max_active_clusters,
current_stream,
)
@ -2104,10 +2085,10 @@ def run_grouped_gemm(
initial_cute_tensors_abc[0],
initial_cute_tensors_abc[1],
initial_cute_tensors_abc[2],
problem_sizes_mnkl_cute_tensor,
strides_abc_cute_tensor,
ptrs_abc_cute_tensor,
tensormap_cute_tensor,
tensor_of_dim_size_mnkl,
tensor_of_strides_abc,
tensor_of_ptrs_abc,
tensor_of_tensormap,
current_stream,
)
# Execution
@ -2116,28 +2097,27 @@ def run_grouped_gemm(
initial_cute_tensors_abc[0],
initial_cute_tensors_abc[1],
initial_cute_tensors_abc[2],
problem_sizes_mnkl_cute_tensor,
strides_abc_cute_tensor,
ptrs_abc_cute_tensor,
tensormap_cute_tensor,
tensor_of_dim_size_mnkl,
tensor_of_strides_abc,
tensor_of_ptrs_abc,
tensor_of_tensormap,
current_stream,
)

torch.cuda.synchronize()

# Compute reference result
if not skip_ref_check:
refs = []
for a, b, _ in torch_fp32_tensors_abc:
ref = (torch.einsum("mkl,nkl->mnl", a, b)).cpu()
refs.append(ref)
for i, ((_, _, c), ref) in enumerate(zip(torch_tensors_abc, refs)):
for i, (a, b, c) in enumerate(torch_tensors_abc):
ref = torch.einsum(
"mkl,nkl->mnl",
a.cpu().to(dtype=torch.float32),
b.cpu().to(dtype=torch.float32),
)
print(f"checking group {i}")
if c_dtype == cutlass.Float32:
ref_c = ref
else:
ref_c = ref.to(cutlass_torch.dtype(c_dtype))
torch.testing.assert_close(
c.cpu(),
ref_c,
ref.to(cutlass_torch.dtype(c_dtype)),
atol=tolerance,
rtol=1e-05,
)
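The new reference check computes one fp32 einsum per group on the CPU and compares it against the kernel output cast to the C dtype; a standalone sketch of that check for a single hypothetical group (shapes invented, kernel output stood in by the same einsum):

# Standalone sketch of the per-group check; illustrative only.
m, n, k, l = 256, 512, 128, 1
a = torch.randn(m, k, l)
b = torch.randn(n, k, l)
c = torch.einsum("mkl,nkl->mnl", a, b).to(torch.float16)   # stand-in for the kernel output
ref = torch.einsum("mkl,nkl->mnl", a.to(torch.float32), b.to(torch.float32))
torch.testing.assert_close(c.cpu(), ref.to(torch.float16), atol=1e-01, rtol=1e-05)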
@ -2266,6 +2246,8 @@ if __name__ == "__main__":
else:
tensormap_update_mode = utils.TensorMapUpdateMode.SMEM

torch.manual_seed(2025)

run_grouped_gemm(
args.num_groups,
args.problem_sizes_mnkl,
3619
examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py
Normal file
File diff suppressed because it is too large
@ -0,0 +1,397 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:

# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.

# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.

# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import torch
import torch.nn.functional as F


def ssd_reference_fp32_all(x, a, delta, B, C, Y_out, Fstate_out, D, has_d, d_has_hdim):
"""
Rearrange tensor dimensions from cuda layout to reference layout, then directly call TriDao's ssd implementation
Arguments:
X/x: (D, L, C, H, B):(C*L, 1, L, D*C*L, H*D*C*L)
A/delta: (L, C, H, B):(1, L, C*L, H*C*L)
a: (H):(1)
B/C: (L, N, C, G, B):(1, C*L, L, N*C*L, G*N*C*L)
D: (1, H):(0, 1) or (D, H):(1, D)
has_d: bool
d_has_hdim: bool
Return:
Y_out: (L, D, C, H, B):(1, C*L, L, D*C*L, H*D*C*L)
Fstate_out: (D, N, H, B):(N, 1, D*N, H*D*N)
"""
assert x.dtype == a.dtype == delta.dtype == B.dtype == C.dtype

A = delta * a.view(1, 1, -1, 1)
X = x * delta.unsqueeze(0)

# Rearrange to match cutlass layout to tridao's layout
block_len = A.shape[0]
initial_states = None
# A: l c h b-> b c l h
A = A.permute(3, 1, 0, 2)
# X: p l c h b -> b c l h p
X = X.permute(4, 2, 1, 3, 0)
# B: l n c g b -> b c l g n
B = B.permute(4, 2, 0, 3, 1)
# C: l n c g b -> b c l g n
C = C.permute(4, 2, 0, 3, 1)
# X/A/B/C: b c l ... -> b (c l) ...
X, A, B, C = [x.reshape(x.shape[0], -1, *x.shape[3:]) for x in (X, A, B, C)]

# Ngroup (g to h) mapping
B_val, CL_val, G_val, N_val = B.shape
H_val = X.shape[2]
ngroup_ratio = H_val // G_val
# B/C: (B, CL, H, N)
h_to_g_mapping = torch.arange(H_val, device=B.device) // ngroup_ratio
B = B.gather(2, h_to_g_mapping.view(1, 1, -1, 1).expand(B_val, CL_val, -1, N_val))
C = C.gather(2, h_to_g_mapping.view(1, 1, -1, 1).expand(B_val, CL_val, -1, N_val))

###################################################################
# Call reference implementation from Tri Dao ssd_minimal_discrete
Y, final_state = ssd_minimal_discrete_fp32_all(
X, A, B, C, block_len, initial_states
)
###################################################################

if has_d:
D_val = Y.shape[3]
if not d_has_hdim:
D = D.expand(D_val, -1)
Y = Y + torch.einsum("bchp,ph->bchp", X, D)

# Rearrange to match tridao's layout to cutlass layout
# Y: b (c l) h p -> b c l h p
Y = Y.reshape(Y.shape[0], -1, block_len, Y.shape[2], Y.shape[3])
# Y: b c l h p -> l p c h b
Y = Y.permute(2, 4, 1, 3, 0)
# Fstate_out: b h p n -> p n h b
Fstate_out.copy_(final_state.permute(2, 3, 1, 0))
Y_out.copy_(Y)
return
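For orientation, a small sketch of how the two layout worlds relate (the sizes are hypothetical; the permute and reshape are the ones performed above):

# Hypothetical sizes: d_head D=64, block L=16, chunks C=4, heads H=8, batch B=2
x_cuda = torch.randn(64, 16, 4, 8, 2)          # (D, L, C, H, B), the layout documented above
X_ref = x_cuda.permute(4, 2, 1, 3, 0)          # -> (B, C, L, H, D)
X_ref = X_ref.reshape(2, 4 * 16, 8, 64)        # -> (B, C*L, H, D), what Tri Dao's reference expects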


def ssd_reference_lowprecision_intermediates(
x, a, delta, B, C, Y_out, Fstate_out, intermediate_dtype, D, has_d, d_has_hdim
):
"""
Rearrange tensor dimensions from cuda layout to reference layout, then call a reduced intermediate dtype version of ssd implementation
Arguments:
X/x: (D, L, C, H, B):(C*L, 1, L, D*C*L, H*D*C*L)
A/delta: (L, C, H, B):(1, L, C*L, H*C*L)
a: (H):(1)
B/C: (L, N, C, G, B):(1, C*L, L, N*C*L, G*N*C*L)
intermediate_dtype: input and intermediate data type
D: (1, H):(0, 1) or (D, H):(1, D)
has_d: bool
d_has_hdim: bool
Return:
Y_out: (L, D, C, H, B):(1, C*L, L, D*C*L, H*D*C*L)
Fstate_out: (D, N, H, B):(N, 1, D*N, H*D*N)
"""
assert x.dtype == a.dtype == delta.dtype == B.dtype == C.dtype

A = delta * a.view(1, 1, -1, 1)

# Rearrange to match cutlass layout to tridao's layout
block_len = A.shape[0]
initial_states = None
# A: l c h b-> b c l h
A = A.permute(3, 1, 0, 2)
# delta: l c h b-> b c l h
delta = delta.permute(3, 1, 0, 2)
# x: p l c h b -> b c l h p
x = x.permute(4, 2, 1, 3, 0)
# B: l n c g b -> b c l g n
B = B.permute(4, 2, 0, 3, 1)
# C: l n c g b -> b c l g n
C = C.permute(4, 2, 0, 3, 1)
# x/A/delta/B/C: b c l ... -> b (c l) ...
x, A, delta, B, C = [
tensor.reshape(tensor.shape[0], -1, *tensor.shape[3:])
for tensor in (x, A, delta, B, C)
]

# Ngroup (g to h) mapping
B_val, CL_val, G_val, N_val = B.shape
H_val = x.shape[2]
ngroup_ratio = H_val // G_val
# B/C: (B, CL, H, N)
h_to_g_mapping = torch.arange(H_val, device=B.device) // ngroup_ratio
B = B.gather(2, h_to_g_mapping.view(1, 1, -1, 1).expand(B_val, CL_val, -1, N_val))
C = C.gather(2, h_to_g_mapping.view(1, 1, -1, 1).expand(B_val, CL_val, -1, N_val))

# Type convert input tensors to input dtype (same as intermediate dtype)
x = x.to(intermediate_dtype).to(torch.float32)
A = A.to(intermediate_dtype).to(torch.float32)
delta = delta.to(intermediate_dtype).to(torch.float32)
B = B.to(intermediate_dtype).to(torch.float32)
C = C.to(intermediate_dtype).to(torch.float32)

#########################################################################
# Call reference implementation ssd_minimal_discrete_bf16_intermediates
Y, final_state = ssd_minimal_discrete_lowprecision_intermediates(
x, A, delta, B, C, block_len, intermediate_dtype, initial_states
)
#########################################################################

if has_d:
D = D.to(intermediate_dtype).to(torch.float32)
D_val = Y.shape[3]
if not d_has_hdim:
D = D.expand(D_val, -1)
Y = Y + torch.einsum("bchp,ph->bchp", x, D)

# Type convert output tensors to output dtype (same as intermediate dtype)
Y = Y.to(intermediate_dtype).to(torch.float32)
final_state = final_state.to(intermediate_dtype).to(torch.float32)

# Rearrange to match tridao's layout to cutlass layout
# Y: b (c l) h p -> b c l h p
Y = Y.reshape(Y.shape[0], -1, block_len, Y.shape[2], Y.shape[3])
# Y: b c l h p -> l p c h b
Y = Y.permute(2, 4, 1, 3, 0)
# Fstate_out: b h p n -> p n h b
Fstate_out.copy_(final_state.permute(2, 3, 1, 0))
Y_out.copy_(Y)
return

def analyze_relative_diffs(actual, expected):
"""
Print statistics of relative differences between actual and expected tensors
"""
# Calculate relative differences
abs_diff = (actual - expected).abs()
rel_diff = abs_diff / (torch.maximum(expected.abs(), actual.abs()) + 0.00001)

total_elements = rel_diff.numel()

# Handle special cases first
nan_mask = torch.isnan(rel_diff)
inf_mask = torch.isinf(rel_diff)
nan_count = nan_mask.sum().item()
inf_count = inf_mask.sum().item()

# Find position and value of maximum relative difference
max_rel_diff = (
rel_diff[~nan_mask & ~inf_mask].max()
if (~nan_mask & ~inf_mask).any()
else float("nan")
)
max_rel_diff_pos = (
rel_diff[~nan_mask & ~inf_mask].argmax()
if (~nan_mask & ~inf_mask).any()
else -1
)

# Print max relative difference info
print(f"Maximum relative difference:")
print(f"Position: {max_rel_diff_pos}")
print(f"Value: {max_rel_diff:.6e}")
print(f"Actual value: {actual.flatten()[max_rel_diff_pos]}")
print(f"Expected value: {expected.flatten()[max_rel_diff_pos]}")
print(f"NaN values: {nan_count} ({100.0 * nan_count / total_elements:.2f}%)")
print(f"Inf values: {inf_count} ({100.0 * inf_count / total_elements:.2f}%)\n")

# Check different rtol thresholds
rtol_levels = [1e-5, 1e-4, 1e-3, 1e-2, 5e-02, 1e-01]

for i, rtol in enumerate(rtol_levels):
if i == 0:
mask = rel_diff <= rtol
else:
mask = (rel_diff <= rtol) & (rel_diff > rtol_levels[i - 1])

count = mask.sum().item()
percentage = (count / total_elements) * 100

if i == 0:
print(f"Elements with rtol <= {rtol:.0e}: {count} ({percentage:.2f}%)")
else:
print(
f"Elements with {rtol_levels[i-1]:.0e} < rtol <= {rtol:.0e}: {count} ({percentage:.2f}%)"
)

# Print elements exceeding the largest rtol
mask = rel_diff > rtol_levels[-1]
count = mask.sum().item()
percentage = (count / total_elements) * 100
print(f"Elements with rtol > {rtol_levels[-1]:.0e}: {count} ({percentage:.2f}%)\n")
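A quick numeric illustration of the relative-difference metric and buckets used above (values invented):

actual, expected = torch.tensor([1.00, 0.50]), torch.tensor([1.02, 0.00])
rel = (actual - expected).abs() / (torch.maximum(expected.abs(), actual.abs()) + 0.00001)
# rel[0] ~= 0.0196 -> falls in the 1e-2 < rtol <= 5e-2 bucket
# rel[1] ~= 1.0    -> counted as "rtol > 1e-1"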

def segsum(x):
"""
More stable segment sum calculation.
x: b h c l
"""
T = x.size(-1)
# x: b h c l -> b h c l l
x = x.unsqueeze(-1).expand(*x.shape, T)
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1)
x = x.masked_fill(~mask, 0)
x_segsum = torch.cumsum(x, dim=-2)
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
return x_segsum
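segsum(x)[..., i, j] is the sum x[j+1] + ... + x[i] for i >= j (zero on the diagonal) and -inf above the diagonal; a tiny worked case with invented values:

x = torch.tensor([[[[1.0, 2.0, 3.0]]]])     # b = h = c = 1, l = 3
segsum(x)[0, 0, 0]
# tensor([[0., -inf, -inf],
#         [2., 0., -inf],
#         [5., 3., 0.]])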


def ssd_minimal_discrete_fp32_all(X, A, B, C, block_len, initial_states=None):
"""
This is the same as https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/ssd_minimal.py
(all accumulation and intermediate results in fp32)

Arguments:
X: (batch(B), length(C*L), n_heads(H), d_head(D))
A: (batch(B), length(C*L), n_heads(H))
B: (batch(B), length(C*L), n_heads(H), d_state(N))
C: (batch(B), length(C*L), n_heads(H), d_state(N))
Return:
Y: (batch(B), length(C*L), n_heads(H), d_head(D))
final_state: (B, H, D, N)
"""
assert X.dtype == A.dtype == B.dtype == C.dtype
assert X.shape[1] % block_len == 0

# Rearrange into blocks/chunks
# X/A/B/C:b (c l) ... -> b c l ...
X, A, B, C = [
x.reshape(x.shape[0], -1, block_len, *x.shape[2:]) for x in (X, A, B, C)
]

# A: b c l h -> b h c l
A = A.permute(0, 3, 1, 2)
# A_cumsum: (B, H, C, L)
A_cumsum = torch.cumsum(A, dim=-1)

# 1. Compute the output for each intra-chunk (diagonal blocks)
segsum_A = segsum(A)
L = torch.exp(segsum_A)
Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)

# 2. Compute the state for each intra-chunk
# (right term of low-rank factorization of off-diagonal blocks; B terms)
decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)

# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
# (middle term of factorization of off-diag blocks; A terms)
if initial_states is None:
initial_states = torch.zeros_like(states[:, :1])
states = torch.cat([initial_states, states], dim=1)
decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
states, final_state = new_states[:, :-1], new_states[:, -1]

# 4. Compute state -> output conversion per chunk
# (left term of low-rank factorization of off-diagonal blocks; C terms)
state_decay_out = torch.exp(A_cumsum)
Y_off = torch.einsum("bclhn,bchpn,bhcl->bclhp", C, states, state_decay_out)

# Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
# Y: b c l h p -> b (c l) h p
Y = (Y_diag + Y_off).reshape(Y_diag.shape[0], -1, Y_diag.shape[3], Y_diag.shape[4])
return Y, final_state
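The four chunked steps above reproduce, in parallel over chunks, the plain per-token SSM recurrence. A hedged sketch of that sequential reference (single batch/head, fp32; the function name and shapes are invented for illustration) that the chunked version can be cross-checked against:

def ssd_sequential_reference(X, A, B, C):
    # X: (T, D), A: (T,), B/C: (T, N)
    # h_t = exp(A_t) * h_{t-1} + X_t (outer) B_t ;  y_t = h_t @ C_t
    T, D = X.shape
    N = B.shape[1]
    h = torch.zeros(D, N)
    Y = torch.zeros(T, D)
    for t in range(T):
        h = torch.exp(A[t]) * h + torch.outer(X[t], B[t])
        Y[t] = h @ C[t]
    return Y, h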

def ssd_minimal_discrete_lowprecision_intermediates(
X, A, delta, B, C, block_len, intermediate_dtype, initial_states=None
):
"""
This is adjusted from ssd_minimal_discrete_fp32_all, with exceptions:
1. accumulation in fp32 but intermediates Q/b_tmem/P are in intermediate_dtype
2. delta is not pre-multiplied into X; it is applied when forming Q/b_tmem to match the GPU implementation

Arguments:
X: (batch(B), length(C*L), n_heads(H), d_head(D))
A: (batch(B), length(C*L), n_heads(H))
delta: (batch(B), length(C*L), n_heads(H))
B: (batch(B), length(C*L), n_heads(H), d_state(N))
C: (batch(B), length(C*L), n_heads(H), d_state(N))
Return:
Y: (batch(B), length(C*L), n_heads(H), d_head(D))
final_state: (B, H, D, N)
"""
assert X.dtype == A.dtype == B.dtype == C.dtype
assert X.shape[1] % block_len == 0

# Rearrange into blocks/chunks
# X/A/delta/B/C: b (c l) ... -> b c l ...
X, A, delta, B, C = [
x.reshape(x.shape[0], -1, block_len, *x.shape[2:]) for x in (X, A, delta, B, C)
]

# A: b c l h -> b h c l
A = A.permute(0, 3, 1, 2)
# delta: b c l h -> b h c l
delta = delta.permute(0, 3, 1, 2)
# A_cumsum: (B, H, C, L)
A_cumsum = torch.cumsum(A, dim=-1)

# 1. Compute the output for each intra-chunk (diagonal blocks)
segsum_A = segsum(A)
L = torch.exp(segsum_A)
intra_acc_0 = torch.einsum("bclhn,bcshn->bclhs", C, B)
Q = torch.einsum("bclhs,bhcls,bhcs->bclhs", intra_acc_0, L, delta)
Y_diag = torch.einsum(
"bclhs,bcshp->bclhp", Q.to(intermediate_dtype).to(torch.float32), X
)

# 2. Compute the state for each intra-chunk
# (right term of low-rank factorization of off-diagonal blocks; B terms)
decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
b_tmem = torch.einsum("bclhn,bhcl,bhcl->bclhn", B, decay_states, delta)
states = torch.einsum(
"bclhn,bclhp->bchpn", b_tmem.to(intermediate_dtype).to(torch.float32), X
)

# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
# (middle term of factorization of off-diag blocks; A terms)
if initial_states is None:
initial_states = torch.zeros_like(states[:, :1])
states = torch.cat([initial_states, states], dim=1)
decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
states, final_state = new_states[:, :-1], new_states[:, -1]
final_state = final_state

# 4. Compute state -> output conversion per chunk
# (left term of low-rank factorization of off-diagonal blocks; C terms)
state_decay_out = torch.exp(A_cumsum)
Y_off_tmp = torch.einsum(
"bclhn,bchpn->bclhp", C, states.to(intermediate_dtype).to(torch.float32)
)
Y_off = torch.einsum("bclhp,bhcl->bclhp", Y_off_tmp, state_decay_out)

# Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
# Y: b c l h p -> b (c l) h p
Y = (Y_diag + Y_off).reshape(
Y_diag.shape[0], -1, Y_diag.shape[3], Y_diag.shape[4]
) # b (c l) h p
return Y, final_state
@ -0,0 +1,200 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:

# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.

# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.

# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from typing import Tuple

from cutlass.cutlass_dsl import (
Boolean,
Integer,
Int32,
min,
extract_mlir_values,
new_from_mlir_values,
dsl_user_op,
)
from cutlass._mlir import ir
import cutlass.cute as cute
from cutlass.utils import WorkTileInfo


class Mamba2SSDTileSchedulerParams:
def __init__(
self,
problem_shape_ntiles: int,
eh: int,
ngroup_ratio: int,
*,
loc=None,
ip=None,
):
self.problem_shape_ntiles = problem_shape_ntiles
self.eh = eh
self.ngroup_ratio = ngroup_ratio
self._loc = loc

def __extract_mlir_values__(self):
values, self._values_pos = [], []
for obj in [self.problem_shape_ntiles, self.eh, self.ngroup_ratio]:
obj_values = extract_mlir_values(obj)
values += obj_values
self._values_pos.append(len(obj_values))
return values

def __new_from_mlir_values__(self, values):
obj_list = []
for obj, n_items in zip(
[self.problem_shape_ntiles, self.eh, self.ngroup_ratio], self._values_pos
):
obj_list.append(new_from_mlir_values(obj, values[:n_items]))
values = values[n_items:]
return Mamba2SSDTileSchedulerParams(*(tuple(obj_list)), loc=self._loc)

@dsl_user_op
def get_grid_shape(
self, max_active_clusters: Int32, *, loc=None, ip=None
) -> Tuple[Integer, Integer, Integer]:
return (min(self.problem_shape_ntiles, max_active_clusters), 1, 1)

class Mamba2SSDTileScheduler:
def __init__(
self,
params: Mamba2SSDTileSchedulerParams,
num_persistent_ctas: Int32,
current_work_linear_idx: Int32,
num_tiles_executed: Int32,
):
self.params = params
self.num_persistent_ctas = num_persistent_ctas
self._current_work_linear_idx = current_work_linear_idx
self._num_tiles_executed = num_tiles_executed

def __extract_mlir_values__(self) -> list[ir.Value]:
values = extract_mlir_values(self.num_persistent_ctas)
values.extend(extract_mlir_values(self._current_work_linear_idx))
values.extend(extract_mlir_values(self._num_tiles_executed))
return values

def __new_from_mlir_values__(
self, values: list[ir.Value]
) -> "Mamba2SSDTileScheduler":
assert len(values) == 3
new_num_persistent_ctas = new_from_mlir_values(
self.num_persistent_ctas, [values[0]]
)
new_current_work_linear_idx = new_from_mlir_values(
self._current_work_linear_idx, [values[1]]
)
new_num_tiles_executed = new_from_mlir_values(
self._num_tiles_executed, [values[2]]
)
return Mamba2SSDTileScheduler(
self.params,
new_num_persistent_ctas,
new_current_work_linear_idx,
new_num_tiles_executed,
)

# called by host
@dsl_user_op
@staticmethod
def create(
params: Mamba2SSDTileSchedulerParams,
block_idx: Tuple[Integer, Integer, Integer],
grid_dim: Tuple[Integer, Integer, Integer],
*,
loc=None,
ip=None,
):
params = params

# Calculate the number of persistent clusters by dividing the total grid size
# by the number of CTAs per cluster
num_persistent_ctas = Int32(cute.size(grid_dim, loc=loc, ip=ip))

bidx, bidy, bidz = block_idx

# Initialize the workload index to the cluster index in the grid
current_work_linear_idx = Int32(bidx)

# Initialize number of tiles executed to zero
num_tiles_executed = Int32(0)
return Mamba2SSDTileScheduler(
params,
num_persistent_ctas,
current_work_linear_idx,
num_tiles_executed,
)

# called by host
@staticmethod
def get_grid_shape(
params: Mamba2SSDTileSchedulerParams,
max_active_clusters: Int32,
*,
loc=None,
ip=None,
) -> Tuple[Integer, Integer, Integer]:
return params.get_grid_shape(max_active_clusters, loc=loc, ip=ip)

# private method
def _get_current_work_for_linear_idx(
self, current_work_linear_idx: Int32, *, loc=None, ip=None
) -> WorkTileInfo:
is_valid = current_work_linear_idx < cute.size(
self.params.problem_shape_ntiles, loc=loc, ip=ip
)

eh_idx = current_work_linear_idx % self.params.eh
b_idx = current_work_linear_idx // self.params.eh
g_idx = eh_idx // self.params.ngroup_ratio
# cur_tile_coord is (b_idx, eh_idx, g_idx)
cur_tile_coord = tuple(Int32(x) for x in (b_idx, eh_idx, g_idx))

return WorkTileInfo(cur_tile_coord, is_valid)
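The linear-index decomposition above maps a 1-D persistent work index onto (batch, expanded-head, group) coordinates; a tiny worked example with hypothetical parameters:

# Hypothetical: eh = 8 expanded heads, ngroup_ratio = 4 heads per group
eh, ngroup_ratio = 8, 4
for idx in (0, 5, 13):
    eh_idx = idx % eh                 # 0, 5, 5
    b_idx = idx // eh                 # 0, 0, 1
    g_idx = eh_idx // ngroup_ratio    # 0, 1, 1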

@dsl_user_op
def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
return self._get_current_work_for_linear_idx(
self._current_work_linear_idx, loc=loc, ip=ip
)

@dsl_user_op
def initial_work_tile_info(self, *, loc=None, ip=None) -> WorkTileInfo:
return self.get_current_work(loc=loc, ip=ip)

@dsl_user_op
def advance_to_next_work(self, *, advance_count: int = 1, loc=None, ip=None):
self._current_work_linear_idx += Int32(advance_count) * Int32(
self.num_persistent_ctas
)
self._num_tiles_executed += Int32(1)

@property
def num_tiles_executed(self) -> Int32:
return self._num_tiles_executed
@ -30,7 +30,7 @@ cmake_minimum_required(VERSION 3.15)
project(tensor)

# Find Python
find_package(Python COMPONENTS Interpreter Development REQUIRED)
find_package(Python3 COMPONENTS Interpreter Development REQUIRED)

# Get Python site-packages directory using Python
execute_process(
@ -36,6 +36,7 @@ import torch
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
import cutlass.pipeline as pipeline
import cutlass.torch as cutlass_torch
from cutlass.cute.runtime import from_dlpack
import cutlass.utils.hopper_helpers as sm90_utils
@ -579,20 +580,25 @@ class HopperWgmmaGemmKernel:
mainloop_pipeline_array_ptr = storage.mainloop_pipeline_array_ptr.data_ptr()

# Threads/warps participating in this pipeline
mainloop_pipeline_producer_group = utils.CooperativeGroup(utils.Agent.Thread)
# Set the consumer arrive count to the number of mcast size
consumer_arrive_cnt = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
mainloop_pipeline_consumer_group = utils.CooperativeGroup(
utils.Agent.Thread, consumer_arrive_cnt
mainloop_pipeline_producer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread
)
# Each warp will contribute to the arrive count with the mcast size
mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
num_warps = self.threads_per_cta // 32
consumer_arrive_cnt = mcast_size * num_warps
mainloop_pipeline_consumer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread, consumer_arrive_cnt
)

mainloop_pipeline = utils.PipelineTmaAsync.create(
cta_layout_vmnk = cute.make_layout((1, *cta_layout_mnk.shape))
mainloop_pipeline = pipeline.PipelineTmaAsync.create(
barrier_storage=mainloop_pipeline_array_ptr,
num_stages=self.ab_stage,
producer_group=mainloop_pipeline_producer_group,
consumer_group=mainloop_pipeline_consumer_group,
tx_count=tma_copy_bytes,
cta_layout_vmnk=cta_layout_mnk,
cta_layout_vmnk=cta_layout_vmnk,
)

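Concretely, with the consumer arrive count now scaled by the warp count (the numbers below are hypothetical, not taken from the kernel):

# Hypothetical 2x1 cluster, 128 threads per CTA
num_mcast_ctas_a, num_mcast_ctas_b = 1, 2
mcast_size = num_mcast_ctas_a + num_mcast_ctas_b - 1    # 2
num_warps = 128 // 32                                   # 4
consumer_arrive_cnt = mcast_size * num_warps            # 8 arrivals release one A/B buffer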
# Cluster arrive after barrier init
@ -616,11 +622,11 @@ class HopperWgmmaGemmKernel:
# ///////////////////////////////////////////////////////////////////////////////
# Local_tile partition global tensors
# ///////////////////////////////////////////////////////////////////////////////
# (bM, bK, loopK)
# (bM, bK, RestK)
gA_mkl = cute.local_tile(
mA_mkl, self.tile_shape_mnk, tile_coord_mnkl, proj=(1, None, 1)
)
# (bN, bK, loopK)
# (bN, bK, RestK)
gB_nkl = cute.local_tile(
mB_nkl, self.tile_shape_mnk, tile_coord_mnkl, proj=(None, 1, 1)
)
@ -696,14 +702,14 @@ class HopperWgmmaGemmKernel:
k_tile_cnt = cute.size(gA_mkl, mode=[2])
prefetch_k_tile_cnt = cutlass.max(cutlass.min(self.ab_stage, k_tile_cnt), 0)

mainloop_producer_state = utils.make_pipeline_state(
utils.PipelineUserType.Producer, self.ab_stage
mainloop_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.ab_stage
)
if warp_idx == 0:
# /////////////////////////////////////////////////////////////////////////////
# Prefetch TMA load
# /////////////////////////////////////////////////////////////////////////////
for prefetch_idx in cutlass.range_dynamic(prefetch_k_tile_cnt, unroll=1):
for prefetch_idx in cutlass.range(prefetch_k_tile_cnt, unroll=1):
# /////////////////////////////////////////////////////////////////////////////
# Wait for A/B buffers to be empty before loading into them
# Also sets the transaction barrier for the A/B buffers
@ -748,11 +754,11 @@ class HopperWgmmaGemmKernel:
# /////////////////////////////////////////////////////////////////////////////
k_pipe_mmas = 1

mainloop_consumer_read_state = utils.make_pipeline_state(
utils.PipelineUserType.Consumer, self.ab_stage
mainloop_consumer_read_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.ab_stage
)
mainloop_consumer_release_state = utils.make_pipeline_state(
utils.PipelineUserType.Consumer, self.ab_stage
mainloop_consumer_release_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.ab_stage
)

peek_ab_full_status = cutlass.Boolean(1)
@ -763,14 +769,14 @@ class HopperWgmmaGemmKernel:
|
||||
|
||||
tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, False)
|
||||
num_k_blocks = cute.size(tCrA, mode=[2])
|
||||
for k_tile in cutlass.range_dynamic(k_pipe_mmas, unroll=1):
|
||||
for k_tile in range(k_pipe_mmas):
|
||||
# Wait for A/B buffer to be ready
|
||||
mainloop_pipeline.consumer_wait(
|
||||
mainloop_consumer_read_state, peek_ab_full_status
|
||||
)
|
||||
|
||||
cute.nvgpu.warpgroup.fence()
|
||||
for k_block_idx in range(num_k_blocks):
|
||||
for k_block_idx in cutlass.range(num_k_blocks, unroll_full=True):
|
||||
k_block_coord = (
|
||||
None,
|
||||
None,
|
||||
@ -800,7 +806,7 @@ class HopperWgmmaGemmKernel:
|
||||
# /////////////////////////////////////////////////////////////////////////////
|
||||
# MAINLOOP
|
||||
# /////////////////////////////////////////////////////////////////////////////
|
||||
for k_tile in cutlass.range_dynamic(k_pipe_mmas, k_tile_cnt, 1, unroll=1):
|
||||
for k_tile in cutlass.range(k_pipe_mmas, k_tile_cnt, 1, unroll=1):
|
||||
# /////////////////////////////////////////////////////////////////////////////
|
||||
# Wait for TMA copies to complete
|
||||
# /////////////////////////////////////////////////////////////////////////////
|
||||
@@ -811,7 +817,7 @@ class HopperWgmmaGemmKernel:
 # WGMMA
 # /////////////////////////////////////////////////////////////////////////////
 cute.nvgpu.warpgroup.fence()
-for k_block_idx in range(num_k_blocks):
+for k_block_idx in cutlass.range(num_k_blocks, unroll_full=True):
     k_block_coord = (
         None,
         None,
@@ -949,7 +955,7 @@ class HopperWgmmaGemmKernel:
 epi_tile_num = cute.size(tcgc_for_tma_partition, mode=[1])
 epi_tile_shape = tcgc_for_tma_partition.shape[1]

-for epi_idx in cutlass.range_dynamic(epi_tile_num, unroll=epi_tile_num):
+for epi_idx in cutlass.range(epi_tile_num, unroll=epi_tile_num):
     # Copy from accumulators to D registers
     for epi_v in range(size_tRS_rD):
         tRS_rD[epi_v] = tRS_rAcc[epi_idx * size_tRS_rD + epi_v]
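The epilogue loop above copies `size_tRS_rD` register fragments per epilogue tile, and the flat accumulator index is simply `epi_idx * size_tRS_rD + epi_v`. A plain-Python illustration of that mapping; the sizes are made up for the example:

# made-up sizes purely to show the index arithmetic used in the epilogue copy
size_tRS_rD = 4   # D-register fragments per epilogue tile
epi_tile_num = 2  # number of epilogue tiles

for epi_idx in range(epi_tile_num):
    for epi_v in range(size_tRS_rD):
        flat = epi_idx * size_tRS_rD + epi_v
        print(f"epi tile {epi_idx}, fragment {epi_v} -> accumulator element {flat}")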
@@ -1213,7 +1219,7 @@ class HopperWgmmaGemmKernel:
 c_cta_v_layout = cute.composition(
     cute.make_identity_layout(tensor_c.shape), epi_tile
 )
-tma_atom_c, tma_tensor_c = cute.nvgpu.cpasync.make_tma_tile_atom(
+tma_atom_c, tma_tensor_c = cute.nvgpu.cpasync.make_tiled_tma_atom(
     cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp(),
     tensor_c,
     epi_smem_layout,
@@ -1250,7 +1256,7 @@ class HopperWgmmaGemmKernel:
 )

 smem_layout = cute.slice_(smem_layout_staged, (None, None, 0))
-tma_atom, tma_tensor = cute.nvgpu.cpasync.make_tma_tile_atom(
+tma_atom, tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom(
     op,
     tensor,
     smem_layout,
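Both TMA hunks are a pure rename: `cute.nvgpu.cpasync.make_tma_tile_atom` becomes `make_tiled_tma_atom`, and the argument list (copy op, global tensor, per-stage smem layout, and so on) is left untouched in this diff. A small, hypothetical compatibility sketch for code that must run against both spellings; the alias name is made up:

import cutlass.cute as cute

# Prefer the 4.1 name and fall back to the pre-4.1 one; this shim is a
# hypothetical convenience for mixed-version code, not part of the library.
_make_tma_atom = getattr(
    cute.nvgpu.cpasync,
    "make_tiled_tma_atom",
    getattr(cute.nvgpu.cpasync, "make_tma_tile_atom", None),
)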
@@ -297,7 +297,7 @@
 " assert depth <= 1, f\"Depth of coalesced layout should be <= 1, got {depth}\"\n",
 "\n",
 " print(\">>> 3. Checking layout functionality remains the same after the coalesce operation:\")\n",
-" for i in range(original_size):\n",
+" for i in cutlass.range_constexpr(original_size):\n",
 " original_value = layout(i)\n",
 " coalesced_value = result(i)\n",
 " print(f\"Index {i}: original {original_value}, coalesced {coalesced_value}\")\n",
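In the notebook hunk above, a plain Python `range` over layout indices becomes `cutlass.range_constexpr`, which makes the compile-time unrolling explicit while leaving trace-time `print` calls intact. A hedged sketch of the pattern; the 2x3 layout and the function name are made up for the example:

import cutlass
import cutlass.cute as cute


@cute.jit
def check_coalesce():
    layout = cute.make_layout((2, 3))  # made-up layout for the example
    result = cute.coalesce(layout)
    # compile-time loop: each iteration is traced with a concrete Python int `i`
    for i in cutlass.range_constexpr(cute.size(layout)):
        print(f"Index {i}: original {layout(i)}, coalesced {result(i)}")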
@@ -60,48 +60,7 @@
 "@cute.jit\n",
 "def foo(a: cutlass.Int32): # annotate `a` as 32-bit integer passed to jit function via ABI\n",
 " ...\n",
-"```\n",
-"To differentiate between compile-time and runtime values, CuTe DSL introduces primitive types that \n",
-"represent dynamic values in JIT-compiled code.\n",
-"\n",
-"CuTe DSL provides a comprehensive set of primitive numeric types for representing dynamic values at \n",
-"runtime. These types are formally defined within the CuTe DSL typing system:\n",
-"\n",
-"### Integer Types\n",
-"- `Int8` - 8-bit signed integer\n",
-"- `Int16` - 16-bit signed integer\n",
-"- `Int32` - 32-bit signed integer\n",
-"- `Int64` - 64-bit signed integer\n",
-"- `Int128` - 128-bit signed integer\n",
-"- `Uint8` - 8-bit unsigned integer\n",
-"- `Uint16` - 16-bit unsigned integer\n",
-"- `Uint32` - 32-bit unsigned integer\n",
-"- `Uint64` - 64-bit unsigned integer\n",
-"- `Uint128` - 128-bit unsigned integer\n",
-"\n",
-"### Floating Point Types\n",
-"- `Float16` - 16-bit floating point\n",
-"- `Float32` - 32-bit floating point\n",
-"- `Float64` - 64-bit floating point\n",
-"- `BFloat16` - Brain Floating Point format (16-bit)\n",
-"- `TFloat32` - Tensor Float32 format (reduced precision format used in tensor operations)\n",
-"- `Float8E4M3` - 8-bit floating point with 4-bit exponent and 3-bit mantissa\n",
-"- `Float8E5M2` - 8-bit floating point with 5-bit exponent and 2-bit mantissa\n",
-"\n",
-"These specialized types are designed to represent dynamic values in CuTe DSL code that will be \n",
-"evaluated at runtime, in contrast to Python's built-in numeric types which are evaluated during \n",
-"compilation.\n",
-"\n",
-"### Example usage:\n",
-"\n",
-"```python\n",
-"x = cutlass.Int32(5) # Creates a 32-bit integer\n",
-"y = cutlass.Float32(3.14) # Creates a 32-bit float\n",
-"\n",
-"@cute.jit\n",
-"def foo(a: cutlass.Int32): # annotate `a` as 32-bit integer passed to jit function via ABI\n",
-" ...\n",
-"```"
+"```\n"
 ]
},
{
@@ -120,7 +120,7 @@
 " src_vec = src.load()\n",
 " dst_vec = src_vec[indices]\n",
 " print(f\"{src_vec} -> {dst_vec}\")\n",
-" if isinstance(dst_vec, cute.TensorSSA):\n",
+" if cutlass.const_expr(isinstance(dst_vec, cute.TensorSSA)):\n",
 " dst.store(dst_vec)\n",
 " cute.print_tensor(dst)\n",
 " else:\n",
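The last hunk wraps a Python-level `isinstance` check in `cutlass.const_expr`, so the branch is resolved while tracing rather than treated as dynamic control flow. A hedged sketch of the idiom; the function name, tensor arguments, and `indices` parameter are placeholders, not the notebook's exact code:

import cutlass
import cutlass.cute as cute


@cute.jit
def store_result(dst: cute.Tensor, src: cute.Tensor, indices: cutlass.Constexpr):
    src_vec = src.load()        # load the tile into a TensorSSA value
    dst_vec = src_vec[indices]  # slicing may yield a TensorSSA or a scalar
    # const_expr marks the predicate as compile-time: the untaken branch is never traced
    if cutlass.const_expr(isinstance(dst_vec, cute.TensorSSA)):
        dst.store(dst_vec)
        cute.print_tensor(dst)
    else:
        print(f"scalar result: {dst_vec}")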