Compare commits

3 Commits: amd_mori...woosuk/fa3

| Author | SHA1 | Date |
|---|---|---|
| | a772948c9d | |
| | 06fba5410c | |
| | 3e56ae2878 | |
```diff
@@ -71,7 +71,8 @@ def llm_pair(request):
     [
         # Model names for the llm_pair fixture
         "deepseek-ai/DeepSeek-V2-Lite",
-        "Qwen/Qwen2-1.5B-Instruct"
+        "Qwen/Qwen2-1.5B-Instruct",
+        "google/gemma-3-1b-it",
     ],
     indirect=True)
 @pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
```
```diff
@@ -126,6 +127,8 @@ class TestFullCUDAGraph:
         ("Qwen/Qwen2-1.5B-Instruct", True),
         # MLA does not support capturing CUDA Graphs with size > max_num_seqs
         ("deepseek-ai/DeepSeek-V2-Lite", False),
+        # Full CUDA graph supports mixed full and sliding window attention.
+        ("google/gemma-3-1b-it", True),
     ])
     @pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
                         reason="Only Hopper GPUs support FA3 and FlashMLA")
```
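For context, `indirect=True` in these parametrize decorators routes each model name into the `llm_pair` fixture via `request.param` rather than passing it to the test directly. A minimal, self-contained sketch of that plumbing (the fixture body here is hypothetical; the real fixture builds a full-CUDA-graph engine and a baseline engine for the model):

```python
import pytest


@pytest.fixture
def llm_pair(request):
    # Each parametrized value arrives via request.param. The real fixture
    # would construct two engines for this model (full CUDA graph vs.
    # baseline); returning the names is enough to show the plumbing.
    model_name = request.param
    return model_name, model_name


@pytest.mark.parametrize(
    "llm_pair",
    ["google/gemma-3-1b-it"],  # delivered to the fixture, not the test
    indirect=True)
def test_llm_pair_plumbing(llm_pair):
    full_cudagraph_model, baseline_model = llm_pair
    assert full_cudagraph_model == baseline_model
```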
```diff
@@ -205,9 +205,11 @@ class FlashAttentionMetadataBuilder(
         # pre-allocated during capture.
         self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
 
-        # Sliding window size to be used with the AOT scheduler will be
-        # populated on first build() call.
-        self.aot_sliding_window: Optional[tuple[int, int]] = None
+        sliding_window = getattr(kv_cache_spec, "sliding_window", None)
+        if sliding_window is not None:
+            self.aot_sliding_window = (sliding_window - 1, 0)
+        else:
+            self.aot_sliding_window = (-1, -1)
 
     def build(self,
               common_prefix_len: int,
```
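The two tuples follow FlashAttention's `(left, right)` window convention: `-1` means unbounded on that side, so `(-1, -1)` disables the sliding window, while a window of size `W` becomes `(W - 1, 0)`, i.e. up to `W - 1` tokens to the left of the current token and none to the right. A small sketch of the same mapping in isolation (the helper name is made up for illustration):

```python
from typing import Optional


def to_fa_window(sliding_window: Optional[int]) -> tuple[int, int]:
    """Map an optional sliding-window size onto FlashAttention's
    (left, right) convention, mirroring the __init__ logic above."""
    if sliding_window is not None:
        # Size-W causal window: the current token plus W - 1 tokens to its
        # left, nothing to the right.
        return (sliding_window - 1, 0)
    # -1 means unbounded on that side, i.e. no sliding window at all.
    return (-1, -1)


assert to_fa_window(None) == (-1, -1)
assert to_fa_window(4096) == (4095, 0)
```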
```diff
@@ -231,23 +233,6 @@
         # the overhead of the aot schedule is not worth it for spec-decode
         aot_schedule = self.aot_schedule and not fast_build
 
-        if self.aot_sliding_window is None:
-            self.aot_sliding_window = (-1, -1)
-            # For the AOT scheduler we need the sliding window value to be
-            # constant for all layers to. We have to populate this on the first
-            # build() call so the layers are constructed (cannot populate)
-            # in __init__.
-            if aot_schedule:
-                sliding_window_configs = _get_sliding_window_configs(
-                    self.vllm_config)
-                if len(sliding_window_configs) == 1:
-                    sliding_window_config = sliding_window_configs.pop()
-                    if sliding_window_config is not None:
-                        self.aot_sliding_window = sliding_window_config
-                elif len(sliding_window_configs) > 1:
-                    self.aot_schedule = False
-                    aot_schedule = False
-
         def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                      max_seq_len, causal):
             cache_dtype = self.cache_config.cache_dtype
```
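This deleted block is the counterpart of the `__init__` change above: previously the builder could only learn the sliding window lazily, by scanning every layer's config on the first `build()` call, and had to disable the AOT schedule outright when layers disagreed. That restriction going away is presumably what lets the new gemma-3 test enable full CUDA graphs with mixed full and sliding-window attention. A hedged restatement of the removed decision logic, with illustrative names:

```python
from typing import Optional


def resolve_window_lazily(
    window_configs: set[Optional[tuple[int, int]]],
    aot_schedule: bool,
) -> tuple[tuple[int, int], bool]:
    """Old behavior, restated: keep the AOT schedule only if every layer
    agrees on a single sliding-window config; otherwise turn it off."""
    window = (-1, -1)
    if aot_schedule:
        if len(window_configs) == 1:
            config = window_configs.pop()
            if config is not None:
                window = config
        elif len(window_configs) > 1:
            # Mixed full and sliding-window layers: AOT schedule disabled.
            aot_schedule = False
    return window, aot_schedule


# All layers share one window -> the schedule stays on.
assert resolve_window_lazily({(4095, 0)}, True) == ((4095, 0), True)
# Mixed configs -> the schedule was forced off under the old logic.
assert resolve_window_lazily({None, (4095, 0)}, True) == ((-1, -1), False)
```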