[Hardware][TPU] Fix the recompiling issue in logits processor after warmup (#14510)

Signed-off-by: Chengji Yao <chengjiyao@google.com>
Author:    Chengji Yao
Date:      2025-03-09 01:44:39 -08:00 (committed by GitHub)
Parent:    fb16eea48b
Commit:    212007b168
2 changed files with 41 additions and 10 deletions
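The recompilation being fixed here comes from the logits processor seeing a different number of requests on every step. The sketch below is illustrative only and not part of the diff; the `compiled_shapes` set is a stand-in for XLA's compilation cache. It shows how rounding the request count up to fixed buckets keeps the set of traced shapes small after warmup.

# Illustrative sketch: padding the request count to fixed buckets means a
# shape-specialized compiler only ever sees a handful of sizes.

def padded_num_reqs(x: int, upper_limit: int) -> int:
    # Same rounding rule as the helper added in this commit.
    res = 64 if x <= 64 else 1 << (x - 1).bit_length()
    return min(res, upper_limit)

compiled_shapes = set()  # stand-in for the XLA compilation cache

def run_logits_processor(num_reqs: int) -> None:
    padded = padded_num_reqs(num_reqs, 256)
    if padded not in compiled_shapes:
        compiled_shapes.add(padded)  # a real backend would compile here

for batch_size in range(1, 257):
    run_logits_processor(batch_size)

print(sorted(compiled_shapes))  # [64, 128, 256] -- three graphs, not 256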


@@ -21,7 +21,9 @@ sampling_params = SamplingParams(temperature=0.7,
 # Set `enforce_eager=True` to avoid ahead-of-time compilation.
 # In real workloads, `enforce_eager` should be `False`.
-llm = LLM(model="google/gemma-2b", enforce_eager=True)
+llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
+          max_num_batched_tokens=64,
+          max_num_seqs=4)
 
 outputs = llm.generate(prompts, sampling_params)
 for output, answer in zip(outputs, answers):
     prompt = output.prompt


@@ -401,6 +401,7 @@ class TPUModelRunner:
         self.query_start_loc_np[0] = 0
         np.cumsum(num_scheduled_tokens_per_req,
                   out=self.query_start_loc_np[1:num_reqs + 1])
+        self.query_start_loc_np[num_reqs + 1:] = 1
         self.seq_lens_np[:num_reqs] = (
             self.input_batch.num_computed_tokens_cpu[:num_reqs] +
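The added line pads the unused tail of `query_start_loc_np` with 1 so that, together with the padded slice in the next hunk, the extra `logits_indices` entries resolve to row 0 instead of stale values. A small worked example with hypothetical sizes (3 requests, a buffer padded to 8 request slots), standalone rather than taken from the runner:

# Worked example (hypothetical sizes, standalone): 3 requests scheduling
# 4, 2 and 5 tokens, in a query_start_loc buffer sized for 8 requests.
import numpy as np

num_reqs = 3
num_scheduled_tokens_per_req = np.array([4, 2, 5], dtype=np.int32)

query_start_loc_np = np.empty(8 + 1, dtype=np.int32)
query_start_loc_np[0] = 0
np.cumsum(num_scheduled_tokens_per_req,
          out=query_start_loc_np[1:num_reqs + 1])
# The new padding: unused tail entries become 1 instead of leftover garbage.
query_start_loc_np[num_reqs + 1:] = 1

print(query_start_loc_np)  # [ 0  4  6 11  1  1  1  1  1]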
@@ -441,7 +442,10 @@ class TPUModelRunner:
         # partial request, we do so for simplicity. We will ignore the sampled
         # token from the partial request.
         # TODO: Support prompt logprobs.
-        logits_indices = query_start_loc[1:] - 1
+        padded_num_reqs = _get_padded_num_reqs_with_upper_limit(
+            num_reqs, self.max_num_reqs)
+        logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1
+        logits_indices = logits_indices.to(self.device)
         return attn_metadata, logits_indices
 
     def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
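Because the slice now runs to `padded_num_reqs` rather than `num_reqs`, the length of `logits_indices` only takes bucketed values, and the padded entries point at token 0 thanks to the tail value written above. A standalone sketch with made-up numbers, continuing the example above and assuming `max_num_reqs` is 4 so the bucket caps at 4:

# Standalone sketch (made-up numbers): last-token indices per request, with
# one padded slot falling back to index 0 (from the tail value 1 above).
import torch

query_start_loc_cpu = torch.tensor([0, 4, 6, 11, 1, 1, 1, 1, 1],
                                   dtype=torch.int32)
num_reqs, padded_num_reqs = 3, 4  # bucket assumed to be capped at 4

logits_indices = query_start_loc_cpu[1:padded_num_reqs + 1] - 1
print(logits_indices)  # tensor([ 3,  5, 10,  0], dtype=torch.int32)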
@@ -551,7 +555,6 @@ class TPUModelRunner:
         # Prepare inputs
         attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
-        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
 
         if self.is_multimodal_model:
             # NOTE(woosuk): To unify token ids and soft tokens (vision
@@ -579,12 +582,10 @@ class TPUModelRunner:
                 kv_caches=self.kv_caches,
                 inputs_embeds=inputs_embeds,
             )
-        hidden_states = hidden_states[:total_num_scheduled_tokens]
         num_reqs = self.input_batch.num_reqs
-        logits_indices = logits_indices[:num_reqs]
-        hidden_states = hidden_states[logits_indices]
-        logits = self.model.compute_logits(hidden_states, None)
-        selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
+        selected_token_ids = self.model.compute_logits(hidden_states,
+                                                       logits_indices, None)
+        selected_token_ids = selected_token_ids.cpu()[:num_reqs]
 
         # Then, let's update the cache state.
         request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
@@ -726,12 +727,31 @@ class TPUModelRunner:
         with set_forward_context(attn_metadata, self.vllm_config, 0):
             assert self.model is not None
-            self.model(
+            hidden_states = self.model(
                 input_ids=input_ids,
                 positions=position_ids,
                 kv_caches=kv_caches,
                 inputs_embeds=inputs_embeds,
             )
+        num_reqs = _get_padded_num_reqs_with_upper_limit(
+            64, self.max_num_reqs)
+        # NOTE(chengjiyao): In total, the compute_logits function utilizes a
+        # compilation cache size of token_bucket_num multiplied by
+        # req_bucket_num. This is acceptable, given the graph's relatively
+        # small size.
+        while True:
+            logits_indices = torch.zeros(
+                num_reqs,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            torch._dynamo.mark_dynamic(hidden_states, 0)
+            torch._dynamo.mark_dynamic(logits_indices, 0)
+            self.model.compute_logits(hidden_states, logits_indices, None)
+            if num_reqs >= self.max_num_reqs:
+                break
+            num_reqs = _get_padded_num_reqs_with_upper_limit(
+                num_reqs + 1, self.max_num_reqs)
 
     def capture_model(self) -> None:
         """Compile the model."""
@@ -823,13 +843,17 @@ class ModelWrapperV1(nn.Module):
         return hidden_states
 
+    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
+        logits_indices: torch.Tensor,
         sampling_metadata,
     ) -> Optional[torch.Tensor]:
+        hidden_states = hidden_states[logits_indices]
         logits = self.model.compute_logits(hidden_states, sampling_metadata)
-        return logits
+        selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
+        return selected_token_ids
 
     def get_multimodal_embeddings(self, *args, **kwargs):
         return self.model.get_multimodal_embeddings(*args, **kwargs)
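Folding the gather, projection and argmax into the compiled `compute_logits` keeps everything shape-stable inside one graph and leaves only the selected token ids to copy off the device. A minimal sketch of the same pattern, using the default `torch.compile` backend on CPU and a made-up projection rather than the openxla backend and the real LM head:

# Minimal sketch (CPU, default backend, toy projection): gather the last
# hidden state per request, project to the vocabulary, take the argmax.
import torch

@torch.compile(fullgraph=True, dynamic=False)
def compute_next_tokens(hidden_states: torch.Tensor,
                        logits_indices: torch.Tensor,
                        lm_head: torch.Tensor) -> torch.Tensor:
    hidden = hidden_states[logits_indices]
    logits = hidden @ lm_head
    return torch.argmax(logits, dim=-1, keepdim=True)

hidden_states = torch.randn(16, 8)            # 16 padded tokens, hidden 8
logits_indices = torch.tensor([3, 5, 10, 0])  # padded to a 4-slot bucket
lm_head = torch.randn(8, 32)                  # vocabulary of 32 tokens
print(compute_next_tokens(hidden_states, logits_indices, lm_head).shape)
# torch.Size([4, 1])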
@@ -846,3 +870,8 @@ def _get_padded_token_len(x: int) -> int:
     if x <= 16:
         return 16
     return 1 << (x - 1).bit_length()
+
+
+def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int:
+    res = 64 if x <= 64 else 1 << (x - 1).bit_length()
+    return min(res, upper_limit)