Compare commits

...

2 Commits

Author SHA1 Message Date
2e773e55b3 docs: merge v1 architecture with class hierarchy 2025-05-17 23:48:12 -07:00
9ab2c02ff8 Support sequence parallelism combined with pipeline parallelism (#18243)
Signed-off-by: cascade812 <cascade812@outlook.com>
2025-05-17 22:47:25 +00:00
4 changed files with 142 additions and 40 deletions

View File

@ -14,8 +14,14 @@ This document provides an overview of the vLLM architecture.
vLLM provides a number of entrypoints for interacting with the system. The
following diagram shows the relationship between them.
:::{image} /assets/design/arch_overview/entrypoints.excalidraw.png
:alt: Entrypoints Diagram
:::{mermaid}
flowchart TD
CLI["vllm CLI"] --> APIServer["OpenAI API Server"]
LLM["LLM Class"] --> LLMEngine
APIServer --> AsyncLLMEngine
LLMEngine --> EngineCoreClient
AsyncLLMEngine --> EngineCoreClient
EngineCoreClient --> EngineCore
:::
### LLM Class
@ -84,8 +90,14 @@ More details on the API server can be found in the [OpenAI-Compatible Server](#o
The `LLMEngine` and `AsyncLLMEngine` classes are central to the functioning of
the vLLM system, handling model inference and asynchronous request processing.
:::{image} /assets/design/arch_overview/llm_engine.excalidraw.png
:alt: LLMEngine Diagram
:::{mermaid}
flowchart LR
Processor --> EngineCoreClient
EngineCoreClient --> EngineCore
EngineCore --> Executor
Executor --> Worker
Worker --> ModelRunner
ModelRunner --> Model
:::
### LLMEngine
@ -104,7 +116,7 @@ processing.
- **Output Processing**: Processes the outputs generated by the model, decoding the
token IDs from a language model into human-readable text.
The code for `LLMEngine` can be found in <gh-file:vllm/engine/llm_engine.py>.
The code for `LLMEngine` can be found in <gh-file:vllm/v1/engine/llm_engine.py>.
### AsyncLLMEngine
@ -116,7 +128,7 @@ can handle multiple concurrent requests and stream outputs to clients.
The OpenAI-compatible API server uses the `AsyncLLMEngine`. There is also a demo
API server that serves as a simpler example in <gh-file:vllm/entrypoints/api_server.py>.
The code for `AsyncLLMEngine` can be found in <gh-file:vllm/engine/async_llm_engine.py>.
The code for `AsyncLLMEngine` can be found in <gh-file:vllm/v1/engine/async_llm.py>.
## Worker
@ -140,15 +152,29 @@ Every model runner object has one model object, which is the actual
`torch.nn.Module` instance. See [huggingface_integration](#huggingface-integration) for how various
configurations affect the class we ultimately get.
## Class Hierarchy
## Class Hierarchy and vLLM V1 Architecture
The following figure shows the class hierarchy of vLLM:
The following diagram shows how the main classes interact:
> :::{figure} /assets/design/hierarchy.png
> :align: center
> :alt: query
> :width: 100%
> :::
:::{mermaid}
classDiagram
class LLMEngine
class AsyncLLMEngine
class EngineCoreClient
class EngineCore
class Executor
class Worker
class ModelRunner
class Model
AsyncLLMEngine --> LLMEngine
LLMEngine --> EngineCoreClient
EngineCoreClient --> EngineCore
EngineCore --> Executor
Executor --> Worker
Worker --> ModelRunner
ModelRunner --> Model
:::
There are several important design choices behind this class hierarchy:
@ -250,3 +276,32 @@ big problem.
In summary, the complete config object `VllmConfig` can be treated as an
engine-level global state that is shared among all vLLM classes.
vLLM V1 introduces a streamlined engine that splits responsibilities between a thin frontend and a highly optimized backend. The design is centered on three core layers:
1. **Frontend (`LLMEngine` and `AsyncLLM`)** user-facing classes that handle tokenization, batching of incoming requests, and postprocessing of generated outputs. These classes interact with the engine core through an `EngineCoreClient`.
2. **Engine Core** the inner loop that schedules requests and runs the model. The core lives in `vllm/v1/engine/core.py` and exposes a lightweight API for adding requests, aborting them, or stepping the model.
3. **Executor and Workers** the executor (for example `MultiprocExecutor` in <gh-file:vllm/v1/executor/multiproc_executor.py>) manages worker processes. Each worker controls a single accelerator device and hosts a `ModelRunner` (such as `GPUModelRunner` in <gh-file:vllm/v1/worker/gpu_model_runner.py>) which executes the forward pass.
### EngineCore and Scheduler
The `EngineCore` maintains a [`Scheduler`](<gh-file:vllm/v1/core/sched/scheduler.py>) and a `KVCacheManager` (<gh-file:vllm/v1/core/kv_cache_manager.py>). At each iteration the scheduler chooses how many tokens to process for every active `Request`, supporting features like prefix caching, chunked prefill and speculative decoding. Scheduled tokens are passed to the model runner and the resulting `EngineCoreOutputs` include generated tokens and per-request events.
The scheduler keeps separate waiting and running queues and enforces limits from
`VllmConfig` such as `max_num_seqs` and `max_num_batched_tokens`. When GPU
memory becomes scarce it can preempt lower priority requests, freeing their KV
cache blocks before resuming them later. After a step finishes it records
statistics and updates each request's progress based on the returned events.
### Communication via EngineCoreClient
To overlap computation with I/O, the engine core often runs in a separate process. `EngineCoreClient` (<gh-file:vllm/v1/engine/core_client.py>) forwards requests and pulls results over ZeroMQ sockets. When using multiple data-parallel ranks, `DPAsyncMPClient` manages a set of engine-core processes and aggregates their outputs.
### Workers and Model Runners
Workers are defined in <gh-dir:vllm/v1/worker>. The default GPU worker initializes CUDA, sets up distributed communication and hosts a `GPUModelRunner` which loads the model, prepares KV cache memory and executes inference kernels. The runner also handles LoRA adapters, attention backends, and cudagraph capture.
### Output Processing
`OutputProcessor` (<gh-file:vllm/v1/engine/output_processor.py>) converts raw `EngineCoreOutputs` into `RequestOutput` objects, assembling logprobs, speculative tokens, and final texts. When using `AsyncLLM`, an asynchronous loop continuously fetches these outputs and streams them back to callers.
This new layering keeps the hot path (`EngineCore`) minimal while letting the frontend focus on user interactions and request bookkeeping. It reduces CPU overhead and simplifies the addition of new optimizations.

View File

@ -26,6 +26,7 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
class ParallelSetup(NamedTuple):
tp_size: int
pp_size: int
sp_enabled: bool
eager_mode: bool
chunked_prefill: bool
@ -60,6 +61,7 @@ class SPTestSettings:
def detailed(
*,
tp_base: int = 2,
pp_base: int = 1,
multi_node_only: bool = False,
task: TaskOption = "auto",
load_format: Optional[str] = None,
@ -67,18 +69,42 @@ class SPTestSettings:
return SPTestSettings(
parallel_setups=[
ParallelSetup(tp_size=tp_base,
pp_size=pp_base,
sp_enabled=True,
eager_mode=False,
chunked_prefill=False),
ParallelSetup(tp_size=tp_base,
pp_size=pp_base,
sp_enabled=True,
eager_mode=False,
chunked_prefill=True),
ParallelSetup(tp_size=tp_base,
pp_size=pp_base,
sp_enabled=True,
eager_mode=True,
chunked_prefill=False),
ParallelSetup(tp_size=tp_base,
pp_size=pp_base,
sp_enabled=True,
eager_mode=True,
chunked_prefill=True),
ParallelSetup(tp_size=tp_base,
pp_size=2 * pp_base,
sp_enabled=True,
eager_mode=False,
chunked_prefill=False),
ParallelSetup(tp_size=tp_base,
pp_size=2 * pp_base,
sp_enabled=True,
eager_mode=False,
chunked_prefill=True),
ParallelSetup(tp_size=tp_base,
pp_size=2 * pp_base,
sp_enabled=True,
eager_mode=True,
chunked_prefill=False),
ParallelSetup(tp_size=tp_base,
pp_size=2 * pp_base,
sp_enabled=True,
eager_mode=True,
chunked_prefill=True)
@ -94,6 +120,7 @@ class SPTestSettings:
def fast(
*,
tp_base: int = 2,
pp_base: int = 1,
task: TaskOption = "auto",
multi_node_only: bool = False,
load_format: Optional[str] = None,
@ -101,6 +128,12 @@ class SPTestSettings:
return SPTestSettings(
parallel_setups=[
ParallelSetup(tp_size=tp_base,
pp_size=pp_base,
sp_enabled=True,
eager_mode=False,
chunked_prefill=False),
ParallelSetup(tp_size=tp_base,
pp_size=2 * pp_base,
sp_enabled=True,
eager_mode=False,
chunked_prefill=False),
@ -136,6 +169,7 @@ def _compare_sp(
):
(
tp_size,
pp_size,
sp_enabled,
eager_mode,
chunked_prefill,
@ -167,7 +201,6 @@ def _compare_sp(
else:
model_info.check_available_online(on_fail="skip")
pp_size = 1
if num_gpus_available < tp_size * pp_size:
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
if VLLM_MULTI_NODE and distributed_backend == "mp":
@ -256,7 +289,7 @@ def _compare_sp(
SP_TEXT_GENERATION_MODELS = {
# [Decoder-only]
"meta-llama/Llama-3.2-1B-Instruct": SPTestSettings.detailed(),
"meta-llama/Llama-3.2-1B-Instruct": SPTestSettings.fast(),
}
SP_TEST_MODELS = [

View File

@ -4287,18 +4287,6 @@ class VllmConfig:
self.compilation_config.level = CompilationLevel.PIECEWISE
self.compilation_config.set_splitting_ops_for_v1()
if self.parallel_config is not None and \
self.parallel_config.tensor_parallel_size > 1 and \
self.parallel_config.pipeline_parallel_size > 1 and \
self.compilation_config is not None and \
self.compilation_config.pass_config is not None and \
self.compilation_config.pass_config.enable_sequence_parallelism:
logger.warning_once(
"Sequence parallelism is not supported with pipeline "
"parallelism. Disabling sequence parallelism.")
self.compilation_config.pass_config.\
enable_sequence_parallelism = False
self._set_cudagraph_sizes()
if self.cache_config is not None and \

View File

@ -1056,6 +1056,40 @@ class GPUModelRunner(LoRAModelRunnerMixin):
indices=out_indices,
)
def sync_and_slice_intermediate_tensors(
self, num_tokens: int, intermediate_tensors: IntermediateTensors,
sync_self: bool) -> IntermediateTensors:
assert self.intermediate_tensors is not None
tp = self.vllm_config.parallel_config.tensor_parallel_size
enabled_sp = self.vllm_config.compilation_config.pass_config. \
enable_sequence_parallelism
if enabled_sp:
# When sequence parallelism is enabled, we always pad num_tokens
# to be a multiple of tensor_parallel_size (tp) earlier
assert num_tokens % tp == 0
is_residual_scattered = tp > 1 and enabled_sp \
and num_tokens % tp == 0
# When sequence parallelism is enabled, the "residual" tensor is sharded
# across tensor parallel ranks, so each rank only needs its own slice.
if sync_self:
assert intermediate_tensors is not None
for k, v in intermediate_tensors.items():
is_scattered = "residual" and is_residual_scattered
copy_len = num_tokens // tp if is_scattered else \
num_tokens
self.intermediate_tensors[k][:copy_len].copy_(
v[:copy_len], non_blocking=True)
return IntermediateTensors({
k:
v[:num_tokens // tp]
if k == "residual" and is_residual_scattered else v[:num_tokens]
for k, v in self.intermediate_tensors.items()
})
@torch.inference_mode()
def execute_model(
self,
@ -1131,15 +1165,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if get_pp_group().is_first_rank:
intermediate_tensors = None
else:
assert intermediate_tensors is not None
assert self.intermediate_tensors is not None
for k, v in intermediate_tensors.items():
self.intermediate_tensors[k][:num_input_tokens].copy_(
v[:num_input_tokens], non_blocking=True)
intermediate_tensors = IntermediateTensors({
k: v[:num_input_tokens]
for k, v in self.intermediate_tensors.items()
})
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
num_input_tokens, intermediate_tensors, True)
# Run the decoder.
# Use persistent buffers for CUDA graphs.
@ -1658,10 +1685,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
batch_size=self.max_num_tokens,
dtype=self.model_config.dtype,
device=self.device))
intermediate_tensors = IntermediateTensors({
k: v[:num_tokens]
for k, v in self.intermediate_tensors.items()
})
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
num_tokens, None, False)
with set_forward_context(attn_metadata,
self.vllm_config,