From a2981c42720a34b5abf59c4c14df701f8105d4cd Mon Sep 17 00:00:00 2001
From: cong-meta
Date: Thu, 30 Oct 2025 12:10:16 -0700
Subject: [PATCH] [EP/DP][API Server] Enable DP-aware routing in OpenAI API
 requests (#24945)

Co-authored-by: Cong Chen
---
 tests/entrypoints/openai/test_serving_chat.py | 76 +++++++++++++++++++
 vllm/entrypoints/openai/serving_chat.py       |  4 +
 vllm/entrypoints/openai/serving_completion.py |  4 +
 vllm/entrypoints/openai/serving_engine.py     | 15 ++++
 4 files changed, 99 insertions(+)

diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py
index d1367b4eea..1b83ed7e31 100644
--- a/tests/entrypoints/openai/test_serving_chat.py
+++ b/tests/entrypoints/openai/test_serving_chat.py
@@ -651,3 +651,79 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
     await serving_chat.create_chat_completion(req)
     engine_prompt = serving_chat._process_inputs.await_args_list[1].args[1]
     assert engine_prompt.get("cache_salt") == "test_salt"
+
+
+@pytest.mark.asyncio
+async def test_serving_chat_data_parallel_rank_extraction():
+    """Test that data_parallel_rank is properly extracted from header and
+    passed to engine."""
+    mock_engine = MagicMock(spec=AsyncLLM)
+    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
+    mock_engine.errored = False
+    mock_engine.model_config = MockModelConfig()
+    mock_engine.processor = MagicMock()
+    mock_engine.io_processor = MagicMock()
+
+    # Mock the generate method to return an async generator
+    async def mock_generate(*args, **kwargs):
+        # Yield a fake RequestOutput
+        from vllm.outputs import CompletionOutput, RequestOutput
+
+        yield RequestOutput(
+            request_id="test-request",
+            prompt="test prompt",
+            prompt_token_ids=[1, 2, 3],
+            prompt_logprobs=None,
+            outputs=[
+                CompletionOutput(
+                    index=0,
+                    text="test response",
+                    token_ids=[4, 5, 6],
+                    cumulative_logprob=0.0,
+                    logprobs=None,
+                    finish_reason="stop",
+                    stop_reason=None,
+                )
+            ],
+            finished=True,
+        )
+
+    mock_engine.generate = AsyncMock(side_effect=mock_generate)
+
+    serving_chat = _build_serving_chat(mock_engine)
+
+    # Test when data_parallel_rank is present in header
+    req = ChatCompletionRequest(
+        model=MODEL_NAME,
+        messages=[{"role": "user", "content": "what is 1+1?"}],
+    )
+
+    # Mock request with X-data-parallel-rank header
+    mock_raw_request = MagicMock()
+    mock_raw_request.headers = {"X-data-parallel-rank": "2"}
+    mock_raw_request.state = MagicMock()
+
+    with suppress(Exception):
+        await serving_chat.create_chat_completion(req, mock_raw_request)
+
+    # Verify that data_parallel_rank was passed to engine.generate
+    assert "data_parallel_rank" in mock_engine.generate.call_args.kwargs
+    assert mock_engine.generate.call_args.kwargs["data_parallel_rank"] == 2
+
+    # Test when data_parallel_rank is not present (defaults to None)
+    req_no_dp = ChatCompletionRequest(
+        model=MODEL_NAME,
+        messages=[{"role": "user", "content": "what is 2+2?"}],
+    )
+
+    # Mock request with no header
+    mock_raw_request_no_dp = MagicMock()
+    mock_raw_request_no_dp.headers = {}
+    mock_raw_request_no_dp.state = MagicMock()
+
+    with suppress(Exception):
+        await serving_chat.create_chat_completion(req_no_dp, mock_raw_request_no_dp)
+
+    # Verify that data_parallel_rank defaults to None
+    assert "data_parallel_rank" in mock_engine.generate.call_args.kwargs
+    assert mock_engine.generate.call_args.kwargs["data_parallel_rank"] is None
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 934ff78b2c..bb770ecf03 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -264,6 +264,9 @@ class OpenAIServingChat(OpenAIServing):
         if raw_request:
             raw_request.state.request_metadata = request_metadata
 
+        # Extract data_parallel_rank from header (router can inject it)
+        data_parallel_rank = self._get_data_parallel_rank(raw_request)
+
         # Schedule the request and get the result generator.
         generators: list[AsyncGenerator[RequestOutput, None]] = []
         try:
@@ -331,6 +334,7 @@ class OpenAIServingChat(OpenAIServing):
                     priority=request.priority,
                     prompt_text=prompt_text,
                     tokenization_kwargs=tokenization_kwargs,
+                    data_parallel_rank=data_parallel_rank,
                 )
 
                 generators.append(generator)
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 62bc932f8b..14dbdd4cb4 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -141,6 +141,9 @@ class OpenAIServingCompletion(OpenAIServing):
             logger.exception("Error in preprocessing prompt inputs")
             return self.create_error_response(str(e))
 
+        # Extract data_parallel_rank from header (router can inject it)
+        data_parallel_rank = self._get_data_parallel_rank(raw_request)
+
         # Schedule the request and get the result generator.
         generators: list[AsyncGenerator[RequestOutput, None]] = []
         try:
@@ -224,6 +227,7 @@ class OpenAIServingCompletion(OpenAIServing):
                     priority=request.priority,
                     prompt_text=prompt_text,
                     tokenization_kwargs=tokenization_kwargs,
+                    data_parallel_rank=data_parallel_rank,
                 )
 
                 generators.append(generator)
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index af5a423134..c0750cd641 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -1298,6 +1298,21 @@ class OpenAIServing:
 
         return raw_request.headers.get("X-Request-Id", default)
 
+    @staticmethod
+    def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
+        """Pulls the data parallel rank from a header, if provided"""
+        if raw_request is None:
+            return None
+
+        rank_str = raw_request.headers.get("X-data-parallel-rank")
+        if rank_str is None:
+            return None
+
+        try:
+            return int(rank_str)
+        except ValueError:
+            return None
+
     @staticmethod
     def _get_decoded_token(
         logprob: Logprob,
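
Usage note (not part of the patch): a minimal client-side sketch of how a DP-aware router or test client might exercise this path. The header name X-data-parallel-rank and the fall-back-to-None behavior for a missing or non-integer value come from the patch itself; the server URL and model name below are hypothetical placeholders, and the snippet assumes a vLLM OpenAI-compatible server running with data parallelism enabled.

    # Hypothetical example: pin one chat completion to data-parallel rank 2 by
    # setting the X-data-parallel-rank header. Omitting the header (or sending a
    # non-integer value) leaves data_parallel_rank as None on the engine side.
    import requests  # third-party HTTP client, used here for brevity

    resp = requests.post(
        "http://localhost:8000/v1/chat/completions",  # assumed local server URL
        headers={"X-data-parallel-rank": "2"},        # parsed by _get_data_parallel_rank
        json={
            "model": "placeholder-model",             # hypothetical model name
            "messages": [{"role": "user", "content": "what is 1+1?"}],
        },
        timeout=60,
    )
    print(resp.json()["choices"][0]["message"]["content"])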