Files
cutlass/examples/python/CuTeDSL/ampere/distributed_vector_add.py
2025-09-15 12:21:53 -04:00

190 lines
5.3 KiB
Python

import os
import torch
import argparse
from typing import Type
from cuda.bindings import driver
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
import torch.multiprocessing as mp
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
def setup(rank, world_size):
    """Join the NCCL process group and select this process's CUDA device.

    Args:
        rank: Index of the calling process; also used as the CUDA device index.
        world_size: Total number of processes taking part in the group.
    """
    # Rendezvous address and port that torch.distributed picks up from the
    # environment.
    os.environ.update({"MASTER_ADDR": "localhost", "MASTER_PORT": "12995"})
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
def cleanup():
    """Tear down the default torch.distributed process group.

    Does nothing when no process group is initialized, so the function can be
    called from error paths, or more than once, without raising.
    """
    if dist.is_initialized():
        dist.destroy_process_group()
@cute.kernel
def vector_add_kernel(
    g0: cute.Tensor,
    g1: cute.Tensor,
    g2: cute.Tensor,
    g3: cute.Tensor,
    g4: cute.Tensor,
    g5: cute.Tensor,
    g6: cute.Tensor,
    g7: cute.Tensor,
    gOut: cute.Tensor,
    tv_layout: cute.Layout,
):
    """Add eight tiled input tensors element-wise and write the sum to ``gOut``.

    Each thread block works on the tile selected by its block index. Every
    thread copies its partition of each input tile into a fragment, adds it to
    an accumulator fragment that starts at zero, and finally copies the
    accumulator to its partition of the output tile.

    Args:
        g0, g1, g2, g3, g4, g5, g6, g7: Input tensors, indexed here with a
            ``(None, block index)`` coordinate.
        gOut: Output tensor, indexed the same way as the inputs.
        tv_layout: Layout whose mode 0 and mode 1 are passed to
            ``cute.make_tiled_copy_tv`` as its two layout arguments.
    """
    thread_id, _, _ = cute.arch.thread_idx()
    block_id, _, _ = cute.arch.block_idx()

    # Keep the whole first mode and select this block's entry of the second.
    tile_coord = (None, block_id)
    out_tile = gOut[tile_coord]
    in_tiles = [g[tile_coord] for g in (g0, g1, g2, g3, g4, g5, g6, g7)]

    # Load and store atoms are configured identically: universal copy of the
    # input element type with volatile memory order and system memory scope.
    load_atom = cute.make_copy_atom(
        cute.nvgpu.CopyUniversalOp(),
        g0.element_type,
        memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE,
        memory_scope=cute.nvgpu.common.MemoryScope.SYS,
    )
    store_atom = cute.make_copy_atom(
        cute.nvgpu.CopyUniversalOp(),
        g0.element_type,
        memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE,
        memory_scope=cute.nvgpu.common.MemoryScope.SYS,
    )

    # Per-thread views of every input tile and of the output tile.
    tiled_copy = cute.make_tiled_copy_tv(load_atom, tv_layout[0], tv_layout[1])
    thr_copy = tiled_copy.get_slice(thread_id)
    thr_inputs = [thr_copy.partition_S(tile) for tile in in_tiles]
    thr_output = thr_copy.partition_D(out_tile)

    # One staging fragment per input plus a zero-initialized accumulator.
    in_frags = [cute.make_fragment_like(part) for part in thr_inputs]
    acc_frag = cute.make_fragment_like(thr_output)
    acc_frag.fill(0.0)

    for thr_in, in_frag in zip(thr_inputs, in_frags):
        cute.copy(load_atom, thr_in, in_frag)
        acc_frag.store(in_frag.load() + acc_frag.load())

    cute.copy(store_atom, acc_frag, thr_output)
@cute.jit
def vector_add(
    m0: cute.Tensor,
    m1: cute.Tensor,
    m2: cute.Tensor,
    m3: cute.Tensor,
    m4: cute.Tensor,
    m5: cute.Tensor,
    m6: cute.Tensor,
    m7: cute.Tensor,
    output: cute.Tensor,
):
    """Tile the eight inputs and the output, then launch ``vector_add_kernel``.

    The grid size is the element count of ``m0`` floor-divided by the tile
    size, so a trailing partial tile is not assigned a block.

    Args:
        m0, m1, m2, m3, m4, m5, m6, m7: Input tensors to be summed.
        output: Tensor that receives the element-wise sum.
    """
    total_elems = cute.size(m0.layout)

    # 128 threads per block and 4 elements per thread
    tv_layout = cute.make_layout((128, 4), stride=(1, 1))
    tile_elems = cute.size(tv_layout.shape)

    # zipped_divide reshapes each tensor to ((Tile),(Rest)).
    tiled_inputs = [
        cute.zipped_divide(src, cute.make_layout(tile_elems))
        for src in (m0, m1, m2, m3, m4, m5, m6, m7)
    ]
    gOut = cute.zipped_divide(output, cute.make_layout(tile_elems))

    # One block per full tile, one thread per entry of tv_layout's first mode.
    launch_grid = [total_elems // tile_elems, 1, 1]
    launch_block = [tv_layout.shape[0], 1, 1]

    vector_add_kernel(
        tiled_inputs[0],
        tiled_inputs[1],
        tiled_inputs[2],
        tiled_inputs[3],
        tiled_inputs[4],
        tiled_inputs[5],
        tiled_inputs[6],
        tiled_inputs[7],
        gOut,
        tv_layout,
    ).launch(grid=launch_grid, block=launch_block)
def run_vector_add(rank, world_size, M, N, dtype: Type[cutlass.Numeric]):
    """Run the distributed vector add on one rank and check the result.

    The rank allocates a symmetric-memory buffer of ``M * N`` elements, fills
    its own buffer with random integers in ``[0, 100)``, sums the buffers of
    all ranks with ``vector_add`` and compares the result against a sum
    computed on the CPU. It prints ``test passed`` or ``test failed`` (the
    latter followed by the expected and actual tensors).

    Args:
        rank: Index of this process; also used as its CUDA device index.
        world_size: Number of participating processes. Must be 8 because
            ``vector_add`` takes exactly eight input tensors.
        M: First size factor; only the product ``M * N`` is used.
        N: Second size factor; only the product ``M * N`` is used.
        dtype: Not referenced by this function; the buffers are created
            without an explicit dtype.

    Raises:
        ValueError: If ``world_size`` is not 8.
    """
    # vector_add has a fixed arity of eight inputs, one per rank. Fail early
    # with a clear message instead of an IndexError (fewer ranks) or a bogus
    # mismatch against the all-rank reference sum (more ranks).
    num_inputs = 8
    if world_size != num_inputs:
        raise ValueError(
            f"distributed vector add requires exactly {num_inputs} ranks, "
            f"got world_size={world_size}"
        )

    setup(rank, world_size)
    try:
        t = symm_mem.empty(M * N, device="cuda")
        hdl = symm_mem.rendezvous(t, dist.group.WORLD)
        # get tensors from other devices from the symmetric memory
        tensor_list = [
            hdl.get_buffer(peer, t.shape, t.dtype) for peer in range(world_size)
        ]
        tensor_list[rank].random_(0, 100)

        # Every rank reads every other rank's buffer, so all fills must have
        # completed before any kernel starts: finish the local fill, then wait
        # for all peers.
        torch.cuda.synchronize()
        dist.barrier(device_ids=[rank])

        # enable peer access
        # NOTE: the driver status codes returned by these calls are not
        # checked, as in the original example.
        driver.cuInit(0)
        dev_list = [driver.cuDeviceGet(i)[1] for i in range(world_size)]
        ctx_list = [driver.cuDevicePrimaryCtxRetain(dev)[1] for dev in dev_list]
        driver.cuCtxSetCurrent(ctx_list[rank])
        for i in range(world_size):
            if i == rank:
                continue
            driver.cuCtxEnablePeerAccess(ctx_list[i], 0)

        output = torch.zeros(M * N, device=f"cuda:{rank}")
        # we have to explicitly pass each tensor instead of a list of tensors
        vector_add(
            from_dlpack(tensor_list[0], assumed_align=32),
            from_dlpack(tensor_list[1], assumed_align=32),
            from_dlpack(tensor_list[2], assumed_align=32),
            from_dlpack(tensor_list[3], assumed_align=32),
            from_dlpack(tensor_list[4], assumed_align=32),
            from_dlpack(tensor_list[5], assumed_align=32),
            from_dlpack(tensor_list[6], assumed_align=32),
            from_dlpack(tensor_list[7], assumed_align=32),
            from_dlpack(output, assumed_align=32),
        )

        # Reference result computed on the host from the same buffers.
        expected = sum(tensor.cpu() for tensor in tensor_list)
        actual = output.cpu()
        # torch.equal compares in native code; the builtin sum() over a
        # boolean tensor would iterate element by element in Python.
        if torch.equal(expected, actual):
            print("test passed")
        else:
            print("test failed")
            print(expected)
            print(actual)

        # Do not tear down while a peer may still be reading this rank's
        # buffer.
        dist.barrier(device_ids=[rank])
    finally:
        cleanup()
def main():
    """Spawn one ``run_vector_add`` worker process per visible CUDA device."""
    num_devices = torch.cuda.device_count()
    rows, cols = 1024, 1024
    # each process will run run_vector_add on a different device; mp.spawn
    # supplies the process index as the first argument (rank).
    mp.spawn(
        run_vector_add,
        args=(num_devices, rows, cols, cutlass.Float32),
        nprocs=num_devices,
        join=True,
    )


if __name__ == "__main__":
    main()