[V0 Deprecation] Remove placeholder attn (#25510)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Author:    Thomas Parnell <tpa@zurich.ibm.com>
Date:      2025-09-24 00:12:14 +02:00
Committer: GitHub
Parent:    4f8c4b890a
Commit:    969b4da3a6
5 changed files with 10 additions and 354 deletions
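
At the call sites, the change is the removal of the trailing boolean argument to get_attn_backend (previously used to route attention-free models to the placeholder backend). A minimal sketch of the new call, assuming get_attn_backend is imported from vllm.attention.selector, the module patched throughout the tests below; the meaning of the positional parameters (head size, dtype, kv-cache dtype, block size) is inferred from the call sites in this diff:

    import torch

    from vllm.attention.selector import get_attn_backend

    # Select a backend for head_size=16, fp16 activations, the default
    # kv-cache dtype, and block size 16. The fifth positional argument
    # (the attention-free flag) no longer exists after this commit.
    backend = get_attn_backend(16, torch.float16, None, 16)
    print(backend.get_name())  # backend name depends on the platform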


@@ -85,8 +85,7 @@ def test_env(
     if device == "cpu":
         with patch("vllm.attention.selector.current_platform",
                    CpuPlatform()):
-            backend = get_attn_backend(16, torch.float16, None, block_size,
-                                       False)
+            backend = get_attn_backend(16, torch.float16, None, block_size)
             assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
     elif device == "hip":
@@ -106,7 +105,6 @@ def test_env(
                     torch.float16,
                     None,
                     block_size,
-                    False,
                     use_mla=use_mla)
             assert f"The selected backend, {name}" in str(
                 exc_info.value)
@@ -117,7 +115,6 @@ def test_env(
                     torch.float16,
                     None,
                     block_size,
-                    False,
                     use_mla=use_mla)
             assert f"The selected backend, {name}" in str(
                 exc_info.value)
@@ -127,7 +124,6 @@ def test_env(
                     torch.float16,
                     None,
                     block_size,
-                    False,
                     use_mla=use_mla)
                 expected = f"{name}_VLLM_V1"
                 assert backend.get_name() == expected
@@ -136,7 +132,6 @@ def test_env(
                     torch.float16,
                     None,
                     block_size,
-                    False,
                     use_mla=use_mla)
                 expected = "TRITON_ATTN_VLLM_V1"
                 assert backend.get_name() == expected
@@ -164,7 +159,6 @@ def test_env(
                     torch.float16,
                     None,
                     block_size,
-                    False,
                     use_mla=use_mla)
                 expected = "CUTLASS_MLA_VLLM_V1"
                 assert backend.get_name() == expected
@@ -179,7 +173,6 @@ def test_env(
                     torch.float16,
                     None,
                     block_size,
-                    False,
                     use_mla=use_mla)
                 expected = "FLASHINFER_MLA"
                 assert backend.get_name() == expected
@@ -199,7 +192,6 @@ def test_env(
                     torch.float16,
                     None,
                     block_size,
-                    False,
                     use_mla=use_mla)
                 expected = f"{name}_VLLM_V1"
                 assert backend.get_name() == expected
@@ -208,7 +200,6 @@ def test_env(
                     torch.float16,
                     None,
                     block_size,
-                    False,
                     use_mla=use_mla)
                 expected = "FLASH_ATTN_MLA"
                 assert backend.get_name() == expected
@@ -218,7 +209,6 @@ def test_env(
                     torch.float16,
                     None,
                     block_size,
-                    False,
                     use_mla=use_mla)
                 expected = "TRITON_MLA_VLLM_V1"
                 assert backend.get_name() == expected
@@ -227,7 +217,6 @@ def test_env(
                     torch.float16,
                     None,
                     block_size,
-                    False,
                     use_mla=use_mla)
                 expected = "FLASHINFER_VLLM_V1"
                 assert backend.get_name() == expected
@@ -236,7 +225,6 @@ def test_env(
                     torch.float16,
                     None,
                     block_size,
-                    False,
                     use_mla=use_mla)
                 expected = "FLASH_ATTN_VLLM_V1"
                 assert backend.get_name() == expected
@@ -245,7 +233,6 @@ def test_env(
                     torch.float16,
                     None,
                     block_size,
-                    False,
                     use_mla=use_mla)
             assert backend.get_name() == "FLEX_ATTENTION", (
                 "Should fallback to FlexAttention if head size is "
@@ -264,13 +251,13 @@ def test_fp32_fallback(
     if device == "cpu":
         with patch("vllm.attention.selector.current_platform",
                    CpuPlatform()):
-            backend = get_attn_backend(16, torch.float32, None, 16, False)
+            backend = get_attn_backend(16, torch.float32, None, 16)
             assert backend.get_name() == "TORCH_SDPA_VLLM_V1"

     elif device == "cuda":
         with patch("vllm.attention.selector.current_platform",
                    CudaPlatform()):
-            backend = get_attn_backend(16, torch.float32, None, 16, False)
+            backend = get_attn_backend(16, torch.float32, None, 16)
             assert backend.get_name() == "FLEX_ATTENTION"
@@ -286,29 +273,29 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
         monkeypatch.setattr(torch.cuda,
                             "get_device_capability",
                             lambda _=None: (7, 5))
-        backend = get_attn_backend(16, torch.float16, None, 16, False)
+        backend = get_attn_backend(16, torch.float16, None, 16)
         assert backend.get_name() != STR_FLASH_ATTN_VAL

         # Reset the monkeypatch for subsequent tests
         monkeypatch.undo()

         # Unsupported data type
-        backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False)
+        backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16)
         assert backend.get_name() != STR_FLASH_ATTN_VAL

         # Unsupported kv cache data type
-        backend = get_attn_backend(16, torch.float16, "fp8", 16, False)
+        backend = get_attn_backend(16, torch.float16, "fp8", 16)
         assert backend.get_name() != STR_FLASH_ATTN_VAL

         # Unsupported block size
-        backend = get_attn_backend(16, torch.float16, None, 8, False)
+        backend = get_attn_backend(16, torch.float16, None, 8)
         assert backend.get_name() != STR_FLASH_ATTN_VAL

         # flash-attn is not installed
         import sys
         original_module = sys.modules.get('vllm_flash_attn')
         monkeypatch.setitem(sys.modules, 'vllm_flash_attn', None)
-        backend = get_attn_backend(16, torch.float16, None, 16, False)
+        backend = get_attn_backend(16, torch.float16, None, 16)
         assert backend.get_name() != STR_FLASH_ATTN_VAL

         # Restore the original module if it existed
@@ -319,11 +306,7 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
             monkeypatch.delitem(sys.modules, 'vllm_flash_attn', raising=False)

         # Unsupported head size
-        backend = get_attn_backend(17, torch.float16, None, 16, False)
-        assert backend.get_name() != STR_FLASH_ATTN_VAL
-
-        # Attention-free models should bypass env and use PlaceholderAttention
-        backend = get_attn_backend(16, torch.float16, None, 16, True)
+        backend = get_attn_backend(17, torch.float16, None, 16)
         assert backend.get_name() != STR_FLASH_ATTN_VAL
@@ -336,5 +319,5 @@ def test_invalid_env(monkeypatch: pytest.MonkeyPatch):
     # Should raise ValueError for invalid backend
     with pytest.raises(ValueError) as exc_info:
-        get_attn_backend(32, torch.float16, None, 16, False)
+        get_attn_backend(32, torch.float16, None, 16)
     assert "Invalid value 'INVALID'" in str(exc_info.value)