From b7b7676d67ee517c0f97cfd245531db477606010 Mon Sep 17 00:00:00 2001 From: Ilya Markov Date: Tue, 1 Apr 2025 07:49:12 +0200 Subject: [PATCH] [Distributed] Add custom allreduce support for ROCM (#14125) Signed-off-by: ilmarkov Co-authored-by: ilmarkov --- CMakeLists.txt | 2 +- csrc/custom_all_reduce.cu | 49 +++- csrc/custom_all_reduce.cuh | 239 ++++++++++++------ csrc/custom_all_reduce_test.cu | 58 ++++- csrc/ops.h | 9 +- csrc/torch_bindings.cpp | 11 +- tests/distributed/test_custom_all_reduce.py | 2 +- tests/utils.py | 11 +- vllm/_custom_ops.py | 16 +- vllm/config.py | 6 +- .../device_communicators/custom_all_reduce.py | 91 +++---- vllm/platforms/cuda.py | 6 +- vllm/platforms/rocm.py | 33 ++- 13 files changed, 373 insertions(+), 160 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index d0436aa1d0..15db4a4f4c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -242,6 +242,7 @@ set(VLLM_EXT_SRC "csrc/quantization/gguf/gguf_kernel.cu" "csrc/cuda_utils_kernels.cu" "csrc/prepare_inputs/advance_step.cu" + "csrc/custom_all_reduce.cu" "csrc/torch_bindings.cpp") if(VLLM_GPU_LANG STREQUAL "CUDA") @@ -283,7 +284,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/mamba/causal_conv1d/causal_conv1d.cu" "csrc/quantization/aqlm/gemm_kernels.cu" "csrc/quantization/awq/gemm_kernels.cu" - "csrc/custom_all_reduce.cu" "csrc/permute_cols.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" "csrc/quantization/fp4/nvfp4_quant_entry.cu" diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu index 123278bfed..a38d6fa24a 100644 --- a/csrc/custom_all_reduce.cu +++ b/csrc/custom_all_reduce.cu @@ -12,7 +12,7 @@ static_assert(sizeof(void*) == sizeof(fptr_t)); fptr_t init_custom_ar(const std::vector& fake_ipc_ptrs, torch::Tensor& rank_data, int64_t rank, - bool full_nvlink) { + bool fully_connected) { int world_size = fake_ipc_ptrs.size(); if (world_size > 8) throw std::invalid_argument("world size > 8 is not supported"); @@ -27,7 +27,7 @@ fptr_t init_custom_ar(const std::vector& fake_ipc_ptrs, } return (fptr_t) new vllm::CustomAllreduce(ipc_ptrs, rank_data.data_ptr(), rank_data.numel(), rank, world_size, - full_nvlink); + fully_connected); } /** @@ -142,3 +142,48 @@ void register_graph_buffers(fptr_t _fa, bytes.reserve(handles.size()); fa->register_graph_buffers(bytes, offsets); } + +std::tuple allocate_shared_buffer_and_handle( + int64_t size) { + auto device_index = c10::cuda::current_device(); + at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index)); + void* buffer; + cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed; + auto stream = c10::cuda::getCurrentCUDAStream().stream(); + AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode)); + + // Allocate buffer +#if defined(USE_ROCM) + // data buffers need to be "uncached" for signal on MI200 + AT_CUDA_CHECK( + hipExtMallocWithFlags((void**)&buffer, size, hipDeviceMallocUncached)); +#else + AT_CUDA_CHECK(cudaMalloc((void**)&buffer, size)); +#endif + AT_CUDA_CHECK(cudaMemsetAsync(buffer, 0, size, stream)); + AT_CUDA_CHECK(cudaStreamSynchronize(stream)); + AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode)); + + // Create IPC memhandle for the allocated buffer. + // Will use it in open_mem_handle. 
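+  // The raw handle bytes are returned as a CPU uint8 tensor so the caller
+  // can exchange them across ranks (e.g. with all_gather_object); each peer
+  // then maps the allocation into its own address space via open_mem_handle.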
+ auto options = + torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); + auto handle = + torch::empty({static_cast(sizeof(cudaIpcMemHandle_t))}, options); + AT_CUDA_CHECK( + cudaIpcGetMemHandle((cudaIpcMemHandle_t*)handle.data_ptr(), buffer)); + + return std::make_tuple(reinterpret_cast(buffer), handle); +} + +fptr_t open_mem_handle(torch::Tensor& mem_handle) { + void* ipc_ptr; + AT_CUDA_CHECK(cudaIpcOpenMemHandle( + (void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)mem_handle.data_ptr()), + cudaIpcMemLazyEnablePeerAccess)); + return reinterpret_cast(ipc_ptr); +} + +void free_shared_buffer(fptr_t buffer) { + AT_CUDA_CHECK(cudaFree(reinterpret_cast(buffer))); +} diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh index b9df4ed160..7150ce29b4 100644 --- a/csrc/custom_all_reduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -5,6 +5,10 @@ #include #include +#if defined(USE_ROCM) +typedef __hip_bfloat16 nv_bfloat16; +#endif + #include #include #include @@ -12,6 +16,7 @@ #include #include +namespace vllm { #define CUDACHECK(cmd) \ do { \ cudaError_t e = cmd; \ @@ -22,24 +27,37 @@ } \ } while (0) -namespace vllm { - +// Maximal number of blocks in allreduce kernel. constexpr int kMaxBlocks = 36; + +// Default number of blocks in allreduce kernel. +#ifndef USE_ROCM +const int defaultBlockLimit = 36; +CUpointer_attribute rangeStartAddrAttr = CU_POINTER_ATTRIBUTE_RANGE_START_ADDR; +#else +const int defaultBlockLimit = 16; +hipPointer_attribute rangeStartAddrAttr = + HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR; +#endif + // Counter may overflow, but it's fine since unsigned int overflow is // well-defined behavior. using FlagType = uint32_t; + +// Two sets of peer counters are needed for two syncs: starting and ending an +// operation. The reason is that it's possible for peer GPU block to arrive at +// the second sync point while the current GPU block haven't passed the first +// sync point. Thus, peer GPU may write counter+1 while current GPU is busy +// waiting for counter. We use alternating counter array to avoid this +// possibility. struct Signal { - alignas(128) FlagType self_counter[kMaxBlocks][8]; - // Two sets of peer counters are needed for two syncs. The reason is that - // it's possible for peer GPU block to arrive at the second sync point while - // the current GPU block haven't passed the first sync point. Thus, peer GPU - // may write counter+1 while current GPU is busy waiting for counter. We use - // alternating counter array to avoid this possibility. 
- alignas(128) FlagType peer_counter[2][kMaxBlocks][8]; + alignas(128) FlagType start[kMaxBlocks][8]; + alignas(128) FlagType end[kMaxBlocks][8]; + alignas(128) FlagType _flag[kMaxBlocks]; // incremental flags for each rank }; struct __align__(16) RankData { - const void* __restrict__ ptrs[8]; + const void* ptrs[8]; }; struct __align__(16) RankSignals { @@ -134,27 +152,29 @@ DINLINE O downcast(array_t val) { } } +#if !defined(USE_ROCM) + static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr)); -#else + #else asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr)); -#endif + #endif } static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) { FlagType flag; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 asm volatile("ld.acquire.sys.global.u32 %0, [%1];" : "=r"(flag) : "l"(flag_addr)); -#else + #else asm volatile("ld.volatile.global.u32 %0, [%1]; membar.gl;" : "=r"(flag) : "l"(flag_addr)); -#endif + #endif return flag; } @@ -170,37 +190,99 @@ static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) { return flag; } -// is_start: whether this is the very first synchronization barrier. -// need_fence: whether a memory fence is needed. If true, a release-acquire -// semantic is used to enforce memory access order before and after this -// barrier. -template -DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg, - int rank) { - if constexpr (!is_start) __syncthreads(); - static_assert( - !(is_start && need_fence)); // Start barrier shouldn't need fence. +// This function is meant to be used as the first synchronization in the all +// reduce kernel. Thus, it doesn't need to make any visibility guarantees for +// prior memory accesses. Note: volatile writes will not be reordered against +// other volatile writes. +template +DINLINE void barrier_at_start(const RankSignals& sg, Signal* self_sg, + int rank) { + uint32_t flag = self_sg->_flag[blockIdx.x] + 1; if (threadIdx.x < ngpus) { - // Increment the counter. Technically we only need one counter, but we use - // multiple per block to eliminate the need to share the counter via smem. - auto val = self_sg->self_counter[blockIdx.x][threadIdx.x] += 1; + auto peer_counter_ptr = &sg.signals[threadIdx.x]->start[blockIdx.x][rank]; + auto self_counter_ptr = &self_sg->start[blockIdx.x][threadIdx.x]; + // Write the expected counter value to peer and wait for correct value + // from peer. + st_flag_volatile(peer_counter_ptr, flag); + while (ld_flag_volatile(self_counter_ptr) != flag); + } + __syncthreads(); + // use one thread to update flag + if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag; +} + +// This function is meant to be used as the second or the final +// synchronization barrier in the all reduce kernel. If it's the final +// synchronization barrier, we don't need to make any visibility guarantees +// for prior memory accesses. 
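+// For the intermediate sync between the reduce-scatter and all-gather
+// stages, a release-store/acquire-load pair is used so partial results are
+// visible to peers; the final sync only needs plain volatile accesses.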
+template +DINLINE void barrier_at_end(const RankSignals& sg, Signal* self_sg, int rank) { + __syncthreads(); + uint32_t flag = self_sg->_flag[blockIdx.x] + 1; + if (threadIdx.x < ngpus) { + auto peer_counter_ptr = &sg.signals[threadIdx.x]->end[blockIdx.x][rank]; + auto self_counter_ptr = &self_sg->end[blockIdx.x][threadIdx.x]; // Write the expected counter value to peer and wait for correct value from // peer. - auto peer_counter_ptr = - &sg.signals[threadIdx.x]->peer_counter[val % 2][blockIdx.x][rank]; - auto self_counter_ptr = - &self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x]; - if constexpr (need_fence) { - st_flag_release(peer_counter_ptr, val); - while (ld_flag_acquire(self_counter_ptr) != val); + if constexpr (!final_sync) { + st_flag_release(peer_counter_ptr, flag); + while (ld_flag_acquire(self_counter_ptr) != flag); } else { - st_flag_volatile(peer_counter_ptr, val); - while (ld_flag_volatile(self_counter_ptr) != val); + st_flag_volatile(peer_counter_ptr, flag); + while (ld_flag_volatile(self_counter_ptr) != flag); } } - if constexpr (is_start || need_fence) __syncthreads(); + if constexpr (!final_sync) __syncthreads(); + + // use one thread to update flag + if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag; } +#else + +template +DINLINE void barrier_at_start(const RankSignals& sg, Signal* self_sg, + int rank) { + uint32_t flag = self_sg->_flag[blockIdx.x] + 1; + if (threadIdx.x < ngpus) { + // simultaneously write to the corresponding flag of all ranks. + // Latency = 1 p2p write + __scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], + flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM); + // wait until we got true from all ranks + while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x], + __ATOMIC_RELAXED, + __MEMORY_SCOPE_DEVICE) < flag); + } + __syncthreads(); + // use one thread to update flag + if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag; +} + +template +DINLINE void barrier_at_end(const RankSignals& sg, Signal* self_sg, int rank) { + __syncthreads(); + uint32_t flag = self_sg->_flag[blockIdx.x] + 1; + if (threadIdx.x < ngpus) { + // simultaneously write to the corresponding flag of all ranks. + // Latency = 1 p2p write + __scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], + flag, + final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE, + __MEMORY_SCOPE_SYSTEM); + // wait until we got true from all ranks + while ( + __scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x], + final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE, + __MEMORY_SCOPE_DEVICE) < flag); + } + if constexpr (!final_sync) __syncthreads(); + // use one thread to update flag + if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag; +} + +#endif + template DINLINE P packed_reduce(const P* ptrs[], int idx) { A tmp = upcast(ptrs[0][idx]); @@ -220,13 +302,13 @@ __global__ void __launch_bounds__(512, 1) // note: we don't reorder the address so the accumulation order is the same // for all ranks, ensuring bitwise identical results auto dp = *_dp; - multi_gpu_barrier(sg, self_sg, rank); + barrier_at_start(sg, self_sg, rank); // do the actual reduction for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { ((P*)result)[idx] = packed_reduce((const P**)&dp.ptrs[0], idx); } - multi_gpu_barrier(sg, self_sg, rank); + barrier_at_end(sg, self_sg, rank); } template @@ -255,18 +337,20 @@ __global__ void __launch_bounds__(512, 1) tmps[i] = get_tmp_buf
<P>
(sg.signals[target]); } auto tmp_out = tmps[0]; - multi_gpu_barrier(sg, self_sg, rank); + barrier_at_start(sg, self_sg, rank); + // stage 1: reduce scatter for (int idx = start + tid; idx < end; idx += stride) { tmp_out[idx - start] = packed_reduce(ptrs, idx); } - multi_gpu_barrier(sg, self_sg, rank); + barrier_at_end(sg, self_sg, rank); // stage 2: allgather. Note: it's important to match the tid between // the two stages, because visibility across devices is only guaranteed // between threads that have the same tid. If thread i computes the sum of - // start + i in the first stage, then thread i also gathers start + i from all - // ranks. + // start + i in the first stage, then thread i also gathers start + i from + // all ranks. + for (int idx = tid; idx < largest_part; idx += stride) { #pragma unroll for (int i = 0; i < ngpus; i++) { @@ -287,21 +371,22 @@ class CustomAllreduce { public: int rank_; int world_size_; - bool full_nvlink_; + // Full NVLink or xGMI connection between GPUs. + bool fully_connected_; RankSignals sg_; - // Stores an map from a pointer to its peer pointters from all ranks. + // Stores an map from a pointer to its peer pointers from all ranks. std::unordered_map buffers_; Signal* self_sg_; // Stores rank data from all ranks. This is mainly for cuda graph purposes. // For cuda graph to work, all kernel arguments must be fixed during graph - // capture time. However, the peer pointers are not known during graph capture - // time. Therefore, during capture, we increment the rank data pointer and use - // that as the argument to the kernel. The kernel arguments are stored in - // graph_unreg_buffers_. The actual peer pointers will be filled in at the - // memory pointed to by the pointers in graph_unreg_buffers_ when - // the IPC handles are exchanged between ranks. + // capture time. However, the peer pointers are not known during graph + // capture time. Therefore, during capture, we increment the rank data + // pointer and use that as the argument to the kernel. The kernel arguments + // are stored in graph_unreg_buffers_. The actual peer pointers will be + // filled in at the memory pointed to by the pointers in + // graph_unreg_buffers_ when the IPC handles are exchanged between ranks. // // The overall process looks like this: // 1. Graph capture. @@ -319,17 +404,18 @@ class CustomAllreduce { * Signals are an array of ipc-enabled buffers from all ranks. * For each of the buffer, the layout is as follows: * | -- sizeof(Signal) -- | ------ a few MB ----- | - * The first section is for allreduce synchronization, and the second section - * is for storing the intermediate results required by some allreduce algos. + * The first section is for allreduce synchronization, and the second + * section is for storing the intermediate results required by some + * allreduce algos. * * Note: this class does not own any device memory. Any required buffers * are passed in from the constructor. 
*/ CustomAllreduce(Signal** signals, void* rank_data, size_t rank_data_sz, - int rank, int world_size, bool full_nvlink = true) + int rank, int world_size, bool fully_connected = true) : rank_(rank), world_size_(world_size), - full_nvlink_(full_nvlink), + fully_connected_(fully_connected), self_sg_(signals[rank]), d_rank_data_base_(reinterpret_cast(rank_data)), d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) { @@ -361,8 +447,7 @@ class CustomAllreduce { void* base_ptr; // note: must share the base address of each allocation, or we get wrong // address - if (cuPointerGetAttribute(&base_ptr, - CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, + if (cuPointerGetAttribute(&base_ptr, rangeStartAddrAttr, (CUdeviceptr)ptr) != CUDA_SUCCESS) throw std::runtime_error("failed to get pointer attr"); CUDACHECK(cudaIpcGetMemHandle( @@ -396,11 +481,11 @@ class CustomAllreduce { // Note: when registering graph buffers, we intentionally choose to not // deduplicate the addresses. That means if the allocator reuses some - // addresses, they will be registered again. This is to account for the remote - // possibility of different allocation patterns between ranks. For example, - // rank 1 may get the same input address for the second allreduce, but rank 2 - // got a different address. IPC handles have internal reference counting - // mechanism so overhead should be small. + // addresses, they will be registered again. This is to account for the + // remote possibility of different allocation patterns between ranks. For + // example, rank 1 may get the same input address for the second allreduce, + // but rank 2 got a different address. IPC handles have internal reference + // counting mechanism so overhead should be small. void register_graph_buffers( const std::vector& handles, const std::vector>& offsets) { @@ -431,15 +516,15 @@ class CustomAllreduce { /** * Performs allreduce, assuming input has already been registered. * - * Block and grid default configs are results after careful grid search. Using - * 36 blocks give the best or close to the best runtime on the devices I - * tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also only - * take a small amount of SMs. Not quite sure the underlying reason, but my - * guess is that too many SMs will cause contention on NVLink bus. + * Block and grid default configs are results after careful grid search. + * Using 36 blocks give the best or close to the best runtime on the devices + * I tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also + * only take a small amount of SMs. Not quite sure the underlying reason, + * but my guess is that too many SMs will cause contention on NVLink bus. */ template void allreduce(cudaStream_t stream, T* input, T* output, int size, - int threads = 512, int block_limit = 36) { + int threads = 512, int block_limit = defaultBlockLimit) { auto d = packed_t::P::size; if (size % d != 0) throw std::runtime_error( @@ -473,13 +558,11 @@ class CustomAllreduce { #define KL(ngpus, name) \ name<<>>(ptrs, sg_, self_sg_, output, \ rank_, size); - // TODO(hanzhi713): Threshold is different for A100 and H100. - // Add per device threshold. 
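+  // Dispatch heuristic: cross_device_reduce_1stage is used for 2 GPUs and
+  // for small inputs on fully connected topologies; larger inputs fall back
+  // to the two-stage (reduce-scatter + all-gather) kernel.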
#define REDUCE_CASE(ngpus) \ case ngpus: { \ if (world_size_ == 2) { \ KL(ngpus, cross_device_reduce_1stage); \ - } else if (full_nvlink_) { \ + } else if (fully_connected_) { \ if ((world_size_ <= 4 && bytes < 512 * 1024) || \ (world_size_ <= 8 && bytes < 256 * 1024)) { \ KL(ngpus, cross_device_reduce_1stage); \ @@ -497,7 +580,8 @@ class CustomAllreduce { REDUCE_CASE(8) default: throw std::runtime_error( - "custom allreduce only supports num gpus in (2,4,6,8). Actual num " + "custom allreduce only supports num gpus in (2,4,6,8). Actual " + "num " "gpus = " + std::to_string(world_size_)); } @@ -511,10 +595,11 @@ class CustomAllreduce { } } }; + /** - * To inspect PTX/SASS, copy paste this header file to compiler explorer and add - a template instantiation: + * To inspect PTX/SASS, copy paste this header file to compiler explorer and + add a template instantiation: * template void vllm::CustomAllreduce::allreduce(cudaStream_t, half *, half *, int, int, int); */ -} // namespace vllm +} // namespace vllm \ No newline at end of file diff --git a/csrc/custom_all_reduce_test.cu b/csrc/custom_all_reduce_test.cu index b59ea40d98..f7f0823465 100644 --- a/csrc/custom_all_reduce_test.cu +++ b/csrc/custom_all_reduce_test.cu @@ -1,9 +1,9 @@ /** * This is a standalone test for custom allreduce. * To compile, make sure you have MPI and NCCL installed in your system. - * export MPI_HOME=xxx + * export MPI_HOME=XXX * nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o - * custom_all_reduce_test -lnccl -I${MPI_HOME} -lmpi + * custom_all_reduce_test -lnccl -I${MPI_HOME}/include -lmpi * * Warning: this C++ test is not designed to be very readable and was used * during the rapid prototyping process. @@ -22,7 +22,15 @@ #include "cuda_profiler_api.h" #include "custom_all_reduce.cuh" #include "mpi.h" -#include "nccl.h" +#ifdef USE_ROCM + #include +typedef __hip_bfloat16 nv_bfloat16; + #include "rccl/rccl.h" + #include "custom_all_reduce_hip.cuh" +#else + #include "nccl.h" + #include "custom_all_reduce.cuh" +#endif #define MPICHECK(cmd) \ do { \ @@ -43,16 +51,29 @@ } \ } while (0) +#ifdef USE_ROCM __global__ void dummy_kernel() { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + for (int i = 0; i < 100; i++) { + uint64_t start = wall_clock64(); + uint64_t cycles_elapsed; + do { + cycles_elapsed = wall_clock64() - start; + } while (cycles_elapsed < 100); + } for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms +} #else +__global__ void dummy_kernel() { + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms + #else for (int i = 0; i < 100; i++) { long long int start = clock64(); while (clock64() - start < 150000000); // approximately 98.4ms on P40 } -#endif + #endif } +#endif template __global__ void set_data(T* data, int size, int myRank) { @@ -121,8 +142,14 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, * registration, they are allocated and registered together in the test for * convenience. 
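+ * On ROCm the buffer is allocated uncached (hipExtMallocWithFlags with
+ * hipDeviceMallocUncached) because uncached memory is required for the
+ * Signal flags on MI200-class GPUs.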
*/ +#ifdef USE_ROCM + CUDACHECK(hipExtMallocWithFlags( + (void**)&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Signal), + hipDeviceMallocUncached)); +#else CUDACHECK( cudaMalloc(&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Signal))); +#endif CUDACHECK( cudaMemset(buffer, 0, 2 * data_size * sizeof(T) + sizeof(vllm::Signal))); CUDACHECK(cudaMalloc(&self_data_copy, data_size * sizeof(T))); @@ -311,13 +338,18 @@ int main(int argc, char** argv) { bool performance_test = true; cudaProfilerStart(); - // Uncomment to scan through different block size configs. - // for (int threads : {256, 512, 1024}) { - // for (int block_limit = 16; block_limit < 112; block_limit += 4) { - // run(myRank, nRanks, comm, threads, block_limit, 1024 * 1024, - // performance_test); - // } - // } +// Uncomment to scan through different block size configs. +// for (int threads : {256, 512, 1024}) { +// for (int block_limit = 16; block_limit < 112; block_limit += 4) { +// run(myRank, nRanks, comm, threads, block_limit, 1024 * 1024, +// performance_test); +// } +// } +#ifdef USE_ROCM + const int block_limit = 16; +#else + const int block_limit = 36; +#endif // Scan through different sizes to test performance. for (int sz = 512; sz <= (8 << 20); sz *= 2) { run(myRank, nRanks, comm, 512, 36, sz + 8 * 47, performance_test); @@ -326,4 +358,4 @@ int main(int argc, char** argv) { cudaProfilerStop(); MPICHECK(MPI_Finalize()); return EXIT_SUCCESS; -} +} \ No newline at end of file diff --git a/csrc/ops.h b/csrc/ops.h index 77d1ab768d..a0985d3242 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -267,10 +267,10 @@ void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, const std::optional& has_initial_state, bool silu_activation, int64_t pad_slot_id); -#ifndef USE_ROCM using fptr_t = int64_t; fptr_t init_custom_ar(const std::vector& fake_ipc_ptrs, - torch::Tensor& rank_data, int64_t rank, bool full_nvlink); + torch::Tensor& rank_data, int64_t rank, + bool fully_connected); void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, fptr_t reg_buffer, int64_t reg_buffer_sz_bytes); void dispose(fptr_t _fa); @@ -281,4 +281,7 @@ get_graph_buffer_ipc_meta(fptr_t _fa); void register_graph_buffers(fptr_t _fa, const std::vector>& handles, const std::vector>& offsets); -#endif +std::tuple allocate_shared_buffer_and_handle( + int64_t size); +int64_t open_mem_handle(torch::Tensor& mem_handle); +void free_shared_buffer(int64_t buffer); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index b0a23a3693..feb3882c4d 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -614,12 +614,11 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { &get_max_shared_memory_per_block_device_attribute); } -#ifndef USE_ROCM TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { // Custom all-reduce kernels custom_ar.def( "init_custom_ar(int[] ipc_tensors, Tensor rank_data, " - "int rank, bool full_nvlink) -> int"); + "int rank, bool fully_connected) -> int"); custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar); custom_ar.def( "all_reduce(int fa, Tensor inp, Tensor! 
out, int reg_buffer, " @@ -632,7 +631,13 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { custom_ar.def("register_buffer", ®ister_buffer); custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta); custom_ar.def("register_graph_buffers", ®ister_graph_buffers); + + custom_ar.def("allocate_shared_buffer_and_handle", + &allocate_shared_buffer_and_handle); + custom_ar.def("open_mem_handle(Tensor mem_handle) -> int", &open_mem_handle); + custom_ar.impl("open_mem_handle", torch::kCPU, &open_mem_handle); + + custom_ar.def("free_shared_buffer", &free_shared_buffer); } -#endif REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py index bfa7d06c4d..a7ba45c9e5 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -106,7 +106,7 @@ def eager_allreduce( # communicate independently num_communication = rank // tp_size + 1 sz = 1024 - fa = get_tp_group().ca_comm + fa = get_tp_group().device_communicator.ca_comm inp = torch.ones(sz, dtype=torch.float32, device=device) out = inp for _ in range(num_communication): diff --git a/tests/utils.py b/tests/utils.py index 8915453ebd..69c96d3f06 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -612,7 +612,16 @@ def multi_process_parallel( # as compared to multiprocessing. # NOTE: We need to set working_dir for distributed tests, # otherwise we may get import errors on ray workers - ray.init(runtime_env={"working_dir": VLLM_PATH}) + # NOTE: Force ray not to use gitignore file as excluding, otherwise + # it will not move .so files to working dir. + # So we have to manually add some of large directories + os.environ["RAY_RUNTIME_ENV_IGNORE_GITIGNORE"] = "1" + ray.init( + runtime_env={ + "working_dir": VLLM_PATH, + "excludes": + ["build", ".git", "cmake-build-*", "shellcheck", "dist"] + }) distributed_init_port = get_open_port() refs = [] diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 2ffcef414c..2aa99ca256 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1337,9 +1337,9 @@ def get_max_shared_memory_per_block_device_attribute(device: int) -> int: # custom ar def init_custom_ar(ipc_tensors: list[torch.Tensor], rank_data: torch.Tensor, - rank: int, full_nvlink: bool) -> int: + rank: int, fully_connected: bool) -> int: return torch.ops._C_custom_ar.init_custom_ar(ipc_tensors, rank_data, rank, - full_nvlink) + fully_connected) def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor, reg_buffer: int, @@ -1369,6 +1369,18 @@ def register_graph_buffers(fa: int, handles: list[list[int]], torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets) +def allocate_shared_buffer_and_handle(size: int) -> tuple[int, torch.Tensor]: + return torch.ops._C_custom_ar.allocate_shared_buffer_and_handle(size) + + +def open_mem_handle(mem_handle: torch.Tensor): + return torch.ops._C_custom_ar.open_mem_handle(mem_handle) + + +def free_shared_buffer(ptr: int) -> None: + torch.ops._C_custom_ar.free_shared_buffer(ptr) + + def get_flash_mla_metadata( cache_seqlens: torch.Tensor, num_heads_per_head_k: int, diff --git a/vllm/config.py b/vllm/config.py index 1dd9359199..84b9836ef5 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1606,11 +1606,13 @@ class ParallelConfig: if self.use_ray: from vllm.executor import ray_utils ray_utils.assert_ray_available() - if current_platform.is_rocm(): + device_capability = current_platform.get_device_capability() + if 
(current_platform.is_rocm() and device_capability is not None + and device_capability < (9, 4)): self.disable_custom_all_reduce = True logger.info( "Disabled the custom all-reduce kernel because it is not " - "supported on AMD GPUs.") + "supported on AMD GPUs older than MI300X.") if self.ray_workers_use_nsight and not self.use_ray: raise ValueError("Unable to use nsight profiling unless workers " "run with Ray.") diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 90f7f2d0f9..45fc2a7118 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -import ctypes from contextlib import contextmanager from typing import List, Optional, Union @@ -10,7 +9,6 @@ from torch.distributed import ProcessGroup import vllm.envs as envs from vllm import _custom_ops as ops -from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary from vllm.distributed.device_communicators.custom_all_reduce_utils import ( gpu_p2p_access_check) from vllm.distributed.parallel_state import in_the_same_node_as @@ -22,7 +20,7 @@ try: ops.meta_size() custom_ar = True except Exception: - # For AMD GPUs and CPUs + # For CPUs custom_ar = False logger = init_logger(__name__) @@ -71,7 +69,9 @@ class CustomAllreduce: if not custom_ar: # disable because of missing custom allreduce library - # e.g. in a non-cuda environment + # e.g. in a non-GPU environment + logger.info("Custom allreduce is disabled because " + "of missing custom allreduce library") return self.group = group @@ -129,11 +129,10 @@ class CustomAllreduce: # test nvlink first, this will filter out most of the cases # where custom allreduce is not supported # this checks hardware and driver support for NVLink - assert current_platform.is_cuda() - from vllm.platforms.cuda import CudaPlatform - cuda_platform: CudaPlatform = current_platform - full_nvlink = cuda_platform.is_full_nvlink(physical_device_ids) - if world_size > 2 and not full_nvlink: + assert current_platform.is_cuda_alike() + fully_connected = current_platform.is_fully_connected( + physical_device_ids) + if world_size > 2 and not fully_connected: logger.warning( "Custom allreduce is disabled because it's not supported on" " more than two PCIe-only GPUs. To silence this warning, " @@ -142,7 +141,8 @@ class CustomAllreduce: # test P2P capability, this checks software/cudaruntime support # this is expensive to compute at the first time # then we cache the result - if not _can_p2p(rank, world_size): + # On AMD GPU, p2p is always enabled between XGMI connected GPUs + if not current_platform.is_rocm() and not _can_p2p(rank, world_size): logger.warning( "Custom allreduce is disabled because your platform lacks " "GPU P2P capability or P2P test failed. To silence this " @@ -154,7 +154,8 @@ class CustomAllreduce: # Meta data composes of two parts: meta data for synchronization and a # temporary buffer for storing intermediate allreduce results. self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + max_size, - group=group) + group=group, + uncached=True) # This is a pre-registered IPC buffer. 
In eager mode, input tensors # are first copied into this buffer before allreduce is performed self.buffer_ptrs = self.create_shared_buffer(max_size, group=group) @@ -169,46 +170,11 @@ class CustomAllreduce: self.max_size = max_size self.rank = rank self.world_size = world_size - self.full_nvlink = full_nvlink + self.fully_connected = fully_connected self._ptr = ops.init_custom_ar(self.meta_ptrs, self.rank_data, rank, - self.full_nvlink) + self.fully_connected) ops.register_buffer(self._ptr, self.buffer_ptrs) - @staticmethod - def create_shared_buffer( - size_in_bytes: int, - group: Optional[ProcessGroup] = None) -> List[int]: - """ - Creates a shared buffer and returns a list of pointers - representing the buffer on all processes in the group. - """ - lib = CudaRTLibrary() - pointer = lib.cudaMalloc(size_in_bytes) - handle = lib.cudaIpcGetMemHandle(pointer) - world_size = dist.get_world_size(group=group) - rank = dist.get_rank(group=group) - handles = [None] * world_size - dist.all_gather_object(handles, handle, group=group) - - pointers: List[int] = [] - for i, h in enumerate(handles): - if i == rank: - pointers.append(pointer.value) # type: ignore - else: - pointers.append( - lib.cudaIpcOpenMemHandle(h).value) # type: ignore - - return pointers - - @staticmethod - def free_shared_buffer(pointers: List[int], - group: Optional[ProcessGroup] = None, - rank: Optional[int] = None) -> None: - if rank is None: - rank = dist.get_rank(group=group) - lib = CudaRTLibrary() - lib.cudaFree(ctypes.c_void_p(pointers[rank])) - @contextmanager def capture(self): """ @@ -255,7 +221,7 @@ class CustomAllreduce: return False # for 4 or more non NVLink-capable GPUs, custom allreduce provides # little performance improvement over NCCL. - if self.world_size == 2 or self.full_nvlink: + if self.world_size == 2 or self.fully_connected: return inp_size < self.max_size return False @@ -306,3 +272,30 @@ class CustomAllreduce: def __del__(self): self.close() + + @staticmethod + def create_shared_buffer(size_in_bytes: int, + group: Optional[ProcessGroup] = None, + uncached: Optional[bool] = False) -> List[int]: + pointer, handle = ops.allocate_shared_buffer_and_handle(size_in_bytes) + + world_size = dist.get_world_size(group=group) + rank = dist.get_rank(group=group) + handles = [None] * world_size + dist.all_gather_object(handles, handle, group=group) + + pointers: List[int] = [] + for i, h in enumerate(handles): + if i == rank: + pointers.append(pointer) # type: ignore + else: + pointers.append(ops.open_mem_handle(h)) + return pointers + + @staticmethod + def free_shared_buffer(pointers: List[int], + group: Optional[ProcessGroup] = None, + rank: Optional[int] = 0) -> None: + if rank is None: + rank = dist.get_rank(group=group) + ops.free_shared_buffer(pointers[rank]) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index ca8a2d2640..28505fca10 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -101,7 +101,7 @@ class CudaPlatformBase(Platform): return True @classmethod - def is_full_nvlink(cls, device_ids: List[int]) -> bool: + def is_fully_connected(cls, device_ids: List[int]) -> bool: raise NotImplementedError @classmethod @@ -362,7 +362,7 @@ class NvmlCudaPlatform(CudaPlatformBase): @classmethod @with_nvml_context - def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool: + def is_fully_connected(cls, physical_device_ids: List[int]) -> bool: """ query if the set of gpus are fully connected by nvlink (1 hop) """ @@ -427,7 +427,7 @@ class NonNvmlCudaPlatform(CudaPlatformBase): 
return device_props.total_memory @classmethod - def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool: + def is_fully_connected(cls, physical_device_ids: List[int]) -> bool: logger.exception( "NVLink detection not possible, as context support was" " not found. Assuming no NVLink available.") diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index d196e24ac7..89b778c7b5 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -20,8 +20,9 @@ else: logger = init_logger(__name__) try: - from amdsmi import (amdsmi_get_gpu_asic_info, amdsmi_get_processor_handles, - amdsmi_init, amdsmi_shut_down) + from amdsmi import (AmdSmiException, amdsmi_get_gpu_asic_info, + amdsmi_get_processor_handles, amdsmi_init, + amdsmi_shut_down, amdsmi_topo_get_link_type) except ImportError as e: logger.warning("Failed to import from amdsmi with %r", e) @@ -135,10 +136,36 @@ class RocmPlatform(Platform): @classmethod @lru_cache(maxsize=8) - def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: + def get_device_capability(cls, + device_id: int = 0 + ) -> Optional[DeviceCapability]: major, minor = torch.cuda.get_device_capability(device_id) return DeviceCapability(major=major, minor=minor) + @staticmethod + @with_amdsmi_context + def is_fully_connected(physical_device_ids: List[int]) -> bool: + """ + Query if the set of gpus are fully connected by xgmi (1 hop) + """ + handles = [ + amdsmi_get_processor_handles()[i] for i in physical_device_ids + ] + for i, handle in enumerate(handles): + for j, peer_handle in enumerate(handles): + if i < j: + try: + link_type = amdsmi_topo_get_link_type( + handle, peer_handle) + # type is 2 for XGMI + if link_type["hops"] != 1 or link_type["type"] != 2: + return False + except AmdSmiException as error: + logger.error("AMD 1 hop XGMI detection failed.", + exc_info=error) + return False + return True + @classmethod @with_amdsmi_context @lru_cache(maxsize=8)
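
For reference, the Python half of the new IPC plumbing composes as sketched below. The sketch is illustrative only and not part of the patch: it assumes torch.distributed is already initialized on every rank, the helper name setup_custom_ar and the buffer sizes are placeholders, and the op names follow vllm._custom_ops as modified above.

import torch
import torch.distributed as dist

from vllm import _custom_ops as ops


def setup_custom_ar(rank: int,
                    max_size: int = 8 * 1024 * 1024,
                    fully_connected: bool = True) -> int:
    """Allocate and exchange the IPC metadata buffer, then create the
    custom allreduce handle (fptr_t)."""
    # Allocate the signal + scratch buffer on the local GPU (uncached on
    # ROCm) and get its IPC handle as a CPU uint8 tensor.
    pointer, handle = ops.allocate_shared_buffer_and_handle(
        ops.meta_size() + max_size)

    # Every rank shares its handle; peers map the remote allocation.
    world_size = dist.get_world_size()
    handles: list = [None] * world_size
    dist.all_gather_object(handles, handle)
    meta_ptrs = [
        pointer if i == rank else ops.open_mem_handle(h)
        for i, h in enumerate(handles)
    ]

    # Per-rank scratch used to stash peer pointers for CUDA graph capture;
    # the size here is illustrative.
    rank_data = torch.empty(8 * 1024 * 1024,
                            dtype=torch.uint8,
                            device=f"cuda:{rank}")
    return ops.init_custom_ar(meta_ptrs, rank_data, rank, fully_connected)


# Cleanup on shutdown frees only the locally owned allocation:
#   ops.free_shared_buffer(meta_ptrs[rank])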