Fix per file ruff ignores related to typing (#26254)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-05 17:37:55 +01:00
Committed by: GitHub
Parent: 5f317530ec
Commit: 1c0c68202c
32 changed files with 258 additions and 285 deletions
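Most of the 258 additions and 285 deletions below apply the typing modernization that ruff rules UP006 and UP035 enforce: builtin generics such as list[int] and dict[str, Any] replace the deprecated typing.List/typing.Dict aliases, and files that already use from __future__ import annotations also switch Optional[X] to X | None in the same pass. That is what allows the matching per-file ignores to be deleted from pyproject.toml. A minimal sketch of the pattern, for illustration only and not taken from the commit (the function name is invented), assuming Python 3.9+:

    from __future__ import annotations  # lets older interpreters accept the new annotation syntax

    # Before: deprecated typing aliases, flagged by UP035 (import) and UP006 (usage)
    # from typing import Dict, List, Optional
    # def count_keys(keys: Optional[List[str]] = None) -> Dict[str, int]: ...

    # After: builtin generics and X | None unions
    def count_keys(keys: list[str] | None = None) -> dict[str, int]:
        return {key: len(key) for key in (keys or [])}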

View File

@@ -115,6 +115,7 @@ include = ["vllm*"]
 "vllm/distributed/parallel_state.py" = ["SIM108"]
 "vllm/entrypoints/chat_utils.py" = ["SIM108"]
 "vllm/entrypoints/llm.py" = ["SIM108"]
+"vllm/executor/ray_distributed_executor.py" = ["SIM108", "SIM112"]
 "vllm/model_executor/layers/batch_invariant.py" = ["SIM108"]
 "vllm/model_executor/layers/fla/ops/chunk_o.py" = ["SIM108"]
 "vllm/model_executor/layers/fused_moe/fused_moe.py" = ["SIM108"]
@@ -134,23 +135,6 @@ include = ["vllm*"]
 "tools/profiler/print_layerwise_table.py" = ["SIM118"]
 ## Loop variable binding issues
 "tests/kernels/mamba/test_mamba_ssm_ssd.py" = ["B023"]
-## Type annotation modernization and other rules
-"vllm/attention/backends/abstract.py" = ["UP035", "UP006"]
-"vllm/attention/layer.py" = ["UP035", "UP006"]
-"vllm/attention/layers/chunked_local_attention.py" = ["UP035", "UP006"]
-"vllm/attention/ops/flashmla.py" = ["UP035", "UP006"]
-"vllm/attention/ops/paged_attn.py" = ["UP035", "UP006"]
-"vllm/engine/arg_utils.py" = ["UP035", "UP006"]
-"vllm/engine/metrics.py" = ["UP035", "UP006"]
-"vllm/engine/metrics_types.py" = ["UP035", "UP006"]
-"vllm/executor/executor_base.py" = ["UP035", "UP006"]
-"vllm/executor/msgspec_utils.py" = ["UP035", "UP006"]
-"vllm/executor/ray_distributed_executor.py" = ["UP035", "UP006", "SIM108", "SIM112"]
-"vllm/executor/ray_utils.py" = ["UP035", "UP006"]
-"vllm/executor/uniproc_executor.py" = ["UP035", "UP006"]
-"vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py" = ["UP035"]
-## Type comparison issues
-"vllm/multimodal/inputs.py" = ["E721"]
 # End of temporary ignores
 [tool.ruff.lint]

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
 import logging
 import tempfile
-from typing import Any, Optional, Union
+from typing import Any, Union
 import pytest
 import torch
@@ -21,7 +21,7 @@ from vllm.utils import is_torch_equal_or_newer
 from ..utils import create_new_process_for_each_test
-def models_list(*, all: bool = True, keywords: Optional[list[str]] = None):
+def models_list(*, all: bool = True, keywords: list[str] | None = None):
 TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
 ("facebook/opt-125m", {}),
 (

View File

@@ -6,7 +6,7 @@ from __future__ import annotations
 import asyncio
 from contextlib import suppress
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any
 from unittest.mock import AsyncMock, MagicMock
 import pytest
@@ -233,9 +233,9 @@ class MockModelConfig:
 multimodal_config = MultiModalConfig()
 hf_config = MockHFConfig()
 logits_processor_pattern = None
-diff_sampling_param: Optional[dict] = None
+diff_sampling_param: dict | None = None
 allowed_local_media_path: str = ""
-allowed_media_domains: Optional[list[str]] = None
+allowed_media_domains: list[str] | None = None
 encoder_config = None
 generation_config: str = "auto"
 media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)

View File

@@ -9,7 +9,7 @@ import os
 import tempfile
 import urllib.request
 from collections.abc import Sequence
-from typing import Any, Optional, Union
+from typing import Any, Union
 import albumentations
 import numpy as np
@@ -98,9 +98,9 @@ def _convert_np_uint8(float_image: torch.Tensor):
 def read_geotiff(
-file_path: Optional[str] = None,
+file_path: str | None = None,
-path_type: Optional[str] = None,
+path_type: str | None = None,
-file_data: Optional[bytes] = None,
+file_data: bytes | None = None,
 ) -> tuple[torch.Tensor, dict, tuple[float, float] | None]:
 """Read all bands from *file_path* and return image + meta info.
@@ -114,8 +114,8 @@ def read_geotiff(
 if all([x is None for x in [file_path, path_type, file_data]]):
 raise Exception("All input fields to read_geotiff are None")
-write_to_file: Optional[bytes] = None
+write_to_file: bytes | None = None
-path: Optional[str] = None
+path: str | None = None
 if file_data is not None:
 # with tempfile.NamedTemporaryFile() as tmpfile:
 # tmpfile.write(file_data)
@@ -162,9 +162,9 @@ def read_geotiff(
 def load_image(
 data: Union[list[str]],
 path_type: str,
-mean: Optional[list[float]] = None,
+mean: list[float] | None = None,
-std: Optional[list[float]] = None,
+std: list[float] | None = None,
-indices: Optional[Union[list[int], None]] = None,
+indices: Union[list[int], None] | None = None,
 ):
 """Build an input example by loading images in *file_paths*.
@@ -278,7 +278,7 @@ class PrithviMultimodalDataProcessor(IOProcessor):
 def pre_process(
 self,
 prompt: IOProcessorInput,
-request_id: Optional[str] = None,
+request_id: str | None = None,
 **kwargs,
 ) -> Union[PromptType, Sequence[PromptType]]:
 image_data = dict(prompt)
@@ -359,7 +359,7 @@ class PrithviMultimodalDataProcessor(IOProcessor):
 def post_process(
 self,
 model_output: Sequence[PoolingRequestOutput],
-request_id: Optional[str] = None,
+request_id: str | None = None,
 **kwargs,
 ) -> IOProcessorOutput:
 pred_imgs_list = []

View File

@@ -3,7 +3,7 @@
 from __future__ import annotations
 import random
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING
 import pytest
@@ -78,7 +78,7 @@ def vllm_model_skip_tokenizer_init(vllm_runner, request, monkeypatch):
 def _get_test_sampling_params(
 prompt_list: list[str],
-seed: Optional[int] = 42,
+seed: int | None = 42,
 structured_outputs: bool = False,
 ) -> tuple[list[SamplingParams], list[int]]:
 """Generate random sampling params for a batch."""

View File

@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from abc import ABC, abstractmethod
-from typing import Generic, List, Optional, Protocol, Tuple, Type, TypeVar
+from typing import Generic, Optional, Protocol, TypeVar
 import torch
@@ -48,12 +48,12 @@ class AttentionBackend(ABC):
 @staticmethod
 @abstractmethod
-def get_impl_cls() -> Type["AttentionImpl"]:
+def get_impl_cls() -> type["AttentionImpl"]:
 raise NotImplementedError
 @staticmethod
 @abstractmethod
-def get_metadata_cls() -> Type["AttentionMetadata"]:
+def get_metadata_cls() -> type["AttentionMetadata"]:
 raise NotImplementedError
 @classmethod
@@ -73,11 +73,11 @@ class AttentionBackend(ABC):
 num_kv_heads: int,
 head_size: int,
 cache_dtype_str: str = "auto",
-) -> Tuple[int, ...]:
+) -> tuple[int, ...]:
 raise NotImplementedError
 @staticmethod
-def get_kv_cache_stride_order() -> Tuple[int, ...]:
+def get_kv_cache_stride_order() -> tuple[int, ...]:
 raise NotImplementedError
 @classmethod
@@ -147,7 +147,7 @@ class AttentionImpl(ABC, Generic[T]):
 head_size: int,
 scale: float,
 num_kv_heads: Optional[int] = None,
-alibi_slopes: Optional[List[float]] = None,
+alibi_slopes: Optional[list[float]] = None,
 sliding_window: Optional[int] = None,
 kv_cache_dtype: str = "auto",
 logits_soft_cap: Optional[float] = None,

View File

@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention layer."""
-from typing import Callable, List, Optional
+from typing import Callable, Optional
 import torch
 import torch.nn as nn
@@ -126,7 +126,7 @@ class Attention(nn.Module, AttentionLayerBase):
 head_size: int,
 scale: float,
 num_kv_heads: Optional[int] = None,
-alibi_slopes: Optional[List[float]] = None,
+alibi_slopes: Optional[list[float]] = None,
 cache_config: Optional[CacheConfig] = None,
 quant_config: Optional[QuantizationConfig] = None,
 logits_soft_cap: Optional[float] = None,
@@ -586,7 +586,7 @@ def wait_for_kv_layer_from_connector(layer_name: str):
 def maybe_save_kv_layer_to_connector(
 layer_name: str,
-kv_cache_layer: List[torch.Tensor],
+kv_cache_layer: list[torch.Tensor],
 ):
 if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
 return

View File

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import functools
-from typing import ClassVar, List, Optional
+from typing import ClassVar, Optional
 import torch
@@ -61,7 +61,7 @@ class ChunkedLocalAttention(Attention):
 scale: float,
 attention_chunk_size: int,
 num_kv_heads: Optional[int] = None,
-alibi_slopes: Optional[List[float]] = None,
+alibi_slopes: Optional[list[float]] = None,
 cache_config: Optional[CacheConfig] = None,
 quant_config: Optional[QuantizationConfig] = None,
 kv_sharing_target_layer_name: Optional[str] = None,

View File

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 # adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py
-from typing import Optional, Tuple
+from typing import Optional
 import torch
@@ -31,7 +31,7 @@ else:
 _flashmla_extension_C_AVAILABLE = False
-def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
+def is_flashmla_supported() -> tuple[bool, Optional[str]]:
 """
 Return: is_supported_flag, unsupported_reason (optional).
 """
@@ -57,7 +57,7 @@ def get_mla_metadata(
 num_heads_q: Optional[int] = None,
 is_fp8_kvcache: bool = False,
 topk: Optional[int] = None,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
 """
 Arguments:
 - cache_seqlens: (batch_size), dtype torch.int32.
@@ -101,7 +101,7 @@ def flash_mla_with_kvcache(
 descale_k: Optional[torch.Tensor] = None,
 is_fp8_kvcache: bool = False,
 indices: Optional[torch.Tensor] = None,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
 """
 Arguments:
 - q: (batch_size, seq_len_q, num_heads_q, head_dim).
@@ -183,7 +183,7 @@ def flash_mla_sparse_prefill(
 indices: torch.Tensor,
 sm_scale: float,
 d_v: int = 512,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 """
 Sparse attention prefill kernel

View File

@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from dataclasses import dataclass
-from typing import List, Optional, Tuple
+from typing import Optional
 import torch
@@ -41,7 +41,7 @@ class PagedAttentionMetadata:
 class PagedAttention:
 @staticmethod
-def get_supported_head_sizes() -> List[int]:
+def get_supported_head_sizes() -> list[int]:
 return [32, 64, 80, 96, 112, 120, 128, 192, 256]
 @staticmethod
@@ -51,7 +51,7 @@ class PagedAttention:
 num_kv_heads: int,
 head_size: int,
 cache_dtype_str: str = "auto",
-) -> Tuple[int, ...]:
+) -> tuple[int, ...]:
 return (2, num_blocks, block_size * num_kv_heads * head_size)
 @staticmethod
@@ -59,7 +59,7 @@ class PagedAttention:
 kv_cache: torch.Tensor,
 num_kv_heads: int,
 head_size: int,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
 x = 16 // kv_cache.element_size()
 num_blocks = kv_cache.shape[1]
@@ -255,7 +255,7 @@ class PagedAttention:
 @staticmethod
 def copy_blocks(
-kv_caches: List[torch.Tensor],
+kv_caches: list[torch.Tensor],
 src_to_dists: torch.Tensor,
 ) -> None:
 key_caches = [kv_cache[0] for kv_cache in kv_caches]

View File

@@ -14,11 +14,8 @@ from typing import (
 Annotated,
 Any,
 Callable,
-Dict,
-List,
 Literal,
 Optional,
-Type,
 TypeVar,
 Union,
 cast,
@@ -325,7 +322,7 @@ class EngineArgs:
 """Arguments for vLLM engine."""
 model: str = ModelConfig.model
-served_model_name: Optional[Union[str, List[str]]] = ModelConfig.served_model_name
+served_model_name: Optional[Union[str, list[str]]] = ModelConfig.served_model_name
 tokenizer: Optional[str] = ModelConfig.tokenizer
 hf_config_path: Optional[str] = ModelConfig.hf_config_path
 runner: RunnerOption = ModelConfig.runner
@@ -350,7 +347,7 @@ class EngineArgs:
 # is intended for expert use only. The API may change without
 # notice.
 distributed_executor_backend: Optional[
-Union[str, DistributedExecutorBackend, Type[ExecutorBase]]
+Union[str, DistributedExecutorBackend, type[ExecutorBase]]
 ] = ParallelConfig.distributed_executor_backend
 # number of P/D disaggregation (or other disaggregation) workers
 pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
@@ -418,7 +415,7 @@ class EngineArgs:
 media_io_kwargs: dict[str, dict[str, Any]] = get_field(
 MultiModalConfig, "media_io_kwargs"
 )
-mm_processor_kwargs: Optional[Dict[str, Any]] = MultiModalConfig.mm_processor_kwargs
+mm_processor_kwargs: Optional[dict[str, Any]] = MultiModalConfig.mm_processor_kwargs
 disable_mm_preprocessor_cache: bool = False # DEPRECATED
 mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb
 mm_processor_cache_type: Optional[MMCacheType] = (
@@ -436,7 +433,7 @@ class EngineArgs:
 enable_lora_bias: bool = LoRAConfig.bias_enabled
 max_loras: int = LoRAConfig.max_loras
 max_lora_rank: int = LoRAConfig.max_lora_rank
-default_mm_loras: Optional[Dict[str, str]] = LoRAConfig.default_mm_loras
+default_mm_loras: Optional[dict[str, str]] = LoRAConfig.default_mm_loras
 fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
 max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
 lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
@@ -446,7 +443,7 @@ class EngineArgs:
 num_gpu_blocks_override: Optional[int] = CacheConfig.num_gpu_blocks_override
 num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
 model_loader_extra_config: dict = get_field(LoadConfig, "model_loader_extra_config")
-ignore_patterns: Optional[Union[str, List[str]]] = LoadConfig.ignore_patterns
+ignore_patterns: Optional[Union[str, list[str]]] = LoadConfig.ignore_patterns
 enable_chunked_prefill: Optional[bool] = SchedulerConfig.enable_chunked_prefill
 disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
@@ -467,7 +464,7 @@ class EngineArgs:
 logits_processor_pattern: Optional[str] = ModelConfig.logits_processor_pattern
-speculative_config: Optional[Dict[str, Any]] = None
+speculative_config: Optional[dict[str, Any]] = None
 show_hidden_metrics_for_version: Optional[str] = (
 ObservabilityConfig.show_hidden_metrics_for_version
@@ -477,7 +474,7 @@ class EngineArgs:
 ObservabilityConfig.collect_detailed_traces
 )
 scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
-scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls
+scheduler_cls: Union[str, type[object]] = SchedulerConfig.scheduler_cls
 pooler_config: Optional[PoolerConfig] = ModelConfig.pooler_config
 override_pooler_config: Optional[Union[dict, PoolerConfig]] = (

View File

@@ -3,7 +3,7 @@
 import time
 from collections import Counter as CollectionsCounter
-from typing import Dict, List, Optional, Type, Union, cast
+from typing import Optional, Union, cast
 import numpy as np
 import prometheus_client
@@ -43,7 +43,7 @@ class Metrics:
 _counter_cls = prometheus_client.Counter
 _histogram_cls = prometheus_client.Histogram
-def __init__(self, labelnames: List[str], vllm_config: VllmConfig):
+def __init__(self, labelnames: list[str], vllm_config: VllmConfig):
 # Unregister any existing vLLM collectors (for CI/CD)
 self._unregister_vllm_metrics()
@@ -304,7 +304,7 @@ class _RayGaugeWrapper:
 self,
 name: str,
 documentation: str = "",
-labelnames: Optional[List[str]] = None,
+labelnames: Optional[list[str]] = None,
 multiprocess_mode: str = "",
 ):
 del multiprocess_mode
@@ -330,7 +330,7 @@ class _RayCounterWrapper:
 prometheus_client.Counter"""
 def __init__(
-self, name: str, documentation: str = "", labelnames: Optional[List[str]] = None
+self, name: str, documentation: str = "", labelnames: Optional[list[str]] = None
 ):
 labelnames_tuple = tuple(labelnames) if labelnames else None
 self._counter = ray_metrics.Counter(
@@ -355,8 +355,8 @@ class _RayHistogramWrapper:
 self,
 name: str,
 documentation: str = "",
-labelnames: Optional[List[str]] = None,
+labelnames: Optional[list[str]] = None,
-buckets: Optional[List[float]] = None,
+buckets: Optional[list[float]] = None,
 ):
 labelnames_tuple = tuple(labelnames) if labelnames else None
 boundaries = buckets if buckets else []
@@ -381,17 +381,17 @@ class RayMetrics(Metrics):
 Provides the same metrics as Metrics but uses Ray's util.metrics library.
 """
-_gauge_cls: Type[prometheus_client.Gauge] = cast(
+_gauge_cls: type[prometheus_client.Gauge] = cast(
-Type[prometheus_client.Gauge], _RayGaugeWrapper
+type[prometheus_client.Gauge], _RayGaugeWrapper
 )
-_counter_cls: Type[prometheus_client.Counter] = cast(
+_counter_cls: type[prometheus_client.Counter] = cast(
-Type[prometheus_client.Counter], _RayCounterWrapper
+type[prometheus_client.Counter], _RayCounterWrapper
 )
-_histogram_cls: Type[prometheus_client.Histogram] = cast(
+_histogram_cls: type[prometheus_client.Histogram] = cast(
-Type[prometheus_client.Histogram], _RayHistogramWrapper
+type[prometheus_client.Histogram], _RayHistogramWrapper
 )
-def __init__(self, labelnames: List[str], vllm_config: VllmConfig):
+def __init__(self, labelnames: list[str], vllm_config: VllmConfig):
 if ray_metrics is None:
 raise ImportError("RayMetrics requires Ray to be installed.")
 super().__init__(labelnames, vllm_config)
@@ -401,14 +401,14 @@ class RayMetrics(Metrics):
 pass
-def build_buckets(mantissa_lst: List[int], max_value: int) -> List[int]:
+def build_buckets(mantissa_lst: list[int], max_value: int) -> list[int]:
 """
 Builds a list of buckets with increasing powers of 10 multiplied by
 mantissa values until the value exceeds the specified maximum.
 """
 exponent = 0
-buckets: List[int] = []
+buckets: list[int] = []
 while True:
 for m in mantissa_lst:
 value = m * 10**exponent
@@ -419,7 +419,7 @@ def build_buckets(mantissa_lst: List[int], max_value: int) -> List[int]:
 exponent += 1
-def build_1_2_5_buckets(max_value: int) -> List[int]:
+def build_1_2_5_buckets(max_value: int) -> list[int]:
 """
 Example:
 >>> build_1_2_5_buckets(100)
@@ -428,7 +428,7 @@ def build_1_2_5_buckets(max_value: int) -> List[int]:
 return build_buckets([1, 2, 5], max_value)
-def build_1_2_3_5_8_buckets(max_value: int) -> List[int]:
+def build_1_2_3_5_8_buckets(max_value: int) -> list[int]:
 """
 Example:
 >>> build_1_2_3_5_8_buckets(100)
@@ -442,7 +442,7 @@ def local_interval_elapsed(now: float, last_log: float, local_interval: float) -
 return elapsed_time > local_interval
-def get_throughput(tracked_stats: List[int], now: float, last_log: float) -> float:
+def get_throughput(tracked_stats: list[int], now: float, last_log: float) -> float:
 return float(np.sum(tracked_stats) / (now - last_log))
@@ -530,7 +530,7 @@ class PrometheusStatLogger(StatLoggerBase):
 _gauge_cls = prometheus_client.Gauge
 def __init__(
-self, local_interval: float, labels: Dict[str, str], vllm_config: VllmConfig
+self, local_interval: float, labels: dict[str, str], vllm_config: VllmConfig
 ) -> None:
 super().__init__(local_interval, vllm_config)
 # Prometheus metrics
@@ -558,12 +558,12 @@ class PrometheusStatLogger(StatLoggerBase):
 for label, count in data.items():
 counter.labels(**{**self.labels, label_key: label}).inc(count)
-def _log_histogram(self, histogram, data: Union[List[int], List[float]]) -> None:
+def _log_histogram(self, histogram, data: Union[list[int], list[float]]) -> None:
 # Convenience function for logging list to histogram.
 for datum in data:
 histogram.labels(**self.labels).observe(datum)
-def _log_gauge_string(self, gauge, data: Dict[str, str]) -> None:
+def _log_gauge_string(self, gauge, data: dict[str, str]) -> None:
 gauge.labels(**data).set_to_current_time()
 def _log_prometheus(self, stats: Stats) -> None:

View File

@@ -16,7 +16,6 @@ do this in Python code and lazily import prometheus_client.
 import time
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import List
 from vllm.config import SupportsMetricsInfo, VllmConfig
@@ -43,26 +42,26 @@ class Stats:
 num_prompt_tokens_iter: int
 num_generation_tokens_iter: int
 num_tokens_iter: int
-time_to_first_tokens_iter: List[float]
+time_to_first_tokens_iter: list[float]
-inter_token_latencies_iter: List[float]
+inter_token_latencies_iter: list[float]
 num_preemption_iter: int
 # Request stats (should have _requests suffix)
 # Latency
-time_e2e_requests: List[float]
+time_e2e_requests: list[float]
-time_queue_requests: List[float]
+time_queue_requests: list[float]
-time_inference_requests: List[float]
+time_inference_requests: list[float]
-time_prefill_requests: List[float]
+time_prefill_requests: list[float]
-time_decode_requests: List[float]
+time_decode_requests: list[float]
 # Metadata
-num_prompt_tokens_requests: List[int]
+num_prompt_tokens_requests: list[int]
-num_generation_tokens_requests: List[int]
+num_generation_tokens_requests: list[int]
-n_requests: List[int]
+n_requests: list[int]
-max_num_generation_tokens_requests: List[int]
+max_num_generation_tokens_requests: list[int]
-max_tokens_requests: List[int]
+max_tokens_requests: list[int]
-finished_reason_requests: List[str]
+finished_reason_requests: list[str]
-waiting_lora_adapters: List[str]
+waiting_lora_adapters: list[str]
-running_lora_adapters: List[str]
+running_lora_adapters: list[str]
 max_lora: str
@@ -71,8 +70,8 @@ class StatLoggerBase(ABC):
 def __init__(self, local_interval: float, vllm_config: VllmConfig) -> None:
 # Tracked stats over current local logging interval.
-self.num_prompt_tokens: List[int] = []
+self.num_prompt_tokens: list[int] = []
-self.num_generation_tokens: List[int] = []
+self.num_generation_tokens: list[int] = []
 self.last_local_log = time.time()
 self.local_interval = local_interval

View File

@@ -6,7 +6,7 @@ from __future__ import annotations
 import datetime
 import json
 from collections.abc import Iterable, Sequence
-from typing import Literal, Optional, Union
+from typing import Literal, Union
 from openai.types.responses import (
 ResponseFunctionToolCall,
@@ -79,13 +79,13 @@ def get_encoding():
 def get_system_message(
-model_identity: Optional[str] = None,
+model_identity: str | None = None,
-reasoning_effort: Optional[Literal["high", "medium", "low"]] = None,
+reasoning_effort: Literal["high", "medium", "low"] | None = None,
-start_date: Optional[str] = None,
+start_date: str | None = None,
-browser_description: Optional[str] = None,
+browser_description: str | None = None,
-python_description: Optional[str] = None,
+python_description: str | None = None,
-container_description: Optional[str] = None,
+container_description: str | None = None,
-instructions: Optional[str] = None,
+instructions: str | None = None,
 with_custom_tools: bool = False,
 ) -> Message:
 sys_msg_content = SystemContent.new()
@@ -137,8 +137,8 @@ def create_tool_definition(tool: Union[ChatCompletionToolsParam, Tool]):
 def get_developer_message(
-instructions: Optional[str] = None,
+instructions: str | None = None,
-tools: Optional[list[Union[Tool, ChatCompletionToolsParam]]] = None,
+tools: list[Union[Tool, ChatCompletionToolsParam]] | None = None,
 ) -> Message:
 dev_msg_content = DeveloperContent.new()
 if instructions is not None and not envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS:
@@ -202,7 +202,7 @@ def parse_response_input(
 msg = msg.with_channel("final")
 elif response_msg["type"] == "function_call_output":
 call_id = response_msg["call_id"]
-call_response: Optional[ResponseFunctionToolCall] = None
+call_response: ResponseFunctionToolCall | None = None
 for prev_response in reversed(prev_responses):
 if (
 isinstance(prev_response, ResponseFunctionToolCall)
@@ -450,7 +450,7 @@ def parse_output_into_messages(token_ids: Iterable[int]) -> StreamableParser:
 def parse_chat_output(
 token_ids: Sequence[int],
-) -> tuple[Optional[str], Optional[str], bool]:
+) -> tuple[str | None, str | None, bool]:
 parser = parse_output_into_messages(token_ids)
 output_msgs = parser.messages
 is_tool_call = False # TODO: update this when tool call is supported

View File

@@ -6,7 +6,7 @@ import time
 from abc import ABC, abstractmethod
 from collections.abc import Awaitable
 from functools import cached_property
-from typing import Any, Callable, List, Optional, Set, Union
+from typing import Any, Callable, Optional, Union
 from typing_extensions import TypeVar
@@ -143,7 +143,7 @@ class ExecutorBase(ABC):
 def execute_model(
 self, execute_model_req: ExecuteModelRequest
-) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
+) -> Optional[list[Union[SamplerOutput, PoolerOutput]]]:
 output = self.collective_rpc("execute_model", args=(execute_model_req,))
 return output[0]
@@ -163,7 +163,7 @@ class ExecutorBase(ABC):
 assert lora_id > 0, "lora_id must be greater than 0."
 return all(self.collective_rpc("pin_lora", args=(lora_id,)))
-def list_loras(self) -> Set[int]:
+def list_loras(self) -> set[int]:
 sets = self.collective_rpc("list_loras")
 for s in sets:
 assert s == sets[0], "All workers should have the same LORAs."
@@ -238,7 +238,7 @@ class ExecutorBase(ABC):
 async def execute_model_async(
 self, execute_model_req: ExecuteModelRequest
-) -> List[SamplerOutput]:
+) -> list[SamplerOutput]:
 """Executes one model step on the given sequences."""
 output = await make_async(self.execute_model)(execute_model_req)
 return output
@@ -272,7 +272,7 @@ class DistributedExecutorBase(ExecutorBase):
 def execute_model(
 self,
 execute_model_req: ExecuteModelRequest,
-) -> List[SamplerOutput]:
+) -> list[SamplerOutput]:
 # TODO: unify into collective_rpc
 if self.parallel_worker_tasks is None:
 self.parallel_worker_tasks = self._run_workers(
@@ -299,7 +299,7 @@ class DistributedExecutorBase(ExecutorBase):
 @abstractmethod
 def _driver_execute_model(
 self, execute_model_req: Optional[ExecuteModelRequest]
-) -> Optional[List[SamplerOutput]]:
+) -> Optional[list[SamplerOutput]]:
 """Run execute_model in the driver worker.
 Passing None will cause the driver to stop the model execution loop
@@ -346,7 +346,7 @@ class DistributedExecutorBase(ExecutorBase):
 async def execute_model_async(
 self, execute_model_req: ExecuteModelRequest
-) -> List[SamplerOutput]:
+) -> list[SamplerOutput]:
 if self.parallel_worker_tasks is None:
 # Start model execution loop running in the parallel workers
 self.parallel_worker_tasks = asyncio.create_task(
@@ -371,7 +371,7 @@ class DistributedExecutorBase(ExecutorBase):
 async def _driver_execute_model_async(
 self,
 execute_model_req: Optional[ExecuteModelRequest] = None,
-) -> List[SamplerOutput]:
+) -> list[SamplerOutput]:
 """Execute the model asynchronously in the driver worker.
 Passing None will cause the driver to stop the model execution

View File

@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from array import array
-from typing import Any, Type
+from typing import Any
 from vllm.multimodal.inputs import MultiModalKwargs
 from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE
@@ -23,7 +23,7 @@ def encode_hook(obj: Any) -> Any:
 return dict(obj)
-def decode_hook(type: Type, obj: Any) -> Any:
+def decode_hook(type: type, obj: Any) -> Any:
 """Custom msgspec dec hook that supports array types and MultiModalKwargs.
 See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder

View File

@@ -5,7 +5,7 @@ import asyncio
 import os
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 import cloudpickle
 import msgspec
@@ -114,10 +114,10 @@ class RayDistributedExecutor(DistributedExecutorBase):
 self._init_workers_ray(placement_group)
 self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
-self.output_decoder = msgspec.msgpack.Decoder(Optional[List[SamplerOutput]])
+self.output_decoder = msgspec.msgpack.Decoder(Optional[list[SamplerOutput]])
 self.use_v1 = envs.VLLM_USE_V1
-self.pp_locks: Optional[List[asyncio.Lock]] = None
+self.pp_locks: Optional[list[asyncio.Lock]] = None
 if not self.use_ray_compiled_dag:
 self.driver_exec_method = make_async(self.driver_worker.execute_method)
@@ -137,7 +137,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
 ray.kill(worker)
 self.forward_dag = None
-def _configure_ray_workers_use_nsight(self, ray_remote_kwargs) -> Dict[str, Any]:
+def _configure_ray_workers_use_nsight(self, ray_remote_kwargs) -> dict[str, Any]:
 # If nsight profiling is enabled, we need to set the profiling
 # configuration for the ray workers as runtime env.
 runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
@@ -164,12 +164,12 @@ class RayDistributedExecutor(DistributedExecutorBase):
 # It holds the resource for the driver worker.
 self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
 # The remaining workers are the actual ray actors.
-self.workers: List[RayWorkerWrapper] = []
+self.workers: list[RayWorkerWrapper] = []
 # Used in ray compiled DAG: indexed first by PP rank,
 # and then TP rank. In other words, the inner list is
 # the TP group of workers for a PP rank.
-self.pp_tp_workers: List[List[RayWorkerWrapper]] = []
+self.pp_tp_workers: list[list[RayWorkerWrapper]] = []
 if self.parallel_config.ray_workers_use_nsight:
 ray_remote_kwargs = self._configure_ray_workers_use_nsight(
@@ -179,7 +179,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
 logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)
 # Create the workers.
-bundle_indices: List[int]
+bundle_indices: list[int]
 if envs.VLLM_RAY_BUNDLE_INDICES:
 # Use the bundle indices specified by the user.
 bundle_indices = list(map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(",")))
@@ -200,7 +200,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
 bundle_indices.append(bundle_id)
 bundle_indices = bundle_indices[: self.parallel_config.world_size]
-worker_metadata: List[RayWorkerMetaData] = []
+worker_metadata: list[RayWorkerMetaData] = []
 driver_ip = get_ip()
 for rank, bundle_id in enumerate(bundle_indices):
 scheduling_strategy = PlacementGroupSchedulingStrategy(
@@ -262,7 +262,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
 "the driver on a GPU node."
 )
-ip_counts: Dict[str, int] = {}
+ip_counts: dict[str, int] = {}
 for ip in worker_ips:
 ip_counts[ip] = ip_counts.get(ip, 0) + 1
@@ -416,11 +416,11 @@ class RayDistributedExecutor(DistributedExecutorBase):
 # This is the list of workers that are rank 0 of each TP group EXCEPT
 # global rank 0. These are the workers that will broadcast to the
 # rest of the workers.
-self.tp_driver_workers: List[RayWorkerWrapper] = []
+self.tp_driver_workers: list[RayWorkerWrapper] = []
 # This is the list of workers that are not drivers and not the first
 # worker in a TP group. These are the workers that will be
 # broadcasted to.
-self.non_driver_workers: List[RayWorkerWrapper] = []
+self.non_driver_workers: list[RayWorkerWrapper] = []
 # Enforce rank order for correct rank to return final output.
 for index, worker in enumerate(self.workers):
@@ -433,7 +433,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
 def _driver_execute_model(
 self, execute_model_req: Optional[ExecuteModelRequest]
-) -> Optional[List[SamplerOutput]]:
+) -> Optional[list[SamplerOutput]]:
 """Run execute_model in the driver worker.
 Passing None will cause the driver to stop the model execution
@@ -446,7 +446,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
 def execute_model(
 self, execute_model_req: ExecuteModelRequest
-) -> List[SamplerOutput]:
+) -> list[SamplerOutput]:
 if not self.use_ray_spmd_worker:
 return super().execute_model(execute_model_req)
@@ -675,7 +675,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
 async def execute_model_async(
 self, execute_model_req: ExecuteModelRequest
-) -> List[SamplerOutput]:
+) -> list[SamplerOutput]:
 if not self.use_ray_spmd_worker:
 return await super().execute_model_async(execute_model_req)
@@ -689,7 +689,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
 async def _driver_execute_model_async(
 self, execute_model_req: Optional[ExecuteModelRequest] = None
-) -> List[SamplerOutput]:
+) -> list[SamplerOutput]:
 assert not self.use_ray_spmd_worker, (
 "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1"
 )

View File

@@ -4,7 +4,7 @@
 import os
 import time
 from collections import defaultdict
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional, Union
 import msgspec
@@ -59,7 +59,7 @@ try:
 def get_node_ip(self) -> str:
 return get_ip()
-def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
+def get_node_and_gpu_ids(self) -> tuple[str, list[int]]:
 node_id = ray.get_runtime_context().get_node_id()
 device_key = vllm.platforms.current_platform.ray_device_key
 if not device_key:
@@ -72,7 +72,7 @@ try:
 def execute_model_spmd(
 self,
-req_or_tuple: Union[bytes, Tuple[bytes, Optional[IntermediateTensors]]],
+req_or_tuple: Union[bytes, tuple[bytes, Optional[IntermediateTensors]]],
 ) -> bytes:
 """Execute model in SPMD fashion: used only when SPMD worker and
 compiled DAG are both enabled.
@@ -126,10 +126,10 @@ try:
 def execute_model_ray(
 self,
 scheduler_output: Union[
-"SchedulerOutput", Tuple["SchedulerOutput", "IntermediateTensors"]
+"SchedulerOutput", tuple["SchedulerOutput", "IntermediateTensors"]
 ],
 ) -> Union[
-"ModelRunnerOutput", Tuple["SchedulerOutput", "IntermediateTensors"]
+"ModelRunnerOutput", tuple["SchedulerOutput", "IntermediateTensors"]
 ]:
 # This method is used by Ray Compiled Graph to execute the model,
 # and it needs a special logic of self.setup_device_if_necessary()
@@ -156,7 +156,7 @@ try:
 output = output.get_output()
 return output
-def override_env_vars(self, vars: Dict[str, str]):
+def override_env_vars(self, vars: dict[str, str]):
 os.environ.update(vars)
 ray_import_err = None
@@ -201,7 +201,7 @@ def _verify_bundles(
 # bundle_idx -> bundle (e.g., {"GPU": 1})
 bundles = pg_data["bundles"]
 # node_id -> List of bundle (e.g., {"GPU": 1})
-node_id_to_bundle: Dict[str, List[Dict[str, float]]] = defaultdict(list)
+node_id_to_bundle: dict[str, list[dict[str, float]]] = defaultdict(list)
 for bundle_idx, node_id in bundle_to_node_ids.items():
 node_id_to_bundle[node_id].append(bundles[bundle_idx])
@@ -383,7 +383,7 @@ def initialize_ray_cluster(
 device_str,
 )
 # Create a new placement group
-placement_group_specs: List[Dict[str, float]] = [
+placement_group_specs: list[dict[str, float]] = [
 {device_str: 1.0} for _ in range(parallel_config.world_size)
 ]

View File

@@ -4,7 +4,7 @@ import os
 from concurrent.futures import Future, ThreadPoolExecutor
 from functools import cached_property
 from multiprocessing import Lock
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Union
 import torch
 import torch.distributed as dist
@@ -68,10 +68,10 @@ class UniProcExecutor(ExecutorBase):
 self,
 method: Union[str, Callable],
 timeout: Optional[float] = None,
-args: Tuple = (),
+args: tuple = (),
-kwargs: Optional[Dict] = None,
+kwargs: Optional[dict] = None,
 non_block: bool = False,
-) -> List[Any]:
+) -> list[Any]:
 if kwargs is None:
 kwargs = {}
 if self.mm_receiver_cache is not None and method == "execute_model":
@@ -158,7 +158,7 @@ class ExecutorWithExternalLauncher(UniProcExecutor):
 local_rank = int(os.environ["LOCAL_RANK"])
 return distributed_init_method, rank, local_rank
-def determine_num_available_blocks(self) -> Tuple[int, int]:
+def determine_num_available_blocks(self) -> tuple[int, int]:
 """
 Determine the number of available KV blocks.
 Add an additional all_reduce to get the min across all ranks.

View File

@@ -1,9 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import (
-List, # noqa: UP035
-Optional,
-)
+from typing import Optional
 import torch
@@ -32,7 +29,7 @@ def flashinfer_fused_moe_blockscale_fp8(
 intermediate_size: int,
 expert_offset: int,
 local_num_experts: int,
-block_shape: List[int], # noqa: UP006
+block_shape: list[int],
 routed_scaling: float = 1.0,
 ) -> torch.Tensor:
 from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe

View File

@@ -289,7 +289,7 @@ class MultiModalFieldElem:
 return (
 (self.modality, self.key) == (other.modality, other.key)
 and data_equal
-and type(self.field) == type(other.field)
+and type(self.field) is type(other.field)
 ) # noqa: E721
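The change above is the E721 fix: the exact-type comparison now uses is rather than ==, which is what lets the E721 entry for vllm/multimodal/inputs.py be dropped from the per-file ignores. A contrived sketch, not from the repository (every name here is invented), of why the identity check is the safer exact-type test:

    class AnyTypeMeta(type):
        # A metaclass whose __eq__ claims equality with everything; == on classes
        # dispatches to hooks like this, so it is not an exact-type comparison.
        def __eq__(cls, other):
            return True

        __hash__ = type.__hash__

    class Weird(metaclass=AnyTypeMeta):
        pass

    print(type(Weird()) == type(1))  # True: the metaclass __eq__ answers the comparison
    print(type(Weird()) is type(1))  # False: identity compares the exact class objects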

View File

@@ -4,7 +4,6 @@
 from __future__ import annotations
 import logging
-from typing import Optional
 from vllm.config import VllmConfig
 from vllm.plugins import load_plugins_by_group
@@ -15,7 +14,7 @@ logger = logging.getLogger(__name__)
 def get_io_processor(
-vllm_config: VllmConfig, plugin_from_init: Optional[str] = None
+vllm_config: VllmConfig, plugin_from_init: str | None = None
 ) -> IOProcessor | None:
 # Input.Output processors are loaded as plugins under the
 # 'vllm.io_processor_plugins' group. Similar to platform

View File

@ -68,7 +68,6 @@ from typing import (
Generic, Generic,
Literal, Literal,
NamedTuple, NamedTuple,
Optional,
TextIO, TextIO,
TypeVar, TypeVar,
Union, Union,
@ -247,9 +246,7 @@ class CacheInfo(NamedTuple):
class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]): class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
def __init__( def __init__(self, capacity: float, getsizeof: Callable[[_V], float] | None = None):
self, capacity: float, getsizeof: Optional[Callable[[_V], float]] = None
):
super().__init__(capacity, getsizeof) super().__init__(capacity, getsizeof)
self.pinned_items = set[_K]() self.pinned_items = set[_K]()
@ -324,15 +321,15 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
self._LRUCache__order[key] = None # type: ignore self._LRUCache__order[key] = None # type: ignore
@overload @overload
def get(self, key: _K, /) -> Optional[_V]: ... def get(self, key: _K, /) -> _V | None: ...
@overload @overload
def get(self, key: _K, /, default: Union[_V, _T]) -> Union[_V, _T]: ... def get(self, key: _K, /, default: Union[_V, _T]) -> Union[_V, _T]: ...
def get( def get(
self, key: _K, /, default: Optional[Union[_V, _T]] = None self, key: _K, /, default: Union[_V, _T] | None = None
) -> Optional[Union[_V, _T]]: ) -> Union[_V, _T] | None:
value: Optional[Union[_V, _T]] value: Union[_V, _T] | None
if key in self: if key in self:
value = self.__getitem__(key, update_info=False) # type: ignore[call-arg] value = self.__getitem__(key, update_info=False) # type: ignore[call-arg]
@ -350,9 +347,9 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
def pop(self, key: _K, default: Union[_V, _T]) -> Union[_V, _T]: ... def pop(self, key: _K, default: Union[_V, _T]) -> Union[_V, _T]: ...
def pop( def pop(
self, key: _K, default: Optional[Union[_V, _T]] = None self, key: _K, default: Union[_V, _T] | None = None
) -> Optional[Union[_V, _T]]: ) -> Union[_V, _T] | None:
value: Optional[Union[_V, _T]] value: Union[_V, _T] | None
if key not in self: if key not in self:
return default return default
@ -379,7 +376,7 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
""" """
self.pinned_items.remove(key) self.pinned_items.remove(key)
def _on_remove(self, key: _K, value: Optional[_V]) -> None: def _on_remove(self, key: _K, value: _V | None) -> None:
pass pass
def remove_oldest(self, *, remove_pinned: bool = False) -> None: def remove_oldest(self, *, remove_pinned: bool = False) -> None:
@ -705,7 +702,7 @@ def in_loop(event_loop: AbstractEventLoop) -> bool:
def make_async( def make_async(
func: Callable[P, T], executor: Optional[concurrent.futures.Executor] = None func: Callable[P, T], executor: concurrent.futures.Executor | None = None
) -> Callable[P, Awaitable[T]]: ) -> Callable[P, Awaitable[T]]:
"""Take a blocking function, and run it on in an executor thread. """Take a blocking function, and run it on in an executor thread.
@ -940,7 +937,7 @@ def _get_open_port() -> int:
return s.getsockname()[1] return s.getsockname()[1]
def find_process_using_port(port: int) -> Optional[psutil.Process]: def find_process_using_port(port: int) -> psutil.Process | None:
# TODO: We can not check for running processes with network # TODO: We can not check for running processes with network
# port on macOS. Therefore, we can not have a full graceful shutdown # port on macOS. Therefore, we can not have a full graceful shutdown
# of vLLM. For now, let's not look for processes in this case. # of vLLM. For now, let's not look for processes in this case.
@ -1025,8 +1022,8 @@ def _generate_random_fp8(
def get_kv_cache_torch_dtype( def get_kv_cache_torch_dtype(
cache_dtype: Optional[Union[str, torch.dtype]], cache_dtype: Union[str, torch.dtype] | None,
model_dtype: Optional[Union[str, torch.dtype]] = None, model_dtype: Union[str, torch.dtype] | None = None,
) -> torch.dtype: ) -> torch.dtype:
if isinstance(cache_dtype, str): if isinstance(cache_dtype, str):
if cache_dtype == "auto": if cache_dtype == "auto":
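For reference, the two spellings swapped in get_kv_cache_torch_dtype name the same type; a quick equivalence check (the alias names below are illustrative, not vLLM code):

    from typing import Optional, Union

    import torch

    OldSpelling = Optional[Union[str, torch.dtype]]
    NewSpelling = Union[str, torch.dtype, None]

    # typing flattens Optional[Union[...]] into Union[..., None], so the two
    # aliases compare equal at runtime.
    assert OldSpelling == NewSpelling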
@ -1053,11 +1050,11 @@ def create_kv_caches_with_random_flash(
num_layers: int, num_layers: int,
num_heads: int, num_heads: int,
head_size: int, head_size: int,
cache_dtype: Optional[Union[str, torch.dtype]], cache_dtype: Union[str, torch.dtype] | None,
model_dtype: Optional[Union[str, torch.dtype]] = None, model_dtype: Union[str, torch.dtype] | None = None,
seed: Optional[int] = None, seed: int | None = None,
device: Optional[str] = "cuda", device: str | None = "cuda",
cache_layout: Optional[str] = "NHD", cache_layout: str | None = "NHD",
) -> tuple[list[torch.Tensor], list[torch.Tensor]]: ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
from vllm.platforms import current_platform from vllm.platforms import current_platform
@ -1095,10 +1092,10 @@ def create_kv_caches_with_random(
num_layers: int, num_layers: int,
num_heads: int, num_heads: int,
head_size: int, head_size: int,
cache_dtype: Optional[Union[str, torch.dtype]], cache_dtype: Union[str, torch.dtype] | None,
model_dtype: Optional[Union[str, torch.dtype]] = None, model_dtype: Union[str, torch.dtype] | None = None,
seed: Optional[int] = None, seed: int | None = None,
device: Optional[str] = "cuda", device: str | None = "cuda",
) -> tuple[list[torch.Tensor], list[torch.Tensor]]: ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
if cache_dtype == "fp8" and head_size % 16: if cache_dtype == "fp8" and head_size % 16:
raise ValueError( raise ValueError(
@ -1156,7 +1153,7 @@ def is_uva_available() -> bool:
class DeviceMemoryProfiler: class DeviceMemoryProfiler:
def __init__(self, device: Optional[torch.types.Device] = None): def __init__(self, device: torch.types.Device | None = None):
self.device = device self.device = device
def current_memory_usage(self) -> float: def current_memory_usage(self) -> float:
@ -1184,7 +1181,7 @@ def make_ndarray_with_pad(
pad: T, pad: T,
dtype: npt.DTypeLike, dtype: npt.DTypeLike,
*, *,
max_len: Optional[int] = None, max_len: int | None = None,
) -> npt.NDArray: ) -> npt.NDArray:
""" """
Make a padded array from 2D inputs. Make a padded array from 2D inputs.
@ -1209,8 +1206,8 @@ def make_tensor_with_pad(
pad: T, pad: T,
dtype: torch.dtype, dtype: torch.dtype,
*, *,
max_len: Optional[int] = None, max_len: int | None = None,
device: Optional[Union[str, torch.device]] = None, device: Union[str, torch.device] | None = None,
pin_memory: bool = False, pin_memory: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
@ -1405,7 +1402,7 @@ def find_nccl_library() -> str:
return so_file return so_file
def find_nccl_include_paths() -> Optional[list[str]]: def find_nccl_include_paths() -> list[str] | None:
""" """
We either use the nccl.h specified by the `VLLM_NCCL_INCLUDE_PATH` We either use the nccl.h specified by the `VLLM_NCCL_INCLUDE_PATH`
environment variable, or we find the library file brought by environment variable, or we find the library file brought by
@ -1525,7 +1522,7 @@ F = TypeVar("F", bound=Callable[..., Any])
def deprecate_args( def deprecate_args(
start_index: int, start_index: int,
is_deprecated: Union[bool, Callable[[], bool]] = True, is_deprecated: Union[bool, Callable[[], bool]] = True,
additional_message: Optional[str] = None, additional_message: str | None = None,
) -> Callable[[F], F]: ) -> Callable[[F], F]:
if not callable(is_deprecated): if not callable(is_deprecated):
is_deprecated = partial(identity, is_deprecated) is_deprecated = partial(identity, is_deprecated)
@ -1565,7 +1562,7 @@ def deprecate_args(
def deprecate_kwargs( def deprecate_kwargs(
*kws: str, *kws: str,
is_deprecated: Union[bool, Callable[[], bool]] = True, is_deprecated: Union[bool, Callable[[], bool]] = True,
additional_message: Optional[str] = None, additional_message: str | None = None,
) -> Callable[[F], F]: ) -> Callable[[F], F]:
deprecated_kws = set(kws) deprecated_kws = set(kws)
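A rough sketch of what a decorator with the deprecate_kwargs signature above can look like; the name and behavior below are simplified illustrations, not vLLM's implementation:

    from __future__ import annotations

    import functools
    import warnings
    from collections.abc import Callable
    from typing import Any, TypeVar

    F = TypeVar("F", bound=Callable[..., Any])


    def warn_deprecated_kwargs(
        *kws: str, additional_message: str | None = None
    ) -> Callable[[F], F]:
        deprecated_kws = set(kws)

        def decorator(func: F) -> F:
            @functools.wraps(func)
            def wrapper(*args: Any, **kwargs: Any) -> Any:
                used = sorted(deprecated_kws & kwargs.keys())
                if used:
                    msg = f"{func.__name__}: deprecated keyword(s) {used}"
                    if additional_message is not None:
                        msg = f"{msg}. {additional_message}"
                    warnings.warn(msg, DeprecationWarning, stacklevel=2)
                return func(*args, **kwargs)

            return wrapper  # type: ignore[return-value]

        return decorator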
@ -1598,7 +1595,7 @@ def deprecate_kwargs(
@lru_cache(maxsize=8) @lru_cache(maxsize=8)
def _cuda_device_count_stateless(cuda_visible_devices: Optional[str] = None) -> int: def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int:
# Note: cuda_visible_devices is not used, but we keep it as an argument for # Note: cuda_visible_devices is not used, but we keep it as an argument for
# LRU Cache purposes. # LRU Cache purposes.
@ -1746,7 +1743,7 @@ class FlexibleArgumentParser(ArgumentParser):
' --json-arg \'{"key4": ["value3", "value4", "value5"]}\'\n' ' --json-arg \'{"key4": ["value3", "value4", "value5"]}\'\n'
" --json-arg.key4+ value3 --json-arg.key4+='value4,value5'\n\n" " --json-arg.key4+ value3 --json-arg.key4+='value4,value5'\n\n"
) )
_search_keyword: Optional[str] = None _search_keyword: str | None = None
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
# Set the default "formatter_class" to SortedHelpFormatter # Set the default "formatter_class" to SortedHelpFormatter
@ -2245,7 +2242,7 @@ def supports_kw(
def get_allowed_kwarg_only_overrides( def get_allowed_kwarg_only_overrides(
callable: Callable[..., object], callable: Callable[..., object],
overrides: Optional[Mapping[str, object]], overrides: Mapping[str, object] | None,
*, *,
requires_kw_only: bool = True, requires_kw_only: bool = True,
allow_var_kwargs: bool = False, allow_var_kwargs: bool = False,
@ -2695,10 +2692,10 @@ vllm_lib = Library("vllm", "FRAGMENT") # noqa
def direct_register_custom_op( def direct_register_custom_op(
op_name: str, op_name: str,
op_func: Callable, op_func: Callable,
mutates_args: Optional[list[str]] = None, mutates_args: list[str] | None = None,
fake_impl: Optional[Callable] = None, fake_impl: Callable | None = None,
target_lib: Optional[Library] = None, target_lib: Library | None = None,
dispatch_key: Optional[str] = None, dispatch_key: str | None = None,
tags: tuple[torch.Tag, ...] = (), tags: tuple[torch.Tag, ...] = (),
): ):
""" """
@ -3016,7 +3013,7 @@ def split_zmq_path(path: str) -> tuple[str, str, str]:
return scheme, host, port return scheme, host, port
def make_zmq_path(scheme: str, host: str, port: Optional[int] = None) -> str: def make_zmq_path(scheme: str, host: str, port: int | None = None) -> str:
"""Make a ZMQ path from its parts. """Make a ZMQ path from its parts.
Args: Args:
@ -3039,9 +3036,9 @@ def make_zmq_socket(
ctx: Union[zmq.asyncio.Context, zmq.Context], # type: ignore[name-defined] ctx: Union[zmq.asyncio.Context, zmq.Context], # type: ignore[name-defined]
path: str, path: str,
socket_type: Any, socket_type: Any,
bind: Optional[bool] = None, bind: bool | None = None,
identity: Optional[bytes] = None, identity: bytes | None = None,
linger: Optional[int] = None, linger: int | None = None,
) -> Union[zmq.Socket, zmq.asyncio.Socket]: # type: ignore[name-defined] ) -> Union[zmq.Socket, zmq.asyncio.Socket]: # type: ignore[name-defined]
"""Make a ZMQ socket with the proper bind/connect semantics.""" """Make a ZMQ socket with the proper bind/connect semantics."""
@ -3098,9 +3095,9 @@ def make_zmq_socket(
def zmq_socket_ctx( def zmq_socket_ctx(
path: str, path: str,
socket_type: Any, socket_type: Any,
bind: Optional[bool] = None, bind: bool | None = None,
linger: int = 0, linger: int = 0,
identity: Optional[bytes] = None, identity: bytes | None = None,
) -> Iterator[zmq.Socket]: ) -> Iterator[zmq.Socket]:
"""Context manager for a ZMQ socket""" """Context manager for a ZMQ socket"""
@ -3163,7 +3160,7 @@ def get_mp_context():
def bind_kv_cache( def bind_kv_cache(
ctx: dict[str, Any], ctx: dict[str, Any],
kv_cache: list[list[torch.Tensor]], # [virtual_engine][layer_index] kv_cache: list[list[torch.Tensor]], # [virtual_engine][layer_index]
shared_kv_cache_layers: Optional[dict[str, str]] = None, shared_kv_cache_layers: dict[str, str] | None = None,
) -> None: ) -> None:
# Bind the kv_cache tensor to Attention modules, similar to # Bind the kv_cache tensor to Attention modules, similar to
# ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)] # ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)]
@ -3379,7 +3376,7 @@ def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None:
@contextlib.contextmanager @contextlib.contextmanager
def cprofile_context(save_file: Optional[str] = None): def cprofile_context(save_file: str | None = None):
"""Run a cprofile """Run a cprofile
Args: Args:
@ -3401,7 +3398,7 @@ def cprofile_context(save_file: Optional[str] = None):
prof.print_stats(sort="cumtime") prof.print_stats(sort="cumtime")
def cprofile(save_file: Optional[str] = None, enabled: bool = True): def cprofile(save_file: str | None = None, enabled: bool = True):
"""Decorator to profile a Python method using cProfile. """Decorator to profile a Python method using cProfile.
Args: Args:
@ -3608,7 +3605,7 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
file.write = write_with_prefix # type: ignore[method-assign] file.write = write_with_prefix # type: ignore[method-assign]
def decorate_logs(process_name: Optional[str] = None) -> None: def decorate_logs(process_name: str | None = None) -> None:
""" """
Adds a process-specific prefix to each line of output written to stdout and Adds a process-specific prefix to each line of output written to stdout and
stderr. stderr.
@ -3631,8 +3628,8 @@ def decorate_logs(process_name: Optional[str] = None) -> None:
def length_from_prompt_token_ids_or_embeds( def length_from_prompt_token_ids_or_embeds(
prompt_token_ids: Optional[list[int]], prompt_token_ids: list[int] | None,
prompt_embeds: Optional[torch.Tensor], prompt_embeds: torch.Tensor | None,
) -> int: ) -> int:
"""Calculate the request length (in number of tokens) give either """Calculate the request length (in number of tokens) give either
prompt_token_ids or prompt_embeds. prompt_token_ids or prompt_embeds.
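length_from_prompt_token_ids_or_embeds expects exactly one of its two inputs; its body is not part of this hunk, so the following is an assumed sketch of that contract:

    from __future__ import annotations

    import torch


    def length_sketch(
        prompt_token_ids: list[int] | None,
        prompt_embeds: torch.Tensor | None,
    ) -> int:
        # Assumed behavior: measure whichever of the two inputs was provided.
        if prompt_token_ids is not None:
            return len(prompt_token_ids)
        assert prompt_embeds is not None
        return prompt_embeds.shape[0]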
View File
@ -10,7 +10,7 @@ from __future__ import annotations
import functools import functools
import importlib import importlib
import os import os
from typing import Any, Callable, NoReturn, Optional from typing import Any, Callable, NoReturn
import torch import torch
@ -325,7 +325,7 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor):
def should_use_deepgemm_for_fp8_linear( def should_use_deepgemm_for_fp8_linear(
output_dtype: torch.dtype, output_dtype: torch.dtype,
weight: torch.Tensor, weight: torch.Tensor,
supports_deep_gemm: Optional[bool] = None, supports_deep_gemm: bool | None = None,
): ):
if supports_deep_gemm is None: if supports_deep_gemm is None:
supports_deep_gemm = is_deep_gemm_supported() supports_deep_gemm = is_deep_gemm_supported()
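The visible body already shows the idiom the new `bool | None` default supports: None means "probe for support here". A stand-alone sketch of the same shape, with a stand-in for the real capability check:

    from __future__ import annotations


    def _probe_deep_gemm() -> bool:
        # Stand-in for a real capability check such as is_deep_gemm_supported().
        return False


    def use_deep_gemm_sketch(supports_deep_gemm: bool | None = None) -> bool:
        # None means "the caller did not decide"; probe inside the function,
        # mirroring the branch visible in should_use_deepgemm_for_fp8_linear.
        if supports_deep_gemm is None:
            supports_deep_gemm = _probe_deep_gemm()
        return supports_deep_gemm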
View File
@ -12,7 +12,7 @@ import functools
import importlib import importlib
import importlib.util import importlib.util
import os import os
from typing import Any, Callable, NoReturn, Optional from typing import Any, Callable, NoReturn
import requests import requests
import torch import torch
@ -202,14 +202,14 @@ def supports_trtllm_attention() -> bool:
@functools.cache @functools.cache
def _force_use_trtllm_attention(env_value: Optional[bool]) -> Optional[bool]: def _force_use_trtllm_attention(env_value: bool | None) -> bool | None:
"""Cache the env value for VLLM_USE_TRTLLM_ATTENTION""" """Cache the env value for VLLM_USE_TRTLLM_ATTENTION"""
if env_value is not None: if env_value is not None:
logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value) logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
return env_value return env_value
def force_use_trtllm_attention() -> Optional[bool]: def force_use_trtllm_attention() -> bool | None:
""" """
Return ``None`` if VLLM_USE_TRTLLM_ATTENTION is not set, Return ``None`` if VLLM_USE_TRTLLM_ATTENTION is not set,
return ``True`` if TRTLLM attention is forced to be used, return ``True`` if TRTLLM attention is forced to be used,
@ -401,7 +401,7 @@ def flashinfer_scaled_fp8_mm(
scale_a: torch.Tensor, scale_a: torch.Tensor,
scale_b: torch.Tensor, scale_b: torch.Tensor,
out_dtype: torch.dtype, out_dtype: torch.dtype,
bias: Optional[torch.Tensor] = None, bias: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert a.ndim == 2 and b.ndim == 2 assert a.ndim == 2 and b.ndim == 2
assert a.shape[1] == b.shape[0] assert a.shape[1] == b.shape[0]
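The _force_use_trtllm_attention / force_use_trtllm_attention pair above caches a tri-state flag: None means unset, True/False means forced on or off. A self-contained sketch of that pattern, with an illustrative environment variable name and simplified parsing (not vLLM's actual logic):

    from __future__ import annotations

    import functools
    import os


    @functools.cache
    def _cache_flag(env_value: bool | None) -> bool | None:
        # The parsed value is passed in so it becomes part of the cache key.
        return env_value


    def force_flag_sketch() -> bool | None:
        raw = os.environ.get("EXAMPLE_FORCE_FLAG")  # illustrative variable name
        return _cache_flag(None if raw is None else raw == "1")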
View File
@ -5,7 +5,7 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import ClassVar, Optional, Union from typing import ClassVar, Union
import numpy as np import numpy as np
import torch import torch
@ -254,12 +254,12 @@ class FlashInferMetadata:
# For cascade attention (CPU for planning). # For cascade attention (CPU for planning).
use_cascade: bool use_cascade: bool
prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None prefill_wrapper: BatchPrefillWithPagedKVCacheWrapper | None = None
decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None decode_wrapper: BatchDecodeWithPagedKVCacheWrapper | None = None
cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None cascade_wrapper: MultiLevelCascadeAttentionWrapper | None = None
qo_indptr_gpu: Optional[torch.Tensor] = None qo_indptr_gpu: torch.Tensor | None = None
paged_kv_indptr_gpu: Optional[torch.Tensor] = None paged_kv_indptr_gpu: torch.Tensor | None = None
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
@ -727,13 +727,13 @@ class FlashInferImpl(AttentionImpl):
head_size: int, head_size: int,
scale: float, scale: float,
num_kv_heads: int, num_kv_heads: int,
alibi_slopes: Optional[list[float]], alibi_slopes: list[float] | None,
sliding_window: Optional[int], sliding_window: int | None,
kv_cache_dtype: str, kv_cache_dtype: str,
logits_soft_cap: Optional[float] = None, logits_soft_cap: float | None = None,
attn_type: AttentionType = AttentionType.DECODER, attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[int] = None, kv_sharing_target_layer_name: int | None = None,
sinks: Optional[torch.Tensor] = None, sinks: torch.Tensor | None = None,
) -> None: ) -> None:
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
@ -763,7 +763,7 @@ class FlashInferImpl(AttentionImpl):
"FlashInferImpl" "FlashInferImpl"
) )
self.sinks: Optional[torch.Tensor] = None self.sinks: torch.Tensor | None = None
if sinks is not None: if sinks is not None:
if sinks.shape[0] != num_heads: if sinks.shape[0] != num_heads:
raise ValueError( raise ValueError(
@ -776,9 +776,9 @@ class FlashInferImpl(AttentionImpl):
self.support_trtllm_attn = ( self.support_trtllm_attn = (
supports_trtllm_attention() and num_heads % num_kv_heads == 0 supports_trtllm_attention() and num_heads % num_kv_heads == 0
) )
self.bmm1_scale: Optional[float] = None self.bmm1_scale: float | None = None
self.bmm2_scale: Optional[float] = None self.bmm2_scale: float | None = None
self.o_sf_scale: Optional[float] = None self.o_sf_scale: float | None = None
def fused_output_quant_supported(self, quant_key: QuantKey): def fused_output_quant_supported(self, quant_key: QuantKey):
return ( return (
@ -795,9 +795,9 @@ class FlashInferImpl(AttentionImpl):
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: FlashInferMetadata, attn_metadata: FlashInferMetadata,
output: Optional[torch.Tensor] = None, output: torch.Tensor | None = None,
output_scale: Optional[torch.Tensor] = None, output_scale: torch.Tensor | None = None,
output_block_scale: Optional[torch.Tensor] = None, output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashInfer. """Forward pass with FlashInfer.
@ -1093,13 +1093,13 @@ def fast_plan_decode(
page_size: int, page_size: int,
pos_encoding_mode: str = "NONE", pos_encoding_mode: str = "NONE",
window_left: int = -1, window_left: int = -1,
logits_soft_cap: Optional[float] = None, logits_soft_cap: float | None = None,
q_data_type: Optional[Union[str, torch.dtype]] = "float16", q_data_type: Union[str, torch.dtype] | None = "float16",
kv_data_type: Optional[Union[str, torch.dtype]] = None, kv_data_type: Union[str, torch.dtype] | None = None,
data_type: Optional[Union[str, torch.dtype]] = None, data_type: Union[str, torch.dtype] | None = None,
sm_scale: Optional[float] = None, sm_scale: float | None = None,
rope_scale: Optional[float] = None, rope_scale: float | None = None,
rope_theta: Optional[float] = None, rope_theta: float | None = None,
non_blocking: bool = True, non_blocking: bool = True,
) -> None: ) -> None:
""" """
View File
@ -4,7 +4,7 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING
from vllm._bc_linter import bc_linter_include from vllm._bc_linter import bc_linter_include
@ -25,14 +25,14 @@ if TYPE_CHECKING:
@dataclass @dataclass
class NewRequestData: class NewRequestData:
req_id: str req_id: str
prompt_token_ids: Optional[list[int]] prompt_token_ids: list[int] | None
mm_features: list[MultiModalFeatureSpec] mm_features: list[MultiModalFeatureSpec]
sampling_params: Optional[SamplingParams] sampling_params: SamplingParams | None
pooling_params: Optional[PoolingParams] pooling_params: PoolingParams | None
block_ids: tuple[list[int], ...] block_ids: tuple[list[int], ...]
num_computed_tokens: int num_computed_tokens: int
lora_request: Optional[LoRARequest] lora_request: LoRARequest | None
prompt_embeds: Optional[torch.Tensor] = None prompt_embeds: torch.Tensor | None = None
@classmethod @classmethod
def from_request( def from_request(
@ -98,7 +98,7 @@ class CachedRequestData:
# NOTE(woosuk): new_token_ids is only used for pipeline parallelism. # NOTE(woosuk): new_token_ids is only used for pipeline parallelism.
# When PP is not used, new_token_ids will be empty. # When PP is not used, new_token_ids will be empty.
new_token_ids: list[list[int]] new_token_ids: list[list[int]]
new_block_ids: list[Optional[tuple[list[int], ...]]] new_block_ids: list[tuple[list[int], ...] | None]
num_computed_tokens: list[int] num_computed_tokens: list[int]
num_output_tokens: list[int] num_output_tokens: list[int]
@ -160,7 +160,7 @@ class SchedulerOutput:
# for filling the next token bitmask # for filling the next token bitmask
structured_output_request_ids: dict[str, int] structured_output_request_ids: dict[str, int]
# the bitmask for the whole batch # the bitmask for the whole batch
grammar_bitmask: Optional[npt.NDArray[np.int32]] grammar_bitmask: npt.NDArray[np.int32] | None
# KV Cache Connector metadata. # KV Cache Connector metadata.
kv_connector_metadata: Optional[KVConnectorMetadata] = None kv_connector_metadata: KVConnectorMetadata | None = None
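These dataclass hunks are again spelling-only changes. A minimal dataclass sketch in the new style, borrowing a few field names from NewRequestData above (this is only a shape sketch, not the real class):

    from __future__ import annotations

    from dataclasses import dataclass

    import torch


    @dataclass
    class NewRequestSketch:
        req_id: str
        prompt_token_ids: list[int] | None
        num_computed_tokens: int
        prompt_embeds: torch.Tensor | None = None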
View File
@ -7,7 +7,7 @@ import itertools
import time import time
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterable from collections.abc import Iterable
from typing import Any, Optional, Union from typing import Any, Union
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
@ -64,7 +64,7 @@ class Scheduler(SchedulerInterface):
# request ids should be included in the EngineCoreOutputs returned # request ids should be included in the EngineCoreOutputs returned
# by update_from_outputs(). This is currently used in the multi-engine # by update_from_outputs(). This is currently used in the multi-engine
# case to track request lifetimes efficiently. # case to track request lifetimes efficiently.
self.finished_req_ids_dict: Optional[dict[int, set[str]]] = ( self.finished_req_ids_dict: dict[int, set[str]] | None = (
defaultdict(set) if include_finished_set else None defaultdict(set) if include_finished_set else None
) )
@ -708,7 +708,7 @@ class Scheduler(SchedulerInterface):
) -> CachedRequestData: ) -> CachedRequestData:
req_ids: list[str] = [] req_ids: list[str] = []
new_token_ids: list[list[int]] = [] new_token_ids: list[list[int]] = []
new_block_ids: list[Optional[tuple[list[int], ...]]] = [] new_block_ids: list[tuple[list[int], ...] | None] = []
num_computed_tokens: list[int] = [] num_computed_tokens: list[int] = []
num_output_tokens: list[int] = [] num_output_tokens: list[int] = []
@ -921,7 +921,7 @@ class Scheduler(SchedulerInterface):
kv_connector_output = model_runner_output.kv_connector_output kv_connector_output = model_runner_output.kv_connector_output
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
spec_decoding_stats: Optional[SpecDecodingStats] = None spec_decoding_stats: SpecDecodingStats | None = None
kv_connector_stats = ( kv_connector_stats = (
kv_connector_output.kv_connector_stats if kv_connector_output else None kv_connector_output.kv_connector_stats if kv_connector_output else None
) )
@ -1212,7 +1212,7 @@ class Scheduler(SchedulerInterface):
request.status = finished_status request.status = finished_status
self._free_request(request) self._free_request(request)
def _free_request(self, request: Request) -> Optional[dict[str, Any]]: def _free_request(self, request: Request) -> dict[str, Any] | None:
assert request.is_finished() assert request.is_finished()
delay_free_blocks, kv_xfer_params = self._connector_finished(request) delay_free_blocks, kv_xfer_params = self._connector_finished(request)
@ -1243,9 +1243,9 @@ class Scheduler(SchedulerInterface):
def make_stats( def make_stats(
self, self,
spec_decoding_stats: Optional[SpecDecodingStats] = None, spec_decoding_stats: SpecDecodingStats | None = None,
kv_connector_stats: Optional[KVConnectorStats] = None, kv_connector_stats: KVConnectorStats | None = None,
) -> Optional[SchedulerStats]: ) -> SchedulerStats | None:
if not self.log_stats: if not self.log_stats:
return None return None
prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats() prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats()
@ -1262,10 +1262,10 @@ class Scheduler(SchedulerInterface):
def make_spec_decoding_stats( def make_spec_decoding_stats(
self, self,
spec_decoding_stats: Optional[SpecDecodingStats], spec_decoding_stats: SpecDecodingStats | None,
num_draft_tokens: int, num_draft_tokens: int,
num_accepted_tokens: int, num_accepted_tokens: int,
) -> Optional[SpecDecodingStats]: ) -> SpecDecodingStats | None:
if not self.log_stats: if not self.log_stats:
return None return None
if spec_decoding_stats is None: if spec_decoding_stats is None:
@ -1285,12 +1285,12 @@ class Scheduler(SchedulerInterface):
# KV Connector Related Methods # KV Connector Related Methods
######################################################################## ########################################################################
def get_kv_connector(self) -> Optional[KVConnectorBase_V1]: def get_kv_connector(self) -> KVConnectorBase_V1 | None:
return self.connector return self.connector
def _connector_finished( def _connector_finished(
self, request: Request self, request: Request
) -> tuple[bool, Optional[dict[str, Any]]]: ) -> tuple[bool, dict[str, Any] | None]:
""" """
Invoke the KV connector request_finished() method if applicable. Invoke the KV connector request_finished() method if applicable.
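make_spec_decoding_stats above shows the accumulate-or-create idiom behind its `SpecDecodingStats | None` parameter and return type. A sketch with a stand-in stats class; only the `if ... is None` branch is taken from the hunk, the rest is assumed:

    from __future__ import annotations

    from dataclasses import dataclass


    @dataclass
    class SpecStatsSketch:
        # Stand-in for SpecDecodingStats; the None handling is the point here.
        num_draft_tokens: int = 0
        num_accepted_tokens: int = 0


    def update_spec_stats(
        spec_decoding_stats: SpecStatsSketch | None,
        num_draft_tokens: int,
        num_accepted_tokens: int,
    ) -> SpecStatsSketch | None:
        if spec_decoding_stats is None:
            spec_decoding_stats = SpecStatsSketch()
        spec_decoding_stats.num_draft_tokens += num_draft_tokens
        spec_decoding_stats.num_accepted_tokens += num_accepted_tokens
        return spec_decoding_stats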
View File
@ -4,7 +4,7 @@ from __future__ import annotations
import multiprocessing import multiprocessing
from concurrent.futures import Future, ThreadPoolExecutor from concurrent.futures import Future, ThreadPoolExecutor
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
@ -35,11 +35,11 @@ class StructuredOutputManager:
"""Engine-level manager for structured output requests.""" """Engine-level manager for structured output requests."""
def __init__(self, vllm_config: VllmConfig): def __init__(self, vllm_config: VllmConfig):
self.backend: Optional[StructuredOutputBackend] = None self.backend: StructuredOutputBackend | None = None
self.reasoner: Optional[ReasoningParser] = None self.reasoner: ReasoningParser | None = None
self.vllm_config = vllm_config self.vllm_config = vllm_config
self._grammar_bitmask: Optional[torch.Tensor] = None self._grammar_bitmask: torch.Tensor | None = None
self._full_mask = torch.tensor(-1, dtype=torch.int32) self._full_mask = torch.tensor(-1, dtype=torch.int32)
max_batch_size = self.vllm_config.scheduler_config.max_num_seqs max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
@ -168,7 +168,7 @@ class StructuredOutputManager:
requests: dict[str, Request], requests: dict[str, Request],
structured_output_request_ids: dict[str, int], structured_output_request_ids: dict[str, int],
scheduled_spec_decode_tokens: dict[str, list[int]], scheduled_spec_decode_tokens: dict[str, list[int]],
) -> Optional[npt.NDArray[np.int32]]: ) -> npt.NDArray[np.int32] | None:
# Prepare the structured output bitmask for this batch. # Prepare the structured output bitmask for this batch.
if not structured_output_request_ids: if not structured_output_request_ids:
return None return None
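The manager above keeps `_grammar_bitmask` as None until a structured-output request actually needs it, and returns None when the batch has none. A rough sketch of that lazy-allocation shape; the sizes and the 32-entries-per-int32 packing below are assumptions, not taken from the hunk:

    from __future__ import annotations

    import torch


    class BitmaskHolderSketch:
        def __init__(self, max_batch_size: int, vocab_size: int) -> None:
            self.max_batch_size = max_batch_size
            self.vocab_size = vocab_size
            self._grammar_bitmask: torch.Tensor | None = None

        def get_bitmask(self, request_ids: dict[str, int]) -> torch.Tensor | None:
            if not request_ids:
                return None
            if self._grammar_bitmask is None:
                # Assumed packing: one int32 word covers 32 vocab entries.
                words = (self.vocab_size + 31) // 32
                self._grammar_bitmask = torch.full(
                    (self.max_batch_size, words), -1, dtype=torch.int32
                )
            return self._grammar_bitmask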
View File
@ -7,7 +7,7 @@ import copy
import json import json
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, Union from typing import TYPE_CHECKING, Any, Union
import torch import torch
@ -252,7 +252,7 @@ def serialize_guidance_grammar(
def validate_guidance_grammar( def validate_guidance_grammar(
sampling_params: SamplingParams, tokenizer: Optional[llguidance.LLTokenizer] = None sampling_params: SamplingParams, tokenizer: llguidance.LLTokenizer | None = None
) -> None: ) -> None:
tp, grm = get_structured_output_key(sampling_params) tp, grm = get_structured_output_key(sampling_params)
guidance_grm = serialize_guidance_grammar(tp, grm) guidance_grm = serialize_guidance_grammar(tp, grm)
View File
@ -20,10 +20,10 @@ from vllm.v1.structured_output.backend_types import (
@dataclasses.dataclass @dataclasses.dataclass
class StructuredOutputRequest: class StructuredOutputRequest:
sampling_params: SamplingParams sampling_params: SamplingParams
_grammar: Optional[
    Union[Future[StructuredOutputGrammar], StructuredOutputGrammar]
] = None
_grammar: Union[Future[StructuredOutputGrammar], StructuredOutputGrammar] | None = (
    None
)
reasoning_ended: Optional[bool] = None reasoning_ended: bool | None = None
def _check_grammar_completion(self) -> bool: def _check_grammar_completion(self) -> bool:
# NOTE: We have to lazy import to gate circular imports # NOTE: We have to lazy import to gate circular imports
@ -43,7 +43,7 @@ class StructuredOutputRequest:
return self._check_grammar_completion() return self._check_grammar_completion()
@property @property
def grammar(self) -> Optional[StructuredOutputGrammar]: def grammar(self) -> StructuredOutputGrammar | None:
completed = self._check_grammar_completion() completed = self._check_grammar_completion()
return ( return (
cast(Optional[StructuredOutputGrammar], self._grammar) cast(Optional[StructuredOutputGrammar], self._grammar)
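The cast(Optional[...], ...) context line above is a runtime expression rather than an annotation: a bare `X | None` expression needs Python 3.10+, while the same spelling inside annotations can be deferred with `from __future__ import annotations`. A version-neutral sketch (not vLLM code) that keeps the new spelling by passing it as a string, which type checkers generally accept:

    from __future__ import annotations

    from typing import cast


    def first_or_none(items: list[str]) -> str | None:
        value = items[0] if items else None
        # cast() takes a runtime expression, so the string form keeps the
        # PEP 604 spelling without evaluating `str | None` at runtime.
        return cast("str | None", value)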
View File
@ -4,7 +4,7 @@
from __future__ import annotations from __future__ import annotations
import os import os
from typing import Any, Callable, Optional, TypeVar, Union from typing import Any, Callable, TypeVar, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -78,8 +78,8 @@ class WorkerBase:
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
# Device and model state # Device and model state
self.device: Optional[torch.device] = None self.device: torch.device | None = None
self.model_runner: Optional[nn.Module] = None self.model_runner: nn.Module | None = None
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
"""Get specifications for KV cache implementation.""" """Get specifications for KV cache implementation."""
@ -115,8 +115,8 @@ class WorkerBase:
raise NotImplementedError raise NotImplementedError
def execute_model( def execute_model(
self, execute_model_req: Optional[ExecuteModelRequest] = None self, execute_model_req: ExecuteModelRequest | None = None
) -> Optional[list[SamplerOutput]]: ) -> list[SamplerOutput] | None:
raise NotImplementedError raise NotImplementedError
def start_worker_execution_loop(self) -> None: def start_worker_execution_loop(self) -> None:
@ -198,8 +198,8 @@ class WorkerWrapperBase:
group. group.
""" """
self.rpc_rank = rpc_rank self.rpc_rank = rpc_rank
self.worker: Optional[WorkerBase] = None self.worker: WorkerBase | None = None
self.vllm_config: Optional[VllmConfig] = None self.vllm_config: VllmConfig | None = None
# do not store this `vllm_config`, `init_worker` will set the final # do not store this `vllm_config`, `init_worker` will set the final
# one. TODO: investigate if we can remove this field in # one. TODO: investigate if we can remove this field in
# `WorkerWrapperBase`, `init_cached_hf_modules` should be # `WorkerWrapperBase`, `init_cached_hf_modules` should be