[EP/DP][API Server] Enable DP-aware routing in OpenAI API requests (#24945)
Co-authored-by: Cong Chen <prowindy@gmail.com>
This commit is contained in:
@ -651,3 +651,79 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
|
||||
await serving_chat.create_chat_completion(req)
|
||||
engine_prompt = serving_chat._process_inputs.await_args_list[1].args[1]
|
||||
assert engine_prompt.get("cache_salt") == "test_salt"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_serving_chat_data_parallel_rank_extraction():
|
||||
"""Test that data_parallel_rank is properly extracted from header and
|
||||
passed to engine."""
|
||||
mock_engine = MagicMock(spec=AsyncLLM)
|
||||
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||
mock_engine.errored = False
|
||||
mock_engine.model_config = MockModelConfig()
|
||||
mock_engine.processor = MagicMock()
|
||||
mock_engine.io_processor = MagicMock()
|
||||
|
||||
# Mock the generate method to return an async generator
|
||||
async def mock_generate(*args, **kwargs):
|
||||
# Yield a fake RequestOutput
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
|
||||
yield RequestOutput(
|
||||
request_id="test-request",
|
||||
prompt="test prompt",
|
||||
prompt_token_ids=[1, 2, 3],
|
||||
prompt_logprobs=None,
|
||||
outputs=[
|
||||
CompletionOutput(
|
||||
index=0,
|
||||
text="test response",
|
||||
token_ids=[4, 5, 6],
|
||||
cumulative_logprob=0.0,
|
||||
logprobs=None,
|
||||
finish_reason="stop",
|
||||
stop_reason=None,
|
||||
)
|
||||
],
|
||||
finished=True,
|
||||
)
|
||||
|
||||
mock_engine.generate = AsyncMock(side_effect=mock_generate)
|
||||
|
||||
serving_chat = _build_serving_chat(mock_engine)
|
||||
|
||||
# Test when data_parallel_rank is present in header
|
||||
req = ChatCompletionRequest(
|
||||
model=MODEL_NAME,
|
||||
messages=[{"role": "user", "content": "what is 1+1?"}],
|
||||
)
|
||||
|
||||
# Mock request with X-data-parallel-rank header
|
||||
mock_raw_request = MagicMock()
|
||||
mock_raw_request.headers = {"X-data-parallel-rank": "2"}
|
||||
mock_raw_request.state = MagicMock()
|
||||
|
||||
with suppress(Exception):
|
||||
await serving_chat.create_chat_completion(req, mock_raw_request)
|
||||
|
||||
# Verify that data_parallel_rank was passed to engine.generate
|
||||
assert "data_parallel_rank" in mock_engine.generate.call_args.kwargs
|
||||
assert mock_engine.generate.call_args.kwargs["data_parallel_rank"] == 2
|
||||
|
||||
# Test when data_parallel_rank is not present (defaults to None)
|
||||
req_no_dp = ChatCompletionRequest(
|
||||
model=MODEL_NAME,
|
||||
messages=[{"role": "user", "content": "what is 2+2?"}],
|
||||
)
|
||||
|
||||
# Mock request with no header
|
||||
mock_raw_request_no_dp = MagicMock()
|
||||
mock_raw_request_no_dp.headers = {}
|
||||
mock_raw_request_no_dp.state = MagicMock()
|
||||
|
||||
with suppress(Exception):
|
||||
await serving_chat.create_chat_completion(req_no_dp, mock_raw_request_no_dp)
|
||||
|
||||
# Verify that data_parallel_rank defaults to None
|
||||
assert "data_parallel_rank" in mock_engine.generate.call_args.kwargs
|
||||
assert mock_engine.generate.call_args.kwargs["data_parallel_rank"] is None
|
||||
|
||||
Reference in New Issue
Block a user