[Build/CI] Fix CUDA 11.8 build (#17679)

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Tyler Michael Smith <tysmith@redhat.com>
Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
Tyler Michael Smith
2025-05-22 15:13:54 -04:00
committed by GitHub
parent f8d2cc5f55
commit 6e588da0f4
9 changed files with 78 additions and 15 deletions

View File

@ -30,7 +30,11 @@ set(ignoreMe "${VLLM_PYTHON_PATH}")
set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12")
# Supported NVIDIA architectures.
set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL)
set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0")
else()
set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0")
endif()
# Supported AMD GPU architectures.
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201")

View File

@ -28,4 +28,6 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
torch::Tensor num_tokens_post_pad, int64_t top_k,
int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
int64_t BLOCK_SIZE_K, int64_t bit);
#endif
#endif
bool moe_permute_unpermute_supported();

View File

@ -5,6 +5,9 @@
#include "permute_unpermute_kernels/dispatch.h"
#include "core/registration.h"
// moe_permute kernels require at least CUDA 12.0
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)
void moe_permute(
const torch::Tensor& input, // [n_token, hidden]
const torch::Tensor& topk_weights, //[n_token, topk]
@ -127,7 +130,45 @@ void moe_unpermute(
});
}
#else
void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights,
torch::Tensor& topk_ids,
const torch::Tensor& token_expert_indicies,
const std::optional<torch::Tensor>& expert_map,
int64_t n_expert, int64_t n_local_expert, int64_t topk,
const std::optional<int64_t>& align_block_size,
torch::Tensor& permuted_input,
torch::Tensor& expert_first_token_offset,
torch::Tensor& src_row_id2dst_row_id_map,
torch::Tensor& m_indices) {
TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0");
}
void moe_unpermute(const torch::Tensor& input,
const torch::Tensor& topk_weights, torch::Tensor& topk_ids,
const torch::Tensor& token_expert_indicies,
const std::optional<torch::Tensor>& expert_map,
int64_t n_expert, int64_t n_local_expert, int64_t topk,
const std::optional<int64_t>& align_block_size,
torch::Tensor& permuted_input,
torch::Tensor& expert_first_token_offset,
torch::Tensor& src_row_id2dst_row_id_map,
torch::Tensor& m_indices) {
TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0");
}
#endif
bool moe_permute_unpermute_supported() {
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)
return true;
#else
return false;
#endif
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("moe_permute", &moe_permute);
m.impl("moe_unpermute", &moe_unpermute);
}
}

View File

@ -1,6 +1,9 @@
#include "moe_permute_unpermute_kernel.h"
// moe_permute kernels require at least CUDA 12.0
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)
// CubKeyValueSorter definition begin
CubKeyValueSorter::CubKeyValueSorter()
: num_experts_(0), num_bits_(sizeof(int) * 8) {}
@ -131,9 +134,6 @@ __global__ void preprocessTopkIdKernel(int* topk_id_ptr, int size,
int num_experts) {
auto tidx = threadIdx.x;
auto bidx = blockIdx.x;
auto lidx = tidx & 31;
auto widx = tidx >> 5;
auto warp_count = (blockDim.x + 31) >> 5;
auto offset = bidx * blockDim.x;
auto bound = min(offset + blockDim.x, size);
extern __shared__ int smem_expert_map[];
@ -226,4 +226,6 @@ void getMIndices(int64_t* expert_first_token_offset,
expert_first_token_offset, align_expert_first_token_offset, m_indices,
num_local_expert, align_block_size);
}
}
}
#endif

View File

@ -77,7 +77,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"Tensor topk_ids,Tensor src_row_id2dst_row_id_map, Tensor "
"expert_first_token_offset, int n_expert, int n_local_expert,int "
"topk, Tensor! hidden_states)->()");
// conditionally compiled so impl registration is in source file
m.def("moe_permute_unpermute_supported() -> bool");
m.impl("moe_permute_unpermute_supported", &moe_permute_unpermute_supported);
#endif
}

View File

@ -123,7 +123,7 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
}
bool cutlass_group_gemm_supported(int64_t cuda_device_capability) {
// CUTLASS groped FP8 kernels need at least CUDA 12.3
// CUTLASS grouped FP8 kernels need at least CUDA 12.3
// and SM90 (Hopper)
#if defined CUDA_VERSION

View File

@ -263,8 +263,11 @@ if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0 10.0+PTX'; \
else \
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX'; \
fi && \
export FLASHINFER_ENABLE_AOT=1; \
fi; \
CUDA_MAJOR="${CUDA_VERSION%%.*}"; \
if [ "$CUDA_MAJOR" -lt 12 ]; then \
export FLASHINFER_ENABLE_SM90=0; \
fi; \
uv pip install --system --no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@21ea1d2545f74782b91eb8c08fd503ac4c0743fc" ; \
fi
COPY examples examples
@ -275,7 +278,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
. /etc/environment && \
uv pip list
# Although we build Flashinfer with AOT mode, there's still
# Even when we build Flashinfer with AOT mode, there's still
# some issues w.r.t. JIT compilation. Therefore we need to
# install build dependencies for JIT compilation.
# TODO: Remove this once FlashInfer AOT wheel is fixed
@ -303,8 +306,11 @@ RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system --no-build-isolation "git+https://github.com/state-spaces/mamba@v2.2.4"
# install development dependencies (for testing)
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/dev.txt
RUN --mount=type=cache,target=/root/.cache/uv \
CUDA_MAJOR="${CUDA_VERSION%%.*}"; \
if [ "$CUDA_MAJOR" -ge 12 ]; then \
uv pip install --system -r requirements/dev.txt; \
fi
# install development dependencies (for testing)
RUN --mount=type=cache,target=/root/.cache/uv \

View File

@ -13,7 +13,7 @@ import torch
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.layer import determine_expert_map
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
moe_permute, moe_unpermute)
moe_permute, moe_permute_unpermute_supported, moe_unpermute)
from vllm.platforms import current_platform
NUM_EXPERTS = [16, 64]
@ -167,6 +167,8 @@ def torch_unpermute(permuted_hidden_states: torch.Tensor,
def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
n_expert: int, ep_size: int, dtype: torch.dtype,
align_block_size: Optional[int]):
if not moe_permute_unpermute_supported():
pytest.skip("moe_permute_unpermute is not supported on this platform.")
fill_invalid_expert = 0
ep_rank = np.random.randint(0, ep_size)
expert_map = None

View File

@ -182,3 +182,7 @@ def moe_unpermute(
expert_first_token_offset, n_expert,
n_local_expert, topk, hidden_states)
return hidden_states
def moe_permute_unpermute_supported():
return torch.ops._moe_C.moe_permute_unpermute_supported()