From b5d70751d82c272a72f105299ef24ae316c41ded Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 30 Oct 2025 12:39:34 +0800 Subject: [PATCH] [BugFix] Reordering extend logic fix (#27739) Signed-off-by: Lucas Wilkinson --- tests/v1/attention/test_batch_reordering.py | 21 ++++++++++++++++++--- vllm/v1/attention/backends/utils.py | 10 +++++----- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/tests/v1/attention/test_batch_reordering.py b/tests/v1/attention/test_batch_reordering.py index b271409b92..e372194542 100644 --- a/tests/v1/attention/test_batch_reordering.py +++ b/tests/v1/attention/test_batch_reordering.py @@ -53,7 +53,7 @@ REORDER_TEST_CASES = { expected_modified=True, ), "already_ordered": ReorderTestCase( - requests=[(1, 10), (1, 20), (100, 100), (200, 200)], + requests=[(1, 10), (1, 20), (100, 100), (200, 0)], expected_order=[0, 1, 2, 3], expected_modified=False, ), @@ -74,15 +74,30 @@ REORDER_TEST_CASES = { expected_modified=True, ), "decode_extend_prefill": ReorderTestCase( - requests=[(100, 100), (10, 50), (1, 10)], + requests=[(100, 0), (10, 50), (1, 10)], expected_order=[2, 1, 0], expected_modified=True, ), "extend_prefill_only": ReorderTestCase( - requests=[(100, 100), (10, 50), (200, 200), (20, 75)], + requests=[(100, 0), (10, 50), (200, 0), (20, 75)], expected_order=[3, 1, 2, 0], # Only swap 0↔3, keep 1 and 2 in place expected_modified=True, ), + "complicated_mixed_interleaved": ReorderTestCase( + requests=[ + (1, 20), + (1, 50), + (374, 0), + (300, 20), + (1, 20), + (256, 0), + (1, 5), + (27, 0), + (1, 4), + ], + expected_order=[0, 1, 6, 8, 4, 3, 2, 7, 5], + expected_modified=True, + ), } diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 389baf1488..07d62e9849 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -811,8 +811,8 @@ def reorder_batch_to_split_decodes_and_prefills( num_computed_tokens_np = input_batch.num_computed_tokens_cpu[:num_reqs] is_decode = num_scheduled_tokens_np <= decode_threshold - is_extend = (~is_decode) & (num_computed_tokens_np > num_scheduled_tokens_np) - is_prefill = (~is_decode) & (num_computed_tokens_np == num_scheduled_tokens_np) + is_extend = (~is_decode) & (num_computed_tokens_np > 0) + is_prefill = (~is_decode) & (num_computed_tokens_np == 0) # Desired order: decode → extend → prefill req_regions = np.zeros(is_decode.shape, dtype=np.int32) # 0 = decode by default @@ -832,11 +832,11 @@ def reorder_batch_to_split_decodes_and_prefills( return False # Extract indices that need swapping and sort by target region - swap_indices = np.where(needs_swap)[0] + orig_indices = np.where(needs_swap)[0] sorted_order = np.argsort(req_regions[needs_swap], kind="stable") - dest_indices = swap_indices[sorted_order] + src_indices = orig_indices[sorted_order] - src_dest_map = {int(src): int(dst) for src, dst in zip(swap_indices, dest_indices)} + src_dest_map = {int(src): int(dst) for src, dst in zip(src_indices, orig_indices)} for src in src_dest_map: dst = src_dest_map[src]