v4.2 tag release. (#2638)

examples/python/CuTeDSL/ampere/all_reduce.py (new file, 314 lines)
@@ -0,0 +1,314 @@
import os
import torch
from typing import Type

from cuda import cuda
from cuda.bindings import driver

import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
import torch.multiprocessing as mp

import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
from cutlass._mlir.dialects import vector

# Assumed import: the T.vector / T.f32 / T.i32 shorthands used below are
# expected to come from the vendored mlir.extras.types module.
from cutlass._mlir.extras import types as T

WORLD_SIZE = 8
PING_PONG_SIZE = 3  # number of ping-pong buffer slots


def setup(rank, world_size):
    # set environment variables for the torch.distributed process group
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12959"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


class AllReduceKernel:
    @cute.jit
    def __call__(
        self,
        rank,
        signal,
        local_input: cute.Tensor,
        local_output: cute.Tensor,
        buffer0: cute.Tensor,
        buffer1: cute.Tensor,
        buffer2: cute.Tensor,
        buffer3: cute.Tensor,
        buffer4: cute.Tensor,
        buffer5: cute.Tensor,
        buffer6: cute.Tensor,
        buffer7: cute.Tensor,
        stream: cuda.CUstream,
    ):
        # define constants for future use
        num_of_elements = cute.size(local_input.layout)
        # 128 threads per block and 4 elements per thread
        tv_layout = cute.make_layout((128, 4), stride=(1, 1))
        tile = cute.size(tv_layout.shape)

        buffers = [
            buffer0,
            buffer1,
            buffer2,
            buffer3,
            buffer4,
            buffer5,
            buffer6,
            buffer7,
        ]
        tiled_buffers = [
            cute.logical_divide(buffer, (tile, None, None)) for buffer in buffers
        ]

        tiled_input = cute.zipped_divide(local_input, cute.make_layout(tile))
        tiled_output = cute.zipped_divide(local_output, cute.make_layout(tile))
        self.kernel(
            tiled_buffers[0],
            tiled_buffers[1],
            tiled_buffers[2],
            tiled_buffers[3],
            tiled_buffers[4],
            tiled_buffers[5],
            tiled_buffers[6],
            tiled_buffers[7],
            tiled_input,
            tiled_output,
            tv_layout,
            cutlass.Int32(signal),
            cutlass.Int32(rank),
        ).launch(
            grid=[num_of_elements // tile, 1, 1],
            block=[tv_layout.shape[0], 1, 1],
            stream=stream,
        )

    # GPU device kernel
    @cute.kernel
    def kernel(
        self,
        buffer0: cute.Tensor,
        buffer1: cute.Tensor,
        buffer2: cute.Tensor,
        buffer3: cute.Tensor,
        buffer4: cute.Tensor,
        buffer5: cute.Tensor,
        buffer6: cute.Tensor,
        buffer7: cute.Tensor,
        local_input: cute.Tensor,
        local_output: cute.Tensor,
        tv_layout: cute.Layout,
        signal: cutlass.Int32,
        rank: cutlass.Int32,
    ):
        tidx, _, _ = cute.arch.thread_idx()
        bidx, _, _ = cute.arch.block_idx()
        # select the ping (read/write) and pong (clear) slots for this round
        ping = signal % PING_PONG_SIZE
        pong = (signal + 1) % PING_PONG_SIZE

        buffers = [
            buffer0,
            buffer1,
            buffer2,
            buffer3,
            buffer4,
            buffer5,
            buffer6,
            buffer7,
        ]

        def get_buffer():
            # The DSL cannot index a Python list with a runtime value, so the
            # rank-to-buffer selection is unrolled into a chain of comparisons.
            t = buffers[2]
            if rank == cutlass.Int32(0):
                t = buffers[0]
            if rank == cutlass.Int32(1):
                t = buffers[1]
            if rank == cutlass.Int32(2):
                t = buffers[2]
            if rank == cutlass.Int32(3):
                t = buffers[3]
            if rank == cutlass.Int32(4):
                t = buffers[4]
            if rank == cutlass.Int32(5):
                t = buffers[5]
            if rank == cutlass.Int32(6):
                t = buffers[6]
            if rank == cutlass.Int32(7):
                t = buffers[7]
            return t

        buffer_local = get_buffer()
        cta_coord = (None, bidx)
        local_tile_in = local_input[cta_coord]
        local_tile_out = local_output[cta_coord]

        ping_coord = ((None, bidx), None, ping)
        read_buffer = buffer_local[ping_coord]

        pong_coord = ((None, bidx), None, pong)
        clear_buffer = buffer_local[pong_coord]

        write_coord = ((None, bidx), rank, ping)
        write_buffers = [buffer[write_coord] for buffer in buffers]

        # assume all buffers have the same element type as the input
        copy_atom_load = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            buffer0.element_type,
            num_bits_per_copy=64,
            memory_scope=cute.nvgpu.common.MemoryScope.SYS,
            memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE,
        )
        copy_atom_store = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            buffer0.element_type,
            num_bits_per_copy=64,
            memory_scope=cute.nvgpu.common.MemoryScope.SYS,
            memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE,
        )
        tiled_copy = cute.make_tiled_copy_tv(copy_atom_load, tv_layout[0], tv_layout[1])
        thr_copy = tiled_copy.get_slice(tidx)

        thr_write_buffer_list = [
            thr_copy.partition_D(tensor) for tensor in write_buffers
        ]
        thr_read_buffer = thr_copy.partition_S(read_buffer)
        thr_clear_buffer = thr_copy.partition_D(clear_buffer)

        thr_in = thr_copy.partition_S(local_tile_in)
        thr_out = thr_copy.partition_D(local_tile_out)

        frg_in = cute.make_fragment_like(thr_in)
        frg_clear = cute.make_fragment_like(thr_clear_buffer)
        frg_acc = cute.make_fragment_like(thr_out)
        frg_acc.fill(0.0)

        # Build a fragment full of -0.0f (bit pattern 0x80000000). -0.0 serves
        # as the "not yet written" sentinel: it is detected by comparing bit
        # patterns, since a float compare would treat -0.0 == +0.0.
        clear_tensor = frg_clear.load()
        frg_size = cute.size(clear_tensor.shape)
        neg0_i32_vec = cute.full_like(clear_tensor, 0x80000000, cutlass.Int32)
        neg0_f32_vec = vector.bitcast(T.vector(frg_size, T.f32()), neg0_i32_vec)
        neg0_f32_tensor = cute.TensorSSA(
            neg0_f32_vec, clear_tensor.shape, cutlass.Float32
        )
        frg_clear.store(neg0_f32_tensor)

        cute.copy(copy_atom_load, thr_in, frg_in)

        # broadcast this rank's tile into every peer's ping slot
        for thr_write_buffer in thr_write_buffer_list:
            cute.copy(copy_atom_store, frg_in, thr_write_buffer)

        # reset the pong slot to the sentinel for the next round
        cute.copy(copy_atom_store, frg_clear, thr_clear_buffer)

        frg_in_vector_neg0_i32 = cute.full_like(
            frg_in, cutlass.Int32(0x80000000), cutlass.Int32
        )
        frg_in_size = cute.size(frg_in.shape)

        # accumulate every rank's contribution, spinning until that rank's
        # tile has actually arrived (i.e. no element still holds -0.0)
        for i in range(WORLD_SIZE):
            read_coord = (None, 0, i)
            cute.copy(copy_atom_load, thr_read_buffer[read_coord], frg_in[None, 0])
            frg_vector = frg_in.load()
            frg_vector_i32 = vector.bitcast(T.vector(frg_in_size, T.i32()), frg_vector)
            is_not_neg0 = cute.all_(frg_vector_i32 != frg_in_vector_neg0_i32)
            while not is_not_neg0:
                cute.copy(copy_atom_load, thr_read_buffer[read_coord], frg_in[None, 0])
                frg_vector = frg_in.load()
                frg_vector_i32 = vector.bitcast(
                    T.vector(frg_in_size, T.i32()), frg_vector
                )
                is_not_neg0 = cute.all_(frg_vector_i32 != frg_in_vector_neg0_i32)
            frg_acc.store(frg_in.load() + frg_acc.load())

        cute.copy(copy_atom_store, frg_acc, thr_out)


def run_all_reduce(rank, M, N, dtype: Type[cutlass.Numeric]):
    setup(rank, WORLD_SIZE)

    input_tensor = torch.randn(M * N, device=f"cuda:{rank}")
    output_tensor = torch.zeros(M * N, device=f"cuda:{rank}")

    # allocate the symmetric-memory exchange buffer on every device, negated
    # so a zero-filled allocation reads back as the -0.0 sentinel
    t = symm_mem.empty(
        [
            PING_PONG_SIZE,
            WORLD_SIZE,
            M * N,
        ],
        device="cuda",
    ).neg_()
    hdl = symm_mem.rendezvous(t, dist.group.WORLD)
    buffer_tensor_list = [
        hdl.get_buffer(peer, t.shape, t.dtype).permute(2, 1, 0)
        for peer in range(WORLD_SIZE)
    ]

    # enable peer access
    driver.cuInit(0)
    dev_list = [driver.cuDeviceGet(i)[1] for i in range(WORLD_SIZE)]
    ctx_list = [driver.cuDevicePrimaryCtxRetain(dev)[1] for dev in dev_list]
    for i in range(WORLD_SIZE):
        driver.cuCtxSetCurrent(ctx_list[i])
        for j in range(WORLD_SIZE):
            if i == j:
                continue
            driver.cuCtxEnablePeerAccess(ctx_list[j], 0)
    driver.cuCtxSetCurrent(ctx_list[rank])

    stream = cutlass.cuda.default_stream()
    all_reduce_kernel = AllReduceKernel()
    dlpack_buffers = [from_dlpack(x, assumed_align=32) for x in buffer_tensor_list]
    all_reduce_kernel(
        rank,
        0,
        from_dlpack(input_tensor, assumed_align=32),
        from_dlpack(output_tensor, assumed_align=32),
        *dlpack_buffers,
        stream,
    )
    torch.cuda.synchronize()

    # compute the reference with the torch API; all_reduce stores its result
    # in place, so work on a clone of the input
    ref_tensor = input_tensor.clone()
    dist.all_reduce(ref_tensor, op=dist.ReduceOp.SUM)

    # check the output tensor, allowing a small error because the
    # accumulation order (and thus rounding) differs between implementations
    equal_mask = (ref_tensor.cpu() - output_tensor.cpu()).abs() < 1e-4
    result = equal_mask.sum().item() == ref_tensor.numel()

    if result:
        print(f"rank {rank} test passed")
    else:
        print(f"rank {rank} test failed")
        print(
            "ref_tensor[ref_tensor != output_tensor]: ",
            ref_tensor[ref_tensor != output_tensor],
        )
        print(
            "output_tensor[ref_tensor != output_tensor]: ",
            output_tensor[ref_tensor != output_tensor],
        )

    cleanup()


def main():
    M = 1024
    N = 1024

    # each process runs run_all_reduce on a different device
    mp.spawn(run_all_reduce, args=(M, N, cutlass.Float32), nprocs=WORLD_SIZE, join=True)

    return


if __name__ == "__main__":
    main()
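The kernel above uses the bit pattern of -0.0f (0x80000000) as its "not yet
written" sentinel: a slot still holding -0.0 is detected by comparing bit
patterns rather than float values, because a float compare would treat -0.0
as equal to +0.0. A standalone sketch of that test, using NumPy instead of
the DSL (values illustrative):

.. code-block:: python

    import numpy as np

    slot = np.full(4, -0.0, dtype=np.float32)  # freshly cleared slot
    sentinel = np.uint32(0x80000000)

    # float equality cannot distinguish -0.0 from +0.0 ...
    assert (slot == 0.0).all()
    # ... but the bit pattern can, which is what vector.bitcast + all_ checks
    assert (slot.view(np.uint32) == sentinel).all()

    slot[:] = [1.0, 2.0, 0.0, 4.0]  # a peer wrote real data (+0.0 included)
    assert (slot.view(np.uint32) != sentinel).all()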
examples/python/CuTeDSL/ampere/call_bypass_dlpack.py (new file, 166 lines)
@@ -0,0 +1,166 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:

# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.

# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.

# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import sys
import os
from typing import Tuple

import torch

import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import make_ptr

"""
|
||||
An Example demonstrating how to call off-the-shelf kernel by-passing dlpack protocol
|
||||
|
||||
The example shows how to directly pass pointers from PyTorch tensors to off-the-shelf kernels
|
||||
written by CuTe DSL with a thin customized wrapper jit function. The jit function will be
|
||||
compiled with inline without introducing overhead.
|
||||
|
||||
To run this example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python examples/ampere/call_bypass_dlpack.py
|
||||
|
||||
|
||||
It's worth to mention that by-passing dlpack protocol can resolve the issue that dlpack doesn't handle shape-1
|
||||
mode correctly. For example, the following code will fail, because dlpack will convert the shape-1 mode
|
||||
with stride-1 which propagate alignment incorrectly.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@cute.kernel
|
||||
def fails_kernel(gX: cute.Tensor):
|
||||
bidx, _, _ = cute.arch.block_idx()
|
||||
mX = gX[None, bidx, None] # We wish to retain alignment
|
||||
# assert mX.iterator.alignment == 16
|
||||
|
||||
|
||||
@cute.jit
|
||||
def fails(gX_: cute.Tensor):
|
||||
gX = gX_
|
||||
fails_kernel(gX).launch(grid=(1, 1, 1), block=(128, 1, 1))
|
||||
|
||||
|
||||
gX_torch = torch.rand((128, 1, 128), device="cuda", dtype=torch.bfloat16)
|
||||
fails(from_dlpack(gX_torch, assumed_align=16))
|
||||
|
||||
"""
|
||||
|
||||
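# A working counterpart, for contrast (a sketch; `works` and the layout below
# are illustrative, not part of this example): building the tensor from a raw
# pointer keeps the declared alignment, because the layout is stated
# explicitly instead of being derived from DLPack strides.
#
#     @cute.jit
#     def works(gX_ptr: cute.Pointer):
#         layout = cute.make_ordered_layout((128, 1, 128), order=(2, 1, 0))
#         gX = cute.make_tensor(gX_ptr, layout=layout)
#         fails_kernel(gX).launch(grid=(1, 1, 1), block=(128, 1, 1))
#
#     gX_ptr = make_ptr(
#         cutlass.BFloat16, gX_torch.data_ptr(), cute.AddressSpace.gmem,
#         assumed_align=16,
#     )
#     works(gX_ptr)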
# Add the current directory to sys.path so the sibling example can be imported
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from tensorop_gemm import TensorOpGemm


@cute.jit
def tensor_op_gemm_wrapper(
    a_ptr: cute.Pointer,
    b_ptr: cute.Pointer,
    c_ptr: cute.Pointer,
    m: cutlass.Int32,
    n: cutlass.Int32,
    k: cutlass.Int32,
    l: cutlass.Int32,
):
    print("\n[DSL INFO] Input Parameters:")
    print(f"[DSL INFO] mnkl: {(m, n, k, l)}")

    # Assume alignment of the shape to call the tensorop_gemm example
    m = cute.assume(m, divby=8)
    n = cute.assume(n, divby=8)

    # Torch is row major
    a_layout = cute.make_ordered_layout((m, k, l), order=(0, 1, 2))
    b_layout = cute.make_ordered_layout((n, k, l), order=(0, 1, 2))
    c_layout = cute.make_ordered_layout((m, n, l), order=(1, 0, 2))
    mA = cute.make_tensor(a_ptr, layout=a_layout)
    mB = cute.make_tensor(b_ptr, layout=b_layout)
    mC = cute.make_tensor(c_ptr, layout=c_layout)

    print(f"[DSL INFO] mA: {mA}")
    print(f"[DSL INFO] mB: {mB}")
    print(f"[DSL INFO] mC: {mC}")

    tensor_op_gemm = TensorOpGemm(
        a_ptr.value_type, c_ptr.value_type, cutlass.Float32, (2, 2, 1)
    )
    print("\n[DSL INFO] Created TensorOpGemm instance")
    print(f"[DSL INFO] Input dtype: {a_ptr.value_type}")
    print(f"[DSL INFO] Output dtype: {c_ptr.value_type}")
    print(f"[DSL INFO] Accumulation dtype: {cutlass.Float32}")
    print(f"[DSL INFO] Atom layout: {(2, 2, 1)}")

    # No need to compile inside a jit function
    tensor_op_gemm(mA, mB, mC)
    print("\n[DSL INFO] Executed TensorOpGemm")


def run_tensor_op_gemm_wrapper(mnkl: Tuple[int, int, int, int]):
    print("\nRunning TensorOpGemm test with:")
    print(f"Tensor dimensions: {mnkl}")

    # (M,K,L)
    a = torch.randn(
        mnkl[3], mnkl[2], mnkl[0], dtype=torch.float16, device="cuda"
    ).permute(2, 1, 0)
    # (N,K,L)
    b = torch.randn(
        mnkl[3], mnkl[2], mnkl[1], dtype=torch.float16, device="cuda"
    ).permute(2, 1, 0)
    # (M,N,L)
    c = torch.randn(
        mnkl[3], mnkl[0], mnkl[1], dtype=torch.float16, device="cuda"
    ).permute(1, 2, 0)

    print("Input tensor shapes:")
    print(f"a: {a.shape}, dtype: {a.dtype}")
    print(f"b: {b.shape}, dtype: {b.dtype}")
    print(f"c: {c.shape}, dtype: {c.dtype}\n")

    a_ptr = make_ptr(
        cutlass.Float16, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=32
    )
    b_ptr = make_ptr(
        cutlass.Float16, b.data_ptr(), cute.AddressSpace.gmem, assumed_align=32
    )
    c_ptr = make_ptr(
        cutlass.Float16, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=32
    )
    tensor_op_gemm_wrapper(a_ptr, b_ptr, c_ptr, *mnkl)
    torch.cuda.synchronize()

    ref = torch.einsum("mkl,nkl->mnl", a, b)
    torch.testing.assert_close(c, ref, atol=1e-05, rtol=1e-05)
    print("\n[DSL INFO] Results verified successfully!")
    print(f"First few elements of result: \n{c[:3, :3, :3]}")


if __name__ == "__main__":
    run_tensor_op_gemm_wrapper((512, 256, 128, 16))
@@ -226,15 +226,15 @@ def run_tensor_op_gemm_wrapper(mnkl: Tuple[int, int, int, int]):
     print(f"c: {c.shape}, dtype: {c.dtype}\n")

     buffer_a = BufferWithLayout(
-        make_ptr(ab_dtype, a.data_ptr(), cute.AddressSpace.gmem),
+        make_ptr(ab_dtype, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=32),
         (2, 1, 0),
     )
     buffer_b = BufferWithLayout(
-        make_ptr(ab_dtype, b.data_ptr(), cute.AddressSpace.gmem),
+        make_ptr(ab_dtype, b.data_ptr(), cute.AddressSpace.gmem, assumed_align=32),
         (2, 1, 0),
     )
     buffer_c = BufferWithLayout(
-        make_ptr(c_dtype, c.data_ptr(), cute.AddressSpace.gmem),
+        make_ptr(c_dtype, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=32),
         (2, 1, 0),
     )

examples/python/CuTeDSL/ampere/distributed_vector_add.py (new file, 189 lines)
@@ -0,0 +1,189 @@
|
||||
import os
|
||||
import torch
|
||||
import argparse
|
||||
from typing import Type
|
||||
from cuda.bindings import driver
|
||||
|
||||
import torch.distributed as dist
|
||||
import torch.distributed._symmetric_memory as symm_mem
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
|
||||
|
||||
def setup(rank, world_size):
|
||||
# set environment variables for torch.distributed environment
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = "12995"
|
||||
dist.init_process_group("nccl", rank=rank, world_size=world_size)
|
||||
torch.cuda.set_device(rank)
|
||||
|
||||
|
||||
def cleanup():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
@cute.kernel
def vector_add_kernel(
    g0: cute.Tensor,
    g1: cute.Tensor,
    g2: cute.Tensor,
    g3: cute.Tensor,
    g4: cute.Tensor,
    g5: cute.Tensor,
    g6: cute.Tensor,
    g7: cute.Tensor,
    gOut: cute.Tensor,
    tv_layout: cute.Layout,
):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    cta_coord = (None, bidx)
    local_tile_out = gOut[cta_coord]
    local_tile_list = [
        g0[cta_coord],
        g1[cta_coord],
        g2[cta_coord],
        g3[cta_coord],
        g4[cta_coord],
        g5[cta_coord],
        g6[cta_coord],
        g7[cta_coord],
    ]

    # volatile, system-scope copies so loads observe peer-device writes
    copy_atom_load = cute.make_copy_atom(
        cute.nvgpu.CopyUniversalOp(),
        g0.element_type,
        memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE,
        memory_scope=cute.nvgpu.common.MemoryScope.SYS,
    )
    copy_atom_store = cute.make_copy_atom(
        cute.nvgpu.CopyUniversalOp(),
        g0.element_type,
        memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE,
        memory_scope=cute.nvgpu.common.MemoryScope.SYS,
    )
    tiled_copy = cute.make_tiled_copy_tv(copy_atom_load, tv_layout[0], tv_layout[1])
    thr_copy = tiled_copy.get_slice(tidx)

    thr_tensor_list = [thr_copy.partition_S(tensor) for tensor in local_tile_list]
    thr_out = thr_copy.partition_D(local_tile_out)
    frg_tensor_list = [cute.make_fragment_like(tensor) for tensor in thr_tensor_list]
    frg_acc = cute.make_fragment_like(thr_out)
    frg_acc.fill(0.0)

    # load each device's tile into registers and accumulate
    for thr, frg in zip(thr_tensor_list, frg_tensor_list):
        cute.copy(copy_atom_load, thr, frg)
        tmp = frg.load() + frg_acc.load()
        frg_acc.store(tmp)

    cute.copy(copy_atom_store, frg_acc, thr_out)


@cute.jit
def vector_add(
    m0: cute.Tensor,
    m1: cute.Tensor,
    m2: cute.Tensor,
    m3: cute.Tensor,
    m4: cute.Tensor,
    m5: cute.Tensor,
    m6: cute.Tensor,
    m7: cute.Tensor,
    output: cute.Tensor,
):
    # define constants for future use
    num_of_elements = cute.size(m0.layout)
    # 128 threads per block and 4 elements per thread
    tv_layout = cute.make_layout((128, 4), stride=(1, 1))
    tile = cute.size(tv_layout.shape)

    tensors = [m0, m1, m2, m3, m4, m5, m6, m7]
    divided_tensors = [
        cute.zipped_divide(tensor, cute.make_layout(tile)) for tensor in tensors
    ]
    gOut = cute.zipped_divide(output, cute.make_layout(tile))  # ((Tile),(Rest))
    vector_add_kernel(
        divided_tensors[0],
        divided_tensors[1],
        divided_tensors[2],
        divided_tensors[3],
        divided_tensors[4],
        divided_tensors[5],
        divided_tensors[6],
        divided_tensors[7],
        gOut,
        tv_layout,
    ).launch(
        grid=[num_of_elements // tile, 1, 1],
        block=[tv_layout.shape[0], 1, 1],
    )


def run_vector_add(rank, world_size, M, N, dtype: Type[cutlass.Numeric]):
    setup(rank, world_size)

    t = symm_mem.empty(M * N, device="cuda")
    hdl = symm_mem.rendezvous(t, dist.group.WORLD)
    # get every device's buffer out of the symmetric-memory allocation
    tensor_list = [
        hdl.get_buffer(peer, t.shape, t.dtype) for peer in range(world_size)
    ]
    tensor_list[rank].random_(0, 100)

    # enable peer access
    driver.cuInit(0)
    dev_list = [driver.cuDeviceGet(i)[1] for i in range(world_size)]
    ctx_list = [driver.cuDevicePrimaryCtxRetain(dev)[1] for dev in dev_list]
    driver.cuCtxSetCurrent(ctx_list[rank])
    for i in range(world_size):
        if i == rank:
            continue
        driver.cuCtxEnablePeerAccess(ctx_list[i], 0)

    output = torch.zeros(M * N, device=f"cuda:{rank}")

    # we have to explicitly pass each tensor instead of a list of tensors
    vector_add(
        from_dlpack(tensor_list[0], assumed_align=32),
        from_dlpack(tensor_list[1], assumed_align=32),
        from_dlpack(tensor_list[2], assumed_align=32),
        from_dlpack(tensor_list[3], assumed_align=32),
        from_dlpack(tensor_list[4], assumed_align=32),
        from_dlpack(tensor_list[5], assumed_align=32),
        from_dlpack(tensor_list[6], assumed_align=32),
        from_dlpack(tensor_list[7], assumed_align=32),
        from_dlpack(output, assumed_align=32),
    )

    sum_tensor = sum(tensor.cpu() for tensor in tensor_list)

    if (sum_tensor == output.cpu()).sum().item() == sum_tensor.numel():
        print("test passed")
    else:
        print("test failed")
        print(sum_tensor)
        print(output.cpu())

    cleanup()


def main():
    world_size = torch.cuda.device_count()
    M = 1024
    N = 1024

    # each process runs run_vector_add on a different device
    mp.spawn(
        run_vector_add,
        args=(world_size, M, N, cutlass.Float32),
        nprocs=world_size,
        join=True,
    )

    return


if __name__ == "__main__":
    main()
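The symmetric-memory handshake used above, reduced to its skeleton (a sketch;
must run under an initialized NCCL process group, sizes illustrative):

.. code-block:: python

    import torch.distributed as dist
    import torch.distributed._symmetric_memory as symm_mem

    # every rank allocates the same shape from the symmetric heap ...
    t = symm_mem.empty(1024, device="cuda")
    # ... and the rendezvous exposes each peer's allocation as a local view
    hdl = symm_mem.rendezvous(t, dist.group.WORLD)
    peer_views = [
        hdl.get_buffer(p, t.shape, t.dtype)
        for p in range(dist.get_world_size())
    ]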
examples/python/CuTeDSL/ampere/dynamic_smem_size.py (new file, 114 lines)
@@ -0,0 +1,114 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:

# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.

# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.

# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import cutlass.cute as cute
import cutlass

"""
|
||||
Example of automatic shared memory size computation for configuring kernel launch
|
||||
|
||||
This example demonstrates how to let the DSL automatically set shared memory
|
||||
size for a kernel launch rather explicitly configuring it at launch time,
|
||||
provided that developers are using `SmemAllocator` for all allocations.
|
||||
|
||||
Usage:
|
||||
python dynamic_smem_size.py # Show auto inference
|
||||
"""
|
||||
|
||||
|
||||
@cute.struct
class SharedData:
    """A struct to demonstrate shared memory allocation."""

    values: cute.struct.MemRange[cutlass.Float32, 64]  # 256 bytes
    counter: cutlass.Int32  # 4 bytes
    flag: cutlass.Int8  # 1 byte


@cute.kernel
def kernel():
    """
    Example kernel that allocates shared memory.
    The total allocation is computed automatically when smem=None.
    """
    allocator = cutlass.utils.SmemAllocator()

    # Allocate various types of shared memory
    shared_data = allocator.allocate(SharedData)
    raw_buffer = allocator.allocate(512, byte_alignment=64)
    int_array = allocator.allocate_array(element_type=cutlass.Int32, num_elems=128)
    tensor_smem = allocator.allocate_tensor(
        element_type=cutlass.Float16,
        layout=cute.make_layout((32, 16)),
        byte_alignment=16,
        swizzle=None,
    )
    return


@cute.kernel
def kernel_no_smem():
    """
    Example kernel that does not allocate shared memory.
    The total allocation is automatically computed as 0 when smem=None.
    """
    bidx, _, _ = cute.arch.block_idx()
    if bidx == 0:
        cute.printf("Hello world")
    return


if __name__ == "__main__":
    # Initialize CUDA context
    cutlass.cuda.initialize_cuda_context()

    print("Launching kernel with auto smem size. (launch config `smem=None`)")

    # Compile the example
    @cute.jit
    def launch_kernel1():
        k = kernel()
        k.launch(
            grid=(1, 1, 1),
            block=(1, 1, 1),
        )
        print(f"Kernel recorded internal smem usage: {k.smem_usage()}")

    @cute.jit
    def launch_kernel2():
        k = kernel_no_smem()
        k.launch(
            grid=(1, 1, 1),
            block=(1, 1, 1),
        )
        print(f"Kernel recorded internal smem usage: {k.smem_usage()}")

    cute.compile(launch_kernel1)
    cute.compile(launch_kernel2)

    print("PASS")
@@ -327,7 +327,6 @@ class FlashAttentionForwardAmpere:
         ).launch(
             grid=grid_dim,
             block=[self._num_threads, 1, 1],
-            smem=SharedStorage.size_in_bytes(),
             stream=stream,
         )

@@ -1014,13 +1013,10 @@
         )

         # compute exp(x - max) using exp2(x * log_2(e) - max * log_2(e))
-        acc_S_row_exp = cute.TensorSSA(
-            self._exp2f(
-                acc_S_row * softmax_params.softmax_scale_log2
-                - row_max_cur_row * softmax_params.softmax_scale_log2
-            ),
-            tuple(acc_S_row.shape),
-            cutlass.Float32,
+        acc_S_row_exp = cute.math.exp2(
+            acc_S_row * softmax_params.softmax_scale_log2
+            - row_max_cur_row * softmax_params.softmax_scale_log2,
+            fastmath=True,
         )
         # acc_S_row_sum => f32
         acc_S_row_sum = acc_S_row_exp.reduce(
@@ -1028,9 +1024,10 @@
         )
         # if it is not the first tile, load the row r of previous row_max and minus row_max_cur_row to update row_sum.
         if cutlass.const_expr(not is_first_n_block):
-            prev_minus_cur_exp = self._exp2f(
+            prev_minus_cur_exp = cute.math.exp2(
                 row_max_prev_row * softmax_params.softmax_scale_log2
-                - row_max_cur_row * softmax_params.softmax_scale_log2
+                - row_max_cur_row * softmax_params.softmax_scale_log2,
+                fastmath=True,
             )
             acc_S_row_sum = (
                 acc_S_row_sum + softmax_params.row_sum[r] * prev_minus_cur_exp
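Both hunks rely on the same identity the `self._exp2f` helper implemented:
with softmax_scale_log2 = softmax_scale * log2(e), exp(scale * (x - m))
equals exp2(x * softmax_scale_log2 - m * softmax_scale_log2). A quick
numerical check of the identity, independent of the DSL (values
illustrative):

.. code-block:: python

    import numpy as np

    x, m, scale = np.float32(1.7), np.float32(2.5), np.float32(0.125)
    scale_log2 = scale * np.log2(np.e)  # softmax_scale_log2 in the kernel
    lhs = np.exp(scale * (x - m))
    rhs = np.exp2(x * scale_log2 - m * scale_log2)
    assert np.isclose(lhs, rhs)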
@@ -1141,26 +1138,6 @@
         """
         return self._threadquad_reduce(val, lambda x, y: x + y)

-    def _exp2f(
-        self, x: Union[cute.TensorSSA, cutlass.Float32]
-    ) -> Union[cute.TensorSSA, cutlass.Float32]:
-        """exp2f calculation for both vector and scalar.
-
-        :param x: input value
-        :type x: cute.TensorSSA or cutlass.Float32
-        :return: exp2 value
-        :rtype: cute.TensorSSA or cutlass.Float32
-        """
-        if isinstance(x, cute.TensorSSA):
-            res = cute.make_fragment(x.shape, cutlass.Float32)
-            res.store(x)
-
-            for i in range(cute.size(x.shape)):
-                res[i] = self._exp2f(res[i])
-
-            return res.load()
-        return cute.arch.exp2(x)
-

 def run(
     dtype: Type[cutlass.Numeric],
@@ -136,10 +136,6 @@ class SGemm:
             stride=(1, (self._bN + padding_b), self._bK * (self._bN + padding_b)),
         )

-        smem_size = cute.size_in_bytes(mA.element_type, sA_layout) + cute.size_in_bytes(
-            mB.element_type, sB_layout
-        )
-
         # ///////////////////////////////////////////////////////////////////////////////
         # Create copy layouts that will be used for asynchronous
         # global memory -> shared memory copies:
@@ -258,7 +254,6 @@ class SGemm:
         ).launch(
             grid=grid_dim,
             block=[cute.size(atoms_layout), 1, 1],
-            smem=smem_size,
             stream=stream,
         )

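These two hunks drop the hand-computed shared-memory size: once every
allocation goes through `SmemAllocator`, the DSL can infer the size when the
`smem` launch argument is omitted, as the dynamic_smem_size.py example above
demonstrates. A minimal sketch of the pattern (kernel body illustrative):

.. code-block:: python

    import cutlass
    import cutlass.cute as cute

    @cute.kernel
    def k():
        allocator = cutlass.utils.SmemAllocator()
        buf = allocator.allocate_array(element_type=cutlass.Int32, num_elems=64)

    @cute.jit
    def launch():
        k().launch(grid=(1, 1, 1), block=(32, 1, 1))  # no smem= argument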
@@ -738,14 +733,20 @@ def run(

     print("Compiling kernel with cute.compile ...")
     start_time = time.time()
-    gemm = cute.compile(sgemm, a_tensor, b_tensor, c_tensor, stream=current_stream)
+    compiled_fn = cute.compile(
+        sgemm,
+        a_tensor,
+        b_tensor,
+        c_tensor,
+        stream=current_stream,
+    )
     compilation_time = time.time() - start_time
     print(f"Compilation time: {compilation_time:.4f} seconds")

     print("Executing GEMM kernel...")

     if not skip_ref_check:
-        gemm(a_tensor, b_tensor, c_tensor)
+        compiled_fn(a_tensor, b_tensor, c_tensor)
         torch.cuda.synchronize()
         print("Verifying results...")
         ref = torch.einsum("mk,nk->mn", a, b)
@@ -804,7 +805,7 @@ def run(
     )

     avg_time_us = testing.benchmark(
-        gemm,
+        compiled_fn,
         workspace_generator=generate_tensors,
         workspace_count=workspace_count,
         stream=current_stream,
@@ -837,6 +838,7 @@ if __name__ == "__main__":
     parser.add_argument("--c_major", choices=["n", "m"], default="n")
     parser.add_argument("--warmup_iterations", default=2, type=int)
     parser.add_argument("--iterations", default=100, type=int)
+    parser.add_argument("--static_shape", action="store_true")
     parser.add_argument("--skip_ref_check", action="store_true")
     parser.add_argument(
         "--use_cold_l2",
@@ -69,7 +69,7 @@ class complex:
 class SharedStorage:
     # struct elements with natural alignment
     a: cute.struct.MemRange[cutlass.Float32, 32]  # array
-    b: cutlass.Int64  # scalar
+    b: cutlass.Int64  # saclar
     c: complex  # nested struct
     # struct elements with strict alignment
     x: cute.struct.Align[
@@ -471,7 +471,7 @@ class TensorOpGemm:
         cute.arch.sync_threads()
         # Start async loads for the first k-tile. Here we take care of the k residue
         # via if/else check along the k dimension. Because we shifted the identity tensor
-        # by the residue_k and because the identity tensor is a counting tensor, the
+        # by the residue_k and because the identity tensor is a coord tensor, the
         # values of any identity tensor element that is poison is less than -1
         num_smem_stages = cute.size(tAsA, mode=[3])
         k_tile_count = cute.size(tAgA, mode=[3])
@@ -683,7 +683,7 @@ class TensorOpGemm:
         # Copy results of D back to shared memory
         cute.autovec_copy(tCrD, tCsC)

-        # Create counting tensor for C
+        # Create coord tensor for C
         ceilM, ceilN, _ = cute.ceil_div(mC.shape, (self.bM, self.bN, 1))
         mcC = cute.make_identity_tensor(
             (
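The comment rename from "counting tensor" to "coord tensor" matches what
cute.make_identity_tensor actually produces: a tensor whose value at each
coordinate is that coordinate itself, which the kernels compare against the
shifted residue to build predicates. A minimal sketch (shape illustrative):

.. code-block:: python

    import cutlass.cute as cute

    @cute.jit
    def show_coords():
        # each element of an identity (coord) tensor is its own coordinate
        mc = cute.make_identity_tensor((2, 3))
        print(mc)  # trace-time print of the coordinate tensor

    show_coords()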