From 09e4b2f6ebc4695ab2b60b64290fda8b5ea05de2 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 30 Oct 2025 16:30:06 -0700 Subject: [PATCH] update --- requirements/cuda.txt | 3 +- vllm/v1/core/sched/output.py | 3 + vllm/v1/core/sched/scheduler.py | 13 +- vllm/v1/outputs.py | 10 +- vllm/v1/worker/gpu/async_utils.py | 61 ++- vllm/v1/worker/gpu/attn_utils.py | 101 ++++- vllm/v1/worker/gpu/block_table.py | 109 +++--- vllm/v1/worker/gpu/cudagraph_utils.py | 175 +++++++++ vllm/v1/worker/gpu/input_batch.py | 116 +++--- vllm/v1/worker/gpu/model_runner.py | 522 ++++++++++++++++++-------- vllm/v1/worker/gpu/sampler.py | 293 +++++++-------- vllm/v1/worker/gpu/states.py | 94 +++-- vllm/v1/worker/gpu_worker.py | 3 + 13 files changed, 1001 insertions(+), 502 deletions(-) create mode 100644 vllm/v1/worker/gpu/cudagraph_utils.py diff --git a/requirements/cuda.txt b/requirements/cuda.txt index 5f7d520cd3..b9fe96dd7a 100644 --- a/requirements/cuda.txt +++ b/requirements/cuda.txt @@ -10,6 +10,7 @@ torchaudio==2.9.0 # These must be updated alongside torch torchvision==0.24.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version # Build from https://github.com/facebookresearch/xformers/releases/tag/v0.0.32.post1 -xformers==0.0.33+5d4b92a5.d20251029; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.9 +# xformers==0.0.33+5d4b92a5.d20251029; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.9 # FlashInfer should be updated together with the Dockerfile flashinfer-python==0.4.1 +apache-tvm-ffi==0.1.0b15 diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index cc6b89e2bf..efd64e671c 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -34,6 +34,7 @@ else: class NewRequestData: req_id: str prompt_token_ids: list[int] | None + prefill_token_ids: list[int] | None mm_features: list[MultiModalFeatureSpec] sampling_params: SamplingParams | None pooling_params: PoolingParams | None @@ -51,6 +52,7 @@ class NewRequestData: return cls( req_id=request.request_id, prompt_token_ids=request.prompt_token_ids, + prefill_token_ids=request._all_token_ids, mm_features=request.mm_features, sampling_params=request.sampling_params, pooling_params=request.pooling_params, @@ -173,6 +175,7 @@ class SchedulerOutput: # This can be used for cascade attention. num_common_prefix_blocks: list[int] + preempted_req_ids: set[str] # Request IDs that are finished in between the previous and the current # steps. This is used to notify the workers about the finished requests # so that they can free the cached states for those requests. diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 9407858f7e..34eb03a32d 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -606,6 +606,9 @@ class Scheduler(SchedulerInterface): ) # Construct the scheduler output. 
+ scheduled_new_reqs = scheduled_new_reqs + scheduled_resumed_reqs + scheduled_resumed_reqs = [] + new_reqs_data = [ NewRequestData.from_request( req, req_to_new_blocks[req.request_id].get_block_ids() @@ -635,6 +638,7 @@ class Scheduler(SchedulerInterface): scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, scheduled_encoder_inputs=scheduled_encoder_inputs, num_common_prefix_blocks=num_common_prefix_blocks, + preempted_req_ids={req.request_id for req in preempted_reqs}, # finished_req_ids is an existing state in the scheduler, # instead of being newly scheduled in this step. # It contains the request IDs that are finished in between @@ -720,14 +724,6 @@ class Scheduler(SchedulerInterface): req.num_computed_tokens : req.num_computed_tokens + num_tokens ] new_token_ids.append(token_ids) - scheduled_in_prev_step = req_id in self.prev_step_scheduled_req_ids - if idx >= num_running_reqs: - assert not scheduled_in_prev_step - resumed_req_ids.add(req_id) - if not scheduled_in_prev_step: - all_token_ids[req_id] = req.all_token_ids[ - : req.num_computed_tokens + num_tokens - ] new_block_ids.append( req_to_new_blocks[req_id].get_block_ids(allow_none=True) ) @@ -902,7 +898,6 @@ class Scheduler(SchedulerInterface): model_runner_output: ModelRunnerOutput, ) -> dict[int, EngineCoreOutputs]: sampled_token_ids = model_runner_output.sampled_token_ids - num_sampled_tokens = model_runner_output.num_sampled_tokens logprobs = model_runner_output.logprobs prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict num_scheduled_tokens = scheduler_output.num_scheduled_tokens diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index e9f42a2479..524ab01c0a 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -15,7 +15,6 @@ else: class LogprobsLists(NamedTuple): - # [num_reqs, max_num_logprobs + 1] logprob_token_ids: np.ndarray # [num_reqs, max_num_logprobs + 1] @@ -135,13 +134,14 @@ class KVConnectorOutput: class ModelRunnerOutput: # [num_reqs] req_ids: list[str] + # req_id -> index + req_id_to_index: dict[str, int] # num_reqs x num_generated_tokens # num_generated_tokens is the number of tokens # generated in the current step. It can be different for # each request due to speculative/jump decoding. 
- sampled_token_ids: np.ndarray | None - num_sampled_tokens: np.ndarray | None + sampled_token_ids: list[list[int]] # [num_reqs, max_num_logprobs + 1] # [num_reqs, max_num_logprobs + 1] @@ -186,8 +186,8 @@ class DraftTokenIds: EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput( req_ids=[], - sampled_token_ids=None, - num_sampled_tokens=None, + req_id_to_index={}, + sampled_token_ids=[], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], diff --git a/vllm/v1/worker/gpu/async_utils.py b/vllm/v1/worker/gpu/async_utils.py index ed11701739..29f4f1e5a4 100644 --- a/vllm/v1/worker/gpu/async_utils.py +++ b/vllm/v1/worker/gpu/async_utils.py @@ -1,21 +1,28 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from contextlib import contextmanager + +import numpy as np import torch -from vllm.v1.outputs import (AsyncModelRunnerOutput, LogprobsTensors, - ModelRunnerOutput, SamplerOutput) +from vllm.v1.outputs import ( + AsyncModelRunnerOutput, + ModelRunnerOutput, + SamplerOutput, +) class AsyncOutput(AsyncModelRunnerOutput): - def __init__( self, model_runner_output: ModelRunnerOutput, sampler_output: SamplerOutput, + num_sampled_tokens: np.ndarray, copy_stream: torch.cuda.Stream, ): self.model_runner_output = model_runner_output self.sampler_output = sampler_output + self.num_sampled_tokens = num_sampled_tokens self.copy_stream = copy_stream self.copy_event = torch.cuda.Event() @@ -23,26 +30,46 @@ class AsyncOutput(AsyncModelRunnerOutput): with torch.cuda.stream(self.copy_stream): self.copy_stream.wait_stream(default_stream) + # NOTE(woosuk): We should keep the CPU tensors unfreed, until the copy completes. self.sampled_token_ids = sampler_output.sampled_token_ids.to( - "cpu", non_blocking=True) - x = sampler_output.logprobs_tensors - if x is not None: - self.logprobs_tensors = LogprobsTensors( - logprob_token_ids=x.logprob_token_ids.to( - "cpu", non_blocking=True), - logprobs=x.logprobs.to("cpu", non_blocking=True), - selected_token_ranks=x.selected_token_ranks.to( - "cpu", non_blocking=True), + "cpu", non_blocking=True + ) + if sampler_output.logprobs_tensors is not None: + self.logprobs_tensors = ( + sampler_output.logprobs_tensors.to_cpu_nonblocking() ) else: self.logprobs_tensors = None - self.copy_event.record() + self.prompt_logprobs_dict = {} + if self.model_runner_output.prompt_logprobs_dict: + for k, v in self.model_runner_output.prompt_logprobs_dict.items(): + self.prompt_logprobs_dict[k] = v.to_cpu_nonblocking() + self.copy_event.record(self.copy_stream) def get_output(self) -> ModelRunnerOutput: self.copy_event.synchronize() - self.model_runner_output.sampled_token_ids = ( - self.sampled_token_ids.numpy()) + + # NOTE(woosuk): The following code ensures compatibility with OSS vLLM. + # Going forward, we should keep the data structures as NumPy arrays + # rather than Python lists. 
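# --- Editor's note (not part of the patch): a minimal sketch of the trimming
# implemented just below in AsyncOutput.get_output(). The padded
# [num_reqs, max_sampled] array copied from the GPU is turned into ragged
# per-request lists; num_sampled_tokens[i] is 0 for a request that is still
# chunked-prefilling and 1 otherwise.
import numpy as np

def to_ragged(sampled_np: np.ndarray, num_sampled_tokens: np.ndarray) -> list[list[int]]:
    out = sampled_np.tolist()
    for i, tokens in enumerate(out):
        del tokens[num_sampled_tokens[i]:]
    return out

# Example: request 1 is mid-prefill, so it emits no token this step.
assert to_ragged(np.array([[17], [42], [99]]), np.array([1, 0, 1])) == [[17], [], [99]]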
+ sampled_token_ids_np = self.sampled_token_ids.numpy() + sampled_token_ids = sampled_token_ids_np.tolist() + for i, tokens in enumerate(sampled_token_ids): + del tokens[self.num_sampled_tokens[i] :] + self.model_runner_output.sampled_token_ids = sampled_token_ids + if self.logprobs_tensors is not None: - self.model_runner_output.logprobs = ( - self.logprobs_tensors.tolists()) + self.model_runner_output.logprobs = self.logprobs_tensors.tolists() + self.model_runner_output.prompt_logprobs_dict = self.prompt_logprobs_dict return self.model_runner_output + + +@contextmanager +def async_barrier(event: torch.cuda.Event | None): + if event is not None: + event.synchronize() + try: + yield + finally: + if event is not None: + event.record() diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py index f23b918736..474c22e524 100644 --- a/vllm/v1/worker/gpu/attn_utils.py +++ b/vllm/v1/worker/gpu/attn_utils.py @@ -7,9 +7,17 @@ import torch from vllm.attention.backends.abstract import AttentionBackend, AttentionType from vllm.attention.layer import Attention from vllm.config import VllmConfig, get_layers_from_vllm_config -from vllm.v1.attention.backends.utils import AttentionMetadataBuilder -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheSpec, SlidingWindowSpec) +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, + CommonAttentionMetadata, +) +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheConfig, + KVCacheSpec, + SlidingWindowSpec, +) +from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.utils import bind_kv_cache @@ -18,7 +26,6 @@ def get_kv_cache_spec( kv_cache_dtype: torch.dtype, ) -> dict[str, KVCacheSpec]: block_size = vllm_config.cache_config.block_size - use_mla = vllm_config.model_config.use_mla kv_cache_spec: dict[str, KVCacheSpec] = {} attn_layers = get_layers_from_vllm_config(vllm_config, Attention) @@ -31,7 +38,6 @@ def get_kv_cache_spec( head_size=attn_module.head_size, dtype=kv_cache_dtype, sliding_window=attn_module.sliding_window, - use_mla=use_mla, ) else: kv_cache_spec[layer_name] = FullAttentionSpec( @@ -39,7 +45,6 @@ def get_kv_cache_spec( num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=kv_cache_dtype, - use_mla=use_mla, ) return kv_cache_spec @@ -52,6 +57,7 @@ def init_attn_backend( attn_backends: dict[str, AttentionBackend] = {} attn_metadata_builders: list[AttentionMetadataBuilder] = [] + flashinfer_workspace: torch.Tensor | None = None attn_layers = get_layers_from_vllm_config(vllm_config, Attention) for kv_cache_group_spec in kv_cache_config.kv_cache_groups: layer_names = kv_cache_group_spec.layer_names @@ -67,7 +73,13 @@ def init_attn_backend( vllm_config, device, ) - attn_metadata_builders.append(attn_metadata_builder) + attn_metadata_builders.append(attn_metadata_builder) # type: ignore + + if "FLASHINFER" in attn_backend.get_name(): + if flashinfer_workspace is None: + flashinfer_workspace = attn_metadata_builder.get_workspace_buffer() + else: + attn_metadata_builder.set_workspace_buffer(flashinfer_workspace) return attn_backends, attn_metadata_builders @@ -77,9 +89,7 @@ def _allocate_kv_cache( ): kv_cache_raw_tensors: dict[str, torch.Tensor] = {} for kv_cache_tensor in kv_cache_config.kv_cache_tensors: - tensor = torch.zeros(kv_cache_tensor.size, - dtype=torch.int8, - device=device) + tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=device) for layer_name in kv_cache_tensor.shared_by: 
kv_cache_raw_tensors[layer_name] = tensor @@ -87,8 +97,9 @@ def _allocate_kv_cache( for group in kv_cache_config.kv_cache_groups: for layer_name in group.layer_names: layer_names.add(layer_name) - assert layer_names == set(kv_cache_raw_tensors.keys() - ), "Some layers are not correctly initialized" + assert layer_names == set(kv_cache_raw_tensors.keys()), ( + "Some layers are not correctly initialized" + ) return kv_cache_raw_tensors @@ -103,17 +114,19 @@ def _reshape_kv_cache( for layer_name in kv_cache_group_spec.layer_names: raw_tensor = kv_cache_raw_tensors[layer_name] assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 - num_blocks = (raw_tensor.numel() // kv_cache_spec.page_size_bytes) + num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes attn_backend = attn_backends[layer_name] kv_cache_shape = attn_backend.get_kv_cache_shape( - num_blocks, kv_cache_spec.block_size, - kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + num_blocks, + kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size, + ) dtype = kv_cache_spec.dtype kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() - kv_cache_shape = tuple(kv_cache_shape[i] - for i in kv_cache_stride_order) + kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) inv_order = [ kv_cache_stride_order.index(i) @@ -132,8 +145,56 @@ def init_kv_cache( kv_cache_config: KVCacheConfig, attn_backends: dict[str, AttentionBackend], device: torch.device, -): +) -> None: kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device) - kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors, - attn_backends) + kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors, attn_backends) bind_kv_cache(kv_caches, forward_context, runner_kv_caches) + + +def build_attn_metadata( + attn_metadata_builders: list[AttentionMetadataBuilder], + num_reqs: int, + num_tokens: int, + query_start_loc: CpuGpuBuffer, + seq_lens: CpuGpuBuffer, + num_computed_tokens_cpu: torch.Tensor, + block_tables: tuple[torch.Tensor, ...], + slot_mappings: torch.Tensor, + kv_cache_config: KVCacheConfig, +) -> dict[str, Any]: + query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1] + query_start_loc_cpu = query_start_loc.cpu[: num_reqs + 1] + max_query_len = int(query_start_loc.np[: num_reqs + 1].max()) + seq_lens_gpu = seq_lens.gpu[:num_reqs] + seq_lens_cpu = seq_lens.cpu[:num_reqs] + max_seq_len = int(seq_lens.np[:num_reqs].max()) + + attn_metadata: dict[str, Any] = {} + kv_cache_groups = kv_cache_config.kv_cache_groups + for i, kv_cache_spec in enumerate(kv_cache_groups): + block_table = block_tables[i] + slot_mapping = slot_mappings[i] + + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=query_start_loc_gpu, + query_start_loc_cpu=query_start_loc_cpu, + seq_lens=seq_lens_gpu, + seq_lens_cpu=seq_lens_cpu, + max_seq_len=max_seq_len, + num_computed_tokens_cpu=num_computed_tokens_cpu, + num_reqs=num_reqs, + num_actual_tokens=num_tokens, + max_query_len=max_query_len, + block_table_tensor=block_table, + slot_mapping=slot_mapping, + causal=True, + ) + + attn_metadata_builder = attn_metadata_builders[i] + metadata = attn_metadata_builder.build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + ) + for layer_name in kv_cache_spec.layer_names: + attn_metadata[layer_name] = metadata + return attn_metadata diff --git a/vllm/v1/worker/gpu/block_table.py b/vllm/v1/worker/gpu/block_table.py index d50a852867..b04cb74564 100644 --- 
a/vllm/v1/worker/gpu/block_table.py +++ b/vllm/v1/worker/gpu/block_table.py @@ -6,14 +6,13 @@ import torch import triton import triton.language as tl -from vllm.utils import cdiv +from vllm.utils.math_utils import cdiv from vllm.v1.utils import CpuGpuBuffer PAD_SLOT_ID = -1 class BlockTables: - def __init__( self, block_sizes: list[int], @@ -50,44 +49,48 @@ class BlockTables: self.input_block_tables: list[torch.Tensor] = [ torch.zeros_like(block_table) for block_table in self.block_tables ] - self.input_block_table_ptrs = self._make_ptr_tensor( - self.input_block_tables) + self.input_block_table_ptrs = self._make_ptr_tensor(self.input_block_tables) self.block_table_strides = torch.tensor( [b.stride(0) for b in self.block_tables], dtype=torch.int64, - device=self.device) - self.block_sizes_tensor = torch.tensor(self.block_sizes, - dtype=torch.int32, - device=self.device) - self.num_blocks = torch.zeros(self.num_kv_cache_groups, - self.max_num_reqs, - dtype=torch.int32, - device=self.device) - self.slot_mappings = torch.zeros(self.num_kv_cache_groups, - self.max_num_batched_tokens, - dtype=torch.int64, - device=self.device) + device=self.device, + ) + self.block_sizes_tensor = torch.tensor( + self.block_sizes, dtype=torch.int32, device=self.device + ) + self.num_blocks = torch.zeros( + self.num_kv_cache_groups, + self.max_num_reqs, + dtype=torch.int32, + device=self.device, + ) + self.slot_mappings = torch.zeros( + self.num_kv_cache_groups, + self.max_num_batched_tokens, + dtype=torch.int64, + device=self.device, + ) # Misc buffers. - self.req_indices = self._make_buffer(self.max_num_reqs, - dtype=torch.int32) + self.req_indices = self._make_buffer(self.max_num_reqs, dtype=torch.int32) self.overwrite = self._make_buffer(self.max_num_reqs, dtype=torch.bool) - self.cu_num_new_blocks = self._make_buffer(self.num_kv_cache_groups, - self.max_num_reqs + 1, - dtype=torch.int32) + self.cu_num_new_blocks = self._make_buffer( + self.num_kv_cache_groups, self.max_num_reqs + 1, dtype=torch.int32 + ) def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer: - return CpuGpuBuffer(*args, - dtype=dtype, - pin_memory=self.pin_memory, - device=self.device) + return CpuGpuBuffer( + *args, dtype=dtype, pin_memory=self.pin_memory, device=self.device + ) def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor: - ptrs_tensor_cpu = torch.tensor([t.data_ptr() for t in x], - dtype=torch.int64, - device="cpu", - pin_memory=self.pin_memory) + ptrs_tensor_cpu = torch.tensor( + [t.data_ptr() for t in x], + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory, + ) return ptrs_tensor_cpu.to(self.device, non_blocking=True) def append_block_ids( @@ -105,7 +108,7 @@ class BlockTables: self.req_indices.np[:num_reqs] = req_indices self.overwrite.np[:num_reqs] = overwrite for i in range(self.num_kv_cache_groups): - self.cu_num_new_blocks.np[i, :num_reqs + 1] = cu_num_new_blocks[i] + self.cu_num_new_blocks.np[i, : num_reqs + 1] = cu_num_new_blocks[i] # NOTE(woosuk): Here, we cannot use a fixed-size buffer because there's # no clear upper bound to the number of new blocks in a single step. 
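# --- Editor's note (not part of the patch): a sketch of the CSR-style inputs
# that BlockTables.append_block_ids() consumes. For every KV-cache group,
# cu_num_new_blocks is a running prefix sum over the scheduled requests and
# new_block_ids is the flattened list of newly allocated block IDs, so request
# b's new blocks are new_block_ids[cu[b]:cu[b + 1]].
cu_num_new_blocks = [0, 2, 2, 5]        # req 0 got 2 new blocks, req 1 none, req 2 three
new_block_ids = [10, 11, 70, 71, 72]    # flattened across the three requests
for b in range(3):
    start, end = cu_num_new_blocks[b], cu_num_new_blocks[b + 1]
    print(f"request {b} appends blocks {new_block_ids[start:end]}")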
@@ -120,9 +123,8 @@ class BlockTables: ) new_block_ids_np = self.new_block_ids_cpu.numpy() for i in range(self.num_kv_cache_groups): - new_block_ids_np[i, :len(new_block_ids[i])] = new_block_ids[i] - new_block_ids_gpu = self.new_block_ids_cpu.to(self.device, - non_blocking=True) + new_block_ids_np[i, : len(new_block_ids[i])] = new_block_ids[i] + new_block_ids_gpu = self.new_block_ids_cpu.to(self.device, non_blocking=True) _append_block_ids_kernel[(self.num_kv_cache_groups, num_reqs)]( self.req_indices.copy_to_gpu(num_reqs), @@ -135,7 +137,7 @@ class BlockTables: self.block_table_ptrs, self.num_blocks, self.num_blocks.stride(0), - BLOCK_SIZE=1024, + BLOCK_SIZE=1024, # type: ignore ) def gather_block_tables( @@ -150,10 +152,9 @@ class BlockTables: self.block_table_strides, self.num_blocks, self.num_blocks.stride(0), - BLOCK_SIZE=1024, + BLOCK_SIZE=1024, # type: ignore ) - return tuple(block_table[:num_reqs] - for block_table in self.input_block_tables) + return tuple(block_table[:num_reqs] for block_table in self.input_block_tables) def compute_slot_mappings( self, @@ -174,7 +175,7 @@ class BlockTables: self.slot_mappings, self.slot_mappings.stride(0), PAD_ID=PAD_SLOT_ID, - BLOCK_SIZE=1024, + BLOCK_SIZE=1024, # type: ignore ) return self.slot_mappings[:, :num_tokens] @@ -201,8 +202,7 @@ def _append_block_ids_kernel( req_idx = tl.load(req_indices + batch_idx) do_overwrite = tl.load(overwrite + batch_idx) - group_new_blocks_ptr = (cu_num_new_blocks_ptr + - group_id * cu_num_new_blocks_stride) + group_new_blocks_ptr = cu_num_new_blocks_ptr + group_id * cu_num_new_blocks_stride start_idx = tl.load(group_new_blocks_ptr + batch_idx) end_idx = tl.load(group_new_blocks_ptr + batch_idx + 1) num_new_blocks = end_idx - start_idx @@ -220,15 +220,15 @@ def _append_block_ids_kernel( block_table_stride = tl.load(block_table_strides + group_id) row_ptr = block_table_ptr + req_idx * block_table_stride - group_new_block_ids_ptr = (new_block_ids_ptr + - group_id * new_block_ids_stride) - for i in tl.range(0, num_new_blocks, BLOCK_SIZE): + group_new_block_ids_ptr = new_block_ids_ptr + group_id * new_block_ids_stride + for i in range(0, num_new_blocks, BLOCK_SIZE): offset = i + tl.arange(0, BLOCK_SIZE) - block_ids = tl.load(group_new_block_ids_ptr + start_idx + offset, - mask=offset < num_new_blocks) - tl.store(row_ptr + dst_start_idx + offset, - block_ids, - mask=offset < num_new_blocks) + block_ids = tl.load( + group_new_block_ids_ptr + start_idx + offset, mask=offset < num_new_blocks + ) + tl.store( + row_ptr + dst_start_idx + offset, block_ids, mask=offset < num_new_blocks + ) @triton.jit @@ -282,11 +282,9 @@ def _compute_slot_mappings_kernel( if req_idx == tl.num_programs(1) - 1: # Pad remaining slots to -1. This is needed for CUDA graphs. 
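# --- Editor's note (not part of the patch): the slot-mapping arithmetic that
# _compute_slot_mappings_kernel implements, written out in plain Python. A token
# at absolute position `pos` of a request lands in KV-cache slot
# block_table[pos // page_size] * page_size + pos % page_size; trailing padded
# tokens are set to PAD_SLOT_ID (-1), which the kernel needs for CUDA graphs.
def slot_id(block_table_row: list[int], pos: int, page_size: int) -> int:
    block_number = block_table_row[pos // page_size]
    return block_number * page_size + pos % page_size

# page_size=16: positions 0..15 fall in physical block 7, position 16 in block 3.
assert slot_id([7, 3], 5, 16) == 7 * 16 + 5
assert slot_id([7, 3], 16, 16) == 3 * 16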
- for i in tl.range(num_tokens, max_num_tokens, BLOCK_SIZE): + for i in range(num_tokens, max_num_tokens, BLOCK_SIZE): offset = i + tl.arange(0, BLOCK_SIZE) - tl.store(slot_mapping_ptr + offset, - PAD_ID, - mask=offset < max_num_tokens) + tl.store(slot_mapping_ptr + offset, PAD_ID, mask=offset < max_num_tokens) return block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32) @@ -295,12 +293,13 @@ def _compute_slot_mappings_kernel( start_idx = tl.load(cu_num_tokens + req_idx) end_idx = tl.load(cu_num_tokens + req_idx + 1) - for i in tl.range(start_idx, end_idx, BLOCK_SIZE): + for i in range(start_idx, end_idx, BLOCK_SIZE): offset = i + tl.arange(0, BLOCK_SIZE) positions = tl.load(pos + offset, mask=offset < end_idx, other=0) block_indices = positions // page_size - block_numbers = tl.load(block_table_ptr + - req_idx * block_table_stride + block_indices) + block_numbers = tl.load( + block_table_ptr + req_idx * block_table_stride + block_indices + ) slot_ids = block_numbers * page_size + positions % page_size tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx) diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py new file mode 100644 index 0000000000..454ff4d268 --- /dev/null +++ b/vllm/v1/worker/gpu/cudagraph_utils.py @@ -0,0 +1,175 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import gc +from contextlib import contextmanager + +import numpy as np +import torch +import torch.nn as nn +from tqdm import tqdm + +from vllm.config import VllmConfig +from vllm.config.compilation import CUDAGraphMode +from vllm.distributed.parallel_state import graph_capture, is_global_first_rank +from vllm.forward_context import set_forward_context +from vllm.v1.attention.backends.utils import AttentionMetadataBuilder +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.worker.gpu.attn_utils import build_attn_metadata +from vllm.v1.worker.gpu.block_table import BlockTables +from vllm.v1.worker.gpu.input_batch import InputBuffers + + +class CudaGraphManager: + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + ): + self.vllm_config = vllm_config + self.device = device + + self.max_model_len = vllm_config.model_config.max_model_len + self.compilation_config = vllm_config.compilation_config + assert self.compilation_config is not None + + self.cudagraph_sizes = sorted(self.compilation_config.cudagraph_capture_sizes) + self.padded_sizes = self._init_padded_sizes() + + self.graphs: dict[int, torch.cuda.CUDAGraph] = {} + self.pool = torch.cuda.graph_pool_handle() + self.hidden_states: torch.Tensor | None = None + + def _init_padded_sizes(self) -> dict[int, int]: + if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE: + # CUDA graphs are disabled. + return {} + if self.compilation_config.cudagraph_mode.requires_piecewise_compilation(): + raise NotImplementedError("Piecewise CUDA graphs are not supported") + if self.compilation_config.level != 0: + raise NotImplementedError("Dynamo is not used. 
Compilation level must be 0") + + padded_sizes: dict[int, int] = {} + assert len(self.cudagraph_sizes) > 0 + for i in range(1, self.cudagraph_sizes[-1] + 1): + for x in self.cudagraph_sizes: + if i <= x: + padded_sizes[i] = x + break + return padded_sizes + + def needs_capture(self) -> bool: + return len(self.padded_sizes) > 0 + + def get_cudagraph_size(self, scheduler_output: SchedulerOutput) -> int | None: + if max(scheduler_output.num_scheduled_tokens.values()) > 1: + # Prefill is included. + return None + return self.padded_sizes.get(scheduler_output.total_num_scheduled_tokens) + + def capture_graph( + self, + batch_size: int, + model: nn.Module, + input_buffers: InputBuffers, + block_tables: BlockTables, + attn_metadata_builders: list[AttentionMetadataBuilder], + kv_cache_config: KVCacheConfig, + ) -> None: + assert batch_size not in self.graphs + + # Prepare dummy inputs. + input_ids = input_buffers.input_ids.gpu[:batch_size] + positions = input_buffers.positions.gpu[:batch_size] + + input_buffers.query_start_loc.np[: batch_size + 1] = np.arange(batch_size + 1) + input_buffers.query_start_loc.np[batch_size:] = batch_size + input_buffers.query_start_loc.copy_to_gpu() + input_buffers.seq_lens.np[:batch_size] = self.max_model_len + input_buffers.seq_lens.np[batch_size:] = 0 + input_buffers.seq_lens.copy_to_gpu() + + input_block_tables = [x[:batch_size] for x in block_tables.input_block_tables] + slot_mappings = block_tables.slot_mappings[:, :batch_size] + + attn_metadata = build_attn_metadata( + attn_metadata_builders=attn_metadata_builders, + num_reqs=batch_size, + num_tokens=batch_size, + query_start_loc=input_buffers.query_start_loc, + seq_lens=input_buffers.seq_lens, + num_computed_tokens_cpu=None, # FIXME + block_tables=input_block_tables, + slot_mappings=slot_mappings, + kv_cache_config=kv_cache_config, + ) + + # Warm up. + with set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=batch_size, + ): + hidden_states = model( + input_ids=input_ids, + positions=positions, + ) + if self.hidden_states is None: + self.hidden_states = torch.empty_like(hidden_states) + torch.cuda.synchronize() + + # Capture the graph. + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, self.pool): + with set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=batch_size, + ): + hidden_states = model( + input_ids=input_ids, + positions=positions, + ) + self.hidden_states[:batch_size] = hidden_states + self.graphs[batch_size] = graph + + @torch.inference_mode() + def capture( + self, + model: nn.Module, + input_buffers: InputBuffers, + block_tables: BlockTables, + attn_metadata_builders: list[AttentionMetadataBuilder], + kv_cache_config: KVCacheConfig, + ) -> None: + assert self.needs_capture() + # Capture larger graphs first. 
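# --- Editor's note (not part of the patch): how _init_padded_sizes() is used. A
# decode-only batch is padded up to the smallest captured graph size and that
# graph is replayed; batch sizes above the largest capture size are absent from
# the map and fall back to eager execution. The loop below re-implements the
# mapping above for illustration.
def init_padded_sizes(cudagraph_sizes: list[int]) -> dict[int, int]:
    sizes = sorted(cudagraph_sizes)
    padded: dict[int, int] = {}
    for i in range(1, sizes[-1] + 1):
        for x in sizes:
            if i <= x:
                padded[i] = x
                break
    return padded

# With capture sizes [1, 2, 4, 8]: a 3-request decode batch replays the size-4
# graph, while a 9-request batch is not in the map and runs eagerly.
assert init_padded_sizes([1, 2, 4, 8])[3] == 4
assert 9 not in init_padded_sizes([1, 2, 4, 8])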
+ sizes_to_capture = sorted(self.cudagraph_sizes, reverse=True) + if is_global_first_rank(): + sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs") + + with freeze_gc(), graph_capture(device=self.device): + for batch_size in sizes_to_capture: + self.capture_graph( + batch_size, + model, + input_buffers, + block_tables, + attn_metadata_builders, + kv_cache_config, + ) + + def run(self, batch_size: int) -> torch.Tensor: + assert batch_size in self.graphs + self.graphs[batch_size].replay() + return self.hidden_states[:batch_size] + + +@contextmanager +def freeze_gc(): + gc.collect() + gc.freeze() + try: + yield + finally: + gc.unfreeze() diff --git a/vllm/v1/worker/gpu/input_batch.py b/vllm/v1/worker/gpu/input_batch.py index cf35945c90..bf43609ec8 100644 --- a/vllm/v1/worker/gpu/input_batch.py +++ b/vllm/v1/worker/gpu/input_batch.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections import defaultdict from dataclasses import dataclass from typing import Any @@ -16,11 +15,12 @@ from vllm.v1.utils import CpuGpuBuffer class InputBuffers: - def __init__( self, max_num_reqs: int, max_num_tokens: int, + hidden_size: int, + dtype: torch.dtype, device: torch.device, pin_memory: bool, ): @@ -32,20 +32,17 @@ class InputBuffers: self.idx_mapping = self._make_buffer(max_num_reqs, dtype=torch.int32) self.input_ids = self._make_buffer(max_num_tokens, dtype=torch.int32) self.positions = self._make_buffer(max_num_tokens, dtype=torch.int64) - self.query_start_loc = self._make_buffer(max_num_reqs + 1, - dtype=torch.int32) + self.query_start_loc = self._make_buffer(max_num_reqs + 1, dtype=torch.int32) self.seq_lens = self._make_buffer(max_num_reqs, dtype=torch.int32) def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer: - return CpuGpuBuffer(*args, - dtype=dtype, - pin_memory=self.pin_memory, - device=self.device) + return CpuGpuBuffer( + *args, dtype=dtype, pin_memory=self.pin_memory, device=self.device + ) @dataclass class InputBatch: - # batch_idx -> req_id req_ids: list[str] num_reqs: int @@ -54,17 +51,23 @@ class InputBatch: idx_mapping: torch.Tensor idx_mapping_np: np.ndarray + # [num_reqs] # batch_idx -> num_scheduled_tokens num_scheduled_tokens: np.ndarray # sum(num_scheduled_tokens) num_tokens: int num_tokens_after_padding: int - # [num_reqs] - is_chunked_prefilling: np.ndarray - # [max_num_batched_tokens] + # [num_reqs + 1] + query_start_loc: torch.Tensor + query_start_loc_np: np.ndarray + # [num_reqs] + seq_lens: torch.Tensor + seq_lens_np: np.ndarray + + # [num_tokens_after_padding] input_ids: torch.Tensor - # [max_num_batched_tokens] + # [num_tokens_after_padding] positions: torch.Tensor # layer_name -> Metadata @@ -78,23 +81,34 @@ class InputBatch: cls, num_reqs: int, num_tokens: int, + input_buffers: InputBuffers, device: torch.device, ) -> "InputBatch": assert 0 < num_reqs <= num_tokens req_ids = [f"req_{i}_{random_uuid()}" for i in range(num_reqs)] idx_mapping_np = np.arange(num_reqs, dtype=np.int32) - idx_mapping = torch.tensor(idx_mapping_np, device=device) - num_scheduled_tokens = np.full(num_reqs, - num_tokens // num_reqs, - dtype=np.int32) + idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device) + num_scheduled_tokens = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32) num_scheduled_tokens[-1] += num_tokens % num_reqs - is_chunked_prefilling = np.zeros(num_reqs, dtype=np.bool_) - input_ids = torch.zeros(num_tokens, dtype=torch.int32, 
device=device) - positions = torch.zeros(num_tokens, dtype=torch.int64, device=device) - attn_metadata = defaultdict(lambda: None) - logits_indices = torch.arange(num_reqs, - dtype=torch.int32, - device=device) + assert int(num_scheduled_tokens.sum()) == num_tokens + + input_buffers.query_start_loc.np[0] = 0 + input_buffers.query_start_loc.np[1 : num_reqs + 1] = np.cumsum( + num_scheduled_tokens + ) + input_buffers.query_start_loc.np[num_reqs + 1 :] = num_tokens + query_start_loc_np = input_buffers.query_start_loc.np[: num_reqs + 1] + query_start_loc = input_buffers.query_start_loc.copy_to_gpu()[: num_reqs + 1] + # seq_len equals to query_len + input_buffers.seq_lens.np[:num_reqs] = num_scheduled_tokens + input_buffers.seq_lens.np[num_reqs:] = 0 + seq_lens_np = input_buffers.seq_lens.np[:num_reqs] + seq_lens = input_buffers.seq_lens.copy_to_gpu()[:num_reqs] + + input_ids = input_buffers.input_ids.copy_to_gpu(num_tokens) + positions = input_buffers.positions.copy_to_gpu(num_tokens) + # attn_metadata = defaultdict(lambda: None) + logits_indices = query_start_loc[1:] - 1 return cls( req_ids=req_ids, num_reqs=num_reqs, @@ -103,10 +117,13 @@ class InputBatch: num_scheduled_tokens=num_scheduled_tokens, num_tokens=num_tokens, num_tokens_after_padding=num_tokens, - is_chunked_prefilling=is_chunked_prefilling, + query_start_loc=query_start_loc, + query_start_loc_np=query_start_loc_np, + seq_lens=seq_lens, + seq_lens_np=seq_lens_np, input_ids=input_ids, positions=positions, - attn_metadata=attn_metadata, + attn_metadata=None, logits_indices=logits_indices, ) @@ -130,14 +147,14 @@ class InputBatch: cache=True, ) def _prepare_inputs( - idx_mapping: np.ndarray, # batch_idx -> req_idx - token_ids: np.ndarray, # [N, max_model_len] - num_computed_tokens: np.ndarray, # [N] - num_scheduled_tokens: np.ndarray, # [B] - input_ids: np.ndarray, # [num_input_tokens] - positions: np.ndarray, # [num_input_tokens] - query_start_loc: np.ndarray, # [B + 1] - seq_lens: np.ndarray, # [B] + idx_mapping: np.ndarray, # batch_idx -> req_idx + token_ids: np.ndarray, # [N, max_model_len] + num_computed_tokens: np.ndarray, # [N] + num_scheduled_tokens: np.ndarray, # [B] + input_ids: np.ndarray, # [num_input_tokens] + positions: np.ndarray, # [num_input_tokens] + query_start_loc: np.ndarray, # [B + 1] + seq_lens: np.ndarray, # [B] ) -> None: num_reqs = num_scheduled_tokens.shape[0] query_start_loc[0] = 0 @@ -161,14 +178,14 @@ def _prepare_inputs( # Pad the inputs for CUDA graphs. # Note: pad query_start_loc to be non-decreasing, as kernels # like FlashAttention requires that - query_start_loc[num_reqs + 1:].fill(cu_num_tokens) + query_start_loc[num_reqs + 1 :].fill(cu_num_tokens) # Fill unused with 0 for full cuda graph mode. seq_lens[num_reqs:].fill(0) def prepare_inputs( idx_mapping: np.ndarray, - prompt_token_ids: np.ndarray, + prefill_token_ids: np.ndarray, num_computed_tokens: np.ndarray, num_scheduled_tokens: np.ndarray, input_ids: CpuGpuBuffer, @@ -176,10 +193,10 @@ def prepare_inputs( query_start_loc: CpuGpuBuffer, seq_lens: CpuGpuBuffer, num_tokens: int, -) -> tuple[np.ndarray, np.ndarray]: +) -> None: _prepare_inputs( idx_mapping, - prompt_token_ids, + prefill_token_ids, num_computed_tokens, num_scheduled_tokens, input_ids.np, @@ -194,11 +211,7 @@ def prepare_inputs( # for full CUDA graph mode. 
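# --- Editor's note (not part of the patch): a worked example of the padding
# convention used in _prepare_inputs(). query_start_loc is kept non-decreasing
# past the real requests (required by kernels such as FlashAttention) and
# seq_lens is zero-filled, so a CUDA-graph-padded batch only sees empty
# trailing requests.
import numpy as np

num_scheduled_tokens = np.array([3, 1, 1], dtype=np.int32)   # 3 real requests
max_num_reqs = 6
query_start_loc = np.zeros(max_num_reqs + 1, dtype=np.int32)
query_start_loc[1:4] = np.cumsum(num_scheduled_tokens)
query_start_loc[4:] = query_start_loc[3]                     # pad non-decreasing
seq_lens = np.zeros(max_num_reqs, dtype=np.int32)
seq_lens[:3] = [10, 4, 1]                                     # real seq lens only
assert query_start_loc.tolist() == [0, 3, 4, 5, 5, 5, 5]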
query_start_loc.copy_to_gpu() seq_lens.copy_to_gpu() - - num_reqs = num_scheduled_tokens.shape[0] - max_query_len = int(num_scheduled_tokens.max()) - max_seq_len = int(seq_lens.np[:num_reqs].max()) - return max_query_len, max_seq_len + return @triton.jit @@ -208,21 +221,18 @@ def _combine_last_token_ids_kernel( last_token_ids_ptr, query_start_loc_ptr, seq_lens_ptr, - num_tokens_ptr, + prefill_len_ptr, ): batch_idx = tl.program_id(0) req_state_idx = tl.load(idx_mapping_ptr + batch_idx) seq_len = tl.load(seq_lens_ptr + batch_idx) - num_tokens = tl.load(num_tokens_ptr + req_state_idx) - if seq_len < num_tokens: - # Chunked prefilling. + prefill_len = tl.load(prefill_len_ptr + req_state_idx) + if seq_len <= prefill_len: + # Handling prefill tokens. return last_token_id = tl.load(last_token_ids_ptr + req_state_idx) - if last_token_id == -1: - return - end = tl.load(query_start_loc_ptr + batch_idx + 1) tl.store(input_ids_ptr + end - 1, last_token_id) @@ -233,15 +243,15 @@ def combine_last_token_ids( last_token_ids: torch.Tensor, query_start_loc: torch.Tensor, seq_lens: torch.Tensor, - num_tokens: torch.Tensor, + prefill_len: torch.Tensor, ) -> torch.Tensor: num_reqs = seq_lens.shape[0] - _combine_last_token_ids_kernel[(num_reqs, )]( + _combine_last_token_ids_kernel[(num_reqs,)]( input_ids, idx_mapping, last_token_ids, query_start_loc, seq_lens, - num_tokens, + prefill_len, ) return input_ids diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 70485676bf..824f5ad927 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -3,41 +3,52 @@ import gc import time from copy import deepcopy -from typing import Any, Optional +from typing import Any import numpy as np import torch import torch.nn as nn from vllm.config import VllmConfig -from vllm.distributed import get_tp_group from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model_loader -from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - GiB_bytes, is_pin_memory_available) -from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.utils.mem_constants import GiB_bytes +from vllm.utils.mem_utils import DeviceMemoryProfiler +from vllm.utils.platform_utils import is_pin_memory_available +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig -from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput +from vllm.v1.outputs import ( + EMPTY_MODEL_RUNNER_OUTPUT, + LogprobsTensors, + ModelRunnerOutput, +) from vllm.v1.sample.sampler import SamplerOutput -from vllm.v1.worker.gpu.async_utils import AsyncOutput -from vllm.v1.worker.gpu.attn_utils import (get_kv_cache_spec, - init_attn_backend, init_kv_cache) +from vllm.v1.worker.gpu.async_utils import AsyncOutput, async_barrier +from vllm.v1.worker.gpu.attn_utils import ( + build_attn_metadata, + get_kv_cache_spec, + init_attn_backend, + init_kv_cache, +) from vllm.v1.worker.gpu.block_table import BlockTables -from vllm.v1.worker.gpu.dist_utils import (all_gather_sampler_output, - evenly_split) -from vllm.v1.worker.gpu.input_batch import (InputBatch, InputBuffers, - combine_last_token_ids, - prepare_inputs) -from vllm.v1.worker.gpu.sampler import Sampler +from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager +from vllm.v1.worker.gpu.input_batch import ( + InputBatch, + 
InputBuffers, + combine_last_token_ids, + prepare_inputs, +) +from vllm.v1.worker.gpu.sampler import Sampler, compute_prompt_logprobs from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata +from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin +from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin logger = init_logger(__name__) -class GPUModelRunner: - +class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def __init__( self, vllm_config: VllmConfig, @@ -61,17 +72,19 @@ class GPUModelRunner: if self.cache_config.cache_dtype != "auto": # Quantized KV cache. self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ - self.cache_config.cache_dtype] + self.cache_config.cache_dtype + ] self.is_pooling_model = False self.vocab_size = self.model_config.get_vocab_size() self.max_model_len = self.model_config.max_model_len self.max_num_tokens = self.scheduler_config.max_num_batched_tokens self.max_num_reqs = self.scheduler_config.max_num_seqs + self.hidden_size = self.model_config.get_hidden_size() self.use_async_scheduling = self.scheduler_config.async_scheduling - assert self.use_async_scheduling - self.output_copy_stream = torch.cuda.Stream() + self.output_copy_stream = torch.cuda.Stream(self.device) + self.input_prep_event = torch.cuda.Event() self.req_states = RequestState( max_num_reqs=self.max_num_reqs, @@ -84,29 +97,46 @@ class GPUModelRunner: self.input_buffers = InputBuffers( max_num_reqs=self.max_num_reqs, max_num_tokens=self.max_num_tokens, + hidden_size=self.hidden_size, + dtype=self.dtype, device=self.device, pin_memory=self.pin_memory, ) - self.sampler = Sampler() + self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode) + + # CUDA graphs. + self.cudagraph_manager = CudaGraphManager( + vllm_config=self.vllm_config, + device=self.device, + ) def get_supported_tasks(self) -> tuple[str]: - return ("generate", ) + return ("generate",) def load_model(self, *args, **kwargs) -> None: time_before_load = time.perf_counter() with DeviceMemoryProfiler() as m: model_loader = get_model_loader(self.vllm_config.load_config) logger.info("Loading model from scratch...") + self.model = model_loader.load_model( vllm_config=self.vllm_config, model_config=self.vllm_config.model_config, ) + if self.lora_config: + self.model = self.load_lora_model( + self.model, + self.vllm_config, + self.device, + ) time_after_load = time.perf_counter() self.model_memory_usage = m.consumed_memory - logger.info("Model loading took %.4f GiB and %.6f seconds", - m.consumed_memory / GiB_bytes, - time_after_load - time_before_load) + logger.info( + "Model loading took %.4f GiB and %.6f seconds", + m.consumed_memory / GiB_bytes, + time_after_load - time_before_load, + ) def get_model(self) -> nn.Module: return self.model @@ -143,32 +173,60 @@ class GPUModelRunner: self.device, ) + @torch.inference_mode() def _dummy_run( self, num_tokens: int, *args, - input_batch: Optional[InputBatch] = None, + input_batch: InputBatch | None = None, + skip_attn: bool = True, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor]: if input_batch is None: + num_reqs = min(num_tokens, self.max_num_reqs) input_batch = InputBatch.make_dummy( - num_reqs=min(num_tokens, self.max_num_reqs), + num_reqs=num_reqs, num_tokens=num_tokens, + input_buffers=self.input_buffers, device=self.device, ) + if not skip_attn: + block_tables = self.block_tables.gather_block_tables( + input_batch.idx_mapping + ) + slot_mappings = self.block_tables.compute_slot_mappings( + 
input_batch.query_start_loc, + input_batch.positions, + ) + attn_metadata = build_attn_metadata( + attn_metadata_builders=self.attn_metadata_builders, + num_reqs=num_reqs, + num_tokens=num_tokens, + query_start_loc=self.input_buffers.query_start_loc, + seq_lens=self.input_buffers.seq_lens, + num_computed_tokens_cpu=None, + block_tables=block_tables, + slot_mappings=slot_mappings, + kv_cache_config=self.kv_cache_config, + ) + input_batch.attn_metadata = attn_metadata - with set_forward_context( + with self.maybe_dummy_run_with_lora( + self.lora_config, input_batch.num_scheduled_tokens + ): + with set_forward_context( input_batch.attn_metadata, self.vllm_config, num_tokens=num_tokens, - ): - hidden_states = self.model( - input_ids=input_batch.input_ids, - positions=input_batch.positions, - ) - sample_hidden_states = hidden_states[input_batch.logits_indices] + ): + hidden_states = self.model( + input_ids=input_batch.input_ids, + positions=input_batch.positions, + ) + sample_hidden_states = hidden_states[input_batch.logits_indices] return hidden_states, sample_hidden_states + @torch.inference_mode() def _dummy_sampler_run( self, hidden_states: torch.Tensor, @@ -179,35 +237,80 @@ class GPUModelRunner: device=self.device, ) logits = self.model.compute_logits(hidden_states) - self.sampler(logits, sampling_metadata) + self.sampler.sample(logits, sampling_metadata) + @torch.inference_mode() def profile_run(self) -> None: input_batch = InputBatch.make_dummy( num_reqs=self.max_num_reqs, num_tokens=self.max_num_tokens, + input_buffers=self.input_buffers, device=self.device, ) hidden_states, sample_hidden_states = self._dummy_run( self.max_num_tokens, input_batch=input_batch, + skip_attn=True, ) self._dummy_sampler_run(sample_hidden_states) torch.cuda.synchronize() del hidden_states, sample_hidden_states gc.collect() + def reset_mm_cache(self) -> None: + pass + + @torch.inference_mode() + def capture_model(self) -> int: + if not self.cudagraph_manager.needs_capture(): + logger.warning( + "Skipping CUDA graph capture. To turn on CUDA graph capture, " + "ensure `cudagraph_mode` was not manually set to `NONE`" + ) + return 0 + + start_time = time.perf_counter() + start_free_gpu_memory = torch.cuda.mem_get_info()[0] + + with self.maybe_setup_dummy_loras(self.lora_config): + self.cudagraph_manager.capture( + model=self.model, + input_buffers=self.input_buffers, + block_tables=self.block_tables, + attn_metadata_builders=self.attn_metadata_builders, + kv_cache_config=self.kv_cache_config, + ) + + end_time = time.perf_counter() + end_free_gpu_memory = torch.cuda.mem_get_info()[0] + elapsed_time = end_time - start_time + cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory + # This usually takes 5~20 seconds. + logger.info( + "Graph capturing finished in %.0f secs, took %.2f GiB", + elapsed_time, + cuda_graph_size / (1 << 30), + ) + return cuda_graph_size + + def warmup_for_prefill(self) -> None: + # For FlashInfer, we would like to execute a dummy prefill run to trigger JIT compilation. 
+ if all("FLASHINFER" in b.get_name() for b in self.attn_backends.values()): + self._dummy_run(self.max_num_tokens, skip_attn=False) + torch.cuda.synchronize() + def update_states(self, scheduler_output: SchedulerOutput) -> None: - # for req_id in scheduler_output.preempted_req_ids: - # self.req_states.remove_request(req_id) + for req_id in scheduler_output.preempted_req_ids: + self.req_states.remove_request(req_id) for req_id in scheduler_output.finished_req_ids: self.req_states.remove_request(req_id) # TODO(woosuk): Change SchedulerOutput. req_indices: list[int] = [] cu_num_new_blocks = tuple( - [0] for _ in range(self.block_tables.num_kv_cache_groups)) - new_block_ids = tuple( - [] for _ in range(self.block_tables.num_kv_cache_groups)) + [0] for _ in range(self.block_tables.num_kv_cache_groups) + ) + new_block_ids = tuple([] for _ in range(self.block_tables.num_kv_cache_groups)) overwrite: list[bool] = [] # Add new requests. @@ -215,9 +318,11 @@ class GPUModelRunner: req_id = new_req_data.req_id self.req_states.add_request( req_id=req_id, - prompt_token_ids=new_req_data.prompt_token_ids, + prompt_len=len(new_req_data.prompt_token_ids), + prefill_token_ids=new_req_data.prefill_token_ids, num_computed_tokens=new_req_data.num_computed_tokens, sampling_params=new_req_data.sampling_params, + lora_request=new_req_data.lora_request, ) req_index = self.req_states.req_id_to_index[req_id] @@ -250,21 +355,30 @@ class GPUModelRunner: overwrite=overwrite, ) - def prepare_inputs(self, scheduler_output: SchedulerOutput) -> InputBatch: + def prepare_inputs( + self, + scheduler_output: SchedulerOutput, + use_cudagraph: bool, + padded_num_tokens: int | None, + ) -> InputBatch: num_tokens = scheduler_output.total_num_scheduled_tokens assert num_tokens > 0 num_reqs = len(scheduler_output.num_scheduled_tokens) # Decode first, then prefill. # batch_idx -> req_id - req_ids = sorted(scheduler_output.num_scheduled_tokens, - key=scheduler_output.num_scheduled_tokens.get) + req_ids = sorted( + scheduler_output.num_scheduled_tokens, + key=scheduler_output.num_scheduled_tokens.get, + ) num_scheduled_tokens = np.array( - [scheduler_output.num_scheduled_tokens[i] for i in req_ids], - dtype=np.int32) - - # TODO(woosuk): Support CUDA graphs. 
- num_tokens_after_padding = num_tokens + [scheduler_output.num_scheduled_tokens[i] for i in req_ids], dtype=np.int32 + ) + if use_cudagraph: + assert padded_num_tokens is not None + num_tokens_after_padding = padded_num_tokens + else: + num_tokens_after_padding = num_tokens idx_mapping_list = [ self.req_states.req_id_to_index[req_id] for req_id in req_ids @@ -277,9 +391,9 @@ class GPUModelRunner: # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks] block_tables = self.block_tables.gather_block_tables(idx_mapping) - max_query_len, max_seq_len = prepare_inputs( + prepare_inputs( idx_mapping_np, - self.req_states.prompt_token_ids, + self.req_states.prefill_token_ids, self.req_states.num_computed_tokens, num_scheduled_tokens, self.input_buffers.input_ids, @@ -290,10 +404,9 @@ class GPUModelRunner: ) query_start_loc = self.input_buffers.query_start_loc - query_start_loc_gpu = query_start_loc.gpu[:num_reqs + 1] - query_start_loc_cpu = query_start_loc.cpu[:num_reqs + 1] + query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1] + query_start_loc_np = query_start_loc.np[: num_reqs + 1] seq_lens_gpu = self.input_buffers.seq_lens.gpu[:num_reqs] - seq_lens_cpu = self.input_buffers.seq_lens.cpu[:num_reqs] seq_lens_np = self.input_buffers.seq_lens.np[:num_reqs] # Some input token ids are directly read from the last sampled tokens. @@ -303,56 +416,33 @@ class GPUModelRunner: self.req_states.last_sampled_tokens, query_start_loc_gpu, seq_lens_gpu, - self.req_states.num_tokens.copy_to_gpu(), + self.req_states.prefill_len.copy_to_gpu(), ) # Compute slot mappings: [num_kv_cache_groups, num_tokens] slot_mappings = self.block_tables.compute_slot_mappings( - query_start_loc_gpu, self.input_buffers.positions.gpu[:num_tokens]) + query_start_loc_gpu, self.input_buffers.positions.gpu[:num_tokens] + ) num_computed_tokens_cpu = torch.from_numpy( - self.req_states.num_computed_tokens[idx_mapping_np]) - - # Whether the request is chunked-prefilling or not. - is_chunked_prefilling = ( - seq_lens_np < self.req_states.num_tokens.np[idx_mapping_np]) + self.req_states.num_computed_tokens[idx_mapping_np] + ) # Logits indices to sample next token from. logits_indices = query_start_loc_gpu[1:] - 1 - num_logits_indices = logits_indices.size(0) # Layer name -> attention metadata. 
- attn_metadata: dict[str, Any] = {} - kv_cache_groups = self.kv_cache_config.kv_cache_groups - for i, kv_cache_spec in enumerate(kv_cache_groups): - block_table = block_tables[i] - slot_mapping = slot_mappings[i] - - common_attn_metadata = CommonAttentionMetadata( - query_start_loc=query_start_loc_gpu, - query_start_loc_cpu=query_start_loc_cpu, - seq_lens=seq_lens_gpu, - seq_lens_cpu=seq_lens_cpu, - num_computed_tokens_cpu=num_computed_tokens_cpu, - num_reqs=num_reqs, - num_actual_tokens=num_tokens, - max_query_len=max_query_len, - max_seq_len=max_seq_len, - block_table_tensor=block_table, - slot_mapping=slot_mapping, - logits_indices_padded=None, - num_logits_indices=num_logits_indices, - causal=True, - encoder_seq_lens=None, - ) - - attn_metadata_builder = self.attn_metadata_builders[i] - metadata = attn_metadata_builder.build( - common_prefix_len=0, - common_attn_metadata=common_attn_metadata, - ) - for layer_name in kv_cache_spec.layer_names: - attn_metadata[layer_name] = metadata + attn_metadata = build_attn_metadata( + attn_metadata_builders=self.attn_metadata_builders, + num_reqs=num_reqs, + num_tokens=num_tokens, + query_start_loc=self.input_buffers.query_start_loc, + seq_lens=self.input_buffers.seq_lens, + num_computed_tokens_cpu=num_computed_tokens_cpu, + block_tables=block_tables, + slot_mappings=slot_mappings, + kv_cache_config=self.kv_cache_config, + ) input_ids = self.input_buffers.input_ids.gpu[:num_tokens_after_padding] positions = self.input_buffers.positions.gpu[:num_tokens_after_padding] @@ -364,7 +454,10 @@ class GPUModelRunner: num_scheduled_tokens=num_scheduled_tokens, num_tokens=num_tokens, num_tokens_after_padding=num_tokens_after_padding, - is_chunked_prefilling=is_chunked_prefilling, + query_start_loc=query_start_loc_gpu, + query_start_loc_np=query_start_loc_np, + seq_lens=seq_lens_gpu, + seq_lens_np=seq_lens_np, input_ids=input_ids, positions=positions, attn_metadata=attn_metadata, @@ -375,102 +468,221 @@ class GPUModelRunner: self, hidden_states: torch.Tensor, input_batch: InputBatch, + sampling_metadata: SamplingMetadata, ) -> SamplerOutput: sample_hidden_states = hidden_states[input_batch.logits_indices] logits = self.model.compute_logits(sample_hidden_states) - pos = input_batch.positions[input_batch.logits_indices] - idx_mapping_np = input_batch.idx_mapping_np - num_reqs = logits.shape[0] - - # When the batch size is large enough, use DP sampler. - tp_group = get_tp_group() - tp_size = tp_group.world_size - n = (num_reqs + tp_size - 1) // tp_size - use_dp_sampler = tp_size > 1 and n > 32 # TODO(woosuk): Tune. - if use_dp_sampler: - # NOTE(woosuk): Make sure that no rank gets zero requests. - tp_rank = tp_group.rank - start, end = evenly_split(num_reqs, tp_size, tp_rank) - logits = logits[start:end] - pos = pos[start:end] - idx_mapping_np = idx_mapping_np[start:end] - - sampling_metadata = self.req_states.make_sampling_metadata( - idx_mapping_np, pos) - sampler_output = self.sampler( - logits=logits, - sampling_metadata=sampling_metadata, - ) - - needs_prompt_logprobs = np.any( - self.req_states.needs_prompt_logprobs[idx_mapping_np]) - assert not needs_prompt_logprobs - - if use_dp_sampler: - # All-gather the outputs. 
- sampler_output = all_gather_sampler_output( - sampler_output, - num_reqs, - tp_size, - ) + sampler_output = self.sampler.sample(logits, sampling_metadata) return sampler_output + def compute_prompt_logprobs( + self, + hidden_states: torch.Tensor, + input_batch: InputBatch, + ) -> dict[str, LogprobsTensors]: + idx_mapping_np = input_batch.idx_mapping_np + needs_prompt_logprobs = self.req_states.needs_prompt_logprobs[idx_mapping_np] + if not np.any(needs_prompt_logprobs): + # No request asks for prompt logprobs. + return {} + + num_computed_tokens = self.req_states.num_computed_tokens[idx_mapping_np] + prompt_lens = self.req_states.prompt_len[idx_mapping_np] + # NOTE(woosuk): -1 because the last prompt token's hidden state is not + # needed for prompt logprobs. + includes_prompt = num_computed_tokens < prompt_lens - 1 + # NOTE(woosuk): If the request was resumed after preemption, its prompt + # logprobs must have been computed before preemption. Skip. + resumed_after_prompt = ( + prompt_lens < self.req_states.prefill_len.np[idx_mapping_np] + ) + needs_prompt_logprobs &= includes_prompt & ~resumed_after_prompt + if not np.any(needs_prompt_logprobs): + return {} + + # Just to be safe, clone the input ids. + n = input_batch.num_tokens + # Shift the input ids by one. + token_ids = torch.empty_like(input_batch.input_ids[:n]) + token_ids[: n - 1] = input_batch.input_ids[1:n] + # To avoid out-of-bound access, set the last token id to 0. + token_ids[n - 1] = 0 + + # Handle chunked prompts. + seq_lens = self.input_buffers.seq_lens.np[: input_batch.num_reqs] + is_prompt_chunked = seq_lens < prompt_lens + prefill_token_ids = self.req_states.prefill_token_ids + query_start_loc = self.input_buffers.query_start_loc.np + for i, req_id in enumerate(input_batch.req_ids): + if not needs_prompt_logprobs[i]: + continue + if not is_prompt_chunked[i]: + continue + # The prompt is chunked. Get the next prompt token. + req_idx = input_batch.idx_mapping_np[i] + next_prompt_token = int(prefill_token_ids[req_idx, seq_lens[i]]) + idx = int(query_start_loc[i + 1] - 1) + # Set the next prompt token. + # NOTE(woosuk): This triggers a GPU operation. + token_ids[idx] = next_prompt_token + + # NOTE(woosuk): We mask out logprobs for negative tokens. + prompt_logprobs, prompt_ranks = compute_prompt_logprobs( + torch.relu(token_ids), + hidden_states[:n], + self.model.compute_logits, + ) + prompt_logprobs[:, 0].masked_fill_(token_ids < 0, 0) + + prompt_token_ids = token_ids.unsqueeze(-1) + prompt_logprobs_dict: dict[str, LogprobsTensors] = {} + for i, req_id in enumerate(input_batch.req_ids): + if not needs_prompt_logprobs[i]: + continue + + start_idx = query_start_loc[i] + end_idx = query_start_loc[i + 1] + assert start_idx < end_idx, ( + f"start_idx ({start_idx}) >= end_idx ({end_idx})" + ) + logprobs = LogprobsTensors( + logprob_token_ids=prompt_token_ids[start_idx:end_idx], + logprobs=prompt_logprobs[start_idx:end_idx], + selected_token_ranks=prompt_ranks[start_idx:end_idx], + ) + + req_extra_data = self.req_states.extra_data[req_id] + prompt_logprobs_list = req_extra_data.in_progress_prompt_logprobs + if is_prompt_chunked[i]: + # Prompt is chunked. Do not return the logprobs yet. + prompt_logprobs_list.append(logprobs) + continue + + if prompt_logprobs_list: + # Merge the in-progress logprobs. 
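# --- Editor's note (not part of the patch): the shift-by-one used in
# compute_prompt_logprobs() above. Prompt logprobs score each prompt token under
# its preceding context, so the hidden state at position i is paired with the
# token at position i + 1; the final slot has no in-chunk target and is filled
# with a placeholder 0 purely to avoid an out-of-bounds read.
prompt_chunk = [101, 7, 8, 9]        # input ids fed to the model this step
targets = prompt_chunk[1:] + [0]     # hidden_states[i] is scored against targets[i]
assert targets == [7, 8, 9, 0]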
+ prompt_logprobs_list.append(logprobs) + logprobs = LogprobsTensors( + logprob_token_ids=torch.cat( + [x.logprob_token_ids for x in prompt_logprobs_list] + ), + logprobs=torch.cat([x.logprobs for x in prompt_logprobs_list]), + selected_token_ranks=torch.cat( + [x.selected_token_ranks for x in prompt_logprobs_list] + ), + ) + prompt_logprobs_list.clear() + + prompt_logprobs_dict[req_id] = logprobs + return prompt_logprobs_dict + def postprocess( self, sampler_output: SamplerOutput, + sampling_metadata: SamplingMetadata, + prompt_logprobs_dict: dict[str, LogprobsTensors], input_batch: InputBatch, - ) -> AsyncOutput: + ) -> AsyncOutput | ModelRunnerOutput: # Store the last sampled token ids. self.req_states.last_sampled_tokens[input_batch.idx_mapping] = ( - sampler_output.sampled_token_ids) - + sampler_output.sampled_token_ids + ) # Get the number of sampled tokens. # 0 if chunked-prefilling, 1 if not. - is_chunked_prefilling = input_batch.is_chunked_prefilling + idx_mapping_np = input_batch.idx_mapping_np + is_chunked_prefilling = ( + input_batch.seq_lens_np < self.req_states.num_tokens[idx_mapping_np] + ) num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32) # Increment the number of tokens. - idx_mapping_np = input_batch.idx_mapping_np - self.req_states.num_tokens.np[idx_mapping_np] += num_sampled_tokens + self.req_states.num_tokens[idx_mapping_np] += num_sampled_tokens # Increment the number of computed tokens. self.req_states.num_computed_tokens[idx_mapping_np] += ( - input_batch.num_scheduled_tokens) + input_batch.num_scheduled_tokens + ) model_runner_output = ModelRunnerOutput( req_ids=input_batch.req_ids, + req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)}, sampled_token_ids=None, - num_sampled_tokens=num_sampled_tokens, logprobs=None, - prompt_logprobs_dict={}, + prompt_logprobs_dict=prompt_logprobs_dict, pooler_output=[], kv_connector_output=None, num_nans_in_logits=None, ) - return AsyncOutput( + async_output = AsyncOutput( model_runner_output=model_runner_output, sampler_output=sampler_output, + num_sampled_tokens=num_sampled_tokens, copy_stream=self.output_copy_stream, ) + if self.use_async_scheduling: + return async_output + return async_output.get_output() + @torch.inference_mode() def execute_model( self, scheduler_output: SchedulerOutput, - ) -> AsyncOutput: - self.update_states(scheduler_output) - if scheduler_output.total_num_scheduled_tokens == 0: - return EMPTY_MODEL_RUNNER_OUTPUT + intermediate_tensors: Any | None = None, + ) -> AsyncOutput | ModelRunnerOutput: + assert intermediate_tensors is None - input_batch = self.prepare_inputs(scheduler_output) - num_tokens = input_batch.num_tokens_after_padding - - with set_forward_context( - input_batch.attn_metadata, - self.vllm_config, - num_tokens=num_tokens, + with async_barrier( + self.input_prep_event if self.use_async_scheduling else None ): - hidden_states = self.model( - input_ids=input_batch.input_ids, - positions=input_batch.positions, + self.update_states(scheduler_output) + if scheduler_output.total_num_scheduled_tokens == 0: + return EMPTY_MODEL_RUNNER_OUTPUT + + padded_num_tokens = self.cudagraph_manager.get_cudagraph_size( + scheduler_output + ) + use_cudagraph = padded_num_tokens is not None + input_batch = self.prepare_inputs( + scheduler_output, + use_cudagraph, + padded_num_tokens, + ) + pos = input_batch.positions[input_batch.logits_indices] + idx_mapping_np = input_batch.idx_mapping_np + sampling_metadata = self.req_states.make_sampling_metadata( + idx_mapping_np, pos ) - 
sampler_output = self.sample(hidden_states, input_batch) - return self.postprocess(sampler_output, input_batch) + if self.lora_config: + # Activate LoRA adapters. + lora_inputs = self.req_states.make_lora_inputs( + input_batch.req_ids, + input_batch.idx_mapping_np, + input_batch.num_scheduled_tokens, + ) + self._set_active_loras(*lora_inputs) + + # Run model. + if use_cudagraph: + # Run CUDA graph. + # NOTE(woosuk): Here, we don't need to pass the input tensors, + # because they are already copied to the CUDA graph input buffers. + hidden_states = self.cudagraph_manager.run(padded_num_tokens) + else: + with set_forward_context( + input_batch.attn_metadata, + self.vllm_config, + num_tokens=input_batch.num_tokens_after_padding, + ): + # Run PyTorch model in eager mode. + hidden_states = self.model( + input_ids=input_batch.input_ids, + positions=input_batch.positions, + ) + + sampler_output = self.sample(hidden_states, input_batch, sampling_metadata) + prompt_logprobs_dict = self.compute_prompt_logprobs(hidden_states, input_batch) + output = self.postprocess( + sampler_output, + sampling_metadata, + prompt_logprobs_dict, + input_batch, + ) + return output diff --git a/vllm/v1/worker/gpu/sampler.py b/vllm/v1/worker/gpu/sampler.py index 65aadf9654..4e8980f336 100644 --- a/vllm/v1/worker/gpu/sampler.py +++ b/vllm/v1/worker/gpu/sampler.py @@ -1,61 +1,76 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable + import torch -import torch.nn as nn import triton import triton.language as tl -from vllm.config import LogprobsMode +from vllm.config.model import LogprobsMode from vllm.v1.outputs import LogprobsTensors, SamplerOutput +from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p -from vllm.v1.worker.gpu.states import SamplingMetadata - -_SAMPLING_EPS = 1e-5 -class Sampler(nn.Module): - +class Sampler: def __init__( self, - logprobs_mode: LogprobsMode = "processed_logprobs", + logprobs_mode: LogprobsMode = "raw_logprobs", ): - super().__init__() - assert logprobs_mode == "processed_logprobs" + if logprobs_mode not in ["processed_logprobs", "raw_logprobs"]: + raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}") self.logprobs_mode = logprobs_mode - def forward( + def sample_token( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> SamplerOutput: - # Divide logits by temperature, in FP32. - logits = apply_temperature(logits, sampling_metadata.temperature) - - # Apply top_k and/or top_p. + return_logits: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + is_greedy = sampling_metadata.temperature == 0 + temp = torch.where(is_greedy, 1.0, sampling_metadata.temperature) + logits = logits / temp.view(-1, 1) logits = apply_top_k_top_p( - logits, - sampling_metadata.top_k, - sampling_metadata.top_p, + logits, sampling_metadata.top_k, sampling_metadata.top_p ) - # Compute the probabilities. probs = torch.softmax(logits, dim=-1, dtype=torch.float32) - # Sample the next token (int64). 
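+        # Sample the next token via the Gumbel-max trick: the kernel divides
+        # each probability by an i.i.d. Exponential(1) draw and takes the
+        # argmax, which is equivalent to drawing from the categorical
+        # distribution defined by `probs`. Greedy requests (temperature 0)
+        # skip the noise and reduce to a plain argmax.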
+ sampled = gumbel_sample( probs, sampling_metadata.temperature, sampling_metadata.seeds, sampling_metadata.pos, ) + sampled = sampled.to(torch.int64) + return sampled, logits if return_logits else None - logprobs_tensors = None - num_logprobs = sampling_metadata.max_num_logprobs - if num_logprobs is not None: - logprobs_tensors = compute_logprobs( + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput: + if sampling_metadata.max_num_logprobs is not None: + if self.logprobs_mode == "processed_logprobs": + sampled, logits = self.sample_token( + logits, sampling_metadata, return_logits=True + ) + else: + assert self.logprobs_mode == "raw_logprobs" + sampled, _ = self.sample_token( + logits, sampling_metadata, return_logits=False + ) + + logprobs_tensors = compute_topk_logprobs( logits, - num_logprobs, + sampling_metadata.max_num_logprobs, sampled, ) + else: + sampled, _ = self.sample_token( + logits, sampling_metadata, return_logits=False + ) + logprobs_tensors = None # These are GPU tensors. sampler_output = SamplerOutput( @@ -69,60 +84,7 @@ class Sampler(nn.Module): @triton.jit -def _apply_temp_kernel( - logits, # bf16[batch_size, vocab_size] - logits_stride, - output, # fp32[batch_size, vocab_size] - output_stride, - temperature, - vocab_size, - BLOCK_SIZE: tl.constexpr, - EPSILON: tl.constexpr, -): - batch_idx = tl.program_id(0) - block_idx = tl.program_id(1) - - temp = tl.load(temperature + batch_idx) - if temp < EPSILON: - # Greedy sampling. Don't apply temperature. - # NOTE(woosuk): In this case, we assume that its logprobs are not used. - temp = 1.0 - - offset = tl.arange(0, BLOCK_SIZE) - block = block_idx * BLOCK_SIZE + offset - - # Load the logits. - x = tl.load(logits + batch_idx * logits_stride + block, - mask=block < vocab_size) - x = x.to(tl.float32) - x = x / temp - tl.store(output + batch_idx * output_stride + block, - x, - mask=block < vocab_size) - - -def apply_temperature( - logits: torch.Tensor, - temperature: torch.Tensor, -) -> torch.Tensor: - batch_size, vocab_size = logits.shape - output = torch.empty_like(logits, dtype=torch.float32) - BLOCK_SIZE = 8192 - _apply_temp_kernel[(batch_size, triton.cdiv(vocab_size, BLOCK_SIZE))]( - logits, - logits.stride(0), - output, - output.stride(0), - temperature, - vocab_size, - BLOCK_SIZE=BLOCK_SIZE, - EPSILON=_SAMPLING_EPS, - ) - return output - - -@triton.jit -def _apply_gumbel_kernel( +def _gumbel_sample_kernel( probs_ptr, probs_stride, seeds_ptr, @@ -130,18 +92,17 @@ def _apply_gumbel_kernel( temp_ptr, vocab_size, BLOCK_SIZE: tl.constexpr, - EPSILON: tl.constexpr, ): req_idx = tl.program_id(0) temp = tl.load(temp_ptr + req_idx) - if temp < EPSILON: + if temp == 0.0: # Greedy sampling. Don't apply gumbel noise. 
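+        # A temperature of exactly 0.0 marks a greedy request (the same
+        # check Sampler.sample_token uses), so no epsilon tolerance is needed.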
         return
 
-    seed = tl.load(seeds_ptr + req_idx).to(tl.uint64)
-    pos = tl.load(pos_ptr + req_idx).to(tl.uint64)
-    gumbel_seed = seed ^ (pos * 0x9E3779B97F4A7C15)
+    seed = tl.load(seeds_ptr + req_idx)
+    pos = tl.load(pos_ptr + req_idx)
+    gumbel_seed = tl.randint(seed, pos)
 
     block_id = tl.program_id(1)
     r_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
@@ -153,42 +114,33 @@ def _apply_gumbel_kernel(
     q = tl.where(q >= RMAX, RMAX_LOG, tl.math.log(q))
     q = -1.0 * q
 
-    p = tl.load(probs_ptr + req_idx * probs_stride + r_offset,
-                mask=r_offset < vocab_size)
+    p = tl.load(
+        probs_ptr + req_idx * probs_stride + r_offset, mask=r_offset < vocab_size
+    )
     p = p / q
-
-    tl.store(probs_ptr + req_idx * probs_stride + r_offset,
-             p,
-             mask=r_offset < vocab_size)
+    tl.store(
+        probs_ptr + req_idx * probs_stride + r_offset, p, mask=r_offset < vocab_size
+    )
 
 
 def gumbel_sample(
-    # fp32[num_reqs, vocab_size]
-    probs: torch.Tensor,
-    # fp32[num_reqs]
-    temperature: torch.Tensor,
-    # int64[num_reqs]
-    seeds: torch.Tensor,
-    # int64[num_reqs]
-    pos: torch.Tensor,
+    probs: torch.Tensor,  # [num_reqs, vocab_size]
+    temperature: torch.Tensor,  # [num_reqs]
+    seed: torch.Tensor,  # [num_reqs]
+    pos: torch.Tensor,  # [num_reqs]
 ) -> torch.Tensor:
-    num_reqs = probs.shape[0]
-    vocab_size = probs.shape[1]
-
-    # Update the probs in-place.
-    BLOCK_SIZE = 8192
-    _apply_gumbel_kernel[(num_reqs, triton.cdiv(vocab_size, BLOCK_SIZE))](
+    num_reqs, vocab_size = probs.shape
+    _gumbel_sample_kernel[(num_reqs, triton.cdiv(vocab_size, 8192))](
         probs,
         probs.stride(0),
-        seeds,
+        seed,
         pos,
         temperature,
         vocab_size,
-        BLOCK_SIZE,
-        EPSILON=_SAMPLING_EPS,
+        BLOCK_SIZE=8192,  # type: ignore
     )
-    # Sample the next token.
-    return probs.argmax(dim=-1).view(-1)
+    sampled = probs.argmax(dim=-1)
+    return sampled
 
 
 @triton.jit
@@ -208,54 +160,31 @@ def _topk_log_softmax_kernel(
     max_val = float("-inf")
     for i in range(0, vocab_size, BLOCK_SIZE):
         block = i + tl.arange(0, BLOCK_SIZE)
-        l = tl.load(row_ptr + block,
-                    mask=block < vocab_size,
-                    other=float("-inf"))
-        max_val = tl.max(tl.maximum(l, max_val))
+        logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
+        max_val = tl.max(tl.maximum(logits, max_val))
+    max_val = max_val.to(tl.float32)
 
     se = 0.0
     for i in range(0, vocab_size, BLOCK_SIZE):
         block = i + tl.arange(0, BLOCK_SIZE)
-        l = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0)
-        e = tl.exp(l - max_val)
+        logits = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0)
+        # NOTE(woosuk): Make sure that logits and all following operations are in float32.
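+        # Accumulating exp() in bf16/fp16 loses precision and can noticeably
+        # skew the log-sum-exp, and hence every returned logprob.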
+ logits = logits.to(tl.float32) + e = tl.exp(logits - max_val) e = tl.where(block < vocab_size, e, 0.0) se += tl.sum(e) lse = tl.log(se) k_offset = tl.arange(0, PADDED_TOPK) k_mask = k_offset < topk - topk_ids = tl.load(topk_ids_ptr + req_idx * topk + k_offset, mask=k_mask) + topk_ids = tl.load(topk_ids_ptr + req_idx * topk + k_offset, mask=k_mask, other=0) - l = tl.load(row_ptr + topk_ids, mask=k_mask) - o = l - max_val - lse + logits = tl.load(row_ptr + topk_ids, mask=k_mask) + logits = logits.to(tl.float32) + o = logits - max_val - lse tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask) -def compute_topk_logprobs( - logits: torch.Tensor, - topk_ids: torch.Tensor, -) -> torch.Tensor: - batch_size, vocab_size = logits.shape - topk = topk_ids.shape[1] - output = torch.empty( - batch_size, - topk, - dtype=torch.float32, - device=logits.device, - ) - _topk_log_softmax_kernel[(batch_size, )]( - output, - logits, - logits.stride(0), - topk_ids, - topk, - vocab_size, - BLOCK_SIZE=1024, - PADDED_TOPK=triton.next_power_of_2(topk), - ) - return output - - @triton.jit def _ranks_kernel( output_ptr, @@ -274,14 +203,39 @@ def _ranks_kernel( n = 0 for i in range(0, vocab_size, BLOCK_SIZE): block = i + tl.arange(0, BLOCK_SIZE) - l = tl.load(row_ptr + block, - mask=block < vocab_size, - other=float("-inf")) - n += tl.sum((l > x).to(tl.int32)) + logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf")) + n += tl.sum((logits > x).to(tl.int32)) tl.store(output_ptr + req_idx, n) -def compute_logprobs( +def compute_token_logprobs( + logits: torch.Tensor, + token_ids: torch.Tensor, +) -> torch.Tensor: + batch_size = logits.shape[0] + vocab_size = logits.shape[1] + token_ids = token_ids.to(torch.int64) + num_logprobs = token_ids.shape[1] + logprobs = torch.empty( + batch_size, + num_logprobs, + dtype=torch.float32, + device=logits.device, + ) + _topk_log_softmax_kernel[(batch_size,)]( + logprobs, + logits, + logits.stride(0), + token_ids, + num_logprobs, + vocab_size, + BLOCK_SIZE=1024, # type: ignore + PADDED_TOPK=triton.next_power_of_2(num_logprobs), + ) + return logprobs + + +def compute_topk_logprobs( logits: torch.Tensor, num_logprobs: int, sampled_token_ids: torch.Tensor, @@ -293,31 +247,56 @@ def compute_logprobs( else: topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices logprob_token_ids = torch.cat( - (sampled_token_ids.unsqueeze(-1), topk_indices), dim=1) + (sampled_token_ids.unsqueeze(-1), topk_indices), dim=1 + ) # NOTE(woosuk): Here, to save GPU memory, we do not materialize the full # logprobs tensor. Instead, we only compute and return the logprobs of # the topk + 1 tokens. - logprobs = compute_topk_logprobs( - logits, - logprob_token_ids, - ) - + logprobs = compute_token_logprobs(logits, logprob_token_ids) token_ranks = torch.empty( batch_size, dtype=torch.int64, device=logits.device, ) - _ranks_kernel[(batch_size, )]( + _ranks_kernel[(batch_size,)]( token_ranks, logits, logits.stride(0), sampled_token_ids, vocab_size, - BLOCK_SIZE=8192, + BLOCK_SIZE=8192, # type: ignore ) return LogprobsTensors( logprob_token_ids=logprob_token_ids, logprobs=logprobs, selected_token_ranks=token_ranks, ) + + +def compute_prompt_logprobs( + prompt_token_ids: torch.Tensor, + prompt_hidden_states: torch.Tensor, + logits_fn: Callable[[torch.Tensor], torch.Tensor], +) -> tuple[torch.Tensor, torch.Tensor]: + # Since materializing the full prompt logits can take too much memory, + # we compute it in chunks. 
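+    # A chunk only needs CHUNK_SIZE x vocab_size logits at a time, e.g.
+    # roughly 0.5 GiB in fp32 for a 128K vocabulary, independent of the
+    # prompt length.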
+ CHUNK_SIZE = 1024 + logprobs = [] + ranks = [] + prompt_token_ids = prompt_token_ids.to(torch.int64) + for start_idx in range(0, prompt_token_ids.shape[0], CHUNK_SIZE): + end_idx = start_idx + CHUNK_SIZE + # NOTE(woosuk): logits_fn can be slow because it involves all-gather. + prompt_logits = logits_fn(prompt_hidden_states[start_idx:end_idx]) + prompt_logprobs = compute_topk_logprobs( + prompt_logits, + 0, # num_logprobs + prompt_token_ids[start_idx:end_idx], + ) + logprobs.append(prompt_logprobs.logprobs) + ranks.append(prompt_logprobs.selected_token_ranks) + + logprobs = torch.cat(logprobs, dim=0) if len(logprobs) > 1 else logprobs[0] + ranks = torch.cat(ranks, dim=0) if len(ranks) > 1 else ranks[0] + return logprobs, ranks diff --git a/vllm/v1/worker/gpu/states.py b/vllm/v1/worker/gpu/states.py index 7980b3d159..23481e8229 100644 --- a/vllm/v1/worker/gpu/states.py +++ b/vllm/v1/worker/gpu/states.py @@ -1,21 +1,22 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from dataclasses import dataclass -from typing import Optional +from dataclasses import dataclass, field import numpy as np import torch +from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams +from vllm.v1.outputs import LogprobsTensors from vllm.v1.utils import CpuGpuBuffer _NP_INT64_MIN = np.iinfo(np.int64).min _NP_INT64_MAX = np.iinfo(np.int64).max +NO_LORA_ID = 0 @dataclass class SamplingMetadata: - temperature: torch.Tensor top_p: torch.Tensor | None @@ -36,12 +37,14 @@ class SamplingMetadata: assert num_reqs > 0 temperature = torch.zeros(num_reqs, dtype=torch.float32, device=device) temperature[0] = 0.5 - top_p = torch.ones(num_reqs, dtype=torch.float32, device=device) - top_p[0] = 0.99 - top_k = torch.ones(num_reqs, dtype=torch.int32, device=device) + # TODO(woosuk): Use top-p and top-k for dummy sampler. + # Currently, they are disabled because of memory usage. + top_p = None + top_k = None seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device) pos = torch.zeros(num_reqs, dtype=torch.int64, device=device) max_num_logprobs = 20 + return cls( temperature=temperature, top_p=top_p, @@ -53,7 +56,6 @@ class SamplingMetadata: class RequestState: - def __init__( self, max_num_reqs: int, @@ -73,15 +75,15 @@ class RequestState: self.req_id_to_index: dict[str, int] = {} self.index_to_req_id: dict[int, str] = {} self.free_indices = list(range(max_num_reqs)) + self.extra_data: dict[str, ExtraData] = {} - # NOTE(woosuk): Strictly speaking, it contains prompt + some output - # because of preemption. - self.prompt_token_ids = np.zeros( + self.prompt_len = np.zeros(self.max_num_reqs, dtype=np.int32) + self.prefill_token_ids = np.zeros( (self.max_num_reqs, self.max_model_len), dtype=np.int32, ) - self.num_tokens = self._make_buffer(self.max_num_reqs, - dtype=torch.int32) + self.prefill_len = self._make_buffer(self.max_num_reqs, dtype=torch.int32) + self.num_tokens = np.zeros(self.max_num_reqs, dtype=np.int32) self.num_computed_tokens = np.zeros(self.max_num_reqs, dtype=np.int32) # Last sampled tokens. @@ -92,6 +94,10 @@ class RequestState: device=device, ) + # LoRA. + self.lora_ids = np.zeros(self.max_num_reqs, dtype=np.int32) + self.lora_ids.fill(NO_LORA_ID) + # Sampling parameters. 
self.temperature = self._make_param(self.max_num_reqs, torch.float32) self.top_p = self._make_param(self.max_num_reqs, torch.float32) @@ -104,16 +110,12 @@ class RequestState: self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool) def _make_param(self, size: int, dtype: torch.dtype) -> "Param": - return Param(size, - dtype=dtype, - device=self.device, - pin_memory=self.pin_memory) + return Param(size, dtype=dtype, device=self.device, pin_memory=self.pin_memory) def _make_buffer(self, size: int, dtype: torch.dtype) -> CpuGpuBuffer: - return CpuGpuBuffer(size, - dtype=dtype, - device=self.device, - pin_memory=self.pin_memory) + return CpuGpuBuffer( + size, dtype=dtype, device=self.device, pin_memory=self.pin_memory + ) @property def num_reqs(self) -> int: @@ -122,23 +124,32 @@ class RequestState: def add_request( self, req_id: str, - prompt_token_ids: list[int], + prompt_len: int, + prefill_token_ids: list[int], num_computed_tokens: int, sampling_params: SamplingParams, + lora_request: LoRARequest | None, ) -> None: - assert len(self.free_indices) > 0 + assert len(self.free_indices) > 0, "No free indices" req_idx = self.free_indices.pop() self.req_id_to_index[req_id] = req_idx self.index_to_req_id[req_idx] = req_id + self.extra_data[req_id] = ExtraData(lora_request) - # NOTE(woosuk): Strictly speaking, "prompt_len" here may include - # output tokens, if the request is resumed from preemption. - prompt_len = len(prompt_token_ids) - self.prompt_token_ids[req_idx, :prompt_len] = prompt_token_ids - self.num_tokens.np[req_idx] = prompt_len + self.prompt_len[req_idx] = prompt_len + prefill_len = len(prefill_token_ids) + assert prefill_len >= prompt_len, ( + f"prefill_len {prefill_len} < prompt_len {prompt_len}" + ) + self.prefill_len.np[req_idx] = prefill_len + self.prefill_token_ids[req_idx, :prefill_len] = prefill_token_ids + self.num_tokens[req_idx] = prefill_len self.num_computed_tokens[req_idx] = num_computed_tokens - # TODO(woosuk): Optimize. - self.last_sampled_tokens[req_idx].fill_(-1) + + if lora_request is not None: + self.lora_ids[req_idx] = lora_request.lora_int_id + else: + self.lora_ids[req_idx] = NO_LORA_ID self.temperature.np[req_idx] = sampling_params.temperature self.top_p.np[req_idx] = sampling_params.top_p @@ -165,6 +176,7 @@ class RequestState: self.needs_prompt_logprobs[req_idx] = needs_prompt_logprobs def remove_request(self, req_id: str) -> None: + self.extra_data.pop(req_id, None) req_idx = self.req_id_to_index.pop(req_id, None) if req_idx is None: # Request not found. 
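Aside on the prompt_len / prefill_len split used in add_request above: a request resumed after preemption is re-prefilled over its prompt plus the tokens it had already generated, so the prefill token array can be longer than the prompt itself. A minimal sketch of the resulting bookkeeping, using made-up token ids and lengths rather than the worker's real inputs:

# Hypothetical request: 5-token prompt, preempted after generating 2 tokens.
prompt_len = 5
prefill_token_ids = [11, 12, 13, 14, 15, 101, 102]  # prompt + prior output

prefill_len = len(prefill_token_ids)   # 7 tokens to (re)compute
assert prefill_len >= prompt_len       # mirrors the check in add_request
num_tokens = prefill_len               # total tokens known for this request
num_computed_tokens = 0                # typically 0 after preemption, unless prefix-cached
# Positions < prompt_len are prompt tokens; positions >= prompt_len were
# generated before the preemption and are replayed during re-prefill.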
@@ -205,9 +217,25 @@ class RequestState: max_num_logprobs=max_num_logprobs, ) + def make_lora_inputs( + self, + req_ids: list[str], + idx_mapping: np.ndarray, + num_scheduled_tokens: np.ndarray, + ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]: + lora_ids = self.lora_ids[idx_mapping] + prompt_lora_mapping = tuple(lora_ids) + token_lora_mapping = tuple(lora_ids.repeat(num_scheduled_tokens)) + + active_lora_requests: set[LoRARequest] = set() + for req_id in req_ids: + lora_request = self.extra_data[req_id].lora_request + if lora_request is not None: + active_lora_requests.add(lora_request) + return prompt_lora_mapping, token_lora_mapping, active_lora_requests + class Param: - def __init__( self, size: int, @@ -227,3 +255,9 @@ class Param: n = x.shape[0] self.buffer.np[:n] = x return self.buffer.copy_to_gpu(n) + + +@dataclass +class ExtraData: + lora_request: LoRARequest | None + in_progress_prompt_logprobs: list[LogprobsTensors] = field(default_factory=list) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 9ca00366e5..47b509e683 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -42,6 +42,7 @@ from vllm.v1.outputs import ( ModelRunnerOutput, ) from vllm.v1.utils import report_usage_stats + # from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.gpu.model_runner import GPUModelRunner from vllm.v1.worker.utils import is_residual_scattered_for_sp @@ -495,6 +496,8 @@ class Worker(WorkerBase): self, scheduler_output: "SchedulerOutput", ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None: + return self.model_runner.execute_model(scheduler_output) + intermediate_tensors = None forward_pass = scheduler_output.total_num_scheduled_tokens > 0 num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens