Compare commits
11 Commits
amd_dev
...
woosuk/tes
| Author | SHA1 | Date | |
|---|---|---|---|
| cb439737db | |||
| a1cac48477 | |||
| 6102536d65 | |||
| f65da69c72 | |||
| a5281395e9 | |||
| eda71c2847 | |||
| 1bff9a59ec | |||
| 69c9a01538 | |||
| 8935ca208d | |||
| dddad8a81c | |||
| 7f783b8a4a |
@ -12,4 +12,5 @@ torchvision==0.23.0 # Required for phi3v processor. See https://github.com/pytor
|
|||||||
# https://github.com/facebookresearch/xformers/releases/tag/v0.0.32.post1
|
# https://github.com/facebookresearch/xformers/releases/tag/v0.0.32.post1
|
||||||
xformers==0.0.32.post1; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.8
|
xformers==0.0.32.post1; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.8
|
||||||
# FlashInfer should be updated together with the Dockerfile
|
# FlashInfer should be updated together with the Dockerfile
|
||||||
flashinfer-python==0.4.0
|
flashinfer-python==0.4.0
|
||||||
|
apache-tvm-ffi==0.1.0b15
|
||||||
|
|||||||
@ -649,5 +649,65 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
|
|||||||
req.cache_salt = "test_salt"
|
req.cache_salt = "test_salt"
|
||||||
with suppress(Exception):
|
with suppress(Exception):
|
||||||
await serving_chat.create_chat_completion(req)
|
await serving_chat.create_chat_completion(req)
|
||||||
engine_prompt = serving_chat._process_inputs.await_args_list[1].args[1]
|
assert mock_engine.generate.call_args.args[0]["cache_salt"] == "test_salt"
|
||||||
assert engine_prompt.get("cache_salt") == "test_salt"
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_serving_chat_data_parallel_rank_extraction():
|
||||||
|
"""Test that data_parallel_rank is properly extracted from header and passed to engine."""
|
||||||
|
mock_engine = MagicMock(spec=MQLLMEngineClient)
|
||||||
|
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||||
|
mock_engine.errored = False
|
||||||
|
|
||||||
|
models = OpenAIServingModels(engine_client=mock_engine,
|
||||||
|
base_model_paths=BASE_MODEL_PATHS,
|
||||||
|
model_config=MockModelConfig())
|
||||||
|
serving_chat = OpenAIServingChat(mock_engine,
|
||||||
|
MockModelConfig(),
|
||||||
|
models,
|
||||||
|
response_role="assistant",
|
||||||
|
chat_template=CHAT_TEMPLATE,
|
||||||
|
chat_template_content_format="auto",
|
||||||
|
request_logger=None)
|
||||||
|
|
||||||
|
# Test when data_parallel_rank is present in header
|
||||||
|
req = ChatCompletionRequest(
|
||||||
|
model=MODEL_NAME,
|
||||||
|
messages=[{
|
||||||
|
"role": "user",
|
||||||
|
"content": "what is 1+1?"
|
||||||
|
}],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock request with X-data-parallel-rank header
|
||||||
|
mock_raw_request = MagicMock()
|
||||||
|
mock_raw_request.headers = {"X-data-parallel-rank": "2"}
|
||||||
|
mock_raw_request.state = MagicMock()
|
||||||
|
|
||||||
|
with suppress(Exception):
|
||||||
|
await serving_chat.create_chat_completion(req, mock_raw_request)
|
||||||
|
|
||||||
|
# Verify that data_parallel_rank was passed to engine.generate
|
||||||
|
assert 'data_parallel_rank' in mock_engine.generate.call_args.kwargs
|
||||||
|
assert mock_engine.generate.call_args.kwargs['data_parallel_rank'] == 2
|
||||||
|
|
||||||
|
# Test when data_parallel_rank is not present (defaults to None)
|
||||||
|
req_no_dp = ChatCompletionRequest(
|
||||||
|
model=MODEL_NAME,
|
||||||
|
messages=[{
|
||||||
|
"role": "user",
|
||||||
|
"content": "what is 2+2?"
|
||||||
|
}],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock request with no header
|
||||||
|
mock_raw_request_no_dp = MagicMock()
|
||||||
|
mock_raw_request_no_dp.headers = {}
|
||||||
|
mock_raw_request_no_dp.state = MagicMock()
|
||||||
|
|
||||||
|
with suppress(Exception):
|
||||||
|
await serving_chat.create_chat_completion(req_no_dp, mock_raw_request_no_dp)
|
||||||
|
|
||||||
|
# Verify that data_parallel_rank defaults to None
|
||||||
|
assert 'data_parallel_rank' in mock_engine.generate.call_args.kwargs
|
||||||
|
assert mock_engine.generate.call_args.kwargs['data_parallel_rank'] is None
|
||||||
|
|||||||
@ -75,7 +75,7 @@ class SampleRequest:
|
|||||||
Represents a single inference request for benchmarking.
|
Represents a single inference request for benchmarking.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
prompt: str | list[str]
|
prompt: str | list[str] | list[int]
|
||||||
prompt_len: int
|
prompt_len: int
|
||||||
expected_output_len: int
|
expected_output_len: int
|
||||||
multi_modal_data: MultiModalDataDict | dict | list[dict] | None = None
|
multi_modal_data: MultiModalDataDict | dict | list[dict] | None = None
|
||||||
@ -402,8 +402,9 @@ def gen_prompt_decode_to_target_len(
|
|||||||
remain_num_try = max_retry
|
remain_num_try = max_retry
|
||||||
token_mismatch = 0
|
token_mismatch = 0
|
||||||
while True:
|
while True:
|
||||||
prompt = tokenizer.decode(token_sequence)
|
# prompt = tokenizer.decode(token_sequence)
|
||||||
token_sequence = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
|
# token_sequence = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
|
||||||
|
prompt = token_sequence
|
||||||
if remain_num_try <= 0:
|
if remain_num_try <= 0:
|
||||||
if len(token_sequence) != target_token_len:
|
if len(token_sequence) != target_token_len:
|
||||||
token_mismatch = len(token_sequence) - target_token_len
|
token_mismatch = len(token_sequence) - target_token_len
|
||||||
|
|||||||
@ -165,9 +165,10 @@ async def async_request_openai_completions(
|
|||||||
"max_tokens": request_func_input.output_len,
|
"max_tokens": request_func_input.output_len,
|
||||||
"logprobs": request_func_input.logprobs,
|
"logprobs": request_func_input.logprobs,
|
||||||
"stream": True,
|
"stream": True,
|
||||||
"stream_options": {
|
# NOTE(woosuk): Disabled for PD.
|
||||||
"include_usage": True,
|
# "stream_options": {
|
||||||
},
|
# "include_usage": True,
|
||||||
|
# },
|
||||||
}
|
}
|
||||||
_update_payload_common(payload, request_func_input)
|
_update_payload_common(payload, request_func_input)
|
||||||
|
|
||||||
|
|||||||
@ -386,6 +386,24 @@ async def get_server_load_metrics(request: Request):
|
|||||||
return JSONResponse(content={"server_load": request.app.state.server_load_metrics})
|
return JSONResponse(content={"server_load": request.app.state.server_load_metrics})
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/get_server_info")
|
||||||
|
async def get_server_info(raw_request: Request):
|
||||||
|
"""Returns server information including DP size for router"""
|
||||||
|
config = raw_request.app.state.vllm_config
|
||||||
|
|
||||||
|
# Extract dp_size from parallel_config
|
||||||
|
dp_size = 1 # Default value
|
||||||
|
if hasattr(config, 'parallel_config') and hasattr(config.parallel_config, 'data_parallel_size'):
|
||||||
|
dp_size = config.parallel_config.data_parallel_size
|
||||||
|
|
||||||
|
server_info = {
|
||||||
|
"vllm_config": str(config),
|
||||||
|
"dp_size": dp_size
|
||||||
|
}
|
||||||
|
return JSONResponse(content=server_info)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/ping", response_class=Response)
|
@router.get("/ping", response_class=Response)
|
||||||
@router.post("/ping", response_class=Response)
|
@router.post("/ping", response_class=Response)
|
||||||
async def ping(raw_request: Request) -> Response:
|
async def ping(raw_request: Request) -> Response:
|
||||||
|
|||||||
@ -264,6 +264,9 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
if raw_request:
|
if raw_request:
|
||||||
raw_request.state.request_metadata = request_metadata
|
raw_request.state.request_metadata = request_metadata
|
||||||
|
|
||||||
|
# Extract data_parallel_rank from header (router can inject it)
|
||||||
|
data_parallel_rank = self._get_data_parallel_rank(raw_request)
|
||||||
|
|
||||||
# Schedule the request and get the result generator.
|
# Schedule the request and get the result generator.
|
||||||
generators: list[AsyncGenerator[RequestOutput, None]] = []
|
generators: list[AsyncGenerator[RequestOutput, None]] = []
|
||||||
try:
|
try:
|
||||||
@ -331,6 +334,7 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
priority=request.priority,
|
priority=request.priority,
|
||||||
prompt_text=prompt_text,
|
prompt_text=prompt_text,
|
||||||
tokenization_kwargs=tokenization_kwargs,
|
tokenization_kwargs=tokenization_kwargs,
|
||||||
|
data_parallel_rank=data_parallel_rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
generators.append(generator)
|
generators.append(generator)
|
||||||
|
|||||||
@ -141,6 +141,10 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||||||
logger.exception("Error in preprocessing prompt inputs")
|
logger.exception("Error in preprocessing prompt inputs")
|
||||||
return self.create_error_response(str(e))
|
return self.create_error_response(str(e))
|
||||||
|
|
||||||
|
# Extract data_parallel_rank from header (router can inject it)
|
||||||
|
data_parallel_rank = self._get_data_parallel_rank(raw_request)
|
||||||
|
|
||||||
|
|
||||||
# Schedule the request and get the result generator.
|
# Schedule the request and get the result generator.
|
||||||
generators: list[AsyncGenerator[RequestOutput, None]] = []
|
generators: list[AsyncGenerator[RequestOutput, None]] = []
|
||||||
try:
|
try:
|
||||||
@ -224,6 +228,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||||||
priority=request.priority,
|
priority=request.priority,
|
||||||
prompt_text=prompt_text,
|
prompt_text=prompt_text,
|
||||||
tokenization_kwargs=tokenization_kwargs,
|
tokenization_kwargs=tokenization_kwargs,
|
||||||
|
data_parallel_rank=data_parallel_rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
generators.append(generator)
|
generators.append(generator)
|
||||||
|
|||||||
@ -1297,6 +1297,21 @@ class OpenAIServing:
|
|||||||
|
|
||||||
return raw_request.headers.get("X-Request-Id", default)
|
return raw_request.headers.get("X-Request-Id", default)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
|
||||||
|
"""Pulls the data parallel rank from a header, if provided"""
|
||||||
|
if raw_request is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
rank_str = raw_request.headers.get("X-data-parallel-rank")
|
||||||
|
if rank_str is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
return int(rank_str)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_decoded_token(
|
def _get_decoded_token(
|
||||||
logprob: Logprob,
|
logprob: Logprob,
|
||||||
|
|||||||
@ -36,9 +36,9 @@ def kernel_warmup(worker: "Worker"):
|
|||||||
max_tokens = worker.scheduler_config.max_num_batched_tokens
|
max_tokens = worker.scheduler_config.max_num_batched_tokens
|
||||||
deep_gemm_warmup(model, max_tokens)
|
deep_gemm_warmup(model, max_tokens)
|
||||||
|
|
||||||
# FlashInfer autotune for Hopper (SM 9.0) and Blackwell (SM 10.0) GPUs
|
# # FlashInfer autotune for Hopper (SM 9.0) and Blackwell (SM 10.0) GPUs
|
||||||
if has_flashinfer() and current_platform.has_device_capability(90):
|
# if has_flashinfer() and current_platform.has_device_capability(90):
|
||||||
flashinfer_autotune(worker.model_runner)
|
# flashinfer_autotune(worker.model_runner)
|
||||||
|
|
||||||
# FlashInfer attention warmup
|
# FlashInfer attention warmup
|
||||||
# Only warmup if the model has FlashInfer attention groups
|
# Only warmup if the model has FlashInfer attention groups
|
||||||
|
|||||||
@ -116,9 +116,14 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
|||||||
num_decode_tokens: int,
|
num_decode_tokens: int,
|
||||||
dcp_tot_seq_lens_device: torch.Tensor | None,
|
dcp_tot_seq_lens_device: torch.Tensor | None,
|
||||||
) -> FlashMLADecodeMetadata:
|
) -> FlashMLADecodeMetadata:
|
||||||
|
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||||
|
# we use the max but all should be the same due to uniform length requirement
|
||||||
|
max_query_len = query_lens_cpu.max().item()
|
||||||
|
num_q_tokens_per_head_k = max_query_len * self.num_q_heads // 1
|
||||||
|
|
||||||
tile_scheduler_metadata, num_splits = get_mla_metadata(
|
tile_scheduler_metadata, num_splits = get_mla_metadata(
|
||||||
seq_lens_device,
|
seq_lens_device,
|
||||||
self.num_q_heads,
|
num_q_tokens_per_head_k,
|
||||||
1, # MQA for the decode path
|
1, # MQA for the decode path
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -509,6 +509,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
pin_memory=self.pin_memory,
|
pin_memory=self.pin_memory,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# device_id = self.device.index
|
||||||
|
|
||||||
|
# def cb(_device, _alloc, _device_alloc, _device_free):
|
||||||
|
# torch.cuda.memory._dump_snapshot(f"/tmp/vllm_oom_{device_id}.pickle")
|
||||||
|
|
||||||
|
# torch.cuda.memory._record_memory_history(max_entries=100_000)
|
||||||
|
# torch._C._cuda_attach_out_of_memory_observer(cb)
|
||||||
|
|
||||||
def reset_mm_cache(self) -> None:
|
def reset_mm_cache(self) -> None:
|
||||||
if self.mm_budget:
|
if self.mm_budget:
|
||||||
self.mm_budget.reset_cache()
|
self.mm_budget.reset_cache()
|
||||||
|
|||||||
Reference in New Issue
Block a user