@@ -156,8 +156,8 @@ class BlockTables:
        self,
        cu_num_tokens: torch.Tensor,
        pos: torch.Tensor,
        num_tokens: int,
    ) -> tuple[torch.Tensor, ...]:
        num_tokens = pos.shape[0]
        num_reqs = cu_num_tokens.shape[0] - 1
        num_groups = self.num_kv_cache_groups
        _compute_slot_mappings_kernel[(num_reqs + 1, num_groups)](

@@ -3,10 +3,8 @@
from dataclasses import dataclass
from typing import Any, Optional

import numba
import numpy as np
import torch
from numba import types

from vllm.v1.spec_decode.metadata import SpecDecodeMetadata

@@ -35,109 +33,3 @@ class InputBatch:
    spec_decode_metadata: Optional[SpecDecodeMetadata]

    logits_indices: torch.Tensor


# NOTE: With the type annotations, this function is pre-compiled
# before the first call.
@numba.jit(
    [
        types.none(
            types.int32[:],  # idx_mapping
            types.int32[:, :],  # token_ids
            types.int32[:],  # num_computed_tokens
            types.int32[:],  # num_scheduled_tokens
            types.int32[:],  # input_ids
            types.int32[:],  # query_start_loc
            types.int32[:],  # seq_lens
            types.int64[:],  # positions
        )
    ],
    nopython=True,
    cache=True,
)
def prepare_inputs(
    # Inputs
    idx_mapping: np.ndarray,  # batch_idx -> req_idx
    token_ids: np.ndarray,  # [N, max_model_len]
    num_computed_tokens: np.ndarray,  # [N]
    num_scheduled_tokens: np.ndarray,  # [B]
    # Outputs
    input_ids: np.ndarray,  # [num_input_tokens]
    query_start_loc: np.ndarray,  # [B + 1]
    seq_lens: np.ndarray,  # [B]
    positions: np.ndarray,  # [num_input_tokens]
) -> None:
    num_reqs = num_scheduled_tokens.shape[0]
    query_start_loc[0] = 0

    cu_num_tokens = 0
    for i in range(num_reqs):
        req_idx = idx_mapping[i]
        start = num_computed_tokens[req_idx]
        end = start + num_scheduled_tokens[i]
        seq_lens[i] = end

        start_idx = cu_num_tokens
        end_idx = start_idx + num_scheduled_tokens[i]
        input_ids[start_idx:end_idx] = token_ids[req_idx, start:end]
        positions[start_idx:end_idx] = np.arange(start, end)

        cu_num_tokens = end_idx
        query_start_loc[i + 1] = cu_num_tokens

    # Pad the inputs for CUDA graphs.
    # Note: pad query_start_loc to be non-decreasing, as kernels
    # like FlashAttention requires that
    query_start_loc[num_reqs + 1:].fill(cu_num_tokens)
    # Fill unused with 0 for full cuda graph mode.
    seq_lens[num_reqs:].fill(0)
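A minimal, runnable sketch of what `prepare_inputs` writes into its output buffers, using plain NumPy (no numba) and made-up toy values, so the buffer layout is easy to see:

```python
import numpy as np

# Toy batch: 2 cached requests, both scheduled this step (values are made up).
token_ids = np.array([[10, 11, 12, 13, 14],
                      [20, 21, 22, 23, 24]], dtype=np.int32)
num_computed_tokens = np.array([2, 0], dtype=np.int32)   # per cached request
num_scheduled_tokens = np.array([3, 2], dtype=np.int32)  # per batch slot
idx_mapping = np.array([0, 1], dtype=np.int32)           # batch_idx -> req_idx

# Output buffers, slightly over-allocated to mimic CUDA-graph padding.
input_ids = np.zeros(8, dtype=np.int32)
query_start_loc = np.zeros(4, dtype=np.int32)
seq_lens = np.zeros(3, dtype=np.int32)
positions = np.zeros(8, dtype=np.int64)

# Same loop as prepare_inputs above, minus the numba decoration.
cu = 0
for i in range(num_scheduled_tokens.shape[0]):
    req = idx_mapping[i]
    start = num_computed_tokens[req]
    end = start + num_scheduled_tokens[i]
    seq_lens[i] = end
    input_ids[cu:cu + end - start] = token_ids[req, start:end]
    positions[cu:cu + end - start] = np.arange(start, end)
    cu += end - start
    query_start_loc[i + 1] = cu
query_start_loc[num_scheduled_tokens.shape[0] + 1:].fill(cu)

print(input_ids[:cu])       # [12 13 14 20 21]
print(query_start_loc)      # [0 3 5 5]
print(seq_lens)             # [5 2 0]
print(positions[:cu])       # [2 3 4 0 1]
```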


def prepare_spec_decode(
    # Inputs
    query_start_loc: np.ndarray,  # [B + 1]
    num_draft_tokens: np.ndarray,  # [B]
    # Outputs
    cu_num_draft_tokens: np.ndarray,  # [B]
    logits_indices: np.ndarray,  # [N + B]
    target_logits_indices: np.ndarray,  # [N]
    bonus_logits_indices: np.ndarray,  # [B]
) -> int:  # N
    # Inputs:
    # query_start_loc:       [  0,   4, 104, 107, 207, 209]
    # num_draft_tokens:      [  3,   0,   2,   0,   1]
    # Outputs:
    # cu_num_draft_tokens:   [  3,   3,   5,   5,   6]
    # logits_indices:        [  0,   1,   2,   3, 103, 104, 105, 106,
    #                         206, 207, 208]
    # target_logits_indices: [  0,   1,   2,   5,   6,   9]
    # bonus_logits_indices:  [  3,   4,   7,   8,  10]
    # return: 6 (total number of draft tokens)

    cu_num_draft = 0
    cu_num_sample = 0
    num_reqs = num_draft_tokens.shape[0]
    for i in range(num_reqs):
        q_end_idx = query_start_loc[i + 1]
        draft_len = num_draft_tokens[i]

        # The last draft_len + 1 query tokens are used for sampling.
        sample_len = draft_len + 1
        sample_start_idx = cu_num_sample
        sample_end_idx = sample_start_idx + sample_len
        logits_indices[sample_start_idx:sample_end_idx] = (np.arange(
            q_end_idx - sample_len, q_end_idx))

        # For each query, the first draft_len tokens need target logits for
        # rejection sampling. The draft_len + 1th token is used for bonus token.
        draft_start_idx = cu_num_draft
        draft_end_idx = draft_start_idx + draft_len
        target_logits_indices[draft_start_idx:draft_end_idx] = (np.arange(
            sample_start_idx, sample_end_idx - 1))
        bonus_logits_indices[i] = sample_end_idx - 1

        cu_num_draft += draft_len
        cu_num_draft_tokens[i] = cu_num_draft
        cu_num_sample += sample_len

    return cu_num_draft
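The worked example in the comments above can be reproduced with a small standalone script; this is a plain NumPy re-implementation of the same loop, with buffer sizes chosen just large enough for the example:

```python
import numpy as np

query_start_loc = np.array([0, 4, 104, 107, 207, 209], dtype=np.int32)
num_draft_tokens = np.array([3, 0, 2, 0, 1], dtype=np.int32)

B = num_draft_tokens.shape[0]
N = int(num_draft_tokens.sum())
cu_num_draft_tokens = np.zeros(B, dtype=np.int32)
logits_indices = np.zeros(N + B, dtype=np.int32)
target_logits_indices = np.zeros(N, dtype=np.int32)
bonus_logits_indices = np.zeros(B, dtype=np.int32)

# Same logic as prepare_spec_decode above.
cu_draft, cu_sample = 0, 0
for i in range(B):
    q_end = query_start_loc[i + 1]
    draft_len = num_draft_tokens[i]
    sample_len = draft_len + 1
    s0, s1 = cu_sample, cu_sample + sample_len
    logits_indices[s0:s1] = np.arange(q_end - sample_len, q_end)
    d0, d1 = cu_draft, cu_draft + draft_len
    target_logits_indices[d0:d1] = np.arange(s0, s1 - 1)
    bonus_logits_indices[i] = s1 - 1
    cu_draft += draft_len
    cu_num_draft_tokens[i] = cu_draft
    cu_sample += sample_len

print(cu_num_draft_tokens)    # [3 3 5 5 6]
print(logits_indices)         # [  0   1   2   3 103 104 105 106 206 207 208]
print(target_logits_indices)  # [0 1 2 5 6 9]
print(bonus_logits_indices)   # [ 3  4  7  8 10]
```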

@@ -77,9 +77,8 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_block_table import BlockTables
from vllm.v1.worker.gpu_input_batch import (InputBatch, prepare_inputs,
                                            prepare_spec_decode)
from vllm.v1.worker.gpu_worker_states import RequestState
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_worker_states import RequestState, prepare_inputs
from vllm.v1.worker.kv_connector_model_runner_mixin import (
    KVConnectorModelRunnerMixin, KVConnectorOutput)
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@@ -233,24 +232,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        # Persistent buffers for CUDA graphs.
        self.input_ids = self._make_buffer(self.max_num_tokens,
                                           dtype=torch.int32)
        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=self.device)
        self.positions = self._make_buffer(self.max_num_tokens,
                                           dtype=torch.int64)
        self.query_start_loc = self._make_buffer(self.max_num_reqs + 1,
                                                 dtype=torch.int32)
        self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=self.device)
        self.cu_num_draft_tokens = self._make_buffer(self.max_num_reqs,
                                                     dtype=torch.int32)
        self.spec_logits_indices = self._make_buffer(self.max_num_tokens +
                                                     self.max_num_reqs,
                                                     dtype=torch.int32)
        self.target_logits_indices = self._make_buffer(self.max_num_tokens,
                                                       dtype=torch.int32)
        self.bonus_logits_indices = self._make_buffer(self.max_num_reqs,
                                                      dtype=torch.int32)

        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
        if self.uses_mrope:
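The `_make_buffer(...)` calls above hand back CpuGpuBuffer-style pairs: a host array (`.np`) that the numba/NumPy code writes into, mirrored by a device tensor (`.gpu`) that `copy_to_gpu(n)` refreshes. A rough, hypothetical stand-in with the same surface, for tracing the data flow only (not vLLM's actual implementation):

```python
import torch


class PairedBuffer:
    """Hypothetical stand-in for the paired CPU/GPU buffers used above."""

    def __init__(self, *size: int, dtype: torch.dtype, device: str = "cuda"):
        cpu = torch.zeros(*size, dtype=dtype, device="cpu")
        # A real implementation would typically pin the host buffer so the
        # async H2D copies below can overlap with CPU work.
        self.cpu = cpu.pin_memory() if torch.cuda.is_available() else cpu
        self.np = self.cpu.numpy()  # host view written by numba/NumPy code
        self.gpu = torch.zeros(*size, dtype=dtype, device=device)

    def copy_to_gpu(self, n: int | None = None) -> torch.Tensor:
        # Copy either the first n entries or, with no argument, the whole
        # buffer including any padded tail.
        if n is None:
            self.gpu.copy_(self.cpu, non_blocking=True)
            return self.gpu
        self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
        return self.gpu[:n]
```

The property relied on later in this diff is that `copy_to_gpu()` with no argument also transfers the padded tail of `query_start_loc` and `seq_lens`, which full CUDA graph mode needs.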
@@ -543,8 +535,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        # batch_idx -> req_id
        req_ids = sorted(scheduler_output.num_scheduled_tokens,
                         key=scheduler_output.num_scheduled_tokens.get)
        # req_id -> batch_idx
        req_id_to_batch_idx = {req_id: i for i, req_id in enumerate(req_ids)}
        # batch_idx -> req_idx
        idx_mapping_list = [
            self.requests.req_id_to_index[req_id] for req_id in req_ids
@@ -552,49 +542,50 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        self.idx_mapping.np[:num_reqs] = idx_mapping_list
        idx_mapping_np = self.idx_mapping.np[:num_reqs]
        idx_mapping = self.idx_mapping.copy_to_gpu(num_reqs)
        # req_id -> batch_idx
        req_id_to_batch_idx = {req_id: i for i, req_id in enumerate(req_ids)}

        # OPTIMIZATION: Start copying the block table first.
        # This way, we can overlap the copy with the following CPU operations.
        block_tables = self.block_tables.compute_block_tables(idx_mapping)

        # Get the number of scheduled tokens for each request.
        tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
        num_scheduled_tokens = np.array(tokens, dtype=np.int32)
        max_num_scheduled_tokens = max(tokens)
        num_scheduled_tokens = np.array(
            [scheduler_output.num_scheduled_tokens[i] for i in req_ids],
            dtype=np.int32)

        prepare_inputs(
            idx_mapping=idx_mapping_np,
            token_ids=self.requests.token_ids.np,
            num_computed_tokens=self.requests.num_computed_tokens.np,
            num_scheduled_tokens=num_scheduled_tokens,
            input_ids=self.input_ids.np,
            query_start_loc=self.query_start_loc.np,
            seq_lens=self.seq_lens.np,
            positions=self.positions.np,
            idx_mapping_np,
            self.requests.token_ids.np,
            self.requests.num_computed_tokens.np,
            num_scheduled_tokens,
            self.input_ids.np,
            self.query_start_loc.np,
            self.seq_lens.np,
            self.positions.np,
        )
        # Calculate M-RoPE positions.
        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
        if self.uses_mrope:
            self._calc_mrope_positions(scheduler_output)

        # Prepare the attention metadata.
        self.query_start_loc.copy_to_gpu()
        query_start_loc = self.query_start_loc.gpu[:num_reqs + 1]

        self.seq_lens.copy_to_gpu()
        seq_lens = self.seq_lens.gpu[:num_reqs]
        max_seq_len = self.seq_lens.np[:num_reqs].max().item()

        # Copy the tensors to the GPU.
        self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
        self.positions.copy_to_gpu(total_num_scheduled_tokens)

        # NOTE(woosuk): We should copy the whole query_start_loc and seq_lens
        # tensors from CPU to GPU, because they may include paddings needed
        # for full CUDA graph mode.
        self.query_start_loc.copy_to_gpu()
        self.seq_lens.copy_to_gpu()
        query_start_loc = self.query_start_loc.gpu[:num_reqs + 1]
        max_query_len = int(num_scheduled_tokens.max())
        seq_lens = self.seq_lens.gpu[:num_reqs]
        max_seq_len = int(self.seq_lens.np[:num_reqs].max())

        # Compute the slot mappings on GPUs.
        slot_mappings = self.block_tables.compute_slot_mappings(
            query_start_loc, self.positions.gpu, total_num_scheduled_tokens)

        if self.uses_mrope:
            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
            self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_(
                self.mrope_positions.cpu[:, :total_num_scheduled_tokens],
                non_blocking=True)
        else:
            # Common case (1D positions)
            self.positions.copy_to_gpu(total_num_scheduled_tokens)
            self._calc_mrope_positions(req_ids, num_scheduled_tokens)
            # Optimization: To avoid gather and scatter, copy the whole M-RoPE
            # tensor from CPU to GPU although only a part of it is used.
            self.mrope_positions.copy_to_gpu()

        use_spec_decode = len(
            scheduler_output.scheduled_spec_decode_tokens) > 0
@@ -603,19 +594,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            # partial requests. While we should not sample any token
            # from these partial requests, we do so for simplicity.
            # We will ignore the sampled tokens from the partial requests.
            # TODO: Support prompt logprobs.
            logits_indices = query_start_loc[1:] - 1
            spec_decode_metadata = None
        else:
            # Get the number of draft tokens for each request.
            num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
            for i, req_id in enumerate(req_ids):
                draft_token_ids = (
                    scheduler_output.scheduled_spec_decode_tokens.get(req_id))
                if draft_token_ids:
                    num_draft_tokens[i] = len(draft_token_ids)
            spec_decode_metadata = self._calc_spec_decode_metadata(
                num_draft_tokens)
            spec_decode_metadata = self._prepare_spec_decode_metadata(
                req_ids,
                scheduler_output.scheduled_spec_decode_tokens,
                query_start_loc,
            )
            logits_indices = spec_decode_metadata.logits_indices

        logits_indices_padded = None
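In the non-speculative branch above, `logits_indices = query_start_loc[1:] - 1` simply picks the last scheduled token of every request; a one-liner check with the `query_start_loc` values from the earlier worked example:

```python
import torch

query_start_loc = torch.tensor([0, 4, 104, 107, 207, 209], dtype=torch.int32)
logits_indices = query_start_loc[1:] - 1
print(logits_indices)  # tensor([  3, 103, 106, 206, 208], dtype=torch.int32)
```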
@@ -643,9 +630,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded]
            )

        slot_mappings = self.block_tables.compute_slot_mappings(
            query_start_loc, self.positions.gpu[:total_num_scheduled_tokens])

        # Used in the below loop.
        query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
        seq_lens_cpu = self.seq_lens.cpu[:num_reqs]
@@ -689,7 +673,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            num_computed_tokens_cpu=num_computed_tokens_cpu,
            num_reqs=num_reqs,
            num_actual_tokens=total_num_scheduled_tokens,
            max_query_len=max_num_scheduled_tokens,
            max_query_len=max_query_len,
            max_seq_len=max_seq_len,
            block_table_tensor=blk_table_tensor,
            slot_mapping=slot_mapping,
@@ -734,7 +718,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            idx_mapping_np=idx_mapping_np,
            num_reqs=num_reqs,
            total_num_tokens=total_num_scheduled_tokens,
            max_query_len=max_num_scheduled_tokens,
            max_query_len=max_query_len,
            attn_metadata=attn_metadata,
            spec_decode_metadata=spec_decode_metadata,
            spec_decode_common_attn_metadata=spec_decode_common_attn_metadata,
@@ -836,17 +820,44 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        )
        return common_prefix_len if use_cascade else 0

    def _calc_mrope_positions(self, input_batch: InputBatch):
        mrope_pos_ptr = 0
        for i, req_id in enumerate(input_batch.req_ids):
            req = self.requests[req_id]
            assert req.mrope_positions is not None
    def _prepare_spec_decode_metadata(
        self,
        req_ids: list[str],
        req_id_to_draft_token_ids: dict[str, list[int]],
        query_start_loc: torch.Tensor,
    ) -> SpecDecodeMetadata:
        # Get the number of draft tokens for each request.
        num_reqs = len(req_ids)
        num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
        for i, req_id in enumerate(req_ids):
            draft_token_ids = req_id_to_draft_token_ids.get(req_id)
            if draft_token_ids:
                num_draft_tokens[i] = len(draft_token_ids)
        np.cumsum(num_draft_tokens,
                  dtype=np.int32,
                  out=self.cu_num_draft_tokens.np[:num_reqs])
        cu_num_draft_tokens = self.cu_num_draft_tokens.copy_to_gpu(num_reqs)
        return self.requests.make_spec_decode_metadata(
            query_start_loc,
            cu_num_draft_tokens,
            self.cu_num_draft_tokens.np[:num_reqs],
            self.input_ids.gpu,
        )
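The `np.cumsum(..., out=...)` call above writes the prefix sums straight into the persistent host buffer instead of allocating a temporary array; a tiny illustration with a hypothetical 8-slot buffer:

```python
import numpy as np

num_draft_tokens = np.array([3, 0, 2, 0, 1], dtype=np.int32)
cu_num_draft_tokens = np.zeros(8, dtype=np.int32)  # hypothetical max_num_reqs = 8

# Write the cumulative sums into the first 5 slots, leaving the tail untouched.
np.cumsum(num_draft_tokens, dtype=np.int32,
          out=cu_num_draft_tokens[:num_draft_tokens.shape[0]])
print(cu_num_draft_tokens)  # [3 3 5 5 6 0 0 0]
```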

            num_computed_tokens = \
                self.requests.num_computed_tokens_cpu[i]
            num_scheduled_tokens = \
                input_batch.num_scheduled_tokens[i]
            num_prompt_tokens = len(req.prompt_token_ids)
    def _calc_mrope_positions(
        self,
        req_ids: list[str],
        query_lens: np.ndarray,
    ):
        mrope_pos_ptr = 0
        for i, req_id in enumerate(req_ids):
            req_idx = self.requests.req_id_to_index[req_id]
            req_data = self.requests.req_data[req_idx]
            assert req_data.mrope_positions is not None

            num_computed_tokens = self.requests.num_computed_tokens.np[req_idx]
            num_scheduled_tokens = query_lens[i]
            num_prompt_tokens = self.requests.num_prompt_tokens.np[req_idx]

            if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
                prompt_part_len = max(0,
@@ -867,7 +878,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                src_end = num_computed_tokens + prompt_part_len

                self.mrope_positions.cpu[:, dst_start:dst_end] = (
                    req.mrope_positions[:, src_start:src_end])
                    req_data.mrope_positions[:, src_start:src_end])
                mrope_pos_ptr += prompt_part_len

            if completion_part_len > 0:
@@ -878,49 +889,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                MRotaryEmbedding.get_next_input_positions_tensor(
                    out=self.mrope_positions.np,
                    out_offset=dst_start,
                    mrope_position_delta=req.mrope_position_delta,
                    mrope_position_delta=req_data.mrope_position_delta,
                    context_len=num_computed_tokens + prompt_part_len,
                    num_new_tokens=completion_part_len,
                )

                mrope_pos_ptr += completion_part_len

    def _calc_spec_decode_metadata(
        self,
        num_draft_tokens: np.ndarray,
    ) -> SpecDecodeMetadata:
        num_reqs = num_draft_tokens.shape[0]
        total_num_draft_tokens = prepare_spec_decode(
            self.query_start_loc.np,
            num_draft_tokens,
            self.cu_num_draft_tokens.np,
            self.logits_indices.np,
            self.target_logits_indices.np,
            self.bonus_logits_indices.np,
        )

        cu_num_draft_tokens = self.cu_num_draft_tokens.copy_to_gpu(num_reqs)
        logits_indices = self.logits_indices.copy_to_gpu(
            num_reqs + total_num_draft_tokens)
        target_logits_indices = self.target_logits_indices.copy_to_gpu(
            total_num_draft_tokens)
        bonus_logits_indices = self.bonus_logits_indices.copy_to_gpu(num_reqs)

        # Compute the draft token ids.
        # draft_token_indices:      [  1,  2,  3, 105, 106, 208]
        draft_token_ids = self.input_ids.gpu[logits_indices]
        draft_token_ids = draft_token_ids[target_logits_indices + 1]

        metadata = SpecDecodeMetadata(
            draft_token_ids=draft_token_ids,
            num_draft_tokens=num_draft_tokens.tolist(),
            cu_num_draft_tokens=cu_num_draft_tokens,
            target_logits_indices=target_logits_indices,
            bonus_logits_indices=bonus_logits_indices,
            logits_indices=logits_indices,
        )
        return metadata

    def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
        if not scheduled_encoder_inputs:
@@ -1353,7 +1328,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            inputs_embeds = None
            model_kwargs = self._init_model_kwargs(num_input_tokens)
        if self.uses_mrope:
            positions = self.mrope_positions.gpu[:, :num_input_tokens]
            positions = self.mrope_positions[:, :num_input_tokens]
        else:
            positions = self.positions.gpu[:num_input_tokens]

@@ -2117,7 +2092,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            model_kwargs = self._init_model_kwargs(num_tokens)

            if self.uses_mrope:
                positions = self.mrope_positions.gpu[:, :num_tokens]
                positions = self.mrope_positions[:, :num_tokens]
            else:
                positions = self.positions.gpu[:num_tokens]


@@ -5,10 +5,12 @@
from dataclasses import dataclass
from typing import Optional, Union

import numba
import numpy as np
import torch
import triton
import triton.language as tl
from numba import types
from typing_extensions import deprecated

from vllm.lora.request import LoRARequest
@@ -18,6 +20,7 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata


@dataclass
@@ -158,6 +161,10 @@ class RequestState:
            is_scalar=num_cols == 1,
        )

    @property
    def num_cached_reqs(self) -> int:
        return len(self.req_id_to_index)

    def add_request(
        self,
        req_id: str,
@@ -292,9 +299,43 @@ class RequestState:
            logitsprocs=None,
        )

    @property
    def num_cached_reqs(self) -> int:
        return len(self.req_id_to_index)
    def make_spec_decode_metadata(
        self,
        query_start_loc: torch.Tensor,
        cu_num_draft_tokens: torch.Tensor,
        cu_num_draft_tokens_np: np.ndarray,
        input_ids: torch.Tensor,
    ) -> SpecDecodeMetadata:
        batch_size = query_start_loc.shape[0] - 1
        total_num_draft_tokens = cu_num_draft_tokens_np[batch_size - 1]
        logits_indices = torch.empty(total_num_draft_tokens + batch_size,
                                     dtype=torch.int32,
                                     device=self.device)
        target_logits_indices = torch.empty(total_num_draft_tokens,
                                            dtype=torch.int32,
                                            device=self.device)
        bonus_logits_indices = torch.empty(batch_size,
                                           dtype=torch.int32,
                                           device=self.device)
        _prepare_spec_decode_kernel[(batch_size, )](
            query_start_loc,
            cu_num_draft_tokens,
            logits_indices,
            target_logits_indices,
            bonus_logits_indices,
            BLOCK_SIZE=triton.next_power_of_2(32 + 1),
        )

        draft_token_ids = input_ids[logits_indices]
        draft_token_ids = draft_token_ids[target_logits_indices + 1]
        return SpecDecodeMetadata(
            draft_token_ids=draft_token_ids,
            num_draft_tokens=cu_num_draft_tokens_np.tolist(),
            cu_num_draft_tokens=cu_num_draft_tokens,
            target_logits_indices=target_logits_indices,
            bonus_logits_indices=bonus_logits_indices,
            logits_indices=logits_indices,
        )
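The two gathers at the end (`input_ids[logits_indices]` followed by `[target_logits_indices + 1]`) recover the draft token ids: for every draft position, the token at the next sampling index is the proposed draft token. A small CPU-only illustration using the index arrays from the worked example earlier and made-up token ids:

```python
import torch

# Indices from the worked example earlier in this diff.
logits_indices = torch.tensor([0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208])
target_logits_indices = torch.tensor([0, 1, 2, 5, 6, 9])

# Made-up flat input_ids buffer: token id == 1000 + position, so the result
# is easy to read off.
input_ids = torch.arange(1000, 1300, dtype=torch.int32)

gathered = input_ids[logits_indices]
draft_token_ids = gathered[target_logits_indices + 1]
print(draft_token_ids)
# tensor([1001, 1002, 1003, 1105, 1106, 1208], dtype=torch.int32)
# i.e. the tokens at positions [1, 2, 3, 105, 106, 208].
```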


@triton.jit
@@ -333,3 +374,90 @@ def _make_sampling_metadata_kernel(

    repetition_penalties = tl.load(src_repetition_penalties + req_idx)
    tl.store(dst_repetition_penalties + batch_idx, repetition_penalties)


@triton.jit
def _prepare_spec_decode_kernel(
    query_start_loc,  # [B + 1]
    cu_num_draft_tokens,  # [B]
    logits_indices,  # [N + B]
    target_logits_indices,  # [N]
    bonus_logits_indices,  # [B]
    BLOCK_SIZE: tl.constexpr,
):
    batch_idx = tl.program_id(0)

    if batch_idx == 0:
        draft_start_idx = 0
    else:
        draft_start_idx = tl.load(cu_num_draft_tokens + batch_idx - 1)
    draft_end_idx = tl.load(cu_num_draft_tokens + batch_idx)
    draft_len = draft_end_idx - draft_start_idx
    sample_len = draft_len + 1

    q_end_idx = tl.load(query_start_loc + batch_idx + 1)

    sample_start_idx = draft_start_idx + batch_idx
    sample_end_idx = sample_start_idx + sample_len
    offset = tl.arange(0, BLOCK_SIZE)
    tl.store(logits_indices + sample_start_idx + offset,
             q_end_idx - sample_len + offset,
             mask=offset < sample_len)
    tl.store(target_logits_indices + draft_start_idx + offset,
             sample_start_idx + offset,
             mask=offset < draft_len)
    tl.store(bonus_logits_indices + batch_idx, sample_end_idx - 1)
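Each program instance of the kernel above handles one request and mirrors the arithmetic of the NumPy `prepare_spec_decode` loop; a host-side reference for a single `batch_idx`, handy for sanity-checking the kernel against the earlier example (toy inputs assumed):

```python
import numpy as np

query_start_loc = np.array([0, 4, 104, 107, 207, 209], dtype=np.int32)
cu_num_draft_tokens = np.array([3, 3, 5, 5, 6], dtype=np.int32)


def reference_program(batch_idx: int):
    # Mirrors one program instance of _prepare_spec_decode_kernel.
    draft_start_idx = 0 if batch_idx == 0 else int(cu_num_draft_tokens[batch_idx - 1])
    draft_end_idx = int(cu_num_draft_tokens[batch_idx])
    draft_len = draft_end_idx - draft_start_idx
    sample_len = draft_len + 1
    q_end_idx = int(query_start_loc[batch_idx + 1])
    sample_start_idx = draft_start_idx + batch_idx
    sample_end_idx = sample_start_idx + sample_len
    logits = np.arange(q_end_idx - sample_len, q_end_idx)      # -> logits_indices slice
    targets = np.arange(sample_start_idx, sample_end_idx - 1)  # -> target_logits_indices slice
    bonus = sample_end_idx - 1                                 # -> bonus_logits_indices[batch_idx]
    return logits, targets, bonus


print(reference_program(0))  # (array([0, 1, 2, 3]), array([0, 1, 2]), 3)
print(reference_program(4))  # (array([207, 208]), array([9]), 10)
```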


# NOTE: With the type annotations, this function is pre-compiled
# before the first call.
@numba.jit(
    [
        types.none(
            types.int32[:],  # idx_mapping
            types.int32[:, :],  # token_ids
            types.int32[:],  # num_computed_tokens
            types.int32[:],  # num_scheduled_tokens
            types.int32[:],  # input_ids
            types.int32[:],  # query_start_loc
            types.int32[:],  # seq_lens
            types.int64[:],  # positions
        )
    ],
    nopython=True,
    cache=True,
)
def prepare_inputs(
    idx_mapping: np.ndarray,  # batch_idx -> req_idx
    token_ids: np.ndarray,  # [N, max_model_len]
    num_computed_tokens: np.ndarray,  # [N]
    num_scheduled_tokens: np.ndarray,  # [B]
    input_ids: np.ndarray,  # [num_input_tokens]
    query_start_loc: np.ndarray,  # [B + 1]
    seq_lens: np.ndarray,  # [B]
    positions: np.ndarray,  # [num_input_tokens]
) -> None:
    num_reqs = num_scheduled_tokens.shape[0]
    query_start_loc[0] = 0

    cu_num_tokens = 0
    for i in range(num_reqs):
        req_idx = idx_mapping[i]
        query_len = num_scheduled_tokens[i]
        start = num_computed_tokens[req_idx]
        end = start + query_len
        seq_lens[i] = end

        start_idx = cu_num_tokens
        end_idx = start_idx + query_len
        input_ids[start_idx:end_idx] = token_ids[req_idx, start:end]
        positions[start_idx:end_idx] = np.arange(start, end, dtype=np.int64)

        cu_num_tokens = end_idx
        query_start_loc[i + 1] = cu_num_tokens

    # Pad the inputs for CUDA graphs.
    # Note: pad query_start_loc to be non-decreasing, as kernels
    # like FlashAttention requires that
    query_start_loc[num_reqs + 1:].fill(cu_num_tokens)
    # Fill unused with 0 for full cuda graph mode.
    seq_lens[num_reqs:].fill(0)