190 lines
5.3 KiB
Python
190 lines
5.3 KiB
Python
import os
|
|
import torch
|
|
import argparse
|
|
from typing import Type
|
|
from cuda.bindings import driver
|
|
|
|
import torch.distributed as dist
|
|
import torch.distributed._symmetric_memory as symm_mem
|
|
import torch.multiprocessing as mp
|
|
|
|
import cutlass
|
|
import cutlass.cute as cute
|
|
from cutlass.cute.runtime import from_dlpack
|
|
|
|
|
|
def setup(rank, world_size):
    """Join the NCCL process group for this rank and bind it to its GPU.

    Args:
        rank: index of this process; also used as its CUDA device index.
        world_size: total number of processes in the group.
    """
    # Fixed single-host rendezvous endpoint for torch.distributed.
    os.environ.update(MASTER_ADDR="localhost", MASTER_PORT="12995")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
|
|
|
|
|
|
def cleanup():
    """Tear down the torch.distributed state initialized by ``setup``."""
    dist.destroy_process_group(group=None)
|
|
|
|
|
|
@cute.kernel
def vector_add_kernel(
    g0: cute.Tensor,
    g1: cute.Tensor,
    g2: cute.Tensor,
    g3: cute.Tensor,
    g4: cute.Tensor,
    g5: cute.Tensor,
    g6: cute.Tensor,
    g7: cute.Tensor,
    gOut: cute.Tensor,
    tv_layout: cute.Layout,
):
    # Device kernel: each CTA sums one tile from each of the eight inputs
    # g0..g7 and writes the result into the matching tile of gOut.
    #
    # g0..g7 and gOut are expected to be zipped_divide'd as ((Tile), (Rest)),
    # so indexing with (None, bidx) selects the whole tile number `bidx`.
    # tv_layout: mode 0 is used as the thread layout and mode 1 as the value
    # layout of the tiled copy.
    # NOTE(review): there is no predication; this assumes every selected tile
    # is full — confirm against the launch grid.
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    cta_coord = (None, bidx)
    local_tile_out = gOut[cta_coord]
    local_tile_list = [
        g0[cta_coord],
        g1[cta_coord],
        g2[cta_coord],
        g3[cta_coord],
        g4[cta_coord],
        g5[cta_coord],
        g6[cta_coord],
        g7[cta_coord],
    ]

    # Both atoms use VOLATILE ordering at SYS scope; presumably so accesses to
    # buffers that live on peer GPUs are not cached or reordered — TODO confirm.
    copy_atom_load = cute.make_copy_atom(
        cute.nvgpu.CopyUniversalOp(),
        g0.element_type,
        memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE,
        memory_scope=cute.nvgpu.common.MemoryScope.SYS,
    )
    copy_atom_store = cute.make_copy_atom(
        cute.nvgpu.CopyUniversalOp(),
        g0.element_type,
        memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE,
        memory_scope=cute.nvgpu.common.MemoryScope.SYS,
    )
    # The thread/value partitioning is built from the load atom and reused for
    # partitioning the output tile as well.
    tiled_copy = cute.make_tiled_copy_tv(copy_atom_load, tv_layout[0], tv_layout[1])
    thr_copy = tiled_copy.get_slice(tidx)

    # This thread's slice of each input tile and of the output tile.
    thr_tensor_list = [thr_copy.partition_S(tensor) for tensor in local_tile_list]
    thr_out = thr_copy.partition_D(local_tile_out)
    # Thread-local fragments shaped like those slices; frg_acc is the running sum.
    frg_tensor_list = [cute.make_fragment_like(tensor) for tensor in thr_tensor_list]
    frg_acc = cute.make_fragment_like(thr_out)
    frg_acc.fill(0.0)

    # Load each input slice into its fragment and add it into the accumulator.
    for thr, frg in zip(thr_tensor_list, frg_tensor_list):
        cute.copy(copy_atom_load, thr, frg)
        tmp = frg.load() + frg_acc.load()
        frg_acc.store(tmp)

    # Write the accumulated sum to this thread's slice of the output tile.
    cute.copy(copy_atom_store, frg_acc, thr_out)
|
|
|
|
|
|
@cute.jit
def vector_add(
    m0: cute.Tensor,
    m1: cute.Tensor,
    m2: cute.Tensor,
    m3: cute.Tensor,
    m4: cute.Tensor,
    m5: cute.Tensor,
    m6: cute.Tensor,
    m7: cute.Tensor,
    output: cute.Tensor,
):
    # JIT host entry: tiles the eight inputs and the output, then launches
    # vector_add_kernel with one CTA per tile. The inputs are passed as separate
    # parameters rather than a list.
    # NOTE(review): the grid size comes from m0 alone, so all tensors are
    # assumed to have m0's element count. Floor division drops a partial
    # trailing tile if that count is not a multiple of the tile size.

    # define constants for future use
    num_of_elements = cute.size(m0.layout)
    # 128 threads per block and 4 elements per thread
    # (the inner parentheses are not tuples: shape is (128, 4), stride (1, 1)).
    tv_layout = cute.make_layout(((128), (4)), stride=((1), (1)))
    # Elements handled by one CTA: 128 * 4 = 512.
    tile = cute.size(tv_layout.shape)

    tensors = [m0, m1, m2, m3, m4, m5, m6, m7]
    # Split every input into tile-sized chunks: ((Tile),(Rest)).
    divided_tensors = [
        cute.zipped_divide(tensor, cute.make_layout(tile)) for tensor in tensors
    ]
    gOut = cute.zipped_divide(output, cute.make_layout(tile))  # ((Tile),(Rest))
    vector_add_kernel(
        divided_tensors[0],
        divided_tensors[1],
        divided_tensors[2],
        divided_tensors[3],
        divided_tensors[4],
        divided_tensors[5],
        divided_tensors[6],
        divided_tensors[7],
        gOut,
        tv_layout,
    ).launch(
        # one CTA per full tile
        grid=[num_of_elements // tile, 1, 1],
        # one thread per entry of the thread mode (128)
        block=[tv_layout.shape[0], 1, 1],
    )
|
|
|
|
|
|
def run_vector_add(rank, world_size, M, N, dtype: Type[cutlass.Numeric]):
    """Fill this rank's symmetric buffer, sum all ranks' buffers on-GPU, and verify.

    Args:
        rank: index of this process; also its CUDA device index.
        world_size: number of participating processes. The kernel reads exactly
            eight buffers (``tensor_list[0]`` .. ``tensor_list[7]``).
        M, N: each buffer is a flat tensor of ``M * N`` elements.
        dtype: not used. Buffers are created without an explicit dtype.

    Prints "test passed" or "test failed" (with both tensors) for this rank.
    """
    setup(rank, world_size)

    t = symm_mem.empty(M * N, device="cuda")
    hdl = symm_mem.rendezvous(t, dist.group.WORLD)
    # Views of every rank's symmetric buffer, indexed by that rank.
    tensor_list = [hdl.get_buffer(peer, t.shape, t.dtype) for peer in range(world_size)]
    tensor_list[rank].random_(0, 100)

    # Every rank reads every peer's buffer below. All ranks must therefore
    # finish writing their own buffer before any rank starts reading; without
    # this the kernel can observe a peer's buffer before it has been filled.
    torch.cuda.synchronize()
    dist.barrier()

    # enable peer access
    driver.cuInit(0)
    dev_list = [driver.cuDeviceGet(i)[1] for i in range(world_size)]
    ctx_list = [driver.cuDevicePrimaryCtxRetain(dev)[1] for dev in dev_list]
    driver.cuCtxSetCurrent(ctx_list[rank])
    for i in range(world_size):
        if i == rank:
            continue
        # NOTE(review): the returned CUresult is ignored, as before.
        driver.cuCtxEnablePeerAccess(ctx_list[i], 0)

    output = torch.zeros(M * N, device=f"cuda:{rank}")

    # we have to explicitly pass each tensor instead of a list of tensors
    vector_add(
        from_dlpack(tensor_list[0], assumed_align=32),
        from_dlpack(tensor_list[1], assumed_align=32),
        from_dlpack(tensor_list[2], assumed_align=32),
        from_dlpack(tensor_list[3], assumed_align=32),
        from_dlpack(tensor_list[4], assumed_align=32),
        from_dlpack(tensor_list[5], assumed_align=32),
        from_dlpack(tensor_list[6], assumed_align=32),
        from_dlpack(tensor_list[7], assumed_align=32),
        from_dlpack(output, assumed_align=32),
    )
    # Device-wide sync, so the kernel has finished whichever stream it used.
    torch.cuda.synchronize()

    # Host reference. The values are integers in [0, 100), so the float sums
    # are exact and an exact comparison is valid.
    sum_tensor = sum(buf.cpu() for buf in tensor_list)
    # torch.equal compares in one vectorized call. The builtin sum() over a
    # bool tensor would iterate element by element in Python.
    if torch.equal(sum_tensor, output.cpu()):
        print("test passed")
    else:
        print("test failed")
        print(sum_tensor)
        print(output.cpu())

    # Peers may still be reading this rank's buffer. Keep it (and the process)
    # alive until every rank is done before tearing down.
    dist.barrier()
    cleanup()
|
|
|
|
|
|
def main():
    """Spawn one worker process per GPU and run the distributed vector-add check.

    ``vector_add_kernel`` sums exactly eight buffers, so exactly eight ranks are
    used. With fewer GPUs the ranks cannot supply eight buffers. With more, the
    reference sum would include buffers the kernel never reads.

    Raises:
        RuntimeError: if fewer than eight CUDA devices are visible.
    """
    num_peers = 8  # must match the number of input tensors of vector_add_kernel
    available = torch.cuda.device_count()
    if available < num_peers:
        raise RuntimeError(
            f"this example requires {num_peers} CUDA devices, found {available}"
        )
    world_size = num_peers
    M = 1024
    N = 1024

    # each process will run run_vector_add on different device
    mp.spawn(
        run_vector_add,
        args=(world_size, M, N, cutlass.Float32),
        nprocs=world_size,
        join=True,
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Run only when executed as a script, not when this module is imported.
    raise SystemExit(main())
|