[Model] Introduce Kimi Linear to vLLM (#27809)

Signed-off-by: lizhiyuan <lizhiyuan@moonshot.cn>
Signed-off-by: Zhiyuan Li <uniartisan2017@gmail.com>
This commit is contained in:
Zhiyuan Li
2025-10-30 21:02:27 +08:00
committed by GitHub
parent 1994de99ea
commit 4e68cc9b6a
15 changed files with 1325 additions and 48 deletions

View File

@ -382,6 +382,7 @@ th {
| `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ |
| `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ |
| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ |
| `KimiLinearForCausalLM` | Kimi-Linear-48B-A3B-Base, Kimi-Linear-48B-A3B-Instruct | `moonshotai/Kimi-Linear-48B-A3B-Base`, `moonshotai/Kimi-Linear-48B-A3B-Instruct` | | ✅︎ |
| `Lfm2ForCausalLM` | LFM2 | `LiquidAI/LFM2-1.2B`, `LiquidAI/LFM2-700M`, `LiquidAI/LFM2-350M`, etc. | ✅︎ | ✅︎ |
| `Lfm2MoeForCausalLM` | LFM2MoE | `LiquidAI/LFM2-8B-A1B-preview`, etc. | ✅︎ | ✅︎ |
| `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ |
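A minimal offline-inference sketch against the `KimiLinearForCausalLM` row added above (the checkpoint name comes from the table; `trust_remote_code=True` mirrors the test-registry entry further down, and the sampling settings are assumptions for illustration):

from vllm import LLM, SamplingParams

# Hypothetical usage; the 48B-A3B checkpoint needs correspondingly large GPU memory
# (set tensor_parallel_size as appropriate for your hardware).
llm = LLM(
    model="moonshotai/Kimi-Linear-48B-A3B-Instruct",
    trust_remote_code=True,
)
outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(temperature=0.0, max_tokens=32),
)
print(outputs[0].outputs[0].text)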

View File

@ -296,6 +296,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"random": "ai21labs/Jamba-tiny-random",
},
),
"KimiLinearForCausalLM": _HfExamplesInfo(
"moonshotai/Kimi-Linear-48B-A3B-Instruct", trust_remote_code=True
),
"Lfm2ForCausalLM": _HfExamplesInfo("LiquidAI/LFM2-1.2B"),
"Lfm2MoeForCausalLM": _HfExamplesInfo(
"LiquidAI/LFM2-8B-A1B", min_transformers_version="4.58"

View File

@ -453,6 +453,7 @@ class CompilationConfig:
"vllm::linear_attention",
"vllm::plamo2_mamba_mixer",
"vllm::gdn_attention",
"vllm::kda_attention",
"vllm::sparse_attn_indexer",
]

View File

@ -1236,6 +1236,7 @@ class ModelConfig:
"deepseek_v32",
"deepseek_mtp",
"kimi_k2",
"kimi_linear",
"longcat_flash",
):
return self.hf_text_config.kv_lora_rank is not None

View File

@ -1304,7 +1304,7 @@ def kda_gate_fwd_kernel(
tl.store(y_ptr, b_y.to(y.dtype.element_ty), boundary_check=(0, 1))
def kda_gate_fwd(
def fused_kda_gate(
g: torch.Tensor,
A: torch.Tensor,
head_k_dim: int,

View File

@ -0,0 +1,426 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from einops import rearrange
from torch import nn
from vllm.attention import AttentionBackend
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed import (
divide,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader.weight_utils import sharded_weight_loader
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
from .fla.ops.kda import (
FusedRMSNormGated,
chunk_kda,
fused_kda_gate,
fused_recurrent_kda,
)
from .linear import (
ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from .mamba.abstract import MambaBase
from .mamba.mamba_utils import MambaStateDtypeCalculator, MambaStateShapeCalculator
from .mamba.ops.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
from .quantization.base_config import QuantizationConfig
logger = init_logger(__name__)
def kda_attention(
hidden_states: torch.Tensor,
output: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self._forward(hidden_states=hidden_states, output=output)
def kda_attention_fake(
hidden_states: torch.Tensor,
output: torch.Tensor,
layer_name: str,
) -> None:
return
direct_register_custom_op(
op_name="kda_attention",
op_func=kda_attention,
mutates_args=["output"],
fake_impl=kda_attention_fake,
)
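# Illustrative note (not part of the diff): once registered, the op is reachable as
# torch.ops.vllm.kda_attention and resolves the target layer by name through the
# forward context, e.g.
#   out = torch.empty_like(hidden_states)
#   torch.ops.vllm.kda_attention(hidden_states, out, "model.layers.0.self_attn")
# where "model.layers.0.self_attn" is a hypothetical prefix that must match the key
# registered in compilation_config.static_forward_context in __init__ below.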
class KimiDeltaAttention(nn.Module, MambaBase):
@property
def mamba_type(self) -> str:
return "linear_attention"
def get_attn_backend(self) -> type["AttentionBackend"]:
from vllm.v1.attention.backends.gdn_attn import GDNAttentionBackend
return GDNAttentionBackend
def get_state_dtype(
self,
) -> tuple[torch.dtype, torch.dtype, torch.dtype, torch.dtype]:
if self.model_config is None or self.cache_config is None:
raise ValueError("model_config and cache_config must be set")
return MambaStateDtypeCalculator.kda_state_dtype(
self.model_config.dtype, self.cache_config.mamba_cache_dtype
)
def get_state_shape(
self,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
return MambaStateShapeCalculator.kda_state_shape(
self.tp_size, self.num_heads, self.head_dim, conv_kernel_size=self.conv_size
)
def __init__(
self,
layer_idx: int,
hidden_size: int,
quant_config: QuantizationConfig | None = None,
cache_config: CacheConfig | None = None,
model_config: ModelConfig | None = None,
rms_norm_eps: float = 1e-5,
prefix: str = "",
**kwargs,
) -> None:
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.hidden_size = hidden_size
self.model_config = model_config
self.cache_config = cache_config
if model_config is None:
raise ValueError("model_config must be provided")
kda_config = model_config.linear_attn_config
self.head_dim = kda_config["head_dim"]
self.num_heads = kda_config["num_heads"]
self.layer_idx = layer_idx
self.prefix = prefix
assert self.num_heads % self.tp_size == 0
self.local_num_heads = divide(self.num_heads, self.tp_size)
projection_size = self.head_dim * self.num_heads
self.conv_size = kda_config["short_conv_kernel_size"]
self.q_proj = ColumnParallelLinear(
self.hidden_size,
projection_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_proj",
)
self.k_proj = ColumnParallelLinear(
self.hidden_size,
projection_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.k_proj",
)
self.v_proj = ColumnParallelLinear(
self.hidden_size,
projection_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.v_proj",
)
self.f_a_proj = ReplicatedLinear(
self.hidden_size,
self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.f_a_proj",
)
self.f_b_proj = ColumnParallelLinear(
self.head_dim,
projection_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.f_b_proj",
)
self.dt_bias = nn.Parameter(
torch.empty(divide(projection_size, self.tp_size), dtype=torch.float32)
)
set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})
self.b_proj = ColumnParallelLinear(
self.hidden_size,
self.num_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.b_proj",
)
self.q_conv1d = ColumnParallelLinear(
input_size=self.conv_size,
output_size=projection_size,
bias=False,
params_dtype=torch.float32,
prefix=f"{prefix}.q_conv1d",
)
self.k_conv1d = ColumnParallelLinear(
input_size=self.conv_size,
output_size=projection_size,
bias=False,
params_dtype=torch.float32,
prefix=f"{prefix}.k_conv1d",
)
self.v_conv1d = ColumnParallelLinear(
input_size=self.conv_size,
output_size=projection_size,
bias=False,
params_dtype=torch.float32,
prefix=f"{prefix}.v_conv1d",
)
# Unsqueeze to fit the conv1d weight shape into the linear weight shape.
# This can't be done in a custom `weight_loader`, since `ColumnParallelLinear`
# already defines one and `set_weight_attrs` doesn't allow overriding it.
self.q_conv1d.weight.data = self.q_conv1d.weight.data.unsqueeze(1)
self.k_conv1d.weight.data = self.k_conv1d.weight.data.unsqueeze(1)
self.v_conv1d.weight.data = self.v_conv1d.weight.data.unsqueeze(1)
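# Illustrative shapes (assuming conv_size = 4 and a per-rank projection size P):
# ColumnParallelLinear allocates the weight as (P, 4); after unsqueeze(1) it becomes
# (P, 1, 4), matching the depthwise conv1d weight layout in the checkpoints, and is
# viewed back to (P, 4) below before being handed to the causal_conv1d kernels.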
self.A_log = nn.Parameter(
torch.empty(1, 1, self.local_num_heads, 1, dtype=torch.float32)
)
set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(2)})
self.g_a_proj = ReplicatedLinear(
self.hidden_size,
self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.g_a_proj",
)
self.g_b_proj = ColumnParallelLinear(
self.head_dim,
projection_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.g_b_proj",
)
self.o_norm = FusedRMSNormGated(
self.head_dim, eps=rms_norm_eps, activation="sigmoid"
)
self.o_proj = RowParallelLinear(
projection_size,
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
def forward(
self,
hidden_states: torch.Tensor,
positions: torch.Tensor,
output: torch.Tensor,
) -> None:
return torch.ops.vllm.kda_attention(
hidden_states,
output,
self.prefix,
)
def _forward(
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
) -> None:
forward_context = get_forward_context()
attn_metadata: AttentionMetadata = forward_context.attn_metadata
if attn_metadata is None:
# V1 profile run
# Mimic the memory allocation in the real run
q = torch.empty_like(hidden_states)
k = torch.empty_like(hidden_states)
v = torch.empty_like(hidden_states)
g = hidden_states.new_empty(
hidden_states.size(0),
self.local_num_heads,
self.head_dim,
dtype=torch.float32,
)
beta = torch.empty(
hidden_states.size(0), self.local_num_heads, dtype=torch.float32
)
core_attn_out = torch.empty_like(hidden_states)
return
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, GDNAttentionMetadata)
has_initial_state = attn_metadata.has_initial_state
non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
constant_caches = self.kv_cache[forward_context.virtual_engine]
(conv_state_q, conv_state_k, conv_state_v, recurrent_state) = constant_caches
# The cache stores each conv state with trailing dims (conv_kernel_size - 1, dim);
# transpose to the (dim, conv_kernel_size - 1) view expected by the causal_conv1d kernels.
conv_state_q = conv_state_q.transpose(-1, -2)
conv_state_k = conv_state_k.transpose(-1, -2)
conv_state_v = conv_state_v.transpose(-1, -2)
q_proj_states = self.q_proj(hidden_states)[0]
k_proj_states = self.k_proj(hidden_states)[0]
v_proj_states = self.v_proj(hidden_states)[0]
q_conv_weights = self.q_conv1d.weight.view(
self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2)
)
k_conv_weights = self.k_conv1d.weight.view(
self.k_conv1d.weight.size(0), self.k_conv1d.weight.size(2)
)
v_conv_weights = self.v_conv1d.weight.view(
self.v_conv1d.weight.size(0), self.v_conv1d.weight.size(2)
)
if attn_metadata.num_prefills > 0:
q_proj_states = q_proj_states.transpose(0, 1)
k_proj_states = k_proj_states.transpose(0, 1)
v_proj_states = v_proj_states.transpose(0, 1)
q = causal_conv1d_fn(
q_proj_states,
q_conv_weights,
self.q_conv1d.bias,
activation="silu",
conv_states=conv_state_q,
has_initial_state=has_initial_state,
cache_indices=non_spec_state_indices_tensor,
query_start_loc=non_spec_query_start_loc,
metadata=attn_metadata,
).transpose(0, 1)
k = causal_conv1d_fn(
k_proj_states,
k_conv_weights,
self.k_conv1d.bias,
activation="silu",
conv_states=conv_state_k,
has_initial_state=has_initial_state,
cache_indices=non_spec_state_indices_tensor,
query_start_loc=non_spec_query_start_loc,
metadata=attn_metadata,
).transpose(0, 1)
v = causal_conv1d_fn(
v_proj_states,
v_conv_weights,
self.v_conv1d.bias,
activation="silu",
conv_states=conv_state_v,
has_initial_state=has_initial_state,
cache_indices=non_spec_state_indices_tensor,
query_start_loc=non_spec_query_start_loc,
metadata=attn_metadata,
).transpose(0, 1)
else:
decode_conv_indices = non_spec_state_indices_tensor[
: attn_metadata.num_decodes
]
q = causal_conv1d_update(
q_proj_states,
conv_state_q,
q_conv_weights,
self.q_conv1d.bias,
activation="silu",
conv_state_indices=decode_conv_indices,
validate_data=True,
)
k = causal_conv1d_update(
k_proj_states,
conv_state_k,
k_conv_weights,
self.k_conv1d.bias,
activation="silu",
conv_state_indices=decode_conv_indices,
validate_data=True,
)
v = causal_conv1d_update(
v_proj_states,
conv_state_v,
v_conv_weights,
self.v_conv1d.bias,
activation="silu",
conv_state_indices=decode_conv_indices,
validate_data=True,
)
q, k, v = map(
lambda x: rearrange(x, "n (h d) -> 1 n h d", d=self.head_dim), (q, k, v)
)
beta = self.b_proj(hidden_states)[0].float().sigmoid()
g = self.f_b_proj(self.f_a_proj(hidden_states)[0])[0]
g = fused_kda_gate(g, self.A_log, self.head_dim, g_bias=self.dt_bias)
beta = beta.unsqueeze(0)
g = g.unsqueeze(0)
if attn_metadata.num_prefills > 0:
zero_idx = non_spec_state_indices_tensor[~has_initial_state]
recurrent_state[zero_idx] = 0
initial_state = recurrent_state[non_spec_state_indices_tensor].contiguous()
(
core_attn_out_non_spec,
last_recurrent_state,
) = chunk_kda(
q=q,
k=k,
v=v,
g=g,
beta=beta,
initial_state=initial_state,
output_final_state=True,
use_qk_l2norm_in_kernel=True,
cu_seqlens=non_spec_query_start_loc,
)
# Write the final recurrent state back into the cache
recurrent_state[non_spec_state_indices_tensor] = last_recurrent_state
else:
(
core_attn_out_non_spec,
last_recurrent_state,
) = fused_recurrent_kda(
q=q,
k=k,
v=v,
g=g,
beta=beta,
initial_state=recurrent_state,
use_qk_l2norm_in_kernel=True,
cu_seqlens=non_spec_query_start_loc,
ssm_state_indices=non_spec_state_indices_tensor,
)
g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0]
g = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim)
core_attn_out = self.o_norm(core_attn_out_non_spec, g)
core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
output[:] = self.o_proj(core_attn_out)[0]

View File

@ -80,6 +80,15 @@ class MambaStateDtypeCalculator:
state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
return (state_dtype, state_dtype)
@classmethod
def kda_state_dtype(
cls,
model_dtype: ModelDType | torch.dtype,
mamba_cache_dtype: MambaDType,
):
state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
return (state_dtype, state_dtype, state_dtype, torch.float32)
class MambaStateShapeCalculator:
@classmethod
@ -182,3 +191,35 @@ class MambaStateShapeCalculator:
head_v_dim,
)
return conv_state_shape, temporal_state_shape
@classmethod
def kda_state_shape(
cls,
tp_world_size: int,
num_heads: int,
head_dim: int,
num_k_heads: int | None = None,
head_k_dim: int | None = None,
conv_kernel_size: int = 4,
num_spec: int = 0,
) -> tuple[tuple[int, int], tuple[int, int], tuple[int, int], tuple[int, int, int]]:
if num_k_heads is None:
num_k_heads = num_heads
if head_k_dim is None:
head_k_dim = head_dim
proj_size = num_heads * head_dim
proj_k_size = num_k_heads * head_k_dim
conv_state_shape = (divide(proj_size, tp_world_size), conv_kernel_size - 1)
conv_state_k_shape = (divide(proj_k_size, tp_world_size), conv_kernel_size - 1)
recurrent_state_shape = (divide(num_heads, tp_world_size), head_dim, head_dim)
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
conv_state_k_shape = conv_state_k_shape[1], conv_state_k_shape[0]
return (
conv_state_shape,
conv_state_k_shape,
conv_state_k_shape,
recurrent_state_shape,
)
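To make the four cache entries concrete, here is a worked example with hypothetical KDA dimensions (num_heads=32, head_dim=128, tp_world_size=1, conv_kernel_size=4; the numbers are illustrative, not taken from a released checkpoint):

import torch
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)

shapes = MambaStateShapeCalculator.kda_state_shape(
    tp_world_size=1, num_heads=32, head_dim=128, conv_kernel_size=4
)
# Expected: ((3, 4096), (3, 4096), (3, 4096), (32, 128, 128)) -- three conv states of
# (conv_kernel_size - 1, num_heads * head_dim / tp) for q/k/v, plus one recurrent
# state of (num_heads / tp, head_dim, head_dim).
dtypes = MambaStateDtypeCalculator.kda_state_dtype(torch.bfloat16, "auto")
# Expected: (torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.float32) -- the
# recurrent state is always kept in float32.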

View File

@ -147,9 +147,10 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
# Add head dim of 1 to k_pe
k_pe = k_pe.unsqueeze(1)
q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb(
positions, q[..., self.qk_nope_head_dim :], k_pe
)
if self.rotary_emb is not None:
q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb(
positions, q[..., self.qk_nope_head_dim :], k_pe
)
if self.indexer and self.is_sparse:
_topk_indices = self.indexer(hidden_states, q_c, positions, self.rotary_emb)

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from copy import deepcopy
from math import lcm
from typing import TYPE_CHECKING
import vllm.envs as envs
@ -8,7 +9,7 @@ from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry
from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec
if TYPE_CHECKING:
from vllm.config import VllmConfig
@ -347,12 +348,28 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
# get attention page size (for 1 token)
attn_page_size_1_token = FullAttentionSpec(
block_size=1,
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
dtype=kv_cache_dtype,
).page_size_bytes
# Attention backend constraints:
# - FlashAttention (FA) requires the block size to be a multiple of 16
# - MLA (Multi-head Latent Attention) requires larger alignment:
# * CUTLASS_MLA backend: kernel_block_size must be a multiple of 128
# * other MLA backends: kernel_block_size must be a multiple of 64
if model_config.use_mla:
use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA"
kernel_block_alignment_size = 128 if use_cutlass_mla else 64
attn_page_size_1_token = MLAAttentionSpec(
block_size=1,
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
dtype=kv_cache_dtype,
).page_size_bytes
else:
kernel_block_alignment_size = 16
attn_page_size_1_token = FullAttentionSpec(
block_size=1,
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
dtype=kv_cache_dtype,
).page_size_bytes
model_cls, _ = ModelRegistry.resolve_model_cls(
model_config.architecture,
@ -372,17 +389,6 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
if mamba_page_size == 0:
return
# Attention backend constraints:
# - FlashAttention (FA) requires block size to be multiple of 16
# - MLA (Multi-head Latent Attention) requires larger alignment:
# * CUTLASS_MLA backend: 128-byte alignment
# * Other MLA backends: 64-byte alignment
if model_config.use_mla:
use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA"
kernel_block_alignment_size = 128 if use_cutlass_mla else 64
else:
kernel_block_alignment_size = 16
if cache_config.enable_prefix_caching:
# With prefix caching, select attention block size to
# optimize for mamba kernel performance
@ -400,15 +406,8 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
# easily by changing the way we layout chunks in the
# mamba2 kernels.
from math import gcd
def lcm(a, b):
return a * b // gcd(a, b)
base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()
base_chunk_size = model_config.get_mamba_chunk_size()
attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
cache_config.mamba_block_size = attn_block_size
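A worked example of the resulting attention block size (all numbers hypothetical): with a mamba chunk size of 256, a non-MLA backend (16-token alignment), a mamba page of 3 MiB and an attention page of 1 KiB per token:

from math import lcm

def cdiv(a: int, b: int) -> int:  # stand-in for vllm.utils.math_utils.cdiv
    return -(-a // b)

base_chunk_size = 256
kernel_block_alignment_size = 16
attn_tokens_per_mamba_state = cdiv(3 * 1024 * 1024, 1024)       # 3072 tokens
chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)  # 256
attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
# attn_block_size == 3072, i.e. rounded up to a whole number of chunks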

View File

@ -0,0 +1,663 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from typing import Any
import torch
from torch import nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig
from vllm.distributed import (
get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.kda import KimiDeltaAttention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.kimi_linear import KimiLinearConfig
from .interfaces import HasInnerState, IsHybrid, MixtureOfExperts, SupportsPP
from .utils import (
PPMissingLayer,
is_pp_missing_parameter,
make_layers,
maybe_prefix,
)
logger = init_logger(__name__)
class KimiMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: QuantizationConfig | None = None,
reduce_results: bool = True,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
raise ValueError(
f"Unsupported activation: {hidden_act}. Only silu is supported for now."
)
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class KimiMoE(nn.Module):
def __init__(
self,
config: KimiLinearConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
layer_idx: int = 0,
):
super().__init__()
hidden_size = config.hidden_size
intermediate_size = config.intermediate_size
moe_intermediate_size = config.moe_intermediate_size
num_experts = config.num_experts
moe_renormalize = config.moe_renormalize
self.tp_size = get_tensor_model_parallel_world_size()
self.routed_scaling_factor = config.routed_scaling_factor
self.num_shared_experts = config.num_shared_experts
self.layer_idx = layer_idx
if config.hidden_act != "silu":
raise ValueError(
f"Unsupported activation: {config.hidden_act}. "
"Only silu is supported for now."
)
# Gate always runs at half / full precision for now.
self.gate = ReplicatedLinear(
hidden_size,
num_experts,
bias=False,
quant_config=None,
prefix=f"{prefix}.gate",
)
self.gate.e_score_correction_bias = nn.Parameter(torch.empty(num_experts))
self.experts = FusedMoE(
num_experts=num_experts,
top_k=config.num_experts_per_token,
hidden_size=hidden_size,
intermediate_size=moe_intermediate_size,
reduce_results=False,
renormalize=moe_renormalize,
quant_config=quant_config,
use_grouped_topk=config.use_grouped_topk,
num_expert_group=config.num_expert_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.moe_router_activation_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
)
if self.num_shared_experts is not None:
intermediate_size = moe_intermediate_size * self.num_shared_experts
self.shared_experts = KimiMLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_size)
if self.num_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = (
self.experts(hidden_states=hidden_states, router_logits=router_logits)
* self.routed_scaling_factor
)
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_size)
class KimiMLAAttention(nn.Module):
"""
Main reference: the DeepseekV2 implementation in vLLM.
"""
def __init__(
self,
config: KimiLinearConfig,
hidden_size: int,
num_heads: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: int | None,
kv_lora_rank: int,
rope_theta: float = 10000,
use_nope: bool = False,
rope_scaling: dict[str, Any] | None = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
**kwargs,
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.num_heads = num_heads
tp_size = get_tensor_model_parallel_world_size()
self.num_local_heads = num_heads // tp_size
self.scaling = self.qk_head_dim**-0.5
self.rope_theta = rope_theta
self.use_nope = use_nope
assert self.use_nope is True
assert self.q_lora_rank is None
assert rope_scaling is None
assert num_heads % tp_size == 0
self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_a_proj_with_mqa",
)
self.q_proj = ColumnParallelLinear(
self.hidden_size,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_proj",
)
self.kv_a_layernorm = RMSNorm(
self.kv_lora_rank,
eps=config.rms_norm_eps,
)
self.kv_b_proj = ColumnParallelLinear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_b_proj",
)
self.o_proj = RowParallelLinear(
self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
mla_modules = MLAModules(
kv_a_layernorm=self.kv_a_layernorm,
kv_b_proj=self.kv_b_proj,
rotary_emb=None,
o_proj=self.o_proj,
fused_qkv_a_proj=None,
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
q_a_layernorm=None,
q_b_proj=None,
q_proj=self.q_proj,
indexer=None,
is_sparse=False,
topk_indices_buffer=None,
)
self.mla_attn = MultiHeadLatentAttentionWrapper(
self.hidden_size,
self.num_local_heads,
self.scaling,
self.qk_nope_head_dim,
self.qk_rope_head_dim,
self.v_head_dim,
self.q_lora_rank,
self.kv_lora_rank,
mla_modules,
cache_config,
quant_config,
prefix,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
output: torch.Tensor,
) -> None:
output[:] = self.mla_attn(positions, hidden_states)
class KimiDecoderLayer(nn.Module):
def __init__(
self,
config: KimiLinearConfig,
layer_idx: int,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
parallel_config: ParallelConfig | None = None,
model_config: ModelConfig | None = None,
prefix: str = "",
**kwargs,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.is_moe = config.is_moe
if config.is_kda_layer(layer_idx):
self.self_attn = KimiDeltaAttention(
layer_idx=layer_idx,
hidden_size=config.hidden_size,
quant_config=quant_config,
cache_config=cache_config,
model_config=config,
prefix=f"{prefix}.self_attn",
)
else:
self.self_attn = KimiMLAAttention(
layer_idx=layer_idx,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
quant_config=quant_config,
cache_config=cache_config,
model_config=model_config,
prefix=f"{prefix}.self_attn",
config=config,
qk_nope_head_dim=config.qk_nope_head_dim,
qk_rope_head_dim=config.qk_rope_head_dim,
v_head_dim=config.v_head_dim,
q_lora_rank=config.q_lora_rank,
kv_lora_rank=config.kv_lora_rank,
use_nope=config.mla_use_nope,
)
if (
self.is_moe
and config.num_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0
):
self.block_sparse_moe = KimiMoE(
config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.mlp = self.block_sparse_moe
else:
self.mlp = KimiMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor | None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
attn_output = torch.empty_like(hidden_states)
self.self_attn(
hidden_states=hidden_states,
positions=positions,
output=attn_output,
)
hidden_states = attn_output
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@support_torch_compile
class KimiLinearModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_text_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
if get_pp_group().is_first_rank:
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
prefix=f"{prefix}.embed_tokens",
)
else:
self.embed_tokens = PPMissingLayer()
extra_kwargs = {}
def get_layer(prefix: str):
layer_idx = int(prefix.rsplit(".", 1)[1])
return KimiDecoderLayer(
config,
layer_idx,
cache_config,
quant_config,
parallel_config,
model_config,
prefix,
**extra_kwargs,
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
get_layer,
prefix=f"{prefix}.layers",
)
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
world_size = get_tensor_model_parallel_world_size()
assert config.num_attention_heads % world_size == 0, (
"num_attention_heads must be divisible by world_size"
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer : self.end_layer]:
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states,
residual=residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors(
{"hidden_states": hidden_states, "residual": residual}
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class KimiLinearForCausalLM(
nn.Module, HasInnerState, SupportsPP, MixtureOfExperts, IsHybrid
):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.model_config = vllm_config.model_config
self.vllm_config = vllm_config
self.config = self.model_config.hf_config
quant_config = vllm_config.quant_config
self.quant_config = quant_config
self.model = KimiLinearModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
if get_pp_group().is_last_rank:
self.lm_head = ParallelLMHead(
self.config.vocab_size,
self.config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
else:
self.lm_head = PPMissingLayer()
logit_scale = getattr(self.config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(
self.config.vocab_size, scale=logit_scale
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor | IntermediateTensors:
hidden_states = self.model(
input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
)
return hidden_states
@classmethod
def get_mamba_state_dtype_from_config(
cls,
vllm_config: "VllmConfig",
) -> tuple[torch.dtype, torch.dtype, torch.dtype, torch.dtype]:
return MambaStateDtypeCalculator.kda_state_dtype(
vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype
)
@classmethod
def get_mamba_state_shape_from_config(
cls, vllm_config: "VllmConfig"
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
parallel_config = vllm_config.parallel_config
hf_config = vllm_config.model_config.hf_config
tp_size = parallel_config.tensor_parallel_size
num_spec = (
vllm_config.speculative_config.num_speculative_tokens
if vllm_config.speculative_config
else 0
)
return MambaStateShapeCalculator.kda_state_shape(
tp_size,
hf_config.linear_attn_config["num_heads"],
hf_config.linear_attn_config["head_dim"],
conv_kernel_size=hf_config.linear_attn_config["short_conv_kernel_size"],
num_spec=num_spec,
)
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
return self.logits_processor(self.lm_head, hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
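# Illustrative (hypothetical weight name): "model.layers.0.mlp.gate_proj.weight"
# is loaded into the merged parameter "model.layers.0.mlp.gate_up_proj.weight" as
# shard 0, and the matching ".up_proj" weight lands in the same parameter as shard 1.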
if self.config.is_moe:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="w1",
ckpt_down_proj_name="w2",
ckpt_up_proj_name="w3",
num_experts=self.config.num_experts,
)
else:
expert_params_mapping = []
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for args in weights:
name, loaded_weight = args[:2]
kwargs = args[2] if len(args) > 2 else {}
if "rotary_emb.inv_freq" in name:
continue
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is not None:
continue # skip spec decode layers for main model
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
# The checkpoint contains mlp.experts[0].gate_proj. Since the experts are
# handled below via expert_params_mapping, we must skip them here BEFORE
# renaming; otherwise the name would become mlp.experts[0].gate_up_proj and
# then be matched again by expert_params_mapping (yielding a bogus
# mlp.experts[0].gate_gate_up_proj), which would break loading.
if ("mlp.experts." in name) and name not in params_dict:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for idx, (param_name, weight_name, expert_id, shard_id) in enumerate(
expert_params_mapping
):
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name,
expert_id=expert_id,
shard_id=shard_id,
)
break
else:
# Skip loading extra bias for GPTQ models.
if (
name.endswith(".bias")
and name not in params_dict
and not self.config.is_linear_attn
): # noqa: E501
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight, **kwargs)
loaded_params.add(name)
def get_spec_layer_idx_from_weight_name(
config: KimiLinearConfig, weight_name: str
) -> int | None:
if hasattr(config, "num_nextn_predict_layers") and (
config.num_nextn_predict_layers > 0
):
layer_idx = config.num_hidden_layers
for i in range(config.num_nextn_predict_layers):
if weight_name.startswith(f"model.layers.{layer_idx + i}."):
return layer_idx + i
return None

View File

@ -118,6 +118,7 @@ _TEXT_GENERATION_MODELS = {
"InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
"JambaForCausalLM": ("jamba", "JambaForCausalLM"),
"KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"), # noqa: E501
"Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
"Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),

View File

@ -79,6 +79,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
deepseek_v3="DeepseekV3Config",
deepseek_v32="DeepseekV3Config",
flex_olmo="FlexOlmoConfig",
kimi_linear="KimiLinearConfig",
kimi_vl="KimiVLConfig",
Llama_Nemotron_Nano_VL="Nemotron_Nano_VL_Config",
RefinedWeb="RWConfig", # For tiiuae/falcon-40b(-instruct)

View File

@ -19,6 +19,7 @@ from vllm.transformers_utils.configs.eagle import EAGLEConfig
from vllm.transformers_utils.configs.falcon import RWConfig
from vllm.transformers_utils.configs.flex_olmo import FlexOlmoConfig
from vllm.transformers_utils.configs.jais import JAISConfig
from vllm.transformers_utils.configs.kimi_linear import KimiLinearConfig
from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig
from vllm.transformers_utils.configs.lfm2_moe import Lfm2MoeConfig
from vllm.transformers_utils.configs.medusa import MedusaConfig
@ -54,6 +55,7 @@ __all__ = [
"MiDashengLMConfig",
"MLPSpeculatorConfig",
"MoonViTConfig",
"KimiLinearConfig",
"KimiVLConfig",
"NemotronConfig",
"NemotronHConfig",

View File

@ -0,0 +1,144 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from transformers.configuration_utils import PretrainedConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
class KimiLinearConfig(PretrainedConfig):
model_type = "kimi_linear"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
model_type="kimi_linear",
vocab_size=163840,
hidden_size=4096,
head_dim=None,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
rope_theta=10000.0,
rope_scaling=None,
tie_word_embeddings=False,
moe_intermediate_size: int | None = None,
moe_renormalize: bool = True,
moe_router_activation_func: str = "sigmoid",
num_experts: int | None = None,
num_experts_per_token: int | None = None,
num_shared_experts: int = 0,
routed_scaling_factor: float = 1.0,
first_k_dense_replace: int = 0,
moe_layer_freq: int = 1,
use_grouped_topk: bool = True,
num_expert_group: int = 1,
topk_group: int = 1,
q_lora_rank: int | None = None,
kv_lora_rank: int | None = None,
qk_nope_head_dim: int | None = None,
qk_rope_head_dim: int | None = None,
v_head_dim: int | None = None,
mla_use_nope: bool | None = False,
num_nextn_predict_layers: int = 0,
linear_attn_config: dict | None = None,
**kwargs,
):
self.model_type = model_type
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.head_dim = (
head_dim if head_dim is not None else hidden_size // num_attention_heads
)
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.mla_use_nope = mla_use_nope
# moe config
self.num_experts = num_experts
self.num_experts_per_token = num_experts_per_token
self.moe_renormalize = moe_renormalize
self.num_shared_experts = num_shared_experts
self.routed_scaling_factor = routed_scaling_factor
self.moe_router_activation_func = moe_router_activation_func
assert self.moe_router_activation_func in ("softmax", "sigmoid")
self.moe_intermediate_size = moe_intermediate_size
self.first_k_dense_replace = first_k_dense_replace
self.moe_layer_freq = moe_layer_freq
self.use_grouped_topk = use_grouped_topk
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.num_nextn_predict_layers = num_nextn_predict_layers
if linear_attn_config is not None:
assert linear_attn_config["kda_layers"] is not None
assert linear_attn_config["full_attn_layers"] is not None
self.linear_attn_config = linear_attn_config
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
@property
def is_mla(self):
return (
self.q_lora_rank is not None
or self.kv_lora_rank is not None
or self.qk_nope_head_dim is not None
or self.qk_rope_head_dim is not None
or self.v_head_dim is not None
or self.mla_use_nope is True
)
@property
def is_moe(self):
return self.num_experts is not None
@property
def is_linear_attn(self) -> bool:
return not (
self.linear_attn_config is None
or (
isinstance(self.linear_attn_config, dict)
and self.linear_attn_config["kda_layers"] is not None
and len(self.linear_attn_config["kda_layers"]) == 0
)
)
def is_kda_layer(self, layer_idx: int):
return (
self.linear_attn_config is not None
and (layer_idx + 1) in self.linear_attn_config["kda_layers"]
)
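A small sketch of how the per-layer dispatch in KimiDecoderLayer consumes this config (the layer indices and head sizes below are made up for illustration; real checkpoints ship their own linear_attn_config):

from vllm.transformers_utils.configs.kimi_linear import KimiLinearConfig

cfg = KimiLinearConfig(
    num_hidden_layers=4,
    linear_attn_config={
        "kda_layers": [1, 2, 3],   # 1-based indices of KDA (linear attention) layers
        "full_attn_layers": [4],   # the remaining layer uses MLA
        "num_heads": 32,
        "head_dim": 128,
        "short_conv_kernel_size": 4,
    },
)
assert [cfg.is_kda_layer(i) for i in range(4)] == [True, True, True, False]
assert cfg.is_linear_attn and not cfg.is_moe and not cfg.is_mla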

View File

@ -8,6 +8,7 @@ from collections import defaultdict
from collections.abc import Iterator
from contextlib import contextmanager
from copy import deepcopy
from functools import reduce
from itertools import product
from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast
@ -4134,26 +4135,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def calculate_reorder_batch_threshold(self) -> None:
"""
Check that if any backends reorder batches; that the reordering
is compatible (e.g., decode threshold is the same)
Choose the minimum reorder batch threshold from all attention groups.
Backends should be able to support a lower threshold than the one they request;
the only cost is a possible performance penalty, since such a backend will then
treat some decodes as prefills.
"""
for group in self._attn_group_iterator():
attn_metadata_builder_i = group.get_metadata_builder()
min_none_high = lambda a, b: a if b is None else b if a is None else min(a, b)
# check that if any backends reorder batches; that the reordering
# is compatible (e.g., decode threshold is the same)
reorder_batch_threshold_i = attn_metadata_builder_i.reorder_batch_threshold
if reorder_batch_threshold_i is not None:
if self.reorder_batch_threshold is not None:
if reorder_batch_threshold_i != self.reorder_batch_threshold:
raise ValueError(
f"Attention backend reorders decodes with "
f"threshold {reorder_batch_threshold_i} but other "
f"backend uses threshold "
f"{self.reorder_batch_threshold}"
)
else:
self.reorder_batch_threshold = reorder_batch_threshold_i
reorder_batch_thresholds = [
group.get_metadata_builder().reorder_batch_threshold
for group in self._attn_group_iterator()
]
self.reorder_batch_threshold = reduce(min_none_high, reorder_batch_thresholds)
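# Illustrative behavior of min_none_high above (hypothetical thresholds), where
# None means "this backend imposes no reordering constraint":
#   reduce(min_none_high, [None, 4, 2]) -> 2
#   reduce(min_none_high, [None, None]) -> None  (no backend reorders batches)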
def _find_compatible_block_sizes(
self,