Improve configs - the rest! (#17562)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -9,7 +9,7 @@ import torch
from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationLevel
from vllm.config import CompilationConfig, CompilationLevel, PassConfig
from vllm.platforms import current_platform
from ..utils import create_new_process_for_each_test
@@ -95,9 +95,6 @@ def test_full_graph(
run_model(optimization_level, model, model_kwargs)
PassConfig = CompilationConfig.PassConfig
# TODO(luka) add other supported compilation config scenarios here
@pytest.mark.parametrize(
"compilation_config, model_info",
@@ -11,7 +11,7 @@ from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
kFp8DynamicTokenSym, kFp8StaticTensorSym)
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, VllmConfig
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from .backend import TestBackend
@@ -53,9 +53,8 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
torch.set_default_device("cuda")
vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(pass_config= \
CompilationConfig.PassConfig(enable_fusion=do_fusion,
enable_noop=True))
vllm_config.compilation_config = CompilationConfig(
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True))
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = FusionPass.instance(vllm_config)
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
@@ -9,7 +9,8 @@ from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
FusionPass, QuantKey)
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
VllmConfig)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity)
@@ -78,8 +79,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
vllm_config.compilation_config.pass_config = \
CompilationConfig.PassConfig(enable_fusion=True,
enable_noop=True)
PassConfig(enable_fusion=True, enable_noop=True)
with vllm.config.set_current_vllm_config(vllm_config):
# Reshape pass is needed for the fusion pass to work
noop_pass = NoOpEliminationPass(vllm_config)
@@ -10,7 +10,7 @@ from vllm.compilation.fx_utils import (find_auto_fn, find_auto_fn_maybe,
find_specified_fn_maybe, is_func)
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
VllmConfig)
PassConfig, VllmConfig)
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (init_distributed_environment,
initialize_model_parallel)
@@ -126,9 +126,8 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
# configure vllm config for SequenceParallelismPass
vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(
pass_config=CompilationConfig.PassConfig(
enable_sequence_parallelism=True, ), )
vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig(
enable_sequence_parallelism=True))
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
# this is a fake model name to construct the model config
@@ -6,7 +6,7 @@ import vllm.envs as envs
from vllm._custom_ops import scaled_fp8_quant
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
from vllm.config import CompilationConfig, VllmConfig
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.model_executor.layers.activation import SiluAndMul
from .backend import TestBackend
@@ -36,8 +36,7 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
# Reshape pass is needed for the fusion pass to work
config = VllmConfig()
config.compilation_config = CompilationConfig(
pass_config=CompilationConfig.PassConfig(enable_fusion=True,
enable_reshape=True))
pass_config=PassConfig(enable_fusion=True, enable_reshape=True))
fusion_pass = ActivationQuantFusionPass(config)
backend = TestBackend(fusion_pass)
@@ -206,7 +206,7 @@ def _compare_sp(
'compile_sizes': [4, 8],
'splitting_ops': [],
'pass_config': {
'enable_sequence_parallism': sp_enabled,
'enable_sequence_parallelism': sp_enabled,
'enable_noop': True,
'enable_fusion': True,
},
@@ -223,7 +223,7 @@ def _compare_sp(
"--distributed-executor-backend",
distributed_backend,
"--compilation_config",
str(compilation_config),
json.dumps(compilation_config),
]
tp_env = {
@@ -8,21 +8,18 @@ from typing import Literal, Optional
import pytest
from vllm.config import config
from vllm.config import CompilationConfig, config
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
get_type, is_not_builtin, is_type,
literal_to_kwargs, nullable_kvs,
optional_type)
optional_type, parse_type)
from vllm.utils import FlexibleArgumentParser
@pytest.mark.parametrize(("type", "value", "expected"), [
(int, "42", 42),
(int, "None", None),
(float, "3.14", 3.14),
(float, "None", None),
(str, "Hello World!", "Hello World!"),
(str, "None", None),
(json.loads, '{"foo":1,"bar":2}', {
"foo": 1,
"bar": 2
@@ -31,15 +28,20 @@ from vllm.utils import FlexibleArgumentParser
"foo": 1,
"bar": 2
}),
(json.loads, "None", None),
])
def test_optional_type(type, value, expected):
optional_type_func = optional_type(type)
def test_parse_type(type, value, expected):
parse_type_func = parse_type(type)
context = nullcontext()
if value == "foo=1,bar=2":
context = pytest.warns(DeprecationWarning)
with context:
assert optional_type_func(value) == expected
assert parse_type_func(value) == expected
def test_optional_type():
optional_type_func = optional_type(int)
assert optional_type_func("None") is None
assert optional_type_func("42") == 42
@pytest.mark.parametrize(("type_hint", "type", "expected"), [
@@ -89,7 +91,40 @@ def test_literal_to_kwargs(type_hints, expected):
@config
@dataclass
class DummyConfigClass:
class NestedConfig:
field: int = 1
"""field"""
@config
@dataclass
class FromCliConfig1:
field: int = 1
"""field"""
@classmethod
def from_cli(cls, cli_value: str):
inst = cls(**json.loads(cli_value))
inst.field += 1
return inst
@config
@dataclass
class FromCliConfig2:
field: int = 1
"""field"""
@classmethod
def from_cli(cls, cli_value: str):
inst = cls(**json.loads(cli_value))
inst.field += 2
return inst
@config
@dataclass
class DummyConfig:
regular_bool: bool = True
"""Regular bool with default True"""
optional_bool: Optional[bool] = None
@@ -108,18 +143,24 @@ class DummyConfigClass:
"""Literal of literals with default 1"""
json_tip: dict = field(default_factory=dict)
"""Dict which will be JSON in CLI"""
nested_config: NestedConfig = field(default_factory=NestedConfig)
"""Nested config"""
from_cli_config1: FromCliConfig1 = field(default_factory=FromCliConfig1)
"""Config with from_cli method"""
from_cli_config2: FromCliConfig2 = field(default_factory=FromCliConfig2)
"""Different config with from_cli method"""
@pytest.mark.parametrize(("type_hint", "expected"), [
(int, False),
(DummyConfigClass, True),
(DummyConfig, True),
])
def test_is_not_builtin(type_hint, expected):
assert is_not_builtin(type_hint) == expected
def test_get_kwargs():
kwargs = get_kwargs(DummyConfigClass)
kwargs = get_kwargs(DummyConfig)
print(kwargs)
# bools should not have their type set
@@ -142,6 +183,11 @@ def test_get_kwargs():
# dict should have json tip in help
json_tip = "\n\nShould be a valid JSON string."
assert kwargs["json_tip"]["help"].endswith(json_tip)
# nested config should should construct the nested config
assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2)
# from_cli configs should be constructed with the correct method
assert kwargs["from_cli_config1"]["type"]('{"field": 2}').field == 3
assert kwargs["from_cli_config2"]["type"]('{"field": 2}').field == 4
@pytest.mark.parametrize(("arg", "expected"), [
@@ -177,7 +223,7 @@ def test_compilation_config():
# default value
args = parser.parse_args([])
assert args.compilation_config is None
assert args.compilation_config == CompilationConfig()
# set to O3
args = parser.parse_args(["-O3"])
@@ -194,7 +240,7 @@ def test_compilation_config():
# set to string form of a dict
args = parser.parse_args([
"--compilation-config",
"{'level': 3, 'cudagraph_capture_sizes': [1, 2, 4, 8]}",
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}',
])
assert (args.compilation_config.level == 3 and
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8])
@@ -202,7 +248,7 @@ def test_compilation_config():
# set to string form of a dict
args = parser.parse_args([
"--compilation-config="
"{'level': 3, 'cudagraph_capture_sizes': [1, 2, 4, 8]}",
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}',
])
assert (args.compilation_config.level == 3 and
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8])
@@ -4,7 +4,7 @@ import time
import torch
from vllm.config import CompilationConfig, VllmConfig
from vllm.config import PassConfig, VllmConfig
# yapf: disable
from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank
from vllm.distributed import (
@@ -56,10 +56,7 @@ class VllmInductorPass(InductorPass):
class PrinterInductorPass(VllmInductorPass):
def __init__(self,
name: str,
config: CompilationConfig.PassConfig,
always=False):
def __init__(self, name: str, config: PassConfig, always=False):
super().__init__(config)
self.name = name
self.always = always
vllm/config.py
@@ -11,8 +11,8 @@ import textwrap
import warnings
from collections import Counter
from contextlib import contextmanager
from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
replace)
from dataclasses import (MISSING, Field, asdict, dataclass, field, fields,
is_dataclass, replace)
from functools import cached_property
from importlib.util import find_spec
from pathlib import Path
@@ -20,7 +20,6 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional,
Protocol, TypeVar, Union, cast, get_args, get_origin)
import torch
from pydantic import BaseModel, Field, PrivateAttr
from torch.distributed import ProcessGroup, ReduceOp
from transformers import PretrainedConfig
from typing_extensions import deprecated
@@ -57,7 +56,7 @@ if TYPE_CHECKING:
ConfigType = type[DataclassInstance]
else:
QuantizationConfig = None
QuantizationConfig = Any
ConfigType = type
logger = init_logger(__name__)
@@ -169,6 +168,12 @@ def config(cls: ConfigT) -> ConfigT:
"""
A decorator that ensures all fields in a dataclass have default values
and that each field has a docstring.
If a `ConfigT` is used as a CLI argument itself, the default value provided
by `get_kwargs` will be the result parsing a JSON string as the kwargs
(i.e. `ConfigT(**json.loads(cli_arg))`). However, if a particular `ConfigT`
requires custom construction from CLI (i.e. `CompilationConfig`), it can
have a `from_cli` method, which will be called instead.
"""
if not is_dataclass(cls):
raise TypeError("The decorated class must be a dataclass.")
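To illustrate the hook described in this docstring, here is a minimal sketch (the class name is hypothetical; it mirrors the `FromCliConfig1` fixture in the test changes above):

```python
import json
from dataclasses import dataclass

from vllm.config import config


@config
@dataclass
class ExampleConfig:
    field: int = 1
    """Example field with a default."""

    @classmethod
    def from_cli(cls, cli_value: str):
        # get_kwargs() calls this instead of cls(**json.loads(cli_value))
        # because the config defines its own CLI constructor.
        inst = cls(**json.loads(cli_value))
        inst.field += 1
        return inst
```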
@@ -202,7 +207,7 @@ def get_field(cls: ConfigType, name: str) -> Field:
cls_fields = {f.name: f for f in fields(cls)}
if name not in cls_fields:
raise ValueError(f"Field '{name}' not found in {cls.__name__}.")
named_field: Field = cls_fields.get(name)
named_field: Field = cls_fields[name]
if (default_factory := named_field.default_factory) is not MISSING:
return field(default_factory=default_factory)
if (default := named_field.default) is not MISSING:
@@ -211,6 +216,10 @@ def get_field(cls: ConfigType, name: str) -> Field:
f"{cls.__name__}.{name} must have a default value or default factory.")
def is_init_field(cls: ConfigType, name: str) -> bool:
return next(f for f in fields(cls) if f.name == name).init
TokenizerMode = Literal["auto", "slow", "mistral", "custom"]
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
@@ -2007,13 +2016,13 @@ class SchedulerConfig:
def __post_init__(self) -> None:
if self.max_model_len is None:
self.max_model_len = 8192
logger.warning(
logger.warning_once(
"max_model_len was is not set. Defaulting to arbitrary value "
"of %d.", self.max_model_len)
if self.max_num_seqs is None:
self.max_num_seqs = 128
logger.warning(
logger.warning_once(
"max_num_seqs was is not set. Defaulting to arbitrary value "
"of %d.", self.max_num_seqs)
@@ -2840,8 +2849,8 @@ class PromptAdapterConfig:
class MultiModalConfig:
"""Controls the behavior of multimodal models."""
limit_per_prompt: dict[str, int] = get_field(ModelConfig,
"limit_mm_per_prompt")
limit_per_prompt: dict[str, int] = \
cast(dict[str, int], get_field(ModelConfig, "limit_mm_per_prompt"))
"""
The maximum number of input items allowed per prompt for each modality.
Defaults to 1 (V0) or 999 (V1) for each modality.
@@ -3415,41 +3424,49 @@ class ObservabilityConfig:
self.collect_detailed_traces[0].split(","))
class KVTransferConfig(BaseModel):
KVProducer = Literal["kv_producer", "kv_both"]
KVConsumer = Literal["kv_consumer", "kv_both"]
KVRole = Literal[KVProducer, KVConsumer]
@config
@dataclass
class KVTransferConfig:
"""Configuration for distributed KV cache transfer."""
# The KV connector for vLLM to transmit KV caches between vLLM instances.
kv_connector: Optional[str] = None
"""The KV connector for vLLM to transmit KV caches between vLLM instances.
"""
# The device used by kv connector to buffer the KV cache.
# Currently only support 'cuda'.
kv_buffer_device: Optional[str] = "cuda"
"""The device used by kv connector to buffer the KV cache.
Currently only support 'cuda'."""
# The buffer size for TorchDistributedConnector. Measured in number of
# bytes. Recommended value: 1e9 (about 1GB).
kv_buffer_size: float = 1e9
"""The buffer size for TorchDistributedConnector. Measured in number of
bytes. Recommended value: 1e9 (about 1GB)."""
# Whether this vLLM instance produces, consumes KV cache, or both. Choices
# are 'kv_producer', 'kv_consumer', and 'both'.
kv_role: Optional[str] = None
kv_role: Optional[KVRole] = None
"""Whether this vLLM instance produces, consumes KV cache, or both. Choices
are 'kv_producer', 'kv_consumer', and 'both'."""
# The rank of this vLLM instance in the KV cache transfer. Typical value:
# 0 for prefill instance, 1 for decode instance.
# Currently only 1P1D is supported.
kv_rank: Optional[int] = None
"""The rank of this vLLM instance in the KV cache transfer. Typical value:
0 for prefill instance, 1 for decode instance.
Currently only 1P1D is supported."""
# The number of parallel instances for KV cache transfer. For
# PyNcclConnector, this should be 2.
kv_parallel_size: int = 1
"""The number of parallel instances for KV cache transfer. For
PyNcclConnector, this should be 2."""
# The KV connector ip, used to build distributed connection
kv_ip: str = "127.0.0.1"
"""The KV connector ip, used to build distributed connection."""
# The KV connector port, used to build distributed connection
kv_port: int = 14579
"""The KV connector port, used to build distributed connection."""
# any extra config that the connector may need
kv_connector_extra_config: dict[str, Any] = {}
kv_connector_extra_config: dict[str, Any] = field(default_factory=dict)
"""any extra config that the connector may need."""
def compute_hash(self) -> str:
"""
@@ -3470,46 +3487,37 @@ class KVTransferConfig(BaseModel):
usedforsecurity=False).hexdigest()
return hash_str
@classmethod
def from_cli(cls, cli_value: str) -> "KVTransferConfig":
"""Parse the CLI value for the kv cache transfer config."""
return KVTransferConfig.model_validate_json(cli_value)
def model_post_init(self, __context: Any) -> None:
if self.kv_role is not None and self.kv_role not in [
"kv_producer", "kv_consumer", "kv_both"
]:
raise ValueError(
f"Unsupported kv_role: {self.kv_role}. "
f"Supported roles are `kv_producer`, `kv_consumer`, "
f"and `kv_both`")
def __post_init__(self) -> None:
if self.kv_role is not None and self.kv_role not in get_args(KVRole):
raise ValueError(f"Unsupported kv_role: {self.kv_role}. "
f"Supported roles are {get_args(KVRole)}")
if self.kv_connector is not None and self.kv_role is None:
raise ValueError("Please specify kv_disagg_role when kv_connector "
"is set, supported roles are `kv_producer`, "
"`kv_consumer`, and `kv_both`")
f"is set, supported roles are {get_args(KVRole)}")
@property
def is_kv_transfer_instance(self) -> bool:
return self.kv_connector is not None and \
self.kv_role in ["kv_producer", "kv_consumer", "kv_both"]
self.kv_role in get_args(KVRole)
@property
def is_kv_producer(self) -> bool:
return self.kv_connector is not None and \
self.kv_role in ["kv_producer", "kv_both"]
self.kv_role in get_args(KVProducer)
@property
def is_kv_consumer(self) -> bool:
return self.kv_connector is not None and \
self.kv_role in ["kv_consumer", "kv_both"]
self.kv_role in get_args(KVConsumer)
def get_from_extra_config(self, key, default) -> Any:
return self.kv_connector_extra_config.get(key, default)
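For reference, a small sketch of how the new `Literal`-based role aliases behave with `typing.get_args` (standalone re-declaration of the aliases above, illustration only):

```python
from typing import Literal, get_args

KVProducer = Literal["kv_producer", "kv_both"]
KVConsumer = Literal["kv_consumer", "kv_both"]
KVRole = Literal[KVProducer, KVConsumer]

# Nested Literals are flattened, so membership checks against
# get_args(...) replace the previous hard-coded role lists.
assert "kv_both" in get_args(KVRole)
assert "kv_consumer" in get_args(KVConsumer)
assert "kv_consumer" not in get_args(KVProducer)
```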
class KVEventsConfig(BaseModel):
@config
@dataclass
class KVEventsConfig:
"""Configuration for KV event publishing."""
enable_kv_cache_events: bool = False
@@ -3548,11 +3556,6 @@ class KVEventsConfig(BaseModel):
this topic to receive events.
"""
@classmethod
def from_cli(cls, cli_value: str) -> "KVEventsConfig":
"""Parse the CLI value for the event publisher config."""
return KVEventsConfig.model_validate_json(cli_value)
class CompilationLevel:
# constants for the levels of the compilation process
@@ -3562,80 +3565,72 @@ class CompilationLevel:
PIECEWISE = 3
class CompilationConfig(BaseModel):
"""
Configuration for compilation.
It has three parts:
@config
@dataclass
class PassConfig:
"""Configuration for custom Inductor passes.
This is separate from general `CompilationConfig` so that inductor passes
don't all have access to full configuration - that would create a cycle as
the `PassManager` is set as a property of config."""
dump_graph_stages: list[str] = field(default_factory=list)
"""List of stages for which we want to dump the graph. Each pass defines
its own stages (before, after, maybe in-between)."""
dump_graph_dir: Path = Path(".")
"""Directory to dump the graphs."""
# TODO(luka) better pass enabling system.
enable_fusion: bool = True
"""Whether to enable the custom fusion pass."""
enable_noop: bool = True
"""Whether to enable the custom no-op elimination pass."""
enable_sequence_parallelism: bool = False
"""Whether to enable sequence parallelism."""
def uuid(self):
"""
Produces a hash unique to the pass configuration.
Any new fields that affect compilation should be added to the hash.
Do not include dump_graph_* in the hash - they don't affect
compilation.
"""
include = {
"enable_fusion", "enable_noop", "enable_sequence_parallelism"
}
dict_ = {k: v for k, v in asdict(self).items() if k in include}
return InductorPass.hash_dict(dict_)
def __post_init__(self) -> None:
if not self.enable_noop and self.enable_fusion:
logger.warning_once(
"Fusion enabled but reshape elimination disabled. "
"RMSNorm + quant (fp8) fusion might not work")
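A small usage sketch of the now top-level `PassConfig` (mirroring how the updated compile tests construct it; the dict form is what arrives when `pass_config` is nested inside a `--compilation-config` JSON string):

```python
from vllm.config import CompilationConfig, PassConfig

# Direct construction, as the compile tests above now do.
config = CompilationConfig(
    pass_config=PassConfig(enable_fusion=True, enable_noop=True))

# When the value comes in as JSON, the nested dict is converted by
# CompilationConfig.__post_init__ (see further down in this diff).
config = CompilationConfig(pass_config={"enable_noop": True})
assert isinstance(config.pass_config, PassConfig)
```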
@config
@dataclass
class CompilationConfig:
"""Configuration for compilation. It has three parts:
- Top-level Compilation control:
- level: the level of compilation.
- 0: no compilation.
- 1: dynamo as is.
- 2: dynamo once.
- 3: piecewise compilation.
- debug_dump_path: the path to dump the debug information.
- cache_dir: the directory to store the compiled graph, to
accelerate Inductor compilation. By default, it will use
model-related information to generate a cache directory.
- backend: the backend for compilation. It needs to be a string.
- "" (empty string): use the default backend.
- "eager"/"openxla"/...: use the specified backend registered in PyTorch.
- "full.module.name": a qualified name which can be used to import the backend function.
We use string to avoid serialization issues when using compilation in a distributed setting.
When the compilation level is 1 or 2, the backend is used for the compilation directly (it sees the whole graph).
When the compilation level is 3, the backend is used for the piecewise compilation (it sees a part of the graph).
- custom_ops: fine-grained control over which custom ops to enable/disable.
Use 'all' to enable all, 'none' to disable all.
Also specify a list of custom op names to enable (prefixed with a '+'),
or disable (prefixed with a '-').
Examples:
- 'all,-op1' to enable all except op1
- 'none,+op1,+op2' to enable only op1 and op2
By default, all custom ops are enabled when running without Inductor
and disabled when running with Inductor (compile_level >= Inductor).
- splitting_ops: a list of ops to split the full graph into subgraphs, used in piecewise compilation.
- {attr}`level`
- {attr}`debug_dump_path`
- {attr}`cache_dir`
- {attr}`backend`
- {attr}`custom_ops`
- {attr}`splitting_ops`
- CudaGraph capture:
- use_cudagraph: whether to use cudagraph inside compilation.
- False: cudagraph inside compilation is not used.
- True: cudagraph inside compilation is used. It requires
that all input buffers have fixed addresses, and all
splitting ops write their outputs to input buffers.
Note that this is orthogonal to the cudagraph capture logic
outside of compilation.
TODO: move outside cudagraph logic into compilation.
torch.compile will handle cudagraph capture logic in the future.
- cudagraph_capture_sizes: sizes to capture cudagraph.
- None (default): capture sizes are inferred from vllm config.
- list[int]: capture sizes are specified as given.
- cudagraph_num_of_warmups: number of warmup runs for cudagraph.
It means the first several runs will be treated as warmup runs.
Only after that, the execution will be recorded, and the recorded
cudagraph will be used for subsequent runs.
- cudagraph_copy_inputs: whether to copy input tensors for
cudagraph. If the caller can guarantee that the same input buffers
are always used, it can set this to False. Otherwise, it should
set this to True, and the compiler will copy the input to an
internally managed buffer. Default is False.
- full_cuda_graph: whether to use a full cuda graph for the entire forward
pass rather than splitting certain operations such as attention into subgraphs.
Thus this flag cannot be used together with splitting_ops. This may provide
performance benefits for smaller models.
- {attr}`use_cudagraph`
- {attr}`cudagraph_capture_sizes`
- {attr}`cudagraph_num_of_warmups`
- {attr}`cudagraph_copy_inputs`
- {attr}`full_cuda_graph`
- Inductor compilation:
- use_inductor: whether to use inductor compilation.
- False: inductor compilation is not used. graph runs in eager.
- True: inductor compilation is used. one graph for symbolic shape
is compiled. In addition, compile for compile_sizes,
using configurations in inductor_compile_config.
- compile_sizes: sizes to compile for inductor. In addition
to integers, it also supports "cudagraph_capture_sizes" to
specify the sizes for cudagraph capture.
- inductor_compile_config: additional configurations for inductor.
- None: use default configurations.
- inductor_passes: additional passes for inductor. It is a dictionary
from pass name to pass function qualified name. We use function
name because the config uses json format. If we pass the config
from Python, functions can also be passed directly via Python object
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`
- custom inductor passes: see PassConfig for more details
- {attr}`use_inductor`
- {attr}`compile_sizes`
- {attr}`inductor_compile_config`
- {attr}`inductor_passes`
- custom inductor passes
Why we have different sizes for cudagraph and inductor:
- cudagraph: a cudagraph captured for a specific size can only be used
@@ -3646,83 +3641,135 @@ class CompilationConfig(BaseModel):
static shapes. However, we find the general shape compilation is
sufficient for most cases. It might be beneficial to compile for
certain small batchsizes, where inductor is good at optimizing.
""" # noqa
"""
# Top-level Compilation control
level: int = 0
"""The level of compilation:
- 0: no compilation.
- 1: dynamo as is.
- 2: dynamo once.
- 3: piecewise compilation."""
debug_dump_path: str = ""
"""The path to dump the debug information."""
cache_dir: str = ""
"""The directory to store the compiled graph, to accelerate Inductor
compilation. By default, it will use model-related information to generate
a cache directory."""
backend: str = ""
custom_ops: list[str] = Field(default_factory=list)
splitting_ops: list[str] = Field(default=None) # type: ignore
"""The backend for compilation. It needs to be a string:
- "" (empty string): use the default backend.
- "eager"/"openxla"/...: use the specified backend registered in PyTorch.
- "full.module.name": a qualified name which can be used to import the
backend function.
We use string to avoid serialization issues when using compilation in a
distributed setting. When the compilation level is 1 or 2, the backend is
used for the compilation directly (it sees the whole graph). When the
compilation level is 3, the backend is used for the piecewise compilation
(it sees a part of the graph)."""
custom_ops: list[str] = field(default_factory=list)
"""Fine-grained control over which custom ops to enable/disable. Use 'all'
to enable all, 'none' to disable all. Also specify a list of custom op
names to enable (prefixed with a '+'), or disable (prefixed with a '-').
Examples:
- 'all,-op1' to enable all except op1
- 'none,+op1,+op2' to enable only op1 and op2
By default, all custom ops are enabled when running without Inductor and
disabled when running with Inductor (compile_level >= Inductor)."""
splitting_ops: list[str] = field(default_factory=list)
"""A list of ops to split the full graph into subgraphs, used in piecewise
compilation."""
# Inductor capture
use_inductor: bool = True
compile_sizes: Optional[list[Union[int, str]]] = Field(default=None)
inductor_compile_config: dict = Field(default_factory=dict)
inductor_passes: dict[str, str] = Field(default_factory=dict)
"""Whether to use inductor compilation:
- False: inductor compilation is not used. graph runs in eager.
- True: inductor compilation is used. one graph for symbolic shape
is compiled. In addition, compile for compile_sizes,
using configurations in inductor_compile_config."""
compile_sizes: Optional[list[Union[int, str]]] = None
"""Sizes to compile for inductor. In addition
to integers, it also supports "cudagraph_capture_sizes" to
specify the sizes for cudagraph capture."""
inductor_compile_config: dict = field(default_factory=dict)
"""Additional configurations for inductor.
- None: use default configurations."""
inductor_passes: dict[str, str] = field(default_factory=dict)
"""Additional passes for inductor. It is a dictionary
from pass name to pass function qualified name. We use function
name because the config uses JSON format. If we pass the config
from Python, functions can also be passed directly via Python object
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`."""
# CudaGraph compilation
use_cudagraph: bool = False
"""Whether to use cudagraph inside compilation.
- False: cudagraph inside compilation is not used.
- True: cudagraph inside compilation is used. It requires
that all input buffers have fixed addresses, and all
splitting ops write their outputs to input buffers.
Note that this is orthogonal to the cudagraph capture logic
outside of compilation.
TODO: move outside cudagraph logic into compilation.
torch.compile will handle cudagraph capture logic in the future."""
cudagraph_num_of_warmups: int = 0
"""Number of warmup runs for cudagraph.
It means the first several runs will be treated as warmup runs.
Only after that, the execution will be recorded, and the recorded
cudagraph will be used for subsequent runs."""
cudagraph_capture_sizes: Optional[list[int]] = None
"""Sizes to capture cudagraph.
- None (default): capture sizes are inferred from vllm config.
- list[int]: capture sizes are specified as given."""
cudagraph_copy_inputs: bool = False
"""Whether to copy input tensors for
cudagraph. If the caller can guarantee that the same input buffers
are always used, it can set this to False. Otherwise, it should
set this to True, and the compiler will copy the input to an
internally managed buffer. Default is False."""
full_cuda_graph: bool = False
"""whether to use a full cuda graph for the entire forward pass rather than
splitting certain operations such as attention into subgraphs. Thus this
flag cannot be used together with splitting_ops. This may provide
performance benefits for smaller models."""
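As a quick illustration of the fields above (a sketch only; the values are arbitrary and mirror the updated fusion test):

```python
from vllm.config import CompilationConfig, CompilationLevel, PassConfig

config = CompilationConfig(
    level=CompilationLevel.PIECEWISE,      # 3: piecewise compilation
    custom_ops=["+rms_norm"],              # force-enable the rms_norm custom op
    cudagraph_capture_sizes=[1, 2, 4, 8],  # explicit capture sizes
    pass_config=PassConfig(enable_fusion=True, enable_noop=True),
)
```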
class PassConfig(BaseModel):
"""
Configuration for custom Inductor passes.
This is separate from general CompilationConfig so that inductor passes
don't all have access to full configuration - that would create a cycle
as the PassManager is set as a property of config.
- dump_graph_stages: list of stages for which we want to dump the graph.
Each pass defines its own stages (before, after, maybe in-between).
- dump_graph_dir: directory to dump the graphs. Default is .
- enable_fusion: whether to enable the custom fusion pass.
- enable_noop: whether to enable the custom no-op elimination pass.
TODO(luka) better pass enabling system.
- enable_sequence_parallelism: whether to enable sequence parallelism.
"""
dump_graph_stages: list[str] = Field(default_factory=list)
dump_graph_dir: Path = Field(default=Path("."))
enable_fusion: bool = True
enable_noop: bool = True
enable_sequence_parallelism: bool = False
pass_config: PassConfig = field(default_factory=PassConfig)
"""Custom inductor passes, see PassConfig for more details"""
def uuid(self):
"""
Produces a hash unique to the pass configuration.
Any new fields that affect compilation should be added to the hash.
Do not include dump_graph_* in the hash - they don't affect
compilation.
"""
dict_ = self.model_dump(include={"enable_fusion", "enable_noop", \
"enable_sequence_parallelism"})
return InductorPass.hash_dict(dict_)
def model_post_init(self, __context: Any) -> None:
if not self.enable_noop and self.enable_fusion:
logger.warning_once(
"Fusion enabled but reshape elimination disabled. "
"RMSNorm + quant (fp8) fusion might not work")
pass_config: PassConfig = Field(default_factory=PassConfig)
# not configurable, computed after init
max_capture_size: int = PrivateAttr
local_cache_dir: str = PrivateAttr # local cache dir for each rank
# optimization:
# Intuitively, bs_to_padded_graph_size should be dict[int, int].
# since we know all keys are in a range [0, max_capture_size],
# we can optimize it to list[int] for better lookup performance.
bs_to_padded_graph_size: list[int] = PrivateAttr
max_capture_size: int = field(default=None, init=False) # type: ignore
"""not configurable, computed after init"""
local_cache_dir: str = field(default=None, init=False) # type: ignore
"""local cache dir for each rank"""
bs_to_padded_graph_size: list[int] = field(
default=None, # type: ignore
init=False)
"""optimization:
Intuitively, bs_to_padded_graph_size should be dict[int, int].
since we know all keys are in a range [0, max_capture_size],
we can optimize it to list[int] for better lookup performance."""
# keep track of enabled and disabled custom ops
enabled_custom_ops: Counter[str] = PrivateAttr
disabled_custom_ops: Counter[str] = PrivateAttr
traced_files: set[str] = PrivateAttr
compilation_time: float = PrivateAttr
enabled_custom_ops: Counter[str] = field(default_factory=Counter,
init=False)
"""custom ops that are enabled"""
disabled_custom_ops: Counter[str] = field(default_factory=Counter,
init=False)
"""custom ops that are disabled"""
traced_files: set[str] = field(default_factory=set, init=False)
"""files that are traced for compilation"""
compilation_time: float = field(default=0.0, init=False)
"""time taken for compilation"""
# Per-model forward context
# Map from layer name to layer objects that need to be accessed outside
# model code, e.g., Attention, FusedMOE when dp_size>1.
static_forward_context: dict[str, Any] = PrivateAttr
static_forward_context: dict[str, Any] = field(default_factory=dict,
init=False)
"""Per-model forward context
Map from layer name to layer objects that need to be accessed outside
model code, e.g., Attention, FusedMOE when dp_size>1."""
def compute_hash(self) -> str:
"""
@@ -3757,7 +3804,17 @@ class CompilationConfig(BaseModel):
"pass_config",
"traced_files",
}
return self.model_dump_json(exclude=exclude, exclude_unset=True)
include = dict()
for k, v in asdict(self).items():
if k in exclude:
continue
f = get_field(CompilationConfig, k)
if (d := f.default) is not MISSING and d == v:
continue
if (df := f.default_factory) is not MISSING and df() == v:
continue
include[k] = v
return json.dumps(include)
__str__ = __repr__
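The replacement for `model_dump_json(exclude_unset=True)` keeps the repr short by dropping fields that still hold their defaults. A standalone sketch of that idea (hypothetical `Example` dataclass, not vLLM code):

```python
import json
from dataclasses import MISSING, asdict, dataclass, field, fields


@dataclass
class Example:
    a: int = 0
    b: list = field(default_factory=list)


def non_default_json(obj) -> str:
    out = {}
    values = asdict(obj)
    for f in fields(obj):
        v = values[f.name]
        if f.default is not MISSING and v == f.default:
            continue  # still the plain default
        if f.default_factory is not MISSING and v == f.default_factory():
            continue  # still the factory default
        out[f.name] = v
    return json.dumps(out)


assert non_default_json(Example()) == "{}"
assert non_default_json(Example(a=3)) == '{"a": 3}'
```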
@@ -3766,12 +3823,9 @@ class CompilationConfig(BaseModel):
"""Parse the CLI value for the compilation config."""
if cli_value in ["0", "1", "2", "3"]:
return cls(level=int(cli_value))
# do not use `eval`, it is dangerous and can execute arbitrary code
dict_value = ast.literal_eval(cli_value)
return CompilationConfig.model_validate(dict_value)
def model_post_init(self, __context: Any) -> None:
return cls(**json.loads(cli_value))
def __post_init__(self) -> None:
count_none = self.custom_ops.count("none")
count_all = self.custom_ops.count("all")
assert count_none + count_all <= 1, "Can only specify 'none' or 'all'"
@@ -3789,9 +3843,6 @@ class CompilationConfig(BaseModel):
if KEY not in self.inductor_compile_config:
self.inductor_compile_config[KEY] = False
if self.splitting_ops is None:
self.splitting_ops = []
for k, v in self.inductor_passes.items():
if not isinstance(v, str):
assert callable(v), (
@@ -3808,11 +3859,8 @@ class CompilationConfig(BaseModel):
self.inductor_compile_config[k] = func if isinstance(
func, InductorPass) else CallableInductorPass(func)
self.enabled_custom_ops = Counter()
self.disabled_custom_ops = Counter()
self.traced_files = set()
self.static_forward_context = {}
self.compilation_time = 0.0
if isinstance(self.pass_config, dict):
self.pass_config = PassConfig(**self.pass_config)
def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
if self.level == CompilationLevel.NO_COMPILATION:
@@ -3899,39 +3947,67 @@ class CompilationConfig(BaseModel):
]
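The reworked `from_cli` accepts either a bare optimization level or a JSON object (single-quoted dicts parsed via `ast.literal_eval` are no longer accepted), roughly:

```python
from vllm.config import CompilationConfig

# "-O3" style: a bare level number.
cfg = CompilationConfig.from_cli("3")
assert cfg.level == 3

# Full config: must now be valid JSON (double quotes), matching the
# updated --compilation-config tests above.
cfg = CompilationConfig.from_cli(
    '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}')
assert cfg.cudagraph_capture_sizes == [1, 2, 4, 8]
```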
@config
@dataclass
class VllmConfig:
"""Dataclass which contains all vllm-related configuration. This
simplifies passing around the distinct configurations in the codebase.
"""
model_config: ModelConfig = field(default=None, init=True) # type: ignore
cache_config: CacheConfig = field(default=None, init=True) # type: ignore
parallel_config: ParallelConfig = field(default_factory=ParallelConfig,
init=True)
scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig,
init=True)
device_config: DeviceConfig = field(default=None,
init=True) # type: ignore
load_config: LoadConfig = field(default=None, init=True) # type: ignore
model_config: ModelConfig = field(default_factory=ModelConfig)
"""Model configuration."""
cache_config: CacheConfig = field(default_factory=CacheConfig)
"""Cache configuration."""
parallel_config: ParallelConfig = field(default_factory=ParallelConfig)
"""Parallel configuration."""
scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig)
"""Scheduler configuration."""
device_config: DeviceConfig = field(default_factory=DeviceConfig)
"""Device configuration."""
load_config: LoadConfig = field(default_factory=LoadConfig)
"""Load configuration."""
lora_config: Optional[LoRAConfig] = None
speculative_config: SpeculativeConfig = field(default=None,
init=True) # type: ignore
"""LoRA configuration."""
speculative_config: Optional[SpeculativeConfig] = None
"""Speculative decoding configuration."""
decoding_config: Optional[DecodingConfig] = None
"""Decoding configuration."""
observability_config: Optional[ObservabilityConfig] = None
"""Observability configuration."""
prompt_adapter_config: Optional[PromptAdapterConfig] = None
"""Prompt adapter configuration."""
quant_config: Optional[QuantizationConfig] = None
compilation_config: CompilationConfig = field(default=None,
init=True) # type: ignore
kv_transfer_config: KVTransferConfig = field(default=None,
init=True) # type: ignore
"""Quantization configuration."""
compilation_config: CompilationConfig = field(
default_factory=CompilationConfig)
"""`torch.compile` configuration for the model.
When it is a number (0, 1, 2, 3), it will be interpreted as the
optimization level.
NOTE: level 0 is the default level without any optimization. level 1 and 2
are for internal testing only. level 3 is the recommended level for
production.
Following the convention of traditional compilers, using `-O` without space
is also supported. `-O3` is equivalent to `-O 3`.
You can specify the full compilation config like so:
`{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}`
"""
kv_transfer_config: Optional[KVTransferConfig] = None
"""The configurations for distributed KV cache transfer."""
kv_events_config: Optional[KVEventsConfig] = None
"""The configurations for event publishing."""
# some opaque config, only used to provide additional information
# for the hash computation, mainly used for testing, debugging or out of
# tree config registration.
additional_config: SupportsHash = field(default=None,
init=True) # type: ignore
additional_config: Union[dict, SupportsHash] = field(default_factory=dict)
"""Additional config for specified platform. Different platforms may
support different configs. Make sure the configs are valid for the platform
you are using. Contents must be hashable."""
instance_id: str = ""
"""The ID of the vLLM instance."""
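For orientation, the way these fields are typically exercised (a sketch; whether an empty `VllmConfig()` is cheap to build depends on the sub-config defaults introduced elsewhere in this series):

```python
from vllm.config import CompilationConfig, VllmConfig

# Programmatic use: every sub-config now has a default factory, so an
# "empty" VllmConfig can be built and filled in piece by piece, as the
# compile tests above do.
vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(level=3)

# From the CLI the same thing is spelled `-O3`, or in full JSON form:
# --compilation-config '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}'
```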
def compute_hash(self) -> str:
"""
@@ -4012,7 +4088,14 @@ class VllmConfig:
else:
vllm_factors.append("None")
if self.additional_config:
vllm_factors.append(self.additional_config.compute_hash())
if isinstance(additional_config := self.additional_config, dict):
additional_config_hash = hashlib.md5(
json.dumps(additional_config, sort_keys=True).encode(),
usedforsecurity=False,
).hexdigest()
else:
additional_config_hash = additional_config.compute_hash()
vllm_factors.append(additional_config_hash)
else:
vllm_factors.append("None")
factors.append(vllm_factors)
@@ -5,6 +5,7 @@ import threading
import time
from abc import ABC, abstractmethod
from collections import deque
from dataclasses import asdict
from itertools import count
from queue import Queue
from typing import Any, Callable, Optional, Union
@@ -284,7 +285,7 @@ class EventPublisherFactory:
if not config:
return NullEventPublisher()
config_dict = config.model_dump()
config_dict = asdict(config)
kind = config_dict.pop("publisher", "null")
config_dict.pop("enable_kv_cache_events")
@@ -7,10 +7,10 @@ import json
import re
import threading
import warnings
from dataclasses import MISSING, dataclass, fields
from dataclasses import MISSING, dataclass, fields, is_dataclass
from itertools import permutations
from typing import (Any, Callable, Dict, List, Literal, Optional, Type,
TypeVar, Union, cast, get_args, get_origin)
from typing import (Annotated, Any, Callable, Dict, List, Literal, Optional,
Type, TypeVar, Union, cast, get_args, get_origin)
import torch
from typing_extensions import TypeIs, deprecated
@@ -36,7 +36,8 @@ from vllm.reasoning import ReasoningParserManager
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.utils import check_gguf_file
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, GiB_bytes, is_in_ray_actor
from vllm.utils import (FlexibleArgumentParser, GiB_bytes, is_in_doc_build,
is_in_ray_actor)
# yapf: enable
@@ -48,12 +49,9 @@ TypeHint = Union[type[Any], object]
TypeHintT = Union[type[T], object]
def optional_type(
return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
def _optional_type(val: str) -> Optional[T]:
if val == "" or val == "None":
return None
def _parse_type(val: str) -> T:
try:
if return_type is json.loads and not re.match("^{.*}$", val):
return cast(T, nullable_kvs(val))
@@ -62,14 +60,24 @@ def optional_type(
raise argparse.ArgumentTypeError(
f"Value {val} cannot be converted to {return_type}.") from e
return _parse_type
def optional_type(
return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
def _optional_type(val: str) -> Optional[T]:
if val == "" or val == "None":
return None
return parse_type(return_type)(val)
return _optional_type
def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]:
if not re.match("^{.*}$", val):
return str(val)
else:
return optional_type(json.loads)(val)
return optional_type(json.loads)(val)
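The split gives two helpers: `parse_type` always converts, while `optional_type` additionally maps "" and "None" to `None`. Informally (mirroring the new tests above):

```python
import json

from vllm.engine.arg_utils import optional_type, parse_type

# parse_type: plain conversion (json.loads inputs that are not JSON
# objects still fall back to the deprecated key=value parsing).
assert parse_type(json.loads)('{"foo": 1}') == {"foo": 1}

# optional_type: same conversion, but empty string / "None" become None.
assert optional_type(int)("None") is None
assert optional_type(int)("42") == 42
```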
@deprecated(
@@ -144,10 +152,25 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
cls_docs = get_attr_docs(cls)
kwargs = {}
for field in fields(cls):
# Get the set of possible types for the field
type_hints: set[TypeHint] = set()
if get_origin(field.type) in {Union, Annotated}:
type_hints.update(get_args(field.type))
else:
type_hints.add(field.type)
# If the field is a dataclass, we can use the model_validate_json
generator = (th for th in type_hints if is_dataclass(th))
dataclass_cls = next(generator, None)
# Get the default value of the field
default = field.default
if field.default_factory is not MISSING:
default = field.default_factory()
if field.default is not MISSING:
default = field.default
elif field.default_factory is not MISSING:
if is_dataclass(field.default_factory) and is_in_doc_build():
default = {}
else:
default = field.default_factory()
# Get the help text for the field
name = field.name
@@ -158,16 +181,17 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
# Initialise the kwargs dictionary for the field
kwargs[name] = {"default": default, "help": help}
# Get the set of possible types for the field
type_hints: set[TypeHint] = set()
if get_origin(field.type) is Union:
type_hints.update(get_args(field.type))
else:
type_hints.add(field.type)
# Set other kwargs based on the type hints
json_tip = "\n\nShould be a valid JSON string."
if contains_type(type_hints, bool):
if dataclass_cls is not None:
dataclass_init = lambda x, f=dataclass_cls: f(**json.loads(x))
# Special case for configs with a from_cli method
if hasattr(dataclass_cls, "from_cli"):
from_cli = dataclass_cls.from_cli
dataclass_init = lambda x, f=from_cli: f(x)
kwargs[name]["type"] = dataclass_init
kwargs[name]["help"] += json_tip
elif contains_type(type_hints, bool):
# Creates --no-<name> and --<name> flags
kwargs[name]["action"] = argparse.BooleanOptionalAction
elif contains_type(type_hints, Literal):
@@ -202,7 +226,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
kwargs[name]["type"] = union_dict_and_str
elif contains_type(type_hints, dict):
# Dict arguments will always be optional
kwargs[name]["type"] = optional_type(json.loads)
kwargs[name]["type"] = parse_type(json.loads)
kwargs[name]["help"] += json_tip
elif (contains_type(type_hints, str)
or any(is_not_builtin(th) for th in type_hints)):
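Putting the pieces together, the generated kwargs feed straight into argparse. A rough sketch of how a dataclass-typed field becomes a JSON-accepting flag (hypothetical standalone parser, mirroring `test_get_kwargs` and the `--kv-transfer-config` wiring below):

```python
import argparse

from vllm.config import KVTransferConfig, VllmConfig
from vllm.engine.arg_utils import get_kwargs

vllm_kwargs = get_kwargs(VllmConfig)

parser = argparse.ArgumentParser()
parser.add_argument("--kv-transfer-config",
                    **vllm_kwargs["kv_transfer_config"])

args = parser.parse_args([
    "--kv-transfer-config",
    '{"kv_connector": "PyNcclConnector", "kv_role": "kv_both"}',
])
# The "type" callable builds the dataclass from the JSON string.
assert isinstance(args.kv_transfer_config, KVTransferConfig)
```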
@@ -771,63 +795,20 @@ class EngineArgs:
scheduler_group.add_argument("--scheduler-cls",
**scheduler_kwargs["scheduler_cls"])
# Compilation arguments
# compilation_kwargs = get_kwargs(CompilationConfig)
compilation_group = parser.add_argument_group(
title="CompilationConfig",
description=CompilationConfig.__doc__,
)
compilation_group.add_argument(
"--compilation-config",
"-O",
type=CompilationConfig.from_cli,
default=None,
help="torch.compile configuration for the model. "
"When it is a number (0, 1, 2, 3), it will be "
"interpreted as the optimization level.\n"
"NOTE: level 0 is the default level without "
"any optimization. level 1 and 2 are for internal "
"testing only. level 3 is the recommended level "
"for production.\n"
"To specify the full compilation config, "
"use a JSON string, e.g. ``{\"level\": 3, "
"\"cudagraph_capture_sizes\": [1, 2, 4, 8]}``\n"
"Following the convention of traditional "
"compilers, using ``-O`` without space is also "
"supported. ``-O3`` is equivalent to ``-O 3``.")
# KVTransfer arguments
# kv_transfer_kwargs = get_kwargs(KVTransferConfig)
kv_transfer_group = parser.add_argument_group(
title="KVTransferConfig",
description=KVTransferConfig.__doc__,
)
kv_transfer_group.add_argument(
"--kv-transfer-config",
type=KVTransferConfig.from_cli,
default=None,
help="The configurations for distributed KV cache "
"transfer. Should be a JSON string.")
kv_transfer_group.add_argument(
'--kv-events-config',
type=KVEventsConfig.from_cli,
default=None,
help='The configurations for event publishing.')
# vLLM arguments
# vllm_kwargs = get_kwargs(VllmConfig)
vllm_kwargs = get_kwargs(VllmConfig)
vllm_group = parser.add_argument_group(
title="VllmConfig",
description=VllmConfig.__doc__,
)
vllm_group.add_argument(
"--additional-config",
type=json.loads,
default=None,
help="Additional config for specified platform in JSON format. "
"Different platforms may support different configs. Make sure the "
"configs are valid for the platform you are using. The input format"
" is like '{\"config_key\":\"config_value\"}'")
vllm_group.add_argument("--kv-transfer-config",
**vllm_kwargs["kv_transfer_config"])
vllm_group.add_argument('--kv-events-config',
**vllm_kwargs["kv_events_config"])
vllm_group.add_argument("--compilation-config", "-O",
**vllm_kwargs["compilation_config"])
vllm_group.add_argument("--additional-config",
**vllm_kwargs["additional_config"])
# Other arguments
parser.add_argument('--use-v2-block-manager',
@@ -13,7 +13,8 @@ from typing_extensions import TypeVar, deprecated
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence, get_beam_search_score)
from vllm.config import CompilationConfig, ModelDType, TokenizerMode
from vllm.config import (CompilationConfig, ModelDType, TokenizerMode,
is_init_field)
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
TaskOption)
from vllm.engine.llm_engine import LLMEngine
@@ -204,9 +205,13 @@ class LLM:
kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)
if compilation_config is not None:
if isinstance(compilation_config, (int, dict)):
compilation_config_instance = CompilationConfig.from_cli(
str(compilation_config))
if isinstance(compilation_config, int):
compilation_config_instance = CompilationConfig(
level=compilation_config)
elif isinstance(compilation_config, dict):
predicate = lambda x: is_init_field(CompilationConfig, x[0])
compilation_config_instance = CompilationConfig(
**dict(filter(predicate, compilation_config.items())))
else:
compilation_config_instance = compilation_config
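From the user's side this keeps the existing `LLM` API working; roughly (the model name is illustrative only):

```python
from vllm import LLM

# An int is treated as the optimization level...
llm = LLM(model="facebook/opt-125m", compilation_config=3)

# ...and a dict is filtered down to CompilationConfig's init fields and
# passed to the dataclass constructor.
llm = LLM(model="facebook/opt-125m",
          compilation_config={"level": 3,
                              "cudagraph_capture_sizes": [1, 2, 4, 8]})
```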
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, Optional, Tuple, Union
from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
import torch
from tpu_info import device
@@ -13,9 +13,10 @@ from vllm.sampling_params import SamplingParams, SamplingType
from .interface import Platform, PlatformEnum, _Backend
if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
from vllm.config import BlockSize, ModelConfig, VllmConfig
from vllm.pooling_params import PoolingParams
else:
BlockSize = None
ModelConfig = None
VllmConfig = None
PoolingParams = None
@@ -94,7 +95,7 @@ class TpuPlatform(Platform):
cache_config = vllm_config.cache_config
# For v0, the default block size is 16.
if cache_config and cache_config.block_size is None:
cache_config.block_size = 16
cache_config.block_size = cast(BlockSize, 16)
compilation_config = vllm_config.compilation_config
# TPU only supports DYNAMO_ONCE compilation level
@@ -118,7 +119,7 @@ class TpuPlatform(Platform):
from vllm.v1.attention.backends.pallas import (
PallasAttentionBackend)
cache_config.block_size = PallasAttentionBackend.get_page_size(
vllm_config)
vllm_config) # type: ignore[assignment]
min_page_size = PallasAttentionBackend.get_min_page_size(
vllm_config)
if min_page_size > cache_config.block_size:
@@ -128,7 +129,7 @@ class TpuPlatform(Platform):
cache_config.block_size,
min_page_size,
)
cache_config.block_size = min_page_size
cache_config.block_size = min_page_size # type: ignore[assignment]
parallel_config = vllm_config.parallel_config
scheduler_config = vllm_config.scheduler_config
|
||||
|
||||
@ -1820,6 +1820,14 @@ def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor:
|
||||
return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor)
|
||||
|
||||
|
||||
def is_in_doc_build() -> bool:
|
||||
try:
|
||||
from sphinx.ext.autodoc.mock import _MockModule
|
||||
return isinstance(zmq, _MockModule)
|
||||
except ModuleNotFoundError:
|
||||
return False
|
||||
|
||||
|
||||
def import_from_path(module_name: str, file_path: Union[str, os.PathLike]):
|
||||
"""
|
||||
Import a Python file according to its file path.
|
||||
|
||||