Compare commits
1 commit: khluu/try_...whisper-tr

| Author | SHA1 | Date |
|---|---|---|
|  | d3eddd6ef1 |  |
@@ -15,12 +15,14 @@ Easy, fast, and cheap LLM serving for everyone
---
[2025/03] We are collaborating with Ollama to host an [Inference Night](https://lu.ma/vllm-ollama) at Y Combinator in San Francisco on Thursday, March 27, at 6 PM. Discuss all things inference local or data center!
[2025/04] We're hosting our first-ever *vLLM Asia Developer Day* in Singapore on *April 3rd*! This is a full-day event (9 AM - 9 PM SGT) in partnership with SGInnovate, AMD, and Embedded LLM. Meet the vLLM team and learn about LLM inference for RL, MI300X, and more! [Register Now](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)
---
*Latest News* 🔥
- [2025/03] We hosted [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing).
- [2025/03] We hosted [the first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg)! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing).
- [2025/03] We hosted [the East Coast vLLM Meetup](https://lu.ma/7mu4k4xx)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1NHiv8EUFF1NLd3fEYODm56nDmL26lEeXCaDgyDlTsRs/edit#slide=id.g31441846c39_0_0).
- [2025/02] We hosted [the ninth vLLM meetup](https://lu.ma/h7g3kuj9) with Meta! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1jzC_PZVXrVNSFVCW-V4cFXb6pn7zZ2CyP_Flwo05aqg/edit?usp=sharing) and AMD [here](https://drive.google.com/file/d/1Zk5qEJIkTmlQ2eQcXQZlljAx3m9s7nwn/view?usp=sharing). The slides from Meta will not be posted.
@@ -4,8 +4,6 @@
We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below:
- [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama), March 27th 2025. [[Slides]](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing).
- [The first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg), March 16th 2025. [[Slides]](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing).
- [The East Coast vLLM Meetup](https://lu.ma/7mu4k4xx), March 11th 2025. [[Slides]](https://docs.google.com/presentation/d/1NHiv8EUFF1NLd3fEYODm56nDmL26lEeXCaDgyDlTsRs/edit#slide=id.g31441846c39_0_0)
- [The ninth vLLM meetup](https://lu.ma/h7g3kuj9), with Meta, February 27th 2025. [[Slides]](https://docs.google.com/presentation/d/1jzC_PZVXrVNSFVCW-V4cFXb6pn7zZ2CyP_Flwo05aqg/edit?usp=sharing)
- [The eighth vLLM meetup](https://lu.ma/zep56hui), with Google Cloud, January 22nd 2025. [[Slides]](https://docs.google.com/presentation/d/1epVkt4Zu8Jz_S5OhEHPc798emsYh2BwYfRuDDVEF7u4/edit?usp=sharing)
@@ -503,11 +503,6 @@ See [this page](#generative-models) for more information on how to use generativ
* `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.
* ✅︎
* ✅︎
- * `MiniMaxText01ForCausalLM`
* MiniMax-Text
* `MiniMaxAI/MiniMax-Text-01`, etc.
*
* ✅︎
- * `Zamba2ForCausalLM`
* Zamba2
* `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc.
@@ -1,8 +1,6 @@
# Usage Stats Collection
vLLM collects anonymous usage data by default to help the engineering team better understand which hardware and model configurations are widely used. This data allows them to prioritize their efforts on the most common workloads. The collected data is transparent and does not contain any sensitive information.
A subset of the data, after cleaning and aggregation, will be publicly released for the community's benefit. For example, you can see the 2024 usage report [here](https://2024.vllm.ai).
vLLM collects anonymous usage data by default to help the engineering team better understand which hardware and model configurations are widely used. This data allows them to prioritize their efforts on the most common workloads. The collected data is transparent, does not contain any sensitive information, and will be publicly released for the community's benefit.
## What data is collected?
@@ -1,286 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.lightning_attn import (
|
||||
linear_decode_forward_triton)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
NUM_HEADS = [4, 8]
|
||||
HEAD_SIZES = [64]
|
||||
BATCH_SIZES = [1, 2]
|
||||
SEQ_LENGTHS = [16]
|
||||
DTYPES = [torch.float32]
|
||||
|
||||
|
||||
def reference_lightning_attention(q, k, v, ed, block_size, kv_history):
|
||||
"""Reference implementation of lightning attention core algorithm
|
||||
|
||||
The difference from the main implementation is that this processes
|
||||
each step sequentially, instead of using parallelized triton kernels
|
||||
"""
|
||||
B, H, S, D = q.shape
|
||||
E = v.shape[-1]
|
||||
dtype = q.dtype
|
||||
output = torch.zeros((B, H, S, E), dtype=dtype, device=q.device)
|
||||
|
||||
# Use clone() to ensure an independent copy
|
||||
if kv_history is None:
|
||||
kv_cache = torch.zeros((B, H, D, E), dtype=dtype, device=q.device)
|
||||
else:
|
||||
kv_cache = kv_history.clone()
|
||||
|
||||
# More efficient implementation
|
||||
# Convert decay factors to matrix form
|
||||
if ed.dim() == 1:
|
||||
decay = torch.exp(-ed).view(1, -1, 1, 1)
|
||||
else:
|
||||
decay = torch.exp(-ed)
|
||||
|
||||
for b in range(B):
|
||||
for step in range(S):
|
||||
# Process all heads at once for this position
|
||||
q_bs = q[b, :, step] # [H, D]
|
||||
k_bs = k[b, :, step] # [H, D]
|
||||
v_bs = v[b, :, step] # [H, E]
|
||||
|
||||
# Calculate KV outer products for all heads
|
||||
for h in range(H):
|
||||
# Calculate KV outer product
|
||||
kv_outer = torch.outer(k_bs[h], v_bs[h])
|
||||
|
||||
# Update KV cache with decay
|
||||
# Note: Using the same order as in the Triton kernel
|
||||
kv_cache[b, h] = decay[0, h, 0, 0] * kv_cache[b, h] + kv_outer
|
||||
|
||||
# Calculate attention output
|
||||
output[b, h, step] = torch.matmul(q_bs[h], kv_cache[b, h])
|
||||
|
||||
# Match the shape returned by the actual implementation
|
||||
# The actual implementation returns a tensor of shape [B, H, 2, D, E]
|
||||
# where dimension 2 contains both KV and KV history
|
||||
kv_reshaped = kv_cache.unsqueeze(2) # [B, H, 1, D, E]
|
||||
final_kv_cache = torch.cat([kv_reshaped, kv_reshaped],
|
||||
dim=2) # [B, H, 2, D, E]
|
||||
|
||||
return output, final_kv_cache
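

# A minimal sketch (not part of the original tests) of how the reference above
# can be exercised on CPU; shapes follow the [B, H, S, D] layout used in this
# file, and the helper name below is illustrative only.
def _reference_lightning_attention_example():
    B, H, S, D = 1, 2, 4, 8
    q = torch.randn(B, H, S, D)
    k = torch.randn(B, H, S, D)
    v = torch.randn(B, H, S, D)
    # Per-head decay rates, mirroring how `ed` is constructed in the tests below.
    ed = 0.1 * (torch.arange(H, dtype=torch.float32) + 1)
    output, kv_cache = reference_lightning_attention(q, k, v, ed, 256, None)
    assert output.shape == (B, H, S, D)
    assert kv_cache.shape == (B, H, 2, D, D)
    return output, kv_cache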


def reference_linear_decode(q, k, v, kv_caches, slope_rate, slot_idx):
|
||||
"""Reference implementation: linear attention decode function"""
|
||||
B, H, _, D = q.shape
|
||||
output = torch.zeros(B, H * D, dtype=q.dtype, device=q.device)
|
||||
|
||||
# Calculate decay factors once (more efficient)
|
||||
decay = torch.exp(-slope_rate).view(-1, 1, 1) # [H, 1, 1]
|
||||
|
||||
# Process each batch
|
||||
for b in range(B):
|
||||
slot_id = slot_idx[b].item()
|
||||
|
||||
# Skip padding positions
|
||||
if slot_id == -1:
|
||||
continue
|
||||
|
||||
# Process all heads at once for this batch
|
||||
q_b = q[b, :, 0] # [H, D]
|
||||
k_b = k[b, :, 0] # [H, D]
|
||||
v_b = v[b, :, 0] # [H, D]
|
||||
|
||||
# Process each attention head
|
||||
for h in range(H):
|
||||
# Get current query, key and value
|
||||
q_bh = q_b[h]
|
||||
k_bh = k_b[h]
|
||||
v_bh = v_b[h]
|
||||
|
||||
# Get cache
|
||||
kv_cache_old = kv_caches[b, h]
|
||||
|
||||
# Calculate new key-value outer product
|
||||
kv_outer = torch.outer(k_bh, v_bh)
|
||||
|
||||
# Apply decay and update cache
|
||||
kv_new = kv_outer + decay[h, 0, 0] * kv_cache_old
|
||||
|
||||
# Calculate output
|
||||
out_h = torch.matmul(q_bh, kv_new)
|
||||
|
||||
# Update output and cache
|
||||
output[b, h * D:(h + 1) * D] = out_h
|
||||
kv_caches[b, h] = kv_new
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@torch.inference_mode()
|
||||
def test_linear_decode_forward_triton(
|
||||
batch_size: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
torch.set_default_device("cuda")
|
||||
torch.manual_seed(42)
|
||||
torch.cuda.manual_seed_all(42)
|
||||
current_platform.seed_everything(42)
|
||||
base = 0.01
|
||||
q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
|
||||
kv_caches = base * torch.randn(batch_size,
|
||||
num_heads,
|
||||
head_size,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
|
||||
kv_caches_copy = kv_caches.clone()
|
||||
|
||||
slope_rate = torch.zeros(num_heads, device="cuda")
|
||||
for h in range(num_heads):
|
||||
slope_rate[h] = 0.1 * (h + 1)
|
||||
|
||||
slot_idx = torch.arange(batch_size, device="cuda")
|
||||
|
||||
triton_output = linear_decode_forward_triton(q, k, v, kv_caches,
|
||||
slope_rate, slot_idx)
|
||||
|
||||
reference_output = reference_linear_decode(q, k, v, kv_caches_copy,
|
||||
slope_rate, slot_idx)
|
||||
torch.testing.assert_close(triton_output,
|
||||
reference_output,
|
||||
rtol=1e-1,
|
||||
atol=1e-1)
|
||||
torch.testing.assert_close(kv_caches, kv_caches_copy, rtol=1e-1, atol=1e-1)
|
||||
|
||||
assert triton_output.shape == (batch_size, num_heads * head_size)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@torch.inference_mode()
|
||||
def test_linear_decode_forward_triton_with_padding(
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
torch.set_default_device("cuda")
|
||||
torch.manual_seed(42)
|
||||
torch.cuda.manual_seed_all(42)
|
||||
current_platform.seed_everything(42)
|
||||
|
||||
batch_size = 4
|
||||
base = 0.01
|
||||
q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
|
||||
kv_caches = base * torch.randn(batch_size,
|
||||
num_heads,
|
||||
head_size,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
|
||||
kv_caches_copy = kv_caches.clone()
|
||||
|
||||
slope_rate = torch.zeros(num_heads, device="cuda")
|
||||
for h in range(num_heads):
|
||||
slope_rate[h] = 0.1 * (h + 1)
|
||||
|
||||
slot_idx = torch.tensor([0, 1, -1, 2], device="cuda")
|
||||
|
||||
triton_output = linear_decode_forward_triton(q, k, v, kv_caches,
|
||||
slope_rate, slot_idx)
|
||||
|
||||
reference_output = reference_linear_decode(q, k, v, kv_caches_copy,
|
||||
slope_rate, slot_idx)
|
||||
|
||||
padding_mask = (slot_idx
|
||||
!= -1).unsqueeze(1).expand(-1, num_heads * head_size)
|
||||
|
||||
triton_masked = triton_output[padding_mask]
|
||||
reference_masked = reference_output[padding_mask]
|
||||
|
||||
atol, rtol = 1.5e-1, 1.5e-1
|
||||
|
||||
valid_indices = slot_idx != -1
|
||||
|
||||
for i in range(batch_size):
|
||||
if valid_indices[i] > 0:
|
||||
torch.testing.assert_close(kv_caches[i],
|
||||
kv_caches_copy[i],
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
|
||||
torch.testing.assert_close(triton_masked,
|
||||
reference_masked,
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
|
||||
assert triton_output.shape == (batch_size, num_heads * head_size)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("seq_len", SEQ_LENGTHS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@torch.inference_mode()
|
||||
def test_lightning_attention_reference(
|
||||
batch_size: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
seq_len: int,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
torch.set_default_device("cuda")
|
||||
torch.manual_seed(42)
|
||||
torch.cuda.manual_seed_all(42)
|
||||
current_platform.seed_everything(42)
|
||||
|
||||
base = 0.01
|
||||
q = base * torch.randn(
|
||||
batch_size, num_heads, seq_len, head_size, dtype=dtype)
|
||||
k = base * torch.randn(
|
||||
batch_size, num_heads, seq_len, head_size, dtype=dtype)
|
||||
v = base * torch.randn(
|
||||
batch_size, num_heads, seq_len, head_size, dtype=dtype)
|
||||
|
||||
ed = torch.zeros(num_heads, device="cuda")
|
||||
for h in range(num_heads):
|
||||
ed[h] = 0.1 * (h + 1)
|
||||
|
||||
kv_history = base * torch.randn(batch_size,
|
||||
num_heads,
|
||||
head_size,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
|
||||
kv_history_clone = kv_history.clone()
|
||||
|
||||
ref_output, ref_kv_cache = reference_lightning_attention(
|
||||
q, k, v, ed, 256, kv_history)
|
||||
|
||||
from vllm.model_executor.layers.lightning_attn import lightning_attention
|
||||
actual_output, actual_kv_cache = lightning_attention(
|
||||
q, k, v, ed, 256, kv_history_clone)
|
||||
|
||||
atol, rtol = 1.5e-1, 1.5e-1
|
||||
torch.testing.assert_close(ref_output, actual_output, rtol=rtol, atol=atol)
|
||||
torch.testing.assert_close(ref_kv_cache,
|
||||
actual_kv_cache,
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
|
||||
assert ref_output.shape == (batch_size, num_heads, seq_len, head_size)
|
||||
assert ref_kv_cache.shape == actual_kv_cache.shape
|
||||
@@ -176,8 +176,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True),
|
||||
"MiniCPM3ForCausalLM": _HfExamplesInfo("openbmb/MiniCPM3-4B",
|
||||
trust_remote_code=True),
|
||||
"MiniMaxText01ForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01",
|
||||
trust_remote_code=True),
|
||||
"MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"),
|
||||
"MixtralForCausalLM": _HfExamplesInfo("mistralai/Mixtral-8x7B-Instruct-v0.1"), # noqa: E501
|
||||
"QuantMixtralForCausalLM": _HfExamplesInfo("mistral-community/Mixtral-8x22B-v0.1-AWQ"), # noqa: E501
|
||||
|
||||
@@ -971,34 +971,26 @@ class ModelConfig:
return sum(not bc.attention.no_op
|
||||
for bc in block_configs[start:end])
|
||||
else:
|
||||
# Hybrid model Jamba
|
||||
# Hybrid model
|
||||
layers_block_type_value = getattr(self.hf_config,
|
||||
"layers_block_type", None)
|
||||
if layers_block_type_value is not None:
|
||||
if hasattr(self.hf_text_config,
|
||||
"model_type") and (self.hf_text_config.model_type
|
||||
== "zamba2"):
|
||||
if attn_block_type:
|
||||
return sum(t == "hybrid"
|
||||
for t in layers_block_type_value[start:end])
|
||||
else:
|
||||
return self.get_num_layers(parallel_config)
|
||||
return sum(t == block_type.value
|
||||
for t in layers_block_type_value[start:end])
|
||||
if layers_block_type_value is None:
|
||||
raise ValueError("The model is an hybrid without a "
|
||||
"layers_block_type in the hf_config, "
|
||||
"cannot determine the num of "
|
||||
f"{block_type.value} layers")
|
||||
|
||||
# Hybrid model Minimax
|
||||
attn_type_list = getattr(self.hf_config, "attn_type_list", None)
|
||||
if attn_type_list:
|
||||
return sum(t == 1 for t in attn_type_list[start:end])
|
||||
if hasattr(self.hf_text_config,
|
||||
"model_type") and (self.hf_text_config.model_type
|
||||
== "zamba2"):
|
||||
if attn_block_type:
|
||||
return sum(t == "hybrid"
|
||||
for t in layers_block_type_value[start:end])
|
||||
else:
|
||||
return self.get_num_layers(parallel_config)
|
||||
|
||||
if layers_block_type_value is None and attn_type_list is None:
|
||||
raise ValueError(
"The model is a hybrid without a "
"layers_block_type or an attn_type_list in the hf_config, "
"cannot determine the num of "
f"{block_type.value} layers")
|
||||
|
||||
return sum(t == 1 for t in attn_type_list[start:end])
|
||||
return sum(t == block_type.value
|
||||
for t in layers_block_type_value[start:end])
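
# A small worked example of the counting above, using hypothetical config
# values: a MiniMax-style hybrid with attn_type_list = [1, 0, 1, 1] and
# (start, end) = (0, 4) reports 3 attention layers, while a config with
# layers_block_type = ["attention", "mamba", "attention"] counted against
# block_type.value == "attention" over the same range yields 2.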
def get_multimodal_config(self) -> "MultiModalConfig":
|
||||
"""
|
||||
|
||||
@@ -303,11 +303,8 @@ class _AsyncLLMEngine(LLMEngine):
ctx.seq_group_metadata_list = seq_group_metadata_list
|
||||
ctx.scheduler_outputs = scheduler_outputs
|
||||
|
||||
if not scheduler_outputs.is_empty():
|
||||
# this will cause mamba_cache/minimax_cache failed
|
||||
# to release finished_requests_ids of the last steps
|
||||
finished_requests_ids = self.scheduler[
|
||||
virtual_engine].get_and_reset_finished_requests_ids()
|
||||
finished_requests_ids = self.scheduler[
|
||||
virtual_engine].get_and_reset_finished_requests_ids()
|
||||
|
||||
# Maybe switch from async mode to sync mode
|
||||
if not allow_async_output_proc and len(ctx.output_queue) > 0:
|
||||
|
||||
@@ -67,6 +67,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
TokenizeResponse,
|
||||
TranscriptionRequest,
|
||||
TranscriptionResponse,
|
||||
TranslationRequest,
|
||||
TranslationResponse,
|
||||
UnloadLoRAAdapterRequest)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
@@ -80,7 +82,7 @@ from vllm.entrypoints.openai.serving_score import ServingScores
from vllm.entrypoints.openai.serving_tokenization import (
|
||||
OpenAIServingTokenization)
|
||||
from vllm.entrypoints.openai.serving_transcription import (
|
||||
OpenAIServingTranscription)
|
||||
OpenAIServingTranscription, OpenAIServingTranslation)
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||
from vllm.entrypoints.utils import (cli_env_setup, load_aware_call,
|
||||
with_cancellation)
|
||||
@@ -383,6 +385,10 @@ def transcription(request: Request) -> OpenAIServingTranscription:
return request.app.state.openai_serving_transcription
|
||||
|
||||
|
||||
def translation(request: Request) -> OpenAIServingTranslation:
|
||||
return request.app.state.openai_serving_translation
|
||||
|
||||
|
||||
def engine_client(request: Request) -> EngineClient:
|
||||
return request.app.state.engine_client
|
||||
|
||||
@@ -625,6 +631,31 @@ async def create_transcriptions(request: Annotated[TranscriptionRequest,
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post("/v1/audio/translations")
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_translations(request: Annotated[TranslationRequest,
|
||||
Form()],
|
||||
raw_request: Request):
|
||||
handler = translation(raw_request)
|
||||
if handler is None:
|
||||
return base(raw_request).create_error_response(
|
||||
message="The model does not support Translations API")
|
||||
|
||||
audio_data = await request.file.read()
|
||||
generator = await handler.create_translation(audio_data, request,
|
||||
raw_request)
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
status_code=generator.code)
|
||||
|
||||
elif isinstance(generator, TranslationResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
return StreamingResponse(content=generator, media_type="text/event-stream")
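
# A minimal client-side sketch for the endpoint above (an assumption-laden
# example, not part of this module): it presumes a running vLLM server with a
# Whisper-style model loaded and the `openai` Python client installed; the
# model name and file path are placeholders.
#
#     from openai import OpenAI
#     client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
#     with open("sample_non_english.wav", "rb") as f:
#         translation = client.audio.translations.create(
#             model="openai/whisper-large-v3", file=f)
#     print(translation.text)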


@router.post("/rerank", dependencies=[Depends(validate_json_request)])
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
@@ -1098,10 +1129,9 @@ async def run_server(args, **uvicorn_kwargs) -> None:
)
|
||||
|
||||
# NB: Await server shutdown only after the backend context is exited
|
||||
try:
|
||||
await shutdown_task
|
||||
finally:
|
||||
sock.close()
|
||||
await shutdown_task
|
||||
|
||||
sock.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1652,3 +1652,196 @@ class TranscriptionResponseVerbose(OpenAIBaseModel):
|
||||
words: Optional[list[TranscriptionWord]] = None
|
||||
"""Extracted words and their corresponding timestamps."""
|
||||
|
||||
|
||||
class TranslationResponseStreamChoice(OpenAIBaseModel):
|
||||
delta: DeltaMessage
|
||||
finish_reason: Optional[str] = None
|
||||
stop_reason: Optional[Union[int, str]] = None
|
||||
|
||||
|
||||
class TranslationStreamResponse(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"trsl-{random_uuid()}")
|
||||
object: Literal["translation.chunk"] = "translation.chunk"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: list[TranslationResponseStreamChoice]
|
||||
usage: Optional[UsageInfo] = Field(default=None)
|
||||
|
||||
|
||||
class TranslationRequest(OpenAIBaseModel):
|
||||
# Ordered by official OpenAI API documentation
|
||||
# https://platform.openai.com/docs/api-reference/audio/createTranslation
|
||||
|
||||
file: UploadFile
|
||||
"""
|
||||
The audio file object (not file name) to translate, in one of these
|
||||
formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
|
||||
"""
|
||||
|
||||
model: Optional[str] = None
|
||||
"""ID of the model to use.
|
||||
"""
|
||||
|
||||
language: Optional[str] = None
|
||||
"""The language of the input audio.
|
||||
|
||||
Supplying the input language in
|
||||
[ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
|
||||
will improve accuracy and latency.
|
||||
"""
|
||||
|
||||
prompt: str = Field(default="")
|
||||
"""An optional text to guide the model's style or continue a previous audio
|
||||
segment.
|
||||
|
||||
The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
|
||||
should match the audio language.
|
||||
"""
|
||||
|
||||
response_format: AudioResponseFormat = Field(default="json")
|
||||
"""
|
||||
The format of the output, in one of these options: `json`, `text`, `srt`,
|
||||
`verbose_json`, or `vtt`.
|
||||
"""
|
||||
|
||||
## TODO (varun) : Support if set to 0, certain thresholds are met !!
|
||||
temperature: float = Field(default=0.0)
|
||||
"""The sampling temperature, between 0 and 1.
|
||||
|
||||
Higher values like 0.8 will make the output more random, while lower values
|
||||
like 0.2 will make it more focused / deterministic. If set to 0, the model
|
||||
will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
|
||||
to automatically increase the temperature until certain thresholds are hit.
|
||||
"""
|
||||
|
||||
timestamp_granularities: list[Literal["word", "segment"]] = Field(
|
||||
alias="timestamp_granularities[]", default=[])
|
||||
"""The timestamp granularities to populate for this translation.
|
||||
|
||||
`response_format` must be set `verbose_json` to use timestamp granularities.
|
||||
Either or both of these options are supported: `word`, or `segment`. Note:
|
||||
There is no additional latency for segment timestamps, but generating word
|
||||
timestamps incurs additional latency.
|
||||
"""
|
||||
|
||||
stream: Optional[bool] = False
|
||||
"""Custom field not present in the original OpenAI definition. When set,
|
||||
it will enable output to be streamed in a similar fashion as the Chat
|
||||
Completion endpoint.
|
||||
"""
|
||||
# Flattened stream option to simplify form data.
|
||||
stream_include_usage: Optional[bool] = False
|
||||
stream_continuous_usage_stats: Optional[bool] = False
|
||||
|
||||
# Default sampling parameters for translation requests.
|
||||
_DEFAULT_SAMPLING_PARAMS: dict = {
|
||||
"temperature": 0,
|
||||
}
|
||||
|
||||
def to_sampling_params(
|
||||
self,
|
||||
default_max_tokens: int,
|
||||
default_sampling_params: Optional[dict] = None) -> SamplingParams:
|
||||
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
|
||||
max_tokens = default_max_tokens
|
||||
|
||||
if default_sampling_params is None:
|
||||
default_sampling_params = {}
|
||||
# Default parameters
|
||||
if (temperature := self.temperature) is None:
|
||||
temperature = default_sampling_params.get(
|
||||
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
|
||||
|
||||
return SamplingParams.from_optional(temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
output_kind=RequestOutputKind.DELTA
|
||||
if self.stream \
|
||||
else RequestOutputKind.FINAL_ONLY)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_stream_options(cls, data):
|
||||
stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
|
||||
stream = data.get("stream", False)
|
||||
if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
|
||||
raise ValueError(
|
||||
"Stream options can only be defined when `stream=True`.")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# Translation response objects
|
||||
class TranslationResponse(OpenAIBaseModel):
|
||||
text: str
|
||||
"""The translated text."""
|
||||
|
||||
|
||||
class TranslationWord(OpenAIBaseModel):
|
||||
end: float
|
||||
"""End time of the word in seconds."""
|
||||
|
||||
start: float
|
||||
"""Start time of the word in seconds."""
|
||||
|
||||
word: str
|
||||
"""The text content of the word."""
|
||||
|
||||
|
||||
class TranslationSegment(OpenAIBaseModel):
|
||||
id: int
|
||||
"""Unique identifier of the segment."""
|
||||
|
||||
avg_logprob: float
|
||||
"""Average logprob of the segment.
|
||||
|
||||
If the value is lower than -1, consider the logprobs failed.
|
||||
"""
|
||||
|
||||
compression_ratio: float
|
||||
"""Compression ratio of the segment.
|
||||
|
||||
If the value is greater than 2.4, consider the compression failed.
|
||||
"""
|
||||
|
||||
end: float
|
||||
"""End time of the segment in seconds."""
|
||||
|
||||
no_speech_prob: float
|
||||
"""Probability of no speech in the segment.
|
||||
|
||||
If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
|
||||
this segment silent.
|
||||
"""
|
||||
|
||||
seek: int
|
||||
"""Seek offset of the segment."""
|
||||
|
||||
start: float
|
||||
"""Start time of the segment in seconds."""
|
||||
|
||||
temperature: float
|
||||
"""Temperature parameter used for generating the segment."""
|
||||
|
||||
text: str
|
||||
"""Text content of the segment."""
|
||||
|
||||
tokens: list[int]
|
||||
"""Array of token IDs for the text content."""
|
||||
|
||||
|
||||
class TranslationResponseVerbose(OpenAIBaseModel):
|
||||
duration: str
|
||||
"""The duration of the input audio."""
|
||||
|
||||
language: str
|
||||
"""The language of the input audio."""
|
||||
|
||||
text: str
|
||||
"""The translated text."""
|
||||
|
||||
segments: Optional[list[TranslationSegment]] = None
|
||||
"""Segments of the translated text and their corresponding details."""
|
||||
|
||||
words: Optional[list[TranslationWord]] = None
|
||||
"""Extracted words and their corresponding timestamps."""
|
||||
|
||||
@@ -4,7 +4,7 @@ import io
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from math import ceil
|
||||
from typing import Final, Optional, Union, cast
|
||||
from typing import Callable, Optional, Union, cast
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
@@ -14,7 +14,8 @@ from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
|
||||
DeltaMessage, ErrorResponse, RequestResponseMetadata, TranscriptionRequest,
|
||||
TranscriptionResponse, TranscriptionResponseStreamChoice,
|
||||
TranscriptionStreamResponse, UsageInfo)
|
||||
TranscriptionStreamResponse, TranslationRequest, TranslationResponse,
|
||||
TranslationResponseStreamChoice, TranslationStreamResponse, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.inputs.data import PromptType
|
||||
@@ -30,7 +31,7 @@ except ImportError:
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages#supported-languages
|
||||
# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages
|
||||
# TODO these configs should live somewhere with the model so we can support
|
||||
# additional ones
|
||||
|
||||
@@ -144,16 +145,19 @@ ISO639_1_OTHER_LANGS = {
MAX_AUDIO_CLIP_FILESIZE_MB = 25
|
||||
|
||||
|
||||
class OpenAIServingTranscription(OpenAIServing):
|
||||
class OpenAISpeechToText(OpenAIServing):
|
||||
"""Base class for speech-to-text operations like transcription and
|
||||
translation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
task_type: str = "transcribe", # or "translate"
|
||||
):
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
@@ -167,15 +171,16 @@ class OpenAIServingTranscription(OpenAIServing):
self.max_audio_clip_s = processor.feature_extractor.chunk_length
|
||||
self.model_sr = processor.feature_extractor.sampling_rate
|
||||
self.hop_length = processor.feature_extractor.hop_length
|
||||
self.task_type = task_type
|
||||
|
||||
if self.default_sampling_params:
|
||||
logger.info(
|
||||
"Overwriting default completion sampling param with: %s",
|
||||
self.default_sampling_params)
|
||||
|
||||
async def _preprocess_transcription(
|
||||
async def _preprocess_speech_to_text(
|
||||
self,
|
||||
request: TranscriptionRequest,
|
||||
request: Union[TranscriptionRequest, TranslationRequest],
|
||||
audio_data: bytes,
|
||||
) -> tuple[PromptType, float]:
|
||||
# Validate request
|
||||
@@ -218,21 +223,22 @@ class OpenAIServingTranscription(OpenAIServing):
},
|
||||
},
|
||||
"decoder_prompt":
|
||||
f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}"
|
||||
(f"<|startoftranscript|>{lang_token}"
|
||||
f"<|{self.task_type}|><|notimestamps|>{request.prompt}")
|
||||
}
|
||||
return cast(PromptType, prompt), duration
|
||||
|
||||
# TODO (varun) : Make verbose response work !
|
||||
async def create_transcription(
|
||||
self, audio_data: bytes, request: TranscriptionRequest,
|
||||
raw_request: Request
|
||||
) -> Union[TranscriptionResponse, AsyncGenerator[str, None],
|
||||
ErrorResponse]:
|
||||
"""Transcription API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/audio/createTranscription
|
||||
for the API specification. This API mimics the OpenAI transcription API.
|
||||
"""
|
||||
async def _create_speech_to_text(
|
||||
self,
|
||||
audio_data: bytes,
|
||||
request: Union[TranscriptionRequest, TranslationRequest],
|
||||
raw_request: Request,
|
||||
response_class: Union[TranscriptionResponse, TranslationResponse],
|
||||
stream_generator_method: Callable,
|
||||
) -> Union[Union[TranscriptionResponse, TranslationResponse],
|
||||
AsyncGenerator[str, None], ErrorResponse]:
|
||||
"""Base method for speech-to-text operations like transcription and
|
||||
translation."""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
@@ -247,7 +253,7 @@ class OpenAIServingTranscription(OpenAIServing):
return self.create_error_response(
|
||||
"Currently only support response_format `text` or `json`")
|
||||
|
||||
request_id = f"trsc-{self._base_request_id(raw_request)}"
|
||||
request_id = f"{self.task_type}-{self._base_request_id(raw_request)}"
|
||||
|
||||
request_metadata = RequestResponseMetadata(request_id=request_id)
|
||||
if raw_request:
|
||||
@@ -261,13 +267,14 @@ class OpenAIServingTranscription(OpenAIServing):
|
||||
if lora_request:
|
||||
return self.create_error_response(
|
||||
"Currently do not support LoRA for Transcription.")
|
||||
"Currently do not support LoRA for "
|
||||
f"{self.task_type.title()}.")
|
||||
if prompt_adapter_request:
|
||||
return self.create_error_response(
|
||||
"Currently do not support PromptAdapter for Transcription."
|
||||
)
|
||||
f"Currently do not support PromptAdapter for "
|
||||
f"{self.task_type.title()}.")
|
||||
|
||||
prompt, duration_s = await self._preprocess_transcription(
|
||||
prompt, duration_s = await self._preprocess_speech_to_text(
|
||||
request=request,
|
||||
audio_data=audio_data,
|
||||
)
|
||||
@@ -300,31 +307,36 @@ class OpenAIServingTranscription(OpenAIServing):
return self.create_error_response(str(e))
|
||||
|
||||
if request.stream:
|
||||
return self.transcription_stream_generator(request,
|
||||
result_generator,
|
||||
request_id,
|
||||
request_metadata,
|
||||
duration_s)
|
||||
return stream_generator_method(request, result_generator,
|
||||
request_id, request_metadata,
|
||||
duration_s)
|
||||
# Non-streaming response.
|
||||
try:
|
||||
assert result_generator is not None
|
||||
async for op in result_generator:
|
||||
result = op
|
||||
return TranscriptionResponse(text=result.outputs[0].text)
|
||||
return response_class(text=result.outputs[0].text)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
async def transcription_stream_generator(
|
||||
self, request: TranscriptionRequest,
|
||||
result_generator: AsyncGenerator[RequestOutput, None],
|
||||
request_id: str, request_metadata: RequestResponseMetadata,
|
||||
audio_duration_s: float) -> AsyncGenerator[str, None]:
|
||||
async def _speech_to_text_stream_generator(
|
||||
self,
|
||||
request: Union[TranscriptionRequest, TranslationRequest],
|
||||
result_generator: AsyncGenerator[RequestOutput, None],
|
||||
request_id: str,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
audio_duration_s: float,
|
||||
chunk_object_type: str,
|
||||
response_stream_choice_class: Union[TranscriptionResponseStreamChoice,
|
||||
TranslationResponseStreamChoice],
|
||||
stream_response_class: Union[TranscriptionStreamResponse,
|
||||
TranslationStreamResponse],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
created_time = int(time.time())
|
||||
model_name = request.model
|
||||
chunk_object_type: Final = "transcription.chunk"
|
||||
|
||||
completion_tokens = 0
|
||||
num_prompt_tokens = 0
|
||||
@@ -361,20 +373,20 @@ class OpenAIServingTranscription(OpenAIServing):
|
||||
if output.finish_reason is None:
|
||||
# Still generating, send delta update.
|
||||
choice_data = TranscriptionResponseStreamChoice(
|
||||
choice_data = response_stream_choice_class(
|
||||
delta=delta_message)
|
||||
else:
|
||||
# Model is finished generating.
|
||||
choice_data = TranscriptionResponseStreamChoice(
|
||||
choice_data = response_stream_choice_class(
|
||||
delta=delta_message,
|
||||
finish_reason=output.finish_reason,
|
||||
stop_reason=output.stop_reason)
|
||||
|
||||
chunk = TranscriptionStreamResponse(id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
chunk = stream_response_class(id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
|
||||
# handle usage stats if requested & if continuous
|
||||
if include_continuous_usage:
|
||||
@@ -395,7 +407,7 @@ class OpenAIServingTranscription(OpenAIServing):
total_tokens=num_prompt_tokens +
|
||||
completion_tokens)
|
||||
|
||||
final_usage_chunk = TranscriptionStreamResponse(
|
||||
final_usage_chunk = stream_response_class(
|
||||
id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
@@ -414,8 +426,115 @@ class OpenAIServingTranscription(OpenAIServing):
|
||||
except Exception as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
logger.exception("Error in chat completion stream generator.")
|
||||
logger.exception("Error in %s stream generator.", self.task_type)
|
||||
data = self.create_streaming_error_response(str(e))
|
||||
yield f"data: {data}\n\n"
|
||||
# Send the final done message after all response.n are finished
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
|
||||
class OpenAIServingTranscription(OpenAISpeechToText):
|
||||
"""Handles transcription requests."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
):
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids,
|
||||
task_type="transcribe")
|
||||
|
||||
async def create_transcription(
|
||||
self, audio_data: bytes, request: TranscriptionRequest,
|
||||
raw_request: Request
|
||||
) -> Union[TranscriptionResponse, AsyncGenerator[str, None],
|
||||
ErrorResponse]:
|
||||
"""Transcription API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/audio/createTranscription
|
||||
for the API specification. This API mimics the OpenAI transcription API.
|
||||
"""
|
||||
return await self._create_speech_to_text(
|
||||
audio_data=audio_data,
|
||||
request=request,
|
||||
raw_request=raw_request,
|
||||
response_class=TranscriptionResponse,
|
||||
stream_generator_method=self.transcription_stream_generator,
|
||||
)
|
||||
|
||||
async def transcription_stream_generator(
|
||||
self, request: TranscriptionRequest,
|
||||
result_generator: AsyncGenerator[RequestOutput, None],
|
||||
request_id: str, request_metadata: RequestResponseMetadata,
|
||||
audio_duration_s: float) -> AsyncGenerator[str, None]:
|
||||
return await self._speech_to_text_stream_generator(
|
||||
request=request,
|
||||
result_generator=result_generator,
|
||||
request_id=request_id,
|
||||
request_metadata=request_metadata,
|
||||
audio_duration_s=audio_duration_s,
|
||||
chunk_object_type="transcription.chunk",
|
||||
response_stream_choice_class=TranscriptionResponseStreamChoice,
|
||||
stream_response_class=TranscriptionStreamResponse,
|
||||
)
|
||||
|
||||
|
||||
class OpenAIServingTranslation(OpenAISpeechToText):
|
||||
"""Handles translation requests."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
):
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids,
|
||||
task_type="translate")
|
||||
|
||||
async def create_translation(
|
||||
self, audio_data: bytes, request: TranslationRequest,
|
||||
raw_request: Request
|
||||
) -> Union[TranslationResponse, AsyncGenerator[str, None], ErrorResponse]:
|
||||
"""Translation API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/audio/createTranslation
|
||||
for the API specification. This API mimics the OpenAI translation API.
|
||||
"""
|
||||
return await self._create_speech_to_text(
|
||||
audio_data=audio_data,
|
||||
request=request,
|
||||
raw_request=raw_request,
|
||||
response_class=TranslationResponse,
|
||||
stream_generator_method=self.translation_stream_generator,
|
||||
)
|
||||
|
||||
async def translation_stream_generator(
|
||||
self, request: TranslationRequest,
|
||||
result_generator: AsyncGenerator[RequestOutput, None],
|
||||
request_id: str, request_metadata: RequestResponseMetadata,
|
||||
audio_duration_s: float) -> AsyncGenerator[str, None]:
|
||||
return await self._speech_to_text_stream_generator(
|
||||
request=request,
|
||||
result_generator=result_generator,
|
||||
request_id=request_id,
|
||||
request_metadata=request_metadata,
|
||||
audio_duration_s=audio_duration_s,
|
||||
chunk_object_type="translation.chunk",
|
||||
response_stream_choice_class=TranslationResponseStreamChoice,
|
||||
stream_response_class=TranslationStreamResponse,
|
||||
)
|
||||
|
||||
@@ -1,651 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_diag_kernel(Q, K, V, Out, S, b: tl.constexpr, h: tl.constexpr, n,
|
||||
d: tl.constexpr, e: tl.constexpr, BLOCK: tl.constexpr,
|
||||
NUM_BLOCK, CBLOCK: tl.constexpr):
|
||||
# This kernel computes the diagonal blocks of the attention matrix
|
||||
# Each diagonal block represents attention
|
||||
# where queries attend to keys in the same block
|
||||
off = tl.program_id(0)
|
||||
off_bh = off // NUM_BLOCK # batch-head index
|
||||
off_block = off % NUM_BLOCK # block index within the sequence
|
||||
off_cblock = tl.program_id(1) # sub-block index within a block
|
||||
|
||||
off_h = off_bh % h # head index
|
||||
|
||||
# Calculate base offsets for the current batch and head
|
||||
qk_offset = off_bh * n * d
|
||||
v_offset = off_bh * n * e
|
||||
o_offset = off_bh * n * e
|
||||
|
||||
# Calculate offsets for the current block
|
||||
block_offset = off_block * BLOCK
|
||||
qk_block_offset = block_offset * d
|
||||
v_block_offset = block_offset * e
|
||||
o_block_offset = block_offset * e
|
||||
|
||||
# Calculate offsets for the current sub-block
|
||||
cblock_offset = off_cblock * CBLOCK
|
||||
q_cblock_offset = cblock_offset * d
|
||||
o_cblock_offset = cblock_offset * e
|
||||
|
||||
# Calculate pointers to the query, key, value, and output tensors
|
||||
Q_block_ptr = (Q + qk_offset + qk_block_offset + q_cblock_offset +
|
||||
tl.arange(0, CBLOCK)[:, None] * d +
|
||||
tl.arange(0, d)[None, :])
|
||||
K_trans_block_ptr = (K + qk_offset + qk_block_offset +
|
||||
tl.arange(0, CBLOCK)[None, :] * d +
|
||||
tl.arange(0, d)[:, None])
|
||||
V_block_ptr = (V + v_offset + v_block_offset +
|
||||
tl.arange(0, CBLOCK)[:, None] * e +
|
||||
tl.arange(0, e)[None, :])
|
||||
O_block_ptr = (Out + o_offset + o_block_offset + o_cblock_offset +
|
||||
tl.arange(0, CBLOCK)[:, None] * e +
|
||||
tl.arange(0, e)[None, :])
|
||||
|
||||
# Load the decay rate for the current head
|
||||
S_block_ptr = S + off_h
|
||||
s = tl.load(S_block_ptr)
|
||||
|
||||
i = off_cblock
|
||||
q_index = tl.arange(0, CBLOCK) + i * CBLOCK
|
||||
|
||||
# Load query values
|
||||
q = tl.load(Q_block_ptr,
|
||||
mask=block_offset + q_index[:, None] < n,
|
||||
other=0.0).to(tl.float32)
|
||||
|
||||
# Initialize output accumulator
|
||||
qkv = tl.zeros([CBLOCK, e], dtype=tl.float32)
|
||||
|
||||
# Process all sub-blocks up to and
|
||||
# including the current one (causal attention)
|
||||
for j in range(i + 1):
|
||||
kv_index = tl.arange(0, CBLOCK) + j * CBLOCK
|
||||
diff = q_index[:, None] - kv_index[None, :]
|
||||
s_index = s * diff
|
||||
# Apply causal mask: only attend to positions before the current one
|
||||
s_index = tl.where(diff >= 0, -s_index, float("-inf"))
|
||||
decay = tl.exp(s_index)
|
||||
|
||||
# Load key and value
|
||||
k_trans = tl.load(
|
||||
K_trans_block_ptr,
|
||||
mask=block_offset + kv_index[None, :] < n,
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
v = tl.load(
|
||||
V_block_ptr,
|
||||
mask=block_offset + kv_index[:, None] < n,
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
|
||||
# Compute attention scores and apply decay
|
||||
qk = tl.dot(q, k_trans) * decay
|
||||
|
||||
# Compute weighted values and accumulate
|
||||
qkv += tl.dot(qk, v)
|
||||
|
||||
# Move to the next sub-block
|
||||
K_trans_block_ptr += CBLOCK * d
|
||||
V_block_ptr += CBLOCK * e
|
||||
|
||||
# Store the result
|
||||
tl.store(
|
||||
O_block_ptr,
|
||||
qkv.to(O_block_ptr.dtype.element_ty),
|
||||
mask=block_offset + q_index[:, None] < n,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kv_parallel(
|
||||
K,
|
||||
V,
|
||||
K_decay,
|
||||
KV,
|
||||
b: tl.constexpr,
|
||||
h: tl.constexpr,
|
||||
n,
|
||||
d: tl.constexpr,
|
||||
e: tl.constexpr,
|
||||
BLOCK: tl.constexpr,
|
||||
NUM_BLOCK,
|
||||
D_FBLOCK: tl.constexpr,
|
||||
E_FBLOCK: tl.constexpr,
|
||||
NUM_FBLOCK: tl.constexpr,
|
||||
CBLOCK: tl.constexpr,
|
||||
NUM_CBLOCK: tl.constexpr,
|
||||
):
|
||||
# This kernel computes the key-value outer
|
||||
# products for each block in parallel
|
||||
off_bh = tl.program_id(0) # batch-head index
|
||||
off_block = tl.program_id(1) # block index
|
||||
|
||||
off_h = off_bh % h # head index
|
||||
|
||||
block_offset = off_block * BLOCK
|
||||
|
||||
# Calculate offsets for the current block
|
||||
k_block_offset = block_offset * d
|
||||
v_block_offset = block_offset * e
|
||||
kv_block_offset = off_block * d * e
|
||||
|
||||
# Calculate base offsets for the current batch and head
|
||||
k_offset = off_bh * n * d
|
||||
v_offset = off_bh * n * e
|
||||
kv_offset = off_bh * NUM_BLOCK * d * e
|
||||
|
||||
# Calculate pointers to the key, value, and key-value tensors
|
||||
K_trans_block_ptr = (K + k_offset + k_block_offset +
|
||||
tl.arange(0, CBLOCK)[None, :] * d +
|
||||
tl.arange(0, D_FBLOCK)[:, None])
|
||||
V_block_ptr = (V + v_offset + v_block_offset +
|
||||
tl.arange(0, CBLOCK)[:, None] * e +
|
||||
tl.arange(0, E_FBLOCK)[None, :])
|
||||
KV_block_ptr = (KV + kv_offset + kv_block_offset +
|
||||
tl.arange(0, D_FBLOCK)[:, None] * e +
|
||||
tl.arange(0, E_FBLOCK)[None, :])
|
||||
|
||||
# Load the decay factors for the current head and block
|
||||
k_decay_ptr = (K_decay + off_h * BLOCK + tl.arange(0, CBLOCK)[None, :])
|
||||
|
||||
kv_index = tl.arange(0, CBLOCK)
|
||||
|
||||
# Initialize the key-value outer product accumulator
|
||||
kv = tl.zeros([D_FBLOCK, E_FBLOCK], dtype=tl.float32)
|
||||
|
||||
# Handle the last block which might be smaller than BLOCK
|
||||
if off_block == NUM_BLOCK - 1:
|
||||
split_n = n - (NUM_BLOCK - 1) * BLOCK
|
||||
else:
|
||||
split_n = BLOCK
|
||||
left_shift = tl.cdiv(split_n, CBLOCK) * CBLOCK - split_n
|
||||
num_blocks = min(tl.cdiv(split_n, CBLOCK), NUM_CBLOCK)
|
||||
k_decay_ptr += (NUM_CBLOCK - num_blocks) * CBLOCK
|
||||
|
||||
# Process all sub-blocks in the current block
|
||||
for j in range(num_blocks):
|
||||
left_bound = (1 - j) * left_shift
|
||||
# Load key and value, handling boundary conditions
|
||||
k_trans = tl.load(K_trans_block_ptr - left_shift * d,
|
||||
mask=kv_index[None, :] >= left_bound,
|
||||
other=0.0)
|
||||
v = tl.load(V_block_ptr - left_shift * e,
|
||||
mask=kv_index[:, None] >= left_bound,
|
||||
other=0.0)
|
||||
|
||||
# Load decay factor and compute weighted key-value outer product
|
||||
k_decay = tl.load(k_decay_ptr)
|
||||
kv += tl.dot(k_trans * k_decay, v)
|
||||
|
||||
# Move to the next sub-block
|
||||
K_trans_block_ptr += CBLOCK * d
|
||||
V_block_ptr += CBLOCK * e
|
||||
k_decay_ptr += CBLOCK
|
||||
|
||||
# Store the result
|
||||
tl.store(KV_block_ptr, kv.to(KV_block_ptr.dtype.element_ty))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kv_reduce(S, KV, KV_HISTORY, b: tl.constexpr, h: tl.constexpr, n,
|
||||
d: tl.constexpr, e: tl.constexpr, BLOCK: tl.constexpr,
|
||||
NUM_BLOCK, D_FBLOCK: tl.constexpr, E_FBLOCK: tl.constexpr):
|
||||
# This kernel reduces the key-value outer products
|
||||
# across blocks and updates the KV history
|
||||
off_bh = tl.program_id(0) # batch-head index
|
||||
off_h = off_bh % h # head index
|
||||
|
||||
kv_offset = off_bh * NUM_BLOCK * d * e
|
||||
|
||||
# Calculate pointer to the key-value tensor
|
||||
KV_block_ptr = (KV + kv_offset + tl.arange(0, D_FBLOCK)[:, None] * e +
|
||||
tl.arange(0, E_FBLOCK)[None, :])
|
||||
|
||||
# Load the decay rate for the current head
|
||||
s_ptrs = S + off_h
|
||||
s = tl.load(s_ptrs)
|
||||
|
||||
# Calculate pointer to the key-value history tensor
|
||||
kv_history_offset = off_bh * d * e
|
||||
KV_HISTORY_block_ptr = (KV_HISTORY + kv_history_offset +
|
||||
tl.arange(0, D_FBLOCK)[:, None] * e +
|
||||
tl.arange(0, E_FBLOCK)[None, :])
|
||||
|
||||
# Load the previous key-value history
|
||||
kv_pre = tl.load(KV_HISTORY_block_ptr).to(tl.float32)
|
||||
|
||||
# Process all blocks in reverse order to compute the prefix sum
|
||||
for i in range(NUM_BLOCK):
|
||||
block_size = min(n - i * BLOCK, BLOCK)
|
||||
# Compute decay factor for the current block
|
||||
block_decay = tl.exp(-s.to(tl.float32) * block_size)
|
||||
|
||||
# Load the current key-value outer product
|
||||
kv_cur = tl.load(KV_block_ptr).to(tl.float32)
|
||||
# Store the previous key-value history to the current block
|
||||
tl.store(KV_block_ptr, kv_pre.to(KV_block_ptr.dtype.element_ty))
|
||||
|
||||
# Update the key-value history with the current block
|
||||
kv_pre = block_decay * kv_pre + kv_cur
|
||||
KV_block_ptr += d * e
|
||||
|
||||
# Store the updated key-value history
|
||||
tl.store(KV_HISTORY_block_ptr, kv_pre)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_none_diag_kernel(
|
||||
Q,
|
||||
Out,
|
||||
S,
|
||||
KV,
|
||||
b: tl.constexpr,
|
||||
h: tl.constexpr,
|
||||
n,
|
||||
d: tl.constexpr,
|
||||
e: tl.constexpr,
|
||||
BLOCK: tl.constexpr,
|
||||
NUM_BLOCK,
|
||||
E_FBLOCK: tl.constexpr,
|
||||
CBLOCK: tl.constexpr,
|
||||
NUM_CBLOCK: tl.constexpr,
|
||||
):
|
||||
# This kernel computes the non-diagonal blocks of the attention matrix
|
||||
# Each non-diagonal block represents attention
|
||||
# where queries attend to keys in different blocks
|
||||
off_bh = tl.program_id(0) # batch-head index
|
||||
off_h = off_bh % h # head index
|
||||
|
||||
off_nc = tl.program_id(1)
|
||||
off_n = off_nc // NUM_CBLOCK # block index
|
||||
off_c = off_nc % NUM_CBLOCK # sub-block index
|
||||
off_e = tl.program_id(2) # output feature block index
|
||||
|
||||
n_offset = off_n * BLOCK
|
||||
c_offset = off_c * CBLOCK
|
||||
e_offset = off_e * E_FBLOCK
|
||||
block_offset = n_offset + c_offset
|
||||
|
||||
# Calculate offsets for the current batch, head, and block
|
||||
q_offset = off_bh * n * d + (n_offset + c_offset) * d
|
||||
o_offset = off_bh * n * e + (n_offset + c_offset) * e + e_offset
|
||||
kv_offset = off_bh * NUM_BLOCK * d * e + off_n * d * e + e_offset
|
||||
|
||||
# Calculate pointers to the query, output, and key-value tensors
|
||||
Q_block_ptr = (Q + q_offset + tl.arange(0, CBLOCK)[:, None] * d +
|
||||
tl.arange(0, d)[None, :])
|
||||
O_block_ptr = (Out + o_offset + tl.arange(0, CBLOCK)[:, None] * e +
|
||||
tl.arange(0, E_FBLOCK)[None, :])
|
||||
KV_block_ptr = (KV + kv_offset + tl.arange(0, d)[:, None] * e +
|
||||
tl.arange(0, E_FBLOCK)[None, :])
|
||||
|
||||
# Load the decay rate for the current head
|
||||
S_block_ptr = S + off_h
|
||||
s = tl.load(S_block_ptr)
|
||||
|
||||
c_array = tl.arange(0, CBLOCK)
|
||||
|
||||
# Load the key-value outer product for the current block
|
||||
kv = tl.load(KV_block_ptr).to(tl.float32)
|
||||
q_index = block_offset + tl.arange(0, CBLOCK)
|
||||
|
||||
# Load query values
|
||||
q = tl.load(Q_block_ptr, mask=q_index[:, None] < n,
|
||||
other=0.).to(tl.float32)
|
||||
|
||||
# Compute decay factors for the current sub-block
|
||||
q_decay = tl.exp(-s.to(tl.float32) * (off_c * CBLOCK + c_array[:, None]))
|
||||
|
||||
# Compute non-diagonal attention output
|
||||
qkv_none_diag = tl.dot(q, kv) * q_decay
|
||||
|
||||
# Load diagonal attention output (computed by _fwd_diag_kernel)
|
||||
qkv_diag = tl.load(O_block_ptr, mask=q_index[:, None] < n,
|
||||
other=0.).to(tl.float32)
|
||||
|
||||
# Combine diagonal and non-diagonal attention outputs
|
||||
qkv = qkv_diag + qkv_none_diag
|
||||
|
||||
# Store the result
|
||||
tl.store(O_block_ptr,
|
||||
qkv.to(O_block_ptr.dtype.element_ty),
|
||||
mask=q_index[:, None] < n)
|
||||
|
||||
|
||||
class _attention(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, q, k, v, s, kv_history):
|
||||
# Forward pass of the lightning attention algorithm
|
||||
q = q.contiguous()
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
s = s.contiguous()
|
||||
|
||||
# Check CUDA compute capability
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability[0] < 8:
|
||||
raise RuntimeError("Flash attention currently only supported",
|
||||
"for compute capability >= 80")
|
||||
|
||||
# Get input dimensions
|
||||
b, h, n, d = q.shape
|
||||
e = v.shape[-1]
|
||||
|
||||
# Initialize output tensor
|
||||
o = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device)
|
||||
|
||||
# Set block sizes
|
||||
BLOCK = 256
|
||||
NUM_BLOCK = triton.cdiv(n, BLOCK)
|
||||
|
||||
CBLOCK = 32
|
||||
NUM_CBLOCK = BLOCK // CBLOCK
|
||||
assert BLOCK % CBLOCK == 0, "BLOCK must be a multiple of CBLOCK"
|
||||
|
||||
# Compute decay factors for keys
|
||||
array = torch.arange(0, BLOCK, device=q.device) + 1
|
||||
k_decay = torch.exp(-s * (BLOCK - array.reshape(1, -1)))
|
||||
|
||||
# Step 1: Compute diagonal blocks of attention
|
||||
grid = (b * h * NUM_BLOCK, NUM_CBLOCK)
|
||||
_fwd_diag_kernel[grid](q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
s,
|
||||
b,
|
||||
h,
|
||||
n,
|
||||
d,
|
||||
e,
|
||||
BLOCK=BLOCK,
|
||||
NUM_BLOCK=NUM_BLOCK,
|
||||
CBLOCK=CBLOCK)
|
||||
|
||||
# Set feature block sizes
|
||||
NUM_FBLOCK = 1
|
||||
D_FBLOCK = d // NUM_FBLOCK
|
||||
assert d % NUM_FBLOCK == 0
|
||||
E_FBLOCK = e // NUM_FBLOCK
|
||||
assert e % NUM_FBLOCK == 0
|
||||
|
||||
CBLOCK = 64
|
||||
NUM_CBLOCK = BLOCK // CBLOCK
|
||||
assert BLOCK % CBLOCK == 0, "BLOCK must be a multiple of CBLOCK"
|
||||
|
||||
# Step 2: Compute key-value outer products for each block in parallel
|
||||
kv = torch.empty((b, h, NUM_BLOCK, d, e),
|
||||
dtype=torch.float32,
|
||||
device=q.device)
|
||||
grid = (b * h, NUM_BLOCK)
|
||||
_fwd_kv_parallel[grid](
|
||||
k,
|
||||
v,
|
||||
k_decay,
|
||||
kv,
|
||||
b,
|
||||
h,
|
||||
n,
|
||||
d,
|
||||
e,
|
||||
BLOCK=BLOCK,
|
||||
NUM_BLOCK=NUM_BLOCK,
|
||||
D_FBLOCK=D_FBLOCK,
|
||||
E_FBLOCK=E_FBLOCK,
|
||||
NUM_FBLOCK=NUM_FBLOCK,
|
||||
CBLOCK=CBLOCK,
|
||||
NUM_CBLOCK=NUM_CBLOCK,
|
||||
)
|
||||
|
||||
# Step 3: Reduce key-value outer products
|
||||
# across blocks and update KV history
|
||||
grid = (b * h, NUM_FBLOCK)
|
||||
_fwd_kv_reduce[grid](s,
|
||||
kv,
|
||||
kv_history,
|
||||
b,
|
||||
h,
|
||||
n,
|
||||
d,
|
||||
e,
|
||||
BLOCK=BLOCK,
|
||||
NUM_BLOCK=NUM_BLOCK,
|
||||
D_FBLOCK=D_FBLOCK,
|
||||
E_FBLOCK=E_FBLOCK)
|
||||
|
||||
# Step 4: Compute non-diagonal blocks of attention
|
||||
grid = (b * h, NUM_BLOCK * NUM_CBLOCK)
|
||||
_fwd_none_diag_kernel[grid](
|
||||
q,
|
||||
o,
|
||||
s,
|
||||
kv,
|
||||
b,
|
||||
h,
|
||||
n,
|
||||
d,
|
||||
e,
|
||||
BLOCK=BLOCK,
|
||||
NUM_BLOCK=NUM_BLOCK,
|
||||
E_FBLOCK=E_FBLOCK,
|
||||
CBLOCK=CBLOCK,
|
||||
NUM_CBLOCK=NUM_CBLOCK,
|
||||
)
|
||||
|
||||
# Save tensors for backward pass
|
||||
ctx.save_for_backward(q, k, v, s, kv)
|
||||
ctx.BLOCK = BLOCK
|
||||
|
||||
return o, torch.cat([kv, kv_history.unsqueeze(2)], dim=2)
|
||||
|
||||
|
||||
# Apply the lightning attention function
|
||||
lightning_attention_ = _attention.apply
|
||||
|
||||
|
||||
def lightning_attention(q, k, v, ed, block_size=256, kv_history=None):
    """
    Apply the lightning attention algorithm to compute attention efficiently.

    Args:
        q: Query tensor of shape [batch, heads, seq_len, dim]
        k: Key tensor of shape [batch, heads, seq_len, dim]
        v: Value tensor of shape [batch, heads, seq_len, dim_v]
        ed: Decay rate tensor of shape [heads]
        block_size: Size of blocks for block-sparse attention
        kv_history: Optional key-value history from previous computations

    Returns:
        output: Attention output
        kv: Updated key-value history
    """
    d = q.shape[-1]
    e = v.shape[-1]

    if ed.dim() == 1:
        ed = ed.view(1, -1, 1, 1)

    # Split the computation into chunks for better parallelism
    m = 128 if d >= 128 else 64
    assert d % m == 0, f"Dimension d ({d}) must be divisible by m ({m})"
    arr = [m * i for i in range(d // m + 1)]
    if arr[-1] != d:
        arr.append(d)
    n = len(arr)
    output = 0

    # Initialize or clone key-value history
    if kv_history is None:
        kv_history = torch.zeros((q.shape[0], q.shape[1], d, e),
                                 dtype=torch.float32,
                                 device=q.device)
    else:
        kv_history = kv_history.clone().contiguous()

    # Process each chunk and accumulate results
    for i in range(n - 1):
        s = arr[i]
        e = arr[i + 1]
        q1 = q[..., s:e]
        k1 = k[..., s:e]
        o, kv = lightning_attention_(q1, k1, v, ed, kv_history)
        output = output + o
    return output, kv
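

# Illustrative usage sketch (not part of the original diff): how the wrapper
# above might be called. Shapes, dtypes, and the _example_* name are
# hypothetical assumptions, not vLLM API.
def _example_lightning_attention_usage():
    q = torch.randn(2, 8, 1024, 128, device="cuda", dtype=torch.float16)
    k = torch.randn(2, 8, 1024, 128, device="cuda", dtype=torch.float16)
    v = torch.randn(2, 8, 1024, 128, device="cuda", dtype=torch.float16)
    ed = torch.rand(8, device="cuda")  # one decay rate per head
    # The first call initializes the KV history; the returned kv can seed the
    # next chunk of the same sequence via the kv_history argument.
    out, kv = lightning_attention(q, k, v, ed)
    return out, kv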


@triton.jit
def _linear_attn_decode_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    kv_cache_ptr,
    slope_rate,
    slot_idx,
    output_ptr,
    D: tl.constexpr,
    qkv_b_stride,
    qkv_h_stride,
    cache_b_stride,
    cache_h_stride,
    cache_d0_stride,
    cache_d1_stride,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Kernel for linear attention decoding with KV cache.

    This kernel computes attention for a single token using the KV cache.
    """
    pid_b = tl.program_id(0)  # batch index
    pid_h = tl.program_id(1)  # head index
    pid_d = tl.program_id(2)  # dimension block index

    # Load slot index for the current batch
    slot_id = tl.load(slot_idx + pid_b)

    # Skip if slot_id is -1 (padding)
    if slot_id == -1:
        return

    batch_id = pid_b
    head_id = pid_h

    # Load decay rate for the current head
    ratio = tl.load(slope_rate + pid_h)

    # Calculate offsets for dimensions
    qk_d_offsets = tl.arange(0, D)
    v_d_offsets = tl.arange(0, BLOCK_SIZE) + pid_d * BLOCK_SIZE
    cache_d_offsets = (qk_d_offsets[:, None] * cache_d0_stride +
                       v_d_offsets[None, :] * cache_d1_stride)

    # Calculate offsets for the current batch and head
    q_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride
    k_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride
    v_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride

    cache_offset = slot_id * cache_b_stride + head_id * cache_h_stride

    # Create masks for loading tensors
    qk_mask = qk_d_offsets < D
    v_mask = v_d_offsets < D

    # Load query, key, and value tensors
    q = tl.load(q_ptr + q_offset + qk_d_offsets, mask=qk_mask, other=0.0)
    k = tl.load(k_ptr + k_offset + qk_d_offsets, mask=qk_mask, other=0.0)
    v = tl.load(v_ptr + v_offset + v_d_offsets, mask=v_mask, other=0.0)

    # Compute key-value outer product
    kv_outer = k[:, None] * v[None, :]
    kv_mask = qk_mask[:, None] & v_mask[None, :]

    # Apply decay to previous KV cache
    ratio = tl.exp(-ratio)
    kv_ptr = kv_cache_ptr + cache_offset + cache_d_offsets
    kv_cache_old = tl.load(kv_ptr, mask=kv_mask, other=0.0)
    kv_outer = kv_outer + ratio * kv_cache_old

    # Compute attention output
    output = q[:, None].to(tl.float32) * kv_outer
    output = tl.sum(output, axis=0)

    # Update KV cache and store output
    tl.store(kv_ptr, kv_outer, mask=kv_mask)
    tl.store(output_ptr + q_offset + v_d_offsets, output, mask=v_mask)
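

# Reference sketch (not part of the original diff): the kernel above applies
# the linear-attention decode recurrence per head,
#     kv_cache <- exp(-slope) * kv_cache + k^T v,    out = q @ kv_cache.
# A plain-PyTorch equivalent for checking the kernel on small inputs might
# look like this (function name is hypothetical; slot indexing is omitted and
# the cache is simply indexed by batch).
def _linear_decode_reference(q, k, v, kv_cache, slope_rate):
    # q, k, v: [B, H, 1, D]; kv_cache: [B, H, D, D]; slope_rate: [H]
    decay = torch.exp(-slope_rate).view(1, -1, 1, 1)
    kv_cache.mul_(decay).add_(k.transpose(-1, -2) @ v)  # in-place cache update
    out = q @ kv_cache  # [B, H, 1, D]
    return out.squeeze(2).reshape(q.shape[0], -1)  # [B, H * D]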


def linear_decode_forward_triton(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kv_caches: torch.Tensor,
    slope_rate: torch.Tensor,
    slot_idx: torch.Tensor,
    BLOCK_SIZE: int = 32,
) -> torch.Tensor:
    """
    Perform linear attention decoding using Triton kernels.

    Args:
        q: Query tensor of shape [B, H, 1, D]
        k: Key tensor of shape [B, H, 1, D]
        v: Value tensor of shape [B, H, 1, D]
        kv_caches: Key-value cache tensor
        slope_rate: Decay rate tensor
        slot_idx: Slot indices for batches
        BLOCK_SIZE: Size of blocks for processing

    Returns:
        output: Attention output tensor
    """
    B, H, _, D = q.shape
    assert k.shape == (B, H, 1, D)
    assert v.shape == (B, H, 1, D)

    # Initialize output tensor
    output = torch.empty_like(q)

    # Set grid dimensions for the kernel
    grid = (B, H, D // BLOCK_SIZE)

    # Calculate strides for tensors
    qkv_b_stride = q.stride(0)
    qkv_h_stride = q.stride(1)

    cache_b_stride = kv_caches.stride(0)
    cache_h_stride = kv_caches.stride(1)
    cache_d0_stride = kv_caches.stride(2)
    cache_d1_stride = kv_caches.stride(3)

    # Launch the kernel
    _linear_attn_decode_kernel[grid](
        q, k, v, kv_caches, slope_rate, slot_idx, output, D,
        qkv_b_stride, qkv_h_stride,
        cache_b_stride, cache_h_stride, cache_d0_stride, cache_d1_stride,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    # Reshape output and return
    output = rearrange(output, "b h n d -> b n (h d)")
    return output.squeeze(1).contiguous()
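

# Hypothetical usage sketch (not from the diff): one decode step for a batch
# of two sequences occupying cache slots 0 and 1. Shapes, the slot layout,
# and the _example_* name are illustrative assumptions.
def _example_linear_decode_step():
    B, H, D = 2, 8, 128
    q = torch.randn(B, H, 1, D, device="cuda")
    k = torch.randn(B, H, 1, D, device="cuda")
    v = torch.randn(B, H, 1, D, device="cuda")
    # One cache slot per running sequence: [num_slots, H, D, D]
    kv_caches = torch.zeros(4, H, D, D, device="cuda")
    slope_rate = torch.rand(H, device="cuda")
    slot_idx = torch.tensor([0, 1], dtype=torch.int32, device="cuda")
    out = linear_decode_forward_triton(q, k, v, kv_caches, slope_rate, slot_idx)
    return out  # shape [B, H * D]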

@ -1,136 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple

import torch

from vllm.attention.backends.utils import PAD_SLOT_ID


class ConstantSizeCache(ABC):
    """
    Abstract base class for managing constant size caches
    like Mamba and Minimax.
    """

    def __init__(self, max_batch_size: int):
        # Maps between the request id and a dict that maps between the seq_id
        # and its index inside the cache
        self.cache_indices_mapping: Dict[str, Dict[int, int]] = {}
        self.free_cache_indices = list(range(max_batch_size))

    @property
    @abstractmethod
    def cache(self) -> Any:
        """Return the underlying cache tensor(s)"""
        pass

    @abstractmethod
    def _copy_cache(self, from_index: int, to_index: int):
        """Copy cache data from one index to another"""
        pass

    def current_run_tensors(self, **kwargs) -> Tuple:
        """
        Return the tensors for the current run's conv and ssm state.
        """
        if "seqlen_agnostic_capture_inputs" not in kwargs:
            # We get here only on Prefill/Eager mode runs
            request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
            finished_requests_ids = kwargs["finished_requests_ids"]

            self._release_finished_requests(finished_requests_ids)
            state_indices = self._prepare_current_run_cache(
                request_ids_to_seq_ids, finished_requests_ids)

            state_indices_tensor = torch.as_tensor(state_indices,
                                                   dtype=torch.int32,
                                                   device="cuda")
            cache_tensors = self.cache
        else:
            # CUDA graph capturing runs
            cache_tensors, state_indices_tensor = kwargs[
                "seqlen_agnostic_capture_inputs"]

        return (cache_tensors, state_indices_tensor)

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        """
        Copy the relevant state_indices into the CUDA graph input buffer
        """
        assert all(
            key in kwargs
            for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
        finished_requests_ids = kwargs["finished_requests_ids"]
        request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
        assert "seqlen_agnostic_capture_inputs" in input_buffers
        _, input_state_indices_buffer = input_buffers[
            "seqlen_agnostic_capture_inputs"]

        self._release_finished_requests(finished_requests_ids)
        state_indices = self._prepare_current_run_cache(
            request_ids_to_seq_ids, finished_requests_ids)
        cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len(
            state_indices)
        state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len)

        input_state_indices_buffer.copy_(
            torch.as_tensor(state_indices, dtype=torch.int32, device="cuda"))

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        """
        Provide the CUDA graph capture runs with a buffer in adjusted size.
        The buffer is used to maintain the Cache during the CUDA graph replay
        runs.
        """
        state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size,
                                               dtype=torch.int32,
                                               device="cuda")
        return (self.cache, state_indices_tensor)

    def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int,
                                      finished_requests_ids) -> int:
        """
        Assign a (req_id, seq_id) pair to a `destination_index`; if it is
        already occupied, move the occupying index to a free index.
        """
        if cur_rid in finished_requests_ids:
            # set as pad, do not allocate destination index
            return PAD_SLOT_ID
        elif cur_rid not in self.cache_indices_mapping:
            destination_index = self.free_cache_indices.pop()
            self.cache_indices_mapping[cur_rid] = {seq_id: destination_index}
            return destination_index
        elif seq_id not in (seq_ids2indices :=
                            self.cache_indices_mapping[cur_rid]):
            # Parallel sampling (n > 1): prefill is assumed to have already
            # happened, so copy the existing cache into the sibling seq_ids'
            # caches.
            index_exists = next(iter(seq_ids2indices.values()))
            # case of decoding n>1, copy prefill cache to decoding indices
            destination_index = self.free_cache_indices.pop()
            self._copy_cache(from_index=index_exists,
                             to_index=destination_index)
            self.cache_indices_mapping[cur_rid][seq_id] = destination_index
            return destination_index
        else:
            return self.cache_indices_mapping[cur_rid][seq_id]

    def _prepare_current_run_cache(
            self, request_ids_to_seq_ids: Dict[str, list[int]],
            finished_requests_ids: List[str]) -> List[int]:
        return [
            self._assign_seq_id_to_cache_index(req_id, seq_id,
                                               finished_requests_ids)
            for req_id, seq_ids in request_ids_to_seq_ids.items()
            for seq_id in seq_ids
        ]

    def _release_finished_requests(self,
                                   finished_seq_groups_req_ids: List[str]):
        for req_id in finished_seq_groups_req_ids:
            if req_id in self.cache_indices_mapping:
                for seq_id in self.cache_indices_mapping[req_id]:
                    self.free_cache_indices.append(
                        self.cache_indices_mapping[req_id][seq_id])
                self.cache_indices_mapping.pop(req_id)
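

# Illustrative sketch (not part of the diff): a minimal concrete subclass of
# ConstantSizeCache showing the two abstract members a backend has to provide.
# The class name and tensor shape are hypothetical.
class _ToyConstantSizeCache(ConstantSizeCache):

    def __init__(self, max_batch_size: int, state_dim: int):
        super().__init__(max_batch_size)
        # Layer dim first, slot (batch) dim second, so cache_t[:, idx] selects
        # one slot across all layers.
        self._state = torch.zeros(1, max_batch_size, state_dim, device="cuda")

    @property
    def cache(self):
        return (self._state, )

    def _copy_cache(self, from_index: int, to_index: int):
        for cache_t in self.cache:
            cache_t[:, to_index].copy_(cache_t[:, from_index],
                                       non_blocking=True)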

@ -1,13 +1,12 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from typing import Tuple
from typing import Dict, List, Tuple

import torch

from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
from vllm.model_executor.models.constant_size_cache import ConstantSizeCache


@dataclass
@ -22,7 +21,7 @@ class MambaCacheParams:
                                self.state_indices_tensor)


class MambaCacheManager(ConstantSizeCache):
class MambaCacheManager:

    def __init__(self, vllm_config: VllmConfig, dtype: torch.dtype,
                 num_mamba_layers: int, conv_state_shape: Tuple[int, int],
@ -33,9 +32,6 @@ class MambaCacheManager(ConstantSizeCache):
        if not vllm_config.model_config.enforce_eager:
            max_batch_size = vllm_config.pad_for_cudagraph(max_batch_size)

        # Initialize parent class
        super().__init__(max_batch_size)

        conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
                                 conv_state_shape,
                                 dtype=dtype,
@ -45,32 +41,126 @@ class MambaCacheManager(ConstantSizeCache):
                                     dtype=dtype,
                                     device="cuda")

        self._mamba_cache = (conv_state, temporal_state)
        self.mamba_cache = (conv_state, temporal_state)

    @property
    def cache(self):
        return self._mamba_cache

    def _copy_cache(self, from_index: int, to_index: int):
        for cache_t in self.cache:
            cache_t[:, to_index].copy_(cache_t[:, from_index],
                                       non_blocking=True)
        # Maps between the request id and a dict that maps between the seq_id
        # and its index inside the self.mamba_cache
        self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
        self.free_cache_indices = list(range(max_batch_size))

    def current_run_tensors(self, **kwargs) -> MambaCacheParams:
        """
        Return the tensors for the current run's conv and ssm state.
        """
        cache_tensors, state_indices_tensor = super().current_run_tensors(
            **kwargs)
        return MambaCacheParams(cache_tensors[0], cache_tensors[1],
        if "seqlen_agnostic_capture_inputs" not in kwargs:
            # We get here only on Prefill/Eager mode runs
            request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
            finished_requests_ids = kwargs["finished_requests_ids"]

            self._release_finished_requests(finished_requests_ids)
            state_indices = self._prepare_current_run_mamba_cache(
                request_ids_to_seq_ids, finished_requests_ids)

            state_indices_tensor = torch.as_tensor(state_indices,
                                                   dtype=torch.int32,
                                                   device="cuda")
            mamba_cache_tensors = self.mamba_cache

        else:
            # CUDA graph capturing runs
            (mamba_cache_tensors,
             state_indices_tensor) = kwargs["seqlen_agnostic_capture_inputs"]

        return MambaCacheParams(mamba_cache_tensors[0], mamba_cache_tensors[1],
                                state_indices_tensor)

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        """
        Copy the relevant state_indices into the CUDA graph input buffer
        """
        assert all(
            key in kwargs
            for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
        finished_requests_ids = kwargs["finished_requests_ids"]
        request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
        assert "seqlen_agnostic_capture_inputs" in input_buffers
        _, input_state_indices_buffer = input_buffers[
            "seqlen_agnostic_capture_inputs"]

        self._release_finished_requests(finished_requests_ids)
        state_indices = self._prepare_current_run_mamba_cache(
            request_ids_to_seq_ids, finished_requests_ids)
        cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len(
            state_indices)
        state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len)

        input_state_indices_buffer.copy_(
            torch.as_tensor(state_indices, dtype=torch.int32, device="cuda"))

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        """
        Provide the CUDA graph capture runs with a buffer in adjusted size.
        The buffer is used to maintain the Mamba Cache during the CUDA graph
        replay runs.
        """
        return self._mamba_cache, torch.as_tensor([PAD_SLOT_ID] * batch_size,
                                                  dtype=torch.int32,
                                                  device="cuda")
        state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size,
                                               dtype=torch.int32,
                                               device="cuda")
        return (self.mamba_cache, state_indices_tensor)

    def _copy_mamba_cache(self, from_index: int, to_index: int):
        assert len(self.mamba_cache) > 0
        for cache_t in self.mamba_cache:
            cache_t[:, to_index].copy_(cache_t[:, from_index],
                                       non_blocking=True)

    def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int,
                                      finished_requests_ids) -> int:
        """
        Assign a (req_id, seq_id) pair to a `destination_index`; if it is
        already occupied, move the occupying index to a free index.
        """
        if cur_rid in finished_requests_ids:
            # set as pad, do not allocate destination index
            return PAD_SLOT_ID
        elif cur_rid not in self.mamba_cache_indices_mapping:
            destination_index = self.free_cache_indices.pop()
            self.mamba_cache_indices_mapping[cur_rid] = {
                seq_id: destination_index
            }
            return destination_index
        elif seq_id not in (seq_ids2indices :=
                            self.mamba_cache_indices_mapping[cur_rid]):
            # Parallel sampling (n > 1): prefill is assumed to have already
            # happened, so copy the existing cache into the sibling seq_ids'
            # caches.
            index_exists = next(iter(seq_ids2indices.values()))
            # case of decoding n>1, copy prefill cache to decoding indices
            destination_index = self.free_cache_indices.pop()
            self._copy_mamba_cache(from_index=index_exists,
                                   to_index=destination_index)
            self.mamba_cache_indices_mapping[cur_rid][
                seq_id] = destination_index
            return destination_index
        else:
            # already exists
            return self.mamba_cache_indices_mapping[cur_rid][seq_id]

    def _prepare_current_run_mamba_cache(
            self, request_ids_to_seq_ids: Dict[str, list[int]],
            finished_requests_ids: List[str]) -> List[int]:
        return [
            self._assign_seq_id_to_cache_index(req_id, seq_id,
                                               finished_requests_ids)
            for req_id, seq_ids in request_ids_to_seq_ids.items()
            for seq_id in seq_ids
        ]

    def _release_finished_requests(self,
                                   finished_seq_groups_req_ids: List[str]):
        for req_id in finished_seq_groups_req_ids:
            if req_id in self.mamba_cache_indices_mapping:
                for seq_id in self.mamba_cache_indices_mapping[req_id]:
                    self.free_cache_indices.append(
                        self.mamba_cache_indices_mapping[req_id][seq_id])
                self.mamba_cache_indices_mapping.pop(req_id)

@ -1,35 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass

import torch

from vllm.model_executor.models.constant_size_cache import ConstantSizeCache


@dataclass
class MinimaxCacheParams:
    minimax_cache: torch.Tensor = torch.Tensor()
    state_indices_tensor: torch.Tensor = torch.Tensor()

    def at_layer_idx(self, layer_idx):
        return MinimaxCacheParams(self.minimax_cache[layer_idx, ...],
                                  self.state_indices_tensor)


class MinimaxCacheManager(ConstantSizeCache):

    def __init__(self, dtype, cache_shape):
        super().__init__(cache_shape[1])  # max_batch_size is cache_shape[1]
        self._minimax_cache = torch.empty(size=cache_shape,
                                          dtype=dtype,
                                          device="cuda")

    @property
    def cache(self):
        return self._minimax_cache

    def _copy_cache(self, from_index: int, to_index: int):
        assert len(self.cache) > 0
        for cache_t in self.cache:
            cache_t[:, to_index].copy_(cache_t[:, from_index],
                                       non_blocking=True)

File diff suppressed because it is too large
@ -35,7 +35,6 @@ _TEXT_GENERATION_MODELS = {
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name