def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
    """Report whether TRTLLM attention is usable for the given head counts.

    The answer is ``False`` outright when ``force_use_trtllm_attention()``
    returns exactly ``False``. Any other result from that call (it is
    annotated ``bool | None``) falls through to the capability check:
    ``supports_trtllm_attention()`` must be truthy and ``num_qo_heads``
    must be an exact multiple of ``num_kv_heads``.

    Args:
        num_qo_heads: Number of query/output attention heads.
        num_kv_heads: Number of key/value attention heads.

    Returns:
        Whether TRTLLM attention can be used with this configuration.
    """
    # Identity comparison on purpose: only a literal False short-circuits;
    # a None result must still reach the support/divisibility checks.
    if force_use_trtllm_attention() is not False:
        trtllm_supported = supports_trtllm_attention()
        # The modulo is evaluated only when the support probe is truthy.
        return trtllm_supported and num_qo_heads % num_kv_heads == 0
    return False