[Bugfix] Fix EAGLE vocab embedding for multimodal target model (#19570)

Signed-off-by: qizixi <qizixi@meta.com>
Author: qizixi
Date: 2025-06-12 20:09:19 -07:00
Committed by: GitHub
Parent: e3b12667d4
Commit: c68698b326
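
Context for the fix: a multimodal target model is a wrapper that does not expose `model.embed_tokens` directly; the text-token embedding table (and `lm_head`) live on the wrapped language model, so the EAGLE head has to unwrap it via `get_language_model()` before comparing embedding shapes or sharing weights. Below is a minimal, self-contained sketch of that lookup pattern. The mock classes and the `hasattr`-based check are illustrative stand-ins for vLLM's real model classes and its `supports_multimodal()` helper, not the actual implementation.

import torch.nn as nn

# Illustrative mocks only -- real vLLM model classes are far larger.
class _InnerModel(nn.Module):
    def __init__(self, vocab_size: int = 32000, hidden_size: int = 64):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)

class _LanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = _InnerModel()
        self.lm_head = nn.Linear(64, 32000, bias=False)

class _MultimodalTarget(nn.Module):
    # A multimodal wrapper: there is no `model.embed_tokens` at this level,
    # which is why reading it off the wrapper directly fails.
    def __init__(self):
        super().__init__()
        self.language_model = _LanguageModel()

    def get_language_model(self) -> nn.Module:
        return self.language_model

def resolve_embedding_source(target_model: nn.Module) -> nn.Embedding:
    # Mirrors the patched logic: unwrap the language model first, then read
    # embed_tokens from it. `hasattr` stands in for supports_multimodal().
    if hasattr(target_model, "get_language_model"):
        target_language_model = target_model.get_language_model()
    else:
        target_language_model = target_model
    return target_language_model.model.embed_tokens

if __name__ == "__main__":
    target = _MultimodalTarget()
    print(resolve_embedding_source(target).weight.shape)  # torch.Size([32000, 64])
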


@@ -329,16 +329,24 @@ class EagleProposer:
         self.attn_layer_names = list(draft_attn_layer_names)
         if supports_multimodal(target_model):
             # handle multimodality
             self.model.config.image_token_index = (
                 target_model.config.image_token_index)
+            target_language_model = target_model.get_language_model()
+        else:
+            target_language_model = target_model
         # share embed_tokens with the target model if needed
         if get_pp_group().world_size == 1 \
             and self.model.model.embed_tokens.weight.shape \
-                == target_model.model.embed_tokens.weight.shape:
+                == target_language_model.model.embed_tokens.weight.shape:
             logger.info(
                 "Assuming the EAGLE head shares the same vocab embedding" \
                 " with the target model."
             )
             del self.model.model.embed_tokens
-            self.model.model.embed_tokens = target_model.model.embed_tokens
+            self.model.model.embed_tokens = (
+                target_language_model.model.embed_tokens)
         else:
             logger.info(
                 "The EAGLE head's vocab embedding will be loaded separately" \
@@ -349,12 +357,9 @@ class EagleProposer:
         # some model definition do not define lm_head explicitly
         # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
         if self.vllm_config.speculative_config.method != "eagle3" and \
-                hasattr(target_model, "lm_head"):
+                hasattr(target_language_model, "lm_head"):
             logger.info("Loading EAGLE LM head weights from the target model.")
-            if supports_multimodal(target_model):
-                self.model.lm_head = target_model.get_language_model().lm_head
-            else:
-                self.model.lm_head = target_model.lm_head
+            self.model.lm_head = target_language_model.lm_head

     @torch.inference_mode()
     def dummy_run(