From 4b2ed7926a1d93d4189ac112209f2e34cd80846a Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 9 May 2025 23:18:44 +0100 Subject: [PATCH] Improve configs - the rest! (#17562) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- tests/compile/test_full_graph.py | 5 +- tests/compile/test_functionalization.py | 7 +- tests/compile/test_fusion.py | 6 +- tests/compile/test_sequence_parallelism.py | 7 +- tests/compile/test_silu_mul_quant_fusion.py | 5 +- tests/distributed/test_sequence_parallel.py | 4 +- tests/engine/test_arg_utils.py | 76 ++- vllm/compilation/vllm_inductor_pass.py | 7 +- vllm/config.py | 513 ++++++++++++-------- vllm/distributed/kv_events.py | 3 +- vllm/engine/arg_utils.py | 131 +++-- vllm/entrypoints/llm.py | 13 +- vllm/platforms/tpu.py | 11 +- vllm/utils.py | 8 + 14 files changed, 456 insertions(+), 340 deletions(-) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index c094063859..397517b866 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -9,7 +9,7 @@ import torch from tests.quantization.utils import is_quant_method_supported from vllm import LLM, SamplingParams -from vllm.config import CompilationConfig, CompilationLevel +from vllm.config import CompilationConfig, CompilationLevel, PassConfig from vllm.platforms import current_platform from ..utils import create_new_process_for_each_test @@ -95,9 +95,6 @@ def test_full_graph( run_model(optimization_level, model, model_kwargs) -PassConfig = CompilationConfig.PassConfig - - # TODO(luka) add other supported compilation config scenarios here @pytest.mark.parametrize( "compilation_config, model_info", diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index 1e1364ce7b..5d38ff9149 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -11,7 +11,7 @@ from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym) from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func from vllm.compilation.noop_elimination import NoOpEliminationPass -from vllm.config import CompilationConfig, VllmConfig +from vllm.config import CompilationConfig, PassConfig, VllmConfig from .backend import TestBackend @@ -53,9 +53,8 @@ def test_fix_functionalization(model: str, quant_key: QuantKey, torch.set_default_device("cuda") vllm_config = VllmConfig() - vllm_config.compilation_config = CompilationConfig(pass_config= \ - CompilationConfig.PassConfig(enable_fusion=do_fusion, - enable_noop=True)) + vllm_config.compilation_config = CompilationConfig( + pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True)) noop_pass = NoOpEliminationPass(vllm_config) fusion_pass = FusionPass.instance(vllm_config) act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 6a696fe022..4d56b34bde 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -9,7 +9,8 @@ from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey, FusionPass, QuantKey) from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe from vllm.compilation.noop_elimination import NoOpEliminationPass -from vllm.config import CompilationConfig, CompilationLevel, VllmConfig +from vllm.config import (CompilationConfig, CompilationLevel, PassConfig, + VllmConfig) from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity) @@ -78,8 +79,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, vllm_config = VllmConfig(compilation_config=CompilationConfig( level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"])) vllm_config.compilation_config.pass_config = \ - CompilationConfig.PassConfig(enable_fusion=True, - enable_noop=True) + PassConfig(enable_fusion=True, enable_noop=True) with vllm.config.set_current_vllm_config(vllm_config): # Reshape pass is needed for the fusion pass to work noop_pass = NoOpEliminationPass(vllm_config) diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index 79f5486dad..6152f17170 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -10,7 +10,7 @@ from vllm.compilation.fx_utils import (find_auto_fn, find_auto_fn_maybe, find_specified_fn_maybe, is_func) from vllm.compilation.sequence_parallelism import SequenceParallelismPass from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig, - VllmConfig) + PassConfig, VllmConfig) from vllm.distributed import tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import (init_distributed_environment, initialize_model_parallel) @@ -126,9 +126,8 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int, # configure vllm config for SequenceParallelismPass vllm_config = VllmConfig() - vllm_config.compilation_config = CompilationConfig( - pass_config=CompilationConfig.PassConfig( - enable_sequence_parallelism=True, ), ) + vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig( + enable_sequence_parallelism=True)) vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) # this is a fake model name to construct the model config diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index 313848372e..f87f175acd 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -6,7 +6,7 @@ import vllm.envs as envs from vllm._custom_ops import scaled_fp8_quant from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe -from vllm.config import CompilationConfig, VllmConfig +from vllm.config import CompilationConfig, PassConfig, VllmConfig from vllm.model_executor.layers.activation import SiluAndMul from .backend import TestBackend @@ -36,8 +36,7 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size): # Reshape pass is needed for the fusion pass to work config = VllmConfig() config.compilation_config = CompilationConfig( - pass_config=CompilationConfig.PassConfig(enable_fusion=True, - enable_reshape=True)) + pass_config=PassConfig(enable_fusion=True, enable_reshape=True)) fusion_pass = ActivationQuantFusionPass(config) backend = TestBackend(fusion_pass) diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py index 19497ad9c1..bbf3ed5843 100644 --- a/tests/distributed/test_sequence_parallel.py +++ b/tests/distributed/test_sequence_parallel.py @@ -206,7 +206,7 @@ def _compare_sp( 'compile_sizes': [4, 8], 'splitting_ops': [], 'pass_config': { - 'enable_sequence_parallism': sp_enabled, + 'enable_sequence_parallelism': sp_enabled, 'enable_noop': True, 'enable_fusion': True, }, @@ -223,7 +223,7 @@ def _compare_sp( "--distributed-executor-backend", distributed_backend, "--compilation_config", - str(compilation_config), + json.dumps(compilation_config), ] tp_env = { diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 65471cb3af..ce8873d58d 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -8,21 +8,18 @@ from typing import Literal, Optional import pytest -from vllm.config import config +from vllm.config import CompilationConfig, config from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs, get_type, is_not_builtin, is_type, literal_to_kwargs, nullable_kvs, - optional_type) + optional_type, parse_type) from vllm.utils import FlexibleArgumentParser @pytest.mark.parametrize(("type", "value", "expected"), [ (int, "42", 42), - (int, "None", None), (float, "3.14", 3.14), - (float, "None", None), (str, "Hello World!", "Hello World!"), - (str, "None", None), (json.loads, '{"foo":1,"bar":2}', { "foo": 1, "bar": 2 @@ -31,15 +28,20 @@ from vllm.utils import FlexibleArgumentParser "foo": 1, "bar": 2 }), - (json.loads, "None", None), ]) -def test_optional_type(type, value, expected): - optional_type_func = optional_type(type) +def test_parse_type(type, value, expected): + parse_type_func = parse_type(type) context = nullcontext() if value == "foo=1,bar=2": context = pytest.warns(DeprecationWarning) with context: - assert optional_type_func(value) == expected + assert parse_type_func(value) == expected + + +def test_optional_type(): + optional_type_func = optional_type(int) + assert optional_type_func("None") is None + assert optional_type_func("42") == 42 @pytest.mark.parametrize(("type_hint", "type", "expected"), [ @@ -89,7 +91,40 @@ def test_literal_to_kwargs(type_hints, expected): @config @dataclass -class DummyConfigClass: +class NestedConfig: + field: int = 1 + """field""" + + +@config +@dataclass +class FromCliConfig1: + field: int = 1 + """field""" + + @classmethod + def from_cli(cls, cli_value: str): + inst = cls(**json.loads(cli_value)) + inst.field += 1 + return inst + + +@config +@dataclass +class FromCliConfig2: + field: int = 1 + """field""" + + @classmethod + def from_cli(cls, cli_value: str): + inst = cls(**json.loads(cli_value)) + inst.field += 2 + return inst + + +@config +@dataclass +class DummyConfig: regular_bool: bool = True """Regular bool with default True""" optional_bool: Optional[bool] = None @@ -108,18 +143,24 @@ class DummyConfigClass: """Literal of literals with default 1""" json_tip: dict = field(default_factory=dict) """Dict which will be JSON in CLI""" + nested_config: NestedConfig = field(default_factory=NestedConfig) + """Nested config""" + from_cli_config1: FromCliConfig1 = field(default_factory=FromCliConfig1) + """Config with from_cli method""" + from_cli_config2: FromCliConfig2 = field(default_factory=FromCliConfig2) + """Different config with from_cli method""" @pytest.mark.parametrize(("type_hint", "expected"), [ (int, False), - (DummyConfigClass, True), + (DummyConfig, True), ]) def test_is_not_builtin(type_hint, expected): assert is_not_builtin(type_hint) == expected def test_get_kwargs(): - kwargs = get_kwargs(DummyConfigClass) + kwargs = get_kwargs(DummyConfig) print(kwargs) # bools should not have their type set @@ -142,6 +183,11 @@ def test_get_kwargs(): # dict should have json tip in help json_tip = "\n\nShould be a valid JSON string." assert kwargs["json_tip"]["help"].endswith(json_tip) + # nested config should should construct the nested config + assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2) + # from_cli configs should be constructed with the correct method + assert kwargs["from_cli_config1"]["type"]('{"field": 2}').field == 3 + assert kwargs["from_cli_config2"]["type"]('{"field": 2}').field == 4 @pytest.mark.parametrize(("arg", "expected"), [ @@ -177,7 +223,7 @@ def test_compilation_config(): # default value args = parser.parse_args([]) - assert args.compilation_config is None + assert args.compilation_config == CompilationConfig() # set to O3 args = parser.parse_args(["-O3"]) @@ -194,7 +240,7 @@ def test_compilation_config(): # set to string form of a dict args = parser.parse_args([ "--compilation-config", - "{'level': 3, 'cudagraph_capture_sizes': [1, 2, 4, 8]}", + '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}', ]) assert (args.compilation_config.level == 3 and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]) @@ -202,7 +248,7 @@ def test_compilation_config(): # set to string form of a dict args = parser.parse_args([ "--compilation-config=" - "{'level': 3, 'cudagraph_capture_sizes': [1, 2, 4, 8]}", + '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}', ]) assert (args.compilation_config.level == 3 and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]) diff --git a/vllm/compilation/vllm_inductor_pass.py b/vllm/compilation/vllm_inductor_pass.py index e8bffb406f..c95e0bce5f 100644 --- a/vllm/compilation/vllm_inductor_pass.py +++ b/vllm/compilation/vllm_inductor_pass.py @@ -4,7 +4,7 @@ import time import torch -from vllm.config import CompilationConfig, VllmConfig +from vllm.config import PassConfig, VllmConfig # yapf: disable from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank from vllm.distributed import ( @@ -56,10 +56,7 @@ class VllmInductorPass(InductorPass): class PrinterInductorPass(VllmInductorPass): - def __init__(self, - name: str, - config: CompilationConfig.PassConfig, - always=False): + def __init__(self, name: str, config: PassConfig, always=False): super().__init__(config) self.name = name self.always = always diff --git a/vllm/config.py b/vllm/config.py index cc185b1d5b..ef0163eaff 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -11,8 +11,8 @@ import textwrap import warnings from collections import Counter from contextlib import contextmanager -from dataclasses import (MISSING, dataclass, field, fields, is_dataclass, - replace) +from dataclasses import (MISSING, Field, asdict, dataclass, field, fields, + is_dataclass, replace) from functools import cached_property from importlib.util import find_spec from pathlib import Path @@ -20,7 +20,6 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional, Protocol, TypeVar, Union, cast, get_args, get_origin) import torch -from pydantic import BaseModel, Field, PrivateAttr from torch.distributed import ProcessGroup, ReduceOp from transformers import PretrainedConfig from typing_extensions import deprecated @@ -57,7 +56,7 @@ if TYPE_CHECKING: ConfigType = type[DataclassInstance] else: - QuantizationConfig = None + QuantizationConfig = Any ConfigType = type logger = init_logger(__name__) @@ -169,6 +168,12 @@ def config(cls: ConfigT) -> ConfigT: """ A decorator that ensures all fields in a dataclass have default values and that each field has a docstring. + + If a `ConfigT` is used as a CLI argument itself, the default value provided + by `get_kwargs` will be the result parsing a JSON string as the kwargs + (i.e. `ConfigT(**json.loads(cli_arg))`). However, if a particular `ConfigT` + requires custom construction from CLI (i.e. `CompilationConfig`), it can + have a `from_cli` method, which will be called instead. """ if not is_dataclass(cls): raise TypeError("The decorated class must be a dataclass.") @@ -202,7 +207,7 @@ def get_field(cls: ConfigType, name: str) -> Field: cls_fields = {f.name: f for f in fields(cls)} if name not in cls_fields: raise ValueError(f"Field '{name}' not found in {cls.__name__}.") - named_field: Field = cls_fields.get(name) + named_field: Field = cls_fields[name] if (default_factory := named_field.default_factory) is not MISSING: return field(default_factory=default_factory) if (default := named_field.default) is not MISSING: @@ -211,6 +216,10 @@ def get_field(cls: ConfigType, name: str) -> Field: f"{cls.__name__}.{name} must have a default value or default factory.") +def is_init_field(cls: ConfigType, name: str) -> bool: + return next(f for f in fields(cls) if f.name == name).init + + TokenizerMode = Literal["auto", "slow", "mistral", "custom"] ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"] @@ -2007,13 +2016,13 @@ class SchedulerConfig: def __post_init__(self) -> None: if self.max_model_len is None: self.max_model_len = 8192 - logger.warning( + logger.warning_once( "max_model_len was is not set. Defaulting to arbitrary value " "of %d.", self.max_model_len) if self.max_num_seqs is None: self.max_num_seqs = 128 - logger.warning( + logger.warning_once( "max_num_seqs was is not set. Defaulting to arbitrary value " "of %d.", self.max_num_seqs) @@ -2840,8 +2849,8 @@ class PromptAdapterConfig: class MultiModalConfig: """Controls the behavior of multimodal models.""" - limit_per_prompt: dict[str, int] = get_field(ModelConfig, - "limit_mm_per_prompt") + limit_per_prompt: dict[str, int] = \ + cast(dict[str, int], get_field(ModelConfig, "limit_mm_per_prompt")) """ The maximum number of input items allowed per prompt for each modality. Defaults to 1 (V0) or 999 (V1) for each modality. @@ -3415,41 +3424,49 @@ class ObservabilityConfig: self.collect_detailed_traces[0].split(",")) -class KVTransferConfig(BaseModel): +KVProducer = Literal["kv_producer", "kv_both"] +KVConsumer = Literal["kv_consumer", "kv_both"] +KVRole = Literal[KVProducer, KVConsumer] + + +@config +@dataclass +class KVTransferConfig: """Configuration for distributed KV cache transfer.""" - # The KV connector for vLLM to transmit KV caches between vLLM instances. kv_connector: Optional[str] = None + """The KV connector for vLLM to transmit KV caches between vLLM instances. + """ - # The device used by kv connector to buffer the KV cache. - # Currently only support 'cuda'. kv_buffer_device: Optional[str] = "cuda" + """The device used by kv connector to buffer the KV cache. + Currently only support 'cuda'.""" - # The buffer size for TorchDistributedConnector. Measured in number of - # bytes. Recommended value: 1e9 (about 1GB). kv_buffer_size: float = 1e9 + """The buffer size for TorchDistributedConnector. Measured in number of + bytes. Recommended value: 1e9 (about 1GB).""" - # Whether this vLLM instance produces, consumes KV cache, or both. Choices - # are 'kv_producer', 'kv_consumer', and 'both'. - kv_role: Optional[str] = None + kv_role: Optional[KVRole] = None + """Whether this vLLM instance produces, consumes KV cache, or both. Choices + are 'kv_producer', 'kv_consumer', and 'both'.""" - # The rank of this vLLM instance in the KV cache transfer. Typical value: - # 0 for prefill instance, 1 for decode instance. - # Currently only 1P1D is supported. kv_rank: Optional[int] = None + """The rank of this vLLM instance in the KV cache transfer. Typical value: + 0 for prefill instance, 1 for decode instance. + Currently only 1P1D is supported.""" - # The number of parallel instances for KV cache transfer. For - # PyNcclConnector, this should be 2. kv_parallel_size: int = 1 + """The number of parallel instances for KV cache transfer. For + PyNcclConnector, this should be 2.""" - # The KV connector ip, used to build distributed connection kv_ip: str = "127.0.0.1" + """The KV connector ip, used to build distributed connection.""" - # The KV connector port, used to build distributed connection kv_port: int = 14579 + """The KV connector port, used to build distributed connection.""" - # any extra config that the connector may need - kv_connector_extra_config: dict[str, Any] = {} + kv_connector_extra_config: dict[str, Any] = field(default_factory=dict) + """any extra config that the connector may need.""" def compute_hash(self) -> str: """ @@ -3470,46 +3487,37 @@ class KVTransferConfig(BaseModel): usedforsecurity=False).hexdigest() return hash_str - @classmethod - def from_cli(cls, cli_value: str) -> "KVTransferConfig": - """Parse the CLI value for the kv cache transfer config.""" - return KVTransferConfig.model_validate_json(cli_value) - - def model_post_init(self, __context: Any) -> None: - - if self.kv_role is not None and self.kv_role not in [ - "kv_producer", "kv_consumer", "kv_both" - ]: - raise ValueError( - f"Unsupported kv_role: {self.kv_role}. " - f"Supported roles are `kv_producer`, `kv_consumer`, " - f"and `kv_both`") + def __post_init__(self) -> None: + if self.kv_role is not None and self.kv_role not in get_args(KVRole): + raise ValueError(f"Unsupported kv_role: {self.kv_role}. " + f"Supported roles are {get_args(KVRole)}") if self.kv_connector is not None and self.kv_role is None: raise ValueError("Please specify kv_disagg_role when kv_connector " - "is set, supported roles are `kv_producer`, " - "`kv_consumer`, and `kv_both`") + f"is set, supported roles are {get_args(KVRole)}") @property def is_kv_transfer_instance(self) -> bool: return self.kv_connector is not None and \ - self.kv_role in ["kv_producer", "kv_consumer", "kv_both"] + self.kv_role in get_args(KVRole) @property def is_kv_producer(self) -> bool: return self.kv_connector is not None and \ - self.kv_role in ["kv_producer", "kv_both"] + self.kv_role in get_args(KVProducer) @property def is_kv_consumer(self) -> bool: return self.kv_connector is not None and \ - self.kv_role in ["kv_consumer", "kv_both"] + self.kv_role in get_args(KVConsumer) def get_from_extra_config(self, key, default) -> Any: return self.kv_connector_extra_config.get(key, default) -class KVEventsConfig(BaseModel): +@config +@dataclass +class KVEventsConfig: """Configuration for KV event publishing.""" enable_kv_cache_events: bool = False @@ -3548,11 +3556,6 @@ class KVEventsConfig(BaseModel): this topic to receive events. """ - @classmethod - def from_cli(cls, cli_value: str) -> "KVEventsConfig": - """Parse the CLI value for the event publisher config.""" - return KVEventsConfig.model_validate_json(cli_value) - class CompilationLevel: # constants for the levels of the compilation process @@ -3562,80 +3565,72 @@ class CompilationLevel: PIECEWISE = 3 -class CompilationConfig(BaseModel): - """ - Configuration for compilation. - It has three parts: +@config +@dataclass +class PassConfig: + """Configuration for custom Inductor passes. + + This is separate from general `CompilationConfig` so that inductor passes + don't all have access to full configuration - that would create a cycle as + the `PassManager` is set as a property of config.""" + + dump_graph_stages: list[str] = field(default_factory=list) + """List of stages for which we want to dump the graph. Each pass defines + its own stages (before, after, maybe in-between).""" + dump_graph_dir: Path = Path(".") + """Directory to dump the graphs.""" + # TODO(luka) better pass enabling system. + enable_fusion: bool = True + """Whether to enable the custom fusion pass.""" + enable_noop: bool = True + """Whether to enable the custom no-op elimination pass.""" + enable_sequence_parallelism: bool = False + """Whether to enable sequence parallelism.""" + + def uuid(self): + """ + Produces a hash unique to the pass configuration. + Any new fields that affect compilation should be added to the hash. + Do not include dump_graph_* in the hash - they don't affect + compilation. + """ + include = { + "enable_fusion", "enable_noop", "enable_sequence_parallelism" + } + dict_ = {k: v for k, v in asdict(self).items() if k in include} + return InductorPass.hash_dict(dict_) + + def __post_init__(self) -> None: + if not self.enable_noop and self.enable_fusion: + logger.warning_once( + "Fusion enabled but reshape elimination disabled. " + "RMSNorm + quant (fp8) fusion might not work") + + +@config +@dataclass +class CompilationConfig: + """Configuration for compilation. It has three parts: + - Top-level Compilation control: - - level: the level of compilation. - - 0: no compilation. - - 1: dynamo as is. - - 2: dynamo once. - - 3: piecewise compilation. - - debug_dump_path: the path to dump the debug information. - - cache_dir: the directory to store the compiled graph, to - accelerate Inductor compilation. By default, it will use - model-related information to generate a cache directory. - - backend: the backend for compilation. It needs to be a string. - - "" (empty string): use the default backend. - - "eager"/"openxla"/...: use the specified backend registered in PyTorch. - - "full.module.name": a qualified name which can be used to import the backend function. - We use string to avoid serialization issues when using compilation in a distributed setting. - When the compilation level is 1 or 2, the backend is used for the compilation directly (it sees the whole graph). - When the compilation level is 3, the backend is used for the piecewise compilation (it sees a part of the graph). - - custom_ops: fine-grained control over which custom ops to enable/disable. - Use 'all' to enable all, 'none' to disable all. - Also specify a list of custom op names to enable (prefixed with a '+'), - or disable (prefixed with a '-'). - Examples: - - 'all,-op1' to enable all except op1 - - 'none,+op1,+op2' to enable only op1 and op2 - By default, all custom ops are enabled when running without Inductor - and disabled when running with Inductor (compile_level >= Inductor). - - splitting_ops: a list of ops to split the full graph into subgraphs, used in piecewise compilation. + - {attr}`level` + - {attr}`debug_dump_path` + - {attr}`cache_dir` + - {attr}`backend` + - {attr}`custom_ops` + - {attr}`splitting_ops` - CudaGraph capture: - - use_cudagraph: whether to use cudagraph inside compilation. - - False: cudagraph inside compilation is not used. - - True: cudagraph inside compilation is used. It requires - that all input buffers have fixed addresses, and all - splitting ops write their outputs to input buffers. - Note that this is orthogonal to the cudagraph capture logic - outside of compilation. - TODO: move outside cudagraph logic into compilation. - torch.compile will handle cudagraph capture logic in the future. - - cudagraph_capture_sizes: sizes to capture cudagraph. - - None (default): capture sizes are inferred from vllm config. - - list[int]: capture sizes are specified as given. - - cudagraph_num_of_warmups: number of warmup runs for cudagraph. - It means the first several runs will be treated as warmup runs. - Only after that, the execution will be recorded, and the recorded - cudagraph will be used for subsequent runs. - - cudagraph_copy_inputs: whether to copy input tensors for - cudagraph. If the caller can guarantee that the same input buffers - are always used, it can set this to False. Otherwise, it should - set this to True, and the compiler will copy the input to an - internally managed buffer. Default is False. - - full_cuda_graph: whether to use a full cuda graph for the entire forward - pass rather than splitting certain operations such as attention into subgraphs. - Thus this flag cannot be used together with splitting_ops. This may provide - performance benefits for smaller models. + - {attr}`use_cudagraph` + - {attr}`cudagraph_capture_sizes` + - {attr}`cudagraph_num_of_warmups` + - {attr}`cudagraph_copy_inputs` + - {attr}`full_cuda_graph` - Inductor compilation: - - use_inductor: whether to use inductor compilation. - - False: inductor compilation is not used. graph runs in eager. - - True: inductor compilation is used. one graph for symbolic shape - is compiled. In addition, compile for compile_sizes, - using configurations in inductor_compile_config. - - compile_sizes: sizes to compile for inductor. In addition - to integers, it also supports "cudagraph_capture_sizes" to - specify the sizes for cudagraph capture. - - inductor_compile_config: additional configurations for inductor. - - None: use default configurations. - - inductor_passes: additional passes for inductor. It is a dictionary - from pass name to pass function qualified name. We use function - name because the config uses json format. If we pass the config - from Python, functions can also be passed directly via Python object - constructor, e.g. `CompilationConfig(inductor_passes={"a": func})` - - custom inductor passes: see PassConfig for more details + - {attr}`use_inductor` + - {attr}`compile_sizes` + - {attr}`inductor_compile_config` + - {attr}`inductor_passes` + - custom inductor passes Why we have different sizes for cudagraph and inductor: - cudagraph: a cudagraph captured for a specific size can only be used @@ -3646,83 +3641,135 @@ class CompilationConfig(BaseModel): static shapes. However, we find the general shape compilation is sufficient for most cases. It might be beneficial to compile for certain small batchsizes, where inductor is good at optimizing. - """ # noqa + """ + # Top-level Compilation control level: int = 0 + """The level of compilation: + + - 0: no compilation. + - 1: dynamo as is. + - 2: dynamo once. + - 3: piecewise compilation.""" debug_dump_path: str = "" + """The path to dump the debug information.""" cache_dir: str = "" + """The directory to store the compiled graph, to accelerate Inductor + compilation. By default, it will use model-related information to generate + a cache directory.""" backend: str = "" - custom_ops: list[str] = Field(default_factory=list) - splitting_ops: list[str] = Field(default=None) # type: ignore + """The backend for compilation. It needs to be a string: + - "" (empty string): use the default backend. + - "eager"/"openxla"/...: use the specified backend registered in PyTorch. + - "full.module.name": a qualified name which can be used to import the + + backend function. + We use string to avoid serialization issues when using compilation in a + distributed setting. When the compilation level is 1 or 2, the backend is + used for the compilation directly (it sees the whole graph). When the + compilation level is 3, the backend is used for the piecewise compilation + (it sees a part of the graph).""" + custom_ops: list[str] = field(default_factory=list) + """Fine-grained control over which custom ops to enable/disable. Use 'all' + to enable all, 'none' to disable all. Also specify a list of custom op + names to enable (prefixed with a '+'), or disable (prefixed with a '-'). + Examples: + + - 'all,-op1' to enable all except op1 + - 'none,+op1,+op2' to enable only op1 and op2 + + By default, all custom ops are enabled when running without Inductor and + disabled when running with Inductor (compile_level >= Inductor).""" + splitting_ops: list[str] = field(default_factory=list) + """A list of ops to split the full graph into subgraphs, used in piecewise + compilation.""" + + # Inductor capture use_inductor: bool = True - compile_sizes: Optional[list[Union[int, str]]] = Field(default=None) - inductor_compile_config: dict = Field(default_factory=dict) - inductor_passes: dict[str, str] = Field(default_factory=dict) + """Whether to use inductor compilation: + - False: inductor compilation is not used. graph runs in eager. + - True: inductor compilation is used. one graph for symbolic shape + is compiled. In addition, compile for compile_sizes, + using configurations in inductor_compile_config.""" + compile_sizes: Optional[list[Union[int, str]]] = None + """Sizes to compile for inductor. In addition + to integers, it also supports "cudagraph_capture_sizes" to + specify the sizes for cudagraph capture.""" + inductor_compile_config: dict = field(default_factory=dict) + """Additional configurations for inductor. + - None: use default configurations.""" + inductor_passes: dict[str, str] = field(default_factory=dict) + """Additional passes for inductor. It is a dictionary + from pass name to pass function qualified name. We use function + name because the config uses JSON format. If we pass the config + from Python, functions can also be passed directly via Python object + constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`.""" + + # CudaGraph compilation use_cudagraph: bool = False + """Whether to use cudagraph inside compilation. + - False: cudagraph inside compilation is not used. + - True: cudagraph inside compilation is used. It requires + that all input buffers have fixed addresses, and all + splitting ops write their outputs to input buffers. + Note that this is orthogonal to the cudagraph capture logic + outside of compilation. + TODO: move outside cudagraph logic into compilation. + torch.compile will handle cudagraph capture logic in the future.""" cudagraph_num_of_warmups: int = 0 + """Number of warmup runs for cudagraph. + It means the first several runs will be treated as warmup runs. + Only after that, the execution will be recorded, and the recorded + cudagraph will be used for subsequent runs.""" cudagraph_capture_sizes: Optional[list[int]] = None + """Sizes to capture cudagraph. + - None (default): capture sizes are inferred from vllm config. + - list[int]: capture sizes are specified as given.""" cudagraph_copy_inputs: bool = False + """Whether to copy input tensors for + cudagraph. If the caller can guarantee that the same input buffers + are always used, it can set this to False. Otherwise, it should + set this to True, and the compiler will copy the input to an + internally managed buffer. Default is False.""" full_cuda_graph: bool = False + """whether to use a full cuda graph for the entire forward pass rather than + splitting certain operations such as attention into subgraphs. Thus this + flag cannot be used together with splitting_ops. This may provide + performance benefits for smaller models.""" - class PassConfig(BaseModel): - """ - Configuration for custom Inductor passes. - This is separate from general CompilationConfig so that inductor passes - don't all have access to full configuration - that would create a cycle - as the PassManager is set as a property of config. - - dump_graph_stages: list of stages for which we want to dump the graph. - Each pass defines its own stages (before, after, maybe in-between). - - dump_graph_dir: directory to dump the graphs. Default is . - - enable_fusion: whether to enable the custom fusion pass. - - enable_noop: whether to enable the custom no-op elimination pass. - TODO(luka) better pass enabling system. - - enable_sequence_parallelism: whether to enable sequence parallelism. - """ - dump_graph_stages: list[str] = Field(default_factory=list) - dump_graph_dir: Path = Field(default=Path(".")) - enable_fusion: bool = True - enable_noop: bool = True - enable_sequence_parallelism: bool = False + pass_config: PassConfig = field(default_factory=PassConfig) + """Custom inductor passes, see PassConfig for more details""" - def uuid(self): - """ - Produces a hash unique to the pass configuration. - Any new fields that affect compilation should be added to the hash. - Do not include dump_graph_* in the hash - they don't affect - compilation. - """ - dict_ = self.model_dump(include={"enable_fusion", "enable_noop", \ - "enable_sequence_parallelism"}) - return InductorPass.hash_dict(dict_) - - def model_post_init(self, __context: Any) -> None: - if not self.enable_noop and self.enable_fusion: - logger.warning_once( - "Fusion enabled but reshape elimination disabled. " - "RMSNorm + quant (fp8) fusion might not work") - - pass_config: PassConfig = Field(default_factory=PassConfig) - - # not configurable, computed after init - max_capture_size: int = PrivateAttr - local_cache_dir: str = PrivateAttr # local cache dir for each rank - # optimization: - # Intuitively, bs_to_padded_graph_size should be dict[int, int]. - # since we know all keys are in a range [0, max_capture_size], - # we can optimize it to list[int] for better lookup performance. - bs_to_padded_graph_size: list[int] = PrivateAttr + max_capture_size: int = field(default=None, init=False) # type: ignore + """not configurable, computed after init""" + local_cache_dir: str = field(default=None, init=False) # type: ignore + """local cache dir for each rank""" + bs_to_padded_graph_size: list[int] = field( + default=None, # type: ignore + init=False) + """optimization: + Intuitively, bs_to_padded_graph_size should be dict[int, int]. + since we know all keys are in a range [0, max_capture_size], + we can optimize it to list[int] for better lookup performance.""" # keep track of enabled and disabled custom ops - enabled_custom_ops: Counter[str] = PrivateAttr - disabled_custom_ops: Counter[str] = PrivateAttr - traced_files: set[str] = PrivateAttr - compilation_time: float = PrivateAttr + enabled_custom_ops: Counter[str] = field(default_factory=Counter, + init=False) + """custom ops that are enabled""" + disabled_custom_ops: Counter[str] = field(default_factory=Counter, + init=False) + """custom ops that are disabled""" + traced_files: set[str] = field(default_factory=set, init=False) + """files that are traced for compilation""" + compilation_time: float = field(default=0.0, init=False) + """time taken for compilation""" - # Per-model forward context - # Map from layer name to layer objects that need to be accessed outside - # model code, e.g., Attention, FusedMOE when dp_size>1. - static_forward_context: dict[str, Any] = PrivateAttr + static_forward_context: dict[str, Any] = field(default_factory=dict, + init=False) + """Per-model forward context + Map from layer name to layer objects that need to be accessed outside + model code, e.g., Attention, FusedMOE when dp_size>1.""" def compute_hash(self) -> str: """ @@ -3757,7 +3804,17 @@ class CompilationConfig(BaseModel): "pass_config", "traced_files", } - return self.model_dump_json(exclude=exclude, exclude_unset=True) + include = dict() + for k, v in asdict(self).items(): + if k in exclude: + continue + f = get_field(CompilationConfig, k) + if (d := f.default) is not MISSING and d == v: + continue + if (df := f.default_factory) is not MISSING and df() == v: + continue + include[k] = v + return json.dumps(include) __str__ = __repr__ @@ -3766,12 +3823,9 @@ class CompilationConfig(BaseModel): """Parse the CLI value for the compilation config.""" if cli_value in ["0", "1", "2", "3"]: return cls(level=int(cli_value)) - # do not use `eval`, it is dangerous and can execute arbitrary code - dict_value = ast.literal_eval(cli_value) - return CompilationConfig.model_validate(dict_value) - - def model_post_init(self, __context: Any) -> None: + return cls(**json.loads(cli_value)) + def __post_init__(self) -> None: count_none = self.custom_ops.count("none") count_all = self.custom_ops.count("all") assert count_none + count_all <= 1, "Can only specify 'none' or 'all'" @@ -3789,9 +3843,6 @@ class CompilationConfig(BaseModel): if KEY not in self.inductor_compile_config: self.inductor_compile_config[KEY] = False - if self.splitting_ops is None: - self.splitting_ops = [] - for k, v in self.inductor_passes.items(): if not isinstance(v, str): assert callable(v), ( @@ -3808,11 +3859,8 @@ class CompilationConfig(BaseModel): self.inductor_compile_config[k] = func if isinstance( func, InductorPass) else CallableInductorPass(func) - self.enabled_custom_ops = Counter() - self.disabled_custom_ops = Counter() - self.traced_files = set() - self.static_forward_context = {} - self.compilation_time = 0.0 + if isinstance(self.pass_config, dict): + self.pass_config = PassConfig(**self.pass_config) def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: if self.level == CompilationLevel.NO_COMPILATION: @@ -3899,39 +3947,67 @@ class CompilationConfig(BaseModel): ] +@config @dataclass class VllmConfig: """Dataclass which contains all vllm-related configuration. This simplifies passing around the distinct configurations in the codebase. """ - model_config: ModelConfig = field(default=None, init=True) # type: ignore - cache_config: CacheConfig = field(default=None, init=True) # type: ignore - parallel_config: ParallelConfig = field(default_factory=ParallelConfig, - init=True) - scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig, - init=True) - device_config: DeviceConfig = field(default=None, - init=True) # type: ignore - load_config: LoadConfig = field(default=None, init=True) # type: ignore + model_config: ModelConfig = field(default_factory=ModelConfig) + """Model configuration.""" + cache_config: CacheConfig = field(default_factory=CacheConfig) + """Cache configuration.""" + parallel_config: ParallelConfig = field(default_factory=ParallelConfig) + """Parallel configuration.""" + scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig) + """Scheduler configuration.""" + device_config: DeviceConfig = field(default_factory=DeviceConfig) + """Device configuration.""" + load_config: LoadConfig = field(default_factory=LoadConfig) + """Load configuration.""" lora_config: Optional[LoRAConfig] = None - speculative_config: SpeculativeConfig = field(default=None, - init=True) # type: ignore + """LoRA configuration.""" + speculative_config: Optional[SpeculativeConfig] = None + """Speculative decoding configuration.""" decoding_config: Optional[DecodingConfig] = None + """Decoding configuration.""" observability_config: Optional[ObservabilityConfig] = None + """Observability configuration.""" prompt_adapter_config: Optional[PromptAdapterConfig] = None + """Prompt adapter configuration.""" quant_config: Optional[QuantizationConfig] = None - compilation_config: CompilationConfig = field(default=None, - init=True) # type: ignore - kv_transfer_config: KVTransferConfig = field(default=None, - init=True) # type: ignore + """Quantization configuration.""" + compilation_config: CompilationConfig = field( + default_factory=CompilationConfig) + """`torch.compile` configuration for the model. + + When it is a number (0, 1, 2, 3), it will be interpreted as the + optimization level. + + NOTE: level 0 is the default level without any optimization. level 1 and 2 + are for internal testing only. level 3 is the recommended level for + production. + + Following the convention of traditional compilers, using `-O` without space + is also supported. `-O3` is equivalent to `-O 3`. + + You can specify the full compilation config like so: + `{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}` + """ + kv_transfer_config: Optional[KVTransferConfig] = None + """The configurations for distributed KV cache transfer.""" kv_events_config: Optional[KVEventsConfig] = None + """The configurations for event publishing.""" # some opaque config, only used to provide additional information # for the hash computation, mainly used for testing, debugging or out of # tree config registration. - additional_config: SupportsHash = field(default=None, - init=True) # type: ignore + additional_config: Union[dict, SupportsHash] = field(default_factory=dict) + """Additional config for specified platform. Different platforms may + support different configs. Make sure the configs are valid for the platform + you are using. Contents must be hashable.""" instance_id: str = "" + """The ID of the vLLM instance.""" def compute_hash(self) -> str: """ @@ -4012,7 +4088,14 @@ class VllmConfig: else: vllm_factors.append("None") if self.additional_config: - vllm_factors.append(self.additional_config.compute_hash()) + if isinstance(additional_config := self.additional_config, dict): + additional_config_hash = hashlib.md5( + json.dumps(additional_config, sort_keys=True).encode(), + usedforsecurity=False, + ).hexdigest() + else: + additional_config_hash = additional_config.compute_hash() + vllm_factors.append(additional_config_hash) else: vllm_factors.append("None") factors.append(vllm_factors) diff --git a/vllm/distributed/kv_events.py b/vllm/distributed/kv_events.py index 9609138585..1141a8e53c 100644 --- a/vllm/distributed/kv_events.py +++ b/vllm/distributed/kv_events.py @@ -5,6 +5,7 @@ import threading import time from abc import ABC, abstractmethod from collections import deque +from dataclasses import asdict from itertools import count from queue import Queue from typing import Any, Callable, Optional, Union @@ -284,7 +285,7 @@ class EventPublisherFactory: if not config: return NullEventPublisher() - config_dict = config.model_dump() + config_dict = asdict(config) kind = config_dict.pop("publisher", "null") config_dict.pop("enable_kv_cache_events") diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index be4be6ed5f..0ff6a6fbbc 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -7,10 +7,10 @@ import json import re import threading import warnings -from dataclasses import MISSING, dataclass, fields +from dataclasses import MISSING, dataclass, fields, is_dataclass from itertools import permutations -from typing import (Any, Callable, Dict, List, Literal, Optional, Type, - TypeVar, Union, cast, get_args, get_origin) +from typing import (Annotated, Any, Callable, Dict, List, Literal, Optional, + Type, TypeVar, Union, cast, get_args, get_origin) import torch from typing_extensions import TypeIs, deprecated @@ -36,7 +36,8 @@ from vllm.reasoning import ReasoningParserManager from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3 from vllm.transformers_utils.utils import check_gguf_file from vllm.usage.usage_lib import UsageContext -from vllm.utils import FlexibleArgumentParser, GiB_bytes, is_in_ray_actor +from vllm.utils import (FlexibleArgumentParser, GiB_bytes, is_in_doc_build, + is_in_ray_actor) # yapf: enable @@ -48,12 +49,9 @@ TypeHint = Union[type[Any], object] TypeHintT = Union[type[T], object] -def optional_type( - return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]: +def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]: - def _optional_type(val: str) -> Optional[T]: - if val == "" or val == "None": - return None + def _parse_type(val: str) -> T: try: if return_type is json.loads and not re.match("^{.*}$", val): return cast(T, nullable_kvs(val)) @@ -62,14 +60,24 @@ def optional_type( raise argparse.ArgumentTypeError( f"Value {val} cannot be converted to {return_type}.") from e + return _parse_type + + +def optional_type( + return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]: + + def _optional_type(val: str) -> Optional[T]: + if val == "" or val == "None": + return None + return parse_type(return_type)(val) + return _optional_type def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]: if not re.match("^{.*}$", val): return str(val) - else: - return optional_type(json.loads)(val) + return optional_type(json.loads)(val) @deprecated( @@ -144,10 +152,25 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: cls_docs = get_attr_docs(cls) kwargs = {} for field in fields(cls): + # Get the set of possible types for the field + type_hints: set[TypeHint] = set() + if get_origin(field.type) in {Union, Annotated}: + type_hints.update(get_args(field.type)) + else: + type_hints.add(field.type) + + # If the field is a dataclass, we can use the model_validate_json + generator = (th for th in type_hints if is_dataclass(th)) + dataclass_cls = next(generator, None) + # Get the default value of the field - default = field.default - if field.default_factory is not MISSING: - default = field.default_factory() + if field.default is not MISSING: + default = field.default + elif field.default_factory is not MISSING: + if is_dataclass(field.default_factory) and is_in_doc_build(): + default = {} + else: + default = field.default_factory() # Get the help text for the field name = field.name @@ -158,16 +181,17 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: # Initialise the kwargs dictionary for the field kwargs[name] = {"default": default, "help": help} - # Get the set of possible types for the field - type_hints: set[TypeHint] = set() - if get_origin(field.type) is Union: - type_hints.update(get_args(field.type)) - else: - type_hints.add(field.type) - # Set other kwargs based on the type hints json_tip = "\n\nShould be a valid JSON string." - if contains_type(type_hints, bool): + if dataclass_cls is not None: + dataclass_init = lambda x, f=dataclass_cls: f(**json.loads(x)) + # Special case for configs with a from_cli method + if hasattr(dataclass_cls, "from_cli"): + from_cli = dataclass_cls.from_cli + dataclass_init = lambda x, f=from_cli: f(x) + kwargs[name]["type"] = dataclass_init + kwargs[name]["help"] += json_tip + elif contains_type(type_hints, bool): # Creates --no- and -- flags kwargs[name]["action"] = argparse.BooleanOptionalAction elif contains_type(type_hints, Literal): @@ -202,7 +226,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: kwargs[name]["type"] = union_dict_and_str elif contains_type(type_hints, dict): # Dict arguments will always be optional - kwargs[name]["type"] = optional_type(json.loads) + kwargs[name]["type"] = parse_type(json.loads) kwargs[name]["help"] += json_tip elif (contains_type(type_hints, str) or any(is_not_builtin(th) for th in type_hints)): @@ -771,63 +795,20 @@ class EngineArgs: scheduler_group.add_argument("--scheduler-cls", **scheduler_kwargs["scheduler_cls"]) - # Compilation arguments - # compilation_kwargs = get_kwargs(CompilationConfig) - compilation_group = parser.add_argument_group( - title="CompilationConfig", - description=CompilationConfig.__doc__, - ) - compilation_group.add_argument( - "--compilation-config", - "-O", - type=CompilationConfig.from_cli, - default=None, - help="torch.compile configuration for the model. " - "When it is a number (0, 1, 2, 3), it will be " - "interpreted as the optimization level.\n" - "NOTE: level 0 is the default level without " - "any optimization. level 1 and 2 are for internal " - "testing only. level 3 is the recommended level " - "for production.\n" - "To specify the full compilation config, " - "use a JSON string, e.g. ``{\"level\": 3, " - "\"cudagraph_capture_sizes\": [1, 2, 4, 8]}``\n" - "Following the convention of traditional " - "compilers, using ``-O`` without space is also " - "supported. ``-O3`` is equivalent to ``-O 3``.") - - # KVTransfer arguments - # kv_transfer_kwargs = get_kwargs(KVTransferConfig) - kv_transfer_group = parser.add_argument_group( - title="KVTransferConfig", - description=KVTransferConfig.__doc__, - ) - kv_transfer_group.add_argument( - "--kv-transfer-config", - type=KVTransferConfig.from_cli, - default=None, - help="The configurations for distributed KV cache " - "transfer. Should be a JSON string.") - kv_transfer_group.add_argument( - '--kv-events-config', - type=KVEventsConfig.from_cli, - default=None, - help='The configurations for event publishing.') - # vLLM arguments - # vllm_kwargs = get_kwargs(VllmConfig) + vllm_kwargs = get_kwargs(VllmConfig) vllm_group = parser.add_argument_group( title="VllmConfig", description=VllmConfig.__doc__, ) - vllm_group.add_argument( - "--additional-config", - type=json.loads, - default=None, - help="Additional config for specified platform in JSON format. " - "Different platforms may support different configs. Make sure the " - "configs are valid for the platform you are using. The input format" - " is like '{\"config_key\":\"config_value\"}'") + vllm_group.add_argument("--kv-transfer-config", + **vllm_kwargs["kv_transfer_config"]) + vllm_group.add_argument('--kv-events-config', + **vllm_kwargs["kv_events_config"]) + vllm_group.add_argument("--compilation-config", "-O", + **vllm_kwargs["compilation_config"]) + vllm_group.add_argument("--additional-config", + **vllm_kwargs["additional_config"]) # Other arguments parser.add_argument('--use-v2-block-manager', diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 72ad79bd2d..cebddcc8e6 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -13,7 +13,8 @@ from typing_extensions import TypeVar, deprecated from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, BeamSearchSequence, get_beam_search_score) -from vllm.config import CompilationConfig, ModelDType, TokenizerMode +from vllm.config import (CompilationConfig, ModelDType, TokenizerMode, + is_init_field) from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig, TaskOption) from vllm.engine.llm_engine import LLMEngine @@ -204,9 +205,13 @@ class LLM: kwargs["worker_cls"] = cloudpickle.dumps(worker_cls) if compilation_config is not None: - if isinstance(compilation_config, (int, dict)): - compilation_config_instance = CompilationConfig.from_cli( - str(compilation_config)) + if isinstance(compilation_config, int): + compilation_config_instance = CompilationConfig( + level=compilation_config) + elif isinstance(compilation_config, dict): + predicate = lambda x: is_init_field(CompilationConfig, x[0]) + compilation_config_instance = CompilationConfig( + **dict(filter(predicate, compilation_config.items()))) else: compilation_config_instance = compilation_config else: diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 2782a3866d..d0a5af3587 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import TYPE_CHECKING, Optional, Tuple, Union +from typing import TYPE_CHECKING, Optional, Tuple, Union, cast import torch from tpu_info import device @@ -13,9 +13,10 @@ from vllm.sampling_params import SamplingParams, SamplingType from .interface import Platform, PlatformEnum, _Backend if TYPE_CHECKING: - from vllm.config import ModelConfig, VllmConfig + from vllm.config import BlockSize, ModelConfig, VllmConfig from vllm.pooling_params import PoolingParams else: + BlockSize = None ModelConfig = None VllmConfig = None PoolingParams = None @@ -94,7 +95,7 @@ class TpuPlatform(Platform): cache_config = vllm_config.cache_config # For v0, the default block size is 16. if cache_config and cache_config.block_size is None: - cache_config.block_size = 16 + cache_config.block_size = cast(BlockSize, 16) compilation_config = vllm_config.compilation_config # TPU only supports DYNAMO_ONCE compilation level @@ -118,7 +119,7 @@ class TpuPlatform(Platform): from vllm.v1.attention.backends.pallas import ( PallasAttentionBackend) cache_config.block_size = PallasAttentionBackend.get_page_size( - vllm_config) + vllm_config) # type: ignore[assignment] min_page_size = PallasAttentionBackend.get_min_page_size( vllm_config) if min_page_size > cache_config.block_size: @@ -128,7 +129,7 @@ class TpuPlatform(Platform): cache_config.block_size, min_page_size, ) - cache_config.block_size = min_page_size + cache_config.block_size = min_page_size # type: ignore[assignment] parallel_config = vllm_config.parallel_config scheduler_config = vllm_config.scheduler_config diff --git a/vllm/utils.py b/vllm/utils.py index 24535196cc..6779c5b3f8 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1820,6 +1820,14 @@ def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor: return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor) +def is_in_doc_build() -> bool: + try: + from sphinx.ext.autodoc.mock import _MockModule + return isinstance(zmq, _MockModule) + except ModuleNotFoundError: + return False + + def import_from_path(module_name: str, file_path: Union[str, os.PathLike]): """ Import a Python file according to its file path.