Compare commits

...

3 Commits

Author SHA1 Message Date
a772948c9d add gemma3 to test
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-04 12:56:13 -07:00
06fba5410c Merge branch 'main' into woosuk/fa3-swa-cudagraph 2025-08-04 12:47:49 -07:00
3e56ae2878 [Bugfix] Support full cuda graph with sliding window attention
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-03 20:28:31 -07:00
2 changed files with 9 additions and 21 deletions

View File

@ -71,7 +71,8 @@ def llm_pair(request):
[
# Model names for the llm_pair fixture
"deepseek-ai/DeepSeek-V2-Lite",
"Qwen/Qwen2-1.5B-Instruct"
"Qwen/Qwen2-1.5B-Instruct",
"google/gemma-3-1b-it",
],
indirect=True)
@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
@ -126,6 +127,8 @@ class TestFullCUDAGraph:
("Qwen/Qwen2-1.5B-Instruct", True),
# MLA does not support capturing CUDA Graphs with size > max_num_seqs
("deepseek-ai/DeepSeek-V2-Lite", False),
# Full CUDA graph supports mixed full and sliding window attention.
("google/gemma-3-1b-it", True),
])
@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
reason="Only Hopper GPUs support FA3 and FlashMLA")

View File

@ -205,9 +205,11 @@ class FlashAttentionMetadataBuilder(
# pre-allocated during capture.
self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
# Sliding window size to be used with the AOT scheduler will be
# populated on first build() call.
self.aot_sliding_window: Optional[tuple[int, int]] = None
sliding_window = getattr(kv_cache_spec, "sliding_window", None)
if sliding_window is not None:
self.aot_sliding_window = (sliding_window - 1, 0)
else:
self.aot_sliding_window = (-1, -1)
def build(self,
common_prefix_len: int,
@ -231,23 +233,6 @@ class FlashAttentionMetadataBuilder(
# the overhead of the aot schedule is not worth it for spec-decode
aot_schedule = self.aot_schedule and not fast_build
if self.aot_sliding_window is None:
self.aot_sliding_window = (-1, -1)
# For the AOT scheduler we need the sliding window value to be
# constant for all layers to. We have to populate this on the first
# build() call so the layers are constructed (cannot populate)
# in __init__.
if aot_schedule:
sliding_window_configs = _get_sliding_window_configs(
self.vllm_config)
if len(sliding_window_configs) == 1:
sliding_window_config = sliding_window_configs.pop()
if sliding_window_config is not None:
self.aot_sliding_window = sliding_window_config
elif len(sliding_window_configs) > 1:
self.aot_schedule = False
aot_schedule = False
def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
max_seq_len, causal):
cache_dtype = self.cache_config.cache_dtype