diff --git a/csrc/launch_bounds_utils.h b/csrc/launch_bounds_utils.h
new file mode 100644
index 0000000000..d5a8969011
--- /dev/null
+++ b/csrc/launch_bounds_utils.h
@@ -0,0 +1,38 @@
+#pragma once
+
+#include <algorithm>
+#include <cuda_runtime.h>
+
+// maximum blocks per SM cap
+#ifndef VLLM_LAUNCH_BLOCKS_CAP
+  #define VLLM_LAUNCH_BLOCKS_CAP 4
+#endif
+
+// compile-time estimate of max threads per SM for launch bounds.
+#ifndef VLLM_MAX_THREADS_PER_SM
+  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 300
+    #define VLLM_MAX_THREADS_PER_SM 1536
+  #else
+    #define VLLM_MAX_THREADS_PER_SM 2048
+  #endif
+#endif
+
+// compute the number of blocks per SM to request in __launch_bounds__
+#define VLLM_BLOCKS_DIV(VAL) (VLLM_MAX_THREADS_PER_SM / (VAL))
+#define VLLM_CLAMP_BLOCKS_PER_SM(VAL) \
+  (((VAL) <= 0)                       \
+       ? 1                            \
+       : (((VAL) < VLLM_LAUNCH_BLOCKS_CAP) ? (VAL) : VLLM_LAUNCH_BLOCKS_CAP))
+#define VLLM_BLOCKS_PER_SM(BLOCK_THREADS) \
+  VLLM_CLAMP_BLOCKS_PER_SM(VLLM_BLOCKS_DIV(BLOCK_THREADS))
+
+// runtime helper to compute blocks/SM
+static inline int vllm_runtime_blocks_per_sm(int block_threads) {
+  int device = -1;
+  cudaGetDevice(&device);
+  int max_threads_per_sm = VLLM_MAX_THREADS_PER_SM;
+  cudaDeviceGetAttribute(&max_threads_per_sm,
+                         cudaDevAttrMaxThreadsPerMultiProcessor, device);
+  int blocks = (block_threads > 0) ? (max_threads_per_sm / block_threads) : 1;
+  return VLLM_CLAMP_BLOCKS_PER_SM(blocks);
+}
diff --git a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu
index 74fde23782..7539f836ec 100644
--- a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu
+++ b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu
@@ -26,6 +26,7 @@

 #include "dispatch_utils.h"
 #include "cuda_utils.h"
+#include "launch_bounds_utils.h"
 #include "nvfp4_utils.cuh"

 namespace vllm {
@@ -63,7 +64,7 @@ __inline__ __device__ PackedVec<Type> compute_silu_mul(PackedVec<Type>& vec,

 // Use UE4M3 by default.
 template <class Type, bool UE8M0_SF = false>
-__global__ void __launch_bounds__(1024, 4)
+__global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
     silu_mul_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
                              float const* SFScale, uint32_t* out,
                              uint32_t* SFout) {
@@ -131,7 +132,8 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output,  // [..., d]
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
   dim3 block(std::min(int(n / ELTS_PER_THREAD), 1024));
-  int const numBlocksPerSM = 2048 / block.x;
+  int const numBlocksPerSM =
+      vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
   dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));

   VLLM_DISPATCH_HALF_TYPES(
diff --git a/csrc/quantization/fp4/nvfp4_experts_quant.cu b/csrc/quantization/fp4/nvfp4_experts_quant.cu
index ce3ba2c19b..6d385e0dd9 100644
--- a/csrc/quantization/fp4/nvfp4_experts_quant.cu
+++ b/csrc/quantization/fp4/nvfp4_experts_quant.cu
@@ -26,12 +26,13 @@

 #include "dispatch_utils.h"
 #include "nvfp4_utils.cuh"
+#include "launch_bounds_utils.h"

 namespace vllm {

 // Use UE4M3 by default.
 template <class Type, bool UE8M0_SF = false>
-__global__ void __launch_bounds__(512, 4)
+__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
     cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
                     float const* SFScale, uint32_t* out, uint32_t* SFout,
                     uint32_t* input_offset_by_experts,
@@ -129,7 +130,7 @@ __global__ void __launch_bounds__(512, 4)

 // Kernel for LARGE_M_TOPK = true (large m_topk optimized version)
 template <class Type, bool UE8M0_SF = false>
-__global__ void __launch_bounds__(1024, 4)
+__global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
     cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
                     float const* SFScale, uint32_t* out, uint32_t* SFout,
                     uint32_t* input_offset_by_experts,
@@ -233,8 +234,9 @@ void quant_impl(void* output, void* output_scale, void* input,
   int const workSizePerRow = k / ELTS_PER_THREAD;
   int const totalWorkSize = m_topk * workSizePerRow;
   dim3 block(std::min(workSizePerRow, 512));
-  // Get number of blocks per SM (assume we can fully utilize the SM).
-  int const numBlocksPerSM = 2048 / block.x;
+  // Get number of blocks per SM
+  int const numBlocksPerSM =
+      vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
   dim3 grid(std::min(static_cast<int>((totalWorkSize + block.x - 1) / block.x),
                      multiProcessorCount * numBlocksPerSM));
   while (grid.x <= multiProcessorCount && block.x > 64) {
diff --git a/csrc/quantization/fp4/nvfp4_quant_kernels.cu b/csrc/quantization/fp4/nvfp4_quant_kernels.cu
index 0c1b9ef066..5575ee8e41 100644
--- a/csrc/quantization/fp4/nvfp4_quant_kernels.cu
+++ b/csrc/quantization/fp4/nvfp4_quant_kernels.cu
@@ -26,13 +26,14 @@

 #include "dispatch_utils.h"
 #include "cuda_utils.h"
+#include "launch_bounds_utils.h"
 #include "nvfp4_utils.cuh"

 namespace vllm {

 // Use UE4M3 by default.
 template <class Type, bool UE8M0_SF = false>
-__global__ void __launch_bounds__(512, 4)
+__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
     cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
                     float const* SFScale, uint32_t* out, uint32_t* SFout) {
   using PackedVec = PackedVec<Type>;
@@ -75,8 +76,9 @@ void invokeFP4Quantization(int m, int n, T const* input, float const* SFScale,
   // Grid, Block size.
   // Each thread converts 8 values.
   dim3 block(std::min(int(n / ELTS_PER_THREAD), 512));
-  // Get number of blocks per SM (assume we can fully utilize the SM).
-  int const numBlocksPerSM = 2048 / block.x;
+  // Get number of blocks per SM
+  int const numBlocksPerSM =
+      vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
   dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));

   // Launch the cvt kernel.
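
Usage sketch (illustrative, not part of the patch): a minimal, self-contained translation unit showing how the two new helpers are meant to be combined. The kernel demo_scale, DEMO_BLOCK_THREADS, and the main() driver are hypothetical; only launch_bounds_utils.h, VLLM_BLOCKS_PER_SM, and vllm_runtime_blocks_per_sm come from the patch above. The compile-time macro constrains the occupancy request in __launch_bounds__, while the runtime helper queries the device's actual threads-per-SM limit before sizing the grid.

#include <algorithm>
#include <cstdio>
#include <cuda_runtime.h>

#include "launch_bounds_utils.h"

#define DEMO_BLOCK_THREADS 512  // hypothetical block size for this sketch

// Compile time: request at most VLLM_BLOCKS_PER_SM(512) resident blocks per
// SM (min(2048 / 512, VLLM_LAUNCH_BLOCKS_CAP) = 4), so the compiler budgets
// registers for that occupancy instead of a hard-coded 4.
__global__ void __launch_bounds__(DEMO_BLOCK_THREADS,
                                  VLLM_BLOCKS_PER_SM(DEMO_BLOCK_THREADS))
    demo_scale(float* data, int n) {
  // Grid-stride loop: a grid capped at SMs * blocksPerSM still covers all n.
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    data[i] *= 2.0f;
  }
}

int main() {
  int device = 0, smCount = 0;
  cudaGetDevice(&device);
  cudaDeviceGetAttribute(&smCount, cudaDevAttrMultiProcessorCount, device);

  int const n = 1 << 20;
  float* d = nullptr;
  cudaMalloc(&d, n * sizeof(float));

  // Run time: derive blocks/SM from the device's reported limit rather than
  // assuming 2048 threads per SM, mirroring the changes in the .cu files.
  int const blocksPerSM = vllm_runtime_blocks_per_sm(DEMO_BLOCK_THREADS);
  int const needed = (n + DEMO_BLOCK_THREADS - 1) / DEMO_BLOCK_THREADS;
  int const grid = std::min(needed, smCount * blocksPerSM);

  demo_scale<<<grid, DEMO_BLOCK_THREADS>>>(d, n);
  cudaDeviceSynchronize();
  std::printf("%d blocks launched (%d per SM, %d SMs)\n", grid, blocksPerSM,
              smCount);
  cudaFree(d);
  return 0;
}

The effect is most visible for the 1024-thread kernels: the old __launch_bounds__(1024, 4) asked for 4096 resident threads per SM, more than any current part can schedule, whereas VLLM_BLOCKS_PER_SM(1024) evaluates to 2 on a 2048-thread/SM device, and vllm_runtime_blocks_per_sm(1024) returns 1 on devices that report 1536 threads per SM.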