[V0 Deprecation] Remove placeholder attn (#25510)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
@ -85,8 +85,7 @@ def test_env(
|
||||
if device == "cpu":
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
CpuPlatform()):
|
||||
backend = get_attn_backend(16, torch.float16, None, block_size,
|
||||
False)
|
||||
backend = get_attn_backend(16, torch.float16, None, block_size)
|
||||
assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
|
||||
|
||||
elif device == "hip":
|
||||
@ -106,7 +105,6 @@ def test_env(
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
False,
|
||||
use_mla=use_mla)
|
||||
assert f"The selected backend, {name}" in str(
|
||||
exc_info.value)
|
||||
@ -117,7 +115,6 @@ def test_env(
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
False,
|
||||
use_mla=use_mla)
|
||||
assert f"The selected backend, {name}" in str(
|
||||
exc_info.value)
|
||||
@ -127,7 +124,6 @@ def test_env(
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
False,
|
||||
use_mla=use_mla)
|
||||
expected = f"{name}_VLLM_V1"
|
||||
assert backend.get_name() == expected
|
||||
@ -136,7 +132,6 @@ def test_env(
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
False,
|
||||
use_mla=use_mla)
|
||||
expected = "TRITON_ATTN_VLLM_V1"
|
||||
assert backend.get_name() == expected
|
||||
@ -164,7 +159,6 @@ def test_env(
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
False,
|
||||
use_mla=use_mla)
|
||||
expected = "CUTLASS_MLA_VLLM_V1"
|
||||
assert backend.get_name() == expected
|
||||
@ -179,7 +173,6 @@ def test_env(
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
False,
|
||||
use_mla=use_mla)
|
||||
expected = "FLASHINFER_MLA"
|
||||
assert backend.get_name() == expected
|
||||
@ -199,7 +192,6 @@ def test_env(
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
False,
|
||||
use_mla=use_mla)
|
||||
expected = f"{name}_VLLM_V1"
|
||||
assert backend.get_name() == expected
|
||||
@ -208,7 +200,6 @@ def test_env(
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
False,
|
||||
use_mla=use_mla)
|
||||
expected = "FLASH_ATTN_MLA"
|
||||
assert backend.get_name() == expected
|
||||
@ -218,7 +209,6 @@ def test_env(
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
False,
|
||||
use_mla=use_mla)
|
||||
expected = "TRITON_MLA_VLLM_V1"
|
||||
assert backend.get_name() == expected
|
||||
@ -227,7 +217,6 @@ def test_env(
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
False,
|
||||
use_mla=use_mla)
|
||||
expected = "FLASHINFER_VLLM_V1"
|
||||
assert backend.get_name() == expected
|
||||
@ -236,7 +225,6 @@ def test_env(
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
False,
|
||||
use_mla=use_mla)
|
||||
expected = "FLASH_ATTN_VLLM_V1"
|
||||
assert backend.get_name() == expected
|
||||
@ -245,7 +233,6 @@ def test_env(
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
False,
|
||||
use_mla=use_mla)
|
||||
assert backend.get_name() == "FLEX_ATTENTION", (
|
||||
"Should fallback to FlexAttention if head size is "
|
||||
@ -264,13 +251,13 @@ def test_fp32_fallback(
|
||||
if device == "cpu":
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
CpuPlatform()):
|
||||
backend = get_attn_backend(16, torch.float32, None, 16, False)
|
||||
backend = get_attn_backend(16, torch.float32, None, 16)
|
||||
assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
|
||||
|
||||
elif device == "cuda":
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
CudaPlatform()):
|
||||
backend = get_attn_backend(16, torch.float32, None, 16, False)
|
||||
backend = get_attn_backend(16, torch.float32, None, 16)
|
||||
assert backend.get_name() == "FLEX_ATTENTION"
|
||||
|
||||
|
||||
@ -286,29 +273,29 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(torch.cuda,
|
||||
"get_device_capability",
|
||||
lambda _=None: (7, 5))
|
||||
backend = get_attn_backend(16, torch.float16, None, 16, False)
|
||||
backend = get_attn_backend(16, torch.float16, None, 16)
|
||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||
|
||||
# Reset the monkeypatch for subsequent tests
|
||||
monkeypatch.undo()
|
||||
|
||||
# Unsupported data type
|
||||
backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False)
|
||||
backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16)
|
||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||
|
||||
# Unsupported kv cache data type
|
||||
backend = get_attn_backend(16, torch.float16, "fp8", 16, False)
|
||||
backend = get_attn_backend(16, torch.float16, "fp8", 16)
|
||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||
|
||||
# Unsupported block size
|
||||
backend = get_attn_backend(16, torch.float16, None, 8, False)
|
||||
backend = get_attn_backend(16, torch.float16, None, 8)
|
||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||
|
||||
# flash-attn is not installed
|
||||
import sys
|
||||
original_module = sys.modules.get('vllm_flash_attn')
|
||||
monkeypatch.setitem(sys.modules, 'vllm_flash_attn', None)
|
||||
backend = get_attn_backend(16, torch.float16, None, 16, False)
|
||||
backend = get_attn_backend(16, torch.float16, None, 16)
|
||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||
|
||||
# Restore the original module if it existed
|
||||
@ -319,11 +306,7 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.delitem(sys.modules, 'vllm_flash_attn', raising=False)
|
||||
|
||||
# Unsupported head size
|
||||
backend = get_attn_backend(17, torch.float16, None, 16, False)
|
||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||
|
||||
# Attention-free models should bypass env and use PlaceholderAttention
|
||||
backend = get_attn_backend(16, torch.float16, None, 16, True)
|
||||
backend = get_attn_backend(17, torch.float16, None, 16)
|
||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||
|
||||
|
||||
@ -336,5 +319,5 @@ def test_invalid_env(monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
# Should raise ValueError for invalid backend
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
get_attn_backend(32, torch.float16, None, 16, False)
|
||||
get_attn_backend(32, torch.float16, None, 16)
|
||||
assert "Invalid value 'INVALID'" in str(exc_info.value)
|
||||
|
||||
Reference in New Issue
Block a user