Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

Author: Woosuk Kwon
Date: 2025-08-31 20:41:38 -07:00
Commit: 22771e5d83 (parent c11d1e6781)

4 changed files with 219 additions and 224 deletions

vllm/v1/worker/gpu_block_table.py

@@ -156,8 +156,8 @@ class BlockTables:
self,
cu_num_tokens: torch.Tensor,
pos: torch.Tensor,
num_tokens: int,
) -> tuple[torch.Tensor, ...]:
num_tokens = pos.shape[0]
num_reqs = cu_num_tokens.shape[0] - 1
num_groups = self.num_kv_cache_groups
_compute_slot_mappings_kernel[(num_reqs + 1, num_groups)](

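The hunk above drops the explicit num_tokens argument; compute_slot_mappings now infers it from pos. A minimal sketch of the resulting call, mirroring the call site shown later in gpu_model_runner.py (illustrative only, not part of the diff):

# Sketch only: num_tokens is now derived inside the method as pos.shape[0].
slot_mappings = self.block_tables.compute_slot_mappings(
    query_start_loc,                                  # cu_num_tokens: [num_reqs + 1]
    self.positions.gpu[:total_num_scheduled_tokens],  # pos: num_tokens = pos.shape[0]
)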
vllm/v1/worker/gpu_input_batch.py

@@ -3,10 +3,8 @@
from dataclasses import dataclass
from typing import Any, Optional
import numba
import numpy as np
import torch
from numba import types
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -35,109 +33,3 @@ class InputBatch:
spec_decode_metadata: Optional[SpecDecodeMetadata]
logits_indices: torch.Tensor
# NOTE: With the type annotations, this function is pre-compiled
# before the first call.
@numba.jit(
[
types.none(
types.int32[:], # idx_mapping
types.int32[:, :], # token_ids
types.int32[:], # num_computed_tokens
types.int32[:], # num_scheduled_tokens
types.int32[:], # input_ids
types.int32[:], # query_start_loc
types.int32[:], # seq_lens
types.int64[:], # positions
)
],
nopython=True,
cache=True,
)
def prepare_inputs(
# Inputs
idx_mapping: np.ndarray, # batch_idx -> req_idx
token_ids: np.ndarray, # [N, max_model_len]
num_computed_tokens: np.ndarray, # [N]
num_scheduled_tokens: np.ndarray, # [B]
# Outputs
input_ids: np.ndarray, # [num_input_tokens]
query_start_loc: np.ndarray, # [B + 1]
seq_lens: np.ndarray, # [B]
positions: np.ndarray, # [num_input_tokens]
) -> None:
num_reqs = num_scheduled_tokens.shape[0]
query_start_loc[0] = 0
cu_num_tokens = 0
for i in range(num_reqs):
req_idx = idx_mapping[i]
start = num_computed_tokens[req_idx]
end = start + num_scheduled_tokens[i]
seq_lens[i] = end
start_idx = cu_num_tokens
end_idx = start_idx + num_scheduled_tokens[i]
input_ids[start_idx:end_idx] = token_ids[req_idx, start:end]
positions[start_idx:end_idx] = np.arange(start, end)
cu_num_tokens = end_idx
query_start_loc[i + 1] = cu_num_tokens
# Pad the inputs for CUDA graphs.
# Note: pad query_start_loc to be non-decreasing, as kernels
# like FlashAttention require it.
query_start_loc[num_reqs + 1:].fill(cu_num_tokens)
# Fill unused with 0 for full cuda graph mode.
seq_lens[num_reqs:].fill(0)
def prepare_spec_decode(
# Inputs
query_start_loc: np.ndarray, # [B + 1]
num_draft_tokens: np.ndarray, # [B]
# Outputs
cu_num_draft_tokens: np.ndarray, # [B]
logits_indices: np.ndarray, # [N + B]
target_logits_indices: np.ndarray, # [N]
bonus_logits_indices: np.ndarray, # [B]
) -> int: # N
# Inputs:
# query_start_loc: [ 0, 4, 104, 107, 207, 209]
# num_draft_tokens: [ 3, 0, 2, 0, 1]
# Outputs:
# cu_num_draft_tokens: [ 3, 3, 5, 5, 6]
# logits_indices: [ 0, 1, 2, 3, 103, 104, 105, 106,
# 206, 207, 208]
# target_logits_indices: [ 0, 1, 2, 5, 6, 9]
# bonus_logits_indices: [ 3, 4, 7, 8, 10]
# return: 6 (total number of draft tokens)
cu_num_draft = 0
cu_num_sample = 0
num_reqs = num_draft_tokens.shape[0]
for i in range(num_reqs):
q_end_idx = query_start_loc[i + 1]
draft_len = num_draft_tokens[i]
# The last draft_len + 1 query tokens are used for sampling.
sample_len = draft_len + 1
sample_start_idx = cu_num_sample
sample_end_idx = sample_start_idx + sample_len
logits_indices[sample_start_idx:sample_end_idx] = (np.arange(
q_end_idx - sample_len, q_end_idx))
# For each query, the first draft_len tokens need target logits for
# rejection sampling. The (draft_len + 1)-th token is used for the bonus token.
draft_start_idx = cu_num_draft
draft_end_idx = draft_start_idx + draft_len
target_logits_indices[draft_start_idx:draft_end_idx] = (np.arange(
sample_start_idx, sample_end_idx - 1))
bonus_logits_indices[i] = sample_end_idx - 1
cu_num_draft += draft_len
cu_num_draft_tokens[i] = cu_num_draft
cu_num_sample += sample_len
return cu_num_draft

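For reference, a small NumPy driver (not part of the diff) that exercises the removed prepare_spec_decode helper with the exact values from its comment block; the Triton kernel added in gpu_worker_states.py must reproduce the same outputs:

import numpy as np

query_start_loc = np.array([0, 4, 104, 107, 207, 209], dtype=np.int32)
num_draft_tokens = np.array([3, 0, 2, 0, 1], dtype=np.int32)
B = num_draft_tokens.shape[0]
N = int(num_draft_tokens.sum())

cu_num_draft_tokens = np.zeros(B, dtype=np.int32)
logits_indices = np.zeros(N + B, dtype=np.int32)
target_logits_indices = np.zeros(N, dtype=np.int32)
bonus_logits_indices = np.zeros(B, dtype=np.int32)

total = prepare_spec_decode(query_start_loc, num_draft_tokens,
                            cu_num_draft_tokens, logits_indices,
                            target_logits_indices, bonus_logits_indices)
# total                 -> 6
# cu_num_draft_tokens   -> [3, 3, 5, 5, 6]
# logits_indices        -> [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
# target_logits_indices -> [0, 1, 2, 5, 6, 9]
# bonus_logits_indices  -> [3, 4, 7, 8, 10]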
vllm/v1/worker/gpu_model_runner.py

@@ -77,9 +77,8 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_block_table import BlockTables
from vllm.v1.worker.gpu_input_batch import (InputBatch, prepare_inputs,
prepare_spec_decode)
from vllm.v1.worker.gpu_worker_states import RequestState
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_worker_states import RequestState, prepare_inputs
from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorModelRunnerMixin, KVConnectorOutput)
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@@ -233,24 +232,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Persistent buffers for CUDA graphs.
self.input_ids = self._make_buffer(self.max_num_tokens,
dtype=torch.int32)
self.inputs_embeds = torch.zeros(
(self.max_num_tokens, self.hidden_size),
dtype=self.dtype,
device=self.device)
self.positions = self._make_buffer(self.max_num_tokens,
dtype=torch.int64)
self.query_start_loc = self._make_buffer(self.max_num_reqs + 1,
dtype=torch.int32)
self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
self.inputs_embeds = torch.zeros(
(self.max_num_tokens, self.hidden_size),
dtype=self.dtype,
device=self.device)
self.cu_num_draft_tokens = self._make_buffer(self.max_num_reqs,
dtype=torch.int32)
self.spec_logits_indices = self._make_buffer(self.max_num_tokens +
self.max_num_reqs,
dtype=torch.int32)
self.target_logits_indices = self._make_buffer(self.max_num_tokens,
dtype=torch.int32)
self.bonus_logits_indices = self._make_buffer(self.max_num_reqs,
dtype=torch.int32)
# Only relevant for models using M-RoPE (e.g., Qwen2-VL)
if self.uses_mrope:
@@ -543,8 +535,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# batch_idx -> req_id
req_ids = sorted(scheduler_output.num_scheduled_tokens,
key=scheduler_output.num_scheduled_tokens.get)
# req_id -> batch_idx
req_id_to_batch_idx = {req_id: i for i, req_id in enumerate(req_ids)}
# batch_idx -> req_idx
idx_mapping_list = [
self.requests.req_id_to_index[req_id] for req_id in req_ids
@@ -552,49 +542,50 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.idx_mapping.np[:num_reqs] = idx_mapping_list
idx_mapping_np = self.idx_mapping.np[:num_reqs]
idx_mapping = self.idx_mapping.copy_to_gpu(num_reqs)
# req_id -> batch_idx
req_id_to_batch_idx = {req_id: i for i, req_id in enumerate(req_ids)}
# OPTIMIZATION: Start copying the block table first.
# This way, we can overlap the copy with the following CPU operations.
block_tables = self.block_tables.compute_block_tables(idx_mapping)
# Get the number of scheduled tokens for each request.
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
max_num_scheduled_tokens = max(tokens)
num_scheduled_tokens = np.array(
[scheduler_output.num_scheduled_tokens[i] for i in req_ids],
dtype=np.int32)
prepare_inputs(
idx_mapping=idx_mapping_np,
token_ids=self.requests.token_ids.np,
num_computed_tokens=self.requests.num_computed_tokens.np,
num_scheduled_tokens=num_scheduled_tokens,
input_ids=self.input_ids.np,
query_start_loc=self.query_start_loc.np,
seq_lens=self.seq_lens.np,
positions=self.positions.np,
idx_mapping_np,
self.requests.token_ids.np,
self.requests.num_computed_tokens.np,
num_scheduled_tokens,
self.input_ids.np,
self.query_start_loc.np,
self.seq_lens.np,
self.positions.np,
)
# Calculate M-RoPE positions.
# Only relevant for models using M-RoPE (e.g., Qwen2-VL)
if self.uses_mrope:
self._calc_mrope_positions(scheduler_output)
# Prepare the attention metadata.
self.query_start_loc.copy_to_gpu()
query_start_loc = self.query_start_loc.gpu[:num_reqs + 1]
self.seq_lens.copy_to_gpu()
seq_lens = self.seq_lens.gpu[:num_reqs]
max_seq_len = self.seq_lens.np[:num_reqs].max().item()
# Copy the tensors to the GPU.
self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
self.positions.copy_to_gpu(total_num_scheduled_tokens)
# NOTE(woosuk): We should copy the whole query_start_loc and seq_lens
# tensors from CPU to GPU, because they may include padding needed
# for full CUDA graph mode.
self.query_start_loc.copy_to_gpu()
self.seq_lens.copy_to_gpu()
query_start_loc = self.query_start_loc.gpu[:num_reqs + 1]
max_query_len = int(num_scheduled_tokens.max())
seq_lens = self.seq_lens.gpu[:num_reqs]
max_seq_len = int(self.seq_lens.np[:num_reqs].max())
# Compute the slot mappings on GPUs.
slot_mappings = self.block_tables.compute_slot_mappings(
query_start_loc, self.positions.gpu, total_num_scheduled_tokens)
if self.uses_mrope:
# Only relevant for models using M-RoPE (e.g., Qwen2-VL)
self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_(
self.mrope_positions.cpu[:, :total_num_scheduled_tokens],
non_blocking=True)
else:
# Common case (1D positions)
self.positions.copy_to_gpu(total_num_scheduled_tokens)
self._calc_mrope_positions(req_ids, num_scheduled_tokens)
# Optimization: To avoid gather and scatter, copy the whole M-RoPE
# tensor from CPU to GPU although only a part of it is used.
self.mrope_positions.copy_to_gpu()
use_spec_decode = len(
scheduler_output.scheduled_spec_decode_tokens) > 0
@@ -603,19 +594,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# partial requests. While we should not sample any token
# from these partial requests, we do so for simplicity.
# We will ignore the sampled tokens from the partial requests.
# TODO: Support prompt logprobs.
logits_indices = query_start_loc[1:] - 1
spec_decode_metadata = None
else:
# Get the number of draft tokens for each request.
num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
for i, req_id in enumerate(req_ids):
draft_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id))
if draft_token_ids:
num_draft_tokens[i] = len(draft_token_ids)
spec_decode_metadata = self._calc_spec_decode_metadata(
num_draft_tokens)
spec_decode_metadata = self._prepare_spec_decode_metadata(
req_ids,
scheduler_output.scheduled_spec_decode_tokens,
query_start_loc,
)
logits_indices = spec_decode_metadata.logits_indices
logits_indices_padded = None
@@ -643,9 +630,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded]
)
slot_mappings = self.block_tables.compute_slot_mappings(
query_start_loc, self.positions.gpu[:total_num_scheduled_tokens])
# Used in the below loop.
query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
seq_lens_cpu = self.seq_lens.cpu[:num_reqs]
@@ -689,7 +673,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
max_query_len=max_query_len,
max_seq_len=max_seq_len,
block_table_tensor=blk_table_tensor,
slot_mapping=slot_mapping,
@@ -734,7 +718,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
idx_mapping_np=idx_mapping_np,
num_reqs=num_reqs,
total_num_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
max_query_len=max_query_len,
attn_metadata=attn_metadata,
spec_decode_metadata=spec_decode_metadata,
spec_decode_common_attn_metadata=spec_decode_common_attn_metadata,
@@ -836,17 +820,44 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
return common_prefix_len if use_cascade else 0
def _calc_mrope_positions(self, input_batch: InputBatch):
mrope_pos_ptr = 0
for i, req_id in enumerate(input_batch.req_ids):
req = self.requests[req_id]
assert req.mrope_positions is not None
def _prepare_spec_decode_metadata(
self,
req_ids: list[str],
req_id_to_draft_token_ids: dict[str, list[int]],
query_start_loc: torch.Tensor,
) -> SpecDecodeMetadata:
# Get the number of draft tokens for each request.
num_reqs = len(req_ids)
num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
for i, req_id in enumerate(req_ids):
draft_token_ids = req_id_to_draft_token_ids.get(req_id)
if draft_token_ids:
num_draft_tokens[i] = len(draft_token_ids)
np.cumsum(num_draft_tokens,
dtype=np.int32,
out=self.cu_num_draft_tokens.np[:num_reqs])
cu_num_draft_tokens = self.cu_num_draft_tokens.copy_to_gpu(num_reqs)
return self.requests.make_spec_decode_metadata(
query_start_loc,
cu_num_draft_tokens,
self.cu_num_draft_tokens.np[:num_reqs],
self.input_ids.gpu,
)
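For intuition (illustrative only), with the per-request draft counts from the example in gpu_input_batch.py, the cumsum above produces the cumulative layout the new helper consumes:

# np.cumsum([3, 0, 2, 0, 1], dtype=np.int32) -> [3, 3, 5, 5, 6]
# The last entry (6) is the total number of draft tokens in the batch.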
num_computed_tokens = \
self.requests.num_computed_tokens_cpu[i]
num_scheduled_tokens = \
input_batch.num_scheduled_tokens[i]
num_prompt_tokens = len(req.prompt_token_ids)
def _calc_mrope_positions(
self,
req_ids: list[str],
query_lens: np.ndarray,
):
mrope_pos_ptr = 0
for i, req_id in enumerate(req_ids):
req_idx = self.requests.req_id_to_index[req_id]
req_data = self.requests.req_data[req_idx]
assert req_data.mrope_positions is not None
num_computed_tokens = self.requests.num_computed_tokens.np[req_idx]
num_scheduled_tokens = query_lens[i]
num_prompt_tokens = self.requests.num_prompt_tokens.np[req_idx]
if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
prompt_part_len = max(0,
@@ -867,7 +878,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
src_end = num_computed_tokens + prompt_part_len
self.mrope_positions.cpu[:, dst_start:dst_end] = (
req.mrope_positions[:, src_start:src_end])
req_data.mrope_positions[:, src_start:src_end])
mrope_pos_ptr += prompt_part_len
if completion_part_len > 0:
@@ -878,49 +889,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
MRotaryEmbedding.get_next_input_positions_tensor(
out=self.mrope_positions.np,
out_offset=dst_start,
mrope_position_delta=req.mrope_position_delta,
mrope_position_delta=req_data.mrope_position_delta,
context_len=num_computed_tokens + prompt_part_len,
num_new_tokens=completion_part_len,
)
mrope_pos_ptr += completion_part_len
def _calc_spec_decode_metadata(
self,
num_draft_tokens: np.ndarray,
) -> SpecDecodeMetadata:
num_reqs = num_draft_tokens.shape[0]
total_num_draft_tokens = prepare_spec_decode(
self.query_start_loc.np,
num_draft_tokens,
self.cu_num_draft_tokens.np,
self.logits_indices.np,
self.target_logits_indices.np,
self.bonus_logits_indices.np,
)
cu_num_draft_tokens = self.cu_num_draft_tokens.copy_to_gpu(num_reqs)
logits_indices = self.logits_indices.copy_to_gpu(
num_reqs + total_num_draft_tokens)
target_logits_indices = self.target_logits_indices.copy_to_gpu(
total_num_draft_tokens)
bonus_logits_indices = self.bonus_logits_indices.copy_to_gpu(num_reqs)
# Compute the draft token ids.
# draft_token_indices: [ 1, 2, 3, 105, 106, 208]
draft_token_ids = self.input_ids.gpu[logits_indices]
draft_token_ids = draft_token_ids[target_logits_indices + 1]
metadata = SpecDecodeMetadata(
draft_token_ids=draft_token_ids,
num_draft_tokens=num_draft_tokens.tolist(),
cu_num_draft_tokens=cu_num_draft_tokens,
target_logits_indices=target_logits_indices,
bonus_logits_indices=bonus_logits_indices,
logits_indices=logits_indices,
)
return metadata
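A worked check of the double gather above (illustrative, using the index arrays from the prepare_spec_decode example in gpu_input_batch.py; a plain array stands in for self.input_ids.gpu, with each value equal to its index so the result is easy to read):

import numpy as np

input_ids = np.arange(209, dtype=np.int32)
logits_indices = np.array([0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208])
target_logits_indices = np.array([0, 1, 2, 5, 6, 9])

draft_token_ids = input_ids[logits_indices][target_logits_indices + 1]
# -> input_ids[[1, 2, 3, 105, 106, 208]], matching the draft_token_indices
#    comment above.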
def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
if not scheduled_encoder_inputs:
@@ -1353,7 +1328,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
inputs_embeds = None
model_kwargs = self._init_model_kwargs(num_input_tokens)
if self.uses_mrope:
positions = self.mrope_positions.gpu[:, :num_input_tokens]
positions = self.mrope_positions[:, :num_input_tokens]
else:
positions = self.positions.gpu[:num_input_tokens]
@@ -2117,7 +2092,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
model_kwargs = self._init_model_kwargs(num_tokens)
if self.uses_mrope:
positions = self.mrope_positions.gpu[:, :num_tokens]
positions = self.mrope_positions[:, :num_tokens]
else:
positions = self.positions.gpu[:num_tokens]

vllm/v1/worker/gpu_worker_states.py

@@ -5,10 +5,12 @@
from dataclasses import dataclass
from typing import Optional, Union
import numba
import numpy as np
import torch
import triton
import triton.language as tl
from numba import types
from typing_extensions import deprecated
from vllm.lora.request import LoRARequest
@@ -18,6 +20,7 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@dataclass
@@ -158,6 +161,10 @@ class RequestState:
is_scalar=num_cols == 1,
)
@property
def num_cached_reqs(self) -> int:
return len(self.req_id_to_index)
def add_request(
self,
req_id: str,
@@ -292,9 +299,43 @@ class RequestState:
logitsprocs=None,
)
@property
def num_cached_reqs(self) -> int:
return len(self.req_id_to_index)
def make_spec_decode_metadata(
self,
query_start_loc: torch.Tensor,
cu_num_draft_tokens: torch.Tensor,
cu_num_draft_tokens_np: np.ndarray,
input_ids: torch.Tensor,
) -> SpecDecodeMetadata:
batch_size = query_start_loc.shape[0] - 1
total_num_draft_tokens = cu_num_draft_tokens_np[batch_size - 1]
logits_indices = torch.empty(total_num_draft_tokens + batch_size,
dtype=torch.int32,
device=self.device)
target_logits_indices = torch.empty(total_num_draft_tokens,
dtype=torch.int32,
device=self.device)
bonus_logits_indices = torch.empty(batch_size,
dtype=torch.int32,
device=self.device)
_prepare_spec_decode_kernel[(batch_size, )](
query_start_loc,
cu_num_draft_tokens,
logits_indices,
target_logits_indices,
bonus_logits_indices,
BLOCK_SIZE=triton.next_power_of_2(32 + 1),
)
draft_token_ids = input_ids[logits_indices]
draft_token_ids = draft_token_ids[target_logits_indices + 1]
return SpecDecodeMetadata(
draft_token_ids=draft_token_ids,
num_draft_tokens=np.diff(cu_num_draft_tokens_np, prepend=np.int32(0)).tolist(),
cu_num_draft_tokens=cu_num_draft_tokens,
target_logits_indices=target_logits_indices,
bonus_logits_indices=bonus_logits_indices,
logits_indices=logits_indices,
)
@triton.jit
@@ -333,3 +374,90 @@ def _make_sampling_metadata_kernel(
repetition_penalties = tl.load(src_repetition_penalties + req_idx)
tl.store(dst_repetition_penalties + batch_idx, repetition_penalties)
@triton.jit
def _prepare_spec_decode_kernel(
query_start_loc, # [B + 1]
cu_num_draft_tokens, # [B]
logits_indices, # [N + B]
target_logits_indices, # [N]
bonus_logits_indices, # [B]
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
if batch_idx == 0:
draft_start_idx = 0
else:
draft_start_idx = tl.load(cu_num_draft_tokens + batch_idx - 1)
draft_end_idx = tl.load(cu_num_draft_tokens + batch_idx)
draft_len = draft_end_idx - draft_start_idx
sample_len = draft_len + 1
q_end_idx = tl.load(query_start_loc + batch_idx + 1)
sample_start_idx = draft_start_idx + batch_idx
sample_end_idx = sample_start_idx + sample_len
offset = tl.arange(0, BLOCK_SIZE)
tl.store(logits_indices + sample_start_idx + offset,
q_end_idx - sample_len + offset,
mask=offset < sample_len)
tl.store(target_logits_indices + draft_start_idx + offset,
sample_start_idx + offset,
mask=offset < draft_len)
tl.store(bonus_logits_indices + batch_idx, sample_end_idx - 1)
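Each program in the (batch_size,) launch grid handles one request: it writes sample_len = draft_len + 1 entries of logits_indices, draft_len entries of target_logits_indices, and a single bonus index. Because the stores are masked with offset < sample_len, BLOCK_SIZE only needs to be at least the largest sample_len in the batch; the call site above uses triton.next_power_of_2(32 + 1), which presumably reflects an assumed cap on draft tokens per request.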
# NOTE: With the type annotations, this function is pre-compiled
# before the first call.
@numba.jit(
[
types.none(
types.int32[:], # idx_mapping
types.int32[:, :], # token_ids
types.int32[:], # num_computed_tokens
types.int32[:], # num_scheduled_tokens
types.int32[:], # input_ids
types.int32[:], # query_start_loc
types.int32[:], # seq_lens
types.int64[:], # positions
)
],
nopython=True,
cache=True,
)
def prepare_inputs(
idx_mapping: np.ndarray, # batch_idx -> req_idx
token_ids: np.ndarray, # [N, max_model_len]
num_computed_tokens: np.ndarray, # [N]
num_scheduled_tokens: np.ndarray, # [B]
input_ids: np.ndarray, # [num_input_tokens]
query_start_loc: np.ndarray, # [B + 1]
seq_lens: np.ndarray, # [B]
positions: np.ndarray, # [num_input_tokens]
) -> None:
num_reqs = num_scheduled_tokens.shape[0]
query_start_loc[0] = 0
cu_num_tokens = 0
for i in range(num_reqs):
req_idx = idx_mapping[i]
query_len = num_scheduled_tokens[i]
start = num_computed_tokens[req_idx]
end = start + query_len
seq_lens[i] = end
start_idx = cu_num_tokens
end_idx = start_idx + query_len
input_ids[start_idx:end_idx] = token_ids[req_idx, start:end]
positions[start_idx:end_idx] = np.arange(start, end, dtype=np.int64)
cu_num_tokens = end_idx
query_start_loc[i + 1] = cu_num_tokens
# Pad the inputs for CUDA graphs.
# Note: pad query_start_loc to be non-decreasing, as kernels
# like FlashAttention require it.
query_start_loc[num_reqs + 1:].fill(cu_num_tokens)
# Fill unused with 0 for full cuda graph mode.
seq_lens[num_reqs:].fill(0)
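A toy driver (not part of the diff) showing the flattening prepare_inputs performs; the values and buffer sizes below are made up for illustration, and the dtypes match the numba signature above:

import numpy as np

# Two requests; request 1 already has 3 computed tokens, and 2 new tokens are
# scheduled for each request.
idx_mapping = np.array([0, 1], dtype=np.int32)            # batch_idx -> req_idx
token_ids = np.array([[10, 11, 12, 13, 14],
                      [20, 21, 22, 23, 24]], dtype=np.int32)
num_computed_tokens = np.array([0, 3], dtype=np.int32)
num_scheduled_tokens = np.array([2, 2], dtype=np.int32)

# Output buffers, deliberately oversized to mimic the persistent CUDA-graph
# buffers (the extra slots get padded / zero-filled).
input_ids = np.zeros(8, dtype=np.int32)
query_start_loc = np.zeros(4, dtype=np.int32)
seq_lens = np.zeros(3, dtype=np.int32)
positions = np.zeros(8, dtype=np.int64)

prepare_inputs(idx_mapping, token_ids, num_computed_tokens,
               num_scheduled_tokens, input_ids, query_start_loc,
               seq_lens, positions)
# input_ids[:4]   -> [10, 11, 23, 24]
# positions[:4]   -> [0, 1, 3, 4]
# query_start_loc -> [0, 2, 4, 4]   (padded to stay non-decreasing)
# seq_lens        -> [2, 5, 0]      (unused slots zero-filled)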