[v1] Re-add fp32 support to v1 engine through FlexAttention (#19754)

Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py
2025-07-05 17:41:10 +08:00
committed by GitHub
parent 8aeaa910a2
commit 32c9be2200
8 changed files with 59 additions and 12 deletions

View File

@ -68,7 +68,7 @@ jobs:
export AWS_ACCESS_KEY_ID=minioadmin
export AWS_SECRET_ACCESS_KEY=minioadmin
sleep 30 && kubectl -n ns-vllm logs -f "$(kubectl -n ns-vllm get pods | awk '/deployment/ {print $1;exit}')" &
helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env"
helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set image.env[2].name=VLLM_CPU_CI_ENV --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string image.env[2].value="1" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env"
- name: curl test
run: |

View File

@ -181,6 +181,34 @@ def test_env(
assert backend.get_name() == expected
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("use_v1", [True, False])
def test_fp32_fallback(
device: str,
use_v1: bool,
monkeypatch: pytest.MonkeyPatch,
):
"""Test attention backend selection with fp32."""
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
if device == "cpu":
with patch("vllm.attention.selector.current_platform",
CpuPlatform()):
backend = get_attn_backend(16, torch.float32, torch.float32,
16, False)
assert (backend.get_name() == "TORCH_SDPA_VLLM_V1"
if use_v1 else "TORCH_SDPA")
elif device == "cuda":
with patch("vllm.attention.selector.current_platform",
CudaPlatform()):
backend = get_attn_backend(16, torch.float32, torch.float32,
16, False)
assert (backend.get_name() == "FLEX_ATTENTION"
if use_v1 else "XFORMERS")
def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
"""Test FlashAttn validation."""
# TODO: When testing for v1, pipe in `use_v1` as an argument to

View File

@ -450,6 +450,7 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
error_msg = f"{layer_1} must come before the current layer"
@ -478,6 +479,7 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
invalid_layer = "model.layers.0.cross_attn.attn"
@ -506,6 +508,7 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
def test_init_kv_cache_with_kv_sharing_target_same_as_current():
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
error_msg = f"{layer_1} cannot be the same as the current layer"
@ -534,6 +537,7 @@ def test_init_kv_cache_with_kv_sharing_target_same_as_current():
def test_init_kv_cache_without_kv_sharing():
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
vllm_config = get_vllm_config()
@ -601,6 +605,7 @@ def test_init_kv_cache_without_kv_sharing():
def test_init_kv_cache_with_kv_sharing_valid():
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
vllm_config = get_vllm_config()

View File

@ -1393,13 +1393,6 @@ class EngineArgs:
recommend_to_remove=False)
return False
# Only Fp16 and Bf16 dtypes since we only support FA.
V1_SUPPORTED_DTYPES = [torch.bfloat16, torch.float16]
if model_config.dtype not in V1_SUPPORTED_DTYPES:
_raise_or_fallback(feature_name=f"--dtype {model_config.dtype}",
recommend_to_remove=False)
return False
# No Mamba or Encoder-Decoder so far.
if not model_config.is_v1_compatible:
_raise_or_fallback(feature_name=model_config.architectures,

View File

@ -104,8 +104,12 @@ class TensorizerLoader(BaseModelLoader):
if is_vllm_tensorized(self.tensorizer_config):
tensorizer_config = self._patch_tensorizer_config(model_config)
model = init_tensorizer_model(tensorizer_config=tensorizer_config,
vllm_config=vllm_config)
device_config = vllm_config.device_config
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = init_tensorizer_model(
tensorizer_config=tensorizer_config,
vllm_config=vllm_config)
self.load_weights(model, model_config)
return model
return self._load_model_serialized_cpu(vllm_config=vllm_config)

View File

@ -251,6 +251,10 @@ class CudaPlatformBase(Platform):
# Default backends for V1 engine
# Prefer FlashInfer for Blackwell GPUs if installed
if dtype not in (torch.float16, torch.bfloat16):
logger.info_once(
f"Using FlexAttenion backend for {dtype} on V1 engine.")
return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
if cls.is_device_capability(100):
try:
import flashinfer # noqa: F401

View File

@ -463,6 +463,13 @@ class FlexAttentionImpl(AttentionImpl):
query = query[:, :, :num_actual_tokens, :]
# Doesn't work for now -> constraint violation
# torch._dynamo.try_mark_dynamic(query, 2)
# default M=64, N=64 may run out of shared memory on
# some GPUs with fp32, so we use smaller M and N.
extra_kernel_options = {
"BLOCK_M": 32,
"BLOCK_N": 32
} if query.dtype == torch.float32 else {}
out = flex_attention_compiled(
query,
key_cache,
@ -471,7 +478,10 @@ class FlexAttentionImpl(AttentionImpl):
attn_metadata.block_mask,
self.scale,
enable_gqa=enable_gqa,
kernel_options={"FORCE_USE_FLEX_ATTENTION": True},
kernel_options={
"FORCE_USE_FLEX_ATTENTION": True,
**extra_kernel_options
},
)
# Flex doesn't have an out variant today, rely on epilogue fusion

View File

@ -101,7 +101,10 @@ class TopKTopPSampler(nn.Module):
"per-request generators. Falling back to "
"PyTorch-native implementation.")
return self.forward_native(logits, generators, k, p)
return flashinfer_sample(logits, k, p, generators)
# flashinfer sampling functions expect contiguous logits.
# In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
# because of slicing operation in logits_processor.
return flashinfer_sample(logits.contiguous(), k, p, generators)
def forward_tpu(
self,