Improve configs - the rest! (#17562)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor
2025-05-09 23:18:44 +01:00
committed by GitHub
parent 7e3571134f
commit 4b2ed7926a
14 changed files with 456 additions and 340 deletions

View File

@ -9,7 +9,7 @@ import torch
from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationLevel
from vllm.config import CompilationConfig, CompilationLevel, PassConfig
from vllm.platforms import current_platform
from ..utils import create_new_process_for_each_test
@ -95,9 +95,6 @@ def test_full_graph(
run_model(optimization_level, model, model_kwargs)
PassConfig = CompilationConfig.PassConfig
# TODO(luka) add other supported compilation config scenarios here
@pytest.mark.parametrize(
"compilation_config, model_info",

View File

@ -11,7 +11,7 @@ from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
kFp8DynamicTokenSym, kFp8StaticTensorSym)
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, VllmConfig
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from .backend import TestBackend
@ -53,9 +53,8 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
torch.set_default_device("cuda")
vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(pass_config= \
CompilationConfig.PassConfig(enable_fusion=do_fusion,
enable_noop=True))
vllm_config.compilation_config = CompilationConfig(
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True))
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = FusionPass.instance(vllm_config)
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
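These test-file changes are mechanical: `PassConfig` is now exported at the top level of `vllm.config` instead of living as a nested class on `CompilationConfig`. A minimal sketch of the new spelling (the flag values are illustrative):

from vllm.config import CompilationConfig, PassConfig, VllmConfig

vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(
    pass_config=PassConfig(enable_fusion=True, enable_noop=True))
# previously: pass_config=CompilationConfig.PassConfig(enable_fusion=True, enable_noop=True)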

View File

@ -9,7 +9,8 @@ from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
FusionPass, QuantKey)
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
VllmConfig)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity)
@ -78,8 +79,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
vllm_config.compilation_config.pass_config = \
CompilationConfig.PassConfig(enable_fusion=True,
enable_noop=True)
PassConfig(enable_fusion=True, enable_noop=True)
with vllm.config.set_current_vllm_config(vllm_config):
# Reshape pass is needed for the fusion pass to work
noop_pass = NoOpEliminationPass(vllm_config)

View File

@ -10,7 +10,7 @@ from vllm.compilation.fx_utils import (find_auto_fn, find_auto_fn_maybe,
find_specified_fn_maybe, is_func)
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
VllmConfig)
PassConfig, VllmConfig)
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (init_distributed_environment,
initialize_model_parallel)
@ -126,9 +126,8 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
# configure vllm config for SequenceParallelismPass
vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(
pass_config=CompilationConfig.PassConfig(
enable_sequence_parallelism=True, ), )
vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig(
enable_sequence_parallelism=True))
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
# this is a fake model name to construct the model config

View File

@ -6,7 +6,7 @@ import vllm.envs as envs
from vllm._custom_ops import scaled_fp8_quant
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
from vllm.config import CompilationConfig, VllmConfig
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.model_executor.layers.activation import SiluAndMul
from .backend import TestBackend
@ -36,8 +36,7 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
# Reshape pass is needed for the fusion pass to work
config = VllmConfig()
config.compilation_config = CompilationConfig(
pass_config=CompilationConfig.PassConfig(enable_fusion=True,
enable_reshape=True))
pass_config=PassConfig(enable_fusion=True, enable_reshape=True))
fusion_pass = ActivationQuantFusionPass(config)
backend = TestBackend(fusion_pass)

View File

@ -206,7 +206,7 @@ def _compare_sp(
'compile_sizes': [4, 8],
'splitting_ops': [],
'pass_config': {
'enable_sequence_parallism': sp_enabled,
'enable_sequence_parallelism': sp_enabled,
'enable_noop': True,
'enable_fusion': True,
},
@ -223,7 +223,7 @@ def _compare_sp(
"--distributed-executor-backend",
distributed_backend,
"--compilation_config",
str(compilation_config),
json.dumps(compilation_config),
]
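The switch from `str(compilation_config)` to `json.dumps(compilation_config)` is needed because the config is now parsed with `json.loads` on the CLI side, and Python's dict repr is not valid JSON. A quick illustration:

import json

compilation_config = {"compile_sizes": [4, 8], "pass_config": {"enable_noop": True}}
str(compilation_config)         # "{'compile_sizes': [4, 8], ...}" - single quotes, not JSON
json.dumps(compilation_config)  # '{"compile_sizes": [4, 8], "pass_config": {"enable_noop": true}}'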
tp_env = {

View File

@ -8,21 +8,18 @@ from typing import Literal, Optional
import pytest
from vllm.config import config
from vllm.config import CompilationConfig, config
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
get_type, is_not_builtin, is_type,
literal_to_kwargs, nullable_kvs,
optional_type)
optional_type, parse_type)
from vllm.utils import FlexibleArgumentParser
@pytest.mark.parametrize(("type", "value", "expected"), [
(int, "42", 42),
(int, "None", None),
(float, "3.14", 3.14),
(float, "None", None),
(str, "Hello World!", "Hello World!"),
(str, "None", None),
(json.loads, '{"foo":1,"bar":2}', {
"foo": 1,
"bar": 2
@ -31,15 +28,20 @@ from vllm.utils import FlexibleArgumentParser
"foo": 1,
"bar": 2
}),
(json.loads, "None", None),
])
def test_optional_type(type, value, expected):
optional_type_func = optional_type(type)
def test_parse_type(type, value, expected):
parse_type_func = parse_type(type)
context = nullcontext()
if value == "foo=1,bar=2":
context = pytest.warns(DeprecationWarning)
with context:
assert optional_type_func(value) == expected
assert parse_type_func(value) == expected
def test_optional_type():
optional_type_func = optional_type(int)
assert optional_type_func("None") is None
assert optional_type_func("42") == 42
@pytest.mark.parametrize(("type_hint", "type", "expected"), [
@ -89,7 +91,40 @@ def test_literal_to_kwargs(type_hints, expected):
@config
@dataclass
class DummyConfigClass:
class NestedConfig:
field: int = 1
"""field"""
@config
@dataclass
class FromCliConfig1:
field: int = 1
"""field"""
@classmethod
def from_cli(cls, cli_value: str):
inst = cls(**json.loads(cli_value))
inst.field += 1
return inst
@config
@dataclass
class FromCliConfig2:
field: int = 1
"""field"""
@classmethod
def from_cli(cls, cli_value: str):
inst = cls(**json.loads(cli_value))
inst.field += 2
return inst
@config
@dataclass
class DummyConfig:
regular_bool: bool = True
"""Regular bool with default True"""
optional_bool: Optional[bool] = None
@ -108,18 +143,24 @@ class DummyConfigClass:
"""Literal of literals with default 1"""
json_tip: dict = field(default_factory=dict)
"""Dict which will be JSON in CLI"""
nested_config: NestedConfig = field(default_factory=NestedConfig)
"""Nested config"""
from_cli_config1: FromCliConfig1 = field(default_factory=FromCliConfig1)
"""Config with from_cli method"""
from_cli_config2: FromCliConfig2 = field(default_factory=FromCliConfig2)
"""Different config with from_cli method"""
@pytest.mark.parametrize(("type_hint", "expected"), [
(int, False),
(DummyConfigClass, True),
(DummyConfig, True),
])
def test_is_not_builtin(type_hint, expected):
assert is_not_builtin(type_hint) == expected
def test_get_kwargs():
kwargs = get_kwargs(DummyConfigClass)
kwargs = get_kwargs(DummyConfig)
print(kwargs)
# bools should not have their type set
@ -142,6 +183,11 @@ def test_get_kwargs():
# dict should have json tip in help
json_tip = "\n\nShould be a valid JSON string."
assert kwargs["json_tip"]["help"].endswith(json_tip)
# nested config should construct the nested config
assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2)
# from_cli configs should be constructed with the correct method
assert kwargs["from_cli_config1"]["type"]('{"field": 2}').field == 3
assert kwargs["from_cli_config2"]["type"]('{"field": 2}').field == 4
@pytest.mark.parametrize(("arg", "expected"), [
@ -177,7 +223,7 @@ def test_compilation_config():
# default value
args = parser.parse_args([])
assert args.compilation_config is None
assert args.compilation_config == CompilationConfig()
# set to O3
args = parser.parse_args(["-O3"])
@ -194,7 +240,7 @@ def test_compilation_config():
# set to string form of a dict
args = parser.parse_args([
"--compilation-config",
"{'level': 3, 'cudagraph_capture_sizes': [1, 2, 4, 8]}",
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}',
])
assert (args.compilation_config.level == 3 and
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8])
@ -202,7 +248,7 @@ def test_compilation_config():
# set to string form of a dict
args = parser.parse_args([
"--compilation-config="
"{'level': 3, 'cudagraph_capture_sizes': [1, 2, 4, 8]}",
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}',
])
assert (args.compilation_config.level == 3 and
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8])

View File

@ -4,7 +4,7 @@ import time
import torch
from vllm.config import CompilationConfig, VllmConfig
from vllm.config import PassConfig, VllmConfig
# yapf: disable
from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank
from vllm.distributed import (
@ -56,10 +56,7 @@ class VllmInductorPass(InductorPass):
class PrinterInductorPass(VllmInductorPass):
def __init__(self,
name: str,
config: CompilationConfig.PassConfig,
always=False):
def __init__(self, name: str, config: PassConfig, always=False):
super().__init__(config)
self.name = name
self.always = always

View File

@ -11,8 +11,8 @@ import textwrap
import warnings
from collections import Counter
from contextlib import contextmanager
from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
replace)
from dataclasses import (MISSING, Field, asdict, dataclass, field, fields,
is_dataclass, replace)
from functools import cached_property
from importlib.util import find_spec
from pathlib import Path
@ -20,7 +20,6 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional,
Protocol, TypeVar, Union, cast, get_args, get_origin)
import torch
from pydantic import BaseModel, Field, PrivateAttr
from torch.distributed import ProcessGroup, ReduceOp
from transformers import PretrainedConfig
from typing_extensions import deprecated
@ -57,7 +56,7 @@ if TYPE_CHECKING:
ConfigType = type[DataclassInstance]
else:
QuantizationConfig = None
QuantizationConfig = Any
ConfigType = type
logger = init_logger(__name__)
@ -169,6 +168,12 @@ def config(cls: ConfigT) -> ConfigT:
"""
A decorator that ensures all fields in a dataclass have default values
and that each field has a docstring.
If a `ConfigT` is used as a CLI argument itself, the default value provided
by `get_kwargs` will be the result of parsing a JSON string as the kwargs
(i.e. `ConfigT(**json.loads(cli_arg))`). However, if a particular `ConfigT`
requires custom construction from the CLI (e.g. `CompilationConfig`), it can
have a `from_cli` method, which will be called instead.
"""
if not is_dataclass(cls):
raise TypeError("The decorated class must be a dataclass.")
@ -202,7 +207,7 @@ def get_field(cls: ConfigType, name: str) -> Field:
cls_fields = {f.name: f for f in fields(cls)}
if name not in cls_fields:
raise ValueError(f"Field '{name}' not found in {cls.__name__}.")
named_field: Field = cls_fields.get(name)
named_field: Field = cls_fields[name]
if (default_factory := named_field.default_factory) is not MISSING:
return field(default_factory=default_factory)
if (default := named_field.default) is not MISSING:
@ -211,6 +216,10 @@ def get_field(cls: ConfigType, name: str) -> Field:
f"{cls.__name__}.{name} must have a default value or default factory.")
def is_init_field(cls: ConfigType, name: str) -> bool:
return next(f for f in fields(cls) if f.name == name).init
TokenizerMode = Literal["auto", "slow", "mistral", "custom"]
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
@ -2007,13 +2016,13 @@ class SchedulerConfig:
def __post_init__(self) -> None:
if self.max_model_len is None:
self.max_model_len = 8192
logger.warning(
logger.warning_once(
"max_model_len was is not set. Defaulting to arbitrary value "
"of %d.", self.max_model_len)
if self.max_num_seqs is None:
self.max_num_seqs = 128
logger.warning(
logger.warning_once(
"max_num_seqs was is not set. Defaulting to arbitrary value "
"of %d.", self.max_num_seqs)
@ -2840,8 +2849,8 @@ class PromptAdapterConfig:
class MultiModalConfig:
"""Controls the behavior of multimodal models."""
limit_per_prompt: dict[str, int] = get_field(ModelConfig,
"limit_mm_per_prompt")
limit_per_prompt: dict[str, int] = \
cast(dict[str, int], get_field(ModelConfig, "limit_mm_per_prompt"))
"""
The maximum number of input items allowed per prompt for each modality.
Defaults to 1 (V0) or 999 (V1) for each modality.
@ -3415,41 +3424,49 @@ class ObservabilityConfig:
self.collect_detailed_traces[0].split(","))
class KVTransferConfig(BaseModel):
KVProducer = Literal["kv_producer", "kv_both"]
KVConsumer = Literal["kv_consumer", "kv_both"]
KVRole = Literal[KVProducer, KVConsumer]
@config
@dataclass
class KVTransferConfig:
"""Configuration for distributed KV cache transfer."""
# The KV connector for vLLM to transmit KV caches between vLLM instances.
kv_connector: Optional[str] = None
"""The KV connector for vLLM to transmit KV caches between vLLM instances.
"""
# The device used by kv connector to buffer the KV cache.
# Currently only support 'cuda'.
kv_buffer_device: Optional[str] = "cuda"
"""The device used by kv connector to buffer the KV cache.
Currently only support 'cuda'."""
# The buffer size for TorchDistributedConnector. Measured in number of
# bytes. Recommended value: 1e9 (about 1GB).
kv_buffer_size: float = 1e9
"""The buffer size for TorchDistributedConnector. Measured in number of
bytes. Recommended value: 1e9 (about 1GB)."""
# Whether this vLLM instance produces, consumes KV cache, or both. Choices
# are 'kv_producer', 'kv_consumer', and 'both'.
kv_role: Optional[str] = None
kv_role: Optional[KVRole] = None
"""Whether this vLLM instance produces, consumes KV cache, or both. Choices
are 'kv_producer', 'kv_consumer', and 'both'."""
# The rank of this vLLM instance in the KV cache transfer. Typical value:
# 0 for prefill instance, 1 for decode instance.
# Currently only 1P1D is supported.
kv_rank: Optional[int] = None
"""The rank of this vLLM instance in the KV cache transfer. Typical value:
0 for prefill instance, 1 for decode instance.
Currently only 1P1D is supported."""
# The number of parallel instances for KV cache transfer. For
# PyNcclConnector, this should be 2.
kv_parallel_size: int = 1
"""The number of parallel instances for KV cache transfer. For
PyNcclConnector, this should be 2."""
# The KV connector ip, used to build distributed connection
kv_ip: str = "127.0.0.1"
"""The KV connector ip, used to build distributed connection."""
# The KV connector port, used to build distributed connection
kv_port: int = 14579
"""The KV connector port, used to build distributed connection."""
# any extra config that the connector may need
kv_connector_extra_config: dict[str, Any] = {}
kv_connector_extra_config: dict[str, Any] = field(default_factory=dict)
"""any extra config that the connector may need."""
def compute_hash(self) -> str:
"""
@ -3470,46 +3487,37 @@ class KVTransferConfig(BaseModel):
usedforsecurity=False).hexdigest()
return hash_str
@classmethod
def from_cli(cls, cli_value: str) -> "KVTransferConfig":
"""Parse the CLI value for the kv cache transfer config."""
return KVTransferConfig.model_validate_json(cli_value)
def model_post_init(self, __context: Any) -> None:
if self.kv_role is not None and self.kv_role not in [
"kv_producer", "kv_consumer", "kv_both"
]:
raise ValueError(
f"Unsupported kv_role: {self.kv_role}. "
f"Supported roles are `kv_producer`, `kv_consumer`, "
f"and `kv_both`")
def __post_init__(self) -> None:
if self.kv_role is not None and self.kv_role not in get_args(KVRole):
raise ValueError(f"Unsupported kv_role: {self.kv_role}. "
f"Supported roles are {get_args(KVRole)}")
if self.kv_connector is not None and self.kv_role is None:
raise ValueError("Please specify kv_disagg_role when kv_connector "
"is set, supported roles are `kv_producer`, "
"`kv_consumer`, and `kv_both`")
f"is set, supported roles are {get_args(KVRole)}")
@property
def is_kv_transfer_instance(self) -> bool:
return self.kv_connector is not None and \
self.kv_role in ["kv_producer", "kv_consumer", "kv_both"]
self.kv_role in get_args(KVRole)
@property
def is_kv_producer(self) -> bool:
return self.kv_connector is not None and \
self.kv_role in ["kv_producer", "kv_both"]
self.kv_role in get_args(KVProducer)
@property
def is_kv_consumer(self) -> bool:
return self.kv_connector is not None and \
self.kv_role in ["kv_consumer", "kv_both"]
self.kv_role in get_args(KVConsumer)
def get_from_extra_config(self, key, default) -> Any:
return self.kv_connector_extra_config.get(key, default)
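Replacing the hand-rolled role checks with `Literal` aliases works because `typing` flattens and de-duplicates nested `Literal`s, so `get_args` yields exactly the valid role strings. A small sketch:

from typing import Literal, get_args

KVProducer = Literal["kv_producer", "kv_both"]
KVConsumer = Literal["kv_consumer", "kv_both"]
KVRole = Literal[KVProducer, KVConsumer]

get_args(KVProducer)  # ('kv_producer', 'kv_both')
get_args(KVRole)      # ('kv_producer', 'kv_both', 'kv_consumer')
"kv_both" in get_args(KVConsumer)  # True, so a kv_both instance is also a consumer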
class KVEventsConfig(BaseModel):
@config
@dataclass
class KVEventsConfig:
"""Configuration for KV event publishing."""
enable_kv_cache_events: bool = False
@ -3548,11 +3556,6 @@ class KVEventsConfig(BaseModel):
this topic to receive events.
"""
@classmethod
def from_cli(cls, cli_value: str) -> "KVEventsConfig":
"""Parse the CLI value for the event publisher config."""
return KVEventsConfig.model_validate_json(cli_value)
class CompilationLevel:
# constants for the levels of the compilation process
@ -3562,80 +3565,72 @@ class CompilationLevel:
PIECEWISE = 3
class CompilationConfig(BaseModel):
"""
Configuration for compilation.
It has three parts:
@config
@dataclass
class PassConfig:
"""Configuration for custom Inductor passes.
This is separate from general `CompilationConfig` so that inductor passes
don't all have access to full configuration - that would create a cycle as
the `PassManager` is set as a property of config."""
dump_graph_stages: list[str] = field(default_factory=list)
"""List of stages for which we want to dump the graph. Each pass defines
its own stages (before, after, maybe in-between)."""
dump_graph_dir: Path = Path(".")
"""Directory to dump the graphs."""
# TODO(luka) better pass enabling system.
enable_fusion: bool = True
"""Whether to enable the custom fusion pass."""
enable_noop: bool = True
"""Whether to enable the custom no-op elimination pass."""
enable_sequence_parallelism: bool = False
"""Whether to enable sequence parallelism."""
def uuid(self):
"""
Produces a hash unique to the pass configuration.
Any new fields that affect compilation should be added to the hash.
Do not include dump_graph_* in the hash - they don't affect
compilation.
"""
include = {
"enable_fusion", "enable_noop", "enable_sequence_parallelism"
}
dict_ = {k: v for k, v in asdict(self).items() if k in include}
return InductorPass.hash_dict(dict_)
def __post_init__(self) -> None:
if not self.enable_noop and self.enable_fusion:
logger.warning_once(
"Fusion enabled but reshape elimination disabled. "
"RMSNorm + quant (fp8) fusion might not work")
@config
@dataclass
class CompilationConfig:
"""Configuration for compilation. It has three parts:
- Top-level Compilation control:
- level: the level of compilation.
- 0: no compilation.
- 1: dynamo as is.
- 2: dynamo once.
- 3: piecewise compilation.
- debug_dump_path: the path to dump the debug information.
- cache_dir: the directory to store the compiled graph, to
accelerate Inductor compilation. By default, it will use
model-related information to generate a cache directory.
- backend: the backend for compilation. It needs to be a string.
- "" (empty string): use the default backend.
- "eager"/"openxla"/...: use the specified backend registered in PyTorch.
- "full.module.name": a qualified name which can be used to import the backend function.
We use string to avoid serialization issues when using compilation in a distributed setting.
When the compilation level is 1 or 2, the backend is used for the compilation directly (it sees the whole graph).
When the compilation level is 3, the backend is used for the piecewise compilation (it sees a part of the graph).
- custom_ops: fine-grained control over which custom ops to enable/disable.
Use 'all' to enable all, 'none' to disable all.
Also specify a list of custom op names to enable (prefixed with a '+'),
or disable (prefixed with a '-').
Examples:
- 'all,-op1' to enable all except op1
- 'none,+op1,+op2' to enable only op1 and op2
By default, all custom ops are enabled when running without Inductor
and disabled when running with Inductor (compile_level >= Inductor).
- splitting_ops: a list of ops to split the full graph into subgraphs, used in piecewise compilation.
- {attr}`level`
- {attr}`debug_dump_path`
- {attr}`cache_dir`
- {attr}`backend`
- {attr}`custom_ops`
- {attr}`splitting_ops`
- CudaGraph capture:
- use_cudagraph: whether to use cudagraph inside compilation.
- False: cudagraph inside compilation is not used.
- True: cudagraph inside compilation is used. It requires
that all input buffers have fixed addresses, and all
splitting ops write their outputs to input buffers.
Note that this is orthogonal to the cudagraph capture logic
outside of compilation.
TODO: move outside cudagraph logic into compilation.
torch.compile will handle cudagraph capture logic in the future.
- cudagraph_capture_sizes: sizes to capture cudagraph.
- None (default): capture sizes are inferred from vllm config.
- list[int]: capture sizes are specified as given.
- cudagraph_num_of_warmups: number of warmup runs for cudagraph.
It means the first several runs will be treated as warmup runs.
Only after that, the execution will be recorded, and the recorded
cudagraph will be used for subsequent runs.
- cudagraph_copy_inputs: whether to copy input tensors for
cudagraph. If the caller can guarantee that the same input buffers
are always used, it can set this to False. Otherwise, it should
set this to True, and the compiler will copy the input to an
internally managed buffer. Default is False.
- full_cuda_graph: whether to use a full cuda graph for the entire forward
pass rather than splitting certain operations such as attention into subgraphs.
Thus this flag cannot be used together with splitting_ops. This may provide
performance benefits for smaller models.
- {attr}`use_cudagraph`
- {attr}`cudagraph_capture_sizes`
- {attr}`cudagraph_num_of_warmups`
- {attr}`cudagraph_copy_inputs`
- {attr}`full_cuda_graph`
- Inductor compilation:
- use_inductor: whether to use inductor compilation.
- False: inductor compilation is not used. graph runs in eager.
- True: inductor compilation is used. one graph for symbolic shape
is compiled. In addition, compile for compile_sizes,
using configurations in inductor_compile_config.
- compile_sizes: sizes to compile for inductor. In addition
to integers, it also supports "cudagraph_capture_sizes" to
specify the sizes for cudagraph capture.
- inductor_compile_config: additional configurations for inductor.
- None: use default configurations.
- inductor_passes: additional passes for inductor. It is a dictionary
from pass name to pass function qualified name. We use function
name because the config uses json format. If we pass the config
from Python, functions can also be passed directly via Python object
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`
- custom inductor passes: see PassConfig for more details
- {attr}`use_inductor`
- {attr}`compile_sizes`
- {attr}`inductor_compile_config`
- {attr}`inductor_passes`
- custom inductor passes
Why we have different sizes for cudagraph and inductor:
- cudagraph: a cudagraph captured for a specific size can only be used
@ -3646,83 +3641,135 @@ class CompilationConfig(BaseModel):
static shapes. However, we find the general shape compilation is
sufficient for most cases. It might be beneficial to compile for
certain small batchsizes, where inductor is good at optimizing.
""" # noqa
"""
# Top-level Compilation control
level: int = 0
"""The level of compilation:
- 0: no compilation.
- 1: dynamo as is.
- 2: dynamo once.
- 3: piecewise compilation."""
debug_dump_path: str = ""
"""The path to dump the debug information."""
cache_dir: str = ""
"""The directory to store the compiled graph, to accelerate Inductor
compilation. By default, it will use model-related information to generate
a cache directory."""
backend: str = ""
custom_ops: list[str] = Field(default_factory=list)
splitting_ops: list[str] = Field(default=None) # type: ignore
"""The backend for compilation. It needs to be a string:
- "" (empty string): use the default backend.
- "eager"/"openxla"/...: use the specified backend registered in PyTorch.
- "full.module.name": a qualified name which can be used to import the
backend function.
We use string to avoid serialization issues when using compilation in a
distributed setting. When the compilation level is 1 or 2, the backend is
used for the compilation directly (it sees the whole graph). When the
compilation level is 3, the backend is used for the piecewise compilation
(it sees a part of the graph)."""
custom_ops: list[str] = field(default_factory=list)
"""Fine-grained control over which custom ops to enable/disable. Use 'all'
to enable all, 'none' to disable all. Also specify a list of custom op
names to enable (prefixed with a '+'), or disable (prefixed with a '-').
Examples:
- 'all,-op1' to enable all except op1
- 'none,+op1,+op2' to enable only op1 and op2
By default, all custom ops are enabled when running without Inductor and
disabled when running with Inductor (compile_level >= Inductor)."""
splitting_ops: list[str] = field(default_factory=list)
"""A list of ops to split the full graph into subgraphs, used in piecewise
compilation."""
# Inductor capture
use_inductor: bool = True
compile_sizes: Optional[list[Union[int, str]]] = Field(default=None)
inductor_compile_config: dict = Field(default_factory=dict)
inductor_passes: dict[str, str] = Field(default_factory=dict)
"""Whether to use inductor compilation:
- False: inductor compilation is not used. graph runs in eager.
- True: inductor compilation is used. one graph for symbolic shape
is compiled. In addition, compile for compile_sizes,
using configurations in inductor_compile_config."""
compile_sizes: Optional[list[Union[int, str]]] = None
"""Sizes to compile for inductor. In addition
to integers, it also supports "cudagraph_capture_sizes" to
specify the sizes for cudagraph capture."""
inductor_compile_config: dict = field(default_factory=dict)
"""Additional configurations for inductor.
- None: use default configurations."""
inductor_passes: dict[str, str] = field(default_factory=dict)
"""Additional passes for inductor. It is a dictionary
from pass name to pass function qualified name. We use function
name because the config uses JSON format. If we pass the config
from Python, functions can also be passed directly via Python object
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`."""
# CudaGraph compilation
use_cudagraph: bool = False
"""Whether to use cudagraph inside compilation.
- False: cudagraph inside compilation is not used.
- True: cudagraph inside compilation is used. It requires
that all input buffers have fixed addresses, and all
splitting ops write their outputs to input buffers.
Note that this is orthogonal to the cudagraph capture logic
outside of compilation.
TODO: move outside cudagraph logic into compilation.
torch.compile will handle cudagraph capture logic in the future."""
cudagraph_num_of_warmups: int = 0
"""Number of warmup runs for cudagraph.
It means the first several runs will be treated as warmup runs.
Only after that, the execution will be recorded, and the recorded
cudagraph will be used for subsequent runs."""
cudagraph_capture_sizes: Optional[list[int]] = None
"""Sizes to capture cudagraph.
- None (default): capture sizes are inferred from vllm config.
- list[int]: capture sizes are specified as given."""
cudagraph_copy_inputs: bool = False
"""Whether to copy input tensors for
cudagraph. If the caller can guarantee that the same input buffers
are always used, it can set this to False. Otherwise, it should
set this to True, and the compiler will copy the input to an
internally managed buffer. Default is False."""
full_cuda_graph: bool = False
"""whether to use a full cuda graph for the entire forward pass rather than
splitting certain operations such as attention into subgraphs. Thus this
flag cannot be used together with splitting_ops. This may provide
performance benefits for smaller models."""
class PassConfig(BaseModel):
"""
Configuration for custom Inductor passes.
This is separate from general CompilationConfig so that inductor passes
don't all have access to full configuration - that would create a cycle
as the PassManager is set as a property of config.
- dump_graph_stages: list of stages for which we want to dump the graph.
Each pass defines its own stages (before, after, maybe in-between).
- dump_graph_dir: directory to dump the graphs. Default is .
- enable_fusion: whether to enable the custom fusion pass.
- enable_noop: whether to enable the custom no-op elimination pass.
TODO(luka) better pass enabling system.
- enable_sequence_parallelism: whether to enable sequence parallelism.
"""
dump_graph_stages: list[str] = Field(default_factory=list)
dump_graph_dir: Path = Field(default=Path("."))
enable_fusion: bool = True
enable_noop: bool = True
enable_sequence_parallelism: bool = False
pass_config: PassConfig = field(default_factory=PassConfig)
"""Custom inductor passes, see PassConfig for more details"""
def uuid(self):
"""
Produces a hash unique to the pass configuration.
Any new fields that affect compilation should be added to the hash.
Do not include dump_graph_* in the hash - they don't affect
compilation.
"""
dict_ = self.model_dump(include={"enable_fusion", "enable_noop", \
"enable_sequence_parallelism"})
return InductorPass.hash_dict(dict_)
def model_post_init(self, __context: Any) -> None:
if not self.enable_noop and self.enable_fusion:
logger.warning_once(
"Fusion enabled but reshape elimination disabled. "
"RMSNorm + quant (fp8) fusion might not work")
pass_config: PassConfig = Field(default_factory=PassConfig)
# not configurable, computed after init
max_capture_size: int = PrivateAttr
local_cache_dir: str = PrivateAttr # local cache dir for each rank
# optimization:
# Intuitively, bs_to_padded_graph_size should be dict[int, int].
# since we know all keys are in a range [0, max_capture_size],
# we can optimize it to list[int] for better lookup performance.
bs_to_padded_graph_size: list[int] = PrivateAttr
max_capture_size: int = field(default=None, init=False) # type: ignore
"""not configurable, computed after init"""
local_cache_dir: str = field(default=None, init=False) # type: ignore
"""local cache dir for each rank"""
bs_to_padded_graph_size: list[int] = field(
default=None, # type: ignore
init=False)
"""optimization:
Intuitively, bs_to_padded_graph_size should be dict[int, int].
since we know all keys are in a range [0, max_capture_size],
we can optimize it to list[int] for better lookup performance."""
# keep track of enabled and disabled custom ops
enabled_custom_ops: Counter[str] = PrivateAttr
disabled_custom_ops: Counter[str] = PrivateAttr
traced_files: set[str] = PrivateAttr
compilation_time: float = PrivateAttr
enabled_custom_ops: Counter[str] = field(default_factory=Counter,
init=False)
"""custom ops that are enabled"""
disabled_custom_ops: Counter[str] = field(default_factory=Counter,
init=False)
"""custom ops that are disabled"""
traced_files: set[str] = field(default_factory=set, init=False)
"""files that are traced for compilation"""
compilation_time: float = field(default=0.0, init=False)
"""time taken for compilation"""
# Per-model forward context
# Map from layer name to layer objects that need to be accessed outside
# model code, e.g., Attention, FusedMOE when dp_size>1.
static_forward_context: dict[str, Any] = PrivateAttr
static_forward_context: dict[str, Any] = field(default_factory=dict,
init=False)
"""Per-model forward context
Map from layer name to layer objects that need to be accessed outside
model code, e.g., Attention, FusedMOE when dp_size>1."""
def compute_hash(self) -> str:
"""
@ -3757,7 +3804,17 @@ class CompilationConfig(BaseModel):
"pass_config",
"traced_files",
}
return self.model_dump_json(exclude=exclude, exclude_unset=True)
include = dict()
for k, v in asdict(self).items():
if k in exclude:
continue
f = get_field(CompilationConfig, k)
if (d := f.default) is not MISSING and d == v:
continue
if (df := f.default_factory) is not MISSING and df() == v:
continue
include[k] = v
return json.dumps(include)
__str__ = __repr__
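The rewritten `__repr__` walks `asdict(self)`, drops the excluded keys, and keeps only values that differ from their field defaults, replacing the old pydantic `exclude_unset=True` dump. A hedged sketch of the expected output:

cfg = CompilationConfig(level=3, use_cudagraph=True)
print(cfg)  # only non-default fields survive, e.g. {"level": 3, "use_cudagraph": true}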
@ -3766,12 +3823,9 @@ class CompilationConfig(BaseModel):
"""Parse the CLI value for the compilation config."""
if cli_value in ["0", "1", "2", "3"]:
return cls(level=int(cli_value))
# do not use `eval`, it is dangerous and can execute arbitrary code
dict_value = ast.literal_eval(cli_value)
return CompilationConfig.model_validate(dict_value)
def model_post_init(self, __context: Any) -> None:
return cls(**json.loads(cli_value))
def __post_init__(self) -> None:
count_none = self.custom_ops.count("none")
count_all = self.custom_ops.count("all")
assert count_none + count_all <= 1, "Can only specify 'none' or 'all'"
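With `ast.literal_eval` gone, `from_cli` accepts either a bare level digit or a strict JSON object; single-quoted Python-dict strings no longer parse, hence the test updates earlier in this commit. A short sketch:

CompilationConfig.from_cli("3")  # equivalent to CompilationConfig(level=3)
CompilationConfig.from_cli('{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}')
# CompilationConfig.from_cli("{'level': 3}") now raises, since json.loads rejects single quotes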
@ -3789,9 +3843,6 @@ class CompilationConfig(BaseModel):
if KEY not in self.inductor_compile_config:
self.inductor_compile_config[KEY] = False
if self.splitting_ops is None:
self.splitting_ops = []
for k, v in self.inductor_passes.items():
if not isinstance(v, str):
assert callable(v), (
@ -3808,11 +3859,8 @@ class CompilationConfig(BaseModel):
self.inductor_compile_config[k] = func if isinstance(
func, InductorPass) else CallableInductorPass(func)
self.enabled_custom_ops = Counter()
self.disabled_custom_ops = Counter()
self.traced_files = set()
self.static_forward_context = {}
self.compilation_time = 0.0
if isinstance(self.pass_config, dict):
self.pass_config = PassConfig(**self.pass_config)
def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
if self.level == CompilationLevel.NO_COMPILATION:
@ -3899,39 +3947,67 @@ class CompilationConfig(BaseModel):
]
@config
@dataclass
class VllmConfig:
"""Dataclass which contains all vllm-related configuration. This
simplifies passing around the distinct configurations in the codebase.
"""
model_config: ModelConfig = field(default=None, init=True) # type: ignore
cache_config: CacheConfig = field(default=None, init=True) # type: ignore
parallel_config: ParallelConfig = field(default_factory=ParallelConfig,
init=True)
scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig,
init=True)
device_config: DeviceConfig = field(default=None,
init=True) # type: ignore
load_config: LoadConfig = field(default=None, init=True) # type: ignore
model_config: ModelConfig = field(default_factory=ModelConfig)
"""Model configuration."""
cache_config: CacheConfig = field(default_factory=CacheConfig)
"""Cache configuration."""
parallel_config: ParallelConfig = field(default_factory=ParallelConfig)
"""Parallel configuration."""
scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig)
"""Scheduler configuration."""
device_config: DeviceConfig = field(default_factory=DeviceConfig)
"""Device configuration."""
load_config: LoadConfig = field(default_factory=LoadConfig)
"""Load configuration."""
lora_config: Optional[LoRAConfig] = None
speculative_config: SpeculativeConfig = field(default=None,
init=True) # type: ignore
"""LoRA configuration."""
speculative_config: Optional[SpeculativeConfig] = None
"""Speculative decoding configuration."""
decoding_config: Optional[DecodingConfig] = None
"""Decoding configuration."""
observability_config: Optional[ObservabilityConfig] = None
"""Observability configuration."""
prompt_adapter_config: Optional[PromptAdapterConfig] = None
"""Prompt adapter configuration."""
quant_config: Optional[QuantizationConfig] = None
compilation_config: CompilationConfig = field(default=None,
init=True) # type: ignore
kv_transfer_config: KVTransferConfig = field(default=None,
init=True) # type: ignore
"""Quantization configuration."""
compilation_config: CompilationConfig = field(
default_factory=CompilationConfig)
"""`torch.compile` configuration for the model.
When it is a number (0, 1, 2, 3), it will be interpreted as the
optimization level.
NOTE: level 0 is the default level without any optimization. level 1 and 2
are for internal testing only. level 3 is the recommended level for
production.
Following the convention of traditional compilers, using `-O` without space
is also supported. `-O3` is equivalent to `-O 3`.
You can specify the full compilation config like so:
`{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}`
"""
kv_transfer_config: Optional[KVTransferConfig] = None
"""The configurations for distributed KV cache transfer."""
kv_events_config: Optional[KVEventsConfig] = None
"""The configurations for event publishing."""
# some opaque config, only used to provide additional information
# for the hash computation, mainly used for testing, debugging or out of
# tree config registration.
additional_config: SupportsHash = field(default=None,
init=True) # type: ignore
additional_config: Union[dict, SupportsHash] = field(default_factory=dict)
"""Additional config for specified platform. Different platforms may
support different configs. Make sure the configs are valid for the platform
you are using. Contents must be hashable."""
instance_id: str = ""
"""The ID of the vLLM instance."""
def compute_hash(self) -> str:
"""
@ -4012,7 +4088,14 @@ class VllmConfig:
else:
vllm_factors.append("None")
if self.additional_config:
vllm_factors.append(self.additional_config.compute_hash())
if isinstance(additional_config := self.additional_config, dict):
additional_config_hash = hashlib.md5(
json.dumps(additional_config, sort_keys=True).encode(),
usedforsecurity=False,
).hexdigest()
else:
additional_config_hash = additional_config.compute_hash()
vllm_factors.append(additional_config_hash)
else:
vllm_factors.append("None")
factors.append(vllm_factors)

View File

@ -5,6 +5,7 @@ import threading
import time
from abc import ABC, abstractmethod
from collections import deque
from dataclasses import asdict
from itertools import count
from queue import Queue
from typing import Any, Callable, Optional, Union
@ -284,7 +285,7 @@ class EventPublisherFactory:
if not config:
return NullEventPublisher()
config_dict = config.model_dump()
config_dict = asdict(config)
kind = config_dict.pop("publisher", "null")
config_dict.pop("enable_kv_cache_events")

View File

@ -7,10 +7,10 @@ import json
import re
import threading
import warnings
from dataclasses import MISSING, dataclass, fields
from dataclasses import MISSING, dataclass, fields, is_dataclass
from itertools import permutations
from typing import (Any, Callable, Dict, List, Literal, Optional, Type,
TypeVar, Union, cast, get_args, get_origin)
from typing import (Annotated, Any, Callable, Dict, List, Literal, Optional,
Type, TypeVar, Union, cast, get_args, get_origin)
import torch
from typing_extensions import TypeIs, deprecated
@ -36,7 +36,8 @@ from vllm.reasoning import ReasoningParserManager
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.utils import check_gguf_file
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, GiB_bytes, is_in_ray_actor
from vllm.utils import (FlexibleArgumentParser, GiB_bytes, is_in_doc_build,
is_in_ray_actor)
# yapf: enable
@ -48,12 +49,9 @@ TypeHint = Union[type[Any], object]
TypeHintT = Union[type[T], object]
def optional_type(
return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
def _optional_type(val: str) -> Optional[T]:
if val == "" or val == "None":
return None
def _parse_type(val: str) -> T:
try:
if return_type is json.loads and not re.match("^{.*}$", val):
return cast(T, nullable_kvs(val))
@ -62,14 +60,24 @@ def optional_type(
raise argparse.ArgumentTypeError(
f"Value {val} cannot be converted to {return_type}.") from e
return _parse_type
def optional_type(
return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
def _optional_type(val: str) -> Optional[T]:
if val == "" or val == "None":
return None
return parse_type(return_type)(val)
return _optional_type
def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]:
if not re.match("^{.*}$", val):
return str(val)
else:
return optional_type(json.loads)(val)
return optional_type(json.loads)(val)
@deprecated(
@ -144,10 +152,25 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
cls_docs = get_attr_docs(cls)
kwargs = {}
for field in fields(cls):
# Get the set of possible types for the field
type_hints: set[TypeHint] = set()
if get_origin(field.type) in {Union, Annotated}:
type_hints.update(get_args(field.type))
else:
type_hints.add(field.type)
# If the field is a dataclass, construct it from a JSON string on the CLI
generator = (th for th in type_hints if is_dataclass(th))
dataclass_cls = next(generator, None)
# Get the default value of the field
default = field.default
if field.default_factory is not MISSING:
default = field.default_factory()
if field.default is not MISSING:
default = field.default
elif field.default_factory is not MISSING:
if is_dataclass(field.default_factory) and is_in_doc_build():
default = {}
else:
default = field.default_factory()
# Get the help text for the field
name = field.name
@ -158,16 +181,17 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
# Initialise the kwargs dictionary for the field
kwargs[name] = {"default": default, "help": help}
# Get the set of possible types for the field
type_hints: set[TypeHint] = set()
if get_origin(field.type) is Union:
type_hints.update(get_args(field.type))
else:
type_hints.add(field.type)
# Set other kwargs based on the type hints
json_tip = "\n\nShould be a valid JSON string."
if contains_type(type_hints, bool):
if dataclass_cls is not None:
dataclass_init = lambda x, f=dataclass_cls: f(**json.loads(x))
# Special case for configs with a from_cli method
if hasattr(dataclass_cls, "from_cli"):
from_cli = dataclass_cls.from_cli
dataclass_init = lambda x, f=from_cli: f(x)
kwargs[name]["type"] = dataclass_init
kwargs[name]["help"] += json_tip
elif contains_type(type_hints, bool):
# Creates --no-<name> and --<name> flags
kwargs[name]["action"] = argparse.BooleanOptionalAction
elif contains_type(type_hints, Literal):
@ -202,7 +226,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
kwargs[name]["type"] = union_dict_and_str
elif contains_type(type_hints, dict):
# Dict arguments will always be optional
kwargs[name]["type"] = optional_type(json.loads)
kwargs[name]["type"] = parse_type(json.loads)
kwargs[name]["help"] += json_tip
elif (contains_type(type_hints, str)
or any(is_not_builtin(th) for th in type_hints)):
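The `dataclass_init` lambda is the generic path for dataclass-typed fields; binding `f=dataclass_cls` at definition time gives each field its own constructor. A self-contained sketch of the same pattern (the class name is illustrative):

import json
from dataclasses import dataclass

@dataclass
class Demo:  # stand-in for a nested config class
    field: int = 1

dataclass_init = lambda x, f=Demo: f(**json.loads(x))
assert dataclass_init('{"field": 2}') == Demo(field=2)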
@ -771,63 +795,20 @@ class EngineArgs:
scheduler_group.add_argument("--scheduler-cls",
**scheduler_kwargs["scheduler_cls"])
# Compilation arguments
# compilation_kwargs = get_kwargs(CompilationConfig)
compilation_group = parser.add_argument_group(
title="CompilationConfig",
description=CompilationConfig.__doc__,
)
compilation_group.add_argument(
"--compilation-config",
"-O",
type=CompilationConfig.from_cli,
default=None,
help="torch.compile configuration for the model. "
"When it is a number (0, 1, 2, 3), it will be "
"interpreted as the optimization level.\n"
"NOTE: level 0 is the default level without "
"any optimization. level 1 and 2 are for internal "
"testing only. level 3 is the recommended level "
"for production.\n"
"To specify the full compilation config, "
"use a JSON string, e.g. ``{\"level\": 3, "
"\"cudagraph_capture_sizes\": [1, 2, 4, 8]}``\n"
"Following the convention of traditional "
"compilers, using ``-O`` without space is also "
"supported. ``-O3`` is equivalent to ``-O 3``.")
# KVTransfer arguments
# kv_transfer_kwargs = get_kwargs(KVTransferConfig)
kv_transfer_group = parser.add_argument_group(
title="KVTransferConfig",
description=KVTransferConfig.__doc__,
)
kv_transfer_group.add_argument(
"--kv-transfer-config",
type=KVTransferConfig.from_cli,
default=None,
help="The configurations for distributed KV cache "
"transfer. Should be a JSON string.")
kv_transfer_group.add_argument(
'--kv-events-config',
type=KVEventsConfig.from_cli,
default=None,
help='The configurations for event publishing.')
# vLLM arguments
# vllm_kwargs = get_kwargs(VllmConfig)
vllm_kwargs = get_kwargs(VllmConfig)
vllm_group = parser.add_argument_group(
title="VllmConfig",
description=VllmConfig.__doc__,
)
vllm_group.add_argument(
"--additional-config",
type=json.loads,
default=None,
help="Additional config for specified platform in JSON format. "
"Different platforms may support different configs. Make sure the "
"configs are valid for the platform you are using. The input format"
" is like '{\"config_key\":\"config_value\"}'")
vllm_group.add_argument("--kv-transfer-config",
**vllm_kwargs["kv_transfer_config"])
vllm_group.add_argument('--kv-events-config',
**vllm_kwargs["kv_events_config"])
vllm_group.add_argument("--compilation-config", "-O",
**vllm_kwargs["compilation_config"])
vllm_group.add_argument("--additional-config",
**vllm_kwargs["additional_config"])
# Other arguments
parser.add_argument('--use-v2-block-manager',

View File

@ -13,7 +13,8 @@ from typing_extensions import TypeVar, deprecated
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence, get_beam_search_score)
from vllm.config import CompilationConfig, ModelDType, TokenizerMode
from vllm.config import (CompilationConfig, ModelDType, TokenizerMode,
is_init_field)
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
TaskOption)
from vllm.engine.llm_engine import LLMEngine
@ -204,9 +205,13 @@ class LLM:
kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)
if compilation_config is not None:
if isinstance(compilation_config, (int, dict)):
compilation_config_instance = CompilationConfig.from_cli(
str(compilation_config))
if isinstance(compilation_config, int):
compilation_config_instance = CompilationConfig(
level=compilation_config)
elif isinstance(compilation_config, dict):
predicate = lambda x: is_init_field(CompilationConfig, x[0])
compilation_config_instance = CompilationConfig(
**dict(filter(predicate, compilation_config.items())))
else:
compilation_config_instance = compilation_config
else:
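For `LLM(...)` the practical effect is that an int is treated directly as the optimization level and a dict is filtered down to `CompilationConfig`'s init-able fields before construction. A hedged usage sketch (the model name is illustrative):

from vllm import LLM

llm = LLM(model="facebook/opt-125m", compilation_config=3)
llm = LLM(model="facebook/opt-125m",
          compilation_config={"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]})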

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, Optional, Tuple, Union
from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
import torch
from tpu_info import device
@ -13,9 +13,10 @@ from vllm.sampling_params import SamplingParams, SamplingType
from .interface import Platform, PlatformEnum, _Backend
if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
from vllm.config import BlockSize, ModelConfig, VllmConfig
from vllm.pooling_params import PoolingParams
else:
BlockSize = None
ModelConfig = None
VllmConfig = None
PoolingParams = None
@ -94,7 +95,7 @@ class TpuPlatform(Platform):
cache_config = vllm_config.cache_config
# For v0, the default block size is 16.
if cache_config and cache_config.block_size is None:
cache_config.block_size = 16
cache_config.block_size = cast(BlockSize, 16)
compilation_config = vllm_config.compilation_config
# TPU only supports DYNAMO_ONCE compilation level
@ -118,7 +119,7 @@ class TpuPlatform(Platform):
from vllm.v1.attention.backends.pallas import (
PallasAttentionBackend)
cache_config.block_size = PallasAttentionBackend.get_page_size(
vllm_config)
vllm_config) # type: ignore[assignment]
min_page_size = PallasAttentionBackend.get_min_page_size(
vllm_config)
if min_page_size > cache_config.block_size:
@ -128,7 +129,7 @@ class TpuPlatform(Platform):
cache_config.block_size,
min_page_size,
)
cache_config.block_size = min_page_size
cache_config.block_size = min_page_size # type: ignore[assignment]
parallel_config = vllm_config.parallel_config
scheduler_config = vllm_config.scheduler_config

View File

@ -1820,6 +1820,14 @@ def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor:
return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor)
def is_in_doc_build() -> bool:
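# zmq is imported at the top of this module; during a Sphinx docs build,
# autodoc replaces mocked imports with _MockModule instances, so detecting
# the mock tells us we are building documentation.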
try:
from sphinx.ext.autodoc.mock import _MockModule
return isinstance(zmq, _MockModule)
except ModuleNotFoundError:
return False
def import_from_path(module_name: str, file_path: Union[str, os.PathLike]):
"""
Import a Python file according to its file path.