Fix per file ruff ignores related to typing (#26254)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -115,6 +115,7 @@ include = ["vllm*"]
 "vllm/distributed/parallel_state.py" = ["SIM108"]
 "vllm/entrypoints/chat_utils.py" = ["SIM108"]
 "vllm/entrypoints/llm.py" = ["SIM108"]
+"vllm/executor/ray_distributed_executor.py" = ["SIM108", "SIM112"]
 "vllm/model_executor/layers/batch_invariant.py" = ["SIM108"]
 "vllm/model_executor/layers/fla/ops/chunk_o.py" = ["SIM108"]
 "vllm/model_executor/layers/fused_moe/fused_moe.py" = ["SIM108"]
@@ -134,23 +135,6 @@ include = ["vllm*"]
 "tools/profiler/print_layerwise_table.py" = ["SIM118"]
 ## Loop variable binding issues
 "tests/kernels/mamba/test_mamba_ssm_ssd.py" = ["B023"]
-## Type annotation modernization and other rules
-"vllm/attention/backends/abstract.py" = ["UP035", "UP006"]
-"vllm/attention/layer.py" = ["UP035", "UP006"]
-"vllm/attention/layers/chunked_local_attention.py" = ["UP035", "UP006"]
-"vllm/attention/ops/flashmla.py" = ["UP035", "UP006"]
-"vllm/attention/ops/paged_attn.py" = ["UP035", "UP006"]
-"vllm/engine/arg_utils.py" = ["UP035", "UP006"]
-"vllm/engine/metrics.py" = ["UP035", "UP006"]
-"vllm/engine/metrics_types.py" = ["UP035", "UP006"]
-"vllm/executor/executor_base.py" = ["UP035", "UP006"]
-"vllm/executor/msgspec_utils.py" = ["UP035", "UP006"]
-"vllm/executor/ray_distributed_executor.py" = ["UP035", "UP006", "SIM108", "SIM112"]
-"vllm/executor/ray_utils.py" = ["UP035", "UP006"]
-"vllm/executor/uniproc_executor.py" = ["UP035", "UP006"]
-"vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py" = ["UP035"]
-## Type comparison issues
-"vllm/multimodal/inputs.py" = ["E721"]
 # End of temporary ignores
 
 [tool.ruff.lint]
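The ignores removed above correspond to ruff's typing rules: UP035 flags imports of deprecated typing aliases such as List, Dict, Tuple and Type, and UP006 flags using those aliases in annotations instead of the PEP 585 builtin generics. Several files below also rewrite Optional[X] as the PEP 604 form X | None. A minimal before/after sketch with a hypothetical function, not taken from the diff:

    # Before: deprecated typing aliases (UP035 on the import, UP006 at the use sites)
    from typing import Dict, List, Optional

    def token_lengths(prompts: List[str], limit: Optional[int] = None) -> Dict[str, int]:
        return {p: len(p) if limit is None else min(len(p), limit) for p in prompts}

    # After: PEP 585 builtin generics and a PEP 604 optional (no typing import needed)
    def token_lengths_585(prompts: list[str], limit: int | None = None) -> dict[str, int]:
        return {p: len(p) if limit is None else min(len(p), limit) for p in prompts}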

@@ -5,7 +5,7 @@ from __future__ import annotations
 
 import logging
 import tempfile
-from typing import Any, Optional, Union
+from typing import Any, Union
 
 import pytest
 import torch
@@ -21,7 +21,7 @@ from vllm.utils import is_torch_equal_or_newer
 from ..utils import create_new_process_for_each_test
 
 
-def models_list(*, all: bool = True, keywords: Optional[list[str]] = None):
+def models_list(*, all: bool = True, keywords: list[str] | None = None):
     TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
         ("facebook/opt-125m", {}),
         (
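The test hunks above sit under from __future__ import annotations, so the new list[str] | None spelling is stored as an unevaluated string annotation (PEP 563) rather than executed at import time. A small sketch with a hypothetical helper showing why this parses even on interpreters without runtime support for PEP 604 unions:

    from __future__ import annotations

    # The annotation below is kept as the string "list[str] | None" and is never
    # evaluated at import time, so the module also imports on older Pythons.
    def filter_keywords(keywords: list[str] | None = None) -> list[str]:
        return [k for k in (keywords or []) if k]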

@@ -6,7 +6,7 @@ from __future__ import annotations
 import asyncio
 from contextlib import suppress
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any
 from unittest.mock import AsyncMock, MagicMock
 
 import pytest
@@ -233,9 +233,9 @@ class MockModelConfig:
     multimodal_config = MultiModalConfig()
     hf_config = MockHFConfig()
     logits_processor_pattern = None
-    diff_sampling_param: Optional[dict] = None
+    diff_sampling_param: dict | None = None
     allowed_local_media_path: str = ""
-    allowed_media_domains: Optional[list[str]] = None
+    allowed_media_domains: list[str] | None = None
     encoder_config = None
     generation_config: str = "auto"
     media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)

@@ -9,7 +9,7 @@ import os
 import tempfile
 import urllib.request
 from collections.abc import Sequence
-from typing import Any, Optional, Union
+from typing import Any, Union
 
 import albumentations
 import numpy as np
@@ -98,9 +98,9 @@ def _convert_np_uint8(float_image: torch.Tensor):
 
 
 def read_geotiff(
-    file_path: Optional[str] = None,
-    path_type: Optional[str] = None,
-    file_data: Optional[bytes] = None,
+    file_path: str | None = None,
+    path_type: str | None = None,
+    file_data: bytes | None = None,
 ) -> tuple[torch.Tensor, dict, tuple[float, float] | None]:
     """Read all bands from *file_path* and return image + meta info.
 
@@ -114,8 +114,8 @@ def read_geotiff(
 
     if all([x is None for x in [file_path, path_type, file_data]]):
         raise Exception("All input fields to read_geotiff are None")
-    write_to_file: Optional[bytes] = None
-    path: Optional[str] = None
+    write_to_file: bytes | None = None
+    path: str | None = None
     if file_data is not None:
         # with tempfile.NamedTemporaryFile() as tmpfile:
         #     tmpfile.write(file_data)
@@ -162,9 +162,9 @@ def read_geotiff(
 def load_image(
     data: Union[list[str]],
     path_type: str,
-    mean: Optional[list[float]] = None,
-    std: Optional[list[float]] = None,
-    indices: Optional[Union[list[int], None]] = None,
+    mean: list[float] | None = None,
+    std: list[float] | None = None,
+    indices: Union[list[int], None] | None = None,
 ):
     """Build an input example by loading images in *file_paths*.
 
@@ -278,7 +278,7 @@ class PrithviMultimodalDataProcessor(IOProcessor):
     def pre_process(
         self,
         prompt: IOProcessorInput,
-        request_id: Optional[str] = None,
+        request_id: str | None = None,
         **kwargs,
     ) -> Union[PromptType, Sequence[PromptType]]:
         image_data = dict(prompt)
@@ -359,7 +359,7 @@ class PrithviMultimodalDataProcessor(IOProcessor):
     def post_process(
         self,
         model_output: Sequence[PoolingRequestOutput],
-        request_id: Optional[str] = None,
+        request_id: str | None = None,
         **kwargs,
     ) -> IOProcessorOutput:
         pred_imgs_list = []

@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 import random
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING
 
 import pytest
 
@@ -78,7 +78,7 @@ def vllm_model_skip_tokenizer_init(vllm_runner, request, monkeypatch):
 
 def _get_test_sampling_params(
     prompt_list: list[str],
-    seed: Optional[int] = 42,
+    seed: int | None = 42,
     structured_outputs: bool = False,
 ) -> tuple[list[SamplingParams], list[int]]:
     """Generate random sampling params for a batch."""

@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from abc import ABC, abstractmethod
-from typing import Generic, List, Optional, Protocol, Tuple, Type, TypeVar
+from typing import Generic, Optional, Protocol, TypeVar
 
 import torch
 
@@ -48,12 +48,12 @@ class AttentionBackend(ABC):
 
     @staticmethod
    @abstractmethod
-    def get_impl_cls() -> Type["AttentionImpl"]:
+    def get_impl_cls() -> type["AttentionImpl"]:
         raise NotImplementedError
 
     @staticmethod
     @abstractmethod
-    def get_metadata_cls() -> Type["AttentionMetadata"]:
+    def get_metadata_cls() -> type["AttentionMetadata"]:
         raise NotImplementedError
 
     @classmethod
@@ -73,11 +73,11 @@ class AttentionBackend(ABC):
         num_kv_heads: int,
         head_size: int,
         cache_dtype_str: str = "auto",
-    ) -> Tuple[int, ...]:
+    ) -> tuple[int, ...]:
         raise NotImplementedError
 
     @staticmethod
-    def get_kv_cache_stride_order() -> Tuple[int, ...]:
+    def get_kv_cache_stride_order() -> tuple[int, ...]:
         raise NotImplementedError
 
     @classmethod
@@ -147,7 +147,7 @@ class AttentionImpl(ABC, Generic[T]):
         head_size: int,
         scale: float,
         num_kv_heads: Optional[int] = None,
-        alibi_slopes: Optional[List[float]] = None,
+        alibi_slopes: Optional[list[float]] = None,
         sliding_window: Optional[int] = None,
         kv_cache_dtype: str = "auto",
         logits_soft_cap: Optional[float] = None,

@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention layer."""
 
-from typing import Callable, List, Optional
+from typing import Callable, Optional
 
 import torch
 import torch.nn as nn
@@ -126,7 +126,7 @@ class Attention(nn.Module, AttentionLayerBase):
         head_size: int,
         scale: float,
         num_kv_heads: Optional[int] = None,
-        alibi_slopes: Optional[List[float]] = None,
+        alibi_slopes: Optional[list[float]] = None,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         logits_soft_cap: Optional[float] = None,
@@ -586,7 +586,7 @@ def wait_for_kv_layer_from_connector(layer_name: str):
 
 def maybe_save_kv_layer_to_connector(
     layer_name: str,
-    kv_cache_layer: List[torch.Tensor],
+    kv_cache_layer: list[torch.Tensor],
 ):
     if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
         return

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import functools
-from typing import ClassVar, List, Optional
+from typing import ClassVar, Optional
 
 import torch
 
@@ -61,7 +61,7 @@ class ChunkedLocalAttention(Attention):
         scale: float,
         attention_chunk_size: int,
         num_kv_heads: Optional[int] = None,
-        alibi_slopes: Optional[List[float]] = None,
+        alibi_slopes: Optional[list[float]] = None,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         kv_sharing_target_layer_name: Optional[str] = None,

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 # adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py
-from typing import Optional, Tuple
+from typing import Optional
 
 import torch
 
@@ -31,7 +31,7 @@ else:
     _flashmla_extension_C_AVAILABLE = False
 
 
-def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
+def is_flashmla_supported() -> tuple[bool, Optional[str]]:
     """
     Return: is_supported_flag, unsupported_reason (optional).
     """
@@ -57,7 +57,7 @@ def get_mla_metadata(
     num_heads_q: Optional[int] = None,
     is_fp8_kvcache: bool = False,
     topk: Optional[int] = None,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
     """
     Arguments:
     - cache_seqlens: (batch_size), dtype torch.int32.
@@ -101,7 +101,7 @@ def flash_mla_with_kvcache(
     descale_k: Optional[torch.Tensor] = None,
     is_fp8_kvcache: bool = False,
     indices: Optional[torch.Tensor] = None,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
     """
     Arguments:
     - q: (batch_size, seq_len_q, num_heads_q, head_dim).
@@ -183,7 +183,7 @@ def flash_mla_sparse_prefill(
     indices: torch.Tensor,
     sm_scale: float,
     d_v: int = 512,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Sparse attention prefill kernel
 

@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from dataclasses import dataclass
-from typing import List, Optional, Tuple
+from typing import Optional
 
 import torch
 
@@ -41,7 +41,7 @@ class PagedAttentionMetadata:
 
 class PagedAttention:
     @staticmethod
-    def get_supported_head_sizes() -> List[int]:
+    def get_supported_head_sizes() -> list[int]:
         return [32, 64, 80, 96, 112, 120, 128, 192, 256]
 
     @staticmethod
@@ -51,7 +51,7 @@ class PagedAttention:
         num_kv_heads: int,
         head_size: int,
         cache_dtype_str: str = "auto",
-    ) -> Tuple[int, ...]:
+    ) -> tuple[int, ...]:
         return (2, num_blocks, block_size * num_kv_heads * head_size)
 
     @staticmethod
@@ -59,7 +59,7 @@ class PagedAttention:
         kv_cache: torch.Tensor,
         num_kv_heads: int,
         head_size: int,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         x = 16 // kv_cache.element_size()
         num_blocks = kv_cache.shape[1]
 
@@ -255,7 +255,7 @@ class PagedAttention:
 
     @staticmethod
     def copy_blocks(
-        kv_caches: List[torch.Tensor],
+        kv_caches: list[torch.Tensor],
         src_to_dists: torch.Tensor,
     ) -> None:
         key_caches = [kv_cache[0] for kv_cache in kv_caches]

@@ -14,11 +14,8 @@ from typing import (
     Annotated,
     Any,
     Callable,
-    Dict,
-    List,
     Literal,
     Optional,
-    Type,
     TypeVar,
     Union,
     cast,
@@ -325,7 +322,7 @@ class EngineArgs:
     """Arguments for vLLM engine."""
 
     model: str = ModelConfig.model
-    served_model_name: Optional[Union[str, List[str]]] = ModelConfig.served_model_name
+    served_model_name: Optional[Union[str, list[str]]] = ModelConfig.served_model_name
     tokenizer: Optional[str] = ModelConfig.tokenizer
     hf_config_path: Optional[str] = ModelConfig.hf_config_path
     runner: RunnerOption = ModelConfig.runner
@@ -350,7 +347,7 @@ class EngineArgs:
     # is intended for expert use only. The API may change without
     # notice.
     distributed_executor_backend: Optional[
-        Union[str, DistributedExecutorBackend, Type[ExecutorBase]]
+        Union[str, DistributedExecutorBackend, type[ExecutorBase]]
     ] = ParallelConfig.distributed_executor_backend
     # number of P/D disaggregation (or other disaggregation) workers
     pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
@@ -418,7 +415,7 @@ class EngineArgs:
     media_io_kwargs: dict[str, dict[str, Any]] = get_field(
         MultiModalConfig, "media_io_kwargs"
     )
-    mm_processor_kwargs: Optional[Dict[str, Any]] = MultiModalConfig.mm_processor_kwargs
+    mm_processor_kwargs: Optional[dict[str, Any]] = MultiModalConfig.mm_processor_kwargs
     disable_mm_preprocessor_cache: bool = False  # DEPRECATED
     mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb
     mm_processor_cache_type: Optional[MMCacheType] = (
@@ -436,7 +433,7 @@ class EngineArgs:
     enable_lora_bias: bool = LoRAConfig.bias_enabled
     max_loras: int = LoRAConfig.max_loras
     max_lora_rank: int = LoRAConfig.max_lora_rank
-    default_mm_loras: Optional[Dict[str, str]] = LoRAConfig.default_mm_loras
+    default_mm_loras: Optional[dict[str, str]] = LoRAConfig.default_mm_loras
     fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
     max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
     lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
@@ -446,7 +443,7 @@ class EngineArgs:
     num_gpu_blocks_override: Optional[int] = CacheConfig.num_gpu_blocks_override
     num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
     model_loader_extra_config: dict = get_field(LoadConfig, "model_loader_extra_config")
-    ignore_patterns: Optional[Union[str, List[str]]] = LoadConfig.ignore_patterns
+    ignore_patterns: Optional[Union[str, list[str]]] = LoadConfig.ignore_patterns
 
     enable_chunked_prefill: Optional[bool] = SchedulerConfig.enable_chunked_prefill
     disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
@@ -467,7 +464,7 @@ class EngineArgs:
 
     logits_processor_pattern: Optional[str] = ModelConfig.logits_processor_pattern
 
-    speculative_config: Optional[Dict[str, Any]] = None
+    speculative_config: Optional[dict[str, Any]] = None
 
     show_hidden_metrics_for_version: Optional[str] = (
         ObservabilityConfig.show_hidden_metrics_for_version
@@ -477,7 +474,7 @@ class EngineArgs:
         ObservabilityConfig.collect_detailed_traces
     )
     scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
-    scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls
+    scheduler_cls: Union[str, type[object]] = SchedulerConfig.scheduler_cls
 
     pooler_config: Optional[PoolerConfig] = ModelConfig.pooler_config
     override_pooler_config: Optional[Union[dict, PoolerConfig]] = (

@@ -3,7 +3,7 @@
 
 import time
 from collections import Counter as CollectionsCounter
-from typing import Dict, List, Optional, Type, Union, cast
+from typing import Optional, Union, cast
 
 import numpy as np
 import prometheus_client
@@ -43,7 +43,7 @@ class Metrics:
     _counter_cls = prometheus_client.Counter
     _histogram_cls = prometheus_client.Histogram
 
-    def __init__(self, labelnames: List[str], vllm_config: VllmConfig):
+    def __init__(self, labelnames: list[str], vllm_config: VllmConfig):
         # Unregister any existing vLLM collectors (for CI/CD)
         self._unregister_vllm_metrics()
 
@@ -304,7 +304,7 @@ class _RayGaugeWrapper:
         self,
         name: str,
         documentation: str = "",
-        labelnames: Optional[List[str]] = None,
+        labelnames: Optional[list[str]] = None,
         multiprocess_mode: str = "",
     ):
         del multiprocess_mode
@@ -330,7 +330,7 @@ class _RayCounterWrapper:
     prometheus_client.Counter"""
 
     def __init__(
-        self, name: str, documentation: str = "", labelnames: Optional[List[str]] = None
+        self, name: str, documentation: str = "", labelnames: Optional[list[str]] = None
     ):
         labelnames_tuple = tuple(labelnames) if labelnames else None
         self._counter = ray_metrics.Counter(
@@ -355,8 +355,8 @@ class _RayHistogramWrapper:
         self,
         name: str,
         documentation: str = "",
-        labelnames: Optional[List[str]] = None,
-        buckets: Optional[List[float]] = None,
+        labelnames: Optional[list[str]] = None,
+        buckets: Optional[list[float]] = None,
     ):
         labelnames_tuple = tuple(labelnames) if labelnames else None
         boundaries = buckets if buckets else []
@@ -381,17 +381,17 @@ class RayMetrics(Metrics):
     Provides the same metrics as Metrics but uses Ray's util.metrics library.
     """
 
-    _gauge_cls: Type[prometheus_client.Gauge] = cast(
-        Type[prometheus_client.Gauge], _RayGaugeWrapper
+    _gauge_cls: type[prometheus_client.Gauge] = cast(
+        type[prometheus_client.Gauge], _RayGaugeWrapper
     )
-    _counter_cls: Type[prometheus_client.Counter] = cast(
-        Type[prometheus_client.Counter], _RayCounterWrapper
+    _counter_cls: type[prometheus_client.Counter] = cast(
+        type[prometheus_client.Counter], _RayCounterWrapper
     )
-    _histogram_cls: Type[prometheus_client.Histogram] = cast(
-        Type[prometheus_client.Histogram], _RayHistogramWrapper
+    _histogram_cls: type[prometheus_client.Histogram] = cast(
+        type[prometheus_client.Histogram], _RayHistogramWrapper
     )
 
-    def __init__(self, labelnames: List[str], vllm_config: VllmConfig):
+    def __init__(self, labelnames: list[str], vllm_config: VllmConfig):
         if ray_metrics is None:
             raise ImportError("RayMetrics requires Ray to be installed.")
         super().__init__(labelnames, vllm_config)
@@ -401,14 +401,14 @@ class RayMetrics(Metrics):
         pass
 
 
-def build_buckets(mantissa_lst: List[int], max_value: int) -> List[int]:
+def build_buckets(mantissa_lst: list[int], max_value: int) -> list[int]:
     """
     Builds a list of buckets with increasing powers of 10 multiplied by
     mantissa values until the value exceeds the specified maximum.
 
     """
     exponent = 0
-    buckets: List[int] = []
+    buckets: list[int] = []
     while True:
         for m in mantissa_lst:
             value = m * 10**exponent
@@ -419,7 +419,7 @@ def build_buckets(mantissa_lst: List[int], max_value: int) -> List[int]:
         exponent += 1
 
 
-def build_1_2_5_buckets(max_value: int) -> List[int]:
+def build_1_2_5_buckets(max_value: int) -> list[int]:
     """
     Example:
     >>> build_1_2_5_buckets(100)
@@ -428,7 +428,7 @@ def build_1_2_5_buckets(max_value: int) -> List[int]:
     return build_buckets([1, 2, 5], max_value)
 
 
-def build_1_2_3_5_8_buckets(max_value: int) -> List[int]:
+def build_1_2_3_5_8_buckets(max_value: int) -> list[int]:
     """
     Example:
     >>> build_1_2_3_5_8_buckets(100)
@@ -442,7 +442,7 @@ def local_interval_elapsed(now: float, last_log: float, local_interval: float) -
     return elapsed_time > local_interval
 
 
-def get_throughput(tracked_stats: List[int], now: float, last_log: float) -> float:
+def get_throughput(tracked_stats: list[int], now: float, last_log: float) -> float:
     return float(np.sum(tracked_stats) / (now - last_log))
 
 
@@ -530,7 +530,7 @@ class PrometheusStatLogger(StatLoggerBase):
     _gauge_cls = prometheus_client.Gauge
 
     def __init__(
-        self, local_interval: float, labels: Dict[str, str], vllm_config: VllmConfig
+        self, local_interval: float, labels: dict[str, str], vllm_config: VllmConfig
     ) -> None:
         super().__init__(local_interval, vllm_config)
         # Prometheus metrics
@@ -558,12 +558,12 @@ class PrometheusStatLogger(StatLoggerBase):
         for label, count in data.items():
             counter.labels(**{**self.labels, label_key: label}).inc(count)
 
-    def _log_histogram(self, histogram, data: Union[List[int], List[float]]) -> None:
+    def _log_histogram(self, histogram, data: Union[list[int], list[float]]) -> None:
         # Convenience function for logging list to histogram.
         for datum in data:
             histogram.labels(**self.labels).observe(datum)
 
-    def _log_gauge_string(self, gauge, data: Dict[str, str]) -> None:
+    def _log_gauge_string(self, gauge, data: dict[str, str]) -> None:
         gauge.labels(**data).set_to_current_time()
 
     def _log_prometheus(self, stats: Stats) -> None:
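The bucket helpers touched above grow a histogram bucket list by multiplying a small mantissa list by increasing powers of 10. A self-contained sketch of that loop, filling in the branch the hunks truncate, under the assumption that the first value above max_value ends the search:

    def build_buckets_sketch(mantissa_lst: list[int], max_value: int) -> list[int]:
        # Multiply each mantissa by 10**exponent, keeping values up to max_value.
        exponent = 0
        buckets: list[int] = []
        while True:
            for m in mantissa_lst:
                value = m * 10**exponent
                if value <= max_value:
                    buckets.append(value)
                else:
                    return buckets
            exponent += 1

    # For example, build_buckets_sketch([1, 2, 5], 100) returns [1, 2, 5, 10, 20, 50, 100].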

@@ -16,7 +16,6 @@ do this in Python code and lazily import prometheus_client.
 import time
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import List
 
 from vllm.config import SupportsMetricsInfo, VllmConfig
 
@@ -43,26 +42,26 @@ class Stats:
     num_prompt_tokens_iter: int
     num_generation_tokens_iter: int
     num_tokens_iter: int
-    time_to_first_tokens_iter: List[float]
-    inter_token_latencies_iter: List[float]
+    time_to_first_tokens_iter: list[float]
+    inter_token_latencies_iter: list[float]
     num_preemption_iter: int
 
     # Request stats (should have _requests suffix)
     # Latency
-    time_e2e_requests: List[float]
-    time_queue_requests: List[float]
-    time_inference_requests: List[float]
-    time_prefill_requests: List[float]
-    time_decode_requests: List[float]
+    time_e2e_requests: list[float]
+    time_queue_requests: list[float]
+    time_inference_requests: list[float]
+    time_prefill_requests: list[float]
+    time_decode_requests: list[float]
     # Metadata
-    num_prompt_tokens_requests: List[int]
-    num_generation_tokens_requests: List[int]
-    n_requests: List[int]
-    max_num_generation_tokens_requests: List[int]
-    max_tokens_requests: List[int]
-    finished_reason_requests: List[str]
-    waiting_lora_adapters: List[str]
-    running_lora_adapters: List[str]
+    num_prompt_tokens_requests: list[int]
+    num_generation_tokens_requests: list[int]
+    n_requests: list[int]
+    max_num_generation_tokens_requests: list[int]
+    max_tokens_requests: list[int]
+    finished_reason_requests: list[str]
+    waiting_lora_adapters: list[str]
+    running_lora_adapters: list[str]
     max_lora: str
 
 
@@ -71,8 +70,8 @@ class StatLoggerBase(ABC):
 
     def __init__(self, local_interval: float, vllm_config: VllmConfig) -> None:
         # Tracked stats over current local logging interval.
-        self.num_prompt_tokens: List[int] = []
-        self.num_generation_tokens: List[int] = []
+        self.num_prompt_tokens: list[int] = []
+        self.num_generation_tokens: list[int] = []
         self.last_local_log = time.time()
         self.local_interval = local_interval
 

@@ -6,7 +6,7 @@ from __future__ import annotations
 import datetime
 import json
 from collections.abc import Iterable, Sequence
-from typing import Literal, Optional, Union
+from typing import Literal, Union
 
 from openai.types.responses import (
     ResponseFunctionToolCall,
@@ -79,13 +79,13 @@ def get_encoding():
 
 
 def get_system_message(
-    model_identity: Optional[str] = None,
-    reasoning_effort: Optional[Literal["high", "medium", "low"]] = None,
-    start_date: Optional[str] = None,
-    browser_description: Optional[str] = None,
-    python_description: Optional[str] = None,
-    container_description: Optional[str] = None,
-    instructions: Optional[str] = None,
+    model_identity: str | None = None,
+    reasoning_effort: Literal["high", "medium", "low"] | None = None,
+    start_date: str | None = None,
+    browser_description: str | None = None,
+    python_description: str | None = None,
+    container_description: str | None = None,
+    instructions: str | None = None,
     with_custom_tools: bool = False,
 ) -> Message:
     sys_msg_content = SystemContent.new()
@@ -137,8 +137,8 @@ def create_tool_definition(tool: Union[ChatCompletionToolsParam, Tool]):
 
 
 def get_developer_message(
-    instructions: Optional[str] = None,
-    tools: Optional[list[Union[Tool, ChatCompletionToolsParam]]] = None,
+    instructions: str | None = None,
+    tools: list[Union[Tool, ChatCompletionToolsParam]] | None = None,
 ) -> Message:
     dev_msg_content = DeveloperContent.new()
     if instructions is not None and not envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS:
@@ -202,7 +202,7 @@ def parse_response_input(
             msg = msg.with_channel("final")
     elif response_msg["type"] == "function_call_output":
         call_id = response_msg["call_id"]
-        call_response: Optional[ResponseFunctionToolCall] = None
+        call_response: ResponseFunctionToolCall | None = None
         for prev_response in reversed(prev_responses):
             if (
                 isinstance(prev_response, ResponseFunctionToolCall)
@@ -450,7 +450,7 @@ def parse_output_into_messages(token_ids: Iterable[int]) -> StreamableParser:
 
 def parse_chat_output(
     token_ids: Sequence[int],
-) -> tuple[Optional[str], Optional[str], bool]:
+) -> tuple[str | None, str | None, bool]:
     parser = parse_output_into_messages(token_ids)
     output_msgs = parser.messages
     is_tool_call = False  # TODO: update this when tool call is supported

@@ -6,7 +6,7 @@ import time
 from abc import ABC, abstractmethod
 from collections.abc import Awaitable
 from functools import cached_property
-from typing import Any, Callable, List, Optional, Set, Union
+from typing import Any, Callable, Optional, Union
 
 from typing_extensions import TypeVar
 
@@ -143,7 +143,7 @@ class ExecutorBase(ABC):
 
     def execute_model(
         self, execute_model_req: ExecuteModelRequest
-    ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
+    ) -> Optional[list[Union[SamplerOutput, PoolerOutput]]]:
         output = self.collective_rpc("execute_model", args=(execute_model_req,))
         return output[0]
 
@@ -163,7 +163,7 @@ class ExecutorBase(ABC):
         assert lora_id > 0, "lora_id must be greater than 0."
         return all(self.collective_rpc("pin_lora", args=(lora_id,)))
 
-    def list_loras(self) -> Set[int]:
+    def list_loras(self) -> set[int]:
         sets = self.collective_rpc("list_loras")
         for s in sets:
             assert s == sets[0], "All workers should have the same LORAs."
@@ -238,7 +238,7 @@ class ExecutorBase(ABC):
 
     async def execute_model_async(
         self, execute_model_req: ExecuteModelRequest
-    ) -> List[SamplerOutput]:
+    ) -> list[SamplerOutput]:
         """Executes one model step on the given sequences."""
         output = await make_async(self.execute_model)(execute_model_req)
         return output
@@ -272,7 +272,7 @@ class DistributedExecutorBase(ExecutorBase):
     def execute_model(
         self,
         execute_model_req: ExecuteModelRequest,
-    ) -> List[SamplerOutput]:
+    ) -> list[SamplerOutput]:
         # TODO: unify into collective_rpc
         if self.parallel_worker_tasks is None:
             self.parallel_worker_tasks = self._run_workers(
@@ -299,7 +299,7 @@ class DistributedExecutorBase(ExecutorBase):
     @abstractmethod
     def _driver_execute_model(
         self, execute_model_req: Optional[ExecuteModelRequest]
-    ) -> Optional[List[SamplerOutput]]:
+    ) -> Optional[list[SamplerOutput]]:
         """Run execute_model in the driver worker.
 
         Passing None will cause the driver to stop the model execution loop
@@ -346,7 +346,7 @@ class DistributedExecutorBase(ExecutorBase):
 
     async def execute_model_async(
         self, execute_model_req: ExecuteModelRequest
-    ) -> List[SamplerOutput]:
+    ) -> list[SamplerOutput]:
         if self.parallel_worker_tasks is None:
             # Start model execution loop running in the parallel workers
             self.parallel_worker_tasks = asyncio.create_task(
@@ -371,7 +371,7 @@ class DistributedExecutorBase(ExecutorBase):
     async def _driver_execute_model_async(
         self,
         execute_model_req: Optional[ExecuteModelRequest] = None,
-    ) -> List[SamplerOutput]:
+    ) -> list[SamplerOutput]:
         """Execute the model asynchronously in the driver worker.
 
         Passing None will cause the driver to stop the model execution

@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from array import array
-from typing import Any, Type
+from typing import Any
 
 from vllm.multimodal.inputs import MultiModalKwargs
 from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE
@@ -23,7 +23,7 @@ def encode_hook(obj: Any) -> Any:
         return dict(obj)
 
 
-def decode_hook(type: Type, obj: Any) -> Any:
+def decode_hook(type: type, obj: Any) -> Any:
     """Custom msgspec dec hook that supports array types and MultiModalKwargs.
 
     See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
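The hook pair above is what msgspec invokes for types it cannot serialize or deserialize natively. A standalone sketch of the same enc_hook/dec_hook pattern, round-tripping a plain array("l", ...) payload instead of vLLM's token arrays and MultiModalKwargs; the names here are illustrative, not vLLM's:

    import msgspec
    from array import array

    def enc_hook(obj):
        # Serialize array.array, which msgspec does not support natively, as a list.
        if isinstance(obj, array):
            return list(obj)
        raise NotImplementedError(f"Objects of type {type(obj)} are not supported")

    def dec_hook(target_type: type, obj):
        # Rebuild the array from the list produced by enc_hook.
        if target_type is array:
            return array("l", obj)
        raise NotImplementedError(f"Objects of type {target_type} are not supported")

    encoder = msgspec.msgpack.Encoder(enc_hook=enc_hook)
    decoder = msgspec.msgpack.Decoder(array, dec_hook=dec_hook)
    assert decoder.decode(encoder.encode(array("l", [1, 2, 3]))) == array("l", [1, 2, 3])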

@@ -5,7 +5,7 @@ import asyncio
 import os
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 import cloudpickle
 import msgspec
@@ -114,10 +114,10 @@ class RayDistributedExecutor(DistributedExecutorBase):
         self._init_workers_ray(placement_group)
 
         self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
-        self.output_decoder = msgspec.msgpack.Decoder(Optional[List[SamplerOutput]])
+        self.output_decoder = msgspec.msgpack.Decoder(Optional[list[SamplerOutput]])
         self.use_v1 = envs.VLLM_USE_V1
 
-        self.pp_locks: Optional[List[asyncio.Lock]] = None
+        self.pp_locks: Optional[list[asyncio.Lock]] = None
         if not self.use_ray_compiled_dag:
             self.driver_exec_method = make_async(self.driver_worker.execute_method)
 
@@ -137,7 +137,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
                 ray.kill(worker)
             self.forward_dag = None
 
-    def _configure_ray_workers_use_nsight(self, ray_remote_kwargs) -> Dict[str, Any]:
+    def _configure_ray_workers_use_nsight(self, ray_remote_kwargs) -> dict[str, Any]:
         # If nsight profiling is enabled, we need to set the profiling
         # configuration for the ray workers as runtime env.
         runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
@@ -164,12 +164,12 @@ class RayDistributedExecutor(DistributedExecutorBase):
         # It holds the resource for the driver worker.
         self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
         # The remaining workers are the actual ray actors.
-        self.workers: List[RayWorkerWrapper] = []
+        self.workers: list[RayWorkerWrapper] = []
 
         # Used in ray compiled DAG: indexed first by PP rank,
         # and then TP rank. In other words, the inner list is
         # the TP group of workers for a PP rank.
-        self.pp_tp_workers: List[List[RayWorkerWrapper]] = []
+        self.pp_tp_workers: list[list[RayWorkerWrapper]] = []
 
         if self.parallel_config.ray_workers_use_nsight:
             ray_remote_kwargs = self._configure_ray_workers_use_nsight(
@@ -179,7 +179,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
         logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)
 
         # Create the workers.
-        bundle_indices: List[int]
+        bundle_indices: list[int]
         if envs.VLLM_RAY_BUNDLE_INDICES:
             # Use the bundle indices specified by the user.
             bundle_indices = list(map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(",")))
@@ -200,7 +200,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
                     bundle_indices.append(bundle_id)
             bundle_indices = bundle_indices[: self.parallel_config.world_size]
 
-        worker_metadata: List[RayWorkerMetaData] = []
+        worker_metadata: list[RayWorkerMetaData] = []
         driver_ip = get_ip()
         for rank, bundle_id in enumerate(bundle_indices):
             scheduling_strategy = PlacementGroupSchedulingStrategy(
@@ -262,7 +262,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
                 "the driver on a GPU node."
             )
 
-        ip_counts: Dict[str, int] = {}
+        ip_counts: dict[str, int] = {}
         for ip in worker_ips:
             ip_counts[ip] = ip_counts.get(ip, 0) + 1
 
@@ -416,11 +416,11 @@ class RayDistributedExecutor(DistributedExecutorBase):
         # This is the list of workers that are rank 0 of each TP group EXCEPT
         # global rank 0. These are the workers that will broadcast to the
        # rest of the workers.
-        self.tp_driver_workers: List[RayWorkerWrapper] = []
+        self.tp_driver_workers: list[RayWorkerWrapper] = []
         # This is the list of workers that are not drivers and not the first
         # worker in a TP group. These are the workers that will be
         # broadcasted to.
-        self.non_driver_workers: List[RayWorkerWrapper] = []
+        self.non_driver_workers: list[RayWorkerWrapper] = []
 
         # Enforce rank order for correct rank to return final output.
         for index, worker in enumerate(self.workers):
@@ -433,7 +433,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
 
     def _driver_execute_model(
         self, execute_model_req: Optional[ExecuteModelRequest]
-    ) -> Optional[List[SamplerOutput]]:
+    ) -> Optional[list[SamplerOutput]]:
         """Run execute_model in the driver worker.
 
         Passing None will cause the driver to stop the model execution
@@ -446,7 +446,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
 
     def execute_model(
         self, execute_model_req: ExecuteModelRequest
-    ) -> List[SamplerOutput]:
+    ) -> list[SamplerOutput]:
         if not self.use_ray_spmd_worker:
             return super().execute_model(execute_model_req)
 
@@ -675,7 +675,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
 
     async def execute_model_async(
        self, execute_model_req: ExecuteModelRequest
-    ) -> List[SamplerOutput]:
+    ) -> list[SamplerOutput]:
         if not self.use_ray_spmd_worker:
             return await super().execute_model_async(execute_model_req)
 
@@ -689,7 +689,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
 
     async def _driver_execute_model_async(
         self, execute_model_req: Optional[ExecuteModelRequest] = None
-    ) -> List[SamplerOutput]:
+    ) -> list[SamplerOutput]:
         assert not self.use_ray_spmd_worker, (
             "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1"
         )

@@ -4,7 +4,7 @@
 import os
 import time
 from collections import defaultdict
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional, Union
 
 import msgspec
 
@@ -59,7 +59,7 @@ try:
         def get_node_ip(self) -> str:
             return get_ip()
 
-        def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
+        def get_node_and_gpu_ids(self) -> tuple[str, list[int]]:
             node_id = ray.get_runtime_context().get_node_id()
             device_key = vllm.platforms.current_platform.ray_device_key
             if not device_key:
@@ -72,7 +72,7 @@ try:
 
         def execute_model_spmd(
             self,
-            req_or_tuple: Union[bytes, Tuple[bytes, Optional[IntermediateTensors]]],
+            req_or_tuple: Union[bytes, tuple[bytes, Optional[IntermediateTensors]]],
         ) -> bytes:
             """Execute model in SPMD fashion: used only when SPMD worker and
             compiled DAG are both enabled.
@@ -126,10 +126,10 @@ try:
         def execute_model_ray(
             self,
             scheduler_output: Union[
-                "SchedulerOutput", Tuple["SchedulerOutput", "IntermediateTensors"]
+                "SchedulerOutput", tuple["SchedulerOutput", "IntermediateTensors"]
             ],
         ) -> Union[
-            "ModelRunnerOutput", Tuple["SchedulerOutput", "IntermediateTensors"]
+            "ModelRunnerOutput", tuple["SchedulerOutput", "IntermediateTensors"]
         ]:
             # This method is used by Ray Compiled Graph to execute the model,
             # and it needs a special logic of self.setup_device_if_necessary()
@@ -156,7 +156,7 @@ try:
                 output = output.get_output()
             return output
 
-        def override_env_vars(self, vars: Dict[str, str]):
+        def override_env_vars(self, vars: dict[str, str]):
             os.environ.update(vars)
 
     ray_import_err = None
@@ -201,7 +201,7 @@ def _verify_bundles(
     # bundle_idx -> bundle (e.g., {"GPU": 1})
     bundles = pg_data["bundles"]
     # node_id -> List of bundle (e.g., {"GPU": 1})
-    node_id_to_bundle: Dict[str, List[Dict[str, float]]] = defaultdict(list)
+    node_id_to_bundle: dict[str, list[dict[str, float]]] = defaultdict(list)
 
     for bundle_idx, node_id in bundle_to_node_ids.items():
         node_id_to_bundle[node_id].append(bundles[bundle_idx])
@@ -383,7 +383,7 @@ def initialize_ray_cluster(
             device_str,
         )
         # Create a new placement group
-        placement_group_specs: List[Dict[str, float]] = [
+        placement_group_specs: list[dict[str, float]] = [
            {device_str: 1.0} for _ in range(parallel_config.world_size)
         ]
 
|
|||||||
@@ -4,7 +4,7 @@ import os
from concurrent.futures import Future, ThreadPoolExecutor
from functools import cached_property
from multiprocessing import Lock
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Union

import torch
import torch.distributed as dist
@@ -68,10 +68,10 @@ class UniProcExecutor(ExecutorBase):
self,
method: Union[str, Callable],
timeout: Optional[float] = None,
-args: Tuple = (),
-kwargs: Optional[Dict] = None,
+args: tuple = (),
+kwargs: Optional[dict] = None,
non_block: bool = False,
-) -> List[Any]:
+) -> list[Any]:
if kwargs is None:
kwargs = {}
if self.mm_receiver_cache is not None and method == "execute_model":
@@ -158,7 +158,7 @@ class ExecutorWithExternalLauncher(UniProcExecutor):
local_rank = int(os.environ["LOCAL_RANK"])
return distributed_init_method, rank, local_rank

-def determine_num_available_blocks(self) -> Tuple[int, int]:
+def determine_num_available_blocks(self) -> tuple[int, int]:
"""
Determine the number of available KV blocks.
Add an additional all_reduce to get the min across all ranks.
@@ -1,9 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import (
-List, # noqa: UP035
-Optional,
-)
+from typing import Optional

import torch

@@ -32,7 +29,7 @@ def flashinfer_fused_moe_blockscale_fp8(
intermediate_size: int,
expert_offset: int,
local_num_experts: int,
-block_shape: List[int], # noqa: UP006
+block_shape: list[int],
routed_scaling: float = 1.0,
) -> torch.Tensor:
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
@@ -289,7 +289,7 @@ class MultiModalFieldElem:
return (
(self.modality, self.key) == (other.modality, other.key)
and data_equal
-and type(self.field) == type(other.field)
+and type(self.field) is type(other.field)
) # noqa: E721

@@ -4,7 +4,6 @@
from __future__ import annotations

import logging
-from typing import Optional

from vllm.config import VllmConfig
from vllm.plugins import load_plugins_by_group
@@ -15,7 +14,7 @@ logger = logging.getLogger(__name__)


def get_io_processor(
-vllm_config: VllmConfig, plugin_from_init: Optional[str] = None
+vllm_config: VllmConfig, plugin_from_init: str | None = None
) -> IOProcessor | None:
# Input.Output processors are loaded as plugins under the
# 'vllm.io_processor_plugins' group. Similar to platform
@@ -68,7 +68,6 @@ from typing import (
Generic,
Literal,
NamedTuple,
-Optional,
TextIO,
TypeVar,
Union,
@@ -247,9 +246,7 @@ class CacheInfo(NamedTuple):


class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
-def __init__(
-self, capacity: float, getsizeof: Optional[Callable[[_V], float]] = None
-):
+def __init__(self, capacity: float, getsizeof: Callable[[_V], float] | None = None):
super().__init__(capacity, getsizeof)

self.pinned_items = set[_K]()
@@ -324,15 +321,15 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
self._LRUCache__order[key] = None # type: ignore

@overload
-def get(self, key: _K, /) -> Optional[_V]: ...
+def get(self, key: _K, /) -> _V | None: ...

@overload
def get(self, key: _K, /, default: Union[_V, _T]) -> Union[_V, _T]: ...

def get(
-self, key: _K, /, default: Optional[Union[_V, _T]] = None
-) -> Optional[Union[_V, _T]]:
-value: Optional[Union[_V, _T]]
+self, key: _K, /, default: Union[_V, _T] | None = None
+) -> Union[_V, _T] | None:
+value: Union[_V, _T] | None
if key in self:
value = self.__getitem__(key, update_info=False) # type: ignore[call-arg]

@@ -350,9 +347,9 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
def pop(self, key: _K, default: Union[_V, _T]) -> Union[_V, _T]: ...

def pop(
-self, key: _K, default: Optional[Union[_V, _T]] = None
-) -> Optional[Union[_V, _T]]:
-value: Optional[Union[_V, _T]]
+self, key: _K, default: Union[_V, _T] | None = None
+) -> Union[_V, _T] | None:
+value: Union[_V, _T] | None
if key not in self:
return default

@@ -379,7 +376,7 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
"""
self.pinned_items.remove(key)

-def _on_remove(self, key: _K, value: Optional[_V]) -> None:
+def _on_remove(self, key: _K, value: _V | None) -> None:
pass

def remove_oldest(self, *, remove_pinned: bool = False) -> None:
@@ -705,7 +702,7 @@ def in_loop(event_loop: AbstractEventLoop) -> bool:


def make_async(
-func: Callable[P, T], executor: Optional[concurrent.futures.Executor] = None
+func: Callable[P, T], executor: concurrent.futures.Executor | None = None
) -> Callable[P, Awaitable[T]]:
"""Take a blocking function, and run it on in an executor thread.

@@ -940,7 +937,7 @@ def _get_open_port() -> int:
return s.getsockname()[1]


-def find_process_using_port(port: int) -> Optional[psutil.Process]:
+def find_process_using_port(port: int) -> psutil.Process | None:
# TODO: We can not check for running processes with network
# port on macOS. Therefore, we can not have a full graceful shutdown
# of vLLM. For now, let's not look for processes in this case.
@@ -1025,8 +1022,8 @@ def _generate_random_fp8(


def get_kv_cache_torch_dtype(
-cache_dtype: Optional[Union[str, torch.dtype]],
-model_dtype: Optional[Union[str, torch.dtype]] = None,
+cache_dtype: Union[str, torch.dtype] | None,
+model_dtype: Union[str, torch.dtype] | None = None,
) -> torch.dtype:
if isinstance(cache_dtype, str):
if cache_dtype == "auto":
@@ -1053,11 +1050,11 @@ def create_kv_caches_with_random_flash(
num_layers: int,
num_heads: int,
head_size: int,
-cache_dtype: Optional[Union[str, torch.dtype]],
-model_dtype: Optional[Union[str, torch.dtype]] = None,
-seed: Optional[int] = None,
-device: Optional[str] = "cuda",
-cache_layout: Optional[str] = "NHD",
+cache_dtype: Union[str, torch.dtype] | None,
+model_dtype: Union[str, torch.dtype] | None = None,
+seed: int | None = None,
+device: str | None = "cuda",
+cache_layout: str | None = "NHD",
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
from vllm.platforms import current_platform

@@ -1095,10 +1092,10 @@ def create_kv_caches_with_random(
num_layers: int,
num_heads: int,
head_size: int,
-cache_dtype: Optional[Union[str, torch.dtype]],
-model_dtype: Optional[Union[str, torch.dtype]] = None,
-seed: Optional[int] = None,
-device: Optional[str] = "cuda",
+cache_dtype: Union[str, torch.dtype] | None,
+model_dtype: Union[str, torch.dtype] | None = None,
+seed: int | None = None,
+device: str | None = "cuda",
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
if cache_dtype == "fp8" and head_size % 16:
raise ValueError(
@@ -1156,7 +1153,7 @@ def is_uva_available() -> bool:


class DeviceMemoryProfiler:
-def __init__(self, device: Optional[torch.types.Device] = None):
+def __init__(self, device: torch.types.Device | None = None):
self.device = device

def current_memory_usage(self) -> float:
@@ -1184,7 +1181,7 @@ def make_ndarray_with_pad(
pad: T,
dtype: npt.DTypeLike,
*,
-max_len: Optional[int] = None,
+max_len: int | None = None,
) -> npt.NDArray:
"""
Make a padded array from 2D inputs.
@@ -1209,8 +1206,8 @@ def make_tensor_with_pad(
pad: T,
dtype: torch.dtype,
*,
-max_len: Optional[int] = None,
-device: Optional[Union[str, torch.device]] = None,
+max_len: int | None = None,
+device: Union[str, torch.device] | None = None,
pin_memory: bool = False,
) -> torch.Tensor:
"""
@@ -1405,7 +1402,7 @@ def find_nccl_library() -> str:
return so_file


-def find_nccl_include_paths() -> Optional[list[str]]:
+def find_nccl_include_paths() -> list[str] | None:
"""
We either use the nccl.h specified by the `VLLM_NCCL_INCLUDE_PATH`
environment variable, or we find the library file brought by
@@ -1525,7 +1522,7 @@ F = TypeVar("F", bound=Callable[..., Any])
def deprecate_args(
start_index: int,
is_deprecated: Union[bool, Callable[[], bool]] = True,
-additional_message: Optional[str] = None,
+additional_message: str | None = None,
) -> Callable[[F], F]:
if not callable(is_deprecated):
is_deprecated = partial(identity, is_deprecated)
@@ -1565,7 +1562,7 @@ def deprecate_args(
def deprecate_kwargs(
*kws: str,
is_deprecated: Union[bool, Callable[[], bool]] = True,
-additional_message: Optional[str] = None,
+additional_message: str | None = None,
) -> Callable[[F], F]:
deprecated_kws = set(kws)

@@ -1598,7 +1595,7 @@ def deprecate_kwargs(


@lru_cache(maxsize=8)
-def _cuda_device_count_stateless(cuda_visible_devices: Optional[str] = None) -> int:
+def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int:
# Note: cuda_visible_devices is not used, but we keep it as an argument for
# LRU Cache purposes.

@@ -1746,7 +1743,7 @@ class FlexibleArgumentParser(ArgumentParser):
' --json-arg \'{"key4": ["value3", "value4", "value5"]}\'\n'
" --json-arg.key4+ value3 --json-arg.key4+='value4,value5'\n\n"
)
-_search_keyword: Optional[str] = None
+_search_keyword: str | None = None

def __init__(self, *args, **kwargs):
# Set the default "formatter_class" to SortedHelpFormatter
@@ -2245,7 +2242,7 @@ def supports_kw(

def get_allowed_kwarg_only_overrides(
callable: Callable[..., object],
-overrides: Optional[Mapping[str, object]],
+overrides: Mapping[str, object] | None,
*,
requires_kw_only: bool = True,
allow_var_kwargs: bool = False,
@@ -2695,10 +2692,10 @@ vllm_lib = Library("vllm", "FRAGMENT") # noqa
def direct_register_custom_op(
op_name: str,
op_func: Callable,
-mutates_args: Optional[list[str]] = None,
-fake_impl: Optional[Callable] = None,
-target_lib: Optional[Library] = None,
-dispatch_key: Optional[str] = None,
+mutates_args: list[str] | None = None,
+fake_impl: Callable | None = None,
+target_lib: Library | None = None,
+dispatch_key: str | None = None,
tags: tuple[torch.Tag, ...] = (),
):
"""
@@ -3016,7 +3013,7 @@ def split_zmq_path(path: str) -> tuple[str, str, str]:
return scheme, host, port


-def make_zmq_path(scheme: str, host: str, port: Optional[int] = None) -> str:
+def make_zmq_path(scheme: str, host: str, port: int | None = None) -> str:
"""Make a ZMQ path from its parts.

Args:
@@ -3039,9 +3036,9 @@ def make_zmq_socket(
ctx: Union[zmq.asyncio.Context, zmq.Context], # type: ignore[name-defined]
path: str,
socket_type: Any,
-bind: Optional[bool] = None,
-identity: Optional[bytes] = None,
-linger: Optional[int] = None,
+bind: bool | None = None,
+identity: bytes | None = None,
+linger: int | None = None,
) -> Union[zmq.Socket, zmq.asyncio.Socket]: # type: ignore[name-defined]
"""Make a ZMQ socket with the proper bind/connect semantics."""

@@ -3098,9 +3095,9 @@ def make_zmq_socket(
def zmq_socket_ctx(
path: str,
socket_type: Any,
-bind: Optional[bool] = None,
+bind: bool | None = None,
linger: int = 0,
-identity: Optional[bytes] = None,
+identity: bytes | None = None,
) -> Iterator[zmq.Socket]:
"""Context manager for a ZMQ socket"""

@@ -3163,7 +3160,7 @@ def get_mp_context():
def bind_kv_cache(
ctx: dict[str, Any],
kv_cache: list[list[torch.Tensor]], # [virtual_engine][layer_index]
-shared_kv_cache_layers: Optional[dict[str, str]] = None,
+shared_kv_cache_layers: dict[str, str] | None = None,
) -> None:
# Bind the kv_cache tensor to Attention modules, similar to
# ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)]
@@ -3379,7 +3376,7 @@ def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None:


@contextlib.contextmanager
-def cprofile_context(save_file: Optional[str] = None):
+def cprofile_context(save_file: str | None = None):
"""Run a cprofile

Args:
@@ -3401,7 +3398,7 @@ def cprofile_context(save_file: Optional[str] = None):
prof.print_stats(sort="cumtime")


-def cprofile(save_file: Optional[str] = None, enabled: bool = True):
+def cprofile(save_file: str | None = None, enabled: bool = True):
"""Decorator to profile a Python method using cProfile.

Args:
@@ -3608,7 +3605,7 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
file.write = write_with_prefix # type: ignore[method-assign]


-def decorate_logs(process_name: Optional[str] = None) -> None:
+def decorate_logs(process_name: str | None = None) -> None:
"""
Adds a process-specific prefix to each line of output written to stdout and
stderr.
@@ -3631,8 +3628,8 @@ def decorate_logs(process_name: Optional[str] = None) -> None:


def length_from_prompt_token_ids_or_embeds(
-prompt_token_ids: Optional[list[int]],
-prompt_embeds: Optional[torch.Tensor],
+prompt_token_ids: list[int] | None,
+prompt_embeds: torch.Tensor | None,
) -> int:
"""Calculate the request length (in number of tokens) give either
prompt_token_ids or prompt_embeds.
@@ -10,7 +10,7 @@ from __future__ import annotations
import functools
import importlib
import os
-from typing import Any, Callable, NoReturn, Optional
+from typing import Any, Callable, NoReturn

import torch

@@ -325,7 +325,7 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor):
def should_use_deepgemm_for_fp8_linear(
output_dtype: torch.dtype,
weight: torch.Tensor,
-supports_deep_gemm: Optional[bool] = None,
+supports_deep_gemm: bool | None = None,
):
if supports_deep_gemm is None:
supports_deep_gemm = is_deep_gemm_supported()
@@ -12,7 +12,7 @@ import functools
import importlib
import importlib.util
import os
-from typing import Any, Callable, NoReturn, Optional
+from typing import Any, Callable, NoReturn

import requests
import torch
@@ -202,14 +202,14 @@ def supports_trtllm_attention() -> bool:


@functools.cache
-def _force_use_trtllm_attention(env_value: Optional[bool]) -> Optional[bool]:
+def _force_use_trtllm_attention(env_value: bool | None) -> bool | None:
"""Cache the env value for VLLM_USE_TRTLLM_ATTENTION"""
if env_value is not None:
logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
return env_value


-def force_use_trtllm_attention() -> Optional[bool]:
+def force_use_trtllm_attention() -> bool | None:
"""
Return ``None`` if VLLM_USE_TRTLLM_ATTENTION is not set,
return ``True`` if TRTLLM attention is forced to be used,
@@ -401,7 +401,7 @@ def flashinfer_scaled_fp8_mm(
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype,
-bias: Optional[torch.Tensor] = None,
+bias: torch.Tensor | None = None,
) -> torch.Tensor:
assert a.ndim == 2 and b.ndim == 2
assert a.shape[1] == b.shape[0]
@@ -5,7 +5,7 @@
from __future__ import annotations

from dataclasses import dataclass
-from typing import ClassVar, Optional, Union
+from typing import ClassVar, Union

import numpy as np
import torch
@@ -254,12 +254,12 @@ class FlashInferMetadata:
# For cascade attention (CPU for planning).
use_cascade: bool

-prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
-decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
-cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None
+prefill_wrapper: BatchPrefillWithPagedKVCacheWrapper | None = None
+decode_wrapper: BatchDecodeWithPagedKVCacheWrapper | None = None
+cascade_wrapper: MultiLevelCascadeAttentionWrapper | None = None

-qo_indptr_gpu: Optional[torch.Tensor] = None
-paged_kv_indptr_gpu: Optional[torch.Tensor] = None
+qo_indptr_gpu: torch.Tensor | None = None
+paged_kv_indptr_gpu: torch.Tensor | None = None


class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
@@ -727,13 +727,13 @@ class FlashInferImpl(AttentionImpl):
head_size: int,
scale: float,
num_kv_heads: int,
-alibi_slopes: Optional[list[float]],
-sliding_window: Optional[int],
+alibi_slopes: list[float] | None,
+sliding_window: int | None,
kv_cache_dtype: str,
-logits_soft_cap: Optional[float] = None,
+logits_soft_cap: float | None = None,
attn_type: AttentionType = AttentionType.DECODER,
-kv_sharing_target_layer_name: Optional[int] = None,
-sinks: Optional[torch.Tensor] = None,
+kv_sharing_target_layer_name: int | None = None,
+sinks: torch.Tensor | None = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
@@ -763,7 +763,7 @@ class FlashInferImpl(AttentionImpl):
"FlashInferImpl"
)

-self.sinks: Optional[torch.Tensor] = None
+self.sinks: torch.Tensor | None = None
if sinks is not None:
if sinks.shape[0] != num_heads:
raise ValueError(
@@ -776,9 +776,9 @@ class FlashInferImpl(AttentionImpl):
self.support_trtllm_attn = (
supports_trtllm_attention() and num_heads % num_kv_heads == 0
)
-self.bmm1_scale: Optional[float] = None
-self.bmm2_scale: Optional[float] = None
-self.o_sf_scale: Optional[float] = None
+self.bmm1_scale: float | None = None
+self.bmm2_scale: float | None = None
+self.o_sf_scale: float | None = None

def fused_output_quant_supported(self, quant_key: QuantKey):
return (
@@ -795,9 +795,9 @@ class FlashInferImpl(AttentionImpl):
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashInferMetadata,
-output: Optional[torch.Tensor] = None,
-output_scale: Optional[torch.Tensor] = None,
-output_block_scale: Optional[torch.Tensor] = None,
+output: torch.Tensor | None = None,
+output_scale: torch.Tensor | None = None,
+output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass with FlashInfer.

@@ -1093,13 +1093,13 @@ def fast_plan_decode(
page_size: int,
pos_encoding_mode: str = "NONE",
window_left: int = -1,
-logits_soft_cap: Optional[float] = None,
-q_data_type: Optional[Union[str, torch.dtype]] = "float16",
-kv_data_type: Optional[Union[str, torch.dtype]] = None,
-data_type: Optional[Union[str, torch.dtype]] = None,
-sm_scale: Optional[float] = None,
-rope_scale: Optional[float] = None,
-rope_theta: Optional[float] = None,
+logits_soft_cap: float | None = None,
+q_data_type: Union[str, torch.dtype] | None = "float16",
+kv_data_type: Union[str, torch.dtype] | None = None,
+data_type: Union[str, torch.dtype] | None = None,
+sm_scale: float | None = None,
+rope_scale: float | None = None,
+rope_theta: float | None = None,
non_blocking: bool = True,
) -> None:
"""
@@ -4,7 +4,7 @@
from __future__ import annotations

from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING

from vllm._bc_linter import bc_linter_include

@@ -25,14 +25,14 @@ if TYPE_CHECKING:
@dataclass
class NewRequestData:
req_id: str
-prompt_token_ids: Optional[list[int]]
+prompt_token_ids: list[int] | None
mm_features: list[MultiModalFeatureSpec]
-sampling_params: Optional[SamplingParams]
-pooling_params: Optional[PoolingParams]
+sampling_params: SamplingParams | None
+pooling_params: PoolingParams | None
block_ids: tuple[list[int], ...]
num_computed_tokens: int
-lora_request: Optional[LoRARequest]
-prompt_embeds: Optional[torch.Tensor] = None
+lora_request: LoRARequest | None
+prompt_embeds: torch.Tensor | None = None

@classmethod
def from_request(
@@ -98,7 +98,7 @@ class CachedRequestData:
# NOTE(woosuk): new_token_ids is only used for pipeline parallelism.
# When PP is not used, new_token_ids will be empty.
new_token_ids: list[list[int]]
-new_block_ids: list[Optional[tuple[list[int], ...]]]
+new_block_ids: list[tuple[list[int], ...] | None]
num_computed_tokens: list[int]
num_output_tokens: list[int]

@@ -160,7 +160,7 @@ class SchedulerOutput:
# for filling the next token bitmask
structured_output_request_ids: dict[str, int]
# the bitmask for the whole batch
-grammar_bitmask: Optional[npt.NDArray[np.int32]]
+grammar_bitmask: npt.NDArray[np.int32] | None

# KV Cache Connector metadata.
-kv_connector_metadata: Optional[KVConnectorMetadata] = None
+kv_connector_metadata: KVConnectorMetadata | None = None
@@ -7,7 +7,7 @@ import itertools
import time
from collections import defaultdict
from collections.abc import Iterable
-from typing import Any, Optional, Union
+from typing import Any, Union

from vllm.config import VllmConfig
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
@@ -64,7 +64,7 @@ class Scheduler(SchedulerInterface):
# request ids should be included in the EngineCoreOutputs returned
# by update_from_outputs(). This is currently used in the multi-engine
# case to track request lifetimes efficiently.
-self.finished_req_ids_dict: Optional[dict[int, set[str]]] = (
+self.finished_req_ids_dict: dict[int, set[str]] | None = (
defaultdict(set) if include_finished_set else None
)

@@ -708,7 +708,7 @@ class Scheduler(SchedulerInterface):
) -> CachedRequestData:
req_ids: list[str] = []
new_token_ids: list[list[int]] = []
-new_block_ids: list[Optional[tuple[list[int], ...]]] = []
+new_block_ids: list[tuple[list[int], ...] | None] = []
num_computed_tokens: list[int] = []
num_output_tokens: list[int] = []

@@ -921,7 +921,7 @@ class Scheduler(SchedulerInterface):
kv_connector_output = model_runner_output.kv_connector_output

outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
-spec_decoding_stats: Optional[SpecDecodingStats] = None
+spec_decoding_stats: SpecDecodingStats | None = None
kv_connector_stats = (
kv_connector_output.kv_connector_stats if kv_connector_output else None
)
@@ -1212,7 +1212,7 @@ class Scheduler(SchedulerInterface):
request.status = finished_status
self._free_request(request)

-def _free_request(self, request: Request) -> Optional[dict[str, Any]]:
+def _free_request(self, request: Request) -> dict[str, Any] | None:
assert request.is_finished()

delay_free_blocks, kv_xfer_params = self._connector_finished(request)
@@ -1243,9 +1243,9 @@ class Scheduler(SchedulerInterface):

def make_stats(
self,
-spec_decoding_stats: Optional[SpecDecodingStats] = None,
-kv_connector_stats: Optional[KVConnectorStats] = None,
-) -> Optional[SchedulerStats]:
+spec_decoding_stats: SpecDecodingStats | None = None,
+kv_connector_stats: KVConnectorStats | None = None,
+) -> SchedulerStats | None:
if not self.log_stats:
return None
prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats()
@@ -1262,10 +1262,10 @@ class Scheduler(SchedulerInterface):

def make_spec_decoding_stats(
self,
-spec_decoding_stats: Optional[SpecDecodingStats],
+spec_decoding_stats: SpecDecodingStats | None,
num_draft_tokens: int,
num_accepted_tokens: int,
-) -> Optional[SpecDecodingStats]:
+) -> SpecDecodingStats | None:
if not self.log_stats:
return None
if spec_decoding_stats is None:
@@ -1285,12 +1285,12 @@ class Scheduler(SchedulerInterface):
# KV Connector Related Methods
########################################################################

-def get_kv_connector(self) -> Optional[KVConnectorBase_V1]:
+def get_kv_connector(self) -> KVConnectorBase_V1 | None:
return self.connector

def _connector_finished(
self, request: Request
-) -> tuple[bool, Optional[dict[str, Any]]]:
+) -> tuple[bool, dict[str, Any] | None]:
"""
Invoke the KV connector request_finished() method if applicable.
@@ -4,7 +4,7 @@ from __future__ import annotations

import multiprocessing
from concurrent.futures import Future, ThreadPoolExecutor
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING

from vllm.config import VllmConfig
from vllm.logger import init_logger
@@ -35,11 +35,11 @@ class StructuredOutputManager:
"""Engine-level manager for structured output requests."""

def __init__(self, vllm_config: VllmConfig):
-self.backend: Optional[StructuredOutputBackend] = None
-self.reasoner: Optional[ReasoningParser] = None
+self.backend: StructuredOutputBackend | None = None
+self.reasoner: ReasoningParser | None = None
self.vllm_config = vllm_config

-self._grammar_bitmask: Optional[torch.Tensor] = None
+self._grammar_bitmask: torch.Tensor | None = None
self._full_mask = torch.tensor(-1, dtype=torch.int32)

max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
@@ -168,7 +168,7 @@ class StructuredOutputManager:
requests: dict[str, Request],
structured_output_request_ids: dict[str, int],
scheduled_spec_decode_tokens: dict[str, list[int]],
-) -> Optional[npt.NDArray[np.int32]]:
+) -> npt.NDArray[np.int32] | None:
# Prepare the structured output bitmask for this batch.
if not structured_output_request_ids:
return None
@@ -7,7 +7,7 @@ import copy
import json
import os
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Union

import torch

@@ -252,7 +252,7 @@ def serialize_guidance_grammar(


def validate_guidance_grammar(
-sampling_params: SamplingParams, tokenizer: Optional[llguidance.LLTokenizer] = None
+sampling_params: SamplingParams, tokenizer: llguidance.LLTokenizer | None = None
) -> None:
tp, grm = get_structured_output_key(sampling_params)
guidance_grm = serialize_guidance_grammar(tp, grm)
@@ -20,10 +20,10 @@ from vllm.v1.structured_output.backend_types import (
@dataclasses.dataclass
class StructuredOutputRequest:
sampling_params: SamplingParams
-_grammar: Optional[
-Union[Future[StructuredOutputGrammar], StructuredOutputGrammar]
-] = None
-reasoning_ended: Optional[bool] = None
+_grammar: Union[Future[StructuredOutputGrammar], StructuredOutputGrammar] | None = (
+None
+)
+reasoning_ended: bool | None = None

def _check_grammar_completion(self) -> bool:
# NOTE: We have to lazy import to gate circular imports
@@ -43,7 +43,7 @@ class StructuredOutputRequest:
return self._check_grammar_completion()

@property
-def grammar(self) -> Optional[StructuredOutputGrammar]:
+def grammar(self) -> StructuredOutputGrammar | None:
completed = self._check_grammar_completion()
return (
cast(Optional[StructuredOutputGrammar], self._grammar)
@@ -4,7 +4,7 @@
from __future__ import annotations

import os
-from typing import Any, Callable, Optional, TypeVar, Union
+from typing import Any, Callable, TypeVar, Union

import torch
import torch.nn as nn
@@ -78,8 +78,8 @@ class WorkerBase:
self.is_driver_worker = is_driver_worker

# Device and model state
-self.device: Optional[torch.device] = None
-self.model_runner: Optional[nn.Module] = None
+self.device: torch.device | None = None
+self.model_runner: nn.Module | None = None

def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
"""Get specifications for KV cache implementation."""
@@ -115,8 +115,8 @@ class WorkerBase:
raise NotImplementedError

def execute_model(
-self, execute_model_req: Optional[ExecuteModelRequest] = None
-) -> Optional[list[SamplerOutput]]:
+self, execute_model_req: ExecuteModelRequest | None = None
+) -> list[SamplerOutput] | None:
raise NotImplementedError

def start_worker_execution_loop(self) -> None:
@@ -198,8 +198,8 @@ class WorkerWrapperBase:
group.
"""
self.rpc_rank = rpc_rank
-self.worker: Optional[WorkerBase] = None
-self.vllm_config: Optional[VllmConfig] = None
+self.worker: WorkerBase | None = None
+self.vllm_config: VllmConfig | None = None
# do not store this `vllm_config`, `init_worker` will set the final
# one. TODO: investigate if we can remove this field in
# `WorkerWrapperBase`, `init_cached_hf_modules` should be