[Core][Optimization] change python dict to pytorch tensor for blocks to swap (#4659)

This commit is contained in:
youkaichao
2024-05-08 12:07:05 -07:00
committed by GitHub
parent ad932a221d
commit 20cfcdec99
21 changed files with 137 additions and 109 deletions

View File

@ -54,10 +54,10 @@ def test_swap() -> None:
a.cuda(), b.cuda(), rtol=0.0, atol=0.0)
# Test swap out.
blocks_to_swap_out = {3: 72, 56: 35, 84: 34}
blocks_to_swap_out = [(3, 72), (56, 35), (84, 34)]
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=[],
blocks_to_swap_in={},
blocks_to_swap_in=[],
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=[],
)
@ -66,24 +66,24 @@ def test_swap() -> None:
for i in range(num_layers):
gpu_key_cache, gpu_value_cache = gpu_cache[i]
cpu_key_cache, cpu_value_cache = cpu_cache[i]
for src, dst in blocks_to_swap_out.items():
for src, dst in blocks_to_swap_out:
assert allclose(gpu_key_cache[src], cpu_key_cache[dst])
assert allclose(gpu_value_cache[src], cpu_value_cache[dst])
# Test swap in.
execute_model_req.blocks_to_swap_out = {}
execute_model_req.blocks_to_swap_in = {
19: 45,
67: 23,
12: 78,
40: 99,
1: 71
}
execute_model_req.blocks_to_swap_out = []
execute_model_req.blocks_to_swap_in = [
(19, 45),
(67, 23),
(12, 78),
(40, 99),
(1, 71),
]
worker.execute_model(execute_model_req=execute_model_req)
for i in range(num_layers):
gpu_key_cache, gpu_value_cache = gpu_cache[i]
cpu_key_cache, cpu_value_cache = cpu_cache[i]
for src, dst in execute_model_req.blocks_to_swap_in.items():
for src, dst in execute_model_req.blocks_to_swap_in:
assert allclose(gpu_key_cache[dst], cpu_key_cache[src])
assert allclose(gpu_value_cache[dst], cpu_value_cache[src])