From 1c0c68202cc128e740223e033273caa949c45f15 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Sun, 5 Oct 2025 17:37:55 +0100 Subject: [PATCH] Fix per file ruff ignores related to typing (#26254) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- pyproject.toml | 18 +--- tests/compile/test_full_graph.py | 4 +- tests/entrypoints/openai/test_serving_chat.py | 6 +- .../prithvi_io_processor/prithvi_processor.py | 22 ++--- tests/v1/engine/test_llm_engine.py | 4 +- vllm/attention/backends/abstract.py | 12 +-- vllm/attention/layer.py | 6 +- .../layers/chunked_local_attention.py | 4 +- vllm/attention/ops/flashmla.py | 10 +- vllm/attention/ops/paged_attn.py | 10 +- vllm/engine/arg_utils.py | 17 ++-- vllm/engine/metrics.py | 42 ++++---- vllm/engine/metrics_types.py | 35 ++++--- vllm/entrypoints/harmony_utils.py | 24 ++--- vllm/executor/executor_base.py | 16 +-- vllm/executor/msgspec_utils.py | 4 +- vllm/executor/ray_distributed_executor.py | 30 +++--- vllm/executor/ray_utils.py | 16 +-- vllm/executor/uniproc_executor.py | 10 +- .../layers/fused_moe/flashinfer_trtllm_moe.py | 7 +- vllm/multimodal/inputs.py | 2 +- vllm/plugins/io_processors/__init__.py | 3 +- vllm/utils/__init__.py | 99 +++++++++---------- vllm/utils/deep_gemm.py | 4 +- vllm/utils/flashinfer.py | 8 +- vllm/v1/attention/backends/flashinfer.py | 50 +++++----- vllm/v1/core/sched/output.py | 18 ++-- vllm/v1/core/sched/scheduler.py | 24 ++--- vllm/v1/structured_output/__init__.py | 10 +- vllm/v1/structured_output/backend_guidance.py | 4 +- vllm/v1/structured_output/request.py | 10 +- vllm/v1/worker/worker_base.py | 14 +-- 32 files changed, 258 insertions(+), 285 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2b416d3206..34846a4f88 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,6 +115,7 @@ include = ["vllm*"] "vllm/distributed/parallel_state.py" = ["SIM108"] "vllm/entrypoints/chat_utils.py" = ["SIM108"] "vllm/entrypoints/llm.py" = ["SIM108"] +"vllm/executor/ray_distributed_executor.py" = ["SIM108", "SIM112"] "vllm/model_executor/layers/batch_invariant.py" = ["SIM108"] "vllm/model_executor/layers/fla/ops/chunk_o.py" = ["SIM108"] "vllm/model_executor/layers/fused_moe/fused_moe.py" = ["SIM108"] @@ -134,23 +135,6 @@ include = ["vllm*"] "tools/profiler/print_layerwise_table.py" = ["SIM118"] ## Loop variable binding issues "tests/kernels/mamba/test_mamba_ssm_ssd.py" = ["B023"] -## Type annotation modernization and other rules -"vllm/attention/backends/abstract.py" = ["UP035", "UP006"] -"vllm/attention/layer.py" = ["UP035", "UP006"] -"vllm/attention/layers/chunked_local_attention.py" = ["UP035", "UP006"] -"vllm/attention/ops/flashmla.py" = ["UP035", "UP006"] -"vllm/attention/ops/paged_attn.py" = ["UP035", "UP006"] -"vllm/engine/arg_utils.py" = ["UP035", "UP006"] -"vllm/engine/metrics.py" = ["UP035", "UP006"] -"vllm/engine/metrics_types.py" = ["UP035", "UP006"] -"vllm/executor/executor_base.py" = ["UP035", "UP006"] -"vllm/executor/msgspec_utils.py" = ["UP035", "UP006"] -"vllm/executor/ray_distributed_executor.py" = ["UP035", "UP006", "SIM108", "SIM112"] -"vllm/executor/ray_utils.py" = ["UP035", "UP006"] -"vllm/executor/uniproc_executor.py" = ["UP035", "UP006"] -"vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py" = ["UP035"] -## Type comparison issues -"vllm/multimodal/inputs.py" = ["E721"] # End of temporary ignores [tool.ruff.lint] diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 319b31d3a9..8ccae4cfb9 
100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -5,7 +5,7 @@ from __future__ import annotations import logging import tempfile -from typing import Any, Optional, Union +from typing import Any, Union import pytest import torch @@ -21,7 +21,7 @@ from vllm.utils import is_torch_equal_or_newer from ..utils import create_new_process_for_each_test -def models_list(*, all: bool = True, keywords: Optional[list[str]] = None): +def models_list(*, all: bool = True, keywords: list[str] | None = None): TEST_MODELS: list[tuple[str, dict[str, Any]]] = [ ("facebook/opt-125m", {}), ( diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index a52b9a436f..abe5a5f4ff 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -6,7 +6,7 @@ from __future__ import annotations import asyncio from contextlib import suppress from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from unittest.mock import AsyncMock, MagicMock import pytest @@ -233,9 +233,9 @@ class MockModelConfig: multimodal_config = MultiModalConfig() hf_config = MockHFConfig() logits_processor_pattern = None - diff_sampling_param: Optional[dict] = None + diff_sampling_param: dict | None = None allowed_local_media_path: str = "" - allowed_media_domains: Optional[list[str]] = None + allowed_media_domains: list[str] | None = None encoder_config = None generation_config: str = "auto" media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) diff --git a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py index 1d8a7d2040..a2a8d0ec9a 100644 --- a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py +++ b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py @@ -9,7 +9,7 @@ import os import tempfile import urllib.request from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, Union import albumentations import numpy as np @@ -98,9 +98,9 @@ def _convert_np_uint8(float_image: torch.Tensor): def read_geotiff( - file_path: Optional[str] = None, - path_type: Optional[str] = None, - file_data: Optional[bytes] = None, + file_path: str | None = None, + path_type: str | None = None, + file_data: bytes | None = None, ) -> tuple[torch.Tensor, dict, tuple[float, float] | None]: """Read all bands from *file_path* and return image + meta info. @@ -114,8 +114,8 @@ def read_geotiff( if all([x is None for x in [file_path, path_type, file_data]]): raise Exception("All input fields to read_geotiff are None") - write_to_file: Optional[bytes] = None - path: Optional[str] = None + write_to_file: bytes | None = None + path: str | None = None if file_data is not None: # with tempfile.NamedTemporaryFile() as tmpfile: # tmpfile.write(file_data) @@ -162,9 +162,9 @@ def read_geotiff( def load_image( data: Union[list[str]], path_type: str, - mean: Optional[list[float]] = None, - std: Optional[list[float]] = None, - indices: Optional[Union[list[int], None]] = None, + mean: list[float] | None = None, + std: list[float] | None = None, + indices: Union[list[int], None] | None = None, ): """Build an input example by loading images in *file_paths*. 
@@ -278,7 +278,7 @@ class PrithviMultimodalDataProcessor(IOProcessor): def pre_process( self, prompt: IOProcessorInput, - request_id: Optional[str] = None, + request_id: str | None = None, **kwargs, ) -> Union[PromptType, Sequence[PromptType]]: image_data = dict(prompt) @@ -359,7 +359,7 @@ class PrithviMultimodalDataProcessor(IOProcessor): def post_process( self, model_output: Sequence[PoolingRequestOutput], - request_id: Optional[str] = None, + request_id: str | None = None, **kwargs, ) -> IOProcessorOutput: pred_imgs_list = [] diff --git a/tests/v1/engine/test_llm_engine.py b/tests/v1/engine/test_llm_engine.py index 820c270928..a19ba56213 100644 --- a/tests/v1/engine/test_llm_engine.py +++ b/tests/v1/engine/test_llm_engine.py @@ -3,7 +3,7 @@ from __future__ import annotations import random -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import pytest @@ -78,7 +78,7 @@ def vllm_model_skip_tokenizer_init(vllm_runner, request, monkeypatch): def _get_test_sampling_params( prompt_list: list[str], - seed: Optional[int] = 42, + seed: int | None = 42, structured_outputs: bool = False, ) -> tuple[list[SamplingParams], list[int]]: """Generate random sampling params for a batch.""" diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index d746c3295c..bb2f362711 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from typing import Generic, List, Optional, Protocol, Tuple, Type, TypeVar +from typing import Generic, Optional, Protocol, TypeVar import torch @@ -48,12 +48,12 @@ class AttentionBackend(ABC): @staticmethod @abstractmethod - def get_impl_cls() -> Type["AttentionImpl"]: + def get_impl_cls() -> type["AttentionImpl"]: raise NotImplementedError @staticmethod @abstractmethod - def get_metadata_cls() -> Type["AttentionMetadata"]: + def get_metadata_cls() -> type["AttentionMetadata"]: raise NotImplementedError @classmethod @@ -73,11 +73,11 @@ class AttentionBackend(ABC): num_kv_heads: int, head_size: int, cache_dtype_str: str = "auto", - ) -> Tuple[int, ...]: + ) -> tuple[int, ...]: raise NotImplementedError @staticmethod - def get_kv_cache_stride_order() -> Tuple[int, ...]: + def get_kv_cache_stride_order() -> tuple[int, ...]: raise NotImplementedError @classmethod @@ -147,7 +147,7 @@ class AttentionImpl(ABC, Generic[T]): head_size: int, scale: float, num_kv_heads: Optional[int] = None, - alibi_slopes: Optional[List[float]] = None, + alibi_slopes: Optional[list[float]] = None, sliding_window: Optional[int] = None, kv_cache_dtype: str = "auto", logits_soft_cap: Optional[float] = None, diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 6f3b67b281..6994debd45 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer.""" -from typing import Callable, List, Optional +from typing import Callable, Optional import torch import torch.nn as nn @@ -126,7 +126,7 @@ class Attention(nn.Module, AttentionLayerBase): head_size: int, scale: float, num_kv_heads: Optional[int] = None, - alibi_slopes: Optional[List[float]] = None, + alibi_slopes: Optional[list[float]] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, logits_soft_cap: Optional[float] = None, @@ -586,7 +586,7 @@ def 
wait_for_kv_layer_from_connector(layer_name: str): def maybe_save_kv_layer_to_connector( layer_name: str, - kv_cache_layer: List[torch.Tensor], + kv_cache_layer: list[torch.Tensor], ): if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): return diff --git a/vllm/attention/layers/chunked_local_attention.py b/vllm/attention/layers/chunked_local_attention.py index 7554a41022..3d37e90160 100644 --- a/vllm/attention/layers/chunked_local_attention.py +++ b/vllm/attention/layers/chunked_local_attention.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools -from typing import ClassVar, List, Optional +from typing import ClassVar, Optional import torch @@ -61,7 +61,7 @@ class ChunkedLocalAttention(Attention): scale: float, attention_chunk_size: int, num_kv_heads: Optional[int] = None, - alibi_slopes: Optional[List[float]] = None, + alibi_slopes: Optional[list[float]] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, kv_sharing_target_layer_name: Optional[str] = None, diff --git a/vllm/attention/ops/flashmla.py b/vllm/attention/ops/flashmla.py index 20eaeb6bd4..0fe01a51ec 100644 --- a/vllm/attention/ops/flashmla.py +++ b/vllm/attention/ops/flashmla.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py -from typing import Optional, Tuple +from typing import Optional import torch @@ -31,7 +31,7 @@ else: _flashmla_extension_C_AVAILABLE = False -def is_flashmla_supported() -> Tuple[bool, Optional[str]]: +def is_flashmla_supported() -> tuple[bool, Optional[str]]: """ Return: is_supported_flag, unsupported_reason (optional). """ @@ -57,7 +57,7 @@ def get_mla_metadata( num_heads_q: Optional[int] = None, is_fp8_kvcache: bool = False, topk: Optional[int] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: """ Arguments: - cache_seqlens: (batch_size), dtype torch.int32. @@ -101,7 +101,7 @@ def flash_mla_with_kvcache( descale_k: Optional[torch.Tensor] = None, is_fp8_kvcache: bool = False, indices: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: """ Arguments: - q: (batch_size, seq_len_q, num_heads_q, head_dim). 
@@ -183,7 +183,7 @@ def flash_mla_sparse_prefill( indices: torch.Tensor, sm_scale: float, d_v: int = 512, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Sparse attention prefill kernel diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index cdf0c929ce..4db7d1a3a3 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import List, Optional, Tuple +from typing import Optional import torch @@ -41,7 +41,7 @@ class PagedAttentionMetadata: class PagedAttention: @staticmethod - def get_supported_head_sizes() -> List[int]: + def get_supported_head_sizes() -> list[int]: return [32, 64, 80, 96, 112, 120, 128, 192, 256] @staticmethod @@ -51,7 +51,7 @@ class PagedAttention: num_kv_heads: int, head_size: int, cache_dtype_str: str = "auto", - ) -> Tuple[int, ...]: + ) -> tuple[int, ...]: return (2, num_blocks, block_size * num_kv_heads * head_size) @staticmethod @@ -59,7 +59,7 @@ class PagedAttention: kv_cache: torch.Tensor, num_kv_heads: int, head_size: int, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: x = 16 // kv_cache.element_size() num_blocks = kv_cache.shape[1] @@ -255,7 +255,7 @@ class PagedAttention: @staticmethod def copy_blocks( - kv_caches: List[torch.Tensor], + kv_caches: list[torch.Tensor], src_to_dists: torch.Tensor, ) -> None: key_caches = [kv_cache[0] for kv_cache in kv_caches] diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d7ba70381d..a94ef598f2 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -14,11 +14,8 @@ from typing import ( Annotated, Any, Callable, - Dict, - List, Literal, Optional, - Type, TypeVar, Union, cast, @@ -325,7 +322,7 @@ class EngineArgs: """Arguments for vLLM engine.""" model: str = ModelConfig.model - served_model_name: Optional[Union[str, List[str]]] = ModelConfig.served_model_name + served_model_name: Optional[Union[str, list[str]]] = ModelConfig.served_model_name tokenizer: Optional[str] = ModelConfig.tokenizer hf_config_path: Optional[str] = ModelConfig.hf_config_path runner: RunnerOption = ModelConfig.runner @@ -350,7 +347,7 @@ class EngineArgs: # is intended for expert use only. The API may change without # notice. 
distributed_executor_backend: Optional[ - Union[str, DistributedExecutorBackend, Type[ExecutorBase]] + Union[str, DistributedExecutorBackend, type[ExecutorBase]] ] = ParallelConfig.distributed_executor_backend # number of P/D disaggregation (or other disaggregation) workers pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size @@ -418,7 +415,7 @@ class EngineArgs: media_io_kwargs: dict[str, dict[str, Any]] = get_field( MultiModalConfig, "media_io_kwargs" ) - mm_processor_kwargs: Optional[Dict[str, Any]] = MultiModalConfig.mm_processor_kwargs + mm_processor_kwargs: Optional[dict[str, Any]] = MultiModalConfig.mm_processor_kwargs disable_mm_preprocessor_cache: bool = False # DEPRECATED mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb mm_processor_cache_type: Optional[MMCacheType] = ( @@ -436,7 +433,7 @@ class EngineArgs: enable_lora_bias: bool = LoRAConfig.bias_enabled max_loras: int = LoRAConfig.max_loras max_lora_rank: int = LoRAConfig.max_lora_rank - default_mm_loras: Optional[Dict[str, str]] = LoRAConfig.default_mm_loras + default_mm_loras: Optional[dict[str, str]] = LoRAConfig.default_mm_loras fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype @@ -446,7 +443,7 @@ class EngineArgs: num_gpu_blocks_override: Optional[int] = CacheConfig.num_gpu_blocks_override num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots model_loader_extra_config: dict = get_field(LoadConfig, "model_loader_extra_config") - ignore_patterns: Optional[Union[str, List[str]]] = LoadConfig.ignore_patterns + ignore_patterns: Optional[Union[str, list[str]]] = LoadConfig.ignore_patterns enable_chunked_prefill: Optional[bool] = SchedulerConfig.enable_chunked_prefill disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input @@ -467,7 +464,7 @@ class EngineArgs: logits_processor_pattern: Optional[str] = ModelConfig.logits_processor_pattern - speculative_config: Optional[Dict[str, Any]] = None + speculative_config: Optional[dict[str, Any]] = None show_hidden_metrics_for_version: Optional[str] = ( ObservabilityConfig.show_hidden_metrics_for_version @@ -477,7 +474,7 @@ class EngineArgs: ObservabilityConfig.collect_detailed_traces ) scheduling_policy: SchedulerPolicy = SchedulerConfig.policy - scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls + scheduler_cls: Union[str, type[object]] = SchedulerConfig.scheduler_cls pooler_config: Optional[PoolerConfig] = ModelConfig.pooler_config override_pooler_config: Optional[Union[dict, PoolerConfig]] = ( diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 2f48a0d622..45b798ed96 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -3,7 +3,7 @@ import time from collections import Counter as CollectionsCounter -from typing import Dict, List, Optional, Type, Union, cast +from typing import Optional, Union, cast import numpy as np import prometheus_client @@ -43,7 +43,7 @@ class Metrics: _counter_cls = prometheus_client.Counter _histogram_cls = prometheus_client.Histogram - def __init__(self, labelnames: List[str], vllm_config: VllmConfig): + def __init__(self, labelnames: list[str], vllm_config: VllmConfig): # Unregister any existing vLLM collectors (for CI/CD) self._unregister_vllm_metrics() @@ -304,7 +304,7 @@ class _RayGaugeWrapper: self, name: str, documentation: str = "", - labelnames: Optional[List[str]] = None, + labelnames: Optional[list[str]] = 
None, multiprocess_mode: str = "", ): del multiprocess_mode @@ -330,7 +330,7 @@ class _RayCounterWrapper: prometheus_client.Counter""" def __init__( - self, name: str, documentation: str = "", labelnames: Optional[List[str]] = None + self, name: str, documentation: str = "", labelnames: Optional[list[str]] = None ): labelnames_tuple = tuple(labelnames) if labelnames else None self._counter = ray_metrics.Counter( @@ -355,8 +355,8 @@ class _RayHistogramWrapper: self, name: str, documentation: str = "", - labelnames: Optional[List[str]] = None, - buckets: Optional[List[float]] = None, + labelnames: Optional[list[str]] = None, + buckets: Optional[list[float]] = None, ): labelnames_tuple = tuple(labelnames) if labelnames else None boundaries = buckets if buckets else [] @@ -381,17 +381,17 @@ class RayMetrics(Metrics): Provides the same metrics as Metrics but uses Ray's util.metrics library. """ - _gauge_cls: Type[prometheus_client.Gauge] = cast( - Type[prometheus_client.Gauge], _RayGaugeWrapper + _gauge_cls: type[prometheus_client.Gauge] = cast( + type[prometheus_client.Gauge], _RayGaugeWrapper ) - _counter_cls: Type[prometheus_client.Counter] = cast( - Type[prometheus_client.Counter], _RayCounterWrapper + _counter_cls: type[prometheus_client.Counter] = cast( + type[prometheus_client.Counter], _RayCounterWrapper ) - _histogram_cls: Type[prometheus_client.Histogram] = cast( - Type[prometheus_client.Histogram], _RayHistogramWrapper + _histogram_cls: type[prometheus_client.Histogram] = cast( + type[prometheus_client.Histogram], _RayHistogramWrapper ) - def __init__(self, labelnames: List[str], vllm_config: VllmConfig): + def __init__(self, labelnames: list[str], vllm_config: VllmConfig): if ray_metrics is None: raise ImportError("RayMetrics requires Ray to be installed.") super().__init__(labelnames, vllm_config) @@ -401,14 +401,14 @@ class RayMetrics(Metrics): pass -def build_buckets(mantissa_lst: List[int], max_value: int) -> List[int]: +def build_buckets(mantissa_lst: list[int], max_value: int) -> list[int]: """ Builds a list of buckets with increasing powers of 10 multiplied by mantissa values until the value exceeds the specified maximum. 
""" exponent = 0 - buckets: List[int] = [] + buckets: list[int] = [] while True: for m in mantissa_lst: value = m * 10**exponent @@ -419,7 +419,7 @@ def build_buckets(mantissa_lst: List[int], max_value: int) -> List[int]: exponent += 1 -def build_1_2_5_buckets(max_value: int) -> List[int]: +def build_1_2_5_buckets(max_value: int) -> list[int]: """ Example: >>> build_1_2_5_buckets(100) @@ -428,7 +428,7 @@ def build_1_2_5_buckets(max_value: int) -> List[int]: return build_buckets([1, 2, 5], max_value) -def build_1_2_3_5_8_buckets(max_value: int) -> List[int]: +def build_1_2_3_5_8_buckets(max_value: int) -> list[int]: """ Example: >>> build_1_2_3_5_8_buckets(100) @@ -442,7 +442,7 @@ def local_interval_elapsed(now: float, last_log: float, local_interval: float) - return elapsed_time > local_interval -def get_throughput(tracked_stats: List[int], now: float, last_log: float) -> float: +def get_throughput(tracked_stats: list[int], now: float, last_log: float) -> float: return float(np.sum(tracked_stats) / (now - last_log)) @@ -530,7 +530,7 @@ class PrometheusStatLogger(StatLoggerBase): _gauge_cls = prometheus_client.Gauge def __init__( - self, local_interval: float, labels: Dict[str, str], vllm_config: VllmConfig + self, local_interval: float, labels: dict[str, str], vllm_config: VllmConfig ) -> None: super().__init__(local_interval, vllm_config) # Prometheus metrics @@ -558,12 +558,12 @@ class PrometheusStatLogger(StatLoggerBase): for label, count in data.items(): counter.labels(**{**self.labels, label_key: label}).inc(count) - def _log_histogram(self, histogram, data: Union[List[int], List[float]]) -> None: + def _log_histogram(self, histogram, data: Union[list[int], list[float]]) -> None: # Convenience function for logging list to histogram. for datum in data: histogram.labels(**self.labels).observe(datum) - def _log_gauge_string(self, gauge, data: Dict[str, str]) -> None: + def _log_gauge_string(self, gauge, data: dict[str, str]) -> None: gauge.labels(**data).set_to_current_time() def _log_prometheus(self, stats: Stats) -> None: diff --git a/vllm/engine/metrics_types.py b/vllm/engine/metrics_types.py index d9a53fed7c..ac796f4e1c 100644 --- a/vllm/engine/metrics_types.py +++ b/vllm/engine/metrics_types.py @@ -16,7 +16,6 @@ do this in Python code and lazily import prometheus_client. 
import time from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import List from vllm.config import SupportsMetricsInfo, VllmConfig @@ -43,26 +42,26 @@ class Stats: num_prompt_tokens_iter: int num_generation_tokens_iter: int num_tokens_iter: int - time_to_first_tokens_iter: List[float] - inter_token_latencies_iter: List[float] + time_to_first_tokens_iter: list[float] + inter_token_latencies_iter: list[float] num_preemption_iter: int # Request stats (should have _requests suffix) # Latency - time_e2e_requests: List[float] - time_queue_requests: List[float] - time_inference_requests: List[float] - time_prefill_requests: List[float] - time_decode_requests: List[float] + time_e2e_requests: list[float] + time_queue_requests: list[float] + time_inference_requests: list[float] + time_prefill_requests: list[float] + time_decode_requests: list[float] # Metadata - num_prompt_tokens_requests: List[int] - num_generation_tokens_requests: List[int] - n_requests: List[int] - max_num_generation_tokens_requests: List[int] - max_tokens_requests: List[int] - finished_reason_requests: List[str] - waiting_lora_adapters: List[str] - running_lora_adapters: List[str] + num_prompt_tokens_requests: list[int] + num_generation_tokens_requests: list[int] + n_requests: list[int] + max_num_generation_tokens_requests: list[int] + max_tokens_requests: list[int] + finished_reason_requests: list[str] + waiting_lora_adapters: list[str] + running_lora_adapters: list[str] max_lora: str @@ -71,8 +70,8 @@ class StatLoggerBase(ABC): def __init__(self, local_interval: float, vllm_config: VllmConfig) -> None: # Tracked stats over current local logging interval. - self.num_prompt_tokens: List[int] = [] - self.num_generation_tokens: List[int] = [] + self.num_prompt_tokens: list[int] = [] + self.num_generation_tokens: list[int] = [] self.last_local_log = time.time() self.local_interval = local_interval diff --git a/vllm/entrypoints/harmony_utils.py b/vllm/entrypoints/harmony_utils.py index 6a0bb152e4..bf6cc3e97c 100644 --- a/vllm/entrypoints/harmony_utils.py +++ b/vllm/entrypoints/harmony_utils.py @@ -6,7 +6,7 @@ from __future__ import annotations import datetime import json from collections.abc import Iterable, Sequence -from typing import Literal, Optional, Union +from typing import Literal, Union from openai.types.responses import ( ResponseFunctionToolCall, @@ -79,13 +79,13 @@ def get_encoding(): def get_system_message( - model_identity: Optional[str] = None, - reasoning_effort: Optional[Literal["high", "medium", "low"]] = None, - start_date: Optional[str] = None, - browser_description: Optional[str] = None, - python_description: Optional[str] = None, - container_description: Optional[str] = None, - instructions: Optional[str] = None, + model_identity: str | None = None, + reasoning_effort: Literal["high", "medium", "low"] | None = None, + start_date: str | None = None, + browser_description: str | None = None, + python_description: str | None = None, + container_description: str | None = None, + instructions: str | None = None, with_custom_tools: bool = False, ) -> Message: sys_msg_content = SystemContent.new() @@ -137,8 +137,8 @@ def create_tool_definition(tool: Union[ChatCompletionToolsParam, Tool]): def get_developer_message( - instructions: Optional[str] = None, - tools: Optional[list[Union[Tool, ChatCompletionToolsParam]]] = None, + instructions: str | None = None, + tools: list[Union[Tool, ChatCompletionToolsParam]] | None = None, ) -> Message: dev_msg_content = DeveloperContent.new() if 
instructions is not None and not envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: @@ -202,7 +202,7 @@ def parse_response_input( msg = msg.with_channel("final") elif response_msg["type"] == "function_call_output": call_id = response_msg["call_id"] - call_response: Optional[ResponseFunctionToolCall] = None + call_response: ResponseFunctionToolCall | None = None for prev_response in reversed(prev_responses): if ( isinstance(prev_response, ResponseFunctionToolCall) @@ -450,7 +450,7 @@ def parse_output_into_messages(token_ids: Iterable[int]) -> StreamableParser: def parse_chat_output( token_ids: Sequence[int], -) -> tuple[Optional[str], Optional[str], bool]: +) -> tuple[str | None, str | None, bool]: parser = parse_output_into_messages(token_ids) output_msgs = parser.messages is_tool_call = False # TODO: update this when tool call is supported diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index af26918598..3a7347b8e4 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -6,7 +6,7 @@ import time from abc import ABC, abstractmethod from collections.abc import Awaitable from functools import cached_property -from typing import Any, Callable, List, Optional, Set, Union +from typing import Any, Callable, Optional, Union from typing_extensions import TypeVar @@ -143,7 +143,7 @@ class ExecutorBase(ABC): def execute_model( self, execute_model_req: ExecuteModelRequest - ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]: + ) -> Optional[list[Union[SamplerOutput, PoolerOutput]]]: output = self.collective_rpc("execute_model", args=(execute_model_req,)) return output[0] @@ -163,7 +163,7 @@ class ExecutorBase(ABC): assert lora_id > 0, "lora_id must be greater than 0." return all(self.collective_rpc("pin_lora", args=(lora_id,))) - def list_loras(self) -> Set[int]: + def list_loras(self) -> set[int]: sets = self.collective_rpc("list_loras") for s in sets: assert s == sets[0], "All workers should have the same LORAs." @@ -238,7 +238,7 @@ class ExecutorBase(ABC): async def execute_model_async( self, execute_model_req: ExecuteModelRequest - ) -> List[SamplerOutput]: + ) -> list[SamplerOutput]: """Executes one model step on the given sequences.""" output = await make_async(self.execute_model)(execute_model_req) return output @@ -272,7 +272,7 @@ class DistributedExecutorBase(ExecutorBase): def execute_model( self, execute_model_req: ExecuteModelRequest, - ) -> List[SamplerOutput]: + ) -> list[SamplerOutput]: # TODO: unify into collective_rpc if self.parallel_worker_tasks is None: self.parallel_worker_tasks = self._run_workers( @@ -299,7 +299,7 @@ class DistributedExecutorBase(ExecutorBase): @abstractmethod def _driver_execute_model( self, execute_model_req: Optional[ExecuteModelRequest] - ) -> Optional[List[SamplerOutput]]: + ) -> Optional[list[SamplerOutput]]: """Run execute_model in the driver worker. 
Passing None will cause the driver to stop the model execution loop @@ -346,7 +346,7 @@ class DistributedExecutorBase(ExecutorBase): async def execute_model_async( self, execute_model_req: ExecuteModelRequest - ) -> List[SamplerOutput]: + ) -> list[SamplerOutput]: if self.parallel_worker_tasks is None: # Start model execution loop running in the parallel workers self.parallel_worker_tasks = asyncio.create_task( @@ -371,7 +371,7 @@ class DistributedExecutorBase(ExecutorBase): async def _driver_execute_model_async( self, execute_model_req: Optional[ExecuteModelRequest] = None, - ) -> List[SamplerOutput]: + ) -> list[SamplerOutput]: """Execute the model asynchronously in the driver worker. Passing None will cause the driver to stop the model execution diff --git a/vllm/executor/msgspec_utils.py b/vllm/executor/msgspec_utils.py index 9ca190cd3c..ac16f06b16 100644 --- a/vllm/executor/msgspec_utils.py +++ b/vllm/executor/msgspec_utils.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from array import array -from typing import Any, Type +from typing import Any from vllm.multimodal.inputs import MultiModalKwargs from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE @@ -23,7 +23,7 @@ def encode_hook(obj: Any) -> Any: return dict(obj) -def decode_hook(type: Type, obj: Any) -> Any: +def decode_hook(type: type, obj: Any) -> Any: """Custom msgspec dec hook that supports array types and MultiModalKwargs. See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py index be124f7643..40f2915667 100644 --- a/vllm/executor/ray_distributed_executor.py +++ b/vllm/executor/ray_distributed_executor.py @@ -5,7 +5,7 @@ import asyncio import os from collections import defaultdict from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import cloudpickle import msgspec @@ -114,10 +114,10 @@ class RayDistributedExecutor(DistributedExecutorBase): self._init_workers_ray(placement_group) self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook) - self.output_decoder = msgspec.msgpack.Decoder(Optional[List[SamplerOutput]]) + self.output_decoder = msgspec.msgpack.Decoder(Optional[list[SamplerOutput]]) self.use_v1 = envs.VLLM_USE_V1 - self.pp_locks: Optional[List[asyncio.Lock]] = None + self.pp_locks: Optional[list[asyncio.Lock]] = None if not self.use_ray_compiled_dag: self.driver_exec_method = make_async(self.driver_worker.execute_method) @@ -137,7 +137,7 @@ class RayDistributedExecutor(DistributedExecutorBase): ray.kill(worker) self.forward_dag = None - def _configure_ray_workers_use_nsight(self, ray_remote_kwargs) -> Dict[str, Any]: + def _configure_ray_workers_use_nsight(self, ray_remote_kwargs) -> dict[str, Any]: # If nsight profiling is enabled, we need to set the profiling # configuration for the ray workers as runtime env. runtime_env = ray_remote_kwargs.setdefault("runtime_env", {}) @@ -164,12 +164,12 @@ class RayDistributedExecutor(DistributedExecutorBase): # It holds the resource for the driver worker. self.driver_dummy_worker: Optional[RayWorkerWrapper] = None # The remaining workers are the actual ray actors. - self.workers: List[RayWorkerWrapper] = [] + self.workers: list[RayWorkerWrapper] = [] # Used in ray compiled DAG: indexed first by PP rank, # and then TP rank. 
In other words, the inner list is # the TP group of workers for a PP rank. - self.pp_tp_workers: List[List[RayWorkerWrapper]] = [] + self.pp_tp_workers: list[list[RayWorkerWrapper]] = [] if self.parallel_config.ray_workers_use_nsight: ray_remote_kwargs = self._configure_ray_workers_use_nsight( @@ -179,7 +179,7 @@ class RayDistributedExecutor(DistributedExecutorBase): logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker) # Create the workers. - bundle_indices: List[int] + bundle_indices: list[int] if envs.VLLM_RAY_BUNDLE_INDICES: # Use the bundle indices specified by the user. bundle_indices = list(map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(","))) @@ -200,7 +200,7 @@ class RayDistributedExecutor(DistributedExecutorBase): bundle_indices.append(bundle_id) bundle_indices = bundle_indices[: self.parallel_config.world_size] - worker_metadata: List[RayWorkerMetaData] = [] + worker_metadata: list[RayWorkerMetaData] = [] driver_ip = get_ip() for rank, bundle_id in enumerate(bundle_indices): scheduling_strategy = PlacementGroupSchedulingStrategy( @@ -262,7 +262,7 @@ class RayDistributedExecutor(DistributedExecutorBase): "the driver on a GPU node." ) - ip_counts: Dict[str, int] = {} + ip_counts: dict[str, int] = {} for ip in worker_ips: ip_counts[ip] = ip_counts.get(ip, 0) + 1 @@ -416,11 +416,11 @@ class RayDistributedExecutor(DistributedExecutorBase): # This is the list of workers that are rank 0 of each TP group EXCEPT # global rank 0. These are the workers that will broadcast to the # rest of the workers. - self.tp_driver_workers: List[RayWorkerWrapper] = [] + self.tp_driver_workers: list[RayWorkerWrapper] = [] # This is the list of workers that are not drivers and not the first # worker in a TP group. These are the workers that will be # broadcasted to. - self.non_driver_workers: List[RayWorkerWrapper] = [] + self.non_driver_workers: list[RayWorkerWrapper] = [] # Enforce rank order for correct rank to return final output. for index, worker in enumerate(self.workers): @@ -433,7 +433,7 @@ class RayDistributedExecutor(DistributedExecutorBase): def _driver_execute_model( self, execute_model_req: Optional[ExecuteModelRequest] - ) -> Optional[List[SamplerOutput]]: + ) -> Optional[list[SamplerOutput]]: """Run execute_model in the driver worker. 
Passing None will cause the driver to stop the model execution @@ -446,7 +446,7 @@ class RayDistributedExecutor(DistributedExecutorBase): def execute_model( self, execute_model_req: ExecuteModelRequest - ) -> List[SamplerOutput]: + ) -> list[SamplerOutput]: if not self.use_ray_spmd_worker: return super().execute_model(execute_model_req) @@ -675,7 +675,7 @@ class RayDistributedExecutor(DistributedExecutorBase): async def execute_model_async( self, execute_model_req: ExecuteModelRequest - ) -> List[SamplerOutput]: + ) -> list[SamplerOutput]: if not self.use_ray_spmd_worker: return await super().execute_model_async(execute_model_req) @@ -689,7 +689,7 @@ class RayDistributedExecutor(DistributedExecutorBase): async def _driver_execute_model_async( self, execute_model_req: Optional[ExecuteModelRequest] = None - ) -> List[SamplerOutput]: + ) -> list[SamplerOutput]: assert not self.use_ray_spmd_worker, ( "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1" ) diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index abe3d2be9f..c3c8a70678 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -4,7 +4,7 @@ import os import time from collections import defaultdict -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Optional, Union import msgspec @@ -59,7 +59,7 @@ try: def get_node_ip(self) -> str: return get_ip() - def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]: + def get_node_and_gpu_ids(self) -> tuple[str, list[int]]: node_id = ray.get_runtime_context().get_node_id() device_key = vllm.platforms.current_platform.ray_device_key if not device_key: @@ -72,7 +72,7 @@ try: def execute_model_spmd( self, - req_or_tuple: Union[bytes, Tuple[bytes, Optional[IntermediateTensors]]], + req_or_tuple: Union[bytes, tuple[bytes, Optional[IntermediateTensors]]], ) -> bytes: """Execute model in SPMD fashion: used only when SPMD worker and compiled DAG are both enabled. 
@@ -126,10 +126,10 @@ try: def execute_model_ray( self, scheduler_output: Union[ - "SchedulerOutput", Tuple["SchedulerOutput", "IntermediateTensors"] + "SchedulerOutput", tuple["SchedulerOutput", "IntermediateTensors"] ], ) -> Union[ - "ModelRunnerOutput", Tuple["SchedulerOutput", "IntermediateTensors"] + "ModelRunnerOutput", tuple["SchedulerOutput", "IntermediateTensors"] ]: # This method is used by Ray Compiled Graph to execute the model, # and it needs a special logic of self.setup_device_if_necessary() @@ -156,7 +156,7 @@ try: output = output.get_output() return output - def override_env_vars(self, vars: Dict[str, str]): + def override_env_vars(self, vars: dict[str, str]): os.environ.update(vars) ray_import_err = None @@ -201,7 +201,7 @@ def _verify_bundles( # bundle_idx -> bundle (e.g., {"GPU": 1}) bundles = pg_data["bundles"] # node_id -> List of bundle (e.g., {"GPU": 1}) - node_id_to_bundle: Dict[str, List[Dict[str, float]]] = defaultdict(list) + node_id_to_bundle: dict[str, list[dict[str, float]]] = defaultdict(list) for bundle_idx, node_id in bundle_to_node_ids.items(): node_id_to_bundle[node_id].append(bundles[bundle_idx]) @@ -383,7 +383,7 @@ def initialize_ray_cluster( device_str, ) # Create a new placement group - placement_group_specs: List[Dict[str, float]] = [ + placement_group_specs: list[dict[str, float]] = [ {device_str: 1.0} for _ in range(parallel_config.world_size) ] diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py index ced054143c..8206f23d18 100644 --- a/vllm/executor/uniproc_executor.py +++ b/vllm/executor/uniproc_executor.py @@ -4,7 +4,7 @@ import os from concurrent.futures import Future, ThreadPoolExecutor from functools import cached_property from multiprocessing import Lock -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Optional, Union import torch import torch.distributed as dist @@ -68,10 +68,10 @@ class UniProcExecutor(ExecutorBase): self, method: Union[str, Callable], timeout: Optional[float] = None, - args: Tuple = (), - kwargs: Optional[Dict] = None, + args: tuple = (), + kwargs: Optional[dict] = None, non_block: bool = False, - ) -> List[Any]: + ) -> list[Any]: if kwargs is None: kwargs = {} if self.mm_receiver_cache is not None and method == "execute_model": @@ -158,7 +158,7 @@ class ExecutorWithExternalLauncher(UniProcExecutor): local_rank = int(os.environ["LOCAL_RANK"]) return distributed_init_method, rank, local_rank - def determine_num_available_blocks(self) -> Tuple[int, int]: + def determine_num_available_blocks(self) -> tuple[int, int]: """ Determine the number of available KV blocks. Add an additional all_reduce to get the min across all ranks. 
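The engine, executor, and attention hunks above all apply the same mechanical rewrite: deprecated typing aliases (List, Dict, Tuple, Set, Type) become the PEP 585 builtin generics that ruff UP006/UP035 ask for, and in most files Optional[X] / Union[X, None] becomes the PEP 604 spelling X | None. A minimal sketch of that pattern on a hypothetical helper, not taken from vLLM; the future import simply keeps the new annotation syntax valid on older interpreters:

from __future__ import annotations

# Before (flagged by UP006/UP035):
#     from typing import Dict, List, Optional
#     def count_ids(ids: List[int], labels: Optional[Dict[int, str]] = None) -> Dict[str, int]:
def count_ids(ids: list[int], labels: dict[int, str] | None = None) -> dict[str, int]:
    """Count occurrences of each id, keyed through an optional label map."""
    counts: dict[str, int] = {}
    for i in ids:
        key = (labels or {}).get(i, str(i))
        counts[key] = counts.get(key, 0) + 1
    return counts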
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index 88fc460d90..d12d059155 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -1,9 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import ( - List, # noqa: UP035 - Optional, -) +from typing import Optional import torch @@ -32,7 +29,7 @@ def flashinfer_fused_moe_blockscale_fp8( intermediate_size: int, expert_offset: int, local_num_experts: int, - block_shape: List[int], # noqa: UP006 + block_shape: list[int], routed_scaling: float = 1.0, ) -> torch.Tensor: from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 51e33ea263..45e6ac2ada 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -289,7 +289,7 @@ class MultiModalFieldElem: return ( (self.modality, self.key) == (other.modality, other.key) and data_equal - and type(self.field) == type(other.field) + and type(self.field) is type(other.field) ) # noqa: E721 diff --git a/vllm/plugins/io_processors/__init__.py b/vllm/plugins/io_processors/__init__.py index 8ec96ed009..7a914442c4 100644 --- a/vllm/plugins/io_processors/__init__.py +++ b/vllm/plugins/io_processors/__init__.py @@ -4,7 +4,6 @@ from __future__ import annotations import logging -from typing import Optional from vllm.config import VllmConfig from vllm.plugins import load_plugins_by_group @@ -15,7 +14,7 @@ logger = logging.getLogger(__name__) def get_io_processor( - vllm_config: VllmConfig, plugin_from_init: Optional[str] = None + vllm_config: VllmConfig, plugin_from_init: str | None = None ) -> IOProcessor | None: # Input.Output processors are loaded as plugins under the # 'vllm.io_processor_plugins' group. Similar to platform diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index c9999649b5..166380219b 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -68,7 +68,6 @@ from typing import ( Generic, Literal, NamedTuple, - Optional, TextIO, TypeVar, Union, @@ -247,9 +246,7 @@ class CacheInfo(NamedTuple): class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]): - def __init__( - self, capacity: float, getsizeof: Optional[Callable[[_V], float]] = None - ): + def __init__(self, capacity: float, getsizeof: Callable[[_V], float] | None = None): super().__init__(capacity, getsizeof) self.pinned_items = set[_K]() @@ -324,15 +321,15 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]): self._LRUCache__order[key] = None # type: ignore @overload - def get(self, key: _K, /) -> Optional[_V]: ... + def get(self, key: _K, /) -> _V | None: ... @overload def get(self, key: _K, /, default: Union[_V, _T]) -> Union[_V, _T]: ... def get( - self, key: _K, /, default: Optional[Union[_V, _T]] = None - ) -> Optional[Union[_V, _T]]: - value: Optional[Union[_V, _T]] + self, key: _K, /, default: Union[_V, _T] | None = None + ) -> Union[_V, _T] | None: + value: Union[_V, _T] | None if key in self: value = self.__getitem__(key, update_info=False) # type: ignore[call-arg] @@ -350,9 +347,9 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]): def pop(self, key: _K, default: Union[_V, _T]) -> Union[_V, _T]: ... 
def pop( - self, key: _K, default: Optional[Union[_V, _T]] = None - ) -> Optional[Union[_V, _T]]: - value: Optional[Union[_V, _T]] + self, key: _K, default: Union[_V, _T] | None = None + ) -> Union[_V, _T] | None: + value: Union[_V, _T] | None if key not in self: return default @@ -379,7 +376,7 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]): """ self.pinned_items.remove(key) - def _on_remove(self, key: _K, value: Optional[_V]) -> None: + def _on_remove(self, key: _K, value: _V | None) -> None: pass def remove_oldest(self, *, remove_pinned: bool = False) -> None: @@ -705,7 +702,7 @@ def in_loop(event_loop: AbstractEventLoop) -> bool: def make_async( - func: Callable[P, T], executor: Optional[concurrent.futures.Executor] = None + func: Callable[P, T], executor: concurrent.futures.Executor | None = None ) -> Callable[P, Awaitable[T]]: """Take a blocking function, and run it on in an executor thread. @@ -940,7 +937,7 @@ def _get_open_port() -> int: return s.getsockname()[1] -def find_process_using_port(port: int) -> Optional[psutil.Process]: +def find_process_using_port(port: int) -> psutil.Process | None: # TODO: We can not check for running processes with network # port on macOS. Therefore, we can not have a full graceful shutdown # of vLLM. For now, let's not look for processes in this case. @@ -1025,8 +1022,8 @@ def _generate_random_fp8( def get_kv_cache_torch_dtype( - cache_dtype: Optional[Union[str, torch.dtype]], - model_dtype: Optional[Union[str, torch.dtype]] = None, + cache_dtype: Union[str, torch.dtype] | None, + model_dtype: Union[str, torch.dtype] | None = None, ) -> torch.dtype: if isinstance(cache_dtype, str): if cache_dtype == "auto": @@ -1053,11 +1050,11 @@ def create_kv_caches_with_random_flash( num_layers: int, num_heads: int, head_size: int, - cache_dtype: Optional[Union[str, torch.dtype]], - model_dtype: Optional[Union[str, torch.dtype]] = None, - seed: Optional[int] = None, - device: Optional[str] = "cuda", - cache_layout: Optional[str] = "NHD", + cache_dtype: Union[str, torch.dtype] | None, + model_dtype: Union[str, torch.dtype] | None = None, + seed: int | None = None, + device: str | None = "cuda", + cache_layout: str | None = "NHD", ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: from vllm.platforms import current_platform @@ -1095,10 +1092,10 @@ def create_kv_caches_with_random( num_layers: int, num_heads: int, head_size: int, - cache_dtype: Optional[Union[str, torch.dtype]], - model_dtype: Optional[Union[str, torch.dtype]] = None, - seed: Optional[int] = None, - device: Optional[str] = "cuda", + cache_dtype: Union[str, torch.dtype] | None, + model_dtype: Union[str, torch.dtype] | None = None, + seed: int | None = None, + device: str | None = "cuda", ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: if cache_dtype == "fp8" and head_size % 16: raise ValueError( @@ -1156,7 +1153,7 @@ def is_uva_available() -> bool: class DeviceMemoryProfiler: - def __init__(self, device: Optional[torch.types.Device] = None): + def __init__(self, device: torch.types.Device | None = None): self.device = device def current_memory_usage(self) -> float: @@ -1184,7 +1181,7 @@ def make_ndarray_with_pad( pad: T, dtype: npt.DTypeLike, *, - max_len: Optional[int] = None, + max_len: int | None = None, ) -> npt.NDArray: """ Make a padded array from 2D inputs. 
@@ -1209,8 +1206,8 @@ def make_tensor_with_pad( pad: T, dtype: torch.dtype, *, - max_len: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, + max_len: int | None = None, + device: Union[str, torch.device] | None = None, pin_memory: bool = False, ) -> torch.Tensor: """ @@ -1405,7 +1402,7 @@ def find_nccl_library() -> str: return so_file -def find_nccl_include_paths() -> Optional[list[str]]: +def find_nccl_include_paths() -> list[str] | None: """ We either use the nccl.h specified by the `VLLM_NCCL_INCLUDE_PATH` environment variable, or we find the library file brought by @@ -1525,7 +1522,7 @@ F = TypeVar("F", bound=Callable[..., Any]) def deprecate_args( start_index: int, is_deprecated: Union[bool, Callable[[], bool]] = True, - additional_message: Optional[str] = None, + additional_message: str | None = None, ) -> Callable[[F], F]: if not callable(is_deprecated): is_deprecated = partial(identity, is_deprecated) @@ -1565,7 +1562,7 @@ def deprecate_args( def deprecate_kwargs( *kws: str, is_deprecated: Union[bool, Callable[[], bool]] = True, - additional_message: Optional[str] = None, + additional_message: str | None = None, ) -> Callable[[F], F]: deprecated_kws = set(kws) @@ -1598,7 +1595,7 @@ def deprecate_kwargs( @lru_cache(maxsize=8) -def _cuda_device_count_stateless(cuda_visible_devices: Optional[str] = None) -> int: +def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int: # Note: cuda_visible_devices is not used, but we keep it as an argument for # LRU Cache purposes. @@ -1746,7 +1743,7 @@ class FlexibleArgumentParser(ArgumentParser): ' --json-arg \'{"key4": ["value3", "value4", "value5"]}\'\n' " --json-arg.key4+ value3 --json-arg.key4+='value4,value5'\n\n" ) - _search_keyword: Optional[str] = None + _search_keyword: str | None = None def __init__(self, *args, **kwargs): # Set the default "formatter_class" to SortedHelpFormatter @@ -2245,7 +2242,7 @@ def supports_kw( def get_allowed_kwarg_only_overrides( callable: Callable[..., object], - overrides: Optional[Mapping[str, object]], + overrides: Mapping[str, object] | None, *, requires_kw_only: bool = True, allow_var_kwargs: bool = False, @@ -2695,10 +2692,10 @@ vllm_lib = Library("vllm", "FRAGMENT") # noqa def direct_register_custom_op( op_name: str, op_func: Callable, - mutates_args: Optional[list[str]] = None, - fake_impl: Optional[Callable] = None, - target_lib: Optional[Library] = None, - dispatch_key: Optional[str] = None, + mutates_args: list[str] | None = None, + fake_impl: Callable | None = None, + target_lib: Library | None = None, + dispatch_key: str | None = None, tags: tuple[torch.Tag, ...] = (), ): """ @@ -3016,7 +3013,7 @@ def split_zmq_path(path: str) -> tuple[str, str, str]: return scheme, host, port -def make_zmq_path(scheme: str, host: str, port: Optional[int] = None) -> str: +def make_zmq_path(scheme: str, host: str, port: int | None = None) -> str: """Make a ZMQ path from its parts. 
Args: @@ -3039,9 +3036,9 @@ def make_zmq_socket( ctx: Union[zmq.asyncio.Context, zmq.Context], # type: ignore[name-defined] path: str, socket_type: Any, - bind: Optional[bool] = None, - identity: Optional[bytes] = None, - linger: Optional[int] = None, + bind: bool | None = None, + identity: bytes | None = None, + linger: int | None = None, ) -> Union[zmq.Socket, zmq.asyncio.Socket]: # type: ignore[name-defined] """Make a ZMQ socket with the proper bind/connect semantics.""" @@ -3098,9 +3095,9 @@ def make_zmq_socket( def zmq_socket_ctx( path: str, socket_type: Any, - bind: Optional[bool] = None, + bind: bool | None = None, linger: int = 0, - identity: Optional[bytes] = None, + identity: bytes | None = None, ) -> Iterator[zmq.Socket]: """Context manager for a ZMQ socket""" @@ -3163,7 +3160,7 @@ def get_mp_context(): def bind_kv_cache( ctx: dict[str, Any], kv_cache: list[list[torch.Tensor]], # [virtual_engine][layer_index] - shared_kv_cache_layers: Optional[dict[str, str]] = None, + shared_kv_cache_layers: dict[str, str] | None = None, ) -> None: # Bind the kv_cache tensor to Attention modules, similar to # ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)] @@ -3379,7 +3376,7 @@ def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None: @contextlib.contextmanager -def cprofile_context(save_file: Optional[str] = None): +def cprofile_context(save_file: str | None = None): """Run a cprofile Args: @@ -3401,7 +3398,7 @@ def cprofile_context(save_file: Optional[str] = None): prof.print_stats(sort="cumtime") -def cprofile(save_file: Optional[str] = None, enabled: bool = True): +def cprofile(save_file: str | None = None, enabled: bool = True): """Decorator to profile a Python method using cProfile. Args: @@ -3608,7 +3605,7 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None: file.write = write_with_prefix # type: ignore[method-assign] -def decorate_logs(process_name: Optional[str] = None) -> None: +def decorate_logs(process_name: str | None = None) -> None: """ Adds a process-specific prefix to each line of output written to stdout and stderr. @@ -3631,8 +3628,8 @@ def decorate_logs(process_name: Optional[str] = None) -> None: def length_from_prompt_token_ids_or_embeds( - prompt_token_ids: Optional[list[int]], - prompt_embeds: Optional[torch.Tensor], + prompt_token_ids: list[int] | None, + prompt_embeds: torch.Tensor | None, ) -> int: """Calculate the request length (in number of tokens) give either prompt_token_ids or prompt_embeds. 
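A second pattern in this patch is the E721 fix in vllm/multimodal/inputs.py: ruff flags equality comparison of type objects, and the hunk switches the exact-type check to an identity comparison. A standalone sketch of that rule, using toy classes that are illustration only and not part of vLLM:

class Square:
    """Toy class, illustration only."""


class Circle:
    """Toy class, illustration only."""


def same_concrete_type(a: object, b: object) -> bool:
    # Before (flagged by E721): type(a) == type(b)
    # Type objects are singletons, so an exact-type check compares identity.
    return type(a) is type(b)


assert same_concrete_type(Square(), Square())
assert not same_concrete_type(Square(), Circle())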
diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index ac4fcc0156..1d7f05cf67 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -10,7 +10,7 @@ from __future__ import annotations import functools import importlib import os -from typing import Any, Callable, NoReturn, Optional +from typing import Any, Callable, NoReturn import torch @@ -325,7 +325,7 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): def should_use_deepgemm_for_fp8_linear( output_dtype: torch.dtype, weight: torch.Tensor, - supports_deep_gemm: Optional[bool] = None, + supports_deep_gemm: bool | None = None, ): if supports_deep_gemm is None: supports_deep_gemm = is_deep_gemm_supported() diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 22dfbe60f8..ab0cf2051f 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -12,7 +12,7 @@ import functools import importlib import importlib.util import os -from typing import Any, Callable, NoReturn, Optional +from typing import Any, Callable, NoReturn import requests import torch @@ -202,14 +202,14 @@ def supports_trtllm_attention() -> bool: @functools.cache -def _force_use_trtllm_attention(env_value: Optional[bool]) -> Optional[bool]: +def _force_use_trtllm_attention(env_value: bool | None) -> bool | None: """Cache the env value for VLLM_USE_TRTLLM_ATTENTION""" if env_value is not None: logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value) return env_value -def force_use_trtllm_attention() -> Optional[bool]: +def force_use_trtllm_attention() -> bool | None: """ Return ``None`` if VLLM_USE_TRTLLM_ATTENTION is not set, return ``True`` if TRTLLM attention is forced to be used, @@ -401,7 +401,7 @@ def flashinfer_scaled_fp8_mm( scale_a: torch.Tensor, scale_b: torch.Tensor, out_dtype: torch.dtype, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: assert a.ndim == 2 and b.ndim == 2 assert a.shape[1] == b.shape[0] diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 07316cd12a..c7a826a67d 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -5,7 +5,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import ClassVar, Optional, Union +from typing import ClassVar, Union import numpy as np import torch @@ -254,12 +254,12 @@ class FlashInferMetadata: # For cascade attention (CPU for planning). 
use_cascade: bool - prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None - decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None - cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None + prefill_wrapper: BatchPrefillWithPagedKVCacheWrapper | None = None + decode_wrapper: BatchDecodeWithPagedKVCacheWrapper | None = None + cascade_wrapper: MultiLevelCascadeAttentionWrapper | None = None - qo_indptr_gpu: Optional[torch.Tensor] = None - paged_kv_indptr_gpu: Optional[torch.Tensor] = None + qo_indptr_gpu: torch.Tensor | None = None + paged_kv_indptr_gpu: torch.Tensor | None = None class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): @@ -727,13 +727,13 @@ class FlashInferImpl(AttentionImpl): head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], + alibi_slopes: list[float] | None, + sliding_window: int | None, kv_cache_dtype: str, - logits_soft_cap: Optional[float] = None, + logits_soft_cap: float | None = None, attn_type: AttentionType = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[int] = None, - sinks: Optional[torch.Tensor] = None, + kv_sharing_target_layer_name: int | None = None, + sinks: torch.Tensor | None = None, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -763,7 +763,7 @@ class FlashInferImpl(AttentionImpl): "FlashInferImpl" ) - self.sinks: Optional[torch.Tensor] = None + self.sinks: torch.Tensor | None = None if sinks is not None: if sinks.shape[0] != num_heads: raise ValueError( @@ -776,9 +776,9 @@ class FlashInferImpl(AttentionImpl): self.support_trtllm_attn = ( supports_trtllm_attention() and num_heads % num_kv_heads == 0 ) - self.bmm1_scale: Optional[float] = None - self.bmm2_scale: Optional[float] = None - self.o_sf_scale: Optional[float] = None + self.bmm1_scale: float | None = None + self.bmm2_scale: float | None = None + self.o_sf_scale: float | None = None def fused_output_quant_supported(self, quant_key: QuantKey): return ( @@ -795,9 +795,9 @@ class FlashInferImpl(AttentionImpl): value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: FlashInferMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: """Forward pass with FlashInfer. 
@@ -1093,13 +1093,13 @@ def fast_plan_decode(
     page_size: int,
     pos_encoding_mode: str = "NONE",
     window_left: int = -1,
-    logits_soft_cap: Optional[float] = None,
-    q_data_type: Optional[Union[str, torch.dtype]] = "float16",
-    kv_data_type: Optional[Union[str, torch.dtype]] = None,
-    data_type: Optional[Union[str, torch.dtype]] = None,
-    sm_scale: Optional[float] = None,
-    rope_scale: Optional[float] = None,
-    rope_theta: Optional[float] = None,
+    logits_soft_cap: float | None = None,
+    q_data_type: Union[str, torch.dtype] | None = "float16",
+    kv_data_type: Union[str, torch.dtype] | None = None,
+    data_type: Union[str, torch.dtype] | None = None,
+    sm_scale: float | None = None,
+    rope_scale: float | None = None,
+    rope_theta: float | None = None,
     non_blocking: bool = True,
 ) -> None:
     """
diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py
index 5d31811662..cbce91b990 100644
--- a/vllm/v1/core/sched/output.py
+++ b/vllm/v1/core/sched/output.py
@@ -4,7 +4,7 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING
 
 from vllm._bc_linter import bc_linter_include
@@ -25,14 +25,14 @@ if TYPE_CHECKING:
 @dataclass
 class NewRequestData:
     req_id: str
-    prompt_token_ids: Optional[list[int]]
+    prompt_token_ids: list[int] | None
     mm_features: list[MultiModalFeatureSpec]
-    sampling_params: Optional[SamplingParams]
-    pooling_params: Optional[PoolingParams]
+    sampling_params: SamplingParams | None
+    pooling_params: PoolingParams | None
     block_ids: tuple[list[int], ...]
     num_computed_tokens: int
-    lora_request: Optional[LoRARequest]
-    prompt_embeds: Optional[torch.Tensor] = None
+    lora_request: LoRARequest | None
+    prompt_embeds: torch.Tensor | None = None
 
     @classmethod
     def from_request(
@@ -98,7 +98,7 @@ class CachedRequestData:
     # NOTE(woosuk): new_token_ids is only used for pipeline parallelism.
     # When PP is not used, new_token_ids will be empty.
     new_token_ids: list[list[int]]
-    new_block_ids: list[Optional[tuple[list[int], ...]]]
+    new_block_ids: list[tuple[list[int], ...] | None]
     num_computed_tokens: list[int]
     num_output_tokens: list[int]
@@ -160,7 +160,7 @@ class SchedulerOutput:
     # for filling the next token bitmask
     structured_output_request_ids: dict[str, int]
     # the bitmask for the whole batch
-    grammar_bitmask: Optional[npt.NDArray[np.int32]]
+    grammar_bitmask: npt.NDArray[np.int32] | None
 
     # KV Cache Connector metadata.
-    kv_connector_metadata: Optional[KVConnectorMetadata] = None
+    kv_connector_metadata: KVConnectorMetadata | None = None
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 24ff87cd0a..d9a0ff1aa5 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -7,7 +7,7 @@ import itertools
 import time
 from collections import defaultdict
 from collections.abc import Iterable
-from typing import Any, Optional, Union
+from typing import Any, Union
 
 from vllm.config import VllmConfig
 from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
@@ -64,7 +64,7 @@ class Scheduler(SchedulerInterface):
         # request ids should be included in the EngineCoreOutputs returned
         # by update_from_outputs(). This is currently used in the multi-engine
         # case to track request lifetimes efficiently.
-        self.finished_req_ids_dict: Optional[dict[int, set[str]]] = (
+        self.finished_req_ids_dict: dict[int, set[str]] | None = (
             defaultdict(set) if include_finished_set else None
         )
@@ -708,7 +708,7 @@ class Scheduler(SchedulerInterface):
     ) -> CachedRequestData:
         req_ids: list[str] = []
         new_token_ids: list[list[int]] = []
-        new_block_ids: list[Optional[tuple[list[int], ...]]] = []
+        new_block_ids: list[tuple[list[int], ...] | None] = []
         num_computed_tokens: list[int] = []
         num_output_tokens: list[int] = []
@@ -921,7 +921,7 @@ class Scheduler(SchedulerInterface):
         kv_connector_output = model_runner_output.kv_connector_output
 
         outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
-        spec_decoding_stats: Optional[SpecDecodingStats] = None
+        spec_decoding_stats: SpecDecodingStats | None = None
         kv_connector_stats = (
             kv_connector_output.kv_connector_stats if kv_connector_output else None
         )
@@ -1212,7 +1212,7 @@ class Scheduler(SchedulerInterface):
             request.status = finished_status
             self._free_request(request)
 
-    def _free_request(self, request: Request) -> Optional[dict[str, Any]]:
+    def _free_request(self, request: Request) -> dict[str, Any] | None:
         assert request.is_finished()
 
         delay_free_blocks, kv_xfer_params = self._connector_finished(request)
@@ -1243,9 +1243,9 @@ class Scheduler(SchedulerInterface):
 
     def make_stats(
         self,
-        spec_decoding_stats: Optional[SpecDecodingStats] = None,
-        kv_connector_stats: Optional[KVConnectorStats] = None,
-    ) -> Optional[SchedulerStats]:
+        spec_decoding_stats: SpecDecodingStats | None = None,
+        kv_connector_stats: KVConnectorStats | None = None,
+    ) -> SchedulerStats | None:
         if not self.log_stats:
             return None
         prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats()
@@ -1262,10 +1262,10 @@ class Scheduler(SchedulerInterface):
 
     def make_spec_decoding_stats(
         self,
-        spec_decoding_stats: Optional[SpecDecodingStats],
+        spec_decoding_stats: SpecDecodingStats | None,
         num_draft_tokens: int,
         num_accepted_tokens: int,
-    ) -> Optional[SpecDecodingStats]:
+    ) -> SpecDecodingStats | None:
         if not self.log_stats:
             return None
         if spec_decoding_stats is None:
@@ -1285,12 +1285,12 @@ class Scheduler(SchedulerInterface):
     # KV Connector Related Methods
     ########################################################################
 
-    def get_kv_connector(self) -> Optional[KVConnectorBase_V1]:
+    def get_kv_connector(self) -> KVConnectorBase_V1 | None:
         return self.connector
 
     def _connector_finished(
         self, request: Request
-    ) -> tuple[bool, Optional[dict[str, Any]]]:
+    ) -> tuple[bool, dict[str, Any] | None]:
         """
         Invoke the KV connector request_finished() method if applicable.
diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py
index 1f51f98ca9..1b5e75313d 100644
--- a/vllm/v1/structured_output/__init__.py
+++ b/vllm/v1/structured_output/__init__.py
@@ -4,7 +4,7 @@ from __future__ import annotations
 
 import multiprocessing
 from concurrent.futures import Future, ThreadPoolExecutor
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING
 
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
@@ -35,11 +35,11 @@ class StructuredOutputManager:
     """Engine-level manager for structured output requests."""
 
     def __init__(self, vllm_config: VllmConfig):
-        self.backend: Optional[StructuredOutputBackend] = None
-        self.reasoner: Optional[ReasoningParser] = None
+        self.backend: StructuredOutputBackend | None = None
+        self.reasoner: ReasoningParser | None = None
         self.vllm_config = vllm_config
 
-        self._grammar_bitmask: Optional[torch.Tensor] = None
+        self._grammar_bitmask: torch.Tensor | None = None
         self._full_mask = torch.tensor(-1, dtype=torch.int32)
 
         max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
@@ -168,7 +168,7 @@ class StructuredOutputManager:
         requests: dict[str, Request],
         structured_output_request_ids: dict[str, int],
         scheduled_spec_decode_tokens: dict[str, list[int]],
-    ) -> Optional[npt.NDArray[np.int32]]:
+    ) -> npt.NDArray[np.int32] | None:
         # Prepare the structured output bitmask for this batch.
         if not structured_output_request_ids:
             return None
diff --git a/vllm/v1/structured_output/backend_guidance.py b/vllm/v1/structured_output/backend_guidance.py
index a48a705e8f..081cdfdc99 100644
--- a/vllm/v1/structured_output/backend_guidance.py
+++ b/vllm/v1/structured_output/backend_guidance.py
@@ -7,7 +7,7 @@ import copy
 import json
 import os
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Union
 
 import torch
 
@@ -252,7 +252,7 @@ def serialize_guidance_grammar(
 
 
 def validate_guidance_grammar(
-    sampling_params: SamplingParams, tokenizer: Optional[llguidance.LLTokenizer] = None
+    sampling_params: SamplingParams, tokenizer: llguidance.LLTokenizer | None = None
 ) -> None:
     tp, grm = get_structured_output_key(sampling_params)
     guidance_grm = serialize_guidance_grammar(tp, grm)
diff --git a/vllm/v1/structured_output/request.py b/vllm/v1/structured_output/request.py
index 26f72ae50c..233c7c1e78 100644
--- a/vllm/v1/structured_output/request.py
+++ b/vllm/v1/structured_output/request.py
@@ -20,10 +20,10 @@ from vllm.v1.structured_output.backend_types import (
 @dataclasses.dataclass
 class StructuredOutputRequest:
     sampling_params: SamplingParams
-    _grammar: Optional[
-        Union[Future[StructuredOutputGrammar], StructuredOutputGrammar]
-    ] = None
-    reasoning_ended: Optional[bool] = None
+    _grammar: Union[Future[StructuredOutputGrammar], StructuredOutputGrammar] | None = (
+        None
+    )
+    reasoning_ended: bool | None = None
 
     def _check_grammar_completion(self) -> bool:
         # NOTE: We have to lazy import to gate circular imports
@@ -43,7 +43,7 @@ class StructuredOutputRequest:
         return self._check_grammar_completion()
 
     @property
-    def grammar(self) -> Optional[StructuredOutputGrammar]:
+    def grammar(self) -> StructuredOutputGrammar | None:
         completed = self._check_grammar_completion()
         return (
             cast(Optional[StructuredOutputGrammar], self._grammar)
diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py
index 5f5c6bcea0..dc9bb3910f 100644
--- a/vllm/v1/worker/worker_base.py
+++ b/vllm/v1/worker/worker_base.py
@@ -4,7 +4,7 @@
 from __future__ import annotations
 
 import os
-from typing import Any, Callable, Optional, TypeVar, Union
+from typing import Any, Callable, TypeVar, Union
 
 import torch
 import torch.nn as nn
@@ -78,8 +78,8 @@ class WorkerBase:
         self.is_driver_worker = is_driver_worker
 
         # Device and model state
-        self.device: Optional[torch.device] = None
-        self.model_runner: Optional[nn.Module] = None
+        self.device: torch.device | None = None
+        self.model_runner: nn.Module | None = None
 
     def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         """Get specifications for KV cache implementation."""
@@ -115,8 +115,8 @@ class WorkerBase:
         raise NotImplementedError
 
     def execute_model(
-        self, execute_model_req: Optional[ExecuteModelRequest] = None
-    ) -> Optional[list[SamplerOutput]]:
+        self, execute_model_req: ExecuteModelRequest | None = None
+    ) -> list[SamplerOutput] | None:
         raise NotImplementedError
 
     def start_worker_execution_loop(self) -> None:
@@ -198,8 +198,8 @@ class WorkerWrapperBase:
         group.
         """
         self.rpc_rank = rpc_rank
-        self.worker: Optional[WorkerBase] = None
-        self.vllm_config: Optional[VllmConfig] = None
+        self.worker: WorkerBase | None = None
+        self.vllm_config: VllmConfig | None = None
         # do not store this `vllm_config`, `init_worker` will set the final
         # one. TODO: investigate if we can remove this field in
         # `WorkerWrapperBase`, `init_cached_hf_modules` should be