[Misc] Clean up utils (#27552)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -65,7 +65,9 @@ ChatCommand = auto_mock("vllm.entrypoints.cli.openai", "ChatCommand")
 CompleteCommand = auto_mock("vllm.entrypoints.cli.openai", "CompleteCommand")
 cli_args = auto_mock("vllm.entrypoints.openai", "cli_args")
 run_batch = auto_mock("vllm.entrypoints.openai", "run_batch")
-FlexibleArgumentParser = auto_mock("vllm.utils", "FlexibleArgumentParser")
+FlexibleArgumentParser = auto_mock(
+    "vllm.utils.argparse_utils", "FlexibleArgumentParser"
+)


 class MarkdownFormatter(HelpFormatter):
@@ -45,9 +45,7 @@ from vllm.entrypoints.cli.serve import ServeSubcommand
 from vllm.model_executor.model_loader import get_model_loader
 from vllm.platforms import current_platform
 from vllm.transformers_utils.tokenizer import get_tokenizer
-from vllm.utils import (
-    FlexibleArgumentParser,
-)
+from vllm.utils.argparse_utils import FlexibleArgumentParser
 from vllm.utils.mem_constants import GB_bytes
 from vllm.utils.network_utils import get_open_port
 from vllm.utils.torch_utils import cuda_device_count_stateless
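
The two hunks above are representative of the mechanical change applied throughout the tree: helpers are no longer re-exported from the top-level `vllm.utils` package, so downstream code imports them from their dedicated submodules. A minimal before/after sketch for caller code (the ceiling-division semantics of `cdiv` are assumed from its name and the surrounding math helpers):

```python
# Before (relied on the re-exports that this commit removes):
# from vllm.utils import FlexibleArgumentParser, cdiv, get_open_port

# After: import each helper from its dedicated submodule.
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.math_utils import cdiv
from vllm.utils.network_utils import get_open_port

parser = FlexibleArgumentParser(description="example")
assert cdiv(10, 3) == 4  # assumed ceiling division
port = get_open_port()
```
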
@@ -4,23 +4,15 @@
 import json
 import os
 import tempfile
 from pathlib import Path
-from unittest.mock import patch

 import pytest
-import torch
 import yaml
 from transformers import AutoTokenizer

-from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
 from vllm.transformers_utils.detokenizer_utils import convert_ids_list_to_tokens
-from vllm.utils import (
-    FlexibleArgumentParser,
-    bind_kv_cache,
-)
-from ..utils import create_new_process_for_each_test, flat_product
+from vllm.utils.argparse_utils import FlexibleArgumentParser
+from ..utils import flat_product


 # Tests for FlexibleArgumentParser
@@ -256,87 +248,6 @@ def test_duplicate_dict_args(caplog_vllm, parser):
     assert "-O.mode" in caplog_vllm.text


-def test_bind_kv_cache():
-    from vllm.attention import Attention
-
-    ctx = {
-        "layers.0.self_attn": Attention(32, 128, 0.1),
-        "layers.1.self_attn": Attention(32, 128, 0.1),
-        "layers.2.self_attn": Attention(32, 128, 0.1),
-        "layers.3.self_attn": Attention(32, 128, 0.1),
-    }
-    kv_cache = [
-        torch.zeros((1,)),
-        torch.zeros((1,)),
-        torch.zeros((1,)),
-        torch.zeros((1,)),
-    ]
-    bind_kv_cache(ctx, [kv_cache])
-    assert ctx["layers.0.self_attn"].kv_cache[0] is kv_cache[0]
-    assert ctx["layers.1.self_attn"].kv_cache[0] is kv_cache[1]
-    assert ctx["layers.2.self_attn"].kv_cache[0] is kv_cache[2]
-    assert ctx["layers.3.self_attn"].kv_cache[0] is kv_cache[3]
-
-
-def test_bind_kv_cache_kv_sharing():
-    from vllm.attention import Attention
-
-    ctx = {
-        "layers.0.self_attn": Attention(32, 128, 0.1),
-        "layers.1.self_attn": Attention(32, 128, 0.1),
-        "layers.2.self_attn": Attention(32, 128, 0.1),
-        "layers.3.self_attn": Attention(32, 128, 0.1),
-    }
-    kv_cache = [
-        torch.zeros((1,)),
-        torch.zeros((1,)),
-        torch.zeros((1,)),
-        torch.zeros((1,)),
-    ]
-    shared_kv_cache_layers = {
-        "layers.2.self_attn": "layers.1.self_attn",
-        "layers.3.self_attn": "layers.0.self_attn",
-    }
-    bind_kv_cache(ctx, [kv_cache], shared_kv_cache_layers)
-    assert ctx["layers.0.self_attn"].kv_cache[0] is kv_cache[0]
-    assert ctx["layers.1.self_attn"].kv_cache[0] is kv_cache[1]
-    assert ctx["layers.2.self_attn"].kv_cache[0] is kv_cache[1]
-    assert ctx["layers.3.self_attn"].kv_cache[0] is kv_cache[0]
-
-
-def test_bind_kv_cache_non_attention():
-    from vllm.attention import Attention
-
-    # example from Jamba PP=2
-    ctx = {
-        "model.layers.20.attn": Attention(32, 128, 0.1),
-        "model.layers.28.attn": Attention(32, 128, 0.1),
-    }
-    kv_cache = [
-        torch.zeros((1,)),
-        torch.zeros((1,)),
-    ]
-    bind_kv_cache(ctx, [kv_cache])
-    assert ctx["model.layers.20.attn"].kv_cache[0] is kv_cache[0]
-    assert ctx["model.layers.28.attn"].kv_cache[0] is kv_cache[1]
-
-
-def test_bind_kv_cache_pp():
-    with patch("vllm.utils.torch_utils.cuda_device_count_stateless", lambda: 2):
-        # this test runs with 1 GPU, but we simulate 2 GPUs
-        cfg = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=2))
-    with set_current_vllm_config(cfg):
-        from vllm.attention import Attention
-
-        ctx = {
-            "layers.0.self_attn": Attention(32, 128, 0.1),
-        }
-        kv_cache = [[torch.zeros((1,))], [torch.zeros((1,))]]
-        bind_kv_cache(ctx, kv_cache)
-        assert ctx["layers.0.self_attn"].kv_cache[0] is kv_cache[0][0]
-        assert ctx["layers.0.self_attn"].kv_cache[1] is kv_cache[1][0]
-
-
 def test_model_specification(
     parser_with_config, cli_config_file, cli_config_file_with_model
 ):
@@ -14,7 +14,7 @@ from vllm.utils.serial_utils import (

 @pytest.mark.parametrize("endianness", ENDIANNESS)
 @pytest.mark.parametrize("embed_dtype", EMBED_DTYPE_TO_TORCH_DTYPE.keys())
-@torch.inference_mode
+@torch.inference_mode()
 def test_encode_and_decode(embed_dtype: str, endianness: str):
     for i in range(10):
         tensor = torch.rand(2, 3, 5, 7, 11, 13, device="cpu", dtype=torch.float32)
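
`torch.inference_mode` works both as a context manager and as a decorator; the hunk above switches the test to the explicit call form `@torch.inference_mode()`. A brief sketch of the two standard PyTorch usages:

```python
import torch


@torch.inference_mode()
def double(x: torch.Tensor) -> torch.Tensor:
    # No autograd state is recorded inside inference mode.
    return x * 2


with torch.inference_mode():
    y = double(torch.rand(3))
assert not y.requires_grad
```
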
@@ -42,7 +42,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import (
 )
 from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
 from vllm.logger import init_logger
-from vllm.utils import cdiv
+from vllm.utils.math_utils import cdiv

 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionMetadata
@@ -44,7 +44,8 @@ from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration.utils impo
 )
 from vllm.distributed.parallel_state import get_tensor_model_parallel_rank, get_tp_group
 from vllm.sampling_params import SamplingParams
-from vllm.utils import cdiv, get_kv_cache_torch_dtype
+from vllm.utils import get_kv_cache_torch_dtype
+from vllm.utils.math_utils import cdiv
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.version import __version__ as VLLM_VERSION
@@ -51,7 +51,8 @@ from vllm.entrypoints.utils import (
     with_cancellation,
 )
 from vllm.logger import init_logger
-from vllm.utils import FlexibleArgumentParser, set_ulimit
+from vllm.utils import set_ulimit
+from vllm.utils.argparse_utils import FlexibleArgumentParser
 from vllm.utils.network_utils import is_valid_ipv6_address
 from vllm.version import __version__ as VLLM_VERSION
@@ -18,7 +18,7 @@ from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_se
 from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import FlexibleArgumentParser
+from vllm.utils.argparse_utils import FlexibleArgumentParser
 from vllm.utils.network_utils import get_tcp_uri
 from vllm.utils.system_utils import decorate_logs, set_process_title
 from vllm.v1.engine.core import EngineCoreProc
@@ -108,7 +108,8 @@ from vllm.entrypoints.utils import (
 from vllm.logger import init_logger
 from vllm.reasoning import ReasoningParserManager
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import Device, FlexibleArgumentParser, set_ulimit
+from vllm.utils import Device, set_ulimit
+from vllm.utils.argparse_utils import FlexibleArgumentParser
 from vllm.utils.network_utils import is_valid_ipv6_address
 from vllm.utils.system_utils import decorate_logs
 from vllm.v1.engine.exceptions import EngineDeadError
@@ -13,7 +13,7 @@ import torch

 from vllm.lora.layers import LoRAMapping
 from vllm.triton_utils import HAS_TRITON, triton
-from vllm.utils import round_up
+from vllm.utils.math_utils import round_up

 if HAS_TRITON:
     from vllm.lora.ops.triton_ops import (
@@ -48,9 +48,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_s
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
-from vllm.utils import round_up
 from vllm.utils.flashinfer import has_flashinfer
 from vllm.utils.import_utils import has_triton_kernels
+from vllm.utils.math_utils import round_up
 from vllm.utils.torch_utils import is_torch_equal_or_newer

 logger = init_logger(__name__)
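
`round_up` now also comes from `vllm.utils.math_utils`, alongside `cdiv`, `round_down`, and the power-of-two helpers listed later in this commit. Assuming the usual vLLM semantics (align a value up or down to a multiple), a tiny sketch:

```python
from vllm.utils.math_utils import round_down, round_up

# Assumed semantics: align to the nearest multiple of the given alignment.
assert round_up(10, 8) == 16
assert round_up(16, 8) == 16
assert round_down(10, 8) == 8
```
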
@@ -12,12 +12,10 @@ import signal
 import sys
 import tempfile
 import threading
-import traceback
 import uuid
 import warnings
-import weakref
 from collections.abc import Callable
-from functools import cache, partial, wraps
+from functools import partial, wraps
 from typing import TYPE_CHECKING, Any, TypeVar

 import cloudpickle
@@ -28,34 +26,6 @@ import vllm.envs as envs
 from vllm.logger import enable_trace_function_call, init_logger
 from vllm.ray.lazy_utils import is_in_ray_actor

-# Import utilities from specialized modules for backward compatibility
-from vllm.utils.argparse_utils import (
-    FlexibleArgumentParser,
-    SortedHelpFormatter,
-    StoreBoolean,
-)
-from vllm.utils.math_utils import (
-    cdiv,
-    next_power_of_2,
-    prev_power_of_2,
-    round_down,
-    round_up,
-)
-from vllm.utils.platform_utils import cuda_is_initialized, xpu_is_initialized
-
-__all__ = [
-    # Argparse utilities
-    "FlexibleArgumentParser",
-    "SortedHelpFormatter",
-    "StoreBoolean",
-    # Math utilities
-    "cdiv",
-    "next_power_of_2",
-    "prev_power_of_2",
-    "round_down",
-    "round_up",
-]
-
 _DEPRECATED_MAPPINGS = {
     "cprofile": "profiling",
     "cprofile_context": "profiling",
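
With the backward-compatibility re-exports and `__all__` gone, `_DEPRECATED_MAPPINGS` is what remains: it maps old top-level attribute names to the submodule that now hosts them, and is consumed by the module-level `__getattr__`/`__dir__` hooks (the `__dir__` definition appears in the next hunk). The hook implementation itself is not shown here; the following is only a hypothetical sketch of the general pattern, not vLLM's exact code:

```python
# Hypothetical sketch of a lazy deprecation shim driven by a mapping like
# _DEPRECATED_MAPPINGS above; the real hook may differ in details.
import importlib
import warnings

_DEPRECATED_MAPPINGS = {"cprofile": "profiling", "cprofile_context": "profiling"}


def __getattr__(name: str):
    if name in _DEPRECATED_MAPPINGS:
        submodule = _DEPRECATED_MAPPINGS[name]
        warnings.warn(
            f"vllm.utils.{name} is deprecated; import it from "
            f"vllm.utils.{submodule} instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return getattr(importlib.import_module(f"vllm.utils.{submodule}"), name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```
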
@@ -84,12 +54,8 @@ def __dir__() -> list[str]:


 if TYPE_CHECKING:
-    from argparse import Namespace
-
     from vllm.config import ModelConfig, VllmConfig
 else:
-    Namespace = object
-
     ModelConfig = object
     VllmConfig = object

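
The `Namespace` fallback is dropped because nothing in this module needs it once the argparse helpers live elsewhere; the remaining `object` placeholders keep annotations resolvable at runtime without importing the heavyweight config module. A minimal sketch of the pattern (the `warmup` function is hypothetical, for illustration only):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = object  # cheap runtime placeholder; real type only for type checkers


def warmup(config: VllmConfig) -> None:  # hypothetical function for illustration
    ...
```
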
@@ -149,37 +115,35 @@ class Counter:
         self.counter = 0


+class AtomicCounter:
+    """An atomic, thread-safe counter"""
+
+    def __init__(self, initial=0):
+        """Initialize a new atomic counter to given initial value"""
+        self._value = initial
+        self._lock = threading.Lock()
+
+    def inc(self, num=1):
+        """Atomically increment the counter by num and return the new value"""
+        with self._lock:
+            self._value += num
+            return self._value
+
+    def dec(self, num=1):
+        """Atomically decrement the counter by num and return the new value"""
+        with self._lock:
+            self._value -= num
+            return self._value
+
+    @property
+    def value(self):
+        return self._value
+
+
 def random_uuid() -> str:
     return str(uuid.uuid4().hex)


-def update_environment_variables(envs: dict[str, str]):
-    for k, v in envs.items():
-        if k in os.environ and os.environ[k] != v:
-            logger.warning(
-                "Overwriting environment variable %s from '%s' to '%s'",
-                k,
-                os.environ[k],
-                v,
-            )
-        os.environ[k] = v
-
-
-@cache
-def is_pin_memory_available() -> bool:
-    from vllm.platforms import current_platform
-
-    return current_platform.is_pin_memory_available()
-
-
-@cache
-def is_uva_available() -> bool:
-    """Check if Unified Virtual Addressing (UVA) is available."""
-    # UVA requires pinned memory.
-    # TODO: Add more requirements for UVA if needed.
-    return is_pin_memory_available()
-
-
 # TODO: This function can be removed if transformer_modules classes are
 # serialized by value when communicating between processes
 def init_cached_hf_modules() -> None:
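
`AtomicCounter` appears to be relocated next to `Counter` here (the copy further down in the file is deleted in the next hunk). A usage sketch, assuming it remains importable from `vllm.utils` after this change:

```python
from concurrent.futures import ThreadPoolExecutor

from vllm.utils import AtomicCounter  # assumed import path after this change

counter = AtomicCounter(initial=0)


def worker() -> None:
    for _ in range(1000):
        counter.inc()


with ThreadPoolExecutor(max_workers=8) as pool:
    for _ in range(8):
        pool.submit(worker)

assert counter.value == 8000
```
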
@@ -212,47 +176,6 @@ def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None:
     enable_trace_function_call(log_path)


-def weak_bind(
-    bound_method: Callable[..., Any],
-) -> Callable[..., None]:
-    """Make an instance method that weakly references
-    its associated instance and no-ops once that
-    instance is collected."""
-    ref = weakref.ref(bound_method.__self__)  # type: ignore[attr-defined]
-    unbound = bound_method.__func__  # type: ignore[attr-defined]
-
-    def weak_bound(*args, **kwargs) -> None:
-        if inst := ref():
-            unbound(inst, *args, **kwargs)
-
-    return weak_bound
-
-
-class AtomicCounter:
-    """An atomic, thread-safe counter"""
-
-    def __init__(self, initial=0):
-        """Initialize a new atomic counter to given initial value"""
-        self._value = initial
-        self._lock = threading.Lock()
-
-    def inc(self, num=1):
-        """Atomically increment the counter by num and return the new value"""
-        with self._lock:
-            self._value += num
-            return self._value
-
-    def dec(self, num=1):
-        """Atomically decrement the counter by num and return the new value"""
-        with self._lock:
-            self._value -= num
-            return self._value
-
-    @property
-    def value(self):
-        return self._value
-
-
 def kill_process_tree(pid: int):
     """
     Kills all descendant processes of the given pid by sending SIGKILL.
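
`weak_bind` wrapped a bound method so that the wrapper holds only a weak reference to the owning instance and silently no-ops after that instance is garbage collected, which avoids keeping long-lived callbacks alive through reference cycles. An illustration using the definition removed above (the `Engine` class is just a stand-in):

```python
import gc


class Engine:
    def step(self) -> None:
        print("stepping")


engine = Engine()
callback = weak_bind(engine.step)  # weak_bind as defined in the removed code above

callback()  # prints "stepping"
del engine
gc.collect()
callback()  # no-op once the owning instance has been collected
```
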
@@ -303,13 +226,6 @@ def set_ulimit(target_soft_limit=65535):
            )


-# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/utils.py#L28 # noqa: E501
-def get_exception_traceback():
-    etype, value, tb = sys.exc_info()
-    err_str = "".join(traceback.format_exception(etype, value, tb))
-    return err_str
-
-
 def _maybe_force_spawn():
     """Check if we need to force the use of the `spawn` multiprocessing start
     method.
@@ -327,6 +243,8 @@ def _maybe_force_spawn():
         os.environ["RAY_ADDRESS"] = ray.get_runtime_context().gcs_address
         reasons.append("In a Ray actor and can only be spawned")

+    from .platform_utils import cuda_is_initialized, xpu_is_initialized
+
     if cuda_is_initialized():
         reasons.append("CUDA is initialized")
     elif xpu_is_initialized():
@@ -356,55 +274,6 @@ def get_mp_context():
     return multiprocessing.get_context(mp_method)


-def bind_kv_cache(
-    ctx: dict[str, Any],
-    kv_cache: list[list[torch.Tensor]],  # [virtual_engine][layer_index]
-    shared_kv_cache_layers: dict[str, str] | None = None,
-) -> None:
-    # Bind the kv_cache tensor to Attention modules, similar to
-    # ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)]
-    # Special things handled here:
-    # 1. Some models have non-attention layers, e.g., Jamba
-    # 2. Pipeline parallelism, each rank only has a subset of layers
-    # 3. Encoder attention has no kv cache
-    # 4. Encoder-decoder models, encoder-decoder attention and decoder-only
-    #    attention of the same layer (e.g., bart's decoder.layers.1.self_attn
-    #    and decoder.layers.1.encoder_attn) is mapped to the same kv cache
-    #    tensor
-    # 5. Some models have attention layers that share kv cache with previous
-    #    layers, this is specified through shared_kv_cache_layers
-    if shared_kv_cache_layers is None:
-        shared_kv_cache_layers = {}
-    from vllm.attention import AttentionType
-    from vllm.model_executor.models.utils import extract_layer_index
-
-    layer_need_kv_cache = [
-        layer_name
-        for layer_name in ctx
-        if (
-            hasattr(ctx[layer_name], "attn_type")
-            and ctx[layer_name].attn_type
-            in (AttentionType.DECODER, AttentionType.ENCODER_DECODER)
-        )
-        and ctx[layer_name].kv_sharing_target_layer_name is None
-    ]
-    layer_index_sorted = sorted(
-        set(extract_layer_index(layer_name) for layer_name in layer_need_kv_cache)
-    )
-    for layer_name in layer_need_kv_cache:
-        kv_cache_idx = layer_index_sorted.index(extract_layer_index(layer_name))
-        forward_ctx = ctx[layer_name]
-        assert len(forward_ctx.kv_cache) == len(kv_cache)
-        for ve, ve_kv_cache in enumerate(kv_cache):
-            forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]
-    if shared_kv_cache_layers is not None:
-        for layer_name, target_layer_name in shared_kv_cache_layers.items():
-            assert extract_layer_index(target_layer_name) < extract_layer_index(
-                layer_name
-            ), "v0 doesn't support interleaving kv sharing"
-            ctx[layer_name].kv_cache = ctx[target_layer_name].kv_cache
-
-
 def run_method(
     obj: Any,
     method: str | bytes | Callable,
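
`bind_kv_cache` (a v0-era helper, per the assertion message in the removed code) attached per-virtual-engine KV cache tensors to the `Attention` modules recorded in the forward context, keyed by layer name; the tests removed earlier in this commit exercised exactly that behaviour. A minimal sketch mirroring those tests:

```python
import torch

from vllm.attention import Attention

# Mirrors the removed test_bind_kv_cache: two decoder layers, one virtual engine.
ctx = {
    "layers.0.self_attn": Attention(32, 128, 0.1),
    "layers.1.self_attn": Attention(32, 128, 0.1),
}
kv_cache = [torch.zeros((1,)), torch.zeros((1,))]

bind_kv_cache(ctx, [kv_cache])  # the helper deleted in this hunk

assert ctx["layers.0.self_attn"].kv_cache[0] is kv_cache[0]
assert ctx["layers.1.self_attn"].kv_cache[0] is kv_cache[1]
```
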