[V0 Deprecation] Remove V0 executors (#27142)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@ -157,11 +157,9 @@ def test_models_distributed(
|
||||
and distributed_executor_backend == "ray"
|
||||
and attention_backend == ""
|
||||
and test_suite == "L4"
|
||||
and enable_prompt_embeds
|
||||
): # noqa
|
||||
if enable_prompt_embeds:
|
||||
pytest.skip("enable_prompt_embeds does not work with ray compiled dag.")
|
||||
monkeypatch_context.setenv("VLLM_USE_RAY_SPMD_WORKER", "1")
|
||||
monkeypatch_context.setenv("VLLM_USE_RAY_COMPILED_DAG", "1")
|
||||
pytest.skip("enable_prompt_embeds does not work with ray compiled dag.")
|
||||
|
||||
if attention_backend:
|
||||
monkeypatch_context.setenv(
|
||||
|
||||
@ -18,8 +18,8 @@ from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
|
||||
from vllm import initialize_ray_cluster
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.executor.ray_utils import _wait_until_pg_removed
|
||||
from vllm.utils.network_utils import get_ip
|
||||
from vllm.v1.executor.ray_utils import _wait_until_pg_removed
|
||||
|
||||
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
|
||||
|
||||
|
||||
@ -305,10 +305,8 @@ def _compare_tp(
|
||||
common_args.extend(["--max-num-seqs", f"{max_num_seqs}"])
|
||||
|
||||
if distributed_backend == "ray":
|
||||
# For V1, test Ray Compiled Graph for all the tests
|
||||
# Test Ray Compiled Graph for all the tests
|
||||
pp_env = {
|
||||
"VLLM_USE_RAY_COMPILED_DAG": "1",
|
||||
"VLLM_USE_RAY_SPMD_WORKER": "1",
|
||||
"VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1",
|
||||
}
|
||||
# Temporary. Currently when zeromq + SPMD is used, it does not properly
|
||||
|
||||
@ -9,7 +9,7 @@ from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.model_executor.model_loader import tensorizer as tensorizer_mod
|
||||
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||
from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port
|
||||
from vllm.v1.executor.abstract import UniProcExecutor
|
||||
from vllm.v1.executor import UniProcExecutor
|
||||
from vllm.v1.worker.worker_base import WorkerWrapperBase
|
||||
|
||||
MODEL_REF = "facebook/opt-125m"
|
||||
|
||||
@ -15,7 +15,8 @@ from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import set_default_torch_num_threads
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.core import EngineCore
|
||||
from vllm.v1.executor.abstract import Executor, UniProcExecutor
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.executor.uniproc_executor import UniProcExecutor
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
|
||||
|
||||
@ -17,8 +17,6 @@ import regex as re
|
||||
# add to this list if absolutely necessary and after careful security review.
|
||||
ALLOWED_FILES = {
|
||||
# pickle
|
||||
"vllm/v1/serial_utils.py",
|
||||
"vllm/v1/executor/multiproc_executor.py",
|
||||
"vllm/multimodal/hasher.py",
|
||||
"vllm/transformers_utils/config.py",
|
||||
"vllm/model_executor/models/registry.py",
|
||||
@ -38,11 +36,13 @@ ALLOWED_FILES = {
|
||||
"benchmarks/cutlass_benchmarks/w8a8_benchmarks.py",
|
||||
"benchmarks/cutlass_benchmarks/sparse_benchmarks.py",
|
||||
# cloudpickle
|
||||
"vllm/executor/mp_distributed_executor.py",
|
||||
"vllm/executor/ray_distributed_executor.py",
|
||||
"vllm/v1/executor/multiproc_executor.py",
|
||||
"vllm/v1/executor/ray_executor.py",
|
||||
"vllm/entrypoints/llm.py",
|
||||
"vllm/utils/__init__.py",
|
||||
"tests/utils.py",
|
||||
# pickle and cloudpickle
|
||||
"vllm/v1/serial_utils.py",
|
||||
}
|
||||
|
||||
PICKLE_RE = re.compile(
|
||||
|
||||
@ -21,7 +21,7 @@ MODULE_ATTRS = {
|
||||
"AsyncLLMEngine": ".engine.async_llm_engine:AsyncLLMEngine",
|
||||
"LLMEngine": ".engine.llm_engine:LLMEngine",
|
||||
"LLM": ".entrypoints.llm:LLM",
|
||||
"initialize_ray_cluster": ".executor.ray_utils:initialize_ray_cluster",
|
||||
"initialize_ray_cluster": ".v1.executor.ray_utils:initialize_ray_cluster",
|
||||
"PromptType": ".inputs:PromptType",
|
||||
"TextPrompt": ".inputs:TextPrompt",
|
||||
"TokensPrompt": ".inputs:TokensPrompt",
|
||||
@ -45,7 +45,6 @@ if typing.TYPE_CHECKING:
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.entrypoints.llm import LLM
|
||||
from vllm.executor.ray_utils import initialize_ray_cluster
|
||||
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.outputs import (
|
||||
@ -62,6 +61,7 @@ if typing.TYPE_CHECKING:
|
||||
)
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.executor.ray_utils import initialize_ray_cluster
|
||||
|
||||
from ._bc_linter import bc_linter_include, bc_linter_skip
|
||||
else:
|
||||
|
||||
@ -25,11 +25,11 @@ if TYPE_CHECKING:
|
||||
from ray.runtime_env import RuntimeEnv
|
||||
from ray.util.placement_group import PlacementGroup
|
||||
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
from vllm.v1.executor import Executor
|
||||
else:
|
||||
RuntimeEnv = Any
|
||||
PlacementGroup = Any
|
||||
ExecutorBase = Any
|
||||
Executor = Any
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -189,7 +189,7 @@ class ParallelConfig:
|
||||
"""ray distributed model workers placement group."""
|
||||
|
||||
distributed_executor_backend: (
|
||||
str | DistributedExecutorBackend | type[ExecutorBase] | None
|
||||
str | DistributedExecutorBackend | type[Executor] | None
|
||||
) = None
|
||||
"""Backend to use for distributed model
|
||||
workers, either "ray" or "mp" (multiprocessing). If the product
|
||||
@ -511,7 +511,7 @@ class ParallelConfig:
|
||||
# We use multiprocessing by default if world_size fits on the
|
||||
# current node and we aren't in a ray placement group.
|
||||
|
||||
from vllm.executor import ray_utils
|
||||
from vllm.v1.executor import ray_utils
|
||||
|
||||
backend: DistributedExecutorBackend = "mp"
|
||||
ray_found = ray_utils.ray_is_available()
|
||||
@ -553,6 +553,12 @@ class ParallelConfig:
|
||||
if self.distributed_executor_backend is None and self.world_size == 1:
|
||||
self.distributed_executor_backend = "uni"
|
||||
|
||||
if self.max_parallel_loading_workers is not None:
|
||||
logger.warning(
|
||||
"max_parallel_loading_workers is currently "
|
||||
"not supported and will be ignored."
|
||||
)
|
||||
|
||||
@property
|
||||
def use_ray(self) -> bool:
|
||||
return self.distributed_executor_backend == "ray" or (
|
||||
@ -563,7 +569,7 @@ class ParallelConfig:
|
||||
@model_validator(mode="after")
|
||||
def _verify_args(self) -> Self:
|
||||
# Lazy import to avoid circular import
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
from vllm.v1.executor import Executor
|
||||
|
||||
# Enable batch invariance settings if requested
|
||||
if vllm_is_batch_invariant():
|
||||
@ -574,17 +580,17 @@ class ParallelConfig:
|
||||
and not isinstance(self.distributed_executor_backend, str)
|
||||
and not (
|
||||
isinstance(self.distributed_executor_backend, type)
|
||||
and issubclass(self.distributed_executor_backend, ExecutorBase)
|
||||
and issubclass(self.distributed_executor_backend, Executor)
|
||||
)
|
||||
):
|
||||
raise ValueError(
|
||||
"Unrecognized distributed executor backend "
|
||||
f"{self.distributed_executor_backend}. Supported "
|
||||
"values are 'ray', 'mp' 'uni', 'external_launcher', "
|
||||
" custom ExecutorBase subclass or its import path."
|
||||
" custom Executor subclass or its import path."
|
||||
)
|
||||
if self.use_ray:
|
||||
from vllm.executor import ray_utils
|
||||
from vllm.v1.executor import ray_utils
|
||||
|
||||
ray_utils.assert_ray_available()
|
||||
|
||||
|
||||
@ -107,12 +107,6 @@ class SchedulerConfig:
|
||||
NOTE: This is not currently configurable. It will be overridden by
|
||||
max_num_batched_tokens in case max multimodal embedding size is larger."""
|
||||
|
||||
send_delta_data: bool = False
|
||||
"""Private API. If used, scheduler sends delta data to
|
||||
workers instead of an entire data. It should be enabled only
|
||||
when SPMD worker architecture is enabled. I.e.,
|
||||
VLLM_USE_RAY_SPMD_WORKER=1"""
|
||||
|
||||
policy: SchedulerPolicy = "fcfs"
|
||||
"""The scheduling policy to use:\n
|
||||
- "fcfs" means first come first served, i.e. requests are handled in order
|
||||
|
||||
@ -31,7 +31,7 @@ if not USE_TPU_INFERENCE:
|
||||
)
|
||||
|
||||
if USE_RAY:
|
||||
from vllm.executor import ray_utils
|
||||
from vllm.v1.executor import ray_utils
|
||||
|
||||
|
||||
class TpuCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
@ -88,12 +88,12 @@ from vllm.utils.network_utils import get_ip
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.model_loader import LoadFormats
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.v1.executor import Executor
|
||||
else:
|
||||
ExecutorBase = Any
|
||||
Executor = Any
|
||||
QuantizationMethods = Any
|
||||
LoadFormats = Any
|
||||
UsageContext = Any
|
||||
@ -369,7 +369,7 @@ class EngineArgs:
|
||||
# is intended for expert use only. The API may change without
|
||||
# notice.
|
||||
distributed_executor_backend: (
|
||||
str | DistributedExecutorBackend | type[ExecutorBase] | None
|
||||
str | DistributedExecutorBackend | type[Executor] | None
|
||||
) = ParallelConfig.distributed_executor_backend
|
||||
# number of P/D disaggregation (or other disaggregation) workers
|
||||
pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
|
||||
@ -1549,7 +1549,6 @@ class EngineArgs:
|
||||
disable_chunked_mm_input=self.disable_chunked_mm_input,
|
||||
is_multimodal_model=model_config.is_multimodal_model,
|
||||
is_encoder_decoder=model_config.is_encoder_decoder,
|
||||
send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER and parallel_config.use_ray),
|
||||
policy=self.scheduling_policy,
|
||||
scheduler_cls=self.scheduler_cls,
|
||||
max_num_partial_prefills=self.max_num_partial_prefills,
|
||||
|
||||
@ -26,7 +26,7 @@ from vllm.utils import (
|
||||
from vllm.utils.network_utils import get_tcp_uri
|
||||
from vllm.v1.engine.core import EngineCoreProc
|
||||
from vllm.v1.engine.utils import CoreEngineProcManager, launch_core_engines
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.executor import Executor
|
||||
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
|
||||
from vllm.v1.utils import APIServerProcessManager, wait_for_completion_or_failure
|
||||
|
||||
|
||||
23
vllm/envs.py
23
vllm/envs.py
@ -56,8 +56,6 @@ if TYPE_CHECKING:
|
||||
VLLM_XLA_CHECK_RECOMPILATION: bool = False
|
||||
VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024
|
||||
VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING: bool = True
|
||||
VLLM_USE_RAY_SPMD_WORKER: bool = False
|
||||
VLLM_USE_RAY_COMPILED_DAG: bool = False
|
||||
VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", "shm"] = "auto"
|
||||
VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False
|
||||
VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True
|
||||
@ -623,22 +621,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_CPU_MOE_PREPACK": lambda: bool(int(os.getenv("VLLM_CPU_MOE_PREPACK", "1"))),
|
||||
# (CPU backend only) whether to use SGL kernels, optimized for small batch.
|
||||
"VLLM_CPU_SGL_KERNEL": lambda: bool(int(os.getenv("VLLM_CPU_SGL_KERNEL", "0"))),
|
||||
# If the env var is set, then all workers will execute as separate
|
||||
# processes from the engine, and we use the same mechanism to trigger
|
||||
# execution on all workers.
|
||||
# Run vLLM with VLLM_USE_RAY_SPMD_WORKER=1 to enable it.
|
||||
"VLLM_USE_RAY_SPMD_WORKER": lambda: bool(
|
||||
int(os.getenv("VLLM_USE_RAY_SPMD_WORKER", "0"))
|
||||
),
|
||||
# If the env var is set, it uses the Ray's Compiled Graph
|
||||
# (previously known as ADAG) API which optimizes the
|
||||
# control plane overhead.
|
||||
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
|
||||
# Note that this variable is set to 1 in V1 by default
|
||||
# when ray distributed executor is used.
|
||||
"VLLM_USE_RAY_COMPILED_DAG": lambda: bool(
|
||||
int(os.getenv("VLLM_USE_RAY_COMPILED_DAG", "0"))
|
||||
),
|
||||
# If the env var is set, Ray Compiled Graph uses the specified
|
||||
# channel type to communicate between workers belonging to
|
||||
# different pipeline-parallel stages.
|
||||
@ -646,20 +628,17 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
# - "auto": use the default channel type
|
||||
# - "nccl": use NCCL for communication
|
||||
# - "shm": use shared memory and gRPC for communication
|
||||
# This flag is ignored if VLLM_USE_RAY_COMPILED_DAG is not set.
|
||||
"VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE": env_with_choices(
|
||||
"VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "auto", ["auto", "nccl", "shm"]
|
||||
),
|
||||
# If the env var is set, it enables GPU communication overlap
|
||||
# (experimental feature) in Ray's Compiled Graph. This flag is ignored if
|
||||
# VLLM_USE_RAY_COMPILED_DAG is not set.
|
||||
# (experimental feature) in Ray's Compiled Graph.
|
||||
"VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM": lambda: bool(
|
||||
int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM", "0"))
|
||||
),
|
||||
# If the env var is set, it uses a Ray Communicator wrapping
|
||||
# vLLM's pipeline parallelism communicator to interact with Ray's
|
||||
# Compiled Graph. Otherwise, it uses Ray's NCCL communicator.
|
||||
# This flag is ignored if VLLM_USE_RAY_COMPILED_DAG is not set.
|
||||
"VLLM_USE_RAY_WRAPPED_PP_COMM": lambda: bool(
|
||||
int(os.getenv("VLLM_USE_RAY_WRAPPED_PP_COMM", "1"))
|
||||
),
|
||||
|
||||
@ -1,393 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Awaitable, Callable
|
||||
from functools import cached_property
|
||||
from typing import Any
|
||||
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
import vllm.platforms
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.utils.async_utils import make_async
|
||||
from vllm.v1.outputs import SamplerOutput
|
||||
from vllm.v1.worker.worker_base import WorkerBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_R = TypeVar("_R", default=Any)
|
||||
|
||||
|
||||
class ExecutorBase(ABC):
|
||||
"""Base class for all executors.
|
||||
|
||||
An executor is responsible for executing the model on one device,
|
||||
or it can be a distributed executor
|
||||
that can execute the model on multiple devices.
|
||||
"""
|
||||
|
||||
uses_ray: bool # whether the executor uses Ray for orchestration.
|
||||
supports_pp: bool = False # whether the executor supports PP
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
) -> None:
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.lora_config = vllm_config.lora_config
|
||||
self.load_config = vllm_config.load_config
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.device_config = vllm_config.device_config
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.observability_config = vllm_config.observability_config
|
||||
self._init_executor()
|
||||
self.is_sleeping = False
|
||||
self.sleeping_tags: set[str] = set()
|
||||
self.kv_output_aggregator: KVOutputAggregator | None = None
|
||||
|
||||
@abstractmethod
|
||||
def _init_executor(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def collective_rpc(
|
||||
self,
|
||||
method: str | Callable[[WorkerBase], _R],
|
||||
timeout: float | None = None,
|
||||
args: tuple = (),
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
) -> list[_R]:
|
||||
"""
|
||||
Execute an RPC call on all workers.
|
||||
|
||||
Args:
|
||||
method: Name of the worker method to execute, or a callable that
|
||||
is serialized and sent to all workers to execute.
|
||||
|
||||
If the method is a callable, it should accept an additional
|
||||
`self` argument, in addition to the arguments passed in `args`
|
||||
and `kwargs`. The `self` argument will be the worker object.
|
||||
timeout: Maximum time in seconds to wait for execution. Raises a
|
||||
[`TimeoutError`][] on timeout. `None` means wait indefinitely.
|
||||
args: Positional arguments to pass to the worker method.
|
||||
kwargs: Keyword arguments to pass to the worker method.
|
||||
|
||||
Returns:
|
||||
A list containing the results from each worker.
|
||||
|
||||
Note:
|
||||
It is recommended to use this API to only pass control messages,
|
||||
and set up data-plane communication to pass data.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def determine_num_available_blocks(self) -> tuple[int, int]:
|
||||
"""Determine the number of available blocks for the GPU KV cache and
|
||||
swappable CPU KV cache.
|
||||
|
||||
Normally, this should simply delegate to the underlying Worker. Some
|
||||
ExecutorBase may require modification of the result, e.g. to ensure the
|
||||
selected cache sizes are compatible with all workers.
|
||||
|
||||
Returns a tuple `(num_gpu_blocks, num_cpu_blocks)`, where
|
||||
`num_gpu_blocks` are blocks that are "active" on the device and can be
|
||||
appended to.
|
||||
`num_cpu_blocks` refers to "swapped" blocks in CPU memory and cannot be
|
||||
appended to.
|
||||
"""
|
||||
results = self.collective_rpc("determine_num_available_blocks")
|
||||
a = min([r[0] for r in results])
|
||||
b = min([r[1] for r in results])
|
||||
return a, b
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None:
|
||||
"""Initialize the KV cache by invoking the underlying worker."""
|
||||
# NOTE: This is logged in the executor because there can be >1 workers.
|
||||
logger.info(
|
||||
"# %s blocks: %d, # CPU blocks: %d",
|
||||
vllm.platforms.current_platform.device_name,
|
||||
num_gpu_blocks,
|
||||
num_cpu_blocks,
|
||||
)
|
||||
max_concurrency = (
|
||||
num_gpu_blocks
|
||||
* self.cache_config.block_size
|
||||
/ self.model_config.max_model_len
|
||||
)
|
||||
logger.info(
|
||||
"Maximum concurrency for %s tokens per request: %.2fx",
|
||||
self.model_config.max_model_len,
|
||||
max_concurrency,
|
||||
)
|
||||
|
||||
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||
|
||||
self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks))
|
||||
|
||||
@cached_property # Avoid unnecessary RPC calls
|
||||
def supported_tasks(self) -> tuple[SupportedTask, ...]:
|
||||
output = self.collective_rpc("get_supported_tasks")
|
||||
return output[0]
|
||||
|
||||
def execute_model(
|
||||
self, execute_model_req: ExecuteModelRequest
|
||||
) -> list[SamplerOutput]:
|
||||
output = self.collective_rpc("execute_model", args=(execute_model_req,))
|
||||
assert output[0] is not None
|
||||
return output[0]
|
||||
|
||||
def stop_remote_worker_execution_loop(self) -> None:
|
||||
"""Releases parallel workers from model loop."""
|
||||
return
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
|
||||
return all(self.collective_rpc("add_lora", args=(lora_request,)))
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
assert lora_id > 0, "lora_id must be greater than 0."
|
||||
return all(self.collective_rpc("remove_lora", args=(lora_id,)))
|
||||
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
assert lora_id > 0, "lora_id must be greater than 0."
|
||||
return all(self.collective_rpc("pin_lora", args=(lora_id,)))
|
||||
|
||||
def list_loras(self) -> set[int]:
|
||||
sets = self.collective_rpc("list_loras")
|
||||
for s in sets:
|
||||
assert s == sets[0], "All workers should have the same LORAs."
|
||||
return sets[0]
|
||||
|
||||
def reset_mm_cache(self) -> None:
|
||||
"""Reset the multi-modal cache in each worker."""
|
||||
self.collective_rpc("reset_mm_cache")
|
||||
|
||||
def start_profile(self) -> None:
|
||||
self.collective_rpc("start_profile")
|
||||
|
||||
def stop_profile(self) -> None:
|
||||
self.collective_rpc("stop_profile")
|
||||
|
||||
def sleep(self, level: int = 1):
|
||||
if self.is_sleeping:
|
||||
logger.warning("Executor is already sleeping.")
|
||||
return
|
||||
time_before_sleep = time.perf_counter()
|
||||
self.collective_rpc("sleep", kwargs=dict(level=level))
|
||||
time_after_sleep = time.perf_counter()
|
||||
self.sleeping_tags = {"weights", "kv_cache"}
|
||||
self.is_sleeping = True
|
||||
logger.info(
|
||||
"It took %.6f seconds to fall asleep.", time_after_sleep - time_before_sleep
|
||||
)
|
||||
|
||||
def wake_up(self, tags: list[str] | None = None):
|
||||
if not self.is_sleeping:
|
||||
logger.warning("Executor is not sleeping.")
|
||||
return
|
||||
if tags:
|
||||
for tag in tags:
|
||||
if tag not in self.sleeping_tags:
|
||||
logger.warning(
|
||||
"Tag %s is not in sleeping tags %s", tag, self.sleeping_tags
|
||||
)
|
||||
return
|
||||
time_before_wakeup = time.perf_counter()
|
||||
self.collective_rpc("wake_up", kwargs=dict(tags=tags))
|
||||
time_after_wakeup = time.perf_counter()
|
||||
logger.info(
|
||||
"It took %.6f seconds to wake up tags %s.",
|
||||
time_after_wakeup - time_before_wakeup,
|
||||
tags if tags is not None else self.sleeping_tags,
|
||||
)
|
||||
if tags:
|
||||
for tag in tags:
|
||||
self.sleeping_tags.remove(tag)
|
||||
else:
|
||||
self.sleeping_tags.clear()
|
||||
if not self.sleeping_tags:
|
||||
self.is_sleeping = False
|
||||
|
||||
def save_sharded_state(
|
||||
self,
|
||||
path: str,
|
||||
pattern: str | None = None,
|
||||
max_size: int | None = None,
|
||||
) -> None:
|
||||
self.collective_rpc(
|
||||
"save_sharded_state",
|
||||
kwargs=dict(path=path, pattern=pattern, max_size=max_size),
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def check_health(self) -> None:
|
||||
"""Checks if the executor is healthy. If not, it should raise an
|
||||
exception."""
|
||||
raise NotImplementedError
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Shutdown the executor."""
|
||||
self.collective_rpc("shutdown")
|
||||
|
||||
async def execute_model_async(
|
||||
self, execute_model_req: ExecuteModelRequest
|
||||
) -> list[SamplerOutput]:
|
||||
"""Executes one model step on the given sequences."""
|
||||
output = await make_async(self.execute_model)(execute_model_req)
|
||||
return output
|
||||
|
||||
async def stop_remote_worker_execution_loop_async(self) -> None:
|
||||
"""Releases parallel workers from model loop."""
|
||||
return
|
||||
|
||||
async def check_health_async(self) -> None:
|
||||
"""Checks if the executor is healthy. If not, it should raise an
|
||||
exception."""
|
||||
self.check_health()
|
||||
|
||||
def init_kv_output_aggregator(self, finished_count: int | None) -> None:
|
||||
"""Init KVOutputAggregator"""
|
||||
self.kv_output_aggregator = KVOutputAggregator(
|
||||
finished_count or self.parallel_config.world_size
|
||||
)
|
||||
|
||||
|
||||
class DistributedExecutorBase(ExecutorBase):
|
||||
"""Abstract superclass of distributed executor implementations."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
# This is non-None when the execute model loop is running
|
||||
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
|
||||
self.parallel_worker_tasks: Any | Awaitable[Any] | None = None
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
) -> list[SamplerOutput]:
|
||||
# TODO: unify into collective_rpc
|
||||
if self.parallel_worker_tasks is None:
|
||||
self.parallel_worker_tasks = self._run_workers(
|
||||
"start_worker_execution_loop",
|
||||
async_run_tensor_parallel_workers_only=True,
|
||||
)
|
||||
|
||||
# Only the driver worker returns the sampling results.
|
||||
driver_outputs = self._driver_execute_model(execute_model_req)
|
||||
assert driver_outputs is not None
|
||||
return driver_outputs
|
||||
|
||||
def stop_remote_worker_execution_loop(self) -> None:
|
||||
if self.parallel_worker_tasks is None:
|
||||
return
|
||||
|
||||
self._driver_execute_model(execute_model_req=None)
|
||||
parallel_worker_tasks = self.parallel_worker_tasks
|
||||
self.parallel_worker_tasks = None
|
||||
# Ensure that workers exit model loop cleanly
|
||||
# (this will raise otherwise)
|
||||
self._wait_for_tasks_completion(parallel_worker_tasks)
|
||||
|
||||
@abstractmethod
|
||||
def _driver_execute_model(
|
||||
self, execute_model_req: ExecuteModelRequest | None
|
||||
) -> list[SamplerOutput] | None:
|
||||
"""Run execute_model in the driver worker.
|
||||
|
||||
Passing None will cause the driver to stop the model execution loop
|
||||
running in each of the remote workers. In this case, this method
|
||||
returns None. Otherwise, this method returns the model output.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def collective_rpc(
|
||||
self,
|
||||
method: str | Callable,
|
||||
timeout: float | None = None,
|
||||
args: tuple = (),
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
) -> list[Any]:
|
||||
return self._run_workers(method, *args, **(kwargs or {}))
|
||||
|
||||
@abstractmethod
|
||||
def _run_workers(
|
||||
self,
|
||||
method: str | Callable,
|
||||
*args,
|
||||
async_run_tensor_parallel_workers_only: bool = False,
|
||||
max_concurrent_workers: int | None = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Runs the given method on all workers.
|
||||
|
||||
Args:
|
||||
async_run_tensor_parallel_workers_only: If True the method will be
|
||||
run only in the remote TP workers, not the driver worker.
|
||||
It will also be run asynchronously and return a list of futures
|
||||
rather than blocking on the results.
|
||||
|
||||
# TODO: simplify and merge with collective_rpc
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
|
||||
"""Wait for futures returned from _run_workers() with
|
||||
async_run_remote_workers_only to complete."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def execute_model_async(
|
||||
self, execute_model_req: ExecuteModelRequest
|
||||
) -> list[SamplerOutput]:
|
||||
if self.parallel_worker_tasks is None:
|
||||
# Start model execution loop running in the parallel workers
|
||||
self.parallel_worker_tasks = asyncio.create_task(
|
||||
self._start_worker_execution_loop()
|
||||
)
|
||||
|
||||
# Only the driver worker returns the sampling results.
|
||||
return await self._driver_execute_model_async(execute_model_req)
|
||||
|
||||
async def stop_remote_worker_execution_loop_async(self) -> None:
|
||||
if self.parallel_worker_tasks is None:
|
||||
return
|
||||
|
||||
await self._driver_execute_model_async()
|
||||
parallel_worker_tasks = self.parallel_worker_tasks
|
||||
self.parallel_worker_tasks = None
|
||||
# Ensure that workers exit model loop cleanly
|
||||
# (this will raise otherwise)
|
||||
await parallel_worker_tasks
|
||||
|
||||
@abstractmethod
|
||||
async def _driver_execute_model_async(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest | None = None,
|
||||
) -> list[SamplerOutput]:
|
||||
"""Execute the model asynchronously in the driver worker.
|
||||
|
||||
Passing None will cause the driver to stop the model execution
|
||||
loop running in each of the remote workers.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def _start_worker_execution_loop(self):
|
||||
"""Run execution loop on all workers. It guarantees all workers run
|
||||
the loop or None of them is running the loop. Loop can be stopped by
|
||||
`stop_remote_worker_execution_loop`.
|
||||
The API is idempotent (guarantee only 1 loop run at any moment)."""
|
||||
raise NotImplementedError
|
||||
@ -1,36 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from array import array
|
||||
from typing import Any
|
||||
|
||||
from vllm.multimodal.inputs import MultiModalKwargs
|
||||
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE
|
||||
|
||||
|
||||
def encode_hook(obj: Any) -> Any:
|
||||
"""Custom msgspec enc hook that supports array types and MultiModalKwargs.
|
||||
|
||||
See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
|
||||
"""
|
||||
if isinstance(obj, array):
|
||||
assert obj.typecode == VLLM_TOKEN_ID_ARRAY_TYPE, (
|
||||
f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. "
|
||||
f"Given array has a type code of {obj.typecode}."
|
||||
)
|
||||
return obj.tobytes()
|
||||
if isinstance(obj, MultiModalKwargs):
|
||||
return dict(obj)
|
||||
|
||||
|
||||
def decode_hook(type: type, obj: Any) -> Any:
|
||||
"""Custom msgspec dec hook that supports array types and MultiModalKwargs.
|
||||
|
||||
See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
|
||||
"""
|
||||
if type is array:
|
||||
deserialized = array(VLLM_TOKEN_ID_ARRAY_TYPE)
|
||||
deserialized.frombytes(obj)
|
||||
return deserialized
|
||||
if type is MultiModalKwargs:
|
||||
return MultiModalKwargs(obj)
|
||||
@ -5,7 +5,6 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import msgspec
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -92,12 +91,3 @@ class IntermediateTensors:
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"IntermediateTensors(tensors={self.tensors})"
|
||||
|
||||
|
||||
class ExecuteModelRequest(
|
||||
msgspec.Struct,
|
||||
array_like=True, # type: ignore[call-arg]
|
||||
omit_defaults=True,
|
||||
): # type: ignore[call-arg]
|
||||
# Placeholder. Remove.
|
||||
pass
|
||||
|
||||
@ -943,7 +943,7 @@ def maybe_register_config_serialize_by_value() -> None:
|
||||
cloudpickle.register_pickle_by_value(transformers_modules)
|
||||
|
||||
# ray vendors its own version of cloudpickle
|
||||
from vllm.executor.ray_utils import ray
|
||||
from vllm.v1.executor.ray_utils import ray
|
||||
|
||||
if ray:
|
||||
ray.cloudpickle.register_pickle_by_value(transformers_modules)
|
||||
|
||||
@ -39,7 +39,7 @@ from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
|
||||
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
|
||||
from vllm.v1.engine.parallel_sampling import ParentRequest
|
||||
from vllm.v1.engine.processor import Processor
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.executor import Executor
|
||||
from vllm.v1.metrics.loggers import (
|
||||
StatLoggerFactory,
|
||||
StatLoggerManager,
|
||||
|
||||
@ -60,7 +60,7 @@ from vllm.v1.engine.utils import (
|
||||
EngineZmqAddresses,
|
||||
get_device_indices,
|
||||
)
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.executor import Executor
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.metrics.stats import SchedulerStats
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
@ -322,7 +322,6 @@ class EngineCore:
|
||||
with self.log_error_detail(scheduler_output):
|
||||
model_output = self.model_executor.execute_model(scheduler_output)
|
||||
|
||||
assert isinstance(model_output, ModelRunnerOutput)
|
||||
engine_core_outputs = self.scheduler.update_from_output(
|
||||
scheduler_output, model_output
|
||||
)
|
||||
@ -364,7 +363,7 @@ class EngineCore:
|
||||
if self.scheduler.has_requests():
|
||||
scheduler_output = self.scheduler.schedule()
|
||||
future = self.model_executor.execute_model(scheduler_output, non_block=True)
|
||||
batch_queue.appendleft((future, scheduler_output)) # type: ignore[arg-type]
|
||||
batch_queue.appendleft((future, scheduler_output))
|
||||
|
||||
model_executed = scheduler_output.total_num_scheduled_tokens > 0
|
||||
if (
|
||||
@ -463,14 +462,6 @@ class EngineCore:
|
||||
) -> list[_R]:
|
||||
return self.model_executor.collective_rpc(method, timeout, args, kwargs)
|
||||
|
||||
def save_tensorized_model(
|
||||
self,
|
||||
tensorizer_config,
|
||||
) -> None:
|
||||
self.model_executor.save_tensorized_model(
|
||||
tensorizer_config=tensorizer_config,
|
||||
)
|
||||
|
||||
def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, int]:
|
||||
"""Preprocess the request.
|
||||
|
||||
|
||||
@ -46,7 +46,7 @@ from vllm.v1.engine.utils import (
|
||||
CoreEngineProcManager,
|
||||
launch_core_engines,
|
||||
)
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.executor import Executor
|
||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -32,7 +32,7 @@ from vllm.v1.engine.core_client import EngineCoreClient
|
||||
from vllm.v1.engine.output_processor import OutputProcessor
|
||||
from vllm.v1.engine.parallel_sampling import ParentRequest
|
||||
from vllm.v1.engine.processor import Processor
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.executor import Executor
|
||||
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
|
||||
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
|
||||
from vllm.v1.metrics.stats import IterationStats
|
||||
|
||||
@ -23,7 +23,7 @@ from vllm.ray.ray_env import get_env_vars_to_copy
|
||||
from vllm.utils import get_mp_context
|
||||
from vllm.utils.network_utils import get_open_zmq_ipc_path, zmq_socket_ctx
|
||||
from vllm.v1.engine.coordinator import DPCoordinator
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.executor import Executor
|
||||
from vllm.v1.utils import get_engine_client_zmq_addr, shutdown
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@ -0,0 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from .abstract import Executor
|
||||
from .uniproc_executor import UniProcExecutor
|
||||
|
||||
__all__ = ["Executor", "UniProcExecutor"]
|
||||
|
||||
@ -1,31 +1,40 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import Future
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from functools import cached_property
|
||||
from typing import Literal, TypeVar, overload
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
from vllm.executor.uniproc_executor import ( # noqa
|
||||
ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0,
|
||||
)
|
||||
from vllm.executor.uniproc_executor import UniProcExecutor as UniProcExecutorV0 # noqa
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.engine import ReconfigureDistributedRequest
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
|
||||
from vllm.v1.worker.worker_base import WorkerBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_R = TypeVar("_R")
|
||||
|
||||
FailureCallback = Callable[[], None]
|
||||
|
||||
|
||||
class Executor(ExecutorBase):
|
||||
class Executor(ABC):
|
||||
"""Abstract base class for vLLM executors."
|
||||
|
||||
An executor is responsible for executing the model on one device,
|
||||
or it can be a distributed executor that can execute the model on multiple devices.
|
||||
"""
|
||||
Abstract class for v1 executors, mainly define some methods for v1.
|
||||
For methods shared by v0 and v1, define them in ExecutorBase"""
|
||||
|
||||
uses_ray: bool = False # whether the executor uses Ray for orchestration.
|
||||
supports_pp: bool = False # whether the executor supports PP
|
||||
|
||||
@staticmethod
|
||||
def get_class(vllm_config: VllmConfig) -> type["Executor"]:
|
||||
@ -34,16 +43,14 @@ class Executor(ExecutorBase):
|
||||
distributed_executor_backend = parallel_config.distributed_executor_backend
|
||||
# distributed_executor_backend must be set in VllmConfig.__post_init__
|
||||
if isinstance(distributed_executor_backend, type):
|
||||
if not issubclass(distributed_executor_backend, ExecutorBase):
|
||||
if not issubclass(distributed_executor_backend, Executor):
|
||||
raise TypeError(
|
||||
"distributed_executor_backend must be a subclass of "
|
||||
f"ExecutorBase. Got {distributed_executor_backend}."
|
||||
f"Executor. Got {distributed_executor_backend}."
|
||||
)
|
||||
executor_class = distributed_executor_backend
|
||||
elif distributed_executor_backend == "ray":
|
||||
from vllm.v1.executor.ray_distributed_executor import ( # noqa
|
||||
RayDistributedExecutor,
|
||||
)
|
||||
from vllm.v1.executor.ray_executor import RayDistributedExecutor
|
||||
|
||||
executor_class = RayDistributedExecutor
|
||||
elif distributed_executor_backend == "mp":
|
||||
@ -51,6 +58,8 @@ class Executor(ExecutorBase):
|
||||
|
||||
executor_class = MultiprocExecutor
|
||||
elif distributed_executor_backend == "uni":
|
||||
from vllm.v1.executor.uniproc_executor import UniProcExecutor
|
||||
|
||||
executor_class = UniProcExecutor
|
||||
elif distributed_executor_backend == "external_launcher":
|
||||
# TODO: make v1 scheduling deterministic
|
||||
@ -58,10 +67,10 @@ class Executor(ExecutorBase):
|
||||
executor_class = ExecutorWithExternalLauncher
|
||||
elif isinstance(distributed_executor_backend, str):
|
||||
executor_class = resolve_obj_by_qualname(distributed_executor_backend)
|
||||
if not issubclass(executor_class, ExecutorBase):
|
||||
if not issubclass(executor_class, Executor):
|
||||
raise TypeError(
|
||||
"distributed_executor_backend must be a subclass of "
|
||||
f"ExecutorBase. Got {executor_class}."
|
||||
f"Executor. Got {executor_class}."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
@ -69,6 +78,29 @@ class Executor(ExecutorBase):
|
||||
)
|
||||
return executor_class
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
) -> None:
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.lora_config = vllm_config.lora_config
|
||||
self.load_config = vllm_config.load_config
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.device_config = vllm_config.device_config
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.observability_config = vllm_config.observability_config
|
||||
self._init_executor()
|
||||
self.is_sleeping = False
|
||||
self.sleeping_tags: set[str] = set()
|
||||
self.kv_output_aggregator: KVOutputAggregator | None = None
|
||||
|
||||
@abstractmethod
|
||||
def _init_executor(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def initialize_from_config(self, kv_cache_configs: list[KVCacheConfig]) -> None:
|
||||
"""
|
||||
Initialize the KV caches and begin the model execution loop of the
|
||||
@ -77,7 +109,7 @@ class Executor(ExecutorBase):
|
||||
self.collective_rpc("initialize_from_config", args=(kv_cache_configs,))
|
||||
self.collective_rpc("compile_or_warm_up_model")
|
||||
|
||||
def register_failure_callback(self, callback: FailureCallback):
|
||||
def register_failure_callback(self, callback: FailureCallback): # noqa: B027
|
||||
"""
|
||||
Register a function to be called if the executor enters a permanent
|
||||
failed state.
|
||||
@ -90,22 +122,78 @@ class Executor(ExecutorBase):
|
||||
def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
|
||||
return self.collective_rpc("get_kv_cache_spec")
|
||||
|
||||
@overload
|
||||
def collective_rpc(
|
||||
self,
|
||||
method: str | Callable,
|
||||
method: str | Callable[[WorkerBase], _R],
|
||||
timeout: float | None = None,
|
||||
args: tuple = (),
|
||||
kwargs: dict | None = None,
|
||||
non_block: bool = False,
|
||||
) -> list[Any]:
|
||||
non_block: Literal[False] = False,
|
||||
) -> list[_R]:
|
||||
"""
|
||||
Execute an RPC call on all workers.
|
||||
|
||||
Args:
|
||||
method: Name of the worker method to execute, or a callable that
|
||||
is serialized and sent to all workers to execute.
|
||||
|
||||
If the method is a callable, it should accept an additional
|
||||
`self` argument, in addition to the arguments passed in `args`
|
||||
and `kwargs`. The `self` argument will be the worker object.
|
||||
timeout: Maximum time in seconds to wait for execution. Raises a
|
||||
[`TimeoutError`][] on timeout. `None` means wait indefinitely.
|
||||
args: Positional arguments to pass to the worker method.
|
||||
kwargs: Keyword arguments to pass to the worker method.
|
||||
non_block: If `True`, returns a list of Futures instead of waiting
|
||||
for the results.
|
||||
|
||||
Returns:
|
||||
A list containing the results from each worker.
|
||||
|
||||
Note:
|
||||
It is recommended to use this API to only pass control messages,
|
||||
and set up data-plane communication to pass data.
|
||||
"""
|
||||
pass
|
||||
|
||||
@overload
|
||||
def collective_rpc(
|
||||
self,
|
||||
method: str | Callable[[WorkerBase], _R],
|
||||
timeout: float | None = None,
|
||||
args: tuple = (),
|
||||
kwargs: dict | None = None,
|
||||
non_block: Literal[True] = True,
|
||||
) -> list[Future[_R]]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def collective_rpc(
|
||||
self, method, timeout=None, args=(), kwargs=None, non_block: bool = False
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
@overload
|
||||
def execute_model(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
non_block: bool = False,
|
||||
non_block: Literal[False] = False,
|
||||
) -> ModelRunnerOutput:
|
||||
pass
|
||||
|
||||
@overload
|
||||
def execute_model(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
non_block: Literal[True] = True,
|
||||
) -> Future[ModelRunnerOutput]:
|
||||
pass
|
||||
|
||||
def execute_model(
|
||||
self, scheduler_output: SchedulerOutput, non_block: bool = False
|
||||
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
|
||||
output = self.collective_rpc(
|
||||
output = self.collective_rpc( # type: ignore[call-overload]
|
||||
"execute_model", args=(scheduler_output,), non_block=non_block
|
||||
)
|
||||
return output[0]
|
||||
@ -114,7 +202,7 @@ class Executor(ExecutorBase):
|
||||
self.collective_rpc("execute_dummy_batch")
|
||||
|
||||
def take_draft_token_ids(self) -> DraftTokenIds | None:
|
||||
output = self.collective_rpc("take_draft_token_ids")
|
||||
output: list[DraftTokenIds] = self.collective_rpc("take_draft_token_ids")
|
||||
return output[0]
|
||||
|
||||
@property
|
||||
@ -124,19 +212,120 @@ class Executor(ExecutorBase):
|
||||
def profile(self, is_start: bool = True):
|
||||
self.collective_rpc("profile", args=(is_start,))
|
||||
|
||||
def save_sharded_state(
|
||||
self,
|
||||
path: str,
|
||||
pattern: str | None = None,
|
||||
max_size: int | None = None,
|
||||
) -> None:
|
||||
self.collective_rpc(
|
||||
"save_sharded_state",
|
||||
kwargs=dict(path=path, pattern=pattern, max_size=max_size),
|
||||
)
|
||||
|
||||
class UniProcExecutor(UniProcExecutorV0, Executor):
|
||||
pass
|
||||
@abstractmethod
|
||||
def check_health(self) -> None:
|
||||
"""Checks if the executor is healthy. If not, it should raise an
|
||||
exception."""
|
||||
raise NotImplementedError
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Shutdown the executor."""
|
||||
self.collective_rpc("shutdown")
|
||||
|
||||
def init_kv_output_aggregator(self, finished_count: int | None) -> None:
|
||||
"""Init KVOutputAggregator"""
|
||||
self.kv_output_aggregator = KVOutputAggregator(
|
||||
finished_count or self.parallel_config.world_size
|
||||
)
|
||||
|
||||
@cached_property # Avoid unnecessary RPC calls
|
||||
def supported_tasks(self) -> tuple[SupportedTask, ...]:
|
||||
output: list[tuple[SupportedTask, ...]]
|
||||
output = self.collective_rpc("get_supported_tasks")
|
||||
return output[0]
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
|
||||
return all(self.collective_rpc("add_lora", args=(lora_request,)))
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
assert lora_id > 0, "lora_id must be greater than 0."
|
||||
return all(self.collective_rpc("remove_lora", args=(lora_id,)))
|
||||
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
assert lora_id > 0, "lora_id must be greater than 0."
|
||||
return all(self.collective_rpc("pin_lora", args=(lora_id,)))
|
||||
|
||||
def list_loras(self) -> set[int]:
|
||||
sets: list[set[int]] = self.collective_rpc("list_loras")
|
||||
for s in sets:
|
||||
assert s == sets[0], "All workers should have the same LORAs."
|
||||
return sets[0]
|
||||
|
||||
def reset_mm_cache(self) -> None:
|
||||
"""Reset the multi-modal cache in each worker."""
|
||||
self.collective_rpc("reset_mm_cache")
|
||||
|
||||
def start_profile(self) -> None:
|
||||
self.collective_rpc("start_profile")
|
||||
|
||||
def stop_profile(self) -> None:
|
||||
self.collective_rpc("stop_profile")
|
||||
|
||||
def sleep(self, level: int = 1):
|
||||
if self.is_sleeping:
|
||||
logger.warning("Executor is already sleeping.")
|
||||
return
|
||||
time_before_sleep = time.perf_counter()
|
||||
self.collective_rpc("sleep", kwargs=dict(level=level))
|
||||
time_after_sleep = time.perf_counter()
|
||||
self.sleeping_tags = {"weights", "kv_cache"}
|
||||
self.is_sleeping = True
|
||||
logger.info(
|
||||
"It took %.6f seconds to fall asleep.", time_after_sleep - time_before_sleep
|
||||
)
|
||||
|
||||
def wake_up(self, tags: list[str] | None = None):
|
||||
if not self.is_sleeping:
|
||||
logger.warning("Executor is not sleeping.")
|
||||
return
|
||||
if tags:
|
||||
for tag in tags:
|
||||
if tag not in self.sleeping_tags:
|
||||
logger.warning(
|
||||
"Tag %s is not in sleeping tags %s", tag, self.sleeping_tags
|
||||
)
|
||||
return
|
||||
time_before_wakeup = time.perf_counter()
|
||||
self.collective_rpc("wake_up", kwargs=dict(tags=tags))
|
||||
time_after_wakeup = time.perf_counter()
|
||||
logger.info(
|
||||
"It took %.6f seconds to wake up tags %s.",
|
||||
time_after_wakeup - time_before_wakeup,
|
||||
tags if tags is not None else self.sleeping_tags,
|
||||
)
|
||||
if tags:
|
||||
for tag in tags:
|
||||
self.sleeping_tags.remove(tag)
|
||||
else:
|
||||
self.sleeping_tags.clear()
|
||||
if not self.sleeping_tags:
|
||||
self.is_sleeping = False
|
||||
|
||||
def reinitialize_distributed(
|
||||
self, reconfig_request: ReconfigureDistributedRequest
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
|
||||
def determine_available_memory(self) -> list[int]: # in bytes
|
||||
# same as determine_num_available_blocks in v0,
|
||||
# we need to get the min across all ranks.
|
||||
memory = super().determine_available_memory()
|
||||
from vllm.distributed.parallel_state import get_world_group
|
||||
from vllm.v1.executor.uniproc_executor import ( # noqa: E402
|
||||
ExecutorWithExternalLauncher as _ExecutorWithExternalLauncher,
|
||||
)
|
||||
from vllm.v1.executor.uniproc_executor import ( # noqa: E402
|
||||
UniProcExecutor as _UniProcExecutor,
|
||||
)
|
||||
|
||||
cpu_group = get_world_group().cpu_group
|
||||
memory_tensor = torch.tensor([memory], device="cpu", dtype=torch.int64)
|
||||
dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
|
||||
return [memory_tensor.item()]
|
||||
# For backwards compatibility.
|
||||
UniProcExecutor = _UniProcExecutor
|
||||
ExecutorWithExternalLauncher = _ExecutorWithExternalLauncher
|
||||
|
||||
@ -179,7 +179,7 @@ class MultiprocExecutor(Executor):
|
||||
else:
|
||||
self.failure_callback = callback
|
||||
|
||||
def execute_model(
|
||||
def execute_model( # type: ignore[override]
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
non_block: bool = False,
|
||||
@ -204,6 +204,7 @@ class MultiprocExecutor(Executor):
|
||||
)
|
||||
|
||||
# aggregate all workers output to a single output
|
||||
assert self.kv_output_aggregator is not None
|
||||
if non_block:
|
||||
return self.kv_output_aggregator.async_aggregate(outputs, self.output_rank)
|
||||
return self.kv_output_aggregator.aggregate(outputs, self.output_rank)
|
||||
|
||||
@ -1,111 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from concurrent.futures import Future
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
|
||||
from vllm.executor.ray_distributed_executor import ( # noqa
|
||||
RayDistributedExecutor as RayDistributedExecutorV0,
|
||||
from vllm.v1.executor.ray_executor import (
|
||||
RayDistributedExecutor as _RayDistributedExecutor,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FutureWrapper(Future):
|
||||
"""A wrapper around Ray output reference to meet the interface
|
||||
of .execute_model(): The top level (core busy loop) expects .result() api
|
||||
to block and return a single output.
|
||||
|
||||
If aggregator is provided, the outputs from all workers are aggregated upon
|
||||
the result() call. If not only the first worker's output is returned.
|
||||
"""
|
||||
|
||||
def __init__(self, refs, aggregator: KVOutputAggregator | None = None):
|
||||
super().__init__()
|
||||
self.refs = refs
|
||||
self.aggregator = aggregator
|
||||
|
||||
def result(self, timeout=None):
|
||||
if timeout is not None:
|
||||
raise NotImplementedError("timeout is not supported")
|
||||
|
||||
if self.aggregator is None:
|
||||
return self.refs[0].get()
|
||||
|
||||
outputs = [ref.get() for ref in self.refs]
|
||||
return self.aggregator.aggregate(outputs, output_rank=0)
|
||||
|
||||
|
||||
class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
|
||||
"""Ray distributed executor using Ray Compiled Graphs."""
|
||||
|
||||
supports_pp: bool = True
|
||||
|
||||
def _init_executor(self) -> None:
|
||||
super()._init_executor()
|
||||
|
||||
# KV connector setup
|
||||
self.has_connector = self.vllm_config.kv_transfer_config is not None
|
||||
|
||||
@property
|
||||
def max_concurrent_batches(self) -> int:
|
||||
"""Ray distributed executor supports pipeline parallelism,
|
||||
meaning that it allows PP size batches to be executed concurrently.
|
||||
"""
|
||||
if self.scheduler_config.async_scheduling:
|
||||
return 2
|
||||
return self.parallel_config.pipeline_parallel_size
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
non_block: bool = False,
|
||||
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
|
||||
"""Execute the model on the Ray workers.
|
||||
|
||||
Args:
|
||||
scheduler_output: The scheduler output to execute.
|
||||
non_block: If True, the method will return a Future.
|
||||
|
||||
Returns:
|
||||
The model runner output.
|
||||
"""
|
||||
# Build the compiled DAG for the first time.
|
||||
if self.forward_dag is None: # type: ignore
|
||||
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
|
||||
|
||||
refs = self.forward_dag.execute(scheduler_output) # type: ignore
|
||||
|
||||
if not self.has_connector:
|
||||
# Get output only from a single worker (output_rank)
|
||||
# When PP is not used, we block here until the result is available.
|
||||
if not non_block:
|
||||
return refs[0].get()
|
||||
|
||||
# When PP is used, we return a FutureWrapper immediately so that
|
||||
# the scheduler can yield to the next batch.
|
||||
return FutureWrapper(refs)
|
||||
|
||||
# Get output from all workers when connector is present
|
||||
if not non_block:
|
||||
# Block and get results from all workers
|
||||
outputs = [ref.get() for ref in refs]
|
||||
return self.kv_output_aggregator.aggregate(outputs)
|
||||
|
||||
# Return a future that will aggregate outputs from all workers
|
||||
return FutureWrapper(refs, self.kv_output_aggregator)
|
||||
|
||||
def reinitialize_distributed(
|
||||
self, reconfig_request: ReconfigureDistributedRequest
|
||||
) -> None:
|
||||
self._run_workers("reinitialize_distributed", reconfig_request)
|
||||
if (
|
||||
reconfig_request.new_data_parallel_rank
|
||||
== ReconfigureRankType.SHUTDOWN_CURRENT_RANK
|
||||
):
|
||||
self.shutdown()
|
||||
# For backwards compatibility.
|
||||
RayDistributedExecutor = _RayDistributedExecutor
|
||||
|
||||
@ -1,31 +1,34 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import Future
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import cloudpickle
|
||||
import msgspec
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.executor.executor_base import DistributedExecutorBase
|
||||
from vllm.executor.msgspec_utils import encode_hook
|
||||
from vllm.executor.ray_utils import RayWorkerWrapper, initialize_ray_cluster, ray
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.ray.ray_env import get_env_vars_to_copy
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.utils.async_utils import make_async
|
||||
from vllm.utils.network_utils import (
|
||||
get_distributed_init_method,
|
||||
get_ip,
|
||||
get_open_port,
|
||||
)
|
||||
from vllm.v1.outputs import SamplerOutput
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.executor.ray_utils import (
|
||||
FutureWrapper,
|
||||
RayWorkerWrapper,
|
||||
initialize_ray_cluster,
|
||||
ray,
|
||||
)
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
|
||||
if ray is not None:
|
||||
from ray.actor import ActorHandle
|
||||
@ -53,7 +56,7 @@ class RayWorkerMetaData:
ip: str = ""


class RayDistributedExecutor(DistributedExecutorBase):
class RayDistributedExecutor(Executor):
"""Ray-based distributed executor"""

# These env vars are worker-specific, therefore are NOT copied
@ -69,37 +72,14 @@ class RayDistributedExecutor(DistributedExecutorBase):
ADDITIONAL_ENV_VARS = {"HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"}

uses_ray: bool = True
supports_pp: bool = True

def _init_executor(self) -> None:
self.forward_dag: ray.dag.CompiledDAG | None = None
if envs.VLLM_USE_V1:
# V1 uses SPMD worker and compiled DAG
os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"

# For TPU or XPU, avoid compiling NVIDIA's NCCL
if current_platform.is_tpu() or current_platform.is_xpu():
os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm"

# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
# Currently, this requires USE_RAY_SPMD_WORKER=True.
self.use_ray_compiled_dag = envs.VLLM_USE_RAY_COMPILED_DAG
# If the env var is set, then we do not distinguish between the
# "driver worker" vs other workers. Also, the rank 0 worker will
# be executed in a remote Ray worker. Currently this requires
# USE_RAY_COMPILED_DAG=True.
self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
if self.use_ray_compiled_dag:
assert self.use_ray_spmd_worker, (
"VLLM_USE_RAY_COMPILED_DAG=1 requires VLLM_USE_RAY_SPMD_WORKER=1"
)
if self.use_ray_spmd_worker:
# TODO: Support SPMD worker for non-DAG Ray executor.
assert self.use_ray_compiled_dag, (
"VLLM_USE_RAY_SPMD_WORKER=1 requires VLLM_USE_RAY_COMPILED_DAG=1"
)
# For TPU or XPU, avoid compiling NVIDIA's NCCL
if current_platform.is_tpu() or current_platform.is_xpu():
os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm"

assert self.uses_ray
initialize_ray_cluster(self.parallel_config)
@ -113,13 +93,17 @@ class RayDistributedExecutor(DistributedExecutorBase):
# Create the parallel GPU workers.
self._init_workers_ray(placement_group)

self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
self.output_decoder = msgspec.msgpack.Decoder(list[SamplerOutput] | None)
self.use_v1 = envs.VLLM_USE_V1
# KV connector setup
self.has_connector = self.vllm_config.kv_transfer_config is not None

self.pp_locks: list[asyncio.Lock] | None = None
if not self.use_ray_compiled_dag:
self.driver_exec_method = make_async(self.driver_worker.execute_method)
@property
def max_concurrent_batches(self) -> int:
"""Ray distributed executor supports pipeline parallelism,
meaning that it allows PP size batches to be executed concurrently.
"""
if self.scheduler_config.async_scheduling:
return 2
return self.parallel_config.pipeline_parallel_size
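
max_concurrent_batches above is what lets the scheduler keep one batch in flight per pipeline stage (or double-buffer when async scheduling is on). A standalone mirror of that rule, with an invented function name, just to make the sizing explicit:

def max_concurrent_batches(pipeline_parallel_size: int, async_scheduling: bool) -> int:
    # Async scheduling double-buffers; otherwise one in-flight batch per PP stage.
    return 2 if async_scheduling else pipeline_parallel_size


assert max_concurrent_batches(4, async_scheduling=False) == 4
assert max_concurrent_batches(4, async_scheduling=True) == 2
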

def shutdown(self) -> None:
if logger:
@ -176,8 +160,6 @@ class RayDistributedExecutor(DistributedExecutorBase):
ray_remote_kwargs
)

logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)

# Create the workers.
bundle_indices: list[int]
if envs.VLLM_RAY_BUNDLE_INDICES:
@ -241,30 +223,8 @@ class RayDistributedExecutor(DistributedExecutorBase):
for each, ip in zip(worker_metadata, worker_ips):
each.ip = ip

if not self.use_ray_spmd_worker:
for i, each in enumerate(worker_metadata):
# find and remove the dummy worker from the list
worker = each.worker
worker_ip = each.ip
if self.driver_dummy_worker is None and worker_ip == driver_ip:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(
vllm_config=self.vllm_config, rpc_rank=0
)
worker_metadata.pop(i)
break

logger.debug("workers: %s", worker_metadata)
logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
raise ValueError(
"Ray does not allocate any GPUs on the driver node."
f"Driver IP: {driver_ip}, worker IPs: {worker_ips}."
"Consider adjusting the Ray placement group or running "
"the driver on a GPU node."
)

ip_counts: dict[str, int] = {}
for ip in worker_ips:
@ -281,7 +241,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
should be placed first.
"""
ip = item.ip
return (0 if ip == driver_ip else 1, ip_counts[ip], ip)
return 0 if ip == driver_ip else 1, ip_counts[ip], ip
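
sort_by_driver_then_worker_ip orders workers driver-node first, then by how many workers share an IP, then by the IP string itself, so co-located workers end up with adjacent ranks. The same key, rendered as a runnable snippet with made-up IPs:

from collections import Counter

driver_ip = "10.0.0.1"
worker_ips = ["10.0.0.2", "10.0.0.1", "10.0.0.2", "10.0.0.3"]
ip_counts = Counter(worker_ips)


def sort_key(ip: str):
    # Driver-node workers first, then less-populated nodes, then by IP.
    return 0 if ip == driver_ip else 1, ip_counts[ip], ip


print(sorted(worker_ips, key=sort_key))
# -> ['10.0.0.1', '10.0.0.3', '10.0.0.2', '10.0.0.2']
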

# After sorting, the workers on the same node will be
# close to each other, and the workers on the driver
@ -289,14 +249,13 @@ class RayDistributedExecutor(DistributedExecutorBase):
sorted_worker_metadata = sorted(
worker_metadata, key=sort_by_driver_then_worker_ip
)
start_rank = 0 if self.use_ray_spmd_worker else 1
for i, item in enumerate(sorted_worker_metadata):
item.adjusted_rank = i + start_rank
item.adjusted_rank = i
self.workers = [item.worker for item in sorted_worker_metadata]
rerank_mapping = {
item.created_rank: item.adjusted_rank for item in sorted_worker_metadata
}
self._run_workers("adjust_rank", rerank_mapping)
self.collective_rpc("adjust_rank", args=(rerank_mapping,))

# Get the set of GPU IDs used on each node.
worker_node_and_gpu_ids = []
@ -365,8 +324,8 @@ class RayDistributedExecutor(DistributedExecutorBase):

self._env_vars_for_all_workers = all_args_to_update_environment_variables

self._run_workers(
"update_environment_variables", self._get_env_vars_to_be_updated()
self.collective_rpc(
"update_environment_variables", args=(self._get_env_vars_to_be_updated(),)
)

if len(node_gpus) == 1:
@ -396,138 +355,95 @@ class RayDistributedExecutor(DistributedExecutorBase):
or (rank % self.parallel_config.tensor_parallel_size == 0),
)
all_kwargs.append(kwargs)
self._run_workers("init_worker", all_kwargs)
self.collective_rpc("init_worker", args=(all_kwargs,))

self._run_workers("init_device")
self._run_workers(
"load_model",
max_concurrent_workers=self.parallel_config.max_parallel_loading_workers,
)
self.collective_rpc("init_device")
self.collective_rpc("load_model")

if self.use_ray_spmd_worker:
for pp_rank in range(self.parallel_config.pipeline_parallel_size):
self.pp_tp_workers.append([])
for tp_rank in range(self.parallel_config.tensor_parallel_size):
# PP=2, TP=4
# pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
rank = (
pp_rank * self.parallel_config.tensor_parallel_size
) + tp_rank
assert len(self.pp_tp_workers[pp_rank]) == tp_rank
assert pp_rank < len(self.pp_tp_workers)
self.pp_tp_workers[pp_rank].append(self.workers[rank])
for pp_rank in range(self.parallel_config.pipeline_parallel_size):
self.pp_tp_workers.append([])
for tp_rank in range(self.parallel_config.tensor_parallel_size):
# PP=2, TP=4
# pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
rank = (pp_rank * self.parallel_config.tensor_parallel_size) + tp_rank
assert len(self.pp_tp_workers[pp_rank]) == tp_rank
assert pp_rank < len(self.pp_tp_workers)
self.pp_tp_workers[pp_rank].append(self.workers[rank])
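
The nested loop above assumes the global rank layout rank = pp_rank * tensor_parallel_size + tp_rank. A quick standalone check of that layout for the PP=2, TP=4 example in the comment:

pp_size, tp_size = 2, 4
pp_tp_workers = [
    [pp_rank * tp_size + tp_rank for tp_rank in range(tp_size)]
    for pp_rank in range(pp_size)
]
assert pp_tp_workers == [[0, 1, 2, 3], [4, 5, 6, 7]]
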

# This is the list of workers that are rank 0 of each TP group EXCEPT
# global rank 0. These are the workers that will broadcast to the
# rest of the workers.
self.tp_driver_workers: list[RayWorkerWrapper] = []
# This is the list of workers that are not drivers and not the first
# worker in a TP group. These are the workers that will be
# broadcasted to.
self.non_driver_workers: list[RayWorkerWrapper] = []
def reinitialize_distributed(
self, reconfig_request: ReconfigureDistributedRequest
) -> None:
self.collective_rpc("reinitialize_distributed", args=(reconfig_request,))
if (
reconfig_request.new_data_parallel_rank
== ReconfigureRankType.SHUTDOWN_CURRENT_RANK
):
self.shutdown()

# Enforce rank order for correct rank to return final output.
for index, worker in enumerate(self.workers):
# The driver worker is rank 0 and not in self.workers.
rank = index + 1
if rank % self.parallel_config.tensor_parallel_size == 0:
self.tp_driver_workers.append(worker)
else:
self.non_driver_workers.append(worker)

def _driver_execute_model(
self, execute_model_req: ExecuteModelRequest | None
) -> list[SamplerOutput] | None:
"""Run execute_model in the driver worker.

Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
"""
assert not self.use_ray_spmd_worker, (
"driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1"
)
return self.driver_worker.execute_method("execute_model", execute_model_req)

def execute_model(
self, execute_model_req: ExecuteModelRequest
) -> list[SamplerOutput]:
if not self.use_ray_spmd_worker:
return super().execute_model(execute_model_req)

if self.forward_dag is None:
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)

if self.use_v1:
serialized_data = execute_model_req
else:
serialized_data = self.input_encoder.encode(execute_model_req)
outputs = ray.get(self.forward_dag.execute(serialized_data))
output = outputs[0] if self.use_v1 else self.output_decoder.decode(outputs[0])
return output

def _run_workers(
self,
method: str | Callable,
*args,
async_run_tensor_parallel_workers_only: bool = False,
max_concurrent_workers: int | None = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers. Can be used in the following
ways:
def execute_model( # type: ignore[override]
self, scheduler_output: SchedulerOutput, non_block: bool = False
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
"""Execute the model on the Ray workers.

Args:
- async_run_tensor_parallel_workers_only: If True the method will be
run only in the remote TP workers, not the driver worker.
It will also be run asynchronously and return a list of futures
rather than blocking on the results.
- args/kwargs: All workers share the same args/kwargs
scheduler_output: The scheduler output to execute.
non_block: If True, the method will return a Future.

Returns:
The model runner output.
"""
# Build the compiled DAG for the first time.
if self.forward_dag is None: # type: ignore
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)

refs = self.forward_dag.execute(scheduler_output) # type: ignore

if not self.has_connector:
# Get output only from a single worker (output_rank)
# When PP is not used, we block here until the result is available.
if not non_block:
return refs[0].get()

# When PP is used, we return a FutureWrapper immediately so that
# the scheduler can yield to the next batch.
return FutureWrapper(refs)

# Get output from all workers when connector is present
assert self.kv_output_aggregator is not None
if not non_block:
# Block and get results from all workers
outputs = [ref.get() for ref in refs]
return self.kv_output_aggregator.aggregate(outputs)

# Return a future that will aggregate outputs from all workers
return FutureWrapper(refs, self.kv_output_aggregator)

def collective_rpc(
self,
method: str | Callable,
timeout: float | None = None,
args: tuple = (),
kwargs: dict[str, Any] | None = None,
non_block: bool = False,
) -> list[Any]:
"""Runs the given method on all workers."""
sent_method = method if isinstance(method, str) else cloudpickle.dumps(method)
del method
if self.use_ray_spmd_worker:
assert not async_run_tensor_parallel_workers_only, (
"async_run_tensor_parallel_workers_only is not supported for spmd mode."
)

if max_concurrent_workers:
raise NotImplementedError("max_concurrent_workers is not supported yet.")

# Start the ray workers first.
ray_workers = self.workers
if async_run_tensor_parallel_workers_only:
ray_workers = self.non_driver_workers
if kwargs is None:
kwargs = {}
ray_worker_outputs = [
worker.execute_method.remote( # type: ignore[attr-defined]
sent_method, *args, **kwargs
)
for worker in ray_workers
for worker in self.workers
]

if async_run_tensor_parallel_workers_only:
# Just return futures
return ray_worker_outputs

driver_worker_output = []
# In SPMD mode, the driver worker is the same as any other worker,
# so we only explicitly execute on the driver worker if using a
# non-SPMD worker class.
if not self.use_ray_spmd_worker:
# Start the driver worker after all the ray workers.
driver_worker_output = [
self.driver_worker.execute_method(sent_method, *args, **kwargs)
]

# Get the results of the ray workers.
if self.workers:
ray_worker_outputs = ray.get(ray_worker_outputs)
if non_block:
return [FutureWrapper((output,)) for output in ray_worker_outputs]

return driver_worker_output + ray_worker_outputs

def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
ray.get(parallel_worker_tasks)
return ray.get(ray_worker_outputs, timeout=timeout)
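
collective_rpc ships either a method name or a cloudpickled callable to every worker; on the worker side the name is looked up on the worker object and the callable is invoked with the worker as its first argument. A rough sketch of that dispatch under those assumptions — DummyWorker and run_rpc are invented here, and stdlib pickle stands in for cloudpickle to keep the snippet dependency-free:

import pickle


class DummyWorker:
    def ping(self, tag: str) -> str:
        return f"pong:{tag}"


def remote_ping(worker: DummyWorker, tag: str) -> str:
    return worker.ping(tag)


def run_rpc(worker, sent_method, *args, **kwargs):
    # A string names a worker method; bytes are a pickled callable
    # applied to the worker itself.
    if isinstance(sent_method, str):
        return getattr(worker, sent_method)(*args, **kwargs)
    return pickle.loads(sent_method)(worker, *args, **kwargs)


w = DummyWorker()
print(run_rpc(w, "ping", "a"))                     # pong:a
print(run_rpc(w, pickle.dumps(remote_ping), "b"))  # pong:b
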

def _check_ray_cgraph_installation(self):
import importlib.metadata
@ -595,13 +511,6 @@ class RayDistributedExecutor(DistributedExecutorBase):
with InputNode() as input_data:
# Example DAG: PP=2, TP=4
#
# For V0:
# ExecuteModelRequest -> 0 -> (ExecuteModelReq, IntermediateTensors) -> 4 -> SamplerOutput # noqa: E501
# ExecuteModelRequest -> 1 -> (ExecuteModelReq, IntermediateTensors) -> 5 -> SamplerOutput # noqa: E501
# ExecuteModelRequest -> 2 -> (ExecuteModelReq, IntermediateTensors) -> 6 -> SamplerOutput # noqa: E501
# ExecuteModelRequest -> 3 -> (ExecuteModelReq, IntermediateTensors) -> 7 -> SamplerOutput # noqa: E501
#
# For V1:
# SchedulerOutput -> 0 -> (SchedulerOutput, IntermediateTensors) -> 4 -> ModelRunnerOutput # noqa: E501
# SchedulerOutput -> 1 -> (SchedulerOutput, IntermediateTensors) -> 5 -> ModelRunnerOutput # noqa: E501
# SchedulerOutput -> 2 -> (SchedulerOutput, IntermediateTensors) -> 6 -> ModelRunnerOutput # noqa: E501
@ -613,20 +522,10 @@ class RayDistributedExecutor(DistributedExecutorBase):
for pp_rank, tp_group in enumerate(self.pp_tp_workers):
# Each PP worker takes in the output of the previous PP worker,
# and the TP group executes in SPMD fashion.
if self.use_v1:
outputs = [
worker.execute_model_ray.bind( # type: ignore[attr-defined]
outputs[i]
)
for i, worker in enumerate(tp_group)
]
else:
outputs = [
worker.execute_model_spmd.bind( # type: ignore[attr-defined]
outputs[i]
)
for i, worker in enumerate(tp_group)
]
outputs = [
worker.execute_model_ray.bind(outputs[i]) # type: ignore[attr-defined]
for i, worker in enumerate(tp_group)
]
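
Each pipeline stage's TP group is bound to the previous stage's outputs, so the scheduler output flows through ranks 0..TP-1, then TP..2*TP-1, and so on, with the last stage producing the model runner outputs. A Ray-free sketch of that chaining over plain callables (make_stage and run_pipeline are illustrative stand-ins for the compiled-graph bind calls):

def make_stage(pp_rank: int):
    # Stand-in for worker.execute_model_ray: consume the previous stage's
    # output and produce this stage's output.
    return lambda x: f"{x}->pp{pp_rank}"


def run_pipeline(num_pp: int, tp_size: int, scheduler_output: str) -> list:
    # Every TP worker in a stage receives the matching output of the previous stage.
    outputs = [scheduler_output] * tp_size
    for pp_rank in range(num_pp):
        stage = make_stage(pp_rank)
        outputs = [stage(outputs[i]) for i in range(tp_size)]
    return outputs  # one output per TP rank of the last stage


print(run_pipeline(2, 4, "sched"))  # ['sched->pp0->pp1', ...] four times
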

last_pp_rank = len(self.pp_tp_workers) - 1
if (
@ -674,82 +573,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
def __del__(self):
self.shutdown()

async def execute_model_async(
self, execute_model_req: ExecuteModelRequest
) -> list[SamplerOutput]:
if not self.use_ray_spmd_worker:
return await super().execute_model_async(execute_model_req)

if self.forward_dag is None:
self.forward_dag = self._compiled_ray_dag(enable_asyncio=True)

serialized_data = self.input_encoder.encode(execute_model_req)
dag_future = await self.forward_dag.execute_async(serialized_data)
output = await dag_future[0]
return self.output_decoder.decode(output)

async def _driver_execute_model_async(
self, execute_model_req: ExecuteModelRequest | None = None
) -> list[SamplerOutput]:
assert not self.use_ray_spmd_worker, (
"driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1"
)
if not self.tp_driver_workers:
return await self.driver_exec_method("execute_model", execute_model_req)
if self.pp_locks is None:
# This locks each pipeline parallel stage so multiple virtual
# engines can't execute on the same stage at the same time
# We create the locks here to avoid creating them in the constructor
# which uses a different asyncio loop.
self.pp_locks = [
asyncio.Lock()
for _ in range(self.parallel_config.pipeline_parallel_size)
]

tasks = [
asyncio.create_task(
_run_task_with_lock(
self.driver_exec_method,
self.pp_locks[0],
"execute_model",
execute_model_req,
)
)
]
for pp_rank, driver_worker in enumerate(self.tp_driver_workers, start=1):
tasks.append(
asyncio.create_task(
_run_task_with_lock(
driver_worker.execute_method.remote, # type: ignore[attr-defined]
self.pp_locks[pp_rank],
"execute_model",
execute_model_req,
)
)
)

results = await asyncio.gather(*tasks)

# Only the last PP stage has the final results.
return results[-1]

async def _start_worker_execution_loop(self):
assert not self.use_ray_spmd_worker, (
"worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1"
)
coros = [
worker.execute_method.remote("start_worker_execution_loop") # type: ignore[attr-defined]
for worker in self.non_driver_workers
]
return await asyncio.gather(*coros)

def check_health(self) -> None:
# Assume that the Ray workers are healthy.
# TODO: check the health of the Ray workers
return


async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, **kwargs):
"""Utility function to run async task in a lock"""
async with lock:
return await task(*args, **kwargs)
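
In the V0 async path shown above, _run_task_with_lock kept two batches from entering the same pipeline stage at once: each stage got an asyncio.Lock and every submission to that stage awaited under it. A self-contained illustration of the same pattern (work() is a made-up coroutine):

import asyncio


async def _run_task_with_lock(task, lock: asyncio.Lock, *args, **kwargs):
    async with lock:
        return await task(*args, **kwargs)


async def main():
    stage_lock = asyncio.Lock()

    async def work(tag: str) -> str:
        await asyncio.sleep(0.01)
        return tag

    # Both submissions target the same "stage"; the lock serializes them.
    print(await asyncio.gather(
        _run_task_with_lock(work, stage_lock, "batch-0"),
        _run_task_with_lock(work, stage_lock, "batch-1"),
    ))


asyncio.run(main())
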
@ -4,17 +4,16 @@
import os
import time
from collections import defaultdict
from concurrent.futures import Future
from typing import TYPE_CHECKING, Union

import msgspec

import vllm.platforms
from vllm.config import ParallelConfig
from vllm.distributed import get_pp_group
from vllm.executor.msgspec_utils import decode_hook, encode_hook
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.sequence import IntermediateTensors
from vllm.utils.network_utils import get_ip
from vllm.v1.outputs import AsyncModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerWrapperBase
@ -51,11 +50,6 @@ try:
# that thread.
self.compiled_dag_cuda_device_set = False

self.input_decoder = msgspec.msgpack.Decoder(
ExecuteModelRequest, dec_hook=decode_hook
)
self.output_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)

def get_node_ip(self) -> str:
return get_ip()

@ -70,47 +64,6 @@ try:
gpu_ids = ray.get_runtime_context().get_accelerator_ids()[device_key]
return node_id, gpu_ids

def execute_model_spmd(
self,
req_or_tuple: bytes | tuple[bytes, IntermediateTensors | None],
) -> bytes:
"""Execute model in SPMD fashion: used only when SPMD worker and
compiled DAG are both enabled.

Args:
req_or_tuple: A request or a tuple containing the
request and intermediate tensors. Intermediate tensors are
None unless if it is provided because it is > 0 pipeline
stage. The request is serialized by msgspec.
"""
if isinstance(req_or_tuple, bytes):
serialized_req, intermediate_tensors = req_or_tuple, None
else:
serialized_req, intermediate_tensors = req_or_tuple

execute_model_req = self.input_decoder.decode(serialized_req)

assert self.worker is not None, "Worker is not initialized"

# TODO(swang): This is needed right now because Ray Compiled Graph
# executes on a background thread, so we need to reset torch's
# current device.
if not self.compiled_dag_cuda_device_set:
assert self.worker.device is not None
current_platform.set_device(self.worker.device)
self.compiled_dag_cuda_device_set = True

output = self.worker._execute_model_spmd( # type: ignore[attr-defined]
execute_model_req, intermediate_tensors
)
# Pipeline model request and output to the next pipeline stage.
if isinstance(output, IntermediateTensors):
output = serialized_req, output
else:
output = self.output_encoder.encode(output)

return output

def setup_device_if_necessary(self):
# TODO(swang): This is needed right now because Ray CG executes
# on a background thread, so we need to reset torch's current
@ -174,6 +127,31 @@ except ImportError as e:
RayWorkerWrapper = None # type: ignore


class FutureWrapper(Future):
"""A wrapper around a Ray output reference to meet the interface
of .execute_model(): the top level (core busy loop) expects the .result() API
to block and return a single output.

If an aggregator is provided, the outputs from all workers are aggregated on
the result() call. If not, only the first worker's output is returned.
"""

def __init__(self, refs, aggregator: KVOutputAggregator | None = None):
super().__init__()
self.refs = refs
self.aggregator = aggregator

def result(self, timeout=None):
if timeout is not None:
raise NotImplementedError("timeout is not supported")

if self.aggregator is None:
return self.refs[0].get()

outputs = [ref.get() for ref in self.refs]
return self.aggregator.aggregate(outputs, output_rank=0)
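
FutureWrapper only needs objects that expose .get(), which makes its two result() paths easy to exercise with fakes. A toy check — FakeRef and SumAggregator are invented stand-ins, and the import assumes a vLLM installation that includes this change:

from vllm.v1.executor.ray_utils import FutureWrapper


class FakeRef:
    def __init__(self, value):
        self.value = value

    def get(self):
        return self.value


class SumAggregator:
    # Plays the role of KVOutputAggregator for this sketch only.
    def aggregate(self, outputs, output_rank=0):
        return sum(outputs)


# No aggregator: result() is just the first worker's output.
assert FutureWrapper([FakeRef(1), FakeRef(2)]).result() == 1
# With an aggregator: result() folds every worker's output together.
assert FutureWrapper([FakeRef(1), FakeRef(2)], SumAggregator()).result() == 3
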


def ray_is_available() -> bool:
"""Returns True if Ray is available."""
return ray is not None
@ -11,20 +11,18 @@ import torch
import torch.distributed as dist

import vllm.envs as envs
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.utils import run_method
from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import AsyncModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)


class UniProcExecutor(ExecutorBase):
uses_ray: bool = False

class UniProcExecutor(Executor):
def _init_executor(self) -> None:
"""Initialize the worker and load the model."""
self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, rpc_rank=0)
@ -44,9 +42,9 @@ class UniProcExecutor(ExecutorBase):
max_workers=1, thread_name_prefix="WorkerAsyncOutput"
)

self.collective_rpc("init_worker", args=([kwargs],))
self.collective_rpc("init_device")
self.collective_rpc("load_model")
self.driver_worker.init_worker(all_kwargs=[kwargs])
self.driver_worker.init_device()
self.driver_worker.load_model()

def _distributed_args(self) -> tuple[str, int, int]:
"""Return (distributed_init_method, rank, local_rank)."""
@ -101,16 +99,12 @@ class UniProcExecutor(ExecutorBase):
== ReconfigureRankType.SHUTDOWN_CURRENT_RANK
):
self.shutdown()
return

def shutdown(self) -> None:
if worker := self.driver_worker:
worker.shutdown()


UniProcExecutorAsync = UniProcExecutor


class ExecutorWithExternalLauncher(UniProcExecutor):
"""An executor that uses external launchers to launch engines,
specially designed for torchrun-compatible launchers, for
@ -128,8 +122,6 @@ class ExecutorWithExternalLauncher(UniProcExecutor):
and they don't need to synchronize the states with each other.
"""

uses_ray: bool = False

def _init_executor(self) -> None:
"""Initialize the worker and load the model."""
if envs.VLLM_USE_V1:
@ -152,22 +144,12 @@ class ExecutorWithExternalLauncher(UniProcExecutor):
local_rank = int(os.environ["LOCAL_RANK"])
return distributed_init_method, rank, local_rank

def determine_num_available_blocks(self) -> tuple[int, int]:
"""
Determine the number of available KV blocks.
Add an additional all_reduce to get the min across all ranks.
Note that even if we have the same `gpu_memory_utilization` and
`swap_space`, the available memory in every rank might still
differ because NCCL can take different amounts of memory in
different ranks. Therefore, it is necessary to test if all ranks
agree on the same KV cache configuration.
"""
a, b = super().determine_num_available_blocks()
def determine_available_memory(self) -> list[int]: # in bytes
# we need to get the min across all ranks.
memory = super().determine_available_memory()
from vllm.distributed.parallel_state import get_world_group

cpu_group = get_world_group().cpu_group
a_tensor = torch.tensor([a], device="cpu", dtype=torch.int64)
b_tensor = torch.tensor([b], device="cpu", dtype=torch.int64)
dist.all_reduce(a_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
dist.all_reduce(b_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
return a_tensor.item(), b_tensor.item()
memory_tensor = torch.tensor([memory], device="cpu", dtype=torch.int64)
dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
return [memory_tensor.item()]
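
determine_available_memory here takes the per-rank measurement and all-reduces it with MIN over the CPU group, so every rank sizes its KV cache from the most constrained rank. A minimal sketch of that reduction, assuming a single-process gloo group just to keep it runnable (a real torchrun launch would have one process per rank):

import os

import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group(backend="gloo", init_method="env://", world_size=1, rank=0)

local_free_bytes = 8 * 1024**3  # pretend this rank measured 8 GiB free
t = torch.tensor([local_free_bytes], device="cpu", dtype=torch.int64)
dist.all_reduce(t, op=dist.ReduceOp.MIN)  # every rank ends up with the global minimum
print(t.item())

dist.destroy_process_group()
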
@ -128,28 +128,6 @@ class WorkerBase:
def execute_model(self, scheduler_output: SchedulerOutput) -> ModelRunnerOutput:
raise NotImplementedError

def start_worker_execution_loop(self) -> None:
"""Execute model loop in parallel worker.

You can stop the loop by executing a driver worker with an empty output.
See `stop_remote_worker_execution_loop` for more details.
"""
raise NotImplementedError("Dead V0 code")

def determine_num_available_blocks(self) -> tuple[int, int]:
"""Determine the number of available blocks for the GPU KV cache and
swappable CPU KV cache.

The implementation may run profiling or other heuristics to determine
the size of caches.

Returns a tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
are blocks that are "active" on the device and can be appended to.
num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
appended to.
"""
raise NotImplementedError

def get_cache_block_size_bytes(self) -> int:
"""Return the size of a single cache block, in bytes. Used in
speculative decoding.