Use NCCL instead of ray for control-plane communication to remove serialization overhead (#2221)

Author: Zhuohan Li
Date: 2024-01-04 03:30:22 +08:00
Committed by: GitHub
Parent: 1066cbd152
Commit: fd4ea8ef5c
34 changed files with 524 additions and 262 deletions
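
The gist of the change: per-step control-plane data (which blocks to copy or swap, which sequences to run) previously reached each worker as pickled arguments of Ray remote calls; after this commit the driver pushes it to the workers over the existing torch.distributed (NCCL) group, keeping Ray off the per-step path. A minimal sketch of that pattern, assuming an already initialized NCCL process group; the function name, tensor shape, and values below are illustrative, not vLLM's API.

import torch
import torch.distributed as dist

def run_step(rank: int, max_batch_tokens: int = 8) -> None:
    # Assumes dist.init_process_group("nccl", ...) has run and each rank
    # has already called torch.cuda.set_device(...).
    if rank == 0:
        # Driver rank prepares this step's token IDs (made-up values).
        input_tokens = torch.randint(0, 32000, (max_batch_tokens,),
                                     dtype=torch.int64, device="cuda")
    else:
        # Worker ranks allocate a buffer of the agreed-upon shape and dtype.
        input_tokens = torch.empty(max_batch_tokens,
                                   dtype=torch.int64, device="cuda")
    # One NCCL broadcast replaces a pickled argument per Ray remote call.
    dist.broadcast(input_tokens, src=0)
    # ... every rank now runs its model shard on input_tokens ...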

View File

@@ -8,11 +8,11 @@ import pytest
 import requests


-def _query_server(prompt: str) -> dict:
+def _query_server(prompt: str, max_tokens: int = 5) -> dict:
     response = requests.post("http://localhost:8000/generate",
                              json={
                                  "prompt": prompt,
-                                 "max_tokens": 100,
+                                 "max_tokens": max_tokens,
                                  "temperature": 0,
                                  "ignore_eos": True
                              })
@@ -20,6 +20,10 @@ def _query_server(prompt: str) -> dict:
     return response.json()


+def _query_server_long(prompt: str) -> dict:
+    return _query_server(prompt, max_tokens=500)
+
+
 @pytest.fixture
 def api_server():
     script_path = Path(__file__).parent.joinpath(
@@ -68,10 +72,11 @@ def test_api_server(api_server):
         for result in pool.map(_query_server, prompts):
             assert result

     with Pool(32) as pool:
         # Cancel requests
         prompts = ["canceled requests"] * 100
-        pool.map_async(_query_server, prompts)
-        time.sleep(0.001)
+        pool.map_async(_query_server_long, prompts)
+        time.sleep(0.01)
         pool.terminate()
         pool.join()

View File

@@ -49,12 +49,13 @@ def test_copy_blocks(
     src_blocks = random.sample(range(num_blocks), num_mappings)
     remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
     dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
-    block_mapping = {}
+    copy_src = []
+    copy_dst = []
     for i in range(num_mappings):
-        src = src_blocks[i]
-        dst1 = dst_blocks[2 * i]
-        dst2 = dst_blocks[2 * i + 1]
-        block_mapping[src] = [dst1, dst2]
+        copy_src.append(src_blocks[i])
+        copy_dst.append(dst_blocks[2 * i])
+        copy_src.append(src_blocks[i])
+        copy_dst.append(dst_blocks[2 * i + 1])

     # Create the KV caches.
     key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
@@ -66,15 +67,14 @@ def test_copy_blocks(
     cloned_value_caches = [value_cache.clone() for value_cache in value_caches]

     # Call the copy blocks kernel.
-    cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
+    cache_ops.copy_blocks(key_caches, value_caches, copy_src, copy_dst)

     # Run the reference implementation.
-    for src, dsts in block_mapping.items():
-        for dst in dsts:
-            for cloned_key_cache in cloned_key_caches:
-                cloned_key_cache[dst].copy_(cloned_key_cache[src])
-            for cloned_value_cache in cloned_value_caches:
-                cloned_value_cache[dst].copy_(cloned_value_cache[src])
+    for src, dst in zip(copy_src, copy_dst):
+        for cloned_key_cache in cloned_key_caches:
+            cloned_key_cache[dst].copy_(cloned_key_cache[src])
+        for cloned_value_cache in cloned_value_caches:
+            cloned_value_cache[dst].copy_(cloned_value_cache[src])

     # Compare the results.
     for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
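
The two hunks above reflect the kernel's new calling convention: cache_ops.copy_blocks now takes two parallel lists of source and destination block numbers instead of a {src: [dst, ...]} dict, a flatter structure to pass around than a nested mapping. A small illustrative helper (not part of the commit, the name is made up) showing how the old mapping translates into the new arguments:

from typing import Dict, List, Tuple

def flatten_block_mapping(
        block_mapping: Dict[int, List[int]]) -> Tuple[List[int], List[int]]:
    # Every (src, dst) copy becomes one entry in both lists; a source block
    # with N destinations appears N times in copy_src.
    copy_src: List[int] = []
    copy_dst: List[int] = []
    for src, dsts in block_mapping.items():
        for dst in dsts:
            copy_src.append(src)
            copy_dst.append(dst)
    return copy_src, copy_dst

# For example, {3: [7, 9]} becomes copy_src == [3, 3], copy_dst == [7, 9].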

View File

@@ -33,8 +33,9 @@ def test_prepare_prompt():
         expected_selected_token_indices.append(selected_token_start_idx +
                                                prompt_len - 1)
         selected_token_start_idx += max_seq_len
-    input_tokens, input_positions, _ = model_runner._prepare_prompt(
-        seq_group_metadata_list)
+    input_tokens, input_positions, _, return_prompt_lens = (
+        model_runner._prepare_prompt(seq_group_metadata_list))
+    assert return_prompt_lens == prompt_lens
     sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                      prompt_lens)
     assert input_tokens.shape == (batch_size, max_seq_len)
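
_prepare_prompt now additionally returns the prompt lengths it computed, and the test asserts they match the lengths the batch was built from. A hedged sketch of how a caller might use the extra return value, in the spirit of the driver preparing inputs once and sharing only lightweight metadata with the other ranks; the wrapper below is illustrative, not vLLM's code, and assumes an initialized torch.distributed group.

import torch.distributed as dist

def driver_prepare_and_share(model_runner, seq_group_metadata_list):
    # The driver prepares the batched inputs once from the full metadata...
    input_tokens, input_positions, _, prompt_lens = (
        model_runner._prepare_prompt(seq_group_metadata_list))
    # ...reuses the returned prompt_lens for sampling metadata instead of
    # recomputing it from seq_group_metadata_list...
    sampling_metadata = model_runner._prepare_sample(
        seq_group_metadata_list, prompt_lens)
    # ...and could ship the small Python-level metadata to non-driver ranks
    # with a single collective call.
    dist.broadcast_object_list([prompt_lens], src=0)
    return input_tokens, input_positions, sampling_metadata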