[Bugfix] Fix EAGLE vocab embedding for multimodal target model (#19570)
Signed-off-by: qizixi <qizixi@meta.com>
This commit is contained in:
@ -329,16 +329,24 @@ class EagleProposer:
|
||||
|
||||
self.attn_layer_names = list(draft_attn_layer_names)
|
||||
|
||||
if supports_multimodal(target_model):
|
||||
# handle multimodality
|
||||
self.model.config.image_token_index = (
|
||||
target_model.config.image_token_index)
|
||||
target_language_model = target_model.get_language_model()
|
||||
else:
|
||||
target_language_model = target_model
|
||||
# share embed_tokens with the target model if needed
|
||||
if get_pp_group().world_size == 1 \
|
||||
and self.model.model.embed_tokens.weight.shape \
|
||||
== target_model.model.embed_tokens.weight.shape:
|
||||
== target_language_model.model.embed_tokens.weight.shape:
|
||||
logger.info(
|
||||
"Assuming the EAGLE head shares the same vocab embedding" \
|
||||
" with the target model."
|
||||
)
|
||||
del self.model.model.embed_tokens
|
||||
self.model.model.embed_tokens = target_model.model.embed_tokens
|
||||
self.model.model.embed_tokens = (
|
||||
target_language_model.model.embed_tokens)
|
||||
else:
|
||||
logger.info(
|
||||
"The EAGLE head's vocab embedding will be loaded separately" \
|
||||
@ -349,12 +357,9 @@ class EagleProposer:
|
||||
# some model definition do not define lm_head explicitly
|
||||
# and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
|
||||
if self.vllm_config.speculative_config.method != "eagle3" and \
|
||||
hasattr(target_model, "lm_head"):
|
||||
hasattr(target_language_model, "lm_head"):
|
||||
logger.info("Loading EAGLE LM head weights from the target model.")
|
||||
if supports_multimodal(target_model):
|
||||
self.model.lm_head = target_model.get_language_model().lm_head
|
||||
else:
|
||||
self.model.lm_head = target_model.lm_head
|
||||
self.model.lm_head = target_language_model.lm_head
|
||||
|
||||
@torch.inference_mode()
|
||||
def dummy_run(
|
||||
|
||||
Reference in New Issue
Block a user