[EP/DP][API Server] Enable DP-aware routing in OpenAI API requests (#24945)
Co-authored-by: Cong Chen <prowindy@gmail.com>
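With this change, a router (or any HTTP client) can pin a request to a specific data-parallel rank by setting the X-data-parallel-rank header. A minimal client-side sketch, assuming a vLLM server at localhost:8000 and the OpenAI Python SDK's per-request extra_headers option; the model name is a placeholder:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Pin this request to data-parallel rank 2 via the header this PR reads.
resp = client.chat.completions.create(
    model="my-model",  # placeholder model name
    messages=[{"role": "user", "content": "what is 1+1?"}],
    extra_headers={"X-data-parallel-rank": "2"},
)
print(resp.choices[0].message.content)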
@@ -651,3 +651,79 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
    await serving_chat.create_chat_completion(req)

    engine_prompt = serving_chat._process_inputs.await_args_list[1].args[1]
    assert engine_prompt.get("cache_salt") == "test_salt"


@pytest.mark.asyncio
async def test_serving_chat_data_parallel_rank_extraction():
    """Test that data_parallel_rank is properly extracted from header and
    passed to engine."""
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False
    mock_engine.model_config = MockModelConfig()
    mock_engine.processor = MagicMock()
    mock_engine.io_processor = MagicMock()

    # Mock the generate method to return an async generator
    async def mock_generate(*args, **kwargs):
        # Yield a fake RequestOutput
        from vllm.outputs import CompletionOutput, RequestOutput

        yield RequestOutput(
            request_id="test-request",
            prompt="test prompt",
            prompt_token_ids=[1, 2, 3],
            prompt_logprobs=None,
            outputs=[
                CompletionOutput(
                    index=0,
                    text="test response",
                    token_ids=[4, 5, 6],
                    cumulative_logprob=0.0,
                    logprobs=None,
                    finish_reason="stop",
                    stop_reason=None,
                )
            ],
            finished=True,
        )

    mock_engine.generate = AsyncMock(side_effect=mock_generate)

    serving_chat = _build_serving_chat(mock_engine)

    # Test when data_parallel_rank is present in header
    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "what is 1+1?"}],
    )

    # Mock request with X-data-parallel-rank header
    mock_raw_request = MagicMock()
    mock_raw_request.headers = {"X-data-parallel-rank": "2"}
    mock_raw_request.state = MagicMock()

    with suppress(Exception):
        await serving_chat.create_chat_completion(req, mock_raw_request)

    # Verify that data_parallel_rank was passed to engine.generate
    assert "data_parallel_rank" in mock_engine.generate.call_args.kwargs
    assert mock_engine.generate.call_args.kwargs["data_parallel_rank"] == 2

    # Test when data_parallel_rank is not present (defaults to None)
    req_no_dp = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "what is 2+2?"}],
    )

    # Mock request with no header
    mock_raw_request_no_dp = MagicMock()
    mock_raw_request_no_dp.headers = {}
    mock_raw_request_no_dp.state = MagicMock()

    with suppress(Exception):
        await serving_chat.create_chat_completion(req_no_dp, mock_raw_request_no_dp)

    # Verify that data_parallel_rank defaults to None
    assert "data_parallel_rank" in mock_engine.generate.call_args.kwargs
    assert mock_engine.generate.call_args.kwargs["data_parallel_rank"] is None
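For context, a DP-aware router sitting in front of the server could inject this same header per request. A hypothetical round-robin sketch (not part of this PR), assuming httpx, a placeholder endpoint URL, and data_parallel_size == 4:

import itertools

import httpx

VLLM_URL = "http://localhost:8000/v1/chat/completions"  # assumed endpoint
rank_cycle = itertools.cycle(range(4))  # assume data_parallel_size == 4

def route_chat(payload: dict) -> dict:
    # Pick the next rank and forward it in the X-data-parallel-rank header.
    rank = next(rank_cycle)
    resp = httpx.post(
        VLLM_URL,
        json=payload,
        headers={"X-data-parallel-rank": str(rank)},
        timeout=60.0,
    )
    resp.raise_for_status()
    return resp.json()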
@@ -264,6 +264,9 @@ class OpenAIServingChat(OpenAIServing):
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        # Extract data_parallel_rank from header (router can inject it)
        data_parallel_rank = self._get_data_parallel_rank(raw_request)

        # Schedule the request and get the result generator.
        generators: list[AsyncGenerator[RequestOutput, None]] = []
        try:
@@ -331,6 +334,7 @@ class OpenAIServingChat(OpenAIServing):
                    priority=request.priority,
                    prompt_text=prompt_text,
                    tokenization_kwargs=tokenization_kwargs,
                    data_parallel_rank=data_parallel_rank,
                )

                generators.append(generator)
@@ -141,6 +141,9 @@ class OpenAIServingCompletion(OpenAIServing):
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        # Extract data_parallel_rank from header (router can inject it)
        data_parallel_rank = self._get_data_parallel_rank(raw_request)

        # Schedule the request and get the result generator.
        generators: list[AsyncGenerator[RequestOutput, None]] = []
        try:
@@ -224,6 +227,7 @@ class OpenAIServingCompletion(OpenAIServing):
                    priority=request.priority,
                    prompt_text=prompt_text,
                    tokenization_kwargs=tokenization_kwargs,
                    data_parallel_rank=data_parallel_rank,
                )

                generators.append(generator)
@@ -1298,6 +1298,21 @@ class OpenAIServing:

        return raw_request.headers.get("X-Request-Id", default)

    @staticmethod
    def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
        """Pulls the data parallel rank from a header, if provided"""
        if raw_request is None:
            return None

        rank_str = raw_request.headers.get("X-data-parallel-rank")
        if rank_str is None:
            return None

        try:
            return int(rank_str)
        except ValueError:
            return None

    @staticmethod
    def _get_decoded_token(
        logprob: Logprob,
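A quick sketch of the helper's behavior, calling the staticmethod with a stub in place of starlette's Request (only .headers is consulted; note a plain dict is case-sensitive, unlike real HTTP headers). The import path is assumed from vLLM's entrypoints layout:

from types import SimpleNamespace

from vllm.entrypoints.openai.serving_engine import OpenAIServing  # assumed path

def stub_request(headers: dict) -> SimpleNamespace:
    # Stand-in for starlette.requests.Request; the helper only reads .headers.
    return SimpleNamespace(headers=headers)

assert OpenAIServing._get_data_parallel_rank(None) is None
assert OpenAIServing._get_data_parallel_rank(stub_request({})) is None
assert OpenAIServing._get_data_parallel_rank(
    stub_request({"X-data-parallel-rank": "2"})) == 2
# Non-integer values are swallowed and treated as "no rank pinned".
assert OpenAIServing._get_data_parallel_rank(
    stub_request({"X-data-parallel-rank": "not-a-number"})) is None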