commit 5f95309a6d
parent 286eeb91e8
Author: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Date:   2025-09-07 12:01:45 -07:00

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>


@@ -252,7 +252,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.rejection_sampler = RejectionSampler()
 
         # Request states.
-        self.requests = RequestState(
+        self.req_states = RequestState(
             max_num_reqs=self.max_num_reqs,
             max_model_len=self.max_model_len,
             max_num_batched_tokens=self.max_num_tokens,
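
The object being renamed throughout this commit is a structure-of-arrays container of per-request state: a req_id -> row-index map plus preallocated NumPy buffers (token_ids, num_tokens, num_computed_tokens, ...) indexed by request slot. The sketch below is not vLLM's RequestState class, only a minimal illustration of that layout; _Buf, ReqStatesSketch, and free_slots are invented names, and sampling parameters, GPU mirrors, and multimodal fields are omitted.

# Minimal, illustrative structure-of-arrays request container (not vLLM's).
# Field names mirror the accesses visible in this diff: req_id_to_index,
# token_ids.np, num_tokens.np, num_computed_tokens.np, add_request(), ...
from dataclasses import dataclass

import numpy as np


@dataclass
class _Buf:
    # Tiny stand-in for a CPU buffer that exposes a .np array view.
    np: np.ndarray


class ReqStatesSketch:

    def __init__(self, max_num_reqs: int, max_model_len: int):
        self.req_id_to_index: dict[str, int] = {}
        self.free_slots = list(range(max_num_reqs))
        # One preallocated row (or scalar slot) per request.
        self.token_ids = _Buf(np.zeros((max_num_reqs, max_model_len), np.int32))
        self.num_tokens = _Buf(np.zeros(max_num_reqs, np.int32))
        self.num_prompt_tokens = _Buf(np.zeros(max_num_reqs, np.int32))
        self.num_computed_tokens = _Buf(np.zeros(max_num_reqs, np.int32))

    def add_request(self, req_id: str, prompt_token_ids: list[int],
                    num_computed_tokens: int) -> None:
        idx = self.free_slots.pop()
        self.req_id_to_index[req_id] = idx
        n = len(prompt_token_ids)
        self.token_ids.np[idx, :n] = prompt_token_ids
        self.num_prompt_tokens.np[idx] = n
        self.num_tokens.np[idx] = n
        self.num_computed_tokens.np[idx] = num_computed_tokens

    def append_token_ids(self, req_index: int,
                         new_token_ids: list[int]) -> None:
        start = self.num_tokens.np[req_index]
        end = start + len(new_token_ids)
        self.token_ids.np[req_index, start:end] = new_token_ids
        self.num_tokens.np[req_index] = end

    def remove_request(self, req_id: str) -> None:
        self.free_slots.append(self.req_id_to_index.pop(req_id))

Under a layout like this, the rename is purely mechanical: every self.requests.<field> access in the hunks below becomes self.req_states.<field>, with no behavioral change.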
@@ -431,7 +431,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # distinct requests - clearing the cached states for the first request
         # and handling the second as a new request.
         for req_id in scheduler_output.finished_req_ids:
-            self.requests.remove_request(req_id)
+            self.req_states.remove_request(req_id)
             self.encoder_cache.pop(req_id, None)
 
         # Free the cached encoder outputs.
@@ -450,14 +450,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Add new requests to the cached states.
         for new_req_data in scheduler_output.scheduled_new_reqs:
             req_id = new_req_data.req_id
-            self.requests.add_request(
+            self.req_states.add_request(
                 req_id=req_id,
                 prompt_token_ids=new_req_data.prompt_token_ids,
                 num_computed_tokens=new_req_data.num_computed_tokens,
                 sampling_params=new_req_data.sampling_params,
             )
-            req_index = self.requests.req_id_to_index[req_id]
+            req_index = self.req_states.req_id_to_index[req_id]
             req_indices.append(req_index)
 
             for i, block_ids in enumerate(new_req_data.block_ids):
                 x = cu_num_new_blocks[i][-1]
@@ -473,7 +473,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         is_last_rank = get_pp_group().is_last_rank
         cached_reqs = scheduler_output.scheduled_cached_reqs
         for i, req_id in enumerate(cached_reqs.req_ids):
-            req_index = self.requests.req_id_to_index[req_id]
+            req_index = self.req_states.req_id_to_index[req_id]
 
             # Update input batch.
             if not is_last_rank:
@@ -481,7 +481,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 # because there's no direct communication between the first-
                 # stage worker and the last-stage worker.
                 new_token_ids = cached_reqs.new_token_ids[i]
-                self.requests.append_token_ids(req_index, new_token_ids)
+                self.req_states.append_token_ids(req_index, new_token_ids)
 
             req_new_block_ids = cached_reqs.new_block_ids[i]
             if req_new_block_ids is not None:
@@ -494,7 +494,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 # overwrite the existing block IDs.
                 overwrite.append(cached_reqs.resumed_from_preemption[i])
 
-            self.requests.num_computed_tokens.np[req_index] = (
+            self.req_states.num_computed_tokens.np[req_index] = (
                 cached_reqs.num_computed_tokens[i])
 
         if req_indices:
@@ -506,10 +506,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             )
 
     def _init_mrope_positions(self, req_id: str) -> None:
-        req_idx = self.requests.req_id_to_index[req_id]
-        req_data = self.requests.req_data[req_idx]
-        prompt_len = self.requests.num_prompt_tokens.np[req_idx]
-        prompt_token_ids = self.requests.token_ids.np[req_idx, :prompt_len]
+        req_idx = self.req_states.req_id_to_index[req_id]
+        req_data = self.req_states.req_data[req_idx]
+        prompt_len = self.req_states.num_prompt_tokens.np[req_idx]
+        prompt_token_ids = self.req_states.token_ids.np[req_idx, :prompt_len]
 
         image_grid_thw = []
         video_grid_thw = []
@@ -585,7 +585,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             key=scheduler_output.num_scheduled_tokens.get)
         # batch_idx -> req_idx
         idx_mapping_list = [
-            self.requests.req_id_to_index[req_id] for req_id in req_ids
+            self.req_states.req_id_to_index[req_id] for req_id in req_ids
         ]
         self.idx_mapping.np[:num_reqs] = idx_mapping_list
         idx_mapping_np = self.idx_mapping.np[:num_reqs]
@@ -604,8 +604,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         prepare_inputs(
             idx_mapping_np,
-            self.requests.token_ids.np,
-            self.requests.num_computed_tokens.np,
+            self.req_states.token_ids.np,
+            self.req_states.num_computed_tokens.np,
             num_scheduled_tokens,
             self.input_ids.np,
             self.query_start_loc.np,
@@ -661,7 +661,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Used in the below loop.
         query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
         seq_lens_cpu = self.seq_lens.cpu[:num_reqs]
-        num_computed_tokens_np = self.requests.num_computed_tokens.np[
+        num_computed_tokens_np = self.req_states.num_computed_tokens.np[
             idx_mapping_np]
         num_computed_tokens_cpu = torch.from_numpy(num_computed_tokens_np)
         spec_decode_common_attn_metadata = None
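
For context on the gather above: indexing the per-slot NumPy array with idx_mapping_np reorders the persistent per-request values into batch order in one vectorized step, and torch.from_numpy then wraps the result without another copy. A self-contained illustration with made-up values (the variable names are borrowed from the hunk, the numbers are not):

# Illustrative only: gather per-request counters into batch order.
import numpy as np
import torch

# One slot per request in the persistent state (request slots 0..3).
num_computed_tokens_all = np.array([7, 0, 42, 3], dtype=np.int32)
# Batch position -> request slot for the requests scheduled this step.
idx_mapping_np = np.array([2, 0], dtype=np.int32)

# Fancy indexing copies the selected slots into a new batch-ordered array.
num_computed_tokens_np = num_computed_tokens_all[idx_mapping_np]   # [42, 7]

# torch.from_numpy shares memory with that array (no additional copy).
num_computed_tokens_cpu = torch.from_numpy(num_computed_tokens_np)
print(num_computed_tokens_np, num_computed_tokens_cpu)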
@@ -865,7 +865,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             dtype=np.int32,
             out=self.cu_num_draft_tokens.np[:num_reqs])
         cu_num_draft_tokens = self.cu_num_draft_tokens.copy_to_gpu(num_reqs)
-        return self.requests.make_spec_decode_metadata(
+        return self.req_states.make_spec_decode_metadata(
             query_start_loc,
             cu_num_draft_tokens,
             cu_num_draft_tokens.np[:num_reqs],
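
The buffer filled just above is the usual prefix-sum layout for ragged per-request draft tokens: entry i is the end offset of request i's drafts in the flattened draft-token buffer. A small sketch of the same NumPy idiom, writing the running totals in place into a preallocated int32 buffer (sizes and counts are invented):

# Illustrative only: prefix sums over per-request draft-token counts.
import numpy as np

MAX_NUM_REQS = 8
cu_num_draft_tokens = np.zeros(MAX_NUM_REQS, dtype=np.int32)  # preallocated

num_draft_tokens = np.array([2, 0, 3], dtype=np.int32)  # per-request counts
num_reqs = len(num_draft_tokens)

# Cumulative sum written directly into the first num_reqs slots.
np.cumsum(num_draft_tokens, dtype=np.int32,
          out=cu_num_draft_tokens[:num_reqs])

print(cu_num_draft_tokens[:num_reqs])  # [2 2 5]
# Request i's drafts occupy [cu[i - 1] if i else 0, cu[i]) in the flat buffer.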
@@ -879,13 +879,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     ):
         mrope_pos_ptr = 0
         for i, req_id in enumerate(req_ids):
-            req_idx = self.requests.req_id_to_index[req_id]
-            req_data = self.requests.req_data[req_idx]
+            req_idx = self.req_states.req_id_to_index[req_id]
+            req_data = self.req_states.req_data[req_idx]
             assert req_data.mrope_positions is not None
 
-            num_computed_tokens = self.requests.num_computed_tokens.np[req_idx]
+            num_computed_tokens = self.req_states.num_computed_tokens.np[req_idx]
             num_scheduled_tokens = query_lens[i]
-            num_prompt_tokens = self.requests.num_prompt_tokens.np[req_idx]
+            num_prompt_tokens = self.req_states.num_prompt_tokens.np[req_idx]
 
             if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
                 prompt_part_len = max(0,
@@ -959,8 +959,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # list of tuple (mm_hash, position_info)
         mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
         for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
-            req_idx = self.requests.req_id_to_index[req_id]
-            req_data = self.requests.req_data[req_idx]
+            req_idx = self.req_states.req_id_to_index[req_id]
+            req_data = self.req_states.req_data[req_idx]
 
             for mm_input_id in encoder_input_ids:
                 mm_hash = req_data.mm_hashes[mm_input_id]
@@ -1014,11 +1014,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         mm_embeds: list[torch.Tensor] = []
         for i, req_id in enumerate(input_batch.req_ids):
             num_scheduled_tokens = input_batch.num_scheduled_tokens[i]
-            req_idx = self.requests.req_id_to_index[req_id]
+            req_idx = self.req_states.req_id_to_index[req_id]
             num_computed_tokens = (
-                self.requests.num_computed_tokens.np[req_idx] +
+                self.req_states.num_computed_tokens.np[req_idx] +
                 shift_computed_tokens)
-            req_data = self.requests.req_data[req_idx]
+            req_data = self.req_states.req_data[req_idx]
             mm_positions = req_data.mm_positions
             mm_hashes = req_data.mm_hashes
             for i, pos_info in enumerate(mm_positions):
@@ -1135,7 +1135,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # request in the batch, as the logit indices are offset by this amount.
         struct_out_req_batch_indices: dict[str, int] = {}
         cumulative_offset = 0
-        seq = sorted(self.requests.req_id_to_index.items(), key=lambda x: x[1])
+        seq = sorted(self.req_states.req_id_to_index.items(), key=lambda x: x[1])
         for req_id, batch_index in seq:
             logit_index = batch_index + cumulative_offset
             cumulative_offset += len(
@@ -1458,7 +1458,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         spec_decode_metadata: Optional[SpecDecodeMetadata]
     ) -> SamplerOutput:
         # Sample the next token and get logprobs if needed.
-        sampling_metadata = self.requests.make_sampling_metadata(
+        sampling_metadata = self.req_states.make_sampling_metadata(
             input_batch.idx_mapping)
         if input_batch.spec_decode_metadata is None:
             sampler_output = self.sampler(
@@ -1517,14 +1517,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # the requests one by one. Optimize.
         discard_sampled_tokens_req_indices: list[int] = []
         for i, req_id in enumerate(input_batch.req_ids):
-            req_idx = self.requests.req_id_to_index[req_id]
-            seq_len = (self.requests.num_computed_tokens.np[req_idx] +
+            req_idx = self.req_states.req_id_to_index[req_id]
+            seq_len = (self.req_states.num_computed_tokens.np[req_idx] +
                        input_batch.num_scheduled_tokens[i])
-            if seq_len < self.requests.num_tokens.np[req_idx]:
+            if seq_len < self.req_states.num_tokens.np[req_idx]:
                 # Ignore the sampled token for partial prefills.
                 # Rewind the generator state as if the token was not sampled.
                 # This relies on cuda-specific torch-internal impl details
-                generator = self.requests.generators.get(req_idx)
+                generator = self.req_states.generators.get(req_idx)
                 if generator is not None:
                     generator.set_offset(generator.get_offset() - 4)
                 # Record the index of the request that should not be sampled,
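
A note on the rewind above: torch.Generator.get_offset() / set_offset() read and write how far a CUDA generator has advanced its Philox stream, so stepping the offset back makes the next draw reuse the same stream position and the discarded sample leaves no trace in the RNG state. The snippet below is a stand-alone check of that behavior, not vLLM code; it assumes a CUDA device is available, and the "- 4" per sampled token used above is a sampler implementation detail that is not reproduced here.

# Illustrative only: rewinding a CUDA generator's Philox offset.
import torch

assert torch.cuda.is_available(), "this sketch needs a CUDA device"

gen = torch.Generator(device="cuda").manual_seed(0)

before = gen.get_offset()
first = torch.rand(4, device="cuda", generator=gen)
after = gen.get_offset()

# Rewind by exactly what the draw consumed, then draw again.
gen.set_offset(before)
again = torch.rand(4, device="cuda", generator=gen)

print(after - before)             # Philox offset consumed by the first draw
print(torch.equal(first, again))  # True: same offset, same samples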
@@ -1584,18 +1584,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             sampled_ids = valid_sampled_token_ids[i]
             if not sampled_ids:
                 continue
-            req_idx = self.requests.req_id_to_index[req_id]
+            req_idx = self.req_states.req_id_to_index[req_id]
 
-            start_idx = self.requests.num_tokens.np[req_idx]
+            start_idx = self.req_states.num_tokens.np[req_idx]
             end_idx = start_idx + len(sampled_ids)
             assert end_idx <= self.max_model_len, (
                 "Sampled token IDs exceed the max model length. "
                 f"Total number of tokens: {end_idx} > max_model_len: "
                 f"{self.max_model_len}")
 
-            self.requests.token_ids.np[req_idx,
+            self.req_states.token_ids.np[req_idx,
                                        start_idx:end_idx] = sampled_ids
-            self.requests.num_tokens.np[req_idx] = end_idx
+            self.req_states.num_tokens.np[req_idx] = end_idx
 
         if self.speculative_config:
             assert input_batch.spec_decode_common_attn_metadata is not None
@@ -1770,7 +1770,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 # Partial prefill (rare case).
                 # Get the next token id from the request state.
                 req_id = req_ids[i]
-                req_state = self.requests[req_id]
+                req_state = self.req_states[req_id]
                 seq_len = (req_state.num_computed_tokens +
                            input_batch.num_scheduled_tokens[i])
                 next_token_id = req_state.get_token_id(seq_len)
@@ -1850,14 +1850,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 # draft_token_ids.append([])
                 # continue
 
-            num_tokens = self.requests.num_tokens.np[i]
+            num_tokens = self.req_states.num_tokens.np[i]
             if num_tokens >= self.max_model_len:
                 # Skip requests that have already reached the max model length.
                 draft_token_ids.append([])
                 continue
 
             drafter_output = self.drafter.propose(
-                self.requests.token_ids.np[i, :num_tokens])
+                self.req_states.token_ids.np[i, :num_tokens])
             if drafter_output is None or len(drafter_output) == 0:
                 draft_token_ids.append([])
             else:
@@ -1995,11 +1995,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         hidden_states: torch.Tensor,
         num_scheduled_tokens: dict[str, int],
     ) -> dict[str, Optional[LogprobsTensors]]:
-        num_prompt_logprobs_dict = self.requests.num_prompt_logprobs
+        num_prompt_logprobs_dict = self.req_states.num_prompt_logprobs
         if not num_prompt_logprobs_dict:
             return {}
 
-        in_progress_dict = self.requests.in_progress_prompt_logprobs_cpu
+        in_progress_dict = self.req_states.in_progress_prompt_logprobs_cpu
         prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}
 
         # Since prompt logprobs are a rare feature, prioritize simple,
@@ -2009,7 +2009,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             num_tokens = num_scheduled_tokens[req_id]
 
             # Get metadata for this request.
-            request = self.requests[req_id]
+            request = self.req_states[req_id]
             num_prompt_tokens = len(request.prompt_token_ids)
             prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
                 self.device, non_blocking=True)
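
One practical detail about the transfer above: .to(self.device, non_blocking=True) only overlaps with GPU work when the source tensor lives in pinned (page-locked) host memory; for a freshly built pageable tensor the copy is effectively synchronous. A hedged illustration of the difference (this is general PyTorch behavior, not something changed by this commit):

# Illustrative only: non_blocking H2D copies need pinned host memory to overlap.
import torch

assert torch.cuda.is_available(), "this sketch needs a CUDA device"

ids = torch.tensor([1, 2, 3, 4], dtype=torch.int64)  # pageable host memory
ids_pinned = ids.pin_memory()                        # page-locked copy of it

# Pageable source: non_blocking=True is accepted but the copy cannot overlap.
a = ids.to("cuda", non_blocking=True)
# Pinned source: the copy is queued on the current stream and can overlap.
b = ids_pinned.to("cuda", non_blocking=True)
torch.cuda.synchronize()  # make sure both copies have finished before use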
@@ -2255,7 +2255,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             1],
             seq_lens=self.seq_lens.gpu[:num_reqs],
             seq_lens_cpu=self.seq_lens.cpu[:num_reqs],
-            num_computed_tokens_cpu=self.requests.num_computed_tokens.
+            num_computed_tokens_cpu=self.req_states.num_computed_tokens.
             cpu[:num_reqs],
             num_reqs=num_reqs,
             num_actual_tokens=num_tokens,