[Multi Modal] Add FA3 in VIT (#24347)
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
@@ -36,31 +36,52 @@ def test_mha_attn_platform(device: str):
     torch.set_default_dtype(torch.float16)
 
     if device == "cpu":
-        with patch("vllm.attention.selector.current_platform",
-                   CpuPlatform()), \
-             patch("vllm.platforms.current_platform", CpuPlatform()):
+        with patch("vllm.attention.layer.current_platform", CpuPlatform()), \
+             patch("vllm.model_executor.models.vision.current_platform",
+                   CpuPlatform()):
             attn = MultiHeadAttention(16, 64, scale=1)
-            assert attn.attn_backend == _Backend.TORCH_SDPA_VLLM_V1
+            assert attn.attn_backend == _Backend.TORCH_SDPA
     elif device == "hip":
-        with patch("vllm.attention.selector.current_platform",
-                   RocmPlatform()), \
-             patch("vllm.platforms.current_platform", RocmPlatform()), \
-             patch("vllm.attention.layer.current_platform", RocmPlatform()):
+        with patch("vllm.attention.layer.current_platform", RocmPlatform()), \
+             patch("vllm.model_executor.models.vision.current_platform",
+                   RocmPlatform()):
             attn = MultiHeadAttention(16, 64, scale=1)
             assert attn.attn_backend == _Backend.TORCH_SDPA
     else:
-        with patch("vllm.attention.selector.current_platform",
-                   CudaPlatform()), \
-             patch("vllm.platforms.current_platform", CudaPlatform()):
+        # Test CUDA with head_size=64 (divisible by 32)
+        # - should use vLLM's FlashAttention
+        with patch("vllm.attention.layer.current_platform", CudaPlatform()), \
+             patch("vllm.model_executor.models.vision.current_platform",
+                   CudaPlatform()):
             attn = MultiHeadAttention(16, 64, scale=1)
-            assert attn.attn_backend == _Backend.XFORMERS
+            assert attn.attn_backend == _Backend.FLASH_ATTN
 
-        with patch("vllm.attention.selector.current_platform",
+        # Test CUDA with head_size=72 (not divisible by 32)
+        # - with upstream FA not available
+        # - should use xformers
+        with patch("vllm.attention.layer.current_platform", CudaPlatform()), \
+             patch("vllm.model_executor.models.vision.current_platform",
                    CudaPlatform()), \
-             patch("vllm.platforms.current_platform", CudaPlatform()):
+             patch("vllm.attention.layer.check_upstream_fa_availability",
+                   return_value=False):
             attn = MultiHeadAttention(16, 72, scale=1)
             assert attn.attn_backend == _Backend.XFORMERS
 
+        # Test CUDA with head_size=72 (not divisible by 32)
+        # - with upstream FA available
+        # - should use upstream FA
+        with patch("vllm.attention.layer.current_platform", CudaPlatform()), \
+             patch("vllm.model_executor.models.vision.current_platform",
+                   CudaPlatform()), \
+             patch("vllm.attention.layer.check_upstream_fa_availability",
+                   return_value=True), \
+             patch.dict('sys.modules', {'flash_attn': type('MockFlashAttn', (),
+                 {
+                     'flash_attn_varlen_func': lambda *args, **kwargs: None
+                 })()}):
+            attn = MultiHeadAttention(16, 72, scale=1)
+            assert attn.attn_backend == _Backend.FLASH_ATTN
+
 
 def ref_attention(
     query: torch.Tensor,
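Note: the assertions above encode a simple selection rule for the ViT attention path. The sketch below restates that rule as plain Python for readability; it is an illustration only, not vLLM's actual implementation. The helper name select_vit_backend and its signature are hypothetical, while the _Backend values and check_upstream_fa_availability mirror names that appear in the diff.

# Illustration only: a hypothetical select_vit_backend restating the rule
# the test expectations above encode; not vLLM's actual internals.
from enum import Enum, auto


class _Backend(Enum):
    TORCH_SDPA = auto()
    XFORMERS = auto()
    FLASH_ATTN = auto()


def check_upstream_fa_availability() -> bool:
    # Stand-in for the real check: is upstream flash-attn importable?
    try:
        import flash_attn  # noqa: F401
        return True
    except ImportError:
        return False


def select_vit_backend(is_cuda: bool, head_size: int) -> _Backend:
    if not is_cuda:
        # CPU and ROCm fall back to torch SDPA in these tests.
        return _Backend.TORCH_SDPA
    if head_size % 32 == 0:
        # Head sizes divisible by 32 use vLLM's own FlashAttention.
        return _Backend.FLASH_ATTN
    if check_upstream_fa_availability():
        # Odd head sizes still get FA when upstream flash-attn is present.
        return _Backend.FLASH_ATTN
    # Otherwise fall back to xformers.
    return _Backend.XFORMERS


assert select_vit_backend(is_cuda=True, head_size=64) == _Backend.FLASH_ATTN
assert select_vit_backend(is_cuda=False, head_size=64) == _Backend.TORCH_SDPA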
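The final case fakes the presence of upstream flash-attn by planting a stub module in sys.modules, so imports inside the selector resolve to the mock. Below is a self-contained sketch of that pattern, substituting types.SimpleNamespace for the inline type(...) construction used in the test:

# Self-contained sketch of the sys.modules patching trick used above.
import importlib
import sys
import types
from unittest.mock import patch

# A stub exposing the one attribute the test cares about.
fake_flash_attn = types.SimpleNamespace(
    flash_attn_varlen_func=lambda *args, **kwargs: None)

with patch.dict(sys.modules, {"flash_attn": fake_flash_attn}):
    # Any import of flash_attn now resolves to the stub, so availability
    # checks based on importing the package succeed.
    mod = importlib.import_module("flash_attn")
    assert callable(mod.flash_attn_varlen_func)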