[Compile] Fix Compile Warning for Ignoring MIN_BLOCK_PER_SM (#25193)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@ -26,6 +26,7 @@
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
#include "cuda_utils.h"
|
||||
#include "launch_bounds_utils.h"
|
||||
#include "nvfp4_utils.cuh"
|
||||
|
||||
namespace vllm {
|
||||
@ -63,7 +64,7 @@ __inline__ __device__ PackedVec<Type> compute_silu_mul(PackedVec<Type>& vec,
|
||||
|
||||
// Use UE4M3 by default.
|
||||
template <class Type, bool UE8M0_SF = false>
|
||||
__global__ void __launch_bounds__(1024, 4)
|
||||
__global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
|
||||
silu_mul_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
|
||||
float const* SFScale, uint32_t* out,
|
||||
uint32_t* SFout) {
|
||||
@ -131,7 +132,8 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
|
||||
dim3 block(std::min(int(n / ELTS_PER_THREAD), 1024));
|
||||
int const numBlocksPerSM = 2048 / block.x;
|
||||
int const numBlocksPerSM =
|
||||
vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
|
||||
dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
|
||||
|
||||
VLLM_DISPATCH_HALF_TYPES(
|
||||
|
||||
@ -26,12 +26,13 @@
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
#include "nvfp4_utils.cuh"
|
||||
#include "launch_bounds_utils.h"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Use UE4M3 by default.
|
||||
template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
|
||||
__global__ void __launch_bounds__(512, 4)
|
||||
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
||||
cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
|
||||
float const* SFScale, uint32_t* out, uint32_t* SFout,
|
||||
uint32_t* input_offset_by_experts,
|
||||
@ -129,7 +130,7 @@ __global__ void __launch_bounds__(512, 4)
|
||||
|
||||
// Kernel for LARGE_M_TOPK = true (large m_topk optimized version)
|
||||
template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
|
||||
__global__ void __launch_bounds__(1024, 4)
|
||||
__global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
|
||||
cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
|
||||
float const* SFScale, uint32_t* out, uint32_t* SFout,
|
||||
uint32_t* input_offset_by_experts,
|
||||
@ -233,8 +234,9 @@ void quant_impl(void* output, void* output_scale, void* input,
|
||||
int const workSizePerRow = k / ELTS_PER_THREAD;
|
||||
int const totalWorkSize = m_topk * workSizePerRow;
|
||||
dim3 block(std::min(workSizePerRow, 512));
|
||||
// Get number of blocks per SM (assume we can fully utilize the SM).
|
||||
int const numBlocksPerSM = 2048 / block.x;
|
||||
// Get number of blocks per SM
|
||||
int const numBlocksPerSM =
|
||||
vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
|
||||
dim3 grid(std::min(static_cast<int>((totalWorkSize + block.x - 1) / block.x),
|
||||
multiProcessorCount * numBlocksPerSM));
|
||||
while (grid.x <= multiProcessorCount && block.x > 64) {
|
||||
|
||||
@ -26,13 +26,14 @@
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
#include "cuda_utils.h"
|
||||
#include "launch_bounds_utils.h"
|
||||
#include "nvfp4_utils.cuh"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Use UE4M3 by default.
|
||||
template <class Type, bool UE8M0_SF = false>
|
||||
__global__ void __launch_bounds__(512, 4)
|
||||
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
||||
cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
|
||||
float const* SFScale, uint32_t* out, uint32_t* SFout) {
|
||||
using PackedVec = PackedVec<Type>;
|
||||
@ -75,8 +76,9 @@ void invokeFP4Quantization(int m, int n, T const* input, float const* SFScale,
|
||||
// Grid, Block size.
|
||||
// Each thread converts 8 values.
|
||||
dim3 block(std::min(int(n / ELTS_PER_THREAD), 512));
|
||||
// Get number of blocks per SM (assume we can fully utilize the SM).
|
||||
int const numBlocksPerSM = 2048 / block.x;
|
||||
// Get number of blocks per SM
|
||||
int const numBlocksPerSM =
|
||||
vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
|
||||
dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
|
||||
|
||||
// Launch the cvt kernel.
|
||||
|
||||
Reference in New Issue
Block a user