This commit is contained in:
Woosuk Kwon
2025-10-30 16:30:06 -07:00
parent 110770170f
commit 09e4b2f6eb
13 changed files with 1001 additions and 502 deletions

View File

@ -10,6 +10,7 @@ torchaudio==2.9.0
# These must be updated alongside torch
torchvision==0.24.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
# Build from https://github.com/facebookresearch/xformers/releases/tag/v0.0.32.post1
xformers==0.0.33+5d4b92a5.d20251029; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.9
# xformers==0.0.33+5d4b92a5.d20251029; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.9
# FlashInfer should be updated together with the Dockerfile
flashinfer-python==0.4.1
apache-tvm-ffi==0.1.0b15

View File

@ -34,6 +34,7 @@ else:
class NewRequestData:
req_id: str
prompt_token_ids: list[int] | None
prefill_token_ids: list[int] | None
mm_features: list[MultiModalFeatureSpec]
sampling_params: SamplingParams | None
pooling_params: PoolingParams | None
@ -51,6 +52,7 @@ class NewRequestData:
return cls(
req_id=request.request_id,
prompt_token_ids=request.prompt_token_ids,
prefill_token_ids=request._all_token_ids,
mm_features=request.mm_features,
sampling_params=request.sampling_params,
pooling_params=request.pooling_params,
@ -173,6 +175,7 @@ class SchedulerOutput:
# This can be used for cascade attention.
num_common_prefix_blocks: list[int]
preempted_req_ids: set[str]
# Request IDs that are finished in between the previous and the current
# steps. This is used to notify the workers about the finished requests
# so that they can free the cached states for those requests.

View File

@ -606,6 +606,9 @@ class Scheduler(SchedulerInterface):
)
# Construct the scheduler output.
scheduled_new_reqs = scheduled_new_reqs + scheduled_resumed_reqs
scheduled_resumed_reqs = []
new_reqs_data = [
NewRequestData.from_request(
req, req_to_new_blocks[req.request_id].get_block_ids()
@ -635,6 +638,7 @@ class Scheduler(SchedulerInterface):
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
scheduled_encoder_inputs=scheduled_encoder_inputs,
num_common_prefix_blocks=num_common_prefix_blocks,
preempted_req_ids={req.request_id for req in preempted_reqs},
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
# It contains the request IDs that are finished in between
@ -720,14 +724,6 @@ class Scheduler(SchedulerInterface):
req.num_computed_tokens : req.num_computed_tokens + num_tokens
]
new_token_ids.append(token_ids)
scheduled_in_prev_step = req_id in self.prev_step_scheduled_req_ids
if idx >= num_running_reqs:
assert not scheduled_in_prev_step
resumed_req_ids.add(req_id)
if not scheduled_in_prev_step:
all_token_ids[req_id] = req.all_token_ids[
: req.num_computed_tokens + num_tokens
]
new_block_ids.append(
req_to_new_blocks[req_id].get_block_ids(allow_none=True)
)
@ -902,7 +898,6 @@ class Scheduler(SchedulerInterface):
model_runner_output: ModelRunnerOutput,
) -> dict[int, EngineCoreOutputs]:
sampled_token_ids = model_runner_output.sampled_token_ids
num_sampled_tokens = model_runner_output.num_sampled_tokens
logprobs = model_runner_output.logprobs
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
num_scheduled_tokens = scheduler_output.num_scheduled_tokens

View File

@ -15,7 +15,6 @@ else:
class LogprobsLists(NamedTuple):
# [num_reqs, max_num_logprobs + 1]
logprob_token_ids: np.ndarray
# [num_reqs, max_num_logprobs + 1]
@ -135,13 +134,14 @@ class KVConnectorOutput:
class ModelRunnerOutput:
# [num_reqs]
req_ids: list[str]
# req_id -> index
req_id_to_index: dict[str, int]
# num_reqs x num_generated_tokens
# num_generated_tokens is the number of tokens
# generated in the current step. It can be different for
# each request due to speculative/jump decoding.
sampled_token_ids: np.ndarray | None
num_sampled_tokens: np.ndarray | None
sampled_token_ids: list[list[int]]
# [num_reqs, max_num_logprobs + 1]
# [num_reqs, max_num_logprobs + 1]
@ -186,8 +186,8 @@ class DraftTokenIds:
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=[],
sampled_token_ids=None,
num_sampled_tokens=None,
req_id_to_index={},
sampled_token_ids=[],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],

View File

@ -1,21 +1,28 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
import numpy as np
import torch
from vllm.v1.outputs import (AsyncModelRunnerOutput, LogprobsTensors,
ModelRunnerOutput, SamplerOutput)
from vllm.v1.outputs import (
AsyncModelRunnerOutput,
ModelRunnerOutput,
SamplerOutput,
)
class AsyncOutput(AsyncModelRunnerOutput):
def __init__(
self,
model_runner_output: ModelRunnerOutput,
sampler_output: SamplerOutput,
num_sampled_tokens: np.ndarray,
copy_stream: torch.cuda.Stream,
):
self.model_runner_output = model_runner_output
self.sampler_output = sampler_output
self.num_sampled_tokens = num_sampled_tokens
self.copy_stream = copy_stream
self.copy_event = torch.cuda.Event()
@ -23,26 +30,46 @@ class AsyncOutput(AsyncModelRunnerOutput):
with torch.cuda.stream(self.copy_stream):
self.copy_stream.wait_stream(default_stream)
# NOTE(woosuk): We should keep the CPU tensors unfreed until the copy completes.
self.sampled_token_ids = sampler_output.sampled_token_ids.to(
"cpu", non_blocking=True)
x = sampler_output.logprobs_tensors
if x is not None:
self.logprobs_tensors = LogprobsTensors(
logprob_token_ids=x.logprob_token_ids.to(
"cpu", non_blocking=True),
logprobs=x.logprobs.to("cpu", non_blocking=True),
selected_token_ranks=x.selected_token_ranks.to(
"cpu", non_blocking=True),
"cpu", non_blocking=True
)
if sampler_output.logprobs_tensors is not None:
self.logprobs_tensors = (
sampler_output.logprobs_tensors.to_cpu_nonblocking()
)
else:
self.logprobs_tensors = None
self.copy_event.record()
self.prompt_logprobs_dict = {}
if self.model_runner_output.prompt_logprobs_dict:
for k, v in self.model_runner_output.prompt_logprobs_dict.items():
self.prompt_logprobs_dict[k] = v.to_cpu_nonblocking()
self.copy_event.record(self.copy_stream)
def get_output(self) -> ModelRunnerOutput:
self.copy_event.synchronize()
self.model_runner_output.sampled_token_ids = (
self.sampled_token_ids.numpy())
# NOTE(woosuk): The following code ensures compatibility with OSS vLLM.
# Going forward, we should keep the data structures as NumPy arrays
# rather than Python lists.
sampled_token_ids_np = self.sampled_token_ids.numpy()
sampled_token_ids = sampled_token_ids_np.tolist()
for i, tokens in enumerate(sampled_token_ids):
del tokens[self.num_sampled_tokens[i] :]
self.model_runner_output.sampled_token_ids = sampled_token_ids
if self.logprobs_tensors is not None:
self.model_runner_output.logprobs = (
self.logprobs_tensors.tolists())
self.model_runner_output.logprobs = self.logprobs_tensors.tolists()
self.model_runner_output.prompt_logprobs_dict = self.prompt_logprobs_dict
return self.model_runner_output
@contextmanager
def async_barrier(event: torch.cuda.Event | None):
if event is not None:
event.synchronize()
try:
yield
finally:
if event is not None:
event.record()
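Note on the pattern above: AsyncOutput issues the device-to-host copies on a dedicated side stream and records a CUDA event, so the default stream can keep launching kernels while the copy is in flight; get_output() then synchronizes only on that event. A minimal, self-contained sketch of the same overlap pattern (illustrative only, not part of the vLLM API; assumes a CUDA device is available):

import torch

def async_d2h_copy(gpu_tensor: torch.Tensor, copy_stream: torch.cuda.Stream) -> torch.Tensor:
    event = torch.cuda.Event()
    default_stream = torch.cuda.current_stream()
    with torch.cuda.stream(copy_stream):
        # Start the copy only after all producers on the default stream finish.
        copy_stream.wait_stream(default_stream)
        cpu_tensor = gpu_tensor.to("cpu", non_blocking=True)
        event.record(copy_stream)
    # ... the caller can launch more GPU work here ...
    event.synchronize()  # block only when the CPU result is actually needed
    return cpu_tensor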

View File

@ -7,9 +7,17 @@ import torch
from vllm.attention.backends.abstract import AttentionBackend, AttentionType
from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec, SlidingWindowSpec)
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
KVCacheSpec,
SlidingWindowSpec,
)
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.utils import bind_kv_cache
@ -18,7 +26,6 @@ def get_kv_cache_spec(
kv_cache_dtype: torch.dtype,
) -> dict[str, KVCacheSpec]:
block_size = vllm_config.cache_config.block_size
use_mla = vllm_config.model_config.use_mla
kv_cache_spec: dict[str, KVCacheSpec] = {}
attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
@ -31,7 +38,6 @@ def get_kv_cache_spec(
head_size=attn_module.head_size,
dtype=kv_cache_dtype,
sliding_window=attn_module.sliding_window,
use_mla=use_mla,
)
else:
kv_cache_spec[layer_name] = FullAttentionSpec(
@ -39,7 +45,6 @@ def get_kv_cache_spec(
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=kv_cache_dtype,
use_mla=use_mla,
)
return kv_cache_spec
@ -52,6 +57,7 @@ def init_attn_backend(
attn_backends: dict[str, AttentionBackend] = {}
attn_metadata_builders: list[AttentionMetadataBuilder] = []
flashinfer_workspace: torch.Tensor | None = None
attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
layer_names = kv_cache_group_spec.layer_names
@ -67,7 +73,13 @@ def init_attn_backend(
vllm_config,
device,
)
attn_metadata_builders.append(attn_metadata_builder)
attn_metadata_builders.append(attn_metadata_builder) # type: ignore
if "FLASHINFER" in attn_backend.get_name():
if flashinfer_workspace is None:
flashinfer_workspace = attn_metadata_builder.get_workspace_buffer()
else:
attn_metadata_builder.set_workspace_buffer(flashinfer_workspace)
return attn_backends, attn_metadata_builders
@ -77,9 +89,7 @@ def _allocate_kv_cache(
):
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
tensor = torch.zeros(kv_cache_tensor.size,
dtype=torch.int8,
device=device)
tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=device)
for layer_name in kv_cache_tensor.shared_by:
kv_cache_raw_tensors[layer_name] = tensor
@ -87,8 +97,9 @@ def _allocate_kv_cache(
for group in kv_cache_config.kv_cache_groups:
for layer_name in group.layer_names:
layer_names.add(layer_name)
assert layer_names == set(kv_cache_raw_tensors.keys()
), "Some layers are not correctly initialized"
assert layer_names == set(kv_cache_raw_tensors.keys()), (
"Some layers are not correctly initialized"
)
return kv_cache_raw_tensors
@ -103,17 +114,19 @@ def _reshape_kv_cache(
for layer_name in kv_cache_group_spec.layer_names:
raw_tensor = kv_cache_raw_tensors[layer_name]
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
num_blocks = (raw_tensor.numel() // kv_cache_spec.page_size_bytes)
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
attn_backend = attn_backends[layer_name]
kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
num_blocks,
kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
)
dtype = kv_cache_spec.dtype
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
kv_cache_shape = tuple(kv_cache_shape[i]
for i in kv_cache_stride_order)
kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
inv_order = [
kv_cache_stride_order.index(i)
@ -132,8 +145,56 @@ def init_kv_cache(
kv_cache_config: KVCacheConfig,
attn_backends: dict[str, AttentionBackend],
device: torch.device,
):
) -> None:
kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device)
kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors,
attn_backends)
kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors, attn_backends)
bind_kv_cache(kv_caches, forward_context, runner_kv_caches)
def build_attn_metadata(
attn_metadata_builders: list[AttentionMetadataBuilder],
num_reqs: int,
num_tokens: int,
query_start_loc: CpuGpuBuffer,
seq_lens: CpuGpuBuffer,
num_computed_tokens_cpu: torch.Tensor,
block_tables: tuple[torch.Tensor, ...],
slot_mappings: torch.Tensor,
kv_cache_config: KVCacheConfig,
) -> dict[str, Any]:
query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
query_start_loc_cpu = query_start_loc.cpu[: num_reqs + 1]
max_query_len = int(query_start_loc.np[: num_reqs + 1].max())
seq_lens_gpu = seq_lens.gpu[:num_reqs]
seq_lens_cpu = seq_lens.cpu[:num_reqs]
max_seq_len = int(seq_lens.np[:num_reqs].max())
attn_metadata: dict[str, Any] = {}
kv_cache_groups = kv_cache_config.kv_cache_groups
for i, kv_cache_spec in enumerate(kv_cache_groups):
block_table = block_tables[i]
slot_mapping = slot_mappings[i]
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc_gpu,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens_gpu,
seq_lens_cpu=seq_lens_cpu,
max_seq_len=max_seq_len,
num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
block_table_tensor=block_table,
slot_mapping=slot_mapping,
causal=True,
)
attn_metadata_builder = attn_metadata_builders[i]
metadata = attn_metadata_builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
for layer_name in kv_cache_spec.layer_names:
attn_metadata[layer_name] = metadata
return attn_metadata
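As an illustration of the inputs build_attn_metadata consumes, here is a toy batch (values made up) showing how query_start_loc, seq_lens, and the max_query_len/max_seq_len scalars relate:

import numpy as np

# Three requests scheduled with 4, 1, and 1 tokens this step
# (one chunked prefill plus two decodes).
num_scheduled_tokens = np.array([4, 1, 1], dtype=np.int32)
num_computed_tokens = np.array([16, 7, 31], dtype=np.int32)

query_start_loc = np.zeros(4, dtype=np.int32)
query_start_loc[1:] = np.cumsum(num_scheduled_tokens)  # [0, 4, 5, 6]
seq_lens = num_computed_tokens + num_scheduled_tokens   # [20, 8, 32]
max_query_len = int(num_scheduled_tokens.max())         # 4
max_seq_len = int(seq_lens.max())                       # 32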

View File

@ -6,14 +6,13 @@ import torch
import triton
import triton.language as tl
from vllm.utils import cdiv
from vllm.utils.math_utils import cdiv
from vllm.v1.utils import CpuGpuBuffer
PAD_SLOT_ID = -1
class BlockTables:
def __init__(
self,
block_sizes: list[int],
@ -50,44 +49,48 @@ class BlockTables:
self.input_block_tables: list[torch.Tensor] = [
torch.zeros_like(block_table) for block_table in self.block_tables
]
self.input_block_table_ptrs = self._make_ptr_tensor(
self.input_block_tables)
self.input_block_table_ptrs = self._make_ptr_tensor(self.input_block_tables)
self.block_table_strides = torch.tensor(
[b.stride(0) for b in self.block_tables],
dtype=torch.int64,
device=self.device)
self.block_sizes_tensor = torch.tensor(self.block_sizes,
dtype=torch.int32,
device=self.device)
self.num_blocks = torch.zeros(self.num_kv_cache_groups,
self.max_num_reqs,
dtype=torch.int32,
device=self.device)
self.slot_mappings = torch.zeros(self.num_kv_cache_groups,
self.max_num_batched_tokens,
dtype=torch.int64,
device=self.device)
device=self.device,
)
self.block_sizes_tensor = torch.tensor(
self.block_sizes, dtype=torch.int32, device=self.device
)
self.num_blocks = torch.zeros(
self.num_kv_cache_groups,
self.max_num_reqs,
dtype=torch.int32,
device=self.device,
)
self.slot_mappings = torch.zeros(
self.num_kv_cache_groups,
self.max_num_batched_tokens,
dtype=torch.int64,
device=self.device,
)
# Misc buffers.
self.req_indices = self._make_buffer(self.max_num_reqs,
dtype=torch.int32)
self.req_indices = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
self.overwrite = self._make_buffer(self.max_num_reqs, dtype=torch.bool)
self.cu_num_new_blocks = self._make_buffer(self.num_kv_cache_groups,
self.max_num_reqs + 1,
dtype=torch.int32)
self.cu_num_new_blocks = self._make_buffer(
self.num_kv_cache_groups, self.max_num_reqs + 1, dtype=torch.int32
)
def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(*args,
dtype=dtype,
pin_memory=self.pin_memory,
device=self.device)
return CpuGpuBuffer(
*args, dtype=dtype, pin_memory=self.pin_memory, device=self.device
)
def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
ptrs_tensor_cpu = torch.tensor([t.data_ptr() for t in x],
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory)
ptrs_tensor_cpu = torch.tensor(
[t.data_ptr() for t in x],
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory,
)
return ptrs_tensor_cpu.to(self.device, non_blocking=True)
def append_block_ids(
@ -105,7 +108,7 @@ class BlockTables:
self.req_indices.np[:num_reqs] = req_indices
self.overwrite.np[:num_reqs] = overwrite
for i in range(self.num_kv_cache_groups):
self.cu_num_new_blocks.np[i, :num_reqs + 1] = cu_num_new_blocks[i]
self.cu_num_new_blocks.np[i, : num_reqs + 1] = cu_num_new_blocks[i]
# NOTE(woosuk): Here, we cannot use a fixed-size buffer because there's
# no clear upper bound on the number of new blocks in a single step.
@ -120,9 +123,8 @@ class BlockTables:
)
new_block_ids_np = self.new_block_ids_cpu.numpy()
for i in range(self.num_kv_cache_groups):
new_block_ids_np[i, :len(new_block_ids[i])] = new_block_ids[i]
new_block_ids_gpu = self.new_block_ids_cpu.to(self.device,
non_blocking=True)
new_block_ids_np[i, : len(new_block_ids[i])] = new_block_ids[i]
new_block_ids_gpu = self.new_block_ids_cpu.to(self.device, non_blocking=True)
_append_block_ids_kernel[(self.num_kv_cache_groups, num_reqs)](
self.req_indices.copy_to_gpu(num_reqs),
@ -135,7 +137,7 @@ class BlockTables:
self.block_table_ptrs,
self.num_blocks,
self.num_blocks.stride(0),
BLOCK_SIZE=1024,
BLOCK_SIZE=1024, # type: ignore
)
def gather_block_tables(
@ -150,10 +152,9 @@ class BlockTables:
self.block_table_strides,
self.num_blocks,
self.num_blocks.stride(0),
BLOCK_SIZE=1024,
BLOCK_SIZE=1024, # type: ignore
)
return tuple(block_table[:num_reqs]
for block_table in self.input_block_tables)
return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)
def compute_slot_mappings(
self,
@ -174,7 +175,7 @@ class BlockTables:
self.slot_mappings,
self.slot_mappings.stride(0),
PAD_ID=PAD_SLOT_ID,
BLOCK_SIZE=1024,
BLOCK_SIZE=1024, # type: ignore
)
return self.slot_mappings[:, :num_tokens]
@ -201,8 +202,7 @@ def _append_block_ids_kernel(
req_idx = tl.load(req_indices + batch_idx)
do_overwrite = tl.load(overwrite + batch_idx)
group_new_blocks_ptr = (cu_num_new_blocks_ptr +
group_id * cu_num_new_blocks_stride)
group_new_blocks_ptr = cu_num_new_blocks_ptr + group_id * cu_num_new_blocks_stride
start_idx = tl.load(group_new_blocks_ptr + batch_idx)
end_idx = tl.load(group_new_blocks_ptr + batch_idx + 1)
num_new_blocks = end_idx - start_idx
@ -220,15 +220,15 @@ def _append_block_ids_kernel(
block_table_stride = tl.load(block_table_strides + group_id)
row_ptr = block_table_ptr + req_idx * block_table_stride
group_new_block_ids_ptr = (new_block_ids_ptr +
group_id * new_block_ids_stride)
for i in tl.range(0, num_new_blocks, BLOCK_SIZE):
group_new_block_ids_ptr = new_block_ids_ptr + group_id * new_block_ids_stride
for i in range(0, num_new_blocks, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
block_ids = tl.load(group_new_block_ids_ptr + start_idx + offset,
mask=offset < num_new_blocks)
tl.store(row_ptr + dst_start_idx + offset,
block_ids,
mask=offset < num_new_blocks)
block_ids = tl.load(
group_new_block_ids_ptr + start_idx + offset, mask=offset < num_new_blocks
)
tl.store(
row_ptr + dst_start_idx + offset, block_ids, mask=offset < num_new_blocks
)
@triton.jit
@ -282,11 +282,9 @@ def _compute_slot_mappings_kernel(
if req_idx == tl.num_programs(1) - 1:
# Pad remaining slots to -1. This is needed for CUDA graphs.
for i in tl.range(num_tokens, max_num_tokens, BLOCK_SIZE):
for i in range(num_tokens, max_num_tokens, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
tl.store(slot_mapping_ptr + offset,
PAD_ID,
mask=offset < max_num_tokens)
tl.store(slot_mapping_ptr + offset, PAD_ID, mask=offset < max_num_tokens)
return
block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
@ -295,12 +293,13 @@ def _compute_slot_mappings_kernel(
start_idx = tl.load(cu_num_tokens + req_idx)
end_idx = tl.load(cu_num_tokens + req_idx + 1)
for i in tl.range(start_idx, end_idx, BLOCK_SIZE):
for i in range(start_idx, end_idx, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
block_indices = positions // page_size
block_numbers = tl.load(block_table_ptr +
req_idx * block_table_stride + block_indices)
block_numbers = tl.load(
block_table_ptr + req_idx * block_table_stride + block_indices
)
slot_ids = block_numbers * page_size + positions % page_size
tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)
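The kernel above boils down to one indexing rule per token. A NumPy reference of the same slot-mapping computation (an illustrative sketch with made-up example values, not the production path):

import numpy as np

def slot_mapping_reference(positions: np.ndarray, block_table: np.ndarray,
                           page_size: int) -> np.ndarray:
    # slot = block_table[pos // page_size] * page_size + pos % page_size
    block_numbers = block_table[positions // page_size]
    return block_numbers * page_size + positions % page_size

# Example: page_size=16 and the request's logical blocks live at
# physical blocks [7, 2, 9].
positions = np.array([0, 15, 16, 40])
print(slot_mapping_reference(positions, np.array([7, 2, 9]), 16))  # [112 127  32 152]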

View File

@ -0,0 +1,175 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
from contextlib import contextmanager
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
from vllm.forward_context import set_forward_context
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.input_batch import InputBuffers
class CudaGraphManager:
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
):
self.vllm_config = vllm_config
self.device = device
self.max_model_len = vllm_config.model_config.max_model_len
self.compilation_config = vllm_config.compilation_config
assert self.compilation_config is not None
self.cudagraph_sizes = sorted(self.compilation_config.cudagraph_capture_sizes)
self.padded_sizes = self._init_padded_sizes()
self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
self.pool = torch.cuda.graph_pool_handle()
self.hidden_states: torch.Tensor | None = None
def _init_padded_sizes(self) -> dict[int, int]:
if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
# CUDA graphs are disabled.
return {}
if self.compilation_config.cudagraph_mode.requires_piecewise_compilation():
raise NotImplementedError("Piecewise CUDA graphs are not supported")
if self.compilation_config.level != 0:
raise NotImplementedError("Dynamo is not used. Compilation level must be 0")
padded_sizes: dict[int, int] = {}
assert len(self.cudagraph_sizes) > 0
for i in range(1, self.cudagraph_sizes[-1] + 1):
for x in self.cudagraph_sizes:
if i <= x:
padded_sizes[i] = x
break
return padded_sizes
def needs_capture(self) -> bool:
return len(self.padded_sizes) > 0
def get_cudagraph_size(self, scheduler_output: SchedulerOutput) -> int | None:
if max(scheduler_output.num_scheduled_tokens.values()) > 1:
# Prefill is included.
return None
return self.padded_sizes.get(scheduler_output.total_num_scheduled_tokens)
def capture_graph(
self,
batch_size: int,
model: nn.Module,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig,
) -> None:
assert batch_size not in self.graphs
# Prepare dummy inputs.
input_ids = input_buffers.input_ids.gpu[:batch_size]
positions = input_buffers.positions.gpu[:batch_size]
input_buffers.query_start_loc.np[: batch_size + 1] = np.arange(batch_size + 1)
input_buffers.query_start_loc.np[batch_size:] = batch_size
input_buffers.query_start_loc.copy_to_gpu()
input_buffers.seq_lens.np[:batch_size] = self.max_model_len
input_buffers.seq_lens.np[batch_size:] = 0
input_buffers.seq_lens.copy_to_gpu()
input_block_tables = [x[:batch_size] for x in block_tables.input_block_tables]
slot_mappings = block_tables.slot_mappings[:, :batch_size]
attn_metadata = build_attn_metadata(
attn_metadata_builders=attn_metadata_builders,
num_reqs=batch_size,
num_tokens=batch_size,
query_start_loc=input_buffers.query_start_loc,
seq_lens=input_buffers.seq_lens,
num_computed_tokens_cpu=None, # FIXME
block_tables=input_block_tables,
slot_mappings=slot_mappings,
kv_cache_config=kv_cache_config,
)
# Warm up.
with set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=batch_size,
):
hidden_states = model(
input_ids=input_ids,
positions=positions,
)
if self.hidden_states is None:
self.hidden_states = torch.empty_like(hidden_states)
torch.cuda.synchronize()
# Capture the graph.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, self.pool):
with set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=batch_size,
):
hidden_states = model(
input_ids=input_ids,
positions=positions,
)
self.hidden_states[:batch_size] = hidden_states
self.graphs[batch_size] = graph
@torch.inference_mode()
def capture(
self,
model: nn.Module,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig,
) -> None:
assert self.needs_capture()
# Capture larger graphs first.
sizes_to_capture = sorted(self.cudagraph_sizes, reverse=True)
if is_global_first_rank():
sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")
with freeze_gc(), graph_capture(device=self.device):
for batch_size in sizes_to_capture:
self.capture_graph(
batch_size,
model,
input_buffers,
block_tables,
attn_metadata_builders,
kv_cache_config,
)
def run(self, batch_size: int) -> torch.Tensor:
assert batch_size in self.graphs
self.graphs[batch_size].replay()
return self.hidden_states[:batch_size]
@contextmanager
def freeze_gc():
gc.collect()
gc.freeze()
try:
yield
finally:
gc.unfreeze()
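The padding table built by _init_padded_sizes simply rounds each decode-only batch size up to the smallest captured size. A small worked example (hypothetical capture sizes; the real list comes from compilation_config):

cudagraph_sizes = [1, 2, 4, 8]
padded_sizes = {}
for i in range(1, cudagraph_sizes[-1] + 1):
    for x in cudagraph_sizes:
        if i <= x:
            padded_sizes[i] = x
            break
assert padded_sizes == {1: 1, 2: 2, 3: 4, 4: 4, 5: 8, 6: 8, 7: 8, 8: 8}
# A decode-only batch of 3 tokens replays the size-4 graph; a batch larger
# than 8 tokens, or any batch containing a prefill, gets None from
# get_cudagraph_size() and runs eagerly.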

View File

@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from dataclasses import dataclass
from typing import Any
@ -16,11 +15,12 @@ from vllm.v1.utils import CpuGpuBuffer
class InputBuffers:
def __init__(
self,
max_num_reqs: int,
max_num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
device: torch.device,
pin_memory: bool,
):
@ -32,20 +32,17 @@ class InputBuffers:
self.idx_mapping = self._make_buffer(max_num_reqs, dtype=torch.int32)
self.input_ids = self._make_buffer(max_num_tokens, dtype=torch.int32)
self.positions = self._make_buffer(max_num_tokens, dtype=torch.int64)
self.query_start_loc = self._make_buffer(max_num_reqs + 1,
dtype=torch.int32)
self.query_start_loc = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
self.seq_lens = self._make_buffer(max_num_reqs, dtype=torch.int32)
def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(*args,
dtype=dtype,
pin_memory=self.pin_memory,
device=self.device)
return CpuGpuBuffer(
*args, dtype=dtype, pin_memory=self.pin_memory, device=self.device
)
@dataclass
class InputBatch:
# batch_idx -> req_id
req_ids: list[str]
num_reqs: int
@ -54,17 +51,23 @@ class InputBatch:
idx_mapping: torch.Tensor
idx_mapping_np: np.ndarray
# [num_reqs]
# batch_idx -> num_scheduled_tokens
num_scheduled_tokens: np.ndarray
# sum(num_scheduled_tokens)
num_tokens: int
num_tokens_after_padding: int
# [num_reqs]
is_chunked_prefilling: np.ndarray
# [max_num_batched_tokens]
# [num_reqs + 1]
query_start_loc: torch.Tensor
query_start_loc_np: np.ndarray
# [num_reqs]
seq_lens: torch.Tensor
seq_lens_np: np.ndarray
# [num_tokens_after_padding]
input_ids: torch.Tensor
# [max_num_batched_tokens]
# [num_tokens_after_padding]
positions: torch.Tensor
# layer_name -> Metadata
@ -78,23 +81,34 @@ class InputBatch:
cls,
num_reqs: int,
num_tokens: int,
input_buffers: InputBuffers,
device: torch.device,
) -> "InputBatch":
assert 0 < num_reqs <= num_tokens
req_ids = [f"req_{i}_{random_uuid()}" for i in range(num_reqs)]
idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
idx_mapping = torch.tensor(idx_mapping_np, device=device)
num_scheduled_tokens = np.full(num_reqs,
num_tokens // num_reqs,
dtype=np.int32)
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
num_scheduled_tokens = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
num_scheduled_tokens[-1] += num_tokens % num_reqs
is_chunked_prefilling = np.zeros(num_reqs, dtype=np.bool_)
input_ids = torch.zeros(num_tokens, dtype=torch.int32, device=device)
positions = torch.zeros(num_tokens, dtype=torch.int64, device=device)
attn_metadata = defaultdict(lambda: None)
logits_indices = torch.arange(num_reqs,
dtype=torch.int32,
device=device)
assert int(num_scheduled_tokens.sum()) == num_tokens
input_buffers.query_start_loc.np[0] = 0
input_buffers.query_start_loc.np[1 : num_reqs + 1] = np.cumsum(
num_scheduled_tokens
)
input_buffers.query_start_loc.np[num_reqs + 1 :] = num_tokens
query_start_loc_np = input_buffers.query_start_loc.np[: num_reqs + 1]
query_start_loc = input_buffers.query_start_loc.copy_to_gpu()[: num_reqs + 1]
# seq_len equals query_len
input_buffers.seq_lens.np[:num_reqs] = num_scheduled_tokens
input_buffers.seq_lens.np[num_reqs:] = 0
seq_lens_np = input_buffers.seq_lens.np[:num_reqs]
seq_lens = input_buffers.seq_lens.copy_to_gpu()[:num_reqs]
input_ids = input_buffers.input_ids.copy_to_gpu(num_tokens)
positions = input_buffers.positions.copy_to_gpu(num_tokens)
# attn_metadata = defaultdict(lambda: None)
logits_indices = query_start_loc[1:] - 1
return cls(
req_ids=req_ids,
num_reqs=num_reqs,
@ -103,10 +117,13 @@ class InputBatch:
num_scheduled_tokens=num_scheduled_tokens,
num_tokens=num_tokens,
num_tokens_after_padding=num_tokens,
is_chunked_prefilling=is_chunked_prefilling,
query_start_loc=query_start_loc,
query_start_loc_np=query_start_loc_np,
seq_lens=seq_lens,
seq_lens_np=seq_lens_np,
input_ids=input_ids,
positions=positions,
attn_metadata=attn_metadata,
attn_metadata=None,
logits_indices=logits_indices,
)
@ -130,14 +147,14 @@ class InputBatch:
cache=True,
)
def _prepare_inputs(
idx_mapping: np.ndarray, # batch_idx -> req_idx
token_ids: np.ndarray, # [N, max_model_len]
num_computed_tokens: np.ndarray, # [N]
num_scheduled_tokens: np.ndarray, # [B]
input_ids: np.ndarray, # [num_input_tokens]
positions: np.ndarray, # [num_input_tokens]
query_start_loc: np.ndarray, # [B + 1]
seq_lens: np.ndarray, # [B]
idx_mapping: np.ndarray, # batch_idx -> req_idx
token_ids: np.ndarray, # [N, max_model_len]
num_computed_tokens: np.ndarray, # [N]
num_scheduled_tokens: np.ndarray, # [B]
input_ids: np.ndarray, # [num_input_tokens]
positions: np.ndarray, # [num_input_tokens]
query_start_loc: np.ndarray, # [B + 1]
seq_lens: np.ndarray, # [B]
) -> None:
num_reqs = num_scheduled_tokens.shape[0]
query_start_loc[0] = 0
@ -161,14 +178,14 @@ def _prepare_inputs(
# Pad the inputs for CUDA graphs.
# Note: pad query_start_loc to be non-decreasing, as kernels
# like FlashAttention require that
query_start_loc[num_reqs + 1:].fill(cu_num_tokens)
query_start_loc[num_reqs + 1 :].fill(cu_num_tokens)
# Fill unused with 0 for full cuda graph mode.
seq_lens[num_reqs:].fill(0)
def prepare_inputs(
idx_mapping: np.ndarray,
prompt_token_ids: np.ndarray,
prefill_token_ids: np.ndarray,
num_computed_tokens: np.ndarray,
num_scheduled_tokens: np.ndarray,
input_ids: CpuGpuBuffer,
@ -176,10 +193,10 @@ def prepare_inputs(
query_start_loc: CpuGpuBuffer,
seq_lens: CpuGpuBuffer,
num_tokens: int,
) -> tuple[np.ndarray, np.ndarray]:
) -> None:
_prepare_inputs(
idx_mapping,
prompt_token_ids,
prefill_token_ids,
num_computed_tokens,
num_scheduled_tokens,
input_ids.np,
@ -194,11 +211,7 @@ def prepare_inputs(
# for full CUDA graph mode.
query_start_loc.copy_to_gpu()
seq_lens.copy_to_gpu()
num_reqs = num_scheduled_tokens.shape[0]
max_query_len = int(num_scheduled_tokens.max())
max_seq_len = int(seq_lens.np[:num_reqs].max())
return max_query_len, max_seq_len
return
@triton.jit
@ -208,21 +221,18 @@ def _combine_last_token_ids_kernel(
last_token_ids_ptr,
query_start_loc_ptr,
seq_lens_ptr,
num_tokens_ptr,
prefill_len_ptr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
seq_len = tl.load(seq_lens_ptr + batch_idx)
num_tokens = tl.load(num_tokens_ptr + req_state_idx)
if seq_len < num_tokens:
# Chunked prefilling.
prefill_len = tl.load(prefill_len_ptr + req_state_idx)
if seq_len <= prefill_len:
# Handling prefill tokens.
return
last_token_id = tl.load(last_token_ids_ptr + req_state_idx)
if last_token_id == -1:
return
end = tl.load(query_start_loc_ptr + batch_idx + 1)
tl.store(input_ids_ptr + end - 1, last_token_id)
@ -233,15 +243,15 @@ def combine_last_token_ids(
last_token_ids: torch.Tensor,
query_start_loc: torch.Tensor,
seq_lens: torch.Tensor,
num_tokens: torch.Tensor,
prefill_len: torch.Tensor,
) -> torch.Tensor:
num_reqs = seq_lens.shape[0]
_combine_last_token_ids_kernel[(num_reqs, )](
_combine_last_token_ids_kernel[(num_reqs,)](
input_ids,
idx_mapping,
last_token_ids,
query_start_loc,
seq_lens,
num_tokens,
prefill_len,
)
return input_ids
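For reference, the gather performed by the numba-jitted _prepare_inputs can be written in plain NumPy as below (an illustrative reference implementation, not the production path): for each batch slot it copies the scheduled window of prefill tokens into the packed input_ids and fills positions with the corresponding absolute offsets.

import numpy as np

def prepare_inputs_reference(idx_mapping, token_ids, num_computed_tokens,
                             num_scheduled_tokens):
    num_tokens = int(num_scheduled_tokens.sum())
    input_ids = np.empty(num_tokens, dtype=np.int32)
    positions = np.empty(num_tokens, dtype=np.int64)
    start = 0
    for batch_idx, n in enumerate(num_scheduled_tokens):
        req_idx = idx_mapping[batch_idx]
        begin = num_computed_tokens[req_idx]
        input_ids[start:start + n] = token_ids[req_idx, begin:begin + n]
        positions[start:start + n] = np.arange(begin, begin + n)
        start += n
    return input_ids, positions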

View File

@ -3,41 +3,52 @@
import gc
import time
from copy import deepcopy
from typing import Any, Optional
from typing import Any
import numpy as np
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.distributed import get_tp_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_loader
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, is_pin_memory_available)
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import DeviceMemoryProfiler
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.outputs import (
EMPTY_MODEL_RUNNER_OUTPUT,
LogprobsTensors,
ModelRunnerOutput,
)
from vllm.v1.sample.sampler import SamplerOutput
from vllm.v1.worker.gpu.async_utils import AsyncOutput
from vllm.v1.worker.gpu.attn_utils import (get_kv_cache_spec,
init_attn_backend, init_kv_cache)
from vllm.v1.worker.gpu.async_utils import AsyncOutput, async_barrier
from vllm.v1.worker.gpu.attn_utils import (
build_attn_metadata,
get_kv_cache_spec,
init_attn_backend,
init_kv_cache,
)
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.dist_utils import (all_gather_sampler_output,
evenly_split)
from vllm.v1.worker.gpu.input_batch import (InputBatch, InputBuffers,
combine_last_token_ids,
prepare_inputs)
from vllm.v1.worker.gpu.sampler import Sampler
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
from vllm.v1.worker.gpu.input_batch import (
InputBatch,
InputBuffers,
combine_last_token_ids,
prepare_inputs,
)
from vllm.v1.worker.gpu.sampler import Sampler, compute_prompt_logprobs
from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
logger = init_logger(__name__)
class GPUModelRunner:
class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def __init__(
self,
vllm_config: VllmConfig,
@ -61,17 +72,19 @@ class GPUModelRunner:
if self.cache_config.cache_dtype != "auto":
# Quantized KV cache.
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
self.cache_config.cache_dtype]
self.cache_config.cache_dtype
]
self.is_pooling_model = False
self.vocab_size = self.model_config.get_vocab_size()
self.max_model_len = self.model_config.max_model_len
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.hidden_size = self.model_config.get_hidden_size()
self.use_async_scheduling = self.scheduler_config.async_scheduling
assert self.use_async_scheduling
self.output_copy_stream = torch.cuda.Stream()
self.output_copy_stream = torch.cuda.Stream(self.device)
self.input_prep_event = torch.cuda.Event()
self.req_states = RequestState(
max_num_reqs=self.max_num_reqs,
@ -84,29 +97,46 @@ class GPUModelRunner:
self.input_buffers = InputBuffers(
max_num_reqs=self.max_num_reqs,
max_num_tokens=self.max_num_tokens,
hidden_size=self.hidden_size,
dtype=self.dtype,
device=self.device,
pin_memory=self.pin_memory,
)
self.sampler = Sampler()
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
# CUDA graphs.
self.cudagraph_manager = CudaGraphManager(
vllm_config=self.vllm_config,
device=self.device,
)
def get_supported_tasks(self) -> tuple[str]:
return ("generate", )
return ("generate",)
def load_model(self, *args, **kwargs) -> None:
time_before_load = time.perf_counter()
with DeviceMemoryProfiler() as m:
model_loader = get_model_loader(self.vllm_config.load_config)
logger.info("Loading model from scratch...")
self.model = model_loader.load_model(
vllm_config=self.vllm_config,
model_config=self.vllm_config.model_config,
)
if self.lora_config:
self.model = self.load_lora_model(
self.model,
self.vllm_config,
self.device,
)
time_after_load = time.perf_counter()
self.model_memory_usage = m.consumed_memory
logger.info("Model loading took %.4f GiB and %.6f seconds",
m.consumed_memory / GiB_bytes,
time_after_load - time_before_load)
logger.info(
"Model loading took %.4f GiB and %.6f seconds",
m.consumed_memory / GiB_bytes,
time_after_load - time_before_load,
)
def get_model(self) -> nn.Module:
return self.model
@ -143,32 +173,60 @@ class GPUModelRunner:
self.device,
)
@torch.inference_mode()
def _dummy_run(
self,
num_tokens: int,
*args,
input_batch: Optional[InputBatch] = None,
input_batch: InputBatch | None = None,
skip_attn: bool = True,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
if input_batch is None:
num_reqs = min(num_tokens, self.max_num_reqs)
input_batch = InputBatch.make_dummy(
num_reqs=min(num_tokens, self.max_num_reqs),
num_reqs=num_reqs,
num_tokens=num_tokens,
input_buffers=self.input_buffers,
device=self.device,
)
if not skip_attn:
block_tables = self.block_tables.gather_block_tables(
input_batch.idx_mapping
)
slot_mappings = self.block_tables.compute_slot_mappings(
input_batch.query_start_loc,
input_batch.positions,
)
attn_metadata = build_attn_metadata(
attn_metadata_builders=self.attn_metadata_builders,
num_reqs=num_reqs,
num_tokens=num_tokens,
query_start_loc=self.input_buffers.query_start_loc,
seq_lens=self.input_buffers.seq_lens,
num_computed_tokens_cpu=None,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
)
input_batch.attn_metadata = attn_metadata
with set_forward_context(
with self.maybe_dummy_run_with_lora(
self.lora_config, input_batch.num_scheduled_tokens
):
with set_forward_context(
input_batch.attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
):
hidden_states = self.model(
input_ids=input_batch.input_ids,
positions=input_batch.positions,
)
sample_hidden_states = hidden_states[input_batch.logits_indices]
):
hidden_states = self.model(
input_ids=input_batch.input_ids,
positions=input_batch.positions,
)
sample_hidden_states = hidden_states[input_batch.logits_indices]
return hidden_states, sample_hidden_states
@torch.inference_mode()
def _dummy_sampler_run(
self,
hidden_states: torch.Tensor,
@ -179,35 +237,80 @@ class GPUModelRunner:
device=self.device,
)
logits = self.model.compute_logits(hidden_states)
self.sampler(logits, sampling_metadata)
self.sampler.sample(logits, sampling_metadata)
@torch.inference_mode()
def profile_run(self) -> None:
input_batch = InputBatch.make_dummy(
num_reqs=self.max_num_reqs,
num_tokens=self.max_num_tokens,
input_buffers=self.input_buffers,
device=self.device,
)
hidden_states, sample_hidden_states = self._dummy_run(
self.max_num_tokens,
input_batch=input_batch,
skip_attn=True,
)
self._dummy_sampler_run(sample_hidden_states)
torch.cuda.synchronize()
del hidden_states, sample_hidden_states
gc.collect()
def reset_mm_cache(self) -> None:
pass
@torch.inference_mode()
def capture_model(self) -> int:
if not self.cudagraph_manager.needs_capture():
logger.warning(
"Skipping CUDA graph capture. To turn on CUDA graph capture, "
"ensure `cudagraph_mode` was not manually set to `NONE`"
)
return 0
start_time = time.perf_counter()
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
with self.maybe_setup_dummy_loras(self.lora_config):
self.cudagraph_manager.capture(
model=self.model,
input_buffers=self.input_buffers,
block_tables=self.block_tables,
attn_metadata_builders=self.attn_metadata_builders,
kv_cache_config=self.kv_cache_config,
)
end_time = time.perf_counter()
end_free_gpu_memory = torch.cuda.mem_get_info()[0]
elapsed_time = end_time - start_time
cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
# This usually takes 5~20 seconds.
logger.info(
"Graph capturing finished in %.0f secs, took %.2f GiB",
elapsed_time,
cuda_graph_size / (1 << 30),
)
return cuda_graph_size
def warmup_for_prefill(self) -> None:
# For FlashInfer, run a dummy prefill to trigger JIT compilation.
if all("FLASHINFER" in b.get_name() for b in self.attn_backends.values()):
self._dummy_run(self.max_num_tokens, skip_attn=False)
torch.cuda.synchronize()
def update_states(self, scheduler_output: SchedulerOutput) -> None:
# for req_id in scheduler_output.preempted_req_ids:
# self.req_states.remove_request(req_id)
for req_id in scheduler_output.preempted_req_ids:
self.req_states.remove_request(req_id)
for req_id in scheduler_output.finished_req_ids:
self.req_states.remove_request(req_id)
# TODO(woosuk): Change SchedulerOutput.
req_indices: list[int] = []
cu_num_new_blocks = tuple(
[0] for _ in range(self.block_tables.num_kv_cache_groups))
new_block_ids = tuple(
[] for _ in range(self.block_tables.num_kv_cache_groups))
[0] for _ in range(self.block_tables.num_kv_cache_groups)
)
new_block_ids = tuple([] for _ in range(self.block_tables.num_kv_cache_groups))
overwrite: list[bool] = []
# Add new requests.
@ -215,9 +318,11 @@ class GPUModelRunner:
req_id = new_req_data.req_id
self.req_states.add_request(
req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids,
prompt_len=len(new_req_data.prompt_token_ids),
prefill_token_ids=new_req_data.prefill_token_ids,
num_computed_tokens=new_req_data.num_computed_tokens,
sampling_params=new_req_data.sampling_params,
lora_request=new_req_data.lora_request,
)
req_index = self.req_states.req_id_to_index[req_id]
@ -250,21 +355,30 @@ class GPUModelRunner:
overwrite=overwrite,
)
def prepare_inputs(self, scheduler_output: SchedulerOutput) -> InputBatch:
def prepare_inputs(
self,
scheduler_output: SchedulerOutput,
use_cudagraph: bool,
padded_num_tokens: int | None,
) -> InputBatch:
num_tokens = scheduler_output.total_num_scheduled_tokens
assert num_tokens > 0
num_reqs = len(scheduler_output.num_scheduled_tokens)
# Decode first, then prefill.
# batch_idx -> req_id
req_ids = sorted(scheduler_output.num_scheduled_tokens,
key=scheduler_output.num_scheduled_tokens.get)
req_ids = sorted(
scheduler_output.num_scheduled_tokens,
key=scheduler_output.num_scheduled_tokens.get,
)
num_scheduled_tokens = np.array(
[scheduler_output.num_scheduled_tokens[i] for i in req_ids],
dtype=np.int32)
# TODO(woosuk): Support CUDA graphs.
num_tokens_after_padding = num_tokens
[scheduler_output.num_scheduled_tokens[i] for i in req_ids], dtype=np.int32
)
if use_cudagraph:
assert padded_num_tokens is not None
num_tokens_after_padding = padded_num_tokens
else:
num_tokens_after_padding = num_tokens
idx_mapping_list = [
self.req_states.req_id_to_index[req_id] for req_id in req_ids
@ -277,9 +391,9 @@ class GPUModelRunner:
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
block_tables = self.block_tables.gather_block_tables(idx_mapping)
max_query_len, max_seq_len = prepare_inputs(
prepare_inputs(
idx_mapping_np,
self.req_states.prompt_token_ids,
self.req_states.prefill_token_ids,
self.req_states.num_computed_tokens,
num_scheduled_tokens,
self.input_buffers.input_ids,
@ -290,10 +404,9 @@ class GPUModelRunner:
)
query_start_loc = self.input_buffers.query_start_loc
query_start_loc_gpu = query_start_loc.gpu[:num_reqs + 1]
query_start_loc_cpu = query_start_loc.cpu[:num_reqs + 1]
query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
query_start_loc_np = query_start_loc.np[: num_reqs + 1]
seq_lens_gpu = self.input_buffers.seq_lens.gpu[:num_reqs]
seq_lens_cpu = self.input_buffers.seq_lens.cpu[:num_reqs]
seq_lens_np = self.input_buffers.seq_lens.np[:num_reqs]
# Some input token ids are directly read from the last sampled tokens.
@ -303,56 +416,33 @@ class GPUModelRunner:
self.req_states.last_sampled_tokens,
query_start_loc_gpu,
seq_lens_gpu,
self.req_states.num_tokens.copy_to_gpu(),
self.req_states.prefill_len.copy_to_gpu(),
)
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
slot_mappings = self.block_tables.compute_slot_mappings(
query_start_loc_gpu, self.input_buffers.positions.gpu[:num_tokens])
query_start_loc_gpu, self.input_buffers.positions.gpu[:num_tokens]
)
num_computed_tokens_cpu = torch.from_numpy(
self.req_states.num_computed_tokens[idx_mapping_np])
# Whether the request is chunked-prefilling or not.
is_chunked_prefilling = (
seq_lens_np < self.req_states.num_tokens.np[idx_mapping_np])
self.req_states.num_computed_tokens[idx_mapping_np]
)
# Logits indices to sample next token from.
logits_indices = query_start_loc_gpu[1:] - 1
num_logits_indices = logits_indices.size(0)
# Layer name -> attention metadata.
attn_metadata: dict[str, Any] = {}
kv_cache_groups = self.kv_cache_config.kv_cache_groups
for i, kv_cache_spec in enumerate(kv_cache_groups):
block_table = block_tables[i]
slot_mapping = slot_mappings[i]
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc_gpu,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens_gpu,
seq_lens_cpu=seq_lens_cpu,
num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
max_seq_len=max_seq_len,
block_table_tensor=block_table,
slot_mapping=slot_mapping,
logits_indices_padded=None,
num_logits_indices=num_logits_indices,
causal=True,
encoder_seq_lens=None,
)
attn_metadata_builder = self.attn_metadata_builders[i]
metadata = attn_metadata_builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
for layer_name in kv_cache_spec.layer_names:
attn_metadata[layer_name] = metadata
attn_metadata = build_attn_metadata(
attn_metadata_builders=self.attn_metadata_builders,
num_reqs=num_reqs,
num_tokens=num_tokens,
query_start_loc=self.input_buffers.query_start_loc,
seq_lens=self.input_buffers.seq_lens,
num_computed_tokens_cpu=num_computed_tokens_cpu,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
)
input_ids = self.input_buffers.input_ids.gpu[:num_tokens_after_padding]
positions = self.input_buffers.positions.gpu[:num_tokens_after_padding]
@ -364,7 +454,10 @@ class GPUModelRunner:
num_scheduled_tokens=num_scheduled_tokens,
num_tokens=num_tokens,
num_tokens_after_padding=num_tokens_after_padding,
is_chunked_prefilling=is_chunked_prefilling,
query_start_loc=query_start_loc_gpu,
query_start_loc_np=query_start_loc_np,
seq_lens=seq_lens_gpu,
seq_lens_np=seq_lens_np,
input_ids=input_ids,
positions=positions,
attn_metadata=attn_metadata,
@ -375,102 +468,221 @@ class GPUModelRunner:
self,
hidden_states: torch.Tensor,
input_batch: InputBatch,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
sample_hidden_states = hidden_states[input_batch.logits_indices]
logits = self.model.compute_logits(sample_hidden_states)
pos = input_batch.positions[input_batch.logits_indices]
idx_mapping_np = input_batch.idx_mapping_np
num_reqs = logits.shape[0]
# When the batch size is large enough, use DP sampler.
tp_group = get_tp_group()
tp_size = tp_group.world_size
n = (num_reqs + tp_size - 1) // tp_size
use_dp_sampler = tp_size > 1 and n > 32 # TODO(woosuk): Tune.
if use_dp_sampler:
# NOTE(woosuk): Make sure that no rank gets zero requests.
tp_rank = tp_group.rank
start, end = evenly_split(num_reqs, tp_size, tp_rank)
logits = logits[start:end]
pos = pos[start:end]
idx_mapping_np = idx_mapping_np[start:end]
sampling_metadata = self.req_states.make_sampling_metadata(
idx_mapping_np, pos)
sampler_output = self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
)
needs_prompt_logprobs = np.any(
self.req_states.needs_prompt_logprobs[idx_mapping_np])
assert not needs_prompt_logprobs
if use_dp_sampler:
# All-gather the outputs.
sampler_output = all_gather_sampler_output(
sampler_output,
num_reqs,
tp_size,
)
sampler_output = self.sampler.sample(logits, sampling_metadata)
return sampler_output
def compute_prompt_logprobs(
self,
hidden_states: torch.Tensor,
input_batch: InputBatch,
) -> dict[str, LogprobsTensors]:
idx_mapping_np = input_batch.idx_mapping_np
needs_prompt_logprobs = self.req_states.needs_prompt_logprobs[idx_mapping_np]
if not np.any(needs_prompt_logprobs):
# No request asks for prompt logprobs.
return {}
num_computed_tokens = self.req_states.num_computed_tokens[idx_mapping_np]
prompt_lens = self.req_states.prompt_len[idx_mapping_np]
# NOTE(woosuk): -1 because the last prompt token's hidden state is not
# needed for prompt logprobs.
includes_prompt = num_computed_tokens < prompt_lens - 1
# NOTE(woosuk): If the request was resumed after preemption, its prompt
# logprobs must have been computed before preemption. Skip.
resumed_after_prompt = (
prompt_lens < self.req_states.prefill_len.np[idx_mapping_np]
)
needs_prompt_logprobs &= includes_prompt & ~resumed_after_prompt
if not np.any(needs_prompt_logprobs):
return {}
# Just to be safe, clone the input ids.
n = input_batch.num_tokens
# Shift the input ids by one.
token_ids = torch.empty_like(input_batch.input_ids[:n])
token_ids[: n - 1] = input_batch.input_ids[1:n]
# To avoid out-of-bound access, set the last token id to 0.
token_ids[n - 1] = 0
# Handle chunked prompts.
seq_lens = self.input_buffers.seq_lens.np[: input_batch.num_reqs]
is_prompt_chunked = seq_lens < prompt_lens
prefill_token_ids = self.req_states.prefill_token_ids
query_start_loc = self.input_buffers.query_start_loc.np
for i, req_id in enumerate(input_batch.req_ids):
if not needs_prompt_logprobs[i]:
continue
if not is_prompt_chunked[i]:
continue
# The prompt is chunked. Get the next prompt token.
req_idx = input_batch.idx_mapping_np[i]
next_prompt_token = int(prefill_token_ids[req_idx, seq_lens[i]])
idx = int(query_start_loc[i + 1] - 1)
# Set the next prompt token.
# NOTE(woosuk): This triggers a GPU operation.
token_ids[idx] = next_prompt_token
# NOTE(woosuk): We mask out logprobs for negative tokens.
prompt_logprobs, prompt_ranks = compute_prompt_logprobs(
torch.relu(token_ids),
hidden_states[:n],
self.model.compute_logits,
)
prompt_logprobs[:, 0].masked_fill_(token_ids < 0, 0)
prompt_token_ids = token_ids.unsqueeze(-1)
prompt_logprobs_dict: dict[str, LogprobsTensors] = {}
for i, req_id in enumerate(input_batch.req_ids):
if not needs_prompt_logprobs[i]:
continue
start_idx = query_start_loc[i]
end_idx = query_start_loc[i + 1]
assert start_idx < end_idx, (
f"start_idx ({start_idx}) >= end_idx ({end_idx})"
)
logprobs = LogprobsTensors(
logprob_token_ids=prompt_token_ids[start_idx:end_idx],
logprobs=prompt_logprobs[start_idx:end_idx],
selected_token_ranks=prompt_ranks[start_idx:end_idx],
)
req_extra_data = self.req_states.extra_data[req_id]
prompt_logprobs_list = req_extra_data.in_progress_prompt_logprobs
if is_prompt_chunked[i]:
# Prompt is chunked. Do not return the logprobs yet.
prompt_logprobs_list.append(logprobs)
continue
if prompt_logprobs_list:
# Merge the in-progress logprobs.
prompt_logprobs_list.append(logprobs)
logprobs = LogprobsTensors(
logprob_token_ids=torch.cat(
[x.logprob_token_ids for x in prompt_logprobs_list]
),
logprobs=torch.cat([x.logprobs for x in prompt_logprobs_list]),
selected_token_ranks=torch.cat(
[x.selected_token_ranks for x in prompt_logprobs_list]
),
)
prompt_logprobs_list.clear()
prompt_logprobs_dict[req_id] = logprobs
return prompt_logprobs_dict
def postprocess(
self,
sampler_output: SamplerOutput,
sampling_metadata: SamplingMetadata,
prompt_logprobs_dict: dict[str, LogprobsTensors],
input_batch: InputBatch,
) -> AsyncOutput:
) -> AsyncOutput | ModelRunnerOutput:
# Store the last sampled token ids.
self.req_states.last_sampled_tokens[input_batch.idx_mapping] = (
sampler_output.sampled_token_ids)
sampler_output.sampled_token_ids
)
# Get the number of sampled tokens.
# 0 if chunked-prefilling, 1 if not.
is_chunked_prefilling = input_batch.is_chunked_prefilling
idx_mapping_np = input_batch.idx_mapping_np
is_chunked_prefilling = (
input_batch.seq_lens_np < self.req_states.num_tokens[idx_mapping_np]
)
num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32)
# Increment the number of tokens.
idx_mapping_np = input_batch.idx_mapping_np
self.req_states.num_tokens.np[idx_mapping_np] += num_sampled_tokens
self.req_states.num_tokens[idx_mapping_np] += num_sampled_tokens
# Increment the number of computed tokens.
self.req_states.num_computed_tokens[idx_mapping_np] += (
input_batch.num_scheduled_tokens)
input_batch.num_scheduled_tokens
)
model_runner_output = ModelRunnerOutput(
req_ids=input_batch.req_ids,
req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
sampled_token_ids=None,
num_sampled_tokens=num_sampled_tokens,
logprobs=None,
prompt_logprobs_dict={},
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
kv_connector_output=None,
num_nans_in_logits=None,
)
return AsyncOutput(
async_output = AsyncOutput(
model_runner_output=model_runner_output,
sampler_output=sampler_output,
num_sampled_tokens=num_sampled_tokens,
copy_stream=self.output_copy_stream,
)
if self.use_async_scheduling:
return async_output
return async_output.get_output()
@torch.inference_mode()
def execute_model(
self,
scheduler_output: SchedulerOutput,
) -> AsyncOutput:
self.update_states(scheduler_output)
if scheduler_output.total_num_scheduled_tokens == 0:
return EMPTY_MODEL_RUNNER_OUTPUT
intermediate_tensors: Any | None = None,
) -> AsyncOutput | ModelRunnerOutput:
assert intermediate_tensors is None
input_batch = self.prepare_inputs(scheduler_output)
num_tokens = input_batch.num_tokens_after_padding
with set_forward_context(
input_batch.attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
with async_barrier(
self.input_prep_event if self.use_async_scheduling else None
):
hidden_states = self.model(
input_ids=input_batch.input_ids,
positions=input_batch.positions,
self.update_states(scheduler_output)
if scheduler_output.total_num_scheduled_tokens == 0:
return EMPTY_MODEL_RUNNER_OUTPUT
padded_num_tokens = self.cudagraph_manager.get_cudagraph_size(
scheduler_output
)
use_cudagraph = padded_num_tokens is not None
input_batch = self.prepare_inputs(
scheduler_output,
use_cudagraph,
padded_num_tokens,
)
pos = input_batch.positions[input_batch.logits_indices]
idx_mapping_np = input_batch.idx_mapping_np
sampling_metadata = self.req_states.make_sampling_metadata(
idx_mapping_np, pos
)
sampler_output = self.sample(hidden_states, input_batch)
return self.postprocess(sampler_output, input_batch)
if self.lora_config:
# Activate LoRA adapters.
lora_inputs = self.req_states.make_lora_inputs(
input_batch.req_ids,
input_batch.idx_mapping_np,
input_batch.num_scheduled_tokens,
)
self._set_active_loras(*lora_inputs)
# Run model.
if use_cudagraph:
# Run CUDA graph.
# NOTE(woosuk): Here, we don't need to pass the input tensors,
# because they are already copied to the CUDA graph input buffers.
hidden_states = self.cudagraph_manager.run(padded_num_tokens)
else:
with set_forward_context(
input_batch.attn_metadata,
self.vllm_config,
num_tokens=input_batch.num_tokens_after_padding,
):
# Run PyTorch model in eager mode.
hidden_states = self.model(
input_ids=input_batch.input_ids,
positions=input_batch.positions,
)
sampler_output = self.sample(hidden_states, input_batch, sampling_metadata)
prompt_logprobs_dict = self.compute_prompt_logprobs(hidden_states, input_batch)
output = self.postprocess(
sampler_output,
sampling_metadata,
prompt_logprobs_dict,
input_batch,
)
return output
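The bookkeeping in postprocess() hinges on one rule: a request that is still midway through chunked prefill produces no sampled token this step. A small sketch with made-up numbers:

import numpy as np

seq_lens_np = np.array([128, 512, 33], dtype=np.int32)    # computed + scheduled this step
total_tokens = np.array([128, 2048, 33], dtype=np.int32)  # prompt + generated so far

is_chunked_prefilling = seq_lens_np < total_tokens              # [False, True, False]
num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32)  # [1, 0, 1]
# Request 1 has only reached 512 of its 2048 prefill tokens, so nothing is
# sampled for it this step; the other two requests each emit one new token.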

View File

@ -1,61 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
import torch.nn as nn
import triton
import triton.language as tl
from vllm.config import LogprobsMode
from vllm.config.model import LogprobsMode
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.worker.gpu.states import SamplingMetadata
_SAMPLING_EPS = 1e-5
class Sampler(nn.Module):
class Sampler:
def __init__(
self,
logprobs_mode: LogprobsMode = "processed_logprobs",
logprobs_mode: LogprobsMode = "raw_logprobs",
):
super().__init__()
assert logprobs_mode == "processed_logprobs"
if logprobs_mode not in ["processed_logprobs", "raw_logprobs"]:
raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}")
self.logprobs_mode = logprobs_mode
    def sample_token(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        return_logits: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        # Divide logits by temperature. Greedy rows (temperature == 0) use a
        # placeholder temperature of 1.0 and skip the Gumbel noise below, so
        # argmax still returns the greedy token.
        is_greedy = sampling_metadata.temperature == 0
        temp = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
        logits = logits / temp.view(-1, 1)
        # Apply top_k and/or top_p.
        logits = apply_top_k_top_p(
            logits, sampling_metadata.top_k, sampling_metadata.top_p
        )
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
        # Sample the next token (int64).
        sampled = gumbel_sample(
            probs,
            sampling_metadata.temperature,
            sampling_metadata.seeds,
            sampling_metadata.pos,
        )
        sampled = sampled.to(torch.int64)
        return sampled, logits if return_logits else None
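    # Editor's note (not part of this commit): a minimal sketch of the greedy
    # handling above, assuming plain PyTorch in place of the Triton kernels.
    # Rows with temperature == 0 are divided by a placeholder temperature of
    # 1.0 and receive no noise, so argmax recovers the greedy token, while the
    # other rows are sampled with the same exponential-race trick:
    #
    #     def sample_rows(logits: torch.Tensor, temperature: torch.Tensor) -> torch.Tensor:
    #         is_greedy = temperature == 0
    #         temp = torch.where(is_greedy, torch.ones_like(temperature), temperature)
    #         probs = torch.softmax(logits / temp.view(-1, 1), dim=-1)
    #         noise = torch.empty_like(probs).exponential_()
    #         # Greedy rows skip the noise, so argmax returns the most likely token.
    #         noise = torch.where(is_greedy.view(-1, 1), torch.ones_like(noise), noise)
    #         return torch.argmax(probs / noise, dim=-1)
    #
    #     print(sample_rows(torch.randn(4, 16), torch.tensor([0.0, 0.7, 1.0, 0.0])))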
    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        if sampling_metadata.max_num_logprobs is not None:
            if self.logprobs_mode == "processed_logprobs":
                sampled, logits = self.sample_token(
                    logits, sampling_metadata, return_logits=True
                )
            else:
                assert self.logprobs_mode == "raw_logprobs"
                sampled, _ = self.sample_token(
                    logits, sampling_metadata, return_logits=False
                )
            logprobs_tensors = compute_topk_logprobs(
                logits,
                sampling_metadata.max_num_logprobs,
                sampled,
            )
        else:
            sampled, _ = self.sample_token(
                logits, sampling_metadata, return_logits=False
            )
            logprobs_tensors = None
        # These are GPU tensors.
        sampler_output = SamplerOutput(
@ -69,60 +84,7 @@ class Sampler(nn.Module):
@triton.jit
def _apply_temp_kernel(
logits, # bf16[batch_size, vocab_size]
logits_stride,
output, # fp32[batch_size, vocab_size]
output_stride,
temperature,
vocab_size,
BLOCK_SIZE: tl.constexpr,
EPSILON: tl.constexpr,
):
batch_idx = tl.program_id(0)
block_idx = tl.program_id(1)
temp = tl.load(temperature + batch_idx)
if temp < EPSILON:
# Greedy sampling. Don't apply temperature.
# NOTE(woosuk): In this case, we assume that its logprobs are not used.
temp = 1.0
offset = tl.arange(0, BLOCK_SIZE)
block = block_idx * BLOCK_SIZE + offset
# Load the logits.
x = tl.load(logits + batch_idx * logits_stride + block,
mask=block < vocab_size)
x = x.to(tl.float32)
x = x / temp
tl.store(output + batch_idx * output_stride + block,
x,
mask=block < vocab_size)
def apply_temperature(
logits: torch.Tensor,
temperature: torch.Tensor,
) -> torch.Tensor:
batch_size, vocab_size = logits.shape
output = torch.empty_like(logits, dtype=torch.float32)
BLOCK_SIZE = 8192
_apply_temp_kernel[(batch_size, triton.cdiv(vocab_size, BLOCK_SIZE))](
logits,
logits.stride(0),
output,
output.stride(0),
temperature,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
EPSILON=_SAMPLING_EPS,
)
return output
@triton.jit
def _gumbel_sample_kernel(
probs_ptr,
probs_stride,
seeds_ptr,
@ -130,18 +92,17 @@ def _apply_gumbel_kernel(
temp_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
temp = tl.load(temp_ptr + req_idx)
    if temp == 0.0:
# Greedy sampling. Don't apply gumbel noise.
return
    seed = tl.load(seeds_ptr + req_idx)
    pos = tl.load(pos_ptr + req_idx)
    gumbel_seed = tl.randint(seed, pos)
block_id = tl.program_id(1)
r_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
@ -153,42 +114,33 @@ def _apply_gumbel_kernel(
q = tl.where(q >= RMAX, RMAX_LOG, tl.math.log(q))
q = -1.0 * q
    p = tl.load(
        probs_ptr + req_idx * probs_stride + r_offset, mask=r_offset < vocab_size
    )
    p = p / q
    tl.store(
        probs_ptr + req_idx * probs_stride + r_offset, p, mask=r_offset < vocab_size
    )
def gumbel_sample(
    probs: torch.Tensor,  # [num_reqs, vocab_size]
    temperature: torch.Tensor,  # [num_reqs]
    seed: torch.Tensor,  # [num_reqs]
    pos: torch.Tensor,  # [num_reqs]
) -> torch.Tensor:
    num_reqs, vocab_size = probs.shape
    # Update the probs in-place.
    _gumbel_sample_kernel[(num_reqs,)](
        probs,
        probs.stride(0),
        seed,
        pos,
        temperature,
        vocab_size,
        BLOCK_SIZE=8192,  # type: ignore
    )
    # Sample the next token.
    sampled = probs.argmax(dim=-1)
    return sampled
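# Editor's note (not part of this commit): gumbel_sample relies on the
# exponential-race form of the Gumbel-max trick. Dividing each probability by
# an independent Exp(1) draw (q = -log(u)) and taking the argmax yields a token
# with exactly the categorical probabilities. A small self-contained check,
# assuming only PyTorch:
def _gumbel_sample_sanity_check() -> None:
    torch.manual_seed(0)
    probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
    # Reference: sample directly from the categorical distribution.
    direct = torch.multinomial(probs, num_samples=200_000, replacement=True)
    # Exponential race: argmax of p_i / E_i with E_i ~ Exp(1), i.e. p_i / (-log u_i).
    u = torch.rand(200_000, probs.numel())
    race = torch.argmax(probs / (-torch.log(u)), dim=-1)
    for counts in (torch.bincount(direct, minlength=4), torch.bincount(race, minlength=4)):
        # Both empirical distributions should be close to [0.1, 0.2, 0.3, 0.4].
        print((counts / counts.sum()).tolist())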
@triton.jit
@ -208,54 +160,31 @@ def _topk_log_softmax_kernel(
max_val = float("-inf")
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
        logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
        max_val = tl.max(tl.maximum(logits, max_val))
    max_val = max_val.to(tl.float32)
    se = 0.0
    for i in range(0, vocab_size, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        logits = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0)
        # NOTE(woosuk): Make sure that logits and all following operations are
        # in float32.
        logits = logits.to(tl.float32)
        e = tl.exp(logits - max_val)
        e = tl.where(block < vocab_size, e, 0.0)
        se += tl.sum(e)
    lse = tl.log(se)
    k_offset = tl.arange(0, PADDED_TOPK)
    k_mask = k_offset < topk
    topk_ids = tl.load(topk_ids_ptr + req_idx * topk + k_offset, mask=k_mask, other=0)
    logits = tl.load(row_ptr + topk_ids, mask=k_mask)
    logits = logits.to(tl.float32)
    o = logits - max_val - lse
tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)
@triton.jit
def _ranks_kernel(
output_ptr,
@ -274,14 +203,39 @@ def _ranks_kernel(
n = 0
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
        logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
        n += tl.sum((logits > x).to(tl.int32))
tl.store(output_ptr + req_idx, n)
def compute_token_logprobs(
logits: torch.Tensor,
token_ids: torch.Tensor,
) -> torch.Tensor:
batch_size = logits.shape[0]
vocab_size = logits.shape[1]
token_ids = token_ids.to(torch.int64)
num_logprobs = token_ids.shape[1]
logprobs = torch.empty(
batch_size,
num_logprobs,
dtype=torch.float32,
device=logits.device,
)
_topk_log_softmax_kernel[(batch_size,)](
logprobs,
logits,
logits.stride(0),
token_ids,
num_logprobs,
vocab_size,
BLOCK_SIZE=1024, # type: ignore
PADDED_TOPK=triton.next_power_of_2(num_logprobs),
)
return logprobs
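# Editor's note (not part of this commit): a reference implementation of the
# gather performed by compute_token_logprobs, assuming plain PyTorch. It
# materializes the full [batch, vocab] log-probs, which the Triton kernel
# above deliberately avoids, but it is handy for testing on small inputs.
def compute_token_logprobs_reference(
    logits: torch.Tensor,
    token_ids: torch.Tensor,
) -> torch.Tensor:
    logprobs = torch.log_softmax(logits.float(), dim=-1)
    return logprobs.gather(dim=-1, index=token_ids.to(torch.int64))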
def compute_topk_logprobs(
logits: torch.Tensor,
num_logprobs: int,
sampled_token_ids: torch.Tensor,
@ -293,31 +247,56 @@ def compute_logprobs(
else:
topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
logprob_token_ids = torch.cat(
        (sampled_token_ids.unsqueeze(-1), topk_indices), dim=1
    )
# NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
# logprobs tensor. Instead, we only compute and return the logprobs of
# the topk + 1 tokens.
    logprobs = compute_token_logprobs(logits, logprob_token_ids)
token_ranks = torch.empty(
batch_size,
dtype=torch.int64,
device=logits.device,
)
    _ranks_kernel[(batch_size,)](
token_ranks,
logits,
logits.stride(0),
sampled_token_ids,
vocab_size,
        BLOCK_SIZE=8192,  # type: ignore
)
return LogprobsTensors(
logprob_token_ids=logprob_token_ids,
logprobs=logprobs,
selected_token_ranks=token_ranks,
)
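# Editor's note (not part of this commit): the LogprobsTensors built above can
# be reproduced with plain PyTorch, which may help when reading the kernels.
# The sampled token id is prepended to the top-k ids, log-probs are gathered
# for those ids, and the rank counts how many vocabulary entries score
# strictly higher than the sampled token. A hedged sketch:
def compute_topk_logprobs_reference(
    logits: torch.Tensor,
    num_logprobs: int,
    sampled_token_ids: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    topk_ids = torch.topk(logits, num_logprobs, dim=-1).indices
    ids = torch.cat((sampled_token_ids.unsqueeze(-1), topk_ids), dim=1)
    logprobs = torch.log_softmax(logits.float(), dim=-1).gather(-1, ids)
    # Rank = number of logits strictly greater than the sampled token's logit.
    sampled_logits = logits.gather(-1, sampled_token_ids.unsqueeze(-1))
    ranks = (logits > sampled_logits).sum(dim=-1)
    return ids, logprobs, ranks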
def compute_prompt_logprobs(
prompt_token_ids: torch.Tensor,
prompt_hidden_states: torch.Tensor,
logits_fn: Callable[[torch.Tensor], torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
# Since materializing the full prompt logits can take too much memory,
# we compute it in chunks.
CHUNK_SIZE = 1024
logprobs = []
ranks = []
prompt_token_ids = prompt_token_ids.to(torch.int64)
for start_idx in range(0, prompt_token_ids.shape[0], CHUNK_SIZE):
end_idx = start_idx + CHUNK_SIZE
# NOTE(woosuk): logits_fn can be slow because it involves all-gather.
prompt_logits = logits_fn(prompt_hidden_states[start_idx:end_idx])
prompt_logprobs = compute_topk_logprobs(
prompt_logits,
0, # num_logprobs
prompt_token_ids[start_idx:end_idx],
)
logprobs.append(prompt_logprobs.logprobs)
ranks.append(prompt_logprobs.selected_token_ranks)
logprobs = torch.cat(logprobs, dim=0) if len(logprobs) > 1 else logprobs[0]
ranks = torch.cat(ranks, dim=0) if len(ranks) > 1 else ranks[0]
return logprobs, ranks
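# Editor's note (not part of this commit): the chunking above keeps at most a
# [CHUNK_SIZE, vocab] logits slice alive instead of the whole prompt's logits.
# A CPU-friendly sketch of the same idea, using a full log_softmax per chunk
# instead of the Triton helpers and a stand-in Linear head as logits_fn:
def _chunked_prompt_logprobs_sketch() -> None:
    hidden_size, vocab_size, prompt_len, chunk_size = 64, 32_000, 3_000, 1024
    head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
    hidden = torch.randn(prompt_len, hidden_size)
    tokens = torch.randint(0, vocab_size, (prompt_len,))
    out = []
    for start in range(0, prompt_len, chunk_size):
        # Only this chunk's logits are materialized at any point in time.
        logits = head(hidden[start : start + chunk_size])
        logprobs = torch.log_softmax(logits.float(), dim=-1)
        ids = tokens[start : start + chunk_size].unsqueeze(-1)
        out.append(logprobs.gather(-1, ids).squeeze(-1))
    print(torch.cat(out, dim=0).shape)  # torch.Size([3000])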

View File

@ -1,21 +1,22 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
import numpy as np
import torch
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.utils import CpuGpuBuffer
_NP_INT64_MIN = np.iinfo(np.int64).min
_NP_INT64_MAX = np.iinfo(np.int64).max
NO_LORA_ID = 0
@dataclass
class SamplingMetadata:
temperature: torch.Tensor
top_p: torch.Tensor | None
@ -36,12 +37,14 @@ class SamplingMetadata:
assert num_reqs > 0
temperature = torch.zeros(num_reqs, dtype=torch.float32, device=device)
temperature[0] = 0.5
        # TODO(woosuk): Use top-p and top-k for dummy sampler.
        # Currently, they are disabled because of memory usage.
        top_p = None
        top_k = None
seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device)
pos = torch.zeros(num_reqs, dtype=torch.int64, device=device)
max_num_logprobs = 20
return cls(
temperature=temperature,
top_p=top_p,
@ -53,7 +56,6 @@ class SamplingMetadata:
class RequestState:
def __init__(
self,
max_num_reqs: int,
@ -73,15 +75,15 @@ class RequestState:
self.req_id_to_index: dict[str, int] = {}
self.index_to_req_id: dict[int, str] = {}
self.free_indices = list(range(max_num_reqs))
self.extra_data: dict[str, ExtraData] = {}
        self.prompt_len = np.zeros(self.max_num_reqs, dtype=np.int32)
        # NOTE(woosuk): Strictly speaking, this may contain the prompt plus
        # some output tokens, because of preemption.
        self.prefill_token_ids = np.zeros(
            (self.max_num_reqs, self.max_model_len),
            dtype=np.int32,
        )
        self.prefill_len = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
        self.num_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
self.num_computed_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
# Last sampled tokens.
@ -92,6 +94,10 @@ class RequestState:
device=device,
)
# LoRA.
self.lora_ids = np.zeros(self.max_num_reqs, dtype=np.int32)
self.lora_ids.fill(NO_LORA_ID)
# Sampling parameters.
self.temperature = self._make_param(self.max_num_reqs, torch.float32)
self.top_p = self._make_param(self.max_num_reqs, torch.float32)
@ -104,16 +110,12 @@ class RequestState:
self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
def _make_param(self, size: int, dtype: torch.dtype) -> "Param":
        return Param(size, dtype=dtype, device=self.device, pin_memory=self.pin_memory)
def _make_buffer(self, size: int, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(
size, dtype=dtype, device=self.device, pin_memory=self.pin_memory
)
@property
def num_reqs(self) -> int:
@ -122,23 +124,32 @@ class RequestState:
def add_request(
self,
req_id: str,
prompt_len: int,
prefill_token_ids: list[int],
num_computed_tokens: int,
sampling_params: SamplingParams,
lora_request: LoRARequest | None,
) -> None:
assert len(self.free_indices) > 0, "No free indices"
req_idx = self.free_indices.pop()
self.req_id_to_index[req_id] = req_idx
self.index_to_req_id[req_idx] = req_id
self.extra_data[req_id] = ExtraData(lora_request)
        self.prompt_len[req_idx] = prompt_len
        # NOTE(woosuk): Strictly speaking, the prefill tokens may include some
        # output tokens, if the request is resumed from preemption.
        prefill_len = len(prefill_token_ids)
        assert prefill_len >= prompt_len, (
            f"prefill_len {prefill_len} < prompt_len {prompt_len}"
        )
        self.prefill_len.np[req_idx] = prefill_len
        self.prefill_token_ids[req_idx, :prefill_len] = prefill_token_ids
        self.num_tokens[req_idx] = prefill_len
self.num_computed_tokens[req_idx] = num_computed_tokens
# TODO(woosuk): Optimize.
self.last_sampled_tokens[req_idx].fill_(-1)
if lora_request is not None:
self.lora_ids[req_idx] = lora_request.lora_int_id
else:
self.lora_ids[req_idx] = NO_LORA_ID
self.temperature.np[req_idx] = sampling_params.temperature
self.top_p.np[req_idx] = sampling_params.top_p
@ -165,6 +176,7 @@ class RequestState:
self.needs_prompt_logprobs[req_idx] = needs_prompt_logprobs
def remove_request(self, req_id: str) -> None:
self.extra_data.pop(req_id, None)
req_idx = self.req_id_to_index.pop(req_id, None)
if req_idx is None:
# Request not found.
@ -205,9 +217,25 @@ class RequestState:
max_num_logprobs=max_num_logprobs,
)
def make_lora_inputs(
self,
req_ids: list[str],
idx_mapping: np.ndarray,
num_scheduled_tokens: np.ndarray,
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
lora_ids = self.lora_ids[idx_mapping]
prompt_lora_mapping = tuple(lora_ids)
token_lora_mapping = tuple(lora_ids.repeat(num_scheduled_tokens))
active_lora_requests: set[LoRARequest] = set()
for req_id in req_ids:
lora_request = self.extra_data[req_id].lora_request
if lora_request is not None:
active_lora_requests.add(lora_request)
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
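# Editor's note (not part of this commit): a small illustration of the index
# handling in make_lora_inputs, using made-up values. idx_mapping selects the
# per-request LoRA ids in batch order, and numpy's repeat expands them to one
# id per scheduled token:
def _make_lora_inputs_example() -> None:
    lora_ids = np.array([0, 3, 0, 7], dtype=np.int32)  # per-slot ids (0 = no LoRA)
    idx_mapping = np.array([3, 1, 0])  # request slots scheduled this step
    num_scheduled_tokens = np.array([2, 1, 3])  # tokens scheduled per request
    batch_lora_ids = lora_ids[idx_mapping]  # ids in batch order: 7, 3, 0
    prompt_lora_mapping = tuple(batch_lora_ids)  # one id per request
    token_lora_mapping = tuple(batch_lora_ids.repeat(num_scheduled_tokens))
    print(prompt_lora_mapping)  # one entry per request
    print(token_lora_mapping)  # expands to 7, 7, 3, 0, 0, 0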
class Param:
def __init__(
self,
size: int,
@ -227,3 +255,9 @@ class Param:
n = x.shape[0]
self.buffer.np[:n] = x
return self.buffer.copy_to_gpu(n)
@dataclass
class ExtraData:
lora_request: LoRARequest | None
in_progress_prompt_logprobs: list[LogprobsTensors] = field(default_factory=list)

View File

@ -42,6 +42,7 @@ from vllm.v1.outputs import (
ModelRunnerOutput,
)
from vllm.v1.utils import report_usage_stats
# from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.gpu.model_runner import GPUModelRunner
from vllm.v1.worker.utils import is_residual_scattered_for_sp
@ -495,6 +496,8 @@ class Worker(WorkerBase):
self,
scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
return self.model_runner.execute_model(scheduler_output)
intermediate_tensors = None
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens