[V1][Speculative Decoding] Fix DeepSeek MTP (#20022)

Signed-off-by: cjackal <44624812+cjackal@users.noreply.github.com>
Author: cjackal
Date: 2025-06-26 00:41:02 +09:00
Committed by: GitHub
Parent: bf5181583f
Commit: 8359f4c8d8
2 changed files with 21 additions and 9 deletions
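The first file's changes make the MTP modules share a single token embedding: VocabParallelEmbedding is removed from DeepSeekMultiTokenPredictorLayer and instantiated once in DeepSeekMultiTokenPredictor, which now computes inputs_embeds before dispatching to the per-step layer. The sketch below only illustrates that ownership change and is not the vLLM implementation; MTPLayerSketch and MTPPredictorSketch are hypothetical names, and plain nn.Embedding, nn.RMSNorm (PyTorch 2.4 or newer), and nn.Linear stand in for VocabParallelEmbedding, vLLM's RMSNorm, and the full MTP block.

import torch
import torch.nn as nn


class MTPLayerSketch(nn.Module):
    """One MTP step; it no longer owns an embedding table."""

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.enorm = nn.RMSNorm(hidden_size)
        self.hnorm = nn.RMSNorm(hidden_size)
        self.eh_proj = nn.Linear(hidden_size * 2, hidden_size)

    def forward(self, inputs_embeds: torch.Tensor,
                previous_hidden_states: torch.Tensor) -> torch.Tensor:
        # inputs_embeds is computed by the parent predictor and passed in.
        return self.eh_proj(
            torch.cat([self.enorm(inputs_embeds),
                       self.hnorm(previous_hidden_states)], dim=-1))


class MTPPredictorSketch(nn.Module):
    """Owns the single shared embedding and dispatches to the step layers."""

    def __init__(self, vocab_size: int, hidden_size: int,
                 num_mtp_layers: int) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.layers = nn.ModuleList(
            [MTPLayerSketch(hidden_size) for _ in range(num_mtp_layers)])

    def forward(self, input_ids: torch.Tensor,
                previous_hidden_states: torch.Tensor,
                spec_step_idx: int = 0) -> torch.Tensor:
        # Embed once at the predictor level, then hand the embeddings down.
        inputs_embeds = self.embed_tokens(input_ids)
        layer = self.layers[spec_step_idx % len(self.layers)]
        return layer(inputs_embeds, previous_hidden_states)

The weight-loading changes later in the diff ensure the shared table is populated only from the first MTP layer's copy in the checkpoint.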


@@ -52,11 +52,6 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.eh_proj = nn.Linear(config.hidden_size * 2,
@@ -74,8 +69,6 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_index: int = 0,
) -> torch.Tensor:
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
assert inputs_embeds is not None
# masking inputs at position 0, as not needed by MTP
inputs_embeds[positions == 0] = 0
@@ -112,7 +105,10 @@ class DeepSeekMultiTokenPredictor(nn.Module):
for idx in range(self.mtp_start_layer_idx,
self.mtp_start_layer_idx + self.num_mtp_layers)
})
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.logits_processor = LogitsProcessor(config.vocab_size)
def forward(
@@ -123,6 +119,8 @@ class DeepSeekMultiTokenPredictor(nn.Module):
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
current_step_idx = (spec_step_idx % self.num_mtp_layers)
return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
input_ids,
@@ -242,6 +240,12 @@ class DeepSeekMTP(nn.Module, SupportsPP):
if name.endswith(".bias") and name not in params_dict:
continue
# According to the DeepSeek-V3 Technical Report, MTP modules
# share the embedding layer, so shared weights are only loaded
# from the first MTP layer.
if (spec_layer != self.model.mtp_start_layer_idx
and ".layers" not in name):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
@@ -253,17 +257,25 @@ class DeepSeekMTP(nn.Module, SupportsPP):
"""
Rewrite the weight name to match the format of the original model.
Add .mtp_block for modules in the transformer layer block of the spec
layer, and rename shared layer weights to be top level.
"""
spec_layer_weight_names = [
"embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
]
shared_weight_names = ["embed_tokens"]
spec_layer_weight = False
shared_weight = False
for weight_name in spec_layer_weight_names:
if weight_name in name:
spec_layer_weight = True
if weight_name in shared_weight_names:
shared_weight = True
break
if not spec_layer_weight:
# treat the remaining weights as transformer layer block weights
name = name.replace(f"model.layers.{spec_layer}.",
f"model.layers.{spec_layer}.mtp_block.")
elif shared_weight:
# treat shared weights as top-level weights
name = name.replace(f"model.layers.{spec_layer}.", "model.")
return name
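To make the renaming rule above concrete, here is a self-contained sketch that reproduces the same mapping on example checkpoint names. The helper name rewrite_spec_layer_name and the layer index 61 (assumed here as the position of DeepSeek-V3's single MTP module) are used only for illustration; the actual logic lives in _rewrite_spec_layer_name.

SPEC_LAYER_WEIGHT_NAMES = ("embed_tokens", "enorm", "hnorm", "eh_proj",
                           "shared_head")
SHARED_WEIGHT_NAMES = ("embed_tokens",)


def rewrite_spec_layer_name(spec_layer: int, name: str) -> str:
    is_spec_weight = any(w in name for w in SPEC_LAYER_WEIGHT_NAMES)
    is_shared_weight = any(w in name for w in SHARED_WEIGHT_NAMES)
    if not is_spec_weight:
        # Anything else belongs to the transformer block of the MTP layer.
        return name.replace(f"model.layers.{spec_layer}.",
                            f"model.layers.{spec_layer}.mtp_block.")
    if is_shared_weight:
        # The shared embedding is hoisted to the top level of the model.
        return name.replace(f"model.layers.{spec_layer}.", "model.")
    return name


print(rewrite_spec_layer_name(61, "model.layers.61.embed_tokens.weight"))
# -> model.embed_tokens.weight
print(rewrite_spec_layer_name(61, "model.layers.61.self_attn.q_proj.weight"))
# -> model.layers.61.mtp_block.self_attn.q_proj.weight
print(rewrite_spec_layer_name(61, "model.layers.61.enorm.weight"))
# -> model.layers.61.enorm.weight (unchanged)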


@@ -148,7 +148,7 @@ class EagleProposer:
assert self.runner is not None
# FIXME: need to consider multiple kv_cache_groups
attn_metadata = self.runner.attn_metadata_builder.build(
attn_metadata = self.runner.attn_metadata_builders[0].build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
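The second file adapts EagleProposer to a runner that now keeps one attention-metadata builder per KV-cache group; the drafter simply takes the first group's builder for now, as the FIXME notes. The snippet below is purely illustrative, with made-up FakeBuilder and FakeRunner classes rather than vLLM's actual runner API, to show the shape of the indexing change.

from dataclasses import dataclass, field


@dataclass
class FakeBuilder:
    group_id: int

    def build(self, common_prefix_len: int) -> dict:
        return {"group": self.group_id,
                "common_prefix_len": common_prefix_len}


@dataclass
class FakeRunner:
    # One builder per KV-cache group instead of a single builder attribute.
    attn_metadata_builders: list = field(
        default_factory=lambda: [FakeBuilder(0), FakeBuilder(1)])


runner = FakeRunner()
# Mirrors the fixed call site: always the first group's builder for now.
print(runner.attn_metadata_builders[0].build(common_prefix_len=0))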