[Core][Optimization] change python dict to pytorch tensor for blocks to swap (#4659)
This commit is contained in:
@ -219,7 +219,7 @@ def test_swap():
|
||||
before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
|
||||
before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
mapping = block_manager.swap_out(seq_group)
|
||||
assert list(mapping.keys()) == gpu_blocks
|
||||
assert [x[0] for x in mapping] == gpu_blocks
|
||||
after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
|
||||
after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks)
|
||||
@ -232,7 +232,7 @@ def test_swap():
|
||||
before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
|
||||
before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
mapping = block_manager.swap_in(seq_group)
|
||||
assert list(mapping.keys()) == cpu_blocks
|
||||
assert [x[0] for x in mapping] == cpu_blocks
|
||||
after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
|
||||
after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks
|
||||
|
||||
@ -355,8 +355,8 @@ def test_swap():
|
||||
_, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert len(out.scheduled_seq_groups) == 0
|
||||
assert out.num_batched_tokens == 0
|
||||
assert out.blocks_to_swap_out != {}
|
||||
assert out.blocks_to_swap_in == {}
|
||||
assert out.blocks_to_swap_out != []
|
||||
assert out.blocks_to_swap_in == []
|
||||
|
||||
# Add 1 more task. Swap should be prioritized over new prefill.
|
||||
_, seq_group = create_dummy_prompt("2", prompt_length=60)
|
||||
@ -365,8 +365,8 @@ def test_swap():
|
||||
assert len(out.scheduled_seq_groups) == 1
|
||||
# 3 decodes. It is swapped in.
|
||||
assert out.num_batched_tokens == 30
|
||||
assert out.blocks_to_swap_in != {}
|
||||
assert out.blocks_to_swap_out == {}
|
||||
assert out.blocks_to_swap_in != []
|
||||
assert out.blocks_to_swap_out == []
|
||||
|
||||
|
||||
def test_running_prefill_prioritized_over_swap():
|
||||
@ -406,8 +406,8 @@ def test_running_prefill_prioritized_over_swap():
|
||||
_, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert len(out.scheduled_seq_groups) == 0
|
||||
assert out.num_batched_tokens == 0
|
||||
assert out.blocks_to_swap_out != {}
|
||||
assert out.blocks_to_swap_in == {}
|
||||
assert out.blocks_to_swap_out != []
|
||||
assert out.blocks_to_swap_in == []
|
||||
|
||||
# Add 1 more task. Swap is not possible, so prefill is running.
|
||||
scheduler.block_manager.can_swap_in = MagicMock()
|
||||
@ -419,8 +419,8 @@ def test_running_prefill_prioritized_over_swap():
|
||||
assert len(out.scheduled_seq_groups) == 1
|
||||
# 3 decodes. It is swapped in.
|
||||
assert out.num_batched_tokens == 30
|
||||
assert out.blocks_to_swap_in == {}
|
||||
assert out.blocks_to_swap_out == {}
|
||||
assert out.blocks_to_swap_in == []
|
||||
assert out.blocks_to_swap_out == []
|
||||
assert out.scheduled_seq_groups[0].seq_group == seq_group2
|
||||
|
||||
# Now although swap is possible, running prefill is prioritized.
|
||||
@ -429,8 +429,8 @@ def test_running_prefill_prioritized_over_swap():
|
||||
assert len(out.scheduled_seq_groups) == 1
|
||||
# 3 decodes. It is swapped in.
|
||||
assert out.num_batched_tokens == 30
|
||||
assert out.blocks_to_swap_in == {}
|
||||
assert out.blocks_to_swap_out == {}
|
||||
assert out.blocks_to_swap_in == []
|
||||
assert out.blocks_to_swap_out == []
|
||||
assert not seq_group2.is_prefill()
|
||||
assert out.scheduled_seq_groups[0].seq_group == seq_group2
|
||||
append_new_token(seq_group2, 1)
|
||||
@ -440,8 +440,8 @@ def test_running_prefill_prioritized_over_swap():
|
||||
assert len(out.scheduled_seq_groups) == 1
|
||||
# 3 decodes. It is swapped in.
|
||||
assert out.num_batched_tokens == 1
|
||||
assert out.blocks_to_swap_in == {}
|
||||
assert out.blocks_to_swap_out == {}
|
||||
assert out.blocks_to_swap_in == []
|
||||
assert out.blocks_to_swap_out == []
|
||||
assert not seq_group2.is_prefill()
|
||||
assert out.scheduled_seq_groups[0].seq_group == seq_group2
|
||||
append_new_token(seq_group2, 1)
|
||||
@ -451,8 +451,8 @@ def test_running_prefill_prioritized_over_swap():
|
||||
_, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert len(out.scheduled_seq_groups) == 1
|
||||
assert out.num_batched_tokens == 30
|
||||
assert out.blocks_to_swap_in != {}
|
||||
assert out.blocks_to_swap_out == {}
|
||||
assert out.blocks_to_swap_in != []
|
||||
assert out.blocks_to_swap_out == []
|
||||
|
||||
|
||||
def test_chunked_prefill_preempt():
|
||||
@ -493,8 +493,8 @@ def test_chunked_prefill_preempt():
|
||||
_, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert len(out.scheduled_seq_groups) == 0
|
||||
assert out.num_batched_tokens == 0
|
||||
assert out.blocks_to_swap_out == {}
|
||||
assert out.blocks_to_swap_in == {}
|
||||
assert out.blocks_to_swap_out == []
|
||||
assert out.blocks_to_swap_in == []
|
||||
|
||||
# Make sure we can reschedule preempted request.
|
||||
_, out = schedule_and_update_computed_tokens(scheduler)
|
||||
|
||||
@ -293,8 +293,8 @@ def test_swapped_out_prioritized():
|
||||
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert len(out.scheduled_seq_groups) == 2
|
||||
assert out.num_batched_tokens == 2
|
||||
assert out.blocks_to_swap_out != {}
|
||||
assert out.blocks_to_swap_in == {}
|
||||
assert out.blocks_to_swap_out != []
|
||||
assert out.blocks_to_swap_in == []
|
||||
append_new_token(out, 1)
|
||||
|
||||
# Add 1 more task. Swap should be prioritized over prefill.
|
||||
@ -305,8 +305,8 @@ def test_swapped_out_prioritized():
|
||||
assert len(out.scheduled_seq_groups) == 3
|
||||
# 3 decodes. It is swapped in.
|
||||
assert out.num_batched_tokens == 3
|
||||
assert out.blocks_to_swap_in != {}
|
||||
assert out.blocks_to_swap_out == {}
|
||||
assert out.blocks_to_swap_in != []
|
||||
assert out.blocks_to_swap_out == []
|
||||
|
||||
|
||||
def initialize_scheduler(*,
|
||||
@ -566,7 +566,7 @@ def test_decode_schedule_preempted():
|
||||
# NOTE: When enable_chunk is False, num_seqs budget is not updated.
|
||||
# assert budget.num_curr_seqs == 1
|
||||
# Both should be preempted, not swapped.
|
||||
assert output.blocks_to_swap_out == {}
|
||||
assert output.blocks_to_swap_out == []
|
||||
# Nothing is copied.
|
||||
assert output.blocks_to_copy == []
|
||||
|
||||
@ -599,7 +599,7 @@ def test_decode_swap_beam_search():
|
||||
scheduler.block_manager.can_append_slots.side_effect = (
|
||||
cannot_append_second_group)
|
||||
scheduler.block_manager.swap_out = MagicMock()
|
||||
expected_swap_mapping = {"5": "7"}
|
||||
expected_swap_mapping = [("5", "7")]
|
||||
scheduler.block_manager.swap_out.return_value = expected_swap_mapping
|
||||
|
||||
remainig_running, output = scheduler._schedule_running(
|
||||
@ -647,7 +647,7 @@ def test_schedule_decode_blocks_to_copy_update():
|
||||
assert len(output.preempted) == 0
|
||||
assert len(output.swapped_out) == 0
|
||||
# Nothing is preempted.
|
||||
assert output.blocks_to_swap_out == {}
|
||||
assert output.blocks_to_swap_out == []
|
||||
# Since append_slot returns the source -> dist mapping, it should
|
||||
# applied.
|
||||
assert output.blocks_to_copy == [(2, 3)]
|
||||
@ -658,7 +658,7 @@ def test_schedule_swapped_simple():
|
||||
swapped = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
blocks_to_swap_out = {}
|
||||
blocks_to_swap_out = []
|
||||
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
append_new_token_seq_group(60, seq_group, 1)
|
||||
@ -674,9 +674,9 @@ def test_schedule_swapped_simple():
|
||||
assert len(output.decode_seq_groups) == 1
|
||||
assert len(output.prefill_seq_groups) == 0
|
||||
# swap in is the reverse of swap out
|
||||
blocks_to_swap_in_reverse = {}
|
||||
for swapin, swapout in output.blocks_to_swap_in.items():
|
||||
blocks_to_swap_in_reverse[swapout] = swapin
|
||||
blocks_to_swap_in_reverse = []
|
||||
for swapin, swapout in output.blocks_to_swap_in:
|
||||
blocks_to_swap_in_reverse.append((swapout, swapin))
|
||||
assert blocks_to_swap_out == blocks_to_swap_in_reverse
|
||||
|
||||
|
||||
@ -685,7 +685,7 @@ def test_schedule_swapped_max_token_budget():
|
||||
swapped = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
blocks_to_swap_out = {}
|
||||
blocks_to_swap_out = []
|
||||
for _ in range(2):
|
||||
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
@ -719,7 +719,7 @@ def test_schedule_swapped_max_seqs():
|
||||
swapped = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
blocks_to_swap_out = {}
|
||||
blocks_to_swap_out = []
|
||||
for i in range(4):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
@ -752,7 +752,7 @@ def test_schedule_swapped_max_loras():
|
||||
swapped = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = set()
|
||||
blocks_to_swap_out = {}
|
||||
blocks_to_swap_out = []
|
||||
for i in range(2):
|
||||
_, seq_group = create_dummy_prompt(str(i),
|
||||
prompt_length=60,
|
||||
@ -781,7 +781,7 @@ def test_schedule_swapped_cannot_swap_in():
|
||||
swapped = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
blocks_to_swap_out = {}
|
||||
blocks_to_swap_out = []
|
||||
for _ in range(2):
|
||||
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
@ -808,7 +808,7 @@ def test_infeasible_swap():
|
||||
swapped = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
blocks_to_swap_out = {}
|
||||
blocks_to_swap_out = []
|
||||
for _ in range(2):
|
||||
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
@ -839,7 +839,7 @@ def test_schedule_swapped_blocks_to_copy():
|
||||
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
append_new_token_seq_group(60, seq_group, 1)
|
||||
blocks_to_swap_out = {}
|
||||
blocks_to_swap_out = []
|
||||
scheduler._swap_out(seq_group, blocks_to_swap_out)
|
||||
swapped.append(seq_group)
|
||||
|
||||
|
||||
@ -315,7 +315,10 @@ def test_swap_blocks(
|
||||
else:
|
||||
dst_blocks = random.sample(range(num_blocks), num_mappings)
|
||||
|
||||
block_mapping = dict(zip(src_blocks, dst_blocks))
|
||||
block_mapping = list(zip(src_blocks, dst_blocks))
|
||||
block_mapping_tensor = torch.tensor(block_mapping,
|
||||
dtype=torch.int64,
|
||||
device="cpu").view(-1, 2)
|
||||
|
||||
# Create the KV caches on the first device.
|
||||
src_key_caches, src_value_caches = kv_cache_factory(
|
||||
@ -331,10 +334,12 @@ def test_swap_blocks(
|
||||
src_value_caches_clone = src_value_caches[0].clone()
|
||||
|
||||
# Call the swap_blocks kernel.
|
||||
ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping)
|
||||
ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping)
|
||||
ops.swap_blocks(src_key_caches[0], dist_key_caches[0],
|
||||
block_mapping_tensor)
|
||||
ops.swap_blocks(src_value_caches[0], dist_value_caches[0],
|
||||
block_mapping_tensor)
|
||||
|
||||
for src, dst in block_mapping.items():
|
||||
for src, dst in block_mapping:
|
||||
assert torch.allclose(src_key_caches_clone[src].cpu(),
|
||||
dist_key_caches[0][dst].cpu())
|
||||
assert torch.allclose(src_value_caches_clone[src].cpu(),
|
||||
|
||||
@ -54,10 +54,10 @@ def test_swap() -> None:
|
||||
a.cuda(), b.cuda(), rtol=0.0, atol=0.0)
|
||||
|
||||
# Test swap out.
|
||||
blocks_to_swap_out = {3: 72, 56: 35, 84: 34}
|
||||
blocks_to_swap_out = [(3, 72), (56, 35), (84, 34)]
|
||||
execute_model_req = ExecuteModelRequest(
|
||||
seq_group_metadata_list=[],
|
||||
blocks_to_swap_in={},
|
||||
blocks_to_swap_in=[],
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=[],
|
||||
)
|
||||
@ -66,24 +66,24 @@ def test_swap() -> None:
|
||||
for i in range(num_layers):
|
||||
gpu_key_cache, gpu_value_cache = gpu_cache[i]
|
||||
cpu_key_cache, cpu_value_cache = cpu_cache[i]
|
||||
for src, dst in blocks_to_swap_out.items():
|
||||
for src, dst in blocks_to_swap_out:
|
||||
assert allclose(gpu_key_cache[src], cpu_key_cache[dst])
|
||||
assert allclose(gpu_value_cache[src], cpu_value_cache[dst])
|
||||
|
||||
# Test swap in.
|
||||
execute_model_req.blocks_to_swap_out = {}
|
||||
execute_model_req.blocks_to_swap_in = {
|
||||
19: 45,
|
||||
67: 23,
|
||||
12: 78,
|
||||
40: 99,
|
||||
1: 71
|
||||
}
|
||||
execute_model_req.blocks_to_swap_out = []
|
||||
execute_model_req.blocks_to_swap_in = [
|
||||
(19, 45),
|
||||
(67, 23),
|
||||
(12, 78),
|
||||
(40, 99),
|
||||
(1, 71),
|
||||
]
|
||||
worker.execute_model(execute_model_req=execute_model_req)
|
||||
|
||||
for i in range(num_layers):
|
||||
gpu_key_cache, gpu_value_cache = gpu_cache[i]
|
||||
cpu_key_cache, cpu_value_cache = cpu_cache[i]
|
||||
for src, dst in execute_model_req.blocks_to_swap_in.items():
|
||||
for src, dst in execute_model_req.blocks_to_swap_in:
|
||||
assert allclose(gpu_key_cache[dst], cpu_key_cache[src])
|
||||
assert allclose(gpu_value_cache[dst], cpu_value_cache[src])
|
||||
|
||||
Reference in New Issue
Block a user