[Core][Optimization] change python dict to pytorch tensor for blocks to swap (#4659)

This commit is contained in:
youkaichao
2024-05-08 12:07:05 -07:00
committed by GitHub
parent ad932a221d
commit 20cfcdec99
21 changed files with 137 additions and 109 deletions

View File

@ -219,7 +219,7 @@ def test_swap():
before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
mapping = block_manager.swap_out(seq_group)
assert list(mapping.keys()) == gpu_blocks
assert [x[0] for x in mapping] == gpu_blocks
after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks)
@ -232,7 +232,7 @@ def test_swap():
before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
mapping = block_manager.swap_in(seq_group)
assert list(mapping.keys()) == cpu_blocks
assert [x[0] for x in mapping] == cpu_blocks
after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks

View File

@ -355,8 +355,8 @@ def test_swap():
_, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) == 0
assert out.num_batched_tokens == 0
assert out.blocks_to_swap_out != {}
assert out.blocks_to_swap_in == {}
assert out.blocks_to_swap_out != []
assert out.blocks_to_swap_in == []
# Add 1 more task. Swap should be prioritized over new prefill.
_, seq_group = create_dummy_prompt("2", prompt_length=60)
@ -365,8 +365,8 @@ def test_swap():
assert len(out.scheduled_seq_groups) == 1
# 3 decodes. It is swapped in.
assert out.num_batched_tokens == 30
assert out.blocks_to_swap_in != {}
assert out.blocks_to_swap_out == {}
assert out.blocks_to_swap_in != []
assert out.blocks_to_swap_out == []
def test_running_prefill_prioritized_over_swap():
@ -406,8 +406,8 @@ def test_running_prefill_prioritized_over_swap():
_, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) == 0
assert out.num_batched_tokens == 0
assert out.blocks_to_swap_out != {}
assert out.blocks_to_swap_in == {}
assert out.blocks_to_swap_out != []
assert out.blocks_to_swap_in == []
# Add 1 more task. Swap is not possible, so prefill is running.
scheduler.block_manager.can_swap_in = MagicMock()
@ -419,8 +419,8 @@ def test_running_prefill_prioritized_over_swap():
assert len(out.scheduled_seq_groups) == 1
# 3 decodes. It is swapped in.
assert out.num_batched_tokens == 30
assert out.blocks_to_swap_in == {}
assert out.blocks_to_swap_out == {}
assert out.blocks_to_swap_in == []
assert out.blocks_to_swap_out == []
assert out.scheduled_seq_groups[0].seq_group == seq_group2
# Now although swap is possible, running prefill is prioritized.
@ -429,8 +429,8 @@ def test_running_prefill_prioritized_over_swap():
assert len(out.scheduled_seq_groups) == 1
# 3 decodes. It is swapped in.
assert out.num_batched_tokens == 30
assert out.blocks_to_swap_in == {}
assert out.blocks_to_swap_out == {}
assert out.blocks_to_swap_in == []
assert out.blocks_to_swap_out == []
assert not seq_group2.is_prefill()
assert out.scheduled_seq_groups[0].seq_group == seq_group2
append_new_token(seq_group2, 1)
@ -440,8 +440,8 @@ def test_running_prefill_prioritized_over_swap():
assert len(out.scheduled_seq_groups) == 1
# 3 decodes. It is swapped in.
assert out.num_batched_tokens == 1
assert out.blocks_to_swap_in == {}
assert out.blocks_to_swap_out == {}
assert out.blocks_to_swap_in == []
assert out.blocks_to_swap_out == []
assert not seq_group2.is_prefill()
assert out.scheduled_seq_groups[0].seq_group == seq_group2
append_new_token(seq_group2, 1)
@ -451,8 +451,8 @@ def test_running_prefill_prioritized_over_swap():
_, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) == 1
assert out.num_batched_tokens == 30
assert out.blocks_to_swap_in != {}
assert out.blocks_to_swap_out == {}
assert out.blocks_to_swap_in != []
assert out.blocks_to_swap_out == []
def test_chunked_prefill_preempt():
@ -493,8 +493,8 @@ def test_chunked_prefill_preempt():
_, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) == 0
assert out.num_batched_tokens == 0
assert out.blocks_to_swap_out == {}
assert out.blocks_to_swap_in == {}
assert out.blocks_to_swap_out == []
assert out.blocks_to_swap_in == []
# Make sure we can reschedule preempted request.
_, out = schedule_and_update_computed_tokens(scheduler)

View File

@ -293,8 +293,8 @@ def test_swapped_out_prioritized():
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) == 2
assert out.num_batched_tokens == 2
assert out.blocks_to_swap_out != {}
assert out.blocks_to_swap_in == {}
assert out.blocks_to_swap_out != []
assert out.blocks_to_swap_in == []
append_new_token(out, 1)
# Add 1 more task. Swap should be prioritized over prefill.
@ -305,8 +305,8 @@ def test_swapped_out_prioritized():
assert len(out.scheduled_seq_groups) == 3
# 3 decodes. It is swapped in.
assert out.num_batched_tokens == 3
assert out.blocks_to_swap_in != {}
assert out.blocks_to_swap_out == {}
assert out.blocks_to_swap_in != []
assert out.blocks_to_swap_out == []
def initialize_scheduler(*,
@ -566,7 +566,7 @@ def test_decode_schedule_preempted():
# NOTE: When enable_chunk is False, num_seqs budget is not updated.
# assert budget.num_curr_seqs == 1
# Both should be preempted, not swapped.
assert output.blocks_to_swap_out == {}
assert output.blocks_to_swap_out == []
# Nothing is copied.
assert output.blocks_to_copy == []
@ -599,7 +599,7 @@ def test_decode_swap_beam_search():
scheduler.block_manager.can_append_slots.side_effect = (
cannot_append_second_group)
scheduler.block_manager.swap_out = MagicMock()
expected_swap_mapping = {"5": "7"}
expected_swap_mapping = [("5", "7")]
scheduler.block_manager.swap_out.return_value = expected_swap_mapping
remainig_running, output = scheduler._schedule_running(
@ -647,7 +647,7 @@ def test_schedule_decode_blocks_to_copy_update():
assert len(output.preempted) == 0
assert len(output.swapped_out) == 0
# Nothing is preempted.
assert output.blocks_to_swap_out == {}
assert output.blocks_to_swap_out == []
# Since append_slot returns the source -> dist mapping, it should
# applied.
assert output.blocks_to_copy == [(2, 3)]
@ -658,7 +658,7 @@ def test_schedule_swapped_simple():
swapped = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
blocks_to_swap_out = {}
blocks_to_swap_out = []
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1)
@ -674,9 +674,9 @@ def test_schedule_swapped_simple():
assert len(output.decode_seq_groups) == 1
assert len(output.prefill_seq_groups) == 0
# swap in is the reverse of swap out
blocks_to_swap_in_reverse = {}
for swapin, swapout in output.blocks_to_swap_in.items():
blocks_to_swap_in_reverse[swapout] = swapin
blocks_to_swap_in_reverse = []
for swapin, swapout in output.blocks_to_swap_in:
blocks_to_swap_in_reverse.append((swapout, swapin))
assert blocks_to_swap_out == blocks_to_swap_in_reverse
@ -685,7 +685,7 @@ def test_schedule_swapped_max_token_budget():
swapped = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
blocks_to_swap_out = {}
blocks_to_swap_out = []
for _ in range(2):
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group)
@ -719,7 +719,7 @@ def test_schedule_swapped_max_seqs():
swapped = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
blocks_to_swap_out = {}
blocks_to_swap_out = []
for i in range(4):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
scheduler._allocate_and_set_running(seq_group)
@ -752,7 +752,7 @@ def test_schedule_swapped_max_loras():
swapped = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = set()
blocks_to_swap_out = {}
blocks_to_swap_out = []
for i in range(2):
_, seq_group = create_dummy_prompt(str(i),
prompt_length=60,
@ -781,7 +781,7 @@ def test_schedule_swapped_cannot_swap_in():
swapped = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
blocks_to_swap_out = {}
blocks_to_swap_out = []
for _ in range(2):
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group)
@ -808,7 +808,7 @@ def test_infeasible_swap():
swapped = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
blocks_to_swap_out = {}
blocks_to_swap_out = []
for _ in range(2):
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group)
@ -839,7 +839,7 @@ def test_schedule_swapped_blocks_to_copy():
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1)
blocks_to_swap_out = {}
blocks_to_swap_out = []
scheduler._swap_out(seq_group, blocks_to_swap_out)
swapped.append(seq_group)

View File

@ -315,7 +315,10 @@ def test_swap_blocks(
else:
dst_blocks = random.sample(range(num_blocks), num_mappings)
block_mapping = dict(zip(src_blocks, dst_blocks))
block_mapping = list(zip(src_blocks, dst_blocks))
block_mapping_tensor = torch.tensor(block_mapping,
dtype=torch.int64,
device="cpu").view(-1, 2)
# Create the KV caches on the first device.
src_key_caches, src_value_caches = kv_cache_factory(
@ -331,10 +334,12 @@ def test_swap_blocks(
src_value_caches_clone = src_value_caches[0].clone()
# Call the swap_blocks kernel.
ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping)
ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping)
ops.swap_blocks(src_key_caches[0], dist_key_caches[0],
block_mapping_tensor)
ops.swap_blocks(src_value_caches[0], dist_value_caches[0],
block_mapping_tensor)
for src, dst in block_mapping.items():
for src, dst in block_mapping:
assert torch.allclose(src_key_caches_clone[src].cpu(),
dist_key_caches[0][dst].cpu())
assert torch.allclose(src_value_caches_clone[src].cpu(),

View File

@ -54,10 +54,10 @@ def test_swap() -> None:
a.cuda(), b.cuda(), rtol=0.0, atol=0.0)
# Test swap out.
blocks_to_swap_out = {3: 72, 56: 35, 84: 34}
blocks_to_swap_out = [(3, 72), (56, 35), (84, 34)]
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=[],
blocks_to_swap_in={},
blocks_to_swap_in=[],
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=[],
)
@ -66,24 +66,24 @@ def test_swap() -> None:
for i in range(num_layers):
gpu_key_cache, gpu_value_cache = gpu_cache[i]
cpu_key_cache, cpu_value_cache = cpu_cache[i]
for src, dst in blocks_to_swap_out.items():
for src, dst in blocks_to_swap_out:
assert allclose(gpu_key_cache[src], cpu_key_cache[dst])
assert allclose(gpu_value_cache[src], cpu_value_cache[dst])
# Test swap in.
execute_model_req.blocks_to_swap_out = {}
execute_model_req.blocks_to_swap_in = {
19: 45,
67: 23,
12: 78,
40: 99,
1: 71
}
execute_model_req.blocks_to_swap_out = []
execute_model_req.blocks_to_swap_in = [
(19, 45),
(67, 23),
(12, 78),
(40, 99),
(1, 71),
]
worker.execute_model(execute_model_req=execute_model_req)
for i in range(num_layers):
gpu_key_cache, gpu_value_cache = gpu_cache[i]
cpu_key_cache, cpu_value_cache = cpu_cache[i]
for src, dst in execute_model_req.blocks_to_swap_in.items():
for src, dst in execute_model_req.blocks_to_swap_in:
assert allclose(gpu_key_cache[dst], cpu_key_cache[src])
assert allclose(gpu_value_cache[dst], cpu_value_cache[src])