v4.2 tag release. (#2638)

This commit is contained in:
Junkai-Wu
2025-09-16 00:21:53 +08:00
committed by GitHub
parent 56f0718a97
commit 6a35b4d22f
161 changed files with 14056 additions and 3793 deletions

View File

@ -0,0 +1,314 @@
import os
import torch
import argparse
from cuda import cuda
from cuda.bindings import driver
from typing import Type
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
import torch.multiprocessing as mp
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
from cutlass._mlir.dialects import llvm, builtin, vector, arith
WORLD_SIZE = 8
PING_PONG_SIZE = 3
def setup(rank, world_size):
    """Join the torch.distributed NCCL process group and bind this process
    to its own GPU.

    :param rank: this process's rank (also its CUDA device index)
    :param world_size: total number of participating processes
    """
    # torch.distributed reads the rendezvous endpoint from the environment.
    os.environ.update({"MASTER_ADDR": "localhost", "MASTER_PORT": "12959"})
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
def cleanup():
    """Tear down the torch.distributed process group created by setup()."""
    dist.destroy_process_group()
class AllReduceKernel:
    """One-shot all-reduce across WORLD_SIZE GPUs over symmetric-memory buffers.

    Each rank broadcasts its input tile into its slot of every peer's buffer,
    then spins on its own buffer until all WORLD_SIZE slots stop holding the
    -0.0 bit-pattern sentinel (0x80000000), accumulates the contributions,
    and stores the sum to its local output.  Buffers cycle through
    PING_PONG_SIZE slots selected by ``signal``; the next ("pong") slot is
    re-armed with the sentinel on every call.
    """

    @cute.jit
    def __call__(
        self,
        rank,
        signal,
        local_input: cute.Tensor,
        local_output: cute.Tensor,
        buffer0: cute.Tensor,
        buffer1: cute.Tensor,
        buffer2: cute.Tensor,
        buffer3: cute.Tensor,
        buffer4: cute.Tensor,
        buffer5: cute.Tensor,
        buffer6: cute.Tensor,
        buffer7: cute.Tensor,
        stream: cuda.CUstream,
    ):
        """Tile all tensors and launch :meth:`kernel` on ``stream``.

        ``buffer0..buffer7`` are the WORLD_SIZE peers' symmetric-memory
        buffers, each laid out as (elements, src_rank, pingpong_slot).
        """
        # define constants for future use
        num_of_elements = cute.size(local_input.layout)
        # 128 threads per block and 4 elements per thread -> 512-element tile.
        # NOTE(review): both modes use stride 1 here — confirm the intended
        # thread/value interleaving.
        tv_layout = cute.make_layout(((128), (4)), stride=((1), (1)))
        tile = cute.size(tv_layout.shape)
        buffers = [
            buffer0,
            buffer1,
            buffer2,
            buffer3,
            buffer4,
            buffer5,
            buffer6,
            buffer7,
        ]
        # Split only the element mode of each peer buffer into per-CTA tiles.
        tiled_buffers = [
            cute.logical_divide(buffer, (tile, None, None)) for buffer in buffers
        ]
        tiled_input = cute.zipped_divide(local_input, cute.make_layout(tile))
        tiled_output = cute.zipped_divide(local_output, cute.make_layout(tile))
        self.kernel(
            tiled_buffers[0],
            tiled_buffers[1],
            tiled_buffers[2],
            tiled_buffers[3],
            tiled_buffers[4],
            tiled_buffers[5],
            tiled_buffers[6],
            tiled_buffers[7],
            tiled_input,
            tiled_output,
            tv_layout,
            cutlass.Int32(signal),
            cutlass.Int32(rank),
        ).launch(
            grid=[num_of_elements // tile, 1, 1],
            block=[tv_layout.shape[0], 1, 1],
            stream=stream,
        )

    # GPU device kernel
    @cute.kernel
    def kernel(
        self,
        buffer0: cute.Tensor,
        buffer1: cute.Tensor,
        buffer2: cute.Tensor,
        buffer3: cute.Tensor,
        buffer4: cute.Tensor,
        buffer5: cute.Tensor,
        buffer6: cute.Tensor,
        buffer7: cute.Tensor,
        local_input: cute.Tensor,
        local_output: cute.Tensor,
        tv_layout: cute.Layout,
        signal: cutlass.Int32,
        rank: cutlass.Int32,
    ):
        """Device body: scatter own tile to all peers, wait for all peers'
        tiles, accumulate, and store the reduced tile."""
        tidx, _, _ = cute.arch.thread_idx()
        bidx, _, _ = cute.arch.block_idx()
        # `ping` is the active slot; `pong` is the next slot, re-armed with
        # the sentinel so a subsequent call can detect fresh data.
        ping = signal % 3
        pong = (signal + 1) % 3
        buffers = [
            buffer0,
            buffer1,
            buffer2,
            buffer3,
            buffer4,
            buffer5,
            buffer6,
            buffer7,
        ]

        def get_buffer():
            # Select this rank's own buffer.  Chained plain `if`s keep the
            # DSL trace simple; the default is never used for rank in [0, 7].
            t = buffers[2]
            if rank == cutlass.Int32(0):
                t = buffers[0]
            if rank == cutlass.Int32(1):
                t = buffers[1]
            if rank == cutlass.Int32(2):
                t = buffers[2]
            if rank == cutlass.Int32(3):
                t = buffers[3]
            if rank == cutlass.Int32(4):
                t = buffers[4]
            if rank == cutlass.Int32(5):
                t = buffers[5]
            if rank == cutlass.Int32(6):
                t = buffers[6]
            if rank == cutlass.Int32(7):
                t = buffers[7]
            return t

        buffer_local = get_buffer()
        cta_coord = (None, bidx)
        local_tile_in = local_input[cta_coord]
        local_tile_out = local_output[cta_coord]
        ping_coord = ((None, bidx), None, ping)
        read_buffer = buffer_local[ping_coord]
        pong_coord = ((None, bidx), None, pong)
        clear_buffer = buffer_local[pong_coord]
        write_coord = ((None, bidx), rank, ping)
        write_buffers = [buffer[write_coord] for buffer in buffers]
        # assume all buffers have the same element type with input;
        # VOLATILE + SYS scope so loads/stores are visible across GPUs.
        copy_atom_load = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            buffer0.element_type,
            num_bits_per_copy=64,
            memory_scope=cute.nvgpu.common.MemoryScope.SYS,
            memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE,
        )
        copy_atom_store = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            buffer0.element_type,
            num_bits_per_copy=64,
            memory_scope=cute.nvgpu.common.MemoryScope.SYS,
            memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE,
        )
        tiled_copy = cute.make_tiled_copy_tv(copy_atom_load, tv_layout[0], tv_layout[1])
        thr_copy = tiled_copy.get_slice(tidx)
        thr_write_buffer_list = [
            thr_copy.partition_D(tensor) for tensor in write_buffers
        ]
        thr_read_buffer = thr_copy.partition_S(read_buffer)
        thr_clear_buffer = thr_copy.partition_D(clear_buffer)
        thr_in = thr_copy.partition_S(local_tile_in)
        thr_out = thr_copy.partition_D(local_tile_out)
        frg_in = cute.make_fragment_like(thr_in)
        frg_clear = cute.make_fragment_like(thr_clear_buffer)
        frg_acc = cute.make_fragment_like(thr_out)
        frg_acc.fill(0.0)
        # Build a fragment full of the sentinel: 0x80000000 reinterpreted as
        # f32 is -0.0.  NOTE(review): `T` (MLIR type helpers) is not among
        # this file's visible imports — confirm it is in scope.
        clear_tensor = frg_clear.load()
        frg_size = cute.size(clear_tensor.shape)
        neg0_i32_vec = cute.full_like(clear_tensor, 0x80000000, cutlass.Int32)
        neg0_f32_vec = vector.bitcast(T.vector(frg_size, T.f32()), neg0_i32_vec)
        neg0_f32_tensor = cute.TensorSSA(
            neg0_f32_vec, clear_tensor.shape, cutlass.Float32
        )
        frg_clear.store(neg0_f32_tensor)
        # Scatter the local tile into every peer's slot for this rank, then
        # re-arm the pong slot with the sentinel.
        cute.copy(copy_atom_load, thr_in, frg_in)
        for thr_write_buffer in thr_write_buffer_list:
            cute.copy(copy_atom_store, frg_in, thr_write_buffer)
        cute.copy(copy_atom_store, frg_clear, thr_clear_buffer)
        frg_in_vector_neg0_i32 = cute.full_like(
            frg_in, cutlass.Int32(0x80000000), cutlass.Int32
        )
        frg_in_size = cute.size(frg_in.shape)
        for i in range(WORLD_SIZE):
            read_coord = (None, 0, i)
            cute.copy(copy_atom_load, thr_read_buffer[read_coord], frg_in[None, 0])
            frg_vector = frg_in.load()
            frg_vector_i32 = vector.bitcast(T.vector(frg_in_size, T.i32()), frg_vector)
            isNotNeg0 = cute.all_(frg_vector_i32 != frg_in_vector_neg0_i32)
            # Spin until rank i's contribution lands (slot leaves the
            # sentinel state).
            while not isNotNeg0:
                cute.copy(copy_atom_load, thr_read_buffer[read_coord], frg_in[None, 0])
                frg_vector = frg_in.load()
                frg_vector_i32 = vector.bitcast(
                    T.vector(frg_in_size, T.i32()), frg_vector
                )
                isNotNeg0 = cute.all_(frg_vector_i32 != frg_in_vector_neg0_i32)
            frg_acc.store(frg_in.load() + frg_acc.load())
        # BUG FIX: the original called cute.copy with an undefined name
        # `copy_atom_stg`; the store atom defined above is `copy_atom_store`.
        cute.copy(copy_atom_store, frg_acc, thr_out)
def run_all_reduce(rank, M, N, dtype: Type[cutlass.Numeric]):
    """Worker entry point: all-reduce an M*N tensor across WORLD_SIZE GPUs.

    :param rank: this process's rank / CUDA device index
    :param M: first logical dimension of the problem
    :param N: second logical dimension of the problem
    :param dtype: element type tag (currently only used by the caller)
    """
    setup(rank, WORLD_SIZE)
    input_tensor = torch.randn(M * N, device=f"cuda:{rank}")
    output_tensor = torch.zeros(M * N, device=f"cuda:{rank}")
    # Symmetric-memory workspace: (pingpong_slot, src_rank, element).
    # .neg_() arms the buffer with -0.0, the sentinel the kernel spins on
    # (assumes symm_mem.empty zero-fills — TODO confirm).
    t = symm_mem.empty(
        [
            PING_PONG_SIZE,
            WORLD_SIZE,
            M * N,
        ],
        device="cuda",
    ).neg_()
    hdl = symm_mem.rendezvous(t, dist.group.WORLD)
    # Peer buffers permuted to (element, src_rank, pingpong_slot), as the
    # kernel expects.  Use `peer`, not `rank`, to avoid shadowing the arg.
    buffer_tensor_list = [
        hdl.get_buffer(peer, t.shape, t.dtype).permute(2, 1, 0)
        for peer in range(WORLD_SIZE)
    ]
    # enable peer access between every pair of devices
    driver.cuInit(0)
    dev_list = [driver.cuDeviceGet(i)[1] for i in range(WORLD_SIZE)]
    ctx_list = [driver.cuDevicePrimaryCtxRetain(dev)[1] for dev in dev_list]
    for i in range(WORLD_SIZE):
        driver.cuCtxSetCurrent(ctx_list[i])
        for j in range(WORLD_SIZE):
            if i == j:
                continue
            driver.cuCtxEnablePeerAccess(ctx_list[j], 0)
    driver.cuCtxSetCurrent(ctx_list[rank])
    stream = cutlass.cuda.default_stream()
    all_reduce_kernel = AllReduceKernel()
    dlpack_buffers = [from_dlpack(x, assumed_align=32) for x in buffer_tensor_list]
    all_reduce_kernel(
        rank,
        0,  # first call: signal 0 selects ping slot 0 / pong slot 1
        from_dlpack(input_tensor, assumed_align=32),
        from_dlpack(output_tensor, assumed_align=32),
        *dlpack_buffers,
        stream,
    )
    # BUG FIX: synchronize this rank's current device (set in setup), not
    # device 0 as the original `torch.cuda.synchronize(0)` did.
    torch.cuda.synchronize()
    # use torch api to get reference and inplace stored to input_tensor
    ref_tensor = input_tensor.clone()
    dist.all_reduce(ref_tensor, op=dist.ReduceOp.SUM)
    # check result of output tensor, allow small error due to different accumulator datatypes
    equal_mask = (ref_tensor.cpu() - output_tensor.cpu()).abs() < 1e-4
    result = (equal_mask.sum()).item() == ref_tensor.numel()
    if result:
        print(f"rank {rank} test passed")
    else:
        print(f"rank {rank} test failed")
        print(
            "ref_tensor[ref_tensor != output_tensor]: ",
            ref_tensor[ref_tensor != output_tensor],
        )
        print(
            "output_tensor[ref_tensor != output_tensor]: ",
            output_tensor[ref_tensor != output_tensor],
        )
    cleanup()
def main():
    """Spawn one all-reduce worker process per GPU."""
    shape = (1024, 1024)
    # each process will run run_all_reduce on a different device
    mp.spawn(
        run_all_reduce,
        args=(*shape, cutlass.Float32),
        nprocs=WORLD_SIZE,
        join=True,
    )
if __name__ == "__main__":
main()

View File

@ -0,0 +1,166 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import sys
import os
from typing import Tuple
import torch
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import make_ptr
"""
An Example demonstrating how to call off-the-shelf kernel by-passing dlpack protocol
The example shows how to directly pass pointers from PyTorch tensors to off-the-shelf kernels
written by CuTe DSL with a thin customized wrapper jit function. The jit function will be
compiled with inline without introducing overhead.
To run this example:
.. code-block:: bash
python examples/ampere/call_bypass_dlpack.py
It's worth mentioning that bypassing the dlpack protocol can resolve the issue that dlpack doesn't handle a shape-1
mode correctly. For example, the following code will fail, because dlpack will convert the shape-1 mode
with stride-1 which propagate alignment incorrectly.
.. code-block:: python
@cute.kernel
def fails_kernel(gX: cute.Tensor):
bidx, _, _ = cute.arch.block_idx()
mX = gX[None, bidx, None] # We wish to retain alignment
# assert mX.iterator.alignment == 16
@cute.jit
def fails(gX_: cute.Tensor):
gX = gX_
fails_kernel(gX).launch(grid=(1, 1, 1), block=(128, 1, 1))
gX_torch = torch.rand((128, 1, 128), device="cuda", dtype=torch.bfloat16)
fails(from_dlpack(gX_torch, assumed_align=16))
"""
# Add the current directory to sys.path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from tensorop_gemm import TensorOpGemm
@cute.jit
def tensor_op_gemm_wrapper(
    a_ptr: cute.Pointer,
    b_ptr: cute.Pointer,
    c_ptr: cute.Pointer,
    m: cutlass.Int32,
    n: cutlass.Int32,
    k: cutlass.Int32,
    l: cutlass.Int32,
):
    """Thin JIT wrapper: rebuild CuTe tensors from raw device pointers and
    dispatch to the off-the-shelf TensorOpGemm kernel.

    Bypasses the dlpack protocol entirely — the caller supplies device
    pointers plus the (m, n, k, l) problem shape, and this function
    reconstructs the layouts before invoking the kernel.
    """
    # Idiom fix (F541): literals without placeholders are plain strings.
    print("\n[DSL INFO] Input Parameters:")
    print(f"[DSL INFO] mnkl: {(m, n, k, l)}")
    # Assume alignment of shape to call tensorop_gemm example
    m = cute.assume(m, divby=8)
    n = cute.assume(n, divby=8)
    # Torch is row major
    a_layout = cute.make_ordered_layout((m, k, l), order=(0, 1, 2))
    b_layout = cute.make_ordered_layout((n, k, l), order=(0, 1, 2))
    c_layout = cute.make_ordered_layout((m, n, l), order=(1, 0, 2))
    mA = cute.make_tensor(a_ptr, layout=a_layout)
    mB = cute.make_tensor(b_ptr, layout=b_layout)
    mC = cute.make_tensor(c_ptr, layout=c_layout)
    print(f"[DSL INFO] mA: {mA}")
    print(f"[DSL INFO] mB: {mB}")
    print(f"[DSL INFO] mC: {mC}")
    tensor_op_gemm = TensorOpGemm(
        a_ptr.value_type, c_ptr.value_type, cutlass.Float32, (2, 2, 1)
    )
    print("\n[DSL INFO] Created TensorOpGemm instance")
    print(f"[DSL INFO] Input dtype: {a_ptr.value_type}")
    print(f"[DSL INFO] Output dtype: {c_ptr.value_type}")
    print(f"[DSL INFO] Accumulation dtype: {cutlass.Float32}")
    print(f"[DSL INFO] Atom layout: {(2, 2, 1)}")
    # No need to compile inside jit function
    tensor_op_gemm(mA, mB, mC)
    print("\n[DSL INFO] Executed TensorOpGemm")
def run_tensor_op_gemm_wrapper(mnkl: Tuple[int, int, int, int]):
    """Allocate torch tensors, wrap their raw pointers, run the GEMM, verify.

    :param mnkl: problem shape (M, N, K, L) where L is the batch dimension
    :raises AssertionError: if the GEMM result does not match the reference
    """
    print("\nRunning TensorOpGemm test with:")
    print(f"Tensor dimensions: {mnkl}")
    # (M,K,L) — allocated (L,K,M), permuted so M leads.
    a = torch.randn(
        mnkl[3], mnkl[2], mnkl[0], dtype=torch.float16, device="cuda"
    ).permute(2, 1, 0)
    # (N,K,L)
    b = torch.randn(
        mnkl[3], mnkl[2], mnkl[1], dtype=torch.float16, device="cuda"
    ).permute(2, 1, 0)
    # (M,N,L) — comment fix: permute(1, 2, 0) of an (L,M,N) allocation
    # yields (M,N,L), not (N,M,L) as the original comment claimed.
    c = torch.randn(
        mnkl[3], mnkl[0], mnkl[1], dtype=torch.float16, device="cuda"
    ).permute(1, 2, 0)
    print("Input tensor shapes:")
    print(f"a: {a.shape}, dtype: {a.dtype}")
    print(f"b: {b.shape}, dtype: {b.dtype}")
    print(f"c: {c.shape}, dtype: {c.dtype}\n")
    a_ptr = make_ptr(
        cutlass.Float16, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=32
    )
    b_ptr = make_ptr(
        cutlass.Float16, b.data_ptr(), cute.AddressSpace.gmem, assumed_align=32
    )
    c_ptr = make_ptr(
        cutlass.Float16, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=32
    )
    tensor_op_gemm_wrapper(a_ptr, b_ptr, c_ptr, *mnkl)
    torch.cuda.synchronize()
    ref = torch.einsum("mkl,nkl->mnl", a, b)
    torch.testing.assert_close(c, ref, atol=1e-05, rtol=1e-05)
    print("\n[DSL INFO] Results verified successfully!")
    print(f"First few elements of result: \n{c[:3, :3, :3]}")
if __name__ == "__main__":
run_tensor_op_gemm_wrapper((512, 256, 128, 16))

View File

@ -226,15 +226,15 @@ def run_tensor_op_gemm_wrapper(mnkl: Tuple[int, int, int, int]):
print(f"c: {c.shape}, dtype: {c.dtype}\n")
buffer_a = BufferWithLayout(
make_ptr(ab_dtype, a.data_ptr(), cute.AddressSpace.gmem),
make_ptr(ab_dtype, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=32),
(2, 1, 0),
)
buffer_b = BufferWithLayout(
make_ptr(ab_dtype, b.data_ptr(), cute.AddressSpace.gmem),
make_ptr(ab_dtype, b.data_ptr(), cute.AddressSpace.gmem, assumed_align=32),
(2, 1, 0),
)
buffer_c = BufferWithLayout(
make_ptr(c_dtype, c.data_ptr(), cute.AddressSpace.gmem),
make_ptr(c_dtype, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=32),
(2, 1, 0),
)

View File

@ -0,0 +1,189 @@
import os
import torch
import argparse
from typing import Type
from cuda.bindings import driver
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
import torch.multiprocessing as mp
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
def setup(rank, world_size):
    """Initialize torch.distributed (NCCL backend) and pin this process to
    GPU ``rank``."""
    # Rendezvous endpoint for the process group.
    os.environ.update({"MASTER_ADDR": "localhost", "MASTER_PORT": "12995"})
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
def cleanup():
    """Shut down the process group started by setup()."""
    dist.destroy_process_group()
@cute.kernel
def vector_add_kernel(
    g0: cute.Tensor,
    g1: cute.Tensor,
    g2: cute.Tensor,
    g3: cute.Tensor,
    g4: cute.Tensor,
    g5: cute.Tensor,
    g6: cute.Tensor,
    g7: cute.Tensor,
    gOut: cute.Tensor,
    tv_layout: cute.Layout,
):
    """Elementwise-sum eight input tensors into gOut, one tile per CTA.

    Each thread copies its fragment of every input and accumulates in a
    register fragment before storing the result.  The copy atoms use
    VOLATILE ordering with SYS scope — presumably because the inputs live
    in peer GPUs' symmetric memory (see the caller) — TODO confirm.
    """
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    # Slice out this CTA's tile of the output and of each input.
    cta_coord = (None, bidx)
    local_tile_out = gOut[cta_coord]
    local_tile_list = [
        g0[cta_coord],
        g1[cta_coord],
        g2[cta_coord],
        g3[cta_coord],
        g4[cta_coord],
        g5[cta_coord],
        g6[cta_coord],
        g7[cta_coord],
    ]
    copy_atom_load = cute.make_copy_atom(
        cute.nvgpu.CopyUniversalOp(),
        g0.element_type,
        memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE,
        memory_scope=cute.nvgpu.common.MemoryScope.SYS,
    )
    copy_atom_store = cute.make_copy_atom(
        cute.nvgpu.CopyUniversalOp(),
        g0.element_type,
        memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE,
        memory_scope=cute.nvgpu.common.MemoryScope.SYS,
    )
    # Partition the tiles across threads according to tv_layout.
    tiled_copy = cute.make_tiled_copy_tv(copy_atom_load, tv_layout[0], tv_layout[1])
    thr_copy = tiled_copy.get_slice(tidx)
    thr_tensor_list = [thr_copy.partition_S(tensor) for tensor in local_tile_list]
    thr_out = thr_copy.partition_D(local_tile_out)
    frg_tensor_list = [cute.make_fragment_like(tensor) for tensor in thr_tensor_list]
    frg_acc = cute.make_fragment_like(thr_out)
    frg_acc.fill(0.0)
    # Load each input's fragment and accumulate in registers.
    for thr, frg in zip(thr_tensor_list, frg_tensor_list):
        cute.copy(copy_atom_load, thr, frg)
        tmp = frg.load() + frg_acc.load()
        frg_acc.store(tmp)
    # Store the accumulated tile to the output.
    cute.copy(copy_atom_store, frg_acc, thr_out)
@cute.jit
def vector_add(
    m0: cute.Tensor,
    m1: cute.Tensor,
    m2: cute.Tensor,
    m3: cute.Tensor,
    m4: cute.Tensor,
    m5: cute.Tensor,
    m6: cute.Tensor,
    m7: cute.Tensor,
    output: cute.Tensor,
):
    """Tile the eight inputs and the output, then launch vector_add_kernel
    with one CTA per 512-element tile (128 threads x 4 elements each)."""
    # define constants for future use
    num_of_elements = cute.size(m0.layout)
    # 128 threads per block and 4 elements per thread
    tv_layout = cute.make_layout(((128), (4)), stride=((1), (1)))
    tile = cute.size(tv_layout.shape)
    tensors = [m0, m1, m2, m3, m4, m5, m6, m7]
    # Split each flat tensor into ((Tile), (Rest)) so a CTA indexes one tile.
    divided_tensors = [
        cute.zipped_divide(tensor, cute.make_layout(tile)) for tensor in tensors
    ]
    gOut = cute.zipped_divide(output, cute.make_layout(tile))  # ((Tile),(Rest))
    vector_add_kernel(
        divided_tensors[0],
        divided_tensors[1],
        divided_tensors[2],
        divided_tensors[3],
        divided_tensors[4],
        divided_tensors[5],
        divided_tensors[6],
        divided_tensors[7],
        gOut,
        tv_layout,
    ).launch(
        grid=[num_of_elements // tile, 1, 1],
        block=[tv_layout.shape[0], 1, 1],
    )
def run_vector_add(rank, world_size, M, N, dtype: Type[cutlass.Numeric]):
    """Worker entry point: sum all ranks' symmetric-memory buffers on this GPU.

    :param rank: this process's rank / CUDA device index
    :param world_size: number of participating processes (must be 8 — the
        vector_add wrapper takes exactly eight tensors)
    :param M, N: problem size; each buffer holds M*N elements
    :param dtype: element type tag (currently unused beyond the caller)
    """
    setup(rank, world_size)
    t = symm_mem.empty(M * N, device="cuda")
    hdl = symm_mem.rendezvous(t, dist.group.WORLD)
    # get tensors from other devices from the symmetric memory
    # (use `peer`, not `rank`, to avoid shadowing the function argument)
    tensor_list = [
        hdl.get_buffer(peer, t.shape, t.dtype) for peer in range(world_size)
    ]
    tensor_list[rank].random_(0, 100)
    # enable peer access
    driver.cuInit(0)
    dev_list = [driver.cuDeviceGet(i)[1] for i in range(world_size)]
    ctx_list = [driver.cuDevicePrimaryCtxRetain(dev)[1] for dev in dev_list]
    driver.cuCtxSetCurrent(ctx_list[rank])
    for i in range(world_size):
        if i == rank:
            continue
        driver.cuCtxEnablePeerAccess(ctx_list[i], 0)
    output = torch.zeros(M * N, device=f"cuda:{rank}")
    # we have to explicitly pass each tensor instead of a list of tensors
    vector_add(
        from_dlpack(tensor_list[0], assumed_align=32),
        from_dlpack(tensor_list[1], assumed_align=32),
        from_dlpack(tensor_list[2], assumed_align=32),
        from_dlpack(tensor_list[3], assumed_align=32),
        from_dlpack(tensor_list[4], assumed_align=32),
        from_dlpack(tensor_list[5], assumed_align=32),
        from_dlpack(tensor_list[6], assumed_align=32),
        from_dlpack(tensor_list[7], assumed_align=32),
        from_dlpack(output, assumed_align=32),
    )
    # Robustness fix: make sure the kernel has finished before verifying
    # (the sibling all-reduce example synchronizes here as well).
    torch.cuda.synchronize()
    sum_tensor = sum(tensor.cpu() for tensor in tensor_list)
    # Perf/idiom fix: the original `sum(sum_tensor.cpu() == output.cpu())`
    # iterated a million-element tensor in pure Python; torch.equal does the
    # exact elementwise comparison in one native call.
    if torch.equal(sum_tensor, output.cpu()):
        print("test passed")
    else:
        print("test failed")
        print(sum_tensor.cpu())
        print(output.cpu())
    cleanup()
def main():
    """Spawn one vector-add worker per visible GPU.

    :raises RuntimeError: if the machine does not expose exactly 8 GPUs —
        vector_add() takes exactly eight peer tensors, so fewer devices
        would crash with an IndexError and more would silently be ignored.
    """
    world_size = torch.cuda.device_count()
    if world_size != 8:
        raise RuntimeError(
            f"this example requires exactly 8 GPUs, found {world_size}"
        )
    M = 1024
    N = 1024
    # each process will run run_vector_add on different device
    mp.spawn(
        run_vector_add,
        args=(world_size, M, N, cutlass.Float32),
        nprocs=world_size,
        join=True,
    )
    return
if __name__ == "__main__":
main()

View File

@ -0,0 +1,114 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import cutlass.cute as cute
import cutlass
"""
Example of automatic shared memory size computation for configuring kernel launch
This example demonstrates how to let the DSL automatically set shared memory
size for a kernel launch rather than explicitly configuring it at launch time,
provided that developers are using `SmemAllocator` for all allocations.
Usage:
python dynamic_smem_size.py # Show auto inference
"""
@cute.struct
class SharedData:
    """A struct to demonstrate shared memory allocation.

    Field sizes add up to the struct's shared-memory footprint, which the
    DSL sums automatically when the kernel is launched with smem=None.
    """

    values: cute.struct.MemRange[cutlass.Float32, 64]  # 64 * 4 B = 256 bytes
    counter: cutlass.Int32  # 4 bytes
    flag: cutlass.Int8  # 1 byte
@cute.kernel
def kernel():
    """
    Example kernel that allocates shared memory.

    The total allocation will be automatically calculated when smem=None.
    """
    allocator = cutlass.utils.SmemAllocator()
    # Allocate various types of shared memory.  The results are deliberately
    # unused: the point of the example is that the DSL tracks the combined
    # footprint of these allocations for the launch configuration.
    shared_data = allocator.allocate(SharedData)
    raw_buffer = allocator.allocate(512, byte_alignment=64)
    int_array = allocator.allocate_array(element_type=cutlass.Int32, num_elems=128)
    tensor_smem = allocator.allocate_tensor(
        element_type=cutlass.Float16,
        layout=cute.make_layout((32, 16)),
        byte_alignment=16,
        swizzle=None,
    )
    return
@cute.kernel
def kernel_no_smem():
    """
    Example kernel that does not allocate shared memory.

    The total allocation will be automatically calculated as 0 when smem=None.
    """
    # Naming fix: cute.arch.block_idx() yields the CTA index, which the
    # original misleadingly bound to a variable named `tidx`.
    bidx, _, _ = cute.arch.block_idx()
    if bidx == 0:
        cute.printf("Hello world")
    return
if __name__ == "__main__":
    # Initialize CUDA context
    cutlass.cuda.initialize_cuda_context()
    print("Launching kernel with auto smem size. (launch config `smem=None`)")

    # Compile the example.  Neither launch passes `smem=`, so the DSL derives
    # the shared-memory size from each kernel's SmemAllocator usage.
    @cute.jit
    def launch_kernel1():
        k = kernel()
        k.launch(
            grid=(1, 1, 1),
            block=(1, 1, 1),
        )
        print(f"Kernel recorded internal smem usage: {k.smem_usage()}")

    @cute.jit
    def launch_kernel2():
        # Same launch shape, but this kernel allocates nothing -> usage 0.
        k = kernel_no_smem()
        k.launch(
            grid=(1, 1, 1),
            block=(1, 1, 1),
        )
        print(f"Kernel recorded internal smem usage: {k.smem_usage()}")

    cute.compile(launch_kernel1)
    cute.compile(launch_kernel2)
    print("PASS")

View File

@ -327,7 +327,6 @@ class FlashAttentionForwardAmpere:
).launch(
grid=grid_dim,
block=[self._num_threads, 1, 1],
smem=SharedStorage.size_in_bytes(),
stream=stream,
)
@ -1014,13 +1013,10 @@ class FlashAttentionForwardAmpere:
)
# compute exp(x - max) using exp2(x * log_2(e) - max * log_2(e))
acc_S_row_exp = cute.TensorSSA(
self._exp2f(
acc_S_row * softmax_params.softmax_scale_log2
- row_max_cur_row * softmax_params.softmax_scale_log2
),
tuple(acc_S_row.shape),
cutlass.Float32,
acc_S_row_exp = cute.math.exp2(
acc_S_row * softmax_params.softmax_scale_log2
- row_max_cur_row * softmax_params.softmax_scale_log2,
fastmath=True,
)
# acc_S_row_sum => f32
acc_S_row_sum = acc_S_row_exp.reduce(
@ -1028,9 +1024,10 @@ class FlashAttentionForwardAmpere:
)
# if it is not the first tile, load the row r of previous row_max and minus row_max_cur_row to update row_sum.
if cutlass.const_expr(not is_first_n_block):
prev_minus_cur_exp = self._exp2f(
prev_minus_cur_exp = cute.math.exp2(
row_max_prev_row * softmax_params.softmax_scale_log2
- row_max_cur_row * softmax_params.softmax_scale_log2
- row_max_cur_row * softmax_params.softmax_scale_log2,
fastmath=True,
)
acc_S_row_sum = (
acc_S_row_sum + softmax_params.row_sum[r] * prev_minus_cur_exp
@ -1141,26 +1138,6 @@ class FlashAttentionForwardAmpere:
"""
return self._threadquad_reduce(val, lambda x, y: x + y)
def _exp2f(
self, x: Union[cute.TensorSSA, cutlass.Float32]
) -> Union[cute.TensorSSA, cutlass.Float32]:
"""exp2f calculation for both vector and scalar.
:param x: input value
:type x: cute.TensorSSA or cutlass.Float32
:return: exp2 value
:rtype: cute.TensorSSA or cutlass.Float32
"""
if isinstance(x, cute.TensorSSA):
res = cute.make_fragment(x.shape, cutlass.Float32)
res.store(x)
for i in range(cute.size(x.shape)):
res[i] = self._exp2f(res[i])
return res.load()
return cute.arch.exp2(x)
def run(
dtype: Type[cutlass.Numeric],

View File

@ -136,10 +136,6 @@ class SGemm:
stride=(1, (self._bN + padding_b), self._bK * (self._bN + padding_b)),
)
smem_size = cute.size_in_bytes(mA.element_type, sA_layout) + cute.size_in_bytes(
mB.element_type, sB_layout
)
# ///////////////////////////////////////////////////////////////////////////////
# Create copy layouts that will be used for asynchronous
# global memory -> shared memory copies:
@ -258,7 +254,6 @@ class SGemm:
).launch(
grid=grid_dim,
block=[cute.size(atoms_layout), 1, 1],
smem=smem_size,
stream=stream,
)
@ -738,14 +733,20 @@ def run(
print("Compiling kernel with cute.compile ...")
start_time = time.time()
gemm = cute.compile(sgemm, a_tensor, b_tensor, c_tensor, stream=current_stream)
compiled_fn = cute.compile(
sgemm,
a_tensor,
b_tensor,
c_tensor,
stream=current_stream,
)
compilation_time = time.time() - start_time
print(f"Compilation time: {compilation_time:.4f} seconds")
print("Executing GEMM kernel...")
if not skip_ref_check:
gemm(a_tensor, b_tensor, c_tensor)
compiled_fn(a_tensor, b_tensor, c_tensor)
torch.cuda.synchronize()
print("Verifying results...")
ref = torch.einsum("mk,nk->mn", a, b)
@ -804,7 +805,7 @@ def run(
)
avg_time_us = testing.benchmark(
gemm,
compiled_fn,
workspace_generator=generate_tensors,
workspace_count=workspace_count,
stream=current_stream,
@ -837,6 +838,7 @@ if __name__ == "__main__":
parser.add_argument("--c_major", choices=["n", "m"], default="n")
parser.add_argument("--warmup_iterations", default=2, type=int)
parser.add_argument("--iterations", default=100, type=int)
parser.add_argument("--static_shape", action="store_true")
parser.add_argument("--skip_ref_check", action="store_true")
parser.add_argument(
"--use_cold_l2",

View File

@ -69,7 +69,7 @@ class complex:
class SharedStorage:
# struct elements with natural alignment
a: cute.struct.MemRange[cutlass.Float32, 32] # array
b: cutlass.Int64 # scalar
b: cutlass.Int64 # scalar
c: complex # nested struct
# struct elements with strict alignment
x: cute.struct.Align[

View File

@ -471,7 +471,7 @@ class TensorOpGemm:
cute.arch.sync_threads()
# Start async loads for the first k-tile. Here we take care of the k residue
# via if/else check along the k dimension. Because we shifted the identity tensor
# by the residue_k and because the identity tensor is a counting tensor, the
# by the residue_k and because the identity tensor is a coord tensor, the
# values of any identity tensor element that is poison is less than -1
num_smem_stages = cute.size(tAsA, mode=[3])
k_tile_count = cute.size(tAgA, mode=[3])
@ -683,7 +683,7 @@ class TensorOpGemm:
# Copy results of D back to shared memory
cute.autovec_copy(tCrD, tCsC)
# Create counting tensor for C
# Create coord tensor for C
ceilM, ceilN, _ = cute.ceil_div(mC.shape, (self.bM, self.bN, 1))
mcC = cute.make_identity_tensor(
(

View File

@ -610,7 +610,6 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
grid=grid,
block=[self.threads_per_cta, 1, 1],
cluster=(*self.cluster_shape_mn, 1),
smem=self.shared_storage.size_in_bytes(),
stream=stream,
)
return
@ -797,7 +796,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
gC_mnl = cute.local_tile(
mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None)
)
k_block_cnt = cute.size(gA_mkl, mode=[3])
k_tile_cnt = cute.size(gA_mkl, mode=[3])
#
# Partition global tensor for TiledMMA_A/B/C
@ -946,17 +945,17 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])
]
# Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt
# Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt
ab_producer_state.reset_count()
peek_ab_empty_status = cutlass.Boolean(1)
if ab_producer_state.count < k_block_cnt:
if ab_producer_state.count < k_tile_cnt:
peek_ab_empty_status = ab_pipeline.producer_try_acquire(
ab_producer_state
)
#
# Tma load loop
#
for k_block in cutlass.range(0, k_block_cnt, 1, unroll=1):
for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
# Conditionally wait for AB buffer empty
ab_pipeline.producer_acquire(
ab_producer_state, peek_ab_empty_status
@ -992,10 +991,10 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
mcast_mask=sfb_full_mcast_mask,
)
# Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt + k_block + 1
# Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1
ab_producer_state.advance()
peek_ab_empty_status = cutlass.Boolean(1)
if ab_producer_state.count < k_block_cnt:
if ab_producer_state.count < k_tile_cnt:
peek_ab_empty_status = ab_pipeline.producer_try_acquire(
ab_producer_state
)
@ -1103,10 +1102,10 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
# (MMA, MMA_M, MMA_N)
tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)]
# Peek (try_wait) AB buffer full for k_block = 0
# Peek (try_wait) AB buffer full for k_tile = 0
ab_consumer_state.reset_count()
peek_ab_full_status = cutlass.Boolean(1)
if ab_consumer_state.count < k_block_cnt and is_leader_cta:
if ab_consumer_state.count < k_tile_cnt and is_leader_cta:
peek_ab_full_status = ab_pipeline.consumer_try_wait(
ab_consumer_state
)
@ -1125,7 +1124,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
#
# Mma mainloop
#
for k_block in range(k_block_cnt):
for k_tile in range(k_tile_cnt):
if is_leader_cta:
# Conditionally wait for AB buffer full
ab_pipeline.consumer_wait(
@ -1154,44 +1153,44 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
)
# tCtAcc += tCrA * tCrSFA * tCrB * tCrSFB
num_kphases = cute.size(tCrA, mode=[2])
for kphase_idx in cutlass.range(num_kphases, unroll_full=True):
kphase_coord = (
num_kblocks = cute.size(tCrA, mode=[2])
for kblock_idx in cutlass.range(num_kblocks, unroll_full=True):
kblock_coord = (
None,
None,
kphase_idx,
kblock_idx,
ab_consumer_state.index,
)
# Set SFA/SFB tensor to tiled_mma
sf_kphase_coord = (None, None, kphase_idx)
sf_kblock_coord = (None, None, kblock_idx)
tiled_mma.set(
tcgen05.Field.SFA,
tCtSFA[sf_kphase_coord].iterator,
tCtSFA[sf_kblock_coord].iterator,
)
tiled_mma.set(
tcgen05.Field.SFB,
tCtSFB[sf_kphase_coord].iterator,
tCtSFB[sf_kblock_coord].iterator,
)
cute.gemm(
tiled_mma,
tCtAcc,
tCrA[kphase_coord],
tCrB[kphase_coord],
tCrA[kblock_coord],
tCrB[kblock_coord],
tCtAcc,
)
# Enable accumulate on tCtAcc after first kphase
# Enable accumulate on tCtAcc after first kblock
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
# Async arrive AB buffer empty
ab_pipeline.consumer_release(ab_consumer_state)
# Peek (try_wait) AB buffer full for k_block = k_block + 1
# Peek (try_wait) AB buffer full for k_tile = k_tile + 1
ab_consumer_state.advance()
peek_ab_full_status = cutlass.Boolean(1)
if ab_consumer_state.count < k_block_cnt:
if ab_consumer_state.count < k_tile_cnt:
if is_leader_cta:
peek_ab_full_status = ab_pipeline.consumer_try_wait(
ab_consumer_state

View File

@ -486,7 +486,6 @@ class DenseGemmKernel:
grid=grid,
block=[self.threads_per_cta, 1, 1],
cluster=(*self.cluster_shape_mn, 1),
smem=self.shared_storage.size_in_bytes(),
stream=stream,
)
return
@ -660,7 +659,7 @@ class DenseGemmKernel:
gC_mnl = cute.local_tile(
mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None)
)
k_block_cnt = cute.size(gA_mkl, mode=[3])
k_tile_cnt = cute.size(gA_mkl, mode=[3])
#
# Partition global tensor for TiledMMA_A/B/C
@ -788,19 +787,19 @@ class DenseGemmKernel:
#
# Pipelining TMA load A/B and MMA mainloop
#
prefetch_k_block_cnt = cutlass.min(self.num_ab_stage - 2, k_block_cnt)
prefetch_k_tile_cnt = cutlass.min(self.num_ab_stage - 2, k_tile_cnt)
if warp_idx == 0:
# Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt
# Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt
peek_ab_empty_status = cutlass.Boolean(1)
if ab_producer_state.count < k_block_cnt:
if ab_producer_state.count < k_tile_cnt:
peek_ab_empty_status = ab_pipeline.producer_try_acquire(
ab_producer_state
)
#
# Prefetch TMA load A/B
#
for prefetch_idx in cutlass.range(prefetch_k_block_cnt, unroll=1):
for prefetch_idx in cutlass.range(prefetch_k_tile_cnt, unroll=1):
# Conditionally wait for AB buffer empty
ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status)
@ -820,27 +819,27 @@ class DenseGemmKernel:
mcast_mask=b_full_mcast_mask,
)
# Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt + k_block + 1
# Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1
ab_producer_state.advance()
peek_ab_empty_status = cutlass.Boolean(1)
if ab_producer_state.count < k_block_cnt:
if ab_producer_state.count < k_tile_cnt:
peek_ab_empty_status = ab_pipeline.producer_try_acquire(
ab_producer_state
)
# Peek (try_wait) AB buffer full for k_block = 0
# Peek (try_wait) AB buffer full for k_tile = 0
peek_ab_full_status = cutlass.Boolean(1)
if ab_consumer_state.count < k_block_cnt and is_leader_cta:
if ab_consumer_state.count < k_tile_cnt and is_leader_cta:
peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state)
#
# MMA mainloop
#
for k_block in range(k_block_cnt):
for k_tile in range(k_tile_cnt):
# Conditionally wait for AB buffer empty
ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status)
if ab_producer_state.count < k_block_cnt:
if ab_producer_state.count < k_tile_cnt:
# TMA load A/B
cute.copy(
tma_atom_a,
@ -862,35 +861,35 @@ class DenseGemmKernel:
ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status)
# tCtAcc += tCrA * tCrB
num_kphases = cute.size(tCrA, mode=[2])
for kphase_idx in cutlass.range(num_kphases, unroll_full=True):
kphase_coord = (None, None, kphase_idx, ab_consumer_state.index)
num_kblocks = cute.size(tCrA, mode=[2])
for kblock_idx in cutlass.range(num_kblocks, unroll_full=True):
kblock_coord = (None, None, kblock_idx, ab_consumer_state.index)
cute.gemm(
tiled_mma,
tCtAcc,
tCrA[kphase_coord],
tCrB[kphase_coord],
tCrA[kblock_coord],
tCrB[kblock_coord],
tCtAcc,
)
# Enable accumulate on tCtAcc after first kphase
# Enable accumulate on tCtAcc after first kblock
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
# Async arrive AB buffer empty
ab_pipeline.consumer_release(ab_consumer_state)
# Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt + k_block + 1
# Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1
ab_producer_state.advance()
peek_ab_empty_status = cutlass.Boolean(1)
if ab_producer_state.count < k_block_cnt:
if ab_producer_state.count < k_tile_cnt:
peek_ab_empty_status = ab_pipeline.producer_try_acquire(
ab_producer_state
)
# Peek (try_wait) AB buffer full for k_block = k_block + 1
# Peek (try_wait) AB buffer full for k_tile = k_tile + 1
ab_consumer_state.advance()
peek_ab_full_status = cutlass.Boolean(1)
if ab_consumer_state.count < k_block_cnt:
if ab_consumer_state.count < k_tile_cnt:
if is_leader_cta:
peek_ab_full_status = ab_pipeline.consumer_try_wait(
ab_consumer_state
@ -1009,8 +1008,8 @@ class DenseGemmKernel:
# Wait A/B buffer empty
#
if warp_idx == 0:
# Reverse prefetch_k_block_cnt times to next available buffer
for i in range(prefetch_k_block_cnt):
# Reverse prefetch_k_tile_cnt times to next available buffer
for i in range(prefetch_k_tile_cnt):
ab_producer_state.reverse()
ab_pipeline.producer_tail(ab_producer_state)
return

View File

@ -510,7 +510,6 @@ class PersistentDenseGemmKernel:
grid=grid,
block=[self.threads_per_cta, 1, 1],
cluster=(*self.cluster_shape_mn, 1),
smem=self.shared_storage.size_in_bytes(),
stream=stream,
)
return
@ -669,7 +668,7 @@ class PersistentDenseGemmKernel:
gC_mnl = cute.local_tile(
mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None)
)
k_block_cnt = cute.size(gA_mkl, mode=[3])
k_tile_cnt = cute.size(gA_mkl, mode=[3])
#
# Partition global tensor for TiledMMA_A/B/C
@ -774,17 +773,17 @@ class PersistentDenseGemmKernel:
(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])
]
# Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt
# Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt
ab_producer_state.reset_count()
peek_ab_empty_status = cutlass.Boolean(1)
if ab_producer_state.count < k_block_cnt:
if ab_producer_state.count < k_tile_cnt:
peek_ab_empty_status = ab_pipeline.producer_try_acquire(
ab_producer_state
)
#
# Tma load loop
#
for k_block in cutlass.range(0, k_block_cnt, 1, unroll=1):
for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
# Conditionally wait for AB buffer empty
ab_pipeline.producer_acquire(
ab_producer_state, peek_ab_empty_status
@ -806,10 +805,10 @@ class PersistentDenseGemmKernel:
mcast_mask=b_full_mcast_mask,
)
# Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt + k_block + 1
# Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1
ab_producer_state.advance()
peek_ab_empty_status = cutlass.Boolean(1)
if ab_producer_state.count < k_block_cnt:
if ab_producer_state.count < k_tile_cnt:
peek_ab_empty_status = ab_pipeline.producer_try_acquire(
ab_producer_state
)
@ -877,10 +876,10 @@ class PersistentDenseGemmKernel:
# (MMA, MMA_M, MMA_N)
tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)]
# Peek (try_wait) AB buffer full for k_block = 0
# Peek (try_wait) AB buffer full for k_tile = 0
ab_consumer_state.reset_count()
peek_ab_full_status = cutlass.Boolean(1)
if ab_consumer_state.count < k_block_cnt and is_leader_cta:
if ab_consumer_state.count < k_tile_cnt and is_leader_cta:
peek_ab_full_status = ab_pipeline.consumer_try_wait(
ab_consumer_state
)
@ -899,7 +898,7 @@ class PersistentDenseGemmKernel:
#
# Mma mainloop
#
for k_block in range(k_block_cnt):
for k_tile in range(k_tile_cnt):
if is_leader_cta:
# Conditionally wait for AB buffer full
ab_pipeline.consumer_wait(
@ -907,32 +906,32 @@ class PersistentDenseGemmKernel:
)
# tCtAcc += tCrA * tCrB
num_kphases = cute.size(tCrA, mode=[2])
for kphase_idx in cutlass.range(num_kphases, unroll_full=True):
kphase_coord = (
num_kblocks = cute.size(tCrA, mode=[2])
for kblock_idx in cutlass.range(num_kblocks, unroll_full=True):
kblock_coord = (
None,
None,
kphase_idx,
kblock_idx,
ab_consumer_state.index,
)
cute.gemm(
tiled_mma,
tCtAcc,
tCrA[kphase_coord],
tCrB[kphase_coord],
tCrA[kblock_coord],
tCrB[kblock_coord],
tCtAcc,
)
# Enable accumulate on tCtAcc after first kphase
# Enable accumulate on tCtAcc after first kblock
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
# Async arrive AB buffer empty
ab_pipeline.consumer_release(ab_consumer_state)
# Peek (try_wait) AB buffer full for k_block = k_block + 1
# Peek (try_wait) AB buffer full for k_tile = k_tile + 1
ab_consumer_state.advance()
peek_ab_full_status = cutlass.Boolean(1)
if ab_consumer_state.count < k_block_cnt:
if ab_consumer_state.count < k_tile_cnt:
if is_leader_cta:
peek_ab_full_status = ab_pipeline.consumer_try_wait(
ab_consumer_state

View File

@ -110,17 +110,6 @@ Constraints:
"""
class PipelineStateMinimal:
"""
Pipeline state contains an index and phase bit corresponding to the current position in the circular buffer.
"""
def __init__(self, count, index, phase):
self.count = count
self.index = index
self.phase = phase
class DenseGemmKernel:
"""
This class implements batched matrix multiplication (C = A x B) with support for various data types
@ -497,7 +486,6 @@ class DenseGemmKernel:
grid=grid,
block=[self.threads_per_cta, 1, 1],
cluster=(*self.cluster_shape_mn, 1),
smem=self.shared_storage.size_in_bytes(),
stream=stream,
)
return
@ -576,13 +564,19 @@ class DenseGemmKernel:
pipeline.Agent.Thread, num_tma_producer
)
ab_pipeline = pipeline.PipelineTmaUmma.create(
barrier_storage=storage.ab_full_mbar_ptr.data_ptr(),
num_stages=self.num_ab_stage,
producer_group=ab_pipeline_producer_group,
consumer_group=ab_pipeline_consumer_group,
tx_count=self.num_tma_load_bytes,
barrier_storage=storage.ab_full_mbar_ptr.data_ptr(),
cta_layout_vmnk=cluster_layout_vmnk,
)
ab_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.num_ab_stage
)
ab_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_ab_stage
)
# Initialize acc_pipeline (barrier) and states
acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
@ -590,10 +584,10 @@ class DenseGemmKernel:
pipeline.Agent.Thread, self.threads_per_cta, self.threads_per_cta
)
acc_pipeline = pipeline.PipelineUmmaAsync.create(
barrier_storage=storage.acc_full_mbar_ptr.data_ptr(),
num_stages=self.num_acc_stage,
producer_group=acc_pipeline_producer_group,
consumer_group=acc_pipeline_consumer_group,
barrier_storage=storage.acc_full_mbar_ptr.data_ptr(),
cta_layout_vmnk=cluster_layout_vmnk,
)
acc_producer_state = pipeline.make_pipeline_state(
@ -665,7 +659,7 @@ class DenseGemmKernel:
gC_mnl = cute.local_tile(
mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None)
)
k_block_cnt = cute.size(gA_mkl, mode=[3])
k_tile_cnt = cute.size(gA_mkl, mode=[3])
#
# Partition global tensor for TiledMMA_A/B/C
@ -793,24 +787,12 @@ class DenseGemmKernel:
# ///////////////////////////////////////////////////////////////////////////////
# MAINLOOP
# ///////////////////////////////////////////////////////////////////////////////
prefetch_k_block_cnt = cutlass.min(self.num_ab_stage - 2, k_block_cnt)
prefetch_k_tile_cnt = cutlass.min(self.num_ab_stage - 2, k_tile_cnt)
if warp_idx == 0:
for k_block in cutlass.range(
k_block_cnt,
pipelining=self.num_ab_stage - 2,
for k_tile in cutlass.range(
k_tile_cnt,
prefetch_stages=self.num_ab_stage - 2,
):
ab_producer_state = PipelineStateMinimal(
k_block,
k_block % self.num_ab_stage,
cutlass.Int32((k_block // self.num_ab_stage) % 2) ^ 1,
)
ab_consumer_state = PipelineStateMinimal(
k_block,
k_block % self.num_ab_stage,
cutlass.Int32((k_block // self.num_ab_stage) % 2),
)
# wait for AB buffer empty
ab_pipeline.producer_acquire(ab_producer_state)
@ -835,22 +817,26 @@ class DenseGemmKernel:
ab_pipeline.consumer_wait(ab_consumer_state)
# tCtAcc += tCrA * tCrB
num_kphases = cute.size(tCrA, mode=[2])
for kphase_idx in cutlass.range(num_kphases, unroll_full=True):
kphase_coord = (None, None, kphase_idx, ab_consumer_state.index)
num_kblocks = cute.size(tCrA, mode=[2])
for kblock_idx in cutlass.range(num_kblocks, unroll_full=True):
kblock_coord = (None, None, kblock_idx, ab_consumer_state.index)
cute.gemm(
tiled_mma,
tCtAcc,
tCrA[kphase_coord],
tCrB[kphase_coord],
tCrA[kblock_coord],
tCrB[kblock_coord],
tCtAcc,
)
# Enable accumulate on tCtAcc after first kphase
# Enable accumulate on tCtAcc after first kblock
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
# Async arrive AB buffer empty
ab_pipeline.consumer_release(ab_consumer_state)
ab_producer_state.advance()
ab_consumer_state.advance()
# Async arrive accumulator buffer full
if is_leader_cta:
acc_pipeline.producer_commit(acc_producer_state)
@ -964,12 +950,10 @@ class DenseGemmKernel:
# Wait A/B buffer empty
#
if warp_idx == 0:
ab_producer_state = PipelineStateMinimal(
k_block_cnt,
k_block_cnt % self.num_ab_stage,
cutlass.Int32((k_block_cnt // self.num_ab_stage) % 2) ^ 1,
)
ab_pipeline.producer_acquire(ab_producer_state)
# Reverse prefetch_k_tile_cnt times to next available buffer
for i in range(prefetch_k_tile_cnt):
ab_producer_state.reverse()
ab_pipeline.producer_tail(ab_producer_state)
return
def epilog_tmem_copy_and_partition(
@ -1579,7 +1563,6 @@ def run_dense_gemm(
warmup_iterations: int = 0,
iterations: int = 1,
skip_ref_check: bool = False,
measure_launch_overhead=False,
):
"""
Prepare A/B/C tensors, launch GPU kernel, and reference checking.
@ -1725,7 +1708,7 @@ def run_dense_gemm(
ref_c = ref
elif c_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}:
# m major: (l, n, m) -> (m, n, l)
# k major: (l, m, n) -> (m, n, l)
# n major: (l, m, n) -> (m, n, l)
permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0)
shape = (l, m, n) if c_major == "n" else (l, n, m)
f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor(

File diff suppressed because it is too large Load Diff

View File

@ -475,7 +475,6 @@ class GroupedGemmKernel:
grid=grid,
block=[self.threads_per_cta, 1, 1],
cluster=(*self.cluster_shape_mn, 1),
smem=self.shared_storage.size_in_bytes(),
stream=stream,
)
return
@ -785,7 +784,7 @@ class GroupedGemmKernel:
)
tensormap_init_done = cutlass.Boolean(False)
# tile count we have searched
total_k_block_cnt = cutlass.Int32(0)
total_k_tile_cnt = cutlass.Int32(0)
# group index of last tile
last_group_idx = cutlass.Int32(-1)
work_tile = tile_sched.initial_work_tile_info()
@ -795,7 +794,7 @@ class GroupedGemmKernel:
cur_tile_coord,
problem_sizes_mnkl,
)
cur_k_block_cnt = grouped_gemm_cta_tile_info.cta_tile_count_k
cur_k_tile_cnt = grouped_gemm_cta_tile_info.cta_tile_count_k
cur_group_idx = grouped_gemm_cta_tile_info.group_idx
is_group_changed = cur_group_idx != last_group_idx
# skip tensormap update if we're working on the same group
@ -861,17 +860,17 @@ class GroupedGemmKernel:
(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])
]
num_prev_k_blk = total_k_block_cnt
total_k_block_cnt += cur_k_block_cnt
num_prev_k_blk = total_k_tile_cnt
total_k_tile_cnt += cur_k_tile_cnt
# Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt
tma_wr_k_block = cutlass.Int32(0)
smem_wr_buffer = (num_prev_k_blk + tma_wr_k_block) % self.num_ab_stage
# Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt
tma_wr_k_tile = cutlass.Int32(0)
smem_wr_buffer = (num_prev_k_blk + tma_wr_k_tile) % self.num_ab_stage
tma_wr_ab_empty_phase = (
num_prev_k_blk + tma_wr_k_block
num_prev_k_blk + tma_wr_k_tile
) // self.num_ab_stage % 2 ^ 1
peek_ab_empty_status = cute.arch.mbarrier_conditional_try_wait(
tma_wr_k_block < cur_k_block_cnt,
tma_wr_k_tile < cur_k_tile_cnt,
ab_empty_mbar_ptr + smem_wr_buffer,
tma_wr_ab_empty_phase,
)
@ -882,10 +881,10 @@ class GroupedGemmKernel:
#
# Tma load loop
#
for k_block in cutlass.range(0, cur_k_block_cnt, 1, unroll=1):
tma_wr_k_block_next = tma_wr_k_block + 1
for k_tile in cutlass.range(0, cur_k_tile_cnt, 1, unroll=1):
tma_wr_k_tile_next = tma_wr_k_tile + 1
smem_wr_buffer_next = (
num_prev_k_blk + tma_wr_k_block_next
num_prev_k_blk + tma_wr_k_tile_next
) % self.num_ab_stage
tma_wr_ab_empty_phase_next = (
tma_wr_ab_empty_phase ^ 1
@ -911,7 +910,7 @@ class GroupedGemmKernel:
# Load A/B with TMA
cute.copy(
tma_atom_a,
tAgA_slice[(None, tma_wr_k_block)],
tAgA_slice[(None, tma_wr_k_tile)],
tAsA[(None, smem_wr_buffer)],
tma_bar_ptr=smem_full_mbar_ptr,
mcast_mask=a_full_mcast_mask,
@ -922,7 +921,7 @@ class GroupedGemmKernel:
)
cute.copy(
tma_atom_b,
tBgB_slice[(None, tma_wr_k_block)],
tBgB_slice[(None, tma_wr_k_tile)],
tBsB[(None, smem_wr_buffer)],
tma_bar_ptr=smem_full_mbar_ptr,
mcast_mask=b_full_mcast_mask,
@ -932,14 +931,14 @@ class GroupedGemmKernel:
),
)
# Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt + k_block + 1
# Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1
peek_ab_empty_status = cute.arch.mbarrier_conditional_try_wait(
tma_wr_k_block_next < cur_k_block_cnt,
tma_wr_k_tile_next < cur_k_tile_cnt,
ab_empty_mbar_ptr + smem_wr_buffer_next,
tma_wr_ab_empty_phase_next,
)
tma_wr_k_block = tma_wr_k_block_next
tma_wr_k_tile = tma_wr_k_tile_next
smem_wr_buffer = smem_wr_buffer_next
tma_wr_ab_empty_phase = tma_wr_ab_empty_phase_next
@ -998,12 +997,12 @@ class GroupedGemmKernel:
work_tile = tile_sched.initial_work_tile_info()
# tile count we have searched
total_k_block_cnt = cutlass.Int32(0)
total_k_tile_cnt = cutlass.Int32(0)
while work_tile.is_valid_tile:
cur_tile_coord = work_tile.tile_idx
# MMA warp is only interested in number of tiles along K dimension
(
cur_k_block_cnt,
cur_k_tile_cnt,
cur_group_idx,
) = group_gemm_ts_helper.search_cluster_tile_count_k(
cur_tile_coord,
@ -1014,17 +1013,17 @@ class GroupedGemmKernel:
# (MMA, MMA_M, MMA_N)
tCtAcc = tCtAcc_base[(None, None, None, acc_buf_idx)]
num_prev_k_blk = total_k_block_cnt
total_k_block_cnt += cur_k_block_cnt
num_prev_k_blk = total_k_tile_cnt
total_k_tile_cnt += cur_k_tile_cnt
# Peek (try_wait) AB buffer full for k_block = 0
mma_rd_k_block = cutlass.Int32(0)
smem_rd_buffer = (num_prev_k_blk + mma_rd_k_block) % self.num_ab_stage
# Peek (try_wait) AB buffer full for k_tile = 0
mma_rd_k_tile = cutlass.Int32(0)
smem_rd_buffer = (num_prev_k_blk + mma_rd_k_tile) % self.num_ab_stage
need_check_rd_buffer_full = (
mma_rd_k_block < cur_k_block_cnt and is_leader_cta
mma_rd_k_tile < cur_k_tile_cnt and is_leader_cta
)
mma_rd_ab_full_phase = (
(num_prev_k_blk + mma_rd_k_block) // self.num_ab_stage % 2
(num_prev_k_blk + mma_rd_k_tile) // self.num_ab_stage % 2
)
peek_ab_full_status = cute.arch.mbarrier_conditional_try_wait(
need_check_rd_buffer_full,
@ -1051,10 +1050,10 @@ class GroupedGemmKernel:
#
# Mma mainloop
#
for k_block in range(cur_k_block_cnt):
mma_rd_k_block_next = cutlass.Int32(k_block + 1)
for k_tile in range(cur_k_tile_cnt):
mma_rd_k_tile_next = cutlass.Int32(k_tile + 1)
smem_rd_buffer_next = (
num_prev_k_blk + mma_rd_k_block_next
num_prev_k_blk + mma_rd_k_tile_next
) % self.num_ab_stage
mma_rd_ab_full_phase_next = (
mma_rd_ab_full_phase ^ 1
@ -1069,18 +1068,18 @@ class GroupedGemmKernel:
)
# tCtAcc += tCrA * tCrB
num_kphases = cute.size(tCrA, mode=[2])
for kphase_idx in cutlass.range(num_kphases, unroll_full=True):
kphase_coord = (None, None, kphase_idx, smem_rd_buffer)
num_kblocks = cute.size(tCrA, mode=[2])
for kblock_idx in cutlass.range(num_kblocks, unroll_full=True):
kblock_coord = (None, None, kblock_idx, smem_rd_buffer)
cute.gemm(
tiled_mma,
tCtAcc,
tCrA[kphase_coord],
tCrB[kphase_coord],
tCrA[kblock_coord],
tCrB[kblock_coord],
tCtAcc,
)
# Enable accumulate on tCtAcc after first kphase
# Enable accumulate on tCtAcc after first kblock
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
# Async arrive AB buffer empty
@ -1091,9 +1090,9 @@ class GroupedGemmKernel:
self.cta_group,
)
# Peek (try_wait) AB buffer full for k_block = k_block + 1
# Peek (try_wait) AB buffer full for k_tile = k_tile + 1
need_check_rd_buffer_full = (
mma_rd_k_block_next < cur_k_block_cnt and is_leader_cta
mma_rd_k_tile_next < cur_k_tile_cnt and is_leader_cta
)
peek_ab_full_status = cute.arch.mbarrier_conditional_try_wait(
@ -1102,7 +1101,7 @@ class GroupedGemmKernel:
mma_rd_ab_full_phase_next,
)
mma_rd_k_block = mma_rd_k_block_next
mma_rd_k_tile = mma_rd_k_tile_next
smem_rd_buffer = smem_rd_buffer_next
mma_rd_ab_full_phase = mma_rd_ab_full_phase_next
@ -1201,7 +1200,7 @@ class GroupedGemmKernel:
# wait tensormap initialization complete before update
tensormap_manager.fence_tensormap_initialization()
# tile count we have searched
total_k_block_cnt = cutlass.Int32(0)
total_k_tile_cnt = cutlass.Int32(0)
# group index of last tile
last_group_idx = cutlass.Int32(-1)
while work_tile.is_valid_tile:
@ -1240,8 +1239,8 @@ class GroupedGemmKernel:
grouped_gemm_cta_tile_info.cta_tile_idx_n,
0,
)
cur_k_block_cnt = grouped_gemm_cta_tile_info.cta_tile_count_k
total_k_block_cnt += cur_k_block_cnt
cur_k_tile_cnt = grouped_gemm_cta_tile_info.cta_tile_count_k
total_k_tile_cnt += cur_k_tile_cnt
#
# Slice to per mma tile index
@ -1370,8 +1369,8 @@ class GroupedGemmKernel:
#
if warp_idx == self.epilog_warp_id[0]:
cute.arch.mbarrier_wait(
(ab_empty_mbar_ptr + ((total_k_block_cnt - 1) % self.num_ab_stage)),
(((total_k_block_cnt - 1) // self.num_ab_stage) % 2),
(ab_empty_mbar_ptr + ((total_k_tile_cnt - 1) % self.num_ab_stage)),
(((total_k_tile_cnt - 1) // self.num_ab_stage) % 2),
)
@cute.jit

View File

@ -622,7 +622,6 @@ class SSDKernel:
block=[self.threads_per_cta, 1, 1],
cluster=self.cluster_shape_mnk,
min_blocks_per_mp=1,
smem=self.shared_storage.size_in_bytes(),
stream=stream,
)
@ -693,7 +692,7 @@ class SSDKernel:
G = cute.size(tma_tensor_b, mode=[3])
NGROUP_RATIO = EH // G
# Make tiledMma
# Make TiledMma
(
tiled_mma_intra1,
tiled_mma_intra2,
@ -1745,7 +1744,7 @@ class SSDKernel:
cute.arch.fence_view_async_tmem_load()
# Combine INTER1_ACC/last_column/State
exp_last_column = cute.arch.exp(last_column.ir_value())
exp_last_column = cute.math.exp(last_column, fastmath=True)
for reg_idx in range(0, cute.size(tTR_rP), 2):
(
tTR_rP[reg_idx],
@ -2267,9 +2266,11 @@ class SSDKernel:
) = cute.arch.fma_packed_f32x2(
(tTR_rInter[reg_idx], tTR_rInter[reg_idx + 1]),
(
cute.arch.exp(tTR_rDeltaA[reg_idx].ir_value()),
cute.arch.exp(
tTR_rDeltaA[reg_idx + 1].ir_value()
cute.math.exp(
tTR_rDeltaA[reg_idx], fastmath=True
),
cute.math.exp(
tTR_rDeltaA[reg_idx + 1], fastmath=True
),
),
(tTR_rIntra[reg_idx], tTR_rIntra[reg_idx + 1]),
@ -3072,14 +3073,19 @@ class SSDKernel:
m, n = tCoord[subtile_idx]
if m < n:
tCompute[subtile_idx] = cutlass.Float32(-float("inf"))
LOG2_E = cutlass.Float32(1.4426950408889634)
for subtile_idx in cutlass.range(0, cute.size(tTR_rQ), 2, unroll_full=True):
# TODO: use math.exp directly
tCompute_log2e = cute.arch.mul_packed_f32x2(
(tCompute[subtile_idx], tCompute[subtile_idx + 1]), (LOG2_E, LOG2_E)
)
(
tCompute[subtile_idx],
tCompute[subtile_idx + 1],
) = cute.arch.mul_packed_f32x2(
cute.arch.exp_packed_f32x2(
(tCompute[subtile_idx], tCompute[subtile_idx + 1])
(
cute.math.exp2(tCompute_log2e[0], fastmath=True),
cute.math.exp2(tCompute_log2e[1], fastmath=True),
),
(tCrDelta[subtile_idx], tCrDelta[subtile_idx + 1]),
)
@ -3245,11 +3251,11 @@ class SSDKernel:
for reg_idx in range(0, cute.size(tBrB_Compute), 2):
tCompute[reg_idx], tCompute[reg_idx + 1] = cute.arch.mul_packed_f32x2(
(
cute.arch.exp(
(last_column - tBrDeltaA_Compute[reg_idx]).ir_value()
cute.math.exp(
(last_column - tBrDeltaA_Compute[reg_idx]), fastmath=True
),
cute.arch.exp(
(last_column - tBrDeltaA_Compute[reg_idx + 1]).ir_value()
cute.math.exp(
(last_column - tBrDeltaA_Compute[reg_idx + 1]), fastmath=True
),
),
(tBrDelta_Compute[reg_idx], tBrDelta_Compute[reg_idx + 1]),

View File

@ -44,7 +44,7 @@ import cutlass.utils.hopper_helpers as sm90_utils
"""
A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Hopper architecture
using CUTE DSL.
using CuTe DSL.
- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M")
- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K")
- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M")
@ -70,7 +70,7 @@ To run this example:
.. code-block:: bash
python examples/hopper/dense_gemm.py \
--mnkl 8192,8192,8192,1 --tile_shape_mnk 128,256,64 \
--mnkl 8192,8192,8192,1 --tile_shape_mn 128,256 \
--cluster_shape_mn 1,1 --a_dtype Float16 --b_dtype Float16 \
--c_dtype Float16 --acc_dtype Float32 \
--a_major k --b_major k --c_major n
@ -85,7 +85,7 @@ To collect performance with NCU profiler:
.. code-block:: bash
ncu python examples/hopper/dense_gemm.py \
--mnkl 8192,8192,8192,1 --tile_shape_mnk 128,256,64 \
--mnkl 8192,8192,8192,1 --tile_shape_mn 128,256 \
--cluster_shape_mn 1,1 --a_dtype Float16 --b_dtype Float16 \
--c_dtype Float16 --acc_dtype Float32 \
--a_major k --b_major k --c_major n
@ -95,14 +95,11 @@ Constraints:
* For fp16 types, A and B must have the same data type
* For fp8 types, A and B can have different types (e4m3fn or e5m2) but both must be 8-bit
* Fp8 types only support k-major layout
* Only fp32 accumulation is supported in this example
* CTA tile shape M must be 64/128
* CTA tile shape N must be 64/128/256
* CTA tile shape K must be 64
* Cluster shape M/N must be positive and power of 2, total cluster size <= 4
* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned,
i.e, number of elements is a multiple of 8, 16 for Float16, and Float8, respectively.
* OOB tiles are not allowed when TMA store is disabled
"""
@ -128,10 +125,10 @@ def parse_arguments() -> argparse.Namespace:
help="mnkl dimensions (comma-separated)",
)
parser.add_argument(
"--tile_shape_mnk",
"--tile_shape_mn",
type=parse_comma_separated_ints,
choices=[(128, 128, 64), (128, 256, 64), (128, 64, 64), (64, 64, 64)],
default=(128, 128, 64),
choices=[(128, 128), (128, 256), (128, 64), (64, 64)],
default=(128, 128),
help="Cta tile shape (comma-separated)",
)
parser.add_argument(
@ -190,8 +187,8 @@ def parse_arguments() -> argparse.Namespace:
if len(args.mnkl) != 4:
parser.error("--mnkl must contain exactly 4 values")
if len(args.tile_shape_mnk) != 3:
parser.error("--tile_shape_mnk must contain exactly 3 values")
if len(args.tile_shape_mn) != 2:
parser.error("--tile_shape_mn must contain exactly 2 values")
if len(args.cluster_shape_mn) != 2:
parser.error("--cluster_shape_mn must contain exactly 2 values")
@ -210,10 +207,10 @@ class HopperWgmmaGemmKernel:
:param acc_dtype: Data type for accumulation during computation
:type acc_dtype: type[cutlass.Numeric]
:param tile_shape_mnk: Shape of the CTA tile (M,N,K)
:type tile_shape_mnk: Tuple[int, int, int]
:param cluster_shape_mnk: Cluster dimensions (M,N,K) for parallel processing
:type cluster_shape_mnk: Tuple[int, int, int]
:param tile_shape_mn: Shape of the CTA tile (M,N)
:type tile_shape_mn: Tuple[int, int]
:param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing
:type cluster_shape_mn: Tuple[int, int]
:note: Data type requirements:
- For 16-bit types: A and B must have the same data type
@ -236,8 +233,8 @@ class HopperWgmmaGemmKernel:
Example:
>>> gemm = HopperWgmmaGemmKernel(
... acc_dtype=cutlass.Float32,
... tile_shape_mnk=(128, 256, 64),
... cluster_shape_mnk=(1, 1, 1)
... tile_shape_mn=(128, 256),
... cluster_shape_mn=(1, 1)
... )
>>> gemm(a_tensor, b_tensor, c_tensor, stream)
"""
@ -245,8 +242,8 @@ class HopperWgmmaGemmKernel:
def __init__(
self,
acc_dtype: type[cutlass.Numeric],
tile_shape_mnk: tuple[int, int, int],
cluster_shape_mnk: tuple[int, int, int],
tile_shape_mn: tuple[int, int],
cluster_shape_mn: tuple[int, int],
):
"""
Initializes the configuration for a Hopper dense GEMM kernel.
@ -256,28 +253,30 @@ class HopperWgmmaGemmKernel:
:param acc_dtype: Data type for accumulation during computation
:type acc_dtype: type[cutlass.Numeric]
:param tile_shape_mnk: Shape of the CTA tile (M,N,K)
:type tile_shape_mnk: Tuple[int, int, int]
:param cluster_shape_mnk: Cluster dimensions (M,N,K) for parallel processing
:type cluster_shape_mnk: Tuple[int, int, int]
:param tile_shape_mn: Shape of the CTA tile (M,N)
:type tile_shape_mn: Tuple[int, int]
:param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing
:type cluster_shape_mn: Tuple[int, int]
"""
self.acc_dtype = acc_dtype
self.cluster_shape_mnk = cluster_shape_mnk
self.cluster_shape_mn = cluster_shape_mn
self.mma_inst_shape_mn = None
self.tile_shape_mnk = tuple(tile_shape_mnk)
# K dimension is deferred in _setup_attributes
self.tile_shape_mnk = (*tile_shape_mn, 1)
# For large tile size, using two warp groups is preferred because using only one warp
# group may result in register spill
self.atom_layout_mnk = (
(2, 1, 1)
if tile_shape_mnk[0] > 64 and tile_shape_mnk[1] > 128
if self.tile_shape_mnk[0] > 64 and self.tile_shape_mnk[1] > 128
else (1, 1, 1)
)
self.num_mcast_ctas_a = None
self.num_mcast_ctas_b = None
self.is_a_mcast = False
self.is_b_mcast = False
self.tiled_mma = None
self.occupancy = 1
self.mma_warp_groups = math.prod(self.atom_layout_mnk)
@ -315,12 +314,27 @@ class HopperWgmmaGemmKernel:
raise ValueError("CTA tile shape M must be 64/128")
if self.tile_shape_mnk[1] not in [64, 128, 256]:
raise ValueError("CTA tile shape N must be 64/128/256")
if self.tile_shape_mnk[2] not in [64]:
raise ValueError("CTA tile shape K must be 64")
self.cta_layout_mnk = cute.make_layout(self.cluster_shape_mnk)
self.num_mcast_ctas_a = self.cluster_shape_mnk[1]
self.num_mcast_ctas_b = self.cluster_shape_mnk[0]
self.tiled_mma = sm90_utils.make_trivial_tiled_mma(
self.a_dtype,
self.b_dtype,
self.a_layout.sm90_mma_major_mode(),
self.b_layout.sm90_mma_major_mode(),
self.acc_dtype,
self.atom_layout_mnk,
tiler_mn=(64, self.tile_shape_mnk[1]),
)
mma_inst_shape_k = cute.size(self.tiled_mma.shape_mnk, mode=[2])
mma_inst_tile_k = 4
self.tile_shape_mnk = (
self.tile_shape_mnk[0],
self.tile_shape_mnk[1],
mma_inst_shape_k * mma_inst_tile_k,
)
self.cta_layout_mnk = cute.make_layout((*self.cluster_shape_mn, 1))
self.num_mcast_ctas_a = self.cluster_shape_mn[1]
self.num_mcast_ctas_b = self.cluster_shape_mn[0]
self.is_a_mcast = self.num_mcast_ctas_a > 1
self.is_b_mcast = self.num_mcast_ctas_b > 1
@ -401,28 +415,18 @@ class HopperWgmmaGemmKernel:
self._setup_attributes()
tiled_mma = sm90_utils.make_trivial_tiled_mma(
self.a_dtype,
self.b_dtype,
self.a_layout.sm90_mma_major_mode(),
self.b_layout.sm90_mma_major_mode(),
self.acc_dtype,
self.atom_layout_mnk,
tiler_mn=(64, self.tile_shape_mnk[1]),
)
tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors(
a,
self.a_smem_layout_staged,
(self.tile_shape_mnk[0], self.tile_shape_mnk[2]),
self.cluster_shape_mnk[1],
self.cluster_shape_mn[1],
)
tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors(
b,
self.b_smem_layout_staged,
(self.tile_shape_mnk[1], self.tile_shape_mnk[2]),
self.cluster_shape_mnk[0],
self.cluster_shape_mn[0],
)
tma_atom_c, tma_tensor_c = self._make_tma_store_atoms_and_tensors(
@ -431,20 +435,20 @@ class HopperWgmmaGemmKernel:
self.epi_tile,
)
grid = self._compute_grid(c, self.tile_shape_mnk, self.cluster_shape_mnk)
grid = self._compute_grid(c, self.tile_shape_mnk, self.cluster_shape_mn)
@cute.struct
class SharedStorage:
mainloop_pipeline_array_ptr: cute.struct.MemRange[
cutlass.Int64, self.ab_stage * 2
]
sa: cute.struct.Align[
sA: cute.struct.Align[
cute.struct.MemRange[
self.a_dtype, cute.cosize(self.a_smem_layout_staged)
],
self.buffer_align_bytes,
]
sb: cute.struct.Align[
sB: cute.struct.Align[
cute.struct.MemRange[
self.b_dtype, cute.cosize(self.b_smem_layout_staged)
],
@ -461,7 +465,7 @@ class HopperWgmmaGemmKernel:
tma_tensor_b,
tma_atom_c,
tma_tensor_c,
tiled_mma,
self.tiled_mma,
self.cta_layout_mnk,
self.a_smem_layout_staged,
self.b_smem_layout_staged,
@ -469,8 +473,7 @@ class HopperWgmmaGemmKernel:
).launch(
grid=grid,
block=[self.threads_per_cta, 1, 1],
cluster=self.cluster_shape_mnk,
smem=self.shared_storage.size_in_bytes(),
cluster=(*self.cluster_shape_mn, 1),
stream=stream,
)
return
@ -562,8 +565,8 @@ class HopperWgmmaGemmKernel:
# Get the pid from cluster id
bidx_in_cluster = cute.arch.block_in_cluster_idx()
pid_m = cid_m * self.cluster_shape_mnk[0] + bidx_in_cluster[0]
pid_n = cid_n * self.cluster_shape_mnk[1] + bidx_in_cluster[1]
pid_m = cid_m * self.cluster_shape_mn[0] + bidx_in_cluster[0]
pid_n = cid_n * self.cluster_shape_mn[1] + bidx_in_cluster[1]
tile_coord_mnkl = (pid_m, pid_n, None, bidz)
cta_rank_in_cluster = cute.arch.make_warp_uniform(
@ -621,22 +624,22 @@ class HopperWgmmaGemmKernel:
)
# Cluster arrive after barrier init
if cute.size(self.cluster_shape_mnk) > 1:
if cute.size(self.cluster_shape_mn) > 1:
cute.arch.cluster_arrive_relaxed()
# ///////////////////////////////////////////////////////////////////////////////
# Generate smem tensor A/B
# ///////////////////////////////////////////////////////////////////////////////
sa = storage.sa.get_tensor(
sA = storage.sA.get_tensor(
a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner
)
sb = storage.sb.get_tensor(
sB = storage.sB.get_tensor(
b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner
)
sc_ptr = cute.recast_ptr(
sa.iterator, epi_smem_layout_staged.inner, dtype=self.c_dtype
sC_ptr = cute.recast_ptr(
sA.iterator, epi_smem_layout_staged.inner, dtype=self.c_dtype
)
sc = cute.make_tensor(sc_ptr, epi_smem_layout_staged.outer)
sC = cute.make_tensor(sC_ptr, epi_smem_layout_staged.outer)
# ///////////////////////////////////////////////////////////////////////////////
# Local_tile partition global tensors
@ -673,34 +676,34 @@ class HopperWgmmaGemmKernel:
# TMA load A partition_S/D
a_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (0, None, 0)).shape)
a_cta_crd = cluster_coord_mnk[1]
sa_for_tma_partition = cute.group_modes(sa, 0, 2)
sA_for_tma_partition = cute.group_modes(sA, 0, 2)
gA_for_tma_partition = cute.group_modes(gA_mkl, 0, 2)
tAsA, tAgA_mkl = cute.nvgpu.cpasync.tma_partition(
tma_atom_a,
a_cta_crd,
a_cta_layout,
sa_for_tma_partition,
sA_for_tma_partition,
gA_for_tma_partition,
)
# TMA load B partition_S/D
b_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (None, 0, 0)).shape)
b_cta_crd = cluster_coord_mnk[0]
sb_for_tma_partition = cute.group_modes(sb, 0, 2)
sB_for_tma_partition = cute.group_modes(sB, 0, 2)
gB_for_tma_partition = cute.group_modes(gB_nkl, 0, 2)
tBsB, tBgB_nkl = cute.nvgpu.cpasync.tma_partition(
tma_atom_b,
b_cta_crd,
b_cta_layout,
sb_for_tma_partition,
sB_for_tma_partition,
gB_for_tma_partition,
)
# //////////////////////////////////////////////////////////////////////////////
# Make frangments
# Make fragments
# //////////////////////////////////////////////////////////////////////////////
tCsA = thr_mma.partition_A(sa)
tCsB = thr_mma.partition_B(sb)
tCsA = thr_mma.partition_A(sA)
tCsB = thr_mma.partition_B(sB)
tCrA = tiled_mma.make_fragment_A(tCsA)
tCrB = tiled_mma.make_fragment_B(tCsB)
@ -711,7 +714,7 @@ class HopperWgmmaGemmKernel:
# Cluster wait
# ///////////////////////////////////////////////////////////////////////////////
# cluster wait for barrier init
if cute.size(self.cluster_shape_mnk) > 1:
if cute.size(self.cluster_shape_mn) > 1:
cute.arch.cluster_wait()
else:
cute.arch.sync_threads()
@ -788,7 +791,7 @@ class HopperWgmmaGemmKernel:
tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, False)
num_k_blocks = cute.size(tCrA, mode=[2])
for k_tile in range(k_pipe_mmas):
for k_tile in cutlass.range_constexpr(k_pipe_mmas):
# Wait for A/B buffer to be ready
mainloop_pipeline.consumer_wait(
mainloop_consumer_read_state, peek_ab_full_status
@ -917,7 +920,7 @@ class HopperWgmmaGemmKernel:
# /////////////////////////////////////////////////////////////////////////////
cute.nvgpu.warpgroup.wait_group(0)
if cute.size(self.cluster_shape_mnk) > 1:
if cute.size(self.cluster_shape_mn) > 1:
# Wait for all threads in the cluster to finish, avoid early release of smem
cute.arch.cluster_arrive()
cute.arch.cluster_wait()
@ -950,33 +953,45 @@ class HopperWgmmaGemmKernel:
# (R2S, R2S_M, R2S_N, PIPE_D)
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
tRS_sD = thr_copy_r2s.partition_D(sc)
tRS_sD = thr_copy_r2s.partition_D(sC)
# (R2S, R2S_M, R2S_N)
tRS_rAcc = tiled_copy_r2s.retile(accumulators)
# Allocate D registers.
rD_shape = cute.shape(thr_copy_r2s.partition_S(sc))
rD_shape = cute.shape(thr_copy_r2s.partition_S(sC))
tRS_rD_layout = cute.make_layout(rD_shape[:3])
tRS_rD = cute.make_fragment_like(tRS_rD_layout, self.acc_dtype)
size_tRS_rD = cute.size(tRS_rD)
sepi_for_tma_partition = cute.group_modes(sc, 0, 2)
tcgc_for_tma_partition = cute.zipped_divide(gC_mnl, self.epi_tile)
sepi_for_tma_partition = cute.group_modes(sC, 0, 2)
tCgC_for_tma_partition = cute.zipped_divide(gC_mnl, self.epi_tile)
bSG_sD, bSG_gD = cute.nvgpu.cpasync.tma_partition(
tma_atom_c,
0,
cute.make_layout(1),
sepi_for_tma_partition,
tcgc_for_tma_partition,
tCgC_for_tma_partition,
)
epi_tile_num = cute.size(tcgc_for_tma_partition, mode=[1])
epi_tile_shape = tcgc_for_tma_partition.shape[1]
epi_tile_num = cute.size(tCgC_for_tma_partition, mode=[1])
epi_tile_shape = tCgC_for_tma_partition.shape[1]
epi_tile_layout = cute.make_layout(
epi_tile_shape, stride=(epi_tile_shape[1], 1)
)
for epi_idx in cutlass.range(epi_tile_num, unroll=epi_tile_num):
# Initialize tma store c_pipeline
c_producer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread, self.threads_per_cta, self.threads_per_cta
)
c_pipeline = pipeline.PipelineTmaStore.create(
num_stages=self.epi_stage,
producer_group=c_producer_group,
)
for epi_idx in cutlass.range_constexpr(epi_tile_num):
# Copy from accumulators to D registers
for epi_v in range(size_tRS_rD):
for epi_v in cutlass.range_constexpr(size_tRS_rD):
tRS_rD[epi_v] = tRS_rAcc[epi_idx * size_tRS_rD + epi_v]
# Type conversion
@ -997,10 +1012,6 @@ class HopperWgmmaGemmKernel:
# barrier for sync
cute.arch.barrier()
# Get the global memory coordinate for the current epi tile.
epi_tile_layout = cute.make_layout(
epi_tile_shape, stride=(epi_tile_shape[1], 1)
)
gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
# Copy from shared memory to global memory
if warp_idx == 0:
@ -1009,11 +1020,14 @@ class HopperWgmmaGemmKernel:
bSG_sD[(None, epi_buffer)],
bSG_gD[(None, gmem_coord)],
)
cute.arch.cp_async_bulk_commit_group()
cute.arch.cp_async_bulk_wait_group(self.epi_stage - 1, read=True)
c_pipeline.producer_commit()
c_pipeline.producer_acquire()
cute.arch.barrier()
if warp_idx == 0:
c_pipeline.producer_tail()
return
@staticmethod
@ -1055,9 +1069,7 @@ class HopperWgmmaGemmKernel:
mbar_helpers_bytes = 1024
ab_stage = (
(smem_capacity - occupancy * 1024) // occupancy
- mbar_helpers_bytes
- epi_bytes
smem_capacity // occupancy - mbar_helpers_bytes - epi_bytes
) // ab_bytes_per_stage
return ab_stage, epi_stage
@ -1195,7 +1207,7 @@ class HopperWgmmaGemmKernel:
def _compute_grid(
c: cute.Tensor,
tile_shape_mnk: tuple[int, int, int],
cluster_shape_mnk: tuple[int, int, int],
cluster_shape_mn: tuple[int, int],
) -> tuple[int, int, int]:
"""Compute grid shape for the output tensor C.
@ -1203,8 +1215,8 @@ class HopperWgmmaGemmKernel:
:type c: cute.Tensor
:param tile_shape_mnk: The shape (M, N, K) of the CTA tile.
:type tile_shape_mnk: tuple[int, int, int]
:param cluster_shape_mnk: Shape of each cluster in M, N, K dimensions.
:type cluster_shape_mnk: tuple[int, int, int]
:param cluster_shape_mn: Shape of each cluster in M, N dimensions.
:type cluster_shape_mn: tuple[int, int]
:return: Grid shape for kernel launch.
:rtype: tuple[int, int, int]
@ -1212,8 +1224,9 @@ class HopperWgmmaGemmKernel:
c_shape = (tile_shape_mnk[0], tile_shape_mnk[1])
gc = cute.zipped_divide(c, tiler=c_shape)
clusters = cute.ceil_div(cute.get(gc.layout, mode=[1]).shape, cluster_shape_mnk)
grid = tuple(x * y for x, y in zip(clusters, cluster_shape_mnk))
cluster_shape_mnl = (*cluster_shape_mn, 1)
clusters = cute.ceil_div(cute.get(gc.layout, mode=[1]).shape, cluster_shape_mnl)
grid = tuple(x * y for x, y in zip(clusters, cluster_shape_mnl))
return grid
@staticmethod
@ -1363,7 +1376,7 @@ def run(
a_major: str,
b_major: str,
c_major: str,
tile_shape_mnk: Tuple[int, int, int],
tile_shape_mn: Tuple[int, int],
cluster_shape_mn: Tuple[int, int],
tolerance: float,
warmup_iterations: int,
@ -1387,8 +1400,8 @@ def run(
:type acc_dtype: Type[cutlass.Numeric]
:param a_major/b_major/c_major: Memory layout of tensor A/B/C
:type a_major/b_major/c_major: str
:param tile_shape_mnk: CTA tile shape (M, N, K)
:type tile_shape_mnk: Tuple[int, int, int]
:param tile_shape_mn: CTA tile shape (M, N)
:type tile_shape_mn: Tuple[int, int]
:param cluster_shape_mn: Cluster shape (M, N)
:type cluster_shape_mn: Tuple[int, int]
:param tolerance: Tolerance value for reference validation comparison
@ -1411,7 +1424,7 @@ def run(
f"A dtype: {a_dtype}, B dtype: {b_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}"
)
print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}")
print(f"Tile Shape: {tile_shape_mnk}, Cluster Shape: {cluster_shape_mn}")
print(f"Tile Shape: {tile_shape_mn}, Cluster Shape: {cluster_shape_mn}")
print(f"Tolerance: {tolerance}")
print(f"Warmup iterations: {warmup_iterations}")
print(f"Iterations: {iterations}")
@ -1420,7 +1433,6 @@ def run(
# Unpack parameters
m, n, k, l = mnkl
cluster_shape_mnk = (*cluster_shape_mn, 1)
# Skip unsupported types
if not HopperWgmmaGemmKernel.is_valid_dtypes(
@ -1488,7 +1500,7 @@ def run(
b, mB, b_torch = create_and_permute_tensor(l, n, k, b_major == "n", b_dtype)
c, mC, c_torch = create_and_permute_tensor(l, m, n, c_major == "m", c_dtype)
gemm = HopperWgmmaGemmKernel(acc_dtype, tile_shape_mnk, cluster_shape_mnk)
gemm = HopperWgmmaGemmKernel(acc_dtype, tile_shape_mn, cluster_shape_mn)
torch_stream = torch.cuda.Stream()
stream = cuda.CUstream(torch_stream.cuda_stream)
@ -1572,7 +1584,7 @@ if __name__ == "__main__":
args.a_major,
args.b_major,
args.c_major,
args.tile_shape_mnk,
args.tile_shape_mn,
args.cluster_shape_mn,
args.tolerance,
args.warmup_iterations,

View File

@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@ -41,22 +41,9 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"a_vec: tensor_value<vector<12xf32> o (3, 4)>\n",
"b_vec: tensor_value<vector<12xf32> o (3, 4)>\n",
"tensor(raw_ptr(0x0000000006cff170: f32, generic, align<4>) o (3,4):(4,1), data=\n",
" [[ 2.000000, 2.000000, 2.000000, 2.000000, ],\n",
" [ 2.000000, 2.000000, 2.000000, 2.000000, ],\n",
" [ 2.000000, 2.000000, 2.000000, 2.000000, ]])\n"
]
}
],
"outputs": [],
"source": [
"@cute.jit\n",
"def load_and_store(res: cute.Tensor, a: cute.Tensor, b: cute.Tensor):\n",
@ -91,22 +78,9 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor_value<vector<24xf32> o (4, 2, 3)> -> tensor_value<vector<12xf32> o (4, 3)>\n",
"tensor(raw_ptr(0x00000000071acaf0: f32, generic, align<4>) o (4,3):(3,1), data=\n",
" [[ 3.000000, 4.000000, 5.000000, ],\n",
" [ 9.000000, 10.000000, 11.000000, ],\n",
" [ 15.000000, 16.000000, 17.000000, ],\n",
" [ 21.000000, 22.000000, 23.000000, ]])\n"
]
}
],
"outputs": [],
"source": [
"@cute.jit\n",
"def apply_slice(src: cute.Tensor, dst: cute.Tensor, indices: cutlass.Constexpr):\n",
@ -155,19 +129,9 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor_value<vector<24xf32> o (4, 2, 3)> -> ?\n",
"tensor(raw_ptr(0x00000000013cbbe0: f32, generic, align<4>) o (1):(1), data=\n",
" [ 10.000000, ])\n"
]
}
],
"outputs": [],
"source": [
"def slice_2():\n",
" src_shape = (4, 2, 3)\n",
@ -195,40 +159,9 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n",
" [ 3.000000, ],\n",
" [ 3.000000, ],\n",
" [ 3.000000, ])\n",
"tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n",
" [-1.000000, ],\n",
" [-1.000000, ],\n",
" [-1.000000, ])\n",
"tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n",
" [ 2.000000, ],\n",
" [ 2.000000, ],\n",
" [ 2.000000, ])\n",
"tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n",
" [ 0.500000, ],\n",
" [ 0.500000, ],\n",
" [ 0.500000, ])\n",
"tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n",
" [ 0.000000, ],\n",
" [ 0.000000, ],\n",
" [ 0.000000, ])\n",
"tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n",
" [ 1.000000, ],\n",
" [ 1.000000, ],\n",
" [ 1.000000, ])\n"
]
}
],
"outputs": [],
"source": [
"@cute.jit\n",
"def binary_op_1(res: cute.Tensor, a: cute.Tensor, b: cute.Tensor):\n",
@ -236,28 +169,22 @@
" b_vec = b.load()\n",
"\n",
" add_res = a_vec + b_vec\n",
" res.store(add_res)\n",
" cute.print_tensor(res) # prints [3.000000, 3.000000, 3.000000]\n",
" cute.print_tensor(add_res) # prints [3.000000, 3.000000, 3.000000]\n",
"\n",
" sub_res = a_vec - b_vec\n",
" res.store(sub_res)\n",
" cute.print_tensor(res) # prints [-1.000000, -1.000000, -1.000000]\n",
" cute.print_tensor(sub_res) # prints [-1.000000, -1.000000, -1.000000]\n",
"\n",
" mul_res = a_vec * b_vec\n",
" res.store(mul_res)\n",
" cute.print_tensor(res) # prints [2.000000, 2.000000, 2.000000]\n",
" cute.print_tensor(mul_res) # prints [2.000000, 2.000000, 2.000000]\n",
"\n",
" div_res = a_vec / b_vec\n",
" res.store(div_res)\n",
" cute.print_tensor(res) # prints [0.500000, 0.500000, 0.500000]\n",
" cute.print_tensor(div_res) # prints [0.500000, 0.500000, 0.500000]\n",
"\n",
" floor_div_res = a_vec // b_vec\n",
" res.store(floor_div_res)\n",
" cute.print_tensor(res) # prints [0.000000, 0.000000, 0.000000]\n",
" cute.print_tensor(res) # prints [0.000000, 0.000000, 0.000000]\n",
"\n",
" mod_res = a_vec % b_vec\n",
" res.store(mod_res)\n",
" cute.print_tensor(res) # prints [1.000000, 1.000000, 1.000000]\n",
" cute.print_tensor(mod_res) # prints [1.000000, 1.000000, 1.000000]\n",
"\n",
"\n",
"a = np.empty((3,), dtype=np.float32)\n",
@ -270,68 +197,31 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n",
" [ 3.000000, ],\n",
" [ 3.000000, ],\n",
" [ 3.000000, ])\n",
"tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n",
" [-1.000000, ],\n",
" [-1.000000, ],\n",
" [-1.000000, ])\n",
"tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n",
" [ 2.000000, ],\n",
" [ 2.000000, ],\n",
" [ 2.000000, ])\n",
"tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n",
" [ 0.500000, ],\n",
" [ 0.500000, ],\n",
" [ 0.500000, ])\n",
"tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n",
" [ 0.000000, ],\n",
" [ 0.000000, ],\n",
" [ 0.000000, ])\n",
"tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n",
" [ 1.000000, ],\n",
" [ 1.000000, ],\n",
" [ 1.000000, ])\n"
]
}
],
"outputs": [],
"source": [
"@cute.jit\n",
"def binary_op_2(res: cute.Tensor, a: cute.Tensor, c: cutlass.Constexpr):\n",
" a_vec = a.load()\n",
"\n",
" add_res = a_vec + c\n",
" res.store(add_res)\n",
" cute.print_tensor(res) # prints [3.000000, 3.000000, 3.000000]\n",
" cute.print_tensor(add_res) # prints [3.000000, 3.000000, 3.000000]\n",
"\n",
" sub_res = a_vec - c\n",
" res.store(sub_res)\n",
" cute.print_tensor(res) # prints [-1.000000, -1.000000, -1.000000]\n",
" cute.print_tensor(sub_res) # prints [-1.000000, -1.000000, -1.000000]\n",
"\n",
" mul_res = a_vec * c\n",
" res.store(mul_res)\n",
" cute.print_tensor(res) # prints [2.000000, 2.000000, 2.000000]\n",
" cute.print_tensor(mul_res) # prints [2.000000, 2.000000, 2.000000]\n",
"\n",
" div_res = a_vec / c\n",
" res.store(div_res)\n",
" cute.print_tensor(res) # prints [0.500000, 0.500000, 0.500000]\n",
" cute.print_tensor(div_res) # prints [0.500000, 0.500000, 0.500000]\n",
"\n",
" floor_div_res = a_vec // c\n",
" res.store(floor_div_res)\n",
" cute.print_tensor(res) # prints [0.000000, 0.000000, 0.000000]\n",
" cute.print_tensor(floor_div_res) # prints [0.000000, 0.000000, 0.000000]\n",
"\n",
" mod_res = a_vec % c\n",
" res.store(mod_res)\n",
" cute.print_tensor(res) # prints [1.000000, 1.000000, 1.000000]\n",
" cute.print_tensor(mod_res) # prints [1.000000, 1.000000, 1.000000]\n",
"\n",
"a = np.empty((3,), dtype=np.float32)\n",
"a.fill(1.0)\n",
@ -342,17 +232,9 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[False True False]\n"
]
}
],
"outputs": [],
"source": [
"@cute.jit\n",
"def binary_op_3(res: cute.Tensor, a: cute.Tensor, b: cute.Tensor):\n",
@ -378,17 +260,9 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[3 0 7]\n"
]
}
],
"outputs": [],
"source": [
"@cute.jit\n",
"def binary_op_4(res: cute.Tensor, a: cute.Tensor, b: cute.Tensor):\n",
@ -420,44 +294,23 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(raw_ptr(0x0000000007fbd180: f32, generic, align<4>) o (3):(1), data=\n",
" [ 2.000000, ],\n",
" [ 2.000000, ],\n",
" [ 2.000000, ])\n",
"tensor(raw_ptr(0x0000000007fbd180: f32, generic, align<4>) o (3):(1), data=\n",
" [-0.756802, ],\n",
" [-0.756802, ],\n",
" [-0.756802, ])\n",
"tensor(raw_ptr(0x0000000007fbd180: f32, generic, align<4>) o (3):(1), data=\n",
" [ 16.000000, ],\n",
" [ 16.000000, ],\n",
" [ 16.000000, ])\n"
]
}
],
"outputs": [],
"source": [
"@cute.jit\n",
"def unary_op_1(res: cute.Tensor, a: cute.Tensor):\n",
" a_vec = a.load()\n",
"\n",
" sqrt_res = cute.math.sqrt(a_vec)\n",
" res.store(sqrt_res)\n",
" cute.print_tensor(res) # prints [2.000000, 2.000000, 2.000000]\n",
" cute.print_tensor(sqrt_res) # prints [2.000000, 2.000000, 2.000000]\n",
"\n",
" sin_res = cute.math.sin(a_vec)\n",
" res.store(sin_res)\n",
" cute.print_tensor(res) # prints [-0.756802, -0.756802, -0.756802]\n",
" cute.print_tensor(sin_res) # prints [-0.756802, -0.756802, -0.756802]\n",
"\n",
" exp2_res = cute.math.exp2(a_vec)\n",
" res.store(exp2_res)\n",
" cute.print_tensor(res) # prints [16.000000, 16.000000, 16.000000]\n",
" cute.print_tensor(exp2_res) # prints [16.000000, 16.000000, 16.000000]\n",
"\n",
"a = np.array([4.0, 4.0, 4.0], dtype=np.float32)\n",
"res = np.empty((3,), dtype=np.float32)\n",
@ -470,29 +323,18 @@
"source": [
"#### Reduction Operation\n",
"\n",
"The `TensorSSA`'s `reduce` method applies a specified reduction operation (`ReductionOp.ADD`, `ReductionOp.MUL`, `ReductionOp.MAX`, `ReductionOp.MIN`) starting with an initial value, and performs this reduction along the dimensions specified by the `reduction_profile.`. The result is typically a new `TensorSSA` with reduced dimensions or a scalar value if reduces across all axes."
"The `TensorSSA`'s `reduce` method applies a specified reduction operation (`ReductionOp.ADD`, \n",
"`ReductionOp.MUL`, `ReductionOp.MAX`, `ReductionOp.MIN`) starting with an initial value, and \n",
"performs this reduction along the dimensions specified by the `reduction_profile`. The result \n",
"is typically a new `TensorSSA` with reduced dimensions or a scalar value if it reduces across \n",
"all axes."
]
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"21.000000\n",
"tensor(raw_ptr(0x00007ffd1ea2bca0: f32, rmem, align<32>) o (2):(1), data=\n",
" [ 6.000000, ],\n",
" [ 15.000000, ])\n",
"tensor(raw_ptr(0x00007ffd1ea2bcc0: f32, rmem, align<32>) o (3):(1), data=\n",
" [ 6.000000, ],\n",
" [ 8.000000, ],\n",
" [ 10.000000, ])\n"
]
}
],
"outputs": [],
"source": [
"@cute.jit\n",
"def reduction_op(a: cute.Tensor):\n",
@ -507,36 +349,138 @@
" 0.0,\n",
" reduction_profile=0\n",
" )\n",
" cute.printf(red_res) # prints 21.000000\n",
" cute.printf(red_res) # prints 21.000000\n",
"\n",
" red_res = a_vec.reduce(\n",
" cute.ReductionOp.ADD,\n",
" 0.0,\n",
" reduction_profile=(None, 1)\n",
" )\n",
" # We can't print the TensorSSA directly at this point, so we store it to a new Tensor and print it.\n",
" res = cute.make_fragment(red_res.shape, cutlass.Float32)\n",
" res.store(red_res)\n",
" cute.print_tensor(res) # prints [6.000000, 15.000000]\n",
" cute.print_tensor(red_res) # prints [6.000000, 15.000000]\n",
"\n",
" red_res = a_vec.reduce(\n",
" cute.ReductionOp.ADD,\n",
" 1.0,\n",
" reduction_profile=(1, None)\n",
" )\n",
" res = cute.make_fragment(red_res.shape, cutlass.Float32)\n",
" res.store(red_res)\n",
" cute.print_tensor(res) # prints [6.000000, 8.000000, 10.000000]\n",
" cute.print_tensor(red_res) # prints [6.000000, 8.000000, 10.000000]\n",
"\n",
"\n",
"a = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)\n",
"reduction_op(from_dlpack(a))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Broadcast\n",
"\n",
"`TensorSSA` supports broadcasting operations following NumPy's broadcasting rules. Broadcasting \n",
"allows you to perform operations on arrays of different shapes when certain conditions are met. \n",
"The key rules are:\n",
"\n",
"1. Source shape is padded with 1's to match the rank of target shape\n",
"2. The size in each mode of source shape must either be 1 or equal to target shape\n",
"3. After broadcasting, all modes should match target shape\n",
"\n",
"Let's look at some examples of broadcasting in action:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import cutlass\n",
"import cutlass.cute as cute\n",
"\n",
"\n",
"@cute.jit\n",
"def broadcast_examples():\n",
" a = cute.make_fragment((1,3), dtype=cutlass.Float32)\n",
" a[0] = 0.0\n",
" a[1] = 1.0\n",
" a[2] = 2.0\n",
" a_val = a.load()\n",
" cute.print_tensor(a_val.broadcast_to((4, 3)))\n",
" # tensor(raw_ptr(0x00007ffe26625740: f32, rmem, align<32>) o (4,3):(1,4), data=\n",
" # [[ 0.000000, 1.000000, 2.000000, ],\n",
" # [ 0.000000, 1.000000, 2.000000, ],\n",
" # [ 0.000000, 1.000000, 2.000000, ],\n",
" # [ 0.000000, 1.000000, 2.000000, ]])\n",
"\n",
" c = cute.make_fragment((4,1), dtype=cutlass.Float32)\n",
" c[0] = 0.0\n",
" c[1] = 1.0\n",
" c[2] = 2.0\n",
" c[3] = 3.0\n",
" cute.print_tensor(a.load() + c.load())\n",
" # tensor(raw_ptr(0x00007ffe26625780: f32, rmem, align<32>) o (4,3):(1,4), data=\n",
" # [[ 0.000000, 1.000000, 2.000000, ],\n",
" # [ 1.000000, 2.000000, 3.000000, ],\n",
" # [ 2.000000, 3.000000, 4.000000, ],\n",
" # [ 3.000000, 4.000000, 5.000000, ]])\n",
"\n",
"\n",
"broadcast_examples()"
]
},
{
"cell_type": "markdown",
"metadata": {
"vscode": {
"languageId": "raw"
}
},
"source": [
"The examples above demonstrate two key broadcasting scenarios:\n",
"\n",
"1. **Row Vector Broadcasting**: In the first example, we create a row vector `a` with shape \n",
" (1, 3) containing values [0.0, 1.0, 2.0]. When we broadcast it to shape (4, 3), the values \n",
" are repeated across the first dimension, resulting in:\n",
" ```\n",
" [[0.0, 1.0, 2.0],\n",
" [0.0, 1.0, 2.0],\n",
" [0.0, 1.0, 2.0],\n",
" [0.0, 1.0, 2.0]]\n",
" ```\n",
" This demonstrates how a row vector can be broadcast to create multiple identical rows.\n",
"\n",
"2. **Column Vector and Row Vector Addition**: In the second example, we have:\n",
" - A row vector `a` with shape (1, 3) containing [0.0, 1.0, 2.0]\n",
" - A column vector `c` with shape (4, 1) containing [0.0, 1.0, 2.0, 3.0]\n",
" \n",
" When we add these together, both vectors are broadcast to shape (4, 3):\n",
" - The row vector is broadcast vertically (4 times)\n",
" - The column vector is broadcast horizontally (3 times)\n",
" \n",
" The result is:\n",
" ```\n",
" [[0.0 + 0.0, 1.0 + 0.0, 2.0 + 0.0],\n",
" [0.0 + 1.0, 1.0 + 1.0, 2.0 + 1.0],\n",
" [0.0 + 2.0, 1.0 + 2.0, 2.0 + 2.0],\n",
" [0.0 + 3.0, 1.0 + 3.0, 2.0 + 3.0]]\n",
" ```\n",
" =\n",
" ```\n",
" [[0.0, 1.0, 2.0],\n",
" [1.0, 2.0, 3.0],\n",
" [2.0, 3.0, 4.0],\n",
" [3.0, 4.0, 5.0]]\n",
" ```\n",
"\n",
"This demonstrates how `TensorSSA` can automatically handle broadcasting of both row and column \n",
"vectors in arithmetic operations, following the broadcasting rules where each dimension must \n",
"either be 1 or match the target size. The broadcasting is handled implicitly during operations, \n",
"making it easy to work with tensors of different shapes.\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": ".venv3_12",
"language": "python",
"name": "python3"
},
@ -550,7 +494,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
"version": "3.12.10"
}
},
"nbformat": 4,