Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -1,46 +0,0 @@
# This local pyproject file is part of the migration from yapf to ruff format.
# It uses the same core rules as the main pyproject.toml file, but with the
# following differences:
# - ruff line length is overridden to 88
# - deprecated typing ignores (UP006, UP035) have been removed

[tool.ruff]
line-length = 88

[tool.ruff.lint.per-file-ignores]
"vllm/third_party/**" = ["ALL"]
"vllm/version.py" = ["F401"]
"vllm/_version.py" = ["ALL"]

[tool.ruff.lint]
select = [
    # pycodestyle
    "E",
    # Pyflakes
    "F",
    # pyupgrade
    "UP",
    # flake8-bugbear
    "B",
    # flake8-simplify
    "SIM",
    # isort
    "I",
    # flake8-logging-format
    "G",
]
ignore = [
    # star imports
    "F405", "F403",
    # lambda expression assignment
    "E731",
    # Loop control variable not used within loop body
    "B007",
    # f-string format
    "UP032",
    # Can remove once 3.10+ is the minimum Python version
    "UP007",
]

[tool.ruff.format]
docstring-code-format = true
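For readers skimming the deleted config above: `E`, `F`, `UP`, `B`, `SIM`, `I`, and `G` are the selected rule families, and the `ignore` list turns off a few individual checks. A minimal, hypothetical Python sketch (not taken from the repository) of the patterns those ignored rules would otherwise flag:

```python
from typing import Optional


def greet(name: Optional[str] = None) -> str:
    # UP007 would prefer `str | None` once 3.10+ syntax is required; ignored above.
    # UP032 would prefer an f-string over str.format(); also ignored.
    return "Hello, {}!".format(name or "world")


square = lambda x: x * x  # E731: assignment of a lambda expression, on the ignore list.

for index in range(3):  # B007: loop control variable unused in the body, also ignored.
    print(greet("vLLM"), square(4))
```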
@@ -6,28 +6,16 @@ default_stages:
- manual # Run in CI
exclude: 'vllm/third_party/.*'
repos:
- repo: https://github.com/google/yapf
  rev: v0.43.0
  hooks:
  - id: yapf
    args: [--in-place, --verbose]
    # Keep the same list from yapfignore here to avoid yapf failing without any inputs
    exclude: '(.buildkite|benchmarks|build|examples)/.*'
- repo: https://github.com/astral-sh/ruff-pre-commit
  rev: v0.11.7
  hooks:
  - id: ruff
    args: [--output-format, github, --fix]
  - id: ruff-format
    files: ^(.buildkite|benchmarks|examples)/.*
- repo: https://github.com/crate-ci/typos
  rev: v1.35.5
  hooks:
  - id: typos
- repo: https://github.com/PyCQA/isort
  rev: 6.0.1
  hooks:
  - id: isort
- repo: https://github.com/pre-commit/mirrors-clang-format
  rev: v20.1.3
  hooks:
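Before this change the `ruff-format` hook was limited by the `files:` pattern above to roughly the directories that yapf already excluded, so the two formatters never touched the same files; this commit removes the yapf and isort hooks along with that scoping. A small illustrative sketch (example paths are made up) of which paths the old pattern selected:

```python
import re

# Pattern copied from the ruff-format hook's `files:` key above.
scoped = re.compile(r"^(.buildkite|benchmarks|examples)/.*")

for path in (
    "benchmarks/kernels/bench_example.py",  # matches: handled by ruff-format
    "examples/offline_inference/demo.py",   # matches
    "vllm/engine/llm_engine.py",            # no match: left to yapf before this change
):
    print(f"{path}: {'ruff-format' if scoped.match(path) else 'yapf'}")
```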
@@ -2,9 +2,9 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc

from benchmark_utils import TimeCollector
from tabulate import tabulate

from benchmark_utils import TimeCollector
from vllm.utils import FlexibleArgumentParser
from vllm.v1.core.block_pool import BlockPool
@@ -5,9 +5,9 @@ import time
from unittest import mock

import numpy as np
from benchmark_utils import TimeCollector
from tabulate import tabulate

from benchmark_utils import TimeCollector
from vllm.config import (
    CacheConfig,
    DeviceConfig,
@@ -37,14 +37,13 @@ from typing import Optional
import datasets
import numpy as np
import pandas as pd
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase

from backend_request_func import (
    ASYNC_REQUEST_FUNCS,
    RequestFuncInput,
    RequestFuncOutput,
)
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase

try:
    from vllm.transformers_utils.tokenizer import get_tokenizer
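The three hunks above are mechanical re-sorts: import ordering moves from the standalone isort hook to ruff's `I` rules, and with `known-first-party = ["vllm"]` local helpers such as `benchmark_utils` now sort into the third-party block while `vllm` imports form their own group. A rough sketch of the resulting grouping (illustrative only; it resolves only inside a vLLM checkout's `benchmarks/` directory):

```python
# Standard library imports come first.
import gc

# Third-party imports, including sibling helper modules with no first-party mapping.
import numpy as np
from benchmark_utils import TimeCollector
from tabulate import tabulate

# First-party imports come last, because known-first-party = ["vllm"].
from vllm.utils import FlexibleArgumentParser
from vllm.v1.core.block_pool import BlockPool
```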
@@ -1,49 +0,0 @@
# This local pyproject file is part of the migration from yapf to ruff format.
# It uses the same core rules as the main pyproject.toml file, but with the
# following differences:
# - ruff line length is overridden to 88
# - deprecated typing ignores (UP006, UP035) have been removed

[tool.ruff]
line-length = 88

[tool.ruff.lint.per-file-ignores]
"vllm/third_party/**" = ["ALL"]
"vllm/version.py" = ["F401"]
"vllm/_version.py" = ["ALL"]

[tool.ruff.lint]
select = [
    # pycodestyle
    "E",
    # Pyflakes
    "F",
    # pyupgrade
    "UP",
    # flake8-bugbear
    "B",
    # flake8-simplify
    "SIM",
    # isort
    "I",
    # flake8-logging-format
    "G",
]
ignore = [
    # star imports
    "F405", "F403",
    # lambda expression assignment
    "E731",
    # Loop control variable not used within loop body
    "B007",
    # f-string format
    "UP032",
    # Can remove once 3.10+ is the minimum Python version
    "UP007",
]

[tool.ruff.lint.isort]
known-first-party = ["vllm"]

[tool.ruff.format]
docstring-code-format = true
@ -16,7 +16,7 @@ import shutil
|
||||
|
||||
from torch.utils.hipify.hipify_python import hipify
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Project directory where all the source + include files live.
|
||||
@ -34,15 +34,14 @@ if __name__ == '__main__':
|
||||
)
|
||||
|
||||
# Source files to convert.
|
||||
parser.add_argument("sources",
|
||||
help="Source files to hipify.",
|
||||
nargs="*",
|
||||
default=[])
|
||||
parser.add_argument(
|
||||
"sources", help="Source files to hipify.", nargs="*", default=[]
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Limit include scope to project_dir only
|
||||
includes = [os.path.join(args.project_dir, '*')]
|
||||
includes = [os.path.join(args.project_dir, "*")]
|
||||
|
||||
# Get absolute path for all source files.
|
||||
extra_files = [os.path.abspath(s) for s in args.sources]
|
||||
@ -51,25 +50,31 @@ if __name__ == '__main__':
|
||||
# The directory might already exist to hold object files so we ignore that.
|
||||
shutil.copytree(args.project_dir, args.output_dir, dirs_exist_ok=True)
|
||||
|
||||
hipify_result = hipify(project_directory=args.project_dir,
|
||||
output_directory=args.output_dir,
|
||||
header_include_dirs=[],
|
||||
includes=includes,
|
||||
extra_files=extra_files,
|
||||
show_detailed=True,
|
||||
is_pytorch_extension=True,
|
||||
hipify_extra_files_only=True)
|
||||
hipify_result = hipify(
|
||||
project_directory=args.project_dir,
|
||||
output_directory=args.output_dir,
|
||||
header_include_dirs=[],
|
||||
includes=includes,
|
||||
extra_files=extra_files,
|
||||
show_detailed=True,
|
||||
is_pytorch_extension=True,
|
||||
hipify_extra_files_only=True,
|
||||
)
|
||||
|
||||
hipified_sources = []
|
||||
for source in args.sources:
|
||||
s_abs = os.path.abspath(source)
|
||||
hipified_s_abs = (hipify_result[s_abs].hipified_path if
|
||||
(s_abs in hipify_result
|
||||
and hipify_result[s_abs].hipified_path is not None)
|
||||
else s_abs)
|
||||
hipified_s_abs = (
|
||||
hipify_result[s_abs].hipified_path
|
||||
if (
|
||||
s_abs in hipify_result
|
||||
and hipify_result[s_abs].hipified_path is not None
|
||||
)
|
||||
else s_abs
|
||||
)
|
||||
hipified_sources.append(hipified_s_abs)
|
||||
|
||||
assert (len(hipified_sources) == len(args.sources))
|
||||
assert len(hipified_sources) == len(args.sources)
|
||||
|
||||
# Print hipified source files.
|
||||
print("\n".join(hipified_sources))
|
||||
|
||||
@ -27,7 +27,7 @@ VLLMDataTypeNames: dict[Union[VLLMDataType, DataType], str] = {
|
||||
**{
|
||||
VLLMDataType.u4b8: "u4b8",
|
||||
VLLMDataType.u8b128: "u8b128",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
VLLMDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
|
||||
@ -35,7 +35,7 @@ VLLMDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
|
||||
**{
|
||||
VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t",
|
||||
VLLMDataType.u8b128: "cutlass::vllm_uint8b128_t",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
VLLMDataTypeSize: dict[Union[VLLMDataType, DataType], int] = {
|
||||
@ -43,7 +43,7 @@ VLLMDataTypeSize: dict[Union[VLLMDataType, DataType], int] = {
|
||||
**{
|
||||
VLLMDataType.u4b8: 4,
|
||||
VLLMDataType.u8b128: 8,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
VLLMDataTypeVLLMScalarTypeTag: dict[Union[VLLMDataType, DataType], str] = {
|
||||
@ -67,15 +67,13 @@ VLLMDataTypeTorchDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
|
||||
DataType.f32: "at::ScalarType::Float",
|
||||
}
|
||||
|
||||
VLLMKernelScheduleTag: dict[Union[
|
||||
MixedInputKernelScheduleType, KernelScheduleType], str] = {
|
||||
**KernelScheduleTag, # type: ignore
|
||||
**{
|
||||
MixedInputKernelScheduleType.TmaWarpSpecialized:
|
||||
"cutlass::gemm::KernelTmaWarpSpecialized",
|
||||
MixedInputKernelScheduleType.TmaWarpSpecializedPingpong:
|
||||
"cutlass::gemm::KernelTmaWarpSpecializedPingpong",
|
||||
MixedInputKernelScheduleType.TmaWarpSpecializedCooperative:
|
||||
"cutlass::gemm::KernelTmaWarpSpecializedCooperative",
|
||||
}
|
||||
}
|
||||
VLLMKernelScheduleTag: dict[
|
||||
Union[MixedInputKernelScheduleType, KernelScheduleType], str
|
||||
] = {
|
||||
**KernelScheduleTag, # type: ignore
|
||||
**{
|
||||
MixedInputKernelScheduleType.TmaWarpSpecialized: "cutlass::gemm::KernelTmaWarpSpecialized",
|
||||
MixedInputKernelScheduleType.TmaWarpSpecializedPingpong: "cutlass::gemm::KernelTmaWarpSpecializedPingpong",
|
||||
MixedInputKernelScheduleType.TmaWarpSpecializedCooperative: "cutlass::gemm::KernelTmaWarpSpecializedCooperative",
|
||||
},
|
||||
}
|
||||
|
||||
@ -17,25 +17,30 @@ FILE_HEAD = """
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
""".strip()
|
||||
|
||||
TEMPLATE = ("template __global__ void Marlin<"
|
||||
"{{scalar_t}}, "
|
||||
"{{w_type_id}}, "
|
||||
"{{s_type_id}}, "
|
||||
"{{threads}}, "
|
||||
"{{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, "
|
||||
"{{thread_k_blocks}}, "
|
||||
"{{'true' if m_block_size_8 else 'false'}}, "
|
||||
"{{stages}}, "
|
||||
"{{group_blocks}}, "
|
||||
"{{'true' if is_zp_float else 'false'}}>"
|
||||
"( MARLIN_KERNEL_PARAMS );")
|
||||
TEMPLATE = (
|
||||
"template __global__ void Marlin<"
|
||||
"{{scalar_t}}, "
|
||||
"{{w_type_id}}, "
|
||||
"{{s_type_id}}, "
|
||||
"{{threads}}, "
|
||||
"{{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, "
|
||||
"{{thread_k_blocks}}, "
|
||||
"{{'true' if m_block_size_8 else 'false'}}, "
|
||||
"{{stages}}, "
|
||||
"{{group_blocks}}, "
|
||||
"{{'true' if is_zp_float else 'false'}}>"
|
||||
"( MARLIN_KERNEL_PARAMS );"
|
||||
)
|
||||
|
||||
# int8 with zero point case (vllm::kU8) is also supported,
|
||||
# we don't add it to reduce wheel size.
|
||||
SCALAR_TYPES = [
|
||||
"vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn",
|
||||
"vllm::kFE2M1f"
|
||||
"vllm::kU4",
|
||||
"vllm::kU4B8",
|
||||
"vllm::kU8B128",
|
||||
"vllm::kFE4M3fn",
|
||||
"vllm::kFE2M1f",
|
||||
]
|
||||
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
|
||||
|
||||
@ -58,11 +63,12 @@ def generate_new_kernels():
|
||||
all_template_str_list = []
|
||||
|
||||
for group_blocks, m_blocks, thread_configs in itertools.product(
|
||||
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):
|
||||
|
||||
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS
|
||||
):
|
||||
# act order case only support gptq-int4 and gptq-int8
|
||||
if group_blocks == 0 and scalar_type not in [
|
||||
"vllm::kU4B8", "vllm::kU8B128"
|
||||
"vllm::kU4B8",
|
||||
"vllm::kU8B128",
|
||||
]:
|
||||
continue
|
||||
if thread_configs[2] == 256:
|
||||
|
||||
@ -17,28 +17,32 @@ FILE_HEAD = """
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
""".strip()
|
||||
|
||||
TEMPLATE = ("template __global__ void Marlin<"
|
||||
"{{scalar_t}}, "
|
||||
"{{w_type_id}}, "
|
||||
"{{s_type_id}}, "
|
||||
"{{threads}}, "
|
||||
"{{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, "
|
||||
"{{thread_k_blocks}}, "
|
||||
"{{'true' if m_block_size_8 else 'false'}}, "
|
||||
"{{stages}}, "
|
||||
"{{group_blocks}}, "
|
||||
"{{'true' if is_zp_float else 'false'}}>"
|
||||
"( MARLIN_KERNEL_PARAMS );")
|
||||
TEMPLATE = (
|
||||
"template __global__ void Marlin<"
|
||||
"{{scalar_t}}, "
|
||||
"{{w_type_id}}, "
|
||||
"{{s_type_id}}, "
|
||||
"{{threads}}, "
|
||||
"{{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, "
|
||||
"{{thread_k_blocks}}, "
|
||||
"{{'true' if m_block_size_8 else 'false'}}, "
|
||||
"{{stages}}, "
|
||||
"{{group_blocks}}, "
|
||||
"{{'true' if is_zp_float else 'false'}}>"
|
||||
"( MARLIN_KERNEL_PARAMS );"
|
||||
)
|
||||
|
||||
# int8 with zero point case (vllm::kU8) is also supported,
|
||||
# we don't add it to reduce wheel size.
|
||||
SCALAR_TYPES = [
|
||||
"vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn",
|
||||
"vllm::kFE2M1f"
|
||||
"vllm::kU4",
|
||||
"vllm::kU4B8",
|
||||
"vllm::kU8B128",
|
||||
"vllm::kFE4M3fn",
|
||||
"vllm::kFE2M1f",
|
||||
]
|
||||
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128),
|
||||
(128, 64, 128)]
|
||||
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)]
|
||||
|
||||
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
|
||||
# group_blocks:
|
||||
@ -59,11 +63,12 @@ def generate_new_kernels():
|
||||
all_template_str_list = []
|
||||
|
||||
for group_blocks, m_blocks, thread_configs in itertools.product(
|
||||
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):
|
||||
|
||||
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS
|
||||
):
|
||||
# act order case only support gptq-int4 and gptq-int8
|
||||
if group_blocks == 0 and scalar_type not in [
|
||||
"vllm::kU4B8", "vllm::kU8B128"
|
||||
"vllm::kU4B8",
|
||||
"vllm::kU8B128",
|
||||
]:
|
||||
continue
|
||||
if thread_configs[2] == 256:
|
||||
@ -93,8 +98,7 @@ def generate_new_kernels():
|
||||
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
|
||||
|
||||
is_zp_float_list = [False]
|
||||
if dtype == "fp16" and scalar_type == "vllm::kU4" and \
|
||||
group_blocks == 4:
|
||||
if dtype == "fp16" and scalar_type == "vllm::kU4" and group_blocks == 4:
|
||||
# HQQ (is_zp_float = true) only supports
|
||||
# 4bit quantization and fp16
|
||||
is_zp_float_list.append(True)
|
||||
|
||||
@ -12,18 +12,24 @@ from functools import reduce
|
||||
from typing import Optional, Union
|
||||
|
||||
import jinja2
|
||||
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm_cutlass_library_extension import (DataType, EpilogueScheduleTag,
|
||||
EpilogueScheduleType,
|
||||
MixedInputKernelScheduleType,
|
||||
TileSchedulerTag,
|
||||
TileSchedulerType, VLLMDataType,
|
||||
VLLMDataTypeNames,
|
||||
VLLMDataTypeSize, VLLMDataTypeTag,
|
||||
VLLMDataTypeTorchDataTypeTag,
|
||||
VLLMDataTypeVLLMScalarTypeTag,
|
||||
VLLMKernelScheduleTag)
|
||||
from vllm_cutlass_library_extension import (
|
||||
DataType,
|
||||
EpilogueScheduleTag,
|
||||
EpilogueScheduleType,
|
||||
MixedInputKernelScheduleType,
|
||||
TileSchedulerTag,
|
||||
TileSchedulerType,
|
||||
VLLMDataType,
|
||||
VLLMDataTypeNames,
|
||||
VLLMDataTypeSize,
|
||||
VLLMDataTypeTag,
|
||||
VLLMDataTypeTorchDataTypeTag,
|
||||
VLLMDataTypeVLLMScalarTypeTag,
|
||||
VLLMKernelScheduleTag,
|
||||
)
|
||||
|
||||
# yapf: enable
|
||||
|
||||
@ -286,18 +292,23 @@ def generate_sch_sig(schedule_config: ScheduleConfig) -> str:
|
||||
tile_shape = (
|
||||
f"{schedule_config.tile_shape_mn[0]}x{schedule_config.tile_shape_mn[1]}"
|
||||
)
|
||||
cluster_shape = (f"{schedule_config.cluster_shape_mnk[0]}" +
|
||||
f"x{schedule_config.cluster_shape_mnk[1]}" +
|
||||
f"x{schedule_config.cluster_shape_mnk[2]}")
|
||||
kernel_schedule = VLLMKernelScheduleTag[schedule_config.kernel_schedule]\
|
||||
.split("::")[-1]
|
||||
epilogue_schedule = EpilogueScheduleTag[
|
||||
schedule_config.epilogue_schedule].split("::")[-1]
|
||||
tile_scheduler = TileSchedulerTag[schedule_config.tile_scheduler]\
|
||||
.split("::")[-1]
|
||||
cluster_shape = (
|
||||
f"{schedule_config.cluster_shape_mnk[0]}"
|
||||
+ f"x{schedule_config.cluster_shape_mnk[1]}"
|
||||
+ f"x{schedule_config.cluster_shape_mnk[2]}"
|
||||
)
|
||||
kernel_schedule = VLLMKernelScheduleTag[schedule_config.kernel_schedule].split(
|
||||
"::"
|
||||
)[-1]
|
||||
epilogue_schedule = EpilogueScheduleTag[schedule_config.epilogue_schedule].split(
|
||||
"::"
|
||||
)[-1]
|
||||
tile_scheduler = TileSchedulerTag[schedule_config.tile_scheduler].split("::")[-1]
|
||||
|
||||
return (f"{tile_shape}_{cluster_shape}_{kernel_schedule}" +
|
||||
f"_{epilogue_schedule}_{tile_scheduler}")
|
||||
return (
|
||||
f"{tile_shape}_{cluster_shape}_{kernel_schedule}"
|
||||
+ f"_{epilogue_schedule}_{tile_scheduler}"
|
||||
)
|
||||
|
||||
|
||||
# mostly unique shorter sch_sig
|
||||
@ -316,18 +327,24 @@ def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str:
|
||||
|
||||
# unique type_name
|
||||
def generate_type_signature(kernel_types: TypeConfig):
|
||||
return str("".join([
|
||||
VLLMDataTypeNames[getattr(kernel_types, field.name)]
|
||||
for field in fields(TypeConfig)
|
||||
]))
|
||||
return str(
|
||||
"".join(
|
||||
[
|
||||
VLLMDataTypeNames[getattr(kernel_types, field.name)]
|
||||
for field in fields(TypeConfig)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def generate_type_option_name(kernel_types: TypeConfig):
|
||||
return ", ".join([
|
||||
f"{field.name.replace('b_', 'with_')+'_type'}=" +
|
||||
VLLMDataTypeNames[getattr(kernel_types, field.name)]
|
||||
for field in fields(TypeConfig)
|
||||
])
|
||||
return ", ".join(
|
||||
[
|
||||
f"{field.name.replace('b_', 'with_') + '_type'}="
|
||||
+ VLLMDataTypeNames[getattr(kernel_types, field.name)]
|
||||
for field in fields(TypeConfig)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def is_power_of_two(n):
|
||||
@ -335,7 +352,6 @@ def is_power_of_two(n):
|
||||
|
||||
|
||||
def to_cute_constant(value: list[int]):
|
||||
|
||||
def _to_cute_constant(value: int):
|
||||
if is_power_of_two(value):
|
||||
return f"_{value}"
|
||||
@ -350,11 +366,11 @@ def to_cute_constant(value: list[int]):
|
||||
|
||||
def unique_schedules(impl_configs: list[ImplConfig]):
|
||||
# Use dict over set for deterministic ordering
|
||||
return list({
|
||||
sch: None
|
||||
for impl_config in impl_configs
|
||||
for sch in impl_config.schedules
|
||||
}.keys())
|
||||
return list(
|
||||
{
|
||||
sch: None for impl_config in impl_configs for sch in impl_config.schedules
|
||||
}.keys()
|
||||
)
|
||||
|
||||
|
||||
def unsigned_type_with_bitwidth(num_bits):
|
||||
@ -380,7 +396,7 @@ template_globals = {
|
||||
"gen_type_sig": generate_type_signature,
|
||||
"unique_schedules": unique_schedules,
|
||||
"unsigned_type_with_bitwidth": unsigned_type_with_bitwidth,
|
||||
"gen_type_option_name": generate_type_option_name
|
||||
"gen_type_option_name": generate_type_option_name,
|
||||
}
|
||||
|
||||
|
||||
@ -398,23 +414,28 @@ prepack_dispatch_template = create_template(PREPACK_TEMPLATE)
|
||||
def create_sources(impl_configs: list[ImplConfig], num_impl_files=8):
|
||||
sources = []
|
||||
|
||||
sources.append((
|
||||
"machete_mm_dispatch",
|
||||
mm_dispatch_template.render(impl_configs=impl_configs),
|
||||
))
|
||||
sources.append(
|
||||
(
|
||||
"machete_mm_dispatch",
|
||||
mm_dispatch_template.render(impl_configs=impl_configs),
|
||||
)
|
||||
)
|
||||
|
||||
prepack_types = []
|
||||
for impl_config in impl_configs:
|
||||
convert_type = impl_config.types.a \
|
||||
if impl_config.types.b_group_scale == DataType.void \
|
||||
else impl_config.types.b_group_scale
|
||||
convert_type = (
|
||||
impl_config.types.a
|
||||
if impl_config.types.b_group_scale == DataType.void
|
||||
else impl_config.types.b_group_scale
|
||||
)
|
||||
prepack_types.append(
|
||||
PrepackTypeConfig(
|
||||
a=impl_config.types.a,
|
||||
b_num_bits=VLLMDataTypeSize[impl_config.types.b],
|
||||
convert=convert_type,
|
||||
accumulator=impl_config.types.accumulator,
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
def prepacked_type_key(prepack_type: PrepackTypeConfig):
|
||||
# For now, we can just use the first accumulator type seen since
|
||||
@ -430,10 +451,14 @@ def create_sources(impl_configs: list[ImplConfig], num_impl_files=8):
|
||||
unique_prepack_types.append(prepack_type)
|
||||
prepack_types_seen.add(key)
|
||||
|
||||
sources.append((
|
||||
"machete_prepack",
|
||||
prepack_dispatch_template.render(types=unique_prepack_types, ),
|
||||
))
|
||||
sources.append(
|
||||
(
|
||||
"machete_prepack",
|
||||
prepack_dispatch_template.render(
|
||||
types=unique_prepack_types,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Split up impls across files
|
||||
num_impls = reduce(lambda x, y: x + len(y.schedules), impl_configs, 0)
|
||||
@ -466,10 +491,12 @@ def create_sources(impl_configs: list[ImplConfig], num_impl_files=8):
|
||||
curr_impl_in_file += len(files_impls[-1][-1].schedules)
|
||||
|
||||
for part, file_impls in enumerate(files_impls):
|
||||
sources.append((
|
||||
f"machete_mm_impl_part{part+1}",
|
||||
mm_impl_template.render(impl_configs=file_impls),
|
||||
))
|
||||
sources.append(
|
||||
(
|
||||
f"machete_mm_impl_part{part + 1}",
|
||||
mm_impl_template.render(impl_configs=file_impls),
|
||||
)
|
||||
)
|
||||
|
||||
return sources
|
||||
|
||||
@ -514,8 +541,7 @@ def generate():
|
||||
# For now we use the same heuristic for all types
|
||||
# Heuristic is currently tuned for H100s
|
||||
default_heuristic = [
|
||||
(cond, ScheduleConfig(*tile_config,
|
||||
**sch_common_params)) # type: ignore
|
||||
(cond, ScheduleConfig(*tile_config, **sch_common_params)) # type: ignore
|
||||
for cond, tile_config in default_tile_heuristic_config.items()
|
||||
]
|
||||
|
||||
@ -541,14 +567,18 @@ def generate():
|
||||
a_token_scale=DataType.void,
|
||||
out=a,
|
||||
accumulator=DataType.f32,
|
||||
) for b in (VLLMDataType.u4b8, VLLMDataType.u8b128)
|
||||
for a in (DataType.f16, DataType.bf16))
|
||||
)
|
||||
for b in (VLLMDataType.u4b8, VLLMDataType.u8b128)
|
||||
for a in (DataType.f16, DataType.bf16)
|
||||
)
|
||||
|
||||
impl_configs += [
|
||||
ImplConfig(x[0], x[1], x[2])
|
||||
for x in zip(GPTQ_kernel_type_configs,
|
||||
itertools.repeat(get_unique_schedules(default_heuristic)),
|
||||
itertools.repeat(default_heuristic))
|
||||
for x in zip(
|
||||
GPTQ_kernel_type_configs,
|
||||
itertools.repeat(get_unique_schedules(default_heuristic)),
|
||||
itertools.repeat(default_heuristic),
|
||||
)
|
||||
]
|
||||
|
||||
AWQ_kernel_type_configs = list(
|
||||
@ -561,14 +591,18 @@ def generate():
|
||||
a_token_scale=DataType.void,
|
||||
out=a,
|
||||
accumulator=DataType.f32,
|
||||
) for b in (DataType.u4, DataType.u8)
|
||||
for a in (DataType.f16, DataType.bf16))
|
||||
)
|
||||
for b in (DataType.u4, DataType.u8)
|
||||
for a in (DataType.f16, DataType.bf16)
|
||||
)
|
||||
|
||||
impl_configs += [
|
||||
ImplConfig(x[0], x[1], x[2])
|
||||
for x in zip(AWQ_kernel_type_configs,
|
||||
itertools.repeat(get_unique_schedules(default_heuristic)),
|
||||
itertools.repeat(default_heuristic))
|
||||
for x in zip(
|
||||
AWQ_kernel_type_configs,
|
||||
itertools.repeat(get_unique_schedules(default_heuristic)),
|
||||
itertools.repeat(default_heuristic),
|
||||
)
|
||||
]
|
||||
|
||||
# TODO: Support W4A8 when ready
|
||||
|
||||
@ -33,8 +33,11 @@ def auto_mock(module, attr, max_mocks=50):
|
||||
try:
|
||||
# First treat attr as an attr, then as a submodule
|
||||
with patch("importlib.metadata.version", return_value="0.0.0"):
|
||||
return getattr(importlib.import_module(module), attr,
|
||||
importlib.import_module(f"{module}.{attr}"))
|
||||
return getattr(
|
||||
importlib.import_module(module),
|
||||
attr,
|
||||
importlib.import_module(f"{module}.{attr}"),
|
||||
)
|
||||
except importlib.metadata.PackageNotFoundError as e:
|
||||
raise e
|
||||
except ModuleNotFoundError as e:
|
||||
@ -42,7 +45,8 @@ def auto_mock(module, attr, max_mocks=50):
|
||||
sys.modules[e.name] = PydanticMagicMock()
|
||||
|
||||
raise ImportError(
|
||||
f"Failed to import {module}.{attr} after mocking {max_mocks} imports")
|
||||
f"Failed to import {module}.{attr} after mocking {max_mocks} imports"
|
||||
)
|
||||
|
||||
|
||||
latency = auto_mock("vllm.benchmarks", "latency")
|
||||
@ -61,9 +65,7 @@ class MarkdownFormatter(HelpFormatter):
|
||||
"""Custom formatter that generates markdown for argument groups."""
|
||||
|
||||
def __init__(self, prog, starting_heading_level=3):
|
||||
super().__init__(prog,
|
||||
max_help_position=float('inf'),
|
||||
width=float('inf'))
|
||||
super().__init__(prog, max_help_position=float("inf"), width=float("inf"))
|
||||
self._section_heading_prefix = "#" * starting_heading_level
|
||||
self._argument_heading_prefix = "#" * (starting_heading_level + 1)
|
||||
self._markdown_output = []
|
||||
@ -85,23 +87,19 @@ class MarkdownFormatter(HelpFormatter):
|
||||
|
||||
def add_arguments(self, actions):
|
||||
for action in actions:
|
||||
if (len(action.option_strings) == 0
|
||||
or "--help" in action.option_strings):
|
||||
if len(action.option_strings) == 0 or "--help" in action.option_strings:
|
||||
continue
|
||||
|
||||
option_strings = f'`{"`, `".join(action.option_strings)}`'
|
||||
option_strings = f"`{'`, `'.join(action.option_strings)}`"
|
||||
heading_md = f"{self._argument_heading_prefix} {option_strings}\n\n"
|
||||
self._markdown_output.append(heading_md)
|
||||
|
||||
if choices := action.choices:
|
||||
choices = f'`{"`, `".join(str(c) for c in choices)}`'
|
||||
self._markdown_output.append(
|
||||
f"Possible choices: {choices}\n\n")
|
||||
elif ((metavar := action.metavar)
|
||||
and isinstance(metavar, (list, tuple))):
|
||||
metavar = f'`{"`, `".join(str(m) for m in metavar)}`'
|
||||
self._markdown_output.append(
|
||||
f"Possible choices: {metavar}\n\n")
|
||||
choices = f"`{'`, `'.join(str(c) for c in choices)}`"
|
||||
self._markdown_output.append(f"Possible choices: {choices}\n\n")
|
||||
elif (metavar := action.metavar) and isinstance(metavar, (list, tuple)):
|
||||
metavar = f"`{'`, `'.join(str(m) for m in metavar)}`"
|
||||
self._markdown_output.append(f"Possible choices: {metavar}\n\n")
|
||||
|
||||
if action.help:
|
||||
self._markdown_output.append(f"{action.help}\n\n")
|
||||
@ -143,24 +141,17 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):
|
||||
|
||||
# Create parsers to document
|
||||
parsers = {
|
||||
"engine_args":
|
||||
create_parser(EngineArgs.add_cli_args),
|
||||
"async_engine_args":
|
||||
create_parser(AsyncEngineArgs.add_cli_args, async_args_only=True),
|
||||
"serve":
|
||||
create_parser(cli_args.make_arg_parser),
|
||||
"chat":
|
||||
create_parser(ChatCommand.add_cli_args),
|
||||
"complete":
|
||||
create_parser(CompleteCommand.add_cli_args),
|
||||
"bench_latency":
|
||||
create_parser(latency.add_cli_args),
|
||||
"bench_throughput":
|
||||
create_parser(throughput.add_cli_args),
|
||||
"bench_serve":
|
||||
create_parser(serve.add_cli_args),
|
||||
"run-batch":
|
||||
create_parser(run_batch.make_arg_parser),
|
||||
"engine_args": create_parser(EngineArgs.add_cli_args),
|
||||
"async_engine_args": create_parser(
|
||||
AsyncEngineArgs.add_cli_args, async_args_only=True
|
||||
),
|
||||
"serve": create_parser(cli_args.make_arg_parser),
|
||||
"chat": create_parser(ChatCommand.add_cli_args),
|
||||
"complete": create_parser(CompleteCommand.add_cli_args),
|
||||
"bench_latency": create_parser(latency.add_cli_args),
|
||||
"bench_throughput": create_parser(throughput.add_cli_args),
|
||||
"bench_serve": create_parser(serve.add_cli_args),
|
||||
"run-batch": create_parser(run_batch.make_arg_parser),
|
||||
}
|
||||
|
||||
# Generate documentation for each parser
|
||||
|
||||
@ -11,7 +11,7 @@ import regex as re
|
||||
logger = logging.getLogger("mkdocs")
|
||||
|
||||
ROOT_DIR = Path(__file__).parent.parent.parent.parent
|
||||
ROOT_DIR_RELATIVE = '../../../../..'
|
||||
ROOT_DIR_RELATIVE = "../../../../.."
|
||||
EXAMPLE_DIR = ROOT_DIR / "examples"
|
||||
EXAMPLE_DOC_DIR = ROOT_DIR / "docs/examples"
|
||||
|
||||
@ -36,7 +36,7 @@ def fix_case(text: str) -> str:
|
||||
r"int\d+": lambda x: x.group(0).upper(), # e.g. int8, int16
|
||||
}
|
||||
for pattern, repl in subs.items():
|
||||
text = re.sub(rf'\b{pattern}\b', repl, text, flags=re.IGNORECASE)
|
||||
text = re.sub(rf"\b{pattern}\b", repl, text, flags=re.IGNORECASE)
|
||||
return text
|
||||
|
||||
|
||||
@ -58,7 +58,8 @@ class Example:
|
||||
determine_other_files() -> list[Path]: Determines other files in the directory excluding the main file.
|
||||
determine_title() -> str: Determines the title of the document.
|
||||
generate() -> str: Generates the documentation content.
|
||||
""" # noqa: E501
|
||||
""" # noqa: E501
|
||||
|
||||
path: Path
|
||||
category: str = None
|
||||
main_file: Path = field(init=False)
|
||||
@ -84,9 +85,8 @@ class Example:
|
||||
Markdown file found in the directory.
|
||||
Raises:
|
||||
IndexError: If no Markdown files are found in the directory.
|
||||
""" # noqa: E501
|
||||
return self.path if self.path.is_file() else list(
|
||||
self.path.glob("*.md")).pop()
|
||||
""" # noqa: E501
|
||||
return self.path if self.path.is_file() else list(self.path.glob("*.md")).pop()
|
||||
|
||||
def determine_other_files(self) -> list[Path]:
|
||||
"""
|
||||
@ -98,7 +98,7 @@ class Example:
|
||||
|
||||
Returns:
|
||||
list[Path]: A list of Path objects representing the other files in the directory.
|
||||
""" # noqa: E501
|
||||
""" # noqa: E501
|
||||
if self.path.is_file():
|
||||
return []
|
||||
is_other_file = lambda file: file.is_file() and file != self.main_file
|
||||
@ -109,9 +109,9 @@ class Example:
|
||||
# Specify encoding for building on Windows
|
||||
with open(self.main_file, encoding="utf-8") as f:
|
||||
first_line = f.readline().strip()
|
||||
match = re.match(r'^#\s+(?P<title>.+)$', first_line)
|
||||
match = re.match(r"^#\s+(?P<title>.+)$", first_line)
|
||||
if match:
|
||||
return match.group('title')
|
||||
return match.group("title")
|
||||
return fix_case(self.path.stem.replace("_", " ").title())
|
||||
|
||||
def fix_relative_links(self, content: str) -> str:
|
||||
@ -127,7 +127,7 @@ class Example:
|
||||
"""
|
||||
# Regex to match markdown links [text](relative_path)
|
||||
# This matches links that don't start with http, https, ftp, or #
|
||||
link_pattern = r'\[([^\]]*)\]\((?!(?:https?|ftp)://|#)([^)]+)\)'
|
||||
link_pattern = r"\[([^\]]*)\]\((?!(?:https?|ftp)://|#)([^)]+)\)"
|
||||
|
||||
def replace_link(match):
|
||||
link_text = match.group(1)
|
||||
@ -137,7 +137,7 @@ class Example:
|
||||
gh_file = (self.main_file.parent / relative_path).resolve()
|
||||
gh_file = gh_file.relative_to(ROOT_DIR)
|
||||
|
||||
return f'[{link_text}](gh-file:{gh_file})'
|
||||
return f"[{link_text}](gh-file:{gh_file})"
|
||||
|
||||
return re.sub(link_pattern, replace_link, content)
|
||||
|
||||
@ -150,9 +150,11 @@ class Example:
|
||||
code_fence = "``````"
|
||||
|
||||
if self.is_code:
|
||||
content += (f"{code_fence}{self.main_file.suffix[1:]}\n"
|
||||
f'--8<-- "{self.main_file}"\n'
|
||||
f"{code_fence}\n")
|
||||
content += (
|
||||
f"{code_fence}{self.main_file.suffix[1:]}\n"
|
||||
f'--8<-- "{self.main_file}"\n'
|
||||
f"{code_fence}\n"
|
||||
)
|
||||
else:
|
||||
with open(self.main_file) as f:
|
||||
# Skip the title from md snippets as it's been included above
|
||||
|
||||
@ -7,7 +7,7 @@ from typing import Literal
|
||||
|
||||
def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):
|
||||
# see https://docs.readthedocs.io/en/stable/reference/environment-variables.html # noqa
|
||||
if os.getenv('READTHEDOCS_VERSION_TYPE') == "tag":
|
||||
if os.getenv("READTHEDOCS_VERSION_TYPE") == "tag":
|
||||
# remove the warning banner if the version is a tagged release
|
||||
mkdocs_dir = Path(__file__).parent.parent
|
||||
announcement_path = mkdocs_dir / "overrides/main.html"
|
||||
|
||||
@ -25,8 +25,9 @@ from mkdocs.structure.files import Files
|
||||
from mkdocs.structure.pages import Page
|
||||
|
||||
|
||||
def on_page_markdown(markdown: str, *, page: Page, config: MkDocsConfig,
|
||||
files: Files) -> str:
|
||||
def on_page_markdown(
|
||||
markdown: str, *, page: Page, config: MkDocsConfig, files: Files
|
||||
) -> str:
|
||||
"""
|
||||
Custom MkDocs plugin hook to rewrite special GitHub reference links
|
||||
in Markdown.
|
||||
@ -92,11 +93,11 @@ def on_page_markdown(markdown: str, *, page: Page, config: MkDocsConfig,
|
||||
Example:
|
||||
[My issue](gh-issue:123) → [:octicons-mark-github-16: My issue](https://github.com/vllm-project/vllm/issues/123)
|
||||
"""
|
||||
url = f'{urls[match.group("type")]}/{match.group("path")}'
|
||||
url = f"{urls[match.group('type')]}/{match.group('path')}"
|
||||
if fragment := match.group("fragment"):
|
||||
url += f"#{fragment}"
|
||||
|
||||
return f'[{gh_icon} {match.group("title")}]({url})'
|
||||
return f"[{gh_icon} {match.group('title')}]({url})"
|
||||
|
||||
def replace_auto_link(match: re.Match) -> str:
|
||||
"""
|
||||
|
||||
@@ -1,54 +0,0 @@
# This local pyproject file is part of the migration from yapf to ruff format.
# It uses the same core rules as the main pyproject.toml file, but with the
# following differences:
# - ruff line length is overridden to 88
# - deprecated typing ignores (UP006, UP035) have been removed

[tool.ruff]
line-length = 88
exclude = [
    # External file, leaving license intact
    "examples/other/fp8/quantizer/quantize.py",
    "vllm/vllm_flash_attn/flash_attn_interface.pyi"
]

[tool.ruff.lint.per-file-ignores]
"vllm/third_party/**" = ["ALL"]
"vllm/version.py" = ["F401"]
"vllm/_version.py" = ["ALL"]

[tool.ruff.lint]
select = [
    # pycodestyle
    "E",
    # Pyflakes
    "F",
    # pyupgrade
    "UP",
    # flake8-bugbear
    "B",
    # flake8-simplify
    "SIM",
    # isort
    "I",
    # flake8-logging-format
    "G",
]
ignore = [
    # star imports
    "F405", "F403",
    # lambda expression assignment
    "E731",
    # Loop control variable not used within loop body
    "B007",
    # f-string format
    "UP032",
    # Can remove once 3.10+ is the minimum Python version
    "UP007",
]

[tool.ruff.lint.isort]
known-first-party = ["vllm"]

[tool.ruff.format]
docstring-code-format = true
pyproject.toml
@@ -52,27 +52,106 @@ lora_filesystem_resolver = "vllm.plugins.lora_resolvers.filesystem_resolver:regi
where = ["."]
include = ["vllm*"]

[tool.yapfignore]
ignore_patterns = [
    ".buildkite/**",
    "benchmarks/**",
    "build/**",
    "examples/**",
]

[tool.ruff]
# Allow lines to be as long as 80.
line-length = 80

[tool.ruff.lint.per-file-ignores]
"vllm/third_party/**" = ["ALL"]
"vllm/version.py" = ["F401"]
"vllm/_version.py" = ["ALL"]
# Python 3.8 typing - skip V0 code
"vllm/attention/**/*.py" = ["UP006", "UP035"]
"vllm/engine/**/*.py" = ["UP006", "UP035"]
"vllm/executor/**/*.py" = ["UP006", "UP035"]
"vllm/worker/**/*.py" = ["UP006", "UP035"]
# TEMPORARY! These ignores will be fixed forward
## Line length violations
"csrc/cutlass_extensions/vllm_cutlass_library_extension.py" = ["E501"]
"tests/compile/piecewise/test_simple.py" = ["E501"]
"tests/compile/piecewise/test_toy_llama.py" = ["E501", "B023"]
"tests/entrypoints/conftest.py" = ["E501"]
"tests/entrypoints/openai/test_audio.py" = ["E501"]
"tests/entrypoints/openai/test_chat.py" = ["E501"]
"tests/entrypoints/openai/test_chat_template.py" = ["E501"]
"tests/entrypoints/openai/test_chat_with_tool_reasoning.py" = ["E501"]
"tests/entrypoints/openai/test_completion_with_function_calling.py" = ["E501"]
"tests/entrypoints/openai/test_video.py" = ["E501"]
"tests/entrypoints/openai/test_vision.py" = ["E501"]
"tests/entrypoints/test_chat_utils.py" = ["E501"]
"tests/kernels/moe/modular_kernel_tools/common.py" = ["E501"]
"tests/models/language/generation/test_gemma.py" = ["E501"]
"tests/models/language/generation/test_mistral.py" = ["E501"]
"tests/models/multimodal/generation/test_ultravox.py" = ["E501"]
"tests/models/multimodal/generation/test_voxtral.py" = ["E501"]
"tests/models/multimodal/generation/vlm_utils/custom_inputs.py" = ["E501"]
"tests/tool_use/test_tool_choice_required.py" = ["E501"]
"tests/v1/attention/utils.py" = ["E501"]
"tests/v1/entrypoints/openai/responses/test_image.py" = ["E501"]
"tests/v1/kv_connector/nixl_integration/test_accuracy.py" = ["E501"]
"tests/v1/kv_connector/unit/test_offloading_connector.py" = ["E501"]
"tests/v1/logits_processors/test_custom_offline.py" = ["E501"]
"vllm/attention/ops/pallas_kv_cache_update.py" = ["E501"]
"vllm/compilation/collective_fusion.py" = ["E501"]
"vllm/compilation/wrapper.py" = ["E501"]
"vllm/config/vllm.py" = ["E501"]
"vllm/distributed/device_communicators/all2all.py" = ["E501"]
"vllm/entrypoints/openai/protocol.py" = ["E501"]
"vllm/lora/layers/vocal_parallel_embedding.py" = ["E501"]
"vllm/model_executor/model_loader/bitsandbytes_loader.py" = ["E501"]
"vllm/model_executor/models/bailing_moe.py" = ["E501"]
"vllm/model_executor/models/hyperclovax_vision.py" = ["E501"]
"vllm/model_executor/models/llama4_eagle.py" = ["E501"]
"vllm/model_executor/models/longcat_flash_mtp.py" = ["E501"]
"vllm/model_executor/models/phi4mm.py" = ["E501"]
"vllm/model_executor/models/qwen3_next.py" = ["E501"]
"vllm/model_executor/layers/quantization/ptpc_fp8.py" = ["E501"]
"vllm/v1/attention/backends/mla/common.py" = ["E501"]
"vllm/v1/engine/utils.py" = ["E501"]
"vllm/v1/utils.py" = ["E501"]
"vllm/v1/worker/gpu_model_runner.py" = ["E501"]
## Simplification rules
"tests/distributed/test_expert_placement.py" = ["SIM108"]
"tests/kernels/attention/test_cutlass_mla_decode.py" = ["SIM108"]
"tests/kernels/attention/test_flashmla.py" = ["SIM108"]
"tests/kernels/attention/test_lightning_attn.py" = ["SIM108"]
"tests/kernels/moe/test_pplx_moe.py" = ["SIM108"]
"tests/kernels/quantization/test_cutlass_scaled_mm.py" = ["SIM108"]
"tests/kernels/test_onednn.py" = ["SIM108"]
"tests/kernels/utils.py" = ["SIM108"]
"tests/multimodal/test_processing.py" = ["SIM108"]
"vllm/attention/ops/triton_reshape_and_cache_flash.py" = ["SIM108"]
"vllm/distributed/parallel_state.py" = ["SIM108"]
"vllm/entrypoints/chat_utils.py" = ["SIM108"]
"vllm/entrypoints/llm.py" = ["SIM108"]
"vllm/model_executor/layers/batch_invariant.py" = ["SIM108"]
"vllm/model_executor/layers/fla/ops/chunk_o.py" = ["SIM108"]
"vllm/model_executor/layers/fused_moe/fused_moe.py" = ["SIM108"]
"vllm/model_executor/layers/fused_moe/layer.py" = ["SIM108"]
"vllm/model_executor/layers/fused_moe/modular_kernel.py" = ["SIM108"]
"vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py" = ["SIM108"]
"vllm/model_executor/layers/layernorm.py" = ["SIM108"]
"vllm/model_executor/layers/lightning_attn.py" = ["SIM108"]
"vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py" = ["SIM103"]
"vllm/model_executor/layers/quantization/compressed_tensors/utils.py" = ["SIM110"]
"vllm/model_executor/layers/quantization/quark/utils.py" = ["SIM110"]
"vllm/utils/__init__.py" = ["SIM108"]
"vllm/v1/sample/ops/bad_words.py" = ["SIM108"]
"vllm/v1/sample/rejection_sampler.py" = ["SIM108"]
"vllm/v1/worker/tpu_model_runner.py" = ["SIM108"]
"vllm/_custom_ops.py" = ["SIM108"]
"tools/profiler/print_layerwise_table.py" = ["SIM118"]
## Loop variable binding issues
"tests/kernels/mamba/test_mamba_ssm_ssd.py" = ["B023"]
## Type annotation modernization and other rules
"vllm/attention/backends/abstract.py" = ["UP035", "UP006"]
"vllm/attention/layer.py" = ["UP035", "UP006"]
"vllm/attention/layers/chunked_local_attention.py" = ["UP035", "UP006"]
"vllm/attention/ops/flashmla.py" = ["UP035", "UP006"]
"vllm/attention/ops/paged_attn.py" = ["UP035", "UP006"]
"vllm/engine/arg_utils.py" = ["UP035", "UP006"]
"vllm/engine/metrics.py" = ["UP035", "UP006"]
"vllm/engine/metrics_types.py" = ["UP035", "UP006"]
"vllm/executor/executor_base.py" = ["UP035", "UP006"]
"vllm/executor/msgspec_utils.py" = ["UP035", "UP006"]
"vllm/executor/ray_distributed_executor.py" = ["UP035", "UP006", "SIM108", "SIM112"]
"vllm/executor/ray_utils.py" = ["UP035", "UP006"]
"vllm/executor/uniproc_executor.py" = ["UP035", "UP006"]
"vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py" = ["UP035"]
## Type comparison issues
"vllm/multimodal/inputs.py" = ["E721"]
# End of temporary ignores

[tool.ruff.lint]
select = [
@@ -87,7 +166,7 @@ select = [
    # flake8-simplify
    "SIM",
    # isort
    # "I",
    "I",
    # flake8-logging-format
    "G",
]
@@ -104,21 +183,15 @@ ignore = [
    "UP007",
]

[tool.ruff.format]
docstring-code-format = true

[tool.mypy]
plugins = ['pydantic.mypy']
ignore_missing_imports = true
check_untyped_defs = true
follow_imports = "silent"

[tool.isort]
skip_glob = [
    ".buildkite/*",
    "benchmarks/*",
    "examples/*",
]
use_parentheses = true
skip_gitignore = true

[tool.pytest.ini_options]
markers = [
    "slow_test",
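Most entries in the `TEMPORARY!` ignore block above exist so that enabling ruff did not force unrelated rewrites in this commit. `SIM108`, the most common code listed, is the flake8-simplify check that collapses an if/else assignment into a ternary; a small self-contained sketch (hypothetical code) of what it asks for:

```python
use_fp16 = True

# Pattern flagged by SIM108 in the files listed above (currently ignored there):
if use_fp16:
    dtype = "float16"
else:
    dtype = "float32"

# The form SIM108 would rewrite it to:
dtype = "float16" if use_fp16 else "float32"
print(dtype)
```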
setup.py
@ -34,32 +34,36 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# cannot import envs directly because it depends on vllm,
|
||||
# which is not installed yet
|
||||
envs = load_module_from_path('envs', os.path.join(ROOT_DIR, 'vllm', 'envs.py'))
|
||||
envs = load_module_from_path("envs", os.path.join(ROOT_DIR, "vllm", "envs.py"))
|
||||
|
||||
VLLM_TARGET_DEVICE = envs.VLLM_TARGET_DEVICE
|
||||
|
||||
if sys.platform.startswith("darwin") and VLLM_TARGET_DEVICE != "cpu":
|
||||
logger.warning(
|
||||
"VLLM_TARGET_DEVICE automatically set to `cpu` due to macOS")
|
||||
logger.warning("VLLM_TARGET_DEVICE automatically set to `cpu` due to macOS")
|
||||
VLLM_TARGET_DEVICE = "cpu"
|
||||
elif not (sys.platform.startswith("linux")
|
||||
or sys.platform.startswith("darwin")):
|
||||
elif not (sys.platform.startswith("linux") or sys.platform.startswith("darwin")):
|
||||
logger.warning(
|
||||
"vLLM only supports Linux platform (including WSL) and MacOS."
|
||||
"Building on %s, "
|
||||
"so vLLM may not be able to run correctly", sys.platform)
|
||||
"so vLLM may not be able to run correctly",
|
||||
sys.platform,
|
||||
)
|
||||
VLLM_TARGET_DEVICE = "empty"
|
||||
elif (sys.platform.startswith("linux") and torch.version.cuda is None
|
||||
and os.getenv("VLLM_TARGET_DEVICE") is None
|
||||
and torch.version.hip is None):
|
||||
elif (
|
||||
sys.platform.startswith("linux")
|
||||
and torch.version.cuda is None
|
||||
and os.getenv("VLLM_TARGET_DEVICE") is None
|
||||
and torch.version.hip is None
|
||||
):
|
||||
# if cuda or hip is not available and VLLM_TARGET_DEVICE is not set,
|
||||
# fallback to cpu
|
||||
VLLM_TARGET_DEVICE = "cpu"
|
||||
|
||||
|
||||
def is_sccache_available() -> bool:
|
||||
return which("sccache") is not None and \
|
||||
not bool(int(os.getenv("VLLM_DISABLE_SCCACHE", "0")))
|
||||
return which("sccache") is not None and not bool(
|
||||
int(os.getenv("VLLM_DISABLE_SCCACHE", "0"))
|
||||
)
|
||||
|
||||
|
||||
def is_ccache_available() -> bool:
|
||||
@ -83,8 +87,7 @@ def is_url_available(url: str) -> bool:
|
||||
|
||||
|
||||
class CMakeExtension(Extension):
|
||||
|
||||
def __init__(self, name: str, cmake_lists_dir: str = '.', **kwa) -> None:
|
||||
def __init__(self, name: str, cmake_lists_dir: str = ".", **kwa) -> None:
|
||||
super().__init__(name, sources=[], py_limited_api=True, **kwa)
|
||||
self.cmake_lists_dir = os.path.abspath(cmake_lists_dir)
|
||||
|
||||
@ -121,8 +124,8 @@ class cmake_build_ext(build_ext):
|
||||
if nvcc_threads is not None:
|
||||
nvcc_threads = int(nvcc_threads)
|
||||
logger.info(
|
||||
"Using NVCC_THREADS=%d as the number of nvcc threads.",
|
||||
nvcc_threads)
|
||||
"Using NVCC_THREADS=%d as the number of nvcc threads.", nvcc_threads
|
||||
)
|
||||
else:
|
||||
nvcc_threads = 1
|
||||
num_jobs = max(1, num_jobs // nvcc_threads)
|
||||
@ -146,36 +149,36 @@ class cmake_build_ext(build_ext):
|
||||
cfg = envs.CMAKE_BUILD_TYPE or default_cfg
|
||||
|
||||
cmake_args = [
|
||||
'-DCMAKE_BUILD_TYPE={}'.format(cfg),
|
||||
'-DVLLM_TARGET_DEVICE={}'.format(VLLM_TARGET_DEVICE),
|
||||
"-DCMAKE_BUILD_TYPE={}".format(cfg),
|
||||
"-DVLLM_TARGET_DEVICE={}".format(VLLM_TARGET_DEVICE),
|
||||
]
|
||||
|
||||
verbose = envs.VERBOSE
|
||||
if verbose:
|
||||
cmake_args += ['-DCMAKE_VERBOSE_MAKEFILE=ON']
|
||||
cmake_args += ["-DCMAKE_VERBOSE_MAKEFILE=ON"]
|
||||
|
||||
if is_sccache_available():
|
||||
cmake_args += [
|
||||
'-DCMAKE_C_COMPILER_LAUNCHER=sccache',
|
||||
'-DCMAKE_CXX_COMPILER_LAUNCHER=sccache',
|
||||
'-DCMAKE_CUDA_COMPILER_LAUNCHER=sccache',
|
||||
'-DCMAKE_HIP_COMPILER_LAUNCHER=sccache',
|
||||
"-DCMAKE_C_COMPILER_LAUNCHER=sccache",
|
||||
"-DCMAKE_CXX_COMPILER_LAUNCHER=sccache",
|
||||
"-DCMAKE_CUDA_COMPILER_LAUNCHER=sccache",
|
||||
"-DCMAKE_HIP_COMPILER_LAUNCHER=sccache",
|
||||
]
|
||||
elif is_ccache_available():
|
||||
cmake_args += [
|
||||
'-DCMAKE_C_COMPILER_LAUNCHER=ccache',
|
||||
'-DCMAKE_CXX_COMPILER_LAUNCHER=ccache',
|
||||
'-DCMAKE_CUDA_COMPILER_LAUNCHER=ccache',
|
||||
'-DCMAKE_HIP_COMPILER_LAUNCHER=ccache',
|
||||
"-DCMAKE_C_COMPILER_LAUNCHER=ccache",
|
||||
"-DCMAKE_CXX_COMPILER_LAUNCHER=ccache",
|
||||
"-DCMAKE_CUDA_COMPILER_LAUNCHER=ccache",
|
||||
"-DCMAKE_HIP_COMPILER_LAUNCHER=ccache",
|
||||
]
|
||||
|
||||
# Pass the python executable to cmake so it can find an exact
|
||||
# match.
|
||||
cmake_args += ['-DVLLM_PYTHON_EXECUTABLE={}'.format(sys.executable)]
|
||||
cmake_args += ["-DVLLM_PYTHON_EXECUTABLE={}".format(sys.executable)]
|
||||
|
||||
# Pass the python path to cmake so it can reuse the build dependencies
|
||||
# on subsequent calls to python.
|
||||
cmake_args += ['-DVLLM_PYTHON_PATH={}'.format(":".join(sys.path))]
|
||||
cmake_args += ["-DVLLM_PYTHON_PATH={}".format(":".join(sys.path))]
|
||||
|
||||
# Override the base directory for FetchContent downloads to $ROOT/.deps
|
||||
# This allows sharing dependencies between profiles,
|
||||
@ -183,7 +186,7 @@ class cmake_build_ext(build_ext):
|
||||
# To override this, set the FETCHCONTENT_BASE_DIR environment variable.
|
||||
fc_base_dir = os.path.join(ROOT_DIR, ".deps")
|
||||
fc_base_dir = os.environ.get("FETCHCONTENT_BASE_DIR", fc_base_dir)
|
||||
cmake_args += ['-DFETCHCONTENT_BASE_DIR={}'.format(fc_base_dir)]
|
||||
cmake_args += ["-DFETCHCONTENT_BASE_DIR={}".format(fc_base_dir)]
|
||||
|
||||
#
|
||||
# Setup parallelism and build tool
|
||||
@ -191,35 +194,36 @@ class cmake_build_ext(build_ext):
|
||||
num_jobs, nvcc_threads = self.compute_num_jobs()
|
||||
|
||||
if nvcc_threads:
|
||||
cmake_args += ['-DNVCC_THREADS={}'.format(nvcc_threads)]
|
||||
cmake_args += ["-DNVCC_THREADS={}".format(nvcc_threads)]
|
||||
|
||||
if is_ninja_available():
|
||||
build_tool = ['-G', 'Ninja']
|
||||
build_tool = ["-G", "Ninja"]
|
||||
cmake_args += [
|
||||
'-DCMAKE_JOB_POOL_COMPILE:STRING=compile',
|
||||
'-DCMAKE_JOB_POOLS:STRING=compile={}'.format(num_jobs),
|
||||
"-DCMAKE_JOB_POOL_COMPILE:STRING=compile",
|
||||
"-DCMAKE_JOB_POOLS:STRING=compile={}".format(num_jobs),
|
||||
]
|
||||
else:
|
||||
# Default build tool to whatever cmake picks.
|
||||
build_tool = []
|
||||
# Make sure we use the nvcc from CUDA_HOME
|
||||
if _is_cuda():
|
||||
cmake_args += [f'-DCMAKE_CUDA_COMPILER={CUDA_HOME}/bin/nvcc']
|
||||
cmake_args += [f"-DCMAKE_CUDA_COMPILER={CUDA_HOME}/bin/nvcc"]
|
||||
|
||||
other_cmake_args = os.environ.get("CMAKE_ARGS")
|
||||
if other_cmake_args:
|
||||
cmake_args += other_cmake_args.split()
|
||||
|
||||
subprocess.check_call(
|
||||
['cmake', ext.cmake_lists_dir, *build_tool, *cmake_args],
|
||||
cwd=self.build_temp)
|
||||
["cmake", ext.cmake_lists_dir, *build_tool, *cmake_args],
|
||||
cwd=self.build_temp,
|
||||
)
|
||||
|
||||
def build_extensions(self) -> None:
|
||||
# Ensure that CMake is present and working
|
||||
try:
|
||||
subprocess.check_output(['cmake', '--version'])
|
||||
subprocess.check_output(["cmake", "--version"])
|
||||
except OSError as e:
|
||||
raise RuntimeError('Cannot find CMake executable') from e
|
||||
raise RuntimeError("Cannot find CMake executable") from e
|
||||
|
||||
# Create build directory if it does not exist.
|
||||
if not os.path.exists(self.build_temp):
|
||||
@ -258,13 +262,18 @@ class cmake_build_ext(build_ext):
|
||||
# CMake appends the extension prefix to the install path,
|
||||
# and outdir already contains that prefix, so we need to remove it.
|
||||
prefix = outdir
|
||||
for _ in range(ext.name.count('.')):
|
||||
for _ in range(ext.name.count(".")):
|
||||
prefix = prefix.parent
|
||||
|
||||
# prefix here should actually be the same for all components
|
||||
install_args = [
|
||||
"cmake", "--install", ".", "--prefix", prefix, "--component",
|
||||
target_name(ext.name)
|
||||
"cmake",
|
||||
"--install",
|
||||
".",
|
||||
"--prefix",
|
||||
prefix,
|
||||
"--component",
|
||||
target_name(ext.name),
|
||||
]
|
||||
subprocess.check_call(install_args, cwd=self.build_temp)
|
||||
|
||||
@ -275,12 +284,15 @@ class cmake_build_ext(build_ext):
|
||||
# copy vllm/vllm_flash_attn/**/*.py from self.build_lib to current
|
||||
# directory so that they can be included in the editable build
|
||||
import glob
|
||||
files = glob.glob(os.path.join(self.build_lib, "vllm",
|
||||
"vllm_flash_attn", "**", "*.py"),
|
||||
recursive=True)
|
||||
|
||||
files = glob.glob(
|
||||
os.path.join(self.build_lib, "vllm", "vllm_flash_attn", "**", "*.py"),
|
||||
recursive=True,
|
||||
)
|
||||
for file in files:
|
||||
dst_file = os.path.join("vllm/vllm_flash_attn",
|
||||
file.split("vllm/vllm_flash_attn/")[-1])
|
||||
dst_file = os.path.join(
|
||||
"vllm/vllm_flash_attn", file.split("vllm/vllm_flash_attn/")[-1]
|
||||
)
|
||||
print(f"Copying {file} to {dst_file}")
|
||||
os.makedirs(os.path.dirname(dst_file), exist_ok=True)
|
||||
self.copy_file(file, dst_file)
|
||||
@ -290,8 +302,7 @@ class precompiled_build_ext(build_ext):
|
||||
"""Disables extension building when using precompiled binaries."""
|
||||
|
||||
def run(self) -> None:
|
||||
assert _is_cuda(
|
||||
), "VLLM_USE_PRECOMPILED is only supported for CUDA builds"
|
||||
assert _is_cuda(), "VLLM_USE_PRECOMPILED is only supported for CUDA builds"
|
||||
|
||||
def build_extensions(self) -> None:
|
||||
print("Skipping build_ext: using precompiled extensions.")
|
||||
@ -312,9 +323,9 @@ class precompiled_wheel_utils:
|
||||
wheel_filename = wheel_url_or_path.split("/")[-1]
|
||||
temp_dir = tempfile.mkdtemp(prefix="vllm-wheels")
|
||||
wheel_path = os.path.join(temp_dir, wheel_filename)
|
||||
print(f"Downloading wheel from {wheel_url_or_path} "
|
||||
f"to {wheel_path}")
|
||||
print(f"Downloading wheel from {wheel_url_or_path} to {wheel_path}")
|
||||
from urllib.request import urlretrieve
|
||||
|
||||
urlretrieve(wheel_url_or_path, filename=wheel_path)
|
||||
else:
|
||||
wheel_path = wheel_url_or_path
|
||||
@ -335,25 +346,29 @@ class precompiled_wheel_utils:
|
||||
]
|
||||
|
||||
compiled_regex = re.compile(
|
||||
r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py")
|
||||
r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py"
|
||||
)
|
||||
file_members = list(
|
||||
filter(lambda x: x.filename in files_to_copy,
|
||||
wheel.filelist))
|
||||
filter(lambda x: x.filename in files_to_copy, wheel.filelist)
|
||||
)
|
||||
file_members += list(
|
||||
filter(lambda x: compiled_regex.match(x.filename),
|
||||
wheel.filelist))
|
||||
filter(lambda x: compiled_regex.match(x.filename), wheel.filelist)
|
||||
)
|
||||
|
||||
for file in file_members:
|
||||
print(f"[extract] {file.filename}")
|
||||
target_path = os.path.join(".", file.filename)
|
||||
os.makedirs(os.path.dirname(target_path), exist_ok=True)
|
||||
with wheel.open(file.filename) as src, open(
|
||||
target_path, "wb") as dst:
|
||||
with (
|
||||
wheel.open(file.filename) as src,
|
||||
open(target_path, "wb") as dst,
|
||||
):
|
||||
shutil.copyfileobj(src, dst)
|
||||
|
||||
pkg = os.path.dirname(file.filename).replace("/", ".")
|
||||
package_data_patch.setdefault(pkg, []).append(
|
||||
os.path.basename(file.filename))
|
||||
os.path.basename(file.filename)
|
||||
)
|
||||
|
||||
return package_data_patch
|
||||
finally:
|
||||
@ -369,10 +384,13 @@ class precompiled_wheel_utils:
|
||||
|
||||
try:
|
||||
# Get the latest commit hash of the upstream main branch.
|
||||
resp_json = subprocess.check_output([
|
||||
"curl", "-s",
|
||||
"https://api.github.com/repos/vllm-project/vllm/commits/main"
|
||||
]).decode("utf-8")
|
||||
resp_json = subprocess.check_output(
|
||||
[
|
||||
"curl",
|
||||
"-s",
|
||||
"https://api.github.com/repos/vllm-project/vllm/commits/main",
|
||||
]
|
||||
).decode("utf-8")
|
||||
upstream_main_commit = json.loads(resp_json)["sha"]
|
||||
|
||||
# In Docker build context, .git may be immutable or missing.
|
||||
@ -382,25 +400,32 @@ class precompiled_wheel_utils:
|
||||
# Check if the upstream_main_commit exists in the local repo
|
||||
try:
|
||||
subprocess.check_output(
|
||||
["git", "cat-file", "-e", f"{upstream_main_commit}"])
|
||||
["git", "cat-file", "-e", f"{upstream_main_commit}"]
|
||||
)
|
||||
except subprocess.CalledProcessError:
|
||||
# If not present, fetch it from the remote repository.
|
||||
# Note that this does not update any local branches,
|
||||
# but ensures that this commit ref and its history are
|
||||
# available in our local repo.
|
||||
subprocess.check_call([
|
||||
"git", "fetch", "https://github.com/vllm-project/vllm",
|
||||
"main"
|
||||
])
|
||||
subprocess.check_call(
|
||||
["git", "fetch", "https://github.com/vllm-project/vllm", "main"]
|
||||
)
|
||||
|
||||
# Then get the commit hash of the current branch that is the same as
|
||||
# the upstream main commit.
|
||||
current_branch = subprocess.check_output(
|
||||
["git", "branch", "--show-current"]).decode("utf-8").strip()
|
||||
current_branch = (
|
||||
subprocess.check_output(["git", "branch", "--show-current"])
|
||||
.decode("utf-8")
|
||||
.strip()
|
||||
)
|
||||
|
||||
base_commit = subprocess.check_output([
|
||||
"git", "merge-base", f"{upstream_main_commit}", current_branch
|
||||
]).decode("utf-8").strip()
|
||||
base_commit = (
|
||||
subprocess.check_output(
|
||||
["git", "merge-base", f"{upstream_main_commit}", current_branch]
|
||||
)
|
||||
.decode("utf-8")
|
||||
.strip()
|
||||
)
|
||||
return base_commit
|
||||
except ValueError as err:
|
||||
raise ValueError(err) from None
|
||||
@ -408,7 +433,9 @@ class precompiled_wheel_utils:
|
||||
logger.warning(
|
||||
"Failed to get the base commit in the main branch. "
|
||||
"Using the nightly wheel. The libraries in this "
|
||||
"wheel may not be compatible with your dev branch: %s", err)
|
||||
"wheel may not be compatible with your dev branch: %s",
|
||||
err,
|
||||
)
|
||||
return "nightly"
|
||||
|
||||
|
||||
@ -418,12 +445,13 @@ def _no_device() -> bool:
|
||||
|
||||
def _is_cuda() -> bool:
|
||||
has_cuda = torch.version.cuda is not None
|
||||
return (VLLM_TARGET_DEVICE == "cuda" and has_cuda and not _is_tpu())
|
||||
return VLLM_TARGET_DEVICE == "cuda" and has_cuda and not _is_tpu()
|
||||
|
||||
|
||||
def _is_hip() -> bool:
|
||||
return (VLLM_TARGET_DEVICE == "cuda"
|
||||
or VLLM_TARGET_DEVICE == "rocm") and torch.version.hip is not None
|
||||
return (
|
||||
VLLM_TARGET_DEVICE == "cuda" or VLLM_TARGET_DEVICE == "rocm"
|
||||
) and torch.version.hip is not None
|
||||
|
||||
|
||||
def _is_tpu() -> bool:
|
||||
@ -462,8 +490,12 @@ def get_rocm_version():
|
||||
minor = ctypes.c_uint32()
|
||||
patch = ctypes.c_uint32()
|
||||
|
||||
if (get_rocm_core_version(ctypes.byref(major), ctypes.byref(minor),
|
||||
ctypes.byref(patch)) == 0):
|
||||
if (
|
||||
get_rocm_core_version(
|
||||
ctypes.byref(major), ctypes.byref(minor), ctypes.byref(patch)
|
||||
)
|
||||
== 0
|
||||
):
|
||||
return f"{major.value}.{minor.value}.{patch.value}"
|
||||
return None
|
||||
except Exception:
|
||||
@ -476,8 +508,9 @@ def get_nvcc_cuda_version() -> Version:
|
||||
Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
|
||||
"""
|
||||
assert CUDA_HOME is not None, "CUDA_HOME is not set"
|
||||
nvcc_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"],
|
||||
universal_newlines=True)
|
||||
nvcc_output = subprocess.check_output(
|
||||
[CUDA_HOME + "/bin/nvcc", "-V"], universal_newlines=True
|
||||
)
|
||||
output = nvcc_output.split()
|
||||
release_idx = output.index("release") + 1
|
||||
nvcc_cuda_version = parse(output[release_idx].split(",")[0])
|
||||
@ -489,14 +522,20 @@ def get_gaudi_sw_version():
|
||||
Returns the driver version.
|
||||
"""
|
||||
# Enable console printing for `hl-smi` check
|
||||
output = subprocess.run("hl-smi",
|
||||
shell=True,
|
||||
text=True,
|
||||
capture_output=True,
|
||||
env={"ENABLE_CONSOLE": "true"})
|
||||
output = subprocess.run(
|
||||
"hl-smi",
|
||||
shell=True,
|
||||
text=True,
|
||||
capture_output=True,
|
||||
env={"ENABLE_CONSOLE": "true"},
|
||||
)
|
||||
if output.returncode == 0 and output.stdout:
|
||||
return output.stdout.split("\n")[2].replace(
|
||||
" ", "").split(":")[1][:-1].split("-")[0]
|
||||
return (
|
||||
output.stdout.split("\n")[2]
|
||||
.replace(" ", "")
|
||||
.split(":")[1][:-1]
|
||||
.split("-")[0]
|
||||
)
|
||||
return "0.0.0" # when hl-smi is not available
|
||||
|
||||
|
||||
@ -546,8 +585,11 @@ def get_requirements() -> list[str]:
|
||||
for line in requirements:
|
||||
if line.startswith("-r "):
|
||||
resolved_requirements += _read_requirements(line.split()[1])
|
||||
elif not line.startswith("--") and not line.startswith(
|
||||
"#") and line.strip() != "":
|
||||
elif (
|
||||
not line.startswith("--")
|
||||
and not line.startswith("#")
|
||||
and line.strip() != ""
|
||||
):
|
||||
resolved_requirements.append(line)
|
||||
return resolved_requirements
|
||||
|
||||
@ -558,7 +600,7 @@ def get_requirements() -> list[str]:
|
||||
cuda_major, cuda_minor = torch.version.cuda.split(".")
|
||||
modified_requirements = []
|
||||
for req in requirements:
|
||||
if ("vllm-flash-attn" in req and cuda_major != "12"):
|
||||
if "vllm-flash-attn" in req and cuda_major != "12":
|
||||
# vllm-flash-attn is built only for CUDA 12.x.
|
||||
# Skip for other versions.
|
||||
continue
|
||||
@ -573,8 +615,7 @@ def get_requirements() -> list[str]:
|
||||
elif _is_xpu():
|
||||
requirements = _read_requirements("xpu.txt")
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported platform, please use CUDA, ROCm, or CPU.")
|
||||
raise ValueError("Unsupported platform, please use CUDA, ROCm, or CPU.")
|
||||
return requirements
|
||||
|
||||
|
||||
@ -590,14 +631,13 @@ if _is_cuda():
|
||||
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C"))
|
||||
if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.3"):
|
||||
# FA3 requires CUDA 12.3 or later
|
||||
ext_modules.append(
|
||||
CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
|
||||
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
|
||||
# Optional since this doesn't get built (produce an .so file) when
|
||||
# not targeting a hopper system
|
||||
ext_modules.append(CMakeExtension(name="vllm._flashmla_C", optional=True))
|
||||
ext_modules.append(
|
||||
CMakeExtension(name="vllm._flashmla_C", optional=True))
|
||||
ext_modules.append(
|
||||
CMakeExtension(name="vllm._flashmla_extension_C", optional=True))
|
||||
CMakeExtension(name="vllm._flashmla_extension_C", optional=True)
|
||||
)
|
||||
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
|
||||
|
||||
if _build_custom_ops():
|
||||
@ -619,6 +659,7 @@ if envs.VLLM_USE_PRECOMPILED:
|
||||
wheel_url = wheel_location
|
||||
else:
|
||||
import platform
|
||||
|
||||
arch = platform.machine()
|
||||
if arch == "x86_64":
|
||||
wheel_tag = "manylinux1_x86_64"
|
||||
@ -628,8 +669,11 @@ if envs.VLLM_USE_PRECOMPILED:
|
||||
raise ValueError(f"Unsupported architecture: {arch}")
|
||||
base_commit = precompiled_wheel_utils.get_base_commit_in_main_branch()
|
||||
wheel_url = f"https://wheels.vllm.ai/{base_commit}/vllm-1.0.0.dev-cp38-abi3-{wheel_tag}.whl"
|
||||
nightly_wheel_url = f"https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-{wheel_tag}.whl"
|
||||
nightly_wheel_url = (
|
||||
f"https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-{wheel_tag}.whl"
|
||||
)
|
||||
from urllib.request import urlopen
|
||||
|
||||
try:
|
||||
with urlopen(wheel_url) as resp:
|
||||
if resp.status != 200:
|
||||
@ -638,8 +682,7 @@ if envs.VLLM_USE_PRECOMPILED:
|
||||
print(f"[warn] Falling back to nightly wheel: {e}")
|
||||
wheel_url = nightly_wheel_url
|
||||
|
||||
patch = precompiled_wheel_utils.extract_precompiled_and_patch_package(
|
||||
wheel_url)
|
||||
patch = precompiled_wheel_utils.extract_precompiled_and_patch_package(wheel_url)
|
||||
for pkg, files in patch.items():
|
||||
package_data.setdefault(pkg, []).extend(files)
|
||||
|
||||
@ -650,8 +693,9 @@ if not ext_modules:
|
||||
cmdclass = {}
|
||||
else:
|
||||
cmdclass = {
|
||||
"build_ext":
|
||||
precompiled_build_ext if envs.VLLM_USE_PRECOMPILED else cmake_build_ext
|
||||
"build_ext": precompiled_build_ext
|
||||
if envs.VLLM_USE_PRECOMPILED
|
||||
else cmake_build_ext
|
||||
}
|
||||
|
||||
setup(
|
||||
@ -664,8 +708,11 @@ setup(
|
||||
"tensorizer": ["tensorizer==2.10.1"],
|
||||
"fastsafetensors": ["fastsafetensors >= 0.1.10"],
|
||||
"runai": ["runai-model-streamer[s3,gcs] >= 0.14.0"],
|
||||
"audio": ["librosa", "soundfile",
|
||||
"mistral_common[audio]"], # Required for audio processing
|
||||
"audio": [
|
||||
"librosa",
|
||||
"soundfile",
|
||||
"mistral_common[audio]",
|
||||
], # Required for audio processing
|
||||
"video": [], # Kept for backwards compatibility
|
||||
# FlashInfer should be updated together with the Dockerfile
|
||||
"flashinfer": ["flashinfer-python==0.3.1"],
|
||||
|
||||
@ -4,6 +4,7 @@
|
||||
|
||||
Run `pytest tests/basic_correctness/test_basic_correctness.py`.
|
||||
"""
|
||||
|
||||
import os
|
||||
import weakref
|
||||
from unittest.mock import Mock
|
||||
@ -37,16 +38,21 @@ def test_vllm_gc_ed():
|
||||
|
||||
|
||||
def _fix_prompt_embed_outputs(
|
||||
vllm_outputs: list[tuple[list[int], str]], hf_model: HfRunner,
|
||||
example_prompts: list[str]) -> list[tuple[list[int], str]]:
|
||||
vllm_outputs: list[tuple[list[int], str]],
|
||||
hf_model: HfRunner,
|
||||
example_prompts: list[str],
|
||||
) -> list[tuple[list[int], str]]:
|
||||
fixed_vllm_outputs = []
|
||||
for vllm_output, hf_input, prompt in zip(
|
||||
vllm_outputs, hf_model.get_inputs(example_prompts),
|
||||
example_prompts):
|
||||
vllm_outputs, hf_model.get_inputs(example_prompts), example_prompts
|
||||
):
|
||||
hf_input_ids = hf_input["input_ids"].tolist()[0]
|
||||
fixed_vllm_outputs.append(
|
||||
(hf_input_ids + vllm_output[0][len(hf_input_ids):],
|
||||
prompt + vllm_output[1]))
|
||||
(
|
||||
hf_input_ids + vllm_output[0][len(hf_input_ids) :],
|
||||
prompt + vllm_output[1],
|
||||
)
|
||||
)
|
||||
return fixed_vllm_outputs
|
||||
|
||||
|
||||
@ -69,8 +75,7 @@ def test_models(
|
||||
enable_prompt_embeds: bool,
|
||||
) -> None:
|
||||
if backend == "XFORMERS" and model == "google/gemma-2-2b-it":
|
||||
pytest.skip(
|
||||
f"{backend} does not support gemma2 with full context length.")
|
||||
pytest.skip(f"{backend} does not support gemma2 with full context length.")
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", backend)
|
||||
@ -78,34 +83,35 @@ def test_models(
|
||||
# 5042 tokens for gemma2
|
||||
# gemma2 has alternating sliding window size of 4096
|
||||
# we need a prompt with more than 4096 tokens to test the sliding window
|
||||
prompt = "The following numbers of the sequence " + ", ".join(
|
||||
str(i) for i in range(1024)) + " are:"
|
||||
prompt = (
|
||||
"The following numbers of the sequence "
|
||||
+ ", ".join(str(i) for i in range(1024))
|
||||
+ " are:"
|
||||
)
|
||||
example_prompts = [prompt]
|
||||
|
||||
with hf_runner(model) as hf_model:
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
if enable_prompt_embeds:
|
||||
with torch.no_grad():
|
||||
prompt_embeds = hf_model.get_prompt_embeddings(
|
||||
example_prompts)
|
||||
prompt_embeds = hf_model.get_prompt_embeddings(example_prompts)
|
||||
|
||||
with VllmRunner(
|
||||
model,
|
||||
max_model_len=8192,
|
||||
enforce_eager=enforce_eager,
|
||||
enable_prompt_embeds=enable_prompt_embeds,
|
||||
gpu_memory_utilization=0.7,
|
||||
async_scheduling=async_scheduling,
|
||||
distributed_executor_backend=model_executor,
|
||||
model,
|
||||
max_model_len=8192,
|
||||
enforce_eager=enforce_eager,
|
||||
enable_prompt_embeds=enable_prompt_embeds,
|
||||
gpu_memory_utilization=0.7,
|
||||
async_scheduling=async_scheduling,
|
||||
distributed_executor_backend=model_executor,
|
||||
) as vllm_model:
|
||||
if enable_prompt_embeds:
|
||||
vllm_outputs = vllm_model.generate_greedy(
|
||||
prompt_embeds, max_tokens)
|
||||
vllm_outputs = vllm_model.generate_greedy(prompt_embeds, max_tokens)
|
||||
vllm_outputs = _fix_prompt_embed_outputs(
|
||||
vllm_outputs, hf_model, example_prompts)
|
||||
vllm_outputs, hf_model, example_prompts
|
||||
)
|
||||
else:
|
||||
vllm_outputs = vllm_model.generate_greedy(
|
||||
example_prompts, max_tokens)
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
|
||||
check_outputs_equal(
|
||||
outputs_0_lst=hf_outputs,
|
||||
@ -117,21 +123,18 @@ def test_models(
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize(
|
||||
"model, distributed_executor_backend, attention_backend, "
|
||||
"test_suite, extra_env", [
|
||||
"model, distributed_executor_backend, attention_backend, test_suite, extra_env",
|
||||
[
|
||||
("distilbert/distilgpt2", "ray", "", "L4", {}),
|
||||
("distilbert/distilgpt2", "mp", "", "L4", {}),
|
||||
("distilbert/distilgpt2", "ray", "", "L4", {
|
||||
"VLLM_SLEEP_WHEN_IDLE": "1"
|
||||
}),
|
||||
("distilbert/distilgpt2", "mp", "", "L4", {
|
||||
"VLLM_SLEEP_WHEN_IDLE": "1"
|
||||
}),
|
||||
("distilbert/distilgpt2", "ray", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}),
|
||||
("distilbert/distilgpt2", "mp", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}),
|
||||
("meta-llama/Llama-3.2-1B-Instruct", "ray", "", "L4", {}),
|
||||
("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4", {}),
|
||||
("distilbert/distilgpt2", "ray", "", "A100", {}),
|
||||
("distilbert/distilgpt2", "mp", "", "A100", {}),
|
||||
])
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
|
||||
def test_models_distributed(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
@ -149,11 +152,14 @@ def test_models_distributed(
|
||||
pytest.skip(f"Skip test for {test_suite}")
|
||||
|
||||
with monkeypatch.context() as monkeypatch_context:
|
||||
if model == "meta-llama/Llama-3.2-1B-Instruct" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4": # noqa
|
||||
if (
|
||||
model == "meta-llama/Llama-3.2-1B-Instruct"
|
||||
and distributed_executor_backend == "ray"
|
||||
and attention_backend == ""
|
||||
and test_suite == "L4"
|
||||
): # noqa
|
||||
if enable_prompt_embeds:
|
||||
pytest.skip(
|
||||
"enable_prompt_embeds does not work with ray compiled dag."
|
||||
)
|
||||
pytest.skip("enable_prompt_embeds does not work with ray compiled dag.")
|
||||
monkeypatch_context.setenv("VLLM_USE_RAY_SPMD_WORKER", "1")
|
||||
monkeypatch_context.setenv("VLLM_USE_RAY_COMPILED_DAG", "1")
|
||||
|
||||
@ -175,30 +181,26 @@ def test_models_distributed(
|
||||
# will hurt multiprocessing backend with fork method
|
||||
# (the default method).
|
||||
with vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
tensor_parallel_size=2,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
enable_prompt_embeds=enable_prompt_embeds,
|
||||
gpu_memory_utilization=0.7,
|
||||
model,
|
||||
dtype=dtype,
|
||||
tensor_parallel_size=2,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
enable_prompt_embeds=enable_prompt_embeds,
|
||||
gpu_memory_utilization=0.7,
|
||||
) as vllm_model:
|
||||
if enable_prompt_embeds:
|
||||
with hf_runner(model, dtype=dtype) as hf_model:
|
||||
with torch.no_grad():
|
||||
prompt_embeds = hf_model.get_prompt_embeddings(
|
||||
example_prompts)
|
||||
vllm_outputs = vllm_model.generate_greedy(
|
||||
prompt_embeds, max_tokens)
|
||||
prompt_embeds = hf_model.get_prompt_embeddings(example_prompts)
|
||||
vllm_outputs = vllm_model.generate_greedy(prompt_embeds, max_tokens)
|
||||
vllm_outputs = _fix_prompt_embed_outputs(
|
||||
vllm_outputs, hf_model, example_prompts)
|
||||
hf_outputs = hf_model.generate_greedy(
|
||||
example_prompts, max_tokens)
|
||||
vllm_outputs, hf_model, example_prompts
|
||||
)
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
else:
|
||||
vllm_outputs = vllm_model.generate_greedy(
|
||||
example_prompts, max_tokens)
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
with hf_runner(model, dtype=dtype) as hf_model:
|
||||
hf_outputs = hf_model.generate_greedy(
|
||||
example_prompts, max_tokens)
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
|
||||
check_outputs_equal(
|
||||
outputs_0_lst=hf_outputs,
|
||||
@ -209,27 +211,23 @@ def test_models_distributed(
|
||||
|
||||
|
||||
def test_failed_model_execution(vllm_runner, monkeypatch) -> None:
|
||||
|
||||
from vllm.envs import VLLM_USE_V1
|
||||
|
||||
if not VLLM_USE_V1:
|
||||
pytest.skip("Skipping V0 test, dump input not supported")
|
||||
|
||||
# Needed to mock an error in the same process
|
||||
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
|
||||
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||
|
||||
with vllm_runner('facebook/opt-125m', enforce_eager=True) as vllm_model:
|
||||
with vllm_runner("facebook/opt-125m", enforce_eager=True) as vllm_model:
|
||||
if isinstance(vllm_model.llm.llm_engine, LLMEngineV1):
|
||||
v1_test_failed_model_execution(vllm_model)
|
||||
|
||||
|
||||
def v1_test_failed_model_execution(vllm_model):
|
||||
|
||||
engine = vllm_model.llm.llm_engine
|
||||
mocked_execute_model = Mock(
|
||||
side_effect=RuntimeError("Mocked Critical Error"))
|
||||
engine.engine_core.engine_core.model_executor.execute_model =\
|
||||
mocked_execute_model
|
||||
mocked_execute_model = Mock(side_effect=RuntimeError("Mocked Critical Error"))
|
||||
engine.engine_core.engine_core.model_executor.execute_model = mocked_execute_model
|
||||
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
prompts = [
|
||||
|
||||
@ -5,5 +5,6 @@ from ..utils import compare_two_settings
|
||||
|
||||
|
||||
def test_cpu_offload():
|
||||
compare_two_settings("meta-llama/Llama-3.2-1B-Instruct", [],
|
||||
["--cpu-offload-gb", "1"])
|
||||
compare_two_settings(
|
||||
"meta-llama/Llama-3.2-1B-Instruct", [], ["--cpu-offload-gb", "1"]
|
||||
)
|
||||
|
||||
@ -23,13 +23,13 @@ def test_python_error():
|
||||
tensors = []
|
||||
with allocator.use_memory_pool():
|
||||
# allocate 70% of the total memory
|
||||
x = torch.empty(alloc_bytes, dtype=torch.uint8, device='cuda')
|
||||
x = torch.empty(alloc_bytes, dtype=torch.uint8, device="cuda")
|
||||
tensors.append(x)
|
||||
# release the memory
|
||||
allocator.sleep()
|
||||
|
||||
# allocate more memory than the total memory
|
||||
y = torch.empty(alloc_bytes, dtype=torch.uint8, device='cuda')
|
||||
y = torch.empty(alloc_bytes, dtype=torch.uint8, device="cuda")
|
||||
tensors.append(y)
|
||||
with pytest.raises(RuntimeError):
|
||||
# when the allocator is woken up, it should raise an error
|
||||
@ -41,17 +41,17 @@ def test_python_error():
|
||||
def test_basic_cumem():
|
||||
# some tensors from default memory pool
|
||||
shape = (1024, 1024)
|
||||
x = torch.empty(shape, device='cuda')
|
||||
x = torch.empty(shape, device="cuda")
|
||||
x.zero_()
|
||||
|
||||
# some tensors from custom memory pool
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
with allocator.use_memory_pool():
|
||||
# custom memory pool
|
||||
y = torch.empty(shape, device='cuda')
|
||||
y = torch.empty(shape, device="cuda")
|
||||
y.zero_()
|
||||
y += 1
|
||||
z = torch.empty(shape, device='cuda')
|
||||
z = torch.empty(shape, device="cuda")
|
||||
z.zero_()
|
||||
z += 2
|
||||
|
||||
@ -74,16 +74,16 @@ def test_basic_cumem():
|
||||
def test_cumem_with_cudagraph():
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
with allocator.use_memory_pool():
|
||||
weight = torch.eye(1024, device='cuda')
|
||||
weight = torch.eye(1024, device="cuda")
|
||||
with allocator.use_memory_pool(tag="discard"):
|
||||
cache = torch.empty(1024, 1024, device='cuda')
|
||||
cache = torch.empty(1024, 1024, device="cuda")
|
||||
|
||||
def model(x):
|
||||
out = x @ weight
|
||||
cache[:out.size(0)].copy_(out)
|
||||
cache[: out.size(0)].copy_(out)
|
||||
return out + 1
|
||||
|
||||
x = torch.empty(128, 1024, device='cuda')
|
||||
x = torch.empty(128, 1024, device="cuda")
|
||||
|
||||
# warmup
|
||||
model(x)
|
||||
@ -109,7 +109,7 @@ def test_cumem_with_cudagraph():
|
||||
model_graph.replay()
|
||||
|
||||
# cache content is as expected
|
||||
assert torch.allclose(x, cache[:x.size(0)])
|
||||
assert torch.allclose(x, cache[: x.size(0)])
|
||||
|
||||
# output content is as expected
|
||||
assert torch.allclose(y, x + 1)
|
||||
@ -123,7 +123,8 @@ def test_cumem_with_cudagraph():
|
||||
("meta-llama/Llama-3.2-1B", True),
|
||||
# sleep mode with pytorch checkpoint
|
||||
("facebook/opt-125m", True),
|
||||
])
|
||||
],
|
||||
)
|
||||
def test_end_to_end(monkeypatch: pytest.MonkeyPatch, model: str, use_v1: bool):
|
||||
with monkeypatch.context() as m:
|
||||
assert use_v1
|
||||
|
||||
@ -10,8 +10,18 @@ MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
@pytest.mark.benchmark
|
||||
def test_bench_latency():
|
||||
command = [
|
||||
"vllm", "bench", "latency", "--model", MODEL_NAME, "--input-len", "32",
|
||||
"--output-len", "1", "--enforce-eager", "--load-format", "dummy"
|
||||
"vllm",
|
||||
"bench",
|
||||
"latency",
|
||||
"--model",
|
||||
MODEL_NAME,
|
||||
"--input-len",
|
||||
"32",
|
||||
"--output-len",
|
||||
"1",
|
||||
"--enforce-eager",
|
||||
"--load-format",
|
||||
"dummy",
|
||||
]
|
||||
result = subprocess.run(command, capture_output=True, text=True)
|
||||
print(result.stdout)
|
||||
|
||||
@ -7,8 +7,11 @@ import numpy as np
|
||||
import pytest
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
||||
|
||||
from vllm.benchmarks.datasets import (RandomDataset, RandomMultiModalDataset,
|
||||
SampleRequest)
|
||||
from vllm.benchmarks.datasets import (
|
||||
RandomDataset,
|
||||
RandomMultiModalDataset,
|
||||
SampleRequest,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@ -27,11 +30,9 @@ class Params(NamedTuple):
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def random_dataset_params() -> Params:
|
||||
return Params(num_requests=16,
|
||||
prefix_len=7,
|
||||
range_ratio=0.3,
|
||||
input_len=50,
|
||||
output_len=20)
|
||||
return Params(
|
||||
num_requests=16, prefix_len=7, range_ratio=0.3, input_len=50, output_len=20
|
||||
)
|
||||
|
||||
|
||||
def _fingerprint_sample(req: SampleRequest) -> tuple[str, int, int]:
|
||||
@ -39,13 +40,15 @@ def _fingerprint_sample(req: SampleRequest) -> tuple[str, int, int]:
|
||||
return (req.prompt, req.prompt_len, req.expected_output_len)
|
||||
|
||||
|
||||
def _collect_samples(dataset: RandomDataset,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int = 16,
|
||||
prefix_len: int = 7,
|
||||
range_ratio: float = 0.3,
|
||||
input_len: int = 50,
|
||||
output_len: int = 20) -> list[tuple[str, int, int]]:
|
||||
def _collect_samples(
|
||||
dataset: RandomDataset,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int = 16,
|
||||
prefix_len: int = 7,
|
||||
range_ratio: float = 0.3,
|
||||
input_len: int = 50,
|
||||
output_len: int = 20,
|
||||
) -> list[tuple[str, int, int]]:
|
||||
samples = dataset.sample(
|
||||
tokenizer=tokenizer,
|
||||
num_requests=num_requests,
|
||||
@ -59,8 +62,8 @@ def _collect_samples(dataset: RandomDataset,
|
||||
|
||||
@pytest.mark.benchmark
|
||||
def test_random_dataset_same_seed(
|
||||
hf_tokenizer: PreTrainedTokenizerBase,
|
||||
random_dataset_params: Params) -> None:
|
||||
hf_tokenizer: PreTrainedTokenizerBase, random_dataset_params: Params
|
||||
) -> None:
|
||||
"""Same seed should yield identical outputs, even if global RNGs change.
|
||||
|
||||
This guards against accidental reliance on Python's random or np.random
|
||||
@ -70,13 +73,15 @@ def test_random_dataset_same_seed(
|
||||
common_seed = 123
|
||||
dataset_a = RandomDataset(random_seed=common_seed)
|
||||
dataset_b = RandomDataset(random_seed=common_seed)
|
||||
a = _collect_samples(dataset_a,
|
||||
hf_tokenizer,
|
||||
num_requests=p.num_requests,
|
||||
prefix_len=p.prefix_len,
|
||||
range_ratio=p.range_ratio,
|
||||
input_len=p.input_len,
|
||||
output_len=p.output_len)
|
||||
a = _collect_samples(
|
||||
dataset_a,
|
||||
hf_tokenizer,
|
||||
num_requests=p.num_requests,
|
||||
prefix_len=p.prefix_len,
|
||||
range_ratio=p.range_ratio,
|
||||
input_len=p.input_len,
|
||||
output_len=p.output_len,
|
||||
)
|
||||
|
||||
# Perturb global RNG state to ensure isolation
|
||||
random.seed(999)
|
||||
@ -84,43 +89,50 @@ def test_random_dataset_same_seed(
|
||||
np.random.seed(888)
|
||||
_ = [np.random.random() for _ in range(100)]
|
||||
|
||||
b = _collect_samples(dataset_b,
|
||||
hf_tokenizer,
|
||||
num_requests=p.num_requests,
|
||||
prefix_len=p.prefix_len,
|
||||
range_ratio=p.range_ratio,
|
||||
input_len=p.input_len,
|
||||
output_len=p.output_len)
|
||||
b = _collect_samples(
|
||||
dataset_b,
|
||||
hf_tokenizer,
|
||||
num_requests=p.num_requests,
|
||||
prefix_len=p.prefix_len,
|
||||
range_ratio=p.range_ratio,
|
||||
input_len=p.input_len,
|
||||
output_len=p.output_len,
|
||||
)
|
||||
assert a == b
|
||||
|
||||
|
||||
@pytest.mark.benchmark
|
||||
def test_random_dataset_different_seeds(
|
||||
hf_tokenizer: PreTrainedTokenizerBase,
|
||||
random_dataset_params: Params) -> None:
|
||||
hf_tokenizer: PreTrainedTokenizerBase, random_dataset_params: Params
|
||||
) -> None:
|
||||
"""Different seeds should change outputs with overwhelming likelihood."""
|
||||
p = random_dataset_params
|
||||
seed_a = 0
|
||||
dataset_a = RandomDataset(random_seed=seed_a)
|
||||
a = _collect_samples(dataset_a,
|
||||
hf_tokenizer,
|
||||
num_requests=p.num_requests,
|
||||
prefix_len=p.prefix_len,
|
||||
range_ratio=p.range_ratio,
|
||||
input_len=p.input_len,
|
||||
output_len=p.output_len)
|
||||
a = _collect_samples(
|
||||
dataset_a,
|
||||
hf_tokenizer,
|
||||
num_requests=p.num_requests,
|
||||
prefix_len=p.prefix_len,
|
||||
range_ratio=p.range_ratio,
|
||||
input_len=p.input_len,
|
||||
output_len=p.output_len,
|
||||
)
|
||||
|
||||
seed_b = 999
|
||||
dataset_b = RandomDataset(random_seed=seed_b)
|
||||
# Perturb global RNG with same seed as dataset_a to ensure isolation
|
||||
random.seed(seed_a)
|
||||
np.random.seed(seed_a)
|
||||
b = _collect_samples(dataset_b,
|
||||
hf_tokenizer,
|
||||
num_requests=p.num_requests,
|
||||
prefix_len=p.prefix_len,
|
||||
range_ratio=p.range_ratio,
|
||||
input_len=p.input_len,
|
||||
output_len=p.output_len)
|
||||
b = _collect_samples(
|
||||
dataset_b,
|
||||
hf_tokenizer,
|
||||
num_requests=p.num_requests,
|
||||
prefix_len=p.prefix_len,
|
||||
range_ratio=p.range_ratio,
|
||||
input_len=p.input_len,
|
||||
output_len=p.output_len,
|
||||
)
|
||||
assert a != b
|
||||
|
||||
|
||||
@ -128,6 +140,7 @@ def test_random_dataset_different_seeds(
|
||||
# RandomMultiModalDataset tests
|
||||
# -----------------------------
|
||||
|
||||
|
||||
def _mm_fingerprint_sample(
|
||||
req: SampleRequest,
|
||||
) -> tuple[str, int, int, int, list[str]]:
|
||||
@ -152,8 +165,13 @@ def _mm_fingerprint_sample(
|
||||
item_prefixes.append(f"video:{url[:22]}")
|
||||
else:
|
||||
item_prefixes.append("unknown:")
|
||||
return (req.prompt, req.prompt_len, req.expected_output_len, len(items),
|
||||
item_prefixes)
|
||||
return (
|
||||
req.prompt,
|
||||
req.prompt_len,
|
||||
req.expected_output_len,
|
||||
len(items),
|
||||
item_prefixes,
|
||||
)
|
||||
|
||||
|
||||
def _collect_mm_samples(
|
||||
@ -214,6 +232,7 @@ def test_random_mm_different_seeds(
|
||||
fb = [_mm_fingerprint_sample(s) for s in b]
|
||||
assert fa != fb
|
||||
|
||||
|
||||
@pytest.mark.benchmark
|
||||
def test_random_mm_respects_limits(
|
||||
hf_tokenizer: PreTrainedTokenizerBase,
|
||||
@ -271,9 +290,9 @@ def test_random_mm_zero_items(hf_tokenizer: PreTrainedTokenizerBase) -> None:
|
||||
for s in samples:
|
||||
assert s.multi_modal_data == []
|
||||
|
||||
|
||||
@pytest.mark.benchmark
|
||||
def test_random_mm_num_items_per_prompt(
|
||||
hf_tokenizer: PreTrainedTokenizerBase) -> None:
|
||||
def test_random_mm_num_items_per_prompt(hf_tokenizer: PreTrainedTokenizerBase) -> None:
|
||||
ds = RandomMultiModalDataset(random_seed=0)
|
||||
# Fixed number of images per prompt
|
||||
# set num_mm_items_range_ratio to 0.0
|
||||
@ -300,7 +319,6 @@ def test_random_mm_num_items_per_prompt(
|
||||
def test_random_mm_bucket_config_not_mutated(
|
||||
hf_tokenizer: PreTrainedTokenizerBase,
|
||||
) -> None:
|
||||
|
||||
ds = RandomMultiModalDataset(random_seed=0)
|
||||
# This bucket config is not normalized to sum to 1
|
||||
# and has more buckets than requested images
|
||||
@ -321,7 +339,6 @@ def test_random_mm_bucket_config_not_mutated(
|
||||
# Ensure the original dict content is unchanged
|
||||
assert original == snapshot
|
||||
|
||||
|
||||
# Vary number of mm items per prompt
|
||||
# set num_mm_items_range_ratio to 0.5
|
||||
samples_varying_items = _collect_mm_samples(
|
||||
|
||||
@ -11,9 +11,7 @@ MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
args = [
|
||||
"--max-model-len", "1024", "--enforce-eager", "--load-format", "dummy"
|
||||
]
|
||||
args = ["--max-model-len", "1024", "--enforce-eager", "--load-format", "dummy"]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
@ -46,6 +44,7 @@ def test_bench_serve(server):
|
||||
|
||||
assert result.returncode == 0, f"Benchmark failed: {result.stderr}"
|
||||
|
||||
|
||||
@pytest.mark.benchmark
|
||||
def test_bench_serve_chat(server):
|
||||
command = [
|
||||
|
||||
@ -10,8 +10,18 @@ MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
@pytest.mark.benchmark
|
||||
def test_bench_throughput():
|
||||
command = [
|
||||
"vllm", "bench", "throughput", "--model", MODEL_NAME, "--input-len",
|
||||
"32", "--output-len", "1", "--enforce-eager", "--load-format", "dummy"
|
||||
"vllm",
|
||||
"bench",
|
||||
"throughput",
|
||||
"--model",
|
||||
MODEL_NAME,
|
||||
"--input-len",
|
||||
"32",
|
||||
"--output-len",
|
||||
"1",
|
||||
"--enforce-eager",
|
||||
"--load-format",
|
||||
"dummy",
|
||||
]
|
||||
result = subprocess.run(command, capture_output=True, text=True)
|
||||
print(result.stdout)
|
||||
|
||||
@ -23,8 +23,7 @@ class LazyInitPass(InductorPass):
|
||||
and then immediately invoke it.
|
||||
"""
|
||||
|
||||
def __init__(self, pass_cls: type[VllmInductorPass],
|
||||
vllm_config: VllmConfig):
|
||||
def __init__(self, pass_cls: type[VllmInductorPass], vllm_config: VllmConfig):
|
||||
self.pass_cls = pass_cls
|
||||
self.vllm_config = weakref.proxy(vllm_config) # avoid cycle
|
||||
|
||||
@ -45,20 +44,18 @@ class TestBackend:
|
||||
Inductor config is default-initialized from VllmConfig.CompilationConfig.
|
||||
"""
|
||||
|
||||
def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph],
|
||||
None]]):
|
||||
def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph], None]]):
|
||||
self.custom_passes = list(passes)
|
||||
compile_config = get_current_vllm_config().compilation_config
|
||||
self.inductor_config = compile_config.inductor_compile_config
|
||||
self.inductor_config['force_disable_caches'] = True
|
||||
self.inductor_config['post_grad_custom_post_pass'] = self.post_pass
|
||||
self.inductor_config["force_disable_caches"] = True
|
||||
self.inductor_config["post_grad_custom_post_pass"] = self.post_pass
|
||||
|
||||
def __call__(self, graph: fx.GraphModule, example_inputs):
|
||||
self.graph_pre_compile = deepcopy(graph)
|
||||
from torch._inductor.compile_fx import compile_fx
|
||||
return compile_fx(graph,
|
||||
example_inputs,
|
||||
config_patches=self.inductor_config)
|
||||
|
||||
return compile_fx(graph, example_inputs, config_patches=self.inductor_config)
|
||||
|
||||
@with_pattern_match_debug
|
||||
def post_pass(self, graph: fx.Graph):
|
||||
@ -82,8 +79,7 @@ class TestBackend:
|
||||
assert num_pre > 0, f"Op {op.name()} not found in pre-pass graph"
|
||||
assert num_pre > num_post, f"All nodes remain for op {op.name()}"
|
||||
if fully_replaced:
|
||||
assert num_post == 0, \
|
||||
f"Unexpected op {op.name()} in post-pass graph"
|
||||
assert num_post == 0, f"Unexpected op {op.name()} in post-pass graph"
|
||||
|
||||
def check_after_ops(self, ops: Sequence[OpOverload]):
|
||||
for op in ops:
|
||||
|
||||
@ -38,8 +38,8 @@ test_params_full_cudagraph = []
|
||||
MLA_backends = ["FlashMLA", "FlashAttentionMLA", "CutlassMLA"]
|
||||
for mla_backend in MLA_backends:
|
||||
test_params_full_cudagraph.append(
|
||||
pytest.param(
|
||||
("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend])))
|
||||
pytest.param(("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend]))
|
||||
)
|
||||
|
||||
# Qwen/Qwen2-1.5B-Instruct with other backends
|
||||
other_backend_configs = [
|
||||
@ -47,7 +47,8 @@ other_backend_configs = [
|
||||
]
|
||||
for backend_config in other_backend_configs:
|
||||
test_params_full_cudagraph.append(
|
||||
pytest.param(("Qwen/Qwen2-1.5B-Instruct", backend_config)))
|
||||
pytest.param(("Qwen/Qwen2-1.5B-Instruct", backend_config))
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
@ -55,8 +56,10 @@ def llm_pair(request):
|
||||
model, backend_config = request.param
|
||||
|
||||
# Dynamically skip test if GPU capability is not met
|
||||
if backend_config.specific_gpu_arch and backend_config.specific_gpu_arch\
|
||||
!= current_platform.get_device_capability():
|
||||
if (
|
||||
backend_config.specific_gpu_arch
|
||||
and backend_config.specific_gpu_arch != current_platform.get_device_capability()
|
||||
):
|
||||
if backend_config.specific_gpu_arch == (9, 0):
|
||||
pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")
|
||||
elif backend_config.specific_gpu_arch == (10, 0):
|
||||
@ -76,8 +79,7 @@ def llm_pair(request):
|
||||
trust_remote_code=True,
|
||||
max_model_len=1024,
|
||||
max_num_seqs=128,
|
||||
compilation_config=\
|
||||
CompilationConfig(**backend_config.comp_config),
|
||||
compilation_config=CompilationConfig(**backend_config.comp_config),
|
||||
generation_config="vllm",
|
||||
seed=42,
|
||||
)
|
||||
@ -113,20 +115,22 @@ class TestFullCUDAGraph:
|
||||
meaning there would be multiple LLM instances hogging memory simultaneously.
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(("batch_size", "max_tokens"), [
|
||||
(1, 10),
|
||||
(7, 10),
|
||||
(16, 10),
|
||||
(25, 10),
|
||||
(32, 10),
|
||||
(45, 10),
|
||||
(64, 10),
|
||||
(123, 10),
|
||||
(8, 5),
|
||||
(8, 30),
|
||||
])
|
||||
def test_full_cudagraph(self, batch_size, max_tokens,
|
||||
llm_pair: tuple[LLM, LLM]):
|
||||
@pytest.mark.parametrize(
|
||||
("batch_size", "max_tokens"),
|
||||
[
|
||||
(1, 10),
|
||||
(7, 10),
|
||||
(16, 10),
|
||||
(25, 10),
|
||||
(32, 10),
|
||||
(45, 10),
|
||||
(64, 10),
|
||||
(123, 10),
|
||||
(8, 5),
|
||||
(8, 30),
|
||||
],
|
||||
)
|
||||
def test_full_cudagraph(self, batch_size, max_tokens, llm_pair: tuple[LLM, LLM]):
|
||||
"""
|
||||
Test various batch sizes and max_tokens to ensure that the
|
||||
full cudagraph compilation works for padded cases too.
|
||||
@ -137,26 +141,34 @@ class TestFullCUDAGraph:
|
||||
prompts = ["the quick brown fox"] * batch_size
|
||||
# Use purely greedy decoding to avoid top-p truncation sensitivity
|
||||
# that can amplify tiny numeric differences across runtimes.
|
||||
sampling_params = SamplingParams(temperature=0.0,
|
||||
max_tokens=max_tokens,
|
||||
top_p=1.0)
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.0, max_tokens=max_tokens, top_p=1.0
|
||||
)
|
||||
|
||||
piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
|
||||
full_responses = full_cudagraph_llm.generate(prompts, sampling_params)
|
||||
|
||||
# Check that all responses are the same
|
||||
for piecewise_res, full_res in zip(piecewise_responses,
|
||||
full_responses):
|
||||
assert piecewise_res.outputs[0].text.lower() == \
|
||||
full_res.outputs[0].text.lower()
|
||||
for piecewise_res, full_res in zip(piecewise_responses, full_responses):
|
||||
assert (
|
||||
piecewise_res.outputs[0].text.lower()
|
||||
== full_res.outputs[0].text.lower()
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
|
||||
def test_full_cudagraph_with_invalid_backend():
|
||||
with temporary_environ({
|
||||
"VLLM_USE_V1": "1",
|
||||
"VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION"
|
||||
# Flex_Attention is not supported with full cuda graph
|
||||
}), pytest.raises(RuntimeError):
|
||||
LLM(model="Qwen/Qwen2-1.5B-Instruct",
|
||||
compilation_config=CompilationConfig(cudagraph_mode="FULL"))
|
||||
with (
|
||||
temporary_environ(
|
||||
{
|
||||
"VLLM_USE_V1": "1",
|
||||
"VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION",
|
||||
# Flex_Attention is not supported with full cuda graph
|
||||
}
|
||||
),
|
||||
pytest.raises(RuntimeError),
|
||||
):
|
||||
LLM(
|
||||
model="Qwen/Qwen2-1.5B-Instruct",
|
||||
compilation_config=CompilationConfig(cudagraph_mode="FULL"),
|
||||
)
|
||||
|
||||
@ -10,10 +10,14 @@ from torch import nn
|
||||
|
||||
from vllm.compilation.backends import set_model_tag
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import (ignore_torch_compile,
|
||||
support_torch_compile)
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
||||
VllmConfig, set_current_vllm_config)
|
||||
from vllm.compilation.decorators import ignore_torch_compile, support_torch_compile
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
CompilationLevel,
|
||||
CUDAGraphMode,
|
||||
VllmConfig,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
|
||||
# This import automatically registers `torch.ops.silly.attention`
|
||||
@ -27,12 +31,7 @@ RANDOM_SEED = 0
|
||||
|
||||
@support_torch_compile
|
||||
class ParentModel(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = '',
|
||||
**kwargs) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@ -40,7 +39,6 @@ class ParentModel(nn.Module):
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
|
||||
def __init__(self, mlp_size: int, hidden_size: int) -> None:
|
||||
super().__init__()
|
||||
self.pre_attn = nn.Linear(mlp_size, hidden_size, bias=False)
|
||||
@ -51,17 +49,21 @@ class Attention(nn.Module):
|
||||
nn.init.xavier_normal_(
|
||||
self.pre_attn.weight.data,
|
||||
generator=torch.Generator().manual_seed(RANDOM_SEED),
|
||||
gain=0.001)
|
||||
gain=0.001,
|
||||
)
|
||||
nn.init.xavier_normal_(
|
||||
self.post_attn.weight.data,
|
||||
generator=torch.Generator().manual_seed(RANDOM_SEED),
|
||||
gain=0.001)
|
||||
gain=0.001,
|
||||
)
|
||||
|
||||
def rms_norm_ref(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x_f32 = x.float()
|
||||
return (x_f32 * torch.rsqrt(
|
||||
torch.mean(x_f32.square(), dim=-1, keepdim=True) + 1e-6) *
|
||||
self.rms_norm_weight).to(x.dtype)
|
||||
return (
|
||||
x_f32
|
||||
* torch.rsqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + 1e-6)
|
||||
* self.rms_norm_weight
|
||||
).to(x.dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.pre_attn(x)
|
||||
@ -76,14 +78,15 @@ class Attention(nn.Module):
|
||||
|
||||
@support_torch_compile
|
||||
class CompiledAttention(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
mlp_size: int,
|
||||
hidden_size: int,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = '',
|
||||
**kwargs) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
mlp_size: int,
|
||||
hidden_size: int,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.attn = Attention(mlp_size, hidden_size)
|
||||
|
||||
@ -93,21 +96,21 @@ class CompiledAttention(nn.Module):
|
||||
|
||||
@support_torch_compile
|
||||
class CompiledAttentionTwo(CompiledAttention):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.attn(x) + x
|
||||
|
||||
|
||||
@ignore_torch_compile
|
||||
class SimpleModelWithTwoGraphs(ParentModel):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
mlp_size: int,
|
||||
hidden_size: int,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = '',
|
||||
**kwargs) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
mlp_size: int,
|
||||
hidden_size: int,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
# Test will fail without set_model_tag here with error:
|
||||
# "ValueError: too many values to unpack (expected 3)"
|
||||
@ -142,32 +145,45 @@ class SimpleModelWithTwoGraphs(ParentModel):
|
||||
|
||||
|
||||
@torch.inference_mode
|
||||
def run_model(vllm_config: VllmConfig, model: nn.Module, inputs: torch.Tensor,
|
||||
cudagraph_runtime_mode: CUDAGraphMode):
|
||||
def run_model(
|
||||
vllm_config: VllmConfig,
|
||||
model: nn.Module,
|
||||
inputs: torch.Tensor,
|
||||
cudagraph_runtime_mode: CUDAGraphMode,
|
||||
):
|
||||
with set_forward_context({}, vllm_config=vllm_config):
|
||||
# warmup for the model with cudagraph_mode NONE
|
||||
model(inputs)
|
||||
|
||||
# simulate cudagraphs capturing
|
||||
with set_forward_context({},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=2, )):
|
||||
with set_forward_context(
|
||||
{},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=2,
|
||||
),
|
||||
):
|
||||
model(inputs[:2])
|
||||
with set_forward_context({},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=1, )):
|
||||
with set_forward_context(
|
||||
{},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=1,
|
||||
),
|
||||
):
|
||||
model(inputs[:1])
|
||||
|
||||
# simulate cudagraphs replay
|
||||
with set_forward_context({},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=2, )):
|
||||
with set_forward_context(
|
||||
{},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=2,
|
||||
),
|
||||
):
|
||||
output = model(inputs[:2])
|
||||
|
||||
output = output.cpu()
|
||||
@ -178,82 +194,104 @@ def test_multi_graph_piecewise_compile_outputs_equal():
|
||||
outputs = []
|
||||
|
||||
# piecewise compile
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=True,
|
||||
splitting_ops=["silly.attention"],
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
))
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=True,
|
||||
splitting_ops=["silly.attention"],
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
)
|
||||
)
|
||||
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
|
||||
hidden_size=HIDDEN_SIZE,
|
||||
vllm_config=vllm_config,
|
||||
prefix='').eval().cuda()
|
||||
model = (
|
||||
SimpleModelWithTwoGraphs(
|
||||
mlp_size=MLP_SIZE,
|
||||
hidden_size=HIDDEN_SIZE,
|
||||
vllm_config=vllm_config,
|
||||
prefix="",
|
||||
)
|
||||
.eval()
|
||||
.cuda()
|
||||
)
|
||||
|
||||
# Pre-allocate memory for CUDAGraph which expects
|
||||
# static tensor addresses
|
||||
inputs = torch.randn(BATCH_SIZE, MLP_SIZE).cuda()
|
||||
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=2, # two graphs for the model
|
||||
num_piecewise_graphs_seen=6,
|
||||
# attn_one, attn_two each has 3 piecewise graphs
|
||||
# (pre attn, post attn, silly_attention) each
|
||||
num_piecewise_capturable_graphs_seen=4,
|
||||
# attn_one, attn_two has pre attn and post attn each, total=4
|
||||
num_backend_compilations=4, # num_piecewise_capturable_graphs_seen
|
||||
num_cudagraph_captured=8,
|
||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
num_graphs_seen=2, # two graphs for the model
|
||||
num_piecewise_graphs_seen=6,
|
||||
# attn_one, attn_two each has 3 piecewise graphs
|
||||
# (pre attn, post attn, silly_attention) each
|
||||
num_piecewise_capturable_graphs_seen=4,
|
||||
# attn_one, attn_two has pre attn and post attn each, total=4
|
||||
num_backend_compilations=4, # num_piecewise_capturable_graphs_seen
|
||||
num_cudagraph_captured=8,
|
||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
):
|
||||
outputs.append(
|
||||
run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
|
||||
outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
|
||||
|
||||
# no compile or cudagraph
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.NO_COMPILATION, ))
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.NO_COMPILATION,
|
||||
)
|
||||
)
|
||||
cudagraph_runtime_mode = CUDAGraphMode.NONE
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
|
||||
hidden_size=HIDDEN_SIZE,
|
||||
vllm_config=vllm_config,
|
||||
prefix='').eval().cuda()
|
||||
model = (
|
||||
SimpleModelWithTwoGraphs(
|
||||
mlp_size=MLP_SIZE,
|
||||
hidden_size=HIDDEN_SIZE,
|
||||
vllm_config=vllm_config,
|
||||
prefix="",
|
||||
)
|
||||
.eval()
|
||||
.cuda()
|
||||
)
|
||||
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=0,
|
||||
num_piecewise_graphs_seen=0,
|
||||
num_piecewise_capturable_graphs_seen=0,
|
||||
num_backend_compilations=0,
|
||||
num_cudagraph_captured=0,
|
||||
num_graphs_seen=0,
|
||||
num_piecewise_graphs_seen=0,
|
||||
num_piecewise_capturable_graphs_seen=0,
|
||||
num_backend_compilations=0,
|
||||
num_cudagraph_captured=0,
|
||||
):
|
||||
outputs.append(
|
||||
run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
|
||||
outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
|
||||
|
||||
# piecewise compile without CUDA graph
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=False,
|
||||
splitting_ops=["silly.attention"],
|
||||
))
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=False,
|
||||
splitting_ops=["silly.attention"],
|
||||
)
|
||||
)
|
||||
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
|
||||
hidden_size=HIDDEN_SIZE,
|
||||
vllm_config=vllm_config,
|
||||
prefix='').eval().cuda()
|
||||
model = (
|
||||
SimpleModelWithTwoGraphs(
|
||||
mlp_size=MLP_SIZE,
|
||||
hidden_size=HIDDEN_SIZE,
|
||||
vllm_config=vllm_config,
|
||||
prefix="",
|
||||
)
|
||||
.eval()
|
||||
.cuda()
|
||||
)
|
||||
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=2,
|
||||
num_piecewise_graphs_seen=6,
|
||||
num_piecewise_capturable_graphs_seen=4,
|
||||
num_backend_compilations=4,
|
||||
num_cudagraph_captured=0, # no cudagraph captured
|
||||
num_graphs_seen=2,
|
||||
num_piecewise_graphs_seen=6,
|
||||
num_piecewise_capturable_graphs_seen=4,
|
||||
num_backend_compilations=4,
|
||||
num_cudagraph_captured=0, # no cudagraph captured
|
||||
):
|
||||
outputs.append(
|
||||
run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
|
||||
outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
|
||||
|
||||
# Generally don't expect outputs with and without inductor
|
||||
# to be bitwise equivalent
|
||||
|
||||
@ -11,8 +11,13 @@ from torch import nn
|
||||
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
||||
VllmConfig, set_current_vllm_config)
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
CompilationLevel,
|
||||
CUDAGraphMode,
|
||||
VllmConfig,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.envs import VLLM_USE_V1
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.utils import is_torch_equal_or_newer
|
||||
@ -23,12 +28,7 @@ from ..silly_attention import get_global_counter, reset_global_counter
|
||||
|
||||
@support_torch_compile
|
||||
class SillyModel(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = '',
|
||||
**kwargs) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@ -60,53 +60,65 @@ def _run_simple_model(
|
||||
expected_num_backend_compilations,
|
||||
expected_num_cudagraph_captured,
|
||||
):
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=True,
|
||||
use_inductor=use_inductor,
|
||||
splitting_ops=splitting_ops,
|
||||
use_inductor_graph_partition=use_inductor_graph_partition,
|
||||
cudagraph_copy_inputs=True,
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
))
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=True,
|
||||
use_inductor=use_inductor,
|
||||
splitting_ops=splitting_ops,
|
||||
use_inductor_graph_partition=use_inductor_graph_partition,
|
||||
cudagraph_copy_inputs=True,
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
)
|
||||
)
|
||||
with set_current_vllm_config(vllm_config):
|
||||
model = SillyModel(vllm_config=vllm_config, prefix='')
|
||||
model = SillyModel(vllm_config=vllm_config, prefix="")
|
||||
|
||||
inputs = torch.randn(100).cuda()
|
||||
|
||||
with compilation_counter.expect(
|
||||
with (
|
||||
compilation_counter.expect(
|
||||
num_graphs_seen=1, # one graph for the model
|
||||
num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
|
||||
num_piecewise_capturable_graphs_seen=
|
||||
expected_num_piecewise_capturable_graphs_seen,
|
||||
num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
|
||||
num_backend_compilations=expected_num_backend_compilations,
|
||||
num_cudagraph_captured=expected_num_cudagraph_captured,
|
||||
), set_forward_context(None,
|
||||
vllm_config=vllm_config): # background context
|
||||
),
|
||||
set_forward_context(None, vllm_config=vllm_config),
|
||||
): # background context
|
||||
# warm up with background context
|
||||
model(inputs)
|
||||
|
||||
# capturing/replaying should under context of cudagraph dispatching
|
||||
with set_forward_context(
|
||||
None,
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||
batch_descriptor=BatchDescriptor(num_tokens=2, )):
|
||||
None,
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=2,
|
||||
),
|
||||
):
|
||||
model(torch.randn(2).cuda())
|
||||
with set_forward_context(
|
||||
None,
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||
batch_descriptor=BatchDescriptor(num_tokens=1, )):
|
||||
None,
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=1,
|
||||
),
|
||||
):
|
||||
model(torch.randn(1).cuda())
|
||||
|
||||
input = torch.zeros(2).cuda()
|
||||
reset_global_counter()
|
||||
with set_forward_context(
|
||||
None,
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||
batch_descriptor=BatchDescriptor(num_tokens=2, )):
|
||||
None,
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=2,
|
||||
),
|
||||
):
|
||||
output = model(input)
|
||||
assert get_global_counter() == 2
|
||||
assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0]))
|
||||
@ -122,10 +134,8 @@ def test_simple_piecewise_compile(use_inductor):
|
||||
use_inductor=use_inductor,
|
||||
expected_num_piecewise_graphs_seen=5, # 2 * num_layers + 1
|
||||
expected_num_piecewise_capturable_graphs_seen=3, # 1 + num_layers
|
||||
expected_num_backend_compilations=
|
||||
3, # num_piecewise_capturable_graphs_seen
|
||||
expected_num_cudagraph_captured=
|
||||
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
expected_num_backend_compilations=3, # num_piecewise_capturable_graphs_seen
|
||||
expected_num_cudagraph_captured=6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
)
|
||||
|
||||
|
||||
@ -134,8 +144,7 @@ def test_simple_piecewise_compile(use_inductor):
|
||||
def test_simple_inductor_graph_partition(splitting_ops):
|
||||
assert VLLM_USE_V1
|
||||
if not is_torch_equal_or_newer("2.9.0.dev"):
|
||||
pytest.skip("inductor graph partition is only available "
|
||||
"in PyTorch 2.9+")
|
||||
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
|
||||
|
||||
_run_simple_model(
|
||||
# inductor graph partition automatically resets splitting_ops
|
||||
@ -143,13 +152,9 @@ def test_simple_inductor_graph_partition(splitting_ops):
|
||||
splitting_ops=splitting_ops,
|
||||
use_inductor_graph_partition=True,
|
||||
use_inductor=True,
|
||||
expected_num_piecewise_graphs_seen=
|
||||
1, # since not splitting at fx graph level
|
||||
expected_num_piecewise_capturable_graphs_seen=
|
||||
1, # since not splitting at fx graph level
|
||||
expected_num_backend_compilations=
|
||||
1, # since not splitting at fx graph level
|
||||
expected_num_cudagraph_captured=
|
||||
6, # inductor graph partition still captures 6
|
||||
expected_num_piecewise_graphs_seen=1, # since not splitting at fx graph level
|
||||
expected_num_piecewise_capturable_graphs_seen=1, # since not splitting at fx graph level
|
||||
expected_num_backend_compilations=1, # since not splitting at fx graph level
|
||||
expected_num_cudagraph_captured=6, # inductor graph partition still captures 6
|
||||
# graph, same as fx graph partition.
|
||||
)
|
||||
|
||||
@ -8,6 +8,7 @@ This is a tractable model, the weights and computation are specially designed
|
||||
if the config `tractable_init` is set to True. Otherwise, the weights are
|
||||
initialized randomly with a fixed seed.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
@ -17,8 +18,13 @@ from torch import nn
|
||||
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
||||
VllmConfig, set_current_vllm_config)
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
CompilationLevel,
|
||||
CUDAGraphMode,
|
||||
VllmConfig,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
|
||||
# This import automatically registers `torch.ops.silly.attention`
|
||||
@ -43,15 +49,14 @@ class LlamaConfig:
|
||||
factors.append((k, v))
|
||||
factors.sort()
|
||||
import hashlib
|
||||
return hashlib.md5(str(factors).encode(),
|
||||
usedforsecurity=False).hexdigest()
|
||||
|
||||
return hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
|
||||
def __post_init__(self):
|
||||
assert self.mlp_size >= self.hidden_size
|
||||
|
||||
|
||||
class LlamaMLP(nn.Module):
|
||||
|
||||
def __init__(self, config: LlamaConfig) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_projection = nn.Linear(
|
||||
@ -66,31 +71,31 @@ class LlamaMLP(nn.Module):
|
||||
)
|
||||
|
||||
if config.tractable_init:
|
||||
nn.init.eye_(self.gate_up_projection.weight.data[:config.mlp_size])
|
||||
nn.init.eye_(self.gate_up_projection.weight.data[config.mlp_size:])
|
||||
nn.init.eye_(self.gate_up_projection.weight.data[: config.mlp_size])
|
||||
nn.init.eye_(self.gate_up_projection.weight.data[config.mlp_size :])
|
||||
nn.init.eye_(self.down_projection.weight.data)
|
||||
else:
|
||||
nn.init.xavier_normal_(self.gate_up_projection.weight.data,
|
||||
generator=torch.Generator().manual_seed(
|
||||
config.random_seed),
|
||||
gain=0.001)
|
||||
nn.init.xavier_normal_(self.down_projection.weight.data,
|
||||
generator=torch.Generator().manual_seed(
|
||||
config.random_seed),
|
||||
gain=0.001)
|
||||
nn.init.xavier_normal_(
|
||||
self.gate_up_projection.weight.data,
|
||||
generator=torch.Generator().manual_seed(config.random_seed),
|
||||
gain=0.001,
|
||||
)
|
||||
nn.init.xavier_normal_(
|
||||
self.down_projection.weight.data,
|
||||
generator=torch.Generator().manual_seed(config.random_seed),
|
||||
gain=0.001,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# for tractable_init and positive input, this is
|
||||
# essentially an elementwise-square
|
||||
x = self.gate_up_projection(x)
|
||||
x = x[:, :x.size(1) // 2] * torch.nn.functional.relu(
|
||||
x[:, x.size(1) // 2:])
|
||||
x = x[:, : x.size(1) // 2] * torch.nn.functional.relu(x[:, x.size(1) // 2 :])
|
||||
x = self.down_projection(x)
|
||||
return x
|
||||
|
||||
|
||||
class LlamaAttention(nn.Module):
|
||||
|
||||
def __init__(self, config: LlamaConfig) -> None:
|
||||
super().__init__()
|
||||
self.qkv_projection = nn.Linear(
|
||||
@ -106,21 +111,25 @@ class LlamaAttention(nn.Module):
|
||||
)
|
||||
|
||||
if config.tractable_init:
|
||||
nn.init.eye_(self.qkv_projection.weight.data[:config.hidden_size])
|
||||
nn.init.eye_(self.qkv_projection.weight.data[config.hidden_size:2 *
|
||||
config.hidden_size])
|
||||
nn.init.eye_(self.qkv_projection.weight.data[2 *
|
||||
config.hidden_size:])
|
||||
nn.init.eye_(self.qkv_projection.weight.data[: config.hidden_size])
|
||||
nn.init.eye_(
|
||||
self.qkv_projection.weight.data[
|
||||
config.hidden_size : 2 * config.hidden_size
|
||||
]
|
||||
)
|
||||
nn.init.eye_(self.qkv_projection.weight.data[2 * config.hidden_size :])
|
||||
nn.init.eye_(self.output_projection.weight.data)
|
||||
else:
|
||||
nn.init.xavier_normal_(self.qkv_projection.weight.data,
|
||||
generator=torch.Generator().manual_seed(
|
||||
config.random_seed),
|
||||
gain=0.001)
|
||||
nn.init.xavier_normal_(self.output_projection.weight.data,
|
||||
generator=torch.Generator().manual_seed(
|
||||
config.random_seed),
|
||||
gain=0.001)
|
||||
nn.init.xavier_normal_(
|
||||
self.qkv_projection.weight.data,
|
||||
generator=torch.Generator().manual_seed(config.random_seed),
|
||||
gain=0.001,
|
||||
)
|
||||
nn.init.xavier_normal_(
|
||||
self.output_projection.weight.data,
|
||||
generator=torch.Generator().manual_seed(config.random_seed),
|
||||
gain=0.001,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -144,7 +153,6 @@ class LlamaAttention(nn.Module):
|
||||
|
||||
|
||||
class LlamaDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, config: LlamaConfig) -> None:
|
||||
super().__init__()
|
||||
self.self_attention = LlamaAttention(config)
|
||||
@ -164,7 +172,7 @@ class LlamaDecoderLayer(nn.Module):
|
||||
- if residual is not None, the outputs are:
|
||||
- residual = (hidden_states + residual + 1) * 3 + positions * 2 + hidden_states + residual = (hidden_states + residual) * 4 + positions * 2 + 3
|
||||
- hidden_states = (residual + 1) ** 2
|
||||
""" # noqa
|
||||
""" # noqa
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = hidden_states + 1
|
||||
@ -173,8 +181,9 @@ class LlamaDecoderLayer(nn.Module):
|
||||
residual = hidden_states
|
||||
hidden_states = hidden_states + 1
|
||||
|
||||
hidden_states = self.self_attention(positions=positions,
|
||||
hidden_states=hidden_states)
|
||||
hidden_states = self.self_attention(
|
||||
positions=positions, hidden_states=hidden_states
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + residual
|
||||
residual = hidden_states
|
||||
@ -186,20 +195,22 @@ class LlamaDecoderLayer(nn.Module):
|
||||
|
||||
@support_torch_compile
|
||||
class LlamaModel(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
config: LlamaConfig,
|
||||
prefix: str = '',
|
||||
**kwargs) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
config: LlamaConfig,
|
||||
prefix: str = "",
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.embedding_tokens = nn.Embedding(
|
||||
num_embeddings=config.vocab_size,
|
||||
embedding_dim=config.hidden_size,
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[LlamaDecoderLayer(config) for _ in range(config.num_layers)])
|
||||
[LlamaDecoderLayer(config) for _ in range(config.num_layers)]
|
||||
)
|
||||
|
||||
# this is the initial value of the hidden states
|
||||
self.embedding_tokens.weight.data.fill_(config.init_value)
|
||||
@ -216,34 +227,39 @@ class LlamaModel(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
def tractable_computation(input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
config: LlamaConfig,
|
||||
init_value: float = 1.0) -> torch.Tensor:
|
||||
hidden_states = torch.ones(input_ids.size(0),
|
||||
config.hidden_size,
|
||||
device=input_ids.device,
|
||||
dtype=input_ids.dtype) * init_value
|
||||
def tractable_computation(
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
config: LlamaConfig,
|
||||
init_value: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = (
|
||||
torch.ones(
|
||||
input_ids.size(0),
|
||||
config.hidden_size,
|
||||
device=input_ids.device,
|
||||
dtype=input_ids.dtype,
|
||||
)
|
||||
* init_value
|
||||
)
|
||||
|
||||
# first layer
|
||||
residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3
|
||||
hidden_states = (residual + 1)**2
|
||||
hidden_states = (residual + 1) ** 2
|
||||
|
||||
# following layers
|
||||
for _ in range(config.num_layers - 1):
|
||||
hidden_states = hidden_states + residual
|
||||
residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3
|
||||
hidden_states = (residual + 1)**2
|
||||
hidden_states = (residual + 1) ** 2
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@torch.inference_mode
|
||||
def run_model(llama_config,
|
||||
use_compile: bool,
|
||||
use_inductor: bool,
|
||||
split_attn: bool = False) -> torch.Tensor:
|
||||
|
||||
def run_model(
|
||||
llama_config, use_compile: bool, use_inductor: bool, split_attn: bool = False
|
||||
) -> torch.Tensor:
|
||||
if use_compile:
|
||||
compilation_config = CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
@ -256,54 +272,66 @@ def run_model(llama_config,
|
||||
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||
else:
|
||||
compilation_config = CompilationConfig(
|
||||
level=CompilationLevel.NO_COMPILATION, )
|
||||
level=CompilationLevel.NO_COMPILATION,
|
||||
)
|
||||
cudagraph_runtime_mode = CUDAGraphMode.NONE
|
||||
|
||||
vllm_config = VllmConfig(compilation_config=compilation_config,
|
||||
additional_config=llama_config)
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=compilation_config, additional_config=llama_config
|
||||
)
|
||||
with set_current_vllm_config(vllm_config):
|
||||
model = LlamaModel(config=llama_config,
|
||||
vllm_config=vllm_config,
|
||||
prefix="").eval().cuda()
|
||||
model = (
|
||||
LlamaModel(config=llama_config, vllm_config=vllm_config, prefix="")
|
||||
.eval()
|
||||
.cuda()
|
||||
)
|
||||
|
||||
with set_forward_context({},
|
||||
vllm_config=vllm_config): # background context
|
||||
with set_forward_context({}, vllm_config=vllm_config): # background context
|
||||
B = 16 # max batch size
|
||||
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
|
||||
input_ids = torch.randint(0, llama_config.vocab_size, (B,)).cuda()
|
||||
positions = torch.arange(B).cuda()
|
||||
|
||||
# warmup for the model with cudagraph_mode NONE
|
||||
model(input_ids, positions)
|
||||
|
||||
# simulate cudagraphs capturing
|
||||
with set_forward_context({},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=2, )):
|
||||
with set_forward_context(
|
||||
{},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=2,
|
||||
),
|
||||
):
|
||||
model(input_ids[:2], positions[:2])
|
||||
with set_forward_context({},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=1, )):
|
||||
with set_forward_context(
|
||||
{},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=1,
|
||||
),
|
||||
):
|
||||
model(input_ids[:1], positions[:1])
|
||||
|
||||
input_ids[:2].zero_()
|
||||
# simulate cudagraphs replay
|
||||
with set_forward_context({},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=2, )):
|
||||
with set_forward_context(
|
||||
{},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=2,
|
||||
),
|
||||
):
|
||||
output = model(input_ids[:2], positions[:2])
|
||||
|
||||
output = output.cpu()
|
||||
|
||||
if llama_config.tractable_init:
|
||||
expected_output = tractable_computation(input_ids[:2],
|
||||
positions[:2],
|
||||
llama_config).cpu()
|
||||
expected_output = tractable_computation(
|
||||
input_ids[:2], positions[:2], llama_config
|
||||
).cpu()
|
||||
|
||||
assert torch.allclose(output, expected_output)
|
||||
else:
|
||||
@ -314,27 +342,23 @@ def run_model(llama_config,
|
||||
def test_toy_llama(use_inductor: bool):
|
||||
# compare output with and without piecewise compilation
|
||||
|
||||
llama_config = LlamaConfig(hidden_size=128,
|
||||
mlp_size=256,
|
||||
vocab_size=128,
|
||||
num_layers=12)
|
||||
llama_config = LlamaConfig(
|
||||
hidden_size=128, mlp_size=256, vocab_size=128, num_layers=12
|
||||
)
|
||||
|
||||
tractable_config = LlamaConfig(hidden_size=128,
|
||||
mlp_size=256,
|
||||
vocab_size=128,
|
||||
num_layers=2,
|
||||
tractable_init=True)
|
||||
tractable_config = LlamaConfig(
|
||||
hidden_size=128, mlp_size=256, vocab_size=128, num_layers=2, tractable_init=True
|
||||
)
|
||||
|
||||
outputs = []
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=0,
|
||||
num_piecewise_graphs_seen=0,
|
||||
num_piecewise_capturable_graphs_seen=0,
|
||||
num_backend_compilations=0,
|
||||
num_cudagraph_captured=0,
|
||||
num_graphs_seen=0,
|
||||
num_piecewise_graphs_seen=0,
|
||||
num_piecewise_capturable_graphs_seen=0,
|
||||
num_backend_compilations=0,
|
||||
num_cudagraph_captured=0,
|
||||
):
|
||||
outputs.append(
|
||||
run_model(llama_config, use_inductor=False, use_compile=False))
|
||||
outputs.append(run_model(llama_config, use_inductor=False, use_compile=False))
|
||||
run_model(tractable_config, use_inductor=False, use_compile=False)
|
||||
|
||||
if use_inductor:
|
||||
@ -343,41 +367,41 @@ def test_toy_llama(use_inductor: bool):
|
||||
kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}
|
||||
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=1, # one graph for the model
|
||||
num_piecewise_graphs_seen=1,
|
||||
num_piecewise_capturable_graphs_seen=1,
|
||||
num_backend_compilations=1, # num_piecewise_capturable_graphs_seen
|
||||
num_cudagraph_captured=
|
||||
2, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
**kwargs,
|
||||
num_graphs_seen=1, # one graph for the model
|
||||
num_piecewise_graphs_seen=1,
|
||||
num_piecewise_capturable_graphs_seen=1,
|
||||
num_backend_compilations=1, # num_piecewise_capturable_graphs_seen
|
||||
num_cudagraph_captured=2, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
**kwargs,
|
||||
):
|
||||
outputs.append(
|
||||
run_model(llama_config,
|
||||
use_inductor=use_inductor,
|
||||
use_compile=True))
|
||||
run_model(llama_config, use_inductor=use_inductor, use_compile=True)
|
||||
)
|
||||
run_model(tractable_config, use_inductor=use_inductor, use_compile=True)
|
||||
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=1, # one graph for the model
|
||||
num_piecewise_graphs_seen=2 * llama_config.num_layers +
|
||||
1, # 2 * num_layers + 1
|
||||
num_piecewise_capturable_graphs_seen=1 +
|
||||
llama_config.num_layers, # 1 + num_layers
|
||||
num_backend_compilations=1 +
|
||||
llama_config.num_layers, # num_piecewise_capturable_graphs_seen
|
||||
num_cudagraph_captured=2 *
|
||||
(1 + llama_config.num_layers
|
||||
), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
num_graphs_seen=1, # one graph for the model
|
||||
num_piecewise_graphs_seen=2 * llama_config.num_layers + 1, # 2 * num_layers + 1
|
||||
num_piecewise_capturable_graphs_seen=1
|
||||
+ llama_config.num_layers, # 1 + num_layers
|
||||
num_backend_compilations=1
|
||||
+ llama_config.num_layers, # num_piecewise_capturable_graphs_seen
|
||||
num_cudagraph_captured=2
|
||||
* (
|
||||
1 + llama_config.num_layers
|
||||
), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
):
|
||||
outputs.append(
|
||||
run_model(llama_config,
|
||||
use_inductor=use_inductor,
|
||||
use_compile=True,
|
||||
split_attn=True))
|
||||
run_model(tractable_config,
|
||||
use_inductor=use_inductor,
|
||||
use_compile=True,
|
||||
split_attn=True)
|
||||
run_model(
|
||||
llama_config,
|
||||
use_inductor=use_inductor,
|
||||
use_compile=True,
|
||||
split_attn=True,
|
||||
)
|
||||
)
|
||||
run_model(
|
||||
tractable_config, use_inductor=use_inductor, use_compile=True, split_attn=True
|
||||
)
|
||||
|
||||
for i in range(1, len(outputs)):
|
||||
assert torch.allclose(outputs[0], outputs[i])
|
||||
@ -388,17 +412,15 @@ def benchmark():
|
||||
from triton.testing import do_bench
|
||||
|
||||
# similar to llama 3.1-8B
|
||||
llama_config = LlamaConfig(hidden_size=4096,
|
||||
mlp_size=14336,
|
||||
vocab_size=128 * 1024,
|
||||
num_layers=32)
|
||||
llama_config = LlamaConfig(
|
||||
hidden_size=4096, mlp_size=14336, vocab_size=128 * 1024, num_layers=32
|
||||
)
|
||||
|
||||
# a tiny model to measure the overhead
|
||||
# of piecewise cudagraph
|
||||
llama_config = LlamaConfig(hidden_size=40,
|
||||
mlp_size=80,
|
||||
vocab_size=128,
|
||||
num_layers=2)
|
||||
llama_config = LlamaConfig(
|
||||
hidden_size=40, mlp_size=80, vocab_size=128, num_layers=2
|
||||
)
|
||||
|
||||
cudagraph_sizes = [1, 2, 4] + [i * 8 for i in range(1, 33)]
|
||||
|
||||
@ -424,12 +446,15 @@ def benchmark():
|
||||
|
||||
vllm_config = VllmConfig(compilation_config=compilation_config)
|
||||
with set_current_vllm_config(vllm_config):
|
||||
model = LlamaModel(config=llama_config,
|
||||
vllm_config=vllm_config,
|
||||
prefix="").eval().cuda().to(torch.bfloat16)
|
||||
model = (
|
||||
LlamaModel(config=llama_config, vllm_config=vllm_config, prefix="")
|
||||
.eval()
|
||||
.cuda()
|
||||
.to(torch.bfloat16)
|
||||
)
|
||||
|
||||
B = 256 # max batch size
|
||||
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
|
||||
input_ids = torch.randint(0, llama_config.vocab_size, (B,)).cuda()
|
||||
positions = torch.arange(B).cuda().to(torch.bfloat16)
|
||||
|
||||
graphs = {}
|
||||
@ -451,21 +476,25 @@ def benchmark():
|
||||
# and use it later, because it will look up the name `b` in the
|
||||
# enclosing scope, and the value of `b` will always be 256.
|
||||
# it is fine here, because we only use the lambda function once.
|
||||
runtime = do_bench(lambda: graphs[b][0] # noqa
|
||||
(input_ids[:b], positions[:b])) # noqa
|
||||
runtime = do_bench(
|
||||
lambda: graphs[b][0]( # noqa
|
||||
input_ids[:b], positions[:b]
|
||||
)
|
||||
) # noqa
|
||||
piecewise_cudagraph_time[b] = runtime
|
||||
else:
|
||||
runtime = do_bench(lambda: graphs[b][0].replay()) # noqa
|
||||
eager_runtime = do_bench(
|
||||
lambda: model(input_ids[:b], positions[:b])) # noqa
|
||||
eager_runtime = do_bench(lambda: model(input_ids[:b], positions[:b])) # noqa
|
||||
full_cudagraph_time[b] = runtime
|
||||
eager_time[b] = eager_runtime
|
||||
|
||||
# print in tabular format
|
||||
print("batch size\teager mode\tfull cudagraph\tpiecewise cudagraph")
|
||||
for b in cudagraph_sizes:
|
||||
print(f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}"
|
||||
f"\t{piecewise_cudagraph_time[b]:.3f}")
|
||||
print(
|
||||
f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}"
|
||||
f"\t{piecewise_cudagraph_time[b]:.3f}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
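A hedged, self-contained sketch of the closure late-binding pitfall described in the comment above the do_bench call (illustrative only, not part of the change):

# Each lambda looks up b at call time, so all of them see the final value 2.
fns = [lambda: b for b in range(3)]
assert [f() for f in fns] == [2, 2, 2]

# Binding b as a default argument captures the value at definition time instead.
fns = [lambda b=b: b for b in range(3)]
assert [f() for f in fns] == [0, 1, 2]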

@ -31,8 +31,9 @@ def reset_global_counter():
_global_counter = 0


def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
def silly_attention(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor
) -> None:
"""
Unified attention implementation that depends on
all inputs and affects the output.
@ -47,8 +48,9 @@ def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out.copy_(q + k + v)


def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
def silly_attention_fake(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor
) -> None:
"""Fake implementation for testing"""
return

@ -60,5 +62,5 @@ direct_register_custom_op(
mutates_args=["out"],
fake_impl=silly_attention_fake,
target_lib=silly_lib,
tags=(torch._C.Tag.cudagraph_unsafe, ),
tags=(torch._C.Tag.cudagraph_unsafe,),
)

@ -8,18 +8,30 @@ import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.collective_fusion import AsyncTPPass
|
||||
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
|
||||
PassConfig, VllmConfig)
|
||||
from vllm.distributed import (tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_reduce_scatter)
|
||||
from vllm.distributed.parallel_state import (init_distributed_environment,
|
||||
initialize_model_parallel)
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
DeviceConfig,
|
||||
ModelConfig,
|
||||
PassConfig,
|
||||
VllmConfig,
|
||||
)
|
||||
from vllm.distributed import (
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_reduce_scatter,
|
||||
)
|
||||
from vllm.distributed.parallel_state import (
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import update_environment_variables
|
||||
|
||||
from ..models.registry import HF_EXAMPLE_MODELS
|
||||
from ..utils import (compare_two_settings, create_new_process_for_each_test,
|
||||
multi_gpu_test)
|
||||
from ..utils import (
|
||||
compare_two_settings,
|
||||
create_new_process_for_each_test,
|
||||
multi_gpu_test,
|
||||
)
|
||||
from .backend import TestBackend
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
@ -33,14 +45,13 @@ prompts = [
|
||||
|
||||
|
||||
class TestMMRSModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size=16, dtype=torch.float16):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.dtype = dtype
|
||||
self.gate_proj = torch.nn.Parameter(torch.empty(
|
||||
(self.hidden_size * 2, hidden_size)),
|
||||
requires_grad=False)
|
||||
self.gate_proj = torch.nn.Parameter(
|
||||
torch.empty((self.hidden_size * 2, hidden_size)), requires_grad=False
|
||||
)
|
||||
# Initialize weights
|
||||
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
||||
|
||||
@ -66,14 +77,13 @@ class TestMMRSModel(torch.nn.Module):
|
||||
|
||||
|
||||
class TestAGMMModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size=16, dtype=torch.float16):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.dtype = dtype
|
||||
self.weight = torch.nn.Parameter(torch.empty(
|
||||
(hidden_size, hidden_size)),
|
||||
requires_grad=False)
|
||||
self.weight = torch.nn.Parameter(
|
||||
torch.empty((hidden_size, hidden_size)), requires_grad=False
|
||||
)
|
||||
# Initialize weights
|
||||
torch.nn.init.normal_(self.weight, std=0.02)
|
||||
|
||||
@ -96,20 +106,21 @@ class TestAGMMModel(torch.nn.Module):
|
||||
|
||||
|
||||
class _BaseScaledMMModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size=16, dtype=torch.float16):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.dtype = dtype
|
||||
self.weight = torch.empty([hidden_size, hidden_size], dtype=FP8_DTYPE)\
|
||||
.contiguous().transpose(0, 1)
|
||||
self.weight = (
|
||||
torch.empty([hidden_size, hidden_size], dtype=FP8_DTYPE)
|
||||
.contiguous()
|
||||
.transpose(0, 1)
|
||||
)
|
||||
|
||||
# Initialize scale_b for _scaled_mm.
|
||||
self.scale_b = torch.ones(1, self.hidden_size, dtype=torch.float32)
|
||||
|
||||
|
||||
class TestScaledMMRSModel(_BaseScaledMMModel):
|
||||
|
||||
def forward(self, input: torch.Tensor):
|
||||
"""
|
||||
Forward pass implementing the scaled_mm + reduce scatter in the FX graph
|
||||
@ -117,11 +128,13 @@ class TestScaledMMRSModel(_BaseScaledMMModel):
|
||||
"""
|
||||
fp8_input = input.to(FP8_DTYPE)
|
||||
scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
|
||||
scaled_mm = torch._scaled_mm(fp8_input,
|
||||
self.weight,
|
||||
scale_a=scale_a,
|
||||
scale_b=self.scale_b,
|
||||
out_dtype=self.dtype)
|
||||
scaled_mm = torch._scaled_mm(
|
||||
fp8_input,
|
||||
self.weight,
|
||||
scale_a=scale_a,
|
||||
scale_b=self.scale_b,
|
||||
out_dtype=self.dtype,
|
||||
)
|
||||
reduce_scatter = tensor_model_parallel_reduce_scatter(scaled_mm, dim=0)
|
||||
return reduce_scatter
|
||||
|
||||
@ -133,7 +146,6 @@ class TestScaledMMRSModel(_BaseScaledMMModel):
|
||||
|
||||
|
||||
class TestAGScaledMMModel(_BaseScaledMMModel):
|
||||
|
||||
def forward(self, input: torch.Tensor):
|
||||
"""
|
||||
Forward pass implementing the all gather + scaled_mm in the FX graph
|
||||
@ -143,11 +155,13 @@ class TestAGScaledMMModel(_BaseScaledMMModel):
|
||||
all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0)
|
||||
|
||||
scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
|
||||
scaled_mm = torch._scaled_mm(all_gather,
|
||||
self.weight,
|
||||
scale_a=scale_a,
|
||||
scale_b=self.scale_b,
|
||||
out_dtype=self.dtype)
|
||||
scaled_mm = torch._scaled_mm(
|
||||
all_gather,
|
||||
self.weight,
|
||||
scale_a=scale_a,
|
||||
scale_b=self.scale_b,
|
||||
out_dtype=self.dtype,
|
||||
)
|
||||
return scaled_mm
|
||||
|
||||
def ops_in_model_before(self):
|
||||
@ -158,7 +172,6 @@ class TestAGScaledMMModel(_BaseScaledMMModel):
|
||||
|
||||
|
||||
class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
|
||||
|
||||
def forward(self, input: torch.Tensor):
|
||||
"""
|
||||
Forward pass implementing the cutlass_scaled_mm + reduce scatter
|
||||
@ -167,11 +180,14 @@ class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
|
||||
"""
|
||||
fp8_input = input.to(FP8_DTYPE)
|
||||
scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
|
||||
mm_out = torch.empty((fp8_input.shape[0], self.weight.shape[1]),
|
||||
dtype=self.dtype,
|
||||
device=input.device)
|
||||
torch.ops._C.cutlass_scaled_mm(mm_out, fp8_input, self.weight, scale_a,
|
||||
self.scale_b, None)
|
||||
mm_out = torch.empty(
|
||||
(fp8_input.shape[0], self.weight.shape[1]),
|
||||
dtype=self.dtype,
|
||||
device=input.device,
|
||||
)
|
||||
torch.ops._C.cutlass_scaled_mm(
|
||||
mm_out, fp8_input, self.weight, scale_a, self.scale_b, None
|
||||
)
|
||||
reduce_scatter = tensor_model_parallel_reduce_scatter(mm_out, dim=0)
|
||||
return reduce_scatter
|
||||
|
||||
@ -183,7 +199,6 @@ class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
|
||||
|
||||
|
||||
class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
|
||||
|
||||
def forward(self, input: torch.Tensor):
|
||||
"""
|
||||
Forward pass implementing the all gather + cutlass_scaled_mm
|
||||
@ -195,11 +210,14 @@ class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
|
||||
|
||||
scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
|
||||
|
||||
mm_out = torch.empty((all_gather.shape[0], self.weight.shape[1]),
|
||||
dtype=self.dtype,
|
||||
device=all_gather.device)
|
||||
torch.ops._C.cutlass_scaled_mm(mm_out, all_gather, self.weight,
|
||||
scale_a, self.scale_b, None)
|
||||
mm_out = torch.empty(
|
||||
(all_gather.shape[0], self.weight.shape[1]),
|
||||
dtype=self.dtype,
|
||||
device=all_gather.device,
|
||||
)
|
||||
torch.ops._C.cutlass_scaled_mm(
|
||||
mm_out, all_gather, self.weight, scale_a, self.scale_b, None
|
||||
)
|
||||
return mm_out
|
||||
|
||||
def ops_in_model_before(self):
|
||||
@ -210,23 +228,37 @@ class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize("test_model", [
|
||||
TestMMRSModel, TestAGMMModel, TestScaledMMRSModel, TestAGScaledMMModel,
|
||||
TestCutlassScaledMMRSModel, TestAGCutlassScaledMMModel
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"test_model",
|
||||
[
|
||||
TestMMRSModel,
|
||||
TestAGMMModel,
|
||||
TestScaledMMRSModel,
|
||||
TestAGScaledMMModel,
|
||||
TestCutlassScaledMMRSModel,
|
||||
TestAGCutlassScaledMMModel,
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seq_len", [16])
|
||||
@pytest.mark.parametrize("hidden_size", [16])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
|
||||
reason="Only test on CUDA")
|
||||
def test_async_tp_pass_replace(test_model: str, batch_size: int, seq_len: int,
|
||||
hidden_size: int, dtype: torch.dtype):
|
||||
if test_model in (TestScaledMMRSModel, TestAGScaledMMModel,
|
||||
TestCutlassScaledMMRSModel,
|
||||
TestAGCutlassScaledMMModel) and dtype == torch.float16:
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
|
||||
def test_async_tp_pass_replace(
|
||||
test_model: str, batch_size: int, seq_len: int, hidden_size: int, dtype: torch.dtype
|
||||
):
|
||||
if (
|
||||
test_model
|
||||
in (
|
||||
TestScaledMMRSModel,
|
||||
TestAGScaledMMModel,
|
||||
TestCutlassScaledMMRSModel,
|
||||
TestAGCutlassScaledMMModel,
|
||||
)
|
||||
and dtype == torch.float16
|
||||
):
|
||||
pytest.skip(
|
||||
"Only bf16 high precision output types are supported for " \
|
||||
"Only bf16 high precision output types are supported for "
|
||||
"per-token (row-wise) scaling"
|
||||
)
|
||||
|
||||
@ -235,19 +267,24 @@ def test_async_tp_pass_replace(test_model: str, batch_size: int, seq_len: int,
|
||||
def run_torch_spawn(fn, nprocs):
|
||||
# need to use torch.mp.spawn, otherwise there will be problems with
|
||||
# torch.distributed and cuda
|
||||
torch.multiprocessing.spawn(fn,
|
||||
args=(num_processes, test_model,
|
||||
batch_size, seq_len, hidden_size,
|
||||
dtype),
|
||||
nprocs=nprocs)
|
||||
torch.multiprocessing.spawn(
|
||||
fn,
|
||||
args=(num_processes, test_model, batch_size, seq_len, hidden_size, dtype),
|
||||
nprocs=nprocs,
|
||||
)
|
||||
|
||||
run_torch_spawn(async_tp_pass_on_test_model, num_processes)
|
||||
|
||||
|
||||
def async_tp_pass_on_test_model(local_rank: int, world_size: int,
|
||||
test_model_cls: torch.nn.Module,
|
||||
batch_size: int, seq_len: int,
|
||||
hidden_size: int, dtype: torch.dtype):
|
||||
def async_tp_pass_on_test_model(
|
||||
local_rank: int,
|
||||
world_size: int,
|
||||
test_model_cls: torch.nn.Module,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
@ -255,13 +292,15 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
|
||||
torch.set_default_device(device)
|
||||
torch.set_default_dtype(dtype)
|
||||
|
||||
update_environment_variables({
|
||||
'RANK': str(local_rank),
|
||||
'LOCAL_RANK': str(local_rank),
|
||||
'WORLD_SIZE': str(world_size),
|
||||
'MASTER_ADDR': 'localhost',
|
||||
'MASTER_PORT': '12345',
|
||||
})
|
||||
update_environment_variables(
|
||||
{
|
||||
"RANK": str(local_rank),
|
||||
"LOCAL_RANK": str(local_rank),
|
||||
"WORLD_SIZE": str(world_size),
|
||||
"MASTER_ADDR": "localhost",
|
||||
"MASTER_PORT": "12345",
|
||||
}
|
||||
)
|
||||
|
||||
# initialize distributed
|
||||
init_distributed_environment()
|
||||
@ -269,27 +308,28 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
|
||||
|
||||
# configure vllm config for AsyncTPPass
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig(
|
||||
enable_async_tp=True, ), )
|
||||
vllm_config.compilation_config = CompilationConfig(
|
||||
pass_config=PassConfig(
|
||||
enable_async_tp=True,
|
||||
),
|
||||
)
|
||||
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
|
||||
|
||||
# this is a fake model name to construct the model config
|
||||
# in the vllm_config, it's not really used.
|
||||
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
|
||||
vllm_config.model_config = ModelConfig(model=model_name,
|
||||
trust_remote_code=True,
|
||||
dtype=dtype,
|
||||
seed=42)
|
||||
vllm_config.model_config = ModelConfig(
|
||||
model=model_name, trust_remote_code=True, dtype=dtype, seed=42
|
||||
)
|
||||
|
||||
async_tp_pass = AsyncTPPass(vllm_config)
|
||||
backend = TestBackend(async_tp_pass)
|
||||
|
||||
model = test_model_cls(hidden_size,
|
||||
dtype) # Pass dtype to model constructor
|
||||
model = test_model_cls(hidden_size, dtype) # Pass dtype to model constructor
|
||||
|
||||
hidden_states = torch.randn((batch_size * seq_len, hidden_size),
|
||||
dtype=dtype,
|
||||
requires_grad=False)
|
||||
hidden_states = torch.randn(
|
||||
(batch_size * seq_len, hidden_size), dtype=dtype, requires_grad=False
|
||||
)
|
||||
|
||||
compiled_model = torch.compile(model, backend=backend)
|
||||
compiled_model(hidden_states)
|
||||
@ -306,10 +346,10 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
|
||||
@pytest.mark.parametrize("model_id", [
|
||||
"meta-llama/Llama-3.2-1B-Instruct",
|
||||
"RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"model_id",
|
||||
["meta-llama/Llama-3.2-1B-Instruct", "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"],
|
||||
)
|
||||
@pytest.mark.parametrize("tp_size", [2])
|
||||
@pytest.mark.parametrize("async_tp_enabled", [True])
|
||||
@pytest.mark.parametrize("distributed_backend", ["mp"])
|
||||
@ -342,12 +382,10 @@ def test_async_tp_pass_correctness(
|
||||
common_args.append("--enforce-eager")
|
||||
|
||||
compilation_config = {
|
||||
'level': 3,
|
||||
'compile_sizes': [2, 4, 8],
|
||||
'splitting_ops': [],
|
||||
'pass_config': {
|
||||
'enable_async_tp': async_tp_enabled
|
||||
},
|
||||
"level": 3,
|
||||
"compile_sizes": [2, 4, 8],
|
||||
"splitting_ops": [],
|
||||
"pass_config": {"enable_async_tp": async_tp_enabled},
|
||||
}
|
||||
|
||||
async_tp_env = tp_env = {
|
||||
@ -372,9 +410,6 @@ def test_async_tp_pass_correctness(
|
||||
"mp",
|
||||
]
|
||||
|
||||
compare_two_settings(model_id,
|
||||
async_tp_args,
|
||||
tp_args,
|
||||
async_tp_env,
|
||||
tp_env,
|
||||
method="generate")
|
||||
compare_two_settings(
|
||||
model_id, async_tp_args, tp_args, async_tp_env, tp_env, method="generate"
|
||||
)
|
||||
|
||||
@ -103,23 +103,28 @@ def test_compile_correctness(
|
||||
attn_backend = test_setting.attn_backend
|
||||
method = test_setting.method
|
||||
if cuda_device_count_stateless() < pp_size * tp_size:
|
||||
pytest.skip(f"Need at least {pp_size}*{tp_size} CUDA gpus but got "
|
||||
f"{cuda_device_count_stateless()}")
|
||||
pytest.skip(
|
||||
f"Need at least {pp_size}*{tp_size} CUDA gpus but got "
|
||||
f"{cuda_device_count_stateless()}"
|
||||
)
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
|
||||
final_args = [
|
||||
"--enforce-eager", *model_args, "-pp",
|
||||
str(pp_size), "-tp",
|
||||
str(tp_size)
|
||||
"--enforce-eager",
|
||||
*model_args,
|
||||
"-pp",
|
||||
str(pp_size),
|
||||
"-tp",
|
||||
str(tp_size),
|
||||
]
|
||||
|
||||
all_args: list[list[str]] = []
|
||||
all_envs: list[dict[str, str] | None] = []
|
||||
|
||||
for level in [
|
||||
CompilationLevel.NO_COMPILATION,
|
||||
CompilationLevel.PIECEWISE,
|
||||
CompilationLevel.NO_COMPILATION,
|
||||
CompilationLevel.PIECEWISE,
|
||||
]:
|
||||
all_args.append(final_args + [f"-O{level}"])
|
||||
all_envs.append({})
|
||||
@ -130,14 +135,15 @@ def test_compile_correctness(
|
||||
model,
|
||||
all_args,
|
||||
all_envs,
|
||||
method=method if method != "generate" else "generate_close")
|
||||
method=method if method != "generate" else "generate_close",
|
||||
)
|
||||
all_envs.clear()
|
||||
all_args.clear()
|
||||
|
||||
for level in [
|
||||
CompilationLevel.NO_COMPILATION,
|
||||
CompilationLevel.DYNAMO_AS_IS,
|
||||
CompilationLevel.DYNAMO_ONCE,
|
||||
CompilationLevel.NO_COMPILATION,
|
||||
CompilationLevel.DYNAMO_AS_IS,
|
||||
CompilationLevel.DYNAMO_ONCE,
|
||||
]:
|
||||
all_args.append(final_args + [f"-O{level}"])
|
||||
all_envs.append({})
|
||||
|
||||
@ -9,11 +9,11 @@ from vllm.utils import _is_torch_equal_or_newer
|
||||
|
||||
|
||||
def test_version():
|
||||
assert _is_torch_equal_or_newer('2.8.0.dev20250624+cu128', '2.8.0.dev')
|
||||
assert _is_torch_equal_or_newer('2.8.0a0+gitc82a174', '2.8.0.dev')
|
||||
assert _is_torch_equal_or_newer('2.8.0', '2.8.0.dev')
|
||||
assert _is_torch_equal_or_newer('2.8.1', '2.8.0.dev')
|
||||
assert not _is_torch_equal_or_newer('2.7.1', '2.8.0.dev')
|
||||
assert _is_torch_equal_or_newer("2.8.0.dev20250624+cu128", "2.8.0.dev")
|
||||
assert _is_torch_equal_or_newer("2.8.0a0+gitc82a174", "2.8.0.dev")
|
||||
assert _is_torch_equal_or_newer("2.8.0", "2.8.0.dev")
|
||||
assert _is_torch_equal_or_newer("2.8.1", "2.8.0.dev")
|
||||
assert not _is_torch_equal_or_newer("2.7.1", "2.8.0.dev")
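Hedged aside: the version orderings asserted here follow PEP 440, and a comparable check can be sketched with the packaging library (an illustration only, not necessarily vLLM's implementation):

from packaging.version import Version

def is_equal_or_newer(current: str, target: str) -> bool:
    # Dev releases sort before pre-releases and final releases under PEP 440,
    # so both "2.8.0a0+gitc82a174" and "2.8.0" compare >= "2.8.0.dev".
    return Version(current) >= Version(target)

assert is_equal_or_newer("2.8.0a0+gitc82a174", "2.8.0.dev")
assert not is_equal_or_newer("2.7.1", "2.8.0.dev")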
|
||||
|
||||
|
||||
def test_use_cudagraphs_dynamic(monkeypatch):
|
||||
@ -21,7 +21,7 @@ def test_use_cudagraphs_dynamic(monkeypatch):
|
||||
vllm_config = VllmConfig()
|
||||
assert vllm_config.compilation_config.use_cudagraph
|
||||
|
||||
monkeypatch.setenv('VLLM_USE_V1', '0')
|
||||
monkeypatch.setenv("VLLM_USE_V1", "0")
|
||||
vllm_config = VllmConfig()
|
||||
assert not vllm_config.compilation_config.use_cudagraph
|
||||
|
||||
@ -44,19 +44,23 @@ def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
|
||||
assert vllm.envs.VLLM_USE_V1
|
||||
|
||||
# Disable multiprocessing so that the counter is in the same process
|
||||
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
|
||||
monkeypatch.setenv('VLLM_DISABLE_COMPILE_CACHE', val)
|
||||
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", val)
|
||||
|
||||
compilation_config = {
|
||||
"use_cudagraph": False, # speed things up a bit
|
||||
}
|
||||
with (
|
||||
compilation_counter.expect(num_cache_entries_updated=0,
|
||||
num_compiled_artifacts_saved=0),
|
||||
# loading the model causes compilation (if enabled) to happen
|
||||
vllm_runner('facebook/opt-125m',
|
||||
compilation_config=compilation_config,
|
||||
gpu_memory_utilization=0.4) as _):
|
||||
compilation_counter.expect(
|
||||
num_cache_entries_updated=0, num_compiled_artifacts_saved=0
|
||||
),
|
||||
# loading the model causes compilation (if enabled) to happen
|
||||
vllm_runner(
|
||||
"facebook/opt-125m",
|
||||
compilation_config=compilation_config,
|
||||
gpu_memory_utilization=0.4,
|
||||
) as _,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
@ -67,22 +71,25 @@ def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
|
||||
assert vllm.envs.VLLM_USE_V1
|
||||
|
||||
# Disable multiprocessing so that the counter is in the same process
|
||||
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
|
||||
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||
|
||||
compilation_config = {
|
||||
"cudagraph_capture_sizes": [100],
|
||||
"use_cudagraph": enabled,
|
||||
}
|
||||
with (
|
||||
compilation_counter.expect(
|
||||
num_graphs_seen=1,
|
||||
num_gpu_runner_capture_triggers=1 if enabled else 0,
|
||||
num_cudagraph_captured=13 if enabled else 0,
|
||||
),
|
||||
# loading the model causes compilation (if enabled) to happen
|
||||
vllm_runner('facebook/opt-125m',
|
||||
compilation_config=compilation_config,
|
||||
gpu_memory_utilization=0.4) as _):
|
||||
compilation_counter.expect(
|
||||
num_graphs_seen=1,
|
||||
num_gpu_runner_capture_triggers=1 if enabled else 0,
|
||||
num_cudagraph_captured=13 if enabled else 0,
|
||||
),
|
||||
# loading the model causes compilation (if enabled) to happen
|
||||
vllm_runner(
|
||||
"facebook/opt-125m",
|
||||
compilation_config=compilation_config,
|
||||
gpu_memory_utilization=0.4,
|
||||
) as _,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
@ -90,14 +97,17 @@ def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
|
||||
@pytest.mark.forked
|
||||
def test_dynamo_as_is(vllm_runner, monkeypatch):
|
||||
# Disable multiprocessing so that the counter is in the same process
|
||||
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
|
||||
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||
|
||||
with (
|
||||
compilation_counter.expect(dynamo_as_is_count=1),
|
||||
# loading the model causes compilation (if enabled) to happen
|
||||
vllm_runner('facebook/opt-125m',
|
||||
compilation_config={"level": 1},
|
||||
gpu_memory_utilization=0.4) as _):
|
||||
compilation_counter.expect(dynamo_as_is_count=1),
|
||||
# loading the model causes compilation (if enabled) to happen
|
||||
vllm_runner(
|
||||
"facebook/opt-125m",
|
||||
compilation_config={"level": 1},
|
||||
gpu_memory_utilization=0.4,
|
||||
) as _,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
@ -105,14 +115,16 @@ def test_dynamo_as_is(vllm_runner, monkeypatch):
|
||||
@pytest.mark.forked
|
||||
def test_no_compilation(vllm_runner, monkeypatch):
|
||||
# Disable multiprocessing so that the counter is in the same process
|
||||
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
|
||||
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||
with (
|
||||
compilation_counter.expect(num_graphs_seen=0,
|
||||
dynamo_as_is_count=0),
|
||||
# loading the model causes compilation (if enabled) to happen
|
||||
vllm_runner('facebook/opt-125m',
|
||||
compilation_config={"level": 0},
|
||||
gpu_memory_utilization=0.4) as _):
|
||||
compilation_counter.expect(num_graphs_seen=0, dynamo_as_is_count=0),
|
||||
# loading the model causes compilation (if enabled) to happen
|
||||
vllm_runner(
|
||||
"facebook/opt-125m",
|
||||
compilation_config={"level": 0},
|
||||
gpu_memory_utilization=0.4,
|
||||
) as _,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
@ -120,77 +132,73 @@ def test_no_compilation(vllm_runner, monkeypatch):
|
||||
@pytest.mark.forked
|
||||
def test_enforce_eager(vllm_runner, monkeypatch):
|
||||
# Disable multiprocessing so that the counter is in the same process
|
||||
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
|
||||
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||
|
||||
with (
|
||||
compilation_counter.expect(num_graphs_seen=0,
|
||||
dynamo_as_is_count=0),
|
||||
# loading the model causes compilation (if enabled) to happen
|
||||
vllm_runner('facebook/opt-125m',
|
||||
enforce_eager=True,
|
||||
gpu_memory_utilization=0.4) as _):
|
||||
compilation_counter.expect(num_graphs_seen=0, dynamo_as_is_count=0),
|
||||
# loading the model causes compilation (if enabled) to happen
|
||||
vllm_runner(
|
||||
"facebook/opt-125m", enforce_eager=True, gpu_memory_utilization=0.4
|
||||
) as _,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
def test_splitting_ops_dynamic():
|
||||
# Default config
|
||||
config = VllmConfig()
|
||||
assert config.compilation_config.cudagraph_mode == \
|
||||
CUDAGraphMode.FULL_AND_PIECEWISE
|
||||
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
|
||||
assert config.compilation_config.splitting_ops_contain_attention()
|
||||
|
||||
# When use_inductor_graph_partition=True
|
||||
if _is_torch_equal_or_newer('2.9.0.dev'):
|
||||
if _is_torch_equal_or_newer("2.9.0.dev"):
|
||||
# inductor graph partition is only available in PyTorch 2.9+.
|
||||
# this is a fast config check so we are not using pytest.skip.
|
||||
config = VllmConfig(compilation_config=CompilationConfig(
|
||||
use_inductor_graph_partition=True,
|
||||
splitting_ops=["silly_attention"]))
|
||||
config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
use_inductor_graph_partition=True, splitting_ops=["silly_attention"]
|
||||
)
|
||||
)
|
||||
# should ignore splitting_ops
|
||||
assert config.compilation_config.splitting_ops == []
|
||||
|
||||
# When the attn_fusion pass is enabled.
|
||||
config = VllmConfig(compilation_config=CompilationConfig(
|
||||
pass_config={
|
||||
"enable_attn_fusion": True,
|
||||
"enable_noop": True
|
||||
},
|
||||
custom_ops=["+quant_fp8"],
|
||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||
))
|
||||
config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
pass_config={"enable_attn_fusion": True, "enable_noop": True},
|
||||
custom_ops=["+quant_fp8"],
|
||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||
)
|
||||
)
|
||||
assert config.compilation_config.splitting_ops == []
|
||||
# cudagraph mode also falls back to FULL
|
||||
assert config.compilation_config.cudagraph_mode == \
|
||||
CUDAGraphMode.FULL
|
||||
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL
|
||||
|
||||
# splitting_ops cannot contain attention ops when the attn_fusion
|
||||
# pass is enabled.
|
||||
with pytest.raises(AssertionError):
|
||||
config = VllmConfig(compilation_config=CompilationConfig(
|
||||
pass_config={
|
||||
"enable_attn_fusion": True,
|
||||
"enable_noop": True
|
||||
},
|
||||
custom_ops=["+quant_fp8"],
|
||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||
# workaround for accessing all attention ops
|
||||
splitting_ops=CompilationConfig()._attention_ops,
|
||||
))
|
||||
config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
pass_config={"enable_attn_fusion": True, "enable_noop": True},
|
||||
custom_ops=["+quant_fp8"],
|
||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||
# workaround for accessing all attention ops
|
||||
splitting_ops=CompilationConfig()._attention_ops,
|
||||
)
|
||||
)
|
||||
|
||||
# When both use_inductor_graph_partition and the attn_fusion pass are enabled.
|
||||
if _is_torch_equal_or_newer('2.9.0.dev'):
|
||||
config = VllmConfig(compilation_config=CompilationConfig(
|
||||
use_inductor_graph_partition=True,
|
||||
pass_config={
|
||||
"enable_attn_fusion": True,
|
||||
"enable_noop": True
|
||||
},
|
||||
custom_ops=["+quant_fp8"],
|
||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||
))
|
||||
if _is_torch_equal_or_newer("2.9.0.dev"):
|
||||
config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
use_inductor_graph_partition=True,
|
||||
pass_config={"enable_attn_fusion": True, "enable_noop": True},
|
||||
custom_ops=["+quant_fp8"],
|
||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||
)
|
||||
)
|
||||
assert config.compilation_config.splitting_ops == []
|
||||
# enable_attn_fusion is directly supported under
|
||||
# use_inductor_graph_partition=True, and cudagraph_mode
|
||||
# is unchanged.
|
||||
assert config.compilation_config.cudagraph_mode == \
|
||||
CUDAGraphMode.PIECEWISE
|
||||
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
|
||||
|
||||
@ -4,10 +4,15 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import (ignore_torch_compile,
|
||||
support_torch_compile)
|
||||
from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
|
||||
CUDAGraphMode, VllmConfig, set_current_vllm_config)
|
||||
from vllm.compilation.decorators import ignore_torch_compile, support_torch_compile
|
||||
from vllm.config import (
|
||||
CacheConfig,
|
||||
CompilationConfig,
|
||||
CompilationLevel,
|
||||
CUDAGraphMode,
|
||||
VllmConfig,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
|
||||
# This import automatically registers `torch.ops.silly.attention`
|
||||
@ -18,32 +23,42 @@ MLP_SIZE = 128
|
||||
|
||||
|
||||
@torch.inference_mode
|
||||
def run_model(vllm_config: VllmConfig, model: nn.Module,
|
||||
cudagraph_runtime_mode: CUDAGraphMode):
|
||||
def run_model(
|
||||
vllm_config: VllmConfig, model: nn.Module, cudagraph_runtime_mode: CUDAGraphMode
|
||||
):
|
||||
with set_forward_context({}, vllm_config=vllm_config):
|
||||
# warmup for the model with cudagraph_mode NONE
|
||||
model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
|
||||
|
||||
# simulate cudagraphs capturing
|
||||
with set_forward_context({},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=2, )):
|
||||
with set_forward_context(
|
||||
{},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=2,
|
||||
),
|
||||
):
|
||||
model(torch.randn(2, MLP_SIZE).cuda())
|
||||
with set_forward_context({},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=1, )):
|
||||
with set_forward_context(
|
||||
{},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=1,
|
||||
),
|
||||
):
|
||||
model(torch.randn(1, MLP_SIZE).cuda())
|
||||
|
||||
# simulate cudagraphs replay
|
||||
with set_forward_context({},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=2, )):
|
||||
with set_forward_context(
|
||||
{},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=2,
|
||||
),
|
||||
):
|
||||
output = model(torch.randn(2, MLP_SIZE).cuda())
|
||||
|
||||
output = output.cpu()
|
||||
@ -52,22 +67,21 @@ def run_model(vllm_config: VllmConfig, model: nn.Module,
|
||||
|
||||
def test_ignore_torch_compile_decorator():
|
||||
# piecewise
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=True,
|
||||
splitting_ops=["silly.attention"],
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
))
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=True,
|
||||
splitting_ops=["silly.attention"],
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
)
|
||||
)
|
||||
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||
|
||||
@support_torch_compile
|
||||
class A(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = '',
|
||||
**kwargs) -> None:
|
||||
def __init__(
|
||||
self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@ -79,66 +93,60 @@ def test_ignore_torch_compile_decorator():
|
||||
return x
|
||||
|
||||
@ignore_torch_compile
|
||||
class B(A):
|
||||
...
|
||||
class B(A): ...
|
||||
|
||||
@support_torch_compile
|
||||
class C(B):
|
||||
...
|
||||
class C(B): ...
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()
|
||||
mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda()
|
||||
|
||||
# A has support_torch_compile
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=1,
|
||||
num_piecewise_graphs_seen=3,
|
||||
num_piecewise_capturable_graphs_seen=2,
|
||||
num_backend_compilations=2,
|
||||
num_cudagraph_captured=4,
|
||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
num_graphs_seen=1,
|
||||
num_piecewise_graphs_seen=3,
|
||||
num_piecewise_capturable_graphs_seen=2,
|
||||
num_backend_compilations=2,
|
||||
num_cudagraph_captured=4,
|
||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
):
|
||||
run_model(vllm_config, mod_A, cudagraph_runtime_mode)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda()
|
||||
mod_B = B(vllm_config=vllm_config, prefix="").eval().cuda()
|
||||
|
||||
# B's ignore_torch_compile should override A's support_torch_compile
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=0,
|
||||
num_piecewise_graphs_seen=0,
|
||||
num_piecewise_capturable_graphs_seen=0,
|
||||
num_backend_compilations=0,
|
||||
num_cudagraph_captured=0,
|
||||
num_graphs_seen=0,
|
||||
num_piecewise_graphs_seen=0,
|
||||
num_piecewise_capturable_graphs_seen=0,
|
||||
num_backend_compilations=0,
|
||||
num_cudagraph_captured=0,
|
||||
):
|
||||
run_model(vllm_config, mod_B, cudagraph_runtime_mode)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda()
|
||||
mod_C = C(vllm_config=vllm_config, prefix="").eval().cuda()
|
||||
|
||||
# C's support_torch_compile should override B's ignore_torch_compile
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=1,
|
||||
num_piecewise_graphs_seen=3,
|
||||
num_piecewise_capturable_graphs_seen=2,
|
||||
num_backend_compilations=2,
|
||||
num_cudagraph_captured=4,
|
||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
num_graphs_seen=1,
|
||||
num_piecewise_graphs_seen=3,
|
||||
num_piecewise_capturable_graphs_seen=2,
|
||||
num_backend_compilations=2,
|
||||
num_cudagraph_captured=4,
|
||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
):
|
||||
run_model(vllm_config, mod_C, cudagraph_runtime_mode)
|
||||
|
||||
|
||||
# Only enable torch.compile if
|
||||
# vllm_config.cache_config.kv_sharing_fast_prefill=True
|
||||
@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config.
|
||||
kv_sharing_fast_prefill)
|
||||
@support_torch_compile(
|
||||
enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill
|
||||
)
|
||||
class B(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = '',
|
||||
**kwargs) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@ -152,15 +160,11 @@ class B(nn.Module):
|
||||
|
||||
# Only enable torch.compile if
|
||||
# vllm_config.cache_config.kv_sharing_fast_prefill=False
|
||||
@support_torch_compile(enable_if=lambda vllm_config: not vllm_config.
|
||||
cache_config.kv_sharing_fast_prefill)
|
||||
@support_torch_compile(
|
||||
enable_if=lambda vllm_config: not vllm_config.cache_config.kv_sharing_fast_prefill
|
||||
)
|
||||
class A(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = '',
|
||||
**kwargs) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
|
||||
super().__init__()
|
||||
self.mod1 = B(vllm_config=vllm_config, prefix=prefix, **kwargs)
|
||||
self.mod2 = B(vllm_config=vllm_config, prefix=prefix, **kwargs)
|
||||
@ -175,54 +179,60 @@ class A(nn.Module):
|
||||
|
||||
|
||||
def test_conditional_compile_enable_if():
|
||||
vllm_config = VllmConfig(cache_config=CacheConfig(
|
||||
kv_sharing_fast_prefill=True, ),
|
||||
compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=True,
|
||||
splitting_ops=["silly.attention"],
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
))
|
||||
vllm_config = VllmConfig(
|
||||
cache_config=CacheConfig(
|
||||
kv_sharing_fast_prefill=True,
|
||||
),
|
||||
compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=True,
|
||||
splitting_ops=["silly.attention"],
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
),
|
||||
)
|
||||
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()
|
||||
mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda()
|
||||
|
||||
# A has support_torch_compile but enable_if fn returns False
|
||||
# enable_if will be True for B, so we expect mod1 and mod2
|
||||
# to be compiled
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=2,
|
||||
num_piecewise_graphs_seen=6,
|
||||
# 3 piecewise graphs per instance of B()
|
||||
num_piecewise_capturable_graphs_seen=4,
|
||||
num_backend_compilations=4,
|
||||
num_cudagraph_captured=8,
|
||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
num_graphs_seen=2,
|
||||
num_piecewise_graphs_seen=6,
|
||||
# 3 piecewise graphs per instance of B()
|
||||
num_piecewise_capturable_graphs_seen=4,
|
||||
num_backend_compilations=4,
|
||||
num_cudagraph_captured=8,
|
||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
):
|
||||
run_model(vllm_config, mod_A, cudagraph_runtime_mode)
|
||||
|
||||
# Set kv_sharing_fast_prefill=False
|
||||
# which will cause A to be compiled and B to not be compiled
|
||||
vllm_config = VllmConfig(cache_config=CacheConfig(
|
||||
kv_sharing_fast_prefill=False, ),
|
||||
compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=True,
|
||||
splitting_ops=["silly.attention"],
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
))
|
||||
vllm_config = VllmConfig(
|
||||
cache_config=CacheConfig(
|
||||
kv_sharing_fast_prefill=False,
|
||||
),
|
||||
compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=True,
|
||||
splitting_ops=["silly.attention"],
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
),
|
||||
)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()
|
||||
mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda()
|
||||
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=1,
|
||||
num_piecewise_graphs_seen=7,
|
||||
# 3 attn ops and 4 non-attn ops
|
||||
num_piecewise_capturable_graphs_seen=4,
|
||||
num_backend_compilations=4,
|
||||
num_cudagraph_captured=8,
|
||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
num_graphs_seen=1,
|
||||
num_piecewise_graphs_seen=7,
|
||||
# 3 attn ops and 4 non-attn ops
|
||||
num_piecewise_capturable_graphs_seen=4,
|
||||
num_backend_compilations=4,
|
||||
num_cudagraph_captured=8,
|
||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
):
|
||||
run_model(vllm_config, mod_A, cudagraph_runtime_mode)
|
||||
|
||||
@ -14,8 +14,7 @@ from tests.quantization.utils import is_quant_method_supported
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.selector import global_force_attn_backend_context_manager
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
||||
PassConfig)
|
||||
from vllm.config import CompilationConfig, CompilationLevel, CUDAGraphMode, PassConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import is_torch_equal_or_newer
|
||||
|
||||
@ -25,43 +24,54 @@ from ..utils import create_new_process_for_each_test
|
||||
def models_list(*, all: bool = True, keywords: Optional[list[str]] = None):
|
||||
TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
|
||||
("facebook/opt-125m", {}),
|
||||
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
|
||||
"dtype": torch.float16,
|
||||
}),
|
||||
("neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic", {
|
||||
"dtype": torch.float16,
|
||||
}),
|
||||
(
|
||||
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
|
||||
{
|
||||
"dtype": torch.float16,
|
||||
},
|
||||
),
|
||||
(
|
||||
"neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic",
|
||||
{
|
||||
"dtype": torch.float16,
|
||||
},
|
||||
),
|
||||
("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}),
|
||||
("meta-llama/Llama-3.2-1B-Instruct", {}),
|
||||
]
|
||||
|
||||
if all:
|
||||
|
||||
# TODO: figure out why this fails.
|
||||
if False and is_quant_method_supported("gguf"): # noqa: SIM223
|
||||
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {
|
||||
"quantization": "gguf"
|
||||
}))
|
||||
TEST_MODELS.append(
|
||||
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {"quantization": "gguf"})
|
||||
)
|
||||
|
||||
if is_quant_method_supported("gptq"):
|
||||
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {
|
||||
"quantization": "gptq"
|
||||
}))
|
||||
TEST_MODELS.append(
|
||||
("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {"quantization": "gptq"})
|
||||
)
|
||||
|
||||
if is_quant_method_supported("gptq_marlin"):
|
||||
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", {
|
||||
"quantization": "gptq_marlin"
|
||||
}))
|
||||
TEST_MODELS.append(
|
||||
(
|
||||
"TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ",
|
||||
{"quantization": "gptq_marlin"},
|
||||
)
|
||||
)
|
||||
|
||||
if is_quant_method_supported("gptq_marlin_24"):
|
||||
TEST_MODELS.append(("alexm-nm/tinyllama-24-marlin24-4bit-g128", {
|
||||
"quantization": "gptq_marlin_24"
|
||||
}))
|
||||
TEST_MODELS.append(
|
||||
(
|
||||
"alexm-nm/tinyllama-24-marlin24-4bit-g128",
|
||||
{"quantization": "gptq_marlin_24"},
|
||||
)
|
||||
)
|
||||
|
||||
if not current_platform.is_rocm() and is_quant_method_supported("awq"):
|
||||
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {
|
||||
"quantization": "AWQ"
|
||||
}))
|
||||
TEST_MODELS.append(
|
||||
("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {"quantization": "AWQ"})
|
||||
)
|
||||
|
||||
if keywords is None:
|
||||
return TEST_MODELS
|
||||
@ -95,22 +105,34 @@ def test_full_graph(
|
||||
"compilation_config, model_info",
|
||||
    [
        # additional compile sizes, only some of the models
        (CompilationConfig(level=CompilationLevel.PIECEWISE,
                           compile_sizes=[1, 2]), model)
        (
            CompilationConfig(level=CompilationLevel.PIECEWISE, compile_sizes=[1, 2]),
            model,
        )
        for model in models_list(all=False)
    ] + [
    ]
    + [
        # RMSNorm + quant fusion, only 8-bit quant models
        (CompilationConfig(level=CompilationLevel.PIECEWISE,
                           custom_ops=["+rms_norm"],
                           pass_config=PassConfig(enable_fusion=True,
                                                  enable_noop=True)), model)
        (
            CompilationConfig(
                level=CompilationLevel.PIECEWISE,
                custom_ops=["+rms_norm"],
                pass_config=PassConfig(enable_fusion=True, enable_noop=True),
            ),
            model,
        )
        for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
    ] + [
    ]
    + [
        # Test depyf integration works
        (CompilationConfig(level=CompilationLevel.PIECEWISE,
                           debug_dump_path=tempfile.gettempdir()),
         ("facebook/opt-125m", {})),
    ] + [
        (
            CompilationConfig(
                level=CompilationLevel.PIECEWISE, debug_dump_path=tempfile.gettempdir()
            ),
            ("facebook/opt-125m", {}),
        ),
    ]
    + [
        # graph inductor partition
        (
            CompilationConfig(
@ -119,20 +141,24 @@ def test_full_graph(
                # torch._C.Tag.cudagraph_unsafe to specify splitting ops
                use_inductor_graph_partition=True,
                cudagraph_mode=CUDAGraphMode.PIECEWISE,
                compile_sizes=[1, 2]),
            model) for model in models_list(all=False)
                compile_sizes=[1, 2],
            ),
            model,
        )
        for model in models_list(all=False)
        if is_torch_equal_or_newer("2.9.0.dev")
    ])
    ],
)
# only test some of the models
@create_new_process_for_each_test()
def test_custom_compile_config(
    compilation_config: CompilationConfig,
    model_info: tuple[str, dict[str, Any]],
):
    if (compilation_config.use_inductor_graph_partition
            and not is_torch_equal_or_newer("2.9.0.dev")):
        pytest.skip("inductor graph partition is only available "
                    "in PyTorch 2.9+")
    if compilation_config.use_inductor_graph_partition and not is_torch_equal_or_newer(
        "2.9.0.dev"
    ):
        pytest.skip("inductor graph partition is only available in PyTorch 2.9+")

    model, model_kwargs = model_info
    print(f"MODEL={model}")
@ -156,8 +182,7 @@ def test_fp8_kv_scale_compile(optimization_level: int):

def test_inductor_graph_partition_attn_fusion(caplog_vllm):
    if not is_torch_equal_or_newer("2.9.0.dev"):
        pytest.skip("inductor graph partition is only available "
                    "in PyTorch 2.9+")
        pytest.skip("inductor graph partition is only available in PyTorch 2.9+")

    model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
    compilation_config = CompilationConfig(
@ -171,14 +196,16 @@ def test_inductor_graph_partition_attn_fusion(caplog_vllm):
        "kv_cache_dtype": "fp8",
        "max_model_len": 1024,
    }
    with caplog_vllm.at_level(
            logging.DEBUG), global_force_attn_backend_context_manager(
                _Backend.FLASHINFER):
    with (
        caplog_vllm.at_level(logging.DEBUG),
        global_force_attn_backend_context_manager(_Backend.FLASHINFER),
    ):
        run_model(compilation_config, model, model_kwargs)

    try:
        assert ("Fused quantization onto 48 attention nodes"
                in caplog_vllm.text), caplog_vllm.text
        assert "Fused quantization onto 48 attention nodes" in caplog_vllm.text, (
            caplog_vllm.text
        )
    except AssertionError:
        # Note: this message is only triggered when the compilation goes
        # through the custom pass. Due to multiple layers of cache on
@ -189,8 +216,11 @@ def test_inductor_graph_partition_attn_fusion(caplog_vllm):
        assert "Fused quantization" not in caplog_vllm.text


def run_model(compile_config: Union[int, CompilationConfig], model: str,
              model_kwargs: dict[str, Any]):
def run_model(
    compile_config: Union[int, CompilationConfig],
    model: str,
    model_kwargs: dict[str, Any],
):
    prompts = [
        "Hello, my name is",
        "The president of the United States is",

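The hunks above only restyle existing code. As a rough, hedged sketch of the target style (not part of this diff; `make_compilation_config` is a hypothetical helper, not a vLLM API): ruff format at line length 88 wraps over-long calls one argument per line with a trailing comma, and groups stacked context managers inside a single parenthesized `with` (Python 3.10+ syntax).

# Illustrative sketch only; it shows the formatting conventions, not vLLM behavior.
from contextlib import nullcontext


def make_compilation_config(
    level: int,
    compile_sizes: list[int],
    debug_dump_path: str = "",
) -> dict:
    # Arguments that no longer fit in 88 columns are exploded one per line,
    # and the trailing comma keeps them exploded on future edits.
    return {
        "level": level,
        "compile_sizes": compile_sizes,
        "debug_dump_path": debug_dump_path,
    }


config = make_compilation_config(
    level=3,
    compile_sizes=[1, 2],
    debug_dump_path="/tmp/depyf_dump",
)

# Stacked context managers are grouped in one parenthesized `with` block.
with (
    nullcontext(config),
    nullcontext(),
):
    pass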
@ -14,10 +14,8 @@ from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    Fp8LinearOp)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform

@ -28,7 +26,6 @@ FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
|
||||
class TestSiluMul(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size: int = 128):
|
||||
super().__init__()
|
||||
self.silu_and_mul = SiluAndMul()
|
||||
@ -36,8 +33,7 @@ class TestSiluMul(torch.nn.Module):
|
||||
self.scale = torch.rand(1, dtype=torch.float32)
|
||||
|
||||
if TEST_FP8:
|
||||
self.w = torch.rand(hidden_size,
|
||||
hidden_size).to(dtype=FP8_DTYPE).t()
|
||||
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=True,
|
||||
act_quant_group_shape=GroupShape.PER_TENSOR,
|
||||
@ -46,17 +42,14 @@ class TestSiluMul(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
y = self.silu_and_mul(x)
|
||||
if TEST_FP8:
|
||||
x2 = self.fp8_linear.apply(y,
|
||||
self.w,
|
||||
self.wscale,
|
||||
input_scale=self.wscale)
|
||||
x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale)
|
||||
return x2
|
||||
else:
|
||||
return y
|
||||
|
||||
def example_inputs(self, num_tokens=32, hidden_size=128):
|
||||
dtype = torch.float16 if TEST_FP8 else torch.float32
|
||||
return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype), )
|
||||
return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype),)
|
||||
|
||||
def ops_in_model(self, do_fusion):
|
||||
if TEST_FP8 and do_fusion:
|
||||
@ -69,7 +62,6 @@ class TestSiluMul(torch.nn.Module):
|
||||
|
||||
|
||||
class TestFusedAddRMSNorm(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size=16, intermediate_size=32):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -78,10 +70,12 @@ class TestFusedAddRMSNorm(torch.nn.Module):
|
||||
dtype = torch.float16 if TEST_FP8 else torch.float32
|
||||
|
||||
self.gate_proj = torch.nn.Parameter(
|
||||
torch.empty((intermediate_size, hidden_size), dtype=dtype))
|
||||
torch.empty((intermediate_size, hidden_size), dtype=dtype)
|
||||
)
|
||||
self.norm = RMSNorm(intermediate_size, 1e-05)
|
||||
self.norm.weight = torch.nn.Parameter(
|
||||
torch.ones(intermediate_size, dtype=dtype))
|
||||
torch.ones(intermediate_size, dtype=dtype)
|
||||
)
|
||||
|
||||
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
||||
|
||||
@ -89,8 +83,7 @@ class TestFusedAddRMSNorm(torch.nn.Module):
|
||||
self.fp8_linear = Fp8LinearOp(act_quant_static=True)
|
||||
|
||||
self.scale = torch.rand(1, dtype=torch.float32)
|
||||
self.w = torch.rand(hidden_size,
|
||||
intermediate_size).to(dtype=FP8_DTYPE).t()
|
||||
self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
|
||||
self.wscale = torch.rand(1, dtype=torch.float32)
|
||||
|
||||
def forward(self, hidden_states, residual):
|
||||
@ -120,10 +113,8 @@ class TestFusedAddRMSNorm(torch.nn.Module):
|
||||
|
||||
def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16):
|
||||
dtype = torch.float16 if TEST_FP8 else torch.float32
|
||||
hidden_states = torch.randn((batch_size * seq_len, hidden_size),
|
||||
dtype=dtype)
|
||||
residual = torch.randn((batch_size * seq_len, hidden_size),
|
||||
dtype=dtype)
|
||||
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
||||
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
||||
return (hidden_states, residual)
|
||||
|
||||
def ops_in_model(self, do_fusion):
|
||||
@ -137,12 +128,7 @@ class TestFusedAddRMSNorm(torch.nn.Module):
|
||||
|
||||
|
||||
class TestRotaryEmbedding(torch.nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
head_dim=64,
|
||||
rotary_dim=None,
|
||||
max_position=2048,
|
||||
base=10000):
|
||||
def __init__(self, head_dim=64, rotary_dim=None, max_position=2048, base=10000):
|
||||
super().__init__()
|
||||
self.head_dim = head_dim
|
||||
self.rotary_dim = rotary_dim or head_dim
|
||||
@ -173,21 +159,15 @@ class TestRotaryEmbedding(torch.nn.Module):
|
||||
|
||||
|
||||
class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
head_dim=64,
|
||||
num_heads=4,
|
||||
max_position=2048,
|
||||
base=10000):
|
||||
def __init__(self, head_dim=64, num_heads=4, max_position=2048, base=10000):
|
||||
super().__init__()
|
||||
self.head_dim = head_dim
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = head_dim * num_heads
|
||||
|
||||
self.qkv_proj = torch.nn.Linear(self.hidden_size,
|
||||
self.hidden_size * 3,
|
||||
bias=False,
|
||||
dtype=torch.float16)
|
||||
self.qkv_proj = torch.nn.Linear(
|
||||
self.hidden_size, self.hidden_size * 3, bias=False, dtype=torch.float16
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
@ -233,21 +213,24 @@ MODELS = [
|
||||
|
||||
@pytest.mark.parametrize("model_class", MODELS)
|
||||
@pytest.mark.parametrize("do_fusion", [True, False])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
|
||||
reason="Only test on CUDA")
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA")
|
||||
def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool):
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.compilation_config = CompilationConfig(
|
||||
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True))
|
||||
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True)
|
||||
)
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
|
||||
|
||||
passes = ([noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
|
||||
if do_fusion else [noop_pass, cleanup_pass])
|
||||
passes = (
|
||||
[noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
|
||||
if do_fusion
|
||||
else [noop_pass, cleanup_pass]
|
||||
)
|
||||
func_pass = FixFunctionalizationPass(vllm_config)
|
||||
|
||||
backend_func = TestBackend(*passes, func_pass)
|
||||
@ -260,8 +243,7 @@ def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool):
|
||||
# check if the functionalization pass is applied
|
||||
for op in model.ops_in_model(do_fusion):
|
||||
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
|
||||
assert (find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op)
|
||||
is None) # noqa: E501
|
||||
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None # noqa: E501
|
||||
|
||||
# make sure the ops were all de-functionalized
|
||||
found = dict()
@ -5,17 +5,26 @@ import pytest
|
||||
import torch
|
||||
|
||||
import vllm.plugins
|
||||
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
|
||||
RMSNormQuantFusionPass)
|
||||
from vllm.compilation.fusion import (
|
||||
FUSED_OPS,
|
||||
QUANT_OPS,
|
||||
FusedRMSQuantKey,
|
||||
RMSNormQuantFusionPass,
|
||||
)
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
|
||||
VllmConfig)
|
||||
from vllm.config import CompilationConfig, CompilationLevel, PassConfig, VllmConfig
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape, QuantKey, ScaleDesc)
|
||||
GroupShape,
|
||||
QuantKey,
|
||||
ScaleDesc,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp, cutlass_fp8_supported, maybe_create_device_identity)
|
||||
Fp8LinearOp,
|
||||
cutlass_fp8_supported,
|
||||
maybe_create_device_identity,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..utils import override_cutlass_fp8_supported
|
||||
@ -25,9 +34,15 @@ FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
|
||||
class TestModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size: int, eps: float, static: bool,
|
||||
cuda_force_torch: bool, *args, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
eps: float,
|
||||
static: bool,
|
||||
cuda_force_torch: bool,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.cuda_force_torch = cuda_force_torch
|
||||
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
|
||||
@ -54,17 +69,15 @@ class TestModel(torch.nn.Module):
|
||||
resid = torch.sqrt(x)
|
||||
y = self.norm[0](x)
|
||||
|
||||
x2 = self.fp8_linear.apply(y,
|
||||
self.w[0],
|
||||
self.wscale[0],
|
||||
input_scale=self.scale[0])
|
||||
x2 = self.fp8_linear.apply(
|
||||
y, self.w[0], self.wscale[0], input_scale=self.scale[0]
|
||||
)
|
||||
# make sure resid is used for replacement to work
|
||||
y2, resid = self.norm[1](x2, resid)
|
||||
|
||||
x3 = self.fp8_linear.apply(y2,
|
||||
self.w[1],
|
||||
self.wscale[1],
|
||||
input_scale=self.scale[1])
|
||||
x3 = self.fp8_linear.apply(
|
||||
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
|
||||
)
|
||||
y3, resid = self.norm[2](x3, resid) # use resid here
|
||||
return y3
|
||||
|
||||
@ -74,7 +87,7 @@ class TestModel(torch.nn.Module):
|
||||
def ops_in_model_after(self):
|
||||
return [
|
||||
FUSED_OPS[FusedRMSQuantKey(self.key, False)],
|
||||
FUSED_OPS[FusedRMSQuantKey(self.key, True)]
|
||||
FUSED_OPS[FusedRMSQuantKey(self.key, True)],
|
||||
]
|
||||
|
||||
|
||||
@ -85,22 +98,27 @@ class TestModel(torch.nn.Module):
|
||||
@pytest.mark.parametrize("static", [True, False])
|
||||
# cuda_force_torch used to test torch code path on platforms that
|
||||
# cutlass_fp8_supported() == True.
|
||||
@pytest.mark.parametrize("cuda_force_torch",
|
||||
[True, False] if cutlass_fp8_supported() else [True])
|
||||
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
|
||||
reason="Only test on CUDA and ROCm")
|
||||
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
|
||||
cuda_force_torch):
|
||||
@pytest.mark.parametrize(
|
||||
"cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True]
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm"
|
||||
)
|
||||
def test_fusion_rmsnorm_quant(
|
||||
dtype, hidden_size, num_tokens, eps, static, cuda_force_torch
|
||||
):
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_dtype(dtype)
|
||||
torch.manual_seed(1)
|
||||
maybe_create_device_identity() # needed for certain non-cutlass fp8 paths
|
||||
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
custom_ops=["+rms_norm", "+quant_fp8"],
|
||||
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
|
||||
))
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
custom_ops=["+rms_norm", "+quant_fp8"],
|
||||
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
|
||||
)
|
||||
)
|
||||
with vllm.config.set_current_vllm_config(vllm_config):
|
||||
# Reshape pass is needed for the fusion pass to work
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
|
||||
@ -10,14 +10,24 @@ from vllm.compilation.collective_fusion import AllReduceFusionPass
|
||||
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig,
|
||||
ModelConfig, PassConfig, VllmConfig)
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
CompilationLevel,
|
||||
DeviceConfig,
|
||||
ModelConfig,
|
||||
PassConfig,
|
||||
VllmConfig,
|
||||
)
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (init_distributed_environment,
|
||||
initialize_model_parallel)
|
||||
from vllm.distributed.parallel_state import (
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel,
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
GroupShape, QuantFP8)
|
||||
GroupShape,
|
||||
QuantFP8,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import update_environment_variables
|
||||
|
||||
@ -26,7 +36,6 @@ from .backend import TestBackend
|
||||
|
||||
|
||||
class TestAllReduceRMSNormModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -47,7 +56,6 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
|
||||
|
||||
|
||||
class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -68,25 +76,22 @@ class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
|
||||
|
||||
|
||||
class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.eps = eps
|
||||
self.norm = RMSNorm(hidden_size, eps)
|
||||
self.quant_fp8 = QuantFP8(static=True,
|
||||
group_shape=GroupShape.PER_TENSOR)
|
||||
self.quant_fp8 = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
|
||||
self.scale = torch.rand(1, dtype=torch.float32)
|
||||
self.output = torch.empty((token_num, hidden_size),
|
||||
dtype=torch.float32)
|
||||
self.output = torch.empty((token_num, hidden_size), dtype=torch.float32)
|
||||
|
||||
def forward(self, hidden_states, residual):
|
||||
view = hidden_states.reshape(-1, self.hidden_size)
|
||||
all_reduce = tensor_model_parallel_all_reduce(view)
|
||||
norm_output, residual_output = self.norm(all_reduce, residual)
|
||||
torch.ops._C.static_scaled_fp8_quant(self.output,
|
||||
norm_output.contiguous(),
|
||||
self.scale)
|
||||
torch.ops._C.static_scaled_fp8_quant(
|
||||
self.output, norm_output.contiguous(), self.scale
|
||||
)
|
||||
return self.output, residual_output
|
||||
|
||||
def ops_in_model_after(self):
|
||||
@ -95,35 +100,33 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
def ops_in_model_before(self):
|
||||
return [
|
||||
torch.ops.vllm.all_reduce.default,
|
||||
torch.ops._C.static_scaled_fp8_quant.default
|
||||
torch.ops._C.static_scaled_fp8_quant.default,
|
||||
]
|
||||
|
||||
|
||||
class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.eps = eps
|
||||
self.norm = RMSNorm(hidden_size, eps)
|
||||
self.scale = torch.rand(1, dtype=torch.float32)
|
||||
self.output = torch.empty((token_num, hidden_size),
|
||||
dtype=torch.float32)
|
||||
self.output = torch.empty((token_num, hidden_size), dtype=torch.float32)
|
||||
|
||||
round_up = lambda x, y: (x + y - 1) // y * y
|
||||
rounded_m = round_up(token_num, 128)
|
||||
scale_n = hidden_size // 16
|
||||
rounded_n = round_up(scale_n, 4)
|
||||
self.output_scale = torch.empty((rounded_m, rounded_n // 4),
|
||||
dtype=torch.int32)
|
||||
self.output_scale = torch.empty((rounded_m, rounded_n // 4), dtype=torch.int32)
|
||||
|
||||
def forward(self, hidden_states, residual):
|
||||
view = hidden_states.reshape(-1, self.hidden_size)
|
||||
all_reduce = tensor_model_parallel_all_reduce(view)
|
||||
norm_output, residual_output = self.norm(all_reduce, residual)
|
||||
norm_output = norm_output.reshape(-1, norm_output.shape[-1])
|
||||
torch.ops._C.scaled_fp4_quant(self.output, norm_output,
|
||||
self.output_scale, self.scale)
|
||||
torch.ops._C.scaled_fp4_quant(
|
||||
self.output, norm_output, self.output_scale, self.scale
|
||||
)
|
||||
return self.output, residual_output, self.output_scale
|
||||
|
||||
def ops_in_model_after(self):
|
||||
@ -132,7 +135,7 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
|
||||
def ops_in_model_before(self):
|
||||
return [
|
||||
torch.ops.vllm.all_reduce.default,
|
||||
torch.ops._C.scaled_fp4_quant.default
|
||||
torch.ops._C.scaled_fp4_quant.default,
|
||||
]
|
||||
|
||||
|
||||
@ -145,41 +148,55 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
|
||||
TestAllReduceFusedAddRMSNormStaticQuantFP8Model,
|
||||
# TODO: Enable with torch==2.8.0
|
||||
# TestAllReduceFusedAddRMSNormStaticQuantFP4Model,
|
||||
])
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seq_len", [8])
|
||||
@pytest.mark.parametrize("hidden_size", [16])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
|
||||
reason="Only test on CUDA")
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
|
||||
@pytest.mark.skipif(
|
||||
not find_spec("flashinfer")
|
||||
or not has_module_attribute("flashinfer.comm", "trtllm_allreduce_fusion"),
|
||||
reason="flashinfer is not found or flashinfer "
|
||||
"is not compiled with trtllm_allreduce_fusion")
|
||||
def test_all_reduce_fusion_pass_replace(test_model: torch.nn.Module,
|
||||
batch_size: int, seq_len: int,
|
||||
hidden_size: int, dtype: torch.dtype):
|
||||
"is not compiled with trtllm_allreduce_fusion",
|
||||
)
|
||||
def test_all_reduce_fusion_pass_replace(
|
||||
test_model: torch.nn.Module,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
num_processes = 2
|
||||
if (test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model
|
||||
and not current_platform.has_device_capability(100)):
|
||||
pytest.skip("Skip as nvfp4 is only supported on "
|
||||
"devices with compute capability 10.0 (Blackwell)")
|
||||
if (
|
||||
test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model
|
||||
and not current_platform.has_device_capability(100)
|
||||
):
|
||||
pytest.skip(
|
||||
"Skip as nvfp4 is only supported on "
|
||||
"devices with compute capability 10.0 (Blackwell)"
|
||||
)
|
||||
|
||||
def run_torch_spawn(fn, nprocs):
|
||||
torch.multiprocessing.spawn(fn,
|
||||
args=(num_processes, test_model,
|
||||
batch_size, seq_len, hidden_size,
|
||||
dtype),
|
||||
nprocs=nprocs)
|
||||
torch.multiprocessing.spawn(
|
||||
fn,
|
||||
args=(num_processes, test_model, batch_size, seq_len, hidden_size, dtype),
|
||||
nprocs=nprocs,
|
||||
)
|
||||
|
||||
run_torch_spawn(all_reduce_fusion_pass_on_test_model, num_processes)
|
||||
|
||||
|
||||
def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
|
||||
test_model_cls: torch.nn.Module,
|
||||
batch_size: int, seq_len: int,
|
||||
hidden_size: int, dtype: torch.dtype):
|
||||
def all_reduce_fusion_pass_on_test_model(
|
||||
local_rank: int,
|
||||
world_size: int,
|
||||
test_model_cls: torch.nn.Module,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
@ -187,39 +204,42 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
|
||||
torch.set_default_device(device)
|
||||
torch.set_default_dtype(dtype)
|
||||
|
||||
update_environment_variables({
|
||||
'RANK': str(local_rank),
|
||||
'LOCAL_RANK': str(local_rank),
|
||||
'WORLD_SIZE': str(world_size),
|
||||
'MASTER_ADDR': 'localhost',
|
||||
'MASTER_PORT': '12345',
|
||||
})
|
||||
update_environment_variables(
|
||||
{
|
||||
"RANK": str(local_rank),
|
||||
"LOCAL_RANK": str(local_rank),
|
||||
"WORLD_SIZE": str(world_size),
|
||||
"MASTER_ADDR": "localhost",
|
||||
"MASTER_PORT": "12345",
|
||||
}
|
||||
)
|
||||
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
custom_ops=["+rms_norm", "+quant_fp8"]))
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm", "+quant_fp8"]
|
||||
)
|
||||
)
|
||||
vllm_config.compilation_config.pass_config = PassConfig(
|
||||
enable_fi_allreduce_fusion=True, enable_noop=True)
|
||||
enable_fi_allreduce_fusion=True, enable_noop=True
|
||||
)
|
||||
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
|
||||
|
||||
# this is a fake model name to construct the model config
|
||||
# in the vllm_config, it's not really used.
|
||||
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
|
||||
vllm_config.model_config = ModelConfig(model=model_name,
|
||||
trust_remote_code=True,
|
||||
dtype=dtype,
|
||||
seed=42)
|
||||
vllm_config.model_config = ModelConfig(
|
||||
model=model_name, trust_remote_code=True, dtype=dtype, seed=42
|
||||
)
|
||||
|
||||
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
func_pass = FixFunctionalizationPass(vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
|
||||
backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass,
|
||||
cleanup_pass)
|
||||
backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass, cleanup_pass)
|
||||
|
||||
token_num = batch_size * seq_len
|
||||
model = test_model_cls(hidden_size, token_num)
|
||||
|
||||
@ -19,14 +19,23 @@ from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
|
||||
from vllm.compilation.fx_utils import find_op_nodes
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
|
||||
ModelConfig, PassConfig, SchedulerConfig, VllmConfig,
|
||||
set_current_vllm_config)
|
||||
from vllm.config import (
|
||||
CacheConfig,
|
||||
CompilationConfig,
|
||||
CompilationLevel,
|
||||
ModelConfig,
|
||||
PassConfig,
|
||||
SchedulerConfig,
|
||||
VllmConfig,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.forward_context import get_forward_context, set_forward_context
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey, kFp8StaticTensorSym, kNvfp4Quant)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp)
|
||||
QuantKey,
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Quant,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import is_torch_equal_or_newer
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
@ -40,14 +49,16 @@ backend_unfused: Optional[TestBackend] = None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model, quant_key",
|
||||
[("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)])
|
||||
"model, quant_key", [("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)]
|
||||
)
|
||||
@pytest.mark.parametrize("use_triton_fa", [True, False])
|
||||
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
|
||||
@pytest.mark.skipif(not current_platform.is_rocm(),
|
||||
reason="V0 attn quant fusion only on ROCm")
|
||||
def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
|
||||
quant_key: QuantKey, use_triton_fa: bool):
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_rocm(), reason="V0 attn quant fusion only on ROCm"
|
||||
)
|
||||
def test_attention_fusion_v0(
|
||||
example_prompts, monkeypatch, model: str, quant_key: QuantKey, use_triton_fa: bool
|
||||
):
|
||||
# Clean Dynamo cache to avoid reusing other test cases
|
||||
# (for some reason the reset at the end is not enough)
|
||||
torch._dynamo.reset()
|
||||
@ -69,22 +80,24 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
|
||||
backend="tests.compile.test_fusion_attn.backend_unfused",
|
||||
custom_ops=["+quant_fp8"],
|
||||
)
|
||||
vllm_config = VllmConfig(compilation_config=compile_config,
|
||||
model_config=ModelConfig(
|
||||
model=model,
|
||||
dtype=torch.bfloat16,
|
||||
))
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=compile_config,
|
||||
model_config=ModelConfig(
|
||||
model=model,
|
||||
dtype=torch.bfloat16,
|
||||
),
|
||||
)
|
||||
backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))
|
||||
|
||||
llm = LLM(model,
|
||||
enforce_eager=True,
|
||||
compilation_config=compile_config,
|
||||
gpu_memory_utilization=0.5,
|
||||
max_model_len=2048)
|
||||
llm = LLM(
|
||||
model,
|
||||
enforce_eager=True,
|
||||
compilation_config=compile_config,
|
||||
gpu_memory_utilization=0.5,
|
||||
max_model_len=2048,
|
||||
)
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.0,
|
||||
max_tokens=10,
|
||||
top_p=0.95)
|
||||
sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_p=0.95)
|
||||
|
||||
unfused_output = llm.generate(prompts, sampling_params)
|
||||
backend_unfused = None # Reset backend to make sure llm gets released
|
||||
@ -97,21 +110,25 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
|
||||
backend="tests.compile.test_fusion_attn.backend",
|
||||
custom_ops=["+quant_fp8"],
|
||||
)
|
||||
vllm_config = VllmConfig(compilation_config=compile_config,
|
||||
model_config=ModelConfig(
|
||||
model=model,
|
||||
dtype=torch.bfloat16,
|
||||
))
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=compile_config,
|
||||
model_config=ModelConfig(
|
||||
model=model,
|
||||
dtype=torch.bfloat16,
|
||||
),
|
||||
)
|
||||
|
||||
# AttnFusionPass needs attention layers to be registered in config upon init
|
||||
# so we initialize it during compilation.
|
||||
attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
|
||||
backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
|
||||
llm2 = LLM(model,
|
||||
enforce_eager=True,
|
||||
compilation_config=compile_config,
|
||||
gpu_memory_utilization=0.5,
|
||||
max_model_len=2048)
|
||||
llm2 = LLM(
|
||||
model,
|
||||
enforce_eager=True,
|
||||
compilation_config=compile_config,
|
||||
gpu_memory_utilization=0.5,
|
||||
max_model_len=2048,
|
||||
)
|
||||
|
||||
# check support
|
||||
attn_fusion_supported = [
|
||||
@ -132,9 +149,9 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
|
||||
for i in range(len(attn_nodes_pre)):
|
||||
assert attn_nodes_pre[i].kwargs["output_scale"] is None
|
||||
fused = attn_nodes_post[i].kwargs["output_scale"] is not None
|
||||
assert fused == attn_fusion_supported[i], \
|
||||
f"Node {i} {'' if fused else 'not '} expected " \
|
||||
f"to have fused output quant"
|
||||
assert fused == attn_fusion_supported[i], (
|
||||
f"Node {i} {'' if fused else 'not '} expected to have fused output quant"
|
||||
)
|
||||
|
||||
# check outputs
|
||||
fused_output = llm2.generate(prompts, sampling_params)
|
||||
@ -160,9 +177,16 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
|
||||
class AttentionQuantPatternModel(torch.nn.Module):
|
||||
"""Base model for AttentionQuantPattern fusion."""
|
||||
|
||||
def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int,
|
||||
kv_cache_dtype: torch.dtype, device: torch.device,
|
||||
vllm_config: VllmConfig, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
num_qo_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
kv_cache_dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
vllm_config: VllmConfig,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_qo_heads = num_qo_heads
|
||||
self.num_kv_heads = num_kv_heads
|
||||
@ -197,33 +221,30 @@ class AttentionQuantPatternModel(torch.nn.Module):
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
def build_attn_metadata(self, batch_size: int, use_hnd: bool) \
|
||||
-> AttentionMetadata:
|
||||
def build_attn_metadata(self, batch_size: int, use_hnd: bool) -> AttentionMetadata:
|
||||
"""Initialize attention metadata."""
|
||||
|
||||
# Create common attn metadata
|
||||
batch_spec = BatchSpec(seq_lens=[1] * batch_size,
|
||||
query_lens=[1] * batch_size)
|
||||
batch_spec = BatchSpec(seq_lens=[1] * batch_size, query_lens=[1] * batch_size)
|
||||
common_attn_metadata = create_common_attn_metadata(
|
||||
batch_spec,
|
||||
self.block_size,
|
||||
self.device,
|
||||
arange_block_indices=True)
|
||||
batch_spec, self.block_size, self.device, arange_block_indices=True
|
||||
)
|
||||
|
||||
max_blocks = (max(batch_spec.seq_lens) + self.block_size -
|
||||
1) // self.block_size
|
||||
max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size
|
||||
num_blocks = batch_size * max_blocks
|
||||
|
||||
# Create dummy KV cache for FlashInfer TRTLLM
|
||||
# - NHD: [num_blocks, block_size, num_kv_heads, head_size]
|
||||
# - HND: [num_blocks, num_kv_heads, block_size, head_size]
|
||||
kv_cache = torch.zeros(num_blocks,
|
||||
2,
|
||||
self.num_kv_heads,
|
||||
self.block_size,
|
||||
self.head_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device)
|
||||
kv_cache = torch.zeros(
|
||||
num_blocks,
|
||||
2,
|
||||
self.num_kv_heads,
|
||||
self.block_size,
|
||||
self.head_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
if current_platform.is_rocm():
|
||||
# k/v as 1st dimention
|
||||
if use_hnd:
|
||||
@ -239,7 +260,8 @@ class AttentionQuantPatternModel(torch.nn.Module):
|
||||
|
||||
# Build attn metadata
|
||||
self.attn_metadata = self.builder.build(
|
||||
common_prefix_len=0, common_attn_metadata=common_attn_metadata)
|
||||
common_prefix_len=0, common_attn_metadata=common_attn_metadata
|
||||
)
|
||||
|
||||
return self.attn_metadata
|
||||
|
||||
@ -254,27 +276,30 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
|
||||
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=self.quant_key.scale.static,
|
||||
act_quant_group_shape=self.quant_key.scale.group_shape)
|
||||
act_quant_group_shape=self.quant_key.scale.group_shape,
|
||||
)
|
||||
|
||||
hidden_size = self.num_qo_heads * self.head_size
|
||||
self.w = kwargs.get(
|
||||
"w", {
|
||||
"weight":
|
||||
torch.randn(hidden_size, hidden_size).to(
|
||||
dtype=FP8_DTYPE, device=self.device).t(),
|
||||
"wscale":
|
||||
torch.tensor([1.0], dtype=torch.float32, device=self.device),
|
||||
"scale":
|
||||
torch.tensor([1.0], dtype=torch.float32, device=self.device),
|
||||
})
|
||||
"w",
|
||||
{
|
||||
"weight": torch.randn(hidden_size, hidden_size)
|
||||
.to(dtype=FP8_DTYPE, device=self.device)
|
||||
.t(),
|
||||
"wscale": torch.tensor([1.0], dtype=torch.float32, device=self.device),
|
||||
"scale": torch.tensor([1.0], dtype=torch.float32, device=self.device),
|
||||
},
|
||||
)
|
||||
|
||||
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
||||
"""Forward pass that creates the pattern to be fused."""
|
||||
attn_output = self.attn(q, k, v)
|
||||
return self.fp8_linear.apply(input=attn_output,
|
||||
weight=self.w["weight"],
|
||||
weight_scale=self.w["wscale"],
|
||||
input_scale=self.w["scale"])
|
||||
return self.fp8_linear.apply(
|
||||
input=attn_output,
|
||||
weight=self.w["weight"],
|
||||
weight_scale=self.w["wscale"],
|
||||
input_scale=self.w["scale"],
|
||||
)
|
||||
|
||||
|
||||
class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
|
||||
@ -287,42 +312,54 @@ class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
|
||||
|
||||
hidden_size = self.num_qo_heads * self.head_size
|
||||
self.w = kwargs.get(
|
||||
"w", {
|
||||
"weight":
|
||||
torch.randint(256, (hidden_size, hidden_size // 2),
|
||||
dtype=FP4_DTYPE,
|
||||
device=self.device),
|
||||
"wscale_swizzled":
|
||||
torch.randn(hidden_size, hidden_size // 16).to(
|
||||
dtype=FP8_DTYPE, device=self.device),
|
||||
"wscale":
|
||||
torch.tensor([500], dtype=torch.float32, device=self.device),
|
||||
"scale":
|
||||
torch.tensor([0.002], dtype=torch.float32, device=self.device),
|
||||
})
|
||||
"w",
|
||||
{
|
||||
"weight": torch.randint(
|
||||
256,
|
||||
(hidden_size, hidden_size // 2),
|
||||
dtype=FP4_DTYPE,
|
||||
device=self.device,
|
||||
),
|
||||
"wscale_swizzled": torch.randn(hidden_size, hidden_size // 16).to(
|
||||
dtype=FP8_DTYPE, device=self.device
|
||||
),
|
||||
"wscale": torch.tensor([500], dtype=torch.float32, device=self.device),
|
||||
"scale": torch.tensor([0.002], dtype=torch.float32, device=self.device),
|
||||
},
|
||||
)
|
||||
|
||||
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
||||
"""Forward pass that creates the pattern to be fused."""
|
||||
attn_output = self.attn(q, k, v)
|
||||
quant_output, output_block_scale = scaled_fp4_quant(
|
||||
attn_output, 1 / self.w["scale"])
|
||||
return cutlass_scaled_fp4_mm(a=quant_output,
|
||||
b=self.w["weight"],
|
||||
block_scale_a=output_block_scale,
|
||||
block_scale_b=self.w["wscale_swizzled"],
|
||||
alpha=self.w["scale"] * self.w["wscale"],
|
||||
out_dtype=attn_output.dtype)
|
||||
attn_output, 1 / self.w["scale"]
|
||||
)
|
||||
return cutlass_scaled_fp4_mm(
|
||||
a=quant_output,
|
||||
b=self.w["weight"],
|
||||
block_scale_a=output_block_scale,
|
||||
block_scale_b=self.w["wscale_swizzled"],
|
||||
alpha=self.w["scale"] * self.w["wscale"],
|
||||
out_dtype=attn_output.dtype,
|
||||
)
|
||||
|
||||
|
||||
if current_platform.is_cuda():
|
||||
MODELS = [("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
|
||||
TestAttentionFp8StaticQuantPatternModel),
|
||||
("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
|
||||
TestAttentionNvfp4QuantPatternModel)]
|
||||
MODELS = [
|
||||
(
|
||||
"nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
|
||||
TestAttentionFp8StaticQuantPatternModel,
|
||||
),
|
||||
(
|
||||
"nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
|
||||
TestAttentionNvfp4QuantPatternModel,
|
||||
),
|
||||
]
|
||||
HEADS = [(64, 8), (40, 8)]
|
||||
elif current_platform.is_rocm():
|
||||
MODELS = [("amd/Llama-3.1-8B-Instruct-FP8-KV",
|
||||
TestAttentionFp8StaticQuantPatternModel)]
|
||||
MODELS = [
|
||||
("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel)
|
||||
]
|
||||
HEADS = [(32, 8), (40, 8)]
|
||||
else:
|
||||
MODELS = []
|
||||
@ -331,41 +368,53 @@ else:
|
||||
|
||||
@pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS)
|
||||
@pytest.mark.parametrize("head_size", [128])
|
||||
@pytest.mark.parametrize("batch_size",
|
||||
[7, 256, 533] if current_platform.is_cuda() else [8])
|
||||
@pytest.mark.parametrize(
|
||||
"batch_size", [7, 256, 533] if current_platform.is_cuda() else [8]
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("model_name, model_class", MODELS)
|
||||
@pytest.mark.parametrize("backend",
|
||||
[_Backend.FLASHINFER] if current_platform.is_cuda()
|
||||
else [_Backend.TRITON_ATTN])
|
||||
@pytest.mark.parametrize(
|
||||
"split_attention",
|
||||
[False, True] if current_platform.is_rocm() else [False])
|
||||
"backend",
|
||||
[_Backend.FLASHINFER] if current_platform.is_cuda() else [_Backend.TRITON_ATTN],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"split_attention", [False, True] if current_platform.is_rocm() else [False]
|
||||
)
|
||||
# TODO(boyuan): test inductor graph partition on rocm
|
||||
@pytest.mark.parametrize(
|
||||
"use_inductor_graph_partition",
|
||||
[False] if current_platform.is_rocm() else [False, True])
|
||||
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
|
||||
reason="Only test ROCm or CUDA")
|
||||
[False] if current_platform.is_rocm() else [False, True],
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA"
|
||||
)
|
||||
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
|
||||
@pytest.mark.skipif(current_platform.is_cuda()
|
||||
and not current_platform.is_device_capability((10, 0)),
|
||||
reason="On CUDA only test on SM100(Blackwell)")
|
||||
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
|
||||
reason="Only test ROCm or CUDA")
|
||||
def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
||||
head_size: int, batch_size: int,
|
||||
dtype: torch.dtype, model_name: str,
|
||||
model_class: type[AttentionQuantPatternModel],
|
||||
backend: _Backend, split_attention: bool,
|
||||
use_inductor_graph_partition: bool,
|
||||
monkeypatch, dist_init, caplog_vllm):
|
||||
@pytest.mark.skipif(
|
||||
current_platform.is_cuda() and not current_platform.is_device_capability((10, 0)),
|
||||
reason="On CUDA only test on SM100(Blackwell)",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA"
|
||||
)
|
||||
def test_attention_quant_pattern(
|
||||
num_qo_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
batch_size: int,
|
||||
dtype: torch.dtype,
|
||||
model_name: str,
|
||||
model_class: type[AttentionQuantPatternModel],
|
||||
backend: _Backend,
|
||||
split_attention: bool,
|
||||
use_inductor_graph_partition: bool,
|
||||
monkeypatch,
|
||||
dist_init,
|
||||
caplog_vllm,
|
||||
):
|
||||
"""Test AttentionStaticQuantPattern fusion pass"""
|
||||
|
||||
if use_inductor_graph_partition and not is_torch_equal_or_newer(
|
||||
"2.9.0.dev"):
|
||||
pytest.skip("inductor graph partition is only available "
|
||||
"in PyTorch 2.9+")
|
||||
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
|
||||
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
|
||||
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||
if split_attention:
|
||||
@ -386,21 +435,13 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
||||
custom_ops=["+quant_fp8"],
|
||||
use_inductor_graph_partition=use_inductor_graph_partition,
|
||||
),
|
||||
cache_config=CacheConfig(cache_dtype="fp8"))
|
||||
cache_config=CacheConfig(cache_dtype="fp8"),
|
||||
)
|
||||
|
||||
# Create test inputs
|
||||
q = torch.randn(batch_size,
|
||||
num_qo_heads * head_size,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
k = torch.randn(batch_size,
|
||||
num_kv_heads * head_size,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
v = torch.randn(batch_size,
|
||||
num_kv_heads * head_size,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
q = torch.randn(batch_size, num_qo_heads * head_size, dtype=dtype, device=device)
|
||||
k = torch.randn(batch_size, num_kv_heads * head_size, dtype=dtype, device=device)
|
||||
v = torch.randn(batch_size, num_kv_heads * head_size, dtype=dtype, device=device)
|
||||
|
||||
# Mark first dimension as dynamic for realistic testing
|
||||
torch._dynamo.mark_dynamic(q, 0)
|
||||
@ -409,42 +450,53 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
||||
|
||||
# Run model directly without compilation and fusion
|
||||
vllm_config_unfused = copy.deepcopy(vllm_config)
|
||||
with set_current_vllm_config(vllm_config_unfused), set_forward_context(
|
||||
attn_metadata=None, vllm_config=vllm_config_unfused
|
||||
), global_force_attn_backend_context_manager(backend):
|
||||
model_unfused = model_class(num_qo_heads=num_qo_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_size=head_size,
|
||||
kv_cache_dtype=FP8_DTYPE,
|
||||
device=device,
|
||||
vllm_config=vllm_config_unfused)
|
||||
with (
|
||||
set_current_vllm_config(vllm_config_unfused),
|
||||
set_forward_context(attn_metadata=None, vllm_config=vllm_config_unfused),
|
||||
global_force_attn_backend_context_manager(backend),
|
||||
):
|
||||
model_unfused = model_class(
|
||||
num_qo_heads=num_qo_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_size=head_size,
|
||||
kv_cache_dtype=FP8_DTYPE,
|
||||
device=device,
|
||||
vllm_config=vllm_config_unfused,
|
||||
)
|
||||
model_unfused = model_unfused.to(device)
|
||||
|
||||
forward_ctx = get_forward_context()
|
||||
forward_ctx.attn_metadata = model_unfused.build_attn_metadata(
|
||||
batch_size, use_hnd=split_attention)
|
||||
batch_size, use_hnd=split_attention
|
||||
)
|
||||
|
||||
# Run model directly without compilation and fusion
|
||||
result_unfused = model_unfused(q, k, v)
|
||||
|
||||
# Run model with attn fusion enabled
|
||||
vllm_config.compilation_config.pass_config = PassConfig(
|
||||
enable_attn_fusion=True, enable_noop=True)
|
||||
with set_current_vllm_config(vllm_config), set_forward_context(
|
||||
attn_metadata=None, vllm_config=vllm_config
|
||||
), global_force_attn_backend_context_manager(backend):
|
||||
model_fused = model_class(num_qo_heads=num_qo_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_size=head_size,
|
||||
kv_cache_dtype=FP8_DTYPE,
|
||||
device=device,
|
||||
vllm_config=vllm_config,
|
||||
w=model_unfused.w)
|
||||
enable_attn_fusion=True, enable_noop=True
|
||||
)
|
||||
with (
|
||||
set_current_vllm_config(vllm_config),
|
||||
set_forward_context(attn_metadata=None, vllm_config=vllm_config),
|
||||
global_force_attn_backend_context_manager(backend),
|
||||
):
|
||||
model_fused = model_class(
|
||||
num_qo_heads=num_qo_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_size=head_size,
|
||||
kv_cache_dtype=FP8_DTYPE,
|
||||
device=device,
|
||||
vllm_config=vllm_config,
|
||||
w=model_unfused.w,
|
||||
)
|
||||
model_fused = model_fused.to(device)
|
||||
|
||||
forward_ctx = get_forward_context()
|
||||
forward_ctx.attn_metadata = model_fused.build_attn_metadata(
|
||||
batch_size, use_hnd=split_attention)
|
||||
batch_size, use_hnd=split_attention
|
||||
)
|
||||
|
||||
# Create test backend with fusion passes enabled
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
@ -454,9 +506,9 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
||||
test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass)
|
||||
|
||||
# Compile model with fusion enabled
|
||||
model_compiled = torch.compile(model_fused,
|
||||
backend=test_backend,
|
||||
fullgraph=True)
|
||||
model_compiled = torch.compile(
|
||||
model_fused, backend=test_backend, fullgraph=True
|
||||
)
|
||||
assert model_compiled.attn._o_scale_float is None
|
||||
|
||||
result_fused_1 = model_compiled(q, k, v)
|
||||
@ -471,49 +523,49 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
||||
|
||||
assert model_compiled.attn._o_scale_float is not None
|
||||
|
||||
torch.testing.assert_close(result_unfused,
|
||||
result_fused_2,
|
||||
atol=1e-2,
|
||||
rtol=1e-2)
|
||||
torch.testing.assert_close(
|
||||
result_unfused, result_fused_2, atol=1e-2, rtol=1e-2
|
||||
)
|
||||
|
||||
# Check attn fusion support
|
||||
quant_key = model_class.quant_key
|
||||
attn_fusion_supported = [
|
||||
layer.impl.fused_output_quant_supported(quant_key) for key, layer in
|
||||
vllm_config.compilation_config.static_forward_context.items()
|
||||
layer.impl.fused_output_quant_supported(quant_key)
|
||||
for key, layer in vllm_config.compilation_config.static_forward_context.items()
|
||||
]
|
||||
if any(attn_fusion_supported):
|
||||
# Check quantization ops in the graph before and after fusion
|
||||
test_backend.check_before_ops([QUANT_OPS[quant_key]],
|
||||
fully_replaced=True)
|
||||
test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=True)
|
||||
|
||||
# access the underlying `AttnFusionPass` on the `LazyInitPass`
|
||||
assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)
|
||||
|
||||
# Check attention ops in the graph before and after fusion
|
||||
attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass))
|
||||
attn_nodes_post = list(find_op_nodes(ATTN_OP,
|
||||
test_backend.graph_post_pass))
|
||||
attn_nodes_post = list(find_op_nodes(ATTN_OP, test_backend.graph_post_pass))
|
||||
|
||||
assert len(attn_nodes_pre) > 0, "Should have attention nodes before fusion"
|
||||
assert len(attn_nodes_pre) == len(attn_nodes_post), \
|
||||
assert len(attn_nodes_pre) == len(attn_nodes_post), (
|
||||
"Should have same number of attention nodes before and after fusion"
|
||||
assert attn_nodes_pre[0].kwargs.get("output_scale") is None, \
|
||||
)
|
||||
assert attn_nodes_pre[0].kwargs.get("output_scale") is None, (
|
||||
"Attention should not have output_scale before fusion"
|
||||
assert attn_nodes_post[0].kwargs.get("output_scale") is not None, \
|
||||
)
|
||||
assert attn_nodes_post[0].kwargs.get("output_scale") is not None, (
|
||||
"Attention should have output_scale after fusion"
|
||||
)
|
||||
|
||||
assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, \
|
||||
assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, (
|
||||
"Attention should not have output_block_scale before fusion"
|
||||
)
|
||||
if quant_key.dtype == FP8_DTYPE:
|
||||
assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, \
|
||||
assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, (
|
||||
"Attention should not have output_block_scale after FP8 fusion"
|
||||
)
|
||||
elif quant_key.dtype == FP4_DTYPE:
|
||||
assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, \
|
||||
"Attention should have output_block_scale after FP4 fusion" # noqa: E501
|
||||
assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, (
|
||||
"Attention should have output_block_scale after FP4 fusion"
|
||||
) # noqa: E501
|
||||
|
||||
# Check that results are close
|
||||
torch.testing.assert_close(result_unfused,
|
||||
result_fused_1,
|
||||
atol=1e-2,
|
||||
rtol=1e-2)
|
||||
torch.testing.assert_close(result_unfused, result_fused_1, atol=1e-2, rtol=1e-2)
|
||||
|
||||
@ -6,14 +6,12 @@ import torch

import vllm
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
                         VllmConfig)
from vllm.config import CompilationConfig, CompilationLevel, PassConfig, VllmConfig

from .backend import TestBackend


@pytest.mark.parametrize("dtype",
                         [torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize("num_tokens", [256, 1024])
@pytest.mark.parametrize("hidden_size", [64, 4096])
def test_noop_elimination(dtype, num_tokens, hidden_size):
@ -22,7 +20,6 @@ def test_noop_elimination(dtype, num_tokens, hidden_size):
|
||||
torch.manual_seed(1)
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
# Chain of reshapes
|
||||
y = x.reshape(-1, 128, 32)
|
||||
@ -32,7 +29,7 @@ def test_noop_elimination(dtype, num_tokens, hidden_size):
|
||||
# Final reshape that should remain
|
||||
b = a.reshape(-1, 128, 32)
|
||||
# No-op slice
|
||||
c = b[0:b.shape[0]]
|
||||
c = b[0 : b.shape[0]]
|
||||
# The pass should replace the result of this op with `c`.
|
||||
d = torch.slice_scatter(
|
||||
torch.ones_like(c), # Dummy tensor to be scattered into
|
||||
@ -43,10 +40,12 @@ def test_noop_elimination(dtype, num_tokens, hidden_size):
|
||||
)
|
||||
return d
|
||||
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
pass_config=PassConfig(enable_noop=True),
|
||||
))
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
pass_config=PassConfig(enable_noop=True),
|
||||
)
|
||||
)
|
||||
with vllm.config.set_current_vllm_config(vllm_config):
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
|
||||
@ -82,17 +81,18 @@ def test_non_noop_slice_preserved():
|
||||
x = torch.randn(16, 16)
|
||||
|
||||
class SliceModel(torch.nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
base = x.clone()
|
||||
src = torch.ones(15, 16)
|
||||
y = torch.slice_scatter(base, src, dim=0, start=0, end=-1)
|
||||
return x[0:-1, :], y
|
||||
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
pass_config=PassConfig(enable_noop=True),
|
||||
))
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
pass_config=PassConfig(enable_noop=True),
|
||||
)
|
||||
)
|
||||
with vllm.config.set_current_vllm_config(vllm_config):
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
backend = TestBackend(noop_pass)
|
||||
|
||||
@ -28,7 +28,6 @@ def test_bad_callable():
|
||||
|
||||
# Pass that inherits from InductorPass
|
||||
class ProperPass(InductorPass):
|
||||
|
||||
def __call__(self, graph: torch.fx.graph.Graph) -> None:
|
||||
pass
|
||||
|
||||
@ -39,8 +38,7 @@ class ProperPass(InductorPass):
|
||||
ProperPass(),
|
||||
# Can also wrap callables in CallableInductorPass for compliance
|
||||
CallableInductorPass(simple_callable),
|
||||
CallableInductorPass(simple_callable,
|
||||
InductorPass.hash_source(__file__))
|
||||
CallableInductorPass(simple_callable, InductorPass.hash_source(__file__)),
|
||||
],
|
||||
)
|
||||
def test_pass_manager_uuid(callable):
|
||||
@ -65,8 +63,9 @@ def test_pass_manager_uuid(callable):
|
||||
|
||||
# UUID should be different due to config change
|
||||
config2 = copy.deepcopy(config)
|
||||
config2.compilation_config.pass_config.enable_fusion = not \
|
||||
config2.compilation_config.pass_config.enable_fusion
|
||||
config2.compilation_config.pass_config.enable_fusion = (
|
||||
not config2.compilation_config.pass_config.enable_fusion
|
||||
)
|
||||
pass_manager3 = PostGradPassManager()
|
||||
pass_manager3.configure(config2)
|
||||
pass_manager3.add(callable)
|
||||
|
||||
@ -12,14 +12,20 @@ from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
|
||||
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
|
||||
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
|
||||
PassConfig, VllmConfig)
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
DeviceConfig,
|
||||
ModelConfig,
|
||||
PassConfig,
|
||||
VllmConfig,
|
||||
)
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (init_distributed_environment,
|
||||
initialize_model_parallel)
|
||||
from vllm.distributed.parallel_state import (
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel,
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import update_environment_variables
|
||||
|
||||
@ -36,16 +42,15 @@ prompts = [


class TestModel(torch.nn.Module):

def __init__(self,
hidden_size=16,
intermediate_size=32,
vllm_config: VllmConfig = None):
def __init__(
self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None
):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.gate_proj = torch.nn.Parameter(
torch.empty((intermediate_size, hidden_size)))
torch.empty((intermediate_size, hidden_size))
)
self.norm = RMSNorm(intermediate_size, 1e-05)
# Initialize weights
torch.nn.init.normal_(self.gate_proj, std=0.02)
@ -64,7 +69,7 @@ class TestModel(torch.nn.Module):
# Reshape input
view = hidden_states.reshape(-1, self.hidden_size)

#matrix multiplication
# matrix multiplication
permute = self.gate_proj.permute(1, 0)
mm = torch.mm(view, permute)

@ -82,7 +87,7 @@ class TestModel(torch.nn.Module):
def ops_in_model_after(self):
return [
torch.ops.vllm.reduce_scatter.default,
torch.ops.vllm.all_gather.default
torch.ops.vllm.all_gather.default,
]

def ops_in_model(self):
@ -90,18 +95,16 @@ class TestModel(torch.nn.Module):


class TestQuantModel(torch.nn.Module):

def __init__(self,
hidden_size=16,
intermediate_size=32,
vllm_config: VllmConfig = None):
def __init__(
self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None
):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.vllm_config = vllm_config
self.gate_proj = torch.nn.Parameter(torch.empty(
(intermediate_size, hidden_size)),
requires_grad=False)
self.gate_proj = torch.nn.Parameter(
torch.empty((intermediate_size, hidden_size)), requires_grad=False
)
self.norm = RMSNorm(intermediate_size, 1e-05)
# Initialize weights
torch.nn.init.normal_(self.gate_proj, std=0.02)
@ -111,8 +114,7 @@ class TestQuantModel(torch.nn.Module):
self.scale = torch.rand(1, dtype=torch.float32)
# Create a weight that is compatible with torch._scaled_mm,
# which expects a column-major layout.
self.w = torch.rand(hidden_size,
intermediate_size).to(dtype=FP8_DTYPE).t()
self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
self.wscale = torch.rand(1, dtype=torch.float32)

def forward(self, hidden_states, residual):
@ -129,7 +131,7 @@ class TestQuantModel(torch.nn.Module):
# Reshape input
view = hidden_states.reshape(-1, self.hidden_size)

#matrix multiplication
# matrix multiplication
permute = self.gate_proj.permute(1, 0)
mm = torch.mm(view, permute)

@ -140,45 +142,51 @@ class TestQuantModel(torch.nn.Module):
norm_output, residual_output = self.norm(all_reduce, residual)

# scaled_mm with static input quantization
fp8_linear_result = self.fp8_linear.apply(norm_output,
self.w,
self.wscale,
input_scale=self.scale.to(
norm_output.device))
fp8_linear_result = self.fp8_linear.apply(
norm_output,
self.w,
self.wscale,
input_scale=self.scale.to(norm_output.device),
)

return fp8_linear_result, residual_output

def ops_in_model_before(self):
ops_to_remove = [torch.ops.vllm.all_reduce.default
] # Always removed by SP
ops_to_remove = [torch.ops.vllm.all_reduce.default] # Always removed by SP
# The following are only removed if fusion happens
if self.vllm_config and self.vllm_config.compilation_config \
.pass_config.enable_fusion:
ops_to_remove.extend([
torch.ops._C.fused_add_rms_norm.default,
torch.ops._C.static_scaled_fp8_quant.default,
])
if (
self.vllm_config
and self.vllm_config.compilation_config.pass_config.enable_fusion
):
ops_to_remove.extend(
[
torch.ops._C.fused_add_rms_norm.default,
torch.ops._C.static_scaled_fp8_quant.default,
]
)
return ops_to_remove

def ops_in_model_after(self):
ops_to_add = [
torch.ops.vllm.reduce_scatter.default,
torch.ops.vllm.all_gather.default
torch.ops.vllm.all_gather.default,
]
# The following is only added if fusion happens
if self.vllm_config and self.vllm_config.compilation_config \
.pass_config.enable_fusion:
ops_to_add.append(
torch.ops._C.fused_add_rms_norm_static_fp8_quant.default)
if (
self.vllm_config
and self.vllm_config.compilation_config.pass_config.enable_fusion
):
ops_to_add.append(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default)
return ops_to_add

def ops_in_model(self):
if self.vllm_config and self.vllm_config.compilation_config \
.pass_config.enable_fusion:
if (
self.vllm_config
and self.vllm_config.compilation_config.pass_config.enable_fusion
):
# If fusion happens, the fused op is the one
# we check for (de)functionalization
return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
] # noqa: E501
return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default] # noqa: E501
else:
# If no fusion, the original ops are checked
return [
@ -195,30 +203,47 @@ class TestQuantModel(torch.nn.Module):
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("enable_fusion", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
reason="Only test on CUDA")
def test_sequence_parallelism_pass(test_model_cls: type[torch.nn.Module],
batch_size: int, seq_len: int,
hidden_size: int, dtype: torch.dtype,
enable_fusion: bool):
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
def test_sequence_parallelism_pass(
test_model_cls: type[torch.nn.Module],
batch_size: int,
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
enable_fusion: bool,
):
num_processes = 2

def run_torch_spawn(fn, nprocs):
# need to use torch.mp.spawn otherwise will have problems with
# torch.distributed and cuda
torch.multiprocessing.spawn(fn,
args=(num_processes, test_model_cls,
batch_size, seq_len, hidden_size,
dtype, enable_fusion),
nprocs=nprocs)
torch.multiprocessing.spawn(
fn,
args=(
num_processes,
test_model_cls,
batch_size,
seq_len,
hidden_size,
dtype,
enable_fusion,
),
nprocs=nprocs,
)

run_torch_spawn(sequence_parallelism_pass_on_test_model, num_processes)


def sequence_parallelism_pass_on_test_model(
|
||||
local_rank: int, world_size: int,
|
||||
test_model_cls: type[torch.nn.Module], batch_size: int, seq_len: int,
|
||||
hidden_size: int, dtype: torch.dtype, enable_fusion: bool):
|
||||
local_rank: int,
|
||||
world_size: int,
|
||||
test_model_cls: type[torch.nn.Module],
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
enable_fusion: bool,
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
@ -226,13 +251,15 @@ def sequence_parallelism_pass_on_test_model(
|
||||
torch.set_default_device(device)
|
||||
torch.set_default_dtype(dtype)
|
||||
|
||||
update_environment_variables({
|
||||
'RANK': str(local_rank),
|
||||
'LOCAL_RANK': str(local_rank),
|
||||
'WORLD_SIZE': str(world_size),
|
||||
'MASTER_ADDR': 'localhost',
|
||||
'MASTER_PORT': '12345',
|
||||
})
|
||||
update_environment_variables(
|
||||
{
|
||||
"RANK": str(local_rank),
|
||||
"LOCAL_RANK": str(local_rank),
|
||||
"WORLD_SIZE": str(world_size),
|
||||
"MASTER_ADDR": "localhost",
|
||||
"MASTER_PORT": "12345",
|
||||
}
|
||||
)
|
||||
|
||||
# initialize distributed
|
||||
init_distributed_environment()
|
||||
@ -240,27 +267,28 @@ def sequence_parallelism_pass_on_test_model(
|
||||
|
||||
# configure vllm config for SequenceParallelismPass
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig(
|
||||
enable_sequence_parallelism=True,
|
||||
enable_fusion=enable_fusion,
|
||||
enable_noop=True)) # NoOp needed for fusion
|
||||
vllm_config.compilation_config = CompilationConfig(
|
||||
pass_config=PassConfig(
|
||||
enable_sequence_parallelism=True,
|
||||
enable_fusion=enable_fusion,
|
||||
enable_noop=True,
|
||||
)
|
||||
) # NoOp needed for fusion
|
||||
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
|
||||
|
||||
# this is a fake model name to construct the model config
|
||||
# in the vllm_config, it's not really used.
|
||||
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
|
||||
vllm_config.model_config = ModelConfig(model=model_name,
|
||||
trust_remote_code=True,
|
||||
dtype=dtype,
|
||||
seed=42)
|
||||
vllm_config.model_config = ModelConfig(
|
||||
model=model_name, trust_remote_code=True, dtype=dtype, seed=42
|
||||
)
|
||||
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
|
||||
func_pass = FixFunctionalizationPass(vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
|
||||
passes_for_backend: list[VllmInductorPass] = \
|
||||
[noop_pass, sequence_parallelism_pass]
|
||||
passes_for_backend: list[VllmInductorPass] = [noop_pass, sequence_parallelism_pass]
|
||||
|
||||
if enable_fusion:
|
||||
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
||||
@ -271,12 +299,9 @@ def sequence_parallelism_pass_on_test_model(
|
||||
backend_no_func = TestBackend(*passes_for_backend)
|
||||
backend_func = TestBackend(*passes_for_backend, func_pass)
|
||||
|
||||
model = test_model_cls(hidden_size,
|
||||
hidden_size * 2,
|
||||
vllm_config=vllm_config)
|
||||
model = test_model_cls(hidden_size, hidden_size * 2, vllm_config=vllm_config)
|
||||
|
||||
hidden_states = torch.randn((batch_size * seq_len, hidden_size),
|
||||
dtype=dtype)
|
||||
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
||||
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
||||
|
||||
compiled_model_no_func = torch.compile(model, backend=backend_no_func)
|
||||
@ -297,8 +322,7 @@ def sequence_parallelism_pass_on_test_model(
|
||||
# check if the functionalization pass is applied
|
||||
for op in model.ops_in_model():
|
||||
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
|
||||
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes,
|
||||
op) is None # noqa: E501
|
||||
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None # noqa: E501
|
||||
|
||||
# make sure the ops were all de-functionalized
|
||||
found = dict()
|
||||
|
||||
@ -8,10 +8,15 @@ import torch
|
||||
import vllm.envs as envs
|
||||
from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor
|
||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.compilation.activation_quant_fusion import (
|
||||
FUSED_OPS, SILU_MUL_OP, ActivationQuantFusionPass)
|
||||
FUSED_OPS,
|
||||
SILU_MUL_OP,
|
||||
ActivationQuantFusionPass,
|
||||
)
|
||||
|
||||
# yapf: enable
|
||||
from vllm.compilation.fusion import QUANT_OPS
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
@ -19,9 +24,14 @@ from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.config import CompilationConfig, PassConfig, VllmConfig
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape, kFp8StaticTensorSym, kNvfp4Quant)
|
||||
GroupShape,
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Quant,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp, cutlass_fp8_supported)
|
||||
Fp8LinearOp,
|
||||
cutlass_fp8_supported,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..utils import override_cutlass_fp8_supported
|
||||
@ -36,7 +46,6 @@ def is_nvfp4_supported():
|
||||
|
||||
|
||||
class TestSiluMulFp8QuantModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs):
|
||||
super().__init__()
|
||||
self.silu_and_mul = SiluAndMul()
|
||||
@ -53,10 +62,7 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
y = self.silu_and_mul(x)
|
||||
x2 = self.fp8_linear.apply(y,
|
||||
self.w,
|
||||
self.wscale,
|
||||
input_scale=self.wscale)
|
||||
x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale)
|
||||
return x2
|
||||
|
||||
def ops_in_model_before(self):
|
||||
@ -67,11 +73,12 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
|
||||
|
||||
|
||||
class TestSiluMulNvfp4QuantModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs):
|
||||
super().__init__()
|
||||
from vllm.compilation.activation_quant_fusion import (
|
||||
silu_and_mul_nvfp4_quant_supported)
|
||||
silu_and_mul_nvfp4_quant_supported,
|
||||
)
|
||||
|
||||
assert silu_and_mul_nvfp4_quant_supported
|
||||
|
||||
self.silu_and_mul = SiluAndMul()
|
||||
@ -88,12 +95,14 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
y = self.silu_and_mul(x)
|
||||
y_quant, y_block_scale = scaled_fp4_quant(y, self.y_global_scale)
|
||||
out = cutlass_scaled_fp4_mm(a=y_quant,
|
||||
b=self.w,
|
||||
block_scale_a=y_block_scale,
|
||||
block_scale_b=self.w_block_scale,
|
||||
alpha=self.alpha,
|
||||
out_dtype=y.dtype)
|
||||
out = cutlass_scaled_fp4_mm(
|
||||
a=y_quant,
|
||||
b=self.w,
|
||||
block_scale_a=y_block_scale,
|
||||
block_scale_b=self.w_block_scale,
|
||||
alpha=self.alpha,
|
||||
out_dtype=y.dtype,
|
||||
)
|
||||
return out
|
||||
|
||||
def ops_in_model_before(self):
|
||||
@ -108,16 +117,24 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize(
|
||||
"model_class",
|
||||
cast(list[type], [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
|
||||
if is_nvfp4_supported() else [TestSiluMulFp8QuantModel]))
|
||||
cast(
|
||||
list[type],
|
||||
[TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
|
||||
if is_nvfp4_supported()
|
||||
else [TestSiluMulFp8QuantModel],
|
||||
),
|
||||
)
|
||||
# cuda_force_torch used to test torch code path on platforms that
|
||||
# cutlass_fp8_supported() == True.
|
||||
@pytest.mark.parametrize("cuda_force_torch",
|
||||
[True, False] if cutlass_fp8_supported() else [True])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
|
||||
reason="Only test on CUDA and ROCm")
|
||||
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class,
|
||||
cuda_force_torch):
|
||||
@pytest.mark.parametrize(
|
||||
"cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True]
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm"
|
||||
)
|
||||
def test_fusion_silu_and_mul_quant(
|
||||
num_tokens, hidden_size, dtype, model_class, cuda_force_torch
|
||||
):
|
||||
if model_class == TestSiluMulNvfp4QuantModel and cuda_force_torch:
|
||||
pytest.skip("Duplicate tests for NVFP4")
|
||||
|
||||
@ -129,17 +146,13 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class,
|
||||
# Reshape pass is needed for the fusion pass to work
|
||||
config = VllmConfig()
|
||||
config.compilation_config = CompilationConfig(
|
||||
pass_config=PassConfig(enable_fusion=True, enable_noop=True))
|
||||
pass_config=PassConfig(enable_fusion=True, enable_noop=True)
|
||||
)
|
||||
fusion_pass = ActivationQuantFusionPass(config)
|
||||
|
||||
passes = [
|
||||
NoOpEliminationPass(config), fusion_pass,
|
||||
PostCleanupPass(config)
|
||||
]
|
||||
passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)]
|
||||
backend = TestBackend(*passes)
|
||||
model = model_class(hidden_size=hidden_size,
|
||||
cuda_force_torch=cuda_force_torch,
|
||||
x=x)
|
||||
model = model_class(hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x)
|
||||
|
||||
# First dimension dynamic
|
||||
torch._dynamo.mark_dynamic(x, 0)
|
||||
@ -155,10 +168,9 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class,
|
||||
elif model_class == TestSiluMulNvfp4QuantModel:
|
||||
atol, rtol = 1e-1, 1e-1
|
||||
|
||||
torch.testing.assert_close(result[0].to(dtype=dtype),
|
||||
result2[0].to(dtype=dtype),
|
||||
atol=atol,
|
||||
rtol=rtol)
|
||||
torch.testing.assert_close(
|
||||
result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
|
||||
)
|
||||
|
||||
assert fusion_pass.matched_count == 1
|
||||
|
||||
|
||||
@ -10,7 +10,6 @@ from vllm.config import CompilationLevel
|
||||
|
||||
|
||||
class MyMod(torch.nn.Module):
|
||||
|
||||
def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
|
||||
if cache is not None:
|
||||
return x + cache
|
||||
@ -18,12 +17,12 @@ class MyMod(torch.nn.Module):
|
||||
|
||||
|
||||
class MyWrapper(TorchCompileWrapperWithCustomDispatcher):
|
||||
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
compiled_callable = torch.compile(self.forward, backend="eager")
|
||||
super().__init__(compiled_callable,
|
||||
compilation_level=CompilationLevel.DYNAMO_ONCE)
|
||||
super().__init__(
|
||||
compiled_callable, compilation_level=CompilationLevel.DYNAMO_ONCE
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
|
||||
# this is the function to be compiled
|
||||
@ -54,10 +53,8 @@ def test_torch_compile_wrapper():
|
||||
|
||||
# for new input, dispatch to the compiled code directly
|
||||
new_x = torch.tensor([3])
|
||||
assert wrapper(new_x,
|
||||
None).item() == 6 # dispatch to the first compiled code
|
||||
assert wrapper(
|
||||
new_x, cache).item() == 5 # dispatch to the second compiled code
|
||||
assert wrapper(new_x, None).item() == 6 # dispatch to the first compiled code
|
||||
assert wrapper(new_x, cache).item() == 5 # dispatch to the second compiled code
|
||||
|
||||
for wrapper in wrappers:
|
||||
# make sure they have independent compiled codes
|
||||
|
||||
@ -14,8 +14,9 @@ def test_cuda_empty_vs_unset_configs(monkeypatch: pytest.MonkeyPatch):
|
||||
"""
|
||||
|
||||
def create_config():
|
||||
engine_args = EngineArgs(model="deepseek-ai/DeepSeek-V2-Lite",
|
||||
trust_remote_code=True)
|
||||
engine_args = EngineArgs(
|
||||
model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True
|
||||
)
|
||||
return engine_args.create_engine_config()
|
||||
|
||||
# Create config with CUDA_VISIBLE_DEVICES set normally
|
||||
@ -34,16 +35,18 @@ def test_cuda_empty_vs_unset_configs(monkeypatch: pytest.MonkeyPatch):
|
||||
empty_config_dict.pop("instance_id", None)
|
||||
|
||||
assert deep_compare(normal_config_dict, empty_config_dict), (
|
||||
"Configs with normal CUDA_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES=\"\""
|
||||
" should be equivalent")
|
||||
'Configs with normal CUDA_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES=""'
|
||||
" should be equivalent"
|
||||
)
|
||||
|
||||
|
||||
def test_ray_runtime_env(monkeypatch: pytest.MonkeyPatch):
|
||||
# In testing, this method needs to be nested inside as ray does not
|
||||
# see the test module.
|
||||
def create_config():
|
||||
engine_args = EngineArgs(model="deepseek-ai/DeepSeek-V2-Lite",
|
||||
trust_remote_code=True)
|
||||
engine_args = EngineArgs(
|
||||
model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True
|
||||
)
|
||||
return engine_args.create_engine_config()
|
||||
|
||||
config = create_config()
|
||||
@ -51,6 +54,7 @@ def test_ray_runtime_env(monkeypatch: pytest.MonkeyPatch):
|
||||
assert parallel_config.ray_runtime_env is None
|
||||
|
||||
import ray
|
||||
|
||||
ray.init()
|
||||
|
||||
runtime_env = {
|
||||
@ -59,13 +63,13 @@ def test_ray_runtime_env(monkeypatch: pytest.MonkeyPatch):
|
||||
},
|
||||
}
|
||||
|
||||
config_ref = ray.remote(create_config).options(
|
||||
runtime_env=runtime_env).remote()
|
||||
config_ref = ray.remote(create_config).options(runtime_env=runtime_env).remote()
|
||||
|
||||
config = ray.get(config_ref)
|
||||
parallel_config = config.parallel_config
|
||||
assert parallel_config.ray_runtime_env is not None
|
||||
assert parallel_config.ray_runtime_env.env_vars().get(
|
||||
"TEST_ENV_VAR") == "test_value"
|
||||
assert (
|
||||
parallel_config.ray_runtime_env.env_vars().get("TEST_ENV_VAR") == "test_value"
|
||||
)
|
||||
|
||||
ray.shutdown()
|
||||
|
||||
@ -16,13 +16,13 @@ def test_mp_reducer(monkeypatch):
|
||||
"""
|
||||
|
||||
# Use V1 AsyncLLM which calls maybe_register_config_serialize_by_value
|
||||
monkeypatch.setenv('VLLM_USE_V1', '1')
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
# Ensure transformers_modules is not in sys.modules
|
||||
if 'transformers_modules' in sys.modules:
|
||||
del sys.modules['transformers_modules']
|
||||
if "transformers_modules" in sys.modules:
|
||||
del sys.modules["transformers_modules"]
|
||||
|
||||
with patch('multiprocessing.reducer.register') as mock_register:
|
||||
with patch("multiprocessing.reducer.register") as mock_register:
|
||||
engine_args = AsyncEngineArgs(
|
||||
model="facebook/opt-125m",
|
||||
max_model_len=32,
|
||||
@ -36,7 +36,8 @@ def test_mp_reducer(monkeypatch):
|
||||
)
|
||||
|
||||
assert mock_register.called, (
|
||||
"multiprocessing.reducer.register should have been called")
|
||||
"multiprocessing.reducer.register should have been called"
|
||||
)
|
||||
|
||||
vllm_config_registered = False
|
||||
for call_args in mock_register.call_args_list:
|
||||
@ -45,8 +46,7 @@ def test_mp_reducer(monkeypatch):
|
||||
vllm_config_registered = True
|
||||
|
||||
reducer_func = call_args[0][1]
|
||||
assert callable(
|
||||
reducer_func), "Reducer function should be callable"
|
||||
assert callable(reducer_func), "Reducer function should be callable"
|
||||
break
|
||||
|
||||
assert vllm_config_registered, (
|
||||
|
||||
@ -30,22 +30,27 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from huggingface_hub import snapshot_download
|
||||
from PIL import Image
|
||||
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
|
||||
BatchEncoding, BatchFeature)
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
BatchEncoding,
|
||||
BatchFeature,
|
||||
)
|
||||
from transformers.models.auto.auto_factory import _BaseAutoModelClass
|
||||
|
||||
from tests.models.utils import (TokensTextLogprobs,
|
||||
TokensTextLogprobsPromptLogprobs)
|
||||
from tests.models.utils import TokensTextLogprobs, TokensTextLogprobsPromptLogprobs
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.assets.audio import AudioAsset
|
||||
from vllm.assets.image import ImageAsset
|
||||
from vllm.assets.video import VideoAsset
|
||||
from vllm.config.model import (ConvertOption, RunnerOption,
|
||||
_get_and_verify_dtype)
|
||||
from vllm.config.model import ConvertOption, RunnerOption, _get_and_verify_dtype
|
||||
from vllm.connections import global_http_connection
|
||||
from vllm.distributed import (cleanup_dist_env_and_memory,
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel)
|
||||
from vllm.distributed import (
|
||||
cleanup_dist_env_and_memory,
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob
|
||||
from vllm.multimodal.utils import fetch_image
|
||||
@ -82,12 +87,13 @@ class ImageAssetPrompts(TypedDict):
|
||||
|
||||
|
||||
class ImageTestAssets(list[ImageAsset]):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__([
|
||||
ImageAsset("stop_sign"),
|
||||
ImageAsset("cherry_blossom"),
|
||||
])
|
||||
super().__init__(
|
||||
[
|
||||
ImageAsset("stop_sign"),
|
||||
ImageAsset("cherry_blossom"),
|
||||
]
|
||||
)
|
||||
|
||||
def prompts(self, prompts: ImageAssetPrompts) -> list[str]:
|
||||
"""
|
||||
@ -104,11 +110,12 @@ class VideoAssetPrompts(TypedDict):
|
||||
|
||||
|
||||
class VideoTestAssets(list[VideoAsset]):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__([
|
||||
VideoAsset("baby_reading"),
|
||||
])
|
||||
super().__init__(
|
||||
[
|
||||
VideoAsset("baby_reading"),
|
||||
]
|
||||
)
|
||||
|
||||
def prompts(self, prompts: VideoAssetPrompts) -> list[str]:
|
||||
return [prompts["baby_reading"]]
|
||||
@ -120,12 +127,13 @@ class AudioAssetPrompts(TypedDict):
|
||||
|
||||
|
||||
class AudioTestAssets(list[AudioAsset]):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__([
|
||||
AudioAsset("mary_had_lamb"),
|
||||
AudioAsset("winning_call"),
|
||||
])
|
||||
super().__init__(
|
||||
[
|
||||
AudioAsset("mary_had_lamb"),
|
||||
AudioAsset("winning_call"),
|
||||
]
|
||||
)
|
||||
|
||||
def prompts(self, prompts: AudioAssetPrompts) -> list[str]:
|
||||
return [prompts["mary_had_lamb"], prompts["winning_call"]]
|
||||
@ -220,6 +228,7 @@ def example_system_message() -> str:
|
||||
|
||||
class DecoderPromptType(Enum):
|
||||
"""For encoder/decoder models only."""
|
||||
|
||||
CUSTOM = 1
|
||||
NONE = 2
|
||||
EMPTY_STR = 3
|
||||
@ -253,15 +262,13 @@ _R = TypeVar("_R")
|
||||
|
||||
|
||||
class HfRunner:
|
||||
|
||||
def get_default_device(self):
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
return ("cpu"
|
||||
if current_platform.is_cpu() else current_platform.device_type)
|
||||
return "cpu" if current_platform.is_cpu() else current_platform.device_type
|
||||
|
||||
def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
|
||||
if x is None or isinstance(x, (bool, )):
|
||||
if x is None or isinstance(x, (bool,)):
|
||||
return x
|
||||
|
||||
if device is None:
|
||||
@ -289,8 +296,11 @@ class HfRunner:
|
||||
# Set this to avoid hanging issue
|
||||
default_torch_num_threads: Optional[int] = None,
|
||||
) -> None:
|
||||
init_ctx = (nullcontext() if default_torch_num_threads is None else
|
||||
set_default_torch_num_threads(default_torch_num_threads))
|
||||
init_ctx = (
|
||||
nullcontext()
|
||||
if default_torch_num_threads is None
|
||||
else set_default_torch_num_threads(default_torch_num_threads)
|
||||
)
|
||||
|
||||
with init_ctx:
|
||||
self._init(
|
||||
@ -362,14 +372,15 @@ class HfRunner:
|
||||
)
|
||||
|
||||
# in case some unquantized custom models are not in same dtype
|
||||
if (getattr(model, "quantization_method", None) is None
|
||||
and any(p.dtype != self.dtype
|
||||
for p in model.parameters())):
|
||||
if getattr(model, "quantization_method", None) is None and any(
|
||||
p.dtype != self.dtype for p in model.parameters()
|
||||
):
|
||||
model = model.to(dtype=self.dtype)
|
||||
|
||||
if (getattr(model, "quantization_method", None) != "bitsandbytes"
|
||||
and len({p.device
|
||||
for p in model.parameters()}) < 2):
|
||||
if (
|
||||
getattr(model, "quantization_method", None) != "bitsandbytes"
|
||||
and len({p.device for p in model.parameters()}) < 2
|
||||
):
|
||||
model = model.to(device=self.device)
|
||||
|
||||
self.model = model
|
||||
@ -384,6 +395,7 @@ class HfRunner:
|
||||
# don't put this import at the top level
|
||||
# it will call torch.cuda.device_count()
|
||||
from transformers import AutoProcessor # noqa: F401
|
||||
|
||||
self.processor = AutoProcessor.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch_dtype,
|
||||
@ -471,10 +483,9 @@ class HfRunner:
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
**kwargs: Any,
|
||||
) -> list[tuple[list[list[int]], list[str]]]:
|
||||
all_inputs = self.get_inputs(prompts,
|
||||
images=images,
|
||||
videos=videos,
|
||||
audios=audios)
|
||||
all_inputs = self.get_inputs(
|
||||
prompts, images=images, videos=videos, audios=audios
|
||||
)
|
||||
|
||||
outputs: list[tuple[list[list[int]], list[str]]] = []
|
||||
for inputs in all_inputs:
|
||||
@ -501,16 +512,17 @@ class HfRunner:
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
**kwargs: Any,
|
||||
) -> list[tuple[list[int], str]]:
|
||||
outputs = self.generate(prompts,
|
||||
do_sample=False,
|
||||
max_new_tokens=max_tokens,
|
||||
images=images,
|
||||
videos=videos,
|
||||
audios=audios,
|
||||
**kwargs)
|
||||
outputs = self.generate(
|
||||
prompts,
|
||||
do_sample=False,
|
||||
max_new_tokens=max_tokens,
|
||||
images=images,
|
||||
videos=videos,
|
||||
audios=audios,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return [(output_ids[0], output_str[0])
|
||||
for output_ids, output_str in outputs]
|
||||
return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs]
|
||||
|
||||
def generate_beam_search(
|
||||
self,
|
||||
@ -521,21 +533,22 @@ class HfRunner:
|
||||
videos: Optional[PromptVideoInput] = None,
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
) -> list[tuple[list[list[int]], list[str]]]:
|
||||
outputs = self.generate(prompts,
|
||||
do_sample=False,
|
||||
max_new_tokens=max_tokens,
|
||||
num_beams=beam_width,
|
||||
num_return_sequences=beam_width,
|
||||
images=images,
|
||||
videos=videos,
|
||||
audios=audios)
|
||||
outputs = self.generate(
|
||||
prompts,
|
||||
do_sample=False,
|
||||
max_new_tokens=max_tokens,
|
||||
num_beams=beam_width,
|
||||
num_return_sequences=beam_width,
|
||||
images=images,
|
||||
videos=videos,
|
||||
audios=audios,
|
||||
)
|
||||
|
||||
for i in range(len(outputs)):
|
||||
output_ids, output_str = outputs[i]
|
||||
for j in range(len(output_ids)):
|
||||
output_ids[j] = [
|
||||
x for x in output_ids[j]
|
||||
if x != self.tokenizer.pad_token_id
|
||||
x for x in output_ids[j] if x != self.tokenizer.pad_token_id
|
||||
]
|
||||
outputs[i] = (output_ids, output_str)
|
||||
return outputs
|
||||
@ -549,10 +562,9 @@ class HfRunner:
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
**kwargs: Any,
|
||||
) -> list[list[torch.Tensor]]:
|
||||
all_inputs = self.get_inputs(prompts,
|
||||
images=images,
|
||||
videos=videos,
|
||||
audios=audios)
|
||||
all_inputs = self.get_inputs(
|
||||
prompts, images=images, videos=videos, audios=audios
|
||||
)
|
||||
|
||||
all_logprobs: list[list[torch.Tensor]] = []
|
||||
for inputs in all_inputs:
|
||||
@ -565,8 +577,7 @@ class HfRunner:
|
||||
return_dict_in_generate=True,
|
||||
**kwargs,
|
||||
)
|
||||
seq_logprobs = self._hidden_states_to_seq_logprobs(
|
||||
output.hidden_states)
|
||||
seq_logprobs = self._hidden_states_to_seq_logprobs(output.hidden_states)
|
||||
all_logprobs.append(seq_logprobs)
|
||||
return all_logprobs
|
||||
|
||||
@ -630,10 +641,9 @@ class HfRunner:
|
||||
videos: Optional[PromptVideoInput] = None,
|
||||
**kwargs: Any,
|
||||
) -> list[TokensTextLogprobs]:
|
||||
all_inputs = self.get_inputs(prompts,
|
||||
images=images,
|
||||
videos=videos,
|
||||
audios=audios)
|
||||
all_inputs = self.get_inputs(
|
||||
prompts, images=images, videos=videos, audios=audios
|
||||
)
|
||||
|
||||
all_logprobs: list[list[dict[int, float]]] = []
|
||||
all_output_ids: list[list[int]] = []
|
||||
@ -653,8 +663,7 @@ class HfRunner:
|
||||
(
|
||||
seq_logprobs_lst,
|
||||
output_len,
|
||||
) = self._hidden_states_to_logprobs(output.hidden_states,
|
||||
num_logprobs)
|
||||
) = self._hidden_states_to_logprobs(output.hidden_states, num_logprobs)
|
||||
|
||||
all_logprobs.append(seq_logprobs_lst)
|
||||
seq_ids = output.sequences[0]
|
||||
@ -664,19 +673,16 @@ class HfRunner:
|
||||
all_output_strs.append(self.tokenizer.decode(output_ids))
|
||||
|
||||
outputs = zip(all_output_ids, all_output_strs, all_logprobs)
|
||||
return [(output_ids, output_str, output_logprobs)
|
||||
for output_ids, output_str, output_logprobs in outputs]
|
||||
return [
|
||||
(output_ids, output_str, output_logprobs)
|
||||
for output_ids, output_str, output_logprobs in outputs
|
||||
]
|
||||
|
||||
def encode(self, prompts: list[str], *args,
|
||||
**kwargs) -> list[list[torch.Tensor]]:
|
||||
def encode(self, prompts: list[str], *args, **kwargs) -> list[list[torch.Tensor]]:
|
||||
return self.model.encode(prompts, *args, **kwargs)
|
||||
|
||||
def predict(self, prompts: list[list[str]], *args,
|
||||
**kwargs) -> torch.Tensor:
|
||||
return self.model.predict(prompts,
|
||||
*args,
|
||||
convert_to_tensor=True,
|
||||
**kwargs)
|
||||
def predict(self, prompts: list[list[str]], *args, **kwargs) -> torch.Tensor:
|
||||
return self.model.predict(prompts, *args, convert_to_tensor=True, **kwargs)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
@ -727,8 +733,11 @@ class VllmRunner:
|
||||
default_torch_num_threads: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
init_ctx = (nullcontext() if default_torch_num_threads is None else
|
||||
set_default_torch_num_threads(default_torch_num_threads))
|
||||
init_ctx = (
|
||||
nullcontext()
|
||||
if default_torch_num_threads is None
|
||||
else set_default_torch_num_threads(default_torch_num_threads)
|
||||
)
|
||||
|
||||
if not kwargs.get("compilation_config", None):
|
||||
kwargs["compilation_config"] = {"cudagraph_capture_sizes": [4]}
|
||||
@ -760,11 +769,12 @@ class VllmRunner:
|
||||
videos: Optional[PromptVideoInput] = None,
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
if any(x is not None and len(x) != len(prompts)
|
||||
for x in [images, videos, audios]):
|
||||
if any(
|
||||
x is not None and len(x) != len(prompts) for x in [images, videos, audios]
|
||||
):
|
||||
raise ValueError(
|
||||
"All non-None multimodal inputs must have the same length as "
|
||||
"prompts")
|
||||
"All non-None multimodal inputs must have the same length as prompts"
|
||||
)
|
||||
|
||||
inputs = list[dict[str, Any]]()
|
||||
for i, prompt in enumerate(prompts):
|
||||
@ -800,14 +810,11 @@ class VllmRunner:
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
**kwargs: Any,
|
||||
) -> list[tuple[list[list[int]], list[str]]]:
|
||||
inputs = self.get_inputs(prompts,
|
||||
images=images,
|
||||
videos=videos,
|
||||
audios=audios)
|
||||
inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
|
||||
|
||||
req_outputs = self.llm.generate(inputs,
|
||||
sampling_params=sampling_params,
|
||||
**kwargs)
|
||||
req_outputs = self.llm.generate(
|
||||
inputs, sampling_params=sampling_params, **kwargs
|
||||
)
|
||||
|
||||
outputs: list[tuple[list[list[int]], list[str]]] = []
|
||||
for req_output in req_outputs:
|
||||
@ -834,8 +841,9 @@ class VllmRunner:
|
||||
output_str = sample.text
|
||||
output_ids = list(sample.token_ids)
|
||||
output_logprobs = sample.logprobs
|
||||
outputs.append((output_ids, output_str, output_logprobs,
|
||||
req_output.prompt_logprobs))
|
||||
outputs.append(
|
||||
(output_ids, output_str, output_logprobs, req_output.prompt_logprobs)
|
||||
)
|
||||
return outputs
|
||||
|
||||
def generate_w_logprobs(
|
||||
@ -846,23 +854,22 @@ class VllmRunner:
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
videos: Optional[PromptVideoInput] = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[list[TokensTextLogprobs],
|
||||
list[TokensTextLogprobsPromptLogprobs]]:
|
||||
inputs = self.get_inputs(prompts,
|
||||
images=images,
|
||||
videos=videos,
|
||||
audios=audios)
|
||||
) -> Union[list[TokensTextLogprobs], list[TokensTextLogprobsPromptLogprobs]]:
|
||||
inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
|
||||
|
||||
req_outputs = self.llm.generate(inputs,
|
||||
sampling_params=sampling_params,
|
||||
**kwargs)
|
||||
req_outputs = self.llm.generate(
|
||||
inputs, sampling_params=sampling_params, **kwargs
|
||||
)
|
||||
|
||||
toks_str_logsprobs_prompt_logprobs = (
|
||||
self._final_steps_generate_w_logprobs(req_outputs))
|
||||
toks_str_logsprobs_prompt_logprobs = self._final_steps_generate_w_logprobs(
|
||||
req_outputs
|
||||
)
|
||||
# Omit prompt logprobs if not required by sampling params
|
||||
return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
|
||||
if sampling_params.prompt_logprobs is None else
|
||||
toks_str_logsprobs_prompt_logprobs)
|
||||
return (
|
||||
[x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
|
||||
if sampling_params.prompt_logprobs is None
|
||||
else toks_str_logsprobs_prompt_logprobs
|
||||
)
|
||||
|
||||
def generate_greedy(
|
||||
self,
|
||||
@ -874,14 +881,15 @@ class VllmRunner:
|
||||
**kwargs: Any,
|
||||
) -> list[tuple[list[int], str]]:
|
||||
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
|
||||
outputs = self.generate(prompts,
|
||||
greedy_params,
|
||||
images=images,
|
||||
videos=videos,
|
||||
audios=audios,
|
||||
**kwargs)
|
||||
return [(output_ids[0], output_str[0])
|
||||
for output_ids, output_str in outputs]
|
||||
outputs = self.generate(
|
||||
prompts,
|
||||
greedy_params,
|
||||
images=images,
|
||||
videos=videos,
|
||||
audios=audios,
|
||||
**kwargs,
|
||||
)
|
||||
return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs]
|
||||
|
||||
def generate_greedy_logprobs(
|
||||
self,
|
||||
@ -895,22 +903,24 @@ class VllmRunner:
|
||||
stop_token_ids: Optional[list[int]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[list[TokensTextLogprobs],
|
||||
list[TokensTextLogprobsPromptLogprobs]]:
|
||||
) -> Union[list[TokensTextLogprobs], list[TokensTextLogprobsPromptLogprobs]]:
|
||||
greedy_logprobs_params = SamplingParams(
|
||||
temperature=0.0,
|
||||
max_tokens=max_tokens,
|
||||
logprobs=num_logprobs,
|
||||
prompt_logprobs=num_prompt_logprobs,
|
||||
stop_token_ids=stop_token_ids,
|
||||
stop=stop)
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
return self.generate_w_logprobs(prompts,
|
||||
greedy_logprobs_params,
|
||||
images=images,
|
||||
audios=audios,
|
||||
videos=videos,
|
||||
**kwargs)
|
||||
return self.generate_w_logprobs(
|
||||
prompts,
|
||||
greedy_logprobs_params,
|
||||
images=images,
|
||||
audios=audios,
|
||||
videos=videos,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def generate_prompt_perplexity(self, prompts: list[str]) -> list[float]:
|
||||
"""
|
||||
@ -919,10 +929,9 @@ class VllmRunner:
|
||||
:param prompts: list of prompts to score
|
||||
:return: perplexity score of each prompt
|
||||
"""
|
||||
outputs = self.generate_greedy_logprobs(prompts,
|
||||
max_tokens=1,
|
||||
num_logprobs=None,
|
||||
num_prompt_logprobs=0)
|
||||
outputs = self.generate_greedy_logprobs(
|
||||
prompts, max_tokens=1, num_logprobs=None, num_prompt_logprobs=0
|
||||
)
|
||||
|
||||
perplexities = []
|
||||
for output in outputs:
|
||||
@ -951,15 +960,13 @@ class VllmRunner:
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
concurrency_limit: Optional[int] = None,
|
||||
) -> list[tuple[list[list[int]], list[str]]]:
|
||||
inputs = self.get_inputs(prompts,
|
||||
images=images,
|
||||
videos=videos,
|
||||
audios=audios)
|
||||
inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
|
||||
|
||||
outputs = self.llm.beam_search(inputs,
|
||||
BeamSearchParams(beam_width=beam_width,
|
||||
max_tokens=max_tokens),
|
||||
concurrency_limit=concurrency_limit)
|
||||
outputs = self.llm.beam_search(
|
||||
inputs,
|
||||
BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens),
|
||||
concurrency_limit=concurrency_limit,
|
||||
)
|
||||
returned_outputs = []
|
||||
for output in outputs:
|
||||
token_ids = [x.tokens for x in output.sequences]
|
||||
@ -971,17 +978,16 @@ class VllmRunner:
|
||||
req_outputs = self.llm.classify(prompts)
|
||||
return [req_output.outputs.probs for req_output in req_outputs]
|
||||
|
||||
def embed(self,
|
||||
prompts: list[str],
|
||||
images: Optional[PromptImageInput] = None,
|
||||
videos: Optional[PromptVideoInput] = None,
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
*args,
|
||||
**kwargs) -> list[list[float]]:
|
||||
inputs = self.get_inputs(prompts,
|
||||
images=images,
|
||||
videos=videos,
|
||||
audios=audios)
|
||||
def embed(
|
||||
self,
|
||||
prompts: list[str],
|
||||
images: Optional[PromptImageInput] = None,
|
||||
videos: Optional[PromptVideoInput] = None,
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> list[list[float]]:
|
||||
inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
|
||||
|
||||
req_outputs = self.llm.embed(inputs, *args, **kwargs)
|
||||
return [req_output.outputs.embedding for req_output in req_outputs]
|
||||
@ -1026,6 +1032,7 @@ def vllm_runner():
|
||||
@pytest.fixture()
|
||||
def temporary_enable_log_propagate():
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("vllm")
|
||||
logger.propagate = True
|
||||
yield
|
||||
@ -1045,6 +1052,7 @@ def num_gpus_available():
|
||||
in current process."""
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
return current_platform.device_count()
|
||||
|
||||
|
||||
@ -1058,12 +1066,11 @@ _dummy_gemma2_embedding_path = os.path.join(temp_dir, "dummy_gemma2_embedding")
|
||||
def dummy_opt_path():
|
||||
json_path = os.path.join(_dummy_opt_path, "config.json")
|
||||
if not os.path.exists(_dummy_opt_path):
|
||||
snapshot_download(repo_id="facebook/opt-125m",
|
||||
local_dir=_dummy_opt_path,
|
||||
ignore_patterns=[
|
||||
"*.bin", "*.bin.index.json", "*.pt", "*.h5",
|
||||
"*.msgpack"
|
||||
])
|
||||
snapshot_download(
|
||||
repo_id="facebook/opt-125m",
|
||||
local_dir=_dummy_opt_path,
|
||||
ignore_patterns=["*.bin", "*.bin.index.json", "*.pt", "*.h5", "*.msgpack"],
|
||||
)
|
||||
assert os.path.exists(json_path)
|
||||
with open(json_path) as f:
|
||||
config = json.load(f)
|
||||
@ -1077,12 +1084,18 @@ def dummy_opt_path():
|
||||
def dummy_llava_path():
|
||||
json_path = os.path.join(_dummy_llava_path, "config.json")
|
||||
if not os.path.exists(_dummy_llava_path):
|
||||
snapshot_download(repo_id="llava-hf/llava-1.5-7b-hf",
|
||||
local_dir=_dummy_llava_path,
|
||||
ignore_patterns=[
|
||||
"*.bin", "*.bin.index.json", "*.pt", "*.h5",
|
||||
"*.msgpack", "*.safetensors"
|
||||
])
|
||||
snapshot_download(
|
||||
repo_id="llava-hf/llava-1.5-7b-hf",
|
||||
local_dir=_dummy_llava_path,
|
||||
ignore_patterns=[
|
||||
"*.bin",
|
||||
"*.bin.index.json",
|
||||
"*.pt",
|
||||
"*.h5",
|
||||
"*.msgpack",
|
||||
"*.safetensors",
|
||||
],
|
||||
)
|
||||
assert os.path.exists(json_path)
|
||||
with open(json_path) as f:
|
||||
config = json.load(f)
|
||||
@ -1096,12 +1109,18 @@ def dummy_llava_path():
|
||||
def dummy_gemma2_embedding_path():
|
||||
json_path = os.path.join(_dummy_gemma2_embedding_path, "config.json")
|
||||
if not os.path.exists(_dummy_gemma2_embedding_path):
|
||||
snapshot_download(repo_id="BAAI/bge-multilingual-gemma2",
|
||||
local_dir=_dummy_gemma2_embedding_path,
|
||||
ignore_patterns=[
|
||||
"*.bin", "*.bin.index.json", "*.pt", "*.h5",
|
||||
"*.msgpack", "*.safetensors"
|
||||
])
|
||||
snapshot_download(
|
||||
repo_id="BAAI/bge-multilingual-gemma2",
|
||||
local_dir=_dummy_gemma2_embedding_path,
|
||||
ignore_patterns=[
|
||||
"*.bin",
|
||||
"*.bin.index.json",
|
||||
"*.pt",
|
||||
"*.h5",
|
||||
"*.msgpack",
|
||||
"*.safetensors",
|
||||
],
|
||||
)
|
||||
assert os.path.exists(json_path)
|
||||
with open(json_path) as f:
|
||||
config = json.load(f)
|
||||
@ -1114,10 +1133,9 @@ def dummy_gemma2_embedding_path():
|
||||
# Add the flag `--optional` to allow run tests
|
||||
# that are marked with @pytest.mark.optional
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption("--optional",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="run optional test")
|
||||
parser.addoption(
|
||||
"--optional", action="store_true", default=False, help="run optional test"
|
||||
)
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(config, items):
|
||||
@ -1185,7 +1203,6 @@ def _find_free_port() -> int:
|
||||
|
||||
|
||||
class LocalAssetServer:
|
||||
|
||||
address: str
|
||||
port: int
|
||||
server: Optional[http.server.ThreadingHTTPServer]
|
||||
@ -1200,9 +1217,9 @@ class LocalAssetServer:
|
||||
def __enter__(self):
|
||||
self.port = _find_free_port()
|
||||
self.server = http.server.ThreadingHTTPServer(
|
||||
(self.address, self.port), AssetHandler)
|
||||
self.thread = threading.Thread(target=self.server.serve_forever,
|
||||
daemon=True)
|
||||
(self.address, self.port), AssetHandler
|
||||
)
|
||||
self.thread = threading.Thread(target=self.server.serve_forever, daemon=True)
|
||||
self.thread.start()
|
||||
return self
|
||||
|
||||
|
||||
@ -13,7 +13,7 @@ from vllm.platforms import current_platform
|
||||
def check_cuda_context():
|
||||
"""Check CUDA driver context status"""
|
||||
try:
|
||||
cuda = ctypes.CDLL('libcuda.so')
|
||||
cuda = ctypes.CDLL("libcuda.so")
|
||||
device = ctypes.c_int()
|
||||
result = cuda.cuCtxGetDevice(ctypes.byref(device))
|
||||
return (True, device.value) if result == 0 else (False, None)
|
||||
@ -27,9 +27,11 @@ def run_cuda_test_in_thread(device_input, expected_device_id):
|
||||
# New thread should have no CUDA context initially
|
||||
valid_before, device_before = check_cuda_context()
|
||||
if valid_before:
|
||||
return False, \
|
||||
"CUDA context should not exist in new thread, " \
|
||||
f"got device {device_before}"
|
||||
return (
|
||||
False,
|
||||
"CUDA context should not exist in new thread, "
|
||||
f"got device {device_before}",
|
||||
)
|
||||
|
||||
# Test setting CUDA context
|
||||
current_platform.set_device(device_input)
|
||||
@ -39,8 +41,7 @@ def run_cuda_test_in_thread(device_input, expected_device_id):
|
||||
if not valid_after:
|
||||
return False, "CUDA context should be valid after set_cuda_context"
|
||||
if device_id != expected_device_id:
|
||||
return False, \
|
||||
f"Expected device {expected_device_id}, got {device_id}"
|
||||
return False, f"Expected device {expected_device_id}, got {device_id}"
|
||||
|
||||
return True, "Success"
|
||||
except Exception as e:
|
||||
@ -50,30 +51,30 @@ def run_cuda_test_in_thread(device_input, expected_device_id):
|
||||
class TestSetCudaContext:
|
||||
"""Test suite for the set_cuda_context function."""
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(),
|
||||
reason="CUDA not available")
|
||||
@pytest.mark.parametrize(argnames="device_input,expected_device_id",
|
||||
argvalues=[
|
||||
(0, 0),
|
||||
(torch.device('cuda:0'), 0),
|
||||
('cuda:0', 0),
|
||||
],
|
||||
ids=["int", "torch_device", "string"])
|
||||
def test_set_cuda_context_parametrized(self, device_input,
|
||||
expected_device_id):
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available")
|
||||
@pytest.mark.parametrize(
|
||||
argnames="device_input,expected_device_id",
|
||||
argvalues=[
|
||||
(0, 0),
|
||||
(torch.device("cuda:0"), 0),
|
||||
("cuda:0", 0),
|
||||
],
|
||||
ids=["int", "torch_device", "string"],
|
||||
)
|
||||
def test_set_cuda_context_parametrized(self, device_input, expected_device_id):
|
||||
"""Test setting CUDA context in isolated threads."""
|
||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(run_cuda_test_in_thread, device_input,
|
||||
expected_device_id)
|
||||
future = executor.submit(
|
||||
run_cuda_test_in_thread, device_input, expected_device_id
|
||||
)
|
||||
success, message = future.result(timeout=30)
|
||||
assert success, message
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(),
|
||||
reason="CUDA not available")
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available")
|
||||
def test_set_cuda_context_invalid_device_type(self):
|
||||
"""Test error handling for invalid device type."""
|
||||
with pytest.raises(ValueError, match="Expected a cuda device"):
|
||||
current_platform.set_device(torch.device('cpu'))
|
||||
current_platform.set_device(torch.device("cpu"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -17,20 +17,16 @@ def test_computed_prefix_blocks(model: str):
|
||||
prompt = (
|
||||
"You are a helpful assistant. How do I build a car from cardboard and "
|
||||
"paper clips? Is there an easy to follow video tutorial available "
|
||||
"online for free?")
|
||||
"online for free?"
|
||||
)
|
||||
|
||||
llm = LLM(model=model)
|
||||
sampling_params = SamplingParams(max_tokens=10,
|
||||
temperature=0.0,
|
||||
detokenize=False)
|
||||
sampling_params = SamplingParams(max_tokens=10, temperature=0.0, detokenize=False)
|
||||
|
||||
outputs_no_detokenization = llm.generate(prompt,
|
||||
sampling_params)[0].outputs[0]
|
||||
outputs_no_detokenization = llm.generate(prompt, sampling_params)[0].outputs[0]
|
||||
sampling_params.detokenize = True
|
||||
outputs_with_detokenization = llm.generate(prompt,
|
||||
sampling_params)[0].outputs[0]
|
||||
outputs_with_detokenization = llm.generate(prompt, sampling_params)[0].outputs[0]
|
||||
|
||||
assert outputs_no_detokenization.text == ''
|
||||
assert outputs_with_detokenization.text != ''
|
||||
assert outputs_no_detokenization.token_ids == \
|
||||
outputs_with_detokenization.token_ids
|
||||
assert outputs_no_detokenization.text == ""
|
||||
assert outputs_with_detokenization.text != ""
|
||||
assert outputs_no_detokenization.token_ids == outputs_with_detokenization.token_ids
|
||||
|
||||
@ -8,15 +8,17 @@ from vllm import SamplingParams
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.detokenizer import FastIncrementalDetokenizer
|
||||
|
||||
PROMPT = "Hello, my name is Lee, and I'm a student in the " + \
|
||||
"college of engineering"
|
||||
PROMPT = "Hello, my name is Lee, and I'm a student in the " + "college of engineering"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("min_tokens,stop,truth", [
|
||||
(0, None, " is Lee, and I'm a student in the college of engineering"),
|
||||
(0, "e", " is L"),
|
||||
(5, "e", " is Lee, and I'm a stud"),
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"min_tokens,stop,truth",
|
||||
[
|
||||
(0, None, " is Lee, and I'm a student in the college of engineering"),
|
||||
(0, "e", " is L"),
|
||||
(5, "e", " is Lee, and I'm a stud"),
|
||||
],
|
||||
)
|
||||
def test_min_tokens_with_stop(min_tokens: int, stop: str, truth: str):
|
||||
"""Test for a specific min_tokens and stop.
|
||||
|
||||
@ -31,16 +33,18 @@ def test_min_tokens_with_stop(min_tokens: int, stop: str, truth: str):
|
||||
stop=stop,
|
||||
min_tokens=min_tokens,
|
||||
)
|
||||
request = EngineCoreRequest(request_id="",
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
mm_features=None,
|
||||
sampling_params=params,
|
||||
pooling_params=None,
|
||||
eos_token_id=None,
|
||||
arrival_time=0.0,
|
||||
lora_request=None,
|
||||
cache_salt=None,
|
||||
data_parallel_rank=None)
|
||||
request = EngineCoreRequest(
|
||||
request_id="",
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
mm_features=None,
|
||||
sampling_params=params,
|
||||
pooling_params=None,
|
||||
eos_token_id=None,
|
||||
arrival_time=0.0,
|
||||
lora_request=None,
|
||||
cache_salt=None,
|
||||
data_parallel_rank=None,
|
||||
)
|
||||
|
||||
detokenizer = FastIncrementalDetokenizer(tokenizer, request)
|
||||
|
||||
|
||||
@ -31,34 +31,39 @@ def test_stop_reason(vllm_model, example_prompts):
|
||||
llm = vllm_model.llm
|
||||
|
||||
# test stop token
|
||||
outputs = llm.generate(example_prompts,
|
||||
sampling_params=SamplingParams(
|
||||
ignore_eos=True,
|
||||
seed=SEED,
|
||||
max_tokens=MAX_TOKENS,
|
||||
stop_token_ids=[stop_token_id]))
|
||||
outputs = llm.generate(
|
||||
example_prompts,
|
||||
sampling_params=SamplingParams(
|
||||
ignore_eos=True,
|
||||
seed=SEED,
|
||||
max_tokens=MAX_TOKENS,
|
||||
stop_token_ids=[stop_token_id],
|
||||
),
|
||||
)
|
||||
for output in outputs:
|
||||
output = output.outputs[0]
|
||||
assert output.finish_reason == "stop"
|
||||
assert output.stop_reason == stop_token_id
|
||||
|
||||
# test stop string
|
||||
outputs = llm.generate(example_prompts,
|
||||
sampling_params=SamplingParams(
|
||||
ignore_eos=True,
|
||||
seed=SEED,
|
||||
max_tokens=MAX_TOKENS,
|
||||
stop="."))
|
||||
outputs = llm.generate(
|
||||
example_prompts,
|
||||
sampling_params=SamplingParams(
|
||||
ignore_eos=True, seed=SEED, max_tokens=MAX_TOKENS, stop="."
|
||||
),
|
||||
)
|
||||
for output in outputs:
|
||||
output = output.outputs[0]
|
||||
assert output.finish_reason == "stop"
|
||||
assert output.stop_reason == STOP_STR
|
||||
|
||||
# test EOS token
|
||||
outputs = llm.generate(example_prompts,
|
||||
sampling_params=SamplingParams(
|
||||
seed=SEED, max_tokens=MAX_TOKENS))
|
||||
outputs = llm.generate(
|
||||
example_prompts,
|
||||
sampling_params=SamplingParams(seed=SEED, max_tokens=MAX_TOKENS),
|
||||
)
|
||||
for output in outputs:
|
||||
output = output.outputs[0]
|
||||
assert output.finish_reason == "length" or (
|
||||
output.finish_reason == "stop" and output.stop_reason is None)
|
||||
output.finish_reason == "stop" and output.stop_reason is None
|
||||
)
|
||||
|
||||
@ -14,7 +14,6 @@ def include_stop_str_in_output(request):
|
||||
|
||||
|
||||
class _DummyDetokenizer(BaseIncrementalDetokenizer):
|
||||
|
||||
def __init__(self, request: EngineCoreRequest):
|
||||
super().__init__(request)
|
||||
|
||||
@ -27,7 +26,8 @@ def _make_request(stop, include_stop_str_in_output: bool, min_tokens: int = 0):
|
||||
params = SamplingParams(
|
||||
stop=stop,
|
||||
include_stop_str_in_output=include_stop_str_in_output,
|
||||
min_tokens=min_tokens)
|
||||
min_tokens=min_tokens,
|
||||
)
|
||||
# Keep other fields minimal for unit test purposes.
|
||||
req = EngineCoreRequest(
|
||||
request_id="test",
|
||||
@ -44,8 +44,7 @@ def _make_request(stop, include_stop_str_in_output: bool, min_tokens: int = 0):
|
||||
return req
|
||||
|
||||
|
||||
def test_stop_string_while_stop_token_terminates(
|
||||
include_stop_str_in_output: bool):
|
||||
def test_stop_string_while_stop_token_terminates(include_stop_str_in_output: bool):
|
||||
"""
|
||||
This test verifies that the detokenizer correctly handles the case where
|
||||
the generated token sequence contains both:
|
||||
@ -78,8 +77,9 @@ def test_stop_string_while_stop_token_terminates(
|
||||
token_ids = [ord(c) for c in generated_text]
|
||||
|
||||
# Create a request with the stop string and initialize the detokenizer.
|
||||
req = _make_request(stop=[stop_string],
|
||||
include_stop_str_in_output=include_stop_str_in_output)
|
||||
req = _make_request(
|
||||
stop=[stop_string], include_stop_str_in_output=include_stop_str_in_output
|
||||
)
|
||||
detok = _DummyDetokenizer(req)
|
||||
|
||||
# Simulate that the last token ('Z') is a stop token (stop_terminated=True).
|
||||
@ -99,5 +99,4 @@ def test_stop_string_while_stop_token_terminates(
|
||||
|
||||
# get_next_output_text should return the full text when finished=True.
|
||||
# (Buffering only applies during streaming when finished=False.)
|
||||
assert detok.get_next_output_text(finished=True,
|
||||
delta=False) == expected_text
|
||||
assert detok.get_next_output_text(finished=True, delta=False) == expected_text
|
||||
|
||||
@ -11,12 +11,14 @@ MODEL = "meta-llama/llama-2-7b-hf"
MAX_TOKENS = 200


def _test_stopping(llm: LLM,
expected_output: str,
expected_reason: Any,
stop: Optional[list[str]] = None,
stop_token_ids: Optional[list[int]] = None,
include_in_output: bool = False) -> None:
def _test_stopping(
llm: LLM,
expected_output: str,
expected_reason: Any,
stop: Optional[list[str]] = None,
stop_token_ids: Optional[list[int]] = None,
include_in_output: bool = False,
) -> None:
output = llm.generate(
"A story about vLLM:\n",
SamplingParams(
@ -25,7 +27,8 @@ def _test_stopping(llm: LLM,
stop=stop,
stop_token_ids=stop_token_ids,
include_stop_str_in_output=include_in_output,
))[0].outputs[0]
),
)[0].outputs[0]

assert output is not None
assert output.text == expected_output
@ -33,17 +36,21 @@ def _test_stopping(llm: LLM,


def _stop_basic(llm):
_test_stopping(llm,
stop=["."],
include_in_output=False,
expected_output="VLLM is a 100% volunteer organization",
expected_reason=".")
_test_stopping(
llm,
stop=["."],
include_in_output=False,
expected_output="VLLM is a 100% volunteer organization",
expected_reason=".",
)

_test_stopping(llm,
stop=["."],
include_in_output=True,
expected_output="VLLM is a 100% volunteer organization.",
expected_reason=".")
_test_stopping(
llm,
stop=["."],
include_in_output=True,
expected_output="VLLM is a 100% volunteer organization.",
expected_reason=".",
)


def _stop_multi_tokens(llm):
@ -52,45 +59,54 @@ def _stop_multi_tokens(llm):
stop=["group of peo", "short"],
include_in_output=False,
expected_output="VLLM is a 100% volunteer organization. We are a ",
expected_reason="group of peo")
expected_reason="group of peo",
)

_test_stopping(
llm,
stop=["group of peo", "short"],
include_in_output=True,
expected_output=
"VLLM is a 100% volunteer organization. We are a group of peo",
expected_reason="group of peo")
expected_output="VLLM is a 100% volunteer organization. We are a group of peo",
expected_reason="group of peo",
)


def _stop_partial_token(llm):
_test_stopping(llm,
stop=["gani"],
include_in_output=False,
expected_output="VLLM is a 100% volunteer or",
expected_reason="gani")
_test_stopping(
llm,
stop=["gani"],
include_in_output=False,
expected_output="VLLM is a 100% volunteer or",
expected_reason="gani",
)

_test_stopping(llm,
stop=["gani"],
include_in_output=True,
expected_output="VLLM is a 100% volunteer organi",
expected_reason="gani")
_test_stopping(
llm,
stop=["gani"],
include_in_output=True,
expected_output="VLLM is a 100% volunteer organi",
expected_reason="gani",
)


def _stop_token_id(llm):
# token id 13013 => " organization"

_test_stopping(llm,
stop_token_ids=[13013],
include_in_output=False,
expected_output="VLLM is a 100% volunteer",
expected_reason=13013)
_test_stopping(
llm,
stop_token_ids=[13013],
include_in_output=False,
expected_output="VLLM is a 100% volunteer",
expected_reason=13013,
)

_test_stopping(llm,
stop_token_ids=[13013],
include_in_output=True,
expected_output="VLLM is a 100% volunteer organization",
expected_reason=13013)
_test_stopping(
llm,
stop_token_ids=[13013],
include_in_output=True,
expected_output="VLLM is a 100% volunteer organization",
expected_reason=13013,
)


@pytest.mark.skip_global_cleanup

@ -111,8 +111,7 @@ class MockSubscriber:
self.last_seq = -1
self.decoder = msgspec.msgpack.Decoder(type=decode_type)

def receive_one(self,
timeout=1000) -> Union[tuple[int, SampleBatch], None]:
def receive_one(self, timeout=1000) -> Union[tuple[int, SampleBatch], None]:
"""Receive a single message with timeout"""
if not self.sub.poll(timeout):
return None
@ -135,8 +134,7 @@ class MockSubscriber:

self.replay_sockets[socket_idx].send(start_seq.to_bytes(8, "big"))

def receive_replay(self,
socket_idx: int = 0) -> list[tuple[int, SampleBatch]]:
def receive_replay(self, socket_idx: int = 0) -> list[tuple[int, SampleBatch]]:
"""Receive replayed messages from a specific replay socket"""
if not self.replay_sockets:
raise ValueError("Replay sockets not initialized")

@ -12,7 +12,8 @@ import torch.distributed as dist

from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
from vllm.distributed.device_communicators.custom_all_reduce import ( # noqa
CustomAllreduce)
CustomAllreduce,
)

# create a cpu process group for communicating metadata (ipc handle)
dist.init_process_group(backend="gloo")
@ -52,7 +53,8 @@ for p in pointers:
assert ord(host_data[i]) == byte_value, (
f"Rank {rank} failed"
f" to verify buffer {p}. Expected {byte_value}, "
f"got {ord(host_data[i])}")
f"got {ord(host_data[i])}"
)

print(f"Rank {rank} verified all buffers")


@ -13,13 +13,19 @@ import pytest
import ray
import torch

from vllm.distributed import (broadcast_tensor_dict, get_pp_group,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
tensor_model_parallel_reduce_scatter)
from vllm.distributed import (
broadcast_tensor_dict,
get_pp_group,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
tensor_model_parallel_reduce_scatter,
)

from ..utils import (init_test_distributed_environment, multi_gpu_test,
multi_process_parallel)
from ..utils import (
init_test_distributed_environment,
multi_gpu_test,
multi_process_parallel,
)


@ray.remote(num_gpus=1, max_calls=1)
@ -37,12 +43,11 @@ def all_reduce_test_worker(

device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port)
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
num_elements = 8
all_tensors = [
torch.arange(num_elements, dtype=torch.float32, device="cuda") *
(r + 1) for r in range(tp_size)
torch.arange(num_elements, dtype=torch.float32, device="cuda") * (r + 1)
for r in range(tp_size)
]
expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
t = all_tensors[rank % tp_size]
@ -51,28 +56,31 @@ def all_reduce_test_worker(


@ray.remote(num_gpus=1, max_calls=1)
def reduce_scatter_test_worker(monkeypatch: pytest.MonkeyPatch, tp_size: int,
pp_size: int, rank: int,
distributed_init_port: str):
def reduce_scatter_test_worker(
monkeypatch: pytest.MonkeyPatch,
tp_size: int,
pp_size: int,
rank: int,
distributed_init_port: str,
):
# it is important to delete the CUDA_VISIBLE_DEVICES environment variable
# so that each worker can see all the GPUs
# they will be able to set the device to the correct GPU
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port)
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)

num_elements = 8
all_tensors = [
torch.arange(num_elements, dtype=torch.float32, device="cuda") *
(r + 1) for r in range(tp_size)
torch.arange(num_elements, dtype=torch.float32, device="cuda") * (r + 1)
for r in range(tp_size)
]

index = rank % tp_size
partition_size = num_elements // tp_size
all_reduce = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
expected = all_reduce[index * partition_size:(index + 1) * partition_size]
expected = all_reduce[index * partition_size : (index + 1) * partition_size]
t = all_tensors[index]
t = tensor_model_parallel_reduce_scatter(t, 0)
torch.testing.assert_close(t, expected)
@ -92,8 +100,7 @@ def all_gather_test_worker(
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port)
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
num_dimensions = 3
tensor_size = list(range(2, num_dimensions + 2))
total_size = 1
@ -101,8 +108,10 @@ def all_gather_test_worker(
total_size *= s
for all_gather_dimension in range(num_dimensions):
all_tensors = [
torch.arange(total_size, dtype=torch.float32,
device="cuda").reshape(tensor_size) * (r + 1)
torch.arange(total_size, dtype=torch.float32, device="cuda").reshape(
tensor_size
)
* (r + 1)
for r in range(tp_size)
]
expected = torch.cat(all_tensors, dim=all_gather_dimension)
@ -125,8 +134,7 @@ def broadcast_tensor_dict_test_worker(
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port)
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
test_dict = {
# device tensor
"a": torch.arange(8, dtype=torch.float32, device="cuda"),
@ -134,10 +142,7 @@ def broadcast_tensor_dict_test_worker(
"b": torch.arange(16, dtype=torch.int8, device="cpu"),
"c": "test",
"d": [1, 2, 3],
"e": {
"a": 1,
"b": 2
},
"e": {"a": 1, "b": 2},
# empty tensor
"f": torch.tensor([], dtype=torch.float32, device="cuda"),
}
@ -166,8 +171,7 @@ def send_recv_tensor_dict_test_worker(
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port)
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)

test_dict = {
# device tensor
@ -176,10 +180,7 @@ def send_recv_tensor_dict_test_worker(
"b": torch.arange(16, dtype=torch.int8, device="cpu"),
"c": "test",
"d": [1, 2, 3],
"e": {
"a": 1,
"b": 2
},
"e": {"a": 1, "b": 2},
# empty tensor
"f": torch.tensor([], dtype=torch.float32, device="cuda"),
}
@ -211,8 +212,7 @@ def send_recv_test_worker(
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port)
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)

size = 64
test_tensor = torch.arange(64, dtype=torch.float32, device="cuda")
@ -229,10 +229,10 @@ def send_recv_test_worker(

@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("test_target", [
all_reduce_test_worker, all_gather_test_worker,
broadcast_tensor_dict_test_worker
])
@pytest.mark.parametrize(
"test_target",
[all_reduce_test_worker, all_gather_test_worker, broadcast_tensor_dict_test_worker],
)
def test_multi_process_tensor_parallel(
monkeypatch: pytest.MonkeyPatch,
tp_size: int,
@ -244,7 +244,8 @@ def test_multi_process_tensor_parallel(
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("pp_size", [2])
@pytest.mark.parametrize(
"test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker])
"test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker]
)
def test_multi_process_pipeline_parallel(
monkeypatch: pytest.MonkeyPatch,
pp_size: int,
@ -256,11 +257,16 @@ def test_multi_process_pipeline_parallel(
@multi_gpu_test(num_gpus=4)
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("pp_size", [2])
@pytest.mark.parametrize("test_target", [
send_recv_test_worker, send_recv_tensor_dict_test_worker,
all_reduce_test_worker, all_gather_test_worker,
broadcast_tensor_dict_test_worker
])
@pytest.mark.parametrize(
"test_target",
[
send_recv_test_worker,
send_recv_tensor_dict_test_worker,
all_reduce_test_worker,
all_gather_test_worker,
broadcast_tensor_dict_test_worker,
],
)
def test_multi_process_tensor_parallel_pipeline_parallel(
tp_size: int,
pp_size: int,

@ -7,6 +7,7 @@ WARNING: This test runs in both single-node (4 GPUs) and multi-node
all workers in a node other than the head node, which can cause the test
to fail.
"""

import json
import os
from dataclasses import dataclass
@ -56,7 +57,8 @@ class CPTestSettings:
raise ValueError(
f"Length mismatch: distributed_backends "
f"({len(self.distributed_backends)}) != "
f"vllm_major_versions ({len(self.vllm_major_versions)})")
f"vllm_major_versions ({len(self.vllm_major_versions)})"
)

@staticmethod
def detailed(
@ -74,29 +76,39 @@ class CPTestSettings:
for dcp_multiplier in [0.5, 1]:
for chunked_prefill_val in [True]:
parallel_setups.append(
ParallelSetup(tp_size=tp_base,
pp_size=pp_multiplier * pp_base,
dcp_size=int(dcp_multiplier *
tp_base),
eager_mode=eager_mode_val,
chunked_prefill=chunked_prefill_val))
ParallelSetup(
tp_size=tp_base,
pp_size=pp_multiplier * pp_base,
dcp_size=int(dcp_multiplier * tp_base),
eager_mode=eager_mode_val,
chunked_prefill=chunked_prefill_val,
)
)
return CPTestSettings(
parallel_setups=parallel_setups,
distributed_backends=["mp"],
vllm_major_versions=["1"],
runner=runner,
test_options=CPTestOptions(multi_node_only=multi_node_only,
load_format=load_format),
test_options=CPTestOptions(
multi_node_only=multi_node_only, load_format=load_format
),
)

def iter_params(self, model_id: str):
opts = self.test_options

for parallel_setup in self.parallel_setups:
for backend, vllm_major_version in zip(self.distributed_backends,
self.vllm_major_versions):
yield (model_id, parallel_setup, backend, vllm_major_version,
self.runner, opts)
for backend, vllm_major_version in zip(
self.distributed_backends, self.vllm_major_versions
):
yield (
model_id,
parallel_setup,
backend,
vllm_major_version,
self.runner,
opts,
)


def _compare_cp_with_tp(
@ -148,8 +160,10 @@ def _compare_cp_with_tp(
if num_gpus_available < tp_size * pp_size:
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
if VLLM_MULTI_NODE and distributed_backend == "mp":
pytest.skip("Skipping multi-node pipeline parallel test for "
"multiprocessing distributed backend")
pytest.skip(
"Skipping multi-node pipeline parallel test for "
"multiprocessing distributed backend"
)
if multi_node_only and not VLLM_MULTI_NODE:
pytest.skip("Not in multi-node setting")

@ -178,8 +192,7 @@ def _compare_cp_with_tp(
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])

cp_env = tp_env = {
"VLLM_USE_V1":
vllm_major_version, # Note(hc): DCP only support V1 engine only
"VLLM_USE_V1": vllm_major_version, # Note(hc): DCP only support V1 engine only
}

cp_args = [
@ -205,13 +218,15 @@ def _compare_cp_with_tp(
]

try:
compare_two_settings(model_id,
cp_args,
tp_args,
cp_env,
tp_env,
method=method,
max_wait_seconds=720)
compare_two_settings(
model_id,
cp_args,
tp_args,
cp_env,
tp_env,
method=method,
max_wait_seconds=720,
)
except Exception:
testing_ray_compiled_graph = cp_env is not None
if testing_ray_compiled_graph and vllm_major_version == "0":
@ -224,9 +239,10 @@ def _compare_cp_with_tp(

CP_TEXT_GENERATION_MODELS = {
# [MLA attention only]
"deepseek-ai/DeepSeek-V2-Lite-Chat":
[CPTestSettings.detailed(),
CPTestSettings.detailed(tp_base=2)],
"deepseek-ai/DeepSeek-V2-Lite-Chat": [
CPTestSettings.detailed(),
CPTestSettings.detailed(tp_base=2),
],
}

CP_TEST_MODELS = [
@ -237,11 +253,19 @@ CP_TEST_MODELS = [


@pytest.mark.parametrize(
("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
"runner", "test_options"),
(
"model_id",
"parallel_setup",
"distributed_backend",
"vllm_major_version",
"runner",
"test_options",
),
[
params for model_id, settings in CP_TEXT_GENERATION_MODELS.items()
for setting in settings for params in setting.iter_params(model_id)
params
for model_id, settings in CP_TEXT_GENERATION_MODELS.items()
for setting in settings
for params in setting.iter_params(model_id)
if model_id in CP_TEST_MODELS
],
)
@ -255,12 +279,14 @@ def test_cp_generation(
test_options: CPTestOptions,
num_gpus_available,
):
_compare_cp_with_tp(model_id,
parallel_setup,
distributed_backend,
vllm_major_version,
runner,
test_options,
num_gpus_available,
method="generate",
is_multimodal=False)
_compare_cp_with_tp(
model_id,
parallel_setup,
distributed_backend,
vllm_major_version,
runner,
test_options,
num_gpus_available,
method="generate",
is_multimodal=False,
)

@ -8,12 +8,14 @@ import ray
import torch
import torch.distributed as dist

from vllm.distributed.communication_op import ( # noqa
tensor_model_parallel_all_reduce)
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa
from vllm.distributed.parallel_state import get_tp_group, graph_capture

from ..utils import (ensure_model_parallel_initialized,
init_test_distributed_environment, multi_process_parallel)
from ..utils import (
ensure_model_parallel_initialized,
init_test_distributed_environment,
multi_process_parallel,
)

random.seed(42)
test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)]
@ -33,8 +35,7 @@ def graph_allreduce(
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port)
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
ensure_model_parallel_initialized(tp_size, pp_size)
group = get_tp_group().device_group

@ -60,18 +61,15 @@ def graph_allreduce(
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
with graph_capture(device=device) as graph_capture_context:
# use integers so result matches NCCL exactly
inp1 = torch.randint(1,
16, (sz, ),
dtype=dtype,
device=torch.cuda.current_device())
inp2 = torch.randint(1,
16, (sz, ),
dtype=dtype,
device=torch.cuda.current_device())
inp1 = torch.randint(
1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
)
inp2 = torch.randint(
1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
)
torch.cuda.synchronize()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph,
stream=graph_capture_context.stream):
with torch.cuda.graph(graph, stream=graph_capture_context.stream):
for i in range(num_communication):
out1 = tensor_model_parallel_all_reduce(inp1)
# the input buffer is immediately modified to test
@ -96,8 +94,7 @@ def eager_allreduce(
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port)
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)

# we use the first group to communicate once
# and the second group to communicate twice
@ -132,5 +129,4 @@ def test_custom_allreduce(
world_size = tp_size * pipeline_parallel_size
if world_size > torch.cuda.device_count():
pytest.skip("Not enough GPUs to run the test.")
multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size,
test_target)
multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, test_target)

@ -1,8 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from ..entrypoints.openai.test_oot_registration import (
|
||||
run_and_test_dummy_opt_api_server)
|
||||
from ..entrypoints.openai.test_oot_registration import run_and_test_dummy_opt_api_server
|
||||
|
||||
|
||||
def test_distributed_oot(dummy_opt_path: str):
|
||||
|
||||
@ -10,10 +10,12 @@ from vllm.distributed.eplb.rebalance_algo import rebalance_experts
|
||||
def test_basic_rebalance():
|
||||
"""Test basic rebalancing functionality"""
|
||||
# Example from https://github.com/deepseek-ai/eplb
|
||||
weight = torch.tensor([
|
||||
[90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86],
|
||||
[20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27],
|
||||
])
|
||||
weight = torch.tensor(
|
||||
[
|
||||
[90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86],
|
||||
[20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27],
|
||||
]
|
||||
)
|
||||
|
||||
num_layers = weight.shape[0]
|
||||
num_replicas = 16
|
||||
@ -21,45 +23,49 @@ def test_basic_rebalance():
|
||||
num_nodes = 2
|
||||
num_gpus = 8
|
||||
|
||||
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
|
||||
num_groups, num_nodes,
|
||||
num_gpus)
|
||||
phy2log, log2phy, logcnt = rebalance_experts(
|
||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||
)
|
||||
|
||||
# Verify output shapes
|
||||
assert phy2log.shape == (
|
||||
2,
|
||||
16,
|
||||
), f"Expected `phy2log` shape (2, 16), got {phy2log.shape}"
|
||||
assert (log2phy.shape[0] == 2
|
||||
), f"Expected `log2phy` first dimension 2, got {log2phy.shape[0]}"
|
||||
assert (
|
||||
log2phy.shape[1] == 12
|
||||
), f"Expected `log2phy` second dimension 12, got {log2phy.shape[1]}"
|
||||
assert log2phy.shape[0] == 2, (
|
||||
f"Expected `log2phy` first dimension 2, got {log2phy.shape[0]}"
|
||||
)
|
||||
assert log2phy.shape[1] == 12, (
|
||||
f"Expected `log2phy` second dimension 12, got {log2phy.shape[1]}"
|
||||
)
|
||||
assert logcnt.shape == (
|
||||
2,
|
||||
12,
|
||||
), f"Expected `logcnt` shape (2, 12), got {logcnt.shape}"
|
||||
|
||||
# Verify physical to logical expert mapping range is correct
|
||||
assert torch.all(phy2log >= 0) and torch.all(
|
||||
phy2log < 12), "Physical to logical mapping should be in range [0, 12)"
|
||||
assert torch.all(phy2log >= 0) and torch.all(phy2log < 12), (
|
||||
"Physical to logical mapping should be in range [0, 12)"
|
||||
)
|
||||
|
||||
# Verify expert count reasonableness
|
||||
assert torch.all(
|
||||
logcnt >= 1), "Each logical expert should have at least 1 replica"
|
||||
assert (
|
||||
torch.sum(logcnt, dim=1).sum() == num_replicas *
|
||||
num_layers), f"Total replicas should be {num_replicas * num_layers}"
|
||||
assert torch.all(logcnt >= 1), "Each logical expert should have at least 1 replica"
|
||||
assert torch.sum(logcnt, dim=1).sum() == num_replicas * num_layers, (
|
||||
f"Total replicas should be {num_replicas * num_layers}"
|
||||
)
|
||||
|
||||
# Verify expected output
|
||||
expected_phy2log = torch.tensor([
|
||||
[5, 6, 5, 7, 8, 4, 3, 4, 10, 9, 10, 2, 0, 1, 11, 1],
|
||||
[7, 10, 6, 8, 6, 11, 8, 9, 2, 4, 5, 1, 5, 0, 3, 1],
|
||||
])
|
||||
expected_phy2log = torch.tensor(
|
||||
[
|
||||
[5, 6, 5, 7, 8, 4, 3, 4, 10, 9, 10, 2, 0, 1, 11, 1],
|
||||
[7, 10, 6, 8, 6, 11, 8, 9, 2, 4, 5, 1, 5, 0, 3, 1],
|
||||
]
|
||||
)
|
||||
assert torch.all(phy2log == expected_phy2log)
|
||||
|
||||
expected_logcnt = torch.tensor([[1, 2, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1],
|
||||
[1, 2, 1, 1, 1, 2, 2, 1, 2, 1, 1, 1]])
|
||||
expected_logcnt = torch.tensor(
|
||||
[[1, 2, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1], [1, 2, 1, 1, 1, 2, 2, 1, 2, 1, 1, 1]]
|
||||
)
|
||||
assert torch.all(logcnt == expected_logcnt)
|
||||
|
||||
|
||||
@ -71,9 +77,9 @@ def test_single_gpu_case():
|
||||
num_nodes = 1
|
||||
num_gpus = 1
|
||||
|
||||
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
|
||||
num_groups, num_nodes,
|
||||
num_gpus)
|
||||
phy2log, log2phy, logcnt = rebalance_experts(
|
||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||
)
|
||||
|
||||
# Verify shapes
|
||||
assert phy2log.shape == (1, 4)
|
||||
@ -93,19 +99,19 @@ def test_equal_weights():
|
||||
num_nodes = 2
|
||||
num_gpus = 4
|
||||
|
||||
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
|
||||
num_groups, num_nodes,
|
||||
num_gpus)
|
||||
phy2log, log2phy, logcnt = rebalance_experts(
|
||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||
)
|
||||
|
||||
# Verify shapes
|
||||
assert phy2log.shape == (1, 8)
|
||||
assert logcnt.shape == (1, 8)
|
||||
|
||||
# With equal weights, each expert should have exactly one replica
|
||||
assert torch.all(
|
||||
logcnt == 1
|
||||
), "With equal weights and no replication, " \
|
||||
"each expert should have exactly 1 replica"
|
||||
assert torch.all(logcnt == 1), (
|
||||
"With equal weights and no replication, "
|
||||
"each expert should have exactly 1 replica"
|
||||
)
|
||||
|
||||
|
||||
def test_extreme_weight_imbalance():
|
||||
@ -116,35 +122,37 @@ def test_extreme_weight_imbalance():
|
||||
num_nodes = 2
|
||||
num_gpus = 4
|
||||
|
||||
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
|
||||
num_groups, num_nodes,
|
||||
num_gpus)
|
||||
phy2log, log2phy, logcnt = rebalance_experts(
|
||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||
)
|
||||
|
||||
# Verify shapes
|
||||
assert phy2log.shape == (1, 12)
|
||||
assert logcnt.shape == (1, 8)
|
||||
|
||||
# Expert with highest weight (index 0) should have more replicas
|
||||
assert (
|
||||
logcnt[0, 0]
|
||||
> logcnt[0, 1]), "Expert with highest weight should have more replicas"
|
||||
assert logcnt[0, 0] > logcnt[0, 1], (
|
||||
"Expert with highest weight should have more replicas"
|
||||
)
|
||||
|
||||
|
||||
def test_multiple_layers():
|
||||
"""Test multiple layers case"""
|
||||
weight = torch.tensor([
|
||||
[10, 20, 30, 40, 50, 60], # First layer
|
||||
[60, 50, 40, 30, 20, 10], # Second layer (opposite weight pattern)
|
||||
[25, 25, 25, 25, 25, 25], # Third layer (equal weights)
|
||||
])
|
||||
weight = torch.tensor(
|
||||
[
|
||||
[10, 20, 30, 40, 50, 60], # First layer
|
||||
[60, 50, 40, 30, 20, 10], # Second layer (opposite weight pattern)
|
||||
[25, 25, 25, 25, 25, 25], # Third layer (equal weights)
|
||||
]
|
||||
)
|
||||
num_replicas = 8
|
||||
num_groups = 2
|
||||
num_nodes = 2
|
||||
num_gpus = 4
|
||||
|
||||
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
|
||||
num_groups, num_nodes,
|
||||
num_gpus)
|
||||
phy2log, log2phy, logcnt = rebalance_experts(
|
||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||
)
|
||||
|
||||
# Verify shapes
|
||||
assert phy2log.shape == (3, 8)
|
||||
@ -152,12 +160,12 @@ def test_multiple_layers():
|
||||
|
||||
# Verify expert allocation is reasonable for each layer
|
||||
for layer in range(3):
|
||||
assert torch.all(phy2log[layer] >= 0) and torch.all(
|
||||
phy2log[layer] < 6
|
||||
), f"Layer {layer} physical to logical mapping" \
|
||||
"should be in range [0, 6)"
|
||||
assert (torch.sum(logcnt[layer]) == num_replicas
|
||||
), f"Layer {layer} total replicas should be {num_replicas}"
|
||||
assert torch.all(phy2log[layer] >= 0) and torch.all(phy2log[layer] < 6), (
|
||||
f"Layer {layer} physical to logical mappingshould be in range [0, 6)"
|
||||
)
|
||||
assert torch.sum(logcnt[layer]) == num_replicas, (
|
||||
f"Layer {layer} total replicas should be {num_replicas}"
|
||||
)
|
||||
|
||||
|
||||
def test_parameter_validation():
|
||||
@ -179,17 +187,19 @@ def test_parameter_validation():
|
||||
|
||||
def test_small_scale_hierarchical():
|
||||
"""Test small-scale hierarchical load balancing"""
|
||||
weight = torch.tensor([
|
||||
[100, 50, 200, 75, 150, 25, 300, 80], # 8 experts
|
||||
])
|
||||
weight = torch.tensor(
|
||||
[
|
||||
[100, 50, 200, 75, 150, 25, 300, 80], # 8 experts
|
||||
]
|
||||
)
|
||||
num_replicas = 12
|
||||
num_groups = 4 # 4 groups, 2 experts each
|
||||
num_nodes = 2 # 2 nodes
|
||||
num_gpus = 4 # 4 GPUs
|
||||
|
||||
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
|
||||
num_groups, num_nodes,
|
||||
num_gpus)
|
||||
phy2log, log2phy, logcnt = rebalance_experts(
|
||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||
)
|
||||
|
||||
# Verify basic constraints
|
||||
assert phy2log.shape == (1, 12)
|
||||
@ -199,8 +209,9 @@ def test_small_scale_hierarchical():
|
||||
|
||||
# Expert with highest weight should have more replicas
|
||||
max_weight_expert = torch.argmax(weight[0])
|
||||
assert (logcnt[0, max_weight_expert]
|
||||
>= 2), "Highest weight expert should have multiple replicas"
|
||||
assert logcnt[0, max_weight_expert] >= 2, (
|
||||
"Highest weight expert should have multiple replicas"
|
||||
)
|
||||
|
||||
|
||||
def test_global_load_balance_fallback():
|
||||
@ -213,9 +224,9 @@ def test_global_load_balance_fallback():
|
||||
num_nodes = 2
|
||||
num_gpus = 4
|
||||
|
||||
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
|
||||
num_groups, num_nodes,
|
||||
num_gpus)
|
||||
phy2log, log2phy, logcnt = rebalance_experts(
|
||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||
)
|
||||
|
||||
# Should work normally, just using global load balancing strategy
|
||||
assert phy2log.shape == (1, 8)
|
||||
@ -235,9 +246,9 @@ def test_device_compatibility(device):
|
||||
num_nodes = 1
|
||||
num_gpus = 2
|
||||
|
||||
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
|
||||
num_groups, num_nodes,
|
||||
num_gpus)
|
||||
phy2log, log2phy, logcnt = rebalance_experts(
|
||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||
)
|
||||
|
||||
# Function will convert to CPU internally, but should handle different
|
||||
# device inputs normally
|
||||
@ -250,7 +261,8 @@ def test_additional_cases():
|
||||
|
||||
# Test case 1: Large-scale distributed setup
|
||||
weight1 = torch.tensor(
|
||||
[[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]])
|
||||
[[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]]
|
||||
)
|
||||
phy2log1, log2phy1, logcnt1 = rebalance_experts(weight1, 24, 8, 4, 8)
|
||||
|
||||
assert phy2log1.shape == (1, 24)
|
||||
@ -258,10 +270,12 @@ def test_additional_cases():
|
||||
assert torch.sum(logcnt1) == 24
|
||||
|
||||
# Test case 2: Different weight distributions
|
||||
weight2 = torch.tensor([
|
||||
[200, 150, 100, 50, 25, 12], # Decreasing weights
|
||||
[12, 25, 50, 100, 150, 200], # Increasing weights
|
||||
])
|
||||
weight2 = torch.tensor(
|
||||
[
|
||||
[200, 150, 100, 50, 25, 12], # Decreasing weights
|
||||
[12, 25, 50, 100, 150, 200], # Increasing weights
|
||||
]
|
||||
)
|
||||
phy2log2, log2phy2, logcnt2 = rebalance_experts(weight2, 10, 3, 1, 2)
|
||||
|
||||
assert phy2log2.shape == (2, 10)
|
||||
@ -274,19 +288,21 @@ def test_additional_cases():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
weight = torch.tensor([
|
||||
[90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86],
|
||||
[20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27],
|
||||
])
|
||||
weight = torch.tensor(
|
||||
[
|
||||
[90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86],
|
||||
[20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27],
|
||||
]
|
||||
)
|
||||
|
||||
num_replicas = 16
|
||||
num_groups = 4
|
||||
num_nodes = 2
|
||||
num_gpus = 8
|
||||
|
||||
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
|
||||
num_groups, num_nodes,
|
||||
num_gpus)
|
||||
phy2log, log2phy, logcnt = rebalance_experts(
|
||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||
)
|
||||
print(phy2log)
|
||||
|
||||
test_basic_rebalance()
|
||||
|
||||
@ -9,11 +9,12 @@ import pytest
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from vllm.distributed.eplb.rebalance_execute import (
|
||||
rearrange_expert_weights_inplace)
|
||||
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
|
||||
get_tp_group,
|
||||
init_distributed_environment)
|
||||
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
|
||||
from vllm.distributed.parallel_state import (
|
||||
ensure_model_parallel_initialized,
|
||||
get_tp_group,
|
||||
init_distributed_environment,
|
||||
)
|
||||
from vllm.utils import update_environment_variables
|
||||
|
||||
|
||||
@ -22,13 +23,13 @@ def distributed_run(fn, world_size):
|
||||
processes: list[multiprocessing.Process] = []
|
||||
for i in range(number_of_processes):
|
||||
env: dict[str, str] = {}
|
||||
env['RANK'] = str(i)
|
||||
env['LOCAL_RANK'] = str(i)
|
||||
env['WORLD_SIZE'] = str(number_of_processes)
|
||||
env['LOCAL_WORLD_SIZE'] = str(number_of_processes)
|
||||
env['MASTER_ADDR'] = 'localhost'
|
||||
env['MASTER_PORT'] = '12345'
|
||||
p = multiprocessing.Process(target=fn, args=(env, ))
|
||||
env["RANK"] = str(i)
|
||||
env["LOCAL_RANK"] = str(i)
|
||||
env["WORLD_SIZE"] = str(number_of_processes)
|
||||
env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
|
||||
env["MASTER_ADDR"] = "localhost"
|
||||
env["MASTER_PORT"] = "12345"
|
||||
p = multiprocessing.Process(target=fn, args=(env,))
|
||||
processes.append(p)
|
||||
p.start()
|
||||
|
||||
@ -45,7 +46,7 @@ def worker_fn_wrapper(fn):
|
||||
# and update the environment variables in the function
|
||||
def wrapped_fn(env):
|
||||
update_environment_variables(env)
|
||||
local_rank = os.environ['LOCAL_RANK']
|
||||
local_rank = os.environ["LOCAL_RANK"]
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
torch.cuda.set_device(device)
|
||||
init_distributed_environment()
|
||||
@ -60,10 +61,10 @@ def worker_fn_wrapper(fn):
|
||||
|
||||
|
||||
def create_expert_indices_with_redundancy(
|
||||
num_layers: int,
|
||||
num_logical_experts: int,
|
||||
total_physical_experts: int,
|
||||
redundancy_config: list[int], # redundancy for each logical expert
|
||||
num_layers: int,
|
||||
num_logical_experts: int,
|
||||
total_physical_experts: int,
|
||||
redundancy_config: list[int], # redundancy for each logical expert
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Create expert indices with redundancy.
|
||||
@ -120,27 +121,27 @@ def create_expert_weights(
|
||||
for layer in range(num_layers):
|
||||
layer_weights = []
|
||||
for weight_idx, hidden_size in enumerate(hidden_sizes):
|
||||
weight_tensor = torch.zeros(num_local_experts,
|
||||
hidden_size,
|
||||
device=device,
|
||||
dtype=torch.float32)
|
||||
weight_tensor = torch.zeros(
|
||||
num_local_experts, hidden_size, device=device, dtype=torch.float32
|
||||
)
|
||||
|
||||
for local_expert in range(num_local_experts):
|
||||
# Get the logical expert ID for this physical expert
|
||||
global_pos = rank * num_local_experts + local_expert
|
||||
logical_expert_id = physical_to_logical_mapping[
|
||||
layer, global_pos].item()
|
||||
layer, global_pos
|
||||
].item()
|
||||
|
||||
# Generate weights based on logical expert ID
|
||||
# (so that all replicas of the same logical expert have the
|
||||
# same weights)
|
||||
base_value = (logical_expert_id * 1000 + layer * 100 +
|
||||
weight_idx * 10)
|
||||
weight_tensor[local_expert] = torch.arange(base_value,
|
||||
base_value +
|
||||
hidden_size,
|
||||
device=device,
|
||||
dtype=torch.float32)
|
||||
base_value = logical_expert_id * 1000 + layer * 100 + weight_idx * 10
|
||||
weight_tensor[local_expert] = torch.arange(
|
||||
base_value,
|
||||
base_value + hidden_size,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
layer_weights.append(weight_tensor)
|
||||
expert_weights.append(layer_weights)
|
||||
@ -182,12 +183,15 @@ def verify_expert_weights_after_shuffle(
|
||||
|
||||
# Check if the weights are correct
|
||||
actual_weights = weight_tensor[local_expert]
|
||||
expected_base = (expected_logical_expert * 1000 + layer * 100 +
|
||||
weight_idx * 10)
|
||||
expected_weights = torch.arange(expected_base,
|
||||
expected_base + hidden_size,
|
||||
device=actual_weights.device,
|
||||
dtype=actual_weights.dtype)
|
||||
expected_base = (
|
||||
expected_logical_expert * 1000 + layer * 100 + weight_idx * 10
|
||||
)
|
||||
expected_weights = torch.arange(
|
||||
expected_base,
|
||||
expected_base + hidden_size,
|
||||
device=actual_weights.device,
|
||||
dtype=actual_weights.dtype,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(
|
||||
actual_weights,
|
||||
@ -195,7 +199,8 @@ def verify_expert_weights_after_shuffle(
|
||||
msg=f"Layer {layer}, weight {weight_idx},"
|
||||
f"local expert {local_expert}: "
|
||||
f"weights do not match. "
|
||||
f"Expected logical expert {expected_logical_expert}")
|
||||
f"Expected logical expert {expected_logical_expert}",
|
||||
)
|
||||
|
||||
|
||||
def verify_redundant_experts_have_same_weights(
|
||||
@ -222,23 +227,23 @@ def verify_redundant_experts_have_same_weights(
|
||||
total_physical_experts,
|
||||
hidden_size,
|
||||
device=expert_weights[layer][weight_idx].device,
|
||||
dtype=expert_weights[layer][weight_idx].dtype)
|
||||
dtype=expert_weights[layer][weight_idx].dtype,
|
||||
)
|
||||
|
||||
# Use all_gather to collect expert weights from current node
|
||||
# expert_weights[layer][weight_idx] shape:
|
||||
# [num_local_experts, hidden_size]
|
||||
local_weights = expert_weights[layer][
|
||||
weight_idx] # [num_local_experts, hidden_size]
|
||||
weight_idx
|
||||
] # [num_local_experts, hidden_size]
|
||||
|
||||
# Split tensor along dim 0 into a list for all_gather
|
||||
gathered_weights_list = torch.chunk(gathered_weights,
|
||||
world_size,
|
||||
dim=0)
|
||||
gathered_weights_list = torch.chunk(gathered_weights, world_size, dim=0)
|
||||
|
||||
torch.distributed.all_gather(
|
||||
# Output list: each element corresponds to one rank's weights
|
||||
list(gathered_weights_list),
|
||||
local_weights # Input: current rank's local weights
|
||||
local_weights, # Input: current rank's local weights
|
||||
)
|
||||
|
||||
all_weights.append(gathered_weights)
|
||||
@ -266,7 +271,8 @@ def verify_redundant_experts_have_same_weights(
|
||||
msg=f"Layer {layer}, weight {weight_idx},"
|
||||
f"logical expert {logical_expert_id}: "
|
||||
f"Physical expert {physical_pos} has different weights"
|
||||
f"than expected")
|
||||
f"than expected",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -290,10 +296,11 @@ def verify_redundant_experts_have_same_weights(
|
||||
# 4 GPU, 8 experts per GPU
|
||||
# 16 logical experts, 32 physical experts, 16 redundant experts
|
||||
(4, 8, 8, 16),
|
||||
])
|
||||
def test_rearrange_expert_weights_with_redundancy(world_size, num_layers,
|
||||
num_local_experts,
|
||||
num_logical_experts):
|
||||
],
|
||||
)
|
||||
def test_rearrange_expert_weights_with_redundancy(
|
||||
world_size, num_layers, num_local_experts, num_logical_experts
|
||||
):
|
||||
"""Test the functionality of rearranging expert weights with redundancy."""
|
||||
|
||||
if torch.cuda.device_count() < world_size:
|
||||
@ -304,8 +311,8 @@ def test_rearrange_expert_weights_with_redundancy(world_size, num_layers,
|
||||
# Initialize model parallel (using tensor parallel as an entrypoint
|
||||
# to expert parallel)
|
||||
ensure_model_parallel_initialized(
|
||||
tensor_model_parallel_size=world_size,
|
||||
pipeline_model_parallel_size=1)
|
||||
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
||||
)
|
||||
|
||||
ep_group = get_tp_group().cpu_group
|
||||
ep_rank = torch.distributed.get_rank()
|
||||
@ -316,8 +323,9 @@ def test_rearrange_expert_weights_with_redundancy(world_size, num_layers,
|
||||
hidden_sizes = [32, 64] # Two different weight matrices
|
||||
|
||||
# Create old expert indices (with redundancy)
|
||||
redundancy_config = create_redundancy_config(num_logical_experts,
|
||||
total_physical_experts)
|
||||
redundancy_config = create_redundancy_config(
|
||||
num_logical_experts, total_physical_experts
|
||||
)
|
||||
|
||||
old_indices = create_expert_indices_with_redundancy(
|
||||
num_layers,
|
||||
@ -328,7 +336,8 @@ def test_rearrange_expert_weights_with_redundancy(world_size, num_layers,
|
||||
|
||||
# Create new expert indices (with redundancy)
|
||||
new_redundancy_config = create_redundancy_config(
|
||||
num_logical_experts, total_physical_experts)
|
||||
num_logical_experts, total_physical_experts
|
||||
)
|
||||
new_indices = create_expert_indices_with_redundancy(
|
||||
num_layers,
|
||||
num_logical_experts,
|
||||
@ -337,9 +346,9 @@ def test_rearrange_expert_weights_with_redundancy(world_size, num_layers,
|
||||
)
|
||||
|
||||
# Create expert weights
|
||||
expert_weights = create_expert_weights(num_layers, num_local_experts,
|
||||
hidden_sizes, ep_rank, device,
|
||||
old_indices)
|
||||
expert_weights = create_expert_weights(
|
||||
num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
|
||||
)
|
||||
|
||||
# Execute weight rearrangement
|
||||
rearrange_expert_weights_inplace(
|
||||
@ -383,8 +392,8 @@ def test_rearrange_expert_weights_no_change(world_size):
|
||||
@worker_fn_wrapper
|
||||
def worker_fn():
|
||||
ensure_model_parallel_initialized(
|
||||
tensor_model_parallel_size=world_size,
|
||||
pipeline_model_parallel_size=1)
|
||||
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
||||
)
|
||||
|
||||
ep_group = get_tp_group().cpu_group
|
||||
ep_rank = torch.distributed.get_rank()
|
||||
@ -401,12 +410,12 @@ def test_rearrange_expert_weights_no_change(world_size):
|
||||
|
||||
# Same indices - no change
|
||||
indices = create_expert_indices_with_redundancy(
|
||||
num_layers, num_logical_experts, total_physical_experts,
|
||||
redundancy_config)
|
||||
num_layers, num_logical_experts, total_physical_experts, redundancy_config
|
||||
)
|
||||
|
||||
expert_weights = create_expert_weights(num_layers, num_local_experts,
|
||||
hidden_sizes, ep_rank, device,
|
||||
indices)
|
||||
expert_weights = create_expert_weights(
|
||||
num_layers, num_local_experts, hidden_sizes, ep_rank, device, indices
|
||||
)
|
||||
|
||||
# Save original weights
|
||||
original_weights = []
|
||||
@ -422,7 +431,8 @@ def test_rearrange_expert_weights_no_change(world_size):
|
||||
indices, # Same indices
|
||||
expert_weights,
|
||||
ep_group,
|
||||
is_profile=False)
|
||||
is_profile=False,
|
||||
)
|
||||
|
||||
# Verify that the weights have not changed
|
||||
for layer in range(num_layers):
|
||||
@ -430,8 +440,8 @@ def test_rearrange_expert_weights_no_change(world_size):
|
||||
torch.testing.assert_close(
|
||||
expert_weights[layer][weight_idx],
|
||||
original_weights[layer][weight_idx],
|
||||
msg=f"Layer {layer}, weight {weight_idx} should remain "
|
||||
f"unchanged")
|
||||
msg=f"Layer {layer}, weight {weight_idx} should remain unchanged",
|
||||
)
|
||||
|
||||
distributed_run(worker_fn, world_size)
|
||||
|
||||
@ -446,8 +456,8 @@ def test_rearrange_expert_weights_profile_mode(world_size):
|
||||
@worker_fn_wrapper
|
||||
def worker_fn():
|
||||
ensure_model_parallel_initialized(
|
||||
tensor_model_parallel_size=world_size,
|
||||
pipeline_model_parallel_size=1)
|
||||
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
||||
)
|
||||
|
||||
ep_group = get_tp_group().cpu_group
|
||||
ep_rank = torch.distributed.get_rank()
|
||||
@ -460,21 +470,23 @@ def test_rearrange_expert_weights_profile_mode(world_size):
|
||||
hidden_sizes = [32]
|
||||
|
||||
# Create different index distributions
|
||||
old_redundancy = create_redundancy_config(num_logical_experts,
|
||||
total_physical_experts)
|
||||
new_redundancy = create_redundancy_config(num_logical_experts,
|
||||
total_physical_experts)
|
||||
old_redundancy = create_redundancy_config(
|
||||
num_logical_experts, total_physical_experts
|
||||
)
|
||||
new_redundancy = create_redundancy_config(
|
||||
num_logical_experts, total_physical_experts
|
||||
)
|
||||
|
||||
old_indices = create_expert_indices_with_redundancy(
|
||||
num_layers, num_logical_experts, total_physical_experts,
|
||||
old_redundancy)
|
||||
num_layers, num_logical_experts, total_physical_experts, old_redundancy
|
||||
)
|
||||
new_indices = create_expert_indices_with_redundancy(
|
||||
num_layers, num_logical_experts, total_physical_experts,
|
||||
new_redundancy)
|
||||
num_layers, num_logical_experts, total_physical_experts, new_redundancy
|
||||
)
|
||||
|
||||
expert_weights = create_expert_weights(num_layers, num_local_experts,
|
||||
hidden_sizes, ep_rank, device,
|
||||
old_indices)
|
||||
expert_weights = create_expert_weights(
|
||||
num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
|
||||
)
|
||||
|
||||
# Save original weights
|
||||
original_weights = []
|
||||
@ -490,7 +502,7 @@ def test_rearrange_expert_weights_profile_mode(world_size):
|
||||
new_indices,
|
||||
expert_weights,
|
||||
ep_group,
|
||||
is_profile=True # Profile mode
|
||||
is_profile=True, # Profile mode
|
||||
)
|
||||
|
||||
# In profile mode, the weights should remain unchanged
|
||||
@ -499,6 +511,7 @@ def test_rearrange_expert_weights_profile_mode(world_size):
|
||||
torch.testing.assert_close(
|
||||
expert_weights[layer][weight_idx],
|
||||
original_weights[layer][weight_idx],
|
||||
msg="In profile mode, the weights should remain unchanged")
|
||||
msg="In profile mode, the weights should remain unchanged",
|
||||
)
|
||||
|
||||
distributed_run(worker_fn, world_size)
|
||||
|
||||
@ -6,24 +6,29 @@ import time
|
||||
import msgspec
|
||||
import pytest
|
||||
|
||||
from vllm.distributed.kv_events import (EventBatch, EventPublisherFactory,
|
||||
NullEventPublisher)
|
||||
from vllm.distributed.kv_events import (
|
||||
EventBatch,
|
||||
EventPublisherFactory,
|
||||
NullEventPublisher,
|
||||
)
|
||||
|
||||
DP_RANK = 0
|
||||
|
||||
|
||||
class EventSample(
|
||||
msgspec.Struct,
|
||||
tag=True, # type: ignore
|
||||
array_like=True # type: ignore
|
||||
msgspec.Struct,
|
||||
tag=True, # type: ignore
|
||||
array_like=True, # type: ignore
|
||||
):
|
||||
"""Test event for publisher testing"""
|
||||
|
||||
id: int
|
||||
value: str
|
||||
|
||||
|
||||
class SampleBatch(EventBatch):
|
||||
"""Test event batch for publisher testing"""
|
||||
|
||||
events: list[EventSample]
|
||||
|
||||
|
||||
@ -44,10 +49,8 @@ def test_basic_publishing(publisher, subscriber):
|
||||
|
||||
seq, received = result
|
||||
assert seq == 0, "Sequence number mismatch"
|
||||
assert received.ts == pytest.approx(test_batch.ts,
|
||||
abs=0.1), ("Timestamp mismatch")
|
||||
assert len(received.events) == len(
|
||||
test_batch.events), ("Number of events mismatch")
|
||||
assert received.ts == pytest.approx(test_batch.ts, abs=0.1), "Timestamp mismatch"
|
||||
assert len(received.events) == len(test_batch.events), "Number of events mismatch"
|
||||
|
||||
for i, event in enumerate(received.events):
|
||||
assert event.id == i, "Event id mismatch"
|
||||
@ -88,9 +91,9 @@ def test_replay_mechanism(publisher, subscriber):
|
||||
assert len(replayed) > 0, "No replayed messages received"
|
||||
seqs = [seq for seq, _ in replayed]
|
||||
assert all(seq >= 10 for seq in seqs), "Replayed messages not in order"
|
||||
assert seqs == list(range(min(seqs),
|
||||
max(seqs) +
|
||||
1)), ("Replayed messages not consecutive")
|
||||
assert seqs == list(range(min(seqs), max(seqs) + 1)), (
|
||||
"Replayed messages not consecutive"
|
||||
)
|
||||
|
||||
|
||||
def test_buffer_limit(publisher, subscriber, publisher_config):
|
||||
@ -126,6 +129,7 @@ def test_topic_filtering(publisher_config):
|
||||
pub = EventPublisherFactory.create(publisher_config, DP_RANK)
|
||||
|
||||
from .conftest import MockSubscriber
|
||||
|
||||
sub_foo = MockSubscriber(publisher_config.endpoint, None, "foo")
|
||||
sub_bar = MockSubscriber(publisher_config.endpoint, None, "bar")
|
||||
|
||||
@ -137,11 +141,13 @@ def test_topic_filtering(publisher_config):
|
||||
|
||||
foo_received = [sub_foo.receive_one(timeout=200) for _ in range(3)]
|
||||
assert all(msg is not None for msg in foo_received), (
|
||||
"Subscriber with matching topic should receive messages")
|
||||
"Subscriber with matching topic should receive messages"
|
||||
)
|
||||
|
||||
bar_received = [sub_bar.receive_one(timeout=200) for _ in range(3)]
|
||||
assert all(msg is None for msg in bar_received), (
|
||||
"Subscriber with non-matching topic should receive no messages")
|
||||
"Subscriber with non-matching topic should receive no messages"
|
||||
)
|
||||
finally:
|
||||
pub.shutdown()
|
||||
sub_foo.close()
|
||||
@ -178,8 +184,7 @@ def test_high_volume(publisher, subscriber):
|
||||
|
||||
publisher_thread.join()
|
||||
|
||||
assert len(received) >= num_batches * 0.9, (
|
||||
"We should have received most messages")
|
||||
assert len(received) >= num_batches * 0.9, "We should have received most messages"
|
||||
|
||||
seqs = [seq for seq, _ in received]
|
||||
assert sorted(seqs) == seqs, "Sequence numbers should be in order"
|
||||
@ -209,13 +214,15 @@ def test_data_parallel_rank_tagging(publisher_config):
|
||||
# For TCP endpoints: tcp://localhost:5557 -> tcp://localhost:5557, tcp://localhost:5558
|
||||
expected_endpoint_0 = base_endpoint # rank 0 gets port + 0 = same port
|
||||
expected_endpoint_1 = base_endpoint.replace(
|
||||
":5557", ":5558") # rank 1 gets port + 1
|
||||
":5557", ":5558"
|
||||
) # rank 1 gets port + 1
|
||||
else:
|
||||
# For inproc endpoints: inproc://test -> inproc://test_dp0, inproc://test_dp1
|
||||
expected_endpoint_0 = base_endpoint # rank 0 gets base
|
||||
expected_endpoint_1 = base_endpoint + "_dp1" # rank 1 gets _dp1
|
||||
|
||||
from .conftest import MockSubscriber
|
||||
|
||||
sub_0 = MockSubscriber(expected_endpoint_0, None, publisher_config.topic)
|
||||
sub_1 = MockSubscriber(expected_endpoint_1, None, publisher_config.topic)
|
||||
|
||||
@ -241,15 +248,15 @@ def test_data_parallel_rank_tagging(publisher_config):
|
||||
|
||||
# Verify DP rank tagging
|
||||
assert received_0.data_parallel_rank == 0, (
|
||||
f"Expected DP rank 0, got {received_0.data_parallel_rank}")
|
||||
f"Expected DP rank 0, got {received_0.data_parallel_rank}"
|
||||
)
|
||||
assert received_1.data_parallel_rank == 1, (
|
||||
f"Expected DP rank 1, got {received_1.data_parallel_rank}")
|
||||
f"Expected DP rank 1, got {received_1.data_parallel_rank}"
|
||||
)
|
||||
|
||||
# Verify event content is correct
|
||||
assert len(
|
||||
received_0.events) == 2, "Wrong number of events from rank 0"
|
||||
assert len(
|
||||
received_1.events) == 3, "Wrong number of events from rank 1"
|
||||
assert len(received_0.events) == 2, "Wrong number of events from rank 0"
|
||||
assert len(received_1.events) == 3, "Wrong number of events from rank 1"
|
||||
|
||||
finally:
|
||||
pub_0.shutdown()
|
||||
|
||||
@ -46,28 +46,24 @@ class EPTestSettings:
|
||||
):
|
||||
return EPTestSettings(
|
||||
parallel_setups=[
|
||||
ParallelSetup(tp_size=tp_base,
|
||||
eager_mode=False,
|
||||
chunked_prefill=False),
|
||||
ParallelSetup(tp_size=tp_base,
|
||||
eager_mode=False,
|
||||
chunked_prefill=True),
|
||||
ParallelSetup(tp_size=tp_base,
|
||||
eager_mode=True,
|
||||
chunked_prefill=False),
|
||||
ParallelSetup(tp_size=2 * tp_base,
|
||||
eager_mode=False,
|
||||
chunked_prefill=True),
|
||||
ParallelSetup(tp_size=2 * tp_base,
|
||||
eager_mode=True,
|
||||
chunked_prefill=False),
|
||||
ParallelSetup(tp_size=tp_base, eager_mode=False, chunked_prefill=False),
|
||||
ParallelSetup(tp_size=tp_base, eager_mode=False, chunked_prefill=True),
|
||||
ParallelSetup(tp_size=tp_base, eager_mode=True, chunked_prefill=False),
|
||||
ParallelSetup(
|
||||
tp_size=2 * tp_base, eager_mode=False, chunked_prefill=True
|
||||
),
|
||||
ParallelSetup(
|
||||
tp_size=2 * tp_base, eager_mode=True, chunked_prefill=False
|
||||
),
|
||||
],
|
||||
distributed_backends=["mp", "ray"],
|
||||
runner=runner,
|
||||
test_options=EPTestOptions(trust_remote_code=trust_remote_code,
|
||||
tokenizer_mode=tokenizer_mode,
|
||||
load_format=load_format,
|
||||
hf_overrides=hf_overrides),
|
||||
test_options=EPTestOptions(
|
||||
trust_remote_code=trust_remote_code,
|
||||
tokenizer_mode=tokenizer_mode,
|
||||
load_format=load_format,
|
||||
hf_overrides=hf_overrides,
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -82,16 +78,16 @@ class EPTestSettings:
|
||||
):
|
||||
return EPTestSettings(
|
||||
parallel_setups=[
|
||||
ParallelSetup(tp_size=tp_base,
|
||||
eager_mode=True,
|
||||
chunked_prefill=False),
|
||||
ParallelSetup(tp_size=tp_base, eager_mode=True, chunked_prefill=False),
|
||||
],
|
||||
distributed_backends=["mp"],
|
||||
runner=runner,
|
||||
test_options=EPTestOptions(trust_remote_code=trust_remote_code,
|
||||
tokenizer_mode=tokenizer_mode,
|
||||
load_format=load_format,
|
||||
hf_overrides=hf_overrides),
|
||||
test_options=EPTestOptions(
|
||||
trust_remote_code=trust_remote_code,
|
||||
tokenizer_mode=tokenizer_mode,
|
||||
load_format=load_format,
|
||||
hf_overrides=hf_overrides,
|
||||
),
|
||||
)
|
||||
|
||||
def iter_params(self, model_name: str):
|
||||
@ -99,8 +95,13 @@ class EPTestSettings:
|
||||
|
||||
for parallel_setup in self.parallel_setups:
|
||||
for distributed_backend in self.distributed_backends:
|
||||
yield (model_name, parallel_setup, distributed_backend,
|
||||
self.runner, opts)
|
||||
yield (
|
||||
model_name,
|
||||
parallel_setup,
|
||||
distributed_backend,
|
||||
self.runner,
|
||||
opts,
|
||||
)
|
||||
|
||||
|
||||
# NOTE: You can adjust tp_base locally to fit the model in GPU
|
||||
|
||||
@ -6,8 +6,7 @@ import pytest
from vllm.model_executor.layers.fused_moe.layer import determine_expert_map


def verify_round_robin_pattern(expert_map, ep_rank, ep_size,
global_num_experts):
def verify_round_robin_pattern(expert_map, ep_rank, ep_size, global_num_experts):
"""Verify that the expert map follows the round_robin pattern."""
# Calculate expected local experts (supporting non-divisible cases)
base_experts = global_num_experts // ep_size
@ -30,24 +29,21 @@ def verify_round_robin_pattern(expert_map, ep_rank, ep_size,
if global_expert_id in expected_expert_ids:
local_expert_id = expert_map[global_expert_id]
expected_local_id = expected_expert_ids.index(global_expert_id)
assert (
local_expert_id == expected_local_id
), f"Global expert {global_expert_id} should map to local expert " \
assert local_expert_id == expected_local_id, (
f"Global expert {global_expert_id} should map to local expert "
f"{expected_local_id}, got {local_expert_id}"
)
else:
assert (
expert_map[global_expert_id] == -1
), f"Global expert {global_expert_id} should not be mapped to " \
f"this rank"
assert expert_map[global_expert_id] == -1, (
f"Global expert {global_expert_id} should not be mapped to this rank"
)

# Verify that all local expert IDs are consecutive starting from 0
local_expert_ids = [
expert_map[global_id] for global_id in expected_expert_ids
]
local_expert_ids = [expert_map[global_id] for global_id in expected_expert_ids]
expected_local_ids = list(range(local_num_experts))
assert (
local_expert_ids == expected_local_ids
), f"Expected local expert IDs {expected_local_ids}, got {local_expert_ids}"
assert local_expert_ids == expected_local_ids, (
f"Expected local expert IDs {expected_local_ids}, got {local_expert_ids}"
)


@pytest.mark.parametrize("expert_placement_strategy", ["round_robin"])
@ -78,8 +74,9 @@ def test_expert_placement_various_sizes(expert_placement_strategy, world_size):

for test_global_experts, test_ep_size in test_cases:
# Ensure ep_size matches world_size
assert (test_ep_size == world_size
), f"ep_size {test_ep_size} must equal world_size {world_size}"
assert test_ep_size == world_size, (
f"ep_size {test_ep_size} must equal world_size {world_size}"
)

# Test each rank
for ep_rank in range(world_size):
@ -98,21 +95,22 @@ def test_expert_placement_various_sizes(expert_placement_strategy, world_size):
expert_placement_strategy=expert_placement_strategy,
)

assert (
test_local_experts == expected_test_local
), f"For {test_global_experts} experts on {test_ep_size} ranks, " \
f"rank {ep_rank}: expected {expected_test_local} local" \
assert test_local_experts == expected_test_local, (
f"For {test_global_experts} experts on {test_ep_size} ranks, "
f"rank {ep_rank}: expected {expected_test_local} local"
f"experts, got {test_local_experts}"
)

if test_expert_map is not None:
assert test_expert_map.shape == (
test_global_experts,
), f"Expected expert map shape ({test_global_experts},), " \
assert test_expert_map.shape == (test_global_experts,), (
f"Expected expert map shape ({test_global_experts},), "
f"got {test_expert_map.shape}"
)

# Verify round_robin pattern for this test case
verify_round_robin_pattern(test_expert_map, ep_rank,
test_ep_size, test_global_experts)
verify_round_robin_pattern(
test_expert_map, ep_rank, test_ep_size, test_global_experts
)


@pytest.mark.parametrize("expert_placement_strategy", ["round_robin"])
@ -147,28 +145,81 @@ def test_determine_expert_map_comprehensive():
# expert_placement_strategy, expected_local, expected_map_pattern)
test_cases = [
# Round robin placement tests
(2, 0, 8, "round_robin", 4, [0, -1, 1, -1, 2, -1, 3,
-1]), # rank 0 gets even experts
(2, 1, 8, "round_robin", 4, [-1, 0, -1, 1, -1, 2, -1,
3]), # rank 1 gets odd experts
(2, 0, 9, "round_robin", 5, [0, -1, 1, -1, 2, -1, 3, -1, 4
]), # rank 0 gets 5 experts (even + last)
(2, 1, 9, "round_robin", 4, [-1, 0, -1, 1, -1, 2, -1, 3,
-1]), # rank 1 gets 4 experts (odd)

(
2,
0,
8,
"round_robin",
4,
[0, -1, 1, -1, 2, -1, 3, -1],
), # rank 0 gets even experts
(
2,
1,
8,
"round_robin",
4,
[-1, 0, -1, 1, -1, 2, -1, 3],
), # rank 1 gets odd experts
(
2,
0,
9,
"round_robin",
5,
[0, -1, 1, -1, 2, -1, 3, -1, 4],
), # rank 0 gets 5 experts (even + last)
(
2,
1,
9,
"round_robin",
4,
[-1, 0, -1, 1, -1, 2, -1, 3, -1],
), # rank 1 gets 4 experts (odd)
# 4-rank tests
(4, 0, 8, "round_robin", 2, [0, -1, -1, -1, 1, -1, -1,
-1]), # rank 0 gets experts 0, 4
(4, 1, 8, "round_robin", 2, [-1, 0, -1, -1, -1, 1, -1,
-1]), # rank 1 gets experts 1, 5
(4, 2, 8, "round_robin", 2, [-1, -1, 0, -1, -1, -1, 1,
-1]), # rank 2 gets experts 2, 6
(4, 3, 8, "round_robin", 2, [-1, -1, -1, 0, -1, -1, -1,
1]), # rank 3 gets experts 3, 7
(
4,
0,
8,
"round_robin",
2,
[0, -1, -1, -1, 1, -1, -1, -1],
), # rank 0 gets experts 0, 4
(
4,
1,
8,
"round_robin",
2,
[-1, 0, -1, -1, -1, 1, -1, -1],
), # rank 1 gets experts 1, 5
(
4,
2,
8,
"round_robin",
2,
[-1, -1, 0, -1, -1, -1, 1, -1],
), # rank 2 gets experts 2, 6
(
4,
3,
8,
"round_robin",
2,
[-1, -1, -1, 0, -1, -1, -1, 1],
), # rank 3 gets experts 3, 7
]

for ep_size, ep_rank, global_num_experts, expert_placement_strategy, \
expected_local, expected_map_pattern in test_cases:
for (
ep_size,
ep_rank,
global_num_experts,
expert_placement_strategy,
expected_local,
expected_map_pattern,
) in test_cases:
local_num_experts, expert_map = determine_expert_map(
ep_size=ep_size,
ep_rank=ep_rank,
@ -176,19 +227,21 @@ def test_determine_expert_map_comprehensive():
expert_placement_strategy=expert_placement_strategy,
)

assert local_num_experts == expected_local, \
f"ep_size={ep_size}, ep_rank={ep_rank}, " \
f"global_num_experts={global_num_experts}, " \
f"expert_placement_strategy={expert_placement_strategy}: " \
assert local_num_experts == expected_local, (
f"ep_size={ep_size}, ep_rank={ep_rank}, "
f"global_num_experts={global_num_experts}, "
f"expert_placement_strategy={expert_placement_strategy}: "
f"expected {expected_local} local experts, got {local_num_experts}"
)

if expected_map_pattern is None:
assert expert_map is None, "Expected expert_map to be None"
else:
assert expert_map is not None, "Expected expert_map to not be None"
actual_map = expert_map.tolist()
assert actual_map == expected_map_pattern, \
f"ep_size={ep_size}, ep_rank={ep_rank}, " \
f"global_num_experts={global_num_experts}, " \
f"expert_placement_strategy={expert_placement_strategy}: " \
assert actual_map == expected_map_pattern, (
f"ep_size={ep_size}, ep_rank={ep_rank}, "
f"global_num_experts={global_num_experts}, "
f"expert_placement_strategy={expert_placement_strategy}: "
f"expected map {expected_map_pattern}, got {actual_map}"
)

@ -1,10 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.config import (DeviceConfig, KVTransferConfig, ModelConfig,
VllmConfig, set_current_vllm_config)
from vllm.config import (
DeviceConfig,
KVTransferConfig,
ModelConfig,
VllmConfig,
set_current_vllm_config,
)
from vllm.distributed.kv_transfer.kv_connector.utils import (
get_kv_connector_cache_layout)
get_kv_connector_cache_layout,
)
from vllm.logger import init_logger

logger = init_logger("test_expert_parallel")
@ -23,8 +29,9 @@ def test_get_kv_connector_cache_layout_with_lmcache_connector():
kv_connector="LMCacheConnectorV1",
kv_role="kv_both",
)
vllm_config = VllmConfig(device_config=DeviceConfig("cpu"),
kv_transfer_config=kv_transfer_config)
vllm_config = VllmConfig(
device_config=DeviceConfig("cpu"), kv_transfer_config=kv_transfer_config
)
with set_current_vllm_config(vllm_config):
# Test with default settings
layout = get_kv_connector_cache_layout()
@ -37,9 +44,11 @@ def test_get_kv_connector_cache_layout_with_nixl_connector():
kv_role="kv_both",
)
model_config = ModelConfig()
vllm_config = VllmConfig(device_config=DeviceConfig("cpu"),
model_config=model_config,
kv_transfer_config=kv_transfer_config)
vllm_config = VllmConfig(
device_config=DeviceConfig("cpu"),
model_config=model_config,
kv_transfer_config=kv_transfer_config,
)
with set_current_vllm_config(vllm_config):
# Test with default settings
layout = get_kv_connector_cache_layout()
@ -47,25 +56,22 @@ def test_get_kv_connector_cache_layout_with_nixl_connector():


def test_get_kv_connector_cache_layout_with_multi_connector():
kv_transfer_config = KVTransferConfig(kv_connector="MultiConnector",
kv_role="kv_both",
kv_connector_extra_config={
"connectors": [{
"kv_connector":
"SharedStorageConnector",
"kv_role":
"kv_both"
}, {
"kv_connector":
"NixlConnector",
"kv_role":
"kv_both"
}]
})
kv_transfer_config = KVTransferConfig(
kv_connector="MultiConnector",
kv_role="kv_both",
kv_connector_extra_config={
"connectors": [
{"kv_connector": "SharedStorageConnector", "kv_role": "kv_both"},
{"kv_connector": "NixlConnector", "kv_role": "kv_both"},
]
},
)
model_config = ModelConfig()
vllm_config = VllmConfig(device_config=DeviceConfig("cpu"),
model_config=model_config,
kv_transfer_config=kv_transfer_config)
vllm_config = VllmConfig(
device_config=DeviceConfig("cpu"),
model_config=model_config,
kv_transfer_config=kv_transfer_config,
)
with set_current_vllm_config(vllm_config):
# Test with default settings
layout = get_kv_connector_cache_layout()

@ -24,14 +24,13 @@ from vllm.utils import get_ip
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"


@pytest.mark.skipif(not VLLM_MULTI_NODE,
reason="Need at least 2 nodes to run the test.")
@pytest.mark.skipif(
not VLLM_MULTI_NODE, reason="Need at least 2 nodes to run the test."
)
def test_multi_node_assignment() -> None:

# NOTE: important to keep this class definition here
# to let ray use cloudpickle to serialize it.
class Actor:

def get_ip(self):
return get_ip()

@ -41,8 +40,7 @@ def test_multi_node_assignment() -> None:

current_ip = get_ip()
workers = []
for bundle_id, bundle in enumerate(
config.placement_group.bundle_specs):
for bundle_id, bundle in enumerate(config.placement_group.bundle_specs):
if not bundle.get("GPU", 0):
continue
scheduling_strategy = PlacementGroupSchedulingStrategy(

@ -11,15 +11,17 @@ import torch.multiprocessing as mp

import vllm.envs as envs
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.distributed.device_communicators.cuda_communicator import (
CudaCommunicator)
from vllm.distributed.device_communicators.pynccl import (
register_nccl_symmetric_ops)
from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator
from vllm.distributed.device_communicators.pynccl import register_nccl_symmetric_ops
from vllm.distributed.device_communicators.pynccl_allocator import (
get_nccl_mem_pool, is_symmetric_memory_enabled)
from vllm.distributed.parallel_state import (get_tp_group,
init_distributed_environment,
initialize_model_parallel)
get_nccl_mem_pool,
is_symmetric_memory_enabled,
)
from vllm.distributed.parallel_state import (
get_tp_group,
init_distributed_environment,
initialize_model_parallel,
)
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables

@ -38,31 +40,32 @@ def nccl_symm_mem_allreduce_worker(local_rank: int, world_size: int):
torch.cuda.set_device(device)
torch.set_default_device(device)
torch.set_default_dtype(dtype)
update_environment_variables({
"RANK": str(local_rank),
"LOCAL_RANK": str(local_rank),
"WORLD_SIZE": str(world_size),
"MASTER_ADDR": "localhost",
"MASTER_PORT": "12345",
})
update_environment_variables(
{
"RANK": str(local_rank),
"LOCAL_RANK": str(local_rank),
"WORLD_SIZE": str(world_size),
"MASTER_ADDR": "localhost",
"MASTER_PORT": "12345",
}
)

init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)

cuda_communicator = typing.cast(CudaCommunicator,
get_tp_group().device_communicator)
cuda_communicator = typing.cast(
CudaCommunicator, get_tp_group().device_communicator
)
pynccl_comm = cuda_communicator.pynccl_comm
if get_nccl_mem_pool() is None:
pytest.skip("NCCL allocator compilation failed "
"(probably missing NCCL headers).")
pytest.skip(
"NCCL allocator compilation failed (probably missing NCCL headers)."
)
if not is_symmetric_memory_enabled():
pytest.skip("NCCL symmetric memory allreduce is disabled.")

register_nccl_symmetric_ops(pynccl_comm)
input = torch.randint(1,
23, (test_size_elements, ),
dtype=dtype,
device=device)
input = torch.randint(1, 23, (test_size_elements,), dtype=dtype, device=device)
input_clone = input.clone()
output = torch.ops.vllm.all_reduce_symmetric_with_copy(input)
assert output is not None
@ -77,8 +80,7 @@ def nccl_symm_mem_allreduce_worker(local_rank: int, world_size: int):
reason="NCCLSymmMemAllreduce is only available for CUDA platforms.",
)
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
reason="Only test on CUDA")
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
def test_nccl_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, world_size):
if world_size > torch.cuda.device_count():
pytest.skip("Not enough GPUs to run the test.")
@ -88,7 +90,5 @@ def test_nccl_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, world_size):
monkeypatch.setenv("NCCL_NVLS_ENABLE", "1")
monkeypatch.setenv("NCCL_CUMEM_ENABLE", "1")

mp.spawn(nccl_symm_mem_allreduce_worker,
args=(world_size, ),
nprocs=world_size)
mp.spawn(nccl_symm_mem_allreduce_worker, args=(world_size,), nprocs=world_size)
cleanup_dist_env_and_memory()

@ -32,12 +32,15 @@ if __name__ == "__main__":
# Expected node count based on environment variable)
expected = int(os.environ.get("NUM_NODES", "1"))

assert test_result == expected, \
f"Expected {expected} nodes, got {test_result}"
assert test_result == expected, f"Expected {expected} nodes, got {test_result}"

if pg == dist.group.WORLD:
print(f"Node count test passed! Got {test_result} nodes "
f"when using torch distributed!")
print(
f"Node count test passed! Got {test_result} nodes "
f"when using torch distributed!"
)
else:
print(f"Node count test passed! Got {test_result} nodes "
f"when using StatelessProcessGroup!")
print(
f"Node count test passed! Got {test_result} nodes "
f"when using StatelessProcessGroup!"
)

@ -7,6 +7,7 @@ WARNING: This test runs in both single-node (4 GPUs) and multi-node
all workers in a node other than the head node, which can cause the test
to fail.
"""

import json
import os
from dataclasses import dataclass
@ -55,26 +56,17 @@ class PPTestSettings:
):
return PPTestSettings(
parallel_setups=[
ParallelSetup(tp_size=tp_base,
pp_size=pp_base,
eager_mode=False),
ParallelSetup(tp_size=tp_base,
pp_size=2 * pp_base,
eager_mode=False),
ParallelSetup(tp_size=tp_base,
pp_size=2 * pp_base,
eager_mode=True),
ParallelSetup(tp_size=2 * tp_base,
pp_size=pp_base,
eager_mode=False),
ParallelSetup(tp_size=2 * tp_base,
pp_size=pp_base,
eager_mode=True),
ParallelSetup(tp_size=tp_base, pp_size=pp_base, eager_mode=False),
ParallelSetup(tp_size=tp_base, pp_size=2 * pp_base, eager_mode=False),
ParallelSetup(tp_size=tp_base, pp_size=2 * pp_base, eager_mode=True),
ParallelSetup(tp_size=2 * tp_base, pp_size=pp_base, eager_mode=False),
ParallelSetup(tp_size=2 * tp_base, pp_size=pp_base, eager_mode=True),
],
distributed_backends=["mp", "ray"],
runner=runner,
test_options=PPTestOptions(multi_node_only=multi_node_only,
load_format=load_format),
test_options=PPTestOptions(
multi_node_only=multi_node_only, load_format=load_format
),
)

@staticmethod
@ -86,17 +78,15 @@ class PPTestSettings:
multi_node_only: bool = False,
load_format: Optional[str] = None,
):

return PPTestSettings(
parallel_setups=[
ParallelSetup(tp_size=tp_base,
pp_size=pp_base,
eager_mode=True),
ParallelSetup(tp_size=tp_base, pp_size=pp_base, eager_mode=True),
],
distributed_backends=["mp"],
runner=runner,
test_options=PPTestOptions(multi_node_only=multi_node_only,
load_format=load_format),
test_options=PPTestOptions(
multi_node_only=multi_node_only, load_format=load_format
),
)

def iter_params(self, model_id: str):
@ -281,8 +271,10 @@ def _compare_tp(
if num_gpus_available < tp_size * pp_size:
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
if VLLM_MULTI_NODE and distributed_backend == "mp":
pytest.skip("Skipping multi-node pipeline parallel test for "
"multiprocessing distributed backend")
pytest.skip(
"Skipping multi-node pipeline parallel test for "
"multiprocessing distributed backend"
)
if multi_node_only and not VLLM_MULTI_NODE:
pytest.skip("Not in multi-node setting")

@ -357,20 +349,16 @@ def _compare_tp(
|
||||
"mp",
|
||||
]
|
||||
|
||||
compare_two_settings(model_id,
|
||||
pp_args,
|
||||
tp_args,
|
||||
pp_env,
|
||||
tp_env,
|
||||
method=method)
|
||||
compare_two_settings(model_id, pp_args, tp_args, pp_env, tp_env, method=method)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("model_id", "parallel_setup", "distributed_backend", "runner",
|
||||
"test_options"),
|
||||
("model_id", "parallel_setup", "distributed_backend", "runner", "test_options"),
|
||||
[
|
||||
params for model_id, settings in TEXT_GENERATION_MODELS.items()
|
||||
for params in settings.iter_params(model_id) if model_id in TEST_MODELS
|
||||
params
|
||||
for model_id, settings in TEXT_GENERATION_MODELS.items()
|
||||
for params in settings.iter_params(model_id)
|
||||
if model_id in TEST_MODELS
|
||||
],
|
||||
)
|
||||
@create_new_process_for_each_test()
|
||||
@ -382,22 +370,25 @@ def test_tp_language_generation(
|
||||
test_options: PPTestOptions,
|
||||
num_gpus_available,
|
||||
):
|
||||
_compare_tp(model_id,
|
||||
parallel_setup,
|
||||
distributed_backend,
|
||||
runner,
|
||||
test_options,
|
||||
num_gpus_available,
|
||||
method="generate",
|
||||
is_multimodal=False)
|
||||
_compare_tp(
|
||||
model_id,
|
||||
parallel_setup,
|
||||
distributed_backend,
|
||||
runner,
|
||||
test_options,
|
||||
num_gpus_available,
|
||||
method="generate",
|
||||
is_multimodal=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("model_id", "parallel_setup", "distributed_backend", "runner",
|
||||
"test_options"),
|
||||
("model_id", "parallel_setup", "distributed_backend", "runner", "test_options"),
|
||||
[
|
||||
params for model_id, settings in EMBEDDING_MODELS.items()
|
||||
for params in settings.iter_params(model_id) if model_id in TEST_MODELS
|
||||
params
|
||||
for model_id, settings in EMBEDDING_MODELS.items()
|
||||
for params in settings.iter_params(model_id)
|
||||
if model_id in TEST_MODELS
|
||||
],
|
||||
)
|
||||
@create_new_process_for_each_test()
|
||||
@ -409,22 +400,25 @@ def test_tp_language_embedding(
|
||||
test_options: PPTestOptions,
|
||||
num_gpus_available,
|
||||
):
|
||||
_compare_tp(model_id,
|
||||
parallel_setup,
|
||||
distributed_backend,
|
||||
runner,
|
||||
test_options,
|
||||
num_gpus_available,
|
||||
method="encode",
|
||||
is_multimodal=False)
|
||||
_compare_tp(
|
||||
model_id,
|
||||
parallel_setup,
|
||||
distributed_backend,
|
||||
runner,
|
||||
test_options,
|
||||
num_gpus_available,
|
||||
method="encode",
|
||||
is_multimodal=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("model_id", "parallel_setup", "distributed_backend", "runner",
|
||||
"test_options"),
|
||||
("model_id", "parallel_setup", "distributed_backend", "runner", "test_options"),
|
||||
[
|
||||
params for model_id, settings in MULTIMODAL_MODELS.items()
|
||||
for params in settings.iter_params(model_id) if model_id in TEST_MODELS
|
||||
params
|
||||
for model_id, settings in MULTIMODAL_MODELS.items()
|
||||
for params in settings.iter_params(model_id)
|
||||
if model_id in TEST_MODELS
|
||||
],
|
||||
)
|
||||
@create_new_process_for_each_test()
|
||||
@ -436,11 +430,13 @@ def test_tp_multimodal_generation(
|
||||
test_options: PPTestOptions,
|
||||
num_gpus_available,
|
||||
):
|
||||
_compare_tp(model_id,
|
||||
parallel_setup,
|
||||
distributed_backend,
|
||||
runner,
|
||||
test_options,
|
||||
num_gpus_available,
|
||||
method="generate",
|
||||
is_multimodal=True)
|
||||
_compare_tp(
|
||||
model_id,
|
||||
parallel_setup,
|
||||
distributed_backend,
|
||||
runner,
|
||||
test_options,
|
||||
num_gpus_available,
|
||||
method="generate",
|
||||
is_multimodal=True,
|
||||
)
|
||||
|
||||
@ -9,7 +9,6 @@ from vllm.distributed.utils import get_pp_indices
|
||||
|
||||
|
||||
def test_custom_layer_partition(monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
|
||||
def _verify(partition_str, num_layers, pp_size, goldens):
|
||||
@ -57,7 +56,8 @@ def test_custom_layer_partition(monkeypatch: pytest.MonkeyPatch):
|
||||
(5, 3, 0, (0, 2)),
|
||||
(5, 3, 1, (2, 4)),
|
||||
(5, 3, 2, (4, 5)),
|
||||
])
|
||||
],
|
||||
)
|
||||
def test_uneven_auto_partition(
|
||||
num_hidden_layers: int,
|
||||
pp_size: int,
|
||||
|
||||
@ -12,12 +12,18 @@ if TYPE_CHECKING:
|
||||
from typing_extensions import LiteralString
|
||||
|
||||
|
||||
@pytest.mark.parametrize("PP_SIZE, MODEL_NAME", [
|
||||
(2, "JackFram/llama-160m"),
|
||||
])
|
||||
@pytest.mark.parametrize("ATTN_BACKEND", [
|
||||
"FLASH_ATTN",
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"PP_SIZE, MODEL_NAME",
|
||||
[
|
||||
(2, "JackFram/llama-160m"),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"ATTN_BACKEND",
|
||||
[
|
||||
"FLASH_ATTN",
|
||||
],
|
||||
)
|
||||
@create_new_process_for_each_test()
|
||||
def test_pp_cudagraph(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
|
||||
@ -9,13 +9,15 @@ import pytest
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from vllm.distributed.communication_op import ( # noqa
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa
|
||||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
|
||||
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
|
||||
get_world_group, graph_capture,
|
||||
init_distributed_environment)
|
||||
from vllm.distributed.parallel_state import (
|
||||
ensure_model_parallel_initialized,
|
||||
get_world_group,
|
||||
graph_capture,
|
||||
init_distributed_environment,
|
||||
)
|
||||
from vllm.utils import update_environment_variables
|
||||
|
||||
|
||||
@ -24,13 +26,13 @@ def distributed_run(fn, world_size):
|
||||
processes: list[multiprocessing.Process] = []
|
||||
for i in range(number_of_processes):
|
||||
env: dict[str, str] = {}
|
||||
env['RANK'] = str(i)
|
||||
env['LOCAL_RANK'] = str(i)
|
||||
env['WORLD_SIZE'] = str(number_of_processes)
|
||||
env['LOCAL_WORLD_SIZE'] = str(number_of_processes)
|
||||
env['MASTER_ADDR'] = 'localhost'
|
||||
env['MASTER_PORT'] = '12345'
|
||||
p = multiprocessing.Process(target=fn, args=(env, ))
|
||||
env["RANK"] = str(i)
|
||||
env["LOCAL_RANK"] = str(i)
|
||||
env["WORLD_SIZE"] = str(number_of_processes)
|
||||
env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
|
||||
env["MASTER_ADDR"] = "localhost"
|
||||
env["MASTER_PORT"] = "12345"
|
||||
p = multiprocessing.Process(target=fn, args=(env,))
|
||||
processes.append(p)
|
||||
p.start()
|
||||
|
||||
@ -47,7 +49,7 @@ def worker_fn_wrapper(fn):
|
||||
# and update the environment variables in the function
|
||||
def wrapped_fn(env):
|
||||
update_environment_variables(env)
|
||||
local_rank = os.environ['LOCAL_RANK']
|
||||
local_rank = os.environ["LOCAL_RANK"]
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
torch.cuda.set_device(device)
|
||||
init_distributed_environment()
|
||||
@ -58,17 +60,18 @@ def worker_fn_wrapper(fn):
|
||||
|
||||
@worker_fn_wrapper
|
||||
def worker_fn():
|
||||
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
|
||||
device=get_world_group().device)
|
||||
tensor = torch.ones(16, 1024, 1024,
|
||||
dtype=torch.float32).cuda(pynccl_comm.rank)
|
||||
pynccl_comm = PyNcclCommunicator(
|
||||
get_world_group().cpu_group, device=get_world_group().device
|
||||
)
|
||||
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(pynccl_comm.rank)
|
||||
tensor = pynccl_comm.all_reduce(tensor)
|
||||
torch.cuda.synchronize()
|
||||
assert torch.all(tensor == pynccl_comm.world_size).cpu().item()
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||
reason="Need at least 2 GPUs to run the test.")
|
||||
@pytest.mark.skipif(
|
||||
torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
|
||||
)
|
||||
def test_pynccl():
|
||||
distributed_run(worker_fn, 2)
|
||||
|
||||
@ -78,7 +81,7 @@ def multiple_allreduce_worker_fn():
|
||||
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
|
||||
groups = [
|
||||
torch.distributed.new_group(ranks=[0, 1], backend="gloo"),
|
||||
torch.distributed.new_group(ranks=[2, 3], backend="gloo")
|
||||
torch.distributed.new_group(ranks=[2, 3], backend="gloo"),
|
||||
]
|
||||
group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
|
||||
pynccl_comm = PyNcclCommunicator(group=group, device=device)
|
||||
@ -95,8 +98,9 @@ def multiple_allreduce_worker_fn():
|
||||
assert torch.all(tensor == 2).cpu().item()
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 4,
|
||||
reason="Need at least 4 GPUs to run the test.")
|
||||
@pytest.mark.skipif(
|
||||
torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test."
|
||||
)
|
||||
def test_pynccl_multiple_allreduce():
|
||||
# this tests pynccl for multiple tp groups, in a standalone way
|
||||
# i.e. call `pynccl_comm.all_reduce` directly
|
||||
@ -121,8 +125,9 @@ def multiple_allreduce_with_vllm_worker_fn():
|
||||
assert torch.all(tensor == 2).cpu().item()
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 4,
|
||||
reason="Need at least 4 GPUs to run the test.")
|
||||
@pytest.mark.skipif(
|
||||
torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test."
|
||||
)
|
||||
def test_pynccl_multiple_allreduce_with_vllm():
|
||||
# this tests pynccl for multiple tp groups, together with vllm
|
||||
# i.e. call `tensor_model_parallel_all_reduce`
|
||||
@ -133,10 +138,11 @@ def test_pynccl_multiple_allreduce_with_vllm():
|
||||
def worker_fn_with_cudagraph():
|
||||
with torch.no_grad():
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
|
||||
device=get_world_group().device)
|
||||
pynccl_comm = PyNcclCommunicator(
|
||||
get_world_group().cpu_group, device=get_world_group().device
|
||||
)
|
||||
# run something in the default stream to initialize torch engine
|
||||
a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
|
||||
a = torch.ones((4, 4), device=f"cuda:{pynccl_comm.rank}")
|
||||
torch.cuda.synchronize()
|
||||
with torch.cuda.graph(graph):
|
||||
a_out = pynccl_comm.all_reduce(a)
|
||||
@ -148,84 +154,90 @@ def worker_fn_with_cudagraph():
|
||||
|
||||
@worker_fn_wrapper
|
||||
def all_gather_worker_fn():
|
||||
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
|
||||
device=get_world_group().device)
|
||||
pynccl_comm = PyNcclCommunicator(
|
||||
get_world_group().cpu_group, device=get_world_group().device
|
||||
)
|
||||
|
||||
rank = pynccl_comm.rank
|
||||
world_size = pynccl_comm.world_size
|
||||
device = f'cuda:{pynccl_comm.rank}'
|
||||
device = f"cuda:{pynccl_comm.rank}"
|
||||
|
||||
num_elems = 1000
|
||||
tensor = torch.arange(num_elems, dtype=torch.float32,
|
||||
device=device) + rank * num_elems
|
||||
result = torch.zeros(num_elems * world_size,
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
tensor = (
|
||||
torch.arange(num_elems, dtype=torch.float32, device=device) + rank * num_elems
|
||||
)
|
||||
result = torch.zeros(num_elems * world_size, dtype=torch.float32, device=device)
|
||||
|
||||
expected = torch.cat([
|
||||
torch.arange(num_elems, dtype=torch.float32) + r * num_elems
|
||||
for r in range(world_size)
|
||||
]).to(device)
|
||||
expected = torch.cat(
|
||||
[
|
||||
torch.arange(num_elems, dtype=torch.float32) + r * num_elems
|
||||
for r in range(world_size)
|
||||
]
|
||||
).to(device)
|
||||
|
||||
pynccl_comm.all_gather(result, tensor)
|
||||
torch.cuda.synchronize()
|
||||
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||
reason="Need at least 2 GPUs to run the test.")
|
||||
@pytest.mark.skipif(
|
||||
torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
|
||||
)
|
||||
def test_pynccl_all_gather():
|
||||
distributed_run(all_gather_worker_fn, 2)
|
||||
|
||||
|
||||
@worker_fn_wrapper
|
||||
def all_gatherv_worker_fn():
|
||||
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
|
||||
device=get_world_group().device)
|
||||
pynccl_comm = PyNcclCommunicator(
|
||||
get_world_group().cpu_group, device=get_world_group().device
|
||||
)
|
||||
|
||||
rank = pynccl_comm.rank
|
||||
world_size = pynccl_comm.world_size
|
||||
device = f'cuda:{pynccl_comm.rank}'
|
||||
device = f"cuda:{pynccl_comm.rank}"
|
||||
|
||||
assert world_size <= 8
|
||||
sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size]
|
||||
num_elems = sizes[rank]
|
||||
tensor = torch.arange(num_elems, dtype=torch.float32,
|
||||
device=device) + rank * 100
|
||||
tensor = torch.arange(num_elems, dtype=torch.float32, device=device) + rank * 100
|
||||
result = torch.zeros(sum(sizes), dtype=torch.float32, device=device)
|
||||
|
||||
expected = torch.cat([
|
||||
torch.arange(sizes[r], dtype=torch.float32) + r * 100
|
||||
for r in range(world_size)
|
||||
]).to(device)
|
||||
expected = torch.cat(
|
||||
[
|
||||
torch.arange(sizes[r], dtype=torch.float32) + r * 100
|
||||
for r in range(world_size)
|
||||
]
|
||||
).to(device)
|
||||
|
||||
pynccl_comm.all_gatherv(result, tensor, sizes=sizes)
|
||||
torch.cuda.synchronize()
|
||||
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||
reason="Need at least 2 GPUs to run the test.")
|
||||
@pytest.mark.skipif(
|
||||
torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
|
||||
)
|
||||
def test_pynccl_all_gatherv():
|
||||
distributed_run(all_gatherv_worker_fn, 2)
|
||||
|
||||
|
||||
@worker_fn_wrapper
|
||||
def reduce_scatter_worker_fn():
|
||||
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
|
||||
device=get_world_group().device)
|
||||
pynccl_comm = PyNcclCommunicator(
|
||||
get_world_group().cpu_group, device=get_world_group().device
|
||||
)
|
||||
|
||||
rank = pynccl_comm.rank
|
||||
world_size = pynccl_comm.world_size
|
||||
device = f'cuda:{pynccl_comm.rank}'
|
||||
device = f"cuda:{pynccl_comm.rank}"
|
||||
|
||||
num_elems = 1000
|
||||
tensor = torch.arange(num_elems, dtype=torch.float32,
|
||||
device=device) + rank * num_elems
|
||||
assert (num_elems % world_size == 0)
|
||||
result = torch.zeros(num_elems // world_size,
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
tensor = (
|
||||
torch.arange(num_elems, dtype=torch.float32, device=device) + rank * num_elems
|
||||
)
|
||||
assert num_elems % world_size == 0
|
||||
result = torch.zeros(num_elems // world_size, dtype=torch.float32, device=device)
|
||||
|
||||
# Calculate expected result for this rank's chunk
|
||||
scattered_size = num_elems // world_size
|
||||
@ -233,34 +245,37 @@ def reduce_scatter_worker_fn():
|
||||
torch.arange(num_elems, dtype=torch.float32) + r * num_elems
|
||||
for r in range(world_size)
|
||||
]
|
||||
expected = sum(tensor[rank * scattered_size:(rank + 1) * scattered_size]
|
||||
for tensor in all_tensors).to(device)
|
||||
expected = sum(
|
||||
tensor[rank * scattered_size : (rank + 1) * scattered_size]
|
||||
for tensor in all_tensors
|
||||
).to(device)
|
||||
|
||||
pynccl_comm.reduce_scatter(result, tensor)
|
||||
torch.cuda.synchronize()
|
||||
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||
reason="Need at least 2 GPUs to run the test.")
|
||||
@pytest.mark.skipif(
|
||||
torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
|
||||
)
|
||||
def test_pynccl_reduce_scatter():
|
||||
distributed_run(reduce_scatter_worker_fn, 2)
|
||||
|
||||
|
||||
@worker_fn_wrapper
|
||||
def reduce_scatterv_worker_fn():
|
||||
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
|
||||
device=get_world_group().device)
|
||||
pynccl_comm = PyNcclCommunicator(
|
||||
get_world_group().cpu_group, device=get_world_group().device
|
||||
)
|
||||
|
||||
rank = pynccl_comm.rank
|
||||
world_size = pynccl_comm.world_size
|
||||
device = f'cuda:{pynccl_comm.rank}'
|
||||
device = f"cuda:{pynccl_comm.rank}"
|
||||
|
||||
assert world_size <= 8
|
||||
sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size]
|
||||
num_elems = sum(sizes)
|
||||
tensor = torch.arange(num_elems, dtype=torch.float32,
|
||||
device=device) + rank * 100
|
||||
tensor = torch.arange(num_elems, dtype=torch.float32, device=device) + rank * 100
|
||||
result = torch.zeros(sizes[rank], dtype=torch.float32, device=device)
|
||||
|
||||
# Calculate expected result for this rank's chunk
|
||||
@ -278,41 +293,41 @@ def reduce_scatterv_worker_fn():
|
||||
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||
reason="Need at least 2 GPUs to run the test.")
|
||||
@pytest.mark.skipif(
|
||||
torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
|
||||
)
|
||||
def test_pynccl_reduce_scatterv():
|
||||
distributed_run(reduce_scatterv_worker_fn, 2)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||
reason="Need at least 2 GPUs to run the test.")
|
||||
@pytest.mark.skipif(
|
||||
torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
|
||||
)
|
||||
def test_pynccl_with_cudagraph():
|
||||
distributed_run(worker_fn_with_cudagraph, 2)
|
||||
|
||||
|
||||
@worker_fn_wrapper
|
||||
def send_recv_worker_fn():
|
||||
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
|
||||
device=get_world_group().device)
|
||||
pynccl_comm = PyNcclCommunicator(
|
||||
get_world_group().cpu_group, device=get_world_group().device
|
||||
)
|
||||
if pynccl_comm.rank == 0:
|
||||
tensor = torch.ones(16, 1024, 1024,
|
||||
dtype=torch.float32).cuda(pynccl_comm.rank)
|
||||
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(pynccl_comm.rank)
|
||||
else:
|
||||
tensor = torch.empty(16, 1024, 1024,
|
||||
dtype=torch.float32).cuda(pynccl_comm.rank)
|
||||
tensor = torch.empty(16, 1024, 1024, dtype=torch.float32).cuda(pynccl_comm.rank)
|
||||
|
||||
if pynccl_comm.rank == 0:
|
||||
pynccl_comm.send(tensor,
|
||||
dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
|
||||
pynccl_comm.send(tensor, dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
|
||||
else:
|
||||
pynccl_comm.recv(tensor,
|
||||
src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
|
||||
pynccl_comm.recv(tensor, src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
|
||||
torch.cuda.synchronize()
|
||||
assert torch.all(tensor == 1).cpu().item()
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||
reason="Need at least 2 GPUs to run the test.")
|
||||
@pytest.mark.skipif(
|
||||
torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
|
||||
)
|
||||
def test_pynccl_send_recv():
|
||||
distributed_run(send_recv_worker_fn, 2)
|
||||
|
||||
@ -322,27 +337,20 @@ def multiple_send_recv_worker_fn():
|
||||
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
|
||||
groups = [
|
||||
torch.distributed.new_group(ranks=[0, 2], backend="gloo"),
|
||||
torch.distributed.new_group(ranks=[1, 3], backend="gloo")
|
||||
torch.distributed.new_group(ranks=[1, 3], backend="gloo"),
|
||||
]
|
||||
group = groups[0] if torch.distributed.get_rank() in [0, 2] else groups[1]
|
||||
pynccl_comm = PyNcclCommunicator(group=group, device=device)
|
||||
if torch.distributed.get_rank() == 0:
|
||||
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
|
||||
elif torch.distributed.get_rank() == 1:
|
||||
tensor = 2 * torch.ones(
|
||||
16, 1024, 1024, dtype=torch.float32, device=device)
|
||||
tensor = 2 * torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
|
||||
else:
|
||||
tensor = torch.empty(16,
|
||||
1024,
|
||||
1024,
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
tensor = torch.empty(16, 1024, 1024, dtype=torch.float32, device=device)
|
||||
if torch.distributed.get_rank() in [0, 1]:
|
||||
pynccl_comm.send(tensor,
|
||||
dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
|
||||
pynccl_comm.send(tensor, dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
|
||||
else:
|
||||
pynccl_comm.recv(tensor,
|
||||
src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
|
||||
pynccl_comm.recv(tensor, src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
|
||||
torch.cuda.synchronize()
|
||||
if torch.distributed.get_rank() in [0, 2]:
|
||||
assert torch.all(tensor == 1).cpu().item()
|
||||
@ -350,14 +358,16 @@ def multiple_send_recv_worker_fn():
|
||||
assert torch.all(tensor == 2).cpu().item()
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 4,
|
||||
reason="Need at least 4 GPUs to run the test.")
|
||||
@pytest.mark.skipif(
|
||||
torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test."
|
||||
)
|
||||
def test_pynccl_multiple_send_recv():
|
||||
distributed_run(multiple_send_recv_worker_fn, 4)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 4,
|
||||
reason="Need at least 4 GPUs to run the test.")
|
||||
@pytest.mark.skipif(
|
||||
torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test."
|
||||
)
|
||||
def test_pynccl_broadcast():
|
||||
distributed_run(broadcast_worker_fn, 4)
|
||||
|
||||
@ -366,19 +376,17 @@ def test_pynccl_broadcast():
|
||||
def broadcast_worker_fn():
|
||||
# Test broadcast for every root rank.
|
||||
# Essentially this is an all-gather operation.
|
||||
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
|
||||
device=get_world_group().device)
|
||||
pynccl_comm = PyNcclCommunicator(
|
||||
get_world_group().cpu_group, device=get_world_group().device
|
||||
)
|
||||
recv_tensors = [
|
||||
torch.empty(16,
|
||||
1024,
|
||||
1024,
|
||||
dtype=torch.float32,
|
||||
device=pynccl_comm.device)
|
||||
torch.empty(16, 1024, 1024, dtype=torch.float32, device=pynccl_comm.device)
|
||||
for i in range(pynccl_comm.world_size)
|
||||
]
|
||||
recv_tensors[pynccl_comm.rank] = torch.ones(
|
||||
16, 1024, 1024, dtype=torch.float32,
|
||||
device=pynccl_comm.device) * pynccl_comm.rank
|
||||
recv_tensors[pynccl_comm.rank] = (
|
||||
torch.ones(16, 1024, 1024, dtype=torch.float32, device=pynccl_comm.device)
|
||||
* pynccl_comm.rank
|
||||
)
|
||||
|
||||
for i in range(pynccl_comm.world_size):
|
||||
pynccl_comm.broadcast(recv_tensors[i], src=i)
|
||||
|
||||
@ -8,20 +8,20 @@ import ray
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm.distributed.communication_op import ( # noqa
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa
|
||||
from vllm.distributed.parallel_state import get_tp_group, graph_capture
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..utils import (ensure_model_parallel_initialized,
|
||||
init_test_distributed_environment, multi_process_parallel)
|
||||
from ..utils import (
|
||||
ensure_model_parallel_initialized,
|
||||
init_test_distributed_environment,
|
||||
multi_process_parallel,
|
||||
)
|
||||
|
||||
torch.manual_seed(42)
|
||||
random.seed(44)
|
||||
# Size over 8MB is sufficient for custom quick allreduce.
|
||||
test_sizes = [
|
||||
random.randint(8 * 1024 * 1024, 10 * 1024 * 1024) for _ in range(8)
|
||||
]
|
||||
test_sizes = [random.randint(8 * 1024 * 1024, 10 * 1024 * 1024) for _ in range(8)]
|
||||
for i, v in enumerate(test_sizes):
|
||||
test_sizes[i] -= v % 8
|
||||
|
||||
@ -38,8 +38,7 @@ def graph_quickreduce(
|
||||
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
init_test_distributed_environment(tp_size, pp_size, rank,
|
||||
distributed_init_port)
|
||||
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
|
||||
ensure_model_parallel_initialized(tp_size, pp_size)
|
||||
group = get_tp_group().device_group
|
||||
|
||||
@ -64,18 +63,15 @@ def graph_quickreduce(
|
||||
for sz in test_sizes:
|
||||
for dtype in [torch.float16, torch.bfloat16]:
|
||||
with graph_capture(device=device) as graph_capture_context:
|
||||
inp1 = torch.randint(1,
|
||||
23, (sz, ),
|
||||
dtype=dtype,
|
||||
device=torch.cuda.current_device())
|
||||
inp2 = torch.randint(-23,
|
||||
1, (sz, ),
|
||||
dtype=dtype,
|
||||
device=torch.cuda.current_device())
|
||||
inp1 = torch.randint(
|
||||
1, 23, (sz,), dtype=dtype, device=torch.cuda.current_device()
|
||||
)
|
||||
inp2 = torch.randint(
|
||||
-23, 1, (sz,), dtype=dtype, device=torch.cuda.current_device()
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph,
|
||||
stream=graph_capture_context.stream):
|
||||
with torch.cuda.graph(graph, stream=graph_capture_context.stream):
|
||||
for _ in range(num_communication):
|
||||
out1 = tensor_model_parallel_all_reduce(inp1)
|
||||
dist.all_reduce(inp1, group=group)
|
||||
@ -99,39 +95,42 @@ def eager_quickreduce(
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
init_test_distributed_environment(tp_size, pp_size, rank,
|
||||
distributed_init_port)
|
||||
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
|
||||
|
||||
# Size over 8MB is sufficient for custom quick allreduce.
|
||||
sz = 16 * 1024 * 1024
|
||||
fa = get_tp_group().device_communicator.qr_comm
|
||||
inp = torch.tensor([1.0 * ((i) % 23) for i in range(sz)],
|
||||
dtype=torch.float16,
|
||||
device=device)
|
||||
inp = torch.tensor(
|
||||
[1.0 * ((i) % 23) for i in range(sz)], dtype=torch.float16, device=device
|
||||
)
|
||||
out = fa.quick_all_reduce(inp)
|
||||
torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1)
|
||||
|
||||
inp = torch.tensor([1.0 * ((i) % 23) for i in range(sz)],
|
||||
dtype=torch.bfloat16,
|
||||
device=device)
|
||||
inp = torch.tensor(
|
||||
[1.0 * ((i) % 23) for i in range(sz)], dtype=torch.bfloat16, device=device
|
||||
)
|
||||
out = fa.quick_all_reduce(inp)
|
||||
torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_rocm(),
|
||||
reason="only test quick allreduce for rocm")
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_rocm(), reason="only test quick allreduce for rocm"
|
||||
)
|
||||
@pytest.mark.parametrize("quant_mode", ["FP", "INT8", "INT6", "INT4"])
|
||||
@pytest.mark.parametrize("tp_size", [2])
|
||||
@pytest.mark.parametrize("pipeline_parallel_size", [1, 2])
|
||||
@pytest.mark.parametrize("test_target", [graph_quickreduce, eager_quickreduce])
|
||||
def test_custom_quick_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,
|
||||
pipeline_parallel_size, test_target,
|
||||
quant_mode):
|
||||
def test_custom_quick_allreduce(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tp_size,
|
||||
pipeline_parallel_size,
|
||||
test_target,
|
||||
quant_mode,
|
||||
):
|
||||
world_size = tp_size * pipeline_parallel_size
|
||||
if world_size > torch.cuda.device_count():
|
||||
pytest.skip("Not enough GPUs to run the test.")
|
||||
|
||||
monkeypatch.setenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", quant_mode)
|
||||
|
||||
multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size,
|
||||
test_target)
|
||||
multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, test_target)
|
||||
|
||||
@ -22,15 +22,13 @@ if __name__ == "__main__":
|
||||
dist.broadcast_object_list(recv, src=0)
|
||||
ip, port = recv
|
||||
|
||||
stateless_pg = StatelessProcessGroup.create(ip, port, rank,
|
||||
dist.get_world_size())
|
||||
stateless_pg = StatelessProcessGroup.create(ip, port, rank, dist.get_world_size())
|
||||
|
||||
for pg in [dist.group.WORLD, stateless_pg]:
|
||||
test_result = all(in_the_same_node_as(pg, source_rank=0))
|
||||
|
||||
expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1"
|
||||
assert test_result == expected, \
|
||||
f"Expected {expected}, got {test_result}"
|
||||
assert test_result == expected, f"Expected {expected}, got {test_result}"
|
||||
if pg == dist.group.WORLD:
|
||||
print("Same node test passed! when using torch distributed!")
|
||||
else:
|
||||
|
||||
@ -7,6 +7,7 @@ WARNING: This test runs in both single-node (4 GPUs) and multi-node
|
||||
all workers in a node other than the head node, which can cause the test
|
||||
to fail.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
@ -56,7 +57,8 @@ class SPTestSettings:
|
||||
raise ValueError(
|
||||
f"Length mismatch: distributed_backends "
|
||||
f"({len(self.distributed_backends)}) != "
|
||||
f"vllm_major_versions ({len(self.vllm_major_versions)})")
|
||||
f"vllm_major_versions ({len(self.vllm_major_versions)})"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def detailed(
|
||||
@ -72,18 +74,22 @@ class SPTestSettings:
|
||||
for pp_multiplier in [1, 2]:
|
||||
for chunked_prefill_val in [False, True]:
|
||||
parallel_setups.append(
|
||||
ParallelSetup(tp_size=tp_base,
|
||||
pp_size=pp_multiplier * pp_base,
|
||||
enable_fusion=False,
|
||||
eager_mode=eager_mode_val,
|
||||
chunked_prefill=chunked_prefill_val))
|
||||
ParallelSetup(
|
||||
tp_size=tp_base,
|
||||
pp_size=pp_multiplier * pp_base,
|
||||
enable_fusion=False,
|
||||
eager_mode=eager_mode_val,
|
||||
chunked_prefill=chunked_prefill_val,
|
||||
)
|
||||
)
|
||||
return SPTestSettings(
|
||||
parallel_setups=parallel_setups,
|
||||
distributed_backends=["mp", "ray"],
|
||||
vllm_major_versions=["1", "1"],
|
||||
runner=runner,
|
||||
test_options=SPTestOptions(multi_node_only=multi_node_only,
|
||||
load_format=load_format),
|
||||
test_options=SPTestOptions(
|
||||
multi_node_only=multi_node_only, load_format=load_format
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -100,18 +106,22 @@ class SPTestSettings:
|
||||
for pp_multiplier in [1, 2]:
|
||||
for chunked_prefill_val in [False, True]:
|
||||
parallel_setups.append(
|
||||
ParallelSetup(tp_size=tp_base,
|
||||
pp_size=pp_multiplier * pp_base,
|
||||
enable_fusion=False,
|
||||
eager_mode=eager_mode_val,
|
||||
chunked_prefill=chunked_prefill_val))
|
||||
ParallelSetup(
|
||||
tp_size=tp_base,
|
||||
pp_size=pp_multiplier * pp_base,
|
||||
enable_fusion=False,
|
||||
eager_mode=eager_mode_val,
|
||||
chunked_prefill=chunked_prefill_val,
|
||||
)
|
||||
)
|
||||
return SPTestSettings(
|
||||
parallel_setups=parallel_setups,
|
||||
distributed_backends=["mp", "ray"],
|
||||
vllm_major_versions=["1", "1"],
|
||||
runner=runner,
|
||||
test_options=SPTestOptions(multi_node_only=multi_node_only,
|
||||
load_format=load_format),
|
||||
test_options=SPTestOptions(
|
||||
multi_node_only=multi_node_only, load_format=load_format
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -126,28 +136,39 @@ class SPTestSettings:
|
||||
parallel_setups = []
|
||||
for fusion_val in [False, True]:
|
||||
parallel_setups.append(
|
||||
ParallelSetup(tp_size=tp_base,
|
||||
pp_size=pp_base,
|
||||
enable_fusion=fusion_val,
|
||||
eager_mode=True,
|
||||
chunked_prefill=False))
|
||||
ParallelSetup(
|
||||
tp_size=tp_base,
|
||||
pp_size=pp_base,
|
||||
enable_fusion=fusion_val,
|
||||
eager_mode=True,
|
||||
chunked_prefill=False,
|
||||
)
|
||||
)
|
||||
return SPTestSettings(
|
||||
parallel_setups=parallel_setups,
|
||||
distributed_backends=["mp", "ray"],
|
||||
vllm_major_versions=["1", "1"],
|
||||
runner=runner,
|
||||
test_options=SPTestOptions(multi_node_only=multi_node_only,
|
||||
load_format=load_format),
|
||||
test_options=SPTestOptions(
|
||||
multi_node_only=multi_node_only, load_format=load_format
|
||||
),
|
||||
)
|
||||
|
||||
def iter_params(self, model_id: str):
|
||||
opts = self.test_options
|
||||
|
||||
for parallel_setup in self.parallel_setups:
|
||||
for backend, vllm_major_version in zip(self.distributed_backends,
|
||||
self.vllm_major_versions):
|
||||
yield (model_id, parallel_setup, backend, vllm_major_version,
|
||||
self.runner, opts)
|
||||
for backend, vllm_major_version in zip(
|
||||
self.distributed_backends, self.vllm_major_versions
|
||||
):
|
||||
yield (
|
||||
model_id,
|
||||
parallel_setup,
|
||||
backend,
|
||||
vllm_major_version,
|
||||
self.runner,
|
||||
opts,
|
||||
)
|
||||
|
||||
|
||||
def _compare_sp(
|
||||
@ -200,8 +221,10 @@ def _compare_sp(
|
||||
if num_gpus_available < tp_size * pp_size:
|
||||
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
|
||||
if VLLM_MULTI_NODE and distributed_backend == "mp":
|
||||
pytest.skip("Skipping multi-node pipeline parallel test for "
|
||||
"multiprocessing distributed backend")
|
||||
pytest.skip(
|
||||
"Skipping multi-node pipeline parallel test for "
|
||||
"multiprocessing distributed backend"
|
||||
)
|
||||
if multi_node_only and not VLLM_MULTI_NODE:
|
||||
pytest.skip("Not in multi-node setting")
|
||||
|
||||
@ -232,13 +255,13 @@ def _compare_sp(
|
||||
common_args.append("--skip-tokenizer-init")
|
||||
|
||||
compilation_config = {
|
||||
'level': 3,
|
||||
'custom_ops': ["+rms_norm"],
|
||||
'compile_sizes': [4, 8],
|
||||
'pass_config': {
|
||||
'enable_sequence_parallelism': True,
|
||||
'enable_fusion': enable_fusion,
|
||||
'enable_noop': True,
|
||||
"level": 3,
|
||||
"custom_ops": ["+rms_norm"],
|
||||
"compile_sizes": [4, 8],
|
||||
"pass_config": {
|
||||
"enable_sequence_parallelism": True,
|
||||
"enable_fusion": enable_fusion,
|
||||
"enable_noop": True,
|
||||
},
|
||||
}
|
||||
|
||||
@ -270,12 +293,9 @@ def _compare_sp(
|
||||
]
|
||||
|
||||
try:
|
||||
compare_two_settings(model_id,
|
||||
tp_sp_args,
|
||||
tp_args,
|
||||
tp_sp_env,
|
||||
tp_env,
|
||||
method=method)
|
||||
compare_two_settings(
|
||||
model_id, tp_sp_args, tp_args, tp_sp_env, tp_env, method=method
|
||||
)
|
||||
except Exception:
|
||||
testing_ray_compiled_graph = tp_sp_env is not None
|
||||
if testing_ray_compiled_graph and vllm_major_version == "0":
|
||||
@ -301,10 +321,17 @@ SP_TEST_MODELS = [
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
|
||||
"runner", "test_options"),
|
||||
(
|
||||
"model_id",
|
||||
"parallel_setup",
|
||||
"distributed_backend",
|
||||
"vllm_major_version",
|
||||
"runner",
|
||||
"test_options",
|
||||
),
|
||||
[
|
||||
params for model_id, settings in SP_TEXT_GENERATION_MODELS.items()
|
||||
params
|
||||
for model_id, settings in SP_TEXT_GENERATION_MODELS.items()
|
||||
for params in settings.iter_params(model_id)
|
||||
if model_id in SP_TEST_MODELS
|
||||
],
|
||||
@ -319,12 +346,14 @@ def test_tp_sp_generation(
|
||||
test_options: SPTestOptions,
|
||||
num_gpus_available,
|
||||
):
|
||||
_compare_sp(model_id,
|
||||
parallel_setup,
|
||||
distributed_backend,
|
||||
vllm_major_version,
|
||||
runner,
|
||||
test_options,
|
||||
num_gpus_available,
|
||||
method="generate",
|
||||
is_multimodal=False)
|
||||
_compare_sp(
|
||||
model_id,
|
||||
parallel_setup,
|
||||
distributed_backend,
|
||||
vllm_major_version,
|
||||
runner,
|
||||
test_options,
|
||||
num_gpus_available,
|
||||
method="generate",
|
||||
is_multimodal=False,
|
||||
)
|
||||
|
||||
@ -26,13 +26,13 @@ def distributed_run(fn, world_size):
|
||||
processes = []
|
||||
for i in range(number_of_processes):
|
||||
env = {}
|
||||
env['RANK'] = str(i)
|
||||
env['LOCAL_RANK'] = str(i)
|
||||
env['WORLD_SIZE'] = str(number_of_processes)
|
||||
env['LOCAL_WORLD_SIZE'] = str(number_of_processes)
|
||||
env['MASTER_ADDR'] = 'localhost'
|
||||
env['MASTER_PORT'] = '12345'
|
||||
p = multiprocessing.Process(target=fn, args=(env, ))
|
||||
env["RANK"] = str(i)
|
||||
env["LOCAL_RANK"] = str(i)
|
||||
env["WORLD_SIZE"] = str(number_of_processes)
|
||||
env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
|
||||
env["MASTER_ADDR"] = "localhost"
|
||||
env["MASTER_PORT"] = "12345"
|
||||
p = multiprocessing.Process(target=fn, args=(env,))
|
||||
processes.append(p)
|
||||
p.start()
|
||||
|
||||
@ -57,25 +57,23 @@ def worker_fn_wrapper(fn):
|
||||
|
||||
@worker_fn_wrapper
|
||||
def worker_fn():
|
||||
|
||||
rank = dist.get_rank()
|
||||
if rank == 0:
|
||||
port = get_open_port()
|
||||
ip = '127.0.0.1'
|
||||
ip = "127.0.0.1"
|
||||
dist.broadcast_object_list([ip, port], src=0)
|
||||
else:
|
||||
recv = [None, None]
|
||||
dist.broadcast_object_list(recv, src=0)
|
||||
ip, port = recv # type: ignore
|
||||
|
||||
stateless_pg = StatelessProcessGroup.create(ip, port, rank,
|
||||
dist.get_world_size())
|
||||
stateless_pg = StatelessProcessGroup.create(ip, port, rank, dist.get_world_size())
|
||||
|
||||
for pg in [dist.group.WORLD, stateless_pg]:
|
||||
|
||||
writer_rank = 2
|
||||
broadcaster = MessageQueue.create_from_process_group(
|
||||
pg, 40 * 1024, 2, writer_rank)
|
||||
pg, 40 * 1024, 2, writer_rank
|
||||
)
|
||||
if rank == writer_rank:
|
||||
seed = random.randint(0, 1000)
|
||||
dist.broadcast_object_list([seed], writer_rank)
|
||||
|
||||
@ -5,7 +5,8 @@ import traceback
|
||||
import unittest
|
||||
|
||||
from vllm.distributed.device_communicators.shm_object_storage import (
|
||||
SingleWriterShmRingBuffer)
|
||||
SingleWriterShmRingBuffer,
|
||||
)
|
||||
|
||||
|
||||
class TestSingleWriterShmRingBuffer(unittest.TestCase):
@ -25,18 +26,21 @@ class TestSingleWriterShmRingBuffer(unittest.TestCase):
"""Test opening an existing buffer"""
# First create a buffer
self.ring_buffer = SingleWriterShmRingBuffer(
data_buffer_size=self.buffer_size, create=True)
data_buffer_size=self.buffer_size, create=True
)

# Then open it with another instance
reader_buffer = SingleWriterShmRingBuffer(*self.ring_buffer.handle())
self.assertFalse(reader_buffer.is_writer)
self.assertEqual(reader_buffer.shared_memory.name,
self.ring_buffer.shared_memory.name)
self.assertEqual(
reader_buffer.shared_memory.name, self.ring_buffer.shared_memory.name
)

def test_buffer_access(self):
"""Test accessing allocated buffers"""
self.ring_buffer = SingleWriterShmRingBuffer(
data_buffer_size=self.buffer_size, create=True)
data_buffer_size=self.buffer_size, create=True
)

size = 100
address, monotonic_id = self.ring_buffer.allocate_buf(size)
@ -44,11 +48,11 @@ class TestSingleWriterShmRingBuffer(unittest.TestCase):
# Write some test data
test_data = b"Hello, World!" * 7  # 91 bytes
with self.ring_buffer.access_buf(address) as (data_buf, metadata):
data_buf[0:len(test_data)] = test_data
data_buf[0 : len(test_data)] = test_data

# Read it back
with self.ring_buffer.access_buf(address) as (data_buf2, metadata2):
read_data = bytes(data_buf2[0:len(test_data)])
read_data = bytes(data_buf2[0 : len(test_data)])
read_id = metadata2[0]

self.assertEqual(read_data, test_data)
@ -58,7 +62,8 @@ class TestSingleWriterShmRingBuffer(unittest.TestCase):
"""Test that MemoryError is raised when buffer is full"""
small_buffer_size = 200
self.ring_buffer = SingleWriterShmRingBuffer(
data_buffer_size=small_buffer_size, create=True)
data_buffer_size=small_buffer_size, create=True
)

# Fill up the buffer
self.ring_buffer.allocate_buf(100)
@ -72,7 +77,8 @@ class TestSingleWriterShmRingBuffer(unittest.TestCase):
"""Test allocation and freeing of buffers"""
small_buffer_size = 200
self.ring_buffer = SingleWriterShmRingBuffer(
data_buffer_size=small_buffer_size, create=True)
data_buffer_size=small_buffer_size, create=True
)

size = 80
# Write some data
@ -81,7 +87,7 @@ class TestSingleWriterShmRingBuffer(unittest.TestCase):
address, monotonic_id = self.ring_buffer.allocate_buf(size)
with self.ring_buffer.access_buf(address) as (data_buf, metadata):
data_buf[0:4] = (0).to_bytes(4, "little")  # 0 for not in-use
data_buf[4:len(test_data) + 4] = test_data
data_buf[4 : len(test_data) + 4] = test_data
print(self.ring_buffer.metadata)
freed_ids = self.ring_buffer.free_buf(lambda *args: True)
print(f" Freed IDs: {freed_ids}")
@ -90,7 +96,8 @@ class TestSingleWriterShmRingBuffer(unittest.TestCase):
def test_clear_buffer(self):
"""Test clearing the buffer"""
self.ring_buffer = SingleWriterShmRingBuffer(
data_buffer_size=self.buffer_size, create=True)
data_buffer_size=self.buffer_size, create=True
)

# Allocate some buffers
for _ in range(3):
@ -121,8 +128,7 @@ def main():
# Manual demonstration
try:
print("Creating ring buffer...")
writer_buffer = SingleWriterShmRingBuffer(data_buffer_size=2048,
create=True)
writer_buffer = SingleWriterShmRingBuffer(data_buffer_size=2048, create=True)
reader_buffer = SingleWriterShmRingBuffer(*writer_buffer.handle())

print(f"Buffer created with name: {writer_buffer.shared_memory.name}")
@ -140,7 +146,7 @@ def main():
# Write some test data
with writer_buffer.access_buf(address) as (data_buf, metadata):
test_message = f"Test message {i}".encode()
data_buf[0:len(test_message)] = test_message
data_buf[0 : len(test_message)] = test_message

except MemoryError as e:
print(f" Failed to allocate {size} bytes: {e}")

@ -12,28 +12,33 @@ import torch

# Assuming these are imported from your module
from vllm.distributed.device_communicators.shm_object_storage import (
MsgpackSerde, SingleWriterShmObjectStorage, SingleWriterShmRingBuffer)
from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem,
MultiModalSharedField)
MsgpackSerde,
SingleWriterShmObjectStorage,
SingleWriterShmRingBuffer,
)
from vllm.multimodal.inputs import (
MultiModalFieldElem,
MultiModalKwargsItem,
MultiModalSharedField,
)


def _dummy_elem(modality: str, key: str, size: int):
return MultiModalFieldElem(
modality=modality,
key=key,
data=torch.empty((size, ), dtype=torch.int8),
data=torch.empty((size,), dtype=torch.int8),
field=MultiModalSharedField(1),
)


def _dummy_item(modality: str, size_by_key: dict[str, int]):
return MultiModalKwargsItem.from_elems([
_dummy_elem(modality, key, size) for key, size in size_by_key.items()
])
return MultiModalKwargsItem.from_elems(
[_dummy_elem(modality, key, size) for key, size in size_by_key.items()]
)


class TestSingleWriterShmObjectStorage(unittest.TestCase):

def setUp(self):
"""Set up test fixtures before each test method."""
ring_buffer = SingleWriterShmRingBuffer(
@ -208,8 +213,7 @@ class TestSingleWriterShmObjectStorage(unittest.TestCase):
with self.assertRaises(ValueError) as context:
self.storage.get(address, monotonic_id + 100)

self.assertIn("has been modified or is invalid", \
str(context.exception))
self.assertIn("has been modified or is invalid", str(context.exception))

def test_clear_storage(self):
"""Test clearing the storage."""
@ -234,8 +238,7 @@ class TestSingleWriterShmObjectStorage(unittest.TestCase):
# Reader process function
def reader_process(process_id, storage_handle, items_to_read):
"""Reader process that connects to existing shared memory and reads data."""
reader_storage = SingleWriterShmObjectStorage.create_from_handle(
storage_handle)
reader_storage = SingleWriterShmObjectStorage.create_from_handle(storage_handle)

print(f"Reader {process_id} started")

@ -276,11 +279,7 @@ def run_multiprocess_example():

# Test basic data types
test_data = [
("user_data", {
"name": "Alice",
"age": 30,
"scores": [95, 87, 92]
}),
("user_data", {"name": "Alice", "age": 30, "scores": [95, 87, 92]}),
("simple_string", "Hello, World!"),
("number", 42),
("list_data", [1, 2, 3, "four", 5.0]),
@ -301,8 +300,9 @@ def run_multiprocess_example():
# initialize lock for reader processes
handle.reader_lock = Lock()
for i in range(storage.n_readers):
p = multiprocessing.Process(target=reader_process,
args=(i, handle, stored_items))
p = multiprocessing.Process(
target=reader_process, args=(i, handle, stored_items)
)
processes.append(p)
p.start()

@ -14,11 +14,12 @@ import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.device_communicators.cuda_communicator import (
CudaCommunicator)
from vllm.distributed.parallel_state import (get_tp_group,
init_distributed_environment,
initialize_model_parallel)
from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator
from vllm.distributed.parallel_state import (
get_tp_group,
init_distributed_environment,
initialize_model_parallel,
)
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.platforms import current_platform
@ -32,8 +33,7 @@ test_size_elements = 1024 * 1024

def symm_mem_allreduce_worker(local_rank: int, world_size: int, q: mp.Queue):
monkeypatch = pytest.MonkeyPatch()
config = VllmConfig(parallel_config=ParallelConfig(
tensor_parallel_size=world_size))
config = VllmConfig(parallel_config=ParallelConfig(tensor_parallel_size=world_size))

with monkeypatch.context() as m, set_current_vllm_config(config):
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
@ -42,34 +42,34 @@ def symm_mem_allreduce_worker(local_rank: int, world_size: int, q: mp.Queue):
torch.cuda.set_device(device)
torch.set_default_device(device)
torch.set_default_dtype(dtype)
update_environment_variables({
'RANK': str(local_rank),
'LOCAL_RANK': str(local_rank),
'WORLD_SIZE': str(world_size),
'MASTER_ADDR': 'localhost',
'MASTER_PORT': '12345',
})
update_environment_variables(
{
"RANK": str(local_rank),
"LOCAL_RANK": str(local_rank),
"WORLD_SIZE": str(world_size),
"MASTER_ADDR": "localhost",
"MASTER_PORT": "12345",
}
)

init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)

cuda_communicator = typing.cast(CudaCommunicator,
get_tp_group().device_communicator)
cuda_communicator = typing.cast(
CudaCommunicator, get_tp_group().device_communicator
)
symm_mem_comm = cuda_communicator.symm_mem_comm
if symm_mem_comm is None or symm_mem_comm.disabled:
# can't use skip under multiprocessing
q.put("SymmMemCommunicator is not available or disabled.")
return

inp_direct_symm_mem = torch.randint(1,
23, (test_size_elements, ),
dtype=dtype,
device=device)
inp_direct_symm_mem = torch.randint(
1, 23, (test_size_elements,), dtype=dtype, device=device
)
if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem):
# can't use skip under multiprocessing
q.put(
"SymmMemCommunicator isn't used for this world and input size."
)
q.put("SymmMemCommunicator isn't used for this world and input size.")
return

original_inp_direct_symm_mem = inp_direct_symm_mem.clone()
@ -78,42 +78,37 @@ def symm_mem_allreduce_worker(local_rank: int, world_size: int, q: mp.Queue):

group = get_tp_group().device_group
dist.all_reduce(original_inp_direct_symm_mem, group=group)
torch.testing.assert_close(out_direct_symm_mem,
original_inp_direct_symm_mem,
atol=2.5,
rtol=0.1)
torch.testing.assert_close(
out_direct_symm_mem, original_inp_direct_symm_mem, atol=2.5, rtol=0.1
)

# Test tensor_model_parallel_all_reduce which should use symm_mem
inp_tensor_parallel = torch.randint(-23,
1, (test_size_elements, ),
dtype=dtype,
device=device)
inp_tensor_parallel = torch.randint(
-23, 1, (test_size_elements,), dtype=dtype, device=device
)
original_inp_tensor_parallel = inp_tensor_parallel.clone()
out_tensor_parallel = tensor_model_parallel_all_reduce(
inp_tensor_parallel)
out_tensor_parallel = tensor_model_parallel_all_reduce(inp_tensor_parallel)
dist.all_reduce(original_inp_tensor_parallel, group=group)
torch.testing.assert_close(out_tensor_parallel,
original_inp_tensor_parallel,
atol=2.5,
rtol=0.1)
torch.testing.assert_close(
out_tensor_parallel, original_inp_tensor_parallel, atol=2.5, rtol=0.1
)


@pytest.mark.skipif(
not current_platform.is_cuda(),
reason="SymmMemAllreduce is only available for CUDA platforms.")
reason="SymmMemAllreduce is only available for CUDA platforms.",
)
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("pipeline_parallel_size", [1])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
reason="Only test on CUDA")
def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,
pipeline_parallel_size):
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
def test_symm_mem_allreduce(
monkeypatch: pytest.MonkeyPatch, tp_size, pipeline_parallel_size
):
world_size = tp_size * pipeline_parallel_size
if world_size > torch.cuda.device_count():
pytest.skip("Not enough GPUs to run the test.")
q = mp.get_context('spawn').Queue()
mp.spawn(symm_mem_allreduce_worker,
args=(world_size, q),
nprocs=world_size)
q = mp.get_context("spawn").Queue()
mp.spawn(symm_mem_allreduce_worker, args=(world_size, q), nprocs=world_size)
try:
val = q.get(timeout=1)
except queue.Empty:
@ -126,18 +121,20 @@ def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,

@pytest.mark.skipif(
not current_platform.is_cuda(),
reason="SymmMemAllreduce is only available for CUDA platforms.")
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
reason="Only test on CUDA")
reason="SymmMemAllreduce is only available for CUDA platforms.",
)
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
def test_dp_with_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch):
world_size = 4
if world_size > torch.cuda.device_count():
pytest.skip("Not enough GPUs to run the test.")
# Verify that the DataParallel runs without error
engine_args = EngineArgs(model="distilbert/distilgpt2",
enforce_eager=True,
enable_prefix_caching=True,
data_parallel_size=2,
tensor_parallel_size=2,
data_parallel_backend="mp")
engine_args = EngineArgs(
model="distilbert/distilgpt2",
enforce_eager=True,
enable_prefix_caching=True,
data_parallel_size=2,
tensor_parallel_size=2,
data_parallel_backend="mp",
)
LLMEngine.from_engine_args(engine_args)

@ -24,13 +24,15 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# set different `gpu_memory_utilization` and `swap_space` for different ranks,
# to test if all ranks agree on the same kv cache configuration.
llm = LLM(model="facebook/opt-125m",
tensor_parallel_size=2,
pipeline_parallel_size=int(os.getenv("PP_SIZE", 1)),
distributed_executor_backend="external_launcher",
gpu_memory_utilization=random.uniform(0.7, 0.9),
swap_space=random.randint(1, 4),
seed=0)
llm = LLM(
model="facebook/opt-125m",
tensor_parallel_size=2,
pipeline_parallel_size=int(os.getenv("PP_SIZE", 1)),
distributed_executor_backend="external_launcher",
gpu_memory_utilization=random.uniform(0.7, 0.9),
swap_space=random.randint(1, 4),
seed=0,
)

outputs = llm.generate(prompts, sampling_params)

@ -48,15 +50,14 @@ def test_consistent_across_ranks(obj):
assert container[0] == obj


test_consistent_across_ranks(
llm.llm_engine.vllm_config.cache_config.num_cpu_blocks)
test_consistent_across_ranks(
llm.llm_engine.vllm_config.cache_config.num_gpu_blocks)
test_consistent_across_ranks(llm.llm_engine.vllm_config.cache_config.num_cpu_blocks)
test_consistent_across_ranks(llm.llm_engine.vllm_config.cache_config.num_gpu_blocks)

# make sure we can access the model parameters from the calling process
# of the `LLM` instance.
params = list(llm.llm_engine.model_executor.driver_worker.worker.model_runner.
model.parameters())
params = list(
llm.llm_engine.model_executor.driver_worker.worker.model_runner.model.parameters()
)
test_consistent_across_ranks(len(params))

# all ranks should have the same outputs
@ -65,5 +66,4 @@ for output in outputs:
generated_text = output.outputs[0].text
test_consistent_across_ranks(prompt)
test_consistent_across_ranks(generated_text)
print(f"Rank {torch_rank}, Prompt: {prompt!r}, "
f"Generated text: {generated_text!r}")
print(f"Rank {torch_rank}, Prompt: {prompt!r}, Generated text: {generated_text!r}")

@ -24,23 +24,22 @@ dp_rank = int(os.getenv("DP_RANK", "0"))

if dp_size > 1:
# distribute the prompts across the data parallel ranks
prompts = [
prompt for idx, prompt in enumerate(prompts)
if idx % dp_size == dp_rank
]
prompts = [prompt for idx, prompt in enumerate(prompts) if idx % dp_size == dp_rank]

sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# set different `gpu_memory_utilization` and `swap_space` for different ranks,
# to test if all ranks agree on the same kv cache configuration.
llm = LLM(model="microsoft/Phi-mini-MoE-instruct",
tensor_parallel_size=int(os.getenv("TP_SIZE", "1")),
pipeline_parallel_size=int(os.getenv("PP_SIZE", "1")),
enable_expert_parallel=int(os.getenv("ENABLE_EP", "0")) == 1,
distributed_executor_backend="external_launcher",
gpu_memory_utilization=random.uniform(0.7, 0.9),
swap_space=random.randint(1, 4),
seed=0)
llm = LLM(
model="microsoft/Phi-mini-MoE-instruct",
tensor_parallel_size=int(os.getenv("TP_SIZE", "1")),
pipeline_parallel_size=int(os.getenv("PP_SIZE", "1")),
enable_expert_parallel=int(os.getenv("ENABLE_EP", "0")) == 1,
distributed_executor_backend="external_launcher",
gpu_memory_utilization=random.uniform(0.7, 0.9),
swap_space=random.randint(1, 4),
seed=0,
)

outputs = llm.generate(prompts, sampling_params)

@ -54,21 +53,18 @@ def test_consistent_across_ranks(obj):
dist.broadcast_object_list([obj], src=group.ranks[0], group=cpu_group)
else:
container = [None]
dist.broadcast_object_list(container,
src=group.ranks[0],
group=cpu_group)
dist.broadcast_object_list(container, src=group.ranks[0], group=cpu_group)
assert container[0] == obj


test_consistent_across_ranks(
llm.llm_engine.vllm_config.cache_config.num_cpu_blocks)
test_consistent_across_ranks(
llm.llm_engine.vllm_config.cache_config.num_gpu_blocks)
test_consistent_across_ranks(llm.llm_engine.vllm_config.cache_config.num_cpu_blocks)
test_consistent_across_ranks(llm.llm_engine.vllm_config.cache_config.num_gpu_blocks)

# make sure we can access the model parameters from the calling process
# of the `LLM` instance.
params = list(llm.llm_engine.model_executor.driver_worker.worker.model_runner.
model.parameters())
params = list(
llm.llm_engine.model_executor.driver_worker.worker.model_runner.model.parameters()
)
test_consistent_across_ranks(len(params))

# all ranks should have the same outputs
@ -77,5 +73,4 @@ for output in outputs:
generated_text = output.outputs[0].text
test_consistent_across_ranks(prompt)
test_consistent_across_ranks(generated_text)
print(f"Rank {group_rank}, Prompt: {prompt!r}, "
f"Generated text: {generated_text!r}")
print(f"Rank {group_rank}, Prompt: {prompt!r}, Generated text: {generated_text!r}")

@ -10,21 +10,22 @@ import torch
import vllm.envs as envs
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup
from vllm.utils import (cuda_device_count_stateless, get_open_port,
update_environment_variables)
from vllm.utils import (
cuda_device_count_stateless,
get_open_port,
update_environment_variables,
)

from ..utils import multi_gpu_test


@ray.remote
class _CUDADeviceCountStatelessTestActor:

def get_count(self):
return cuda_device_count_stateless()

def set_cuda_visible_devices(self, cuda_visible_devices: str):
update_environment_variables(
{"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})

def get_cuda_visible_devices(self):
return envs.CUDA_VISIBLE_DEVICES
@ -34,10 +35,9 @@ def test_cuda_device_count_stateless():
"""Test that cuda_device_count_stateless changes return value if
CUDA_VISIBLE_DEVICES is changed."""
actor = _CUDADeviceCountStatelessTestActor.options(  # type: ignore
num_gpus=2).remote()
assert len(
sorted(ray.get(
actor.get_cuda_visible_devices.remote()).split(","))) == 2
num_gpus=2
).remote()
assert len(sorted(ray.get(actor.get_cuda_visible_devices.remote()).split(","))) == 2
assert ray.get(actor.get_count.remote()) == 2
ray.get(actor.set_cuda_visible_devices.remote("0"))
assert ray.get(actor.get_count.remote()) == 1
@ -46,15 +46,13 @@ def test_cuda_device_count_stateless():


def cpu_worker(rank, WORLD_SIZE, port1, port2):
pg1 = StatelessProcessGroup.create(host="127.0.0.1",
port=port1,
rank=rank,
world_size=WORLD_SIZE)
pg1 = StatelessProcessGroup.create(
host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE
)
if rank <= 2:
pg2 = StatelessProcessGroup.create(host="127.0.0.1",
port=port2,
rank=rank,
world_size=3)
pg2 = StatelessProcessGroup.create(
host="127.0.0.1", port=port2, rank=rank, world_size=3
)
data = torch.tensor([rank])
data = pg1.broadcast_obj(data, src=2)
assert data.item() == 2
@ -68,16 +66,14 @@ def cpu_worker(rank, WORLD_SIZE, port1, port2):

def gpu_worker(rank, WORLD_SIZE, port1, port2):
torch.cuda.set_device(rank)
pg1 = StatelessProcessGroup.create(host="127.0.0.1",
port=port1,
rank=rank,
world_size=WORLD_SIZE)
pg1 = StatelessProcessGroup.create(
host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE
)
pynccl1 = PyNcclCommunicator(pg1, device=rank)
if rank <= 2:
pg2 = StatelessProcessGroup.create(host="127.0.0.1",
port=port2,
rank=rank,
world_size=3)
pg2 = StatelessProcessGroup.create(
host="127.0.0.1", port=port2, rank=rank, world_size=3
)
pynccl2 = PyNcclCommunicator(pg2, device=rank)
data = torch.tensor([rank]).cuda()
pynccl1.all_reduce(data)
@ -96,10 +92,9 @@ def gpu_worker(rank, WORLD_SIZE, port1, port2):


def broadcast_worker(rank, WORLD_SIZE, port1, port2):
pg1 = StatelessProcessGroup.create(host="127.0.0.1",
port=port1,
rank=rank,
world_size=WORLD_SIZE)
pg1 = StatelessProcessGroup.create(
host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE
)
if rank == 2:
pg1.broadcast_obj("secret", src=2)
else:
@ -109,10 +104,9 @@ def broadcast_worker(rank, WORLD_SIZE, port1, port2):


def allgather_worker(rank, WORLD_SIZE, port1, port2):
pg1 = StatelessProcessGroup.create(host="127.0.0.1",
port=port1,
rank=rank,
world_size=WORLD_SIZE)
pg1 = StatelessProcessGroup.create(
host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE
)
data = pg1.all_gather_obj(rank)
assert data == list(range(WORLD_SIZE))
pg1.barrier()
@ -121,7 +115,8 @@ def allgather_worker(rank, WORLD_SIZE, port1, port2):
@pytest.mark.skip(reason="This test is flaky and prone to hang.")
@multi_gpu_test(num_gpus=4)
@pytest.mark.parametrize(
"worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker])
"worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker]
)
def test_stateless_process_group(worker):
port1 = get_open_port()
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
@ -129,12 +124,14 @@ def test_stateless_process_group(worker):
port2 = get_open_port()
WORLD_SIZE = 4
from multiprocessing import get_context

ctx = get_context("fork")
processes = []
for i in range(WORLD_SIZE):
rank = i
processes.append(
ctx.Process(target=worker, args=(rank, WORLD_SIZE, port1, port2)))
ctx.Process(target=worker, args=(rank, WORLD_SIZE, port1, port2))
)
for p in processes:
p.start()
for p in processes:

@ -10,22 +10,30 @@ from typing import Annotated, Literal, Optional, Union
|
||||
import pytest
|
||||
|
||||
from vllm.config import CompilationConfig, config
|
||||
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
|
||||
get_type, get_type_hints, is_not_builtin,
|
||||
is_type, literal_to_kwargs, optional_type,
|
||||
parse_type)
|
||||
from vllm.engine.arg_utils import (
|
||||
EngineArgs,
|
||||
contains_type,
|
||||
get_kwargs,
|
||||
get_type,
|
||||
get_type_hints,
|
||||
is_not_builtin,
|
||||
is_type,
|
||||
literal_to_kwargs,
|
||||
optional_type,
|
||||
parse_type,
|
||||
)
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("type", "value", "expected"), [
|
||||
(int, "42", 42),
|
||||
(float, "3.14", 3.14),
|
||||
(str, "Hello World!", "Hello World!"),
|
||||
(json.loads, '{"foo":1,"bar":2}', {
|
||||
"foo": 1,
|
||||
"bar": 2
|
||||
}),
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
("type", "value", "expected"),
|
||||
[
|
||||
(int, "42", 42),
|
||||
(float, "3.14", 3.14),
|
||||
(str, "Hello World!", "Hello World!"),
|
||||
(json.loads, '{"foo":1,"bar":2}', {"foo": 1, "bar": 2}),
|
||||
],
|
||||
)
|
||||
def test_parse_type(type, value, expected):
|
||||
parse_type_func = parse_type(type)
|
||||
assert parse_type_func(value) == expected
|
||||
@ -37,50 +45,56 @@ def test_optional_type():
|
||||
assert optional_type_func("42") == 42
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("type_hint", "type", "expected"), [
|
||||
(int, int, True),
|
||||
(int, float, False),
|
||||
(list[int], list, True),
|
||||
(list[int], tuple, False),
|
||||
(Literal[0, 1], Literal, True),
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
("type_hint", "type", "expected"),
|
||||
[
|
||||
(int, int, True),
|
||||
(int, float, False),
|
||||
(list[int], list, True),
|
||||
(list[int], tuple, False),
|
||||
(Literal[0, 1], Literal, True),
|
||||
],
|
||||
)
|
||||
def test_is_type(type_hint, type, expected):
|
||||
assert is_type(type_hint, type) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("type_hints", "type", "expected"), [
|
||||
({float, int}, int, True),
|
||||
({int, tuple}, int, True),
|
||||
({int, tuple[int]}, int, True),
|
||||
({int, tuple[int, ...]}, int, True),
|
||||
({int, tuple[int]}, float, False),
|
||||
({int, tuple[int, ...]}, float, False),
|
||||
({str, Literal["x", "y"]}, Literal, True),
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
("type_hints", "type", "expected"),
|
||||
[
|
||||
({float, int}, int, True),
|
||||
({int, tuple}, int, True),
|
||||
({int, tuple[int]}, int, True),
|
||||
({int, tuple[int, ...]}, int, True),
|
||||
({int, tuple[int]}, float, False),
|
||||
({int, tuple[int, ...]}, float, False),
|
||||
({str, Literal["x", "y"]}, Literal, True),
|
||||
],
|
||||
)
|
||||
def test_contains_type(type_hints, type, expected):
|
||||
assert contains_type(type_hints, type) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("type_hints", "type", "expected"), [
|
||||
({int, float}, int, int),
|
||||
({int, float}, str, None),
|
||||
({str, Literal["x", "y"]}, Literal, Literal["x", "y"]),
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
("type_hints", "type", "expected"),
|
||||
[
|
||||
({int, float}, int, int),
|
||||
({int, float}, str, None),
|
||||
({str, Literal["x", "y"]}, Literal, Literal["x", "y"]),
|
||||
],
|
||||
)
|
||||
def test_get_type(type_hints, type, expected):
|
||||
assert get_type(type_hints, type) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("type_hints", "expected"), [
|
||||
({Literal[1, 2]}, {
|
||||
"type": int,
|
||||
"choices": [1, 2]
|
||||
}),
|
||||
({str, Literal["x", "y"]}, {
|
||||
"type": str,
|
||||
"metavar": ["x", "y"]
|
||||
}),
|
||||
({Literal[1, "a"]}, Exception),
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
("type_hints", "expected"),
|
||||
[
|
||||
({Literal[1, 2]}, {"type": int, "choices": [1, 2]}),
|
||||
({str, Literal["x", "y"]}, {"type": str, "metavar": ["x", "y"]}),
|
||||
({Literal[1, "a"]}, Exception),
|
||||
],
|
||||
)
|
||||
def test_literal_to_kwargs(type_hints, expected):
|
||||
context = nullcontext()
|
||||
if expected is Exception:
|
||||
@ -123,22 +137,27 @@ class DummyConfig:
|
||||
"""Nested config"""
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("type_hint", "expected"), [
|
||||
(int, False),
|
||||
(DummyConfig, True),
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
("type_hint", "expected"),
|
||||
[
|
||||
(int, False),
|
||||
(DummyConfig, True),
|
||||
],
|
||||
)
|
||||
def test_is_not_builtin(type_hint, expected):
|
||||
assert is_not_builtin(type_hint) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("type_hint", "expected"), [
|
||||
("type_hint", "expected"),
|
||||
[
|
||||
(Annotated[int, "annotation"], {int}),
|
||||
(Optional[int], {int, type(None)}),
|
||||
(Annotated[Optional[int], "annotation"], {int, type(None)}),
|
||||
(Optional[Annotated[int, "annotation"]], {int, type(None)}),
|
||||
],
|
||||
ids=["Annotated", "Optional", "Annotated_Optional", "Optional_Annotated"])
|
||||
ids=["Annotated", "Optional", "Annotated_Optional", "Optional_Annotated"],
|
||||
)
|
||||
def test_get_type_hints(type_hint, expected):
|
||||
assert get_type_hints(type_hint) == expected
|
||||
|
||||
@ -178,24 +197,16 @@ def test_get_kwargs():
|
||||
("arg", "expected"),
|
||||
[
|
||||
(None, dict()),
|
||||
('{"video": {"num_frames": 123} }', {
|
||||
"video": {
|
||||
"num_frames": 123
|
||||
}
|
||||
}),
|
||||
('{"video": {"num_frames": 123} }', {"video": {"num_frames": 123}}),
|
||||
(
|
||||
'{"video": {"num_frames": 123, "fps": 1.0, "foo": "bar"}, "image": {"foo": "bar"} }', # noqa
|
||||
{
|
||||
"video": {
|
||||
"num_frames": 123,
|
||||
"fps": 1.0,
|
||||
"foo": "bar"
|
||||
},
|
||||
"image": {
|
||||
"foo": "bar"
|
||||
}
|
||||
}),
|
||||
])
|
||||
"video": {"num_frames": 123, "fps": 1.0, "foo": "bar"},
|
||||
"image": {"foo": "bar"},
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_media_io_kwargs_parser(arg, expected):
|
||||
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
|
||||
if arg is None:
|
||||
@ -230,24 +241,32 @@ def test_compilation_config():
|
||||
assert args.compilation_config.level == 3
|
||||
|
||||
# set to string form of a dict
|
||||
args = parser.parse_args([
|
||||
"-O",
|
||||
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
|
||||
'"use_inductor": false}',
|
||||
])
|
||||
assert (args.compilation_config.level == 3 and
|
||||
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
|
||||
and not args.compilation_config.use_inductor)
|
||||
args = parser.parse_args(
|
||||
[
|
||||
"-O",
|
||||
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
|
||||
'"use_inductor": false}',
|
||||
]
|
||||
)
|
||||
assert (
|
||||
args.compilation_config.level == 3
|
||||
and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
|
||||
and not args.compilation_config.use_inductor
|
||||
)
|
||||
|
||||
# set to string form of a dict
|
||||
args = parser.parse_args([
|
||||
"--compilation-config="
|
||||
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
|
||||
'"use_inductor": true}',
|
||||
])
|
||||
assert (args.compilation_config.level == 3 and
|
||||
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
|
||||
and args.compilation_config.use_inductor)
|
||||
args = parser.parse_args(
|
||||
[
|
||||
"--compilation-config="
|
||||
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
|
||||
'"use_inductor": true}',
|
||||
]
|
||||
)
|
||||
assert (
|
||||
args.compilation_config.level == 3
|
||||
and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
|
||||
and args.compilation_config.use_inductor
|
||||
)
|
||||
|
||||
|
||||
def test_prefix_cache_default():
|
||||
@ -255,8 +274,7 @@ def test_prefix_cache_default():
|
||||
args = parser.parse_args([])
|
||||
|
||||
engine_args = EngineArgs.from_cli_args(args=args)
|
||||
assert (not engine_args.enable_prefix_caching
|
||||
), "prefix caching defaults to off."
|
||||
assert not engine_args.enable_prefix_caching, "prefix caching defaults to off."
|
||||
|
||||
# with flag to turn it on.
|
||||
args = parser.parse_args(["--enable-prefix-caching"])
|
||||
|
||||
@ -5,12 +5,12 @@ import pytest
|
||||
|
||||
from ..conftest import IMAGE_ASSETS
|
||||
|
||||
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
|
||||
"stop_sign":
|
||||
"USER: <image>\nWhat's the content of the image?\nASSISTANT:",
|
||||
"cherry_blossom":
|
||||
"USER: <image>\nWhat is the season?\nASSISTANT:",
|
||||
})
|
||||
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts(
|
||||
{
|
||||
"stop_sign": "USER: <image>\nWhat's the content of the image?\nASSISTANT:",
|
||||
"cherry_blossom": "USER: <image>\nWhat is the season?\nASSISTANT:",
|
||||
}
|
||||
)
|
||||
|
||||
models = ["llava-hf/llava-1.5-7b-hf"]
|
||||
|
||||
@ -19,8 +19,7 @@ models = ["llava-hf/llava-1.5-7b-hf"]
|
||||
def test_context_length_too_short(vllm_runner, image_assets, model):
|
||||
images = [asset.pil_image for asset in image_assets]
|
||||
|
||||
with pytest.raises(ValueError,
|
||||
match="longer than the maximum model length"):
|
||||
with pytest.raises(ValueError, match="longer than the maximum model length"):
|
||||
vllm_model = vllm_runner(
|
||||
model,
|
||||
max_model_len=128, # LLaVA has a feature size of 576
|
||||
@ -29,6 +28,6 @@ def test_context_length_too_short(vllm_runner, image_assets, model):
|
||||
)
|
||||
|
||||
with vllm_model:
|
||||
vllm_model.generate_greedy([HF_IMAGE_PROMPTS[0]],
|
||||
max_tokens=1,
|
||||
images=[images[0]])
|
||||
vllm_model.generate_greedy(
|
||||
[HF_IMAGE_PROMPTS[0]], max_tokens=1, images=[images[0]]
|
||||
)
|
||||
|
||||
@ -26,8 +26,10 @@ def sample_token_ids():
|
||||
|
||||
@pytest.fixture
|
||||
def sample_regex():
|
||||
return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
|
||||
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")
|
||||
return (
|
||||
r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
|
||||
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -35,40 +37,27 @@ def sample_json_schema():
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"age": {
|
||||
"type": "integer"
|
||||
},
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"},
|
||||
"skills": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"maxLength": 10
|
||||
},
|
||||
"minItems": 3
|
||||
"items": {"type": "string", "maxLength": 10},
|
||||
"minItems": 3,
|
||||
},
|
||||
"work_history": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"company": {
|
||||
"type": "string"
|
||||
},
|
||||
"duration": {
|
||||
"type": "number"
|
||||
},
|
||||
"position": {
|
||||
"type": "string"
|
||||
}
|
||||
"company": {"type": "string"},
|
||||
"duration": {"type": "number"},
|
||||
"position": {"type": "string"},
|
||||
},
|
||||
"required": ["company", "position"]
|
||||
}
|
||||
}
|
||||
"required": ["company", "position"],
|
||||
},
|
||||
},
|
||||
},
|
||||
"required": ["name", "age", "skills", "work_history"]
|
||||
"required": ["name", "age", "skills", "work_history"],
|
||||
}
|
||||
|
||||
|
||||
@ -80,65 +69,53 @@ def sample_complex_json_schema():
|
||||
"score": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"maximum": 100 # Numeric range
|
||||
"maximum": 100, # Numeric range
|
||||
},
|
||||
"grade": {
|
||||
"type": "string",
|
||||
"pattern": "^[A-D]$" # Regex pattern
|
||||
"pattern": "^[A-D]$", # Regex pattern
|
||||
},
|
||||
"email": {
|
||||
"type": "string",
|
||||
"pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$"
|
||||
"pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$",
|
||||
},
|
||||
"tags": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"pattern":
|
||||
"^[a-z]{1,10}$" # Combining length and pattern restrictions
|
||||
}
|
||||
}
|
||||
"pattern": "^[a-z]{1,10}$", # Combining length and pattern restrictions
|
||||
},
|
||||
},
|
||||
},
|
||||
"required": ["score", "grade", "email", "tags"]
|
||||
"required": ["score", "grade", "email", "tags"],
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_definition_json_schema():
|
||||
return {
|
||||
'$defs': {
|
||||
'Step': {
|
||||
'properties': {
|
||||
'explanation': {
|
||||
'title': 'Explanation',
|
||||
'type': 'string'
|
||||
},
|
||||
'output': {
|
||||
'title': 'Output',
|
||||
'type': 'string'
|
||||
}
|
||||
"$defs": {
|
||||
"Step": {
|
||||
"properties": {
|
||||
"explanation": {"title": "Explanation", "type": "string"},
|
||||
"output": {"title": "Output", "type": "string"},
|
||||
},
|
||||
'required': ['explanation', 'output'],
|
||||
'title': 'Step',
|
||||
'type': 'object'
|
||||
"required": ["explanation", "output"],
|
||||
"title": "Step",
|
||||
"type": "object",
|
||||
}
|
||||
},
|
||||
'properties': {
|
||||
'steps': {
|
||||
'items': {
|
||||
'$ref': '#/$defs/Step'
|
||||
},
|
||||
'title': 'Steps',
|
||||
'type': 'array'
|
||||
"properties": {
|
||||
"steps": {
|
||||
"items": {"$ref": "#/$defs/Step"},
|
||||
"title": "Steps",
|
||||
"type": "array",
|
||||
},
|
||||
'final_answer': {
|
||||
'title': 'Final Answer',
|
||||
'type': 'string'
|
||||
}
|
||||
"final_answer": {"title": "Final Answer", "type": "string"},
|
||||
},
|
||||
'required': ['steps', 'final_answer'],
|
||||
'title': 'MathReasoning',
|
||||
'type': 'object'
|
||||
"required": ["steps", "final_answer"],
|
||||
"title": "MathReasoning",
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
|
||||
@ -149,64 +126,71 @@ def sample_enum_json_schema():
|
||||
"properties": {
|
||||
"status": {
|
||||
"type": "string",
|
||||
"enum": ["active", "inactive",
|
||||
"pending"] # Literal values using enum
|
||||
"enum": ["active", "inactive", "pending"], # Literal values using enum
|
||||
},
|
||||
"priority": {
|
||||
"type": "string",
|
||||
"enum": ["low", "medium", "high", "critical"]
|
||||
"enum": ["low", "medium", "high", "critical"],
|
||||
},
|
||||
"category": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"enum": ["bug", "feature", "improvement"]
|
||||
"enum": ["bug", "feature", "improvement"],
|
||||
},
|
||||
"severity": {
|
||||
"type": "integer",
|
||||
"enum": [1, 2, 3, 4,
|
||||
5] # Enum can also contain numbers
|
||||
}
|
||||
"enum": [1, 2, 3, 4, 5], # Enum can also contain numbers
|
||||
},
|
||||
},
|
||||
"required": ["type", "severity"]
|
||||
"required": ["type", "severity"],
|
||||
},
|
||||
"flags": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"enum": ["urgent", "blocked", "needs_review", "approved"]
|
||||
}
|
||||
}
|
||||
"enum": ["urgent", "blocked", "needs_review", "approved"],
|
||||
},
|
||||
},
|
||||
},
|
||||
"required": ["status", "priority", "category", "flags"]
|
||||
"required": ["status", "priority", "category", "flags"],
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_structured_outputs_choices():
|
||||
return [
|
||||
"Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript",
|
||||
"Ruby", "Swift", "Kotlin"
|
||||
"Python",
|
||||
"Java",
|
||||
"JavaScript",
|
||||
"C++",
|
||||
"C#",
|
||||
"PHP",
|
||||
"TypeScript",
|
||||
"Ruby",
|
||||
"Swift",
|
||||
"Kotlin",
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_sql_statements():
|
||||
return ("""
|
||||
return """
|
||||
start: select_statement
|
||||
select_statement: "SELECT" column "from" table "where" condition
|
||||
column: "col_1" | "col_2"
|
||||
table: "table_1" | "table_2"
|
||||
condition: column "=" number
|
||||
number: "1" | "2"
|
||||
""")
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def zephyr_lora_files():
|
||||
"""Download zephyr LoRA files once per test session."""
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
return snapshot_download(repo_id="typeof/zephyr-7b-beta-lora")
|
||||
|
||||
|
||||
@ -214,5 +198,5 @@ def zephyr_lora_files():
|
||||
def opt125_lora_files() -> str:
|
||||
"""Download opt-125m LoRA files once per test session."""
|
||||
from huggingface_hub import snapshot_download
|
||||
return snapshot_download(
|
||||
repo_id="peft-internal-testing/opt-125m-dummy-lora")
|
||||
|
||||
return snapshot_download(repo_id="peft-internal-testing/opt-125m-dummy-lora")
|
||||
|
||||
@ -48,20 +48,23 @@ def run_test(model_name, more_args=None):
|
||||
|
||||
measured_value = results["results"][TASK][FILTER]
|
||||
assert model_name in EXPECTED_VALUES, (
|
||||
f"Cannot find the expected value for the model {model_name=}")
|
||||
f"Cannot find the expected value for the model {model_name=}"
|
||||
)
|
||||
expected_value = EXPECTED_VALUES[model_name]
|
||||
assert (measured_value - RTOL < expected_value
|
||||
and measured_value + RTOL > expected_value
|
||||
), f"Expected: {expected_value} | Measured: {measured_value}"
|
||||
assert (
|
||||
measured_value - RTOL < expected_value
|
||||
and measured_value + RTOL > expected_value
|
||||
), f"Expected: {expected_value} | Measured: {measured_value}"
|
||||
|
||||
|
||||
# TODO: [AlexM] Fix it with new CI/CD tests
|
||||
TPU_TP_TEST_STR = "" #"tensor_parallel_size=4"
|
||||
TPU_TP_TEST_STR = "" # "tensor_parallel_size=4"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda()
|
||||
and not current_platform.is_tpu(),
|
||||
reason="V1 is currently only supported on CUDA and TPU")
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda() and not current_platform.is_tpu(),
|
||||
reason="V1 is currently only supported on CUDA and TPU",
|
||||
)
|
||||
@pytest.mark.parametrize("model", MODEL_NAMES)
|
||||
def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Run with the V1 Engine."""
|
||||
@ -82,12 +85,14 @@ def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch):
|
||||
run_test(model, more_args)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda()
|
||||
and not current_platform.is_tpu(),
|
||||
reason="V1 is currently only supported on CUDA and TPU")
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda() and not current_platform.is_tpu(),
|
||||
reason="V1 is currently only supported on CUDA and TPU",
|
||||
)
|
||||
@pytest.mark.parametrize("model", FP8_KV_MODEL_NAMES)
|
||||
def test_lm_eval_accuracy_v1_engine_fp8_kv_cache(
|
||||
model, monkeypatch: pytest.MonkeyPatch):
|
||||
model, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
"""Run with the V1 Engine."""
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
|
||||
@ -14,9 +14,7 @@ from ..openai.test_vision import TEST_IMAGE_ASSETS
|
||||
def text_llm():
|
||||
# pytest caches the fixture so we use weakref.proxy to
|
||||
# enable garbage collection
|
||||
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
|
||||
enforce_eager=True,
|
||||
seed=0)
|
||||
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", enforce_eager=True, seed=0)
|
||||
|
||||
yield weakref.proxy(llm)
|
||||
|
||||
@ -28,14 +26,8 @@ def text_llm():
|
||||
def test_chat(text_llm):
|
||||
prompt1 = "Explain the concept of entropy."
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt1
|
||||
},
|
||||
{"role": "system", "content": "You are a helpful assistant"},
|
||||
{"role": "user", "content": prompt1},
|
||||
]
|
||||
outputs = text_llm.chat(messages)
|
||||
assert len(outputs) == 1
|
||||
@ -46,25 +38,13 @@ def test_multi_chat(text_llm):
|
||||
prompt2 = "Explain what among us is."
|
||||
|
||||
conversation1 = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt1
|
||||
},
|
||||
{"role": "system", "content": "You are a helpful assistant"},
|
||||
{"role": "user", "content": prompt1},
|
||||
]
|
||||
|
||||
conversation2 = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt2
|
||||
},
|
||||
{"role": "system", "content": "You are a helpful assistant"},
|
||||
{"role": "user", "content": prompt2},
|
||||
]
|
||||
|
||||
messages = [conversation1, conversation2]
|
||||
@ -94,26 +74,22 @@ def vision_llm():
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("image_urls",
|
||||
[[TEST_IMAGE_ASSETS[0], TEST_IMAGE_ASSETS[1]]],
|
||||
indirect=True)
|
||||
@pytest.mark.parametrize(
|
||||
"image_urls", [[TEST_IMAGE_ASSETS[0], TEST_IMAGE_ASSETS[1]]], indirect=True
|
||||
)
|
||||
def test_chat_multi_image(vision_llm, image_urls: list[str]):
|
||||
messages = [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
*({
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
} for image_url in image_urls),
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What's in this image?"
|
||||
},
|
||||
],
|
||||
}]
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
*(
|
||||
{"type": "image_url", "image_url": {"url": image_url}}
|
||||
for image_url in image_urls
|
||||
),
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
outputs = vision_llm.chat(messages)
|
||||
assert len(outputs) >= 0
|
||||
|
||||
@ -124,14 +100,8 @@ def test_llm_chat_tokenization_no_double_bos(text_llm):
|
||||
Check we get a single BOS token for llama chat.
|
||||
"""
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello!"
|
||||
},
|
||||
{"role": "system", "content": "You are a helpful assistant"},
|
||||
{"role": "user", "content": "Hello!"},
|
||||
]
|
||||
outputs = text_llm.chat(messages)
|
||||
assert len(outputs) == 1
|
||||
@ -167,14 +137,8 @@ def thinking_llm():
|
||||
@pytest.mark.parametrize("enable_thinking", [True, False])
|
||||
def test_chat_extra_kwargs(thinking_llm, enable_thinking):
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is 1+1?"
|
||||
},
|
||||
{"role": "system", "content": "You are a helpful assistant"},
|
||||
{"role": "user", "content": "What is 1+1?"},
|
||||
]
|
||||
|
||||
outputs = thinking_llm.chat(
|
||||
|
||||
@ -23,9 +23,11 @@ def test_collective_rpc(tp_size, backend, monkeypatch):
|
||||
return self.rank
|
||||
|
||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
|
||||
enforce_eager=True,
|
||||
load_format="dummy",
|
||||
tensor_parallel_size=tp_size,
|
||||
distributed_executor_backend=backend)
|
||||
llm = LLM(
|
||||
model="meta-llama/Llama-3.2-1B-Instruct",
|
||||
enforce_eager=True,
|
||||
load_format="dummy",
|
||||
tensor_parallel_size=tp_size,
|
||||
distributed_executor_backend=backend,
|
||||
)
|
||||
assert llm.collective_rpc(echo_rank) == list(range(tp_size))
|
||||
|
||||
@ -29,11 +29,13 @@ TOKEN_IDS = [
|
||||
def llm():
|
||||
# pytest caches the fixture so we use weakref.proxy to
|
||||
# enable garbage collection
|
||||
llm = LLM(model=MODEL_NAME,
|
||||
max_num_batched_tokens=4096,
|
||||
tensor_parallel_size=1,
|
||||
gpu_memory_utilization=0.10,
|
||||
enforce_eager=True)
|
||||
llm = LLM(
|
||||
model=MODEL_NAME,
|
||||
max_num_batched_tokens=4096,
|
||||
tensor_parallel_size=1,
|
||||
gpu_memory_utilization=0.10,
|
||||
enforce_eager=True,
|
||||
)
|
||||
|
||||
yield weakref.proxy(llm)
|
||||
|
||||
@ -81,7 +83,8 @@ def test_max_model_len():
|
||||
outputs = llm.generate(PROMPTS, sampling_params)
|
||||
for output in outputs:
|
||||
num_total_tokens = len(output.prompt_token_ids) + len(
|
||||
output.outputs[0].token_ids)
|
||||
output.outputs[0].token_ids
|
||||
)
|
||||
# Total tokens must not exceed max_model_len + 1 (the last token can be
|
||||
# generated with the context length equal to the max model length)
|
||||
# It can be less if generation finishes due to other reasons (e.g., EOS)
|
||||
|
||||
@ -16,9 +16,8 @@ def test_gpu_memory_utilization():
|
||||
# makes sure gpu_memory_utilization is per-instance limit,
|
||||
# not a global limit
|
||||
llms = [
|
||||
LLM(model="facebook/opt-125m",
|
||||
gpu_memory_utilization=0.3,
|
||||
enforce_eager=True) for i in range(3)
|
||||
LLM(model="facebook/opt-125m", gpu_memory_utilization=0.3, enforce_eager=True)
|
||||
for i in range(3)
|
||||
]
|
||||
for llm in llms:
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
@ -8,12 +8,12 @@ from vllm import LLM
|
||||
|
||||
def test_empty_prompt():
|
||||
llm = LLM(model="openai-community/gpt2", enforce_eager=True)
|
||||
with pytest.raises(ValueError, match='decoder prompt cannot be empty'):
|
||||
with pytest.raises(ValueError, match="decoder prompt cannot be empty"):
|
||||
llm.generate([""])
|
||||
|
||||
|
||||
@pytest.mark.skip_v1
|
||||
def test_out_of_vocab_token():
|
||||
llm = LLM(model="openai-community/gpt2", enforce_eager=True)
|
||||
with pytest.raises(ValueError, match='out of vocabulary'):
|
||||
with pytest.raises(ValueError, match="out of vocabulary"):
|
||||
llm.generate({"prompt_token_ids": [999999]})
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for HF_HUB_OFFLINE mode"""
|
||||
|
||||
import dataclasses
|
||||
import importlib
|
||||
import sys
|
||||
@ -91,12 +92,11 @@ def test_offline_mode(monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
|
||||
def _re_import_modules():
|
||||
hf_hub_module_names = [
|
||||
k for k in sys.modules if k.startswith("huggingface_hub")
|
||||
]
|
||||
hf_hub_module_names = [k for k in sys.modules if k.startswith("huggingface_hub")]
|
||||
transformers_module_names = [
|
||||
k for k in sys.modules if k.startswith("transformers")
|
||||
and not k.startswith("transformers_modules")
|
||||
k
|
||||
for k in sys.modules
|
||||
if k.startswith("transformers") and not k.startswith("transformers_modules")
|
||||
]
|
||||
|
||||
reload_exception = None
|
||||
|
||||
@ -7,14 +7,14 @@ from vllm.assets.audio import AudioAsset
|
||||
|
||||
@pytest.fixture
|
||||
def mary_had_lamb():
|
||||
path = AudioAsset('mary_had_lamb').get_local_path()
|
||||
path = AudioAsset("mary_had_lamb").get_local_path()
|
||||
with open(str(path), "rb") as f:
|
||||
yield f
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def winning_call():
|
||||
path = AudioAsset('winning_call').get_local_path()
|
||||
path = AudioAsset("winning_call").get_local_path()
|
||||
with open(str(path), "rb") as f:
|
||||
yield f
|
||||
|
||||
@ -22,6 +22,6 @@ def winning_call():
|
||||
@pytest.fixture
|
||||
def foscolo():
|
||||
# Test translation it->en
|
||||
path = AudioAsset('azacinto_foscolo').get_local_path()
|
||||
path = AudioAsset("azacinto_foscolo").get_local_path()
|
||||
with open(str(path), "rb") as f:
|
||||
yield f
|
||||
|
||||
@ -44,14 +44,15 @@ def run_test(more_args):
|
||||
print(f"Running with: {args}")
|
||||
|
||||
with RemoteOpenAIServer(
|
||||
MODEL_NAME, args,
|
||||
max_wait_seconds=MAX_WAIT_SECONDS) as remote_server:
|
||||
MODEL_NAME, args, max_wait_seconds=MAX_WAIT_SECONDS
|
||||
) as remote_server:
|
||||
url = f"{remote_server.url_for('v1')}/completions"
|
||||
|
||||
model_args = (
|
||||
f"model={MODEL_NAME},"
|
||||
f"base_url={url},"
|
||||
f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False")
|
||||
f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False"
|
||||
)
|
||||
|
||||
results = lm_eval.simple_evaluate(
|
||||
model="local-completions",
|
||||
@ -60,15 +61,18 @@ def run_test(more_args):
|
||||
)
|
||||
|
||||
measured_value = results["results"][TASK][FILTER]
|
||||
assert (measured_value - RTOL < EXPECTED_VALUE
|
||||
and measured_value + RTOL > EXPECTED_VALUE
|
||||
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
|
||||
assert (
|
||||
measured_value - RTOL < EXPECTED_VALUE
|
||||
and measured_value + RTOL > EXPECTED_VALUE
|
||||
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda()
|
||||
and not current_platform.is_tpu()
|
||||
and not current_platform.is_xpu(),
|
||||
reason="V1 currently only supported on CUDA, XPU and TPU")
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda()
|
||||
and not current_platform.is_tpu()
|
||||
and not current_platform.is_xpu(),
|
||||
reason="V1 currently only supported on CUDA, XPU and TPU",
|
||||
)
|
||||
def test_lm_eval_accuracy_v1_engine(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Run with the V1 Engine."""
|
||||
|
||||
|
||||
@ -7,6 +7,7 @@ a baseline.
|
||||
This simulates real work usage of the API and makes sure that the frontend and
|
||||
AsyncLLMEngine are working correctly.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import time
|
||||
@ -45,7 +46,8 @@ async def transcribe_audio(client, tokenizer, y, sr):
|
||||
# NOTE there's no streaming in transcriptions, can't measure ttft
|
||||
latency = end_time - start_time
|
||||
num_output_tokens = len(
|
||||
tokenizer(transcription.text, add_special_tokens=False).input_ids)
|
||||
tokenizer(transcription.text, add_special_tokens=False).input_ids
|
||||
)
|
||||
return latency, num_output_tokens, transcription.text
|
||||
|
||||
|
||||
@ -73,8 +75,8 @@ async def process_dataset(model, client, data, concurrent_request):
|
||||
for sample in data:
|
||||
audio, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"]
|
||||
task = asyncio.create_task(
|
||||
bound_transcribe(sem, client, tokenizer, (audio, sr),
|
||||
sample["text"]))
|
||||
bound_transcribe(sem, client, tokenizer, (audio, sr), sample["text"])
|
||||
)
|
||||
tasks.append(task)
|
||||
return await asyncio.gather(*tasks)
|
||||
|
||||
@ -98,34 +100,35 @@ def print_performance_metrics(results, total_time):
|
||||
|
||||
|
||||
def add_duration(sample):
|
||||
y, sr = sample['audio']["array"], sample['audio']["sampling_rate"]
|
||||
sample['duration_ms'] = librosa.get_duration(y=y, sr=sr) * 1000
|
||||
y, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"]
|
||||
sample["duration_ms"] = librosa.get_duration(y=y, sr=sr) * 1000
|
||||
return sample
|
||||
|
||||
|
||||
def load_hf_dataset(dataset_repo: str, split='validation', **hf_kwargs):
|
||||
def load_hf_dataset(dataset_repo: str, split="validation", **hf_kwargs):
|
||||
## Load and filter the dataset
|
||||
dataset = load_dataset(dataset_repo, split=split, **hf_kwargs)
|
||||
if 'duration_ms' not in dataset[0]:
|
||||
if "duration_ms" not in dataset[0]:
|
||||
# compute duration to filter
|
||||
dataset = dataset.map(add_duration)
|
||||
|
||||
# Whisper max supported duration
|
||||
dataset = dataset.filter(lambda example: example['duration_ms'] < 30000)
|
||||
dataset = dataset.filter(lambda example: example["duration_ms"] < 30000)
|
||||
return dataset
|
||||
|
||||
|
||||
def run_evaluation(model: str,
|
||||
client,
|
||||
dataset,
|
||||
max_concurrent_reqs: int,
|
||||
n_examples: int = -1,
|
||||
print_metrics: bool = True):
|
||||
def run_evaluation(
|
||||
model: str,
|
||||
client,
|
||||
dataset,
|
||||
max_concurrent_reqs: int,
|
||||
n_examples: int = -1,
|
||||
print_metrics: bool = True,
|
||||
):
|
||||
if n_examples > 0:
|
||||
dataset = dataset.select(range(n_examples))
|
||||
start = time.perf_counter()
|
||||
results = asyncio.run(
|
||||
process_dataset(model, client, dataset, max_concurrent_reqs))
|
||||
results = asyncio.run(process_dataset(model, client, dataset, max_concurrent_reqs))
|
||||
end = time.perf_counter()
|
||||
total_time = end - start
|
||||
print(f"Total Test Time: {total_time:.4f} seconds")
|
||||
@ -135,8 +138,7 @@ def run_evaluation(model: str,
predictions = [res[2] for res in results]
references = [res[3] for res in results]
wer = load("wer")
wer_score = 100 * wer.compute(references=references,
predictions=predictions)
wer_score = 100 * wer.compute(references=references, predictions=predictions)
print("WER:", wer_score)
return wer_score

@ -145,26 +147,25 @@ def run_evaluation(model: str,
@pytest.mark.parametrize("model_name", ["openai/whisper-large-v3"])
# Original dataset is 20GB+ in size, hence we use a pre-filtered slice.
@pytest.mark.parametrize(
"dataset_repo", ["D4nt3/esb-datasets-earnings22-validation-tiny-filtered"])
"dataset_repo", ["D4nt3/esb-datasets-earnings22-validation-tiny-filtered"]
)
# NOTE: Expected WER measured with equivalent hf.transformers args:
# whisper-large-v3 + esb-datasets-earnings22-validation-tiny-filtered.
@pytest.mark.parametrize("expected_wer", [12.744980])
def test_wer_correctness(model_name,
dataset_repo,
expected_wer,
n_examples=-1,
max_concurrent_request=None):
def test_wer_correctness(
model_name, dataset_repo, expected_wer, n_examples=-1, max_concurrent_request=None
):
# TODO refactor to use `ASRDataset`
with RemoteOpenAIServer(model_name, ['--enforce-eager']) as remote_server:
with RemoteOpenAIServer(model_name, ["--enforce-eager"]) as remote_server:
dataset = load_hf_dataset(dataset_repo)

if not max_concurrent_request:
# No max concurrency
max_concurrent_request = n_examples if n_examples > 0\
else len(dataset)
max_concurrent_request = n_examples if n_examples > 0 else len(dataset)

client = remote_server.get_async_client()
wer = run_evaluation(model_name, client, dataset,
max_concurrent_request, n_examples)
wer = run_evaluation(
model_name, client, dataset, max_concurrent_request, n_examples
)
if expected_wer:
torch.testing.assert_close(wer, expected_wer, atol=1e-1, rtol=1e-2)

@ -44,15 +44,11 @@ async def client(server):
ids=["completion", "chat"],
argnames=["create_func_gen", "content_body"],
argvalues=[
(lambda x: x.completions.create, {
"prompt": " ".join(['A'] * 10_000)
}),
(lambda x: x.chat.completions.create, {
"messages": [{
"role": "user",
"content": " ".join(['A'] * 10_000)
}]
}),
(lambda x: x.completions.create, {"prompt": " ".join(["A"] * 10_000)}),
(
lambda x: x.chat.completions.create,
{"messages": [{"role": "user", "content": " ".join(["A"] * 10_000)}]},
),
],
)
async def test_with_and_without_truncate(
@ -65,15 +61,15 @@ async def test_with_and_without_truncate(
body = {"model": MODEL_NAME, **content_body, "max_tokens": 10}

num_requests = 10
truncate_prompt_tokens = ([1000] * (num_requests // 2) + [None] *
(num_requests - num_requests // 2))
truncate_prompt_tokens = [1000] * (num_requests // 2) + [None] * (
num_requests - num_requests // 2
)
random.shuffle(truncate_prompt_tokens)

bodies = [{
**body, "extra_body": {
'truncate_prompt_tokens': t
}
} for t in truncate_prompt_tokens]
bodies = [
{**body, "extra_body": {"truncate_prompt_tokens": t}}
for t in truncate_prompt_tokens
]

async def get_status_code(**kwargs):
try:

@ -56,24 +56,18 @@ def base64_encoded_audio() -> dict[str, str]:
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
async def test_single_chat_session_audio(client: openai.AsyncOpenAI,
model_name: str, audio_url: str):
messages = [{
"role":
"user",
"content": [
{
"type": "audio_url",
"audio_url": {
"url": audio_url
}
},
{
"type": "text",
"text": "What's happening in this audio?"
},
],
}]
async def test_single_chat_session_audio(
client: openai.AsyncOpenAI, model_name: str, audio_url: str
):
messages = [
{
"role": "user",
"content": [
{"type": "audio_url", "audio_url": {"url": audio_url}},
{"type": "text", "text": "What's happening in this audio?"},
],
}
]

# test single completion
chat_completion = await client.chat.completions.create(
@ -82,13 +76,15 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI,
max_completion_tokens=10,
logprobs=True,
temperature=0.0,
top_logprobs=5)
top_logprobs=5,
)
assert len(chat_completion.choices) == 1

choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=202, total_tokens=212)
completion_tokens=10, prompt_tokens=202, total_tokens=212
)

message = choice.message
message = chat_completion.choices[0].message
@ -110,56 +106,52 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
async def test_error_on_invalid_audio_url_type(client: openai.AsyncOpenAI,
model_name: str,
audio_url: str):
messages = [{
"role":
"user",
"content": [
{
"type": "audio_url",
"audio_url": audio_url
},
{
"type": "text",
"text": "What's happening in this audio?"
},
],
}]
async def test_error_on_invalid_audio_url_type(
client: openai.AsyncOpenAI, model_name: str, audio_url: str
):
messages = [
{
"role": "user",
"content": [
{"type": "audio_url", "audio_url": audio_url},
{"type": "text", "text": "What's happening in this audio?"},
],
}
]

# audio_url should be a dict {"url": "some url"}, not directly a string
with pytest.raises(openai.BadRequestError):
_ = await client.chat.completions.create(model=model_name,
messages=messages,
max_completion_tokens=10,
temperature=0.0)
_ = await client.chat.completions.create(
model=model_name,
messages=messages,
max_completion_tokens=10,
temperature=0.0,
)


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
async def test_single_chat_session_audio_base64encoded(
client: openai.AsyncOpenAI, model_name: str, audio_url: str,
base64_encoded_audio: dict[str, str]):

messages = [{
"role":
"user",
"content": [
{
"type": "audio_url",
"audio_url": {
"url":
f"data:audio/wav;base64,{base64_encoded_audio[audio_url]}"
}
},
{
"type": "text",
"text": "What's happening in this audio?"
},
],
}]
client: openai.AsyncOpenAI,
model_name: str,
audio_url: str,
base64_encoded_audio: dict[str, str],
):
messages = [
{
"role": "user",
"content": [
{
"type": "audio_url",
"audio_url": {
"url": f"data:audio/wav;base64,{base64_encoded_audio[audio_url]}"
},
},
{"type": "text", "text": "What's happening in this audio?"},
],
}
]

# test single completion
chat_completion = await client.chat.completions.create(
@ -168,13 +160,15 @@ async def test_single_chat_session_audio_base64encoded(
max_completion_tokens=10,
logprobs=True,
temperature=0.0,
top_logprobs=5)
top_logprobs=5,
)
assert len(chat_completion.choices) == 1

choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=202, total_tokens=212)
completion_tokens=10, prompt_tokens=202, total_tokens=212
)

message = choice.message
message = chat_completion.choices[0].message
@ -198,25 +192,26 @@ async def test_single_chat_session_audio_base64encoded(
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
async def test_single_chat_session_input_audio(
client: openai.AsyncOpenAI, model_name: str, audio_url: str,
base64_encoded_audio: dict[str, str]):
messages = [{
"role":
"user",
"content": [
{
"type": "input_audio",
"input_audio": {
"data": base64_encoded_audio[audio_url],
"format": "wav"
}
},
{
"type": "text",
"text": "What's happening in this audio?"
},
],
}]
client: openai.AsyncOpenAI,
model_name: str,
audio_url: str,
base64_encoded_audio: dict[str, str],
):
messages = [
{
"role": "user",
"content": [
{
"type": "input_audio",
"input_audio": {
"data": base64_encoded_audio[audio_url],
"format": "wav",
},
},
{"type": "text", "text": "What's happening in this audio?"},
],
}
]

# test single completion
chat_completion = await client.chat.completions.create(
@ -224,13 +219,15 @@ async def test_single_chat_session_input_audio(
messages=messages,
max_completion_tokens=10,
logprobs=True,
top_logprobs=5)
top_logprobs=5,
)
assert len(chat_completion.choices) == 1

choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=202, total_tokens=212)
completion_tokens=10, prompt_tokens=202, total_tokens=212
)

message = choice.message
message = chat_completion.choices[0].message
@ -252,24 +249,18 @@ async def test_single_chat_session_input_audio(
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
async def test_chat_streaming_audio(client: openai.AsyncOpenAI,
model_name: str, audio_url: str):
messages = [{
"role":
"user",
"content": [
{
"type": "audio_url",
"audio_url": {
"url": audio_url
}
},
{
"type": "text",
"text": "What's happening in this audio?"
},
],
}]
async def test_chat_streaming_audio(
client: openai.AsyncOpenAI, model_name: str, audio_url: str
):
messages = [
{
"role": "user",
"content": [
{"type": "audio_url", "audio_url": {"url": audio_url}},
{"type": "text", "text": "What's happening in this audio?"},
],
}
]

# test single completion
chat_completion = await client.chat.completions.create(
@ -309,27 +300,27 @@ async def test_chat_streaming_audio(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI,
model_name: str, audio_url: str,
base64_encoded_audio: dict[str,
str]):
messages = [{
"role":
"user",
"content": [
{
"type": "input_audio",
"input_audio": {
"data": base64_encoded_audio[audio_url],
"format": "wav"
}
},
{
"type": "text",
"text": "What's happening in this audio?"
},
],
}]
async def test_chat_streaming_input_audio(
client: openai.AsyncOpenAI,
model_name: str,
audio_url: str,
base64_encoded_audio: dict[str, str],
):
messages = [
{
"role": "user",
"content": [
{
"type": "input_audio",
"input_audio": {
"data": base64_encoded_audio[audio_url],
"format": "wav",
},
},
{"type": "text", "text": "What's happening in this audio?"},
],
}
]

# test single completion
chat_completion = await client.chat.completions.create(
@ -369,26 +360,23 @@ async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize(
"audio_urls", [TEST_AUDIO_URLS, TEST_AUDIO_URLS + [TEST_AUDIO_URLS[0]]])
async def test_multi_audio_input(client: openai.AsyncOpenAI, model_name: str,
audio_urls: list[str]):

messages = [{
"role":
"user",
"content": [
*({
"type": "audio_url",
"audio_url": {
"url": audio_url
}
} for audio_url in audio_urls),
{
"type": "text",
"text": "What's happening in this audio?"
},
],
}]
"audio_urls", [TEST_AUDIO_URLS, TEST_AUDIO_URLS + [TEST_AUDIO_URLS[0]]]
)
async def test_multi_audio_input(
client: openai.AsyncOpenAI, model_name: str, audio_urls: list[str]
):
messages = [
{
"role": "user",
"content": [
*(
{"type": "audio_url", "audio_url": {"url": audio_url}}
for audio_url in audio_urls
),
{"type": "text", "text": "What's happening in this audio?"},
],
}
]

if len(audio_urls) > MAXIMUM_AUDIOS:
with pytest.raises(openai.BadRequestError): # test multi-audio input

@ -16,9 +16,9 @@ from ...utils import RemoteOpenAIServer
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"


@pytest.fixture(scope='module')
@pytest.fixture(scope="module")
def server_args(request: pytest.FixtureRequest) -> list[str]:
""" Provide extra arguments to the server via indirect parametrization
"""Provide extra arguments to the server via indirect parametrization

Usage:

@ -80,8 +80,10 @@ async def client(server):
"server_args",
[
pytest.param([], id="default-frontend-multiprocessing"),
pytest.param(["--disable-frontend-multiprocessing"],
id="disable-frontend-multiprocessing")
pytest.param(
["--disable-frontend-multiprocessing"],
id="disable-frontend-multiprocessing",
),
],
indirect=True,
)
@ -97,8 +99,10 @@ async def test_show_version(server: RemoteOpenAIServer):
"server_args",
[
pytest.param([], id="default-frontend-multiprocessing"),
pytest.param(["--disable-frontend-multiprocessing"],
id="disable-frontend-multiprocessing")
pytest.param(
["--disable-frontend-multiprocessing"],
id="disable-frontend-multiprocessing",
),
],
indirect=True,
)
@ -112,11 +116,13 @@ async def test_check_health(server: RemoteOpenAIServer):
@pytest.mark.parametrize(
"server_args",
[
pytest.param(["--max-model-len", "10100"],
id="default-frontend-multiprocessing"),
pytest.param(
["--max-model-len", "10100"], id="default-frontend-multiprocessing"
),
pytest.param(
["--disable-frontend-multiprocessing", "--max-model-len", "10100"],
id="disable-frontend-multiprocessing")
id="disable-frontend-multiprocessing",
),
],
indirect=True,
)
@ -131,14 +137,16 @@ async def test_request_cancellation(server: RemoteOpenAIServer):
# Request about 2 million tokens
for _ in range(200):
task = asyncio.create_task(
client.chat.completions.create(messages=chat_input,
model=MODEL_NAME,
max_tokens=10000,
extra_body={"min_tokens": 10000}))
client.chat.completions.create(
messages=chat_input,
model=MODEL_NAME,
max_tokens=10000,
extra_body={"min_tokens": 10000},
)
)
tasks.append(task)

done, pending = await asyncio.wait(tasks,
return_when=asyncio.ALL_COMPLETED)
done, pending = await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)

# Make sure all requests were sent to the server and timed out
# (We don't want to hide other errors like 400s that would invalidate this
@ -151,16 +159,15 @@ async def test_request_cancellation(server: RemoteOpenAIServer):
# If the server had not cancelled all the other requests, then it would not
# be able to respond to this one within the timeout
client = server.get_async_client(timeout=5)
response = await client.chat.completions.create(messages=chat_input,
model=MODEL_NAME,
max_tokens=10)
response = await client.chat.completions.create(
messages=chat_input, model=MODEL_NAME, max_tokens=10
)

assert len(response.choices) == 1


@pytest.mark.asyncio
async def test_request_wrong_content_type(server: RemoteOpenAIServer):

chat_input = [{"role": "user", "content": "Write a long story"}]
client = server.get_async_client()

@ -169,17 +176,13 @@ async def test_request_wrong_content_type(server: RemoteOpenAIServer):
messages=chat_input,
model=MODEL_NAME,
max_tokens=10000,
extra_headers={
"Content-Type": "application/x-www-form-urlencoded"
})
extra_headers={"Content-Type": "application/x-www-form-urlencoded"},
)


@pytest.mark.parametrize(
"server_args",
[
pytest.param(["--enable-server-load-tracking"],
id="enable-server-load-tracking")
],
[pytest.param(["--enable-server-load-tracking"], id="enable-server-load-tracking")],
indirect=True,
)
@pytest.mark.asyncio
@ -202,7 +205,8 @@ async def test_server_load(server: RemoteOpenAIServer):

# Start the completion request in a background thread.
completion_future = asyncio.create_task(
asyncio.to_thread(make_long_completion_request))
asyncio.to_thread(make_long_completion_request)
)

# Give a short delay to ensure the request has started.
await asyncio.sleep(0.1)
