[Core] Refactor Worker and ModelRunner to consolidate control plane communication (#5408)
Signed-off-by: Stephanie Wang <swang@cs.berkeley.edu> Signed-off-by: Stephanie <swang@anyscale.com> Co-authored-by: Stephanie <swang@anyscale.com>
This commit is contained in:
@ -61,12 +61,13 @@ def test_prepare_prompt(batch_size):
|
||||
expected_selected_token_indices.append(selected_token_start_idx +
|
||||
seq_len - 1)
|
||||
selected_token_start_idx += seq_len
|
||||
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
|
||||
model_input = model_runner._prepare_model_input_tensors(
|
||||
seq_group_metadata_list)
|
||||
input_tokens = model_input.input_tokens
|
||||
input_positions = model_input.input_positions
|
||||
attn_metadata = model_input.attn_metadata
|
||||
return_seq_lens = model_input.seq_lens
|
||||
slot_mapping = model_input.slot_mapping
|
||||
slot_mapping = attn_metadata.slot_mapping
|
||||
assert return_seq_lens == seq_lens
|
||||
assert len(slot_mapping) == len(input_tokens)
|
||||
|
||||
@ -174,10 +175,11 @@ def test_prepare_decode_cuda_graph(batch_size):
|
||||
assert seq_group_metadata.token_chunk_size == 1
|
||||
seq_group_metadata_list.append(seq_group_metadata)
|
||||
|
||||
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
|
||||
model_input = model_runner._prepare_model_input_tensors(
|
||||
seq_group_metadata_list)
|
||||
input_tokens, input_positions, attn_metadata, slot_mapping = (
|
||||
model_input.input_tokens, model_input.input_positions,
|
||||
model_input.attn_metadata, model_input.slot_mapping)
|
||||
model_input.attn_metadata, model_input.attn_metadata.slot_mapping)
|
||||
assert len(slot_mapping) == len(input_tokens)
|
||||
|
||||
expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
|
||||
@ -259,32 +261,29 @@ def test_empty_seq_group():
|
||||
enforce_eager=False,
|
||||
)
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata] = []
|
||||
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
|
||||
input_tokens, input_positions, attn_metadata, slot_mapping = (
|
||||
model_input = model_runner._prepare_model_input_tensors(
|
||||
seq_group_metadata_list)
|
||||
input_tokens, input_positions, attn_metadata = (
|
||||
model_input.input_tokens,
|
||||
model_input.input_positions,
|
||||
model_input.attn_metadata,
|
||||
model_input.slot_mapping,
|
||||
)
|
||||
assert len(input_tokens) == 0
|
||||
assert len(input_positions) == 0
|
||||
assert input_tokens is None
|
||||
assert input_positions is None
|
||||
assert attn_metadata is None
|
||||
assert len(slot_mapping) == 0
|
||||
|
||||
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
|
||||
(input_tokens, input_positions, attn_metadata, slot_mapping,
|
||||
return_seq_lens) = (
|
||||
model_input.input_tokens,
|
||||
model_input.input_positions,
|
||||
model_input.attn_metadata,
|
||||
model_input.slot_mapping,
|
||||
model_input.seq_lens,
|
||||
)
|
||||
assert len(input_tokens) == 0
|
||||
assert len(input_positions) == 0
|
||||
model_input = model_runner._prepare_model_input_tensors(
|
||||
seq_group_metadata_list)
|
||||
(input_tokens, input_positions, attn_metadata, return_seq_lens) = (
|
||||
model_input.input_tokens,
|
||||
model_input.input_positions,
|
||||
model_input.attn_metadata,
|
||||
model_input.seq_lens,
|
||||
)
|
||||
assert input_tokens is None
|
||||
assert input_positions is None
|
||||
assert attn_metadata is None
|
||||
assert len(slot_mapping) == 0
|
||||
assert len(return_seq_lens) == 0
|
||||
assert return_seq_lens is None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -353,8 +352,12 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
|
||||
seq_group_metadata_list.append(seq_group_metadata)
|
||||
decode_metadata_list.append(seq_group_metadata)
|
||||
|
||||
(input_tokens, input_positions, attn_metadata, _, _, _,
|
||||
_) = model_runner.prepare_input_tensors(seq_group_metadata_list)
|
||||
model_input = model_runner.prepare_model_input(seq_group_metadata_list)
|
||||
(input_tokens, input_positions, attn_metadata) = (
|
||||
model_input.input_tokens,
|
||||
model_input.input_positions,
|
||||
model_input.attn_metadata,
|
||||
)
|
||||
|
||||
prefill_meta_actual = attn_metadata.prefill_metadata
|
||||
decode_meta_actual = attn_metadata.decode_metadata
|
||||
@ -367,7 +370,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
|
||||
|
||||
# Verify attn metadata is consistent. We don't need to test individual
|
||||
# values here because they are tested above.
|
||||
attn_metadata = model_runner._prepare_model_input(
|
||||
attn_metadata = model_runner._prepare_model_input_tensors(
|
||||
seq_group_metadata_list).attn_metadata
|
||||
|
||||
for attr_expected, attr_actual in zip(vars(attn_metadata.prefill_metadata),
|
||||
|
||||
Reference in New Issue
Block a user