Compare commits

...

21 Commits

Author SHA1 Message Date
01e389cd94 fix 2025-10-16 16:48:51 +00:00
9decb2a5b1 Merge remote-tracking branch 'test/nixl-ptp-gt-dtp' into woosuk/router-nixl 2025-10-16 16:34:15 +00:00
8935ca208d Merge branch 'main' into woosuk/test-router 2025-10-16 00:32:13 +00:00
938c43ea7f [ci] Adjusting AMD test composition 2025-10-14 (#26852)
Signed-off-by: Alexei V. Ivanov <alexei.ivanov@amd.com>
2025-10-15 23:52:13 +00:00
0a9ef0cfce Move query quantization to attention layer for Flashinfer & Triton. (#26534)
Signed-off-by: adabeyta <aabeyta@redhat.com>
Signed-off-by: Adrian Abeyta <aabeyta@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
2025-10-15 19:01:38 -04:00
e5b438a247 [Bug] Temporally Disable VLLM_ALLREDUCE_USE_SYMM_MEM by Default (#26925)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
2025-10-15 16:18:50 -04:00
0b99f5d302 support flashinfer_fp4 moe for 5090 gpu (#26669)
Signed-off-by: XiaobingSuper <xiaobingzhangupc@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
2025-10-15 15:06:47 -04:00
1f491aa0c8 Vectorize RMS norm variance using vectorize_read_with_alignment (#26234)
Signed-off-by: Benji Beck <benjibeck@meta.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
2025-10-15 11:54:41 -07:00
de92d916fe [NVIDIA] Add support for cudnn fp4 gemm via flashinfer (#26107)
Signed-off-by: kaixih <kaixih@nvidia.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
2025-10-15 13:53:00 -04:00
a1063628a4 [Chore] Clean up CODEOWNERS (#26923)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-10-15 10:52:54 -07:00
dddad8a81c minor 2025-10-14 22:41:25 +00:00
7f783b8a4a merge 2025-10-14 22:39:55 +00:00
1dc9df9842 more integration tests
Signed-off-by: NickLucche <nlucches@redhat.com>
2025-10-13 14:20:41 +00:00
b8d520232f fix mla
Signed-off-by: NickLucche <nlucches@redhat.com>
2025-10-13 14:01:34 +00:00
6601c9c5be add and update tests
Signed-off-by: NickLucche <nlucches@redhat.com>
2025-10-10 10:30:38 +00:00
9f38fed93c clean up
Signed-off-by: NickLucche <nlucches@redhat.com>
2025-10-09 15:43:43 +00:00
7bb3861faf hacky
Signed-off-by: NickLucche <nlucches@redhat.com>
2025-10-09 13:24:04 +00:00
684c9b7b6d init
Signed-off-by: NickLucche <nlucches@redhat.com>
2025-10-08 17:50:07 +00:00
5d45b77124 docs
Signed-off-by: NickLucche <nlucches@redhat.com>
2025-10-08 14:12:35 +00:00
84dfd367a1 review
Signed-off-by: NickLucche <nlucches@redhat.com>
2025-10-08 12:56:04 +00:00
1a1c81ca2f init
Signed-off-by: NickLucche <nlucches@redhat.com>
2025-10-08 10:29:47 +00:00
24 changed files with 773 additions and 292 deletions

View File

@ -63,7 +63,7 @@ steps:
- label: Async Engine, Inputs, Utils, Worker Test (CPU) # 4 mins
timeout_in_minutes: 10
mirror_hardwares: [amdexperimental]
mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_1
# grade: Blocking
source_file_dependencies:
@ -353,7 +353,7 @@ steps:
- pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine
- label: V1 Test others (CPU) # 5 mins
mirror_hardwares: [amdexperimental]
mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_1
# grade: Blocking
source_file_dependencies:
@ -459,6 +459,7 @@ steps:
- pytest -v -s compile/test_fusion_all_reduce.py
- pytest -v -s compile/test_decorator.py
- pytest -v -s compile/test_noop_elimination.py
- pytest -v -s compile/test_aot_compile.py
- label: PyTorch Fullgraph Smoke Test # 15min
timeout_in_minutes: 30
@ -487,14 +488,14 @@ steps:
- label: Kernels Core Operation Test # 48min
timeout_in_minutes: 75
mirror_hardwares: [amdexperimental]
mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_1
# grade: Blocking
source_file_dependencies:
- csrc/
- tests/kernels/core
commands:
- pytest -v -s kernels/core
- pytest -v -s kernels/core kernels/test_top_k_per_row.py
- label: Kernels Attention Test %N # 23min
timeout_in_minutes: 35
@ -632,7 +633,7 @@ steps:
- label: OpenAI-Compatible Tool Use # 23 min
timeout_in_minutes: 35
mirror_hardwares: [amdexperimental]
mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_1
# grade: Blocking
fast_check: false

.github/CODEOWNERS vendored
View File

@ -5,9 +5,7 @@
/vllm/attention @LucasWilkinson
/vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
/vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @22quinn
/vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @22quinn
/vllm/model_executor/layers/fused_moe @mgoin
/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @NickLucche
/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256
/vllm/model_executor/layers/mamba @tdoublep
/vllm/model_executor/model_loader @22quinn
@ -26,7 +24,6 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/vllm/config/cache.py @simon-mo @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor @yewentao256 @ProExpertProg @heheda12345
# vLLM V1
/vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat
/vllm/v1/attention @LucasWilkinson
/vllm/v1/attention/backends/flashinfer.py @mgoin
/vllm/v1/attention/backends/triton_attn.py @tdoublep

View File

@ -2,6 +2,7 @@
#include "dispatch_utils.h"
#include "cub_helpers.h"
#include "core/batch_invariant.hpp"
#include "quantization/vectorization_utils.cuh"
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
@ -18,11 +19,22 @@ __global__ void rms_norm_kernel(
const float epsilon, const int num_tokens, const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;
const scalar_t* input_row = input + blockIdx.x * input_stride;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x = (float)input[blockIdx.x * input_stride + idx];
constexpr int VEC_SIZE = 8;
auto vec_op = [&variance](const vec_n_t<scalar_t, VEC_SIZE>& vec) {
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
float x = static_cast<float>(vec.val[i]);
variance += x * x;
}
};
auto scalar_op = [&variance](const scalar_t& val) {
float x = static_cast<float>(val);
variance += x * x;
}
};
vllm::vectorize_read_with_alignment<VEC_SIZE>(
input_row, hidden_size, threadIdx.x, blockDim.x, vec_op, scalar_op);
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;

View File

@ -10,6 +10,7 @@
#include "dispatch_utils.h"
#include "cub_helpers.h"
#include "core/batch_invariant.hpp"
#include "quantization/vectorization_utils.cuh"
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
@ -28,10 +29,22 @@ __global__ void rms_norm_static_fp8_quant_kernel(
__shared__ float s_variance;
float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x = (float)input[blockIdx.x * input_stride + idx];
const scalar_t* input_row = input + blockIdx.x * input_stride;
constexpr int VEC_SIZE = 8;
auto vec_op = [&variance](const vec_n_t<scalar_t, VEC_SIZE>& vec) {
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
float x = static_cast<float>(vec.val[i]);
variance += x * x;
}
};
auto scalar_op = [&variance](const scalar_t& val) {
float x = static_cast<float>(val);
variance += x * x;
}
};
vllm::vectorize_read_with_alignment<VEC_SIZE>(
input_row, hidden_size, threadIdx.x, blockDim.x, vec_op, scalar_op);
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
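Both RMS-norm kernels above replace the per-element variance loop with vllm::vectorize_read_with_alignment, which walks each row in aligned chunks of VEC_SIZE elements (vec_op) and handles leftover elements one at a time (scalar_op). A rough, purely illustrative Python sketch of that access pattern, with made-up helper names and none of the CUDA specifics:

def sum_of_squares(row, vec_size=8):
    # Chunked main loop ("vec_op") plus scalar tail ("scalar_op"),
    # mirroring the split done by vectorize_read_with_alignment.
    variance = 0.0
    vec_end = (len(row) // vec_size) * vec_size
    for start in range(0, vec_end, vec_size):
        for x in row[start:start + vec_size]:
            variance += float(x) * float(x)
    for x in row[vec_end:]:
        variance += float(x) * float(x)
    return variance

assert sum_of_squares([1.0] * 19) == 19.0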

View File

@ -421,7 +421,9 @@ def test_attention_quant_pattern(
]
if any(attn_fusion_supported):
# Check quantization ops in the graph before and after fusion
test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=True)
# Note: fully_replaced=False because query quant ops remain in graph.
# Only output quant ops are fused into attention.
test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=False)
# access the underlying `AttnFusionPass` on the `LazyInitPass`
assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)

View File

@ -649,5 +649,65 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
req.cache_salt = "test_salt"
with suppress(Exception):
await serving_chat.create_chat_completion(req)
engine_prompt = serving_chat._process_inputs.await_args_list[1].args[1]
assert engine_prompt.get("cache_salt") == "test_salt"
assert mock_engine.generate.call_args.args[0]["cache_salt"] == "test_salt"
@pytest.mark.asyncio
async def test_serving_chat_data_parallel_rank_extraction():
"""Test that data_parallel_rank is properly extracted from header and passed to engine."""
mock_engine = MagicMock(spec=MQLLMEngineClient)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
models = OpenAIServingModels(engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=MockModelConfig())
serving_chat = OpenAIServingChat(mock_engine,
MockModelConfig(),
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None)
# Test when data_parallel_rank is present in header
req = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{
"role": "user",
"content": "what is 1+1?"
}],
)
# Mock request with X-data-parallel-rank header
mock_raw_request = MagicMock()
mock_raw_request.headers = {"X-data-parallel-rank": "2"}
mock_raw_request.state = MagicMock()
with suppress(Exception):
await serving_chat.create_chat_completion(req, mock_raw_request)
# Verify that data_parallel_rank was passed to engine.generate
assert 'data_parallel_rank' in mock_engine.generate.call_args.kwargs
assert mock_engine.generate.call_args.kwargs['data_parallel_rank'] == 2
# Test when data_parallel_rank is not present (defaults to None)
req_no_dp = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{
"role": "user",
"content": "what is 2+2?"
}],
)
# Mock request with no header
mock_raw_request_no_dp = MagicMock()
mock_raw_request_no_dp.headers = {}
mock_raw_request_no_dp.state = MagicMock()
with suppress(Exception):
await serving_chat.create_chat_completion(req_no_dp, mock_raw_request_no_dp)
# Verify that data_parallel_rank defaults to None
assert 'data_parallel_rank' in mock_engine.generate.call_args.kwargs
assert mock_engine.generate.call_args.kwargs['data_parallel_rank'] is None
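For context, the new test drives the X-data-parallel-rank header end to end: a router in front of vLLM can pin a request to a specific data-parallel rank by injecting that header on an OpenAI-compatible call. A hedged client-side sketch (host, port and model are placeholders, not taken from this diff):

import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    headers={"X-data-parallel-rank": "2"},  # ask for DP rank 2
    json={
        "model": "Qwen/Qwen3-0.6B",
        "messages": [{"role": "user", "content": "what is 1+1?"}],
    },
)
print(resp.json())

Omitting the header leaves data_parallel_rank as None and the engine picks a rank as usual.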

View File

@ -34,15 +34,21 @@ else
fi
# Models to run
MODELS=(
"Qwen/Qwen3-0.6B"
)
MODEL_NAMES=${MODEL_NAMES:-}
if [[ -n "$MODEL_NAMES" ]]; then
MODELS=("$MODEL_NAMES")
else
MODELS=(
"Qwen/Qwen3-0.6B"
)
fi
# Number of prefill and decode instances to create
NUM_PREFILL_INSTANCES=${NUM_PREFILL_INSTANCES:-1} # Default to 1
NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-1} # Default to 1
PREFILLER_TP_SIZE=${PREFILLER_TP_SIZE:-1}
DECODER_TP_SIZE=${DECODER_TP_SIZE:-1}
GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.2}
# Find the git repository root directory
GIT_ROOT=$(git rev-parse --show-toplevel)
@ -130,7 +136,7 @@ run_tests_for_model() {
vllm serve $model_name \
--port $PORT \
--enforce-eager \
--gpu-memory-utilization 0.2 \
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
--tensor-parallel-size $PREFILLER_TP_SIZE \
--kv-transfer-config '$KV_CONFIG'"
@ -171,7 +177,7 @@ run_tests_for_model() {
vllm serve $model_name \
--port $PORT \
--enforce-eager \
--gpu-memory-utilization 0.2 \
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
--tensor-parallel-size $DECODER_TP_SIZE \
--kv-transfer-config '$KV_CONFIG'"

View File

@ -12,7 +12,11 @@ FILTER = "exact_match,strict-match"
RTOL = 0.03
# Model-specific expected values
EXPECTED_VALUES = {"Qwen/Qwen3-0.6B": 0.41, "deepseek-ai/deepseek-vl2-small": 0.59}
EXPECTED_VALUES = {
"Qwen/Qwen3-0.6B": 0.41,
"deepseek-ai/deepseek-vl2-small": 0.59,
"deepseek-ai/DeepSeek-V2-Lite-Chat": 0.65,
}
SIMPLE_PROMPT = (
"The best part about working on vLLM is that I got to meet so many people across "

View File

@ -0,0 +1,43 @@
#!/usr/bin/env bash
set -euo pipefail
# Utility to run integration tests sequentially with varying TP configurations.
# If FLASHINFER is set, reruns all tests with VLLM_ATTENTION_BACKEND=FLASHINFER.
SCRIPT="tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh"
# Define test configurations
configs=(
"PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2"
"PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2"
"PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1"
"GPU_MEMORY_UTILIZATION=0.6 MODEL_NAMES=deepseek-ai/DeepSeek-V2-Lite-Chat" # MLA case
# TP greater than num heads
)
run_tests() {
local label=$1
local extra_env=$2
echo "=== Running tests (${label}) ==="
for cfg in "${configs[@]}"; do
echo "-> Running with ${cfg} ${extra_env:+and ${extra_env}}"
# Use 'env' to safely set variables without eval
if ! env ${extra_env} ${cfg} bash "${SCRIPT}"; then
echo "❌ Test failed for config: ${cfg} ${extra_env:+(${extra_env})}"
exit 1
fi
done
echo "✅ All ${label} tests passed!"
}
# Run base tests
run_tests "default backend" ""
# Check if FLASHINFER is set (non-empty)
if [[ -n "${FLASHINFER:-}" ]]; then
echo "FLASHINFER is set, rerunning with VLLM_ATTENTION_BACKEND=FLASHINFER"
run_tests "FLASHINFER backend" "VLLM_ATTENTION_BACKEND=FLASHINFER"
else
echo "FLASHINFER not set, skipping FLASHINFER runs."
fi

View File

@ -308,21 +308,42 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
assert expected_engine_id == self.REMOTE_ENGINE_ID
remote_agent_name = self.add_remote_agent(
NixlAgentMetadata(
engine_id=self.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
num_blocks=1,
block_lens=self.block_len_per_layer,
attn_backend_name=self.backend_name,
# `self.kv_cache_layout` is only forced to HND when vllm engine
# is started. We mock HND here.
kv_cache_layout="HND",
),
remote_tp_size=remote_tp_size,
)
return {0: remote_agent_name}
# Adjust remote block length metadata to satisfy heterogeneous TP
# invariants enforced during handshake validation.
remote_block_lens = list(self.block_len_per_layer)
tp_ratio = self.kv_info.tp_ratio(remote_tp_size=remote_tp_size)
if remote_tp_size > self.world_size:
# P TP > D TP case, block_len of remote is smaller
remote_block_lens = [
block_len // (-tp_ratio) for block_len in remote_block_lens
]
elif remote_tp_size < self.world_size:
remote_block_lens = [
block_len * tp_ratio for block_len in remote_block_lens
]
# When remote tp_size > local tp_size, handshake with multiple
# remote ranks.
num_handshakes = 1 if tp_ratio > 0 else -tp_ratio
remote_agents: dict[int, str] = {}
for remote_tp_rank in range(num_handshakes):
remote_agent_name = self.add_remote_agent(
NixlAgentMetadata(
engine_id=self.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
num_blocks=1,
block_lens=remote_block_lens,
attn_backend_name=self.backend_name,
# `self.kv_cache_layout` is only forced to HND when vllm engine
# is started. We mock HND here.
kv_cache_layout="HND",
),
remote_tp_rank=remote_tp_rank,
remote_tp_size=remote_tp_size,
)
remote_agents[remote_tp_rank] = remote_agent_name
return remote_agents
class TestNixlHandshake:
@ -353,7 +374,13 @@ class TestNixlHandshake:
vllm_config, connector.engine_id, hand_shake_latency=0
)
assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3)
worker = connector.connector_worker
worker.nixl_wrapper.set_cycles_before_xfer_done(3)
# simulate handshake
worker.dst_xfer_side_handles = {
FakeNixlConnectorWorker.REMOTE_ENGINE_ID: {0: 1}
}
worker.kv_cache_layout = "HND"
num_xfers = 4
while True:
# For the same request_id, initiate multiple xfers across different
@ -465,6 +492,70 @@ class TestNixlHandshake:
return
raise TimeoutError("Took too long to complete async handshake.")
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
)
@pytest.mark.parametrize("local_tp_size", [1, 2])
def test_prefill_tp_size_greater_than_decode_tp_size(
self, local_tp_size: int, dist_init
):
"""
Verify remote TP > local TP handshake succeeds with different
remote configurations.
"""
vllm_config = create_vllm_config()
local_tp_size = 1
vllm_config.parallel_config.tensor_parallel_size = local_tp_size
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0
)
worker = connector.connector_worker
# Minimal local registration params used by add_remote_agent
worker.slot_size_per_layer = [4096]
worker.block_len_per_layer = [4096 * worker.block_size]
worker.num_blocks = 1
worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
worker.src_blocks_data = [(0, worker.block_len_per_layer[0], worker.tp_rank)]
def check_handshake(remote_tp_size: int):
tp_ratio = remote_tp_size // local_tp_size
assert set(remote_agents.keys()) == set(range(tp_ratio))
remote_engine_id = worker.REMOTE_ENGINE_ID
assert worker._tp_size[remote_engine_id] == remote_tp_size
assert -tp_ratio == worker.kv_info.tp_ratio(remote_engine_id)
# ensure src_xfer_side_chunked_handles is populated with tp_ratio chunks
assert -tp_ratio in worker.src_xfer_side_chunked_handles
assert len(worker.src_xfer_side_chunked_handles[-tp_ratio]) == tp_ratio
assert remote_engine_id in worker.dst_xfer_side_handles
assert set(worker.dst_xfer_side_handles[remote_engine_id].keys()) == set(
range(tp_ratio)
)
remote_agents = worker._nixl_handshake(
host="localhost",
port=1234,
remote_tp_size=2,
expected_engine_id=worker.REMOTE_ENGINE_ID,
)
check_handshake(2)
# NOTE flexibility: a second remote with a higher number of ranks
# is discovered
worker.REMOTE_ENGINE_ID = "remote_engine_2"
remote_agents = worker._nixl_handshake(
host="localhost",
port=1234,
remote_tp_size=6,
expected_engine_id=worker.REMOTE_ENGINE_ID,
)
check_handshake(6)
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
@ -565,12 +656,9 @@ class TestNixlHandshake:
kv_cache_layout=mismatched_layout,
)
# We don't check layout for homogeneous TP and MLA for now, as the
# whole block is moved.
with pytest.raises(RuntimeError):
# mismatched layout is expected to fail
worker.add_remote_agent(meta, remote_tp_size=2)
# Layout check done for both homogeneous and heterogeneous TP.
with pytest.raises(AssertionError):
worker.add_remote_agent(meta, remote_tp_size=2)
worker.add_remote_agent(meta, remote_tp_size=1)
@patch(
@ -1180,7 +1268,8 @@ def test_shutdown_cleans_up_resources(dist_init):
):
worker._recving_transfers = {"req1": [(123, time.perf_counter())]}
worker.src_xfer_side_handle = 456
worker.dst_xfer_side_handles = {"engine1": 789}
worker.src_xfer_side_chunked_handles = {-2: [456]}
worker.dst_xfer_side_handles = {"engine1": {0: 789}}
worker._remote_agents = {"engine1": {0: "agent1"}}
worker._registered_descs = ["desc1", "desc2"]
@ -1194,7 +1283,7 @@ def test_shutdown_cleans_up_resources(dist_init):
mock_listener.join.assert_called_once_with(timeout=0)
mock_rel_xfer.assert_called_once_with(123)
assert mock_rel_dlist.call_count == 2
assert mock_rel_dlist.call_count == 3
mock_rel_dlist.assert_any_call(456) # src handle
mock_rel_dlist.assert_any_call(789) # dst handle
mock_rem_agent.assert_called_once_with("agent1")

View File

@ -41,14 +41,6 @@ class AttentionBackend(ABC):
# makes sure the output tensor is allocated inside the cudagraph.
accept_output_buffer: bool = False
# Whether this backend supports receiving pre-quantized query input.
# If True, the attention layer will handle query quantization instead
# of the backend, allowing torch.compile to fuse quantization with
# previous operations.
# Needs to be worked through for all backends
# https://github.com/vllm-project/vllm/issues/25584
supports_quant_query_input: bool = False
@staticmethod
@abstractmethod
def get_name() -> str:
@ -199,6 +191,22 @@ class AttentionImpl(ABC, Generic[T]):
"""
return False
def supports_quant_query_input(self) -> bool:
"""
Check if this attention implementation supports pre-quantized query input.
When True, the attention layer will quantize queries before passing them
to this backend, allowing torch.compile to fuse the quantization with
previous operations. This is typically supported when using FP8 KV cache
with compatible attention kernels (e.g., TRT-LLM).
TODO add support to more backends:
https://github.com/vllm-project/vllm/issues/25584
Returns:
bool: True if the implementation can accept pre-quantized queries.
"""
return False
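Since the flag moved from a class attribute on AttentionBackend to a method on AttentionImpl, backends now opt in by overriding the method: FlashAttention returns a constant True further down in this diff, while FlashInfer and Triton gate it on runtime conditions. A minimal standalone sketch of the override pattern (toy class names, not vLLM APIs):

class ToyAttentionImpl:
    # Mirrors the default added to AttentionImpl above.
    def supports_quant_query_input(self) -> bool:
        return False

class ToyFp8FriendlyImpl(ToyAttentionImpl):
    # A backend that can accept pre-quantized FP8 queries opts in here;
    # the attention layer then runs QuantFP8 on the query before forward().
    def supports_quant_query_input(self) -> bool:
        return True

assert ToyFp8FriendlyImpl().supports_quant_query_input()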
class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
@abstractmethod

View File

@ -36,6 +36,7 @@ from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.platforms import current_platform
from vllm.utils import GiB_bytes, direct_register_custom_op
FP8_DTYPE = current_platform.fp8_dtype()
logger = init_logger(__name__)
USE_XFORMERS_OPS = None
@ -304,7 +305,7 @@ class Attention(nn.Module, AttentionLayerBase):
self.query_quant = None
if (
self.kv_cache_dtype.startswith("fp8")
and self.attn_backend.supports_quant_query_input
and self.impl.supports_quant_query_input()
):
self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
@ -329,7 +330,6 @@ class Attention(nn.Module, AttentionLayerBase):
"""
if self.calculate_kv_scales:
torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)
output_dtype = query.dtype
if self.query_quant is not None:
# quantizing with a simple torch operation enables
@ -338,7 +338,10 @@ class Attention(nn.Module, AttentionLayerBase):
# Otherwise queries are quantized using custom ops
# which causes decoding overheads
assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
query, _ = self.query_quant(query, self._q_scale)
# check if query quantization is supported
if self.impl.supports_quant_query_input():
query, _ = self.query_quant(query, self._q_scale)
if self.use_output:
output_shape = output_shape if output_shape is not None else query.shape

View File

@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import contextlib
import copy
import logging
@ -36,7 +38,6 @@ from vllm.distributed.parallel_state import (
get_tensor_model_parallel_world_size,
get_tp_group,
)
from vllm.distributed.utils import divide
from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
from vllm.platforms import current_platform
@ -513,6 +514,88 @@ class NixlConnectorScheduler:
class NixlConnectorWorker:
"""Implementation of Worker side methods"""
@dataclass
class KVInfo:
tp_size: int
tp_rank: int
remote_tp_size: dict[EngineId, int]
is_mla: bool
total_num_kv_heads: int
def tp_ratio(
self,
remote_engine_id: Optional[EngineId] = None,
remote_tp_size: Optional[int] = None,
) -> int:
"""
Calculate the tensor parallel ratio between local and remote TP.
We can think of it as the number of local TP workers per remote TP
worker. Local workers will read from the same remote TP worker in
groups of size `tp_ratio`. If remote tp_size > local tp_size, the
ratio is flipped (remote_size/local_size) and the returned value is
negative.
"""
if remote_tp_size is None:
assert remote_engine_id is not None
remote_tp_size = self.remote_tp_size[remote_engine_id]
if self.tp_size >= remote_tp_size:
assert self.tp_size % remote_tp_size == 0, (
f"Local tensor parallel size {self.tp_size} is not divisible "
f"by remote tensor parallel size {remote_tp_size}."
)
return self.tp_size // remote_tp_size
else:
assert remote_tp_size % self.tp_size == 0, (
f"Remote tensor parallel size {remote_tp_size} is not divisible "
f"by local tensor parallel size {self.tp_size}."
)
# P TP > D TP case, return the ratio as negative
return -remote_tp_size // self.tp_size
def is_kv_replicated(
self, engine_id: Optional[EngineId] = None, tp_size: Optional[int] = None
) -> bool:
"""
Whether the KV cache is replicated across TP workers due to the
number of TP workers being greater than the number of KV heads.
"""
if tp_size is None:
assert engine_id is not None
tp_size = self.remote_tp_size[engine_id]
return tp_size // self.total_num_kv_heads >= 1
def replicates_kv_cache(
self,
remote_engine_id: Optional[EngineId] = None,
remote_tp_size: Optional[int] = None,
) -> bool:
# MLA is always replicated as the hidden dim can't be split.
return self.is_mla or self.is_kv_replicated(
remote_engine_id, remote_tp_size
)
def get_target_remote_ranks(
self,
remote_engine_id: Optional[EngineId] = None,
remote_tp_size: Optional[int] = None,
) -> list[int]:
"""
Get the remote TP rank (on P) that the current local TP rank
(on D) will read from. When remote tp_size > local tp_size, we
read from multiple remote ranks.
"""
tp_ratio = self.tp_ratio(remote_engine_id, remote_tp_size)
if tp_ratio > 0:
return [self.tp_rank // tp_ratio]
else:
# P TP > D TP case, D reads from |tp_ratio| remote workers.
tp_ratio = -tp_ratio
if self.replicates_kv_cache(remote_engine_id, remote_tp_size):
# When cache is replicated on remote, we only need to read
# from one remote (they all have the same cache).
return [self.tp_rank * tp_ratio]
return [self.tp_rank * tp_ratio + i for i in range(tp_ratio)]
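To make the sign convention concrete: with local TP 2 and remote TP 4, tp_ratio is -2 and local rank 1 reads from remote ranks 2 and 3; with local TP 4 and remote TP 2, tp_ratio is 2 and local rank 3 reads only from remote rank 1. A standalone restatement of the two helpers for the non-MLA, non-replicated case (hypothetical function names mirroring the logic above):

def tp_ratio(local_tp: int, remote_tp: int) -> int:
    # Positive: number of local (D) workers per remote (P) worker.
    # Negative: remote TP is larger; each local rank reads from |ratio| remote workers.
    if local_tp >= remote_tp:
        assert local_tp % remote_tp == 0
        return local_tp // remote_tp
    assert remote_tp % local_tp == 0
    return -(remote_tp // local_tp)

def target_remote_ranks(tp_rank: int, local_tp: int, remote_tp: int) -> list[int]:
    ratio = tp_ratio(local_tp, remote_tp)
    if ratio > 0:
        return [tp_rank // ratio]  # several D ranks share one P rank
    ratio = -ratio
    return [tp_rank * ratio + i for i in range(ratio)]  # one D rank, many P ranks

assert target_remote_ranks(tp_rank=3, local_tp=4, remote_tp=2) == [1]
assert target_remote_ranks(tp_rank=1, local_tp=2, remote_tp=4) == [2, 3]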
def __init__(self, vllm_config: VllmConfig, engine_id: str):
if NixlWrapper is None:
logger.error("NIXL is not available")
@ -601,8 +684,10 @@ class NixlConnectorWorker:
self.copy_blocks: CopyBlocksOp | None = None
# Map of engine_id -> kv_caches_base_addr. For TP case, each local
# rank will still only pull from a single remote TP worker.
self.kv_caches_base_addr: dict[EngineId, list[int]] = {}
# rank may pull from multiple remote TP workers.
self.kv_caches_base_addr: defaultdict[EngineId, dict[int, list[int]]] = (
defaultdict(dict)
)
# Number of NIXL regions. Currently one region per cache
# (so 1 per layer for MLA, otherwise 2 per layer)
@ -611,8 +696,13 @@ class NixlConnectorWorker:
# nixl_prepped_dlist_handle.
self.src_xfer_side_handle: int = 0
# Populated dynamically during handshake based on remote configuration.
# Keep track of regions at different tp_ratio values. tp_ratio->handles
self.src_xfer_side_chunked_handles: dict[int, list[int]] = {}
# Map of engine_id -> nixl_prepped_dlist_handle (int)].
self.dst_xfer_side_handles: dict[EngineId, int] = {}
self.dst_xfer_side_handles: defaultdict[EngineId, dict[int, int]] = defaultdict(
dict
)
# Map of engine_id -> num_blocks. All ranks in the same deployment will
# have the same number of blocks.
@ -646,7 +736,6 @@ class NixlConnectorWorker:
# Protects _handshake_futures and _remote_agents.
self._handshake_lock = threading.RLock()
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
@ -678,6 +767,14 @@ class NixlConnectorWorker:
self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
self.xfer_stats = NixlKVConnectorStats()
self.kv_info = self.KVInfo(
tp_size=self.world_size,
tp_rank=self.tp_rank,
remote_tp_size=self._tp_size, # shared state
is_mla=self.use_mla,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
)
@staticmethod
def _nixl_handshake_listener(
metadata: NixlAgentMetadata,
@ -717,52 +814,53 @@ class NixlConnectorWorker:
start_time = time.perf_counter()
# NOTE(rob): we need each rank to have a unique port. This is
# a hack to keep us moving. We will switch when moving to etcd
# or where we have a single ZMQ socket in the scheduler.
# Handshake only with the remote TP rank that current local rank will
# pull from. With homogeneous TP it happens to be the same rank_i.
tp_ratio = self._tp_size[self.engine_id] // remote_tp_size
p_remote_rank = self.tp_rank // tp_ratio
path = make_zmq_path("tcp", host, port + p_remote_rank)
logger.debug(
"Querying metadata on path: %s at remote rank %s", path, p_remote_rank
# When target instance TP > local TP, we need to perform multiple
# handshakes. Do it in a single background job for simplicity.
# Regardless, only handshake with the remote TP rank(s) that the current
# local rank will read from. Note that with homogeneous TP,
# this happens to be the same single rank_i.
p_remote_ranks = self.kv_info.get_target_remote_ranks(
remote_tp_size=remote_tp_size
)
# Send query for the request.
with zmq_ctx(zmq.REQ, path) as sock:
# Set receive timeout to 5 seconds to avoid hanging on dead server
sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds
sock.send(GET_META_MSG)
metadata_bytes = sock.recv()
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
metadata = decoder.decode(metadata_bytes)
got_metadata_time = time.perf_counter()
logger.debug(
"NIXL handshake: get metadata took: %s", got_metadata_time - start_time
remote_rank_to_agent_name = {}
for remote_rank in p_remote_ranks:
path = make_zmq_path("tcp", host, port + remote_rank)
logger.warning(
"Querying metadata on path: %s at remote rank %s", path, remote_rank
)
# Ensure engine id matches.
if metadata.engine_id != expected_engine_id:
raise RuntimeError(
f"Remote NIXL agent engine ID mismatch. "
f"Expected {expected_engine_id},"
f"received {metadata.engine_id}."
# Send query for the request.
with zmq_ctx(zmq.REQ, path) as sock:
# Set receive timeout to 5 seconds to avoid hanging on dead server
sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds
sock.send(GET_META_MSG)
metadata_bytes = sock.recv()
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
metadata = decoder.decode(metadata_bytes)
got_metadata_time = time.perf_counter()
logger.debug(
"NIXL handshake: get metadata took: %s", got_metadata_time - start_time
)
# Register Remote agent.
remote_agent_name = self.add_remote_agent(
metadata, p_remote_rank, remote_tp_size
)
setup_agent_time = time.perf_counter()
logger.debug(
"NIXL handshake: add agent took: %s",
setup_agent_time - got_metadata_time,
)
# Ensure engine id matches.
if metadata.engine_id != expected_engine_id:
raise RuntimeError(
f"Remote NIXL agent engine ID mismatch. "
f"Expected {expected_engine_id},"
f"received {metadata.engine_id}."
)
# Remote rank -> agent name.
return {p_remote_rank: remote_agent_name}
# Register Remote agent.
remote_agent_name = self.add_remote_agent(
metadata, remote_rank, remote_tp_size
)
setup_agent_time = time.perf_counter()
logger.debug(
"NIXL handshake: add agent took: %s",
setup_agent_time - got_metadata_time,
)
remote_rank_to_agent_name[remote_rank] = remote_agent_name
return remote_rank_to_agent_name
def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> None:
"""
@ -916,7 +1014,7 @@ class NixlConnectorWorker:
assert len(self.block_len_per_layer) == len(seen_base_addresses)
assert self.num_blocks != 0
self.kv_caches_base_addr[self.engine_id] = seen_base_addresses
self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses
self.num_regions = len(caches_data)
self.num_layers = len(xfer_buffers.keys())
@ -942,7 +1040,7 @@ class NixlConnectorWorker:
self.num_regions *= 2
# Register local/src descr for NIXL xfer.
blocks_data = []
self.src_blocks_data = []
for i, base_addr in enumerate(seen_base_addresses):
kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
# NOTE With heter-TP, more blocks are prepared than what are
@ -954,7 +1052,7 @@ class NixlConnectorWorker:
block_offset = block_id * self.block_len_per_layer[i]
addr = base_addr + block_offset
# (addr, len, device id)
blocks_data.append((addr, kv_block_len, self.tp_rank))
self.src_blocks_data.append((addr, kv_block_len, self.tp_rank))
if self._use_flashinfer:
# Separate and interleave K/V regions to maintain the same
@ -965,15 +1063,17 @@ class NixlConnectorWorker:
addr = base_addr + block_offset
# Register addresses for V cache (K registered first).
v_addr = addr + kv_block_len
blocks_data.append((v_addr, kv_block_len, self.tp_rank))
self.src_blocks_data.append((v_addr, kv_block_len, self.tp_rank))
logger.debug(
"Created %s blocks for src engine %s and rank %s",
len(blocks_data),
len(self.src_blocks_data),
self.engine_id,
self.tp_rank,
)
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
descs = self.nixl_wrapper.get_xfer_descs(
self.src_blocks_data, self.nixl_memory_type
)
# NIXL_INIT_AGENT to be used for preparations of local descs.
self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
"NIXL_INIT_AGENT", descs
@ -981,13 +1081,11 @@ class NixlConnectorWorker:
# TODO(mgoin): Hybrid memory allocator is currently disabled for
# models with local attention (Llama 4). Can remove this once enabled.
if self.vllm_config.model_config.hf_config.model_type == "llama4":
if self.model_config.hf_config.model_type == "llama4":
from transformers import Llama4TextConfig
assert isinstance(
self.vllm_config.model_config.hf_text_config, Llama4TextConfig
)
llama4_config = self.vllm_config.model_config.hf_text_config
assert isinstance(self.model_config.hf_text_config, Llama4TextConfig)
llama4_config = self.model_config.hf_text_config
no_rope_layers = llama4_config.no_rope_layers
chunk_size = llama4_config.attention_chunk_size
chunk_block_size = math.ceil(chunk_size / self.block_size)
@ -1007,7 +1105,7 @@ class NixlConnectorWorker:
metadata = NixlAgentMetadata(
engine_id=self.engine_id,
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id][self.tp_rank],
num_blocks=self.num_blocks,
block_lens=self.block_len_per_layer,
attn_backend_name=self.backend_name,
@ -1035,10 +1133,12 @@ class NixlConnectorWorker:
In particular, handle both homogeneous and heterogeneous TP. The former
requires local rank_i to read from remote rank_i.
The latter, assuming D.world_size > P.world_size, requires that two or
more local TP worker share the xfer from a single TP worker.
The latter, in the case of D.world_size < P.world_size, requires that a
local (D) TP worker reads from multiple remote (P) TP workers.
Conversely, assuming D.world_size > P.world_size, two or more local TP
workers will read from a single remote TP worker.
Here's an example (non-MLA case):
Here's an example for the last case described above (non-MLA):
rank_offset p_remote_tp_rank
(kv split no)
@ -1070,107 +1170,91 @@ class NixlConnectorWorker:
engine_id = nixl_agent_meta.engine_id
# TODO re-evaluate refreshing for scaling/recovery
if remote_tp_rank in self._remote_agents.get(engine_id, {}):
logger.warning(
"Remote agent with engine_id %s and rank"
"%s already exchanged metadata, skip handshake.",
engine_id,
remote_tp_rank,
)
return self._remote_agents[engine_id][remote_tp_rank]
### Register remote agent metadata
if engine_id not in self._tp_size:
self._tp_size[engine_id] = remote_tp_size
else:
assert self._tp_size[engine_id] == remote_tp_size
# TODO We may eventually want to skip enforcing the same attn backend.
assert nixl_agent_meta.attn_backend_name == self.backend_name
remote_agent_name = self.nixl_wrapper.add_remote_agent(
nixl_agent_meta.agent_metadata
)
# Number of D TP workers reading from a single P TP worker. This is
# 1 when P and D `--tensor-parallel-size` match.
tp_ratio = divide(self._tp_size[self.engine_id], self._tp_size[engine_id])
assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
assert not self._use_pallas or tp_ratio == 1, (
"TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
)
# Handle tp_size>num_kv_heads: replicate KV cache.
total_num_kv_heads = self.model_config.get_total_num_kv_heads()
is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1
remote_block_len = nixl_agent_meta.block_lens[0]
if nixl_agent_meta.kv_cache_layout != self.kv_cache_layout:
if (
self.vllm_config.kv_transfer_config is not None
and self.vllm_config.kv_transfer_config.enable_permute_local_kv
and nixl_agent_meta.kv_cache_layout == "HND"
):
logger.info(
"Remote is HND and local is NHD, enabled additional permute "
"on local device KV."
)
self.enable_permute_local_kv = True
else:
raise RuntimeError(
"Heterogeneous TP expects same kv_cache_layout. "
"Or enable experimental feature to use HND to NHD support by "
"setting 'enable_permute_local_kv'=True in --kv-transfer-config."
)
if self.use_mla or is_kv_replicated:
# With replicated KV cache, only the number of blocks can differ.
assert self.block_len_per_layer == nixl_agent_meta.block_lens, (
"KV cache sizes must match between P and D when replicated"
)
remote_block_size = remote_block_len // (self.slot_size_per_layer[0])
else:
# When MLA is not used, this is a list of the same block length
for block_len in nixl_agent_meta.block_lens:
assert block_len == remote_block_len, (
"All remote layers must have the same block size"
)
remote_block_size = remote_block_len // (
self.slot_size_per_layer[0] * tp_ratio
)
if self._use_flashinfer:
# With flashinfer, KV are sent in the same message.
remote_block_size //= 2
if tp_ratio > 1:
# Heterogeneous TP expects same kv_cache_layout.
if nixl_agent_meta.kv_cache_layout == "NHD":
raise ValueError(
"Heterogeneous TP is not supported for remote with NHD."
)
if self.device_type == "xpu":
raise ValueError("Heterogeneous TP is not supported on XPU")
assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
)
assert self.block_size == remote_block_size, (
"Remote P worker with different page/block size is not supported "
f"{self.block_size=}, {remote_block_size=}"
)
# Create dst descs and xfer side handles. TP workers have same #blocks.
if engine_id in self.dst_num_blocks:
assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks
else:
# Create dst descs and xfer side handles. TP workers have same #blocks
# so we only register once per engine_id.
if engine_id not in self.dst_num_blocks:
self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
# Keep track of remote agent kv caches base addresses.
self.kv_caches_base_addr[engine_id][remote_tp_rank] = (
nixl_agent_meta.kv_caches_base_addr
)
self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size)
# This is 1 when P and D `--tensor-parallel-size` match. Otherwise,
# this is the ratio between the two sizes.
tp_ratio = self.kv_info.tp_ratio(engine_id)
# Handle tp_size>num_kv_heads: replicate KV cache.
indexes_into_remote = (
not self.kv_info.replicates_kv_cache(engine_id) and tp_ratio > 0
)
logger.debug(
"Registering remote agent (%s, rank %s) memory regions with tp_ratio %s",
engine_id,
remote_tp_rank,
tp_ratio,
)
### (Optional) Register local agent memory regions.
# MLA-optimization: only prepare one region.
if (
tp_ratio < 0
and not self.use_mla
and tp_ratio not in self.src_xfer_side_chunked_handles
):
# Remote tp_size > local tp_size: read from multiple remote ranks.
# Logically "split" own regions into |tp_ratio| chunks. Mind that
# we only do this once per remote tp_size (replica-friendly).
self.src_xfer_side_chunked_handles[tp_ratio] = []
for i in range(-tp_ratio):
blocks_data = []
for memory_region in self.src_blocks_data:
addr, local_block_len, own_tp_rank = memory_region
# Computing block len layer by layer allows for different
# block sizes to be used.
remote_block_len = local_block_len // (-tp_ratio)
addr = addr + i * remote_block_len
blocks_data.append((addr, remote_block_len, own_tp_rank))
descs = self.nixl_wrapper.get_xfer_descs(
blocks_data, self.nixl_memory_type
)
handle = self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs)
self.src_xfer_side_chunked_handles[tp_ratio].append(handle)
### Register remote agent memory regions
blocks_data = []
# With homogeneous TP, D pulls the whole kv cache from corresponding
# rank. With heterogeneous TP, prepare the descriptors by splitting the
# P KV cache along kv_head dim, of D worker's kv_head size (D>P).
# Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr
assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer)
# Register all remote blocks, but only the corresponding kv heads.
for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
# Read our whole local region size from remote.
local_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
if tp_ratio < 0 and not self.use_mla:
# Remote tp is bigger: read a chunk of local region from remote
local_block_len = local_block_len // (-tp_ratio)
rank_offset = (
self.tp_rank % tp_ratio * kv_block_len
if not (self.use_mla or is_kv_replicated)
else 0
self.tp_rank % tp_ratio * local_block_len if indexes_into_remote else 0
)
for block_id in range(nixl_agent_meta.num_blocks):
block_offset = block_id * nixl_agent_meta.block_lens[i]
@ -1179,7 +1263,7 @@ class NixlConnectorWorker:
# self.block_len == remote_block_len//tp_ratio bytes.
addr = base_addr + block_offset + rank_offset
# (addr, len, device id)
blocks_data.append((addr, kv_block_len, remote_tp_rank))
blocks_data.append((addr, local_block_len, remote_tp_rank))
if self._use_flashinfer:
# With FlashInfer index V separately to allow head splitting.
@ -1187,7 +1271,7 @@ class NixlConnectorWorker:
block_offset = block_id * nixl_agent_meta.block_lens[i]
addr = base_addr + block_offset + rank_offset
v_addr = addr + nixl_agent_meta.block_lens[i] // 2
blocks_data.append((v_addr, kv_block_len, remote_tp_rank))
blocks_data.append((v_addr, local_block_len, remote_tp_rank))
logger.debug(
"Created %s blocks for dst engine %s with remote rank %s and local rank %s",
@ -1199,12 +1283,87 @@ class NixlConnectorWorker:
# Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
self.dst_xfer_side_handles[engine_id] = self.nixl_wrapper.prep_xfer_dlist(
remote_agent_name, descs
self.dst_xfer_side_handles[engine_id][remote_tp_rank] = (
self.nixl_wrapper.prep_xfer_dlist(remote_agent_name, descs)
)
return remote_agent_name
def _validate_remote_agent_handshake(
self, nixl_agent_meta: NixlAgentMetadata, remote_tp_size: int
):
"""
Validate the remote agent handshake metadata ensuring the
invariants hold true.
"""
remote_engine_id = nixl_agent_meta.engine_id
assert self._tp_size[remote_engine_id] == remote_tp_size
# TODO We may eventually want to skip enforcing the same attn backend.
assert nixl_agent_meta.attn_backend_name == self.backend_name
assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout
tp_ratio = self.kv_info.tp_ratio(remote_engine_id)
assert not self._use_pallas or tp_ratio == 1, (
"TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
)
# Replicated remote KV (remote tp_size >= num_kv_heads) with P TP > D TP is not supported
assert not (tp_ratio < 0 and self.kv_info.is_kv_replicated(remote_engine_id))
# Block len can only vary across layers when using MLA.
remote_block_len = nixl_agent_meta.block_lens[0]
if self.kv_info.replicates_kv_cache(remote_engine_id):
# With replicated KV cache, only the number of blocks can differ.
assert self.block_len_per_layer == nixl_agent_meta.block_lens, (
"KV cache sizes must match between P and D when replicated"
)
remote_block_size = remote_block_len // (self.slot_size_per_layer[0])
else:
if tp_ratio != 1 and self.device_type == "xpu":
# XPU uses NHD, hence it does not support splitting on H
raise ValueError("Heterogeneous TP is not supported on XPU")
# When MLA is not used, this is a list of the same block length
for block_len in nixl_agent_meta.block_lens:
assert block_len == remote_block_len, (
"All remote layers must have the same block size"
)
if tp_ratio > 0:
# Remote NHD/H'D*tp_ratio=N -page_size-
remote_block_size = remote_block_len // (
self.slot_size_per_layer[0] * tp_ratio
)
# Remote tp is smaller: remote block_len size is bigger
assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads*tp_ratio, page_size, head_dim] and same dtype."
) # noqa: E501
else:
# Remote NHD/(H'D/tp_ratio)=N -page_size-
remote_block_size = remote_block_len // (
self.slot_size_per_layer[0] // (-tp_ratio)
)
# Remote tp is bigger: remote block_len size is smaller
assert remote_block_len == self.block_len_per_layer[0] // (-tp_ratio), (
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads/tp_ratio, page_size, head_dim] and same dtype."
) # noqa: E501
if self._use_flashinfer:
# With flashinfer, KV are sent in the same message.
remote_block_size //= 2
# We may allow it in the future with logical kvcache manager block_size
assert self.block_size == remote_block_size, (
"Remote P worker with different page/block size is not supported "
f"{self.block_size=}, {remote_block_size=}"
)
# TP workers (handshakes with the same remote) have the same #blocks.
assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks
# Same number of regions/~layers.
assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer)
def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta):
"""copy recved kv from host buffer to device."""
assert self.use_host_buffer
@ -1384,7 +1543,7 @@ class NixlConnectorWorker:
"""
done_req_ids: set[str] = set()
for req_id, handles in list(transfers.items()):
in_progress = False
in_progress = []
for handle, _xfer_stime in handles:
xfer_state = self.nixl_wrapper.check_xfer_state(handle)
if xfer_state == "DONE":
@ -1393,7 +1552,7 @@ class NixlConnectorWorker:
self.xfer_stats.record_transfer(res)
self.nixl_wrapper.release_xfer_handle(handle)
elif xfer_state == "PROC":
in_progress = True
in_progress.append((handle, _xfer_stime))
continue
else:
# transfer failed - mark blocks as invalid
@ -1410,8 +1569,11 @@ class NixlConnectorWorker:
self.nixl_wrapper.release_xfer_handle(handle)
self.xfer_stats.record_failed_transfer()
if not in_progress:
# Only report request as completed when all transfers are done.
done_req_ids.add(req_id)
del transfers[req_id]
else:
transfers[req_id] = in_progress
return done_req_ids
def start_load_kv(self, metadata: NixlConnectorMetadata):
@ -1466,17 +1628,37 @@ class NixlConnectorWorker:
self._reqs_to_send[req_id] = expiration_time
def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
logger.debug(
"Remote agent %s available, calling _read_blocks for req %s",
meta.remote_engine_id,
req_id,
)
self._read_blocks(
request_id=req_id,
dst_engine_id=meta.remote_engine_id,
local_block_ids=meta.local_block_ids,
remote_block_ids=meta.remote_block_ids,
)
remote_ranks = self.kv_info.get_target_remote_ranks(meta.remote_engine_id)
tp_ratio = self.kv_info.tp_ratio(meta.remote_engine_id)
# D may have to perform multiple reads from different remote ranks.
for i, remote_rank in enumerate(remote_ranks):
logger.debug(
"Remote agent %s available, calling _read_blocks"
" on remote rank %s for req %s",
meta.remote_engine_id,
remote_rank,
req_id,
)
if tp_ratio < 0 and not self.use_mla:
# Remote tp_size > local tp_size: we must perform multiple
# reads. Get the memory chunk we will write to.
local_xfer_side_handle = self.src_xfer_side_chunked_handles[tp_ratio][i]
else:
# Single read from remote, we write to the whole memory region.
local_xfer_side_handle = self.src_xfer_side_handle
# Destination handle: remote_engine_id -> remote_rank -> handle.
remote_xfer_side_handle = self.dst_xfer_side_handles[meta.remote_engine_id][
remote_rank
]
self._read_blocks(
request_id=req_id,
dst_engine_id=meta.remote_engine_id,
local_block_ids=meta.local_block_ids,
remote_block_ids=meta.remote_block_ids,
remote_rank=remote_rank,
local_xfer_side_handle=local_xfer_side_handle,
remote_xfer_side_handle=remote_xfer_side_handle,
)
def _read_blocks(
self,
@ -1484,7 +1666,14 @@ class NixlConnectorWorker:
remote_block_ids: list[int],
dst_engine_id: str,
request_id: str,
remote_rank: int,
local_xfer_side_handle: int,
remote_xfer_side_handle: int,
):
"""
Post a READ xfer request from a single local worker to a single
remote worker.
"""
# NOTE(rob): having the staging blocks be on the READER side is
# not going to work well (since we will have to call rearrange tensors).
# after we detect the txn is complete (which means we cannot make the
@ -1497,14 +1686,14 @@ class NixlConnectorWorker:
# Number of D TP workers that will read from dst P. Propagate tp_ratio
# on notification so that dst worker can wait before freeing blocks.
tp_ratio = self._tp_size[self.engine_id] // self._tp_size[dst_engine_id]
# Cap to 1 when P TP > D TP: only a single rank will read from remote.
tp_ratio = max(1, self.kv_info.tp_ratio(dst_engine_id))
notif_id = f"{request_id}:{tp_ratio}".encode()
# Full prefix cache hit: do not need to read remote blocks,
# just notify P worker that we have the blocks we need.
num_local_blocks = len(local_block_ids)
if num_local_blocks == 0:
remote_rank = self.tp_rank // tp_ratio
agent_name = self._remote_agents[dst_engine_id][remote_rank]
try:
self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
@ -1524,10 +1713,6 @@ class NixlConnectorWorker:
if num_local_blocks < num_remote_blocks:
remote_block_ids = remote_block_ids[-num_local_blocks:]
# Get side handles.
local_xfer_side_handle = self.src_xfer_side_handle
remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
# corresponding rank. With heterogeneous TP, fixing D>P, the D tp
# workers will issue xfers to parts of the P worker remote kv caches.
@ -1680,15 +1865,20 @@ class NixlConnectorWorker:
if self._nixl_handshake_listener_t is not None:
self._nixl_handshake_listener_t.join(timeout=0)
self._nixl_handshake_listener_t = None
for handles in self._recving_transfers.values():
for handle, _ in handles:
for rcv_handles in self._recving_transfers.values():
for handle, _ in rcv_handles:
self.nixl_wrapper.release_xfer_handle(handle)
self._recving_transfers.clear()
if self.src_xfer_side_handle:
self.nixl_wrapper.release_dlist_handle(self.src_xfer_side_handle)
self.src_xfer_side_handle = 0
for dst_xfer_side_handle in self.dst_xfer_side_handles.values():
self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle)
for handles in self.src_xfer_side_chunked_handles.values():
for handle in handles:
self.nixl_wrapper.release_dlist_handle(handle)
self.src_xfer_side_chunked_handles.clear()
for dst_xfer_side_handles in self.dst_xfer_side_handles.values():
for dst_xfer_side_handle in dst_xfer_side_handles.values():
self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle)
self.dst_xfer_side_handles.clear()
for remote_agents in self._remote_agents.values():
for agent_name in remote_agents.values():

View File

@ -386,6 +386,24 @@ async def get_server_load_metrics(request: Request):
return JSONResponse(content={"server_load": request.app.state.server_load_metrics})
@router.get("/get_server_info")
async def get_server_info(raw_request: Request):
"""Returns server information including DP size for router"""
config = raw_request.app.state.vllm_config
# Extract dp_size from parallel_config
dp_size = 1 # Default value
if hasattr(config, 'parallel_config') and hasattr(config.parallel_config, 'data_parallel_size'):
dp_size = config.parallel_config.data_parallel_size
server_info = {
"vllm_config": str(config),
"dp_size": dp_size
}
return JSONResponse(content=server_info)
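A router can hit this endpoint first to learn how many data-parallel ranks exist before choosing X-data-parallel-rank values. A hedged consumer-side sketch (base URL is a placeholder):

import requests

info = requests.get("http://localhost:8000/get_server_info").json()
dp_size = info.get("dp_size", 1)
# Valid X-data-parallel-rank header values are then 0 .. dp_size - 1.
print(f"server exposes {dp_size} data-parallel ranks")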
@router.get("/ping", response_class=Response)
@router.post("/ping", response_class=Response)
async def ping(raw_request: Request) -> Response:

View File

@ -264,6 +264,9 @@ class OpenAIServingChat(OpenAIServing):
if raw_request:
raw_request.state.request_metadata = request_metadata
# Extract data_parallel_rank from header (router can inject it)
data_parallel_rank = self._get_data_parallel_rank(raw_request)
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
@ -331,6 +334,7 @@ class OpenAIServingChat(OpenAIServing):
priority=request.priority,
prompt_text=prompt_text,
tokenization_kwargs=tokenization_kwargs,
data_parallel_rank=data_parallel_rank,
)
generators.append(generator)

View File

@ -141,6 +141,10 @@ class OpenAIServingCompletion(OpenAIServing):
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
# Extract data_parallel_rank from header (router can inject it)
data_parallel_rank = self._get_data_parallel_rank(raw_request)
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
@ -224,6 +228,7 @@ class OpenAIServingCompletion(OpenAIServing):
priority=request.priority,
prompt_text=prompt_text,
tokenization_kwargs=tokenization_kwargs,
data_parallel_rank=data_parallel_rank,
)
generators.append(generator)

View File

@ -1297,6 +1297,21 @@ class OpenAIServing:
return raw_request.headers.get("X-Request-Id", default)
@staticmethod
def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
"""Pulls the data parallel rank from a header, if provided"""
if raw_request is None:
return None
rank_str = raw_request.headers.get("X-data-parallel-rank")
if rank_str is None:
return None
try:
return int(rank_str)
except ValueError:
return None
@staticmethod
def _get_decoded_token(
logprob: Logprob,

View File

@ -191,13 +191,14 @@ if TYPE_CHECKING:
VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
VLLM_ENABLE_RESPONSES_API_STORE: bool = False
VLLM_USE_TRTLLM_ATTENTION: str | None = None
VLLM_NVFP4_GEMM_BACKEND: str | None = None
VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION: bool = False
VLLM_HAS_FLASHINFER_CUBIN: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
VLLM_ROCM_FP8_MFMA_PAGE_ATTN: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: bool = False
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = True
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
VLLM_TUNED_CONFIG_FOLDER: str | None = None
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
@ -1292,11 +1293,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
# If set, it means we pre-downloaded cubin files and flashinfer will
# read the cubin files directly.
"VLLM_HAS_FLASHINFER_CUBIN": lambda: os.getenv("VLLM_HAS_FLASHINFER_CUBIN", False),
# If set to 1, force the use of TRTLLM FP4 GEMM backend in flashinfer.
# Otherwise, uses the first available of: flashinfer cutlass GEMM,
# vllm cutlass GEMM, marlin GEMM.
"VLLM_USE_TRTLLM_FP4_GEMM": lambda: bool(
int(os.getenv("VLLM_USE_TRTLLM_FP4_GEMM", "0"))
# Supported options:
# - "flashinfer-cudnn": use flashinfer cudnn GEMM backend
# - "flashinfer-trtllm": use flashinfer trtllm GEMM backend
# - "flashinfer-cutlass": use flashinfer cutlass GEMM backend
# - <none>: automatically pick an available backend
"VLLM_NVFP4_GEMM_BACKEND": env_with_choices(
"VLLM_NVFP4_GEMM_BACKEND",
None,
["flashinfer-cudnn", "flashinfer-trtllm", "flashinfer-cutlass"],
),
# Controls garbage collection during CUDA graph capture.
# If set to 0 (default), enables GC freezing to speed up capture time.
@ -1338,7 +1343,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
),
# Whether to use pytorch symmetric memory for allreduce
"VLLM_ALLREDUCE_USE_SYMM_MEM": lambda: bool(
int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1"))
int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0"))
),
# Allows vllm to find tuned config under customized folder
"VLLM_TUNED_CONFIG_FOLDER": lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None),
@ -1492,7 +1497,6 @@ def compute_hash() -> str:
"VLLM_DISABLED_KERNELS",
"VLLM_USE_DEEP_GEMM",
"VLLM_USE_DEEP_GEMM_E8M0",
"VLLM_USE_TRTLLM_FP4_GEMM",
"VLLM_USE_FUSED_MOE_GROUPED_TOPK",
"VLLM_USE_FLASHINFER_MOE_FP16",
"VLLM_USE_FLASHINFER_MOE_FP8",
@ -1524,6 +1528,7 @@ def compute_hash() -> str:
"VLLM_ROCM_FP8_MFMA_PAGE_ATTN",
"VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE",
"VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING",
"VLLM_NVFP4_GEMM_BACKEND",
"VLLM_USE_FBGEMM",
]
for key in environment_variables_to_hash:

View File

@ -14,7 +14,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
run_nvfp4_emulations,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import swizzle_blockscale
from vllm.model_executor.layers.quantization.utils.quant_utils import (
cutlass_fp4_supported,
swizzle_blockscale,
)
from vllm.model_executor.parameter import (
GroupQuantScaleParameter,
ModelWeightParameter,
@ -29,10 +32,12 @@ __all__ = ["CompressedTensorsW4A4Fp4"]
class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
def __init__(self):
if envs.VLLM_USE_TRTLLM_FP4_GEMM:
assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
self.backend = "flashinfer-trtllm"
logger.info_once("Using flashinfer-trtllm for FP4")
self.backend = "none"
if envs.VLLM_NVFP4_GEMM_BACKEND is None:
if has_flashinfer():
self.backend = "flashinfer-cutlass"
elif cutlass_fp4_supported():
self.backend = "cutlass"
elif envs.VLLM_USE_FBGEMM:
self.backend = "fbgemm"
try:
@ -42,12 +47,17 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
"Backend fbgemm requires fbgemm.f4f4bf16 operator, "
"Please install with: pip install fbgemm-gpu-genai"
) from exc
logger.info_once("Using FGBEMM-GPU-GENAI for FP4")
elif has_flashinfer():
self.backend = "flashinfer-cutlass"
logger.info_once("Using flashinfer-cutlass for FP4")
else:
self.backend = "cutlass"
elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
if self.backend == "none":
raise ValueError(
"No valid NVFP4 GEMM backend found. "
"Please check your platform capability."
)
logger.info_once(f"Using {self.backend} for NVFP4 GEMM")
self.group_size = 16
@classmethod
@ -184,10 +194,9 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
layer.alpha,
output_dtype,
)
if self.backend == "flashinfer-trtllm":
out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
elif self.backend == "flashinfer-cutlass":
out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
if self.backend.startswith("flashinfer-"):
backend_name = self.backend[len("flashinfer-") :]
out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
elif self.backend == "fbgemm":
out = torch.ops.fbgemm.f4f4bf16(
x_fp4,
@ -198,6 +207,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
use_mx=False,
).to(output_dtype)
else:
assert self.backend == "cutlass"
out = cutlass_scaled_fp4_mm(*mm_args)
if bias is not None:

View File

@ -926,22 +926,26 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
self.quant_config = quant_config
if envs.VLLM_USE_TRTLLM_FP4_GEMM:
assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
self.backend = "flashinfer-trtllm"
elif has_flashinfer():
self.backend = "flashinfer-cutlass"
elif cutlass_fp4_supported():
self.backend = "cutlass"
elif is_fp4_marlin_supported():
self.backend = "marlin"
else:
self.backend = "none"
if envs.VLLM_NVFP4_GEMM_BACKEND is None:
if has_flashinfer():
self.backend = "flashinfer-cutlass"
elif cutlass_fp4_supported():
self.backend = "cutlass"
elif is_fp4_marlin_supported():
self.backend = "marlin"
elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
if self.backend == "none":
raise ValueError(
"Current platform does not support NVFP4"
" quantization. Please use Blackwell and"
" above."
"No valid NVFP4 GEMM backend found. "
"Please check your platform capability."
)
logger.info_once(f"Using {self.backend} for NVFP4 GEMM")
def create_weights(
self,
layer: torch.nn.Module,
@ -1109,11 +1113,11 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
layer.alpha,
output_dtype,
)
if self.backend == "flashinfer-trtllm":
out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
elif self.backend == "flashinfer-cutlass":
out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
if self.backend.startswith("flashinfer-"):
backend_name = self.backend[len("flashinfer-") :]
out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
else:
assert self.backend == "cutlass"
out = cutlass_scaled_fp4_mm(*mm_args)
if bias is not None:
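Both quantization methods above now share the same backend-selection scheme: with VLLM_NVFP4_GEMM_BACKEND unset, the first available of FlashInfer CUTLASS, vLLM CUTLASS, then FBGEMM (compressed-tensors) or Marlin (ModelOpt) is chosen; with a flashinfer-* value set, FlashInfer must be installed. A simplified sketch of that precedence for the ModelOpt path (stand-in predicate arguments, not vLLM APIs):

def pick_nvfp4_backend(env_backend, has_flashinfer, cutlass_ok, marlin_ok):
    if env_backend is None:
        # Auto-selection: prefer FlashInfer, then CUTLASS, then Marlin.
        if has_flashinfer:
            return "flashinfer-cutlass"
        if cutlass_ok:
            return "cutlass"
        if marlin_ok:
            return "marlin"
        raise ValueError("No valid NVFP4 GEMM backend found.")
    # Explicit flashinfer-cudnn / flashinfer-trtllm / flashinfer-cutlass choice.
    assert env_backend.startswith("flashinfer-") and has_flashinfer
    return env_backend

assert pick_nvfp4_backend(None, False, True, True) == "cutlass"
assert pick_nvfp4_backend("flashinfer-trtllm", True, False, False) == "flashinfer-trtllm"

At matmul time the flashinfer- prefix is stripped and passed to flashinfer_scaled_fp4_mm as the backend argument, which is why a single code path now serves the trtllm, cutlass and cudnn variants.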

View File

@ -32,7 +32,7 @@ def is_flashinfer_fp4_cutlass_moe_available() -> bool:
envs.VLLM_USE_FLASHINFER_MOE_FP4
and has_flashinfer_cutlass_fused_moe()
and current_platform.is_cuda()
and current_platform.is_device_capability(100)
and current_platform.has_device_capability(100)
)

View File

@ -49,7 +49,6 @@ logger = init_logger(__name__)
class FlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supports_quant_query_input: bool = True
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
@ -494,6 +493,9 @@ class FlashAttentionImpl(AttentionImpl):
"heads in the layer"
)
def supports_quant_query_input(self) -> bool:
return True
def forward(
self,
layer: torch.nn.Module,

View File

@ -16,7 +16,6 @@ from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
from flashinfer.utils import FP4Tensor
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
@ -828,6 +827,12 @@ class FlashInferImpl(AttentionImpl):
and quant_key in (kFp8StaticTensorSym, kNvfp4Quant)
)
def supports_quant_query_input(self) -> bool:
if flashinfer_disable_q_quantization():
return False
return self.support_trtllm_attn
def forward(
self,
layer: torch.nn.Module,
@ -859,6 +864,12 @@ class FlashInferImpl(AttentionImpl):
# Profiling run.
return output.fill_(0)
# Ensure query dtype matches the expected dtype from attention metadata
assert attn_metadata.q_data_type == query.dtype, (
f"Query dtype mismatch: expected {attn_metadata.q_data_type}, "
f"got {query.dtype}"
)
if self.bmm1_scale is None:
self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale
@ -899,15 +910,6 @@ class FlashInferImpl(AttentionImpl):
elif output.dtype == FP4_DTYPE:
self.o_sf_scale = layer._o_scale_float
# Insert FP8 quant for query
if attn_metadata.q_data_type == FP8_DTYPE:
num_tokens, num_heads, head_size = query.shape
query, _ = ops.scaled_fp8_quant(
query.reshape((num_tokens, num_heads * head_size)).contiguous(),
layer._q_scale,
)
query = query.reshape((num_tokens, num_heads, head_size))
# IMPORTANT!
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead

View File

@ -32,11 +32,6 @@ from vllm.v1.attention.backends.utils import (
)
from vllm.v1.kv_cache_interface import AttentionSpec
if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops as ops
logger = init_logger(__name__)
@ -210,6 +205,9 @@ class TritonAttentionImpl(AttentionImpl):
def fused_output_quant_supported(self, quant_key: QuantKey):
return quant_key == kFp8StaticTensorSym
def supports_quant_query_input(self) -> bool:
return current_platform.is_cuda()
def __init__(
self,
num_heads: int,
@ -338,19 +336,9 @@ class TritonAttentionImpl(AttentionImpl):
if key_cache.dtype != self.fp8_dtype:
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
num_tokens, num_heads, head_size = query.shape
assert layer._q_scale_float == 1.0, (
"A non 1.0 q_scale is not currently supported."
)
if current_platform.is_cuda():
# Skip Q quantization on ROCm and XPU, enable this on cuda
# only, since dequantizing back to f32 in the attention kernel
# is not supported.
query, _ = ops.scaled_fp8_quant(
query.reshape((num_tokens, num_heads * head_size)).contiguous(),
layer._q_scale,
)
query = query.reshape((num_tokens, num_heads, head_size))
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens