[V1][Speculative Decoding] Fix DeepSeek MTP (#20022)
Signed-off-by: cjackal <44624812+cjackal@users.noreply.github.com>
@@ -52,11 +52,6 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-        )
-
         self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.eh_proj = nn.Linear(config.hidden_size * 2,
@@ -74,8 +69,6 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
         inputs_embeds: Optional[torch.Tensor] = None,
         spec_step_index: int = 0,
     ) -> torch.Tensor:
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
         assert inputs_embeds is not None
         # masking inputs at position 0, as not needed by MTP
         inputs_embeds[positions == 0] = 0
@@ -112,7 +105,10 @@ class DeepSeekMultiTokenPredictor(nn.Module):
             for idx in range(self.mtp_start_layer_idx,
                              self.mtp_start_layer_idx + self.num_mtp_layers)
         })
-
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
         self.logits_processor = LogitsProcessor(config.vocab_size)

     def forward(
@@ -123,6 +119,8 @@ class DeepSeekMultiTokenPredictor(nn.Module):
         inputs_embeds: Optional[torch.Tensor] = None,
         spec_step_idx: int = 0,
     ) -> torch.Tensor:
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
         current_step_idx = (spec_step_idx % self.num_mtp_layers)
         return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
             input_ids,
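Taken together, the hunks above move the embedding table off DeepSeekMultiTokenPredictorLayer and onto DeepSeekMultiTokenPredictor, which now embeds the input ids once and hands the result to whichever MTP layer serves the current speculative step. Below is a minimal sketch of that pattern; ToyMTPLayer and ToyMTPPredictor are illustrative stand-ins, not the vLLM classes.

import torch
import torch.nn as nn


class ToyMTPLayer(nn.Module):
    """Stand-in for one MTP layer; the real layer also takes positions, hidden states, etc."""

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        return self.proj(inputs_embeds)


class ToyMTPPredictor(nn.Module):
    """Owns one shared embedding table and dispatches each spec step to one MTP layer."""

    def __init__(self, vocab_size: int, hidden_size: int,
                 mtp_start_layer_idx: int, num_mtp_layers: int) -> None:
        super().__init__()
        self.mtp_start_layer_idx = mtp_start_layer_idx
        self.num_mtp_layers = num_mtp_layers
        # One embedding table shared by every MTP layer.
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.layers = nn.ModuleDict({
            str(idx): ToyMTPLayer(hidden_size)
            for idx in range(mtp_start_layer_idx,
                             mtp_start_layer_idx + num_mtp_layers)
        })

    def forward(self, input_ids: torch.Tensor,
                spec_step_idx: int = 0) -> torch.Tensor:
        # Embed once at the predictor level, then pick the layer for this step.
        inputs_embeds = self.embed_tokens(input_ids)
        current_step_idx = spec_step_idx % self.num_mtp_layers
        return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
            inputs_embeds)


predictor = ToyMTPPredictor(vocab_size=128, hidden_size=16,
                            mtp_start_layer_idx=61, num_mtp_layers=1)
out = predictor(torch.tensor([[1, 2, 3]]), spec_step_idx=0)
print(out.shape)  # torch.Size([1, 3, 16])

Keeping the table on the predictor avoids allocating a duplicate vocab-size embedding per MTP layer and lines up with the checkpoint, where the shared embedding is stored once.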
@@ -242,6 +240,12 @@ class DeepSeekMTP(nn.Module, SupportsPP):
                 if name.endswith(".bias") and name not in params_dict:
                     continue

+                # According to DeepSeek-V3 Technical Report, MTP modules
+                # shares embedding layer. We only load the first weights.
+                if (spec_layer != self.model.mtp_start_layer_idx
+                        and ".layers" not in name):
+                    continue
+
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
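The guard added above loads shared (non-".layers") weights, such as the hoisted embedding, only for the first MTP layer and skips them for any later spec layer, while per-layer ".layers" weights are always loaded. An illustrative sketch of the condition with made-up layer indices and weight names (DeepSeek-V3 ships a single MTP layer, so the second index below is hypothetical):

mtp_start_layer_idx = 61

def should_load(spec_layer: int, name: str) -> bool:
    # Inverse of the "continue" condition above: keep shared weights only for
    # the first spec layer, keep ".layers" weights for every spec layer.
    return spec_layer == mtp_start_layer_idx or ".layers" in name

print(should_load(61, "model.embed_tokens.weight"))  # True
print(should_load(62, "model.embed_tokens.weight"))  # False: shared, already loaded
print(should_load(62, "model.layers.62.mtp_block.mlp.gate_proj.weight"))  # True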
@@ -253,17 +257,25 @@ class DeepSeekMTP(nn.Module, SupportsPP):
         """
         Rewrite the weight name to match the format of the original model.
         Add .mtp_block for modules in transformer layer block for spec layer
+        and rename shared layer weights to be top level.
         """
         spec_layer_weight_names = [
             "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
         ]
+        shared_weight_names = ["embed_tokens"]
         spec_layer_weight = False
+        shared_weight = False
         for weight_name in spec_layer_weight_names:
             if weight_name in name:
                 spec_layer_weight = True
+                if weight_name in shared_weight_names:
+                    shared_weight = True
                 break
         if not spec_layer_weight:
             # treat rest weights as weights for transformer layer block
             name = name.replace(f"model.layers.{spec_layer}.",
                                 f"model.layers.{spec_layer}.mtp_block.")
+        elif shared_weight:
+            # treat shared weights as top level weights
+            name = name.replace(f"model.layers.{spec_layer}.", "model.")
         return name
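To make the rewriting concrete, here is a compressed standalone version of the rule applied to hypothetical checkpoint names (layer index 61 is only an example); it is a sketch for illustration, not the class method itself.

def rewrite_spec_layer_name(spec_layer: int, name: str) -> str:
    spec_layer_weight_names = [
        "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
    ]
    shared_weight_names = ["embed_tokens"]
    spec_layer_weight = any(w in name for w in spec_layer_weight_names)
    shared_weight = any(w in name for w in shared_weight_names)
    if not spec_layer_weight:
        # ordinary transformer-block weights gain the .mtp_block segment
        return name.replace(f"model.layers.{spec_layer}.",
                            f"model.layers.{spec_layer}.mtp_block.")
    if shared_weight:
        # shared weights (the embedding) are renamed to top level
        return name.replace(f"model.layers.{spec_layer}.", "model.")
    return name

assert (rewrite_spec_layer_name(61, "model.layers.61.embed_tokens.weight")
        == "model.embed_tokens.weight")
assert (rewrite_spec_layer_name(61, "model.layers.61.enorm.weight")
        == "model.layers.61.enorm.weight")
assert (rewrite_spec_layer_name(61, "model.layers.61.mlp.gate_proj.weight")
        == "model.layers.61.mtp_block.mlp.gate_proj.weight")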
@@ -148,7 +148,7 @@ class EagleProposer:
         assert self.runner is not None

         # FIXME: need to consider multiple kv_cache_groups
-        attn_metadata = self.runner.attn_metadata_builder.build(
+        attn_metadata = self.runner.attn_metadata_builders[0].build(
             common_prefix_len=0,
             common_attn_metadata=common_attn_metadata,
         )
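For the EagleProposer hunk: per the FIXME, the runner keeps one attention-metadata builder per KV-cache group, and the drafter simply uses the first group's builder for now. A toy sketch of that indexing pattern; ToyBuilder and the dict payloads are hypothetical, not the vLLM API.

from dataclasses import dataclass
from typing import Any


@dataclass
class ToyBuilder:
    group_id: int

    def build(self, common_prefix_len: int, common_attn_metadata: Any) -> dict:
        # Bundle whatever per-group metadata the attention backend needs.
        return {"group": self.group_id,
                "prefix_len": common_prefix_len,
                "common": common_attn_metadata}


# One builder per KV-cache group; only group 0 is consulted until
# multi-group support lands.
attn_metadata_builders = [ToyBuilder(0), ToyBuilder(1)]
attn_metadata = attn_metadata_builders[0].build(
    common_prefix_len=0, common_attn_metadata={"num_reqs": 4})
print(attn_metadata)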