[Hardware][TPU] Fix the recompiling issue in logits processor after warmup (#14510)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
@@ -21,7 +21,9 @@ sampling_params = SamplingParams(temperature=0.7,
 
 # Set `enforce_eager=True` to avoid ahead-of-time compilation.
 # In real workloads, `enforce_eager` should be `False`.
-llm = LLM(model="google/gemma-2b", enforce_eager=True)
+llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
+          max_num_batched_tokens=64,
+          max_num_seqs=4)
 outputs = llm.generate(prompts, sampling_params)
 for output, answer in zip(outputs, answers):
     prompt = output.prompt
@@ -401,6 +401,7 @@ class TPUModelRunner:
         self.query_start_loc_np[0] = 0
         np.cumsum(num_scheduled_tokens_per_req,
                   out=self.query_start_loc_np[1:num_reqs + 1])
+        self.query_start_loc_np[num_reqs + 1:] = 1
 
         self.seq_lens_np[:num_reqs] = (
             self.input_batch.num_computed_tokens_cpu[:num_reqs] +
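Note: the added line pads the tail of query_start_loc_np with 1, so the padded request slots used further down resolve to a fixed, harmless logits index (1 - 1 = 0) instead of whatever was left in the buffer. A minimal standalone sketch of the effect (toy sizes; the real buffer lives on TPUModelRunner and is sized from max_num_reqs):

import numpy as np

# Toy stand-ins for the runner's persistent buffers.
max_num_reqs_padded = 8
num_scheduled_tokens_per_req = np.array([3, 5, 2], dtype=np.int32)
num_reqs = len(num_scheduled_tokens_per_req)

query_start_loc = np.zeros(max_num_reqs_padded + 1, dtype=np.int32)
query_start_loc[0] = 0
np.cumsum(num_scheduled_tokens_per_req, out=query_start_loc[1:num_reqs + 1])
query_start_loc[num_reqs + 1:] = 1   # the line added in this hunk

print(query_start_loc)               # -> [ 0  3  8 10  1  1  1  1  1]
# Padded entries later become logits index 1 - 1 = 0: a dummy gather of
# token 0 rather than a stale or out-of-range index.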
@@ -441,7 +442,10 @@ class TPUModelRunner:
         # partial request, we do so for simplicity. We will ignore the sampled
         # token from the partial request.
         # TODO: Support prompt logprobs.
-        logits_indices = query_start_loc[1:] - 1
+        padded_num_reqs = _get_padded_num_reqs_with_upper_limit(
+            num_reqs, self.max_num_reqs)
+        logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1
+        logits_indices = logits_indices.to(self.device)
         return attn_metadata, logits_indices
 
     def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
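Note: logits_indices is now always sliced to padded_num_reqs, one of a small set of request buckets, and is built in the CPU buffer before a single host-to-device copy. The compiled graph therefore only ever sees a handful of index-tensor lengths, which is what lets the warmup further down cover every shape ahead of time. A rough shape sketch (toy sizes, plain CPU PyTorch; not part of the patch):

import torch

padded_num_tokens, hidden_size, padded_num_reqs = 128, 8, 64
hidden_states = torch.randn(padded_num_tokens, hidden_size)

logits_indices = torch.zeros(padded_num_reqs, dtype=torch.int32)
logits_indices[:3] = torch.tensor([2, 7, 9], dtype=torch.int32)  # last token of each real request

# The gather inside compute_logits keeps a bucketed, step-independent shape.
gathered = hidden_states[logits_indices]
print(gathered.shape)   # torch.Size([64, 8]) for any num_reqs <= 64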
@@ -551,7 +555,6 @@ class TPUModelRunner:
 
         # Prepare inputs
         attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
-        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
 
         if self.is_multimodal_model:
             # NOTE(woosuk): To unify token ids and soft tokens (vision
@@ -579,12 +582,10 @@ class TPUModelRunner:
                 kv_caches=self.kv_caches,
                 inputs_embeds=inputs_embeds,
             )
-        hidden_states = hidden_states[:total_num_scheduled_tokens]
         num_reqs = self.input_batch.num_reqs
-        logits_indices = logits_indices[:num_reqs]
-        hidden_states = hidden_states[logits_indices]
-        logits = self.model.compute_logits(hidden_states, None)
-        selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
+        selected_token_ids = self.model.compute_logits(hidden_states,
+                                                       logits_indices, None)
+        selected_token_ids = selected_token_ids.cpu()[:num_reqs]
 
         # Then, let's update the cache state.
         request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
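Note: execute_model no longer slices hidden states or runs argmax on the host; it hands the padded hidden_states and logits_indices to the wrapper's compiled compute_logits (see the ModelWrapperV1 hunk below) and gets back one token id per padded request slot. Only the first num_reqs rows are meaningful after the copy to host, roughly like this toy stand-in (sizes made up, not part of the patch):

import torch

padded_num_reqs, num_reqs, vocab_size = 64, 3, 32000
# Stand-in for the compiled graph's output: argmax token ids, one per padded slot.
selected_token_ids = torch.randint(0, vocab_size, (padded_num_reqs, 1))
selected_token_ids = selected_token_ids.cpu()[:num_reqs]   # keep only real requests
print(selected_token_ids.shape)                            # torch.Size([3, 1])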
@@ -726,12 +727,31 @@ class TPUModelRunner:
 
         with set_forward_context(attn_metadata, self.vllm_config, 0):
             assert self.model is not None
-            self.model(
+            hidden_states = self.model(
                 input_ids=input_ids,
                 positions=position_ids,
                 kv_caches=kv_caches,
                 inputs_embeds=inputs_embeds,
             )
+        num_reqs = _get_padded_num_reqs_with_upper_limit(
+            64, self.max_num_reqs)
+        # NOTE(chengjiyao): In total, the compute_logits function utilizes a
+        # compilation cache size of token_bucket_num multiplied by
+        # req_bucket_num. This is acceptable, given the graph's relatively
+        # small size.
+        while True:
+            logits_indices = torch.zeros(
+                num_reqs,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            torch._dynamo.mark_dynamic(hidden_states, 0)
+            torch._dynamo.mark_dynamic(logits_indices, 0)
+            self.model.compute_logits(hidden_states, logits_indices, None)
+            if num_reqs >= self.max_num_reqs:
+                break
+            num_reqs = _get_padded_num_reqs_with_upper_limit(
+                num_reqs + 1, self.max_num_reqs)
 
     def capture_model(self) -> None:
         """Compile the model."""
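Note: the warmup loop walks the request buckets from 64 up to self.max_num_reqs and triggers one compute_logits compilation per bucket, so serving-time calls always hit a cached graph. Combined with the token-size buckets used for the backbone, the cache holds roughly token_bucket_num x req_bucket_num graphs, as the NOTE above says. A small sketch of the buckets the loop visits (the helper mirrors _get_padded_num_reqs_with_upper_limit from the end of this diff; the max_num_reqs value is made up):

def bucket(x: int, upper_limit: int) -> int:
    res = 64 if x <= 64 else 1 << (x - 1).bit_length()
    return min(res, upper_limit)

max_num_reqs = 384           # assumed value for illustration
visited, n = [], bucket(64, max_num_reqs)
while True:
    visited.append(n)        # one compute_logits compilation per entry
    if n >= max_num_reqs:
        break
    n = bucket(n + 1, max_num_reqs)

print(visited)               # [64, 128, 256, 384]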
@@ -823,13 +843,17 @@ class ModelWrapperV1(nn.Module):
 
         return hidden_states
 
+    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
+        logits_indices: torch.Tensor,
         sampling_metadata,
     ) -> Optional[torch.Tensor]:
+        hidden_states = hidden_states[logits_indices]
         logits = self.model.compute_logits(hidden_states, sampling_metadata)
-        return logits
+        selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
+        return selected_token_ids
 
     def get_multimodal_embeddings(self, *args, **kwargs):
         return self.model.get_multimodal_embeddings(*args, **kwargs)
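Note: compute_logits is now its own fullgraph, dynamic=False compiled graph, and the warmup's torch._dynamo.mark_dynamic calls make each padded (token count, request count) pair a pre-built specialization, so no recompilation happens after warmup; moving the argmax inside also means only token ids, not full-vocab logits, leave the graph. The toy below mimics that contract with the default torch.compile backend (the real wrapper uses backend="openxla" on TPU; ToyWrapper, lm_head, and the sizes are made up for illustration):

import torch

class ToyWrapper(torch.nn.Module):

    def __init__(self, hidden_size: int = 8, vocab_size: int = 32):
        super().__init__()
        self.lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)

    @torch.compile(fullgraph=True, dynamic=False)
    def compute_logits(self, hidden_states, logits_indices):
        hidden_states = hidden_states[logits_indices]
        logits = self.lm_head(hidden_states)
        return torch.argmax(logits, dim=-1, keepdim=True)

wrapper = ToyWrapper()
for padded_num_reqs in (64, 128):            # one specialization per padded shape
    hidden_states = torch.randn(256, 8)
    logits_indices = torch.zeros(padded_num_reqs, dtype=torch.int32)
    out = wrapper.compute_logits(hidden_states, logits_indices)
    print(out.shape)                         # (64, 1), then (128, 1)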
@@ -846,3 +870,8 @@ def _get_padded_token_len(x: int) -> int:
     if x <= 16:
         return 16
     return 1 << (x - 1).bit_length()
+
+
+def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int:
+    res = 64 if x <= 64 else 1 << (x - 1).bit_length()
+    return min(res, upper_limit)
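Note: the new helper rounds a request count up to a bucket: a floor of 64, then the next power of two, capped at upper_limit. A few hand-checked values (standalone copy of the helper, for illustration only):

def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int:
    res = 64 if x <= 64 else 1 << (x - 1).bit_length()
    return min(res, upper_limit)

assert _get_padded_num_reqs_with_upper_limit(1, 256) == 64     # floor bucket
assert _get_padded_num_reqs_with_upper_limit(64, 256) == 64
assert _get_padded_num_reqs_with_upper_limit(65, 256) == 128   # next power of two
assert _get_padded_num_reqs_with_upper_limit(200, 100) == 100  # capped by upper_limit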