[Core] [Bugfix] Add Input Embeddings (#15428)
Signed-off-by: Andrew Sansom <andrew@protopia.ai>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: 临景 <linjing.yx@alibaba-inc.com>
Co-authored-by: Bryce1010 <bryceyx@gmail.com>
Co-authored-by: Nan2018 <nan@protopia.ai>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -787,7 +787,7 @@ class VllmRunner:

     def get_inputs(
         self,
-        prompts: list[str],
+        prompts: Union[list[str], list[torch.Tensor]],
         images: Optional[PromptImageInput] = None,
         videos: Optional[PromptVideoInput] = None,
         audios: Optional[PromptAudioInput] = None,
@@ -809,16 +809,18 @@ class VllmRunner:
             if audios is not None and (audio := audios[i]) is not None:
                 multi_modal_data["audio"] = audio

-            inputs.append(
-                TextPrompt(prompt=prompt,
-                           multi_modal_data=multi_modal_data
-                           if multi_modal_data else None))
+            text_prompt_kwargs = {
+                ("prompt" if isinstance(prompt, str) else "prompt_embeds"):
+                prompt,
+                "multi_modal_data": multi_modal_data or None
+            }
+            inputs.append(TextPrompt(**text_prompt_kwargs))

         return inputs

     def generate(
         self,
-        prompts: list[str],
+        prompts: Union[list[str], list[torch.Tensor]],
         sampling_params: SamplingParams,
         images: Optional[PromptImageInput] = None,
         videos: Optional[PromptVideoInput] = None,
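The kwargs dispatch above is the crux of the change: the same helper now routes a string to TextPrompt(prompt=...) and a tensor to TextPrompt(prompt_embeds=...). A minimal sketch of that selection logic in isolation (the tensor shape and hidden size are illustrative assumptions, not taken from this diff):

    import torch

    def pick_text_prompt_kwargs(prompt):
        # Mirrors the dispatch in get_inputs(): strings stay "prompt",
        # tensors become "prompt_embeds".
        key = "prompt" if isinstance(prompt, str) else "prompt_embeds"
        return {key: prompt, "multi_modal_data": None}

    assert pick_text_prompt_kwargs("hi") == {"prompt": "hi",
                                             "multi_modal_data": None}
    embeds = torch.rand(7, 4096)  # (seq_len, hidden_size) is an assumption
    assert pick_text_prompt_kwargs(embeds)["prompt_embeds"] is embeds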
@@ -844,7 +846,7 @@ class VllmRunner:
                 output_str = sample.text
                 output_ids = list(sample.token_ids)
                 req_sample_output_ids.append(prompt_ids + output_ids)
-                req_sample_output_strs.append(prompt_str + output_str)
+                req_sample_output_strs.append((prompt_str or "") + output_str)
             outputs.append((req_sample_output_ids, req_sample_output_strs))
         return outputs
@@ -911,7 +913,7 @@ class VllmRunner:

     def generate_greedy(
         self,
-        prompts: list[str],
+        prompts: Union[list[str], list[torch.Tensor]],
         max_tokens: int,
         images: Optional[PromptImageInput] = None,
         videos: Optional[PromptVideoInput] = None,
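With the widened signatures, the test helpers accept tensor prompts end to end. A hedged usage sketch, assuming a test that has the vllm_runner fixture available (the model choice and embedding shape are illustrative):

    import torch

    # Prompt embeddings shaped (seq_len, hidden_size); values are dummies.
    embeds_prompts = [torch.rand(5, 768) for _ in range(2)]
    with vllm_runner("facebook/opt-125m") as vllm_model:
        outputs = vllm_model.generate_greedy(embeds_prompts, max_tokens=16)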
@@ -2,16 +2,17 @@

 import time
 from collections import deque
+from typing import Optional
 from unittest.mock import MagicMock

 import pytest  # noqa
+import torch

 from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
 from vllm.core.interfaces import AllocStatus
 from vllm.core.scheduler import Scheduler, SchedulingBudget
 from vllm.lora.request import LoRARequest
-from vllm.sequence import SequenceGroup
+from vllm.sequence import SequenceGroup, SequenceStatus

 from .utils import (append_new_token, append_new_token_seq,
                     append_new_token_seq_group, create_dummy_prompt,
@@ -968,3 +970,73 @@ def test_no_multiple_partial_prefills_with_chunked_prefill_and_prefix_caching(
     ), "A partial prefix of C (4 tokens) should be prefilled, with the "
     "remaining tokens fit into 3 token budget (4-1 from the seqA). It will "
     "then be rounded down to 2 tokens on block size, thus 6 tokens in total."
+
+
+def test_no_batches_mixed_with_prompt_tokens_and_prompt_embeds():
+    """
+    Test that the scheduler does not schedule batches with prompt tokens and
+    prompt embeddings co-mingled.
+    """
+    block_size = 2
+    max_seq_group = 3
+    scheduler = initialize_scheduler(
+        block_size=block_size,
+        num_cpu_blocks=16,
+        num_gpu_blocks=16,
+        max_num_seqs=max_seq_group,
+        max_model_len=100,
+        enable_prefix_caching=True,
+    )
+
+    # The even-indexed inputs are passed in via embeddings, the odd-indexed
+    # ones via token IDs.
+    seq_length = 7
+    embedding_size = 5
+    num_seqs = 11
+    seq_tokens: list[list[int]] = []
+    seq_embeds: list[Optional[torch.Tensor]] = []
+    for i in range(num_seqs):
+        if i % 2:
+            seq_tokens.append(list(range(seq_length)))
+            seq_embeds.append(None)
+        else:
+            seq_tokens.append([0] * seq_length)
+            seq_embeds.append(torch.rand(embedding_size))
+
+    seq_and_seq_groups = [
+        create_dummy_prompt(f"{i}",
+                            prompt_tokens=seq_tokens[i],
+                            prompt_embeds=seq_embeds[i],
+                            block_size=block_size)
+        for i in range(len(seq_tokens))
+    ]
+
+    for _, seq_group in seq_and_seq_groups:
+        scheduler.add_seq_group(seq_group)
+
+    while not all(seq.is_finished() for seq, _ in seq_and_seq_groups):
+        unfinished_seq_groups = [
+            seq_group for _, seq_group in seq_and_seq_groups
+            if not seq_group.is_finished()
+        ]
+        _, out = schedule_and_update_computed_tokens(scheduler)
+        assert len(out.scheduled_seq_groups) > 0
+        batch_is_prompt_embeds = out.scheduled_seq_groups[
+            0].seq_group.uses_prompt_embeds()
+        expected_scheduled_seq_groups = [
+            seq_group for seq_group in unfinished_seq_groups
+            if seq_group.uses_prompt_embeds() == batch_is_prompt_embeds
+        ]
+
+        # We should have as many scheduled groups as possible, without mixing
+        assert len(out.scheduled_seq_groups) == min(
+            max_seq_group, len(expected_scheduled_seq_groups))
+        assert all(scheduled_seq_group.seq_group.uses_prompt_embeds() ==
+                   batch_is_prompt_embeds
+                   for scheduled_seq_group in out.scheduled_seq_groups)
+
+        # Finish the scheduled groups
+        for scheduled_seq_group in out.scheduled_seq_groups:
+            for seq in scheduled_seq_group.seq_group.seqs:
+                seq.status = SequenceStatus.FINISHED_STOPPED
+        scheduler.free_finished_seq_groups()
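The scheduling invariant this test pins down is that every batch is homogeneous: either all scheduled groups use prompt embeddings or none does. A small helper expressing that check (our own sketch, not code from the diff):

    def batch_is_homogeneous(scheduled_seq_groups) -> bool:
        # All groups in one batch must agree on uses_prompt_embeds().
        flags = {
            scheduled.seq_group.uses_prompt_embeds()
            for scheduled in scheduled_seq_groups
        }
        return len(flags) <= 1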
@@ -5,9 +5,11 @@ from collections import defaultdict
 from collections.abc import Sequence as GenericSequence
 from typing import Any, Optional

+import torch
+
 from vllm import SamplingParams
 from vllm.core.scheduler import Scheduler, SchedulerOutputs
-from vllm.inputs import EncoderDecoderInputs, token_inputs
+from vllm.inputs import EncoderDecoderInputs, embeds_inputs, token_inputs
 from vllm.lora.request import LoRARequest
 from vllm.sequence import (Logprob, Sequence, SequenceGroup,
                            SequenceGroupMetadata)
@@ -19,6 +21,7 @@ def create_dummy_prompt(
     block_size: Optional[int] = None,
     lora_request: Optional[LoRARequest] = None,
     prompt_tokens: Optional[list[int]] = None,
+    prompt_embeds: Optional[torch.Tensor] = None,
     min_tokens: int = 0,
     max_tokens: int = 16,
 ) -> tuple[Sequence, SequenceGroup]:
@@ -31,9 +34,13 @@ def create_dummy_prompt(
         prompt_tokens = list(range(prompt_length))

     prompt_str = " ".join([str(t) for t in prompt_tokens])
+    inputs = token_inputs(
+        prompt_token_ids=prompt_tokens,
+        prompt=prompt_str) if prompt_embeds is None else embeds_inputs(
+            prompt_embeds=prompt_embeds)
     prompt = Sequence(
         int(request_id),
-        inputs=token_inputs(prompt_tokens, prompt=prompt_str),
+        inputs=inputs,
         block_size=block_size,
     )
     seq_group = SequenceGroup(
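Call sites can now build either input type from the same helper. These two invocations (values are illustrative) mirror how the scheduler test above constructs its sequences:

    import torch

    # Token-ID-backed dummy prompt (unchanged behavior):
    seq, seq_group = create_dummy_prompt("0", prompt_tokens=[0, 1, 2],
                                         block_size=2)
    # Embedding-backed dummy prompt (new path via embeds_inputs):
    seq, seq_group = create_dummy_prompt("1", prompt_tokens=[0, 0, 0],
                                         prompt_embeds=torch.rand(5),
                                         block_size=2)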
@@ -1,4 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
+import os
+from typing import Optional

 import pytest
+import torch

@@ -110,6 +113,18 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
         hf_outputs = hf_model.generate_greedy_logprobs_limit(
             example_prompts, max_tokens, num_logprobs)

+        prompt_embeds: Optional[list[torch.Tensor]] = [] if os.getenv(
+            "VLLM_USE_V1") == "0" else None
+        prompt_token_ids = []
+        for prompt in example_prompts:
+            token_ids = hf_model.tokenizer(prompt,
+                                           return_tensors="pt").input_ids.to(
+                                               hf_model.model.device)
+            prompt_token_ids.append(token_ids)
+            if prompt_embeds is not None:
+                prompt_embeds.append(hf_model.model.get_input_embeddings()(
+                    token_ids).squeeze(0))
+
     with vllm_runner(
         model,
         tokenizer_name=model_info.tokenizer or model,
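The loop above shows the general recipe for producing reference prompt embeddings: tokenize with the HF tokenizer, then pass the token IDs through the model's input embedding layer and drop the batch dimension. A standalone sketch of the same recipe (the model choice is illustrative):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    name = "facebook/opt-125m"  # illustrative; any causal LM works
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModelForCausalLM.from_pretrained(name)

    token_ids = tokenizer("Hello, world!", return_tensors="pt").input_ids
    with torch.no_grad():
        # (1, seq_len, hidden) -> (seq_len, hidden), matching the diff above
        prompt_embeds = model.get_input_embeddings()(token_ids).squeeze(0)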
@@ -119,6 +134,9 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
     ) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy_logprobs(
             example_prompts, max_tokens, num_logprobs)
+        if prompt_embeds is not None:
+            vllm_outputs_from_embeds = vllm_model.generate_greedy_logprobs(
+                prompt_embeds, max_tokens, num_logprobs)

     check_logprobs_close(
         outputs_0_lst=hf_outputs,
@@ -126,6 +144,14 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
         name_0="hf",
         name_1="vllm",
     )
+    if prompt_embeds is not None:
+        check_logprobs_close(
+            outputs_0_lst=vllm_outputs,
+            outputs_1_lst=vllm_outputs_from_embeds,
+            name_0="vllm",
+            name_1="vllm_from_embeds",
+        )

     if use_rocm_aiter:
         # this is to ensure that vllm engine
         # has deallocated the memory before running the next
@@ -31,8 +31,13 @@ def test_deepseek_mla_attn_backend_module():
     assert model_runner.attn_backend.__name__ == "TritonMLABackend"


-@pytest.mark.parametrize("batch_size", list(range(1, 257)))
-def test_prepare_prompt(batch_size):
+@pytest.mark.parametrize("batch_size", list(range(1, 257, 3)))
+@pytest.mark.parametrize("use_prompt_embeds", [True, False])
+def test_prepare_prompt(batch_size, use_prompt_embeds, monkeypatch):
+    if use_prompt_embeds:
+        # Prompt embeddings are currently only supported on V0.
+        monkeypatch.setenv("VLLM_USE_V1", "0")
+
     model_runner = _create_model_runner(
         "facebook/opt-125m",
         max_num_batched_tokens=100000,
@@ -43,11 +48,20 @@ def test_prepare_prompt(batch_size):
     seq_lens: list[int] = []
     seq_group_metadata_list: list[SequenceGroupMetadata] = []
     block_tables = {0: [1]}
+    expected_input_embeds_len = 0
     for i in range(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
         seq_lens.append(seq_len)
-        seq_data = SequenceData.from_seqs(range(seq_len))
+        if use_prompt_embeds:
+            seq_data = SequenceData.from_seqs(
+                prompt_token_ids=[0] * seq_len,
+                prompt_embeds=torch.rand(seq_len, 10),
+            )
+            expected_input_embeds_len += seq_len
+        else:
+            seq_data = SequenceData.from_seqs(prompt_token_ids=range(seq_len))
+
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
@@ -68,6 +82,7 @@ def test_prepare_prompt(batch_size):
         seq_group_metadata_list)
     input_tokens = model_input.input_tokens
     input_positions = model_input.input_positions
+    input_embeds = model_input.inputs_embeds
     attn_metadata = model_input.attn_metadata
     return_seq_lens = model_input.seq_lens
     slot_mapping = attn_metadata.slot_mapping
@@ -121,7 +136,11 @@ def test_prepare_prompt(batch_size):

     assert len(input_tokens) == sum(seq_lens)
     assert len(input_positions) == sum(seq_lens)
-    torch.testing.assert_close(input_tokens, input_positions)
+    if expected_input_embeds_len == 0:
+        torch.testing.assert_close(input_tokens, input_positions)
+        assert input_embeds is None
+    else:
+        assert len(input_embeds) == expected_input_embeds_len

     sampling_metadata = SamplingMetadata.prepare(
         seq_group_metadata_list,
@@ -145,8 +164,13 @@ def test_prepare_prompt(batch_size):
     torch.testing.assert_close(actual, expected)


-@pytest.mark.parametrize("batch_size", list(range(1, 257)))
-def test_prepare_decode_cuda_graph(batch_size):
+@pytest.mark.parametrize("batch_size", list(range(1, 257, 3)))
+@pytest.mark.parametrize("use_prompt_embeds", [True, False])
+def test_prepare_decode_cuda_graph(batch_size, use_prompt_embeds, monkeypatch):
+    if use_prompt_embeds:
+        # Prompt embeddings are currently only supported on V0.
+        monkeypatch.setenv("VLLM_USE_V1", "0")
+
     model_runner = _create_model_runner(
         "facebook/opt-125m",
         seed=0,
@@ -164,10 +188,19 @@ def test_prepare_decode_cuda_graph(batch_size):
         # make sure all tokens fit into one block
         context_len = i % (model_runner.block_size - 1) + 1
         context_lens.append(context_len)
-        seq_data = SequenceData.from_seqs(range(context_len))
+        if use_prompt_embeds:
+            seq_data = SequenceData.from_seqs(
+                prompt_token_ids=[0] * context_len,
+                prompt_embeds=torch.rand(context_len, 10),
+            )
+            output_embed = torch.rand(10)
+        else:
+            seq_data = SequenceData.from_seqs(
+                prompt_token_ids=range(context_len))
+            output_embed = None
         seq_data.update_num_computed_tokens(context_len)
         # Append one token ID since prefill is finished.
-        seq_data.append_token_id(1, 0)
+        seq_data.append_token_id(1, 0, output_embed)
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=False,
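Note the third argument to append_token_id: once a sequence runs on prompt embeddings, each sampled token must carry its own embedding, since there are no token IDs to re-embed on the next decode step. A minimal sketch of that bookkeeping, reusing the dummy hidden size of 10 from the diff above:

    import torch

    seq_data = SequenceData.from_seqs(
        prompt_token_ids=[0] * 4,
        prompt_embeds=torch.rand(4, 10),
    )
    seq_data.update_num_computed_tokens(4)
    # token_id=1, logprob=0, plus the new token's embedding:
    seq_data.append_token_id(1, 0, torch.rand(10))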
@@ -180,9 +213,12 @@ def test_prepare_decode_cuda_graph(batch_size):

     model_input = model_runner._prepare_model_input_tensors(
         seq_group_metadata_list)
-    input_tokens, input_positions, attn_metadata, slot_mapping = (
-        model_input.input_tokens, model_input.input_positions,
-        model_input.attn_metadata, model_input.attn_metadata.slot_mapping)
+    input_tokens = model_input.input_tokens
+    input_positions = model_input.input_positions
+    input_embeds = model_input.inputs_embeds
+    attn_metadata = model_input.attn_metadata
+    slot_mapping = attn_metadata.slot_mapping

     assert len(slot_mapping) == len(input_tokens)

     expected_bs = model_runner.vllm_config.pad_for_cudagraph(
@@ -227,7 +263,7 @@ def test_prepare_decode_cuda_graph(batch_size):
     # block table's first index corresponds to each batch, meaning in
     # decoding it is each token.
     assert attn_metadata.block_tables.shape[0] == len(input_tokens)
-    # Block table's second dim correspondsd to each token's block number.
+    # Block table's second dim corresponds to each token's block number.
     # It is padded up to
     assert attn_metadata.block_tables.shape[1] == (
         model_runner.get_max_block_per_batch())
@@ -235,7 +271,12 @@ def test_prepare_decode_cuda_graph(batch_size):

     assert len(input_tokens) == expected_bs
     assert len(input_positions) == expected_bs
-    torch.allclose(input_tokens, input_positions)
+    if use_prompt_embeds:
+        expected_input_embeds_length = start_loc[-1]
+        assert len(input_embeds) == expected_input_embeds_length
+        assert expected_input_embeds_length <= expected_bs
+    else:
+        assert input_embeds is None

     # Verify Sampling
     expected_selected_token_indices = []
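In the check above, start_loc[-1] works because cumulative sequence start offsets end at the total token count of the batch, which is exactly how many embeddings were gathered before CUDA-graph padding brings the batch up to expected_bs. A toy illustration of that identity (not code from the diff):

    seq_lens = [1, 2, 3]
    start_loc = [0]
    for n in seq_lens:
        start_loc.append(start_loc[-1] + n)
    assert start_loc[-1] == sum(seq_lens) == 6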
@@ -266,25 +307,27 @@ def test_empty_seq_group():
     seq_group_metadata_list: list[SequenceGroupMetadata] = []
     model_input = model_runner._prepare_model_input_tensors(
         seq_group_metadata_list)
-    input_tokens, input_positions, attn_metadata = (
-        model_input.input_tokens,
-        model_input.input_positions,
-        model_input.attn_metadata,
-    )
+
+    input_tokens = model_input.input_tokens
+    input_positions = model_input.input_positions
+    attn_metadata = model_input.attn_metadata
+
     assert input_tokens is None
     assert input_positions is None
     assert attn_metadata is None

     model_input = model_runner._prepare_model_input_tensors(
         seq_group_metadata_list)
-    (input_tokens, input_positions, attn_metadata, return_seq_lens) = (
-        model_input.input_tokens,
-        model_input.input_positions,
-        model_input.attn_metadata,
-        model_input.seq_lens,
-    )
+
+    input_tokens = model_input.input_tokens
+    input_positions = model_input.input_positions
+    input_embeds = model_input.inputs_embeds
+    attn_metadata = model_input.attn_metadata
+    return_seq_lens = model_input.seq_lens
+
     assert input_tokens is None
     assert input_positions is None
+    assert input_embeds is None
     assert attn_metadata is None
     assert return_seq_lens is None
@@ -299,9 +342,15 @@ def distributed_init():
     ensure_model_parallel_initialized(1, 1)


-@pytest.mark.parametrize("batch_size", list(range(2, 128)))
+@pytest.mark.parametrize("batch_size", list(range(2, 128, 3)))
 @pytest.mark.parametrize("enforce_eager", [True, False])
-def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
+@pytest.mark.parametrize('use_prompt_embeds', [True, False])
+def test_hybrid_batches(batch_size, enforce_eager, use_prompt_embeds,
+                        distributed_init, monkeypatch):
+    if use_prompt_embeds:
+        # Prompt embeddings are currently only supported on V0.
+        monkeypatch.setenv("VLLM_USE_V1", "0")
+
     model_runner = _create_model_runner(
         "facebook/opt-125m",
         seed=0,
@@ -320,11 +369,20 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
     block_tables = {0: [1]}
     prefill_batch_size = batch_size // 2
     decode_batch_size = batch_size - prefill_batch_size
+    expected_input_embeds_len = 0
     for i in range(prefill_batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
         seq_lens.append(seq_len)
-        seq_data = SequenceData.from_seqs(range(seq_len))
+        if use_prompt_embeds:
+            seq_data = SequenceData.from_seqs(
+                prompt_token_ids=[0] * seq_len,
+                prompt_embeds=torch.rand(seq_len, 10),
+            )
+            expected_input_embeds_len += seq_len
+        else:
+            seq_data = SequenceData.from_seqs(
+                prompt_token_ids=range(seq_len), )
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
@@ -340,8 +398,21 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
     for i in range(prefill_batch_size, batch_size):
         # make sure all tokens fit into one block
         context_len = i % (model_runner.block_size - 1) + 1
-        seq_data = SequenceData.from_seqs(range(context_len))
-        seq_data.append_token_id(1, 0)
+        if use_prompt_embeds:
+            seq_data = SequenceData.from_seqs(
+                prompt_token_ids=[0] * context_len,
+                prompt_embeds=torch.rand(context_len, 10),
+            )
+            output_embed = torch.rand(10)
+            # This also increments the expected input_embeds length, because
+            # the model needs the input and output embeddings passed in together.
+            expected_input_embeds_len += 1
+        else:
+            seq_data = SequenceData.from_seqs(
+                prompt_token_ids=range(context_len), )
+            output_embed = None
+        assert len(seq_data.prompt_token_ids) == context_len
+        seq_data.append_token_id(1, 0, output_embed)
         seq_data.update_num_computed_tokens(context_len)
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
@@ -355,11 +426,11 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
         decode_metadata_list.append(seq_group_metadata)

     model_input = model_runner.prepare_model_input(seq_group_metadata_list)
-    (input_tokens, input_positions, attn_metadata) = (
-        model_input.input_tokens,
-        model_input.input_positions,
-        model_input.attn_metadata,
-    )
+
+    input_tokens = model_input.input_tokens
+    input_positions = model_input.input_positions
+    input_embeds = model_input.inputs_embeds
+    attn_metadata = model_input.attn_metadata

     prefill_meta_actual = attn_metadata.prefill_metadata
     decode_meta_actual = attn_metadata.decode_metadata
@@ -369,6 +440,10 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
     assert attn_metadata.num_prefills == prefill_batch_size
     assert attn_metadata.num_decode_tokens == decode_batch_size
     assert attn_metadata.num_prefill_tokens == sum(seq_lens)
+    if expected_input_embeds_len == 0:
+        assert input_embeds is None
+    else:
+        assert len(input_embeds) == expected_input_embeds_len

     # Verify attn metadata is consistent. We don't need to test individual
     # values here because they are tested above.