Compare commits
4 Commits
whisper-tr
...
khluu/try_
| Author | SHA1 | Date | |
|---|---|---|---|
| | db9dfcfa6a | | |
| | 9ef98d527e | | |
| | 93491aefc7 | | |
| | 7acd539cd7 | | |
@@ -15,14 +15,12 @@ Easy, fast, and cheap LLM serving for everyone
|
||||
|
||||
---
|
||||
|
||||
[2025/03] We are collaborating with Ollama to host an [Inference Night](https://lu.ma/vllm-ollama) at Y Combinator in San Francisco on Thursday, March 27, at 6 PM. Discuss all things inference local or data center!
|
||||
|
||||
[2025/04] We're hosting our first-ever *vLLM Asia Developer Day* in Singapore on *April 3rd*! This is a full-day event (9 AM - 9 PM SGT) in partnership with SGInnovate, AMD, and Embedded LLM. Meet the vLLM team and learn about LLM inference for RL, MI300X, and more! [Register Now](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)
|
||||
|
||||
---
|
||||
|
||||
*Latest News* 🔥
|
||||
|
||||
- [2025/03] We hosted [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing).
- [2025/03] We hosted [the first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg)! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing).
- [2025/03] We hosted [the East Coast vLLM Meetup](https://lu.ma/7mu4k4xx)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1NHiv8EUFF1NLd3fEYODm56nDmL26lEeXCaDgyDlTsRs/edit#slide=id.g31441846c39_0_0).
- [2025/02] We hosted [the ninth vLLM meetup](https://lu.ma/h7g3kuj9) with Meta! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1jzC_PZVXrVNSFVCW-V4cFXb6pn7zZ2CyP_Flwo05aqg/edit?usp=sharing) and AMD [here](https://drive.google.com/file/d/1Zk5qEJIkTmlQ2eQcXQZlljAx3m9s7nwn/view?usp=sharing). The slides from Meta will not be posted.
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
|
||||
We host regular meetups in the San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below:
|
||||
|
||||
- [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama), March 27th 2025. [[Slides]](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing).
- [The first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg), March 16th 2025. [[Slides]](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing).
- [The East Coast vLLM Meetup](https://lu.ma/7mu4k4xx), March 11th 2025. [[Slides]](https://docs.google.com/presentation/d/1NHiv8EUFF1NLd3fEYODm56nDmL26lEeXCaDgyDlTsRs/edit#slide=id.g31441846c39_0_0)
- [The ninth vLLM meetup](https://lu.ma/h7g3kuj9), with Meta, February 27th 2025. [[Slides]](https://docs.google.com/presentation/d/1jzC_PZVXrVNSFVCW-V4cFXb6pn7zZ2CyP_Flwo05aqg/edit?usp=sharing)
- [The eighth vLLM meetup](https://lu.ma/zep56hui), with Google Cloud, January 22nd 2025. [[Slides]](https://docs.google.com/presentation/d/1epVkt4Zu8Jz_S5OhEHPc798emsYh2BwYfRuDDVEF7u4/edit?usp=sharing)
|
||||
|
||||
@@ -503,6 +503,11 @@ See [this page](#generative-models) for more information on how to use generativ
|
||||
* `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.
* ✅︎
* ✅︎
- * `MiniMaxText01ForCausalLM`
* MiniMax-Text
* `MiniMaxAI/MiniMax-Text-01`, etc.
*
* ✅︎
- * `Zamba2ForCausalLM`
* Zamba2
* `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc.
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# Usage Stats Collection
|
||||
|
||||
-vLLM collects anonymous usage data by default to help the engineering team better understand which hardware and model configurations are widely used. This data allows them to prioritize their efforts on the most common workloads. The collected data is transparent, does not contain any sensitive information, and will be publicly released for the community's benefit.
+vLLM collects anonymous usage data by default to help the engineering team better understand which hardware and model configurations are widely used. This data allows them to prioritize their efforts on the most common workloads. The collected data is transparent, does not contain any sensitive information.
+
+A subset of the data, after cleaning and aggregation, will be publicly released for the community's benefit. For example, you can see the 2024 usage report [here](https://2024.vllm.ai).
## What data is collected?
|
||||
|
||||
|
||||
286
tests/kernels/test_lightning_attn.py
Normal file
@@ -0,0 +1,286 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.lightning_attn import (
|
||||
linear_decode_forward_triton)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
NUM_HEADS = [4, 8]
|
||||
HEAD_SIZES = [64]
|
||||
BATCH_SIZES = [1, 2]
|
||||
SEQ_LENGTHS = [16]
|
||||
DTYPES = [torch.float32]
|
||||
|
||||
|
||||
def reference_lightning_attention(q, k, v, ed, block_size, kv_history):
|
||||
"""Reference implementation of lightning attention core algorithm
|
||||
|
||||
The difference from the main implementation is that this processes
|
||||
each step sequentially, instead of using parallelized triton kernels
|
||||
"""
|
||||
B, H, S, D = q.shape
|
||||
E = v.shape[-1]
|
||||
dtype = q.dtype
|
||||
output = torch.zeros((B, H, S, E), dtype=dtype, device=q.device)
|
||||
|
||||
# Use clone() to ensure an independent copy
|
||||
if kv_history is None:
|
||||
kv_cache = torch.zeros((B, H, D, E), dtype=dtype, device=q.device)
|
||||
else:
|
||||
kv_cache = kv_history.clone()
|
||||
|
||||
# More efficient implementation
|
||||
# Convert decay factors to matrix form
|
||||
if ed.dim() == 1:
|
||||
decay = torch.exp(-ed).view(1, -1, 1, 1)
|
||||
else:
|
||||
decay = torch.exp(-ed)
|
||||
|
||||
for b in range(B):
|
||||
for step in range(S):
|
||||
# Process all heads at once for this position
|
||||
q_bs = q[b, :, step] # [H, D]
|
||||
k_bs = k[b, :, step] # [H, D]
|
||||
v_bs = v[b, :, step] # [H, E]
|
||||
|
||||
# Calculate KV outer products for all heads
|
||||
for h in range(H):
|
||||
# Calculate KV outer product
|
||||
kv_outer = torch.outer(k_bs[h], v_bs[h])
|
||||
|
||||
# Update KV cache with decay
|
||||
# Note: Using the same order as in the Triton kernel
|
||||
kv_cache[b, h] = decay[0, h, 0, 0] * kv_cache[b, h] + kv_outer
|
||||
|
||||
# Calculate attention output
|
||||
output[b, h, step] = torch.matmul(q_bs[h], kv_cache[b, h])
|
||||
|
||||
# Match the shape returned by the actual implementation
|
||||
# The actual implementation returns a tensor of shape [B, H, 2, D, E]
|
||||
# where dimension 2 contains both KV and KV history
|
||||
kv_reshaped = kv_cache.unsqueeze(2) # [B, H, 1, D, E]
|
||||
final_kv_cache = torch.cat([kv_reshaped, kv_reshaped],
|
||||
dim=2) # [B, H, 2, D, E]
|
||||
|
||||
return output, final_kv_cache
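The sequential update in `reference_lightning_attention` (decay the running KV state, add the new key-value product, then read it with the query) can be checked against an explicit causal sum with exponential decay. The following is a minimal pure-Python sketch for a single head with scalar features (D = E = 1); the names `recurrent` and `closed_form` are illustrative and not part of vLLM.

```python
import math


def recurrent(q, k, v, rate, kv0=0.0):
    """Sequential form: decay the state, add k*v, then read with q."""
    decay = math.exp(-rate)
    kv = kv0
    out = []
    for qt, kt, vt in zip(q, k, v):
        kv = decay * kv + kt * vt
        out.append(qt * kv)
    return out, kv


def closed_form(q, k, v, rate, kv0=0.0):
    """Same outputs written as an explicit causal sum with decay."""
    out = []
    for t, qt in enumerate(q):
        # The initial state has been decayed t + 1 times by step t.
        acc = math.exp(-rate * (t + 1)) * kv0
        for s in range(t + 1):
            acc += math.exp(-rate * (t - s)) * k[s] * v[s]
        out.append(qt * acc)
    return out
```

Both functions should agree up to floating-point rounding, which is the property the test relies on when it compares a step-by-step reference against the parallel kernels.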
|
||||
|
||||
|
||||
def reference_linear_decode(q, k, v, kv_caches, slope_rate, slot_idx):
|
||||
"""Reference implementation: linear attention decode function"""
|
||||
B, H, _, D = q.shape
|
||||
output = torch.zeros(B, H * D, dtype=q.dtype, device=q.device)
|
||||
|
||||
# Calculate decay factors once (more efficient)
|
||||
decay = torch.exp(-slope_rate).view(-1, 1, 1) # [H, 1, 1]
|
||||
|
||||
# Process each batch
|
||||
for b in range(B):
|
||||
slot_id = slot_idx[b].item()
|
||||
|
||||
# Skip padding positions
|
||||
if slot_id == -1:
|
||||
continue
|
||||
|
||||
# Process all heads at once for this batch
|
||||
q_b = q[b, :, 0] # [H, D]
|
||||
k_b = k[b, :, 0] # [H, D]
|
||||
v_b = v[b, :, 0] # [H, D]
|
||||
|
||||
# Process each attention head
|
||||
for h in range(H):
|
||||
# Get current query, key and value
|
||||
q_bh = q_b[h]
|
||||
k_bh = k_b[h]
|
||||
v_bh = v_b[h]
|
||||
|
||||
# Get cache
|
||||
kv_cache_old = kv_caches[b, h]
|
||||
|
||||
# Calculate new key-value outer product
|
||||
kv_outer = torch.outer(k_bh, v_bh)
|
||||
|
||||
# Apply decay and update cache
|
||||
kv_new = kv_outer + decay[h, 0, 0] * kv_cache_old
|
||||
|
||||
# Calculate output
|
||||
out_h = torch.matmul(q_bh, kv_new)
|
||||
|
||||
# Update output and cache
|
||||
output[b, h * D:(h + 1) * D] = out_h
|
||||
kv_caches[b, h] = kv_new
|
||||
|
||||
return output
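A single decode step of the reference above, for one batch entry and one head, can be written without tensors. This is a minimal sketch using plain lists; `decode_step` is a hypothetical helper name, and the padding behaviour (zero output, untouched cache) mirrors the `continue` branch of the reference loop rather than any guarantee of the Triton kernel.

```python
import math


def decode_step(q, k, v, kv_cache, slope, slot_id):
    """One decode step for one head; returns (output, new_cache)."""
    d = len(q)
    if slot_id == -1:
        # Padding slot: leave the cache alone and emit zeros.
        return [0.0] * d, kv_cache
    decay = math.exp(-slope)
    # kv_new[i][j] = k[i] * v[j] + decay * kv_old[i][j]
    kv_new = [[k[i] * v[j] + decay * kv_cache[i][j] for j in range(d)]
              for i in range(d)]
    # out[j] = sum_i q[i] * kv_new[i][j]
    out = [sum(q[i] * kv_new[i][j] for i in range(d)) for j in range(d)]
    return out, kv_new
```

For example, with a zero cache and `slope = 0.0`, the output is just `q` applied to the outer product of `k` and `v`.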
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@torch.inference_mode()
|
||||
def test_linear_decode_forward_triton(
|
||||
batch_size: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
torch.set_default_device("cuda")
|
||||
torch.manual_seed(42)
|
||||
torch.cuda.manual_seed_all(42)
|
||||
current_platform.seed_everything(42)
|
||||
base = 0.01
|
||||
q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
|
||||
kv_caches = base * torch.randn(batch_size,
|
||||
num_heads,
|
||||
head_size,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
|
||||
kv_caches_copy = kv_caches.clone()
|
||||
|
||||
slope_rate = torch.zeros(num_heads, device="cuda")
|
||||
for h in range(num_heads):
|
||||
slope_rate[h] = 0.1 * (h + 1)
|
||||
|
||||
slot_idx = torch.arange(batch_size, device="cuda")
|
||||
|
||||
triton_output = linear_decode_forward_triton(q, k, v, kv_caches,
|
||||
slope_rate, slot_idx)
|
||||
|
||||
reference_output = reference_linear_decode(q, k, v, kv_caches_copy,
|
||||
slope_rate, slot_idx)
|
||||
torch.testing.assert_close(triton_output,
|
||||
reference_output,
|
||||
rtol=1e-1,
|
||||
atol=1e-1)
|
||||
torch.testing.assert_close(kv_caches, kv_caches_copy, rtol=1e-1, atol=1e-1)
|
||||
|
||||
assert triton_output.shape == (batch_size, num_heads * head_size)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@torch.inference_mode()
|
||||
def test_linear_decode_forward_triton_with_padding(
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
torch.set_default_device("cuda")
|
||||
torch.manual_seed(42)
|
||||
torch.cuda.manual_seed_all(42)
|
||||
current_platform.seed_everything(42)
|
||||
|
||||
batch_size = 4
|
||||
base = 0.01
|
||||
q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
|
||||
kv_caches = base * torch.randn(batch_size,
|
||||
num_heads,
|
||||
head_size,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
|
||||
kv_caches_copy = kv_caches.clone()
|
||||
|
||||
slope_rate = torch.zeros(num_heads, device="cuda")
|
||||
for h in range(num_heads):
|
||||
slope_rate[h] = 0.1 * (h + 1)
|
||||
|
||||
slot_idx = torch.tensor([0, 1, -1, 2], device="cuda")
|
||||
|
||||
triton_output = linear_decode_forward_triton(q, k, v, kv_caches,
|
||||
slope_rate, slot_idx)
|
||||
|
||||
reference_output = reference_linear_decode(q, k, v, kv_caches_copy,
|
||||
slope_rate, slot_idx)
|
||||
|
||||
padding_mask = (slot_idx
|
||||
!= -1).unsqueeze(1).expand(-1, num_heads * head_size)
|
||||
|
||||
triton_masked = triton_output[padding_mask]
|
||||
reference_masked = reference_output[padding_mask]
|
||||
|
||||
atol, rtol = 1.5e-1, 1.5e-1
|
||||
|
||||
valid_indices = slot_idx != -1
|
||||
|
||||
for i in range(batch_size):
|
||||
if valid_indices[i] > 0:
|
||||
torch.testing.assert_close(kv_caches[i],
|
||||
kv_caches_copy[i],
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
|
||||
torch.testing.assert_close(triton_masked,
|
||||
reference_masked,
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
|
||||
assert triton_output.shape == (batch_size, num_heads * head_size)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("seq_len", SEQ_LENGTHS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@torch.inference_mode()
|
||||
def test_lightning_attention_reference(
|
||||
batch_size: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
seq_len: int,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
torch.set_default_device("cuda")
|
||||
torch.manual_seed(42)
|
||||
torch.cuda.manual_seed_all(42)
|
||||
current_platform.seed_everything(42)
|
||||
|
||||
base = 0.01
|
||||
q = base * torch.randn(
|
||||
batch_size, num_heads, seq_len, head_size, dtype=dtype)
|
||||
k = base * torch.randn(
|
||||
batch_size, num_heads, seq_len, head_size, dtype=dtype)
|
||||
v = base * torch.randn(
|
||||
batch_size, num_heads, seq_len, head_size, dtype=dtype)
|
||||
|
||||
ed = torch.zeros(num_heads, device="cuda")
|
||||
for h in range(num_heads):
|
||||
ed[h] = 0.1 * (h + 1)
|
||||
|
||||
kv_history = base * torch.randn(batch_size,
|
||||
num_heads,
|
||||
head_size,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
|
||||
kv_history_clone = kv_history.clone()
|
||||
|
||||
ref_output, ref_kv_cache = reference_lightning_attention(
|
||||
q, k, v, ed, 256, kv_history)
|
||||
|
||||
from vllm.model_executor.layers.lightning_attn import lightning_attention
|
||||
actual_output, actual_kv_cache = lightning_attention(
|
||||
q, k, v, ed, 256, kv_history_clone)
|
||||
|
||||
atol, rtol = 1.5e-1, 1.5e-1
|
||||
torch.testing.assert_close(ref_output, actual_output, rtol=rtol, atol=atol)
|
||||
torch.testing.assert_close(ref_kv_cache,
|
||||
actual_kv_cache,
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
|
||||
assert ref_output.shape == (batch_size, num_heads, seq_len, head_size)
|
||||
assert ref_kv_cache.shape == actual_kv_cache.shape
|
||||
@@ -176,6 +176,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
trust_remote_code=True),
|
||||
"MiniCPM3ForCausalLM": _HfExamplesInfo("openbmb/MiniCPM3-4B",
|
||||
trust_remote_code=True),
|
||||
"MiniMaxText01ForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01",
|
||||
trust_remote_code=True),
|
||||
"MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"),
|
||||
"MixtralForCausalLM": _HfExamplesInfo("mistralai/Mixtral-8x7B-Instruct-v0.1"), # noqa: E501
|
||||
"QuantMixtralForCausalLM": _HfExamplesInfo("mistral-community/Mixtral-8x22B-v0.1-AWQ"), # noqa: E501
|
||||
|
||||
@@ -971,26 +971,34 @@ class ModelConfig:
|
||||
return sum(not bc.attention.no_op
|
||||
for bc in block_configs[start:end])
|
||||
else:
|
||||
# Hybrid model
|
||||
# Hybrid model Jamba
|
||||
layers_block_type_value = getattr(self.hf_config,
|
||||
"layers_block_type", None)
|
||||
if layers_block_type_value is None:
|
||||
raise ValueError("The model is an hybrid without a "
|
||||
"layers_block_type in the hf_config, "
|
||||
"cannot determine the num of "
|
||||
f"{block_type.value} layers")
|
||||
if layers_block_type_value is not None:
|
||||
if hasattr(self.hf_text_config,
|
||||
"model_type") and (self.hf_text_config.model_type
|
||||
== "zamba2"):
|
||||
if attn_block_type:
|
||||
return sum(t == "hybrid"
|
||||
for t in layers_block_type_value[start:end])
|
||||
else:
|
||||
return self.get_num_layers(parallel_config)
|
||||
return sum(t == block_type.value
|
||||
for t in layers_block_type_value[start:end])
|
||||
|
||||
if hasattr(self.hf_text_config,
|
||||
"model_type") and (self.hf_text_config.model_type
|
||||
== "zamba2"):
|
||||
if attn_block_type:
|
||||
return sum(t == "hybrid"
|
||||
for t in layers_block_type_value[start:end])
|
||||
else:
|
||||
return self.get_num_layers(parallel_config)
|
||||
# Hybrid model Minimax
|
||||
attn_type_list = getattr(self.hf_config, "attn_type_list", None)
|
||||
if attn_type_list:
|
||||
return sum(t == 1 for t in attn_type_list[start:end])
|
||||
|
||||
return sum(t == block_type.value
|
||||
for t in layers_block_type_value[start:end])
|
||||
if layers_block_type_value is None and attn_type_list is None:
|
||||
raise ValueError(
|
||||
"The model is a hybrid without a "
"layers_block_type or an attn_type_list in the hf_config, "
"cannot determine the num of "
f"{block_type.value} layers")
|
||||
|
||||
return sum(t == 1 for t in attn_type_list[start:end])
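The `attn_type_list` branch counts entries over a half-open layer range. Below is a minimal sketch of that count, assuming (as the `t == 1` comparison suggests) that entries equal to 1 mark the layers being counted; the helper name `count_type_one_layers` and the example list are hypothetical.

```python
def count_type_one_layers(attn_type_list, start, end):
    """Count entries equal to 1 in the half-open layer range [start, end)."""
    return sum(t == 1 for t in attn_type_list[start:end])


# Hypothetical 8-layer model split across two pipeline stages of 4 layers each.
example_attn_type_list = [0, 0, 0, 1, 0, 0, 0, 1]
stage_counts = [
    count_type_one_layers(example_attn_type_list, 0, 4),
    count_type_one_layers(example_attn_type_list, 4, 8),
]
```

Because the slice is taken before counting, each pipeline stage only sees the layers it owns.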
|
||||
|
||||
def get_multimodal_config(self) -> "MultiModalConfig":
|
||||
"""
|
||||
|
||||
@@ -303,8 +303,11 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
ctx.seq_group_metadata_list = seq_group_metadata_list
|
||||
ctx.scheduler_outputs = scheduler_outputs
|
||||
|
||||
-finished_requests_ids = self.scheduler[
-    virtual_engine].get_and_reset_finished_requests_ids()
+if not scheduler_outputs.is_empty():
+    # Otherwise mamba_cache/minimax_cache would fail to release
+    # the finished_requests_ids of the last steps
+    finished_requests_ids = self.scheduler[
+        virtual_engine].get_and_reset_finished_requests_ids()
|
||||
|
||||
# Maybe switch from async mode to sync mode
|
||||
if not allow_async_output_proc and len(ctx.output_queue) > 0:
|
||||
|
||||
@@ -1098,9 +1098,10 @@ async def run_server(args, **uvicorn_kwargs) -> None:
|
||||
)
|
||||
|
||||
# NB: Await server shutdown only after the backend context is exited
|
||||
-await shutdown_task
-
-sock.close()
+try:
+    await shutdown_task
+finally:
+    sock.close()
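The point of wrapping the await in `try`/`finally` is that the socket is closed even when the awaited shutdown raises. Here is a minimal, self-contained sketch of that behaviour; `FakeSocket`, `failing_shutdown`, and `wait_for_shutdown` are hypothetical stand-ins, not vLLM code.

```python
import asyncio


class FakeSocket:
    """Hypothetical stand-in for a socket; records whether close() ran."""

    def __init__(self):
        self.closed = False

    def close(self):
        self.closed = True


async def failing_shutdown():
    raise RuntimeError("shutdown failed")


async def wait_for_shutdown(shutdown_task, sock):
    try:
        await shutdown_task
    finally:
        # Runs on success, on exception, and on cancellation.
        sock.close()


def run_demo():
    sock = FakeSocket()
    try:
        asyncio.run(wait_for_shutdown(failing_shutdown(), sock))
    except RuntimeError:
        pass
    return sock.closed
```

Without the `finally`, the exception would propagate past the `close()` call and the socket would stay open.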
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
651
vllm/model_executor/layers/lightning_attn.py
Normal file
@@ -0,0 +1,651 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_diag_kernel(Q, K, V, Out, S, b: tl.constexpr, h: tl.constexpr, n,
|
||||
d: tl.constexpr, e: tl.constexpr, BLOCK: tl.constexpr,
|
||||
NUM_BLOCK, CBLOCK: tl.constexpr):
|
||||
# This kernel computes the diagonal blocks of the attention matrix
|
||||
# Each diagonal block represents attention
|
||||
# where queries attend to keys in the same block
|
||||
off = tl.program_id(0)
|
||||
off_bh = off // NUM_BLOCK # batch-head index
|
||||
off_block = off % NUM_BLOCK # block index within the sequence
|
||||
off_cblock = tl.program_id(1) # sub-block index within a block
|
||||
|
||||
off_h = off_bh % h # head index
|
||||
|
||||
# Calculate base offsets for the current batch and head
|
||||
qk_offset = off_bh * n * d
|
||||
v_offset = off_bh * n * e
|
||||
o_offset = off_bh * n * e
|
||||
|
||||
# Calculate offsets for the current block
|
||||
block_offset = off_block * BLOCK
|
||||
qk_block_offset = block_offset * d
|
||||
v_block_offset = block_offset * e
|
||||
o_block_offset = block_offset * e
|
||||
|
||||
# Calculate offsets for the current sub-block
|
||||
cblock_offset = off_cblock * CBLOCK
|
||||
q_cblock_offset = cblock_offset * d
|
||||
o_cblock_offset = cblock_offset * e
|
||||
|
||||
# Calculate pointers to the query, key, value, and output tensors
|
||||
Q_block_ptr = (Q + qk_offset + qk_block_offset + q_cblock_offset +
|
||||
tl.arange(0, CBLOCK)[:, None] * d +
|
||||
tl.arange(0, d)[None, :])
|
||||
K_trans_block_ptr = (K + qk_offset + qk_block_offset +
|
||||
tl.arange(0, CBLOCK)[None, :] * d +
|
||||
tl.arange(0, d)[:, None])
|
||||
V_block_ptr = (V + v_offset + v_block_offset +
|
||||
tl.arange(0, CBLOCK)[:, None] * e +
|
||||
tl.arange(0, e)[None, :])
|
||||
O_block_ptr = (Out + o_offset + o_block_offset + o_cblock_offset +
|
||||
tl.arange(0, CBLOCK)[:, None] * e +
|
||||
tl.arange(0, e)[None, :])
|
||||
|
||||
# Load the decay rate for the current head
|
||||
S_block_ptr = S + off_h
|
||||
s = tl.load(S_block_ptr)
|
||||
|
||||
i = off_cblock
|
||||
q_index = tl.arange(0, CBLOCK) + i * CBLOCK
|
||||
|
||||
# Load query values
|
||||
q = tl.load(Q_block_ptr,
|
||||
mask=block_offset + q_index[:, None] < n,
|
||||
other=0.0).to(tl.float32)
|
||||
|
||||
# Initialize output accumulator
|
||||
qkv = tl.zeros([CBLOCK, e], dtype=tl.float32)
|
||||
|
||||
# Process all sub-blocks up to and
|
||||
# including the current one (causal attention)
|
||||
for j in range(i + 1):
|
||||
kv_index = tl.arange(0, CBLOCK) + j * CBLOCK
|
||||
diff = q_index[:, None] - kv_index[None, :]
|
||||
s_index = s * diff
|
||||
# Apply causal mask: only attend to the current and earlier positions
|
||||
s_index = tl.where(diff >= 0, -s_index, float("-inf"))
|
||||
decay = tl.exp(s_index)
|
||||
|
||||
# Load key and value
|
||||
k_trans = tl.load(
|
||||
K_trans_block_ptr,
|
||||
mask=block_offset + kv_index[None, :] < n,
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
v = tl.load(
|
||||
V_block_ptr,
|
||||
mask=block_offset + kv_index[:, None] < n,
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
|
||||
# Compute attention scores and apply decay
|
||||
qk = tl.dot(q, k_trans) * decay
|
||||
|
||||
# Compute weighted values and accumulate
|
||||
qkv += tl.dot(qk, v)
|
||||
|
||||
# Move to the next sub-block
|
||||
K_trans_block_ptr += CBLOCK * d
|
||||
V_block_ptr += CBLOCK * e
|
||||
|
||||
# Store the result
|
||||
tl.store(
|
||||
O_block_ptr,
|
||||
qkv.to(O_block_ptr.dtype.element_ty),
|
||||
mask=block_offset + q_index[:, None] < n,
|
||||
)
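The decay-and-mask step inside the loop above (`tl.where(diff >= 0, -s_index, float("-inf"))` followed by `tl.exp`) produces a lower-triangular matrix of decay weights. The following pure-Python sketch builds the same kind of matrix for small sizes; `decay_mask` and its parameters are illustrative names only.

```python
import math


def decay_mask(num_q, num_k, rate, q_start=0, k_start=0):
    """Weights exp(-rate * (i - j)) for i >= j, and 0 for future keys."""
    rows = []
    for i in range(q_start, q_start + num_q):
        row = []
        for j in range(k_start, k_start + num_k):
            diff = i - j
            # exp(-inf) evaluates to 0.0, which masks out future positions.
            exponent = -rate * diff if diff >= 0 else float("-inf")
            row.append(math.exp(exponent))
        rows.append(row)
    return rows
```

The diagonal is always 1 (a position attends to itself without decay), and weights shrink as the key moves further into the past.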
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kv_parallel(
|
||||
K,
|
||||
V,
|
||||
K_decay,
|
||||
KV,
|
||||
b: tl.constexpr,
|
||||
h: tl.constexpr,
|
||||
n,
|
||||
d: tl.constexpr,
|
||||
e: tl.constexpr,
|
||||
BLOCK: tl.constexpr,
|
||||
NUM_BLOCK,
|
||||
D_FBLOCK: tl.constexpr,
|
||||
E_FBLOCK: tl.constexpr,
|
||||
NUM_FBLOCK: tl.constexpr,
|
||||
CBLOCK: tl.constexpr,
|
||||
NUM_CBLOCK: tl.constexpr,
|
||||
):
|
||||
# This kernel computes the key-value outer
|
||||
# products for each block in parallel
|
||||
off_bh = tl.program_id(0) # batch-head index
|
||||
off_block = tl.program_id(1) # block index
|
||||
|
||||
off_h = off_bh % h # head index
|
||||
|
||||
block_offset = off_block * BLOCK
|
||||
|
||||
# Calculate offsets for the current block
|
||||
k_block_offset = block_offset * d
|
||||
v_block_offset = block_offset * e
|
||||
kv_block_offset = off_block * d * e
|
||||
|
||||
# Calculate base offsets for the current batch and head
|
||||
k_offset = off_bh * n * d
|
||||
v_offset = off_bh * n * e
|
||||
kv_offset = off_bh * NUM_BLOCK * d * e
|
||||
|
||||
# Calculate pointers to the key, value, and key-value tensors
|
||||
K_trans_block_ptr = (K + k_offset + k_block_offset +
|
||||
tl.arange(0, CBLOCK)[None, :] * d +
|
||||
tl.arange(0, D_FBLOCK)[:, None])
|
||||
V_block_ptr = (V + v_offset + v_block_offset +
|
||||
tl.arange(0, CBLOCK)[:, None] * e +
|
||||
tl.arange(0, E_FBLOCK)[None, :])
|
||||
KV_block_ptr = (KV + kv_offset + kv_block_offset +
|
||||
tl.arange(0, D_FBLOCK)[:, None] * e +
|
||||
tl.arange(0, E_FBLOCK)[None, :])
|
||||
|
||||
# Load the decay factors for the current head and block
|
||||
k_decay_ptr = (K_decay + off_h * BLOCK + tl.arange(0, CBLOCK)[None, :])
|
||||
|
||||
kv_index = tl.arange(0, CBLOCK)
|
||||
|
||||
# Initialize the key-value outer product accumulator
|
||||
kv = tl.zeros([D_FBLOCK, E_FBLOCK], dtype=tl.float32)
|
||||
|
||||
# Handle the last block which might be smaller than BLOCK
|
||||
if off_block == NUM_BLOCK - 1:
|
||||
split_n = n - (NUM_BLOCK - 1) * BLOCK
|
||||
else:
|
||||
split_n = BLOCK
|
||||
left_shift = tl.cdiv(split_n, CBLOCK) * CBLOCK - split_n
|
||||
num_blocks = min(tl.cdiv(split_n, CBLOCK), NUM_CBLOCK)
|
||||
k_decay_ptr += (NUM_CBLOCK - num_blocks) * CBLOCK
|
||||
|
||||
# Process all sub-blocks in the current block
|
||||
for j in range(num_blocks):
|
||||
left_bound = (1 - j) * left_shift
|
||||
# Load key and value, handling boundary conditions
|
||||
k_trans = tl.load(K_trans_block_ptr - left_shift * d,
|
||||
mask=kv_index[None, :] >= left_bound,
|
||||
other=0.0)
|
||||
v = tl.load(V_block_ptr - left_shift * e,
|
||||
mask=kv_index[:, None] >= left_bound,
|
||||
other=0.0)
|
||||
|
||||
# Load decay factor and compute weighted key-value outer product
|
||||
k_decay = tl.load(k_decay_ptr)
|
||||
kv += tl.dot(k_trans * k_decay, v)
|
||||
|
||||
# Move to the next sub-block
|
||||
K_trans_block_ptr += CBLOCK * d
|
||||
V_block_ptr += CBLOCK * e
|
||||
k_decay_ptr += CBLOCK
|
||||
|
||||
# Store the result
|
||||
tl.store(KV_block_ptr, kv.to(KV_block_ptr.dtype.element_ty))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kv_reduce(S, KV, KV_HISTORY, b: tl.constexpr, h: tl.constexpr, n,
|
||||
d: tl.constexpr, e: tl.constexpr, BLOCK: tl.constexpr,
|
||||
NUM_BLOCK, D_FBLOCK: tl.constexpr, E_FBLOCK: tl.constexpr):
|
||||
# This kernel reduces the key-value outer products
|
||||
# across blocks and updates the KV history
|
||||
off_bh = tl.program_id(0) # batch-head index
|
||||
off_h = off_bh % h # head index
|
||||
|
||||
kv_offset = off_bh * NUM_BLOCK * d * e
|
||||
|
||||
# Calculate pointer to the key-value tensor
|
||||
KV_block_ptr = (KV + kv_offset + tl.arange(0, D_FBLOCK)[:, None] * e +
|
||||
tl.arange(0, E_FBLOCK)[None, :])
|
||||
|
||||
# Load the decay rate for the current head
|
||||
s_ptrs = S + off_h
|
||||
s = tl.load(s_ptrs)
|
||||
|
||||
# Calculate pointer to the key-value history tensor
|
||||
kv_history_offset = off_bh * d * e
|
||||
KV_HISTORY_block_ptr = (KV_HISTORY + kv_history_offset +
|
||||
tl.arange(0, D_FBLOCK)[:, None] * e +
|
||||
tl.arange(0, E_FBLOCK)[None, :])
|
||||
|
||||
# Load the previous key-value history
|
||||
kv_pre = tl.load(KV_HISTORY_block_ptr).to(tl.float32)
|
||||
|
||||
# Process all blocks in order to compute an exclusive, decayed prefix sum
|
||||
for i in range(NUM_BLOCK):
|
||||
block_size = min(n - i * BLOCK, BLOCK)
|
||||
# Compute decay factor for the current block
|
||||
block_decay = tl.exp(-s.to(tl.float32) * block_size)
|
||||
|
||||
# Load the current key-value outer product
|
||||
kv_cur = tl.load(KV_block_ptr).to(tl.float32)
|
||||
# Store the previous key-value history to the current block
|
||||
tl.store(KV_block_ptr, kv_pre.to(KV_block_ptr.dtype.element_ty))
|
||||
|
||||
# Update the key-value history with the current block
|
||||
kv_pre = block_decay * kv_pre + kv_cur
|
||||
KV_block_ptr += d * e
|
||||
|
||||
# Store the updated key-value history
|
||||
tl.store(KV_HISTORY_block_ptr, kv_pre)
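The loop in `_fwd_kv_reduce` behaves like an exclusive scan: each block slot is overwritten with the state accumulated before it, and the running state is then decayed and extended. Below is a minimal scalar sketch of that pattern; `kv_reduce` is a hypothetical name, and scalars stand in for the `[d, e]` tiles.

```python
import math


def kv_reduce(block_kv, history, rate, n, block):
    """Exclusive decayed scan over per-block KV values.

    Returns (prefix, new_history): prefix[i] is the state visible to
    block i, and new_history is the state after the last block.
    """
    prefix = []
    kv_pre = history
    for i, kv_cur in enumerate(block_kv):
        block_size = min(n - i * block, block)  # the last block may be short
        prefix.append(kv_pre)
        kv_pre = math.exp(-rate * block_size) * kv_pre + kv_cur
    return prefix, kv_pre
```

With `rate = 0` the decay factor is 1 and the result reduces to an ordinary exclusive prefix sum, which makes the pattern easy to verify by hand.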
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_none_diag_kernel(
|
||||
Q,
|
||||
Out,
|
||||
S,
|
||||
KV,
|
||||
b: tl.constexpr,
|
||||
h: tl.constexpr,
|
||||
n,
|
||||
d: tl.constexpr,
|
||||
e: tl.constexpr,
|
||||
BLOCK: tl.constexpr,
|
||||
NUM_BLOCK,
|
||||
E_FBLOCK: tl.constexpr,
|
||||
CBLOCK: tl.constexpr,
|
||||
NUM_CBLOCK: tl.constexpr,
|
||||
):
|
||||
# This kernel computes the non-diagonal blocks of the attention matrix
|
||||
# Each non-diagonal block represents attention
|
||||
# where queries attend to keys in different blocks
|
||||
off_bh = tl.program_id(0) # batch-head index
|
||||
off_h = off_bh % h # head index
|
||||
|
||||
off_nc = tl.program_id(1)
|
||||
off_n = off_nc // NUM_CBLOCK # block index
|
||||
off_c = off_nc % NUM_CBLOCK # sub-block index
|
||||
off_e = tl.program_id(2) # output feature block index
|
||||
|
||||
n_offset = off_n * BLOCK
|
||||
c_offset = off_c * CBLOCK
|
||||
e_offset = off_e * E_FBLOCK
|
||||
block_offset = n_offset + c_offset
|
||||
|
||||
# Calculate offsets for the current batch, head, and block
|
||||
q_offset = off_bh * n * d + (n_offset + c_offset) * d
|
||||
o_offset = off_bh * n * e + (n_offset + c_offset) * e + e_offset
|
||||
kv_offset = off_bh * NUM_BLOCK * d * e + off_n * d * e + e_offset
|
||||
|
||||
# Calculate pointers to the query, output, and key-value tensors
|
||||
Q_block_ptr = (Q + q_offset + tl.arange(0, CBLOCK)[:, None] * d +
|
||||
tl.arange(0, d)[None, :])
|
||||
O_block_ptr = (Out + o_offset + tl.arange(0, CBLOCK)[:, None] * e +
|
||||
tl.arange(0, E_FBLOCK)[None, :])
|
||||
KV_block_ptr = (KV + kv_offset + tl.arange(0, d)[:, None] * e +
|
||||
tl.arange(0, E_FBLOCK)[None, :])
|
||||
|
||||
# Load the decay rate for the current head
|
||||
S_block_ptr = S + off_h
|
||||
s = tl.load(S_block_ptr)
|
||||
|
||||
c_array = tl.arange(0, CBLOCK)
|
||||
|
||||
# Load the key-value outer product for the current block
|
||||
kv = tl.load(KV_block_ptr).to(tl.float32)
|
||||
q_index = block_offset + tl.arange(0, CBLOCK)
|
||||
|
||||
# Load query values
|
||||
q = tl.load(Q_block_ptr, mask=q_index[:, None] < n,
|
||||
other=0.).to(tl.float32)
|
||||
|
||||
# Compute decay factors for the current sub-block
|
||||
q_decay = tl.exp(-s.to(tl.float32) * (off_c * CBLOCK + c_array[:, None]))
|
||||
|
||||
# Compute non-diagonal attention output
|
||||
qkv_none_diag = tl.dot(q, kv) * q_decay
|
||||
|
||||
# Load diagonal attention output (computed by _fwd_diag_kernel)
|
||||
qkv_diag = tl.load(O_block_ptr, mask=q_index[:, None] < n,
|
||||
other=0.).to(tl.float32)
|
||||
|
||||
# Combine diagonal and non-diagonal attention outputs
|
||||
qkv = qkv_diag + qkv_none_diag
|
||||
|
||||
# Store the result
|
||||
tl.store(O_block_ptr,
|
||||
qkv.to(O_block_ptr.dtype.element_ty),
|
||||
mask=q_index[:, None] < n)
|
||||
|
||||
|
||||
class _attention(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, q, k, v, s, kv_history):
|
||||
# Forward pass of the lightning attention algorithm
|
||||
q = q.contiguous()
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
s = s.contiguous()
|
||||
|
||||
# Check CUDA compute capability
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability[0] < 8:
|
||||
raise RuntimeError("Flash attention currently only supported "
"for compute capability >= 80")
|
||||
|
||||
# Get input dimensions
|
||||
b, h, n, d = q.shape
|
||||
e = v.shape[-1]
|
||||
|
||||
# Initialize output tensor
|
||||
o = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device)
|
||||
|
||||
# Set block sizes
|
||||
BLOCK = 256
|
||||
NUM_BLOCK = triton.cdiv(n, BLOCK)
|
||||
|
||||
CBLOCK = 32
|
||||
NUM_CBLOCK = BLOCK // CBLOCK
|
||||
assert BLOCK % CBLOCK == 0, "BLOCK must be a multiple of CBLOCK"
|
||||
|
||||
# Compute decay factors for keys
|
||||
array = torch.arange(0, BLOCK, device=q.device) + 1
|
||||
k_decay = torch.exp(-s * (BLOCK - array.reshape(1, -1)))
|
||||
|
||||
# Step 1: Compute diagonal blocks of attention
|
||||
grid = (b * h * NUM_BLOCK, NUM_CBLOCK)
|
||||
_fwd_diag_kernel[grid](q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
s,
|
||||
b,
|
||||
h,
|
||||
n,
|
||||
d,
|
||||
e,
|
||||
BLOCK=BLOCK,
|
||||
NUM_BLOCK=NUM_BLOCK,
|
||||
CBLOCK=CBLOCK)
|
||||
|
||||
# Set feature block sizes
|
||||
NUM_FBLOCK = 1
|
||||
D_FBLOCK = d // NUM_FBLOCK
|
||||
assert d % NUM_FBLOCK == 0
|
||||
E_FBLOCK = e // NUM_FBLOCK
|
||||
assert e % NUM_FBLOCK == 0
|
||||
|
||||
CBLOCK = 64
|
||||
NUM_CBLOCK = BLOCK // CBLOCK
|
||||
assert BLOCK % CBLOCK == 0, "BLOCK must be a multiple of CBLOCK"
|
||||
|
||||
# Step 2: Compute key-value outer products for each block in parallel
|
||||
kv = torch.empty((b, h, NUM_BLOCK, d, e),
|
||||
dtype=torch.float32,
|
||||
device=q.device)
|
||||
grid = (b * h, NUM_BLOCK)
|
||||
_fwd_kv_parallel[grid](
|
||||
k,
|
||||
v,
|
||||
k_decay,
|
||||
kv,
|
||||
b,
|
||||
h,
|
||||
n,
|
||||
d,
|
||||
e,
|
||||
BLOCK=BLOCK,
|
||||
NUM_BLOCK=NUM_BLOCK,
|
||||
D_FBLOCK=D_FBLOCK,
|
||||
E_FBLOCK=E_FBLOCK,
|
||||
NUM_FBLOCK=NUM_FBLOCK,
|
||||
CBLOCK=CBLOCK,
|
||||
NUM_CBLOCK=NUM_CBLOCK,
|
||||
)
|
||||
|
||||
# Step 3: Reduce key-value outer products
|
||||
# across blocks and update KV history
|
||||
grid = (b * h, NUM_FBLOCK)
|
||||
_fwd_kv_reduce[grid](s,
|
||||
kv,
|
||||
kv_history,
|
||||
b,
|
||||
h,
|
||||
n,
|
||||
d,
|
||||
e,
|
||||
BLOCK=BLOCK,
|
||||
NUM_BLOCK=NUM_BLOCK,
|
||||
D_FBLOCK=D_FBLOCK,
|
||||
E_FBLOCK=E_FBLOCK)
|
||||
|
||||
# Step 4: Compute non-diagonal blocks of attention
|
||||
grid = (b * h, NUM_BLOCK * NUM_CBLOCK)
|
||||
_fwd_none_diag_kernel[grid](
|
||||
q,
|
||||
o,
|
||||
s,
|
||||
kv,
|
||||
b,
|
||||
h,
|
||||
n,
|
||||
d,
|
||||
e,
|
||||
BLOCK=BLOCK,
|
||||
NUM_BLOCK=NUM_BLOCK,
|
||||
E_FBLOCK=E_FBLOCK,
|
||||
CBLOCK=CBLOCK,
|
||||
NUM_CBLOCK=NUM_CBLOCK,
|
||||
)
|
||||
|
||||
# Save tensors for backward pass
|
||||
ctx.save_for_backward(q, k, v, s, kv)
|
||||
ctx.BLOCK = BLOCK
|
||||
|
||||
return o, torch.cat([kv, kv_history.unsqueeze(2)], dim=2)


# Apply the lightning attention function
lightning_attention_ = _attention.apply


def lightning_attention(q, k, v, ed, block_size=256, kv_history=None):
    """
    Apply lightning attention algorithm
    to compute attention efficiently.

    Args:
        q: Query tensor of shape [batch, heads, seq_len, dim]
        k: Key tensor of shape [batch, heads, seq_len, dim]
        v: Value tensor of shape [batch, heads, seq_len, dim_v]
        ed: Decay rate tensor of shape [heads]
        block_size: Block size (currently unused; the forward pass uses a
            fixed BLOCK of 256)
        kv_history: Optional key-value history from previous computations

    Returns:
        output: Attention output
        kv: Updated key-value history
    """
    d = q.shape[-1]
    e = v.shape[-1]

    if ed.dim() == 1:
        ed = ed.view(1, -1, 1, 1)

    # Split the computation into chunks for better parallelism
    m = 128 if d >= 128 else 64
    assert d % m == 0, f"Dimension d ({d}) must be divisible by m ({m})"
    arr = [m * i for i in range(d // m + 1)]
    if arr[-1] != d:
        arr.append(d)
    n = len(arr)
    output = 0

    # Initialize or clone key-value history
    if kv_history is None:
        kv_history = torch.zeros((q.shape[0], q.shape[1], d, e),
                                 dtype=torch.float32,
                                 device=q.device)
    else:
        kv_history = kv_history.clone().contiguous()

    # Process each chunk of the feature dimension and accumulate results
    for i in range(n - 1):
        start = arr[i]
        end = arr[i + 1]
        q1 = q[..., start:end]
        k1 = k[..., start:end]
        o, kv = lightning_attention_(q1, k1, v, ed, kv_history)
        output = output + o
    return output, kv


@triton.jit
def _linear_attn_decode_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    kv_cache_ptr,
    slope_rate,
    slot_idx,
    output_ptr,
    D: tl.constexpr,
    qkv_b_stride,
    qkv_h_stride,
    cache_b_stride,
    cache_h_stride,
    cache_d0_stride,
    cache_d1_stride,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Kernel for linear attention decoding with KV cache.

    This kernel computes attention for a single token using the KV cache.
    """
    pid_b = tl.program_id(0)  # batch index
    pid_h = tl.program_id(1)  # head index
    pid_d = tl.program_id(2)  # dimension block index

    # Load slot index for the current batch
    slot_id = tl.load(slot_idx + pid_b)

    # Skip if slot_id is -1 (padding)
    if slot_id == -1:
        return

    batch_id = pid_b
    head_id = pid_h

    # Load decay rate for the current head
    ratio = tl.load(slope_rate + pid_h)

    # Calculate offsets for dimensions
    qk_d_offsets = tl.arange(0, D)
    v_d_offsets = tl.arange(0, BLOCK_SIZE) + pid_d * BLOCK_SIZE
    cache_d_offsets = qk_d_offsets[:, None] * cache_d0_stride + v_d_offsets[
        None, :] * cache_d1_stride

    # Calculate offsets for the current batch and head
    q_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride
    k_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride
    v_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride

    cache_offset = slot_id * cache_b_stride + head_id * cache_h_stride

    # Create masks for loading tensors
    qk_mask = qk_d_offsets < D
    v_mask = v_d_offsets < D

    # Load query, key, and value tensors
    q = tl.load(q_ptr + q_offset + qk_d_offsets, mask=qk_mask, other=0.0)
    k = tl.load(k_ptr + k_offset + qk_d_offsets, mask=qk_mask, other=0.0)
    v = tl.load(v_ptr + v_offset + v_d_offsets, mask=v_mask, other=0.0)

    # Compute key-value outer product
    kv_outer = k[:, None] * v[None, :]
    kv_mask = qk_mask[:, None] & v_mask[None, :]

    # Apply decay to previous KV cache
    ratio = tl.exp(-ratio)
    kv_ptr = kv_cache_ptr + cache_offset + cache_d_offsets
    kv_cache_old = tl.load(kv_ptr, mask=kv_mask, other=0.0)
    kv_outer = kv_outer + ratio * kv_cache_old

    # Compute attention output
    output = q[:, None].to(tl.float32) * kv_outer
    output = tl.sum(output, axis=0)

    # Update KV cache and store output
    tl.store(kv_ptr, kv_outer, mask=kv_mask)
    tl.store(output_ptr + q_offset + v_d_offsets, output, mask=v_mask)


def linear_decode_forward_triton(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kv_caches: torch.Tensor,
    slope_rate: torch.Tensor,
    slot_idx: torch.Tensor,
    BLOCK_SIZE: int = 32,
) -> torch.Tensor:
    """
    Perform linear attention decoding using Triton kernels.

    Args:
        q: Query tensor of shape [B, H, 1, D]
        k: Key tensor of shape [B, H, 1, D]
        v: Value tensor of shape [B, H, 1, D]
        kv_caches: Key-value cache tensor
        slope_rate: Decay rate tensor
        slot_idx: Slot indices for batches
        BLOCK_SIZE: Size of blocks for processing

    Returns:
        output: Attention output tensor
    """
    B, H, _, D = q.shape
    assert k.shape == (B, H, 1, D)
    assert v.shape == (B, H, 1, D)

    # Initialize output tensor
    output = torch.empty_like(q)

    # Set grid dimensions for the kernel; round up so that a partial last
    # block is still launched (the kernel masks out-of-range offsets)
    grid = (B, H, triton.cdiv(D, BLOCK_SIZE))

    # Calculate strides for tensors
    qkv_b_stride = q.stride(0)
    qkv_h_stride = q.stride(1)

    cache_b_stride = kv_caches.stride(0)
    cache_h_stride = kv_caches.stride(1)
    cache_d0_stride = kv_caches.stride(2)
    cache_d1_stride = kv_caches.stride(3)

    # Launch the kernel
    _linear_attn_decode_kernel[grid](
        q,
        k,
        v,
        kv_caches,
        slope_rate,
        slot_idx,
        output,
        D,
        qkv_b_stride,
        qkv_h_stride,
        cache_b_stride,
        cache_h_stride,
        cache_d0_stride,
        cache_d1_stride,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    # Reshape output and return
    output = rearrange(output, "b h n d -> b n (h d)")
    return output.squeeze(1).contiguous()
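The decode kernel above keeps a per-head state that is decayed by `exp(-slope)` and then updated with the outer product of the new key and value, once per token. The following is a minimal pure-Python sketch of that recurrence, not the Triton implementation: `decode_step`, `run_decode`, and `direct_output` are hypothetical helper names, and plain lists stand in for tensors. It lets the recurrence be compared against the explicit decayed sum over all earlier tokens.

```python
import math


def decode_step(q, k, v, kv_cache, slope):
    """One decode step for a single head (hypothetical reference helper).

    Updates kv_cache in place: kv <- exp(-slope) * kv + outer(k, v),
    then returns o[j] = sum_i q[i] * kv[i][j].
    """
    decay = math.exp(-slope)
    d, e = len(k), len(v)
    for i in range(d):
        for j in range(e):
            kv_cache[i][j] = decay * kv_cache[i][j] + k[i] * v[j]
    return [sum(q[i] * kv_cache[i][j] for i in range(d)) for j in range(e)]


def run_decode(qs, ks, vs, slope):
    """Run decode_step token by token, starting from an all-zero cache."""
    d, e = len(ks[0]), len(vs[0])
    kv_cache = [[0.0] * e for _ in range(d)]
    return [decode_step(q, k, v, kv_cache, slope)
            for q, k, v in zip(qs, ks, vs)]


def direct_output(qs, ks, vs, slope, t):
    """Explicit form: o_t = sum_{j<=t} exp(-slope*(t-j)) * (q_t . k_j) * v_j."""
    out = [0.0] * len(vs[0])
    for j in range(t + 1):
        dot = sum(a * b for a, b in zip(qs[t], ks[j]))
        weight = math.exp(-slope * (t - j)) * dot
        for c in range(len(out)):
            out[c] += weight * vs[j][c]
    return out
```

Because the state has a fixed size of `d x e` per head, the cost of each decode step does not grow with the number of tokens already processed, which is what makes a constant-size cache sufficient here.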
136
vllm/model_executor/models/constant_size_cache.py
Normal file
@ -0,0 +1,136 @@
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple

import torch

from vllm.attention.backends.utils import PAD_SLOT_ID


class ConstantSizeCache(ABC):
    """
    Abstract base class for managing constant size caches
    like Mamba and Minimax.
    """

    def __init__(self, max_batch_size: int):
        # Maps between the request id and a dict that maps between the seq_id
        # and its index inside the cache
        self.cache_indices_mapping: Dict[str, Dict[int, int]] = {}
        self.free_cache_indices = list(range(max_batch_size))

    @property
    @abstractmethod
    def cache(self) -> Any:
        """Return the underlying cache tensor(s)"""
        pass

    @abstractmethod
    def _copy_cache(self, from_index: int, to_index: int):
        """Copy cache data from one index to another"""
        pass

    def current_run_tensors(self, **kwargs) -> Tuple:
        """
        Return the cache tensors and state indices for the current run.
        """
        if "seqlen_agnostic_capture_inputs" not in kwargs:
            # We get here only on Prefill/Eager mode runs
            request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
            finished_requests_ids = kwargs["finished_requests_ids"]

            self._release_finished_requests(finished_requests_ids)
            state_indices = self._prepare_current_run_cache(
                request_ids_to_seq_ids, finished_requests_ids)

            state_indices_tensor = torch.as_tensor(state_indices,
                                                   dtype=torch.int32,
                                                   device="cuda")
            cache_tensors = self.cache
        else:
            # CUDA graph capturing runs
            cache_tensors, state_indices_tensor = kwargs[
                "seqlen_agnostic_capture_inputs"]

        return (cache_tensors, state_indices_tensor)

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        """
        Copy the relevant state_indices into the CUDA graph input buffer
        """
        assert all(
            key in kwargs
            for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
        finished_requests_ids = kwargs["finished_requests_ids"]
        request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
        assert "seqlen_agnostic_capture_inputs" in input_buffers
        _, input_state_indices_buffer = input_buffers[
            "seqlen_agnostic_capture_inputs"]

        self._release_finished_requests(finished_requests_ids)
        state_indices = self._prepare_current_run_cache(
            request_ids_to_seq_ids, finished_requests_ids)
        cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len(
            state_indices)
        state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len)

        input_state_indices_buffer.copy_(
            torch.as_tensor(state_indices, dtype=torch.int32, device="cuda"))

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        """
        Provide the CUDA graph capture runs with a buffer of the right size.
        The buffer is used to maintain the cache during the CUDA graph replay
        runs.
        """
        state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size,
                                               dtype=torch.int32,
                                               device="cuda")
        return (self.cache, state_indices_tensor)

    def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int,
                                      finished_requests_ids) -> int:
        """
        Return the cache index for a (req_id, seq_id) pair, allocating a free
        index when the pair is new; new siblings get a copy of an existing cache.
        """
        if cur_rid in finished_requests_ids:
            # set as pad, do not allocate destination index
            return PAD_SLOT_ID
        elif cur_rid not in self.cache_indices_mapping:
            destination_index = self.free_cache_indices.pop()
            self.cache_indices_mapping[cur_rid] = {seq_id: destination_index}
            return destination_index
        elif seq_id not in (seq_ids2indices :=
                            self.cache_indices_mapping[cur_rid]):
            # Parallel sampling (n > 1): assume prefill has already
            # happened, so copy the existing cache into the
            # sibling seq_id's cache
            index_exists = next(iter(seq_ids2indices.values()))
            # case of decoding n>1, copy prefill cache to decoding indices
            destination_index = self.free_cache_indices.pop()
            self._copy_cache(from_index=index_exists,
                             to_index=destination_index)
            self.cache_indices_mapping[cur_rid][seq_id] = destination_index
            return destination_index
        else:
            return self.cache_indices_mapping[cur_rid][seq_id]

    def _prepare_current_run_cache(
            self, request_ids_to_seq_ids: Dict[str, list[int]],
            finished_requests_ids: List[str]) -> List[int]:
        return [
            self._assign_seq_id_to_cache_index(req_id, seq_id,
                                               finished_requests_ids)
            for req_id, seq_ids in request_ids_to_seq_ids.items()
            for seq_id in seq_ids
        ]

    def _release_finished_requests(self,
                                   finished_seq_groups_req_ids: List[str]):
        for req_id in finished_seq_groups_req_ids:
            if req_id in self.cache_indices_mapping:
                for seq_id in self.cache_indices_mapping[req_id]:
                    self.free_cache_indices.append(
                        self.cache_indices_mapping[req_id][seq_id])
                self.cache_indices_mapping.pop(req_id)
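The slot bookkeeping in `ConstantSizeCache` can be followed without any tensors. The sketch below is a hypothetical, simplified model of it: `ToySlotAllocator` and its method names are illustrative and not part of vLLM, `PAD_SLOT_ID = -1` is an assumption (chosen to match the `slot_id == -1` padding check in the decode kernel), and tensor copies are replaced by a log of `(from_slot, to_slot)` pairs.

```python
PAD_SLOT_ID = -1  # assumption: the decode kernel treats slot -1 as padding


class ToySlotAllocator:
    """Hypothetical, tensor-free model of the slot bookkeeping."""

    def __init__(self, max_batch_size):
        self.mapping = {}  # request id -> {seq_id: slot}
        self.free = list(range(max_batch_size))
        self.copies = []  # (from_slot, to_slot) pairs instead of tensor copies

    def release(self, finished):
        # Return every slot owned by a finished request to the free list.
        for req_id in finished:
            self.free.extend(self.mapping.pop(req_id, {}).values())

    def assign(self, req_id, seq_id, finished):
        if req_id in finished:
            return PAD_SLOT_ID  # finished requests only get padding
        if req_id not in self.mapping:
            slot = self.free.pop()
            self.mapping[req_id] = {seq_id: slot}
            return slot
        slots = self.mapping[req_id]
        if seq_id not in slots:
            # New sibling sequence: take a free slot and copy existing state.
            src = next(iter(slots.values()))
            slots[seq_id] = self.free.pop()
            self.copies.append((src, slots[seq_id]))
        return slots[seq_id]

    def prepare(self, request_ids_to_seq_ids, finished):
        self.release(finished)
        return [
            self.assign(req_id, seq_id, finished)
            for req_id, seq_ids in request_ids_to_seq_ids.items()
            for seq_id in seq_ids
        ]
```

A request keeps the same slot across steps, a second sequence of the same request triggers exactly one copy, and slots come back to the free list only when the request id shows up in the finished list.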
@ -1,12 +1,13 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from typing import Dict, List, Tuple
from typing import Tuple

import torch

from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
from vllm.model_executor.models.constant_size_cache import ConstantSizeCache


@dataclass
@ -21,7 +22,7 @@ class MambaCacheParams:
                                self.state_indices_tensor)


class MambaCacheManager:
class MambaCacheManager(ConstantSizeCache):

    def __init__(self, vllm_config: VllmConfig, dtype: torch.dtype,
                 num_mamba_layers: int, conv_state_shape: Tuple[int, int],
@ -32,6 +33,9 @@ class MambaCacheManager:
        if not vllm_config.model_config.enforce_eager:
            max_batch_size = vllm_config.pad_for_cudagraph(max_batch_size)

        # Initialize parent class
        super().__init__(max_batch_size)

        conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
                                 conv_state_shape,
                                 dtype=dtype,
@ -41,126 +45,32 @@ class MambaCacheManager:
                                     dtype=dtype,
                                     device="cuda")

        self.mamba_cache = (conv_state, temporal_state)
        self._mamba_cache = (conv_state, temporal_state)

        # Maps between the request id and a dict that maps between the seq_id
        # and its index inside the self.mamba_cache
        self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
        self.free_cache_indices = list(range(max_batch_size))
    @property
    def cache(self):
        return self._mamba_cache

    def _copy_cache(self, from_index: int, to_index: int):
        for cache_t in self.cache:
            cache_t[:, to_index].copy_(cache_t[:, from_index],
                                       non_blocking=True)

    def current_run_tensors(self, **kwargs) -> MambaCacheParams:
        """
        Return the tensors for the current run's conv and ssm state.
        """
        if "seqlen_agnostic_capture_inputs" not in kwargs:
            # We get here only on Prefill/Eager mode runs
            request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
            finished_requests_ids = kwargs["finished_requests_ids"]

            self._release_finished_requests(finished_requests_ids)
            state_indices = self._prepare_current_run_mamba_cache(
                request_ids_to_seq_ids, finished_requests_ids)

            state_indices_tensor = torch.as_tensor(state_indices,
                                                   dtype=torch.int32,
                                                   device="cuda")
            mamba_cache_tensors = self.mamba_cache

        else:
            # CUDA graph capturing runs
            (mamba_cache_tensors,
             state_indices_tensor) = kwargs["seqlen_agnostic_capture_inputs"]

        return MambaCacheParams(mamba_cache_tensors[0], mamba_cache_tensors[1],
        cache_tensors, state_indices_tensor = super().current_run_tensors(
            **kwargs)
        return MambaCacheParams(cache_tensors[0], cache_tensors[1],
                                state_indices_tensor)

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        """
        Copy the relevant state_indices into the CUDA graph input buffer
        """
        assert all(
            key in kwargs
            for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
        finished_requests_ids = kwargs["finished_requests_ids"]
        request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
        assert "seqlen_agnostic_capture_inputs" in input_buffers
        _, input_state_indices_buffer = input_buffers[
            "seqlen_agnostic_capture_inputs"]

        self._release_finished_requests(finished_requests_ids)
        state_indices = self._prepare_current_run_mamba_cache(
            request_ids_to_seq_ids, finished_requests_ids)
        cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len(
            state_indices)
        state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len)

        input_state_indices_buffer.copy_(
            torch.as_tensor(state_indices, dtype=torch.int32, device="cuda"))

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        """
        Provide the CUDA graph capture runs with a buffer in adjusted size.
        The buffer is used to maintain the Mamba Cache during the CUDA graph
        replay runs.
        """
        state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size,
                                               dtype=torch.int32,
                                               device="cuda")
        return (self.mamba_cache, state_indices_tensor)

    def _copy_mamba_cache(self, from_index: int, to_index: int):
        assert len(self.mamba_cache) > 0
        for cache_t in self.mamba_cache:
            cache_t[:, to_index].copy_(cache_t[:, from_index],
                                       non_blocking=True)

    def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int,
                                      finished_requests_ids) -> int:
        """
        Assign (req_id,seq_id) pair to a `destination_index` index, if
        already occupied, move the occupying index to a free index.
        """
        if cur_rid in finished_requests_ids:
            # set as pad, do not allocate destination index
            return PAD_SLOT_ID
        elif cur_rid not in self.mamba_cache_indices_mapping:
            destination_index = self.free_cache_indices.pop()
            self.mamba_cache_indices_mapping[cur_rid] = {
                seq_id: destination_index
            }
            return destination_index
        elif seq_id not in (seq_ids2indices :=
                            self.mamba_cache_indices_mapping[cur_rid]):
            # parallel sampling , where n > 1, assume prefill have
            # already happened, so we copy the
            # existing cache into the siblings seq_ids caches
            index_exists = next(iter(seq_ids2indices.values()))
            # case of decoding n>1, copy prefill cache to decoding indices
            destination_index = self.free_cache_indices.pop()
            self._copy_mamba_cache(from_index=index_exists,
                                   to_index=destination_index)
            self.mamba_cache_indices_mapping[cur_rid][
                seq_id] = destination_index
            return destination_index
        else:
            # already exists
            return self.mamba_cache_indices_mapping[cur_rid][seq_id]

    def _prepare_current_run_mamba_cache(
            self, request_ids_to_seq_ids: Dict[str, list[int]],
            finished_requests_ids: List[str]) -> List[int]:
        return [
            self._assign_seq_id_to_cache_index(req_id, seq_id,
                                               finished_requests_ids)
            for req_id, seq_ids in request_ids_to_seq_ids.items()
            for seq_id in seq_ids
        ]

    def _release_finished_requests(self,
                                   finished_seq_groups_req_ids: List[str]):
        for req_id in finished_seq_groups_req_ids:
            if req_id in self.mamba_cache_indices_mapping:
                for seq_id in self.mamba_cache_indices_mapping[req_id]:
                    self.free_cache_indices.append(
                        self.mamba_cache_indices_mapping[req_id][seq_id])
                self.mamba_cache_indices_mapping.pop(req_id)
        return self._mamba_cache, torch.as_tensor([PAD_SLOT_ID] * batch_size,
                                                  dtype=torch.int32,
                                                  device="cuda")
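The refactor above moves the shared bookkeeping into a base class and leaves two hooks, `cache` and `_copy_cache`, to each subclass. A minimal, tensor-free sketch of that split is shown below. Everything in it is hypothetical: `ToyConstantSizeCache`, `ToyTupleCache`, and the `fork` method are illustrative names, and nested lists laid out as `[layer][slot]` stand in for the `(conv_state, temporal_state)` tensors.

```python
from abc import ABC, abstractmethod


class ToyConstantSizeCache(ABC):
    """Hypothetical, tensor-free stand-in for the shared base class."""

    @property
    @abstractmethod
    def cache(self):
        """Return the underlying cache container(s)."""

    @abstractmethod
    def _copy_cache(self, from_index, to_index):
        """Copy one slot's state into another slot."""

    def fork(self, from_index, to_index):
        # Shared logic lives in the base class and calls the subclass hook.
        self._copy_cache(from_index, to_index)
        return to_index


class ToyTupleCache(ToyConstantSizeCache):
    """Two state containers, each laid out as [layer][slot]."""

    def __init__(self, num_layers, num_slots):
        self._cache = (
            [[0] * num_slots for _ in range(num_layers)],
            [[0] * num_slots for _ in range(num_layers)],
        )

    @property
    def cache(self):
        return self._cache

    def _copy_cache(self, from_index, to_index):
        # Mirrors cache_t[:, to_index].copy_(cache_t[:, from_index]):
        # for every container and every layer, copy the slot entry.
        for container in self.cache:
            for layer in container:
                layer[to_index] = layer[from_index]
```

The base class never needs to know how many containers a subclass holds or how they are laid out; it only asks the subclass to copy a slot.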
35
vllm/model_executor/models/minimax_cache.py
Normal file
@ -0,0 +1,35 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass

import torch

from vllm.model_executor.models.constant_size_cache import ConstantSizeCache


@dataclass
class MinimaxCacheParams:
    minimax_cache: torch.Tensor = torch.Tensor()
    state_indices_tensor: torch.Tensor = torch.Tensor()

    def at_layer_idx(self, layer_idx):
        return MinimaxCacheParams(self.minimax_cache[layer_idx, ...],
                                  self.state_indices_tensor)


class MinimaxCacheManager(ConstantSizeCache):

    def __init__(self, dtype, cache_shape):
        super().__init__(cache_shape[1])  # max_batch_size is cache_shape[1]
        self._minimax_cache = torch.empty(size=cache_shape,
                                          dtype=dtype,
                                          device="cuda")

    @property
    def cache(self):
        return self._minimax_cache

    def _copy_cache(self, from_index: int, to_index: int):
        assert len(self.cache) > 0
        # The cache is a single tensor whose batch (slot) dimension is at
        # index 1, so index it directly rather than iterating over dim 0
        self.cache[:, to_index].copy_(self.cache[:, from_index],
                                      non_blocking=True)
1273
vllm/model_executor/models/minimax_text_01.py
Normal file
File diff suppressed because it is too large
@ -35,6 +35,7 @@ _TEXT_GENERATION_MODELS = {
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
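Each registry entry pairs an architecture name with a module name and a class name. As a rough sketch of how such an entry can be turned into an import target, the snippet below builds the dotted module path without importing anything. It rests on assumptions: `MODELS_PACKAGE` is inferred from the file paths in this diff, `TOY_TEXT_GENERATION_MODELS` and `resolve_target` are hypothetical names, and the way the real registry loads modules is not shown here.

```python
# Assumption: modules named in the registry live under this package, as the
# file paths in the diff (vllm/model_executor/models/...) suggest.
MODELS_PACKAGE = "vllm.model_executor.models"

# Hypothetical excerpt of the mapping: architecture -> (module, class).
TOY_TEXT_GENERATION_MODELS = {
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01",
                                 "MiniMaxText01ForCausalLM"),
}


def resolve_target(architecture):
    """Return (dotted module path, class name) without importing anything."""
    module_name, class_name = TOY_TEXT_GENERATION_MODELS[architecture]
    return f"{MODELS_PACKAGE}.{module_name}", class_name
```

Keeping only names in the table means a model's module does not have to be imported until that architecture is actually requested.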