[ROCm] Env variable to trigger custom PA (#15557)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Gregory Shtrasberg
2025-03-27 01:46:12 -04:00
committed by GitHub
parent dcf2a590f5
commit ecff8309a3
2 changed files with 8 additions and 1 deletion


@@ -908,4 +908,5 @@ def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
             and (qtype == torch.half or qtype == torch.bfloat16)
             and (head_size == 64 or head_size == 128)
             and (block_size == 16 or block_size == 32)
-            and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)
+            and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768
+            and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
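Taken together, the hunk gates the custom kernel behind the new flag. A minimal sketch of the resulting predicate, reconstructed only from the context lines visible above (the parameters beyond the two shown in the hunk header are inferred from the body, and the real function may carry additional conditions, e.g. a GPU architecture check, outside this hunk):

    import torch
    import vllm.envs as envs

    def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
                                         block_size: int, gqa_ratio: int,
                                         max_seq_len: int) -> bool:
        # fp16/bf16 queries only, head sizes 64/128, block sizes 16/32,
        # GQA ratio 1..16, sequences up to 32768 tokens, and (new in this
        # commit) the VLLM_ROCM_CUSTOM_PAGED_ATTN kill switch.
        return ((qtype == torch.half or qtype == torch.bfloat16)
                and (head_size == 64 or head_size == 128)
                and (block_size == 16 or block_size == 32)
                and (gqa_ratio >= 1 and gqa_ratio <= 16)
                and max_seq_len <= 32768
                and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)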


@@ -78,6 +78,7 @@ if TYPE_CHECKING:
     VLLM_ROCM_USE_AITER_RMSNORM: bool = True
     VLLM_ROCM_FP8_PADDING: bool = True
     VLLM_ROCM_MOE_PADDING: bool = True
+    VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True
     VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
     VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
     VLLM_DISABLE_COMPILE_CACHE: bool = False
@@ -541,6 +542,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_ROCM_MOE_PADDING":
     lambda: bool(int(os.getenv("VLLM_ROCM_MOE_PADDING", "1"))),
+    # custom paged attention kernel for MI3* cards
+    "VLLM_ROCM_CUSTOM_PAGED_ATTN":
+    lambda: (os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in
+             ("true", "1")),
     # Divisor for dynamic query scale factor calculation for FP8 KV Cache
     "Q_SCALE_CONSTANT":
     lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")),
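
Since the flag defaults to "True", the kernel stays on unless explicitly disabled; per the lambda above, any value outside the case-insensitive set {"true", "1"} parses as False. A minimal usage sketch (assumes a ROCm build of vLLM where vllm.envs evaluates these lambdas on attribute access):

    import os

    # Disable the custom paged attention kernel; "0" is not in
    # ("true", "1"), so the lambda above yields False.
    os.environ["VLLM_ROCM_CUSTOM_PAGED_ATTN"] = "0"

    import vllm.envs as envs
    print(envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)  # -> False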