Compare commits
10 Commits
v0.9.0
...
lwilkinson
| Author | SHA1 | Date | |
|---|---|---|---|
| d3b51c9bba | |||
| 5d3b0bc39c | |||
| e7f00f3368 | |||
| ef6cb1295c | |||
| 0e8f560f48 | |||
| 09acadf893 | |||
| bf75539f9e | |||
| 7ebe64f94b | |||
| 4e00778a60 | |||
| b478b18f07 |
485
CMakeLists.txt
485
CMakeLists.txt
@ -308,333 +308,142 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
|
||||
# are not supported by Machete yet.
|
||||
# 9.0 for latest bf16 atomicAdd PTX
|
||||
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;9.0+PTX" "${CUDA_ARCHS}")
|
||||
if (MARLIN_ARCHS)
|
||||
# Marlin kernels: generate and build for supported architectures
|
||||
optional_cuda_sources(
|
||||
NAME Marlin
|
||||
ARCHS "8.0;9.0+PTX"
|
||||
GEN_SCRIPT "${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/gptq_marlin/generate_kernels.py"
|
||||
GEN_GLOB "csrc/quantization/gptq_marlin/kernel_*.cu"
|
||||
SRCS
|
||||
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
|
||||
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
|
||||
"csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"
|
||||
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
|
||||
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
|
||||
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu"
|
||||
)
|
||||
|
||||
#
|
||||
# For the Marlin kernels we automatically generate sources for various
|
||||
# preselected input type pairs and schedules.
|
||||
# Generate sources:
|
||||
set(MARLIN_GEN_SCRIPT
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/gptq_marlin/generate_kernels.py)
|
||||
file(MD5 ${MARLIN_GEN_SCRIPT} MARLIN_GEN_SCRIPT_HASH)
|
||||
# AllSpark kernels
|
||||
optional_cuda_sources(
|
||||
NAME AllSpark
|
||||
ARCHS "8.0;8.6;8.7;8.9"
|
||||
SRCS
|
||||
"csrc/quantization/gptq_allspark/allspark_repack.cu"
|
||||
"csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu"
|
||||
)
|
||||
|
||||
message(STATUS "Marlin generation script hash: ${MARLIN_GEN_SCRIPT_HASH}")
|
||||
message(STATUS "Last run Marlin generate script hash: $CACHE{MARLIN_GEN_SCRIPT_HASH}")
|
||||
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require CUDA 12.0 or later
|
||||
optional_cuda_sources(
|
||||
NAME scaled_mm_c3x_sm90
|
||||
MIN_VERSION 12.0
|
||||
ARCHS "9.0a"
|
||||
SRCS
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu"
|
||||
FLAGS "-DENABLE_SCALED_MM_SM90=1"
|
||||
VERSION_MSG
|
||||
"Not building scaled_mm_c3x_sm90: CUDA Compiler version is not >= 12.0."
|
||||
"Please upgrade to CUDA 12.0 or later to run FP8 quantized models on Hopper."
|
||||
)
|
||||
|
||||
if (NOT DEFINED CACHE{MARLIN_GEN_SCRIPT_HASH}
|
||||
OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MARLIN_GEN_SCRIPT_HASH})
|
||||
execute_process(
|
||||
COMMAND ${CMAKE_COMMAND} -E env
|
||||
PYTHONPATH=$PYTHONPATH
|
||||
${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT}
|
||||
RESULT_VARIABLE marlin_generation_result
|
||||
OUTPUT_VARIABLE marlin_generation_result
|
||||
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log
|
||||
ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log
|
||||
)
|
||||
|
||||
if (NOT marlin_generation_result EQUAL 0)
|
||||
message(FATAL_ERROR "Marlin generation failed."
|
||||
" Result: \"${marlin_generation_result}\""
|
||||
"\nCheck the log for details: "
|
||||
"${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log")
|
||||
else()
|
||||
set(MARLIN_GEN_SCRIPT_HASH ${MARLIN_GEN_SCRIPT_HASH}
|
||||
CACHE STRING "Last run Marlin generate script hash" FORCE)
|
||||
message(STATUS "Marlin generation completed successfully.")
|
||||
endif()
|
||||
else()
|
||||
message(STATUS "Marlin generation script has not changed, skipping generation.")
|
||||
endif()
|
||||
|
||||
file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/kernel_*.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_ARCHS}")
|
||||
|
||||
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
|
||||
|
||||
set(MARLIN_SRCS
|
||||
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
|
||||
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
|
||||
"csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"
|
||||
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
|
||||
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
|
||||
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_SRCS}"
|
||||
CUDA_ARCHS "${MARLIN_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${MARLIN_SRCS}")
|
||||
message(STATUS "Building Marlin kernels for archs: ${MARLIN_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building Marlin kernels as no compatible archs found"
|
||||
" in CUDA target architectures")
|
||||
endif()
|
||||
|
||||
# Only build AllSpark kernels if we are building for at least some compatible archs.
|
||||
cuda_archs_loose_intersection(ALLSPARK_ARCHS "8.0;8.6;8.7;8.9" "${CUDA_ARCHS}")
|
||||
if (ALLSPARK_ARCHS)
|
||||
set(ALLSPARK_SRCS
|
||||
"csrc/quantization/gptq_allspark/allspark_repack.cu"
|
||||
"csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${ALLSPARK_SRCS}"
|
||||
CUDA_ARCHS "${ALLSPARK_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${ALLSPARK_SRCS}")
|
||||
message(STATUS "Building AllSpark kernels for archs: ${ALLSPARK_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building AllSpark kernels as no compatible archs found"
|
||||
" in CUDA target architectures")
|
||||
endif()
|
||||
|
||||
|
||||
set(SCALED_MM_3X_ARCHS)
|
||||
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
|
||||
# CUDA 12.0 or later
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM90=1")
|
||||
# Let scaled_mm_c2x know it doesn't need to build these arches
|
||||
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
|
||||
message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is "
|
||||
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
|
||||
"later if you intend on running FP8 quantized models on "
|
||||
"Hopper.")
|
||||
else()
|
||||
message(STATUS "Not building scaled_mm_c3x_sm90 as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# The cutlass_scaled_mm kernels for Blackwell (c3x, i.e. CUTLASS 3.x) require
|
||||
# CUDA 12.8 or later
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;12.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
|
||||
set(SRCS
|
||||
# The cutlass_scaled_mm kernels for Blackwell (c3x, i.e. CUTLASS 3.x) require CUDA 12.8 or later
|
||||
optional_cuda_sources(
|
||||
NAME scaled_mm_c3x_sm100
|
||||
MIN_VERSION 12.8
|
||||
ARCHS "10.0a;10.1a;12.0a"
|
||||
SRCS
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu"
|
||||
)
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM100=1")
|
||||
# Let scaled_mm_c2x know it doesn't need to build these arches
|
||||
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
|
||||
message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is "
|
||||
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
|
||||
"later if you intend on running FP8 quantized models on "
|
||||
"Blackwell.")
|
||||
else()
|
||||
message(STATUS "Not building scaled_mm_c3x_100 as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
FLAGS "-DENABLE_SCALED_MM_SM100=1"
|
||||
VERSION_MSG
|
||||
"Not building scaled_mm_c3x_sm100: CUDA Compiler version is not >= 12.8."
|
||||
"Please upgrade to CUDA 12.8 or later to run FP8 quantized models on Blackwell."
|
||||
)
|
||||
|
||||
#
|
||||
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
|
||||
# kernels for the remaining archs that are not already built for 3x.
|
||||
# (Build 8.9 for FP8)
|
||||
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
|
||||
"7.5;8.0;8.9+PTX" "${CUDA_ARCHS}")
|
||||
# subtract out the archs that are already built for 3x
|
||||
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
|
||||
if (SCALED_MM_2X_ARCHS)
|
||||
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_2X_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C2X=1")
|
||||
message(STATUS "Building scaled_mm_c2x for archs: ${SCALED_MM_2X_ARCHS}")
|
||||
else()
|
||||
if (SCALED_MM_3X_ARCHS)
|
||||
message(STATUS "Not building scaled_mm_c2x as all archs are already built"
|
||||
" for and covered by scaled_mm_c3x")
|
||||
else()
|
||||
message(STATUS "Not building scaled_mm_c2x as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
# For the cutlass_scaled_mm kernels for Pre-hopper (c2x, i.e. CUTLASS 2.x)
|
||||
optional_cuda_sources(
|
||||
NAME scaled_mm_c2x
|
||||
ARCHS "7.5;8.0;8.9+PTX"
|
||||
SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu"
|
||||
FLAGS "-DENABLE_SCALED_MM_C2X=1"
|
||||
)
|
||||
|
||||
#
|
||||
# 2:4 Sparse Kernels
|
||||
|
||||
# The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
|
||||
# require CUDA 12.2 or later (and only work on Hopper).
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_ARCHS)
|
||||
set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SPARSE_SCALED_MM_C3X=1")
|
||||
message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building sparse_scaled_mm_c3x kernels as CUDA Compiler version is "
|
||||
"not >= 12.2, we recommend upgrading to CUDA 12.2 or later "
|
||||
"if you intend on running FP8 sparse quantized models on Hopper.")
|
||||
else()
|
||||
message(STATUS "Not building sparse_scaled_mm_c3x as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
optional_cuda_sources(
|
||||
NAME sparse_scaled_mm_c3x
|
||||
MIN_VERSION 12.2
|
||||
ARCHS "9.0a;"
|
||||
SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu"
|
||||
FLAGS "-DENABLE_SPARSE_SCALED_MM_C3X=1"
|
||||
VERSION_MSG
|
||||
"Not building sparse_scaled_mm_c3x: CUDA Compiler version is not >= 12.2."
|
||||
"Please upgrade to CUDA 12.2 or later to run FP8 sparse quantized models on Hopper."
|
||||
)
|
||||
|
||||
# FP4 Archs and flags
|
||||
cuda_archs_loose_intersection(FP4_ARCHS "10.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND FP4_ARCHS)
|
||||
set(SRCS
|
||||
optional_cuda_sources(
|
||||
NAME NVFP4
|
||||
MIN_VERSION 12.8
|
||||
ARCHS "10.0a"
|
||||
SRCS
|
||||
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
|
||||
"csrc/quantization/fp4/nvfp4_experts_quant.cu"
|
||||
"csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
|
||||
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${FP4_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4=1")
|
||||
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building NVFP4 as no compatible archs were found.")
|
||||
# clear FP4_ARCHS
|
||||
set(FP4_ARCHS)
|
||||
endif()
|
||||
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu"
|
||||
FLAGS "-DENABLE_NVFP4=1"
|
||||
)
|
||||
|
||||
# CUTLASS MLA Archs and flags
|
||||
cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND MLA_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/attention/mla/cutlass_mla_kernels.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${MLA_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MLA=1")
|
||||
# Add MLA-specific include directories only to MLA source files
|
||||
set_source_files_properties(${SRCS}
|
||||
PROPERTIES INCLUDE_DIRECTORIES "${CUTLASS_DIR}/examples/77_blackwell_fmha;${CUTLASS_DIR}/examples/common")
|
||||
message(STATUS "Building CUTLASS MLA for archs: ${MLA_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building CUTLASS MLA as no compatible archs were found.")
|
||||
# clear MLA_ARCHS
|
||||
set(MLA_ARCHS)
|
||||
endif()
|
||||
optional_cuda_sources(
|
||||
NAME CUTLASS_MLA
|
||||
MIN_VERSION 12.8
|
||||
ARCHS "10.0a"
|
||||
SRCS "csrc/attention/mla/cutlass_mla_kernels.cu"
|
||||
FLAGS "-DENABLE_CUTLASS_MLA=1"
|
||||
)
|
||||
# Add MLA-specific include directories only to MLA source files
|
||||
set_source_files_properties(
|
||||
"csrc/attention/mla/cutlass_mla_kernels.cu"
|
||||
PROPERTIES INCLUDE_DIRECTORIES "${CUTLASS_DIR}/examples/77_blackwell_fmha;${CUTLASS_DIR}/examples/common"
|
||||
)
|
||||
|
||||
# CUTLASS MoE kernels
|
||||
|
||||
# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works
|
||||
# on Hopper). get_cutlass_moe_mm_data should only be compiled if it's possible
|
||||
# to compile MoE kernels that use its output.
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
|
||||
set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu"
|
||||
"csrc/quantization/cutlass_w8a8/moe/moe_data.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM90=1")
|
||||
message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
|
||||
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
|
||||
"if you intend on running FP8 quantized MoE models on Hopper.")
|
||||
else()
|
||||
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
optional_cuda_sources(
|
||||
NAME grouped_mm_c3x
|
||||
MIN_VERSION 12.3
|
||||
ARCHS "9.0a;10.0a"
|
||||
SRCS
|
||||
"csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu"
|
||||
"csrc/quantization/cutlass_w8a8/moe/moe_data.cu"
|
||||
FLAGS "-DENABLE_CUTLASS_MOE_SM90=1"
|
||||
VERSION_MSG
|
||||
"Not building grouped_mm_c3x kernels as CUDA Compiler is less than 12.3."
|
||||
"We recommend upgrading to CUDA 12.3 or later if you intend on running FP8 quantized MoE models on Hopper."
|
||||
)
|
||||
|
||||
#
|
||||
# Machete kernels
|
||||
|
||||
# The machete kernels only work on hopper and require CUDA 12.0 or later.
|
||||
# Only build Machete kernels if we are building for something compatible with sm90a
|
||||
# Machete kernels: generate and build for supported architectures
|
||||
cuda_archs_loose_intersection(MACHETE_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND MACHETE_ARCHS)
|
||||
#
|
||||
# For the Machete kernels we automatically generate sources for various
|
||||
# preselected input type pairs and schedules.
|
||||
# Generate sources:
|
||||
set(MACHETE_GEN_SCRIPT
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/machete/generate.py)
|
||||
file(MD5 ${MACHETE_GEN_SCRIPT} MACHETE_GEN_SCRIPT_HASH)
|
||||
|
||||
message(STATUS "Machete generation script hash: ${MACHETE_GEN_SCRIPT_HASH}")
|
||||
message(STATUS "Last run machete generate script hash: $CACHE{MACHETE_GEN_SCRIPT_HASH}")
|
||||
|
||||
if (NOT DEFINED CACHE{MACHETE_GEN_SCRIPT_HASH}
|
||||
OR NOT $CACHE{MACHETE_GEN_SCRIPT_HASH} STREQUAL ${MACHETE_GEN_SCRIPT_HASH})
|
||||
execute_process(
|
||||
COMMAND ${CMAKE_COMMAND} -E env
|
||||
PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH
|
||||
${Python_EXECUTABLE} ${MACHETE_GEN_SCRIPT}
|
||||
RESULT_VARIABLE machete_generation_result
|
||||
OUTPUT_VARIABLE machete_generation_output
|
||||
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log
|
||||
ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log
|
||||
)
|
||||
|
||||
if (NOT machete_generation_result EQUAL 0)
|
||||
message(FATAL_ERROR "Machete generation failed."
|
||||
" Result: \"${machete_generation_result}\""
|
||||
"\nCheck the log for details: "
|
||||
"${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log")
|
||||
else()
|
||||
set(MACHETE_GEN_SCRIPT_HASH ${MACHETE_GEN_SCRIPT_HASH}
|
||||
CACHE STRING "Last run machete generate script hash" FORCE)
|
||||
message(STATUS "Machete generation completed successfully.")
|
||||
endif()
|
||||
else()
|
||||
message(STATUS "Machete generation script has not changed, skipping generation.")
|
||||
endif()
|
||||
|
||||
# Add machete generated sources
|
||||
file(GLOB MACHETE_GEN_SOURCES "csrc/quantization/machete/generated/*.cu")
|
||||
list(APPEND VLLM_EXT_SRC ${MACHETE_GEN_SOURCES})
|
||||
|
||||
# forward compatible
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MACHETE_GEN_SOURCES}"
|
||||
CUDA_ARCHS "${MACHETE_ARCHS}")
|
||||
|
||||
list(APPEND VLLM_EXT_SRC
|
||||
csrc/quantization/machete/machete_pytorch.cu)
|
||||
|
||||
message(STATUS "Building Machete kernels for archs: ${MACHETE_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0
|
||||
AND MACHETE_ARCHS)
|
||||
message(STATUS "Not building Machete kernels as CUDA Compiler version is "
|
||||
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
|
||||
"later if you intend on running w4a16 quantized models on "
|
||||
"Hopper.")
|
||||
else()
|
||||
message(STATUS "Not building Machete kernels as no compatible archs "
|
||||
"found in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
optional_cuda_sources(
|
||||
NAME Machete
|
||||
MIN_VERSION 12.0
|
||||
ARCHS "${MACHETE_ARCHS}"
|
||||
GEN_SCRIPT "${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/machete/generate.py"
|
||||
GEN_PYTHONPATH_PREPEND
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/"
|
||||
GEN_GLOB "csrc/quantization/machete/generated/*.cu"
|
||||
SRCS "csrc/quantization/machete/machete_pytorch.cu"
|
||||
VERSION_MSG
|
||||
"Not building Machete kernels as CUDA Compiler version is less than 12.0."
|
||||
"We recommend upgrading to CUDA 12.0 or later to run w4a16 quantized models on Hopper."
|
||||
)
|
||||
# if CUDA endif
|
||||
endif()
|
||||
|
||||
@ -670,72 +479,20 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/moe_wna16.cu")
|
||||
endif()
|
||||
|
||||
# Apply gencode flags to base MOE extension sources
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${VLLM_MOE_EXT_SRC}"
|
||||
CUDA_ARCHS "${CUDA_ARCHS}")
|
||||
|
||||
## Marlin MOE kernels: generate and include for supported architectures
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
set(VLLM_MOE_WNA16_SRC
|
||||
"csrc/moe/moe_wna16.cu")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${VLLM_MOE_WNA16_SRC}"
|
||||
CUDA_ARCHS "${CUDA_ARCHS}")
|
||||
|
||||
list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}")
|
||||
# 9.0 for latest bf16 atomicAdd PTX
|
||||
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;9.0+PTX" "${CUDA_ARCHS}")
|
||||
if (MARLIN_MOE_ARCHS)
|
||||
|
||||
#
|
||||
# For the Marlin MOE kernels we automatically generate sources for various
|
||||
# preselected input type pairs and schedules.
|
||||
# Generate sources:
|
||||
set(MOE_MARLIN_GEN_SCRIPT
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/csrc/moe/marlin_moe_wna16/generate_kernels.py)
|
||||
file(MD5 ${MOE_MARLIN_GEN_SCRIPT} MOE_MARLIN_GEN_SCRIPT_HASH)
|
||||
|
||||
message(STATUS "Marlin MOE generation script hash: ${MOE_MARLIN_GEN_SCRIPT_HASH}")
|
||||
message(STATUS "Last run Marlin MOE generate script hash: $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH}")
|
||||
|
||||
if (NOT DEFINED CACHE{MOE_MARLIN_GEN_SCRIPT_HASH}
|
||||
OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH})
|
||||
execute_process(
|
||||
COMMAND ${CMAKE_COMMAND} -E env
|
||||
PYTHONPATH=$PYTHONPATH
|
||||
${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT}
|
||||
RESULT_VARIABLE moe_marlin_generation_result
|
||||
OUTPUT_VARIABLE moe_marlin_generation_output
|
||||
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log
|
||||
ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log
|
||||
)
|
||||
|
||||
if (NOT moe_marlin_generation_result EQUAL 0)
|
||||
message(FATAL_ERROR "Marlin MOE generation failed."
|
||||
" Result: \"${moe_marlin_generation_result}\""
|
||||
"\nCheck the log for details: "
|
||||
"${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log")
|
||||
else()
|
||||
set(MOE_MARLIN_GEN_SCRIPT_HASH ${MOE_MARLIN_GEN_SCRIPT_HASH}
|
||||
CACHE STRING "Last run Marlin MOE generate script hash" FORCE)
|
||||
message(STATUS "Marlin MOE generation completed successfully.")
|
||||
endif()
|
||||
else()
|
||||
message(STATUS "Marlin MOE generation script has not changed, skipping generation.")
|
||||
endif()
|
||||
|
||||
file(GLOB MOE_WNAA16_MARLIN_SRC "csrc/moe/marlin_moe_wna16/*.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MOE_WNAA16_MARLIN_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
|
||||
|
||||
list(APPEND VLLM_MOE_EXT_SRC ${MOE_WNAA16_MARLIN_SRC})
|
||||
|
||||
message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building Marlin MOE kernels as no compatible archs found"
|
||||
" in CUDA target architectures")
|
||||
endif()
|
||||
optional_cuda_sources(
|
||||
NAME "Marlin MOE"
|
||||
ARCHS "8.0;9.0+PTX"
|
||||
GEN_SCRIPT "${CMAKE_CURRENT_SOURCE_DIR}/csrc/moe/marlin_moe_wna16/generate_kernels.py"
|
||||
GEN_GLOB "csrc/moe/marlin_moe_wna16/*.cu"
|
||||
OUT_SRCS_VAR VLLM_MOE_EXT_SRC
|
||||
)
|
||||
endif()
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
|
||||
@ -39,6 +39,35 @@ function (run_python OUT EXPR ERR_MSG)
|
||||
set(${OUT} ${PYTHON_OUT} PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
# Generate CUDA sources via a Python script with caching based on script hash
|
||||
function(generate_cuda_sources NAME SCRIPT PYTHONPATH_PREPEND GLOB OUT_SRCS)
|
||||
string(TOUPPER "${NAME}" _UPPER_NAME)
|
||||
set(_CACHE_VAR "${_UPPER_NAME}_GEN_SCRIPT_HASH")
|
||||
file(MD5 "${SCRIPT}" _GEN_HASH)
|
||||
message(STATUS "${NAME} generation script hash: ${_GEN_HASH}")
|
||||
message(STATUS "Last run ${NAME} generation script hash: $CACHE{${_CACHE_VAR}}")
|
||||
if(NOT DEFINED CACHE{${_CACHE_VAR}} OR NOT $CACHE{${_CACHE_VAR}} STREQUAL "${_GEN_HASH}")
|
||||
message(STATUS "Running ${NAME} generation script: ${SCRIPT}, with PYTHONPATH prepend: ${PYTHONPATH_PREPEND}")
|
||||
execute_process(
|
||||
COMMAND ${CMAKE_COMMAND} -E env
|
||||
PYTHONPATH=${PYTHONPATH_PREPEND}:$ENV{PYTHONPATH}
|
||||
${Python_EXECUTABLE} "${SCRIPT}"
|
||||
RESULT_VARIABLE _GEN_RESULT
|
||||
OUTPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/${NAME}_generation.log"
|
||||
ERROR_FILE "${CMAKE_CURRENT_BINARY_DIR}/${NAME}_generation.log"
|
||||
)
|
||||
if(NOT _GEN_RESULT EQUAL 0)
|
||||
message(FATAL_ERROR "${NAME} generation failed. Result: \"${_GEN_RESULT}\"\nCheck the log for details: ${CMAKE_CURRENT_BINARY_DIR}/${NAME}_generation.log")
|
||||
else()
|
||||
set(${_CACHE_VAR} "${_GEN_HASH}" CACHE STRING "Last run ${NAME} generation script hash" FORCE)
|
||||
message(STATUS "${NAME} generation completed successfully.")
|
||||
endif()
|
||||
else()
|
||||
message(STATUS "${NAME} generation script has not changed, skipping generation.")
|
||||
endif()
|
||||
file(GLOB ${OUT_SRCS} "${GLOB}")
|
||||
endfunction()
|
||||
|
||||
# Run `EXPR` in python after importing `PKG`. Use the result of this to extend
|
||||
# `CMAKE_PREFIX_PATH` so the torch cmake configuration can be imported.
|
||||
macro (append_cmake_prefix_path PKG EXPR)
|
||||
@ -86,6 +115,87 @@ function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS)
|
||||
set(${OUT_SRCS} ${HIP_SRCS} PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
# Macro to conditionally include CUDA sources based on architecture and CUDA
|
||||
# compiler version, optionally generating sources via a Python script.
|
||||
# Usage:
|
||||
# optional_cuda_sources(
|
||||
# NAME <name>
|
||||
# [MIN_VERSION <version>]
|
||||
# ARCHS <arch1;arch2;...>
|
||||
# SRCS <source1> [<source2> ...]
|
||||
# [FLAGS <flag1> ...]
|
||||
# [VERSION_MSG <line1> [<line2> ...]]
|
||||
# [NO_ARCH_MSG <line1> [<line2> ...]]
|
||||
# [GEN_SCRIPT <path/to/generate_script.py>]
|
||||
# [GEN_PYTHONPATH_PREPEND <str_to_prepend_to_pythonpath>]
|
||||
# [GEN_GLOB <glob_pattern_for_generated_sources>]
|
||||
# This will check if `CMAKE_CUDA_COMPILER_VERSION` is greater than or equal
|
||||
# to `MIN_VERSION` and the `cuda_archs_loose_intersection` of `ARCHS` and
|
||||
# `CUDA_ARCHS` (taken from global scope) is not empty. If both these conditions
|
||||
# are met, it will:
|
||||
# 1) set the gencode flags of the sources in `SRCS` to be the
|
||||
# `cuda_archs_loose_intersection` of `ARCHS` and `CUDA_ARCHS`.
|
||||
# 2) append the sources in `SRCS` to the global `VLLM_EXT_SRC` variable.
|
||||
# 3) append the flags in `FLAGS` to the global `VLLM_GPU_FLAGS` variable.
|
||||
#
|
||||
# This will also run GEN_SCRIPT (if supplied and the hash of the script does not
|
||||
# match the latest in the cmake cache), with GEN_PYTHONPATH_PREPEND prepended to
|
||||
# to the PYTHONPATH when calling before globbing sources matching GEN_GLOB
|
||||
# and appending them alongside SRCS (with the gencodes set)
|
||||
macro(optional_cuda_sources)
|
||||
set(oneValueArgs NAME MIN_VERSION GEN_SCRIPT GEN_PYTHONPATH_PREPEND GEN_GLOB OUT_SRCS_VAR)
|
||||
set(multiValueArgs ARCHS SRCS FLAGS VERSION_MSG NO_ARCH_MSG)
|
||||
cmake_parse_arguments(OCS "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||
if(NOT OCS_NAME)
|
||||
message(FATAL_ERROR "optional_cuda_sources: NAME is required")
|
||||
endif()
|
||||
if(NOT OCS_ARCHS)
|
||||
message(FATAL_ERROR "optional_cuda_sources ${OCS_NAME}: ARCHS is required")
|
||||
endif()
|
||||
if(NOT OCS_SRCS AND NOT OCS_GEN_SCRIPT)
|
||||
message(FATAL_ERROR "optional_cuda_sources ${OCS_NAME}: either SRCS or GEN_SCRIPT must be provided")
|
||||
endif()
|
||||
if(NOT OCS_MIN_VERSION)
|
||||
set(OCS_MIN_VERSION "0.0")
|
||||
endif()
|
||||
cuda_archs_loose_intersection(_OCS_ARCHS "${OCS_ARCHS}" "${CUDA_ARCHS}")
|
||||
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL ${OCS_MIN_VERSION})
|
||||
if(_OCS_ARCHS)
|
||||
set(_OCS_SRCS ${OCS_SRCS})
|
||||
# Generate sources if a script is provided
|
||||
if(OCS_GEN_SCRIPT AND OCS_GEN_GLOB)
|
||||
generate_cuda_sources(
|
||||
${OCS_NAME} "${OCS_GEN_SCRIPT}" "${OCS_GEN_PYTHONPATH_PREPEND}" "${OCS_GEN_GLOB}" _OCS_GEN_SRCS)
|
||||
list(APPEND _OCS_SRCS ${_OCS_GEN_SRCS})
|
||||
endif()
|
||||
set_gencode_flags_for_srcs(SRCS "${_OCS_SRCS}" CUDA_ARCHS "${_OCS_ARCHS}")
|
||||
if(OCS_OUT_SRCS_VAR)
|
||||
list(APPEND ${OCS_OUT_SRCS_VAR} ${_OCS_SRCS})
|
||||
else()
|
||||
list(APPEND VLLM_EXT_SRC ${_OCS_SRCS})
|
||||
endif()
|
||||
if(OCS_FLAGS)
|
||||
list(APPEND VLLM_GPU_FLAGS ${OCS_FLAGS})
|
||||
endif()
|
||||
message(STATUS "Building ${OCS_NAME} for archs: ${_OCS_ARCHS}")
|
||||
else()
|
||||
if(OCS_NO_ARCH_MSG)
|
||||
list(JOIN OCS_NO_ARCH_MSG "\n " _OCS_NO_ARCH_JOINED)
|
||||
message(STATUS "${_OCS_NO_ARCH_JOINED}")
|
||||
else()
|
||||
message(STATUS "Not building ${OCS_NAME}: no compatible architectures found in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
else()
|
||||
if(OCS_VERSION_MSG)
|
||||
list(JOIN OCS_VERSION_MSG "\n " _OCS_VERSION_JOINED)
|
||||
message(STATUS "${_OCS_VERSION_JOINED}")
|
||||
else()
|
||||
message(STATUS "Not building ${OCS_NAME}: CUDA Compiler version is less than ${OCS_MIN_VERSION}")
|
||||
endif()
|
||||
endif()
|
||||
endmacro()
|
||||
|
||||
#
|
||||
# Get additional GPU compiler flags from torch.
|
||||
#
|
||||
|
||||
@ -1,18 +1,15 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
from mistral_common.protocol.instruct.messages import (AssistantMessage,
|
||||
ToolMessage,
|
||||
UserMessage)
|
||||
from mistral_common.protocol.instruct.messages import UserMessage
|
||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||
from mistral_common.protocol.instruct.tool_calls import (Function,
|
||||
FunctionCall, Tool,
|
||||
ToolCall)
|
||||
from mistral_common.protocol.instruct.tool_calls import Function, Tool
|
||||
|
||||
from vllm.transformers_utils.tokenizers.mistral import (
|
||||
make_mistral_chat_completion_request)
|
||||
|
||||
|
||||
# yapf: enable
|
||||
@pytest.mark.parametrize(
|
||||
"openai_request,expected_mistral_request",
|
||||
[(
|
||||
@ -81,107 +78,6 @@ from vllm.transformers_utils.tokenizers.mistral import (
|
||||
)
|
||||
def test_make_mistral_chat_completion_request(openai_request,
|
||||
expected_mistral_request):
|
||||
actual_request = make_mistral_chat_completion_request(
|
||||
openai_request["messages"], openai_request["tools"])
|
||||
assert actual_request == expected_mistral_request
|
||||
|
||||
|
||||
# Tool use with list content and reasoning_content
|
||||
@pytest.mark.parametrize("openai_request,expected_mistral_request", [(
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather in Paris?",
|
||||
},
|
||||
{
|
||||
"role":
|
||||
"assistant",
|
||||
"reasoning_content":
|
||||
None,
|
||||
"content":
|
||||
None,
|
||||
"tool_calls": [{
|
||||
"id": "call123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "Paris"}',
|
||||
},
|
||||
}],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": "Rainy"
|
||||
}],
|
||||
"name": "get_weather",
|
||||
"tool_call_id": "call123",
|
||||
},
|
||||
],
|
||||
"tools": [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Gets the current weather in a city.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The city name"
|
||||
}
|
||||
},
|
||||
"required": ["city"],
|
||||
},
|
||||
},
|
||||
}],
|
||||
},
|
||||
ChatCompletionRequest(
|
||||
messages=[
|
||||
UserMessage(content="What's the weather in Paris?"),
|
||||
AssistantMessage(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
id="call123",
|
||||
function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments='{"city": "Paris"}',
|
||||
),
|
||||
)
|
||||
],
|
||||
),
|
||||
ToolMessage(
|
||||
content="Rainy",
|
||||
tool_call_id="call123",
|
||||
name="get_weather",
|
||||
),
|
||||
],
|
||||
tools=[
|
||||
Tool(
|
||||
type="function",
|
||||
function=Function(
|
||||
name="get_weather",
|
||||
description="Gets the current weather in a city.",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The city name"
|
||||
}
|
||||
},
|
||||
"required": ["city"],
|
||||
},
|
||||
),
|
||||
)
|
||||
],
|
||||
),
|
||||
)])
|
||||
def test_make_mistral_chat_completion_request_list_content(
|
||||
openai_request, expected_mistral_request):
|
||||
actual_request = make_mistral_chat_completion_request(
|
||||
openai_request["messages"], openai_request["tools"])
|
||||
assert actual_request == expected_mistral_request
|
||||
assert (make_mistral_chat_completion_request(
|
||||
openai_request["messages"],
|
||||
openai_request["tools"]) == expected_mistral_request)
|
||||
|
||||
@ -210,7 +210,9 @@ class DeepseekVL2MultiModalProcessor(
|
||||
dict(prompt=prompt, **mm_data),
|
||||
mm_kwargs,
|
||||
)
|
||||
pixel_values = processed_outputs["pixel_values"]
|
||||
target_dtype = self.info.ctx.model_config.dtype
|
||||
pixel_values = processed_outputs.pop("pixel_values").to(
|
||||
target_dtype)
|
||||
# split pixel values into patches corresponding to each image
|
||||
images_spatial_crop = processed_outputs["images_spatial_crop"]
|
||||
patches_per_image = [
|
||||
|
||||
@ -263,6 +263,11 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
|
||||
mm_data,
|
||||
mm_kwargs,
|
||||
)
|
||||
if "pixel_values" in processed_outputs:
|
||||
# Cast pixel values to model dtype already here,
|
||||
# so we need to transfer less data to the GPU
|
||||
processed_outputs["pixel_values"] = processed_outputs[
|
||||
"pixel_values"].to(self.info.ctx.model_config.dtype)
|
||||
|
||||
# HF processor pops the `num_crops` kwarg, which is needed by vLLM
|
||||
if (images := mm_data.get("images")) is not None:
|
||||
|
||||
@ -746,17 +746,11 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
|
||||
batched_inputs: BatchedTensorInputs,
|
||||
*,
|
||||
device: torch.types.Device,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> BatchedTensorInputs:
|
||||
json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)
|
||||
|
||||
def maybe_cast_dtype(x: torch.Tensor):
|
||||
# This mimics the behavior of transformers.BatchFeature
|
||||
return x.to(dtype=dtype) if x.is_floating_point() else x
|
||||
|
||||
json_mapped = json_map_leaves(
|
||||
# NOTE: Cast the dtype before sending it to device
|
||||
lambda x: maybe_cast_dtype(x).to(device=device, non_blocking=True),
|
||||
lambda x: x.to(device, non_blocking=True),
|
||||
json_inputs,
|
||||
)
|
||||
|
||||
|
||||
@ -294,11 +294,8 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
|
||||
inputs_embeds=None,
|
||||
positions=model_input.input_positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**MultiModalKwargs.as_kwargs(
|
||||
multi_modal_kwargs,
|
||||
dtype=self.model_runner.model_config.dtype,
|
||||
device=self.device,
|
||||
),
|
||||
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
|
||||
device=self.device),
|
||||
**model_execute_kwargs,
|
||||
)
|
||||
|
||||
|
||||
@ -156,11 +156,7 @@ def make_mistral_chat_completion_request(
|
||||
#
|
||||
# [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
|
||||
for message in messages:
|
||||
# Remove reasoning_content as unsupported by Mistral
|
||||
_ = message.pop("reasoning_content", None) # type: ignore
|
||||
|
||||
# Convert list text content to string
|
||||
if message.get("role") in ("assistant", "tool"):
|
||||
if message.get("role") == "assistant":
|
||||
content = message.get("content")
|
||||
if isinstance(content, list):
|
||||
content = "\n".join(chunk.get("text") for chunk in content)
|
||||
|
||||
@ -929,11 +929,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
encoder_outputs = []
|
||||
for grouped_mm_inputs in grouped_mm_inputs_list:
|
||||
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
|
||||
batched_mm_inputs = MultiModalKwargs.as_kwargs(
|
||||
batched_mm_inputs,
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
|
||||
device=self.device)
|
||||
|
||||
# Run the encoder.
|
||||
# `curr_group_outputs` is either of the following:
|
||||
@ -1877,10 +1874,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
batched_dummy_mm_inputs = MultiModalKwargs.batch(
|
||||
[dummy_mm_kwargs] * max_num_mm_items)
|
||||
batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
|
||||
batched_dummy_mm_inputs,
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
batched_dummy_mm_inputs, device=self.device)
|
||||
|
||||
# Run multimodal encoder.
|
||||
dummy_encoder_outputs = self.model.get_multimodal_embeddings(
|
||||
|
||||
@ -652,11 +652,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
||||
encoder_outputs = []
|
||||
for grouped_mm_inputs in grouped_mm_inputs_list:
|
||||
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
|
||||
batched_mm_inputs = MultiModalKwargs.as_kwargs(
|
||||
batched_mm_inputs,
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
|
||||
device=self.device)
|
||||
|
||||
# Run the encoder.
|
||||
# `curr_group_outputs` is either of the following:
|
||||
@ -1438,11 +1435,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] *
|
||||
batch_size)
|
||||
return MultiModalKwargs.as_kwargs(
|
||||
batched_dummy_mm_inputs,
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
return MultiModalKwargs.as_kwargs(batched_dummy_mm_inputs,
|
||||
device=self.device)
|
||||
|
||||
|
||||
def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]:
|
||||
|
||||
@ -297,11 +297,8 @@ class CPUEncoderDecoderModelRunner(
|
||||
model_input.encoder_input_tokens,
|
||||
"encoder_positions":
|
||||
model_input.encoder_input_positions,
|
||||
**MultiModalKwargs.as_kwargs(
|
||||
model_input.multi_modal_kwargs or {},
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device,
|
||||
),
|
||||
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
|
||||
device=self.device),
|
||||
"intermediate_tensors":
|
||||
intermediate_tensors,
|
||||
}
|
||||
|
||||
@ -628,10 +628,7 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
|
||||
multimodal_kwargs = {}
|
||||
if model_input.multi_modal_kwargs is not None:
|
||||
multimodal_kwargs = MultiModalKwargs.as_kwargs(
|
||||
model_input.multi_modal_kwargs,
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
model_input.multi_modal_kwargs, device=self.device)
|
||||
execute_model_kwargs = {}
|
||||
if previous_hidden_states is not None:
|
||||
execute_model_kwargs.update(
|
||||
|
||||
@ -50,11 +50,8 @@ class CPUPoolingModelRunner(
|
||||
model_input.input_tokens,
|
||||
"positions":
|
||||
model_input.input_positions,
|
||||
**MultiModalKwargs.as_kwargs(
|
||||
model_input.multi_modal_kwargs or {},
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device,
|
||||
),
|
||||
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
|
||||
device=self.device),
|
||||
**cross_enc_kwargs,
|
||||
"intermediate_tensors":
|
||||
intermediate_tensors,
|
||||
|
||||
@ -202,13 +202,9 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
|
||||
encoder_input_ids=model_input.encoder_input_tokens,
|
||||
encoder_positions=model_input.encoder_input_positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**MultiModalKwargs.as_kwargs(
|
||||
multi_modal_kwargs,
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device,
|
||||
),
|
||||
**seqlen_agnostic_kwargs,
|
||||
)
|
||||
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
|
||||
device=self.device),
|
||||
**seqlen_agnostic_kwargs)
|
||||
|
||||
logits = self.model.compute_logits(hidden_or_intermediate_states,
|
||||
model_input.sampling_metadata)
|
||||
|
||||
@ -1845,11 +1845,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
|
||||
inputs_embeds=model_input.inputs_embeds,
|
||||
positions=model_input.input_positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**MultiModalKwargs.as_kwargs(
|
||||
multi_modal_kwargs,
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device,
|
||||
),
|
||||
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
|
||||
device=self.device),
|
||||
**seqlen_agnostic_kwargs,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
@ -70,11 +70,8 @@ class MultiStepNeuronModelRunner(NeuronModelRunner):
|
||||
input_ids=model_input.input_tokens,
|
||||
positions=model_input.input_positions,
|
||||
input_block_ids=model_input.input_block_ids,
|
||||
**MultiModalKwargs.as_kwargs(
|
||||
model_input.multi_modal_kwargs or {},
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device,
|
||||
),
|
||||
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
|
||||
device=self.device),
|
||||
)
|
||||
|
||||
output = self.model.sample(
|
||||
|
||||
@ -49,11 +49,8 @@ class MultiStepNeuronxDistributedModelRunner(NeuronxDistributedModelRunner):
|
||||
positions=model_input.input_positions,
|
||||
input_block_ids=model_input.input_block_ids,
|
||||
sampling_params=sampling_params,
|
||||
**MultiModalKwargs.as_kwargs(
|
||||
model_input.multi_modal_kwargs or {},
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device,
|
||||
),
|
||||
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
|
||||
device=self.device),
|
||||
)
|
||||
|
||||
output = self.model.sample(
|
||||
|
||||
@ -378,11 +378,9 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
|
||||
positions=model_input.input_positions,
|
||||
input_block_ids=model_input.input_block_ids,
|
||||
sampling_params=sampling_params,
|
||||
**MultiModalKwargs.as_kwargs(
|
||||
model_input.multi_modal_kwargs or {},
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device,
|
||||
),
|
||||
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs
|
||||
or {},
|
||||
device=self.device),
|
||||
)
|
||||
elif current_platform.use_transformers_neuronx():
|
||||
# [TODO] validate on-device sampling
|
||||
@ -391,11 +389,9 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
|
||||
input_ids=model_input.input_tokens,
|
||||
positions=model_input.input_positions,
|
||||
input_block_ids=model_input.input_block_ids,
|
||||
**MultiModalKwargs.as_kwargs(
|
||||
model_input.multi_modal_kwargs or {},
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device,
|
||||
),
|
||||
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs
|
||||
or {},
|
||||
device=self.device),
|
||||
)
|
||||
|
||||
# Compute the logits only if the on-device sampling is turned off as
|
||||
|
||||
@ -119,14 +119,10 @@ class PoolingModelRunner(
|
||||
input_ids=model_input.input_tokens,
|
||||
positions=model_input.input_positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**MultiModalKwargs.as_kwargs(
|
||||
multi_modal_kwargs,
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device,
|
||||
),
|
||||
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
|
||||
device=self.device),
|
||||
**cross_enc_kwargs,
|
||||
**seqlen_agnostic_kwargs,
|
||||
)
|
||||
**seqlen_agnostic_kwargs)
|
||||
|
||||
if (self.observability_config is not None
|
||||
and self.observability_config.collect_model_forward_time):
|
||||
|
||||
@ -562,12 +562,9 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
|
||||
input_ids=model_input.input_tokens,
|
||||
positions=model_input.input_positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**MultiModalKwargs.as_kwargs(
|
||||
model_input.multi_modal_kwargs or {},
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device,
|
||||
),
|
||||
)
|
||||
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs
|
||||
or {},
|
||||
device=self.device))
|
||||
# Compute the logits in the last pipeline stage.
|
||||
if not get_pp_group().is_last_rank:
|
||||
return hidden_or_intermediate_states
|
||||
|
||||
Reference in New Issue
Block a user