[Core][2/N] Model runner refactoring part 2. Combine prepare prefill / decode to a single API (#4681)
This PR combines prepare_prompt and prepare_decode into a single API. It also coalesces the prefill/decode attention metadata into a single class that can be sliced into per-phase views when running the attention backend, and it renames subquery_start_loc to query_start_loc, which was not refactored in the previous PR.
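For context while reviewing the diff below, here is a minimal Python sketch of the combined-metadata idea. The field and property names follow what the updated tests read (num_prefills, num_prefill_tokens, num_decode_tokens, prefill_metadata, decode_metadata); the class name and slicing logic are illustrative, not the exact implementation:

from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class CombinedAttnMetadata:
    """Illustrative combined metadata: prefill tokens first, then decode."""
    num_prefills: int
    num_prefill_tokens: int
    num_decode_tokens: int
    slot_mapping: torch.Tensor  # one slot per token, prefill tokens first
    seq_lens: List[int]  # per-sequence lengths, prefill sequences first

    @property
    def prefill_metadata(self) -> Optional["CombinedAttnMetadata"]:
        # View over the leading prefill portion; None if the batch has none.
        if self.num_prefills == 0:
            return None
        return CombinedAttnMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
            seq_lens=self.seq_lens[:self.num_prefills])

    @property
    def decode_metadata(self) -> Optional["CombinedAttnMetadata"]:
        # View over the trailing decode portion; None if the batch has none.
        if self.num_decode_tokens == 0:
            return None
        return CombinedAttnMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
            seq_lens=self.seq_lens[self.num_prefills:])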
@@ -58,19 +58,25 @@ def test_prepare_prompt(batch_size):
         expected_selected_token_indices.append(selected_token_start_idx +
                                                seq_len - 1)
         selected_token_start_idx += seq_len
-    (input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _,
-     _, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
+    model_input = model_runner._prepare_model_input(seq_group_metadata_list)
+    input_tokens = model_input.input_tokens
+    input_positions = model_input.input_positions
+    attn_metadata = model_input.attn_metadata
+    return_seq_lens = model_input.seq_lens
+    slot_mapping = model_input.slot_mapping
     assert return_seq_lens == seq_lens
     assert len(slot_mapping) == len(input_tokens)

     # Verify input metadata is correct for prompts.
     device = model_runner.device
     assert attn_metadata.is_prompt is True
+    assert attn_metadata.num_prefills > 0
+    assert attn_metadata.num_decode_tokens == 0
     assert torch.allclose(
         attn_metadata.seq_lens_tensor,
         torch.tensor(seq_lens, device=device, dtype=torch.int))
     assert attn_metadata.seq_lens == seq_lens
-    assert attn_metadata.max_seq_len == max(seq_lens)
+    assert attn_metadata.max_prefill_seq_len == max(seq_lens)
+    assert attn_metadata.max_decode_seq_len == 0

     # Test subquery start locs.
     start_idx = 0
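The hunk above switches the test from long tuple unpacking to named fields on the return value of _prepare_model_input. A plausible shape for that container, inferred only from the fields these tests read (the real class likely carries more, e.g. LoRA and multi-modal inputs):

from typing import List, NamedTuple, Optional

import torch


class ModelInputSketch(NamedTuple):
    input_tokens: torch.Tensor
    input_positions: torch.Tensor
    attn_metadata: Optional[object]  # None when the batch is empty
    seq_lens: List[int]
    slot_mapping: torch.Tensor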
@@ -79,11 +85,11 @@ def test_prepare_prompt(batch_size):
         start_idx += seq_len
         start_loc.append(start_idx)
     assert torch.allclose(
-        attn_metadata.subquery_start_loc,
+        attn_metadata.query_start_loc,
         torch.tensor(start_loc, dtype=torch.int32, device=device))

     # Test seq start locs. Note that for normal prefill it is
-    # equivalent to subquery_start_loc.
+    # equivalent to query_start_loc.
     start_idx = 0
     seq_start_loc = [start_idx]
     for seq_len in seq_lens:
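Concretely, query_start_loc and seq_start_loc are exclusive prefix sums of the per-sequence query/sequence lengths; for a normal (unchunked) prefill the two coincide. A tiny example with made-up lengths:

import torch

seq_lens = [2, 3, 4]  # hypothetical per-sequence lengths
start_loc = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
torch.cumsum(torch.tensor(seq_lens, dtype=torch.int32),
             dim=0, out=start_loc[1:])
print(start_loc)  # tensor([0, 2, 5, 9], dtype=torch.int32)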
@@ -123,7 +129,7 @@ def test_prepare_prompt(batch_size):
                             device=actual.device,
                             dtype=actual.dtype)
     torch.testing.assert_close(actual, expected)
-    assert input_tokens == input_positions
+    torch.allclose(input_tokens, input_positions)

     actual = sampling_metadata.selected_token_indices
     expected = torch.tensor(expected_selected_token_indices,
@@ -144,14 +150,18 @@ def test_prepare_decode_cuda_graph(batch_size):
         enable_chunked_prefill=False,
     )

-    seq_lens = []
+    context_lens = []
     seq_group_metadata_list = []
+    # Assume each seq group finishes prefill.
     for i in range(batch_size):
         # make sure all tokens fit into one block
-        seq_len = i % (model_runner.block_size - 1) + 1
-        seq_lens.append(seq_len)
-        seq_data = list(range(seq_len))
+        context_len = i % (model_runner.block_size - 1) + 1
+        context_lens.append(context_len)
+        seq_data = list(range(context_len))
         seq_data = SequenceData(seq_data)
+        seq_data.update_num_computed_tokens(context_len)
+        # Append one token ID since prefill is finished.
+        seq_data.append_token_id(1, 0)
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=False,
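The loop above encodes the decode-batch invariant: every context token was computed during prefill and exactly one generated token has been appended, so the length seen at decode time is context_len + 1 and still fits in one block. A standalone check of that arithmetic (block size assumed; the test reads model_runner.block_size):

block_size = 16  # assumed; the test reads model_runner.block_size
for i in range(100):
    context_len = i % (block_size - 1) + 1  # 1 .. block_size - 1
    seq_len = context_len + 1  # one token appended after prefill
    assert seq_len <= block_size  # all tokens fit into one block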
@@ -162,18 +172,45 @@ def test_prepare_decode_cuda_graph(batch_size):
         assert seq_group_metadata.token_chunk_size == 1
         seq_group_metadata_list.append(seq_group_metadata)

-    input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
-        model_runner._prepare_decode(seq_group_metadata_list))
+    model_input = model_runner._prepare_model_input(seq_group_metadata_list)
+    input_tokens, input_positions, attn_metadata, slot_mapping = (
+        model_input.input_tokens, model_input.input_positions,
+        model_input.attn_metadata, model_input.slot_mapping)
     assert len(slot_mapping) == len(input_tokens)

     expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
     # Verify input metadata is correct for prompts.
     device = model_runner.device
     assert attn_metadata.is_prompt is False
-    assert attn_metadata.seq_lens is None
-    assert attn_metadata.subquery_start_loc is None
-    assert attn_metadata.seq_start_loc is None
-    assert attn_metadata.max_seq_len == max(seq_lens)
+    assert attn_metadata.num_prefills == 0
+    assert attn_metadata.num_prefill_tokens == 0
+    seq_lens = [context_len + 1 for context_len in context_lens]
+    # seq_lens are padded to expected_bs
+    for _ in range(expected_bs - len(seq_lens)):
+        seq_lens.append(1)
+    assert attn_metadata.seq_lens == seq_lens
+    start_idx = 0
+    start_loc = [start_idx]
+    for _ in context_lens:
+        # decode has only 1 token for query.
+        start_idx += 1
+        start_loc.append(start_idx)
+    assert torch.allclose(
+        attn_metadata.query_start_loc,
+        torch.tensor(start_loc, dtype=torch.int32, device=device))
+
+    start_idx = 0
+    seq_start_loc = [start_idx]
+    for seq_len in seq_lens:
+        start_idx += seq_len
+        seq_start_loc.append(start_idx)
+    assert torch.allclose(
+        attn_metadata.seq_start_loc,
+        torch.tensor(seq_start_loc, dtype=torch.int32, device=device))
+
+    assert torch.allclose(
+        attn_metadata.context_lens_tensor,
+        torch.tensor(context_lens, dtype=torch.int, device=device))
+    assert attn_metadata.max_decode_seq_len == max(seq_lens)
     assert torch.allclose(
         attn_metadata.seq_lens_tensor[:len(seq_lens)],
         torch.tensor(seq_lens, dtype=torch.int, device=device))
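_get_graph_batch_size pads the decode batch up to the nearest captured CUDA-graph size, which is why seq_lens is padded with 1s to expected_bs above. A sketch of the rounding rule, assuming the capture sizes vLLM used at the time (1, 2, 4, then multiples of 8):

_BATCH_SIZE_ALIGNMENT = 8  # assumed alignment of captured graph sizes

def graph_batch_size(batch_size: int) -> int:
    """Round up to the nearest captured CUDA-graph batch size (sketch)."""
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1)
            // _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)

assert graph_batch_size(3) == 4
assert graph_batch_size(9) == 16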
@@ -185,23 +222,23 @@ def test_prepare_decode_cuda_graph(batch_size):
     # It is padded up to
     assert attn_metadata.block_tables.shape[1] == (
         model_runner.get_max_block_per_batch())
-    # Cuda graph should not be used for prerill.
     assert attn_metadata.use_cuda_graph is True

     assert len(input_tokens) == expected_bs
     assert len(input_positions) == expected_bs
-    assert input_tokens == input_positions
+    torch.allclose(input_tokens, input_positions)

     # Verify Sampling
     expected_selected_token_indices = []
     selected_token_start_idx = 0
-    for seq_len in seq_lens:
+    for _ in context_lens:
         expected_selected_token_indices.append(selected_token_start_idx)
         selected_token_start_idx += 1
     sampling_metadata = SamplingMetadata.prepare(
         seq_group_metadata_list,
         seq_lens,
-        query_lens=seq_lens,
+        # query lens is all 1 for decode.
+        query_lens=[1 for _ in range(len(context_lens))],
         device=model_runner.device,
         pin_memory=model_runner.pin_memory)
     actual = sampling_metadata.selected_token_indices
@@ -220,15 +257,27 @@ def test_empty_seq_group():
         enforce_eager=False,
     )
     seq_group_metadata_list = []
-    input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
-        model_runner._prepare_decode(seq_group_metadata_list))
+    model_input = model_runner._prepare_model_input(seq_group_metadata_list)
+    input_tokens, input_positions, attn_metadata, slot_mapping = (
+        model_input.input_tokens,
+        model_input.input_positions,
+        model_input.attn_metadata,
+        model_input.slot_mapping,
+    )
     assert len(input_tokens) == 0
     assert len(input_positions) == 0
     assert attn_metadata is None
     assert len(slot_mapping) == 0

-    (input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _,
-     _, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
+    model_input = model_runner._prepare_model_input(seq_group_metadata_list)
+    (input_tokens, input_positions, attn_metadata, slot_mapping,
+     return_seq_lens) = (
+         model_input.input_tokens,
+         model_input.input_positions,
+         model_input.attn_metadata,
+         model_input.slot_mapping,
+         model_input.seq_lens,
+     )
     assert len(input_tokens) == 0
     assert len(input_positions) == 0
     assert attn_metadata is None
@@ -285,9 +334,11 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
     # Add decode requests
     for i in range(prefill_batch_size, batch_size):
         # make sure all tokens fit into one block
-        seq_len = i % (model_runner.block_size - 1) + 1
-        prompt_toks = list(range(seq_len))
+        context_len = i % (model_runner.block_size - 1) + 1
+        prompt_toks = list(range(context_len))
         seq_data = SequenceData(prompt_toks)
         seq_data.append_token_id(1, 0)
+        seq_data.update_num_computed_tokens(context_len)
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=False,
@@ -308,23 +359,17 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
     assert len(attn_metadata.slot_mapping) == len(input_tokens)
     assert len(input_positions) == len(input_tokens)
     assert attn_metadata.num_prefills == prefill_batch_size
-    if enforce_eager:
-        assert attn_metadata.num_decode_tokens == decode_batch_size
-    else:
-        assert attn_metadata.num_decode_tokens == _get_graph_batch_size(
-            decode_batch_size)
+    assert attn_metadata.num_decode_tokens == decode_batch_size
     assert attn_metadata.num_prefill_tokens == sum(seq_lens)

     # Verify attn metadata is consistent. We don't need to test individual
     # values here because they are tested above.
-    prefill_meta = model_runner._prepare_prompt(
-        prefill_metadata_list).attn_metadata
-    decode_meta = model_runner._prepare_decode(
-        decode_metadata_list).attn_metadata
+    attn_metadata = model_runner._prepare_model_input(
+        seq_group_metadata_list).attn_metadata

-    for attr_expected, attr_actual in zip(vars(prefill_meta),
+    for attr_expected, attr_actual in zip(vars(attn_metadata.prefill_metadata),
                                           vars(prefill_meta_actual)):
         assert attr_expected[1] == attr_actual[1]
-    for attr_expected, attr_actual in zip(vars(decode_meta),
+    for attr_expected, attr_actual in zip(vars(attn_metadata.decode_metadata),
                                           vars(decode_meta_actual)):
         assert attr_expected[1] == attr_actual[1]
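With the combined metadata, an attention backend can run one forward pass over a mixed batch and slice it per phase. A hedged sketch of that dispatch pattern (run_prefill/run_decode are placeholders, not real vLLM APIs):

import torch

def run_prefill(q, meta):  # placeholder for the real prefill kernel
    return q

def run_decode(q, meta):  # placeholder for the real decode kernel
    return q

def forward(query: torch.Tensor, attn_metadata) -> torch.Tensor:
    # Tokens are laid out prefill-first, so the two phases slice cleanly.
    num_prefill_tokens = attn_metadata.num_prefill_tokens
    output = torch.empty_like(query)
    if (prefill_meta := attn_metadata.prefill_metadata) is not None:
        output[:num_prefill_tokens] = run_prefill(
            query[:num_prefill_tokens], prefill_meta)
    if (decode_meta := attn_metadata.decode_metadata) is not None:
        output[num_prefill_tokens:] = run_decode(
            query[num_prefill_tokens:], decode_meta)
    return output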