Compare commits
1 Commits
khluu/use_
...
compile-ep
| Author | SHA1 | Date | |
|---|---|---|---|
| 787384dd4a |
@ -100,19 +100,21 @@ def test_models(
|
||||
else:
|
||||
hf_outputs = None
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "0")
|
||||
if model not in V0_UNSUPPORTED_MODELS:
|
||||
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
||||
vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
else:
|
||||
vllm_v0_outputs = None
|
||||
if model not in V0_UNSUPPORTED_MODELS:
|
||||
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
||||
vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
else:
|
||||
vllm_v0_outputs = None
|
||||
|
||||
if model in V1_SUPPORTED_MODELS:
|
||||
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
||||
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
with vllm_runner(model,
|
||||
max_num_seqs=MAX_NUM_SEQS,
|
||||
enable_prefix_caching=False) as vllm_model:
|
||||
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
else:
|
||||
vllm_v1_outputs = None
|
||||
|
||||
@ -135,7 +137,7 @@ def test_models(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
|
||||
@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
def test_batching(
|
||||
@ -145,6 +147,10 @@ def test_batching(
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
) -> None:
|
||||
if model in V0_UNSUPPORTED_MODELS:
|
||||
pytest.skip(
|
||||
f"Unsupported V0 Engine. Skipping `test_batching` on {model}.")
|
||||
|
||||
try:
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
@ -182,32 +188,29 @@ def test_chunked_prefill(
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
chunked_prefill_token_size: int,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
max_num_seqs = chunked_prefill_token_size
|
||||
max_num_batched_tokens = chunked_prefill_token_size
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "0")
|
||||
with vllm_runner(model,
|
||||
enable_chunked_prefill=True,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
max_num_seqs=max_num_seqs) as vllm_model:
|
||||
chunked = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
with vllm_runner(model,
|
||||
enable_chunked_prefill=True,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
max_num_seqs=max_num_seqs) as vllm_model:
|
||||
chunked = vllm_model.generate_greedy_logprobs(example_prompts,
|
||||
max_tokens, num_logprobs)
|
||||
|
||||
with vllm_runner(model,
|
||||
enable_chunked_prefill=False,
|
||||
max_num_seqs=max_num_seqs) as vllm_model:
|
||||
non_chunked = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
with vllm_runner(model,
|
||||
enable_chunked_prefill=False,
|
||||
max_num_seqs=max_num_seqs) as vllm_model:
|
||||
non_chunked = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=chunked,
|
||||
outputs_1_lst=non_chunked,
|
||||
name_0="chunked",
|
||||
name_1="non_chunked",
|
||||
)
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=chunked,
|
||||
outputs_1_lst=non_chunked,
|
||||
name_0="chunked",
|
||||
name_1="non_chunked",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
|
||||
@ -278,29 +281,25 @@ def test_models_preemption_recompute(
|
||||
example_prompts,
|
||||
model: str,
|
||||
max_tokens: int,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
"""
|
||||
Tests that outputs are identical with and w/o preemptions (recompute).
|
||||
"""
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "0")
|
||||
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
||||
scheduler = vllm_model.llm.llm_engine.scheduler[0]
|
||||
scheduler.ENABLE_ARTIFICIAL_PREEMPT = True
|
||||
preempt_vllm_outputs = vllm_model.generate_greedy(
|
||||
example_prompts, max_tokens)
|
||||
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
||||
scheduler = vllm_model.llm.llm_engine.scheduler[0]
|
||||
scheduler.ENABLE_ARTIFICIAL_PREEMPT = True
|
||||
preempt_vllm_outputs = vllm_model.generate_greedy(
|
||||
example_prompts, max_tokens)
|
||||
|
||||
scheduler.ENABLE_ARTIFICIAL_PREEMPT = False
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts,
|
||||
max_tokens)
|
||||
scheduler.ENABLE_ARTIFICIAL_PREEMPT = False
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
|
||||
check_outputs_equal(
|
||||
outputs_0_lst=preempt_vllm_outputs,
|
||||
outputs_1_lst=vllm_outputs,
|
||||
name_0="vllm_preepmtions",
|
||||
name_1="vllm",
|
||||
)
|
||||
check_outputs_equal(
|
||||
outputs_0_lst=preempt_vllm_outputs,
|
||||
outputs_1_lst=vllm_outputs,
|
||||
name_0="vllm_preepmtions",
|
||||
name_1="vllm",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
|
||||
@ -403,18 +402,24 @@ def test_full_cuda_graph(
|
||||
else:
|
||||
hf_outputs = None
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "0")
|
||||
if model not in V0_UNSUPPORTED_MODELS:
|
||||
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
||||
vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
else:
|
||||
vllm_v0_outputs = None
|
||||
if model not in V0_UNSUPPORTED_MODELS:
|
||||
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
||||
vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
else:
|
||||
vllm_v0_outputs = None
|
||||
|
||||
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
||||
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
if model in HYBRID_MODELS:
|
||||
# required due to reorder_batch behaviour
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
|
||||
with vllm_runner(model,
|
||||
max_num_seqs=MAX_NUM_SEQS,
|
||||
compilation_config={'full_cuda_graph': True},
|
||||
enable_prefix_caching=False) as vllm_model:
|
||||
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
|
||||
if hf_outputs is not None and vllm_v0_outputs is not None:
|
||||
check_logprobs_close(
|
||||
@ -461,20 +466,24 @@ def test_fp32_state(
|
||||
else:
|
||||
hf_outputs = None
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "0")
|
||||
with vllm_runner(model,
|
||||
max_num_seqs=MAX_NUM_SEQS,
|
||||
mamba_ssm_cache_dtype="float32") as vllm_model:
|
||||
vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
|
||||
with vllm_runner(model,
|
||||
max_num_seqs=MAX_NUM_SEQS,
|
||||
mamba_ssm_cache_dtype="float32") as vllm_model:
|
||||
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
|
||||
vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
if model in HYBRID_MODELS:
|
||||
# required due to reorder_batch behaviour
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
|
||||
with vllm_runner(model,
|
||||
max_num_seqs=MAX_NUM_SEQS,
|
||||
mamba_ssm_cache_dtype="float32",
|
||||
enable_prefix_caching=False) as vllm_model:
|
||||
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
|
||||
if hf_outputs is not None:
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
|
||||
@ -255,7 +255,7 @@ class DeviceCommunicatorBase:
|
||||
if module.__class__.__name__ == "FusedMoE"
|
||||
]
|
||||
for module in moe_modules:
|
||||
module.quant_method.init_prepare_finalize(module)
|
||||
module.quant_method.init_prepare_finalize()
|
||||
|
||||
def dispatch(
|
||||
self, hidden_states: torch.Tensor,
|
||||
|
||||
@ -1463,6 +1463,11 @@ class EngineArgs:
|
||||
recommend_to_remove=False)
|
||||
return False
|
||||
|
||||
# V1 mamba models are unoptimized.
|
||||
if model_config.has_inner_state and _warn_or_fallback(
|
||||
feature_name="Mamba"):
|
||||
return False
|
||||
|
||||
# No Concurrent Partial Prefills so far.
|
||||
if (self.max_num_partial_prefills
|
||||
!= SchedulerConfig.max_num_partial_prefills
|
||||
|
||||
@ -166,7 +166,6 @@ if TYPE_CHECKING:
|
||||
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
|
||||
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
|
||||
VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
|
||||
VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
|
||||
|
||||
|
||||
def get_default_cache_root():
|
||||
@ -1145,12 +1144,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_ENABLE_CUDAGRAPH_GC":
|
||||
lambda: bool(int(os.getenv("VLLM_ENABLE_CUDAGRAPH_GC", "0"))),
|
||||
|
||||
# Disable padding to CUDA graph capture batch sizes.
|
||||
# TODO(wentao): https://github.com/vllm-project/vllm/issues/23378
|
||||
# After the issue is fixed, we can remove this flag.
|
||||
"VLLM_DISABLE_PAD_FOR_CUDAGRAPH":
|
||||
lambda: bool(int(os.getenv("VLLM_DISABLE_PAD_FOR_CUDAGRAPH", "0"))),
|
||||
|
||||
# Used to force set up loopback IP
|
||||
"VLLM_LOOPBACK_IP":
|
||||
lambda: os.getenv("VLLM_LOOPBACK_IP", ""),
|
||||
|
||||
@ -257,8 +257,6 @@ class InputPreprocessor:
|
||||
mm_processor_kwargs: Optional[Mapping[str, object]],
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
*,
|
||||
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
|
||||
) -> MultiModalInputs:
|
||||
"""
|
||||
Apply the model's multi-modal processor to a multi-modal prompt,
|
||||
@ -275,13 +273,10 @@ class InputPreprocessor:
|
||||
if mm_processor_kwargs is None:
|
||||
mm_processor_kwargs = {}
|
||||
|
||||
return mm_processor.apply(
|
||||
prompt,
|
||||
mm_data,
|
||||
hf_processor_mm_kwargs=mm_processor_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_hash_overrides=mm_hash_overrides,
|
||||
)
|
||||
return mm_processor.apply(prompt,
|
||||
mm_data,
|
||||
hf_processor_mm_kwargs=mm_processor_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs)
|
||||
|
||||
async def _process_multimodal_async(
|
||||
self,
|
||||
@ -290,8 +285,6 @@ class InputPreprocessor:
|
||||
mm_processor_kwargs: Optional[Mapping[str, object]],
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
*,
|
||||
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
|
||||
) -> MultiModalInputs:
|
||||
"""
|
||||
Async version of
|
||||
@ -308,13 +301,10 @@ class InputPreprocessor:
|
||||
if mm_processor_kwargs is None:
|
||||
mm_processor_kwargs = {}
|
||||
|
||||
return mm_processor.apply(
|
||||
prompt,
|
||||
mm_data,
|
||||
hf_processor_mm_kwargs=mm_processor_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_hash_overrides=mm_hash_overrides,
|
||||
)
|
||||
return mm_processor.apply(prompt,
|
||||
mm_data,
|
||||
hf_processor_mm_kwargs=mm_processor_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs)
|
||||
|
||||
def _process_embeds(
|
||||
self,
|
||||
@ -351,8 +341,6 @@ class InputPreprocessor:
|
||||
parsed_content: TokensPrompt,
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
*,
|
||||
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
|
||||
) -> Union[TokenInputs, MultiModalInputs]:
|
||||
prompt_token_ids = parsed_content["prompt_token_ids"]
|
||||
token_type_ids = parsed_content.get("token_type_ids")
|
||||
@ -365,7 +353,6 @@ class InputPreprocessor:
|
||||
parsed_content.get("mm_processor_kwargs"),
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
lora_request=lora_request,
|
||||
mm_hash_overrides=mm_hash_overrides,
|
||||
)
|
||||
else:
|
||||
inputs = token_inputs(
|
||||
@ -383,8 +370,6 @@ class InputPreprocessor:
|
||||
parsed_content: TokensPrompt,
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
*,
|
||||
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
|
||||
) -> Union[TokenInputs, MultiModalInputs]:
|
||||
prompt_token_ids = parsed_content["prompt_token_ids"]
|
||||
token_type_ids = parsed_content.get("token_type_ids")
|
||||
@ -397,7 +382,6 @@ class InputPreprocessor:
|
||||
parsed_content.get("mm_processor_kwargs"),
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
lora_request=lora_request,
|
||||
mm_hash_overrides=mm_hash_overrides,
|
||||
)
|
||||
else:
|
||||
inputs = token_inputs(
|
||||
@ -415,8 +399,6 @@ class InputPreprocessor:
|
||||
parsed_content: TextPrompt,
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
*,
|
||||
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
|
||||
) -> Union[TokenInputs, MultiModalInputs]:
|
||||
prompt_text = parsed_content["prompt"]
|
||||
|
||||
@ -428,7 +410,6 @@ class InputPreprocessor:
|
||||
parsed_content.get("mm_processor_kwargs"),
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
lora_request=lora_request,
|
||||
mm_hash_overrides=mm_hash_overrides,
|
||||
)
|
||||
else:
|
||||
prompt_token_ids = self._tokenize_prompt(
|
||||
@ -451,8 +432,6 @@ class InputPreprocessor:
|
||||
parsed_content: TextPrompt,
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
*,
|
||||
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
|
||||
) -> Union[TokenInputs, MultiModalInputs]:
|
||||
prompt_text = parsed_content["prompt"]
|
||||
|
||||
@ -464,7 +443,6 @@ class InputPreprocessor:
|
||||
parsed_content.get("mm_processor_kwargs"),
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
lora_request=lora_request,
|
||||
mm_hash_overrides=mm_hash_overrides,
|
||||
)
|
||||
else:
|
||||
prompt_token_ids = await self._tokenize_prompt_async(
|
||||
@ -487,8 +465,6 @@ class InputPreprocessor:
|
||||
prompt: SingletonPrompt,
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
*,
|
||||
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
|
||||
) -> SingletonInputs:
|
||||
"""
|
||||
Extract the singleton inputs from a prompt.
|
||||
@ -510,21 +486,18 @@ class InputPreprocessor:
|
||||
return self._process_tokens(
|
||||
parsed["content"],
|
||||
lora_request=lora_request,
|
||||
mm_hash_overrides=mm_hash_overrides,
|
||||
)
|
||||
if parsed["type"] == "text":
|
||||
return self._process_text(
|
||||
parsed["content"],
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
lora_request=lora_request,
|
||||
mm_hash_overrides=mm_hash_overrides,
|
||||
)
|
||||
if parsed["type"] == "str":
|
||||
return self._process_text(
|
||||
TextPrompt(prompt=parsed["content"]),
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
lora_request=lora_request,
|
||||
mm_hash_overrides=mm_hash_overrides,
|
||||
)
|
||||
|
||||
assert_never(parsed)
|
||||
@ -534,8 +507,6 @@ class InputPreprocessor:
|
||||
prompt: SingletonPrompt,
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
*,
|
||||
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
|
||||
) -> SingletonInputs:
|
||||
"""
|
||||
Async version of
|
||||
@ -549,21 +520,18 @@ class InputPreprocessor:
|
||||
return await self._process_tokens_async(
|
||||
parsed["content"],
|
||||
lora_request=lora_request,
|
||||
mm_hash_overrides=mm_hash_overrides,
|
||||
)
|
||||
if parsed["type"] == "text":
|
||||
return await self._process_text_async(
|
||||
parsed["content"],
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
lora_request=lora_request,
|
||||
mm_hash_overrides=mm_hash_overrides,
|
||||
)
|
||||
if parsed["type"] == "str":
|
||||
return await self._process_text_async(
|
||||
TextPrompt(prompt=parsed["content"]),
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
lora_request=lora_request,
|
||||
mm_hash_overrides=mm_hash_overrides,
|
||||
)
|
||||
|
||||
assert_never(parsed)
|
||||
@ -673,8 +641,6 @@ class InputPreprocessor:
|
||||
self,
|
||||
prompt: PromptType,
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
*,
|
||||
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
|
||||
) -> EncoderDecoderInputs:
|
||||
"""
|
||||
For encoder/decoder models only:
|
||||
@ -716,7 +682,6 @@ class InputPreprocessor:
|
||||
encoder_inputs = self._prompt_to_llm_inputs(
|
||||
prompt["encoder_prompt"],
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_hash_overrides=mm_hash_overrides,
|
||||
)
|
||||
if (decoder_input := prompt["decoder_prompt"]) is None:
|
||||
decoder_inputs = None
|
||||
@ -732,7 +697,6 @@ class InputPreprocessor:
|
||||
inputs = self._prompt_to_llm_inputs(
|
||||
prompt,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_hash_overrides=mm_hash_overrides,
|
||||
)
|
||||
if self.model_config.is_multimodal_model:
|
||||
# Encoder-Decoder Multimodal model
|
||||
@ -748,8 +712,6 @@ class InputPreprocessor:
|
||||
self,
|
||||
prompt: PromptType,
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
*,
|
||||
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
|
||||
) -> EncoderDecoderInputs:
|
||||
"""
|
||||
Async version of
|
||||
@ -762,7 +724,6 @@ class InputPreprocessor:
|
||||
encoder_task = self._prompt_to_llm_inputs_async(
|
||||
prompt["encoder_prompt"],
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_hash_overrides=mm_hash_overrides,
|
||||
)
|
||||
|
||||
if (decoder_input := prompt["decoder_prompt"]) is None:
|
||||
@ -772,7 +733,6 @@ class InputPreprocessor:
|
||||
decoder_task = self._prompt_to_llm_inputs_async(
|
||||
decoder_input,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_hash_overrides=mm_hash_overrides,
|
||||
)
|
||||
|
||||
encoder_inputs, decoder_inputs = await asyncio.gather(
|
||||
@ -788,7 +748,6 @@ class InputPreprocessor:
|
||||
inputs = await self._prompt_to_llm_inputs_async(
|
||||
prompt,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_hash_overrides=mm_hash_overrides,
|
||||
)
|
||||
if self.model_config.is_multimodal_model:
|
||||
# Encoder-Decoder Multimodal model
|
||||
@ -815,8 +774,6 @@ class InputPreprocessor:
|
||||
prompt: SingletonPrompt,
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
*,
|
||||
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
|
||||
) -> DecoderOnlyInputs:
|
||||
"""
|
||||
For decoder-only models:
|
||||
@ -837,7 +794,6 @@ class InputPreprocessor:
|
||||
prompt,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
lora_request=lora_request,
|
||||
mm_hash_overrides=mm_hash_overrides,
|
||||
)
|
||||
|
||||
return self._build_decoder_only_llm_inputs(prompt_comps)
|
||||
@ -847,8 +803,6 @@ class InputPreprocessor:
|
||||
prompt: SingletonPrompt,
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
*,
|
||||
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
|
||||
) -> DecoderOnlyInputs:
|
||||
"""
|
||||
Async version of
|
||||
@ -858,7 +812,6 @@ class InputPreprocessor:
|
||||
prompt,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
lora_request=lora_request,
|
||||
mm_hash_overrides=mm_hash_overrides,
|
||||
)
|
||||
|
||||
return self._build_decoder_only_llm_inputs(prompt_comps)
|
||||
@ -868,8 +821,6 @@ class InputPreprocessor:
|
||||
prompt: PromptType,
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
*,
|
||||
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
|
||||
) -> ProcessorInputs:
|
||||
"""Preprocess the input prompt."""
|
||||
if self.model_config.is_encoder_decoder:
|
||||
@ -878,7 +829,6 @@ class InputPreprocessor:
|
||||
return self._process_encoder_decoder_prompt(
|
||||
prompt,
|
||||
tokenization_kwargs,
|
||||
mm_hash_overrides=mm_hash_overrides,
|
||||
)
|
||||
|
||||
if is_explicit_encoder_decoder_prompt(prompt):
|
||||
@ -890,7 +840,6 @@ class InputPreprocessor:
|
||||
prompt,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
lora_request=lora_request,
|
||||
mm_hash_overrides=mm_hash_overrides,
|
||||
)
|
||||
|
||||
async def preprocess_async(
|
||||
@ -898,8 +847,6 @@ class InputPreprocessor:
|
||||
prompt: PromptType,
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
*,
|
||||
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
|
||||
) -> ProcessorInputs:
|
||||
"""
|
||||
Async version of
|
||||
@ -911,7 +858,6 @@ class InputPreprocessor:
|
||||
return await self._process_encoder_decoder_prompt_async(
|
||||
prompt,
|
||||
tokenization_kwargs,
|
||||
mm_hash_overrides=mm_hash_overrides,
|
||||
)
|
||||
|
||||
if is_explicit_encoder_decoder_prompt(prompt):
|
||||
@ -923,7 +869,6 @@ class InputPreprocessor:
|
||||
prompt,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
lora_request=lora_request,
|
||||
mm_hash_overrides=mm_hash_overrides,
|
||||
)
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
|
||||
@ -450,12 +450,6 @@ class FusedMoEConfig:
|
||||
if quant_dtype is None and isinstance(quant_config, Fp8Config):
|
||||
quant_dtype = torch.float8_e4m3fn
|
||||
|
||||
from vllm.model_executor.layers.quantization.mxfp4 import (
|
||||
Mxfp4Config)
|
||||
if (quant_dtype is None and isinstance(quant_config, Mxfp4Config)
|
||||
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8):
|
||||
quant_dtype = "mxfp8"
|
||||
|
||||
from vllm.model_executor.layers.quantization.modelopt import (
|
||||
ModelOptNvFp4Config)
|
||||
if quant_dtype is None and isinstance(quant_config,
|
||||
|
||||
@ -200,7 +200,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
|
||||
# Note: init_prepare_finalize should only be called by
|
||||
# prepare_communication_buffer_for_model.
|
||||
def init_prepare_finalize(self, layer: torch.nn.Module):
|
||||
def init_prepare_finalize(self):
|
||||
assert self.moe is not None
|
||||
prepare_finalize = self.maybe_make_prepare_finalize(self.moe)
|
||||
|
||||
@ -211,7 +211,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
assert self.fused_experts is None, \
|
||||
f"Attempt to override experts for {id(self)}!"
|
||||
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
|
||||
experts = self.select_gemm_impl(prepare_finalize, self.moe, layer)
|
||||
experts = self.select_gemm_impl(prepare_finalize, self.moe)
|
||||
self.fused_experts = FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
experts,
|
||||
@ -221,7 +221,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
moe: FusedMoEConfig,
|
||||
layer: torch.nn.Module,
|
||||
) -> FusedMoEPermuteExpertsUnpermute:
|
||||
# based on the all2all implementation, select the appropriate
|
||||
# gemm implementation
|
||||
@ -274,7 +273,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
# TODO(bnell): Remove. Every layer should have an moe config object.
|
||||
moe: FusedMoEConfig,
|
||||
layer: torch.nn.Module,
|
||||
) -> FusedMoEPermuteExpertsUnpermute:
|
||||
if (prepare_finalize.activation_format ==
|
||||
FusedMoEActivationFormat.BatchedExperts):
|
||||
@ -1403,6 +1401,66 @@ class FusedMoE(CustomOp):
|
||||
self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx]
|
||||
self.logical_replica_count = logical_replica_count[moe_layer_idx]
|
||||
|
||||
@staticmethod
|
||||
@torch.compile(dynamic=True,
|
||||
backend=current_platform.simple_compile_backend)
|
||||
def handle_eplb(
|
||||
topk_ids: torch.Tensor,
|
||||
logical_replica_count: torch.Tensor,
|
||||
logical_to_physical_map: torch.Tensor,
|
||||
expert_load_view: torch.Tensor,
|
||||
indices_type: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
# 1. Convert the logical expert ids to physical expert ids
|
||||
# Directly select a random replica for each logical expert
|
||||
|
||||
# TODO: maybe optimize this by using specified kernels,
|
||||
# or compute pseudo-random indices by modulo
|
||||
|
||||
# In case `indices_type` is not `torch.long` or `torch.int`,
|
||||
# e.g. `torch.uint32` as required by dispatch/combine kernels
|
||||
topk_ids_long = topk_ids.long()
|
||||
replica_indices = (
|
||||
torch.rand_like(topk_ids, dtype=torch.float) *
|
||||
logical_replica_count[topk_ids_long]).long().unsqueeze(-1)
|
||||
physical_ids = logical_to_physical_map[topk_ids_long].gather(
|
||||
-1, replica_indices).squeeze(-1)
|
||||
|
||||
topk_ids = physical_ids
|
||||
|
||||
# 2. Record expert load metrics.
|
||||
|
||||
# TODO(bowen): When using `FusedMoEModularKernel`, this
|
||||
# can be done in a more unified way, since
|
||||
# `FusedMoEPrepareAndFinalize` will return the expert
|
||||
# token count, in some cases directly from the kernel.
|
||||
# However, now there are many code paths not using
|
||||
# the modular kernel, e.g. calling `fused_experts`,
|
||||
# so we decide to keep the logic here.
|
||||
#
|
||||
# If a later refactor moves all the MoE kernel calls
|
||||
# to the modular kernel, we can move this logic there
|
||||
# to achieve better efficiency.
|
||||
|
||||
# `expert_load_view`: (num_physical_experts,)
|
||||
|
||||
topk_ids_flatten = topk_ids.flatten()
|
||||
|
||||
# Performance optimization:
|
||||
# `masked_fill` is significantly faster than `masked_select`
|
||||
invalid_mask = topk_ids_flatten < 0
|
||||
# Replace invalid expert ids with 0 (just a dummy position)
|
||||
# to avoid out-of-bounds errors in scatter_add_
|
||||
index = topk_ids_flatten.masked_fill_(invalid_mask, 0)
|
||||
# `src` is the valid mask, which is 1 for valid and 0 for invalid
|
||||
src = ~invalid_mask
|
||||
|
||||
expert_load_view.scatter_add_(dim=0,
|
||||
index=index.long(),
|
||||
src=src.to(expert_load_view))
|
||||
|
||||
return topk_ids.to(dtype=indices_type)
|
||||
|
||||
@staticmethod
|
||||
def select_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
@ -1482,56 +1540,12 @@ class FusedMoE(CustomOp):
|
||||
assert expert_load_view is not None
|
||||
assert logical_to_physical_map is not None
|
||||
assert logical_replica_count is not None
|
||||
|
||||
# 1. Convert the logical expert ids to physical expert ids
|
||||
# Directly select a random replica for each logical expert
|
||||
|
||||
# TODO: maybe optimize this by using specified kernels,
|
||||
# or compute pseudo-random indices by modulo
|
||||
|
||||
# In case `indices_type` is not `torch.long` or `torch.int`,
|
||||
# e.g. `torch.uint32` as required by dispatch/combine kernels
|
||||
topk_ids_long = topk_ids.long()
|
||||
replica_indices = (
|
||||
torch.rand_like(topk_ids, dtype=torch.float) *
|
||||
logical_replica_count[topk_ids_long]).long().unsqueeze(-1)
|
||||
physical_ids = logical_to_physical_map[topk_ids_long].gather(
|
||||
-1, replica_indices).squeeze(-1)
|
||||
|
||||
topk_ids = physical_ids
|
||||
|
||||
# 2. Record expert load metrics.
|
||||
|
||||
# TODO(bowen): When using `FusedMoEModularKernel`, this
|
||||
# can be done in a more unified way, since
|
||||
# `FusedMoEPrepareAndFinalize` will return the expert
|
||||
# token count, in some cases directly from the kernel.
|
||||
# However, now there are many code paths not using
|
||||
# the modular kernel, e.g. calling `fused_experts`,
|
||||
# so we decide to keep the logic here.
|
||||
#
|
||||
# If a later refactor moves all the MoE kernel calls
|
||||
# to the modular kernel, we can move this logic there
|
||||
# to achieve better efficiency.
|
||||
|
||||
# `expert_load_view`: (num_physical_experts,)
|
||||
|
||||
topk_ids_flatten = topk_ids.flatten()
|
||||
|
||||
# Performance optimization:
|
||||
# `masked_fill` is significantly faster than `masked_select`
|
||||
invalid_mask = topk_ids_flatten < 0
|
||||
# Replace invalid expert ids with 0 (just a dummy position)
|
||||
# to avoid out-of-bounds errors in scatter_add_
|
||||
index = topk_ids_flatten.masked_fill_(invalid_mask, 0)
|
||||
# `src` is the valid mask, which is 1 for valid and 0 for invalid
|
||||
src = ~invalid_mask
|
||||
|
||||
expert_load_view.scatter_add_(dim=0,
|
||||
index=index.long(),
|
||||
src=src.to(expert_load_view))
|
||||
|
||||
topk_ids = topk_ids.to(dtype=indices_type)
|
||||
topk_ids = FusedMoE.handle_eplb(
|
||||
topk_ids=topk_ids,
|
||||
logical_replica_count=logical_replica_count,
|
||||
logical_to_physical_map=logical_to_physical_map,
|
||||
expert_load_view=expert_load_view,
|
||||
indices_type=indices_type)
|
||||
|
||||
assert topk_ids.dtype == indices_type or indices_type is None
|
||||
|
||||
|
||||
@ -1,197 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceNoOP)
|
||||
from vllm.utils import next_power_of_2
|
||||
|
||||
|
||||
class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
moe: FusedMoEConfig,
|
||||
gemm1_alpha,
|
||||
gemm1_beta,
|
||||
gemm1_clamp_limit,
|
||||
w13_bias,
|
||||
w2_bias,
|
||||
max_capture_size,
|
||||
):
|
||||
super().__init__(moe.quant_config)
|
||||
self.moe = moe
|
||||
self.gemm1_alpha = gemm1_alpha
|
||||
self.gemm1_beta = gemm1_beta
|
||||
self.gemm1_clamp_limit = gemm1_clamp_limit
|
||||
self.w13_bias = w13_bias
|
||||
self.w2_bias = w2_bias
|
||||
self.max_capture_size = max_capture_size
|
||||
|
||||
@property
|
||||
def activation_formats(
|
||||
self
|
||||
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
|
||||
return (mk.FusedMoEActivationFormat.Standard,
|
||||
mk.FusedMoEActivationFormat.Standard)
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return True
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return True
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
return TopKWeightAndReduceNoOP()
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
aq: torch.Tensor,
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
topk: int,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||
# The workspaces for this implementation are managed by flashinfer.
|
||||
# TODO(varun) : workspace1 could be used as the output tensor. This
|
||||
# is error-prone. Allow the `workspace_shapes` to return None workspaces
|
||||
workspace1 = (M, K)
|
||||
workspace2 = (0, 0)
|
||||
output = (M, K)
|
||||
return (workspace1, workspace2, output, a.dtype)
|
||||
|
||||
def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int,
|
||||
local_num_experts: int):
|
||||
# Number of tokens in the input tensor.
|
||||
num_tokens = x.shape[0]
|
||||
# Factor to account for the imbalance of the experts.
|
||||
# factor equals to the
|
||||
# max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
|
||||
# 1.0 means perfect expert distribution.
|
||||
# > 1.0 means some experts have more tokens than the perfect
|
||||
# distribution.
|
||||
# < 1.0 does not make sense.
|
||||
imbalance_factor = 1.3
|
||||
# Calculate the number of tokens per expert assuming perfect
|
||||
# distribution.
|
||||
num_tokens_per_expert = (num_tokens * top_k) // local_num_experts
|
||||
# Apply the imbalance factor.
|
||||
num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
|
||||
# And pad the number to the next power of 2.
|
||||
tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
|
||||
# Cap to 8-64 tokens per CTA tile as it's the range supported by the
|
||||
# kernel.
|
||||
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
|
||||
|
||||
return tile_tokens_dim
|
||||
|
||||
def apply(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
topk = topk_ids.size(-1)
|
||||
local_num_experts = w1.size(0)
|
||||
intermediate_size = w2.size(1)
|
||||
local_expert_offset = self.moe.ep_rank * local_num_experts
|
||||
|
||||
x_quant = hidden_states
|
||||
x_scale = a1q_scale
|
||||
if x_scale is not None:
|
||||
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(
|
||||
*x_quant.shape[:-1], -1)
|
||||
|
||||
packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
|
||||
torch.bfloat16).view(torch.int16)
|
||||
|
||||
assert w1_scale is not None
|
||||
assert w2_scale is not None
|
||||
kwargs = {
|
||||
"topk_ids":
|
||||
packed_tensor,
|
||||
"routing_bias":
|
||||
None,
|
||||
"hidden_states":
|
||||
x_quant,
|
||||
"hidden_states_scale":
|
||||
x_scale,
|
||||
"gemm1_weights":
|
||||
w1,
|
||||
"gemm1_weights_scale":
|
||||
w1_scale,
|
||||
"gemm1_bias":
|
||||
self.w13_bias,
|
||||
"gemm1_alpha":
|
||||
self.gemm1_alpha,
|
||||
"gemm1_beta":
|
||||
self.gemm1_beta,
|
||||
"gemm1_clamp_limit":
|
||||
self.gemm1_clamp_limit,
|
||||
"gemm2_weights":
|
||||
w2,
|
||||
"gemm2_weights_scale":
|
||||
w2_scale,
|
||||
"gemm2_bias":
|
||||
self.w2_bias,
|
||||
"output1_scale_scalar":
|
||||
None,
|
||||
"output1_scale_gate_scalar":
|
||||
None,
|
||||
"output2_scale_scalar":
|
||||
None,
|
||||
"num_experts":
|
||||
global_num_experts,
|
||||
"top_k":
|
||||
topk,
|
||||
"n_group":
|
||||
None,
|
||||
"topk_group":
|
||||
None,
|
||||
"intermediate_size":
|
||||
intermediate_size,
|
||||
"local_expert_offset":
|
||||
local_expert_offset,
|
||||
"local_num_experts":
|
||||
local_num_experts,
|
||||
"routed_scaling_factor":
|
||||
None,
|
||||
"tile_tokens_dim":
|
||||
self._get_tile_tokens_dim(x_quant, topk, local_num_experts),
|
||||
"routing_method_type":
|
||||
1,
|
||||
"do_finalize":
|
||||
True,
|
||||
"output":
|
||||
output,
|
||||
"tune_max_num_tokens":
|
||||
self.max_capture_size,
|
||||
}
|
||||
|
||||
from flashinfer import trtllm_fp4_block_scale_routed_moe
|
||||
trtllm_fp4_block_scale_routed_moe(**kwargs)
|
||||
return output
|
||||
@ -12,8 +12,6 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
||||
per_token_group_quant_int8, per_token_quant_int8)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||
quant_dequant_mxfp4)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
|
||||
mxfp8_quantize)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils import cdiv
|
||||
@ -179,18 +177,6 @@ def _mxfp4_quantize(
|
||||
return A, None
|
||||
|
||||
|
||||
def _mxfp8_quantize(
|
||||
A: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
per_act_token_quant: bool,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert A_scale is None
|
||||
assert not per_act_token_quant
|
||||
assert block_shape is None
|
||||
return mxfp8_quantize(A)
|
||||
|
||||
|
||||
def moe_kernel_quantize_input(
|
||||
A: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
@ -209,8 +195,6 @@ def moe_kernel_quantize_input(
|
||||
is_sf_swizzled_layout=is_fp4_scale_swizzled)
|
||||
elif quant_dtype == "mxfp4":
|
||||
return _mxfp4_quantize(A, A_scale, per_act_token_quant, block_shape)
|
||||
elif quant_dtype == "mxfp8":
|
||||
return _mxfp8_quantize(A, A_scale, per_act_token_quant, block_shape)
|
||||
else:
|
||||
return A, A_scale
|
||||
|
||||
|
||||
@ -322,7 +322,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
moe: FusedMoEConfig,
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
"""Return the appropriate GEMM experts implementation."""
|
||||
experts = select_nvfp4_gemm_impl(
|
||||
@ -720,9 +719,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
dtype=torch.int64)
|
||||
|
||||
def select_gemm_impl(
|
||||
self, prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
moe: FusedMoEConfig,
|
||||
layer: torch.nn.Module) -> FusedMoEPermuteExpertsUnpermute:
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
moe: FusedMoEConfig,
|
||||
) -> FusedMoEPermuteExpertsUnpermute:
|
||||
# cutlass path
|
||||
if self.use_cutlass:
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
|
||||
@ -897,7 +897,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
moe: FusedMoEConfig,
|
||||
layer: torch.nn.Module,
|
||||
) -> FusedMoEPermuteExpertsUnpermute:
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
BatchedTritonOrDeepGemmExperts, TritonOrDeepGemmExperts)
|
||||
|
||||
@ -311,7 +311,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
moe: FusedMoEConfig,
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
experts = select_cutlass_fp8_gemm_impl(
|
||||
moe,
|
||||
@ -891,11 +890,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
||||
assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
|
||||
"Weight Block scale must be represented as FP8-E4M3")
|
||||
|
||||
if self.backend == "marlin":
|
||||
prepare_fp4_layer_for_marlin(layer)
|
||||
del layer.alpha
|
||||
del layer.input_scale
|
||||
elif self.backend == "flashinfer-trtllm":
|
||||
if self.backend == "flashinfer-trtllm":
|
||||
# FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
|
||||
# FlashInfer provides nvfp4_quantize to quantize + shuffle the
|
||||
# layout but we use our own quantization so we have to call
|
||||
@ -920,6 +915,11 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
||||
requires_grad=False)
|
||||
layer.weight = Parameter(layer.weight.data, requires_grad=False)
|
||||
|
||||
if self.backend == "marlin":
|
||||
prepare_fp4_layer_for_marlin(layer)
|
||||
del layer.alpha
|
||||
del layer.input_scale
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@ -1032,7 +1032,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
moe: FusedMoEConfig,
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
experts = select_nvfp4_gemm_impl(
|
||||
moe,
|
||||
@ -1311,13 +1310,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
del layer.w2_weight_scale
|
||||
del layer.w13_weight
|
||||
del layer.w13_weight_scale
|
||||
elif self.use_marlin:
|
||||
# Marlin processing
|
||||
prepare_moe_fp4_layer_for_marlin(layer)
|
||||
del layer.g1_alphas
|
||||
del layer.g2_alphas
|
||||
del layer.w13_input_scale_quant
|
||||
del layer.w2_input_scale_quant
|
||||
else:
|
||||
# Non-TRT-LLM processing (Cutlass or non-flashinfer)
|
||||
assert (layer.w13_weight_scale.shape[2] % 16 == 0), (
|
||||
@ -1339,6 +1331,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
layer.w2_weight = Parameter(layer.w2_weight.data,
|
||||
requires_grad=False)
|
||||
|
||||
if self.use_marlin:
|
||||
prepare_moe_fp4_layer_for_marlin(layer)
|
||||
del layer.g1_alphas
|
||||
del layer.g2_alphas
|
||||
del layer.w13_input_scale_quant
|
||||
del layer.w2_input_scale_quant
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
|
||||
@ -10,8 +10,6 @@ from vllm.config import get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
|
||||
FusedMoEMethodBase)
|
||||
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
@ -447,91 +445,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
return tile_tokens_dim
|
||||
|
||||
def select_gemm_impl(
    self,
    prepare_finalize: mk.FusedMoEPrepareAndFinalize,
    moe: FusedMoEConfig,
    layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
    """Return the experts implementation for the modular MoE kernel.

    Args:
        prepare_finalize: Only its `activation_format` is inspected.
        moe: MoE configuration handed to the experts implementation.
        layer: Layer providing `gemm1_alpha`, `gemm1_beta`,
            `gemm1_clamp_limit`, `w13_bias` and `w2_bias`.

    Returns:
        A `TrtLlmGenExperts` instance when `should_use_flashinfer_mxfp4()`
        is true.

    Raises:
        NotImplementedError: For the batched-experts activation format, or
            when the FlashInfer MXFP4 path is not in use.
    """
    if (prepare_finalize.activation_format ==
            mk.FusedMoEActivationFormat.BatchedExperts):
        raise NotImplementedError(
            "Mxfp4 does not support batched experts format for EP")

    if not should_use_flashinfer_mxfp4():
        # Use matmul_ogs from triton_kernels here!
        raise NotImplementedError(
            "Mxfp4 does not support non-batched experts format for EP")

    # B200 code-path
    kwargs = {
        "gemm1_alpha": layer.gemm1_alpha,
        "gemm1_beta": layer.gemm1_beta,
        "gemm1_clamp_limit": layer.gemm1_clamp_limit,
        "w13_bias": layer.w13_bias,
        "w2_bias": layer.w2_bias,
        "max_capture_size": self.max_capture_size,
    }
    return TrtLlmGenExperts(moe, **kwargs)
|
||||
|
||||
def _route_and_experts(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Select experts for `x`, then run the modular fused-experts kernel.

    Routing is done by `FusedMoE.select_experts`; the resulting top-k
    weights and ids are passed to `self.fused_experts` together with the
    layer's `w13_weight` / `w2_weight` and their scales.

    Returns:
        The result of calling `self.fused_experts`.

    Raises:
        AssertionError: If `self.fused_experts` is not a
            `mk.FusedMoEModularKernel`.
    """
    assert isinstance(self.fused_experts, mk.FusedMoEModularKernel)

    topk_weights, topk_ids = FusedMoE.select_experts(
        hidden_states=x,
        router_logits=router_logits,
        use_grouped_topk=use_grouped_topk,
        top_k=top_k,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias,
        indices_type=self.topk_indices_dtype,
        enable_eplb=enable_eplb,
        expert_map=expert_map,
        expert_load_view=expert_load_view,
        logical_to_physical_map=logical_to_physical_map,
        logical_replica_count=logical_replica_count)

    return self.fused_experts(
        hidden_states=x,
        w1=layer.w13_weight,
        w2=layer.w2_weight,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=True,
        activation=activation,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        w1_scale=layer.w13_weight_scale,
        w2_scale=layer.w2_weight_scale,
        apply_router_weight_on_input=apply_router_weight_on_input,
    )
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@ -590,29 +503,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
activation=activation,
|
||||
expert_map=expert_map)
|
||||
|
||||
if self.fused_experts is not None:
|
||||
return self._route_and_experts(
|
||||
layer,
|
||||
x,
|
||||
router_logits,
|
||||
top_k,
|
||||
renormalize,
|
||||
use_grouped_topk,
|
||||
topk_group,
|
||||
num_expert_group,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
custom_routing_function,
|
||||
scoring_func,
|
||||
e_score_correction_bias,
|
||||
apply_router_weight_on_input,
|
||||
activation,
|
||||
enable_eplb,
|
||||
expert_load_view,
|
||||
logical_to_physical_map,
|
||||
logical_replica_count,
|
||||
)
|
||||
|
||||
assert _can_support_mxfp4(
|
||||
use_grouped_topk, topk_group, num_expert_group, expert_map,
|
||||
custom_routing_function, e_score_correction_bias,
|
||||
|
||||
@ -66,10 +66,11 @@ def _can_support_mxfp4(use_grouped_topk: bool = False,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None):
|
||||
return not (use_grouped_topk or topk_group or num_expert_group
|
||||
or custom_routing_function or e_score_correction_bias
|
||||
or apply_router_weight_on_input or scoring_func != "softmax"
|
||||
or activation != "swigluoai" or expert_load_view
|
||||
or logical_to_physical_map or logical_replica_count)
|
||||
or expert_map or custom_routing_function
|
||||
or e_score_correction_bias or apply_router_weight_on_input
|
||||
or scoring_func != "softmax" or activation != "swigluoai"
|
||||
or expert_load_view or logical_to_physical_map
|
||||
or logical_replica_count)
|
||||
|
||||
|
||||
def _dequant_mxfp4(x: torch.Tensor, scale: torch.Tensor,
|
||||
|
||||
@ -1,20 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def mxfp8_quantize(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize `x` to MX-FP8 using FlashInfer.

    FlashInfer is imported lazily so that this module can be loaded on
    installations without it.

    Args:
        x: Tensor to quantize.

    Returns:
        The pair returned by `flashinfer.mxfp8_quantize` with
        `is_sf_swizzled_layout=False` (presumably the quantized tensor and
        its scale factors -- confirm against the FlashInfer documentation).

    Raises:
        ImportError: If `flashinfer` cannot be imported.
    """
    try:
        # Aliased so the import does not shadow this function's own name.
        from flashinfer import mxfp8_quantize as flashinfer_mxfp8_quantize
    except ImportError as err:
        raise ImportError("The package `flashinfer` is required to do "
                          "MX-FP8 quantization. Please install it with "
                          "`pip install flashinfer`") from err

    return flashinfer_mxfp8_quantize(x, is_sf_swizzled_layout=False)
|
||||
@ -417,5 +417,4 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
|
||||
"GptOssForCausalLM": GptOssForCausalLMConfig,
|
||||
"MambaForCausalLM": MambaModelConfig,
|
||||
"Mamba2ForCausalLM": MambaModelConfig,
|
||||
"FalconMambaForCausalLM": MambaModelConfig,
|
||||
}
|
||||
|
||||
@ -290,7 +290,6 @@ class DeepseekVL2MultiModalProcessor(
|
||||
mm_data_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
tokenization_kwargs: Mapping[str, object],
|
||||
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
|
||||
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
|
||||
# The processor logic is different for len(images) <= 2 vs > 2
|
||||
# Since the processing cache assumes that the processor output is
|
||||
@ -302,7 +301,6 @@ class DeepseekVL2MultiModalProcessor(
|
||||
mm_data_items=mm_data_items,
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_hash_overrides=mm_hash_overrides,
|
||||
)
|
||||
|
||||
return super()._cached_apply_hf_processor(
|
||||
@ -310,7 +308,6 @@ class DeepseekVL2MultiModalProcessor(
|
||||
mm_data_items=mm_data_items,
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_hash_overrides=mm_hash_overrides,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -479,7 +479,6 @@ class H2OVLMultiModalProcessor(
|
||||
mm_data_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
tokenization_kwargs: Mapping[str, object],
|
||||
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
|
||||
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
|
||||
# The processor logic is different for len(images) <= 1 vs > 1
|
||||
# Since the processing cache assumes that the processor output is
|
||||
@ -491,7 +490,6 @@ class H2OVLMultiModalProcessor(
|
||||
mm_data_items=mm_data_items,
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_hash_overrides=mm_hash_overrides,
|
||||
)
|
||||
|
||||
return super()._cached_apply_hf_processor(
|
||||
@ -499,7 +497,6 @@ class H2OVLMultiModalProcessor(
|
||||
mm_data_items=mm_data_items,
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_hash_overrides=mm_hash_overrides,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -795,7 +795,6 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
|
||||
mm_data: MultiModalDataDict,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
tokenization_kwargs: Optional[Mapping[str, object]] = None,
|
||||
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
|
||||
) -> MultiModalInputs:
|
||||
hf_config = self.info.get_hf_config()
|
||||
image_token_id = hf_config.image_token_index
|
||||
@ -806,11 +805,8 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
|
||||
image_height=-1,
|
||||
)
|
||||
|
||||
result = super().apply(prompt,
|
||||
mm_data,
|
||||
hf_processor_mm_kwargs,
|
||||
tokenization_kwargs,
|
||||
mm_hash_overrides=mm_hash_overrides)
|
||||
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
|
||||
tokenization_kwargs)
|
||||
|
||||
mm_items = self._to_mm_items(mm_data)
|
||||
mm_item_counts = mm_items.get_all_counts()
|
||||
|
||||
@ -184,13 +184,9 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
|
||||
mm_data: MultiModalDataDict,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
tokenization_kwargs: Optional[Mapping[str, object]] = None,
|
||||
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
|
||||
) -> MultiModalEncDecInputs:
|
||||
mm_inputs = super().apply(prompt,
|
||||
mm_data,
|
||||
hf_processor_mm_kwargs,
|
||||
tokenization_kwargs,
|
||||
mm_hash_overrides=mm_hash_overrides)
|
||||
mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
|
||||
tokenization_kwargs)
|
||||
|
||||
image_token_id = self.info.get_hf_config().image_token_index
|
||||
# Check that the number of image tokens in the decoder prompt matches
|
||||
|
||||
@ -203,13 +203,9 @@ class PaliGemmaMultiModalProcessor(
|
||||
mm_data: MultiModalDataDict,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
tokenization_kwargs: Optional[Mapping[str, object]] = None,
|
||||
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
|
||||
) -> MultiModalInputs:
|
||||
mm_inputs = super().apply(prompt,
|
||||
mm_data,
|
||||
hf_processor_mm_kwargs,
|
||||
tokenization_kwargs,
|
||||
mm_hash_overrides=mm_hash_overrides)
|
||||
mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
|
||||
tokenization_kwargs)
|
||||
prompt_token_ids = mm_inputs["prompt_token_ids"]
|
||||
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
|
||||
@ -314,14 +314,12 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
|
||||
mm_data_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
tokenization_kwargs: Mapping[str, object],
|
||||
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
|
||||
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
|
||||
prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data_items=mm_data_items,
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_hash_overrides=mm_hash_overrides,
|
||||
)
|
||||
|
||||
# NOTE: The tokens are already inserted by the chat template
|
||||
|
||||
@ -138,7 +138,6 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
|
||||
mm_data: MultiModalDataDict,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
tokenization_kwargs: Optional[Mapping[str, object]] = None,
|
||||
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
|
||||
) -> MultiModalInputs:
|
||||
if "image" in mm_data:
|
||||
image_data = mm_data["image"]
|
||||
@ -147,10 +146,8 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
|
||||
mm_data = {"image": mm_data}
|
||||
|
||||
mm_items = self._to_mm_items(mm_data)
|
||||
tokenization_kwargs = tokenization_kwargs or {}
|
||||
mm_hashes = (mm_hash_overrides if mm_hash_overrides is not None else
|
||||
self._hash_mm_items(mm_items, hf_processor_mm_kwargs,
|
||||
tokenization_kwargs))
|
||||
mm_hashes = self._hash_mm_items(mm_items, hf_processor_mm_kwargs,
|
||||
tokenization_kwargs or {})
|
||||
mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}
|
||||
|
||||
mm_processed_data = BatchFeature(image_data)
|
||||
|
||||
@ -327,7 +327,6 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
|
||||
mm_data: MultiModalDataDict,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
tokenization_kwargs: Optional[Mapping[str, object]] = None,
|
||||
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
|
||||
) -> MultiModalInputs:
|
||||
"""
|
||||
Process multi-modal inputs to be used in vLLM.
|
||||
@ -394,11 +393,9 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
|
||||
self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs,
|
||||
num_image_patches),
|
||||
)
|
||||
# Use overrides if provided; fallback to data-dependent hashing.
|
||||
mm_hashes = (mm_hash_overrides if mm_hash_overrides is not None else
|
||||
self._hash_mm_items(mm_items, hf_processor_mm_kwargs,
|
||||
tokenization_kwargs))
|
||||
|
||||
mm_hashes = self._hash_mm_items(mm_items, hf_processor_mm_kwargs,
|
||||
tokenization_kwargs)
|
||||
return MultiModalInputs(
|
||||
type="multimodal",
|
||||
prompt=prompt,
|
||||
|
||||
@ -288,14 +288,12 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]
|
||||
mm_data_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
tokenization_kwargs: Mapping[str, object],
|
||||
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
|
||||
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
|
||||
prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data_items=mm_data_items,
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_hash_overrides=mm_hash_overrides,
|
||||
)
|
||||
|
||||
# NOTE: The tokens are already inserted by the chat template
|
||||
|
||||
@ -1020,13 +1020,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
prompt: str,
|
||||
mm_data: MultiModalDataDict,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
*,
|
||||
mm_hash_overrides: Optional[MultiModalHashes] = None,
|
||||
) -> MultiModalInputs:
|
||||
return self.apply(prompt,
|
||||
mm_data,
|
||||
hf_processor_mm_kwargs,
|
||||
mm_hash_overrides=mm_hash_overrides)
|
||||
return self.apply(prompt, mm_data, hf_processor_mm_kwargs)
|
||||
|
||||
def _get_data_parser(self) -> MultiModalDataParser:
|
||||
"""
|
||||
@ -1362,11 +1357,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
tokenization_kwargs: Mapping[str, object],
|
||||
) -> MultiModalHashes:
|
||||
"""Create MM hashes to be returned (only used in V1).
|
||||
|
||||
Note: When overrides are provided via callers of `apply`,
|
||||
`_hash_mm_items` will be bypassed and the overrides will be used.
|
||||
"""
|
||||
"""Create MM hashes to be returned (only used in V1)."""
|
||||
model_id = self.info.model_id
|
||||
|
||||
return {
|
||||
@ -1473,8 +1464,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
mm_data_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
tokenization_kwargs: Mapping[str, object],
|
||||
*,
|
||||
mm_hash_overrides: Optional[MultiModalHashes] = None,
|
||||
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
|
||||
(
|
||||
prompt_ids,
|
||||
@ -1494,10 +1483,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
hf_processor_mm_kwargs),
|
||||
)
|
||||
|
||||
# Use overrides if provided; fallback to data-dependent hashing.
|
||||
mm_hashes = (mm_hash_overrides if mm_hash_overrides is not None else
|
||||
self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs,
|
||||
tokenization_kwargs))
|
||||
mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs,
|
||||
tokenization_kwargs)
|
||||
|
||||
mm_prompt_updates = self._get_mm_prompt_updates(
|
||||
mm_data_items,
|
||||
@ -1519,8 +1506,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
mm_data_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
tokenization_kwargs: Mapping[str, object],
|
||||
*,
|
||||
mm_hash_overrides: Optional[MultiModalHashes] = None,
|
||||
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
|
||||
"""
|
||||
Apply the HF processor on the full prompt text,
|
||||
@ -1535,13 +1520,10 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
mm_data_items=mm_data_items,
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_hash_overrides=mm_hash_overrides,
|
||||
)
|
||||
|
||||
# Use overrides if provided; fallback to data-dependent hashing.
|
||||
mm_hashes = (mm_hash_overrides if mm_hash_overrides is not None else
|
||||
self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs,
|
||||
tokenization_kwargs))
|
||||
mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs,
|
||||
tokenization_kwargs)
|
||||
|
||||
mm_missing_data_items = self._get_cache_missing_items(
|
||||
cache=cache,
|
||||
@ -1741,8 +1723,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
mm_data: MultiModalDataDict,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
tokenization_kwargs: Optional[Mapping[str, object]] = None,
|
||||
*,
|
||||
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
|
||||
) -> MultiModalInputs:
|
||||
"""
|
||||
Process multi-modal inputs to be used in vLLM.
|
||||
@ -1771,7 +1751,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
mm_items,
|
||||
hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_hash_overrides=mm_hash_overrides,
|
||||
)
|
||||
|
||||
# NOTE: tokenization_kwargs are not required to init processor
|
||||
@ -1856,8 +1835,6 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
mm_data: MultiModalDataDict,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
tokenization_kwargs: Optional[Mapping[str, object]] = None,
|
||||
*,
|
||||
mm_hash_overrides: Optional[MultiModalHashes] = None,
|
||||
) -> MultiModalEncDecInputs:
|
||||
"""
|
||||
Process multi-modal inputs to be used in vLLM.
|
||||
@ -1872,7 +1849,6 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
mm_data,
|
||||
hf_processor_mm_kwargs,
|
||||
tokenization_kwargs,
|
||||
mm_hash_overrides=mm_hash_overrides,
|
||||
)
|
||||
|
||||
return self._get_enc_dec_inputs(
|
||||
|
||||
@ -225,41 +225,6 @@ class Processor:
|
||||
# Remember that this backend was set automatically
|
||||
params.guided_decoding.backend_was_auto = True
|
||||
|
||||
def _maybe_build_mm_hash_overrides(
    self,
    request_id: str,
    prompt: PromptType,
) -> Optional[dict[str, list[str]]]:
    """Build per-item multimodal hash overrides for a request.

    Multimodal data items are identified by their request id, modality and
    index rather than by their content.

    Args:
        request_id: Prefix used for every generated identifier.
        prompt: Prompt to inspect. For a dict with an "encoder_prompt" key,
            the multimodal data is taken from the encoder prompt; for any
            other dict it is taken from the prompt itself.

    Returns:
        A dictionary of modality -> list[str] of overrides, one
        `"{request_id}-{modality}-{i}"` entry per item, or None if no
        multimodal data is present.
    """

    def _extract_mm_data(p: PromptType):
        # Encoder/decoder prompts keep multimodal data on the encoder side.
        if isinstance(p, dict) and "encoder_prompt" in p:
            enc = p.get("encoder_prompt")
            if isinstance(enc, dict):
                return enc.get("multi_modal_data")
            return None
        if isinstance(p, dict):
            return p.get("multi_modal_data")
        return None

    mm_data = _extract_mm_data(prompt)
    if not mm_data:
        return None

    overrides: dict[str, list[str]] = {}
    for modality, data in mm_data.items():
        # A non-list value counts as a single item.
        n = len(data) if isinstance(data, list) else 1
        overrides[modality] = [
            f"{request_id}-{modality}-{i}" for i in range(n)
        ]
    return overrides
|
||||
|
||||
def process_inputs(
|
||||
self,
|
||||
request_id: str,
|
||||
@ -289,18 +254,6 @@ class Processor:
|
||||
if arrival_time is None:
|
||||
arrival_time = time.time()
|
||||
|
||||
# Optionally generate multimodal hash overrides based on request id.
|
||||
# NOTE: when users explicitly turn off BOTH prefix caching and input
|
||||
# processing caching, no multimodal features or embeddings will be
|
||||
# reused across requests, therefore hashing is no longer necessary.
|
||||
if (self.model_config.multimodal_config and
|
||||
self.model_config.multimodal_config.mm_processor_cache_gb == 0
|
||||
and not self.cache_config.enable_prefix_caching):
|
||||
mm_hash_overrides = self._maybe_build_mm_hash_overrides(
|
||||
request_id, prompt)
|
||||
else:
|
||||
mm_hash_overrides = None
|
||||
|
||||
# Process inputs, which includes:
|
||||
# 1. Tokenize text prompt, with LoRA request if one exists.
|
||||
# 2. For multimodal models with a merged preprocessor, preprocess
|
||||
@ -309,7 +262,6 @@ class Processor:
|
||||
prompt,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
lora_request=lora_request,
|
||||
mm_hash_overrides=mm_hash_overrides,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
current_platform.validate_request(
|
||||
|
||||
@ -1491,7 +1491,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
||||
and not envs.VLLM_DISABLE_PAD_FOR_CUDAGRAPH
|
||||
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
|
||||
# Use CUDA graphs.
|
||||
# Add padding to the batch size.
|
||||
|
||||
Reference in New Issue
Block a user