[Multi Modal] Add FA3 in VIT (#24347)

Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
Authored by Wenlong Wang on 2025-09-12 06:27:24 -07:00; committed by GitHub
parent fdb09c77d6
commit 72fc8aa412
13 changed files with 247 additions and 66 deletions


@@ -36,31 +36,52 @@ def test_mha_attn_platform(device: str):
     torch.set_default_dtype(torch.float16)
 
     if device == "cpu":
-        with patch("vllm.attention.selector.current_platform",
-                   CpuPlatform()), \
-             patch("vllm.platforms.current_platform", CpuPlatform()):
+        with patch("vllm.attention.layer.current_platform", CpuPlatform()), \
+             patch("vllm.model_executor.models.vision.current_platform",
+                   CpuPlatform()):
             attn = MultiHeadAttention(16, 64, scale=1)
-            assert attn.attn_backend == _Backend.TORCH_SDPA_VLLM_V1
+            assert attn.attn_backend == _Backend.TORCH_SDPA
     elif device == "hip":
-        with patch("vllm.attention.selector.current_platform",
-                   RocmPlatform()), \
-             patch("vllm.platforms.current_platform", RocmPlatform()), \
-             patch("vllm.attention.layer.current_platform", RocmPlatform()):
+        with patch("vllm.attention.layer.current_platform", RocmPlatform()), \
+             patch("vllm.model_executor.models.vision.current_platform",
+                   RocmPlatform()):
             attn = MultiHeadAttention(16, 64, scale=1)
             assert attn.attn_backend == _Backend.TORCH_SDPA
     else:
-        with patch("vllm.attention.selector.current_platform",
-                   CudaPlatform()), \
-             patch("vllm.platforms.current_platform", CudaPlatform()):
+        # Test CUDA with head_size=64 (divisible by 32)
+        # - should use vLLM's FlashAttention
+        with patch("vllm.attention.layer.current_platform", CudaPlatform()), \
+             patch("vllm.model_executor.models.vision.current_platform",
+                   CudaPlatform()):
             attn = MultiHeadAttention(16, 64, scale=1)
-            assert attn.attn_backend == _Backend.XFORMERS
+            assert attn.attn_backend == _Backend.FLASH_ATTN
 
-        with patch("vllm.attention.selector.current_platform",
-                   CudaPlatform()), \
-             patch("vllm.platforms.current_platform", CudaPlatform()):
+        # Test CUDA with head_size=72 (not divisible by 32)
+        # - with upstream FA not available
+        # - should use xformers
+        with patch("vllm.attention.layer.current_platform", CudaPlatform()), \
+             patch("vllm.model_executor.models.vision.current_platform",
+                   CudaPlatform()), \
+             patch("vllm.attention.layer.check_upstream_fa_availability",
+                   return_value=False):
             attn = MultiHeadAttention(16, 72, scale=1)
             assert attn.attn_backend == _Backend.XFORMERS
+
+        # Test CUDA with head_size=72 (not divisible by 32)
+        # - with upstream FA available
+        # - should use upstream FA
+        with patch("vllm.attention.layer.current_platform", CudaPlatform()), \
+             patch("vllm.model_executor.models.vision.current_platform",
+                   CudaPlatform()), \
+             patch("vllm.attention.layer.check_upstream_fa_availability",
+                   return_value=True), \
+             patch.dict('sys.modules', {'flash_attn': type('MockFlashAttn', (),
+                 {
+                     'flash_attn_varlen_func': lambda *args, **kwargs: None
+                 })()}):
+            attn = MultiHeadAttention(16, 72, scale=1)
+            assert attn.attn_backend == _Backend.FLASH_ATTN
 
 
 def ref_attention(
     query: torch.Tensor,
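
For orientation, the updated test encodes the following backend selection for MultiHeadAttention (the ViT attention path): CPU and ROCm use torch SDPA; CUDA with a head size divisible by 32 uses vLLM's FlashAttention; CUDA with an unsupported head size falls back to upstream flash-attn when it is available, otherwise to xformers. The sketch below restates that behavior in isolation. It is a hypothetical helper (pick_vit_attn_backend) written for illustration only, not vLLM's actual selection code.

# Hypothetical sketch of the selection logic the test above exercises.
# NOT vLLM's implementation; the helper name and signature are invented.
from enum import Enum, auto

class Backend(Enum):
    TORCH_SDPA = auto()
    FLASH_ATTN = auto()
    XFORMERS = auto()

def pick_vit_attn_backend(platform: str, head_size: int,
                          upstream_fa_available: bool) -> Backend:
    if platform in ("cpu", "hip"):
        # CPU and ROCm fall back to PyTorch scaled_dot_product_attention.
        return Backend.TORCH_SDPA
    if head_size % 32 == 0:
        # CUDA with a head size divisible by 32: vLLM's bundled FlashAttention.
        return Backend.FLASH_ATTN
    if upstream_fa_available:
        # Otherwise use the upstream flash-attn package if it can be imported.
        return Backend.FLASH_ATTN
    # Last resort on CUDA: xformers.
    return Backend.XFORMERS

# The CUDA cases asserted in the diff above:
assert pick_vit_attn_backend("cuda", 64, upstream_fa_available=False) is Backend.FLASH_ATTN
assert pick_vit_attn_backend("cuda", 72, upstream_fa_available=False) is Backend.XFORMERS
assert pick_vit_attn_backend("cuda", 72, upstream_fa_available=True) is Backend.FLASH_ATTN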