Compare commits


2 Commits

SHA1        Message                                                        Date
fefed35cee  fix (Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>)    2025-09-01 18:58:00 -07:00
901afda905  wip (Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>)    2025-09-01 09:32:49 -07:00
4 changed files with 89 additions and 69 deletions

View File

@@ -842,6 +842,7 @@ class Scheduler(SchedulerInterface):
         scheduler_output: SchedulerOutput,
         model_runner_output: ModelRunnerOutput,
     ) -> dict[int, EngineCoreOutputs]:
+        num_sampled_tokens = model_runner_output.num_sampled_tokens
         sampled_token_ids = model_runner_output.sampled_token_ids
         logprobs = model_runner_output.logprobs
         prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
@@ -849,6 +850,10 @@ class Scheduler(SchedulerInterface):
         pooler_outputs = model_runner_output.pooler_output
         num_nans_in_logits = model_runner_output.num_nans_in_logits
 
+        if sampled_token_ids is not None:
+            # Optimization: Avoid a .tolist() call for each request.
+            sampled_token_ids = sampled_token_ids.tolist()
+
         outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
         spec_decoding_stats: Optional[SpecDecodingStats] = None
@@ -867,14 +872,19 @@ class Scheduler(SchedulerInterface):
                 continue
 
             req_index = model_runner_output.req_id_to_index[req_id]
-            generated_token_ids = sampled_token_ids[
-                req_index] if sampled_token_ids else []
+            generated_token_ids: list[int] = []
+            if sampled_token_ids is not None:
+                assert num_sampled_tokens is not None
+                num_sampled = num_sampled_tokens[req_index]
+                if num_sampled > 0:
+                    generated_token_ids = sampled_token_ids[
+                        req_index][:num_sampled]
 
             scheduled_spec_token_ids = (
                 scheduler_output.scheduled_spec_decode_tokens.get(req_id))
             if scheduled_spec_token_ids:
                 num_draft_tokens = len(scheduled_spec_token_ids)
-                num_accepted = len(generated_token_ids) - 1
+                num_accepted = num_sampled - 1
                 num_rejected = num_draft_tokens - num_accepted
                 # num_computed_tokens represents the number of tokens
                 # processed in the current step, considering scheduled
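
On the scheduler side, the change above swaps the per-request list of sampled token lists for a single padded array plus a per-request count, so only one `.tolist()` call is needed per step. A minimal, self-contained sketch of the slicing pattern (the token values and counts below are invented for illustration):

    import numpy as np

    # Padded batch of 3 requests, max_num_sampled_tokens = 3; -1 marks padding.
    sampled_token_ids = np.array([
        [101,  -1,  -1],   # request 0: one new token
        [202, 203, 204],   # request 1: three tokens (accepted spec-decode drafts + bonus)
        [ -1,  -1,  -1],   # request 2: chunked prefill, nothing sampled
    ])
    num_sampled_tokens = np.array([1, 3, 0])

    # One tolist() for the whole batch, as in the updated scheduler code.
    token_lists = sampled_token_ids.tolist()

    for req_index, num_sampled in enumerate(num_sampled_tokens):
        generated_token_ids = token_lists[req_index][:num_sampled] if num_sampled > 0 else []
        print(req_index, generated_token_ids)
    # 0 [101]
    # 1 [202, 203, 204]
    # 2 []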

View File

@@ -4,6 +4,7 @@
 from dataclasses import dataclass
 from typing import NamedTuple, Optional
 
+import numpy as np
 import torch
@@ -88,11 +89,12 @@ class ModelRunnerOutput:
     # req_id -> index
     req_id_to_index: dict[str, int]
 
-    # num_reqs x num_generated_tokens
-    # num_generated_tokens is the number of tokens
-    # generated in the current step. It can be different for
-    # each request due to speculative/jump decoding.
-    sampled_token_ids: list[list[int]]
+    # [num_reqs]
+    # Number of tokens sampled in the current step. Each request may generate
+    # different number of tokens due to chunked prefilling and spec decoding.
+    num_sampled_tokens: Optional[np.ndarray]
+    # [num_reqs, max_num_sampled_tokens]
+    sampled_token_ids: Optional[np.ndarray]
 
     # [num_reqs, max_num_logprobs + 1]
@@ -123,10 +125,13 @@ class DraftTokenIds:
     draft_token_ids: list[list[int]]
 
-EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
-                                              req_id_to_index={},
-                                              sampled_token_ids=[],
-                                              logprobs=None,
-                                              prompt_logprobs_dict={},
-                                              pooler_output=[],
-                                              num_nans_in_logits=None)
+EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
+    req_ids=[],
+    req_id_to_index={},
+    num_sampled_tokens=None,
+    sampled_token_ids=None,
+    logprobs=None,
+    prompt_logprobs_dict={},
+    pooler_output=[],
+    num_nans_in_logits=None,
+)

View File

@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from typing import Optional
 
+import numpy as np
 import torch
 import torch.nn as nn
@@ -106,9 +107,9 @@ class RejectionSampler(nn.Module):
     @staticmethod
     def parse_output(
-        output_token_ids: torch.Tensor,
+        output_token_ids: np.ndarray,
         vocab_size: int,
-    ) -> list[list[int]]:
+    ) -> np.ndarray:
         """Parse the output of the rejection sampler.
 
         Args:
@@ -119,17 +120,14 @@ class RejectionSampler(nn.Module):
             vocab_size: The size of the vocabulary.
 
         Returns:
-            A list of lists of token IDs.
+            A Numpy array of the number of valid sampled tokens.
         """
-        output_token_ids_np = output_token_ids.cpu().numpy()
         # Create mask for valid tokens.
-        valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) &
-                      (output_token_ids_np < vocab_size))
-        outputs = [
-            row[valid_mask[i]].tolist()
-            for i, row in enumerate(output_token_ids_np)
-        ]
-        return outputs
+        valid_mask = ((output_token_ids != PLACEHOLDER_TOKEN_ID) &
+                      (output_token_ids < vocab_size))
+        # Get the number until the first valid_mask=False.
+        num_sampled_tokens = np.cumprod(valid_mask, axis=1).sum(axis=1)
+        return num_sampled_tokens
 
     def rejection_sample(
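
The rewritten parse_output above counts accepted tokens per request with a cumulative-product trick: cumprod over the validity mask stays 1 up to the first invalid position and is 0 afterwards, so the row sum is the length of the valid prefix. A standalone illustration (the token IDs are invented; PLACEHOLDER_TOKEN_ID stands for the rejection sampler's padding constant):

    import numpy as np

    PLACEHOLDER_TOKEN_ID = -1  # placeholder for rejected/padded positions
    vocab_size = 32000

    # 2 requests, up to 4 sampled tokens each (draft tokens + bonus token).
    output_token_ids = np.array([
        [11, 12, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID],  # 2 valid tokens
        [21, 22, 23, 24],                                       # 4 valid tokens
    ])

    valid_mask = ((output_token_ids != PLACEHOLDER_TOKEN_ID) &
                  (output_token_ids < vocab_size))
    # Row-wise cumprod: [1, 1, 0, 0] and [1, 1, 1, 1].
    num_sampled_tokens = np.cumprod(valid_mask, axis=1).sum(axis=1)
    print(num_sampled_tokens)  # [2 4]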

View File

@@ -1456,7 +1456,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             return ModelRunnerOutput(
                 req_ids=self.input_batch.req_ids,
                 req_id_to_index=self.input_batch.req_id_to_index,
-                sampled_token_ids=[],
+                num_sampled_tokens=None,
+                sampled_token_ids=None,
                 logprobs=None,
                 prompt_logprobs_dict={},
                 pooler_output=pooler_output,
@@ -1665,23 +1666,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
             num_nans_in_logits = self._get_nans_in_logits(logits)
 
-        # TODO(woosuk): The following loop can be slow since it iterates over
-        # the requests one by one. Optimize.
-        discard_sampled_tokens_req_indices = []
-        for i, req_id in enumerate(self.input_batch.req_ids):
-            req_state = self.requests[req_id]
-            seq_len = (req_state.num_computed_tokens +
-                       scheduler_output.num_scheduled_tokens[req_id])
-            if seq_len < req_state.num_tokens:
+        # Post-processing for chunked prefill.
+        num_reqs = self.input_batch.num_reqs
+        chunked_prefilling = (
+            self.input_batch.num_computed_tokens_cpu[:num_reqs] +
+            num_scheduled_tokens_np
+            < self.input_batch.num_tokens_no_spec[:num_reqs])
+        if self.input_batch.generators:
+            chunked_prefill_indices = np.where(chunked_prefilling)[0]
+            for i in chunked_prefill_indices:
                 # Ignore the sampled token for partial prefills.
                 # Rewind the generator state as if the token was not sampled.
                 # This relies on cuda-specific torch-internal impl details
                 generator = self.input_batch.generators.get(i)
                 if generator is not None:
                     generator.set_offset(generator.get_offset() - 4)
-                # Record the index of the request that should not be sampled,
-                # so that we could clear the sampled tokens before returning.
-                discard_sampled_tokens_req_indices.append(i)
 
         # NOTE: GPU -> CPU Sync happens here.
         # Move as many CPU operations as possible before this sync point.
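
The hunk above replaces the per-request Python loop that detected partial (chunked) prefills with a vectorized NumPy mask. A rough standalone sketch of that masking logic with invented per-request numbers (the names mirror the diff, but this is not the actual vLLM code):

    import numpy as np

    # Invented state for a batch of 3 requests.
    num_computed_tokens = np.array([100, 7, 512])    # tokens already processed
    num_scheduled_tokens_np = np.array([1, 1, 256])  # tokens scheduled this step
    num_tokens_no_spec = np.array([101, 8, 1024])    # total prompt+output tokens so far

    # A request is still chunk-prefilling if this step does not reach its last token.
    chunked_prefilling = (num_computed_tokens + num_scheduled_tokens_np
                          < num_tokens_no_spec)
    print(chunked_prefilling)  # [False False  True]

    # Without spec decode, each non-prefilling request contributes one sampled token.
    num_sampled_tokens = (~chunked_prefilling).astype(np.int32)
    print(num_sampled_tokens)  # [1 1 0]

Only the requests flagged by this mask need their sampler RNG rewound, which is what the generator.set_offset(...) loop above does for seeded requests.
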
@@ -1700,16 +1699,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         max_gen_len = sampled_token_ids.shape[-1]
         if max_gen_len == 1:
             # No spec decode tokens.
-            valid_sampled_token_ids = self._to_list(sampled_token_ids)
+            sampled_token_ids_np = self._to_numpy(sampled_token_ids)
+            num_sampled_tokens = (~chunked_prefilling).astype(np.int32)
         else:
             # Includes spec decode tokens.
-            valid_sampled_token_ids = self.rejection_sampler.parse_output(
-                sampled_token_ids,
+            sampled_token_ids_np = sampled_token_ids.cpu().numpy()
+            num_sampled_tokens = self.rejection_sampler.parse_output(
+                sampled_token_ids_np,
                 self.input_batch.vocab_size,
             )
-        # Mask out the sampled tokens that should not be sampled.
-        for i in discard_sampled_tokens_req_indices:
-            valid_sampled_token_ids[i].clear()
+        num_sampled_tokens *= ~chunked_prefilling
 
         # Cache the sampled tokens in the model runner, so that the scheduler
         # doesn't need to send them back.
@@ -1717,9 +1716,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # the sampled tokens back, because there's no direct communication
         # between the first-stage worker and the last-stage worker.
         req_ids = self.input_batch.req_ids
-        for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
-            if not sampled_ids:
+        for req_idx in range(num_reqs):
+            num_sampled = num_sampled_tokens[req_idx]
+            if num_sampled == 0:
                 continue
+            sampled_ids = sampled_token_ids_np[req_idx][:num_sampled].tolist()
 
             start_idx = self.input_batch.num_tokens_no_spec[req_idx]
             end_idx = start_idx + len(sampled_ids)
@@ -1740,7 +1741,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             assert spec_decode_common_attn_metadata is not None
             self._draft_token_ids = self.propose_draft_token_ids(
                 scheduler_output,
-                valid_sampled_token_ids,
+                num_sampled_tokens,
+                sampled_token_ids_np,
                 sampling_metadata,
                 hidden_states,
                 sample_hidden_states,
@@ -1754,7 +1756,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         return ModelRunnerOutput(
             req_ids=self.input_batch.req_ids,
             req_id_to_index=self.input_batch.req_id_to_index,
-            sampled_token_ids=valid_sampled_token_ids,
+            num_sampled_tokens=num_sampled_tokens,
+            sampled_token_ids=sampled_token_ids_np,
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
             pooler_output=[],
@@ -1776,7 +1779,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     def propose_draft_token_ids(
         self,
         scheduler_output: "SchedulerOutput",
-        sampled_token_ids: list[list[int]],
+        num_sampled_tokens: np.ndarray,
+        sampled_token_ids: np.ndarray,
         sampling_metadata: SamplingMetadata,
         hidden_states: torch.Tensor,
         sample_hidden_states: torch.Tensor,
@@ -1788,19 +1792,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         if self.speculative_config.method == "ngram":
             assert isinstance(self.drafter, NgramProposer)
             draft_token_ids = self.propose_ngram_draft_token_ids(
-                sampled_token_ids)
+                num_sampled_tokens)
         elif self.speculative_config.method == "medusa":
             assert isinstance(self.drafter, MedusaProposer)
-            if sample_hidden_states.shape[0] == len(sampled_token_ids):
+            if sample_hidden_states.shape[0] == len(num_sampled_tokens):
                 # The input to the target model does not include draft tokens.
                 hidden_states = sample_hidden_states
             else:
                 indices = []
                 offset = 0
-                for num_draft, tokens in zip(
+                for num_draft, num_sampled in zip(
                         spec_decode_metadata.num_draft_tokens,
-                        sampled_token_ids):
-                    indices.append(offset + len(tokens) - 1)
+                        num_sampled_tokens):
+                    indices.append(offset + num_sampled - 1)
                     offset += num_draft + 1
                 indices = torch.tensor(indices, device=self.device)
                 hidden_states = sample_hidden_states[indices]
@@ -1813,11 +1817,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             assert isinstance(self.drafter, EagleProposer)
             # TODO(woosuk): Refactor the loop.
             req_ids = self.input_batch.req_ids
+            num_reqs = self.input_batch.num_reqs
             next_token_ids: list[int] = []
-            for i, token_ids in enumerate(sampled_token_ids):
-                if token_ids:
+            for i in range(num_reqs):
+                num_sampled = num_sampled_tokens[i]
+                if num_sampled > 0:
                     # Common case.
-                    next_token_id = token_ids[-1]
+                    next_token_id = sampled_token_ids[i][num_sampled - 1]
                 else:
                     # Partial prefill (rare case).
                     # Get the next token id from the request state.
@@ -1844,13 +1850,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 target_hidden_states = hidden_states[:num_scheduled_tokens]
             else:
                 # TODO(woosuk): Refactor this.
-                num_draft_tokens = spec_decode_metadata.num_draft_tokens
-                num_rejected_tokens = [
-                    n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
-                    for i, n in enumerate(num_draft_tokens)
-                ]
-                num_rejected_tokens_cpu = torch.tensor(num_rejected_tokens,
-                                                       dtype=torch.int32)
+                num_draft_tokens = np.asarray(
+                    spec_decode_metadata.num_draft_tokens, dtype=np.int32)
+                num_accepted_tokens = num_sampled_tokens - 1
+                num_rejected_tokens = np.clip(num_draft_tokens -
+                                              num_accepted_tokens,
+                                              a_min=0)
+                num_rejected_tokens_cpu = torch.from_numpy(num_rejected_tokens)
                 common_attn_metadata, token_indices =\
                     self.drafter.prepare_inputs(
                         common_attn_metadata, num_rejected_tokens_cpu)
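
In the EAGLE branch above, the per-request rejected-token count is now computed with array arithmetic instead of a list comprehension: each request samples its accepted drafts plus one bonus token, so num_rejected = num_draft - (num_sampled - 1), clamped at zero. A small worked example with made-up counts (written here with an explicit upper bound for np.clip):

    import numpy as np

    num_draft_tokens = np.array([3, 2, 0], dtype=np.int32)    # drafts proposed per request
    num_sampled_tokens = np.array([2, 3, 1], dtype=np.int32)  # accepted drafts + 1 bonus token

    num_accepted_tokens = num_sampled_tokens - 1               # [1, 2, 0]
    num_rejected_tokens = np.clip(num_draft_tokens - num_accepted_tokens, 0, None)
    print(num_rejected_tokens)  # [2 0 0]
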
@@ -1881,13 +1887,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     def propose_ngram_draft_token_ids(
         self,
-        sampled_token_ids: list[list[int]],
+        num_sampled_tokens: np.ndarray,
     ) -> list[list[int]]:
         # TODO(woosuk): Optimize.
         req_ids = self.input_batch.req_ids
+        num_reqs = self.input_batch.num_reqs
         draft_token_ids: list[list[int]] = []
-        for i, sampled_ids in enumerate(sampled_token_ids):
-            num_sampled_ids = len(sampled_ids)
+        for i in range(num_reqs):
+            num_sampled_ids = num_sampled_tokens[i]
             if not num_sampled_ids:
                 # Skip speculative decoding.
                 draft_token_ids.append([])
@@ -3267,7 +3274,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         return kv_cache_spec
 
-    def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
+    def _to_numpy(self, sampled_token_ids: torch.Tensor) -> np.ndarray:
         # This is a short term mitigation for issue mentioned in
         # https://github.com/vllm-project/vllm/issues/22754.
         # `tolist` would trigger a cuda wise stream sync, which
@@ -3280,4 +3287,4 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         pinned.copy_(sampled_token_ids, non_blocking=True)
         self.transfer_event.record()
         self.transfer_event.synchronize()
-        return pinned.tolist()
+        return pinned.numpy()
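
_to_numpy keeps the same pinned-buffer workaround as the old _to_list: copy the GPU tensor into pre-allocated pinned host memory, wait on a CUDA event instead of calling .tolist() (which would synchronize the whole device), and only then read the result on the CPU. A self-contained sketch of that pattern, assuming a CUDA device is available (buffer names and shapes here are illustrative, not vLLM's actual members):

    import torch

    def gpu_to_numpy(gpu_tensor, pinned, event):
        # Async copy into page-locked host memory, then wait on this one event
        # rather than triggering a device-wide synchronization.
        pinned.copy_(gpu_tensor, non_blocking=True)
        event.record()
        event.synchronize()
        return pinned.numpy()

    if torch.cuda.is_available():
        sampled = torch.randint(0, 32000, (8, 4), device="cuda")
        pinned_buf = torch.empty(sampled.shape, dtype=sampled.dtype,
                                 device="cpu", pin_memory=True)
        event = torch.cuda.Event()
        print(gpu_to_numpy(sampled, pinned_buf, event).shape)  # (8, 4)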