[Misc] Clean up utils (#27552)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -65,7 +65,9 @@ ChatCommand = auto_mock("vllm.entrypoints.cli.openai", "ChatCommand")
 CompleteCommand = auto_mock("vllm.entrypoints.cli.openai", "CompleteCommand")
 cli_args = auto_mock("vllm.entrypoints.openai", "cli_args")
 run_batch = auto_mock("vllm.entrypoints.openai", "run_batch")
-FlexibleArgumentParser = auto_mock("vllm.utils", "FlexibleArgumentParser")
+FlexibleArgumentParser = auto_mock(
+    "vllm.utils.argparse_utils", "FlexibleArgumentParser"
+)
 
 
 class MarkdownFormatter(HelpFormatter):
@@ -45,9 +45,7 @@ from vllm.entrypoints.cli.serve import ServeSubcommand
 from vllm.model_executor.model_loader import get_model_loader
 from vllm.platforms import current_platform
 from vllm.transformers_utils.tokenizer import get_tokenizer
-from vllm.utils import (
-    FlexibleArgumentParser,
-)
+from vllm.utils.argparse_utils import FlexibleArgumentParser
 from vllm.utils.mem_constants import GB_bytes
 from vllm.utils.network_utils import get_open_port
 from vllm.utils.torch_utils import cuda_device_count_stateless
@@ -4,23 +4,15 @@
 
 import json
 import os
-import tempfile
-from pathlib import Path
-from unittest.mock import patch
 
 import pytest
-import torch
 import yaml
 from transformers import AutoTokenizer
 
-from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
 from vllm.transformers_utils.detokenizer_utils import convert_ids_list_to_tokens
 
-from vllm.utils import (
-    FlexibleArgumentParser,
-    bind_kv_cache,
-)
-from ..utils import create_new_process_for_each_test, flat_product
+from vllm.utils.argparse_utils import FlexibleArgumentParser
+from ..utils import flat_product
 
 
 # Tests for FlexibleArgumentParser
@@ -256,87 +248,6 @@ def test_duplicate_dict_args(caplog_vllm, parser):
     assert "-O.mode" in caplog_vllm.text
 
 
-def test_bind_kv_cache():
-    from vllm.attention import Attention
-
-    ctx = {
-        "layers.0.self_attn": Attention(32, 128, 0.1),
-        "layers.1.self_attn": Attention(32, 128, 0.1),
-        "layers.2.self_attn": Attention(32, 128, 0.1),
-        "layers.3.self_attn": Attention(32, 128, 0.1),
-    }
-    kv_cache = [
-        torch.zeros((1,)),
-        torch.zeros((1,)),
-        torch.zeros((1,)),
-        torch.zeros((1,)),
-    ]
-    bind_kv_cache(ctx, [kv_cache])
-    assert ctx["layers.0.self_attn"].kv_cache[0] is kv_cache[0]
-    assert ctx["layers.1.self_attn"].kv_cache[0] is kv_cache[1]
-    assert ctx["layers.2.self_attn"].kv_cache[0] is kv_cache[2]
-    assert ctx["layers.3.self_attn"].kv_cache[0] is kv_cache[3]
-
-
-def test_bind_kv_cache_kv_sharing():
-    from vllm.attention import Attention
-
-    ctx = {
-        "layers.0.self_attn": Attention(32, 128, 0.1),
-        "layers.1.self_attn": Attention(32, 128, 0.1),
-        "layers.2.self_attn": Attention(32, 128, 0.1),
-        "layers.3.self_attn": Attention(32, 128, 0.1),
-    }
-    kv_cache = [
-        torch.zeros((1,)),
-        torch.zeros((1,)),
-        torch.zeros((1,)),
-        torch.zeros((1,)),
-    ]
-    shared_kv_cache_layers = {
-        "layers.2.self_attn": "layers.1.self_attn",
-        "layers.3.self_attn": "layers.0.self_attn",
-    }
-    bind_kv_cache(ctx, [kv_cache], shared_kv_cache_layers)
-    assert ctx["layers.0.self_attn"].kv_cache[0] is kv_cache[0]
-    assert ctx["layers.1.self_attn"].kv_cache[0] is kv_cache[1]
-    assert ctx["layers.2.self_attn"].kv_cache[0] is kv_cache[1]
-    assert ctx["layers.3.self_attn"].kv_cache[0] is kv_cache[0]
-
-
-def test_bind_kv_cache_non_attention():
-    from vllm.attention import Attention
-
-    # example from Jamba PP=2
-    ctx = {
-        "model.layers.20.attn": Attention(32, 128, 0.1),
-        "model.layers.28.attn": Attention(32, 128, 0.1),
-    }
-    kv_cache = [
-        torch.zeros((1,)),
-        torch.zeros((1,)),
-    ]
-    bind_kv_cache(ctx, [kv_cache])
-    assert ctx["model.layers.20.attn"].kv_cache[0] is kv_cache[0]
-    assert ctx["model.layers.28.attn"].kv_cache[0] is kv_cache[1]
-
-
-def test_bind_kv_cache_pp():
-    with patch("vllm.utils.torch_utils.cuda_device_count_stateless", lambda: 2):
-        # this test runs with 1 GPU, but we simulate 2 GPUs
-        cfg = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=2))
-    with set_current_vllm_config(cfg):
-        from vllm.attention import Attention
-
-        ctx = {
-            "layers.0.self_attn": Attention(32, 128, 0.1),
-        }
-        kv_cache = [[torch.zeros((1,))], [torch.zeros((1,))]]
-        bind_kv_cache(ctx, kv_cache)
-        assert ctx["layers.0.self_attn"].kv_cache[0] is kv_cache[0][0]
-        assert ctx["layers.0.self_attn"].kv_cache[1] is kv_cache[1][0]
-
-
 def test_model_specification(
     parser_with_config, cli_config_file, cli_config_file_with_model
 ):
@@ -14,7 +14,7 @@ from vllm.utils.serial_utils import (
 
 @pytest.mark.parametrize("endianness", ENDIANNESS)
 @pytest.mark.parametrize("embed_dtype", EMBED_DTYPE_TO_TORCH_DTYPE.keys())
-@torch.inference_mode
+@torch.inference_mode()
 def test_encode_and_decode(embed_dtype: str, endianness: str):
     for i in range(10):
         tensor = torch.rand(2, 3, 5, 7, 11, 13, device="cpu", dtype=torch.float32)
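For context on the one-line fix above: torch.inference_mode is applied as a decorator in its instantiated form. A minimal sketch, illustrative only and not part of this diff:

```python
import torch


@torch.inference_mode()  # instantiate the mode, then decorate
def double(x: torch.Tensor) -> torch.Tensor:
    # runs with autograd disabled
    return x * 2


x = torch.ones(3, requires_grad=True)
y = double(x)
assert not y.requires_grad
```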
@@ -42,7 +42,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import (
 )
 from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
 from vllm.logger import init_logger
-from vllm.utils import cdiv
+from vllm.utils.math_utils import cdiv
 
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionMetadata
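cdiv now comes from vllm.utils.math_utils; the call sites are otherwise unchanged. A sketch of the ceiling-division semantics they rely on, using a hypothetical stand-in rather than the real helper:

```python
# Hypothetical stand-in for vllm.utils.math_utils.cdiv, shown only to
# illustrate the assumed semantics (ceiling division).
def cdiv(a: int, b: int) -> int:
    return -(a // -b)


assert cdiv(10, 4) == 3  # e.g. 10 tokens in blocks of 4 -> 3 blocks
assert cdiv(8, 4) == 2
```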
@@ -44,7 +44,8 @@ from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration.utils impo
 )
 from vllm.distributed.parallel_state import get_tensor_model_parallel_rank, get_tp_group
 from vllm.sampling_params import SamplingParams
-from vllm.utils import cdiv, get_kv_cache_torch_dtype
+from vllm.utils import get_kv_cache_torch_dtype
+from vllm.utils.math_utils import cdiv
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.version import __version__ as VLLM_VERSION
 
@@ -51,7 +51,8 @@ from vllm.entrypoints.utils import (
     with_cancellation,
 )
 from vllm.logger import init_logger
-from vllm.utils import FlexibleArgumentParser, set_ulimit
+from vllm.utils import set_ulimit
+from vllm.utils.argparse_utils import FlexibleArgumentParser
 from vllm.utils.network_utils import is_valid_ipv6_address
 from vllm.version import __version__ as VLLM_VERSION
 
@@ -18,7 +18,7 @@ from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_se
 from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import FlexibleArgumentParser
+from vllm.utils.argparse_utils import FlexibleArgumentParser
 from vllm.utils.network_utils import get_tcp_uri
 from vllm.utils.system_utils import decorate_logs, set_process_title
 from vllm.v1.engine.core import EngineCoreProc
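Call sites only need the new import path here; assuming FlexibleArgumentParser keeps its argparse-compatible interface, usage stays the same. A minimal sketch with made-up arguments:

```python
from vllm.utils.argparse_utils import FlexibleArgumentParser

parser = FlexibleArgumentParser(description="toy example")
parser.add_argument("--model", type=str, default="facebook/opt-125m")
args = parser.parse_args(["--model", "facebook/opt-125m"])
assert args.model == "facebook/opt-125m"
```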
@@ -108,7 +108,8 @@ from vllm.entrypoints.utils import (
 from vllm.logger import init_logger
 from vllm.reasoning import ReasoningParserManager
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import Device, FlexibleArgumentParser, set_ulimit
+from vllm.utils import Device, set_ulimit
+from vllm.utils.argparse_utils import FlexibleArgumentParser
 from vllm.utils.network_utils import is_valid_ipv6_address
 from vllm.utils.system_utils import decorate_logs
 from vllm.v1.engine.exceptions import EngineDeadError
@@ -13,7 +13,7 @@ import torch
 
 from vllm.lora.layers import LoRAMapping
 from vllm.triton_utils import HAS_TRITON, triton
-from vllm.utils import round_up
+from vllm.utils.math_utils import round_up
 
 if HAS_TRITON:
     from vllm.lora.ops.triton_ops import (
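round_up makes the same move to vllm.utils.math_utils. A sketch of the semantics these call sites assume (round x up to the next multiple of y), again with a hypothetical stand-in:

```python
# Hypothetical stand-in for vllm.utils.math_utils.round_up, for illustration.
def round_up(x: int, y: int) -> int:
    return ((x + y - 1) // y) * y


assert round_up(10, 8) == 16
assert round_up(16, 8) == 16
```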
@@ -48,9 +48,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_s
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
-from vllm.utils import round_up
 from vllm.utils.flashinfer import has_flashinfer
 from vllm.utils.import_utils import has_triton_kernels
+from vllm.utils.math_utils import round_up
 from vllm.utils.torch_utils import is_torch_equal_or_newer
 
 logger = init_logger(__name__)
@@ -12,12 +12,10 @@ import signal
 import sys
 import tempfile
 import threading
-import traceback
 import uuid
 import warnings
-import weakref
 from collections.abc import Callable
-from functools import cache, partial, wraps
+from functools import partial, wraps
 from typing import TYPE_CHECKING, Any, TypeVar
 
 import cloudpickle
@@ -28,34 +26,6 @@ import vllm.envs as envs
 from vllm.logger import enable_trace_function_call, init_logger
 from vllm.ray.lazy_utils import is_in_ray_actor
 
-# Import utilities from specialized modules for backward compatibility
-from vllm.utils.argparse_utils import (
-    FlexibleArgumentParser,
-    SortedHelpFormatter,
-    StoreBoolean,
-)
-from vllm.utils.math_utils import (
-    cdiv,
-    next_power_of_2,
-    prev_power_of_2,
-    round_down,
-    round_up,
-)
-from vllm.utils.platform_utils import cuda_is_initialized, xpu_is_initialized
-
-__all__ = [
-    # Argparse utilities
-    "FlexibleArgumentParser",
-    "SortedHelpFormatter",
-    "StoreBoolean",
-    # Math utilities
-    "cdiv",
-    "next_power_of_2",
-    "prev_power_of_2",
-    "round_down",
-    "round_up",
-]
-
 _DEPRECATED_MAPPINGS = {
     "cprofile": "profiling",
     "cprofile_context": "profiling",
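With the re-exports and __all__ removed, vllm/utils/__init__.py keeps only the _DEPRECATED_MAPPINGS table (the module-level __dir__ appears in the next hunk). A sketch of how such a table is typically served through a module-level __getattr__ (PEP 562); illustrative only, not necessarily the exact vLLM implementation:

```python
import importlib
import warnings

_DEPRECATED_MAPPINGS = {
    "cprofile": "profiling",
    "cprofile_context": "profiling",
}


def __getattr__(name: str):
    # Forward deprecated names to their new submodule and emit a warning.
    if name in _DEPRECATED_MAPPINGS:
        submodule = _DEPRECATED_MAPPINGS[name]
        warnings.warn(
            f"vllm.utils.{name} is deprecated; use vllm.utils.{submodule}.{name}",
            DeprecationWarning,
            stacklevel=2,
        )
        return getattr(importlib.import_module(f"vllm.utils.{submodule}"), name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```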
@@ -84,12 +54,8 @@ def __dir__() -> list[str]:
 
 
 if TYPE_CHECKING:
-    from argparse import Namespace
-
     from vllm.config import ModelConfig, VllmConfig
 else:
-    Namespace = object
-
     ModelConfig = object
     VllmConfig = object
 
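Only the vLLM config types keep the TYPE_CHECKING fallback now that Namespace is unused. A minimal sketch of the pattern, with a hypothetical function name:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only for static analysis; never at runtime.
    from vllm.config import VllmConfig
else:
    # Keeps the name resolvable without importing vllm.config eagerly.
    VllmConfig = object


def describe(config: "VllmConfig") -> str:
    return type(config).__name__
```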
@@ -149,37 +115,35 @@ class Counter:
         self.counter = 0
 
 
+class AtomicCounter:
+    """An atomic, thread-safe counter"""
+
+    def __init__(self, initial=0):
+        """Initialize a new atomic counter to given initial value"""
+        self._value = initial
+        self._lock = threading.Lock()
+
+    def inc(self, num=1):
+        """Atomically increment the counter by num and return the new value"""
+        with self._lock:
+            self._value += num
+            return self._value
+
+    def dec(self, num=1):
+        """Atomically decrement the counter by num and return the new value"""
+        with self._lock:
+            self._value -= num
+            return self._value
+
+    @property
+    def value(self):
+        return self._value
+
+
 def random_uuid() -> str:
     return str(uuid.uuid4().hex)
 
 
-def update_environment_variables(envs: dict[str, str]):
-    for k, v in envs.items():
-        if k in os.environ and os.environ[k] != v:
-            logger.warning(
-                "Overwriting environment variable %s from '%s' to '%s'",
-                k,
-                os.environ[k],
-                v,
-            )
-        os.environ[k] = v
-
-
-@cache
-def is_pin_memory_available() -> bool:
-    from vllm.platforms import current_platform
-
-    return current_platform.is_pin_memory_available()
-
-
-@cache
-def is_uva_available() -> bool:
-    """Check if Unified Virtual Addressing (UVA) is available."""
-    # UVA requires pinned memory.
-    # TODO: Add more requirements for UVA if needed.
-    return is_pin_memory_available()
-
-
 # TODO: This function can be removed if transformer_modules classes are
 # serialized by value when communicating between processes
 def init_cached_hf_modules() -> None:
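AtomicCounter is unchanged, only relocated next to Counter. A small usage sketch, assuming it stays importable from vllm.utils as shown above:

```python
from concurrent.futures import ThreadPoolExecutor

from vllm.utils import AtomicCounter

counter = AtomicCounter()
with ThreadPoolExecutor(max_workers=4) as pool:
    for _ in range(1000):
        pool.submit(counter.inc)  # each increment takes the internal lock

assert counter.value == 1000
```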
@@ -212,47 +176,6 @@ def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None:
         enable_trace_function_call(log_path)
 
 
-def weak_bind(
-    bound_method: Callable[..., Any],
-) -> Callable[..., None]:
-    """Make an instance method that weakly references
-    its associated instance and no-ops once that
-    instance is collected."""
-    ref = weakref.ref(bound_method.__self__)  # type: ignore[attr-defined]
-    unbound = bound_method.__func__  # type: ignore[attr-defined]
-
-    def weak_bound(*args, **kwargs) -> None:
-        if inst := ref():
-            unbound(inst, *args, **kwargs)
-
-    return weak_bound
-
-
-class AtomicCounter:
-    """An atomic, thread-safe counter"""
-
-    def __init__(self, initial=0):
-        """Initialize a new atomic counter to given initial value"""
-        self._value = initial
-        self._lock = threading.Lock()
-
-    def inc(self, num=1):
-        """Atomically increment the counter by num and return the new value"""
-        with self._lock:
-            self._value += num
-            return self._value
-
-    def dec(self, num=1):
-        """Atomically decrement the counter by num and return the new value"""
-        with self._lock:
-            self._value -= num
-            return self._value
-
-    @property
-    def value(self):
-        return self._value
-
-
 def kill_process_tree(pid: int):
     """
     Kills all descendant processes of the given pid by sending SIGKILL.
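For reference, the weak_bind helper deleted above behaved roughly as follows; this sketch is derived from the removed code, and Worker is a made-up class:

```python
import gc


class Worker:
    def ping(self, msg: str) -> None:
        print("ping:", msg)


w = Worker()
cb = weak_bind(w.ping)  # weak_bind as defined in the removed block above
cb("alive")             # forwards to w.ping while the instance is alive
del w
gc.collect()
cb("gone")              # instance collected -> the call becomes a no-op
```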
@@ -303,13 +226,6 @@ def set_ulimit(target_soft_limit=65535):
             )
 
 
-# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/utils.py#L28 # noqa: E501
-def get_exception_traceback():
-    etype, value, tb = sys.exc_info()
-    err_str = "".join(traceback.format_exception(etype, value, tb))
-    return err_str
-
-
 def _maybe_force_spawn():
     """Check if we need to force the use of the `spawn` multiprocessing start
     method.
@@ -327,6 +243,8 @@ def _maybe_force_spawn():
         os.environ["RAY_ADDRESS"] = ray.get_runtime_context().gcs_address
         reasons.append("In a Ray actor and can only be spawned")
 
+    from .platform_utils import cuda_is_initialized, xpu_is_initialized
+
     if cuda_is_initialized():
         reasons.append("CUDA is initialized")
     elif xpu_is_initialized():
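The platform probes are now imported inside the function instead of at module import time. A minimal sketch of that deferred-import pattern, with a hypothetical wrapper name and assuming the goal is simply to delay loading the platform code until it is needed:

```python
def spawn_reasons() -> list[str]:
    # Importing here defers loading vllm.utils.platform_utils until the
    # check actually runs.
    from vllm.utils.platform_utils import cuda_is_initialized, xpu_is_initialized

    reasons: list[str] = []
    if cuda_is_initialized():
        reasons.append("CUDA is initialized")
    elif xpu_is_initialized():
        reasons.append("XPU is initialized")
    return reasons
```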
@@ -356,55 +274,6 @@ def get_mp_context():
     return multiprocessing.get_context(mp_method)
 
 
-def bind_kv_cache(
-    ctx: dict[str, Any],
-    kv_cache: list[list[torch.Tensor]],  # [virtual_engine][layer_index]
-    shared_kv_cache_layers: dict[str, str] | None = None,
-) -> None:
-    # Bind the kv_cache tensor to Attention modules, similar to
-    # ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)]
-    # Special things handled here:
-    # 1. Some models have non-attention layers, e.g., Jamba
-    # 2. Pipeline parallelism, each rank only has a subset of layers
-    # 3. Encoder attention has no kv cache
-    # 4. Encoder-decoder models, encoder-decoder attention and decoder-only
-    #    attention of the same layer (e.g., bart's decoder.layers.1.self_attn
-    #    and decoder.layers.1.encoder_attn) is mapped to the same kv cache
-    #    tensor
-    # 5. Some models have attention layers that share kv cache with previous
-    #    layers, this is specified through shared_kv_cache_layers
-    if shared_kv_cache_layers is None:
-        shared_kv_cache_layers = {}
-    from vllm.attention import AttentionType
-    from vllm.model_executor.models.utils import extract_layer_index
-
-    layer_need_kv_cache = [
-        layer_name
-        for layer_name in ctx
-        if (
-            hasattr(ctx[layer_name], "attn_type")
-            and ctx[layer_name].attn_type
-            in (AttentionType.DECODER, AttentionType.ENCODER_DECODER)
-        )
-        and ctx[layer_name].kv_sharing_target_layer_name is None
-    ]
-    layer_index_sorted = sorted(
-        set(extract_layer_index(layer_name) for layer_name in layer_need_kv_cache)
-    )
-    for layer_name in layer_need_kv_cache:
-        kv_cache_idx = layer_index_sorted.index(extract_layer_index(layer_name))
-        forward_ctx = ctx[layer_name]
-        assert len(forward_ctx.kv_cache) == len(kv_cache)
-        for ve, ve_kv_cache in enumerate(kv_cache):
-            forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]
-    if shared_kv_cache_layers is not None:
-        for layer_name, target_layer_name in shared_kv_cache_layers.items():
-            assert extract_layer_index(target_layer_name) < extract_layer_index(
-                layer_name
-            ), "v0 doesn't support interleaving kv sharing"
-            ctx[layer_name].kv_cache = ctx[target_layer_name].kv_cache
-
-
 def run_method(
     obj: Any,
     method: str | bytes | Callable,