v4.2 tag release. (#2638)

This commit is contained in:
Junkai-Wu
2025-09-16 00:21:53 +08:00
committed by GitHub
parent 56f0718a97
commit 6a35b4d22f
161 changed files with 14056 additions and 3793 deletions

View File

@@ -0,0 +1,314 @@
import os
import torch
import argparse
from cuda import cuda
from cuda.bindings import driver
from typing import Type
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
import torch.multiprocessing as mp
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
from cutlass._mlir.dialects import llvm, builtin, vector, arith
from cutlass._mlir.extras import types as T  # assumed vendored-MLIR path for the T.vector/T.f32/T.i32 helpers used below
WORLD_SIZE = 8
PING_PONG_SIZE = 3
def setup(rank, world_size):
# set environment variables for torch.distributed environment
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12959"
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
def cleanup():
dist.destroy_process_group()
class AllReduceKernel:
@cute.jit
def __call__(
self,
rank,
signal,
local_input: cute.Tensor,
local_output: cute.Tensor,
buffer0: cute.Tensor,
buffer1: cute.Tensor,
buffer2: cute.Tensor,
buffer3: cute.Tensor,
buffer4: cute.Tensor,
buffer5: cute.Tensor,
buffer6: cute.Tensor,
buffer7: cute.Tensor,
stream: cuda.CUstream,
):
# define constants for future use
num_of_elements = cute.size(local_input.layout)
# 128 threads per block and 4 elements per thread
tv_layout = cute.make_layout(((128), (4)), stride=((1), (1)))
tile = cute.size(tv_layout.shape)
buffers = [
buffer0,
buffer1,
buffer2,
buffer3,
buffer4,
buffer5,
buffer6,
buffer7,
]
tiled_buffers = [
cute.logical_divide(buffer, (tile, None, None)) for buffer in buffers
]
tiled_input = cute.zipped_divide(local_input, cute.make_layout(tile))
tiled_output = cute.zipped_divide(local_output, cute.make_layout(tile))
self.kernel(
tiled_buffers[0],
tiled_buffers[1],
tiled_buffers[2],
tiled_buffers[3],
tiled_buffers[4],
tiled_buffers[5],
tiled_buffers[6],
tiled_buffers[7],
tiled_input,
tiled_output,
tv_layout,
cutlass.Int32(signal),
cutlass.Int32(rank),
).launch(
grid=[num_of_elements // tile, 1, 1],
block=[tv_layout.shape[0], 1, 1],
stream=stream,
)
# GPU device kernel
@cute.kernel
def kernel(
self,
buffer0: cute.Tensor,
buffer1: cute.Tensor,
buffer2: cute.Tensor,
buffer3: cute.Tensor,
buffer4: cute.Tensor,
buffer5: cute.Tensor,
buffer6: cute.Tensor,
buffer7: cute.Tensor,
local_input: cute.Tensor,
local_output: cute.Tensor,
tv_layout: cute.Layout,
signal: cutlass.Int32,
rank: cutlass.Int32,
):
tidx, _, _ = cute.arch.thread_idx()
bidx, _, _ = cute.arch.block_idx()
        ping = signal % PING_PONG_SIZE
        pong = (signal + 1) % PING_PONG_SIZE
buffers = [
buffer0,
buffer1,
buffer2,
buffer3,
buffer4,
buffer5,
buffer6,
buffer7,
]
        # rank is a runtime Int32 and cannot index a Python list inside the
        # DSL, so enumerate the ranks explicitly to select this rank's buffer
        def get_buffer():
t = buffers[2]
if rank == cutlass.Int32(0):
t = buffers[0]
if rank == cutlass.Int32(1):
t = buffers[1]
if rank == cutlass.Int32(2):
t = buffers[2]
if rank == cutlass.Int32(3):
t = buffers[3]
if rank == cutlass.Int32(4):
t = buffers[4]
if rank == cutlass.Int32(5):
t = buffers[5]
if rank == cutlass.Int32(6):
t = buffers[6]
if rank == cutlass.Int32(7):
t = buffers[7]
return t
buffer_local = get_buffer()
cta_coord = (None, bidx)
local_tile_in = local_input[cta_coord]
local_tile_out = local_output[cta_coord]
ping_coord = ((None, bidx), None, ping)
read_buffer = buffer_local[ping_coord]
pong_coord = ((None, bidx), None, pong)
clear_buffer = buffer_local[pong_coord]
write_coord = ((None, bidx), rank, ping)
write_buffers = [buffer[write_coord] for buffer in buffers]
        # assume all buffers have the same element type as the input
copy_atom_load = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
buffer0.element_type,
num_bits_per_copy=64,
memory_scope=cute.nvgpu.common.MemoryScope.SYS,
memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE,
)
copy_atom_store = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
buffer0.element_type,
num_bits_per_copy=64,
memory_scope=cute.nvgpu.common.MemoryScope.SYS,
memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE,
)
tiled_copy = cute.make_tiled_copy_tv(copy_atom_load, tv_layout[0], tv_layout[1])
thr_copy = tiled_copy.get_slice(tidx)
thr_write_buffer_list = [
thr_copy.partition_D(tensor) for tensor in write_buffers
]
thr_read_buffer = thr_copy.partition_S(read_buffer)
thr_clear_buffer = thr_copy.partition_D(clear_buffer)
thr_in = thr_copy.partition_S(local_tile_in)
thr_out = thr_copy.partition_D(local_tile_out)
frg_in = cute.make_fragment_like(thr_in)
frg_clear = cute.make_fragment_like(thr_clear_buffer)
frg_acc = cute.make_fragment_like(thr_out)
frg_acc.fill(0.0)
clear_tensor = frg_clear.load()
frg_size = cute.size(clear_tensor.shape)
        # 0x80000000 is the IEEE-754 bit pattern of -0.0f; it marks a slot as
        # "not yet written" until a peer overwrites it with real data
        neg0_i32_vec = cute.full_like(clear_tensor, 0x80000000, cutlass.Int32)
neg0_f32_vec = vector.bitcast(T.vector(frg_size, T.f32()), neg0_i32_vec)
neg0_f32_tensor = cute.TensorSSA(
neg0_f32_vec, clear_tensor.shape, cutlass.Float32
)
frg_clear.store(neg0_f32_tensor)
cute.copy(copy_atom_load, thr_in, frg_in)
for thr_write_buffer in thr_write_buffer_list:
cute.copy(copy_atom_store, frg_in, thr_write_buffer)
cute.copy(copy_atom_store, frg_clear, thr_clear_buffer)
frg_in_vector_neg0_i32 = cute.full_like(
frg_in, cutlass.Int32(0x80000000), cutlass.Int32
)
frg_in_size = cute.size(frg_in.shape)
for i in range(WORLD_SIZE):
read_coord = (None, 0, i)
cute.copy(copy_atom_load, thr_read_buffer[read_coord], frg_in[None, 0])
frg_vector = frg_in.load()
frg_vector_i32 = vector.bitcast(T.vector(frg_in_size, T.i32()), frg_vector)
isNotNeg0 = cute.all_(frg_vector_i32 != frg_in_vector_neg0_i32)
while not isNotNeg0:
cute.copy(copy_atom_load, thr_read_buffer[read_coord], frg_in[None, 0])
frg_vector = frg_in.load()
frg_vector_i32 = vector.bitcast(
T.vector(frg_in_size, T.i32()), frg_vector
)
isNotNeg0 = cute.all_(frg_vector_i32 != frg_in_vector_neg0_i32)
frg_acc.store(frg_in.load() + frg_acc.load())
        cute.copy(copy_atom_store, frg_acc, thr_out)
def run_all_reduce(rank, M, N, dtype: Type[cutlass.Numeric]):
setup(rank, WORLD_SIZE)
input_tensor = torch.randn(M * N, device=f"cuda:{rank}")
output_tensor = torch.zeros(M * N, device=f"cuda:{rank}")
# init tensors on different devices
t = symm_mem.empty(
[
PING_PONG_SIZE,
WORLD_SIZE,
M * N,
],
device="cuda",
).neg_()
hdl = symm_mem.rendezvous(t, dist.group.WORLD)
buffer_tensor_list = [
        hdl.get_buffer(peer, t.shape, t.dtype).permute(2, 1, 0)
        for peer in range(WORLD_SIZE)
]
# enable peer access
driver.cuInit(0)
dev_list = [driver.cuDeviceGet(i)[1] for i in range(WORLD_SIZE)]
ctx_list = [driver.cuDevicePrimaryCtxRetain(dev)[1] for dev in dev_list]
for i in range(WORLD_SIZE):
driver.cuCtxSetCurrent(ctx_list[i])
for j in range(WORLD_SIZE):
if i == j:
continue
driver.cuCtxEnablePeerAccess(ctx_list[j], 0)
driver.cuCtxSetCurrent(ctx_list[rank])
stream = cutlass.cuda.default_stream()
all_reduce_kernel = AllReduceKernel()
dlpack_buffers = [from_dlpack(x, assumed_align=32) for x in buffer_tensor_list]
all_reduce_kernel(
rank,
0,
from_dlpack(input_tensor, assumed_align=32),
from_dlpack(output_tensor, assumed_align=32),
*dlpack_buffers,
stream,
)
    torch.cuda.synchronize()
# use torch api to get reference and inplace stored to input_tensor
ref_tensor = input_tensor.clone()
dist.all_reduce(ref_tensor, op=dist.ReduceOp.SUM)
    # check the output tensor; allow a small tolerance since the summation
    # order differs from NCCL's reduction order
    equal_mask = (ref_tensor.cpu() - output_tensor.cpu()).abs() < 1e-4
    result = bool(equal_mask.all())
if result:
print(f"rank {rank} test passed")
else:
print(f"rank {rank} test failed")
print(
"ref_tensor[ref_tensor != output_tensor]: ",
ref_tensor[ref_tensor != output_tensor],
)
print(
"output_tensor[ref_tensor != output_tensor]: ",
output_tensor[ref_tensor != output_tensor],
)
cleanup()
def main():
M = 1024
N = 1024
    # each process runs run_all_reduce on a different device
mp.spawn(run_all_reduce, args=(M, N, cutlass.Float32), nprocs=WORLD_SIZE, join=True)
return
if __name__ == "__main__":
main()
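A note on the synchronization scheme above: before publishing data, each rank pre-fills the pong buffer with the bit pattern 0x80000000, which is IEEE-754 negative zero, and the read loop spins until every lane of the incoming fragment differs bitwise from that pattern. Negative zero makes a usable "not yet written" sentinel because it compares equal to 0.0 as a float yet remains distinguishable at the bit level; the scheme does assume real payloads never contain an exact -0.0. A minimal host-side sketch of the underlying bit-level facts:

import struct

# -0.0 occupies the bit pattern 0x80000000 in IEEE-754 binary32 ...
assert struct.unpack("<I", struct.pack("<f", -0.0))[0] == 0x80000000
# ... yet compares equal to +0.0 as a float, which is why the kernel
# bitcasts to Int32 and compares bitwise instead of using float equality
assert -0.0 == 0.0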

View File

@@ -0,0 +1,166 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import sys
import os
from typing import Tuple
import torch
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import make_ptr
"""
An example demonstrating how to call an off-the-shelf kernel while bypassing the DLPack protocol.
The example shows how to pass pointers from PyTorch tensors directly to an off-the-shelf kernel
written in the CuTe DSL through a thin custom wrapper jit function. The wrapper is inlined
during compilation and introduces no overhead.
To run this example:
.. code-block:: bash
python examples/ampere/call_bypass_dlpack.py
It is worth mentioning that bypassing the DLPack protocol also sidesteps an issue where DLPack
does not handle shape-1 modes correctly. For example, the following code fails because DLPack
converts a shape-1 mode to stride-1, which propagates alignment information incorrectly.
.. code-block:: python
@cute.kernel
def fails_kernel(gX: cute.Tensor):
bidx, _, _ = cute.arch.block_idx()
mX = gX[None, bidx, None] # We wish to retain alignment
# assert mX.iterator.alignment == 16
@cute.jit
def fails(gX_: cute.Tensor):
gX = gX_
fails_kernel(gX).launch(grid=(1, 1, 1), block=(128, 1, 1))
gX_torch = torch.rand((128, 1, 128), device="cuda", dtype=torch.bfloat16)
fails(from_dlpack(gX_torch, assumed_align=16))
"""
# Add the current directory to sys.path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from tensorop_gemm import TensorOpGemm
@cute.jit
def tensor_op_gemm_wrapper(
a_ptr: cute.Pointer,
b_ptr: cute.Pointer,
c_ptr: cute.Pointer,
m: cutlass.Int32,
n: cutlass.Int32,
k: cutlass.Int32,
l: cutlass.Int32,
):
print(f"\n[DSL INFO] Input Parameters:")
print(f"[DSL INFO] mnkl: {(m, n, k, l)}")
    # Assume m and n are divisible by 8 so the tensorop_gemm example's alignment requirements hold
m = cute.assume(m, divby=8)
n = cute.assume(n, divby=8)
# Torch is row major
a_layout = cute.make_ordered_layout((m, k, l), order=(0, 1, 2))
b_layout = cute.make_ordered_layout((n, k, l), order=(0, 1, 2))
c_layout = cute.make_ordered_layout((m, n, l), order=(1, 0, 2))
mA = cute.make_tensor(a_ptr, layout=a_layout)
mB = cute.make_tensor(b_ptr, layout=b_layout)
mC = cute.make_tensor(c_ptr, layout=c_layout)
print(f"[DSL INFO] mA: {mA}")
print(f"[DSL INFO] mB: {mB}")
print(f"[DSL INFO] mC: {mC}")
tensor_op_gemm = TensorOpGemm(
a_ptr.value_type, c_ptr.value_type, cutlass.Float32, (2, 2, 1)
)
print(f"\n[DSL INFO] Created TensorOpGemm instance")
print(f"[DSL INFO] Input dtype: {a_ptr.value_type}")
print(f"[DSL INFO] Output dtype: {c_ptr.value_type}")
print(f"[DSL INFO] Accumulation dtype: {cutlass.Float32}")
print(f"[DSL INFO] Atom layout: {(2, 2, 1)}")
    # No explicit cute.compile is needed inside a jit function; the call is inlined
tensor_op_gemm(mA, mB, mC)
print(f"\n[DSL INFO] Executed TensorOpGemm")
def run_tensor_op_gemm_wrapper(mnkl: Tuple[int, int, int, int]):
print(f"\nRunning TensorOpGemm test with:")
print(f"Tensor dimensions: {mnkl}")
# (M,K,L)
a = torch.randn(
mnkl[3], mnkl[2], mnkl[0], dtype=torch.float16, device="cuda"
).permute(2, 1, 0)
# (N,K,L)
b = torch.randn(
mnkl[3], mnkl[2], mnkl[1], dtype=torch.float16, device="cuda"
).permute(2, 1, 0)
# (N,M,L)
c = torch.randn(
mnkl[3], mnkl[0], mnkl[1], dtype=torch.float16, device="cuda"
).permute(1, 2, 0)
print(f"Input tensor shapes:")
print(f"a: {a.shape}, dtype: {a.dtype}")
print(f"b: {b.shape}, dtype: {b.dtype}")
print(f"c: {c.shape}, dtype: {c.dtype}\n")
a_ptr = make_ptr(
cutlass.Float16, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=32
)
b_ptr = make_ptr(
cutlass.Float16, b.data_ptr(), cute.AddressSpace.gmem, assumed_align=32
)
c_ptr = make_ptr(
cutlass.Float16, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=32
)
tensor_op_gemm_wrapper(a_ptr, b_ptr, c_ptr, *mnkl)
torch.cuda.synchronize()
ref = torch.einsum("mkl,nkl->mnl", a, b)
torch.testing.assert_close(c, ref, atol=1e-05, rtol=1e-05)
print(f"\n[DSL INFO] Results verified successfully!")
print(f"First few elements of result: \n{c[:3, :3, :3]}")
if __name__ == "__main__":
run_tensor_op_gemm_wrapper((512, 256, 128, 16))
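One caveat when constructing pointers this way: assumed_align is, as the name says, a promise taken on trust rather than validated against the actual address. Fresh CUDA allocations from PyTorch's caching allocator are typically at least 256-byte aligned, but slices and views need not be. A small hypothetical helper (checked_ptr is not a library API) that fails fast before wrapping:

import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import make_ptr

def checked_ptr(t, dtype, align=32):
    # hypothetical convenience wrapper: verify the base address actually
    # satisfies the alignment we are about to promise to the DSL
    assert t.data_ptr() % align == 0, f"data_ptr not {align}-byte aligned"
    return make_ptr(dtype, t.data_ptr(), cute.AddressSpace.gmem, assumed_align=align)

With it, the three make_ptr calls above collapse to, e.g., a_ptr = checked_ptr(a, cutlass.Float16).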

View File

@@ -226,15 +226,15 @@ def run_tensor_op_gemm_wrapper(mnkl: Tuple[int, int, int, int]):
     print(f"c: {c.shape}, dtype: {c.dtype}\n")
     buffer_a = BufferWithLayout(
-        make_ptr(ab_dtype, a.data_ptr(), cute.AddressSpace.gmem),
+        make_ptr(ab_dtype, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=32),
         (2, 1, 0),
     )
     buffer_b = BufferWithLayout(
-        make_ptr(ab_dtype, b.data_ptr(), cute.AddressSpace.gmem),
+        make_ptr(ab_dtype, b.data_ptr(), cute.AddressSpace.gmem, assumed_align=32),
         (2, 1, 0),
     )
     buffer_c = BufferWithLayout(
-        make_ptr(c_dtype, c.data_ptr(), cute.AddressSpace.gmem),
+        make_ptr(c_dtype, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=32),
         (2, 1, 0),
     )

View File

@@ -0,0 +1,189 @@
import os
import torch
import argparse
from typing import Type
from cuda.bindings import driver
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
import torch.multiprocessing as mp
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
def setup(rank, world_size):
# set environment variables for torch.distributed environment
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12995"
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
def cleanup():
dist.destroy_process_group()
@cute.kernel
def vector_add_kernel(
g0: cute.Tensor,
g1: cute.Tensor,
g2: cute.Tensor,
g3: cute.Tensor,
g4: cute.Tensor,
g5: cute.Tensor,
g6: cute.Tensor,
g7: cute.Tensor,
gOut: cute.Tensor,
tv_layout: cute.Layout,
):
tidx, _, _ = cute.arch.thread_idx()
bidx, _, _ = cute.arch.block_idx()
cta_coord = (None, bidx)
local_tile_out = gOut[cta_coord]
local_tile_list = [
g0[cta_coord],
g1[cta_coord],
g2[cta_coord],
g3[cta_coord],
g4[cta_coord],
g5[cta_coord],
g6[cta_coord],
g7[cta_coord],
]
copy_atom_load = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
g0.element_type,
memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE,
memory_scope=cute.nvgpu.common.MemoryScope.SYS,
)
copy_atom_store = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
g0.element_type,
memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE,
memory_scope=cute.nvgpu.common.MemoryScope.SYS,
)
tiled_copy = cute.make_tiled_copy_tv(copy_atom_load, tv_layout[0], tv_layout[1])
thr_copy = tiled_copy.get_slice(tidx)
thr_tensor_list = [thr_copy.partition_S(tensor) for tensor in local_tile_list]
thr_out = thr_copy.partition_D(local_tile_out)
frg_tensor_list = [cute.make_fragment_like(tensor) for tensor in thr_tensor_list]
frg_acc = cute.make_fragment_like(thr_out)
frg_acc.fill(0.0)
for thr, frg in zip(thr_tensor_list, frg_tensor_list):
cute.copy(copy_atom_load, thr, frg)
tmp = frg.load() + frg_acc.load()
frg_acc.store(tmp)
cute.copy(copy_atom_store, frg_acc, thr_out)
@cute.jit
def vector_add(
m0: cute.Tensor,
m1: cute.Tensor,
m2: cute.Tensor,
m3: cute.Tensor,
m4: cute.Tensor,
m5: cute.Tensor,
m6: cute.Tensor,
m7: cute.Tensor,
output: cute.Tensor,
):
# define constants for future use
num_of_elements = cute.size(m0.layout)
# 128 threads per block and 4 elements per thread
tv_layout = cute.make_layout(((128), (4)), stride=((1), (1)))
tile = cute.size(tv_layout.shape)
tensors = [m0, m1, m2, m3, m4, m5, m6, m7]
divided_tensors = [
cute.zipped_divide(tensor, cute.make_layout(tile)) for tensor in tensors
]
gOut = cute.zipped_divide(output, cute.make_layout(tile)) # ((Tile),(Rest))
vector_add_kernel(
divided_tensors[0],
divided_tensors[1],
divided_tensors[2],
divided_tensors[3],
divided_tensors[4],
divided_tensors[5],
divided_tensors[6],
divided_tensors[7],
gOut,
tv_layout,
).launch(
grid=[num_of_elements // tile, 1, 1],
block=[tv_layout.shape[0], 1, 1],
)
def run_vector_add(rank, world_size, M, N, dtype: Type[cutlass.Numeric]):
setup(rank, world_size)
t = symm_mem.empty(M * N, device="cuda")
hdl = symm_mem.rendezvous(t, dist.group.WORLD)
    # map each peer's buffer out of the symmetric memory allocation
    tensor_list = [hdl.get_buffer(peer, t.shape, t.dtype) for peer in range(world_size)]
tensor_list[rank].random_(0, 100)
# enable peer access
driver.cuInit(0)
dev_list = [driver.cuDeviceGet(i)[1] for i in range(world_size)]
ctx_list = [driver.cuDevicePrimaryCtxRetain(dev)[1] for dev in dev_list]
driver.cuCtxSetCurrent(ctx_list[rank])
for i in range(world_size):
if i == rank:
continue
driver.cuCtxEnablePeerAccess(ctx_list[i], 0)
output = torch.zeros(M * N, device=f"cuda:{rank}")
# we have to explicitly pass each tensor instead of a list of tensors
vector_add(
from_dlpack(tensor_list[0], assumed_align=32),
from_dlpack(tensor_list[1], assumed_align=32),
from_dlpack(tensor_list[2], assumed_align=32),
from_dlpack(tensor_list[3], assumed_align=32),
from_dlpack(tensor_list[4], assumed_align=32),
from_dlpack(tensor_list[5], assumed_align=32),
from_dlpack(tensor_list[6], assumed_align=32),
from_dlpack(tensor_list[7], assumed_align=32),
from_dlpack(output, assumed_align=32),
)
sum_tensor = sum([tensor.cpu() for tensor in tensor_list])
    if torch.equal(sum_tensor, output.cpu()):
print("test passed")
else:
print("test failed")
print(sum_tensor.cpu())
print(output.cpu())
cleanup()
def main():
    world_size = torch.cuda.device_count()
    # vector_add's jit signature is fixed at 8 tensors, so require 8 devices
    assert world_size == 8, "this example expects exactly 8 GPUs"
M = 1024
N = 1024
    # each process runs run_vector_add on a different device
mp.spawn(
run_vector_add,
args=(world_size, M, N, cutlass.Float32),
nprocs=world_size,
join=True,
)
return
if __name__ == "__main__":
main()
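Both distributed examples share the same tiling arithmetic: the tv_layout pairs 128 threads with 4 elements per thread, so each CTA covers one 512-element tile and the grid is sized as num_of_elements // tile. Note the floor division: the launch assumes the element count is an exact multiple of the tile, or trailing elements would go unprocessed. A plain-Python sketch of the arithmetic under the sizes used in main():

THREADS, ELEMS_PER_THREAD = 128, 4
TILE = THREADS * ELEMS_PER_THREAD  # 512 elements covered per CTA
N_ELEMENTS = 1024 * 1024           # M * N from main()
assert N_ELEMENTS % TILE == 0      # exact tiling, nothing left over
print(N_ELEMENTS // TILE)          # 2048 CTAs in the launch grid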

View File

@@ -0,0 +1,114 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import cutlass.cute as cute
import cutlass
"""
Example of automatic shared memory size computation for configuring kernel launch
This example demonstrates how to let the DSL automatically compute the shared memory
size for a kernel launch, rather than explicitly configuring it at launch time,
provided that all allocations go through `SmemAllocator`.
Usage:
python dynamic_smem_size.py # Show auto inference
"""
@cute.struct
class SharedData:
"""A struct to demonstrate shared memory allocation."""
values: cute.struct.MemRange[cutlass.Float32, 64] # 256 bytes
counter: cutlass.Int32 # 4 bytes
flag: cutlass.Int8 # 1 byte
@cute.kernel
def kernel():
"""
Example kernel that allocates shared memory.
The total allocation will be automatically calculated when smem=None.
"""
allocator = cutlass.utils.SmemAllocator()
# Allocate various types of shared memory
shared_data = allocator.allocate(SharedData)
raw_buffer = allocator.allocate(512, byte_alignment=64)
int_array = allocator.allocate_array(element_type=cutlass.Int32, num_elems=128)
tensor_smem = allocator.allocate_tensor(
element_type=cutlass.Float16,
layout=cute.make_layout((32, 16)),
byte_alignment=16,
swizzle=None,
)
return
@cute.kernel
def kernel_no_smem():
"""
    Example kernel that does not allocate shared memory.
    The total allocation is automatically computed as 0 when smem=None.
"""
    tidx, _, _ = cute.arch.thread_idx()
    if tidx == 0:
cute.printf("Hello world")
return
if __name__ == "__main__":
# Initialize CUDA context
cutlass.cuda.initialize_cuda_context()
print("Launching kernel with auto smem size. (launch config `smem=None`)")
# Compile the example
@cute.jit
def launch_kernel1():
k = kernel()
k.launch(
grid=(1, 1, 1),
block=(1, 1, 1),
)
print(f"Kernel recorded internal smem usage: {k.smem_usage()}")
@cute.jit
def launch_kernel2():
k = kernel_no_smem()
k.launch(
grid=(1, 1, 1),
block=(1, 1, 1),
)
print(f"Kernel recorded internal smem usage: {k.smem_usage()}")
cute.compile(launch_kernel1)
cute.compile(launch_kernel2)
print("PASS")

View File

@@ -327,7 +327,6 @@ class FlashAttentionForwardAmpere:
         ).launch(
             grid=grid_dim,
             block=[self._num_threads, 1, 1],
-            smem=SharedStorage.size_in_bytes(),
             stream=stream,
         )
@@ -1014,13 +1013,10 @@ class FlashAttentionForwardAmpere:
         )
         # compute exp(x - max) using exp2(x * log_2(e) - max * log_2(e))
-        acc_S_row_exp = cute.TensorSSA(
-            self._exp2f(
-                acc_S_row * softmax_params.softmax_scale_log2
-                - row_max_cur_row * softmax_params.softmax_scale_log2
-            ),
-            tuple(acc_S_row.shape),
-            cutlass.Float32,
+        acc_S_row_exp = cute.math.exp2(
+            acc_S_row * softmax_params.softmax_scale_log2
+            - row_max_cur_row * softmax_params.softmax_scale_log2,
+            fastmath=True,
         )
         # acc_S_row_sum => f32
         acc_S_row_sum = acc_S_row_exp.reduce(
@@ -1028,9 +1024,10 @@ class FlashAttentionForwardAmpere:
         )
         # if this is not the first tile, load row r of the previous row_max and subtract row_max_cur_row to update row_sum
         if cutlass.const_expr(not is_first_n_block):
-            prev_minus_cur_exp = self._exp2f(
+            prev_minus_cur_exp = cute.math.exp2(
                 row_max_prev_row * softmax_params.softmax_scale_log2
-                - row_max_cur_row * softmax_params.softmax_scale_log2
+                - row_max_cur_row * softmax_params.softmax_scale_log2,
+                fastmath=True,
             )
             acc_S_row_sum = (
                 acc_S_row_sum + softmax_params.row_sum[r] * prev_minus_cur_exp
@@ -1141,26 +1138,6 @@ class FlashAttentionForwardAmpere:
         """
         return self._threadquad_reduce(val, lambda x, y: x + y)
-    def _exp2f(
-        self, x: Union[cute.TensorSSA, cutlass.Float32]
-    ) -> Union[cute.TensorSSA, cutlass.Float32]:
-        """exp2f calculation for both vector and scalar.
-        :param x: input value
-        :type x: cute.TensorSSA or cutlass.Float32
-        :return: exp2 value
-        :rtype: cute.TensorSSA or cutlass.Float32
-        """
-        if isinstance(x, cute.TensorSSA):
-            res = cute.make_fragment(x.shape, cutlass.Float32)
-            res.store(x)
-            for i in range(cute.size(x.shape)):
-                res[i] = self._exp2f(res[i])
-            return res.load()
-        return cute.arch.exp2(x)
 def run(
     dtype: Type[cutlass.Numeric],
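
The two cute.math.exp2 call sites above implement the online-softmax rescaling: whenever a new tile raises the running row maximum, the previously accumulated row sum is scaled by exp2((prev_max - cur_max) * softmax_scale_log2) before the new tile's contributions are added, keeping every partial sum normalized to the current maximum. A scalar sketch of the invariant, assuming an unscaled softmax (softmax_scale_log2 = log2(e)):

import math

log2e = math.log2(math.e)
xs_tile1, xs_tile2 = [1.0, 3.0], [5.0, 2.0]

# streaming pass over two tiles, mirroring the kernel's update rule
m1 = max(xs_tile1)
s = sum(2.0 ** (x * log2e - m1 * log2e) for x in xs_tile1)
m2 = max(m1, max(xs_tile2))
s *= 2.0 ** (m1 * log2e - m2 * log2e)  # rescale the old sum to the new max
s += sum(2.0 ** (x * log2e - m2 * log2e) for x in xs_tile2)

# reference: one-shot softmax denominator over all values
m = max(xs_tile1 + xs_tile2)
ref = sum(math.exp(x - m) for x in xs_tile1 + xs_tile2)
assert abs(s - ref) < 1e-9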

View File

@@ -136,10 +136,6 @@ class SGemm:
             stride=(1, (self._bN + padding_b), self._bK * (self._bN + padding_b)),
         )
-        smem_size = cute.size_in_bytes(mA.element_type, sA_layout) + cute.size_in_bytes(
-            mB.element_type, sB_layout
-        )
         # ///////////////////////////////////////////////////////////////////////////////
         # Create copy layouts that will be used for asynchronous
         # global memory -> shared memory copies:
@@ -258,7 +254,6 @@ class SGemm:
         ).launch(
             grid=grid_dim,
             block=[cute.size(atoms_layout), 1, 1],
-            smem=smem_size,
             stream=stream,
         )
@@ -738,14 +733,20 @@ def run(
     print("Compiling kernel with cute.compile ...")
     start_time = time.time()
-    gemm = cute.compile(sgemm, a_tensor, b_tensor, c_tensor, stream=current_stream)
+    compiled_fn = cute.compile(
+        sgemm,
+        a_tensor,
+        b_tensor,
+        c_tensor,
+        stream=current_stream,
+    )
     compilation_time = time.time() - start_time
     print(f"Compilation time: {compilation_time:.4f} seconds")
     print("Executing GEMM kernel...")
     if not skip_ref_check:
-        gemm(a_tensor, b_tensor, c_tensor)
+        compiled_fn(a_tensor, b_tensor, c_tensor)
         torch.cuda.synchronize()
         print("Verifying results...")
         ref = torch.einsum("mk,nk->mn", a, b)
@@ -804,7 +805,7 @@ def run(
     )
     avg_time_us = testing.benchmark(
-        gemm,
+        compiled_fn,
         workspace_generator=generate_tensors,
         workspace_count=workspace_count,
         stream=current_stream,
@@ -837,6 +838,7 @@ if __name__ == "__main__":
     parser.add_argument("--c_major", choices=["n", "m"], default="n")
     parser.add_argument("--warmup_iterations", default=2, type=int)
     parser.add_argument("--iterations", default=100, type=int)
+    parser.add_argument("--static_shape", action="store_true")
     parser.add_argument("--skip_ref_check", action="store_true")
     parser.add_argument(
         "--use_cold_l2",

View File

@@ -69,7 +69,7 @@ class complex:
 class SharedStorage:
     # struct elements with natural alignment
     a: cute.struct.MemRange[cutlass.Float32, 32]  # array
-    b: cutlass.Int64  # saclar
+    b: cutlass.Int64  # scalar
     c: complex  # nested struct
     # struct elements with strict alignment
     x: cute.struct.Align[

View File

@@ -471,7 +471,7 @@ class TensorOpGemm:
         cute.arch.sync_threads()
         # Start async loads for the first k-tile. Here we take care of the k residue
         # via an if/else check along the k dimension. Because we shifted the identity tensor
-        # by the residue_k and because the identity tensor is a counting tensor, the
+        # by the residue_k and because the identity tensor is a coord tensor, the
         # values of any identity tensor element that is poison are less than -1
         num_smem_stages = cute.size(tAsA, mode=[3])
         k_tile_count = cute.size(tAgA, mode=[3])
@@ -683,7 +683,7 @@ class TensorOpGemm:
         # Copy results of D back to shared memory
         cute.autovec_copy(tCrD, tCsC)
-        # Create counting tensor for C
+        # Create coord tensor for C
         ceilM, ceilN, _ = cute.ceil_div(mC.shape, (self.bM, self.bN, 1))
         mcC = cute.make_identity_tensor(
             (