Compare commits
8 Commits
woosuk/rou...woosuk/tes
| Author | SHA1 | Date |
|---|---|---|
|  | cb439737db |  |
|  | a1cac48477 |  |
|  | 6102536d65 |  |
|  | f65da69c72 |  |
|  | a5281395e9 |  |
|  | eda71c2847 |  |
|  | 1bff9a59ec |  |
|  | 69c9a01538 |  |
@@ -12,4 +12,5 @@ torchvision==0.23.0 # Required for phi3v processor. See https://github.com/pytor
 # https://github.com/facebookresearch/xformers/releases/tag/v0.0.32.post1
 xformers==0.0.32.post1; platform_system == 'Linux' and platform_machine == 'x86_64'  # Requires PyTorch >= 2.8
 # FlashInfer should be updated together with the Dockerfile
-flashinfer-python==0.4.0
+flashinfer-python==0.4.0
+apache-tvm-ffi==0.1.0b15
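These pins are expected to move in lockstep (FlashInfer with the Dockerfile, xformers with the PyTorch floor). A small convenience sketch for checking that a local environment matches the pinned versions; the distribution names are taken from the hunk above, and this is not part of the repository itself:

```python
import importlib.metadata

# Distribution names and versions as they appear in the requirements hunk above.
PINS = {
    "xformers": "0.0.32.post1",
    "flashinfer-python": "0.4.0",
    "apache-tvm-ffi": "0.1.0b15",
}

for dist, expected in PINS.items():
    try:
        installed = importlib.metadata.version(dist)
    except importlib.metadata.PackageNotFoundError:
        print(f"{dist}: not installed (expected {expected})")
        continue
    status = "OK" if installed == expected else f"mismatch (expected {expected})"
    print(f"{dist}: {installed} {status}")
```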
@@ -75,7 +75,7 @@ class SampleRequest:
     Represents a single inference request for benchmarking.
     """
 
-    prompt: str | list[str]
+    prompt: str | list[str] | list[int]
     prompt_len: int
     expected_output_len: int
     multi_modal_data: MultiModalDataDict | dict | list[dict] | None = None
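The widened `prompt` type lets a benchmark request carry pre-tokenized input (raw token IDs) instead of text. A minimal sketch, assuming `SampleRequest` is the dataclass from the hunk above and using a Hugging Face tokenizer; the import path and model name are assumptions for illustration only:

```python
from transformers import AutoTokenizer

from vllm.benchmarks.datasets import SampleRequest  # import path assumed; adjust to where the dataclass lives

# Any HF tokenizer works here; gpt2 is just a small example.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

text = "The quick brown fox jumps over the lazy dog."
token_ids = tokenizer.encode(text)  # list[int]

# With the change above, the prompt field may hold raw token IDs,
# so the benchmark does not have to round-trip through text again.
request = SampleRequest(
    prompt=token_ids,
    prompt_len=len(token_ids),
    expected_output_len=128,
)
```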
@@ -402,8 +402,9 @@ def gen_prompt_decode_to_target_len(
     remain_num_try = max_retry
     token_mismatch = 0
     while True:
-        prompt = tokenizer.decode(token_sequence)
-        token_sequence = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
+        # prompt = tokenizer.decode(token_sequence)
+        # token_sequence = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
+        prompt = token_sequence
         if remain_num_try <= 0:
             if len(token_sequence) != target_token_len:
                 token_mismatch = len(token_sequence) - target_token_len
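The decode→encode round trip that this hunk comments out is exactly what the retry loop and `token_mismatch` accounting compensate for: re-encoding decoded text can merge or split tokens differently, so the length can drift from the target. Using the token IDs directly as the prompt sidesteps that. A tiny self-contained sketch of the drift (tokenizer and IDs are arbitrary examples):

```python
from transformers import AutoTokenizer

# Illustrative only: shows why decode -> encode is not always length-preserving.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

token_sequence = [9999, 12, 345, 678]          # arbitrary token IDs
prompt = tokenizer.decode(token_sequence)       # back to text
reencoded = tokenizer.encode(prompt, add_special_tokens=False)

# The re-encoded sequence may differ in length from the original IDs,
# which is the mismatch the original retry loop tried to repair.
print(len(token_sequence), len(reencoded))
```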
@@ -165,9 +165,10 @@ async def async_request_openai_completions(
         "max_tokens": request_func_input.output_len,
         "logprobs": request_func_input.logprobs,
         "stream": True,
-        "stream_options": {
-            "include_usage": True,
-        },
+        # NOTE(woosuk): Disabled for PD.
+        # "stream_options": {
+        #     "include_usage": True,
+        # },
     }
     _update_payload_common(payload, request_func_input)
 
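For context, the dict assembled here is the request body for an OpenAI-compatible `/v1/completions` endpoint; with `stream_options` commented out, the server still streams tokens but is no longer asked to append a final usage chunk. A minimal sketch of the resulting payload, with placeholder values standing in for the `request_func_input` fields:

```python
import json

payload = {
    "model": "my-model",       # placeholder for the benchmark's model name
    "prompt": "Hello, world",  # placeholder for request_func_input.prompt
    "max_tokens": 128,         # request_func_input.output_len
    "logprobs": None,          # request_func_input.logprobs
    "stream": True,
    # "stream_options": {"include_usage": True},  # disabled for PD, per the hunk above
}
print(json.dumps(payload, indent=2))
```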
@@ -36,9 +36,9 @@ def kernel_warmup(worker: "Worker"):
         max_tokens = worker.scheduler_config.max_num_batched_tokens
         deep_gemm_warmup(model, max_tokens)
 
-    # FlashInfer autotune for Hopper (SM 9.0) and Blackwell (SM 10.0) GPUs
-    if has_flashinfer() and current_platform.has_device_capability(90):
-        flashinfer_autotune(worker.model_runner)
+    # # FlashInfer autotune for Hopper (SM 9.0) and Blackwell (SM 10.0) GPUs
+    # if has_flashinfer() and current_platform.has_device_capability(90):
+    #     flashinfer_autotune(worker.model_runner)
 
     # FlashInfer attention warmup
     # Only warmup if the model has FlashInfer attention groups
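The block that gets commented out gates FlashInfer autotuning on the GPU's compute capability (SM 9.0 or newer). A standalone sketch of an equivalent check using plain PyTorch; the vLLM helpers `has_flashinfer` and `current_platform.has_device_capability` are replaced here with assumed stand-ins, not the real implementations:

```python
import importlib.util

import torch


def has_flashinfer() -> bool:
    # Stand-in for vLLM's helper: just check that the flashinfer package is importable.
    return importlib.util.find_spec("flashinfer") is not None


def has_device_capability(threshold: int) -> bool:
    # Capability encoded as major * 10 + minor: 90 -> SM 9.0 (Hopper), 100 -> SM 10.0 (Blackwell).
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor >= threshold


if torch.cuda.is_available() and has_flashinfer() and has_device_capability(90):
    # This is where the (now commented-out) flashinfer_autotune call would run.
    print("FlashInfer autotune would be triggered")
```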
@@ -116,9 +116,14 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
         num_decode_tokens: int,
         dcp_tot_seq_lens_device: torch.Tensor | None,
     ) -> FlashMLADecodeMetadata:
+        query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
+        # we use the max but all should be the same due to uniform length requirement
+        max_query_len = query_lens_cpu.max().item()
+        num_q_tokens_per_head_k = max_query_len * self.num_q_heads // 1
+
         tile_scheduler_metadata, num_splits = get_mla_metadata(
             seq_lens_device,
-            self.num_q_heads,
+            num_q_tokens_per_head_k,
             1,  # MQA for the decode path
         )
 
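The added lines derive per-request query lengths from the cumulative `query_start_loc` offsets and rescale the value passed to `get_mla_metadata`: with one KV head (MQA) in the decode path, each KV head serves `max_query_len * num_q_heads` query tokens, so the `// 1` is division by the KV-head count and is a no-op here. A small sketch of the arithmetic with made-up values:

```python
import torch

# Cumulative query offsets for 3 requests, each with 2 decode tokens (made-up values).
query_start_loc_cpu = torch.tensor([0, 2, 4, 6])
num_q_heads = 16
num_kv_heads = 1  # MQA in the decode path

query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]  # -> [2, 2, 2]
max_query_len = int(query_lens_cpu.max().item())                     # -> 2

# Query tokens handled per KV head; dividing by 1 KV head leaves it unchanged.
num_q_tokens_per_head_k = max_query_len * num_q_heads // num_kv_heads  # -> 32
print(num_q_tokens_per_head_k)
```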
@@ -509,6 +509,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             pin_memory=self.pin_memory,
         )
 
+        # device_id = self.device.index
+
+        # def cb(_device, _alloc, _device_alloc, _device_free):
+        #     torch.cuda.memory._dump_snapshot(f"/tmp/vllm_oom_{device_id}.pickle")
+
+        # torch.cuda.memory._record_memory_history(max_entries=100_000)
+        # torch._C._cuda_attach_out_of_memory_observer(cb)
+
     def reset_mm_cache(self) -> None:
         if self.mm_budget:
             self.mm_budget.reset_cache()
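The commented-out block wires up CUDA memory-history recording plus an out-of-memory observer that dumps a snapshot for later inspection with PyTorch's memory-viz tooling; both calls are private PyTorch APIs and may change between releases. A standalone sketch of the same debugging hook outside the model runner, with the dump path kept only as an example:

```python
import torch

device_id = torch.cuda.current_device()


def on_oom(_device, _alloc, _device_alloc, _device_free):
    # Dump the recorded allocation history when an OOM is observed.
    torch.cuda.memory._dump_snapshot(f"/tmp/vllm_oom_{device_id}.pickle")


# Private PyTorch APIs, mirroring the commented-out block in the hunk above.
torch.cuda.memory._record_memory_history(max_entries=100_000)
torch._C._cuda_attach_out_of_memory_observer(on_oom)
```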