[Model] Ultravox: Support Llama 4 and Gemma 3 backends (#17818)
Signed-off-by: Farzad Abdolhosseini <farzad@fixie.ai> Signed-off-by: Patrick Li <patrick8289@gmail.com> Co-authored-by: Patrick Li <patrick8289@gmail.com>
This commit is contained in:
committed by
GitHub
parent
7ae75fa6d0
commit
62965de5fe
@ -221,6 +221,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
"fp8": "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"}), # noqa: E501
|
||||
"LLaMAForCausalLM": _HfExamplesInfo("decapoda-research/llama-7b-hf",
|
||||
is_available_online=False),
|
||||
"Llama4ForCausalLM": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501
|
||||
is_available_online=False),
|
||||
"MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"),
|
||||
"Mamba2ForCausalLM": _HfExamplesInfo("mistralai/Mamba-Codestral-7B-v0.1"),
|
||||
"FalconMambaForCausalLM": _HfExamplesInfo("tiiuae/falcon-mamba-7b-instruct"), # noqa: E501
|
||||
|
||||
@ -89,6 +89,7 @@ _TEXT_GENERATION_MODELS = {
|
||||
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
|
||||
"JambaForCausalLM": ("jamba", "JambaForCausalLM"),
|
||||
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||
"Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"), # noqa: E501
|
||||
# For decapoda-research/llama-*
|
||||
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||
"MambaForCausalLM": ("mamba", "MambaForCausalLM"),
|
||||
|
||||
@ -39,9 +39,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||
merge_multimodal_embeddings,
|
||||
merge_multimodal_embeddings_from_map)
|
||||
|
||||
_AUDIO_PLACEHOLDER_OVERRIDE = "<|reserved_special_token_0|>"
|
||||
_AUDIO_PLACEHOLDER_TOKEN = 128002
|
||||
_AUDIO_TOKENS_PER_SECOND = 6.25
|
||||
_AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>"
|
||||
_MAX_ENCODER_BATCH_SIZE = 16
|
||||
|
||||
|
||||
@ -80,14 +78,15 @@ class UltravoxProcessingInfo(BaseProcessingInfo):
|
||||
sampling_rate: Optional[int] = None,
|
||||
**kwargs: object,
|
||||
) -> ProcessorMixin:
|
||||
config = self.ctx.model_config.hf_config
|
||||
hf_processor = self.ctx.get_hf_processor(**kwargs)
|
||||
|
||||
# NOTE: Ultravox processing definition uses '<|eot_id|>' as the
|
||||
# placeholder that will cause confusion with the actual end of turn
|
||||
# token, thus we override placeholder with a reserved special
|
||||
# token.
|
||||
# token, thus we override placeholder with a reserved token.
|
||||
hf_processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE
|
||||
hf_processor.audio_replacement_token_id = _AUDIO_PLACEHOLDER_TOKEN
|
||||
hf_processor.audio_replacement_token_id = config.audio_token_index
|
||||
|
||||
return hf_processor
|
||||
|
||||
def get_feature_extractor(
|
||||
@ -274,7 +273,7 @@ class UltravoxProjector(nn.Module):
|
||||
else:
|
||||
self.act = get_act_fn(config.projector_act)
|
||||
|
||||
dim_out = config.text_config.hidden_size
|
||||
dim_out = config.text_hidden_size
|
||||
self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False)
|
||||
|
||||
# Ultravox v0.4.1 and below use layer_norm after the second linear layer
|
||||
@ -572,9 +571,14 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None \
|
||||
and len(multimodal_embeddings) != 0:
|
||||
# The audio token index is not included in the embedding table
|
||||
# We need to remove it before embedding lookup
|
||||
safe_input_ids = input_ids.clone()
|
||||
safe_input_ids[safe_input_ids == self.config.audio_token_index] = 0
|
||||
inputs_embeds = self.language_model.get_input_embeddings(
|
||||
safe_input_ids)
|
||||
if multimodal_embeddings is not None and len(
|
||||
multimodal_embeddings) > 0:
|
||||
|
||||
# TODO(ywang96): remove this block after v0 is deprecated.
|
||||
if not envs.VLLM_USE_V1:
|
||||
@ -585,7 +589,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
|
||||
else:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, multimodal_embeddings,
|
||||
_AUDIO_PLACEHOLDER_TOKEN)
|
||||
self.config.audio_token_index)
|
||||
return inputs_embeds
|
||||
|
||||
def forward(self,
|
||||
@ -623,10 +627,14 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
|
||||
multimodal_embeddings)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.language_model.model(input_ids,
|
||||
positions,
|
||||
intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds)
|
||||
language_model = self.language_model
|
||||
if hasattr(language_model, "language_model"):
|
||||
language_model = language_model.language_model
|
||||
|
||||
hidden_states = language_model.model(input_ids,
|
||||
positions,
|
||||
intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
|
||||
@ -45,6 +45,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
|
||||
"""
|
||||
|
||||
model_type = "ultravox"
|
||||
audio_token = "<|audio|>"
|
||||
is_composition = False
|
||||
|
||||
def __init__(
|
||||
@ -80,29 +81,32 @@ class UltravoxConfig(transformers.PretrainedConfig):
|
||||
# Avoid circular import
|
||||
from vllm.transformers_utils.config import get_config
|
||||
|
||||
self.text_config = get_config(text_model_id,
|
||||
trust_remote_code=False)
|
||||
text_config_obj = get_config(text_model_id,
|
||||
trust_remote_code=False)
|
||||
else:
|
||||
text_config = text_config or {}
|
||||
self.text_config = transformers.CONFIG_MAPPING[text_config.get(
|
||||
text_config_obj = transformers.CONFIG_MAPPING[text_config.get(
|
||||
"model_type", "llama")](**text_config)
|
||||
|
||||
inner_text_config = text_config_obj.get_text_config()
|
||||
|
||||
if audio_model_id is not None:
|
||||
# Avoid circular import
|
||||
from vllm.transformers_utils.config import get_config
|
||||
|
||||
self.audio_config = get_config(audio_model_id,
|
||||
trust_remote_code=False)
|
||||
audio_config = get_config(audio_model_id, trust_remote_code=False)
|
||||
else:
|
||||
audio_config = audio_config or {}
|
||||
self.audio_config = transformers.CONFIG_MAPPING[audio_config.get(
|
||||
audio_config = transformers.CONFIG_MAPPING[audio_config.get(
|
||||
"model_type", "whisper")](**audio_config)
|
||||
|
||||
self.text_config = text_config_obj
|
||||
self.audio_config = audio_config
|
||||
self.text_model_lora_config = text_model_lora_config or {}
|
||||
self.audio_model_lora_config = audio_model_lora_config or {}
|
||||
|
||||
self.vocab_size = self.text_config.vocab_size
|
||||
|
||||
self.initializer_range = self.text_config.initializer_range
|
||||
self.vocab_size = inner_text_config.vocab_size
|
||||
self.initializer_range = inner_text_config.initializer_range
|
||||
self.text_hidden_size = inner_text_config.hidden_size
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
Reference in New Issue
Block a user