[V0 Deprecation] Remove V0 executors (#27142)

Signed-off-by: Nick Hill <nhill@redhat.com>
Authored by Nick Hill on 2025-10-21 11:09:37 -07:00; committed by GitHub
parent ddeec11ba9
commit 647214f3d5
31 changed files with 425 additions and 1043 deletions

View File

@ -157,11 +157,9 @@ def test_models_distributed(
and distributed_executor_backend == "ray"
and attention_backend == ""
and test_suite == "L4"
and enable_prompt_embeds
): # noqa
if enable_prompt_embeds:
pytest.skip("enable_prompt_embeds does not work with ray compiled dag.")
monkeypatch_context.setenv("VLLM_USE_RAY_SPMD_WORKER", "1")
monkeypatch_context.setenv("VLLM_USE_RAY_COMPILED_DAG", "1")
pytest.skip("enable_prompt_embeds does not work with ray compiled dag.")
if attention_backend:
monkeypatch_context.setenv(

View File

@ -18,8 +18,8 @@ from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from vllm import initialize_ray_cluster
from vllm.config import ParallelConfig
from vllm.executor.ray_utils import _wait_until_pg_removed
from vllm.utils.network_utils import get_ip
from vllm.v1.executor.ray_utils import _wait_until_pg_removed
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"

View File

@ -305,10 +305,8 @@ def _compare_tp(
common_args.extend(["--max-num-seqs", f"{max_num_seqs}"])
if distributed_backend == "ray":
# For V1, test Ray Compiled Graph for all the tests
# Test Ray Compiled Graph for all the tests
pp_env = {
"VLLM_USE_RAY_COMPILED_DAG": "1",
"VLLM_USE_RAY_SPMD_WORKER": "1",
"VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1",
}
# Temporary. Currently when zeromq + SPMD is used, it does not properly

View File

@ -9,7 +9,7 @@ from vllm.distributed import cleanup_dist_env_and_memory
from vllm.model_executor.model_loader import tensorizer as tensorizer_mod
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.executor.abstract import UniProcExecutor
from vllm.v1.executor import UniProcExecutor
from vllm.v1.worker.worker_base import WorkerWrapperBase
MODEL_REF = "facebook/opt-125m"

View File

@ -15,7 +15,8 @@ from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core import EngineCore
from vllm.v1.executor.abstract import Executor, UniProcExecutor
from vllm.v1.executor.abstract import Executor
from vllm.v1.executor.uniproc_executor import UniProcExecutor
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import ModelRunnerOutput

View File

@ -17,8 +17,6 @@ import regex as re
# add to this list if absolutely necessary and after careful security review.
ALLOWED_FILES = {
# pickle
"vllm/v1/serial_utils.py",
"vllm/v1/executor/multiproc_executor.py",
"vllm/multimodal/hasher.py",
"vllm/transformers_utils/config.py",
"vllm/model_executor/models/registry.py",
@ -38,11 +36,13 @@ ALLOWED_FILES = {
"benchmarks/cutlass_benchmarks/w8a8_benchmarks.py",
"benchmarks/cutlass_benchmarks/sparse_benchmarks.py",
# cloudpickle
"vllm/executor/mp_distributed_executor.py",
"vllm/executor/ray_distributed_executor.py",
"vllm/v1/executor/multiproc_executor.py",
"vllm/v1/executor/ray_executor.py",
"vllm/entrypoints/llm.py",
"vllm/utils/__init__.py",
"tests/utils.py",
# pickle and cloudpickle
"vllm/v1/serial_utils.py",
}
PICKLE_RE = re.compile(

View File

@ -21,7 +21,7 @@ MODULE_ATTRS = {
"AsyncLLMEngine": ".engine.async_llm_engine:AsyncLLMEngine",
"LLMEngine": ".engine.llm_engine:LLMEngine",
"LLM": ".entrypoints.llm:LLM",
"initialize_ray_cluster": ".executor.ray_utils:initialize_ray_cluster",
"initialize_ray_cluster": ".v1.executor.ray_utils:initialize_ray_cluster",
"PromptType": ".inputs:PromptType",
"TextPrompt": ".inputs:TextPrompt",
"TokensPrompt": ".inputs:TokensPrompt",
@ -45,7 +45,6 @@ if typing.TYPE_CHECKING:
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import (
@ -62,6 +61,7 @@ if typing.TYPE_CHECKING:
)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.v1.executor.ray_utils import initialize_ray_cluster
from ._bc_linter import bc_linter_include, bc_linter_skip
else:

View File

@ -25,11 +25,11 @@ if TYPE_CHECKING:
from ray.runtime_env import RuntimeEnv
from ray.util.placement_group import PlacementGroup
from vllm.executor.executor_base import ExecutorBase
from vllm.v1.executor import Executor
else:
RuntimeEnv = Any
PlacementGroup = Any
ExecutorBase = Any
Executor = Any
logger = init_logger(__name__)
@ -189,7 +189,7 @@ class ParallelConfig:
"""ray distributed model workers placement group."""
distributed_executor_backend: (
str | DistributedExecutorBackend | type[ExecutorBase] | None
str | DistributedExecutorBackend | type[Executor] | None
) = None
"""Backend to use for distributed model
workers, either "ray" or "mp" (multiprocessing). If the product
@ -511,7 +511,7 @@ class ParallelConfig:
# We use multiprocessing by default if world_size fits on the
# current node and we aren't in a ray placement group.
from vllm.executor import ray_utils
from vllm.v1.executor import ray_utils
backend: DistributedExecutorBackend = "mp"
ray_found = ray_utils.ray_is_available()
@ -553,6 +553,12 @@ class ParallelConfig:
if self.distributed_executor_backend is None and self.world_size == 1:
self.distributed_executor_backend = "uni"
if self.max_parallel_loading_workers is not None:
logger.warning(
"max_parallel_loading_workers is currently "
"not supported and will be ignored."
)
@property
def use_ray(self) -> bool:
return self.distributed_executor_backend == "ray" or (
@ -563,7 +569,7 @@ class ParallelConfig:
@model_validator(mode="after")
def _verify_args(self) -> Self:
# Lazy import to avoid circular import
from vllm.executor.executor_base import ExecutorBase
from vllm.v1.executor import Executor
# Enable batch invariance settings if requested
if vllm_is_batch_invariant():
@ -574,17 +580,17 @@ class ParallelConfig:
and not isinstance(self.distributed_executor_backend, str)
and not (
isinstance(self.distributed_executor_backend, type)
and issubclass(self.distributed_executor_backend, ExecutorBase)
and issubclass(self.distributed_executor_backend, Executor)
)
):
raise ValueError(
"Unrecognized distributed executor backend "
f"{self.distributed_executor_backend}. Supported "
"values are 'ray', 'mp' 'uni', 'external_launcher', "
" custom ExecutorBase subclass or its import path."
" custom Executor subclass or its import path."
)
if self.use_ray:
from vllm.executor import ray_utils
from vllm.v1.executor import ray_utils
ray_utils.assert_ray_available()
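The hunk above narrows the accepted backend values to "ray", "mp", "uni", "external_launcher", a subclass of vllm.v1.executor.Executor, or its import path. A minimal sketch of supplying each form follows; the EngineArgs usage and the MyExecutor class are illustrative assumptions, not part of this diff:

from vllm.engine.arg_utils import EngineArgs
from vllm.v1.executor import Executor

# Built-in string backends are unchanged.
args = EngineArgs(model="facebook/opt-125m", distributed_executor_backend="mp")

# A custom backend must now subclass the V1 Executor (ExecutorBase is gone).
class MyExecutor(Executor):  # hypothetical subclass, shown only to illustrate the contract
    def _init_executor(self) -> None: ...
    def collective_rpc(self, method, timeout=None, args=(), kwargs=None, non_block=False): ...
    def check_health(self) -> None: ...

# The class itself (or its import path as a string) can be passed directly.
args_custom = EngineArgs(model="facebook/opt-125m", distributed_executor_backend=MyExecutor)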

View File

@ -107,12 +107,6 @@ class SchedulerConfig:
NOTE: This is not currently configurable. It will be overridden by
max_num_batched_tokens in case max multimodal embedding size is larger."""
send_delta_data: bool = False
"""Private API. If used, scheduler sends delta data to
workers instead of an entire data. It should be enabled only
when SPMD worker architecture is enabled. I.e.,
VLLM_USE_RAY_SPMD_WORKER=1"""
policy: SchedulerPolicy = "fcfs"
"""The scheduling policy to use:\n
- "fcfs" means first come first served, i.e. requests are handled in order

View File

@ -31,7 +31,7 @@ if not USE_TPU_INFERENCE:
)
if USE_RAY:
from vllm.executor import ray_utils
from vllm.v1.executor import ray_utils
class TpuCommunicator(DeviceCommunicatorBase):

View File

@ -88,12 +88,12 @@ from vllm.utils.network_utils import get_ip
from vllm.v1.sample.logits_processor import LogitsProcessor
if TYPE_CHECKING:
from vllm.executor.executor_base import ExecutorBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.model_loader import LoadFormats
from vllm.usage.usage_lib import UsageContext
from vllm.v1.executor import Executor
else:
ExecutorBase = Any
Executor = Any
QuantizationMethods = Any
LoadFormats = Any
UsageContext = Any
@ -369,7 +369,7 @@ class EngineArgs:
# is intended for expert use only. The API may change without
# notice.
distributed_executor_backend: (
str | DistributedExecutorBackend | type[ExecutorBase] | None
str | DistributedExecutorBackend | type[Executor] | None
) = ParallelConfig.distributed_executor_backend
# number of P/D disaggregation (or other disaggregation) workers
pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
@ -1549,7 +1549,6 @@ class EngineArgs:
disable_chunked_mm_input=self.disable_chunked_mm_input,
is_multimodal_model=model_config.is_multimodal_model,
is_encoder_decoder=model_config.is_encoder_decoder,
send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER and parallel_config.use_ray),
policy=self.scheduling_policy,
scheduler_cls=self.scheduler_cls,
max_num_partial_prefills=self.max_num_partial_prefills,

View File

@ -26,7 +26,7 @@ from vllm.utils import (
from vllm.utils.network_utils import get_tcp_uri
from vllm.v1.engine.core import EngineCoreProc
from vllm.v1.engine.utils import CoreEngineProcManager, launch_core_engines
from vllm.v1.executor.abstract import Executor
from vllm.v1.executor import Executor
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
from vllm.v1.utils import APIServerProcessManager, wait_for_completion_or_failure

View File

@ -56,8 +56,6 @@ if TYPE_CHECKING:
VLLM_XLA_CHECK_RECOMPILATION: bool = False
VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024
VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING: bool = True
VLLM_USE_RAY_SPMD_WORKER: bool = False
VLLM_USE_RAY_COMPILED_DAG: bool = False
VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", "shm"] = "auto"
VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False
VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True
@ -623,22 +621,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_CPU_MOE_PREPACK": lambda: bool(int(os.getenv("VLLM_CPU_MOE_PREPACK", "1"))),
# (CPU backend only) whether to use SGL kernels, optimized for small batch.
"VLLM_CPU_SGL_KERNEL": lambda: bool(int(os.getenv("VLLM_CPU_SGL_KERNEL", "0"))),
# If the env var is set, then all workers will execute as separate
# processes from the engine, and we use the same mechanism to trigger
# execution on all workers.
# Run vLLM with VLLM_USE_RAY_SPMD_WORKER=1 to enable it.
"VLLM_USE_RAY_SPMD_WORKER": lambda: bool(
int(os.getenv("VLLM_USE_RAY_SPMD_WORKER", "0"))
),
# If the env var is set, it uses the Ray's Compiled Graph
# (previously known as ADAG) API which optimizes the
# control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
# Note that this variable is set to 1 in V1 by default
# when ray distributed executor is used.
"VLLM_USE_RAY_COMPILED_DAG": lambda: bool(
int(os.getenv("VLLM_USE_RAY_COMPILED_DAG", "0"))
),
# If the env var is set, Ray Compiled Graph uses the specified
# channel type to communicate between workers belonging to
# different pipeline-parallel stages.
@ -646,20 +628,17 @@ environment_variables: dict[str, Callable[[], Any]] = {
# - "auto": use the default channel type
# - "nccl": use NCCL for communication
# - "shm": use shared memory and gRPC for communication
# This flag is ignored if VLLM_USE_RAY_COMPILED_DAG is not set.
"VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE": env_with_choices(
"VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "auto", ["auto", "nccl", "shm"]
),
# If the env var is set, it enables GPU communication overlap
# (experimental feature) in Ray's Compiled Graph. This flag is ignored if
# VLLM_USE_RAY_COMPILED_DAG is not set.
# (experimental feature) in Ray's Compiled Graph.
"VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM": lambda: bool(
int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM", "0"))
),
# If the env var is set, it uses a Ray Communicator wrapping
# vLLM's pipeline parallelism communicator to interact with Ray's
# Compiled Graph. Otherwise, it uses Ray's NCCL communicator.
# This flag is ignored if VLLM_USE_RAY_COMPILED_DAG is not set.
"VLLM_USE_RAY_WRAPPED_PP_COMM": lambda: bool(
int(os.getenv("VLLM_USE_RAY_WRAPPED_PP_COMM", "1"))
),
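With VLLM_USE_RAY_SPMD_WORKER and VLLM_USE_RAY_COMPILED_DAG removed, the Ray backend always runs its workers in SPMD mode over a compiled graph, so the remaining variables in this hunk apply unconditionally. A small sketch of setting them before engine start; the values shown are the defaults from this hunk, and setting them via os.environ ahead of time is an assumed usage pattern, not prescribed by the diff:

import os

# Channel between pipeline-parallel stages: "auto", "nccl", or "shm".
os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "auto"
# Experimental GPU communication overlap in Ray Compiled Graph (off by default).
os.environ["VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM"] = "0"
# Use vLLM's wrapped PP communicator instead of Ray's NCCL communicator (on by default).
os.environ["VLLM_USE_RAY_WRAPPED_PP_COMM"] = "1"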

View File

@ -1,393 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import time
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable
from functools import cached_property
from typing import Any
from typing_extensions import TypeVar
import vllm.platforms
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest
from vllm.tasks import SupportedTask
from vllm.utils.async_utils import make_async
from vllm.v1.outputs import SamplerOutput
from vllm.v1.worker.worker_base import WorkerBase
logger = init_logger(__name__)
_R = TypeVar("_R", default=Any)
class ExecutorBase(ABC):
"""Base class for all executors.
An executor is responsible for executing the model on one device,
or it can be a distributed executor
that can execute the model on multiple devices.
"""
uses_ray: bool # whether the executor uses Ray for orchestration.
supports_pp: bool = False # whether the executor supports PP
def __init__(
self,
vllm_config: VllmConfig,
) -> None:
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.observability_config = vllm_config.observability_config
self._init_executor()
self.is_sleeping = False
self.sleeping_tags: set[str] = set()
self.kv_output_aggregator: KVOutputAggregator | None = None
@abstractmethod
def _init_executor(self) -> None:
raise NotImplementedError
@abstractmethod
def collective_rpc(
self,
method: str | Callable[[WorkerBase], _R],
timeout: float | None = None,
args: tuple = (),
kwargs: dict[str, Any] | None = None,
) -> list[_R]:
"""
Execute an RPC call on all workers.
Args:
method: Name of the worker method to execute, or a callable that
is serialized and sent to all workers to execute.
If the method is a callable, it should accept an additional
`self` argument, in addition to the arguments passed in `args`
and `kwargs`. The `self` argument will be the worker object.
timeout: Maximum time in seconds to wait for execution. Raises a
[`TimeoutError`][] on timeout. `None` means wait indefinitely.
args: Positional arguments to pass to the worker method.
kwargs: Keyword arguments to pass to the worker method.
Returns:
A list containing the results from each worker.
Note:
It is recommended to use this API to only pass control messages,
and set up data-plane communication to pass data.
"""
raise NotImplementedError
def determine_num_available_blocks(self) -> tuple[int, int]:
"""Determine the number of available blocks for the GPU KV cache and
swappable CPU KV cache.
Normally, this should simply delegate to the underlying Worker. Some
ExecutorBase may require modification of the result, e.g. to ensure the
selected cache sizes are compatible with all workers.
Returns a tuple `(num_gpu_blocks, num_cpu_blocks)`, where
`num_gpu_blocks` are blocks that are "active" on the device and can be
appended to.
`num_cpu_blocks` refers to "swapped" blocks in CPU memory and cannot be
appended to.
"""
results = self.collective_rpc("determine_num_available_blocks")
a = min([r[0] for r in results])
b = min([r[1] for r in results])
return a, b
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None:
"""Initialize the KV cache by invoking the underlying worker."""
# NOTE: This is logged in the executor because there can be >1 workers.
logger.info(
"# %s blocks: %d, # CPU blocks: %d",
vllm.platforms.current_platform.device_name,
num_gpu_blocks,
num_cpu_blocks,
)
max_concurrency = (
num_gpu_blocks
* self.cache_config.block_size
/ self.model_config.max_model_len
)
logger.info(
"Maximum concurrency for %s tokens per request: %.2fx",
self.model_config.max_model_len,
max_concurrency,
)
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks))
@cached_property # Avoid unnecessary RPC calls
def supported_tasks(self) -> tuple[SupportedTask, ...]:
output = self.collective_rpc("get_supported_tasks")
return output[0]
def execute_model(
self, execute_model_req: ExecuteModelRequest
) -> list[SamplerOutput]:
output = self.collective_rpc("execute_model", args=(execute_model_req,))
assert output[0] is not None
return output[0]
def stop_remote_worker_execution_loop(self) -> None:
"""Releases parallel workers from model loop."""
return
def add_lora(self, lora_request: LoRARequest) -> bool:
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
return all(self.collective_rpc("add_lora", args=(lora_request,)))
def remove_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return all(self.collective_rpc("remove_lora", args=(lora_id,)))
def pin_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return all(self.collective_rpc("pin_lora", args=(lora_id,)))
def list_loras(self) -> set[int]:
sets = self.collective_rpc("list_loras")
for s in sets:
assert s == sets[0], "All workers should have the same LORAs."
return sets[0]
def reset_mm_cache(self) -> None:
"""Reset the multi-modal cache in each worker."""
self.collective_rpc("reset_mm_cache")
def start_profile(self) -> None:
self.collective_rpc("start_profile")
def stop_profile(self) -> None:
self.collective_rpc("stop_profile")
def sleep(self, level: int = 1):
if self.is_sleeping:
logger.warning("Executor is already sleeping.")
return
time_before_sleep = time.perf_counter()
self.collective_rpc("sleep", kwargs=dict(level=level))
time_after_sleep = time.perf_counter()
self.sleeping_tags = {"weights", "kv_cache"}
self.is_sleeping = True
logger.info(
"It took %.6f seconds to fall asleep.", time_after_sleep - time_before_sleep
)
def wake_up(self, tags: list[str] | None = None):
if not self.is_sleeping:
logger.warning("Executor is not sleeping.")
return
if tags:
for tag in tags:
if tag not in self.sleeping_tags:
logger.warning(
"Tag %s is not in sleeping tags %s", tag, self.sleeping_tags
)
return
time_before_wakeup = time.perf_counter()
self.collective_rpc("wake_up", kwargs=dict(tags=tags))
time_after_wakeup = time.perf_counter()
logger.info(
"It took %.6f seconds to wake up tags %s.",
time_after_wakeup - time_before_wakeup,
tags if tags is not None else self.sleeping_tags,
)
if tags:
for tag in tags:
self.sleeping_tags.remove(tag)
else:
self.sleeping_tags.clear()
if not self.sleeping_tags:
self.is_sleeping = False
def save_sharded_state(
self,
path: str,
pattern: str | None = None,
max_size: int | None = None,
) -> None:
self.collective_rpc(
"save_sharded_state",
kwargs=dict(path=path, pattern=pattern, max_size=max_size),
)
@abstractmethod
def check_health(self) -> None:
"""Checks if the executor is healthy. If not, it should raise an
exception."""
raise NotImplementedError
def shutdown(self) -> None:
"""Shutdown the executor."""
self.collective_rpc("shutdown")
async def execute_model_async(
self, execute_model_req: ExecuteModelRequest
) -> list[SamplerOutput]:
"""Executes one model step on the given sequences."""
output = await make_async(self.execute_model)(execute_model_req)
return output
async def stop_remote_worker_execution_loop_async(self) -> None:
"""Releases parallel workers from model loop."""
return
async def check_health_async(self) -> None:
"""Checks if the executor is healthy. If not, it should raise an
exception."""
self.check_health()
def init_kv_output_aggregator(self, finished_count: int | None) -> None:
"""Init KVOutputAggregator"""
self.kv_output_aggregator = KVOutputAggregator(
finished_count or self.parallel_config.world_size
)
class DistributedExecutorBase(ExecutorBase):
"""Abstract superclass of distributed executor implementations."""
def __init__(self, *args, **kwargs):
# This is non-None when the execute model loop is running
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
self.parallel_worker_tasks: Any | Awaitable[Any] | None = None
super().__init__(*args, **kwargs)
def execute_model(
self,
execute_model_req: ExecuteModelRequest,
) -> list[SamplerOutput]:
# TODO: unify into collective_rpc
if self.parallel_worker_tasks is None:
self.parallel_worker_tasks = self._run_workers(
"start_worker_execution_loop",
async_run_tensor_parallel_workers_only=True,
)
# Only the driver worker returns the sampling results.
driver_outputs = self._driver_execute_model(execute_model_req)
assert driver_outputs is not None
return driver_outputs
def stop_remote_worker_execution_loop(self) -> None:
if self.parallel_worker_tasks is None:
return
self._driver_execute_model(execute_model_req=None)
parallel_worker_tasks = self.parallel_worker_tasks
self.parallel_worker_tasks = None
# Ensure that workers exit model loop cleanly
# (this will raise otherwise)
self._wait_for_tasks_completion(parallel_worker_tasks)
@abstractmethod
def _driver_execute_model(
self, execute_model_req: ExecuteModelRequest | None
) -> list[SamplerOutput] | None:
"""Run execute_model in the driver worker.
Passing None will cause the driver to stop the model execution loop
running in each of the remote workers. In this case, this method
returns None. Otherwise, this method returns the model output.
"""
raise NotImplementedError
def collective_rpc(
self,
method: str | Callable,
timeout: float | None = None,
args: tuple = (),
kwargs: dict[str, Any] | None = None,
) -> list[Any]:
return self._run_workers(method, *args, **(kwargs or {}))
@abstractmethod
def _run_workers(
self,
method: str | Callable,
*args,
async_run_tensor_parallel_workers_only: bool = False,
max_concurrent_workers: int | None = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers.
Args:
async_run_tensor_parallel_workers_only: If True the method will be
run only in the remote TP workers, not the driver worker.
It will also be run asynchronously and return a list of futures
rather than blocking on the results.
# TODO: simplify and merge with collective_rpc
"""
raise NotImplementedError
@abstractmethod
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
raise NotImplementedError
async def execute_model_async(
self, execute_model_req: ExecuteModelRequest
) -> list[SamplerOutput]:
if self.parallel_worker_tasks is None:
# Start model execution loop running in the parallel workers
self.parallel_worker_tasks = asyncio.create_task(
self._start_worker_execution_loop()
)
# Only the driver worker returns the sampling results.
return await self._driver_execute_model_async(execute_model_req)
async def stop_remote_worker_execution_loop_async(self) -> None:
if self.parallel_worker_tasks is None:
return
await self._driver_execute_model_async()
parallel_worker_tasks = self.parallel_worker_tasks
self.parallel_worker_tasks = None
# Ensure that workers exit model loop cleanly
# (this will raise otherwise)
await parallel_worker_tasks
@abstractmethod
async def _driver_execute_model_async(
self,
execute_model_req: ExecuteModelRequest | None = None,
) -> list[SamplerOutput]:
"""Execute the model asynchronously in the driver worker.
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
"""
raise NotImplementedError
@abstractmethod
async def _start_worker_execution_loop(self):
"""Run execution loop on all workers. It guarantees all workers run
the loop or None of them is running the loop. Loop can be stopped by
`stop_remote_worker_execution_loop`.
The API is idempotent (guarantee only 1 loop run at any moment)."""
raise NotImplementedError

View File

@ -1,36 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from array import array
from typing import Any
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE
def encode_hook(obj: Any) -> Any:
"""Custom msgspec enc hook that supports array types and MultiModalKwargs.
See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
"""
if isinstance(obj, array):
assert obj.typecode == VLLM_TOKEN_ID_ARRAY_TYPE, (
f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. "
f"Given array has a type code of {obj.typecode}."
)
return obj.tobytes()
if isinstance(obj, MultiModalKwargs):
return dict(obj)
def decode_hook(type: type, obj: Any) -> Any:
"""Custom msgspec dec hook that supports array types and MultiModalKwargs.
See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
"""
if type is array:
deserialized = array(VLLM_TOKEN_ID_ARRAY_TYPE)
deserialized.frombytes(obj)
return deserialized
if type is MultiModalKwargs:
return MultiModalKwargs(obj)

View File

@ -5,7 +5,6 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
import msgspec
import torch
if TYPE_CHECKING:
@ -92,12 +91,3 @@ class IntermediateTensors:
def __repr__(self) -> str:
return f"IntermediateTensors(tensors={self.tensors})"
class ExecuteModelRequest(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True,
): # type: ignore[call-arg]
# Placeholder. Remove.
pass

View File

@ -943,7 +943,7 @@ def maybe_register_config_serialize_by_value() -> None:
cloudpickle.register_pickle_by_value(transformers_modules)
# ray vendors its own version of cloudpickle
from vllm.executor.ray_utils import ray
from vllm.v1.executor.ray_utils import ray
if ray:
ray.cloudpickle.register_pickle_by_value(transformers_modules)

View File

@ -39,7 +39,7 @@ from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.executor import Executor
from vllm.v1.metrics.loggers import (
StatLoggerFactory,
StatLoggerManager,

View File

@ -60,7 +60,7 @@ from vllm.v1.engine.utils import (
EngineZmqAddresses,
get_device_indices,
)
from vllm.v1.executor.abstract import Executor
from vllm.v1.executor import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
@ -322,7 +322,6 @@ class EngineCore:
with self.log_error_detail(scheduler_output):
model_output = self.model_executor.execute_model(scheduler_output)
assert isinstance(model_output, ModelRunnerOutput)
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
)
@ -364,7 +363,7 @@ class EngineCore:
if self.scheduler.has_requests():
scheduler_output = self.scheduler.schedule()
future = self.model_executor.execute_model(scheduler_output, non_block=True)
batch_queue.appendleft((future, scheduler_output)) # type: ignore[arg-type]
batch_queue.appendleft((future, scheduler_output))
model_executed = scheduler_output.total_num_scheduled_tokens > 0
if (
@ -463,14 +462,6 @@ class EngineCore:
) -> list[_R]:
return self.model_executor.collective_rpc(method, timeout, args, kwargs)
def save_tensorized_model(
self,
tensorizer_config,
) -> None:
self.model_executor.save_tensorized_model(
tensorizer_config=tensorizer_config,
)
def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, int]:
"""Preprocess the request.

View File

@ -46,7 +46,7 @@ from vllm.v1.engine.utils import (
CoreEngineProcManager,
launch_core_engines,
)
from vllm.v1.executor.abstract import Executor
from vllm.v1.executor import Executor
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
logger = init_logger(__name__)

View File

@ -32,7 +32,7 @@ from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.executor import Executor
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
from vllm.v1.metrics.stats import IterationStats

View File

@ -23,7 +23,7 @@ from vllm.ray.ray_env import get_env_vars_to_copy
from vllm.utils import get_mp_context
from vllm.utils.network_utils import get_open_zmq_ipc_path, zmq_socket_ctx
from vllm.v1.engine.coordinator import DPCoordinator
from vllm.v1.executor.abstract import Executor
from vllm.v1.executor import Executor
from vllm.v1.utils import get_engine_client_zmq_addr, shutdown
if TYPE_CHECKING:

View File

@ -0,0 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .abstract import Executor
from .uniproc_executor import UniProcExecutor
__all__ = ["Executor", "UniProcExecutor"]
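This new package __init__ gives call sites a stable import point, which the rest of the diff switches to. A short before/after sketch of the import style; the "before" lines are taken from paths removed or bypassed elsewhere in this commit:

# Before:
#   from vllm.executor.executor_base import ExecutorBase
#   from vllm.v1.executor.abstract import Executor, UniProcExecutor
# After:
from vllm.v1.executor import Executor, UniProcExecutor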

View File

@ -1,31 +1,40 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from abc import ABC, abstractmethod
from collections.abc import Callable
from concurrent.futures import Future
from typing import Any
import torch
import torch.distributed as dist
from functools import cached_property
from typing import Literal, TypeVar, overload
from vllm.config import VllmConfig
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.uniproc_executor import ( # noqa
ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0,
)
from vllm.executor.uniproc_executor import UniProcExecutor as UniProcExecutorV0 # noqa
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.tasks import SupportedTask
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerBase
logger = init_logger(__name__)
_R = TypeVar("_R")
FailureCallback = Callable[[], None]
class Executor(ExecutorBase):
class Executor(ABC):
"""Abstract base class for vLLM executors."
An executor is responsible for executing the model on one device,
or it can be a distributed executor that can execute the model on multiple devices.
"""
Abstract class for v1 executors, mainly define some methods for v1.
For methods shared by v0 and v1, define them in ExecutorBase"""
uses_ray: bool = False # whether the executor uses Ray for orchestration.
supports_pp: bool = False # whether the executor supports PP
@staticmethod
def get_class(vllm_config: VllmConfig) -> type["Executor"]:
@ -34,16 +43,14 @@ class Executor(ExecutorBase):
distributed_executor_backend = parallel_config.distributed_executor_backend
# distributed_executor_backend must be set in VllmConfig.__post_init__
if isinstance(distributed_executor_backend, type):
if not issubclass(distributed_executor_backend, ExecutorBase):
if not issubclass(distributed_executor_backend, Executor):
raise TypeError(
"distributed_executor_backend must be a subclass of "
f"ExecutorBase. Got {distributed_executor_backend}."
f"Executor. Got {distributed_executor_backend}."
)
executor_class = distributed_executor_backend
elif distributed_executor_backend == "ray":
from vllm.v1.executor.ray_distributed_executor import ( # noqa
RayDistributedExecutor,
)
from vllm.v1.executor.ray_executor import RayDistributedExecutor
executor_class = RayDistributedExecutor
elif distributed_executor_backend == "mp":
@ -51,6 +58,8 @@ class Executor(ExecutorBase):
executor_class = MultiprocExecutor
elif distributed_executor_backend == "uni":
from vllm.v1.executor.uniproc_executor import UniProcExecutor
executor_class = UniProcExecutor
elif distributed_executor_backend == "external_launcher":
# TODO: make v1 scheduling deterministic
@ -58,10 +67,10 @@ class Executor(ExecutorBase):
executor_class = ExecutorWithExternalLauncher
elif isinstance(distributed_executor_backend, str):
executor_class = resolve_obj_by_qualname(distributed_executor_backend)
if not issubclass(executor_class, ExecutorBase):
if not issubclass(executor_class, Executor):
raise TypeError(
"distributed_executor_backend must be a subclass of "
f"ExecutorBase. Got {executor_class}."
f"Executor. Got {executor_class}."
)
else:
raise ValueError(
@ -69,6 +78,29 @@ class Executor(ExecutorBase):
)
return executor_class
def __init__(
self,
vllm_config: VllmConfig,
) -> None:
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.observability_config = vllm_config.observability_config
self._init_executor()
self.is_sleeping = False
self.sleeping_tags: set[str] = set()
self.kv_output_aggregator: KVOutputAggregator | None = None
@abstractmethod
def _init_executor(self) -> None:
raise NotImplementedError
def initialize_from_config(self, kv_cache_configs: list[KVCacheConfig]) -> None:
"""
Initialize the KV caches and begin the model execution loop of the
@ -77,7 +109,7 @@ class Executor(ExecutorBase):
self.collective_rpc("initialize_from_config", args=(kv_cache_configs,))
self.collective_rpc("compile_or_warm_up_model")
def register_failure_callback(self, callback: FailureCallback):
def register_failure_callback(self, callback: FailureCallback): # noqa: B027
"""
Register a function to be called if the executor enters a permanent
failed state.
@ -90,22 +122,78 @@ class Executor(ExecutorBase):
def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
return self.collective_rpc("get_kv_cache_spec")
@overload
def collective_rpc(
self,
method: str | Callable,
method: str | Callable[[WorkerBase], _R],
timeout: float | None = None,
args: tuple = (),
kwargs: dict | None = None,
non_block: bool = False,
) -> list[Any]:
non_block: Literal[False] = False,
) -> list[_R]:
"""
Execute an RPC call on all workers.
Args:
method: Name of the worker method to execute, or a callable that
is serialized and sent to all workers to execute.
If the method is a callable, it should accept an additional
`self` argument, in addition to the arguments passed in `args`
and `kwargs`. The `self` argument will be the worker object.
timeout: Maximum time in seconds to wait for execution. Raises a
[`TimeoutError`][] on timeout. `None` means wait indefinitely.
args: Positional arguments to pass to the worker method.
kwargs: Keyword arguments to pass to the worker method.
non_block: If `True`, returns a list of Futures instead of waiting
for the results.
Returns:
A list containing the results from each worker.
Note:
It is recommended to use this API to only pass control messages,
and set up data-plane communication to pass data.
"""
pass
@overload
def collective_rpc(
self,
method: str | Callable[[WorkerBase], _R],
timeout: float | None = None,
args: tuple = (),
kwargs: dict | None = None,
non_block: Literal[True] = True,
) -> list[Future[_R]]:
pass
@abstractmethod
def collective_rpc(
self, method, timeout=None, args=(), kwargs=None, non_block: bool = False
):
raise NotImplementedError
@overload
def execute_model(
self,
scheduler_output: SchedulerOutput,
non_block: bool = False,
non_block: Literal[False] = False,
) -> ModelRunnerOutput:
pass
@overload
def execute_model(
self,
scheduler_output: SchedulerOutput,
non_block: Literal[True] = True,
) -> Future[ModelRunnerOutput]:
pass
def execute_model(
self, scheduler_output: SchedulerOutput, non_block: bool = False
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
output = self.collective_rpc(
output = self.collective_rpc( # type: ignore[call-overload]
"execute_model", args=(scheduler_output,), non_block=non_block
)
return output[0]
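The overloads above encode the calling convention rather than changing behavior: with non_block=False the call returns materialized results, with non_block=True it returns futures. A usage sketch against any concrete Executor; the variable names are assumptions for illustration:

# Blocking: one result per worker.
tasks = executor.collective_rpc("get_supported_tasks")

# Non-blocking: a list of concurrent.futures.Future objects.
futures = executor.collective_rpc("get_supported_tasks", non_block=True)
results = [f.result() for f in futures]

# execute_model follows the same pattern.
output = executor.execute_model(scheduler_output)                  # ModelRunnerOutput
future = executor.execute_model(scheduler_output, non_block=True)  # Future[ModelRunnerOutput]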
@ -114,7 +202,7 @@ class Executor(ExecutorBase):
self.collective_rpc("execute_dummy_batch")
def take_draft_token_ids(self) -> DraftTokenIds | None:
output = self.collective_rpc("take_draft_token_ids")
output: list[DraftTokenIds] = self.collective_rpc("take_draft_token_ids")
return output[0]
@property
@ -124,19 +212,120 @@ class Executor(ExecutorBase):
def profile(self, is_start: bool = True):
self.collective_rpc("profile", args=(is_start,))
def save_sharded_state(
self,
path: str,
pattern: str | None = None,
max_size: int | None = None,
) -> None:
self.collective_rpc(
"save_sharded_state",
kwargs=dict(path=path, pattern=pattern, max_size=max_size),
)
class UniProcExecutor(UniProcExecutorV0, Executor):
pass
@abstractmethod
def check_health(self) -> None:
"""Checks if the executor is healthy. If not, it should raise an
exception."""
raise NotImplementedError
def shutdown(self) -> None:
"""Shutdown the executor."""
self.collective_rpc("shutdown")
def init_kv_output_aggregator(self, finished_count: int | None) -> None:
"""Init KVOutputAggregator"""
self.kv_output_aggregator = KVOutputAggregator(
finished_count or self.parallel_config.world_size
)
@cached_property # Avoid unnecessary RPC calls
def supported_tasks(self) -> tuple[SupportedTask, ...]:
output: list[tuple[SupportedTask, ...]]
output = self.collective_rpc("get_supported_tasks")
return output[0]
def add_lora(self, lora_request: LoRARequest) -> bool:
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
return all(self.collective_rpc("add_lora", args=(lora_request,)))
def remove_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return all(self.collective_rpc("remove_lora", args=(lora_id,)))
def pin_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return all(self.collective_rpc("pin_lora", args=(lora_id,)))
def list_loras(self) -> set[int]:
sets: list[set[int]] = self.collective_rpc("list_loras")
for s in sets:
assert s == sets[0], "All workers should have the same LORAs."
return sets[0]
def reset_mm_cache(self) -> None:
"""Reset the multi-modal cache in each worker."""
self.collective_rpc("reset_mm_cache")
def start_profile(self) -> None:
self.collective_rpc("start_profile")
def stop_profile(self) -> None:
self.collective_rpc("stop_profile")
def sleep(self, level: int = 1):
if self.is_sleeping:
logger.warning("Executor is already sleeping.")
return
time_before_sleep = time.perf_counter()
self.collective_rpc("sleep", kwargs=dict(level=level))
time_after_sleep = time.perf_counter()
self.sleeping_tags = {"weights", "kv_cache"}
self.is_sleeping = True
logger.info(
"It took %.6f seconds to fall asleep.", time_after_sleep - time_before_sleep
)
def wake_up(self, tags: list[str] | None = None):
if not self.is_sleeping:
logger.warning("Executor is not sleeping.")
return
if tags:
for tag in tags:
if tag not in self.sleeping_tags:
logger.warning(
"Tag %s is not in sleeping tags %s", tag, self.sleeping_tags
)
return
time_before_wakeup = time.perf_counter()
self.collective_rpc("wake_up", kwargs=dict(tags=tags))
time_after_wakeup = time.perf_counter()
logger.info(
"It took %.6f seconds to wake up tags %s.",
time_after_wakeup - time_before_wakeup,
tags if tags is not None else self.sleeping_tags,
)
if tags:
for tag in tags:
self.sleeping_tags.remove(tag)
else:
self.sleeping_tags.clear()
if not self.sleeping_tags:
self.is_sleeping = False
def reinitialize_distributed(
self, reconfig_request: ReconfigureDistributedRequest
) -> None:
raise NotImplementedError
class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
def determine_available_memory(self) -> list[int]: # in bytes
# same as determine_num_available_blocks in v0,
# we need to get the min across all ranks.
memory = super().determine_available_memory()
from vllm.distributed.parallel_state import get_world_group
from vllm.v1.executor.uniproc_executor import ( # noqa: E402
ExecutorWithExternalLauncher as _ExecutorWithExternalLauncher,
)
from vllm.v1.executor.uniproc_executor import ( # noqa: E402
UniProcExecutor as _UniProcExecutor,
)
cpu_group = get_world_group().cpu_group
memory_tensor = torch.tensor([memory], device="cpu", dtype=torch.int64)
dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
return [memory_tensor.item()]
# For backwards compatibility.
UniProcExecutor = _UniProcExecutor
ExecutorWithExternalLauncher = _ExecutorWithExternalLauncher
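Because the concrete classes now live in vllm/v1/executor/uniproc_executor.py while abstract.py keeps these aliases, existing imports resolve to the same objects. A quick illustrative check, assuming nothing beyond the aliases shown above:

from vllm.v1.executor.abstract import UniProcExecutor as legacy_cls
from vllm.v1.executor.uniproc_executor import UniProcExecutor as new_cls

assert legacy_cls is new_cls  # the backwards-compatibility alias points at the moved class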

View File

@ -179,7 +179,7 @@ class MultiprocExecutor(Executor):
else:
self.failure_callback = callback
def execute_model(
def execute_model( # type: ignore[override]
self,
scheduler_output: SchedulerOutput,
non_block: bool = False,
@ -204,6 +204,7 @@ class MultiprocExecutor(Executor):
)
# aggregate all workers output to a single output
assert self.kv_output_aggregator is not None
if non_block:
return self.kv_output_aggregator.async_aggregate(outputs, self.output_rank)
return self.kv_output_aggregator.aggregate(outputs, self.output_rank)
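The aggregation step above is what merges per-worker outputs when a KV connector is configured; kv_output_aggregator is created by Executor.init_kv_output_aggregator. A minimal sketch of the aggregator contract as used here, where world_size, worker_outputs, and worker_futures are illustrative placeholders:

from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator

aggregator = KVOutputAggregator(world_size)
# Blocking path: merge already-materialized per-worker outputs, keeping output_rank's result.
merged = aggregator.aggregate(worker_outputs, output_rank=0)
# Non-blocking path: merge a list of futures into a single future.
merged_future = aggregator.async_aggregate(worker_futures, output_rank=0)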

View File

@ -1,111 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from concurrent.futures import Future
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.executor.ray_distributed_executor import ( # noqa
RayDistributedExecutor as RayDistributedExecutorV0,
from vllm.v1.executor.ray_executor import (
RayDistributedExecutor as _RayDistributedExecutor,
)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import ModelRunnerOutput
logger = init_logger(__name__)
class FutureWrapper(Future):
"""A wrapper around Ray output reference to meet the interface
of .execute_model(): The top level (core busy loop) expects .result() api
to block and return a single output.
If aggregator is provided, the outputs from all workers are aggregated upon
the result() call. If not only the first worker's output is returned.
"""
def __init__(self, refs, aggregator: KVOutputAggregator | None = None):
super().__init__()
self.refs = refs
self.aggregator = aggregator
def result(self, timeout=None):
if timeout is not None:
raise NotImplementedError("timeout is not supported")
if self.aggregator is None:
return self.refs[0].get()
outputs = [ref.get() for ref in self.refs]
return self.aggregator.aggregate(outputs, output_rank=0)
class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
"""Ray distributed executor using Ray Compiled Graphs."""
supports_pp: bool = True
def _init_executor(self) -> None:
super()._init_executor()
# KV connector setup
self.has_connector = self.vllm_config.kv_transfer_config is not None
@property
def max_concurrent_batches(self) -> int:
"""Ray distributed executor supports pipeline parallelism,
meaning that it allows PP size batches to be executed concurrently.
"""
if self.scheduler_config.async_scheduling:
return 2
return self.parallel_config.pipeline_parallel_size
def execute_model(
self,
scheduler_output: SchedulerOutput,
non_block: bool = False,
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
"""Execute the model on the Ray workers.
Args:
scheduler_output: The scheduler output to execute.
non_block: If True, the method will return a Future.
Returns:
The model runner output.
"""
# Build the compiled DAG for the first time.
if self.forward_dag is None: # type: ignore
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
refs = self.forward_dag.execute(scheduler_output) # type: ignore
if not self.has_connector:
# Get output only from a single worker (output_rank)
# When PP is not used, we block here until the result is available.
if not non_block:
return refs[0].get()
# When PP is used, we return a FutureWrapper immediately so that
# the scheduler can yield to the next batch.
return FutureWrapper(refs)
# Get output from all workers when connector is present
if not non_block:
# Block and get results from all workers
outputs = [ref.get() for ref in refs]
return self.kv_output_aggregator.aggregate(outputs)
# Return a future that will aggregate outputs from all workers
return FutureWrapper(refs, self.kv_output_aggregator)
def reinitialize_distributed(
self, reconfig_request: ReconfigureDistributedRequest
) -> None:
self._run_workers("reinitialize_distributed", reconfig_request)
if (
reconfig_request.new_data_parallel_rank
== ReconfigureRankType.SHUTDOWN_CURRENT_RANK
):
self.shutdown()
# For backwards compatibility.
RayDistributedExecutor = _RayDistributedExecutor

View File

@ -1,31 +1,34 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import os
from collections import defaultdict
from collections.abc import Callable
from concurrent.futures import Future
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
import cloudpickle
import msgspec
import vllm.envs as envs
from vllm.executor.executor_base import DistributedExecutorBase
from vllm.executor.msgspec_utils import encode_hook
from vllm.executor.ray_utils import RayWorkerWrapper, initialize_ray_cluster, ray
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.ray.ray_env import get_env_vars_to_copy
from vllm.sequence import ExecuteModelRequest
from vllm.utils.async_utils import make_async
from vllm.utils.network_utils import (
get_distributed_init_method,
get_ip,
get_open_port,
)
from vllm.v1.outputs import SamplerOutput
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.abstract import Executor
from vllm.v1.executor.ray_utils import (
FutureWrapper,
RayWorkerWrapper,
initialize_ray_cluster,
ray,
)
from vllm.v1.outputs import ModelRunnerOutput
if ray is not None:
from ray.actor import ActorHandle
@ -53,7 +56,7 @@ class RayWorkerMetaData:
ip: str = ""
class RayDistributedExecutor(DistributedExecutorBase):
class RayDistributedExecutor(Executor):
"""Ray-based distributed executor"""
# These env vars are worker-specific, therefore are NOT copied
@ -69,37 +72,14 @@ class RayDistributedExecutor(DistributedExecutorBase):
ADDITIONAL_ENV_VARS = {"HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"}
uses_ray: bool = True
supports_pp: bool = True
def _init_executor(self) -> None:
self.forward_dag: ray.dag.CompiledDAG | None = None
if envs.VLLM_USE_V1:
# V1 uses SPMD worker and compiled DAG
os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"
# For TPU or XPU, avoid compiling NVIDIA's NCCL
if current_platform.is_tpu() or current_platform.is_xpu():
os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm"
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
# Currently, this requires USE_RAY_SPMD_WORKER=True.
self.use_ray_compiled_dag = envs.VLLM_USE_RAY_COMPILED_DAG
# If the env var is set, then we do not distinguish between the
# "driver worker" vs other workers. Also, the rank 0 worker will
# be executed in a remote Ray worker. Currently this requires
# USE_RAY_COMPILED_DAG=True.
self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
if self.use_ray_compiled_dag:
assert self.use_ray_spmd_worker, (
"VLLM_USE_RAY_COMPILED_DAG=1 requires VLLM_USE_RAY_SPMD_WORKER=1"
)
if self.use_ray_spmd_worker:
# TODO: Support SPMD worker for non-DAG Ray executor.
assert self.use_ray_compiled_dag, (
"VLLM_USE_RAY_SPMD_WORKER=1 requires VLLM_USE_RAY_COMPILED_DAG=1"
)
# For TPU or XPU, avoid compiling NVIDIA's NCCL
if current_platform.is_tpu() or current_platform.is_xpu():
os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm"
assert self.uses_ray
initialize_ray_cluster(self.parallel_config)
@ -113,13 +93,17 @@ class RayDistributedExecutor(DistributedExecutorBase):
# Create the parallel GPU workers.
self._init_workers_ray(placement_group)
self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
self.output_decoder = msgspec.msgpack.Decoder(list[SamplerOutput] | None)
self.use_v1 = envs.VLLM_USE_V1
# KV connector setup
self.has_connector = self.vllm_config.kv_transfer_config is not None
self.pp_locks: list[asyncio.Lock] | None = None
if not self.use_ray_compiled_dag:
self.driver_exec_method = make_async(self.driver_worker.execute_method)
@property
def max_concurrent_batches(self) -> int:
"""Ray distributed executor supports pipeline parallelism,
meaning that it allows PP size batches to be executed concurrently.
"""
if self.scheduler_config.async_scheduling:
return 2
return self.parallel_config.pipeline_parallel_size
def shutdown(self) -> None:
if logger:
@ -176,8 +160,6 @@ class RayDistributedExecutor(DistributedExecutorBase):
ray_remote_kwargs
)
logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)
# Create the workers.
bundle_indices: list[int]
if envs.VLLM_RAY_BUNDLE_INDICES:
@ -241,30 +223,8 @@ class RayDistributedExecutor(DistributedExecutorBase):
for each, ip in zip(worker_metadata, worker_ips):
each.ip = ip
if not self.use_ray_spmd_worker:
for i, each in enumerate(worker_metadata):
# find and remove the dummy worker from the list
worker = each.worker
worker_ip = each.ip
if self.driver_dummy_worker is None and worker_ip == driver_ip:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(
vllm_config=self.vllm_config, rpc_rank=0
)
worker_metadata.pop(i)
break
logger.debug("workers: %s", worker_metadata)
logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
raise ValueError(
"Ray does not allocate any GPUs on the driver node."
f"Driver IP: {driver_ip}, worker IPs: {worker_ips}."
"Consider adjusting the Ray placement group or running "
"the driver on a GPU node."
)
ip_counts: dict[str, int] = {}
for ip in worker_ips:
@ -281,7 +241,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
should be placed first.
"""
ip = item.ip
return (0 if ip == driver_ip else 1, ip_counts[ip], ip)
return 0 if ip == driver_ip else 1, ip_counts[ip], ip
# After sorting, the workers on the same node will be
# close to each other, and the workers on the driver
@ -289,14 +249,13 @@ class RayDistributedExecutor(DistributedExecutorBase):
sorted_worker_metadata = sorted(
worker_metadata, key=sort_by_driver_then_worker_ip
)
start_rank = 0 if self.use_ray_spmd_worker else 1
for i, item in enumerate(sorted_worker_metadata):
item.adjusted_rank = i + start_rank
item.adjusted_rank = i
self.workers = [item.worker for item in sorted_worker_metadata]
rerank_mapping = {
item.created_rank: item.adjusted_rank for item in sorted_worker_metadata
}
self._run_workers("adjust_rank", rerank_mapping)
self.collective_rpc("adjust_rank", args=(rerank_mapping,))
# Get the set of GPU IDs used on each node.
worker_node_and_gpu_ids = []
@ -365,8 +324,8 @@ class RayDistributedExecutor(DistributedExecutorBase):
self._env_vars_for_all_workers = all_args_to_update_environment_variables
self._run_workers(
"update_environment_variables", self._get_env_vars_to_be_updated()
self.collective_rpc(
"update_environment_variables", args=(self._get_env_vars_to_be_updated(),)
)
if len(node_gpus) == 1:
@ -396,138 +355,95 @@ class RayDistributedExecutor(DistributedExecutorBase):
or (rank % self.parallel_config.tensor_parallel_size == 0),
)
all_kwargs.append(kwargs)
self._run_workers("init_worker", all_kwargs)
self.collective_rpc("init_worker", args=(all_kwargs,))
self._run_workers("init_device")
self._run_workers(
"load_model",
max_concurrent_workers=self.parallel_config.max_parallel_loading_workers,
)
self.collective_rpc("init_device")
self.collective_rpc("load_model")
if self.use_ray_spmd_worker:
for pp_rank in range(self.parallel_config.pipeline_parallel_size):
self.pp_tp_workers.append([])
for tp_rank in range(self.parallel_config.tensor_parallel_size):
# PP=2, TP=4
# pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
rank = (
pp_rank * self.parallel_config.tensor_parallel_size
) + tp_rank
assert len(self.pp_tp_workers[pp_rank]) == tp_rank
assert pp_rank < len(self.pp_tp_workers)
self.pp_tp_workers[pp_rank].append(self.workers[rank])
for pp_rank in range(self.parallel_config.pipeline_parallel_size):
self.pp_tp_workers.append([])
for tp_rank in range(self.parallel_config.tensor_parallel_size):
# PP=2, TP=4
# pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
rank = (pp_rank * self.parallel_config.tensor_parallel_size) + tp_rank
assert len(self.pp_tp_workers[pp_rank]) == tp_rank
assert pp_rank < len(self.pp_tp_workers)
self.pp_tp_workers[pp_rank].append(self.workers[rank])
# This is the list of workers that are rank 0 of each TP group EXCEPT
# global rank 0. These are the workers that will broadcast to the
# rest of the workers.
self.tp_driver_workers: list[RayWorkerWrapper] = []
# This is the list of workers that are not drivers and not the first
# worker in a TP group. These are the workers that will be
# broadcasted to.
self.non_driver_workers: list[RayWorkerWrapper] = []
def reinitialize_distributed(
self, reconfig_request: ReconfigureDistributedRequest
) -> None:
self.collective_rpc("reinitialize_distributed", args=(reconfig_request,))
if (
reconfig_request.new_data_parallel_rank
== ReconfigureRankType.SHUTDOWN_CURRENT_RANK
):
self.shutdown()
# Enforce rank order for correct rank to return final output.
for index, worker in enumerate(self.workers):
# The driver worker is rank 0 and not in self.workers.
rank = index + 1
if rank % self.parallel_config.tensor_parallel_size == 0:
self.tp_driver_workers.append(worker)
else:
self.non_driver_workers.append(worker)
def _driver_execute_model(
self, execute_model_req: ExecuteModelRequest | None
) -> list[SamplerOutput] | None:
"""Run execute_model in the driver worker.
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
"""
assert not self.use_ray_spmd_worker, (
"driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1"
)
return self.driver_worker.execute_method("execute_model", execute_model_req)
def execute_model(
self, execute_model_req: ExecuteModelRequest
) -> list[SamplerOutput]:
if not self.use_ray_spmd_worker:
return super().execute_model(execute_model_req)
if self.forward_dag is None:
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
if self.use_v1:
serialized_data = execute_model_req
else:
serialized_data = self.input_encoder.encode(execute_model_req)
outputs = ray.get(self.forward_dag.execute(serialized_data))
output = outputs[0] if self.use_v1 else self.output_decoder.decode(outputs[0])
return output
    def _run_workers(
        self,
        method: str | Callable,
        *args,
        async_run_tensor_parallel_workers_only: bool = False,
        max_concurrent_workers: int | None = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers. Can be used in the following
        ways:

        Args:
        - async_run_tensor_parallel_workers_only: If True the method will be
          run only in the remote TP workers, not the driver worker.
          It will also be run asynchronously and return a list of futures
          rather than blocking on the results.
        - args/kwargs: All workers share the same args/kwargs
        """
        if self.use_ray_spmd_worker:
            assert not async_run_tensor_parallel_workers_only, (
                "async_run_tensor_parallel_workers_only is not supported for spmd mode."
            )
        if max_concurrent_workers:
            raise NotImplementedError("max_concurrent_workers is not supported yet.")

        sent_method = method if isinstance(method, str) else cloudpickle.dumps(method)
        del method

        # Start the ray workers first.
        ray_workers = self.workers
        if async_run_tensor_parallel_workers_only:
            ray_workers = self.non_driver_workers
        ray_worker_outputs = [
            worker.execute_method.remote(  # type: ignore[attr-defined]
                sent_method, *args, **kwargs
            )
            for worker in ray_workers
        ]

        if async_run_tensor_parallel_workers_only:
            # Just return futures
            return ray_worker_outputs

        driver_worker_output = []
        # In SPMD mode, the driver worker is the same as any other worker,
        # so we only explicitly execute on the driver worker if using a
        # non-SPMD worker class.
        if not self.use_ray_spmd_worker:
            # Start the driver worker after all the ray workers.
            driver_worker_output = [
                self.driver_worker.execute_method(sent_method, *args, **kwargs)
            ]

        # Get the results of the ray workers.
        if self.workers:
            ray_worker_outputs = ray.get(ray_worker_outputs)

        return driver_worker_output + ray_worker_outputs

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_remote_workers_only to complete."""
        ray.get(parallel_worker_tasks)

    def execute_model(  # type: ignore[override]
        self, scheduler_output: SchedulerOutput, non_block: bool = False
    ) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
        """Execute the model on the Ray workers.

        Args:
            scheduler_output: The scheduler output to execute.
            non_block: If True, the method will return a Future.

        Returns:
            The model runner output.
        """
        # Build the compiled DAG for the first time.
        if self.forward_dag is None:  # type: ignore
            self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)

        refs = self.forward_dag.execute(scheduler_output)  # type: ignore

        if not self.has_connector:
            # Get output only from a single worker (output_rank)
            # When PP is not used, we block here until the result is available.
            if not non_block:
                return refs[0].get()
            # When PP is used, we return a FutureWrapper immediately so that
            # the scheduler can yield to the next batch.
            return FutureWrapper(refs)

        # Get output from all workers when connector is present
        assert self.kv_output_aggregator is not None
        if not non_block:
            # Block and get results from all workers
            outputs = [ref.get() for ref in refs]
            return self.kv_output_aggregator.aggregate(outputs)

        # Return a future that will aggregate outputs from all workers
        return FutureWrapper(refs, self.kv_output_aggregator)

    def collective_rpc(
        self,
        method: str | Callable,
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
        non_block: bool = False,
    ) -> list[Any]:
        """Runs the given method on all workers."""
        sent_method = method if isinstance(method, str) else cloudpickle.dumps(method)
        del method
        if kwargs is None:
            kwargs = {}
        ray_worker_outputs = [
            worker.execute_method.remote(  # type: ignore[attr-defined]
                sent_method, *args, **kwargs
            )
            for worker in self.workers
        ]
        if non_block:
            return [FutureWrapper((output,)) for output in ray_worker_outputs]
        return ray.get(ray_worker_outputs, timeout=timeout)
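For illustration, a minimal sketch (not taken from this diff) of how a caller might consume the non_block path of execute_model() above; `executor` and `scheduler_output` are assumed placeholders for an already-constructed Ray executor and a prepared SchedulerOutput.

from concurrent.futures import Future

def run_one_step(executor, scheduler_output):
    # With non_block=True the executor hands back a Future-like wrapper
    # (FutureWrapper), so the scheduler can prepare the next batch before
    # blocking on .result().
    out = executor.execute_model(scheduler_output, non_block=True)
    return out.result() if isinstance(out, Future) else out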
def _check_ray_cgraph_installation(self):
import importlib.metadata
@ -595,13 +511,6 @@ class RayDistributedExecutor(DistributedExecutorBase):
with InputNode() as input_data:
# Example DAG: PP=2, TP=4
#
# For V0:
# ExecuteModelRequest -> 0 -> (ExecuteModelReq, IntermediateTensors) -> 4 -> SamplerOutput # noqa: E501
# ExecuteModelRequest -> 1 -> (ExecuteModelReq, IntermediateTensors) -> 5 -> SamplerOutput # noqa: E501
# ExecuteModelRequest -> 2 -> (ExecuteModelReq, IntermediateTensors) -> 6 -> SamplerOutput # noqa: E501
# ExecuteModelRequest -> 3 -> (ExecuteModelReq, IntermediateTensors) -> 7 -> SamplerOutput # noqa: E501
#
# For V1:
# SchedulerOutput -> 0 -> (SchedulerOutput, IntermediateTensors) -> 4 -> ModelRunnerOutput # noqa: E501
# SchedulerOutput -> 1 -> (SchedulerOutput, IntermediateTensors) -> 5 -> ModelRunnerOutput # noqa: E501
# SchedulerOutput -> 2 -> (SchedulerOutput, IntermediateTensors) -> 6 -> ModelRunnerOutput # noqa: E501
@ -613,20 +522,10 @@ class RayDistributedExecutor(DistributedExecutorBase):
for pp_rank, tp_group in enumerate(self.pp_tp_workers):
# Each PP worker takes in the output of the previous PP worker,
# and the TP group executes in SPMD fashion.
if self.use_v1:
outputs = [
worker.execute_model_ray.bind( # type: ignore[attr-defined]
outputs[i]
)
for i, worker in enumerate(tp_group)
]
else:
outputs = [
worker.execute_model_spmd.bind( # type: ignore[attr-defined]
outputs[i]
)
for i, worker in enumerate(tp_group)
]
outputs = [
worker.execute_model_ray.bind(outputs[i]) # type: ignore[attr-defined]
for i, worker in enumerate(tp_group)
]
last_pp_rank = len(self.pp_tp_workers) - 1
if (
@ -674,82 +573,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
def __del__(self):
self.shutdown()
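For a rough picture of the topology that `_compiled_ray_dag` builds (described in the comments above), here is a hypothetical standalone Ray DAG sketch with PP=2, TP=2; the Worker actor and its pass-through execute_model_ray are placeholders for illustration, not vLLM's real worker wiring.

import ray
from ray.dag import InputNode, MultiOutputNode

@ray.remote
class Worker:
    def execute_model_ray(self, inp):
        # The real first-stage workers return (SchedulerOutput,
        # IntermediateTensors) and the last stage returns ModelRunnerOutput;
        # here we simply forward the input to show the data flow.
        return inp

# PP=2, TP=2 grid of workers: workers[pp_rank][tp_rank]
workers = [[Worker.remote() for _ in range(2)] for _ in range(2)]

with InputNode() as input_data:
    outputs = [input_data for _ in range(2)]
    for tp_group in workers:
        outputs = [
            worker.execute_model_ray.bind(outputs[i])
            for i, worker in enumerate(tp_group)
        ]
    dag = MultiOutputNode(outputs)

compiled_dag = dag.experimental_compile()
refs = compiled_dag.execute("scheduler-output-placeholder")
print([ref.get() for ref in refs])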
async def execute_model_async(
self, execute_model_req: ExecuteModelRequest
) -> list[SamplerOutput]:
if not self.use_ray_spmd_worker:
return await super().execute_model_async(execute_model_req)
if self.forward_dag is None:
self.forward_dag = self._compiled_ray_dag(enable_asyncio=True)
serialized_data = self.input_encoder.encode(execute_model_req)
dag_future = await self.forward_dag.execute_async(serialized_data)
output = await dag_future[0]
return self.output_decoder.decode(output)
async def _driver_execute_model_async(
self, execute_model_req: ExecuteModelRequest | None = None
) -> list[SamplerOutput]:
assert not self.use_ray_spmd_worker, (
"driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1"
)
if not self.tp_driver_workers:
return await self.driver_exec_method("execute_model", execute_model_req)
if self.pp_locks is None:
# This locks each pipeline parallel stage so multiple virtual
# engines can't execute on the same stage at the same time
# We create the locks here to avoid creating them in the constructor
# which uses a different asyncio loop.
self.pp_locks = [
asyncio.Lock()
for _ in range(self.parallel_config.pipeline_parallel_size)
]
tasks = [
asyncio.create_task(
_run_task_with_lock(
self.driver_exec_method,
self.pp_locks[0],
"execute_model",
execute_model_req,
)
)
]
for pp_rank, driver_worker in enumerate(self.tp_driver_workers, start=1):
tasks.append(
asyncio.create_task(
_run_task_with_lock(
driver_worker.execute_method.remote, # type: ignore[attr-defined]
self.pp_locks[pp_rank],
"execute_model",
execute_model_req,
)
)
)
results = await asyncio.gather(*tasks)
# Only the last PP stage has the final results.
return results[-1]
async def _start_worker_execution_loop(self):
assert not self.use_ray_spmd_worker, (
"worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1"
)
coros = [
worker.execute_method.remote("start_worker_execution_loop") # type: ignore[attr-defined]
for worker in self.non_driver_workers
]
return await asyncio.gather(*coros)
def check_health(self) -> None:
# Assume that the Ray workers are healthy.
# TODO: check the health of the Ray workers
return
async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, **kwargs):
"""Utility function to run async task in a lock"""
async with lock:
return await task(*args, **kwargs)
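A small self-contained sketch, not from the diff, of the per-stage locking pattern that `_run_task_with_lock` supported above: one asyncio.Lock per pipeline stage so that two virtual engines never execute on the same stage concurrently. The fake_execute coroutine is a stand-in for the driver's execute_model call.

import asyncio

async def _run_task_with_lock(task, lock, *args, **kwargs):
    """Utility function to run async task in a lock (as defined above)."""
    async with lock:
        return await task(*args, **kwargs)

async def fake_execute(stage: int, request_id: str) -> str:
    # Stand-in for driver_worker.execute_method("execute_model", ...).
    await asyncio.sleep(0.01)
    return f"stage {stage} handled {request_id}"

async def main() -> None:
    pp_locks = [asyncio.Lock() for _ in range(2)]  # one lock per PP stage
    results = await asyncio.gather(
        _run_task_with_lock(fake_execute, pp_locks[0], 0, "req-A"),
        _run_task_with_lock(fake_execute, pp_locks[1], 1, "req-A"),
    )
    # Only the last PP stage holds the final output in the real executor.
    print(results[-1])

asyncio.run(main())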

View File

@ -4,17 +4,16 @@
import os
import time
from collections import defaultdict
from concurrent.futures import Future
from typing import TYPE_CHECKING, Union
import msgspec
import vllm.platforms
from vllm.config import ParallelConfig
from vllm.distributed import get_pp_group
from vllm.executor.msgspec_utils import decode_hook, encode_hook
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.sequence import IntermediateTensors
from vllm.utils.network_utils import get_ip
from vllm.v1.outputs import AsyncModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerWrapperBase
@ -51,11 +50,6 @@ try:
# that thread.
self.compiled_dag_cuda_device_set = False
self.input_decoder = msgspec.msgpack.Decoder(
ExecuteModelRequest, dec_hook=decode_hook
)
self.output_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
def get_node_ip(self) -> str:
return get_ip()
@ -70,47 +64,6 @@ try:
gpu_ids = ray.get_runtime_context().get_accelerator_ids()[device_key]
return node_id, gpu_ids
def execute_model_spmd(
self,
req_or_tuple: bytes | tuple[bytes, IntermediateTensors | None],
) -> bytes:
"""Execute model in SPMD fashion: used only when SPMD worker and
compiled DAG are both enabled.
Args:
            req_or_tuple: A request, or a tuple containing the request and
                intermediate tensors. Intermediate tensors are None unless
                this worker is at a pipeline stage > 0, in which case they
                are provided by the previous stage. The request is
                serialized by msgspec.
"""
if isinstance(req_or_tuple, bytes):
serialized_req, intermediate_tensors = req_or_tuple, None
else:
serialized_req, intermediate_tensors = req_or_tuple
execute_model_req = self.input_decoder.decode(serialized_req)
assert self.worker is not None, "Worker is not initialized"
# TODO(swang): This is needed right now because Ray Compiled Graph
# executes on a background thread, so we need to reset torch's
# current device.
if not self.compiled_dag_cuda_device_set:
assert self.worker.device is not None
current_platform.set_device(self.worker.device)
self.compiled_dag_cuda_device_set = True
output = self.worker._execute_model_spmd( # type: ignore[attr-defined]
execute_model_req, intermediate_tensors
)
# Pipeline model request and output to the next pipeline stage.
if isinstance(output, IntermediateTensors):
output = serialized_req, output
else:
output = self.output_encoder.encode(output)
return output
def setup_device_if_necessary(self):
# TODO(swang): This is needed right now because Ray CG executes
# on a background thread, so we need to reset torch's current
@ -174,6 +127,31 @@ except ImportError as e:
RayWorkerWrapper = None # type: ignore
class FutureWrapper(Future):
"""A wrapper around Ray output reference to meet the interface
of .execute_model(): The top level (core busy loop) expects .result() api
to block and return a single output.
If aggregator is provided, the outputs from all workers are aggregated upon
the result() call. If not only the first worker's output is returned.
"""
def __init__(self, refs, aggregator: KVOutputAggregator | None = None):
super().__init__()
self.refs = refs
self.aggregator = aggregator
def result(self, timeout=None):
if timeout is not None:
raise NotImplementedError("timeout is not supported")
if self.aggregator is None:
return self.refs[0].get()
outputs = [ref.get() for ref in self.refs]
return self.aggregator.aggregate(outputs, output_rank=0)
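A small, hypothetical sketch of the contract result() implements above; FakeRef stands in for a Ray compiled-graph output reference and is not part of the diff.

class FakeRef:
    """Stand-in for a Ray compiled-graph ref exposing a blocking .get()."""

    def __init__(self, value):
        self._value = value

    def get(self):
        return self._value

# Without an aggregator, result() blocks on the first ref only and returns
# that single worker's output.
future = FutureWrapper([FakeRef("rank0-output"), FakeRef("rank1-output")])
assert future.result() == "rank0-output"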
def ray_is_available() -> bool:
"""Returns True if Ray is available."""
return ray is not None

View File

@ -11,20 +11,18 @@ import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.utils import run_method
from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import AsyncModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
class UniProcExecutor(ExecutorBase):
uses_ray: bool = False
class UniProcExecutor(Executor):
def _init_executor(self) -> None:
"""Initialize the worker and load the model."""
self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, rpc_rank=0)
@ -44,9 +42,9 @@ class UniProcExecutor(ExecutorBase):
max_workers=1, thread_name_prefix="WorkerAsyncOutput"
)
self.collective_rpc("init_worker", args=([kwargs],))
self.collective_rpc("init_device")
self.collective_rpc("load_model")
self.driver_worker.init_worker(all_kwargs=[kwargs])
self.driver_worker.init_device()
self.driver_worker.load_model()
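For orientation, a minimal sketch of how this single-process executor might be constructed and torn down; the `vllm_config` object is assumed to be built elsewhere, and the constructor call mirrors the surrounding code rather than a documented public API.

from vllm.v1.executor import UniProcExecutor

def build_uniproc_executor(vllm_config) -> UniProcExecutor:
    # _init_executor() above creates the in-process WorkerWrapperBase and now
    # calls init_worker / init_device / load_model on it directly instead of
    # routing them through collective_rpc.
    return UniProcExecutor(vllm_config)

# ...after use, release the in-process worker:
# executor.shutdown()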
def _distributed_args(self) -> tuple[str, int, int]:
"""Return (distributed_init_method, rank, local_rank)."""
@ -101,16 +99,12 @@ class UniProcExecutor(ExecutorBase):
== ReconfigureRankType.SHUTDOWN_CURRENT_RANK
):
self.shutdown()
return
def shutdown(self) -> None:
if worker := self.driver_worker:
worker.shutdown()
UniProcExecutorAsync = UniProcExecutor
class ExecutorWithExternalLauncher(UniProcExecutor):
"""An executor that uses external launchers to launch engines,
specially designed for torchrun-compatible launchers, for
@ -128,8 +122,6 @@ class ExecutorWithExternalLauncher(UniProcExecutor):
and they don't need to synchronize the states with each other.
"""
uses_ray: bool = False
def _init_executor(self) -> None:
"""Initialize the worker and load the model."""
if envs.VLLM_USE_V1:
@ -152,22 +144,12 @@ class ExecutorWithExternalLauncher(UniProcExecutor):
local_rank = int(os.environ["LOCAL_RANK"])
return distributed_init_method, rank, local_rank
    def determine_num_available_blocks(self) -> tuple[int, int]:
        """
        Determine the number of available KV blocks.
        Add an additional all_reduce to get the min across all ranks.
        Note that even if we have the same `gpu_memory_utilization` and
        `swap_space`, the available memory in every rank might still
        differ because NCCL can take different amounts of memory in
        different ranks. Therefore, it is necessary to test if all ranks
        agree on the same KV cache configuration.
        """
        a, b = super().determine_num_available_blocks()
        from vllm.distributed.parallel_state import get_world_group

        cpu_group = get_world_group().cpu_group
        a_tensor = torch.tensor([a], device="cpu", dtype=torch.int64)
        b_tensor = torch.tensor([b], device="cpu", dtype=torch.int64)
        dist.all_reduce(a_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
        dist.all_reduce(b_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
        return a_tensor.item(), b_tensor.item()

    def determine_available_memory(self) -> list[int]:  # in bytes
        # we need to get the min across all ranks.
        memory = super().determine_available_memory()
        from vllm.distributed.parallel_state import get_world_group

        cpu_group = get_world_group().cpu_group
        memory_tensor = torch.tensor([memory], device="cpu", dtype=torch.int64)
        dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
        return [memory_tensor.item()]
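To make the min-reduction above concrete, here is a standalone sketch assuming a gloo CPU process group launched with torchrun; it shows how ranks that profile different free-memory values still agree on the same minimum, and hence on the same KV-cache size.

# Run with (assumed launcher): torchrun --nproc-per-node=2 min_memory_demo.py
import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo")
rank = dist.get_rank()

# Pretend each rank profiled a different amount of free memory (bytes).
local_free_bytes = 8_000_000_000 - rank * 500_000_000

tensor = torch.tensor([local_free_bytes], device="cpu", dtype=torch.int64)
dist.all_reduce(tensor, op=dist.ReduceOp.MIN)

# Every rank now sees the same, smallest value.
print(f"rank {rank}: min free bytes across ranks = {tensor.item()}")
dist.destroy_process_group()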

View File

@ -128,28 +128,6 @@ class WorkerBase:
def execute_model(self, scheduler_output: SchedulerOutput) -> ModelRunnerOutput:
raise NotImplementedError
def start_worker_execution_loop(self) -> None:
"""Execute model loop in parallel worker.
You can stop the loop by executing a driver worker with an empty output.
See `stop_remote_worker_execution_loop` for more details.
"""
raise NotImplementedError("Dead V0 code")
def determine_num_available_blocks(self) -> tuple[int, int]:
"""Determine the number of available blocks for the GPU KV cache and
swappable CPU KV cache.
The implementation may run profiling or other heuristics to determine
the size of caches.
Returns a tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
are blocks that are "active" on the device and can be appended to.
num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
appended to.
"""
raise NotImplementedError
def get_cache_block_size_bytes(self) -> int:
"""Return the size of a single cache block, in bytes. Used in
speculative decoding.