Compare commits

...

6 Commits

Author SHA1 Message Date
9e011d3954 Update mistaken usage of GREATER to GREATER_EQUAL
Signed-off-by: mgoin <mgoin64@gmail.com>
2025-07-09 13:41:55 -04:00
b24f0531e3 Fix flags
Signed-off-by: mgoin <mgoin64@gmail.com>
2025-07-09 13:33:36 -04:00
f1fd89a9bf Remove from dockerfile
Signed-off-by: mgoin <mgoin64@gmail.com>
2025-07-09 13:02:10 -04:00
721dcb2ebc Change cmakelists
Signed-off-by: mgoin <mgoin64@gmail.com>
2025-07-09 12:55:41 -04:00
0204263598 Try nvcc compress-mode to reduce binary size
Signed-off-by: mgoin <mgoin64@gmail.com>
2025-07-09 12:28:06 -04:00
4ac9c33f78 [Bugfix] Fix handling of Tensorizer arguments for LoadConfig (#20643)
Signed-off-by: Sanger Steel <sangersteel@gmail.com>
2025-07-09 15:36:37 +00:00
5 changed files with 40 additions and 64 deletions

View File

@ -171,6 +171,13 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}")
endif()
#
# Set nvcc fatbin compression.
#
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_GPU_FLAGS "-Xfatbin" "-compress-all" "-compress-mode=size")
endif()
#
# Use FetchContent for C++ dependencies that are compiled as part of vLLM's build process.
@ -393,7 +400,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.0 or later
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_ARCHS)
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
set(SRCS
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu"
@ -409,7 +416,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_ARCHS)
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
"later if you intend on running FP8 quantized models on "
@ -424,7 +431,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# The cutlass_scaled_mm kernels for Geforce Blackwell SM120 (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.8 or later
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0;12.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu"
@ -438,7 +445,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
message(STATUS "Building scaled_mm_c3x_sm120 for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
message(STATUS "Not building scaled_mm_c3x_sm120 as CUDA Compiler version is "
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
"later if you intend on running FP8 quantized models on "
@ -453,7 +460,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x)
# require CUDA 12.8 or later
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
@ -468,7 +475,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is "
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
"later if you intend on running FP8 quantized models on "
@ -511,7 +518,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
# require CUDA 12.2 or later (and only work on Hopper).
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_ARCHS)
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.2 AND SCALED_MM_ARCHS)
set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
@ -520,7 +527,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SPARSE_SCALED_MM_C3X=1")
message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_ARCHS)
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.2 AND SCALED_MM_ARCHS)
message(STATUS "Not building sparse_scaled_mm_c3x kernels as CUDA Compiler version is "
"not >= 12.2, we recommend upgrading to CUDA 12.2 or later "
"if you intend on running FP8 sparse quantized models on Hopper.")
@ -532,7 +539,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# FP4 Archs and flags
cuda_archs_loose_intersection(FP4_ARCHS "10.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND FP4_ARCHS)
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
set(SRCS
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/quantization/fp4/nvfp4_experts_quant.cu"
@ -553,7 +560,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# CUTLASS MLA Archs and flags
cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND MLA_ARCHS)
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS)
set(SRCS
"csrc/attention/mla/cutlass_mla_kernels.cu")
set_gencode_flags_for_srcs(
@ -642,7 +649,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# The machete kernels only work on hopper and require CUDA 12.0 or later.
# Only build Machete kernels if we are building for something compatible with sm90a
cuda_archs_loose_intersection(MACHETE_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND MACHETE_ARCHS)
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND MACHETE_ARCHS)
#
# For the Machete kernels we automatically generate sources for various
# preselected input type pairs and schedules.
@ -694,7 +701,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "Building Machete kernels for archs: ${MACHETE_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0
AND MACHETE_ARCHS)
message(STATUS "Not building Machete kernels as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "

View File

@ -103,25 +103,6 @@ def write_keyfile(keyfile_path: str):
f.write(encryption_params.key)
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_can_deserialize_s3(vllm_runner):
model_ref = "EleutherAI/pythia-1.4b"
tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"
with vllm_runner(model_ref,
load_format="tensorizer",
model_loader_extra_config=TensorizerConfig(
tensorizer_uri=tensorized_path,
num_readers=1,
s3_endpoint="object.ord1.coreweave.com",
)) as loaded_hf_model:
deserialized_outputs = loaded_hf_model.generate(
prompts, sampling_params)
# noqa: E501
assert deserialized_outputs
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_deserialized_encrypted_vllm_model_has_same_outputs(
model_ref, vllm_runner, tmp_path, model_path):

View File

@ -1003,41 +1003,27 @@ class EngineArgs:
override_attention_dtype=self.override_attention_dtype,
)
def valid_tensorizer_config_provided(self) -> bool:
"""
Checks if a parseable TensorizerConfig was passed to
self.model_loader_extra_config. It first checks if the config passed
is a dict or a TensorizerConfig object directly, and if the latter is
true (by checking that the object has TensorizerConfig's
.to_serializable() method), converts it in to a serializable dict
format
"""
if self.model_loader_extra_config:
if hasattr(self.model_loader_extra_config, "to_serializable"):
self.model_loader_extra_config = (
self.model_loader_extra_config.to_serializable())
for allowed_to_pass in ["tensorizer_uri", "tensorizer_dir"]:
try:
self.model_loader_extra_config[allowed_to_pass]
return False
except KeyError:
pass
return True
def validate_tensorizer_args(self):
from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig)
for key in self.model_loader_extra_config:
if key in TensorizerConfig._fields:
self.model_loader_extra_config["tensorizer_config"][
key] = self.model_loader_extra_config[key]
def create_load_config(self) -> LoadConfig:
if self.quantization == "bitsandbytes":
self.load_format = "bitsandbytes"
if (self.load_format == "tensorizer"
and self.valid_tensorizer_config_provided()):
logger.info("Inferring Tensorizer args from %s", self.model)
self.model_loader_extra_config = {"tensorizer_dir": self.model}
else:
logger.info(
"Using Tensorizer args from --model-loader-extra-config. "
"Note that you can now simply pass the S3 directory in the "
"model tag instead of providing the JSON string.")
if self.load_format == "tensorizer":
if hasattr(self.model_loader_extra_config, "to_serializable"):
self.model_loader_extra_config = (
self.model_loader_extra_config.to_serializable())
self.model_loader_extra_config["tensorizer_config"] = {}
self.model_loader_extra_config["tensorizer_config"][
"tensorizer_dir"] = self.model
self.validate_tensorizer_args()
return LoadConfig(
load_format=self.load_format,

View File

@ -223,9 +223,11 @@ class TensorizerConfig(MutableMapping):
and re.search(r'%0\dd', self.tensorizer_uri) is not None
if self.tensorizer_dir and self.tensorizer_uri:
raise ValueError(
"Either tensorizer_dir or tensorizer_uri must be provided, "
"not both.")
logger.warning_once(
"Provided both tensorizer_dir and tensorizer_uri. "
"Inferring tensorizer_dir from tensorizer_uri as the "
"latter takes precedence.")
self.tensorizer_dir = os.path.dirname(self.tensorizer_uri)
if self.tensorizer_dir and self.lora_dir:
raise ValueError(
"Only one of tensorizer_dir or lora_dir may be specified. "

View File

@ -43,7 +43,7 @@ class TensorizerLoader(BaseModelLoader):
else:
validate_config(load_config.model_loader_extra_config)
self.tensorizer_config = TensorizerConfig(
**load_config.model_loader_extra_config)
**load_config.model_loader_extra_config["tensorizer_config"])
def _verify_config(self, model_config: ModelConfig,
parallel_config: ParallelConfig):