Compare commits

..

8 Commits

Author SHA1 Message Date
85013bf094 Prune Ray v1 non-SPMD code paths 2025-09-18 20:42:46 -07:00
07665f8679 Fix Ray executor futures to resolve asynchronously 2025-09-18 17:22:42 -07:00
9fac6aa30b [BugFix] Fix DeepGEMM warmup, no m.weight_scale_inv (#25206)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
2025-09-18 14:26:28 -07:00
a53ad626d6 [KV offload][1b/N] rename offloading to kv_offload (#25191)
Signed-off-by: Or Ozeri <oro@il.ibm.com>
2025-09-18 20:53:52 +00:00
1c3dad22ff [V0 Deprecation] Remove unused async_timeout.py (#25190)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 20:35:21 +00:00
d2a30a2d93 [Bug] Fix torch Compilation Cache Hit Error (#25093)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
2025-09-18 12:38:37 -07:00
75fb112d80 [Bug] Fix returned_lse not Defined issue (#25106)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
2025-09-18 19:32:24 +00:00
38db529f66 [feat]: Create interface for model-specific M-RoPE (#24194)
Signed-off-by: AzizCode92 <azizbenothman76@gmail.com>
Signed-off-by: Aziz <azizbenothman76@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
2025-09-18 19:18:56 +00:00
22 changed files with 860 additions and 1175 deletions

View File

@@ -280,7 +280,7 @@ steps:
# split the test to avoid interference
- pytest -v -s v1/core
- pytest -v -s v1/executor
- pytest -v -s v1/offloading
- pytest -v -s v1/kv_offload
- pytest -v -s v1/sample
- pytest -v -s v1/logits_processors
- pytest -v -s v1/worker

View File

@@ -1,15 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import threading
from typing import Optional
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.v1.metrics.loggers import AggregatedStatLogger, LoggingStatLogger
"""
To run this example, run the following commands simultaneously with
@@ -25,67 +22,37 @@ send a request to the instance with DP rank 1.
"""
def _do_background_logging(engine, interval, stop_event):
try:
while not stop_event.is_set():
asyncio.run(engine.do_log_stats())
stop_event.wait(interval)
except Exception as e:
print(f"vLLM background logging shutdown: {e}")
pass
async def main():
engine_args = AsyncEngineArgs(
model="ibm-research/PowerMoE-3b",
data_parallel_size=2,
tensor_parallel_size=1,
dtype="auto",
max_model_len=2048,
data_parallel_address="127.0.0.1",
data_parallel_rpc_port=62300,
data_parallel_size_local=1,
enforce_eager=True,
enable_log_requests=True,
disable_custom_all_reduce=True,
)
def per_engine_logger_factory(config: VllmConfig, rank: int) -> LoggingStatLogger:
return LoggingStatLogger(config, rank)
engine_client = AsyncLLMEngine.from_engine_args(engine_args)
engine_client = AsyncLLMEngine.from_engine_args(
engine_args,
# Example: Using both regular loggers and aggregated logger
stat_loggers=[per_engine_logger_factory, AggregatedStatLogger],
)
stop_logging_event = threading.Event()
logging_thread = threading.Thread(
target=_do_background_logging,
args=(engine_client, 5, stop_logging_event),
daemon=True,
)
logging_thread.start()
sampling_params = SamplingParams(
temperature=0.7,
top_p=0.9,
max_tokens=100,
)
num_prompts = 10
for i in range(num_prompts):
prompt = "Who won the 2004 World Series?"
final_output: Optional[RequestOutput] = None
async for output in engine_client.generate(
prompt=prompt,
sampling_params=sampling_params,
request_id=f"abcdef-{i}",
data_parallel_rank=1,
):
final_output = output
if final_output:
print(final_output.outputs[0].text)
stop_logging_event.set()
logging_thread.join()
prompt = "Who won the 2004 World Series?"
final_output: Optional[RequestOutput] = None
async for output in engine_client.generate(
prompt=prompt,
sampling_params=sampling_params,
request_id="abcdef",
data_parallel_rank=1,
):
final_output = output
if final_output:
print(final_output.outputs[0].text)
if __name__ == "__main__":

View File

@@ -18,7 +18,7 @@ from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind
from vllm.utils import set_default_torch_num_threads
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.metrics.loggers import AggregatedStatLogger, LoggingStatLogger
from vllm.v1.metrics.loggers import LoggingStatLogger
if not current_platform.is_cuda():
pytest.skip(reason="V1 currently only supported on CUDA.",
@@ -389,15 +389,6 @@ class MockLoggingStatLogger(LoggingStatLogger):
self.log = MagicMock()
class MockAggregatedStatLogger(AggregatedStatLogger):
def __init__(self,
vllm_config: VllmConfig,
engine_indexes: Optional[list[int]] = None):
super().__init__(vllm_config, engine_indexes)
self.log = MagicMock()
@pytest.mark.asyncio
async def test_customize_loggers(monkeypatch):
"""Test that we can customize the loggers.
@@ -424,35 +415,6 @@ async def test_customize_loggers(monkeypatch):
stat_loggers[0][0].log.assert_called_once()
@pytest.mark.asyncio
async def test_customize_aggregated_loggers(monkeypatch):
"""Test that we can customize the aggregated loggers.
If a customized logger is provided at the init, it should
be added to the default loggers.
"""
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")
with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(
TEXT_ENGINE_ARGS,
stat_loggers=[MockLoggingStatLogger, MockAggregatedStatLogger],
)
after.callback(engine.shutdown)
await engine.do_log_stats()
stat_loggers = engine.logger_manager.per_engine_logger_dict
assert len(stat_loggers) == 1
assert len(
stat_loggers[0]) == 2 # LoggingStatLogger + MockLoggingStatLogger
aggregated_loggers = engine.logger_manager.aggregated_loggers
assert len(aggregated_loggers) == 1
aggregated_loggers[0].log.assert_called_once()
stat_loggers[0][0].log.assert_called_once()
@pytest.mark.asyncio(scope="module")
async def test_dp_rank_argument(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m, ExitStack() as after:

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.v1.offloading.abstract import LoadStoreSpec
from vllm.v1.offloading.worker.worker import (OffloadingHandler,
from vllm.v1.kv_offload.abstract import LoadStoreSpec
from vllm.v1.kv_offload.worker.worker import (OffloadingHandler,
OffloadingWorker, TransferResult,
TransferSpec)

View File

@@ -50,8 +50,8 @@ ALLOWED_FILES = set([
# cloudpickle
'vllm/worker/worker_base.py',
'vllm/executor/mp_distributed_executor.py',
'vllm/executor/ray_distributed_executor.py',
'vllm/entrypoints/llm.py',
'vllm/v1/executor/ray_distributed_executor.py',
'tests/utils.py',
# pickle and cloudpickle
'vllm/utils/__init__.py',

View File

@@ -563,18 +563,6 @@ class CompilationConfig:
self.cudagraph_mode = CUDAGraphMode.FULL
self.splitting_ops = []
if envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput":
# exclude MoE dispatch/combine from capture by ensuring
# piecewise splitting includes them, so communication remains
# outside CUDA graphs while compute can still be graphed.
moe_ops = [
"vllm.moe_forward",
"vllm.moe_forward_shared",
]
for op in moe_ops:
if op not in self.splitting_ops:
self.splitting_ops.append(op)
def splitting_ops_contain_attention(self) -> bool:
return self.splitting_ops is not None and all(
op in self.splitting_ops for op in self._attention_ops)

View File

@@ -1,173 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Workaround for https://github.com/python/cpython/issues/86296
#
# From https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py
# Licensed under the Apache License (Apache-2.0)
import asyncio
import enum
import sys
from types import TracebackType
from typing import Any, Optional, Type
if sys.version_info[:2] >= (3, 11):
from asyncio import timeout as asyncio_timeout
else:
class _State(enum.Enum):
INIT = "INIT"
ENTER = "ENTER"
TIMEOUT = "TIMEOUT"
EXIT = "EXIT"
class Timeout:
# Internal class, please don't instantiate it directly
# Use timeout() and timeout_at() public factories instead.
#
# Implementation note: `async with timeout()` is preferred
# over `with timeout()`.
# While technically the Timeout class implementation
# doesn't need to be async at all,
# the `async with` statement explicitly points that
# the context manager should be used from async function context.
#
# This design allows to avoid many silly misusages.
#
# TimeoutError is raised immediately when scheduled
# if the deadline is passed.
# The purpose is to time out as soon as possible
# without waiting for the next await expression.
__slots__ = ("_deadline", "_loop", "_state", "_timeout_handler")
def __init__(self, deadline: Optional[float],
loop: asyncio.AbstractEventLoop) -> None:
self._loop = loop
self._state = _State.INIT
self._timeout_handler = None # type: Optional[asyncio.Handle]
if deadline is None:
self._deadline = None # type: Optional[float]
else:
self.update(deadline)
async def __aenter__(self) -> "Timeout":
self._do_enter()
return self
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> Optional[bool]:
self._do_exit(exc_type)
return None
@property
def expired(self) -> bool:
"""Is timeout expired during execution?"""
return self._state == _State.TIMEOUT
@property
def deadline(self) -> Optional[float]:
return self._deadline
def reject(self) -> None:
"""Reject scheduled timeout if any."""
# cancel is maybe better name but
# task.cancel() raises CancelledError in asyncio world.
if self._state not in (_State.INIT, _State.ENTER):
raise RuntimeError(f"invalid state {self._state.value}")
self._reject()
def _reject(self) -> None:
if self._timeout_handler is not None:
self._timeout_handler.cancel()
self._timeout_handler = None
def shift(self, delay: float) -> None:
"""Advance timeout on delay seconds.
The delay can be negative.
Raise RuntimeError if shift is called when deadline is not scheduled
"""
deadline = self._deadline
if deadline is None:
raise RuntimeError(
"cannot shift timeout if deadline is not scheduled")
self.update(deadline + delay)
def update(self, deadline: float) -> None:
"""Set deadline to absolute value.
deadline argument points on the time in the same clock system
as loop.time().
If new deadline is in the past the timeout is raised immediately.
Please note: it is not POSIX time but a time with
undefined starting base, e.g. the time of the system power on.
"""
if self._state == _State.EXIT:
raise RuntimeError(
"cannot reschedule after exit from context manager")
if self._state == _State.TIMEOUT:
raise RuntimeError("cannot reschedule expired timeout")
if self._timeout_handler is not None:
self._timeout_handler.cancel()
self._deadline = deadline
if self._state != _State.INIT:
self._reschedule()
def _reschedule(self) -> None:
assert self._state == _State.ENTER
deadline = self._deadline
if deadline is None:
return
now = self._loop.time()
if self._timeout_handler is not None:
self._timeout_handler.cancel()
task = asyncio.current_task()
if deadline <= now:
self._timeout_handler = self._loop.call_soon(
self._on_timeout, task)
else:
self._timeout_handler = self._loop.call_at(
deadline, self._on_timeout, task)
def _do_enter(self) -> None:
if self._state != _State.INIT:
raise RuntimeError(f"invalid state {self._state.value}")
self._state = _State.ENTER
self._reschedule()
def _do_exit(self, exc_type: Optional[Type[BaseException]]) -> None:
if exc_type is asyncio.CancelledError and \
self._state == _State.TIMEOUT:
self._timeout_handler = None
raise asyncio.TimeoutError
# timeout has not expired
self._state = _State.EXIT
self._reject()
return None
def _on_timeout(self, task: "Optional[asyncio.Task[Any]]") -> None:
if task:
task.cancel()
self._state = _State.TIMEOUT
# drop the reference early
self._timeout_handler = None
def asyncio_timeout(delay: Optional[float]) -> Timeout:
"""timeout context manager.
Useful in cases when you want to apply timeout logic around block
of code or in cases when asyncio.wait_for is not suitable. For example:
>>> async with timeout(0.001):
... async with aiohttp.get('https://github.com') as r:
... await r.text()
delay - value in seconds or None to disable timeout logic
"""
loop = asyncio.get_running_loop()
deadline = loop.time() + delay if delay is not None else None
return Timeout(deadline, loop)
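Note on the deleted file above: async_timeout.py was a backport of the asyncio.timeout context manager for Python versions before 3.11, and the commit removes it as unused. For readers who want the same behavior, the standard library provides it on supported interpreters. A minimal, hedged sketch of the stdlib equivalent (illustrative only, not part of this diff):

import asyncio

async def do_work_with_deadline() -> None:
    # asyncio.timeout() (Python 3.11+) is the same async context manager
    # behavior that the removed backport implemented.
    try:
        async with asyncio.timeout(0.5):
            await asyncio.sleep(1.0)  # stand-in for any awaitable work
    except TimeoutError:
        print("operation timed out")

if __name__ == "__main__":
    asyncio.run(do_work_with_deadline())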

View File

@@ -433,9 +433,9 @@ class LLMEngine:
f"ExecutorBase. Got {distributed_executor_backend}.")
executor_class = distributed_executor_backend
elif distributed_executor_backend == "ray":
from vllm.executor.ray_distributed_executor import (
RayDistributedExecutor)
executor_class = RayDistributedExecutor
raise RuntimeError(
"The Ray distributed executor is only available in the v1 "
"engine. Enable it by setting 'VLLM_USE_V1=1'.")
elif distributed_executor_backend == "mp":
from vllm.executor.mp_distributed_executor import (
MultiprocessingDistributedExecutor)

View File

@@ -1,699 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import os
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
import cloudpickle
import msgspec
import vllm.envs as envs
from vllm.executor.executor_base import (
DistributedExecutorBase) # yapf: disable
from vllm.executor.msgspec_utils import encode_hook
from vllm.executor.ray_utils import (RayWorkerWrapper, initialize_ray_cluster,
ray)
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.ray.ray_env import get_env_vars_to_copy
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
get_ip, get_open_port, make_async)
if ray is not None:
from ray.actor import ActorHandle
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
else:
ActorHandle = None
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
logger = init_logger(__name__)
@dataclass
class RayWorkerMetaData:
"""
Metadata for a Ray worker.
The order of ray worker creation can be random,
and we need to reset the rank after creating all workers.
"""
worker: ActorHandle
created_rank: int
adjusted_rank: int = -1
ip: str = ""
class RayDistributedExecutor(DistributedExecutorBase):
"""Ray-based distributed executor"""
# These env vars are worker-specific, therefore are NOT copied
# from the driver to the workers
WORKER_SPECIFIC_ENV_VARS = {
"VLLM_HOST_IP", "VLLM_HOST_PORT", "LOCAL_RANK", "CUDA_VISIBLE_DEVICES"
}
# These non-vLLM env vars are copied from the driver to workers
ADDITIONAL_ENV_VARS = {"HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"}
uses_ray: bool = True
def _init_executor(self) -> None:
self.forward_dag: Optional[ray.dag.CompiledDAG] = None
if envs.VLLM_USE_V1:
# V1 uses SPMD worker and compiled DAG
os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"
# For TPU or XPU, avoid compiling NVIDIA's NCCL
if current_platform.is_tpu() or current_platform.is_xpu():
os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm"
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
# Currently, this requires USE_RAY_SPMD_WORKER=True.
self.use_ray_compiled_dag = envs.VLLM_USE_RAY_COMPILED_DAG
# If the env var is set, then we do not distinguish between the
# "driver worker" vs other workers. Also, the rank 0 worker will
# be executed in a remote Ray worker. Currently this requires
# USE_RAY_COMPILED_DAG=True.
self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
if self.use_ray_compiled_dag:
assert self.use_ray_spmd_worker, (
"VLLM_USE_RAY_COMPILED_DAG=1 requires "
"VLLM_USE_RAY_SPMD_WORKER=1")
if self.use_ray_spmd_worker:
# TODO: Support SPMD worker for non-DAG Ray executor.
assert self.use_ray_compiled_dag, (
"VLLM_USE_RAY_SPMD_WORKER=1 requires "
"VLLM_USE_RAY_COMPILED_DAG=1")
assert self.uses_ray
initialize_ray_cluster(self.parallel_config)
placement_group = self.parallel_config.placement_group
# Disable Ray usage stats collection.
ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
if ray_usage != "1":
os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
# Create the parallel GPU workers.
self._init_workers_ray(placement_group)
self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
self.output_decoder = msgspec.msgpack.Decoder(
Optional[List[SamplerOutput]])
self.use_v1 = envs.VLLM_USE_V1
self.pp_locks: Optional[List[asyncio.Lock]] = None
if not self.use_ray_compiled_dag:
self.driver_exec_method = make_async(
self.driver_worker.execute_method)
def shutdown(self) -> None:
if logger:
# Somehow logger can be None here.
logger.info(
"Shutting down Ray distributed executor. If you see error log "
"from logging.cc regarding SIGTERM received, please ignore "
"because this is the expected termination process in Ray.")
if hasattr(self, "forward_dag") and self.forward_dag is not None:
self.forward_dag.teardown()
import ray
for worker in self.workers:
ray.kill(worker)
self.forward_dag = None
def _configure_ray_workers_use_nsight(self,
ray_remote_kwargs) -> Dict[str, Any]:
# If nsight profiling is enabled, we need to set the profiling
# configuration for the ray workers as runtime env.
runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
runtime_env.update({
"nsight": {
"t": "cuda,cudnn,cublas",
"o": "'worker_process_%p'",
"cuda-graph-trace": "node",
}
})
return ray_remote_kwargs
# child class could overwrite this to return actual env vars.
def _get_env_vars_to_be_updated(self):
return self._env_vars_for_all_workers
def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs):
num_gpus = envs.VLLM_RAY_PER_WORKER_GPUS
# The driver dummy worker does not actually use any resources.
# It holds the resource for the driver worker.
self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
# The remaining workers are the actual ray actors.
self.workers: List[RayWorkerWrapper] = []
# Used in ray compiled DAG: indexed first by PP rank,
# and then TP rank. In other words, the inner list is
# the TP group of workers for a PP rank.
self.pp_tp_workers: List[List[RayWorkerWrapper]] = []
if self.parallel_config.ray_workers_use_nsight:
ray_remote_kwargs = self._configure_ray_workers_use_nsight(
ray_remote_kwargs)
logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)
# Create the workers.
bundle_indices: List[int]
if envs.VLLM_RAY_BUNDLE_INDICES:
# Use the bundle indices specified by the user.
bundle_indices = list(
map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(",")))
assert len(bundle_indices) == self.parallel_config.world_size, \
("VLLM_RAY_BUNDLE_INDICES must have the same size"
f" as the world size, but got {bundle_indices=} "
f"and {self.parallel_config.world_size=}")
assert len(set(bundle_indices)) == len(bundle_indices), \
("VLLM_RAY_BUNDLE_INDICES cannot have duplicate values,"
f" but got {bundle_indices=}")
else:
# use the first N bundles that have GPU resources.
bundle_indices = []
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if bundle.get(current_platform.ray_device_key, 0):
bundle_indices.append(bundle_id)
bundle_indices = bundle_indices[:self.parallel_config.world_size]
worker_metadata: List[RayWorkerMetaData] = []
driver_ip = get_ip()
for rank, bundle_id in enumerate(bundle_indices):
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=bundle_id,
)
if current_platform.ray_device_key == "GPU":
# NV+AMD GPUs, and Intel XPUs
worker = ray.remote(
num_cpus=0,
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
rpc_rank=rank)
else:
worker = ray.remote(
num_cpus=0,
num_gpus=0,
resources={current_platform.ray_device_key: num_gpus},
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
rpc_rank=rank)
worker_metadata.append(
RayWorkerMetaData(worker=worker, created_rank=rank))
worker_ips = ray.get([
each.worker.get_node_ip.remote() # type: ignore[attr-defined]
for each in worker_metadata
])
for each, ip in zip(worker_metadata, worker_ips):
each.ip = ip
if not self.use_ray_spmd_worker:
for i, each in enumerate(worker_metadata):
# find and remove the dummy worker from the list
worker = each.worker
worker_ip = each.ip
if self.driver_dummy_worker is None and worker_ip == driver_ip:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(
vllm_config=self.vllm_config, rpc_rank=0)
worker_metadata.pop(i)
break
logger.debug("workers: %s", worker_metadata)
logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
raise ValueError(
"Ray does not allocate any GPUs on the driver node."
f"Driver IP: {driver_ip}, worker IPs: {worker_ips}."
"Consider adjusting the Ray placement group or running "
"the driver on a GPU node.")
ip_counts: Dict[str, int] = {}
for ip in worker_ips:
ip_counts[ip] = ip_counts.get(ip, 0) + 1
def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
"""
Sort the workers based on 3 properties:
1. If the worker is on the same node as the driver (vllm engine),
it should be placed first.
2. Then, if the worker is on a node with fewer workers, it should
be placed first.
3. Finally, if the worker is on a node with a smaller IP address, it
should be placed first.
"""
ip = item.ip
return (0 if ip == driver_ip else 1, ip_counts[ip], ip)
# After sorting, the workers on the same node will be
# close to each other, and the workers on the driver
# node will be placed first.
sorted_worker_metadata = sorted(worker_metadata,
key=sort_by_driver_then_worker_ip)
start_rank = 0 if self.use_ray_spmd_worker else 1
for i, item in enumerate(sorted_worker_metadata):
item.adjusted_rank = i + start_rank
self.workers = [item.worker for item in sorted_worker_metadata]
rerank_mapping = {
item.created_rank: item.adjusted_rank
for item in sorted_worker_metadata
}
self._run_workers("adjust_rank", rerank_mapping)
# Get the set of GPU IDs used on each node.
worker_node_and_gpu_ids = []
for worker in [self.driver_dummy_worker] + self.workers:
if worker is None:
# driver_dummy_worker can be None when using ray spmd worker.
continue
worker_node_and_gpu_ids.append(
ray.get(worker.get_node_and_gpu_ids.remote()) \
) # type: ignore
node_workers = defaultdict(list) # node id -> list of worker ranks
node_gpus = defaultdict(list) # node id -> list of gpu ids
for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
node_workers[node_id].append(i)
# `gpu_ids` can be a list of strings or integers.
# convert them to integers for consistency.
# NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
# string sorting is not sufficient.
# see https://github.com/vllm-project/vllm/issues/5590
gpu_ids = [int(x) for x in gpu_ids]
node_gpus[node_id].extend(gpu_ids)
for node_id, gpu_ids in node_gpus.items():
node_gpus[node_id] = sorted(gpu_ids)
all_ips = set(worker_ips + [driver_ip])
n_ips = len(all_ips)
n_nodes = len(node_workers)
if n_nodes != n_ips:
raise RuntimeError(
f"Every node should have a unique IP address. Got {n_nodes}"
f" nodes with node ids {list(node_workers.keys())} and "
f"{n_ips} unique IP addresses {all_ips}. Please check your"
" network configuration. If you set `VLLM_HOST_IP`"
" environment variable, make sure it is unique for"
" each node.")
# Set environment variables for the driver and workers.
all_args_to_update_environment_variables = [{
current_platform.device_control_env_var:
",".join(map(str, node_gpus[node_id])),
} for (node_id, _) in worker_node_and_gpu_ids]
# Environment variables to copy from driver to workers
env_vars_to_copy = get_env_vars_to_copy(
exclude_vars=self.WORKER_SPECIFIC_ENV_VARS,
additional_vars=set(current_platform.additional_env_vars).union(
self.ADDITIONAL_ENV_VARS),
destination="workers")
# Copy existing env vars to each worker's args
for args in all_args_to_update_environment_variables:
# TODO: refactor platform-specific env vars
for name in env_vars_to_copy:
if name in os.environ:
args[name] = os.environ[name]
self._env_vars_for_all_workers = (
all_args_to_update_environment_variables)
self._run_workers("update_environment_variables",
self._get_env_vars_to_be_updated())
if len(node_gpus) == 1:
# in single node case, we don't need to get the IP address.
# the loopback address is sufficient
# NOTE: a node may have several IP addresses, one for each
# network interface. `get_ip()` might return any of them,
# while they might not work for communication inside the node
# if the network setup is complicated. Using the loopback address
# solves this issue, as it always works for communication inside
# the node.
driver_ip = "127.0.0.1"
distributed_init_method = get_distributed_init_method(
driver_ip, get_open_port())
# Initialize the actual workers inside worker wrapper.
all_kwargs = []
for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
local_rank = node_workers[node_id].index(rank)
kwargs = dict(
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
is_driver_worker=(not self.parallel_config)
or (rank % self.parallel_config.tensor_parallel_size == 0),
)
all_kwargs.append(kwargs)
self._run_workers("init_worker", all_kwargs)
self._run_workers("init_device")
self._run_workers("load_model",
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers)
if self.use_ray_spmd_worker:
for pp_rank in range(self.parallel_config.pipeline_parallel_size):
self.pp_tp_workers.append([])
for tp_rank in range(
self.parallel_config.tensor_parallel_size):
# PP=2, TP=4
# pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
rank = (pp_rank * self.parallel_config.tensor_parallel_size
) + tp_rank
assert len(self.pp_tp_workers[pp_rank]) == tp_rank
assert pp_rank < len(self.pp_tp_workers)
self.pp_tp_workers[pp_rank].append(self.workers[rank])
# This is the list of workers that are rank 0 of each TP group EXCEPT
# global rank 0. These are the workers that will broadcast to the
# rest of the workers.
self.tp_driver_workers: List[RayWorkerWrapper] = []
# This is the list of workers that are not drivers and not the first
# worker in a TP group. These are the workers that will be
# broadcasted to.
self.non_driver_workers: List[RayWorkerWrapper] = []
# Enforce rank order for correct rank to return final output.
for index, worker in enumerate(self.workers):
# The driver worker is rank 0 and not in self.workers.
rank = index + 1
if rank % self.parallel_config.tensor_parallel_size == 0:
self.tp_driver_workers.append(worker)
else:
self.non_driver_workers.append(worker)
def _driver_execute_model(
self, execute_model_req: Optional[ExecuteModelRequest]
) -> Optional[List[SamplerOutput]]:
"""Run execute_model in the driver worker.
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
"""
assert not self.use_ray_spmd_worker, (
"driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
return self.driver_worker.execute_method("execute_model",
execute_model_req)
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
if not self.use_ray_spmd_worker:
return super().execute_model(execute_model_req)
if self.forward_dag is None:
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
if self.use_v1:
serialized_data = execute_model_req
else:
serialized_data = self.input_encoder.encode(execute_model_req)
outputs = ray.get(self.forward_dag.execute(serialized_data))
if self.use_v1:
output = outputs[0]
else:
output = self.output_decoder.decode(outputs[0])
return output
def _run_workers(
self,
method: Union[str, Callable],
*args,
async_run_tensor_parallel_workers_only: bool = False,
max_concurrent_workers: Optional[int] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers. Can be used in the following
ways:
Args:
- async_run_tensor_parallel_workers_only: If True the method will be
run only in the remote TP workers, not the driver worker.
It will also be run asynchronously and return a list of futures
rather than blocking on the results.
- args/kwargs: All workers share the same args/kwargs
"""
if isinstance(method, str):
sent_method = method
else:
sent_method = cloudpickle.dumps(method)
del method
if self.use_ray_spmd_worker:
assert not async_run_tensor_parallel_workers_only, (
"async_run_tensor_parallel_workers_only is not supported for "
"spmd mode.")
if max_concurrent_workers:
raise NotImplementedError(
"max_concurrent_workers is not supported yet.")
# Start the ray workers first.
ray_workers = self.workers
if async_run_tensor_parallel_workers_only:
ray_workers = self.non_driver_workers
ray_worker_outputs = [
worker.execute_method.remote(sent_method, *args, **kwargs)
for worker in ray_workers
]
if async_run_tensor_parallel_workers_only:
# Just return futures
return ray_worker_outputs
driver_worker_output = []
# In SPMD mode, the driver worker is the same as any other worker,
# so we only explicitly execute on the driver worker if using a
# non-SPMD worker class.
if not self.use_ray_spmd_worker:
# Start the driver worker after all the ray workers.
driver_worker_output = [
self.driver_worker.execute_method(sent_method, *args, **kwargs)
]
# Get the results of the ray workers.
if self.workers:
ray_worker_outputs = ray.get(ray_worker_outputs)
return driver_worker_output + ray_worker_outputs
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
ray.get(parallel_worker_tasks)
def _check_ray_cgraph_installation(self):
import importlib.metadata
from packaging import version
required_version = version.parse("2.43.0")
current_version = version.parse(importlib.metadata.version("ray"))
if current_version < required_version:
raise ValueError(f"Ray version {required_version} is "
f"required, but found {current_version}")
import importlib.util
cgraph_spec = importlib.util.find_spec(
"ray.experimental.compiled_dag_ref")
if cgraph_spec is None:
raise ValueError("Ray Compiled Graph is not installed. "
"Run `pip install ray[cgraph]` to install it.")
cupy_spec = importlib.util.find_spec("cupy")
if (cupy_spec is None
and envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl"):
raise ValueError(
"cupy is not installed but required since "
"VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE is set to 'nccl'. "
"Run `pip install ray[cgraph]` and check cupy installation.")
def _compiled_ray_dag(self, enable_asyncio: bool):
assert self.parallel_config.use_ray
self._check_ray_cgraph_installation()
# Enlarge the default value of "RAY_CGRAPH_get_timeout" to 300 seconds
# (it is 10 seconds by default). This is a Ray environment variable to
# control the timeout of getting result from a compiled graph execution,
# i.e., the distributed execution that includes model forward runs and
# intermediate tensor communications, in the case of vllm.
# Note: we should set this env var before importing
# ray.dag, otherwise it will not take effect.
os.environ.setdefault("RAY_CGRAPH_get_timeout", "300") # noqa: SIM112
from ray.dag import InputNode, MultiOutputNode
logger.info("RAY_CGRAPH_get_timeout is set to %s",
os.environ["RAY_CGRAPH_get_timeout"]) # noqa: SIM112
logger.info("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE = %s",
envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE)
logger.info("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s",
envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM)
channel_type = envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE
if channel_type not in ("auto", "nccl", "shm"):
raise ValueError(
"Invalid value for VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: "
f"{channel_type}. Valid values are: 'auto', 'nccl', or 'shm'.")
with InputNode() as input_data:
# Example DAG: PP=2, TP=4
#
# For V0:
# ExecuteModelRequest -> 0 -> (ExecuteModelReq, IntermediateTensors) -> 4 -> SamplerOutput # noqa: E501
# ExecuteModelRequest -> 1 -> (ExecuteModelReq, IntermediateTensors) -> 5 -> SamplerOutput # noqa: E501
# ExecuteModelRequest -> 2 -> (ExecuteModelReq, IntermediateTensors) -> 6 -> SamplerOutput # noqa: E501
# ExecuteModelRequest -> 3 -> (ExecuteModelReq, IntermediateTensors) -> 7 -> SamplerOutput # noqa: E501
#
# For V1:
# SchedulerOutput -> 0 -> (SchedulerOutput, IntermediateTensors) -> 4 -> ModelRunnerOutput # noqa: E501
# SchedulerOutput -> 1 -> (SchedulerOutput, IntermediateTensors) -> 5 -> ModelRunnerOutput # noqa: E501
# SchedulerOutput -> 2 -> (SchedulerOutput, IntermediateTensors) -> 6 -> ModelRunnerOutput # noqa: E501
# SchedulerOutput -> 3 -> (SchedulerOutput, IntermediateTensors) -> 7 -> ModelRunnerOutput # noqa: E501
# All workers in the first TP group will take in the
# ExecuteModelRequest as input.
outputs = [input_data for _ in self.pp_tp_workers[0]]
for pp_rank, tp_group in enumerate(self.pp_tp_workers):
# Each PP worker takes in the output of the previous PP worker,
# and the TP group executes in SPMD fashion.
if self.use_v1:
outputs = [
worker.execute_model_ray.
bind( # type: ignore[attr-defined]
outputs[i]) for i, worker in enumerate(tp_group)
]
else:
outputs = [
worker.execute_model_spmd.
bind( # type: ignore[attr-defined]
outputs[i]) for i, worker in enumerate(tp_group)
]
last_pp_rank = len(self.pp_tp_workers) - 1
if (pp_rank < last_pp_rank and
envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE != "shm"):
# Specify how intermediate tensors should be passed
# between pp stages, no need to specify for the last
# pp stage or when using shared memory (the default).
transport = envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE
outputs = [
output.with_tensor_transport(transport=transport)
for output in outputs
]
forward_dag = MultiOutputNode(outputs)
if envs.VLLM_USE_RAY_WRAPPED_PP_COMM:
from ray.experimental.channel.accelerator_context import (
register_accelerator_context)
from vllm.distributed.device_communicators.ray_communicator import (
RayPPCommunicator)
register_accelerator_context(torch_module_name="cuda",
communicator_cls=RayPPCommunicator)
logger.info("Using RayPPCommunicator "
"(which wraps vLLM _PP GroupCoordinator) "
"for Ray Compiled Graph communication.")
else:
logger.info("Using Ray's NCCL communicator for "
"Ray Compiled Graph communication.")
return forward_dag.experimental_compile(
enable_asyncio=enable_asyncio,
_overlap_gpu_communication=envs.
VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM)
def __del__(self):
self.shutdown()
async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
if not self.use_ray_spmd_worker:
return await super().execute_model_async(execute_model_req)
if self.forward_dag is None:
self.forward_dag = self._compiled_ray_dag(enable_asyncio=True)
serialized_data = self.input_encoder.encode(execute_model_req)
dag_future = await self.forward_dag.execute_async(serialized_data)
output = await dag_future[0]
return self.output_decoder.decode(output)
async def _driver_execute_model_async(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
assert not self.use_ray_spmd_worker, (
"driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
if not self.tp_driver_workers:
return await self.driver_exec_method("execute_model",
execute_model_req)
if self.pp_locks is None:
# This locks each pipeline parallel stage so multiple virtual
# engines can't execute on the same stage at the same time
# We create the locks here to avoid creating them in the constructor
# which uses a different asyncio loop.
self.pp_locks = [
asyncio.Lock()
for _ in range(self.parallel_config.pipeline_parallel_size)
]
tasks = [
asyncio.create_task(
_run_task_with_lock(self.driver_exec_method, self.pp_locks[0],
"execute_model", execute_model_req))
]
for pp_rank, driver_worker in enumerate(self.tp_driver_workers,
start=1):
tasks.append(
asyncio.create_task(
_run_task_with_lock(driver_worker.execute_method.remote,
self.pp_locks[pp_rank],
"execute_model", execute_model_req)))
results = await asyncio.gather(*tasks)
# Only the last PP stage has the final results.
return results[-1]
async def _start_worker_execution_loop(self):
assert not self.use_ray_spmd_worker, (
"worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1")
coros = [
worker.execute_method.remote("start_worker_execution_loop")
for worker in self.non_driver_workers
]
return await asyncio.gather(*coros)
def check_health(self) -> None:
# Assume that the Ray workers are healthy.
# TODO: check the health of the Ray workers
return
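The compiled-graph wiring that this removed executor built in _compiled_ray_dag (and that the V1 executor later in this diff still calls) follows Ray's Compiled Graph pattern: bind worker methods to a shared InputNode, chain stages, wrap the final outputs in a MultiOutputNode, and compile once. A minimal standalone sketch of that pattern, assuming Ray with Compiled Graph support (ray[cgraph]) is installed; the Stage actor is hypothetical and only stands in for vLLM's RayWorkerWrapper:

import ray
from ray.dag import InputNode, MultiOutputNode

@ray.remote
class Stage:
    def forward(self, x):
        # stand-in for one pipeline stage's model forward pass
        return x + 1

ray.init()
stages = [Stage.remote() for _ in range(2)]  # hypothetical 2-stage pipeline

with InputNode() as inp:
    out = inp
    for stage in stages:
        out = stage.forward.bind(out)  # chain stages, like PP ranks
    dag = MultiOutputNode([out])

compiled = dag.experimental_compile()
refs = compiled.execute(1)
print(ray.get(refs))  # -> [3]
compiled.teardown()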

View File

@@ -1,10 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal,
SupportsPP, SupportsTranscription, SupportsV0Only,
has_inner_state, supports_lora, supports_multimodal,
supports_pp, supports_transcription, supports_v0_only)
from .interfaces import (HasInnerState, SupportsLoRA, SupportsMRoPE,
SupportsMultiModal, SupportsPP, SupportsTranscription,
SupportsV0Only, has_inner_state, supports_lora,
supports_mrope, supports_multimodal, supports_pp,
supports_transcription, supports_v0_only)
from .interfaces_base import (VllmModelForPooling, VllmModelForTextGeneration,
is_pooling_model, is_text_generation_model)
from .registry import ModelRegistry
@@ -21,6 +22,8 @@ __all__ = [
"supports_lora",
"SupportsMultiModal",
"supports_multimodal",
"SupportsMRoPE",
"supports_mrope",
"SupportsPP",
"supports_pp",
"SupportsTranscription",

View File

@@ -8,6 +8,7 @@ from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
import numpy as np
import torch
from torch import Tensor
from transformers import PretrainedConfig
from transformers.models.whisper.tokenization_whisper import LANGUAGES
from typing_extensions import Self, TypeIs
@@ -852,3 +853,70 @@ def supports_eagle3(
model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsEagle3]], TypeIs[SupportsEagle3]]:
return isinstance(model, SupportsEagle3)
@runtime_checkable
class SupportsMRoPE(Protocol):
"""The interface required for all models that support M-RoPE."""
supports_mrope: ClassVar[Literal[True]] = True
"""
A flag that indicates this model supports M-RoPE.
Note:
There is no need to redefine this flag if this class is in the
MRO of your model class.
"""
def get_mrope_input_positions(
self,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
second_per_grid_ts: Optional[list[float]] = None,
context_len: int = 0,
seq_len: Optional[int] = None,
audio_feature_lengths: Optional[torch.Tensor] = None,
use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
"""
Get M-RoPE input positions and delta value for this specific model.
This method should be implemented by each model that supports M-RoPE
to provide model-specific logic for computing input positions.
Args:
input_tokens: List of input token IDs
hf_config: HuggingFace model configuration
image_grid_thw: Image grid dimensions (t, h, w)
video_grid_thw: Video grid dimensions (t, h, w)
second_per_grid_ts: Seconds per grid timestep for videos
context_len: Context length
seq_len: Sequence length
audio_feature_lengths: Audio feature lengths for multimodal models
use_audio_in_video: Whether to use audio in video for interleaving
Returns:
Tuple of (llm_positions, mrope_position_delta)
- llm_positions: Tensor of shape [3, num_tokens]
with T/H/W positions
- mrope_position_delta: Delta for position calculations
"""
...
@overload
def supports_mrope(model: type[object]) -> TypeIs[type[SupportsMRoPE]]:
...
@overload
def supports_mrope(model: object) -> TypeIs[SupportsMRoPE]:
...
def supports_mrope(
model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsMRoPE]], TypeIs[SupportsMRoPE]]:
return isinstance(model, SupportsMRoPE)
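A hedged sketch of how a caller might use the new protocol and its TypeIs helper; the compute_positions wrapper and its arguments are hypothetical, only SupportsMRoPE / supports_mrope come from the interface added above:

from vllm.model_executor.models.interfaces import supports_mrope

def compute_positions(model, input_tokens, hf_config, image_grid_thw, video_grid_thw):
    # supports_mrope() narrows the type via the runtime-checkable protocol,
    # then we delegate to the model-specific M-RoPE implementation.
    if supports_mrope(model):
        return model.get_mrope_input_positions(
            input_tokens,
            hf_config,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
        )
    raise TypeError(f"{type(model).__name__} does not implement SupportsMRoPE")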

View File

@@ -32,7 +32,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from transformers import AutoConfig, BatchFeature
from transformers import AutoConfig, BatchFeature, PretrainedConfig
from transformers.models.qwen2_vl import (Qwen2VLImageProcessor,
Qwen2VLProcessor)
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
@@ -73,7 +73,7 @@ from vllm.transformers_utils.config import uses_mrope
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMRoPE,
SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, WeightsMapper,
init_vllm_registered_model, maybe_prefix,
@@ -1096,7 +1096,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
info=Qwen2VLProcessingInfo,
dummy_inputs=Qwen2VLDummyInputsBuilder)
class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsLoRA, SupportsPP):
SupportsLoRA, SupportsPP, SupportsMRoPE):
# To ensure correct weight loading and mapping.
hf_to_vllm_mapper = WeightsMapper(
@@ -1109,6 +1109,118 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
"model.": "language_model.model.",
})
def get_mrope_input_positions(
self,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
second_per_grid_ts: Optional[list[float]] = None,
context_len: int = 0,
seq_len: Optional[int] = None,
audio_feature_lengths: Optional[torch.Tensor] = None,
use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
"""Get M-RoPE input positions for Qwen2-VL model."""
if image_grid_thw is None:
image_grid_thw = []
if video_grid_thw is None:
video_grid_thw = []
if second_per_grid_ts is None:
second_per_grid_ts = []
image_token_id = hf_config.image_token_id
video_token_id = hf_config.video_token_id
vision_start_token_id = hf_config.vision_start_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
tokens_per_second = getattr(hf_config.vision_config,
"tokens_per_second", 1.0)
input_tokens_tensor = torch.tensor(input_tokens)
vision_start_indices = torch.argwhere(
input_tokens_tensor == vision_start_token_id).squeeze(1)
vision_tokens = input_tokens_tensor[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos = image_nums, video_nums
image_index, video_index = 0, 0
for _ in range(image_nums + video_nums):
video_second_per_grid_t = 0.0
if remain_images > 0:
try:
ed_image = input_tokens.index(image_token_id, st)
except ValueError:
ed_image = len(input_tokens) + 1
else:
ed_image = len(input_tokens) + 1
if remain_videos > 0:
try:
ed_video = input_tokens.index(video_token_id, st)
except ValueError:
ed_video = len(input_tokens) + 1
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
video_second_per_grid_t = 1.0
if second_per_grid_ts:
video_second_per_grid_t = second_per_grid_ts[video_index]
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = \
t, h // spatial_merge_size, w // spatial_merge_size
text_len = ed - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
t_index = (torch.arange(llm_grid_t).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w) * video_second_per_grid_t *
tokens_per_second).long().flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 -
len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]
return llm_positions, mrope_position_delta
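As a sanity check on the method above: for a text-only prompt there are no vision-start tokens, the image/video loop never runs, and the result degenerates to the usual 0..n-1 position ramp broadcast across the T/H/W rows with a zero position delta. A small standalone sketch of that degenerate case (not calling the vLLM class):

import torch

def text_only_mrope_positions(num_tokens: int) -> tuple[torch.Tensor, int]:
    # With no image/video tokens, M-RoPE reduces to ordinary positions:
    # three identical rows and mrope_position_delta == 0.
    positions = torch.arange(num_tokens).view(1, -1).expand(3, -1)
    delta = int(positions.max().item()) + 1 - num_tokens  # always 0 here
    return positions, delta

pos, delta = text_only_mrope_positions(5)
print(pos.shape, delta)  # torch.Size([3, 5]) 0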
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
if modality.startswith("image"):

View File

@@ -36,7 +36,7 @@ def _extract_data_from_linear_base_module(
assert m.quant_method.quant_config is not None
w = m.weight
ws = m.weight_scale_inv
ws = m.weight_scale
quant_block_size = m.quant_method.quant_config.weight_block_size
assert isinstance(w, torch.Tensor)

View File

@@ -191,14 +191,17 @@ class CudaPlatformBase(Platform):
compilation_config = vllm_config.compilation_config
if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
and parallel_config.data_parallel_size > 1
and compilation_config.cudagraph_mode
not in [CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE]):
and compilation_config.cudagraph_mode != CUDAGraphMode.NONE):
# TODO: Piecewise Cuda graph might be enabled
# if torch compile cache key issue fixed
# See https://github.com/vllm-project/vllm/pull/25093
logger.info(
"Data Parallel with DeepEP high-throughput: using PIECEWISE "
"CUDA graphs and excluding MoE ops from capture. Set "
"VLLM_ALL2ALL_BACKEND=deepep_low_latency if you need MoE "
"graphs captured as well.")
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
"Data Parallel: disabling cudagraphs since DP "
"with DeepEP high-throughput kernels are not CUDA Graph "
"compatible. The DeepEP low-latency kernels are CUDA Graph "
"compatible. Set the all_to_all backend to deepep_low_latency "
"to use those kernels instead.")
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
@classmethod
def get_current_memory_usage(cls,

View File

@@ -206,12 +206,11 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
)
if H < MAX_HEADS:
# Extract the subsets of the outputs
returned_lse = lse[:, :H].contiguous(
) if self.need_to_return_lse_for_decode else lse
out = out[:, :H]
if self.need_to_return_lse_for_decode:
lse = lse[:, :H].contiguous()
return out, returned_lse
return out, lse
def _forward_decode(
self,

View File

@@ -1,62 +1,165 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import threading
from collections import defaultdict
from concurrent.futures import Future
from typing import Optional, Union
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
import cloudpickle
import vllm.envs as envs
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.executor.ray_distributed_executor import ( # noqa
RayDistributedExecutor as RayDistributedExecutorV0)
from vllm.executor.executor_base import DistributedExecutorBase
from vllm.executor.msgspec_utils import encode_hook
from vllm.executor.ray_utils import (RayWorkerWrapper, initialize_ray_cluster,
ray)
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.ray.ray_env import get_env_vars_to_copy
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import ModelRunnerOutput
try: # msgspec is optional at runtime but required for serialization.
import msgspec
except ImportError: # pragma: no cover - msgspec is an optional dependency.
msgspec = None # type: ignore
if ray is not None:
from ray.actor import ActorHandle
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
else:
ActorHandle = None
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
logger = init_logger(__name__)
class FutureWrapper(Future):
"""A wrapper around Ray output reference to meet the interface
of .execute_model(): The top level (core busy loop) expects .result() api
to block and return a single output.
If aggregator is provided, the outputs from all workers are aggregated upon
the result() call. If not, only the first worker's output is returned.
@dataclass
class RayWorkerMetaData:
"""
Metadata for a Ray worker.
The order of ray worker creation can be random,
and we need to reset the rank after creating all workers.
"""
def __init__(self, refs, aggregator: Optional[KVOutputAggregator] = None):
worker: ActorHandle
created_rank: int
adjusted_rank: int = -1
ip: str = ""
class FutureWrapper(Future):
"""Future compatible wrapper around Ray object references."""
def __init__(self,
refs,
aggregator: Optional[KVOutputAggregator] = None) -> None:
super().__init__()
self.refs = refs
self.aggregator = aggregator
self._refs = refs
self._aggregator = aggregator
# Resolve the Ray object references off-thread so that the driver event
# loop is not blocked and Future callbacks fire when the result is
# ready.
threading.Thread(target=self._resolve, daemon=True).start()
def result(self, timeout=None):
if timeout is not None:
raise NotImplementedError("timeout is not supported")
def cancel(self) -> bool: # pragma: no cover - cancellation unsupported.
return False
if self.aggregator is None:
return self.refs[0].get()
outputs = [ref.get() for ref in self.refs]
return self.aggregator.aggregate(outputs, output_rank=0)
def _resolve(self) -> None:
try:
if ray is None:
raise RuntimeError("Ray is required to resolve distributed "
"results.")
outputs = ray.get(self._refs)
if self._aggregator is None:
result = outputs[0]
else:
result = self._aggregator.aggregate(outputs, output_rank=0)
self.set_result(result)
except BaseException as exc: # pragma: no cover - Ray errors propagated.
self.set_exception(exc)
finally:
self._refs = None
self._aggregator = None
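The wrapper above adapts a blocking fetch into a concurrent.futures.Future by resolving it on a helper thread and publishing through set_result / set_exception, so both result() and done-callbacks work without blocking the caller. A minimal standalone sketch of the same pattern, with a plain blocking function standing in for ray.get on the DAG output refs:

import threading
import time
from concurrent.futures import Future

class ThreadResolvedFuture(Future):
    def __init__(self, blocking_fn):
        super().__init__()
        # Resolve off-thread so the constructor returns immediately.
        threading.Thread(target=self._resolve, args=(blocking_fn,),
                         daemon=True).start()

    def _resolve(self, blocking_fn):
        try:
            self.set_result(blocking_fn())
        except BaseException as exc:
            self.set_exception(exc)

def slow_fetch():
    time.sleep(0.1)  # stand-in for a blocking ray.get() call
    return "done"

fut = ThreadResolvedFuture(slow_fetch)
fut.add_done_callback(lambda f: print("callback fired:", f.result()))
print("blocking result:", fut.result())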
class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
"""Ray distributed executor using Ray Compiled Graphs."""
class RayDistributedExecutor(DistributedExecutorBase, Executor):
"""Ray-based distributed executor for the v1 engine."""
# These env vars are worker-specific, therefore are NOT copied
# from the driver to the workers
WORKER_SPECIFIC_ENV_VARS = {
"VLLM_HOST_IP", "VLLM_HOST_PORT", "LOCAL_RANK", "CUDA_VISIBLE_DEVICES"
}
# These non-vLLM env vars are copied from the driver to workers
ADDITIONAL_ENV_VARS = {"HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"}
uses_ray: bool = True
supports_pp: bool = True
def _init_executor(self) -> None:
super()._init_executor()
self.forward_dag: Optional[ray.dag.CompiledDAG] = None # type: ignore
# V1 executor always relies on the SPMD worker implementation which in
# turn requires the compiled DAG API.
os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"
# For TPU or XPU, avoid compiling NVIDIA's NCCL
if current_platform.is_tpu() or current_platform.is_xpu():
os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm"
# These flags configure the worker setup.
self.use_ray_compiled_dag = envs.VLLM_USE_RAY_COMPILED_DAG
self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
if self.use_ray_compiled_dag:
assert self.use_ray_spmd_worker, (
"VLLM_USE_RAY_COMPILED_DAG=1 requires "
"VLLM_USE_RAY_SPMD_WORKER=1")
if self.use_ray_spmd_worker:
assert self.use_ray_compiled_dag, (
"VLLM_USE_RAY_SPMD_WORKER=1 requires "
"VLLM_USE_RAY_COMPILED_DAG=1")
assert self.uses_ray
initialize_ray_cluster(self.parallel_config)
placement_group = self.parallel_config.placement_group
# Disable Ray usage stats collection.
ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
if ray_usage != "1":
os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
# Create the parallel GPU workers.
self._init_workers_ray(placement_group)
# msgspec is only required when compiled DAG is disabled which is not
# expected for V1, but initialize the codec for completeness.
if msgspec is not None:
self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
self.output_decoder = msgspec.msgpack.Decoder(
Optional[List[SamplerOutput]])
else: # pragma: no cover - msgspec should normally be available.
self.input_encoder = None
self.output_decoder = None
# KV connector setup
self.has_connector = self.vllm_config.kv_transfer_config is not None
@property
def max_concurrent_batches(self) -> int:
"""Ray distributed executor supports pipeline parallelism,
meaning that it allows PP size batches to be executed concurrently.
"""
"""Ray distributed executor supports pipeline parallelism."""
if self.scheduler_config.async_scheduling:
return 2
return self.parallel_config.pipeline_parallel_size
@@ -66,43 +169,443 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
scheduler_output: SchedulerOutput,
non_block: bool = False,
) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
"""Execute the model on the Ray workers.
"""Execute the model on the Ray workers."""
Args:
scheduler_output: The scheduler output to execute.
non_block: If True, the method will return a Future.
Returns:
The model runner output.
"""
# Build the compiled DAG for the first time.
if self.forward_dag is None: # type: ignore
if self.forward_dag is None:
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
refs = self.forward_dag.execute(scheduler_output) # type: ignore
refs = self.forward_dag.execute(scheduler_output)
if not self.has_connector:
# Get output only from a single worker (output_rank)
# When PP is not used, we block here until the result is available.
if not non_block:
return refs[0].get()
# When PP is used, we return a FutureWrapper immediately so that
# the scheduler can yield to the next batch.
return FutureWrapper(refs)
# Get output from all workers when connector is present
assert self.kv_output_aggregator is not None, (
"KVOutputAggregator must be initialized when kv transfer is "
"configured")
if not non_block:
# Block and get results from all workers
outputs = [ref.get() for ref in refs]
return self.kv_output_aggregator.aggregate(outputs)
# Return a future that will aggregate outputs from all workers
return FutureWrapper(refs, self.kv_output_aggregator)
def reinitialize_distributed(
self, reconfig_request: ReconfigureDistributedRequest) -> None:
self._run_workers("reinitialize_distributed", reconfig_request)
if reconfig_request.new_data_parallel_rank == \
ReconfigureRankType.SHUTDOWN_CURRENT_RANK:
ReconfigureRankType.SHUTDOWN_CURRENT_RANK:
self.shutdown()
def shutdown(self) -> None:
if logger:
# Somehow logger can be None here.
logger.info(
"Shutting down Ray distributed executor. If you see error log "
"from logging.cc regarding SIGTERM received, please ignore "
"because this is the expected termination process in Ray.")
if hasattr(self, "forward_dag") and self.forward_dag is not None:
self.forward_dag.teardown()
import ray as _ray
for worker in self.workers:
_ray.kill(worker)
self.forward_dag = None
def _configure_ray_workers_use_nsight(self,
ray_remote_kwargs) -> Dict[str, Any]:
# If nsight profiling is enabled, we need to set the profiling
# configuration for the ray workers as runtime env.
runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
runtime_env.update({
"nsight": {
"t": "cuda,cudnn,cublas",
"o": "'worker_process_%p'",
"cuda-graph-trace": "node",
}
})
return ray_remote_kwargs
# child class could overwrite this to return actual env vars.
def _get_env_vars_to_be_updated(self):
return self._env_vars_for_all_workers
def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs):
num_gpus = envs.VLLM_RAY_PER_WORKER_GPUS
# Ray actors that perform all model execution.
self.workers: List[RayWorkerWrapper] = []
# Used in ray compiled DAG: indexed first by PP rank,
# and then TP rank. In other words, the inner list is
# the TP group of workers for a PP rank.
self.pp_tp_workers: List[List[RayWorkerWrapper]] = []
if self.parallel_config.ray_workers_use_nsight:
ray_remote_kwargs = self._configure_ray_workers_use_nsight(
ray_remote_kwargs)
logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)
# Create the workers.
bundle_indices: List[int]
if envs.VLLM_RAY_BUNDLE_INDICES:
# Use the bundle indices specified by the user.
bundle_indices = list(
map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(",")))
assert len(bundle_indices) == self.parallel_config.world_size, \
("VLLM_RAY_BUNDLE_INDICES must have the same size"
f" as the world size, but got {bundle_indices=} "
f"and {self.parallel_config.world_size=}")
assert len(set(bundle_indices)) == len(bundle_indices), \
("VLLM_RAY_BUNDLE_INDICES cannot have duplicate values,"
f" but got {bundle_indices=}")
else:
# use the first N bundles that have GPU resources.
bundle_indices = []
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if bundle.get(current_platform.ray_device_key, 0):
bundle_indices.append(bundle_id)
bundle_indices = bundle_indices[:self.parallel_config.world_size]
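        # Example (assumed values): VLLM_RAY_BUNDLE_INDICES="0,2,4,6" with
        # world_size=4 pins the workers to bundles 0, 2, 4 and 6; without the
        # env var, the first 4 GPU-bearing bundles of the placement group are
        # used instead.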
worker_metadata: List[RayWorkerMetaData] = []
driver_ip = get_ip()
for rank, bundle_id in enumerate(bundle_indices):
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=bundle_id,
)
if current_platform.ray_device_key == "GPU":
# NV+AMD GPUs, and Intel XPUs
worker = ray.remote(
num_cpus=0,
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
rpc_rank=rank)
else:
worker = ray.remote(
num_cpus=0,
num_gpus=0,
resources={current_platform.ray_device_key: num_gpus},
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
rpc_rank=rank)
worker_metadata.append(
RayWorkerMetaData(worker=worker, created_rank=rank))
worker_ips = ray.get([
each.worker.get_node_ip.remote() # type: ignore[attr-defined]
for each in worker_metadata
])
for each, ip in zip(worker_metadata, worker_ips):
each.ip = ip
logger.debug("workers: %s", worker_metadata)
ip_counts: Dict[str, int] = {}
for ip in worker_ips:
ip_counts[ip] = ip_counts.get(ip, 0) + 1
def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
"""
Sort the workers based on 3 properties:
1. If the worker is on the same node as the driver (vllm engine),
it should be placed first.
2. Then, if the worker is on a node with fewer workers, it should
be placed first.
            3. Finally, if the worker is on a node with a smaller IP address, it
should be placed first.
"""
ip = item.ip
return (0 if ip == driver_ip else 1, ip_counts[ip], ip)
# After sorting, the workers on the same node will be
# close to each other, and the workers on the driver
# node will be placed first.
sorted_worker_metadata = sorted(worker_metadata,
key=sort_by_driver_then_worker_ip)
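        # Worked example of the sort key (hypothetical IPs): with
        # driver_ip="10.0.0.1", two workers on "10.0.0.1" and four on
        # "10.0.0.2", the keys are (0, 2, "10.0.0.1") and (1, 4, "10.0.0.2"),
        # so driver-local workers receive the lowest adjusted ranks.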
start_rank = 0
for i, item in enumerate(sorted_worker_metadata):
item.adjusted_rank = i + start_rank
self.workers = [item.worker for item in sorted_worker_metadata]
rerank_mapping = {
item.created_rank: item.adjusted_rank
for item in sorted_worker_metadata
}
self._run_workers("adjust_rank", rerank_mapping)
# Get the set of GPU IDs used on each node.
worker_node_and_gpu_ids = []
for worker in self.workers:
worker_node_and_gpu_ids.append(
ray.get(worker.get_node_and_gpu_ids.remote())
) # type: ignore
node_workers = defaultdict(list) # node id -> list of worker ranks
node_gpus = defaultdict(list) # node id -> list of gpu ids
for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
node_workers[node_id].append(i)
# `gpu_ids` can be a list of strings or integers.
# convert them to integers for consistency.
            # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs), so
# string sorting is not sufficient.
# see https://github.com/vllm-project/vllm/issues/5590
gpu_ids = [int(x) for x in gpu_ids]
node_gpus[node_id].extend(gpu_ids)
for node_id, gpu_ids in node_gpus.items():
node_gpus[node_id] = sorted(gpu_ids)
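        # Example of why the integer conversion matters (assumed IDs):
        # sorted(["9", "10", "11"]) == ["10", "11", "9"], whereas
        # sorted([9, 10, 11]) == [9, 10, 11].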
all_ips = set(worker_ips + [driver_ip])
n_ips = len(all_ips)
n_nodes = len(node_workers)
if n_nodes != n_ips:
raise RuntimeError(
f"Every node should have a unique IP address. Got {n_nodes}"
f" nodes with node ids {list(node_workers.keys())} and "
f"{n_ips} unique IP addresses {all_ips}. Please check your"
" network configuration. If you set `VLLM_HOST_IP`"
" environment variable, make sure it is unique for"
" each node.")
# Set environment variables for the driver and workers.
all_args_to_update_environment_variables = [{
current_platform.device_control_env_var:
",".join(map(str, node_gpus[node_id])),
} for (node_id, _) in worker_node_and_gpu_ids]
# Environment variables to copy from driver to workers
env_vars_to_copy = get_env_vars_to_copy(
exclude_vars=self.WORKER_SPECIFIC_ENV_VARS,
additional_vars=set(current_platform.additional_env_vars).union(
self.ADDITIONAL_ENV_VARS),
destination="workers")
# Copy existing env vars to each worker's args
for args in all_args_to_update_environment_variables:
# TODO: refactor platform-specific env vars
for name in env_vars_to_copy:
if name in os.environ:
args[name] = os.environ[name]
self._env_vars_for_all_workers = (
all_args_to_update_environment_variables)
self._run_workers("update_environment_variables",
self._get_env_vars_to_be_updated())
if len(node_gpus) == 1:
# in single node case, we don't need to get the IP address.
# the loopback address is sufficient
# NOTE: a node may have several IP addresses, one for each
# network interface. `get_ip()` might return any of them,
# while they might not work for communication inside the node
# if the network setup is complicated. Using the loopback address
# solves this issue, as it always works for communication inside
# the node.
driver_ip = "127.0.0.1"
distributed_init_method = get_distributed_init_method(
driver_ip, get_open_port())
# Initialize the actual workers inside worker wrapper.
all_kwargs = []
for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
local_rank = node_workers[node_id].index(rank)
kwargs = dict(
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
is_driver_worker=(not self.parallel_config)
or (rank % self.parallel_config.tensor_parallel_size == 0),
)
all_kwargs.append(kwargs)
self._run_workers("init_worker", all_kwargs)
self._run_workers("init_device")
self._run_workers("load_model",
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers)
for pp_rank in range(self.parallel_config.pipeline_parallel_size):
self.pp_tp_workers.append([])
for tp_rank in range(self.parallel_config.tensor_parallel_size):
# PP=2, TP=4
# pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
rank = (pp_rank * self.parallel_config.tensor_parallel_size
) + tp_rank
assert len(self.pp_tp_workers[pp_rank]) == tp_rank
assert pp_rank < len(self.pp_tp_workers)
self.pp_tp_workers[pp_rank].append(self.workers[rank])
def _driver_execute_model(
self, execute_model_req: Optional[ExecuteModelRequest]
) -> Optional[List[SamplerOutput]]:
raise RuntimeError(
"RayDistributedExecutor only supports compiled DAG execution "
"and does not expose a separate driver worker loop.")
def _run_workers(
self,
method: Union[str, Callable],
*args,
async_run_tensor_parallel_workers_only: bool = False,
max_concurrent_workers: Optional[int] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
if isinstance(method, str):
sent_method = method
else:
sent_method = cloudpickle.dumps(method)
del method
if self.use_ray_spmd_worker:
assert not async_run_tensor_parallel_workers_only, (
"async_run_tensor_parallel_workers_only is not supported for "
"spmd mode.")
if max_concurrent_workers:
raise NotImplementedError(
"max_concurrent_workers is not supported yet.")
# Start the ray workers first.
ray_worker_outputs = [
worker.execute_method.remote(sent_method, *args, **kwargs)
for worker in self.workers
]
if not self.workers:
return []
# Get the results of the ray workers.
return ray.get(ray_worker_outputs)
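        # Example call patterns (both taken from _init_workers_ray above):
        #   self._run_workers("init_worker", all_kwargs)
        #   self._run_workers("init_device")
        # A plain string names a method on each RayWorkerWrapper, while a
        # callable is cloudpickled and shipped to every worker.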
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
"""Wait for futures returned from _run_workers()."""
ray.get(parallel_worker_tasks)
def _check_ray_cgraph_installation(self):
import importlib.metadata
from packaging import version
required_version = version.parse("2.43.0")
current_version = version.parse(importlib.metadata.version("ray"))
if current_version < required_version:
raise ValueError(f"Ray version {required_version} is "
f"required, but found {current_version}")
import importlib.util
cgraph_spec = importlib.util.find_spec(
"ray.experimental.compiled_dag_ref")
if cgraph_spec is None:
raise ValueError("Ray Compiled Graph is not installed. "
"Run `pip install ray[cgraph]` to install it.")
cupy_spec = importlib.util.find_spec("cupy")
if (cupy_spec is None
and envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl"):
raise ValueError(
"cupy is not installed but required since "
"VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE is set to 'nccl'. "
"Run `pip install ray[cgraph]` and check cupy installation.")
def _compiled_ray_dag(self, enable_asyncio: bool):
assert self.parallel_config.use_ray
self._check_ray_cgraph_installation()
# Enlarge the default value of "RAY_CGRAPH_get_timeout" to 300 seconds
os.environ.setdefault("RAY_CGRAPH_get_timeout", "300") # noqa: SIM112
from ray.dag import InputNode, MultiOutputNode
logger.info("RAY_CGRAPH_get_timeout is set to %s",
os.environ["RAY_CGRAPH_get_timeout"]) # noqa: SIM112
logger.info("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE = %s",
envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE)
logger.info("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s",
envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM)
channel_type = envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE
if channel_type not in ("auto", "nccl", "shm"):
raise ValueError(
"Invalid value for VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: "
f"{channel_type}. Valid values are: 'auto', 'nccl', or 'shm'.")
with InputNode() as input_data:
# Example DAG: PP=2, TP=4
# SchedulerOutput -> 0 -> (SchedulerOutput, IntermediateTensors) ->
# 4 -> ModelRunnerOutput, etc.
outputs = [input_data for _ in self.pp_tp_workers[0]]
for pp_rank, tp_group in enumerate(self.pp_tp_workers):
outputs = [
worker.execute_model_ray.
bind( # type: ignore[attr-defined]
outputs[i]) for i, worker in enumerate(tp_group)
]
last_pp_rank = len(self.pp_tp_workers) - 1
if (pp_rank < last_pp_rank and
envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE != "shm"):
transport = envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE
outputs = [
output.with_tensor_transport(transport=transport)
for output in outputs
]
forward_dag = MultiOutputNode(outputs)
if envs.VLLM_USE_RAY_WRAPPED_PP_COMM:
from ray.experimental.channel.accelerator_context import (
register_accelerator_context)
from vllm.distributed.device_communicators.ray_communicator import (
RayPPCommunicator)
register_accelerator_context(torch_module_name="cuda",
communicator_cls=RayPPCommunicator)
logger.info("Using RayPPCommunicator "
"(which wraps vLLM _PP GroupCoordinator) "
"for Ray Compiled Graph communication.")
else:
logger.info("Using Ray's NCCL communicator for "
"Ray Compiled Graph communication.")
return forward_dag.experimental_compile(
enable_asyncio=enable_asyncio,
_overlap_gpu_communication=envs.
VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM)
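    # Illustrative lifecycle sketch (hypothetical driver-side code): the DAG
    # is compiled once, re-executed on every step, and torn down on shutdown.
    #
    #   dag = self._compiled_ray_dag(enable_asyncio=False)
    #   refs = dag.execute(scheduler_output)  # one ref per last-PP-stage worker
    #   outputs = [ref.get() for ref in refs]
    #   dag.teardown()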
def __del__(self):
self.shutdown()
async def execute_model_async(
self,
scheduler_output: SchedulerOutput) -> ModelRunnerOutput:
return await make_async(self.execute_model)(scheduler_output)
async def _driver_execute_model_async(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
raise RuntimeError(
"RayDistributedExecutor only supports compiled DAG execution "
"and does not expose a separate driver worker loop.")
async def _start_worker_execution_loop(self):
raise RuntimeError(
"RayDistributedExecutor only supports compiled DAG execution "
"and does not expose a separate driver worker loop.")
def check_health(self) -> None:
# Assume that the Ray workers are healthy.
# TODO: check the health of the Ray workers
return

View File

@ -4,7 +4,7 @@ from abc import ABC
import numpy as np
from vllm.v1.offloading.abstract import LoadStoreSpec
from vllm.v1.kv_offload.abstract import LoadStoreSpec
class BlockIDsLoadStoreSpec(LoadStoreSpec, ABC):

View File

@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from vllm.logger import init_logger
from vllm.v1.offloading.abstract import LoadStoreSpec
from vllm.v1.kv_offload.abstract import LoadStoreSpec
# a single transfer spec (src_blocks_spec, dst_blocks_spec)
TransferSpec = tuple[LoadStoreSpec, LoadStoreSpec]

View File

@ -18,9 +18,7 @@ from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm
logger = init_logger(__name__)
PerEngineStatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"]
StatLoggerFactory = Union[PerEngineStatLoggerFactory,
type["AggregatedStatLoggerBase"]]
StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"]
class StatLoggerBase(ABC):
@ -50,16 +48,6 @@ class StatLoggerBase(ABC):
pass
class AggregatedStatLoggerBase(StatLoggerBase):
"""Abstract base class for loggers that
    aggregate statistics across multiple engines."""
@abstractmethod
def __init__(self, vllm_config: VllmConfig,
engine_indexes: Optional[list[int]]):
...
class LoggingStatLogger(StatLoggerBase):
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
@ -73,7 +61,6 @@ class LoggingStatLogger(StatLoggerBase):
self.spec_decoding_logging = SpecDecodingLogging()
self.last_prompt_throughput: float = 0.0
self.last_generation_throughput: float = 0.0
self.engine_is_idle = False
def _reset(self, now):
self.last_log_time = now
@ -113,25 +100,25 @@ class LoggingStatLogger(StatLoggerBase):
self.last_scheduler_stats = scheduler_stats
def get_log_stats(self):
def log(self):
now = time.monotonic()
prompt_throughput = self._get_throughput(self.num_prompt_tokens, now)
generation_throughput = self._get_throughput(
self.num_generation_tokens, now)
self._reset(now)
self.last_generation_throughput = generation_throughput
self.last_prompt_throughput = prompt_throughput
self.engine_is_idle = not any(
(prompt_throughput, generation_throughput,
self.last_prompt_throughput, self.last_generation_throughput))
def log(self):
self.get_log_stats()
scheduler_stats = self.last_scheduler_stats
log_fn = logger.info
if self.engine_is_idle:
if not any(
(prompt_throughput, generation_throughput,
self.last_prompt_throughput, self.last_generation_throughput)):
# Avoid log noise on an idle production system
log_fn = logger.debug
self.last_generation_throughput = generation_throughput
self.last_prompt_throughput = prompt_throughput
# Format and print output.
log_fn(
"Engine %03d: "
@ -141,11 +128,11 @@ class LoggingStatLogger(StatLoggerBase):
"GPU KV cache usage: %.1f%%, "
"Prefix cache hit rate: %.1f%%",
self.engine_index,
self.last_prompt_throughput,
self.last_generation_throughput,
self.last_scheduler_stats.num_running_reqs,
self.last_scheduler_stats.num_waiting_reqs,
self.last_scheduler_stats.kv_cache_usage * 100,
prompt_throughput,
generation_throughput,
scheduler_stats.num_running_reqs,
scheduler_stats.num_waiting_reqs,
scheduler_stats.kv_cache_usage * 100,
self.prefix_caching_metrics.hit_rate * 100,
)
self.spec_decoding_logging.log(log_fn=log_fn)
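        # Worked example of the idle check above (assumed numbers): if the
        # current and previous prompt/generation throughputs are all 0.0
        # tokens/s, the line is emitted via logger.debug; any non-zero value
        # among the four switches it back to logger.info.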
@ -158,61 +145,7 @@ class LoggingStatLogger(StatLoggerBase):
self.vllm_config.cache_config.num_gpu_blocks)
class AggregatedStatLogger(LoggingStatLogger, AggregatedStatLoggerBase):
def __init__(self,
vllm_config: VllmConfig,
engine_idxs: Optional[list[int]] = None):
if engine_idxs is None:
engine_idxs = [0]
self.engine_idxs = engine_idxs
LoggingStatLogger.__init__(self, vllm_config, engine_index=-1)
def record(
self,
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
engine_idx: int = 0,
):
if engine_idx not in self.engine_idxs:
logger.warning("Unexpected engine_idx: %d", engine_idx)
return
LoggingStatLogger.record(self, scheduler_stats, iteration_stats,
engine_idx)
def log(self):
self.get_log_stats()
log_fn = logger.info
if self.engine_is_idle:
# Avoid log noise on an idle production system
log_fn = logger.debug
# Format and print output.
log_fn(
"%s Engines Aggregated: "
"Avg prompt throughput: %.1f tokens/s, "
"Avg generation throughput: %.1f tokens/s, "
"Running: %d reqs, Waiting: %d reqs, "
"GPU KV cache usage: %.1f%%, "
"Prefix cache hit rate: %.1f%%",
len(self.engine_idxs),
self.last_prompt_throughput,
self.last_generation_throughput,
self.last_scheduler_stats.num_running_reqs,
self.last_scheduler_stats.num_waiting_reqs,
self.last_scheduler_stats.kv_cache_usage * 100,
self.prefix_caching_metrics.hit_rate * 100,
)
self.spec_decoding_logging.log(log_fn=log_fn)
def log_engine_initialized(self):
if self.vllm_config.cache_config.num_gpu_blocks:
logger.info(
"%d Engines: vllm cache_config_info with initialization "
"after num_gpu_blocks is: %d", len(self.engine_idxs),
self.vllm_config.cache_config.num_gpu_blocks)
class PrometheusStatLogger(AggregatedStatLoggerBase):
class PrometheusStatLogger(StatLoggerBase):
_gauge_cls = prometheus_client.Gauge
_counter_cls = prometheus_client.Counter
_histogram_cls = prometheus_client.Histogram
@ -741,32 +674,23 @@ class StatLoggerManager:
# engine_idx: StatLogger
self.per_engine_logger_dict: dict[int, list[StatLoggerBase]] = {}
self.aggregated_loggers: list[AggregatedStatLoggerBase] = []
aggregated_loggers_factories = set()
prometheus_factory = PrometheusStatLogger
for engine_idx in self.engine_idxs:
loggers: list[StatLoggerBase] = []
for logger_factory in factories:
                # If we get a custom prometheus logger or aggregated logger,
                # we initialize it separately with all engine idxs.
                # A custom prometheus logger is typically used for the Ray case.
if (isinstance(logger_factory, type) and issubclass(
logger_factory, AggregatedStatLoggerBase)):
aggregated_loggers_factories.add(logger_factory)
else:
loggers.append(logger_factory(vllm_config,
engine_idx)) # type: ignore
# If we get a custom prometheus logger, use that
# instead. This is typically used for the ray case.
if (isinstance(logger_factory, type)
and issubclass(logger_factory, PrometheusStatLogger)):
prometheus_factory = logger_factory
continue
loggers.append(logger_factory(vllm_config,
engine_idx)) # type: ignore
self.per_engine_logger_dict[engine_idx] = loggers
        # If no custom aggregated logger is provided,
        # we default to PrometheusStatLogger.
if not aggregated_loggers_factories:
aggregated_loggers_factories.add(PrometheusStatLogger)
        # For a custom aggregated logger (or the default Prometheus logger),
        # we need to share the metrics between EngineCores.
# For Prometheus, need to share the metrics between EngineCores.
# Each EngineCore's metrics are expressed as a unique label.
for aggregated_loggers_factory in aggregated_loggers_factories:
self.aggregated_loggers.append(
aggregated_loggers_factory(vllm_config, engine_idxs))
self.prometheus_logger = prometheus_factory(vllm_config, engine_idxs)
def record(
self,
@ -780,19 +704,18 @@ class StatLoggerManager:
per_engine_loggers = self.per_engine_logger_dict[engine_idx]
for logger in per_engine_loggers:
logger.record(scheduler_stats, iteration_stats, engine_idx)
for logger in self.aggregated_loggers:
logger.record(scheduler_stats, iteration_stats, engine_idx)
self.prometheus_logger.record(scheduler_stats, iteration_stats,
engine_idx)
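        # Example dispatch (assumed call): record(scheduler_stats,
        # iteration_stats, engine_idx=1) forwards the stats to every
        # per-engine logger registered for engine 1 and to the single shared
        # Prometheus logger, which labels the metrics with that engine index.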
def log(self):
for per_engine_loggers in self.per_engine_logger_dict.values():
for logger in per_engine_loggers:
logger.log()
for logger in self.aggregated_loggers:
logger.log()
def log_engine_initialized(self):
for agg_logger in self.aggregated_loggers:
agg_logger.log_engine_initialized()
self.prometheus_logger.log_engine_initialized()
for per_engine_loggers in self.per_engine_logger_dict.values():
for logger in per_engine_loggers:
logger.log_engine_initialized()

View File

@ -42,6 +42,7 @@ from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
supports_eagle3,
supports_mrope,
supports_transcription)
from vllm.model_executor.models.interfaces_base import (
VllmModelForPooling, is_pooling_model, is_text_generation_model)
@ -730,16 +731,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if mm_input.get("use_audio_in_video") is True:
use_audio_in_video = True
req_state.mrope_positions, req_state.mrope_position_delta = \
MRotaryEmbedding.get_input_positions_tensor(
req_state.prompt_token_ids,
hf_config=self.model_config.hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
audio_feature_lengths=audio_feature_lengths,
use_audio_in_video=use_audio_in_video,
)
if supports_mrope(self.model):
req_state.mrope_positions, req_state.mrope_position_delta = \
self.model.get_mrope_input_positions(
req_state.prompt_token_ids,
hf_config=self.model_config.hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
audio_feature_lengths=audio_feature_lengths,
use_audio_in_video=use_audio_in_video,
)
else:
req_state.mrope_positions, req_state.mrope_position_delta = \
MRotaryEmbedding.get_input_positions_tensor(
req_state.prompt_token_ids,
hf_config=self.model_config.hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
audio_feature_lengths=audio_feature_lengths,
use_audio_in_video=use_audio_in_video,
)
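            # Hypothetical model-side sketch of this interface (signature
            # assumed from the calls above; not part of this diff): a model
            # that opts in via supports_mrope() provides, e.g.,
            #
            #   def get_mrope_input_positions(self, input_tokens, hf_config,
            #           image_grid_thw, video_grid_thw, second_per_grid_ts=None,
            #           audio_feature_lengths=None, use_audio_in_video=False,
            #           context_len=0, seq_len=None):
            #       ...  # returns (mrope_positions, mrope_position_delta)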
def _extract_mm_kwargs(
self,

View File

@ -41,7 +41,8 @@ from vllm.model_executor.layers.sampler import (Sampler, SamplerOutput,
get_sampler)
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.model_executor.models import (supports_lora, supports_mrope,
supports_multimodal)
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalKwargs, MultiModalPlaceholderMap,
@ -670,18 +671,33 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
inter_data.seq_ids[seq_idx]]
token_ids = seq_data.get_token_ids()
mrope_input_positions, mrope_position_delta = \
MRotaryEmbedding.get_input_positions(
token_ids,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
context_len=inter_data.context_lens[seq_idx],
seq_len=inter_data.seq_lens[seq_idx],
audio_feature_lengths=audio_feature_lengths,
use_audio_in_video=use_audio_in_video,
)
if supports_mrope(self.runner.model):
mrope_input_positions, mrope_position_delta = \
self.runner.model.get_mrope_input_positions(
token_ids,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
context_len=inter_data.context_lens[seq_idx],
seq_len=inter_data.seq_lens[seq_idx],
audio_feature_lengths=audio_feature_lengths,
use_audio_in_video=use_audio_in_video,
)
mrope_input_positions = mrope_input_positions.tolist()
else:
mrope_input_positions, mrope_position_delta = \
MRotaryEmbedding.get_input_positions(
token_ids,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
context_len=inter_data.context_lens[seq_idx],
seq_len=inter_data.seq_lens[seq_idx],
audio_feature_lengths=audio_feature_lengths,
use_audio_in_video=use_audio_in_video,
)
seq_data.mrope_position_delta = mrope_position_delta
inter_data.mrope_input_positions[