[Bugfix] Fix MTP+FlashInfer crash when trtllm kernels are available but disabled (#26361)

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <chislett.ben@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Benjamin Chislett
2025-10-07 18:12:26 -04:00
committed by GitHub
parent 1b86bd8e18
commit caf8b1c084

View File

@ -220,6 +220,8 @@ def force_use_trtllm_attention() -> bool | None:
def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
"""Check if the current configuration supports TRTLLM attention."""
if force_use_trtllm_attention() is False:
return False
has_trtllm = supports_trtllm_attention()
return has_trtllm and (num_qo_heads % num_kv_heads == 0)