[Core][2/N] Model runner refactoring part 2. Combine prepare prefill / decode to a single API (#4681)

This PR combines prepare_prompt and prepare_decode into a single API. This PR also coelsce the attn metadata for prefill/decode to a single class and allow to slice them when running attn backend.

It also refactors subquery_start_loc which was not refactored in the previous PR
This commit is contained in:
SangBin Cho
2024-05-15 14:00:10 +09:00
committed by GitHub
parent 8a7cc254a0
commit 65bf2ac165
18 changed files with 781 additions and 734 deletions

View File

@ -58,19 +58,25 @@ def test_prepare_prompt(batch_size):
expected_selected_token_indices.append(selected_token_start_idx +
seq_len - 1)
selected_token_start_idx += seq_len
(input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _,
_, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
input_tokens = model_input.input_tokens
input_positions = model_input.input_positions
attn_metadata = model_input.attn_metadata
return_seq_lens = model_input.seq_lens
slot_mapping = model_input.slot_mapping
assert return_seq_lens == seq_lens
assert len(slot_mapping) == len(input_tokens)
# Verify input metadata is correct for prompts.
device = model_runner.device
assert attn_metadata.is_prompt is True
assert attn_metadata.num_prefills > 0
assert attn_metadata.num_decode_tokens == 0
assert torch.allclose(
attn_metadata.seq_lens_tensor,
torch.tensor(seq_lens, device=device, dtype=torch.int))
assert attn_metadata.seq_lens == seq_lens
assert attn_metadata.max_seq_len == max(seq_lens)
assert attn_metadata.max_prefill_seq_len == max(seq_lens)
assert attn_metadata.max_decode_seq_len == 0
# Test subquery start locs.
start_idx = 0
@ -79,11 +85,11 @@ def test_prepare_prompt(batch_size):
start_idx += seq_len
start_loc.append(start_idx)
assert torch.allclose(
attn_metadata.subquery_start_loc,
attn_metadata.query_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device))
# Test seq start locs. Note that for normal prefill it is
# equivalent to subquery_start_loc.
# equivalent to query_start_loc.
start_idx = 0
seq_start_loc = [start_idx]
for seq_len in seq_lens:
@ -123,7 +129,7 @@ def test_prepare_prompt(batch_size):
device=actual.device,
dtype=actual.dtype)
torch.testing.assert_close(actual, expected)
assert input_tokens == input_positions
torch.allclose(input_tokens, input_positions)
actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices,
@ -144,14 +150,18 @@ def test_prepare_decode_cuda_graph(batch_size):
enable_chunked_prefill=False,
)
seq_lens = []
context_lens = []
seq_group_metadata_list = []
# Assume each seq group finishes prefill.
for i in range(batch_size):
# make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = list(range(seq_len))
context_len = i % (model_runner.block_size - 1) + 1
context_lens.append(context_len)
seq_data = list(range(context_len))
seq_data = SequenceData(seq_data)
seq_data.update_num_computed_tokens(context_len)
# Append one token ID since prefill is finished.
seq_data.append_token_id(1, 0)
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=False,
@ -162,18 +172,45 @@ def test_prepare_decode_cuda_graph(batch_size):
assert seq_group_metadata.token_chunk_size == 1
seq_group_metadata_list.append(seq_group_metadata)
input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
model_runner._prepare_decode(seq_group_metadata_list))
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
input_tokens, input_positions, attn_metadata, slot_mapping = (
model_input.input_tokens, model_input.input_positions,
model_input.attn_metadata, model_input.slot_mapping)
assert len(slot_mapping) == len(input_tokens)
expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
# Verify input metadata is correct for prompts.
device = model_runner.device
assert attn_metadata.is_prompt is False
assert attn_metadata.seq_lens is None
assert attn_metadata.subquery_start_loc is None
assert attn_metadata.seq_start_loc is None
assert attn_metadata.max_seq_len == max(seq_lens)
assert attn_metadata.num_prefills == 0
assert attn_metadata.num_prefill_tokens == 0
seq_lens = [context_len + 1 for context_len in context_lens]
# seq_lens are padded to expected_bs
for _ in range(expected_bs - len(seq_lens)):
seq_lens.append(1)
assert attn_metadata.seq_lens == seq_lens
start_idx = 0
start_loc = [start_idx]
for _ in context_lens:
# decode has only 1 token for query.
start_idx += 1
start_loc.append(start_idx)
assert torch.allclose(
attn_metadata.query_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device))
start_idx = 0
seq_start_loc = [start_idx]
for seq_len in seq_lens:
start_idx += seq_len
seq_start_loc.append(start_idx)
assert torch.allclose(
attn_metadata.seq_start_loc,
torch.tensor(seq_start_loc, dtype=torch.int32, device=device))
assert torch.allclose(
attn_metadata.context_lens_tensor,
torch.tensor(context_lens, dtype=torch.int, device=device))
assert attn_metadata.max_decode_seq_len == max(seq_lens)
assert torch.allclose(
attn_metadata.seq_lens_tensor[:len(seq_lens)],
torch.tensor(seq_lens, dtype=torch.int, device=device))
@ -185,23 +222,23 @@ def test_prepare_decode_cuda_graph(batch_size):
# It is padded up to
assert attn_metadata.block_tables.shape[1] == (
model_runner.get_max_block_per_batch())
# Cuda graph should not be used for prerill.
assert attn_metadata.use_cuda_graph is True
assert len(input_tokens) == expected_bs
assert len(input_positions) == expected_bs
assert input_tokens == input_positions
torch.allclose(input_tokens, input_positions)
# Verify Sampling
expected_selected_token_indices = []
selected_token_start_idx = 0
for seq_len in seq_lens:
for _ in context_lens:
expected_selected_token_indices.append(selected_token_start_idx)
selected_token_start_idx += 1
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
query_lens=seq_lens,
# query lens is all 1 for decode.
query_lens=[1 for _ in range(len(context_lens))],
device=model_runner.device,
pin_memory=model_runner.pin_memory)
actual = sampling_metadata.selected_token_indices
@ -220,15 +257,27 @@ def test_empty_seq_group():
enforce_eager=False,
)
seq_group_metadata_list = []
input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
model_runner._prepare_decode(seq_group_metadata_list))
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
input_tokens, input_positions, attn_metadata, slot_mapping = (
model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata,
model_input.slot_mapping,
)
assert len(input_tokens) == 0
assert len(input_positions) == 0
assert attn_metadata is None
assert len(slot_mapping) == 0
(input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _,
_, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
(input_tokens, input_positions, attn_metadata, slot_mapping,
return_seq_lens) = (
model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata,
model_input.slot_mapping,
model_input.seq_lens,
)
assert len(input_tokens) == 0
assert len(input_positions) == 0
assert attn_metadata is None
@ -285,9 +334,11 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
# Add decode requests
for i in range(prefill_batch_size, batch_size):
# make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1
prompt_toks = list(range(seq_len))
context_len = i % (model_runner.block_size - 1) + 1
prompt_toks = list(range(context_len))
seq_data = SequenceData(prompt_toks)
seq_data.append_token_id(1, 0)
seq_data.update_num_computed_tokens(context_len)
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=False,
@ -308,23 +359,17 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
assert len(attn_metadata.slot_mapping) == len(input_tokens)
assert len(input_positions) == len(input_tokens)
assert attn_metadata.num_prefills == prefill_batch_size
if enforce_eager:
assert attn_metadata.num_decode_tokens == decode_batch_size
else:
assert attn_metadata.num_decode_tokens == _get_graph_batch_size(
decode_batch_size)
assert attn_metadata.num_decode_tokens == decode_batch_size
assert attn_metadata.num_prefill_tokens == sum(seq_lens)
# Verify attn metadata is consistent. We don't need to test individual
# values here because they are tested above.
prefill_meta = model_runner._prepare_prompt(
prefill_metadata_list).attn_metadata
decode_meta = model_runner._prepare_decode(
decode_metadata_list).attn_metadata
attn_metadata = model_runner._prepare_model_input(
seq_group_metadata_list).attn_metadata
for attr_expected, attr_actual in zip(vars(prefill_meta),
for attr_expected, attr_actual in zip(vars(attn_metadata.prefill_metadata),
vars(prefill_meta_actual)):
assert attr_expected[1] == attr_actual[1]
for attr_expected, attr_actual in zip(vars(decode_meta),
for attr_expected, attr_actual in zip(vars(attn_metadata.decode_metadata),
vars(decode_meta_actual)):
assert attr_expected[1] == attr_actual[1]