This commit is contained in:
Woosuk Kwon
2025-10-30 16:30:06 -07:00
parent 110770170f
commit 09e4b2f6eb
13 changed files with 1001 additions and 502 deletions

View File

@ -10,6 +10,7 @@ torchaudio==2.9.0
# These must be updated alongside torch # These must be updated alongside torch
torchvision==0.24.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version torchvision==0.24.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
# Build from https://github.com/facebookresearch/xformers/releases/tag/v0.0.32.post1 # Build from https://github.com/facebookresearch/xformers/releases/tag/v0.0.32.post1
xformers==0.0.33+5d4b92a5.d20251029; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.9 # xformers==0.0.33+5d4b92a5.d20251029; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.9
# FlashInfer should be updated together with the Dockerfile # FlashInfer should be updated together with the Dockerfile
flashinfer-python==0.4.1 flashinfer-python==0.4.1
apache-tvm-ffi==0.1.0b15

View File

@ -34,6 +34,7 @@ else:
class NewRequestData: class NewRequestData:
req_id: str req_id: str
prompt_token_ids: list[int] | None prompt_token_ids: list[int] | None
prefill_token_ids: list[int] | None
mm_features: list[MultiModalFeatureSpec] mm_features: list[MultiModalFeatureSpec]
sampling_params: SamplingParams | None sampling_params: SamplingParams | None
pooling_params: PoolingParams | None pooling_params: PoolingParams | None
@ -51,6 +52,7 @@ class NewRequestData:
return cls( return cls(
req_id=request.request_id, req_id=request.request_id,
prompt_token_ids=request.prompt_token_ids, prompt_token_ids=request.prompt_token_ids,
prefill_token_ids=request._all_token_ids,
mm_features=request.mm_features, mm_features=request.mm_features,
sampling_params=request.sampling_params, sampling_params=request.sampling_params,
pooling_params=request.pooling_params, pooling_params=request.pooling_params,
@ -173,6 +175,7 @@ class SchedulerOutput:
# This can be used for cascade attention. # This can be used for cascade attention.
num_common_prefix_blocks: list[int] num_common_prefix_blocks: list[int]
preempted_req_ids: set[str]
# Request IDs that are finished in between the previous and the current # Request IDs that are finished in between the previous and the current
# steps. This is used to notify the workers about the finished requests # steps. This is used to notify the workers about the finished requests
# so that they can free the cached states for those requests. # so that they can free the cached states for those requests.

View File

@ -606,6 +606,9 @@ class Scheduler(SchedulerInterface):
) )
# Construct the scheduler output. # Construct the scheduler output.
scheduled_new_reqs = scheduled_new_reqs + scheduled_resumed_reqs
scheduled_resumed_reqs = []
new_reqs_data = [ new_reqs_data = [
NewRequestData.from_request( NewRequestData.from_request(
req, req_to_new_blocks[req.request_id].get_block_ids() req, req_to_new_blocks[req.request_id].get_block_ids()
@ -635,6 +638,7 @@ class Scheduler(SchedulerInterface):
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
scheduled_encoder_inputs=scheduled_encoder_inputs, scheduled_encoder_inputs=scheduled_encoder_inputs,
num_common_prefix_blocks=num_common_prefix_blocks, num_common_prefix_blocks=num_common_prefix_blocks,
preempted_req_ids={req.request_id for req in preempted_reqs},
# finished_req_ids is an existing state in the scheduler, # finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step. # instead of being newly scheduled in this step.
# It contains the request IDs that are finished in between # It contains the request IDs that are finished in between
@ -720,14 +724,6 @@ class Scheduler(SchedulerInterface):
req.num_computed_tokens : req.num_computed_tokens + num_tokens req.num_computed_tokens : req.num_computed_tokens + num_tokens
] ]
new_token_ids.append(token_ids) new_token_ids.append(token_ids)
scheduled_in_prev_step = req_id in self.prev_step_scheduled_req_ids
if idx >= num_running_reqs:
assert not scheduled_in_prev_step
resumed_req_ids.add(req_id)
if not scheduled_in_prev_step:
all_token_ids[req_id] = req.all_token_ids[
: req.num_computed_tokens + num_tokens
]
new_block_ids.append( new_block_ids.append(
req_to_new_blocks[req_id].get_block_ids(allow_none=True) req_to_new_blocks[req_id].get_block_ids(allow_none=True)
) )
@ -902,7 +898,6 @@ class Scheduler(SchedulerInterface):
model_runner_output: ModelRunnerOutput, model_runner_output: ModelRunnerOutput,
) -> dict[int, EngineCoreOutputs]: ) -> dict[int, EngineCoreOutputs]:
sampled_token_ids = model_runner_output.sampled_token_ids sampled_token_ids = model_runner_output.sampled_token_ids
num_sampled_tokens = model_runner_output.num_sampled_tokens
logprobs = model_runner_output.logprobs logprobs = model_runner_output.logprobs
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
num_scheduled_tokens = scheduler_output.num_scheduled_tokens num_scheduled_tokens = scheduler_output.num_scheduled_tokens

View File

@ -15,7 +15,6 @@ else:
class LogprobsLists(NamedTuple): class LogprobsLists(NamedTuple):
# [num_reqs, max_num_logprobs + 1] # [num_reqs, max_num_logprobs + 1]
logprob_token_ids: np.ndarray logprob_token_ids: np.ndarray
# [num_reqs, max_num_logprobs + 1] # [num_reqs, max_num_logprobs + 1]
@ -135,13 +134,14 @@ class KVConnectorOutput:
class ModelRunnerOutput: class ModelRunnerOutput:
# [num_reqs] # [num_reqs]
req_ids: list[str] req_ids: list[str]
# req_id -> index
req_id_to_index: dict[str, int]
# num_reqs x num_generated_tokens # num_reqs x num_generated_tokens
# num_generated_tokens is the number of tokens # num_generated_tokens is the number of tokens
# generated in the current step. It can be different for # generated in the current step. It can be different for
# each request due to speculative/jump decoding. # each request due to speculative/jump decoding.
sampled_token_ids: np.ndarray | None sampled_token_ids: list[list[int]]
num_sampled_tokens: np.ndarray | None
# [num_reqs, max_num_logprobs + 1] # [num_reqs, max_num_logprobs + 1]
# [num_reqs, max_num_logprobs + 1] # [num_reqs, max_num_logprobs + 1]
@ -186,8 +186,8 @@ class DraftTokenIds:
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput( EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=[], req_ids=[],
sampled_token_ids=None, req_id_to_index={},
num_sampled_tokens=None, sampled_token_ids=[],
logprobs=None, logprobs=None,
prompt_logprobs_dict={}, prompt_logprobs_dict={},
pooler_output=[], pooler_output=[],

View File

@ -1,21 +1,28 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
import numpy as np
import torch import torch
from vllm.v1.outputs import (AsyncModelRunnerOutput, LogprobsTensors, from vllm.v1.outputs import (
ModelRunnerOutput, SamplerOutput) AsyncModelRunnerOutput,
ModelRunnerOutput,
SamplerOutput,
)
class AsyncOutput(AsyncModelRunnerOutput): class AsyncOutput(AsyncModelRunnerOutput):
def __init__( def __init__(
self, self,
model_runner_output: ModelRunnerOutput, model_runner_output: ModelRunnerOutput,
sampler_output: SamplerOutput, sampler_output: SamplerOutput,
num_sampled_tokens: np.ndarray,
copy_stream: torch.cuda.Stream, copy_stream: torch.cuda.Stream,
): ):
self.model_runner_output = model_runner_output self.model_runner_output = model_runner_output
self.sampler_output = sampler_output self.sampler_output = sampler_output
self.num_sampled_tokens = num_sampled_tokens
self.copy_stream = copy_stream self.copy_stream = copy_stream
self.copy_event = torch.cuda.Event() self.copy_event = torch.cuda.Event()
@ -23,26 +30,46 @@ class AsyncOutput(AsyncModelRunnerOutput):
with torch.cuda.stream(self.copy_stream): with torch.cuda.stream(self.copy_stream):
self.copy_stream.wait_stream(default_stream) self.copy_stream.wait_stream(default_stream)
# NOTE(woosuk): We should keep the CPU tensors unfreed, until the copy completes.
self.sampled_token_ids = sampler_output.sampled_token_ids.to( self.sampled_token_ids = sampler_output.sampled_token_ids.to(
"cpu", non_blocking=True) "cpu", non_blocking=True
x = sampler_output.logprobs_tensors )
if x is not None: if sampler_output.logprobs_tensors is not None:
self.logprobs_tensors = LogprobsTensors( self.logprobs_tensors = (
logprob_token_ids=x.logprob_token_ids.to( sampler_output.logprobs_tensors.to_cpu_nonblocking()
"cpu", non_blocking=True),
logprobs=x.logprobs.to("cpu", non_blocking=True),
selected_token_ranks=x.selected_token_ranks.to(
"cpu", non_blocking=True),
) )
else: else:
self.logprobs_tensors = None self.logprobs_tensors = None
self.copy_event.record() self.prompt_logprobs_dict = {}
if self.model_runner_output.prompt_logprobs_dict:
for k, v in self.model_runner_output.prompt_logprobs_dict.items():
self.prompt_logprobs_dict[k] = v.to_cpu_nonblocking()
self.copy_event.record(self.copy_stream)
def get_output(self) -> ModelRunnerOutput: def get_output(self) -> ModelRunnerOutput:
self.copy_event.synchronize() self.copy_event.synchronize()
self.model_runner_output.sampled_token_ids = (
self.sampled_token_ids.numpy()) # NOTE(woosuk): The following code ensures compatibility with OSS vLLM.
# Going forward, we should keep the data structures as NumPy arrays
# rather than Python lists.
sampled_token_ids_np = self.sampled_token_ids.numpy()
sampled_token_ids = sampled_token_ids_np.tolist()
for i, tokens in enumerate(sampled_token_ids):
del tokens[self.num_sampled_tokens[i] :]
self.model_runner_output.sampled_token_ids = sampled_token_ids
if self.logprobs_tensors is not None: if self.logprobs_tensors is not None:
self.model_runner_output.logprobs = ( self.model_runner_output.logprobs = self.logprobs_tensors.tolists()
self.logprobs_tensors.tolists()) self.model_runner_output.prompt_logprobs_dict = self.prompt_logprobs_dict
return self.model_runner_output return self.model_runner_output
@contextmanager
def async_barrier(event: torch.cuda.Event | None):
if event is not None:
event.synchronize()
try:
yield
finally:
if event is not None:
event.record()

View File

@ -7,9 +7,17 @@ import torch
from vllm.attention.backends.abstract import AttentionBackend, AttentionType from vllm.attention.backends.abstract import AttentionBackend, AttentionType
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder from vllm.v1.attention.backends.utils import (
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, AttentionMetadataBuilder,
KVCacheSpec, SlidingWindowSpec) CommonAttentionMetadata,
)
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
KVCacheSpec,
SlidingWindowSpec,
)
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.utils import bind_kv_cache from vllm.v1.worker.utils import bind_kv_cache
@ -18,7 +26,6 @@ def get_kv_cache_spec(
kv_cache_dtype: torch.dtype, kv_cache_dtype: torch.dtype,
) -> dict[str, KVCacheSpec]: ) -> dict[str, KVCacheSpec]:
block_size = vllm_config.cache_config.block_size block_size = vllm_config.cache_config.block_size
use_mla = vllm_config.model_config.use_mla
kv_cache_spec: dict[str, KVCacheSpec] = {} kv_cache_spec: dict[str, KVCacheSpec] = {}
attn_layers = get_layers_from_vllm_config(vllm_config, Attention) attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
@ -31,7 +38,6 @@ def get_kv_cache_spec(
head_size=attn_module.head_size, head_size=attn_module.head_size,
dtype=kv_cache_dtype, dtype=kv_cache_dtype,
sliding_window=attn_module.sliding_window, sliding_window=attn_module.sliding_window,
use_mla=use_mla,
) )
else: else:
kv_cache_spec[layer_name] = FullAttentionSpec( kv_cache_spec[layer_name] = FullAttentionSpec(
@ -39,7 +45,6 @@ def get_kv_cache_spec(
num_kv_heads=attn_module.num_kv_heads, num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size, head_size=attn_module.head_size,
dtype=kv_cache_dtype, dtype=kv_cache_dtype,
use_mla=use_mla,
) )
return kv_cache_spec return kv_cache_spec
@ -52,6 +57,7 @@ def init_attn_backend(
attn_backends: dict[str, AttentionBackend] = {} attn_backends: dict[str, AttentionBackend] = {}
attn_metadata_builders: list[AttentionMetadataBuilder] = [] attn_metadata_builders: list[AttentionMetadataBuilder] = []
flashinfer_workspace: torch.Tensor | None = None
attn_layers = get_layers_from_vllm_config(vllm_config, Attention) attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
for kv_cache_group_spec in kv_cache_config.kv_cache_groups: for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
layer_names = kv_cache_group_spec.layer_names layer_names = kv_cache_group_spec.layer_names
@ -67,7 +73,13 @@ def init_attn_backend(
vllm_config, vllm_config,
device, device,
) )
attn_metadata_builders.append(attn_metadata_builder) attn_metadata_builders.append(attn_metadata_builder) # type: ignore
if "FLASHINFER" in attn_backend.get_name():
if flashinfer_workspace is None:
flashinfer_workspace = attn_metadata_builder.get_workspace_buffer()
else:
attn_metadata_builder.set_workspace_buffer(flashinfer_workspace)
return attn_backends, attn_metadata_builders return attn_backends, attn_metadata_builders
@ -77,9 +89,7 @@ def _allocate_kv_cache(
): ):
kv_cache_raw_tensors: dict[str, torch.Tensor] = {} kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
for kv_cache_tensor in kv_cache_config.kv_cache_tensors: for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
tensor = torch.zeros(kv_cache_tensor.size, tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=device)
dtype=torch.int8,
device=device)
for layer_name in kv_cache_tensor.shared_by: for layer_name in kv_cache_tensor.shared_by:
kv_cache_raw_tensors[layer_name] = tensor kv_cache_raw_tensors[layer_name] = tensor
@ -87,8 +97,9 @@ def _allocate_kv_cache(
for group in kv_cache_config.kv_cache_groups: for group in kv_cache_config.kv_cache_groups:
for layer_name in group.layer_names: for layer_name in group.layer_names:
layer_names.add(layer_name) layer_names.add(layer_name)
assert layer_names == set(kv_cache_raw_tensors.keys() assert layer_names == set(kv_cache_raw_tensors.keys()), (
), "Some layers are not correctly initialized" "Some layers are not correctly initialized"
)
return kv_cache_raw_tensors return kv_cache_raw_tensors
@ -103,17 +114,19 @@ def _reshape_kv_cache(
for layer_name in kv_cache_group_spec.layer_names: for layer_name in kv_cache_group_spec.layer_names:
raw_tensor = kv_cache_raw_tensors[layer_name] raw_tensor = kv_cache_raw_tensors[layer_name]
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
num_blocks = (raw_tensor.numel() // kv_cache_spec.page_size_bytes) num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
attn_backend = attn_backends[layer_name] attn_backend = attn_backends[layer_name]
kv_cache_shape = attn_backend.get_kv_cache_shape( kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size, num_blocks,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
)
dtype = kv_cache_spec.dtype dtype = kv_cache_spec.dtype
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
kv_cache_shape = tuple(kv_cache_shape[i] kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
for i in kv_cache_stride_order)
inv_order = [ inv_order = [
kv_cache_stride_order.index(i) kv_cache_stride_order.index(i)
@ -132,8 +145,56 @@ def init_kv_cache(
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
attn_backends: dict[str, AttentionBackend], attn_backends: dict[str, AttentionBackend],
device: torch.device, device: torch.device,
): ) -> None:
kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device) kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device)
kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors, kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors, attn_backends)
attn_backends)
bind_kv_cache(kv_caches, forward_context, runner_kv_caches) bind_kv_cache(kv_caches, forward_context, runner_kv_caches)
def build_attn_metadata(
    attn_metadata_builders: list[AttentionMetadataBuilder],
    num_reqs: int,
    num_tokens: int,
    query_start_loc: CpuGpuBuffer,
    seq_lens: CpuGpuBuffer,
    num_computed_tokens_cpu: torch.Tensor,
    block_tables: tuple[torch.Tensor, ...],
    slot_mappings: torch.Tensor,
    kv_cache_config: KVCacheConfig,
) -> dict[str, Any]:
    """Build per-layer attention metadata for one forward pass.

    For each KV-cache group, one ``CommonAttentionMetadata`` is assembled
    from the shared batch buffers and handed to that group's builder; the
    resulting metadata object is shared by every layer in the group.
    Returns a mapping from layer name to its group's metadata.
    """
    # Views over the active prefix of the batch buffers.
    num_bounds = num_reqs + 1
    qsl_gpu = query_start_loc.gpu[:num_bounds]
    qsl_cpu = query_start_loc.cpu[:num_bounds]
    max_query_len = int(query_start_loc.np[:num_bounds].max())

    lens_gpu = seq_lens.gpu[:num_reqs]
    lens_cpu = seq_lens.cpu[:num_reqs]
    max_seq_len = int(seq_lens.np[:num_reqs].max())

    attn_metadata: dict[str, Any] = {}
    for group_idx, group in enumerate(kv_cache_config.kv_cache_groups):
        common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=qsl_gpu,
            query_start_loc_cpu=qsl_cpu,
            seq_lens=lens_gpu,
            seq_lens_cpu=lens_cpu,
            max_seq_len=max_seq_len,
            num_computed_tokens_cpu=num_computed_tokens_cpu,
            num_reqs=num_reqs,
            num_actual_tokens=num_tokens,
            max_query_len=max_query_len,
            block_table_tensor=block_tables[group_idx],
            slot_mapping=slot_mappings[group_idx],
            causal=True,
        )
        metadata = attn_metadata_builders[group_idx].build(
            common_prefix_len=0,
            common_attn_metadata=common_attn_metadata,
        )
        # Every layer in the group shares the same metadata object.
        attn_metadata.update(dict.fromkeys(group.layer_names, metadata))
    return attn_metadata

View File

@ -6,14 +6,13 @@ import torch
import triton import triton
import triton.language as tl import triton.language as tl
from vllm.utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.utils import CpuGpuBuffer from vllm.v1.utils import CpuGpuBuffer
PAD_SLOT_ID = -1 PAD_SLOT_ID = -1
class BlockTables: class BlockTables:
def __init__( def __init__(
self, self,
block_sizes: list[int], block_sizes: list[int],
@ -50,44 +49,48 @@ class BlockTables:
self.input_block_tables: list[torch.Tensor] = [ self.input_block_tables: list[torch.Tensor] = [
torch.zeros_like(block_table) for block_table in self.block_tables torch.zeros_like(block_table) for block_table in self.block_tables
] ]
self.input_block_table_ptrs = self._make_ptr_tensor( self.input_block_table_ptrs = self._make_ptr_tensor(self.input_block_tables)
self.input_block_tables)
self.block_table_strides = torch.tensor( self.block_table_strides = torch.tensor(
[b.stride(0) for b in self.block_tables], [b.stride(0) for b in self.block_tables],
dtype=torch.int64, dtype=torch.int64,
device=self.device) device=self.device,
self.block_sizes_tensor = torch.tensor(self.block_sizes, )
dtype=torch.int32, self.block_sizes_tensor = torch.tensor(
device=self.device) self.block_sizes, dtype=torch.int32, device=self.device
self.num_blocks = torch.zeros(self.num_kv_cache_groups, )
self.max_num_reqs, self.num_blocks = torch.zeros(
dtype=torch.int32, self.num_kv_cache_groups,
device=self.device) self.max_num_reqs,
self.slot_mappings = torch.zeros(self.num_kv_cache_groups, dtype=torch.int32,
self.max_num_batched_tokens, device=self.device,
dtype=torch.int64, )
device=self.device) self.slot_mappings = torch.zeros(
self.num_kv_cache_groups,
self.max_num_batched_tokens,
dtype=torch.int64,
device=self.device,
)
# Misc buffers. # Misc buffers.
self.req_indices = self._make_buffer(self.max_num_reqs, self.req_indices = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
dtype=torch.int32)
self.overwrite = self._make_buffer(self.max_num_reqs, dtype=torch.bool) self.overwrite = self._make_buffer(self.max_num_reqs, dtype=torch.bool)
self.cu_num_new_blocks = self._make_buffer(self.num_kv_cache_groups, self.cu_num_new_blocks = self._make_buffer(
self.max_num_reqs + 1, self.num_kv_cache_groups, self.max_num_reqs + 1, dtype=torch.int32
dtype=torch.int32) )
def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer: def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(*args, return CpuGpuBuffer(
dtype=dtype, *args, dtype=dtype, pin_memory=self.pin_memory, device=self.device
pin_memory=self.pin_memory, )
device=self.device)
def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor: def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
ptrs_tensor_cpu = torch.tensor([t.data_ptr() for t in x], ptrs_tensor_cpu = torch.tensor(
dtype=torch.int64, [t.data_ptr() for t in x],
device="cpu", dtype=torch.int64,
pin_memory=self.pin_memory) device="cpu",
pin_memory=self.pin_memory,
)
return ptrs_tensor_cpu.to(self.device, non_blocking=True) return ptrs_tensor_cpu.to(self.device, non_blocking=True)
def append_block_ids( def append_block_ids(
@ -105,7 +108,7 @@ class BlockTables:
self.req_indices.np[:num_reqs] = req_indices self.req_indices.np[:num_reqs] = req_indices
self.overwrite.np[:num_reqs] = overwrite self.overwrite.np[:num_reqs] = overwrite
for i in range(self.num_kv_cache_groups): for i in range(self.num_kv_cache_groups):
self.cu_num_new_blocks.np[i, :num_reqs + 1] = cu_num_new_blocks[i] self.cu_num_new_blocks.np[i, : num_reqs + 1] = cu_num_new_blocks[i]
# NOTE(woosuk): Here, we cannot use a fixed-size buffer because there's # NOTE(woosuk): Here, we cannot use a fixed-size buffer because there's
# no clear upper bound to the number of new blocks in a single step. # no clear upper bound to the number of new blocks in a single step.
@ -120,9 +123,8 @@ class BlockTables:
) )
new_block_ids_np = self.new_block_ids_cpu.numpy() new_block_ids_np = self.new_block_ids_cpu.numpy()
for i in range(self.num_kv_cache_groups): for i in range(self.num_kv_cache_groups):
new_block_ids_np[i, :len(new_block_ids[i])] = new_block_ids[i] new_block_ids_np[i, : len(new_block_ids[i])] = new_block_ids[i]
new_block_ids_gpu = self.new_block_ids_cpu.to(self.device, new_block_ids_gpu = self.new_block_ids_cpu.to(self.device, non_blocking=True)
non_blocking=True)
_append_block_ids_kernel[(self.num_kv_cache_groups, num_reqs)]( _append_block_ids_kernel[(self.num_kv_cache_groups, num_reqs)](
self.req_indices.copy_to_gpu(num_reqs), self.req_indices.copy_to_gpu(num_reqs),
@ -135,7 +137,7 @@ class BlockTables:
self.block_table_ptrs, self.block_table_ptrs,
self.num_blocks, self.num_blocks,
self.num_blocks.stride(0), self.num_blocks.stride(0),
BLOCK_SIZE=1024, BLOCK_SIZE=1024, # type: ignore
) )
def gather_block_tables( def gather_block_tables(
@ -150,10 +152,9 @@ class BlockTables:
self.block_table_strides, self.block_table_strides,
self.num_blocks, self.num_blocks,
self.num_blocks.stride(0), self.num_blocks.stride(0),
BLOCK_SIZE=1024, BLOCK_SIZE=1024, # type: ignore
) )
return tuple(block_table[:num_reqs] return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)
for block_table in self.input_block_tables)
def compute_slot_mappings( def compute_slot_mappings(
self, self,
@ -174,7 +175,7 @@ class BlockTables:
self.slot_mappings, self.slot_mappings,
self.slot_mappings.stride(0), self.slot_mappings.stride(0),
PAD_ID=PAD_SLOT_ID, PAD_ID=PAD_SLOT_ID,
BLOCK_SIZE=1024, BLOCK_SIZE=1024, # type: ignore
) )
return self.slot_mappings[:, :num_tokens] return self.slot_mappings[:, :num_tokens]
@ -201,8 +202,7 @@ def _append_block_ids_kernel(
req_idx = tl.load(req_indices + batch_idx) req_idx = tl.load(req_indices + batch_idx)
do_overwrite = tl.load(overwrite + batch_idx) do_overwrite = tl.load(overwrite + batch_idx)
group_new_blocks_ptr = (cu_num_new_blocks_ptr + group_new_blocks_ptr = cu_num_new_blocks_ptr + group_id * cu_num_new_blocks_stride
group_id * cu_num_new_blocks_stride)
start_idx = tl.load(group_new_blocks_ptr + batch_idx) start_idx = tl.load(group_new_blocks_ptr + batch_idx)
end_idx = tl.load(group_new_blocks_ptr + batch_idx + 1) end_idx = tl.load(group_new_blocks_ptr + batch_idx + 1)
num_new_blocks = end_idx - start_idx num_new_blocks = end_idx - start_idx
@ -220,15 +220,15 @@ def _append_block_ids_kernel(
block_table_stride = tl.load(block_table_strides + group_id) block_table_stride = tl.load(block_table_strides + group_id)
row_ptr = block_table_ptr + req_idx * block_table_stride row_ptr = block_table_ptr + req_idx * block_table_stride
group_new_block_ids_ptr = (new_block_ids_ptr + group_new_block_ids_ptr = new_block_ids_ptr + group_id * new_block_ids_stride
group_id * new_block_ids_stride) for i in range(0, num_new_blocks, BLOCK_SIZE):
for i in tl.range(0, num_new_blocks, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE) offset = i + tl.arange(0, BLOCK_SIZE)
block_ids = tl.load(group_new_block_ids_ptr + start_idx + offset, block_ids = tl.load(
mask=offset < num_new_blocks) group_new_block_ids_ptr + start_idx + offset, mask=offset < num_new_blocks
tl.store(row_ptr + dst_start_idx + offset, )
block_ids, tl.store(
mask=offset < num_new_blocks) row_ptr + dst_start_idx + offset, block_ids, mask=offset < num_new_blocks
)
@triton.jit @triton.jit
@ -282,11 +282,9 @@ def _compute_slot_mappings_kernel(
if req_idx == tl.num_programs(1) - 1: if req_idx == tl.num_programs(1) - 1:
# Pad remaining slots to -1. This is needed for CUDA graphs. # Pad remaining slots to -1. This is needed for CUDA graphs.
for i in tl.range(num_tokens, max_num_tokens, BLOCK_SIZE): for i in range(num_tokens, max_num_tokens, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE) offset = i + tl.arange(0, BLOCK_SIZE)
tl.store(slot_mapping_ptr + offset, tl.store(slot_mapping_ptr + offset, PAD_ID, mask=offset < max_num_tokens)
PAD_ID,
mask=offset < max_num_tokens)
return return
block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32) block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
@ -295,12 +293,13 @@ def _compute_slot_mappings_kernel(
start_idx = tl.load(cu_num_tokens + req_idx) start_idx = tl.load(cu_num_tokens + req_idx)
end_idx = tl.load(cu_num_tokens + req_idx + 1) end_idx = tl.load(cu_num_tokens + req_idx + 1)
for i in tl.range(start_idx, end_idx, BLOCK_SIZE): for i in range(start_idx, end_idx, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE) offset = i + tl.arange(0, BLOCK_SIZE)
positions = tl.load(pos + offset, mask=offset < end_idx, other=0) positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
block_indices = positions // page_size block_indices = positions // page_size
block_numbers = tl.load(block_table_ptr + block_numbers = tl.load(
req_idx * block_table_stride + block_indices) block_table_ptr + req_idx * block_table_stride + block_indices
)
slot_ids = block_numbers * page_size + positions % page_size slot_ids = block_numbers * page_size + positions % page_size
tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx) tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)

View File

@ -0,0 +1,175 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
from contextlib import contextmanager
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
from vllm.forward_context import set_forward_context
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.input_batch import InputBuffers
class CudaGraphManager:
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
):
self.vllm_config = vllm_config
self.device = device
self.max_model_len = vllm_config.model_config.max_model_len
self.compilation_config = vllm_config.compilation_config
assert self.compilation_config is not None
self.cudagraph_sizes = sorted(self.compilation_config.cudagraph_capture_sizes)
self.padded_sizes = self._init_padded_sizes()
self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
self.pool = torch.cuda.graph_pool_handle()
self.hidden_states: torch.Tensor | None = None
def _init_padded_sizes(self) -> dict[int, int]:
if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
# CUDA graphs are disabled.
return {}
if self.compilation_config.cudagraph_mode.requires_piecewise_compilation():
raise NotImplementedError("Piecewise CUDA graphs are not supported")
if self.compilation_config.level != 0:
raise NotImplementedError("Dynamo is not used. Compilation level must be 0")
padded_sizes: dict[int, int] = {}
assert len(self.cudagraph_sizes) > 0
for i in range(1, self.cudagraph_sizes[-1] + 1):
for x in self.cudagraph_sizes:
if i <= x:
padded_sizes[i] = x
break
return padded_sizes
def needs_capture(self) -> bool:
return len(self.padded_sizes) > 0
def get_cudagraph_size(self, scheduler_output: SchedulerOutput) -> int | None:
if max(scheduler_output.num_scheduled_tokens.values()) > 1:
# Prefill is included.
return None
return self.padded_sizes.get(scheduler_output.total_num_scheduled_tokens)
def capture_graph(
self,
batch_size: int,
model: nn.Module,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig,
) -> None:
assert batch_size not in self.graphs
# Prepare dummy inputs.
input_ids = input_buffers.input_ids.gpu[:batch_size]
positions = input_buffers.positions.gpu[:batch_size]
input_buffers.query_start_loc.np[: batch_size + 1] = np.arange(batch_size + 1)
input_buffers.query_start_loc.np[batch_size:] = batch_size
input_buffers.query_start_loc.copy_to_gpu()
input_buffers.seq_lens.np[:batch_size] = self.max_model_len
input_buffers.seq_lens.np[batch_size:] = 0
input_buffers.seq_lens.copy_to_gpu()
input_block_tables = [x[:batch_size] for x in block_tables.input_block_tables]
slot_mappings = block_tables.slot_mappings[:, :batch_size]
attn_metadata = build_attn_metadata(
attn_metadata_builders=attn_metadata_builders,
num_reqs=batch_size,
num_tokens=batch_size,
query_start_loc=input_buffers.query_start_loc,
seq_lens=input_buffers.seq_lens,
num_computed_tokens_cpu=None, # FIXME
block_tables=input_block_tables,
slot_mappings=slot_mappings,
kv_cache_config=kv_cache_config,
)
# Warm up.
with set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=batch_size,
):
hidden_states = model(
input_ids=input_ids,
positions=positions,
)
if self.hidden_states is None:
self.hidden_states = torch.empty_like(hidden_states)
torch.cuda.synchronize()
# Capture the graph.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, self.pool):
with set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=batch_size,
):
hidden_states = model(
input_ids=input_ids,
positions=positions,
)
self.hidden_states[:batch_size] = hidden_states
self.graphs[batch_size] = graph
@torch.inference_mode()
def capture(
self,
model: nn.Module,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig,
) -> None:
assert self.needs_capture()
# Capture larger graphs first.
sizes_to_capture = sorted(self.cudagraph_sizes, reverse=True)
if is_global_first_rank():
sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")
with freeze_gc(), graph_capture(device=self.device):
for batch_size in sizes_to_capture:
self.capture_graph(
batch_size,
model,
input_buffers,
block_tables,
attn_metadata_builders,
kv_cache_config,
)
def run(self, batch_size: int) -> torch.Tensor:
    """Replay the graph captured for `batch_size`.

    Returns the first `batch_size` rows of the shared hidden-states
    buffer, which the captured graph writes into during replay.
    """
    assert batch_size in self.graphs
    captured_graph = self.graphs[batch_size]
    captured_graph.replay()
    return self.hidden_states[:batch_size]
class _GCFreezeGuard:
    """Context manager that freezes the garbage collector on entry.

    A full collection runs before freezing so that garbage is not moved
    into the permanent generation; the GC is unfrozen again on exit even
    if the wrapped block raises.
    """

    def __enter__(self) -> None:
        gc.collect()
        gc.freeze()

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        # Always unfreeze, even on exception; exceptions still propagate.
        gc.unfreeze()


def freeze_gc():
    """Freeze the GC for the duration of a `with` block."""
    return _GCFreezeGuard()

View File

@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any
@ -16,11 +15,12 @@ from vllm.v1.utils import CpuGpuBuffer
class InputBuffers: class InputBuffers:
def __init__( def __init__(
self, self,
max_num_reqs: int, max_num_reqs: int,
max_num_tokens: int, max_num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
device: torch.device, device: torch.device,
pin_memory: bool, pin_memory: bool,
): ):
@ -32,20 +32,17 @@ class InputBuffers:
self.idx_mapping = self._make_buffer(max_num_reqs, dtype=torch.int32) self.idx_mapping = self._make_buffer(max_num_reqs, dtype=torch.int32)
self.input_ids = self._make_buffer(max_num_tokens, dtype=torch.int32) self.input_ids = self._make_buffer(max_num_tokens, dtype=torch.int32)
self.positions = self._make_buffer(max_num_tokens, dtype=torch.int64) self.positions = self._make_buffer(max_num_tokens, dtype=torch.int64)
self.query_start_loc = self._make_buffer(max_num_reqs + 1, self.query_start_loc = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
dtype=torch.int32)
self.seq_lens = self._make_buffer(max_num_reqs, dtype=torch.int32) self.seq_lens = self._make_buffer(max_num_reqs, dtype=torch.int32)
def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer: def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(*args, return CpuGpuBuffer(
dtype=dtype, *args, dtype=dtype, pin_memory=self.pin_memory, device=self.device
pin_memory=self.pin_memory, )
device=self.device)
@dataclass @dataclass
class InputBatch: class InputBatch:
# batch_idx -> req_id # batch_idx -> req_id
req_ids: list[str] req_ids: list[str]
num_reqs: int num_reqs: int
@ -54,17 +51,23 @@ class InputBatch:
idx_mapping: torch.Tensor idx_mapping: torch.Tensor
idx_mapping_np: np.ndarray idx_mapping_np: np.ndarray
# [num_reqs]
# batch_idx -> num_scheduled_tokens # batch_idx -> num_scheduled_tokens
num_scheduled_tokens: np.ndarray num_scheduled_tokens: np.ndarray
# sum(num_scheduled_tokens) # sum(num_scheduled_tokens)
num_tokens: int num_tokens: int
num_tokens_after_padding: int num_tokens_after_padding: int
# [num_reqs]
is_chunked_prefilling: np.ndarray
# [max_num_batched_tokens] # [num_reqs + 1]
query_start_loc: torch.Tensor
query_start_loc_np: np.ndarray
# [num_reqs]
seq_lens: torch.Tensor
seq_lens_np: np.ndarray
# [num_tokens_after_padding]
input_ids: torch.Tensor input_ids: torch.Tensor
# [max_num_batched_tokens] # [num_tokens_after_padding]
positions: torch.Tensor positions: torch.Tensor
# layer_name -> Metadata # layer_name -> Metadata
@ -78,23 +81,34 @@ class InputBatch:
cls, cls,
num_reqs: int, num_reqs: int,
num_tokens: int, num_tokens: int,
input_buffers: InputBuffers,
device: torch.device, device: torch.device,
) -> "InputBatch": ) -> "InputBatch":
assert 0 < num_reqs <= num_tokens assert 0 < num_reqs <= num_tokens
req_ids = [f"req_{i}_{random_uuid()}" for i in range(num_reqs)] req_ids = [f"req_{i}_{random_uuid()}" for i in range(num_reqs)]
idx_mapping_np = np.arange(num_reqs, dtype=np.int32) idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
idx_mapping = torch.tensor(idx_mapping_np, device=device) idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
num_scheduled_tokens = np.full(num_reqs, num_scheduled_tokens = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
num_tokens // num_reqs,
dtype=np.int32)
num_scheduled_tokens[-1] += num_tokens % num_reqs num_scheduled_tokens[-1] += num_tokens % num_reqs
is_chunked_prefilling = np.zeros(num_reqs, dtype=np.bool_) assert int(num_scheduled_tokens.sum()) == num_tokens
input_ids = torch.zeros(num_tokens, dtype=torch.int32, device=device)
positions = torch.zeros(num_tokens, dtype=torch.int64, device=device) input_buffers.query_start_loc.np[0] = 0
attn_metadata = defaultdict(lambda: None) input_buffers.query_start_loc.np[1 : num_reqs + 1] = np.cumsum(
logits_indices = torch.arange(num_reqs, num_scheduled_tokens
dtype=torch.int32, )
device=device) input_buffers.query_start_loc.np[num_reqs + 1 :] = num_tokens
query_start_loc_np = input_buffers.query_start_loc.np[: num_reqs + 1]
query_start_loc = input_buffers.query_start_loc.copy_to_gpu()[: num_reqs + 1]
# seq_len equals to query_len
input_buffers.seq_lens.np[:num_reqs] = num_scheduled_tokens
input_buffers.seq_lens.np[num_reqs:] = 0
seq_lens_np = input_buffers.seq_lens.np[:num_reqs]
seq_lens = input_buffers.seq_lens.copy_to_gpu()[:num_reqs]
input_ids = input_buffers.input_ids.copy_to_gpu(num_tokens)
positions = input_buffers.positions.copy_to_gpu(num_tokens)
# attn_metadata = defaultdict(lambda: None)
logits_indices = query_start_loc[1:] - 1
return cls( return cls(
req_ids=req_ids, req_ids=req_ids,
num_reqs=num_reqs, num_reqs=num_reqs,
@ -103,10 +117,13 @@ class InputBatch:
num_scheduled_tokens=num_scheduled_tokens, num_scheduled_tokens=num_scheduled_tokens,
num_tokens=num_tokens, num_tokens=num_tokens,
num_tokens_after_padding=num_tokens, num_tokens_after_padding=num_tokens,
is_chunked_prefilling=is_chunked_prefilling, query_start_loc=query_start_loc,
query_start_loc_np=query_start_loc_np,
seq_lens=seq_lens,
seq_lens_np=seq_lens_np,
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
attn_metadata=attn_metadata, attn_metadata=None,
logits_indices=logits_indices, logits_indices=logits_indices,
) )
@ -130,14 +147,14 @@ class InputBatch:
cache=True, cache=True,
) )
def _prepare_inputs( def _prepare_inputs(
idx_mapping: np.ndarray, # batch_idx -> req_idx idx_mapping: np.ndarray, # batch_idx -> req_idx
token_ids: np.ndarray, # [N, max_model_len] token_ids: np.ndarray, # [N, max_model_len]
num_computed_tokens: np.ndarray, # [N] num_computed_tokens: np.ndarray, # [N]
num_scheduled_tokens: np.ndarray, # [B] num_scheduled_tokens: np.ndarray, # [B]
input_ids: np.ndarray, # [num_input_tokens] input_ids: np.ndarray, # [num_input_tokens]
positions: np.ndarray, # [num_input_tokens] positions: np.ndarray, # [num_input_tokens]
query_start_loc: np.ndarray, # [B + 1] query_start_loc: np.ndarray, # [B + 1]
seq_lens: np.ndarray, # [B] seq_lens: np.ndarray, # [B]
) -> None: ) -> None:
num_reqs = num_scheduled_tokens.shape[0] num_reqs = num_scheduled_tokens.shape[0]
query_start_loc[0] = 0 query_start_loc[0] = 0
@ -161,14 +178,14 @@ def _prepare_inputs(
# Pad the inputs for CUDA graphs. # Pad the inputs for CUDA graphs.
# Note: pad query_start_loc to be non-decreasing, as kernels # Note: pad query_start_loc to be non-decreasing, as kernels
# like FlashAttention requires that # like FlashAttention requires that
query_start_loc[num_reqs + 1:].fill(cu_num_tokens) query_start_loc[num_reqs + 1 :].fill(cu_num_tokens)
# Fill unused with 0 for full cuda graph mode. # Fill unused with 0 for full cuda graph mode.
seq_lens[num_reqs:].fill(0) seq_lens[num_reqs:].fill(0)
def prepare_inputs( def prepare_inputs(
idx_mapping: np.ndarray, idx_mapping: np.ndarray,
prompt_token_ids: np.ndarray, prefill_token_ids: np.ndarray,
num_computed_tokens: np.ndarray, num_computed_tokens: np.ndarray,
num_scheduled_tokens: np.ndarray, num_scheduled_tokens: np.ndarray,
input_ids: CpuGpuBuffer, input_ids: CpuGpuBuffer,
@ -176,10 +193,10 @@ def prepare_inputs(
query_start_loc: CpuGpuBuffer, query_start_loc: CpuGpuBuffer,
seq_lens: CpuGpuBuffer, seq_lens: CpuGpuBuffer,
num_tokens: int, num_tokens: int,
) -> tuple[np.ndarray, np.ndarray]: ) -> None:
_prepare_inputs( _prepare_inputs(
idx_mapping, idx_mapping,
prompt_token_ids, prefill_token_ids,
num_computed_tokens, num_computed_tokens,
num_scheduled_tokens, num_scheduled_tokens,
input_ids.np, input_ids.np,
@ -194,11 +211,7 @@ def prepare_inputs(
# for full CUDA graph mode. # for full CUDA graph mode.
query_start_loc.copy_to_gpu() query_start_loc.copy_to_gpu()
seq_lens.copy_to_gpu() seq_lens.copy_to_gpu()
return
num_reqs = num_scheduled_tokens.shape[0]
max_query_len = int(num_scheduled_tokens.max())
max_seq_len = int(seq_lens.np[:num_reqs].max())
return max_query_len, max_seq_len
@triton.jit @triton.jit
@ -208,21 +221,18 @@ def _combine_last_token_ids_kernel(
last_token_ids_ptr, last_token_ids_ptr,
query_start_loc_ptr, query_start_loc_ptr,
seq_lens_ptr, seq_lens_ptr,
num_tokens_ptr, prefill_len_ptr,
): ):
batch_idx = tl.program_id(0) batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx) req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
seq_len = tl.load(seq_lens_ptr + batch_idx) seq_len = tl.load(seq_lens_ptr + batch_idx)
num_tokens = tl.load(num_tokens_ptr + req_state_idx) prefill_len = tl.load(prefill_len_ptr + req_state_idx)
if seq_len < num_tokens: if seq_len <= prefill_len:
# Chunked prefilling. # Handling prefill tokens.
return return
last_token_id = tl.load(last_token_ids_ptr + req_state_idx) last_token_id = tl.load(last_token_ids_ptr + req_state_idx)
if last_token_id == -1:
return
end = tl.load(query_start_loc_ptr + batch_idx + 1) end = tl.load(query_start_loc_ptr + batch_idx + 1)
tl.store(input_ids_ptr + end - 1, last_token_id) tl.store(input_ids_ptr + end - 1, last_token_id)
@ -233,15 +243,15 @@ def combine_last_token_ids(
last_token_ids: torch.Tensor, last_token_ids: torch.Tensor,
query_start_loc: torch.Tensor, query_start_loc: torch.Tensor,
seq_lens: torch.Tensor, seq_lens: torch.Tensor,
num_tokens: torch.Tensor, prefill_len: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
num_reqs = seq_lens.shape[0] num_reqs = seq_lens.shape[0]
_combine_last_token_ids_kernel[(num_reqs, )]( _combine_last_token_ids_kernel[(num_reqs,)](
input_ids, input_ids,
idx_mapping, idx_mapping,
last_token_ids, last_token_ids,
query_start_loc, query_start_loc,
seq_lens, seq_lens,
num_tokens, prefill_len,
) )
return input_ids return input_ids

View File

@ -3,41 +3,52 @@
import gc import gc
import time import time
from copy import deepcopy from copy import deepcopy
from typing import Any, Optional from typing import Any
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import get_tp_group
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_loader from vllm.model_executor.model_loader import get_model_loader
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, from vllm.utils.mem_constants import GiB_bytes
GiB_bytes, is_pin_memory_available) from vllm.utils.mem_utils import DeviceMemoryProfiler
from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.utils.platform_utils import is_pin_memory_available
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput from vllm.v1.outputs import (
EMPTY_MODEL_RUNNER_OUTPUT,
LogprobsTensors,
ModelRunnerOutput,
)
from vllm.v1.sample.sampler import SamplerOutput from vllm.v1.sample.sampler import SamplerOutput
from vllm.v1.worker.gpu.async_utils import AsyncOutput from vllm.v1.worker.gpu.async_utils import AsyncOutput, async_barrier
from vllm.v1.worker.gpu.attn_utils import (get_kv_cache_spec, from vllm.v1.worker.gpu.attn_utils import (
init_attn_backend, init_kv_cache) build_attn_metadata,
get_kv_cache_spec,
init_attn_backend,
init_kv_cache,
)
from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.dist_utils import (all_gather_sampler_output, from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
evenly_split) from vllm.v1.worker.gpu.input_batch import (
from vllm.v1.worker.gpu.input_batch import (InputBatch, InputBuffers, InputBatch,
combine_last_token_ids, InputBuffers,
prepare_inputs) combine_last_token_ids,
from vllm.v1.worker.gpu.sampler import Sampler prepare_inputs,
)
from vllm.v1.worker.gpu.sampler import Sampler, compute_prompt_logprobs
from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
logger = init_logger(__name__) logger = init_logger(__name__)
class GPUModelRunner: class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def __init__( def __init__(
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
@ -61,17 +72,19 @@ class GPUModelRunner:
if self.cache_config.cache_dtype != "auto": if self.cache_config.cache_dtype != "auto":
# Quantized KV cache. # Quantized KV cache.
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
self.cache_config.cache_dtype] self.cache_config.cache_dtype
]
self.is_pooling_model = False self.is_pooling_model = False
self.vocab_size = self.model_config.get_vocab_size() self.vocab_size = self.model_config.get_vocab_size()
self.max_model_len = self.model_config.max_model_len self.max_model_len = self.model_config.max_model_len
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.max_num_reqs = self.scheduler_config.max_num_seqs self.max_num_reqs = self.scheduler_config.max_num_seqs
self.hidden_size = self.model_config.get_hidden_size()
self.use_async_scheduling = self.scheduler_config.async_scheduling self.use_async_scheduling = self.scheduler_config.async_scheduling
assert self.use_async_scheduling self.output_copy_stream = torch.cuda.Stream(self.device)
self.output_copy_stream = torch.cuda.Stream() self.input_prep_event = torch.cuda.Event()
self.req_states = RequestState( self.req_states = RequestState(
max_num_reqs=self.max_num_reqs, max_num_reqs=self.max_num_reqs,
@ -84,29 +97,46 @@ class GPUModelRunner:
self.input_buffers = InputBuffers( self.input_buffers = InputBuffers(
max_num_reqs=self.max_num_reqs, max_num_reqs=self.max_num_reqs,
max_num_tokens=self.max_num_tokens, max_num_tokens=self.max_num_tokens,
hidden_size=self.hidden_size,
dtype=self.dtype,
device=self.device, device=self.device,
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
) )
self.sampler = Sampler() self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
# CUDA graphs.
self.cudagraph_manager = CudaGraphManager(
vllm_config=self.vllm_config,
device=self.device,
)
def get_supported_tasks(self) -> tuple[str]: def get_supported_tasks(self) -> tuple[str]:
return ("generate", ) return ("generate",)
def load_model(self, *args, **kwargs) -> None: def load_model(self, *args, **kwargs) -> None:
time_before_load = time.perf_counter() time_before_load = time.perf_counter()
with DeviceMemoryProfiler() as m: with DeviceMemoryProfiler() as m:
model_loader = get_model_loader(self.vllm_config.load_config) model_loader = get_model_loader(self.vllm_config.load_config)
logger.info("Loading model from scratch...") logger.info("Loading model from scratch...")
self.model = model_loader.load_model( self.model = model_loader.load_model(
vllm_config=self.vllm_config, vllm_config=self.vllm_config,
model_config=self.vllm_config.model_config, model_config=self.vllm_config.model_config,
) )
if self.lora_config:
self.model = self.load_lora_model(
self.model,
self.vllm_config,
self.device,
)
time_after_load = time.perf_counter() time_after_load = time.perf_counter()
self.model_memory_usage = m.consumed_memory self.model_memory_usage = m.consumed_memory
logger.info("Model loading took %.4f GiB and %.6f seconds", logger.info(
m.consumed_memory / GiB_bytes, "Model loading took %.4f GiB and %.6f seconds",
time_after_load - time_before_load) m.consumed_memory / GiB_bytes,
time_after_load - time_before_load,
)
def get_model(self) -> nn.Module: def get_model(self) -> nn.Module:
return self.model return self.model
@ -143,32 +173,60 @@ class GPUModelRunner:
self.device, self.device,
) )
@torch.inference_mode()
def _dummy_run( def _dummy_run(
self, self,
num_tokens: int, num_tokens: int,
*args, *args,
input_batch: Optional[InputBatch] = None, input_batch: InputBatch | None = None,
skip_attn: bool = True,
**kwargs, **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
if input_batch is None: if input_batch is None:
num_reqs = min(num_tokens, self.max_num_reqs)
input_batch = InputBatch.make_dummy( input_batch = InputBatch.make_dummy(
num_reqs=min(num_tokens, self.max_num_reqs), num_reqs=num_reqs,
num_tokens=num_tokens, num_tokens=num_tokens,
input_buffers=self.input_buffers,
device=self.device, device=self.device,
) )
if not skip_attn:
block_tables = self.block_tables.gather_block_tables(
input_batch.idx_mapping
)
slot_mappings = self.block_tables.compute_slot_mappings(
input_batch.query_start_loc,
input_batch.positions,
)
attn_metadata = build_attn_metadata(
attn_metadata_builders=self.attn_metadata_builders,
num_reqs=num_reqs,
num_tokens=num_tokens,
query_start_loc=self.input_buffers.query_start_loc,
seq_lens=self.input_buffers.seq_lens,
num_computed_tokens_cpu=None,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
)
input_batch.attn_metadata = attn_metadata
with set_forward_context( with self.maybe_dummy_run_with_lora(
self.lora_config, input_batch.num_scheduled_tokens
):
with set_forward_context(
input_batch.attn_metadata, input_batch.attn_metadata,
self.vllm_config, self.vllm_config,
num_tokens=num_tokens, num_tokens=num_tokens,
): ):
hidden_states = self.model( hidden_states = self.model(
input_ids=input_batch.input_ids, input_ids=input_batch.input_ids,
positions=input_batch.positions, positions=input_batch.positions,
) )
sample_hidden_states = hidden_states[input_batch.logits_indices] sample_hidden_states = hidden_states[input_batch.logits_indices]
return hidden_states, sample_hidden_states return hidden_states, sample_hidden_states
@torch.inference_mode()
def _dummy_sampler_run( def _dummy_sampler_run(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -179,35 +237,80 @@ class GPUModelRunner:
device=self.device, device=self.device,
) )
logits = self.model.compute_logits(hidden_states) logits = self.model.compute_logits(hidden_states)
self.sampler(logits, sampling_metadata) self.sampler.sample(logits, sampling_metadata)
@torch.inference_mode()
def profile_run(self) -> None: def profile_run(self) -> None:
input_batch = InputBatch.make_dummy( input_batch = InputBatch.make_dummy(
num_reqs=self.max_num_reqs, num_reqs=self.max_num_reqs,
num_tokens=self.max_num_tokens, num_tokens=self.max_num_tokens,
input_buffers=self.input_buffers,
device=self.device, device=self.device,
) )
hidden_states, sample_hidden_states = self._dummy_run( hidden_states, sample_hidden_states = self._dummy_run(
self.max_num_tokens, self.max_num_tokens,
input_batch=input_batch, input_batch=input_batch,
skip_attn=True,
) )
self._dummy_sampler_run(sample_hidden_states) self._dummy_sampler_run(sample_hidden_states)
torch.cuda.synchronize() torch.cuda.synchronize()
del hidden_states, sample_hidden_states del hidden_states, sample_hidden_states
gc.collect() gc.collect()
def reset_mm_cache(self) -> None:
pass
@torch.inference_mode()
def capture_model(self) -> int:
if not self.cudagraph_manager.needs_capture():
logger.warning(
"Skipping CUDA graph capture. To turn on CUDA graph capture, "
"ensure `cudagraph_mode` was not manually set to `NONE`"
)
return 0
start_time = time.perf_counter()
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
with self.maybe_setup_dummy_loras(self.lora_config):
self.cudagraph_manager.capture(
model=self.model,
input_buffers=self.input_buffers,
block_tables=self.block_tables,
attn_metadata_builders=self.attn_metadata_builders,
kv_cache_config=self.kv_cache_config,
)
end_time = time.perf_counter()
end_free_gpu_memory = torch.cuda.mem_get_info()[0]
elapsed_time = end_time - start_time
cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
# This usually takes 5~20 seconds.
logger.info(
"Graph capturing finished in %.0f secs, took %.2f GiB",
elapsed_time,
cuda_graph_size / (1 << 30),
)
return cuda_graph_size
def warmup_for_prefill(self) -> None:
# For FlashInfer, we would like to execute a dummy prefill run to trigger JIT compilation.
if all("FLASHINFER" in b.get_name() for b in self.attn_backends.values()):
self._dummy_run(self.max_num_tokens, skip_attn=False)
torch.cuda.synchronize()
def update_states(self, scheduler_output: SchedulerOutput) -> None: def update_states(self, scheduler_output: SchedulerOutput) -> None:
# for req_id in scheduler_output.preempted_req_ids: for req_id in scheduler_output.preempted_req_ids:
# self.req_states.remove_request(req_id) self.req_states.remove_request(req_id)
for req_id in scheduler_output.finished_req_ids: for req_id in scheduler_output.finished_req_ids:
self.req_states.remove_request(req_id) self.req_states.remove_request(req_id)
# TODO(woosuk): Change SchedulerOutput. # TODO(woosuk): Change SchedulerOutput.
req_indices: list[int] = [] req_indices: list[int] = []
cu_num_new_blocks = tuple( cu_num_new_blocks = tuple(
[0] for _ in range(self.block_tables.num_kv_cache_groups)) [0] for _ in range(self.block_tables.num_kv_cache_groups)
new_block_ids = tuple( )
[] for _ in range(self.block_tables.num_kv_cache_groups)) new_block_ids = tuple([] for _ in range(self.block_tables.num_kv_cache_groups))
overwrite: list[bool] = [] overwrite: list[bool] = []
# Add new requests. # Add new requests.
@ -215,9 +318,11 @@ class GPUModelRunner:
req_id = new_req_data.req_id req_id = new_req_data.req_id
self.req_states.add_request( self.req_states.add_request(
req_id=req_id, req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids, prompt_len=len(new_req_data.prompt_token_ids),
prefill_token_ids=new_req_data.prefill_token_ids,
num_computed_tokens=new_req_data.num_computed_tokens, num_computed_tokens=new_req_data.num_computed_tokens,
sampling_params=new_req_data.sampling_params, sampling_params=new_req_data.sampling_params,
lora_request=new_req_data.lora_request,
) )
req_index = self.req_states.req_id_to_index[req_id] req_index = self.req_states.req_id_to_index[req_id]
@ -250,21 +355,30 @@ class GPUModelRunner:
overwrite=overwrite, overwrite=overwrite,
) )
def prepare_inputs(self, scheduler_output: SchedulerOutput) -> InputBatch: def prepare_inputs(
self,
scheduler_output: SchedulerOutput,
use_cudagraph: bool,
padded_num_tokens: int | None,
) -> InputBatch:
num_tokens = scheduler_output.total_num_scheduled_tokens num_tokens = scheduler_output.total_num_scheduled_tokens
assert num_tokens > 0 assert num_tokens > 0
num_reqs = len(scheduler_output.num_scheduled_tokens) num_reqs = len(scheduler_output.num_scheduled_tokens)
# Decode first, then prefill. # Decode first, then prefill.
# batch_idx -> req_id # batch_idx -> req_id
req_ids = sorted(scheduler_output.num_scheduled_tokens, req_ids = sorted(
key=scheduler_output.num_scheduled_tokens.get) scheduler_output.num_scheduled_tokens,
key=scheduler_output.num_scheduled_tokens.get,
)
num_scheduled_tokens = np.array( num_scheduled_tokens = np.array(
[scheduler_output.num_scheduled_tokens[i] for i in req_ids], [scheduler_output.num_scheduled_tokens[i] for i in req_ids], dtype=np.int32
dtype=np.int32) )
if use_cudagraph:
# TODO(woosuk): Support CUDA graphs. assert padded_num_tokens is not None
num_tokens_after_padding = num_tokens num_tokens_after_padding = padded_num_tokens
else:
num_tokens_after_padding = num_tokens
idx_mapping_list = [ idx_mapping_list = [
self.req_states.req_id_to_index[req_id] for req_id in req_ids self.req_states.req_id_to_index[req_id] for req_id in req_ids
@ -277,9 +391,9 @@ class GPUModelRunner:
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks] # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
block_tables = self.block_tables.gather_block_tables(idx_mapping) block_tables = self.block_tables.gather_block_tables(idx_mapping)
max_query_len, max_seq_len = prepare_inputs( prepare_inputs(
idx_mapping_np, idx_mapping_np,
self.req_states.prompt_token_ids, self.req_states.prefill_token_ids,
self.req_states.num_computed_tokens, self.req_states.num_computed_tokens,
num_scheduled_tokens, num_scheduled_tokens,
self.input_buffers.input_ids, self.input_buffers.input_ids,
@ -290,10 +404,9 @@ class GPUModelRunner:
) )
query_start_loc = self.input_buffers.query_start_loc query_start_loc = self.input_buffers.query_start_loc
query_start_loc_gpu = query_start_loc.gpu[:num_reqs + 1] query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
query_start_loc_cpu = query_start_loc.cpu[:num_reqs + 1] query_start_loc_np = query_start_loc.np[: num_reqs + 1]
seq_lens_gpu = self.input_buffers.seq_lens.gpu[:num_reqs] seq_lens_gpu = self.input_buffers.seq_lens.gpu[:num_reqs]
seq_lens_cpu = self.input_buffers.seq_lens.cpu[:num_reqs]
seq_lens_np = self.input_buffers.seq_lens.np[:num_reqs] seq_lens_np = self.input_buffers.seq_lens.np[:num_reqs]
# Some input token ids are directly read from the last sampled tokens. # Some input token ids are directly read from the last sampled tokens.
@ -303,56 +416,33 @@ class GPUModelRunner:
self.req_states.last_sampled_tokens, self.req_states.last_sampled_tokens,
query_start_loc_gpu, query_start_loc_gpu,
seq_lens_gpu, seq_lens_gpu,
self.req_states.num_tokens.copy_to_gpu(), self.req_states.prefill_len.copy_to_gpu(),
) )
# Compute slot mappings: [num_kv_cache_groups, num_tokens] # Compute slot mappings: [num_kv_cache_groups, num_tokens]
slot_mappings = self.block_tables.compute_slot_mappings( slot_mappings = self.block_tables.compute_slot_mappings(
query_start_loc_gpu, self.input_buffers.positions.gpu[:num_tokens]) query_start_loc_gpu, self.input_buffers.positions.gpu[:num_tokens]
)
num_computed_tokens_cpu = torch.from_numpy( num_computed_tokens_cpu = torch.from_numpy(
self.req_states.num_computed_tokens[idx_mapping_np]) self.req_states.num_computed_tokens[idx_mapping_np]
)
# Whether the request is chunked-prefilling or not.
is_chunked_prefilling = (
seq_lens_np < self.req_states.num_tokens.np[idx_mapping_np])
# Logits indices to sample next token from. # Logits indices to sample next token from.
logits_indices = query_start_loc_gpu[1:] - 1 logits_indices = query_start_loc_gpu[1:] - 1
num_logits_indices = logits_indices.size(0)
# Layer name -> attention metadata. # Layer name -> attention metadata.
attn_metadata: dict[str, Any] = {} attn_metadata = build_attn_metadata(
kv_cache_groups = self.kv_cache_config.kv_cache_groups attn_metadata_builders=self.attn_metadata_builders,
for i, kv_cache_spec in enumerate(kv_cache_groups): num_reqs=num_reqs,
block_table = block_tables[i] num_tokens=num_tokens,
slot_mapping = slot_mappings[i] query_start_loc=self.input_buffers.query_start_loc,
seq_lens=self.input_buffers.seq_lens,
common_attn_metadata = CommonAttentionMetadata( num_computed_tokens_cpu=num_computed_tokens_cpu,
query_start_loc=query_start_loc_gpu, block_tables=block_tables,
query_start_loc_cpu=query_start_loc_cpu, slot_mappings=slot_mappings,
seq_lens=seq_lens_gpu, kv_cache_config=self.kv_cache_config,
seq_lens_cpu=seq_lens_cpu, )
num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
max_seq_len=max_seq_len,
block_table_tensor=block_table,
slot_mapping=slot_mapping,
logits_indices_padded=None,
num_logits_indices=num_logits_indices,
causal=True,
encoder_seq_lens=None,
)
attn_metadata_builder = self.attn_metadata_builders[i]
metadata = attn_metadata_builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
for layer_name in kv_cache_spec.layer_names:
attn_metadata[layer_name] = metadata
input_ids = self.input_buffers.input_ids.gpu[:num_tokens_after_padding] input_ids = self.input_buffers.input_ids.gpu[:num_tokens_after_padding]
positions = self.input_buffers.positions.gpu[:num_tokens_after_padding] positions = self.input_buffers.positions.gpu[:num_tokens_after_padding]
@ -364,7 +454,10 @@ class GPUModelRunner:
num_scheduled_tokens=num_scheduled_tokens, num_scheduled_tokens=num_scheduled_tokens,
num_tokens=num_tokens, num_tokens=num_tokens,
num_tokens_after_padding=num_tokens_after_padding, num_tokens_after_padding=num_tokens_after_padding,
is_chunked_prefilling=is_chunked_prefilling, query_start_loc=query_start_loc_gpu,
query_start_loc_np=query_start_loc_np,
seq_lens=seq_lens_gpu,
seq_lens_np=seq_lens_np,
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
@ -375,102 +468,221 @@ class GPUModelRunner:
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_batch: InputBatch, input_batch: InputBatch,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput: ) -> SamplerOutput:
sample_hidden_states = hidden_states[input_batch.logits_indices] sample_hidden_states = hidden_states[input_batch.logits_indices]
logits = self.model.compute_logits(sample_hidden_states) logits = self.model.compute_logits(sample_hidden_states)
pos = input_batch.positions[input_batch.logits_indices] sampler_output = self.sampler.sample(logits, sampling_metadata)
idx_mapping_np = input_batch.idx_mapping_np
num_reqs = logits.shape[0]
# When the batch size is large enough, use DP sampler.
tp_group = get_tp_group()
tp_size = tp_group.world_size
n = (num_reqs + tp_size - 1) // tp_size
use_dp_sampler = tp_size > 1 and n > 32 # TODO(woosuk): Tune.
if use_dp_sampler:
# NOTE(woosuk): Make sure that no rank gets zero requests.
tp_rank = tp_group.rank
start, end = evenly_split(num_reqs, tp_size, tp_rank)
logits = logits[start:end]
pos = pos[start:end]
idx_mapping_np = idx_mapping_np[start:end]
sampling_metadata = self.req_states.make_sampling_metadata(
idx_mapping_np, pos)
sampler_output = self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
)
needs_prompt_logprobs = np.any(
self.req_states.needs_prompt_logprobs[idx_mapping_np])
assert not needs_prompt_logprobs
if use_dp_sampler:
# All-gather the outputs.
sampler_output = all_gather_sampler_output(
sampler_output,
num_reqs,
tp_size,
)
return sampler_output return sampler_output
def compute_prompt_logprobs(
self,
hidden_states: torch.Tensor,
input_batch: InputBatch,
) -> dict[str, LogprobsTensors]:
idx_mapping_np = input_batch.idx_mapping_np
needs_prompt_logprobs = self.req_states.needs_prompt_logprobs[idx_mapping_np]
if not np.any(needs_prompt_logprobs):
# No request asks for prompt logprobs.
return {}
num_computed_tokens = self.req_states.num_computed_tokens[idx_mapping_np]
prompt_lens = self.req_states.prompt_len[idx_mapping_np]
# NOTE(woosuk): -1 because the last prompt token's hidden state is not
# needed for prompt logprobs.
includes_prompt = num_computed_tokens < prompt_lens - 1
# NOTE(woosuk): If the request was resumed after preemption, its prompt
# logprobs must have been computed before preemption. Skip.
resumed_after_prompt = (
prompt_lens < self.req_states.prefill_len.np[idx_mapping_np]
)
needs_prompt_logprobs &= includes_prompt & ~resumed_after_prompt
if not np.any(needs_prompt_logprobs):
return {}
# Just to be safe, clone the input ids.
n = input_batch.num_tokens
# Shift the input ids by one.
token_ids = torch.empty_like(input_batch.input_ids[:n])
token_ids[: n - 1] = input_batch.input_ids[1:n]
# To avoid out-of-bound access, set the last token id to 0.
token_ids[n - 1] = 0
# Handle chunked prompts.
seq_lens = self.input_buffers.seq_lens.np[: input_batch.num_reqs]
is_prompt_chunked = seq_lens < prompt_lens
prefill_token_ids = self.req_states.prefill_token_ids
query_start_loc = self.input_buffers.query_start_loc.np
for i, req_id in enumerate(input_batch.req_ids):
if not needs_prompt_logprobs[i]:
continue
if not is_prompt_chunked[i]:
continue
# The prompt is chunked. Get the next prompt token.
req_idx = input_batch.idx_mapping_np[i]
next_prompt_token = int(prefill_token_ids[req_idx, seq_lens[i]])
idx = int(query_start_loc[i + 1] - 1)
# Set the next prompt token.
# NOTE(woosuk): This triggers a GPU operation.
token_ids[idx] = next_prompt_token
# NOTE(woosuk): We mask out logprobs for negative tokens.
prompt_logprobs, prompt_ranks = compute_prompt_logprobs(
torch.relu(token_ids),
hidden_states[:n],
self.model.compute_logits,
)
prompt_logprobs[:, 0].masked_fill_(token_ids < 0, 0)
prompt_token_ids = token_ids.unsqueeze(-1)
prompt_logprobs_dict: dict[str, LogprobsTensors] = {}
for i, req_id in enumerate(input_batch.req_ids):
if not needs_prompt_logprobs[i]:
continue
start_idx = query_start_loc[i]
end_idx = query_start_loc[i + 1]
assert start_idx < end_idx, (
f"start_idx ({start_idx}) >= end_idx ({end_idx})"
)
logprobs = LogprobsTensors(
logprob_token_ids=prompt_token_ids[start_idx:end_idx],
logprobs=prompt_logprobs[start_idx:end_idx],
selected_token_ranks=prompt_ranks[start_idx:end_idx],
)
req_extra_data = self.req_states.extra_data[req_id]
prompt_logprobs_list = req_extra_data.in_progress_prompt_logprobs
if is_prompt_chunked[i]:
# Prompt is chunked. Do not return the logprobs yet.
prompt_logprobs_list.append(logprobs)
continue
if prompt_logprobs_list:
# Merge the in-progress logprobs.
prompt_logprobs_list.append(logprobs)
logprobs = LogprobsTensors(
logprob_token_ids=torch.cat(
[x.logprob_token_ids for x in prompt_logprobs_list]
),
logprobs=torch.cat([x.logprobs for x in prompt_logprobs_list]),
selected_token_ranks=torch.cat(
[x.selected_token_ranks for x in prompt_logprobs_list]
),
)
prompt_logprobs_list.clear()
prompt_logprobs_dict[req_id] = logprobs
return prompt_logprobs_dict
def postprocess( def postprocess(
self, self,
sampler_output: SamplerOutput, sampler_output: SamplerOutput,
sampling_metadata: SamplingMetadata,
prompt_logprobs_dict: dict[str, LogprobsTensors],
input_batch: InputBatch, input_batch: InputBatch,
) -> AsyncOutput: ) -> AsyncOutput | ModelRunnerOutput:
# Store the last sampled token ids. # Store the last sampled token ids.
self.req_states.last_sampled_tokens[input_batch.idx_mapping] = ( self.req_states.last_sampled_tokens[input_batch.idx_mapping] = (
sampler_output.sampled_token_ids) sampler_output.sampled_token_ids
)
# Get the number of sampled tokens. # Get the number of sampled tokens.
# 0 if chunked-prefilling, 1 if not. # 0 if chunked-prefilling, 1 if not.
is_chunked_prefilling = input_batch.is_chunked_prefilling idx_mapping_np = input_batch.idx_mapping_np
is_chunked_prefilling = (
input_batch.seq_lens_np < self.req_states.num_tokens[idx_mapping_np]
)
num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32) num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32)
# Increment the number of tokens. # Increment the number of tokens.
idx_mapping_np = input_batch.idx_mapping_np self.req_states.num_tokens[idx_mapping_np] += num_sampled_tokens
self.req_states.num_tokens.np[idx_mapping_np] += num_sampled_tokens
# Increment the number of computed tokens. # Increment the number of computed tokens.
self.req_states.num_computed_tokens[idx_mapping_np] += ( self.req_states.num_computed_tokens[idx_mapping_np] += (
input_batch.num_scheduled_tokens) input_batch.num_scheduled_tokens
)
model_runner_output = ModelRunnerOutput( model_runner_output = ModelRunnerOutput(
req_ids=input_batch.req_ids, req_ids=input_batch.req_ids,
req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
sampled_token_ids=None, sampled_token_ids=None,
num_sampled_tokens=num_sampled_tokens,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}, prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[], pooler_output=[],
kv_connector_output=None, kv_connector_output=None,
num_nans_in_logits=None, num_nans_in_logits=None,
) )
return AsyncOutput( async_output = AsyncOutput(
model_runner_output=model_runner_output, model_runner_output=model_runner_output,
sampler_output=sampler_output, sampler_output=sampler_output,
num_sampled_tokens=num_sampled_tokens,
copy_stream=self.output_copy_stream, copy_stream=self.output_copy_stream,
) )
if self.use_async_scheduling:
return async_output
return async_output.get_output()
@torch.inference_mode()
def execute_model( def execute_model(
self, self,
scheduler_output: SchedulerOutput, scheduler_output: SchedulerOutput,
) -> AsyncOutput: intermediate_tensors: Any | None = None,
self.update_states(scheduler_output) ) -> AsyncOutput | ModelRunnerOutput:
if scheduler_output.total_num_scheduled_tokens == 0: assert intermediate_tensors is None
return EMPTY_MODEL_RUNNER_OUTPUT
input_batch = self.prepare_inputs(scheduler_output) with async_barrier(
num_tokens = input_batch.num_tokens_after_padding self.input_prep_event if self.use_async_scheduling else None
with set_forward_context(
input_batch.attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
): ):
hidden_states = self.model( self.update_states(scheduler_output)
input_ids=input_batch.input_ids, if scheduler_output.total_num_scheduled_tokens == 0:
positions=input_batch.positions, return EMPTY_MODEL_RUNNER_OUTPUT
padded_num_tokens = self.cudagraph_manager.get_cudagraph_size(
scheduler_output
)
use_cudagraph = padded_num_tokens is not None
input_batch = self.prepare_inputs(
scheduler_output,
use_cudagraph,
padded_num_tokens,
)
pos = input_batch.positions[input_batch.logits_indices]
idx_mapping_np = input_batch.idx_mapping_np
sampling_metadata = self.req_states.make_sampling_metadata(
idx_mapping_np, pos
) )
sampler_output = self.sample(hidden_states, input_batch) if self.lora_config:
return self.postprocess(sampler_output, input_batch) # Activate LoRA adapters.
lora_inputs = self.req_states.make_lora_inputs(
input_batch.req_ids,
input_batch.idx_mapping_np,
input_batch.num_scheduled_tokens,
)
self._set_active_loras(*lora_inputs)
# Run model.
if use_cudagraph:
# Run CUDA graph.
# NOTE(woosuk): Here, we don't need to pass the input tensors,
# because they are already copied to the CUDA graph input buffers.
hidden_states = self.cudagraph_manager.run(padded_num_tokens)
else:
with set_forward_context(
input_batch.attn_metadata,
self.vllm_config,
num_tokens=input_batch.num_tokens_after_padding,
):
# Run PyTorch model in eager mode.
hidden_states = self.model(
input_ids=input_batch.input_ids,
positions=input_batch.positions,
)
sampler_output = self.sample(hidden_states, input_batch, sampling_metadata)
prompt_logprobs_dict = self.compute_prompt_logprobs(hidden_states, input_batch)
output = self.postprocess(
sampler_output,
sampling_metadata,
prompt_logprobs_dict,
input_batch,
)
return output

View File

@ -1,61 +1,76 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch import torch
import torch.nn as nn
import triton import triton
import triton.language as tl import triton.language as tl
from vllm.config import LogprobsMode from vllm.config.model import LogprobsMode
from vllm.v1.outputs import LogprobsTensors, SamplerOutput from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.worker.gpu.states import SamplingMetadata
_SAMPLING_EPS = 1e-5
class Sampler(nn.Module): class Sampler:
def __init__( def __init__(
self, self,
logprobs_mode: LogprobsMode = "processed_logprobs", logprobs_mode: LogprobsMode = "raw_logprobs",
): ):
super().__init__() if logprobs_mode not in ["processed_logprobs", "raw_logprobs"]:
assert logprobs_mode == "processed_logprobs" raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}")
self.logprobs_mode = logprobs_mode self.logprobs_mode = logprobs_mode
def forward( def sample_token(
self, self,
logits: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> SamplerOutput: return_logits: bool = False,
# Divide logits by temperature, in FP32. ) -> tuple[torch.Tensor, torch.Tensor | None]:
logits = apply_temperature(logits, sampling_metadata.temperature) is_greedy = sampling_metadata.temperature == 0
temp = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
# Apply top_k and/or top_p. logits = logits / temp.view(-1, 1)
logits = apply_top_k_top_p( logits = apply_top_k_top_p(
logits, logits, sampling_metadata.top_k, sampling_metadata.top_p
sampling_metadata.top_k,
sampling_metadata.top_p,
) )
# Compute the probabilities.
probs = torch.softmax(logits, dim=-1, dtype=torch.float32) probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
# Sample the next token (int64).
sampled = gumbel_sample( sampled = gumbel_sample(
probs, probs,
sampling_metadata.temperature, sampling_metadata.temperature,
sampling_metadata.seeds, sampling_metadata.seeds,
sampling_metadata.pos, sampling_metadata.pos,
) )
sampled = sampled.to(torch.int64)
return sampled, logits if return_logits else None
logprobs_tensors = None def sample(
num_logprobs = sampling_metadata.max_num_logprobs self,
if num_logprobs is not None: logits: torch.Tensor,
logprobs_tensors = compute_logprobs( sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
if sampling_metadata.max_num_logprobs is not None:
if self.logprobs_mode == "processed_logprobs":
sampled, logits = self.sample_token(
logits, sampling_metadata, return_logits=True
)
else:
assert self.logprobs_mode == "raw_logprobs"
sampled, _ = self.sample_token(
logits, sampling_metadata, return_logits=False
)
logprobs_tensors = compute_topk_logprobs(
logits, logits,
num_logprobs, sampling_metadata.max_num_logprobs,
sampled, sampled,
) )
else:
sampled, _ = self.sample_token(
logits, sampling_metadata, return_logits=False
)
logprobs_tensors = None
# These are GPU tensors. # These are GPU tensors.
sampler_output = SamplerOutput( sampler_output = SamplerOutput(
@ -69,60 +84,7 @@ class Sampler(nn.Module):
@triton.jit @triton.jit
def _apply_temp_kernel( def _gumbel_sample_kernel(
logits, # bf16[batch_size, vocab_size]
logits_stride,
output, # fp32[batch_size, vocab_size]
output_stride,
temperature,
vocab_size,
BLOCK_SIZE: tl.constexpr,
EPSILON: tl.constexpr,
):
batch_idx = tl.program_id(0)
block_idx = tl.program_id(1)
temp = tl.load(temperature + batch_idx)
if temp < EPSILON:
# Greedy sampling. Don't apply temperature.
# NOTE(woosuk): In this case, we assume that its logprobs are not used.
temp = 1.0
offset = tl.arange(0, BLOCK_SIZE)
block = block_idx * BLOCK_SIZE + offset
# Load the logits.
x = tl.load(logits + batch_idx * logits_stride + block,
mask=block < vocab_size)
x = x.to(tl.float32)
x = x / temp
tl.store(output + batch_idx * output_stride + block,
x,
mask=block < vocab_size)
def apply_temperature(
logits: torch.Tensor,
temperature: torch.Tensor,
) -> torch.Tensor:
batch_size, vocab_size = logits.shape
output = torch.empty_like(logits, dtype=torch.float32)
BLOCK_SIZE = 8192
_apply_temp_kernel[(batch_size, triton.cdiv(vocab_size, BLOCK_SIZE))](
logits,
logits.stride(0),
output,
output.stride(0),
temperature,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
EPSILON=_SAMPLING_EPS,
)
return output
@triton.jit
def _apply_gumbel_kernel(
probs_ptr, probs_ptr,
probs_stride, probs_stride,
seeds_ptr, seeds_ptr,
@ -130,18 +92,17 @@ def _apply_gumbel_kernel(
temp_ptr, temp_ptr,
vocab_size, vocab_size,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
EPSILON: tl.constexpr,
): ):
req_idx = tl.program_id(0) req_idx = tl.program_id(0)
temp = tl.load(temp_ptr + req_idx) temp = tl.load(temp_ptr + req_idx)
if temp < EPSILON: if temp == 0.0:
# Greedy sampling. Don't apply gumbel noise. # Greedy sampling. Don't apply gumbel noise.
return return
seed = tl.load(seeds_ptr + req_idx).to(tl.uint64) seed = tl.load(seeds_ptr + req_idx)
pos = tl.load(pos_ptr + req_idx).to(tl.uint64) pos = tl.load(pos_ptr + req_idx)
gumbel_seed = seed ^ (pos * 0x9E3779B97F4A7C15) gumbel_seed = tl.randint(seed, pos)
block_id = tl.program_id(1) block_id = tl.program_id(1)
r_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) r_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
@ -153,42 +114,33 @@ def _apply_gumbel_kernel(
q = tl.where(q >= RMAX, RMAX_LOG, tl.math.log(q)) q = tl.where(q >= RMAX, RMAX_LOG, tl.math.log(q))
q = -1.0 * q q = -1.0 * q
p = tl.load(probs_ptr + req_idx * probs_stride + r_offset, p = tl.load(
mask=r_offset < vocab_size) probs_ptr + req_idx * probs_stride + r_offset, mask=r_offset < vocab_size
)
p = p / q p = p / q
tl.store(
tl.store(probs_ptr + req_idx * probs_stride + r_offset, probs_ptr + req_idx * probs_stride + r_offset, p, mask=r_offset < vocab_size
p, )
mask=r_offset < vocab_size)
def gumbel_sample( def gumbel_sample(
# fp32[num_reqs, vocab_size] probs: torch.Tensor, # [num_reqs, vocab_size]
probs: torch.Tensor, temperature: torch.Tensor, # [num_reqs]
# fp32[num_reqs] seed: torch.Tensor, # [num_reqs]
temperature: torch.Tensor, pos: torch.Tensor, # [num_reqs]
# int64[num_reqs]
seeds: torch.Tensor,
# int64[num_reqs]
pos: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
num_reqs = probs.shape[0] num_reqs, vocab_size = probs.shape
vocab_size = probs.shape[1] _gumbel_sample_kernel[(num_reqs,)](
# Update the probs in-place.
BLOCK_SIZE = 8192
_apply_gumbel_kernel[(num_reqs, triton.cdiv(vocab_size, BLOCK_SIZE))](
probs, probs,
probs.stride(0), probs.stride(0),
seeds, seed,
pos, pos,
temperature, temperature,
vocab_size, vocab_size,
BLOCK_SIZE, BLOCK_SIZE=8192, # type: ignore
EPSILON=_SAMPLING_EPS,
) )
# Sample the next token. sampled = probs.argmax(dim=-1)
return probs.argmax(dim=-1).view(-1) return sampled
@triton.jit @triton.jit
@ -208,54 +160,31 @@ def _topk_log_softmax_kernel(
max_val = float("-inf") max_val = float("-inf")
for i in range(0, vocab_size, BLOCK_SIZE): for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE) block = i + tl.arange(0, BLOCK_SIZE)
l = tl.load(row_ptr + block, logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
mask=block < vocab_size, max_val = tl.max(tl.maximum(logits, max_val))
other=float("-inf")) max_val = max_val.to(tl.float32)
max_val = tl.max(tl.maximum(l, max_val))
se = 0.0 se = 0.0
for i in range(0, vocab_size, BLOCK_SIZE): for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE) block = i + tl.arange(0, BLOCK_SIZE)
l = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0) logits = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0)
e = tl.exp(l - max_val) # NOTE(woosuk): Make sure that logits and all following operations are in float32.
logits = logits.to(tl.float32)
e = tl.exp(logits - max_val)
e = tl.where(block < vocab_size, e, 0.0) e = tl.where(block < vocab_size, e, 0.0)
se += tl.sum(e) se += tl.sum(e)
lse = tl.log(se) lse = tl.log(se)
k_offset = tl.arange(0, PADDED_TOPK) k_offset = tl.arange(0, PADDED_TOPK)
k_mask = k_offset < topk k_mask = k_offset < topk
topk_ids = tl.load(topk_ids_ptr + req_idx * topk + k_offset, mask=k_mask) topk_ids = tl.load(topk_ids_ptr + req_idx * topk + k_offset, mask=k_mask, other=0)
l = tl.load(row_ptr + topk_ids, mask=k_mask) logits = tl.load(row_ptr + topk_ids, mask=k_mask)
o = l - max_val - lse logits = logits.to(tl.float32)
o = logits - max_val - lse
tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask) tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)
def compute_topk_logprobs(
logits: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor:
batch_size, vocab_size = logits.shape
topk = topk_ids.shape[1]
output = torch.empty(
batch_size,
topk,
dtype=torch.float32,
device=logits.device,
)
_topk_log_softmax_kernel[(batch_size, )](
output,
logits,
logits.stride(0),
topk_ids,
topk,
vocab_size,
BLOCK_SIZE=1024,
PADDED_TOPK=triton.next_power_of_2(topk),
)
return output
@triton.jit @triton.jit
def _ranks_kernel( def _ranks_kernel(
output_ptr, output_ptr,
@ -274,14 +203,39 @@ def _ranks_kernel(
n = 0 n = 0
for i in range(0, vocab_size, BLOCK_SIZE): for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE) block = i + tl.arange(0, BLOCK_SIZE)
l = tl.load(row_ptr + block, logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
mask=block < vocab_size, n += tl.sum((logits > x).to(tl.int32))
other=float("-inf"))
n += tl.sum((l > x).to(tl.int32))
tl.store(output_ptr + req_idx, n) tl.store(output_ptr + req_idx, n)
def compute_logprobs( def compute_token_logprobs(
logits: torch.Tensor,
token_ids: torch.Tensor,
) -> torch.Tensor:
batch_size = logits.shape[0]
vocab_size = logits.shape[1]
token_ids = token_ids.to(torch.int64)
num_logprobs = token_ids.shape[1]
logprobs = torch.empty(
batch_size,
num_logprobs,
dtype=torch.float32,
device=logits.device,
)
_topk_log_softmax_kernel[(batch_size,)](
logprobs,
logits,
logits.stride(0),
token_ids,
num_logprobs,
vocab_size,
BLOCK_SIZE=1024, # type: ignore
PADDED_TOPK=triton.next_power_of_2(num_logprobs),
)
return logprobs
def compute_topk_logprobs(
logits: torch.Tensor, logits: torch.Tensor,
num_logprobs: int, num_logprobs: int,
sampled_token_ids: torch.Tensor, sampled_token_ids: torch.Tensor,
@ -293,31 +247,56 @@ def compute_logprobs(
else: else:
topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
logprob_token_ids = torch.cat( logprob_token_ids = torch.cat(
(sampled_token_ids.unsqueeze(-1), topk_indices), dim=1) (sampled_token_ids.unsqueeze(-1), topk_indices), dim=1
)
# NOTE(woosuk): Here, to save GPU memory, we do not materialize the full # NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
# logprobs tensor. Instead, we only compute and return the logprobs of # logprobs tensor. Instead, we only compute and return the logprobs of
# the topk + 1 tokens. # the topk + 1 tokens.
logprobs = compute_topk_logprobs( logprobs = compute_token_logprobs(logits, logprob_token_ids)
logits,
logprob_token_ids,
)
token_ranks = torch.empty( token_ranks = torch.empty(
batch_size, batch_size,
dtype=torch.int64, dtype=torch.int64,
device=logits.device, device=logits.device,
) )
_ranks_kernel[(batch_size, )]( _ranks_kernel[(batch_size,)](
token_ranks, token_ranks,
logits, logits,
logits.stride(0), logits.stride(0),
sampled_token_ids, sampled_token_ids,
vocab_size, vocab_size,
BLOCK_SIZE=8192, BLOCK_SIZE=8192, # type: ignore
) )
return LogprobsTensors( return LogprobsTensors(
logprob_token_ids=logprob_token_ids, logprob_token_ids=logprob_token_ids,
logprobs=logprobs, logprobs=logprobs,
selected_token_ranks=token_ranks, selected_token_ranks=token_ranks,
) )
def compute_prompt_logprobs(
prompt_token_ids: torch.Tensor,
prompt_hidden_states: torch.Tensor,
logits_fn: Callable[[torch.Tensor], torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
# Since materializing the full prompt logits can take too much memory,
# we compute it in chunks.
CHUNK_SIZE = 1024
logprobs = []
ranks = []
prompt_token_ids = prompt_token_ids.to(torch.int64)
for start_idx in range(0, prompt_token_ids.shape[0], CHUNK_SIZE):
end_idx = start_idx + CHUNK_SIZE
# NOTE(woosuk): logits_fn can be slow because it involves all-gather.
prompt_logits = logits_fn(prompt_hidden_states[start_idx:end_idx])
prompt_logprobs = compute_topk_logprobs(
prompt_logits,
0, # num_logprobs
prompt_token_ids[start_idx:end_idx],
)
logprobs.append(prompt_logprobs.logprobs)
ranks.append(prompt_logprobs.selected_token_ranks)
logprobs = torch.cat(logprobs, dim=0) if len(logprobs) > 1 else logprobs[0]
ranks = torch.cat(ranks, dim=0) if len(ranks) > 1 else ranks[0]
return logprobs, ranks

View File

@ -1,21 +1,22 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Optional
import numpy as np import numpy as np
import torch import torch
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.utils import CpuGpuBuffer from vllm.v1.utils import CpuGpuBuffer
_NP_INT64_MIN = np.iinfo(np.int64).min _NP_INT64_MIN = np.iinfo(np.int64).min
_NP_INT64_MAX = np.iinfo(np.int64).max _NP_INT64_MAX = np.iinfo(np.int64).max
NO_LORA_ID = 0
@dataclass @dataclass
class SamplingMetadata: class SamplingMetadata:
temperature: torch.Tensor temperature: torch.Tensor
top_p: torch.Tensor | None top_p: torch.Tensor | None
@ -36,12 +37,14 @@ class SamplingMetadata:
assert num_reqs > 0 assert num_reqs > 0
temperature = torch.zeros(num_reqs, dtype=torch.float32, device=device) temperature = torch.zeros(num_reqs, dtype=torch.float32, device=device)
temperature[0] = 0.5 temperature[0] = 0.5
top_p = torch.ones(num_reqs, dtype=torch.float32, device=device) # TODO(woosuk): Use top-p and top-k for dummy sampler.
top_p[0] = 0.99 # Currently, they are disabled because of memory usage.
top_k = torch.ones(num_reqs, dtype=torch.int32, device=device) top_p = None
top_k = None
seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device) seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device)
pos = torch.zeros(num_reqs, dtype=torch.int64, device=device) pos = torch.zeros(num_reqs, dtype=torch.int64, device=device)
max_num_logprobs = 20 max_num_logprobs = 20
return cls( return cls(
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
@ -53,7 +56,6 @@ class SamplingMetadata:
class RequestState: class RequestState:
def __init__( def __init__(
self, self,
max_num_reqs: int, max_num_reqs: int,
@ -73,15 +75,15 @@ class RequestState:
self.req_id_to_index: dict[str, int] = {} self.req_id_to_index: dict[str, int] = {}
self.index_to_req_id: dict[int, str] = {} self.index_to_req_id: dict[int, str] = {}
self.free_indices = list(range(max_num_reqs)) self.free_indices = list(range(max_num_reqs))
self.extra_data: dict[str, ExtraData] = {}
# NOTE(woosuk): Strictly speaking, it contains prompt + some output self.prompt_len = np.zeros(self.max_num_reqs, dtype=np.int32)
# because of preemption. self.prefill_token_ids = np.zeros(
self.prompt_token_ids = np.zeros(
(self.max_num_reqs, self.max_model_len), (self.max_num_reqs, self.max_model_len),
dtype=np.int32, dtype=np.int32,
) )
self.num_tokens = self._make_buffer(self.max_num_reqs, self.prefill_len = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
dtype=torch.int32) self.num_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
self.num_computed_tokens = np.zeros(self.max_num_reqs, dtype=np.int32) self.num_computed_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
# Last sampled tokens. # Last sampled tokens.
@ -92,6 +94,10 @@ class RequestState:
device=device, device=device,
) )
# LoRA.
self.lora_ids = np.zeros(self.max_num_reqs, dtype=np.int32)
self.lora_ids.fill(NO_LORA_ID)
# Sampling parameters. # Sampling parameters.
self.temperature = self._make_param(self.max_num_reqs, torch.float32) self.temperature = self._make_param(self.max_num_reqs, torch.float32)
self.top_p = self._make_param(self.max_num_reqs, torch.float32) self.top_p = self._make_param(self.max_num_reqs, torch.float32)
@ -104,16 +110,12 @@ class RequestState:
self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool) self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
def _make_param(self, size: int, dtype: torch.dtype) -> "Param": def _make_param(self, size: int, dtype: torch.dtype) -> "Param":
return Param(size, return Param(size, dtype=dtype, device=self.device, pin_memory=self.pin_memory)
dtype=dtype,
device=self.device,
pin_memory=self.pin_memory)
def _make_buffer(self, size: int, dtype: torch.dtype) -> CpuGpuBuffer: def _make_buffer(self, size: int, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(size, return CpuGpuBuffer(
dtype=dtype, size, dtype=dtype, device=self.device, pin_memory=self.pin_memory
device=self.device, )
pin_memory=self.pin_memory)
@property @property
def num_reqs(self) -> int: def num_reqs(self) -> int:
@ -122,23 +124,32 @@ class RequestState:
def add_request( def add_request(
self, self,
req_id: str, req_id: str,
prompt_token_ids: list[int], prompt_len: int,
prefill_token_ids: list[int],
num_computed_tokens: int, num_computed_tokens: int,
sampling_params: SamplingParams, sampling_params: SamplingParams,
lora_request: LoRARequest | None,
) -> None: ) -> None:
assert len(self.free_indices) > 0 assert len(self.free_indices) > 0, "No free indices"
req_idx = self.free_indices.pop() req_idx = self.free_indices.pop()
self.req_id_to_index[req_id] = req_idx self.req_id_to_index[req_id] = req_idx
self.index_to_req_id[req_idx] = req_id self.index_to_req_id[req_idx] = req_id
self.extra_data[req_id] = ExtraData(lora_request)
# NOTE(woosuk): Strictly speaking, "prompt_len" here may include self.prompt_len[req_idx] = prompt_len
# output tokens, if the request is resumed from preemption. prefill_len = len(prefill_token_ids)
prompt_len = len(prompt_token_ids) assert prefill_len >= prompt_len, (
self.prompt_token_ids[req_idx, :prompt_len] = prompt_token_ids f"prefill_len {prefill_len} < prompt_len {prompt_len}"
self.num_tokens.np[req_idx] = prompt_len )
self.prefill_len.np[req_idx] = prefill_len
self.prefill_token_ids[req_idx, :prefill_len] = prefill_token_ids
self.num_tokens[req_idx] = prefill_len
self.num_computed_tokens[req_idx] = num_computed_tokens self.num_computed_tokens[req_idx] = num_computed_tokens
# TODO(woosuk): Optimize.
self.last_sampled_tokens[req_idx].fill_(-1) if lora_request is not None:
self.lora_ids[req_idx] = lora_request.lora_int_id
else:
self.lora_ids[req_idx] = NO_LORA_ID
self.temperature.np[req_idx] = sampling_params.temperature self.temperature.np[req_idx] = sampling_params.temperature
self.top_p.np[req_idx] = sampling_params.top_p self.top_p.np[req_idx] = sampling_params.top_p
@ -165,6 +176,7 @@ class RequestState:
self.needs_prompt_logprobs[req_idx] = needs_prompt_logprobs self.needs_prompt_logprobs[req_idx] = needs_prompt_logprobs
def remove_request(self, req_id: str) -> None: def remove_request(self, req_id: str) -> None:
self.extra_data.pop(req_id, None)
req_idx = self.req_id_to_index.pop(req_id, None) req_idx = self.req_id_to_index.pop(req_id, None)
if req_idx is None: if req_idx is None:
# Request not found. # Request not found.
@ -205,9 +217,25 @@ class RequestState:
max_num_logprobs=max_num_logprobs, max_num_logprobs=max_num_logprobs,
) )
def make_lora_inputs(
self,
req_ids: list[str],
idx_mapping: np.ndarray,
num_scheduled_tokens: np.ndarray,
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
lora_ids = self.lora_ids[idx_mapping]
prompt_lora_mapping = tuple(lora_ids)
token_lora_mapping = tuple(lora_ids.repeat(num_scheduled_tokens))
active_lora_requests: set[LoRARequest] = set()
for req_id in req_ids:
lora_request = self.extra_data[req_id].lora_request
if lora_request is not None:
active_lora_requests.add(lora_request)
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
class Param: class Param:
def __init__( def __init__(
self, self,
size: int, size: int,
@ -227,3 +255,9 @@ class Param:
n = x.shape[0] n = x.shape[0]
self.buffer.np[:n] = x self.buffer.np[:n] = x
return self.buffer.copy_to_gpu(n) return self.buffer.copy_to_gpu(n)
@dataclass
class ExtraData:
lora_request: LoRARequest | None
in_progress_prompt_logprobs: list[LogprobsTensors] = field(default_factory=list)

View File

@ -42,6 +42,7 @@ from vllm.v1.outputs import (
ModelRunnerOutput, ModelRunnerOutput,
) )
from vllm.v1.utils import report_usage_stats from vllm.v1.utils import report_usage_stats
# from vllm.v1.worker.gpu_model_runner import GPUModelRunner # from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.gpu.model_runner import GPUModelRunner from vllm.v1.worker.gpu.model_runner import GPUModelRunner
from vllm.v1.worker.utils import is_residual_scattered_for_sp from vllm.v1.worker.utils import is_residual_scattered_for_sp
@ -495,6 +496,8 @@ class Worker(WorkerBase):
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput | AsyncModelRunnerOutput | None: ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
return self.model_runner.execute_model(scheduler_output)
intermediate_tensors = None intermediate_tensors = None
forward_pass = scheduler_output.total_num_scheduled_tokens > 0 forward_pass = scheduler_output.total_num_scheduled_tokens > 0
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens