diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 5d5080479e..48ec611df1 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -491,6 +491,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         "qscale_act": "input_scale",
         "qscale_weight": "weight_scale",
         "kv_fake_quantizer.qscale_act": "kv_scale",
+        "q_fake_quantizer.qscale_act": "attn.q_scale",
+        "k_fake_quantizer.qscale_act": "k_scale",
+        "v_fake_quantizer.qscale_act": "v_scale",
         "wq": "q_proj",
         "wk": "k_proj",
         "wv": "v_proj",
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index 9ccde29297..411c970b2f 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -7,7 +7,7 @@ import os
 import time
 from functools import cache, partial
 from pathlib import Path
-from typing import Any, Callable, Literal, Optional, TypeVar, Union
+from typing import Any, Callable, Optional, TypeVar, Union
 
 import huggingface_hub
 from huggingface_hub import get_safetensors_metadata, hf_hub_download
@@ -42,6 +42,7 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
                                              SkyworkR1VChatConfig, SolarConfig,
                                              Telechat2Config, UltravoxConfig)
 # yapf: enable
+from vllm.transformers_utils.configs.mistral import adapt_config_dict
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.utils import resolve_obj_by_qualname
 
@@ -394,7 +395,16 @@ def get_config(
             config = _maybe_remap_hf_config_attrs(config)
 
     elif config_format == ConfigFormat.MISTRAL:
-        config = load_params_config(model, revision, **kwargs)
+        # This function loads a params.json config which
+        # should be used when loading models in mistral format
+        config_dict = _download_mistral_config_file(model, revision)
+        if (max_position_embeddings :=
+                config_dict.get("max_position_embeddings")) is None:
+            max_position_embeddings = _maybe_retrieve_max_pos_from_hf(
+                model, revision, **kwargs)
+        config_dict["max_position_embeddings"] = max_position_embeddings
+
+        config = adapt_config_dict(config_dict)
     else:
         supported_formats = [
             fmt.value for fmt in ConfigFormat if fmt != ConfigFormat.AUTO
@@ -693,117 +703,6 @@ def maybe_register_config_serialize_by_value() -> None:
             exc_info=e)
 
 
-def load_params_config(model: Union[str, Path], revision: Optional[str],
-                       **kwargs) -> PretrainedConfig:
-    # This function loads a params.json config which
-    # should be used when loading models in mistral format
-
-    config_file_name = "params.json"
-
-    config_dict = get_hf_file_to_dict(config_file_name, model, revision)
-    if config_dict is None:
-        raise ValueError(
-            f"Failed to load mistral '{config_file_name}' config for model "
-            f"{model}. Please check if the model is a mistral-format model "
-            f"and if the config file exists.")
-    assert isinstance(config_dict, dict)
-
-    config_mapping = {
-        "dim": "hidden_size",
-        "norm_eps": "rms_norm_eps",
-        "n_kv_heads": "num_key_value_heads",
-        "n_layers": "num_hidden_layers",
-        "n_heads": "num_attention_heads",
-        "hidden_dim": "intermediate_size",
-    }
-
-    def recurse_elems(elem: Any):
-        if isinstance(elem, dict):
-            config_dict = {}
-            for key, value in elem.items():
-                key = config_mapping.get(key, key)
-                config_dict[key] = recurse_elems(value)
-
-            return config_dict
-        else:
-            return elem
-
-    config_dict["model_type"] = config_dict.get("model_type", "transformer")
-    config_dict["hidden_act"] = config_dict.get("activation", "silu")
-    config_dict["tie_word_embeddings"] = config_dict.get(
-        "tie_embeddings", False)
-
-    if config_dict.get("max_position_embeddings") is None:
-        max_position_embeddings = 128_000
-        try:
-            trust_remote_code_val = kwargs.get("trust_remote_code", False)
-            hf_config = get_config(model=model,
-                                   trust_remote_code=trust_remote_code_val,
-                                   revision=revision,
-                                   config_format=ConfigFormat.HF)
-            if hf_value := hf_config.get_text_config().max_position_embeddings:
-                max_position_embeddings = hf_value
-        except Exception as e:
-            logger.warning(
-                "The params.json file is missing 'max_position_embeddings'"
-                " and could not get a value from the HF config."
-                " Defaulting to 128000",
-                exc_info=e)
-        config_dict["max_position_embeddings"] = max_position_embeddings
-
-    if config_dict.get("quantization") is not None:
-        quantization = config_dict.get("quantization", {})
-        if quantization.get("qformat_weight") == "fp8_e4m3":
-            # This maps to the FP8 static per-tensor quantization scheme
-            quantization_config = {
-                "quant_method": "fp8",
-                "activation_scheme": "static"
-            }
-        elif quantization.get("quant_method") == "compressed-tensors":
-            # Pass through the quantization config to compressed-tensors
-            quantization_config = quantization
-        else:
-            raise ValueError(
-                f"Found unknown quantization='{quantization}' in config")
-
-        config_dict["quantization_config"] = quantization_config
-
-    config_type: Literal["text",
-                         "multimodal"] = "multimodal" if config_dict.get(
-                             "vision_encoder") is not None else "text"
-
-    if config_dict.get("moe") is not None:
-        config_dict["architectures"] = ["MixtralForCausalLM"]
-    else:
-        config_dict["architectures"] = ["MistralForCausalLM"]
-
-    if config_type == "multimodal":
-        multimodal_config = config_dict.pop("vision_encoder")
-        quantization_config = config_dict.get("quantization_config", {})
-
-        config_dict = {
-            "text_config": config_dict,
-            "vision_config": multimodal_config
-        }
-        config_dict["architectures"] = ["PixtralForConditionalGeneration"]
-        config_dict["model_type"] = "pixtral"
-        if quantization_config:
-            config_dict["quantization_config"] = quantization_config
-
-    config_dict.update(kwargs)
-
-    config_dict = recurse_elems(config_dict)
-
-    # transform to HF config format
-    if config_type == "multimodal":
-        config_dict["text_config"] = PretrainedConfig(
-            **config_dict["text_config"])
-        config_dict["vision_config"] = PretrainedConfig(
-            **config_dict["vision_config"])
-
-    return PretrainedConfig(**config_dict)
-
-
 def get_hf_image_processor_config(
     model: Union[str, Path],
     hf_token: Optional[Union[bool, str]] = None,
@@ -920,3 +819,35 @@ def try_get_tokenizer_config(
         )
     except Exception:
         return None
+
+
+def _download_mistral_config_file(model, revision) -> dict:
+    config_file_name = "params.json"
+    config_dict = get_hf_file_to_dict(config_file_name, model, revision)
+    if config_dict is None:
+        raise ValueError(
+            f"Failed to load mistral '{config_file_name}' config for model "
+            f"{model}. Please check if the model is a mistral-format model "
+            f"and if the config file exists.")
+    assert isinstance(config_dict, dict)
+    return config_dict
+
+
+def _maybe_retrieve_max_pos_from_hf(model, revision, **kwargs) -> int:
+    max_position_embeddings = 128_000
+    try:
+        trust_remote_code_val = kwargs.get("trust_remote_code", False)
+        hf_config = get_config(model=model,
+                               trust_remote_code=trust_remote_code_val,
+                               revision=revision,
+                               config_format=ConfigFormat.HF)
+        if hf_value := hf_config.get_text_config().max_position_embeddings:
+            max_position_embeddings = hf_value
+    except Exception as e:
+        logger.warning(
+            "The params.json file is missing 'max_position_embeddings'"
+            " and could not get a value from the HF config."
+            " Defaulting to 128000",
+            exc_info=e)
+
+    return max_position_embeddings
diff --git a/vllm/transformers_utils/configs/mistral.py b/vllm/transformers_utils/configs/mistral.py
new file mode 100644
index 0000000000..d2059c55a3
--- /dev/null
+++ b/vllm/transformers_utils/configs/mistral.py
@@ -0,0 +1,120 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Any
+
+from transformers import PretrainedConfig
+
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def adapt_config_dict(config_dict: dict[str, Any],
+                      **kwargs) -> PretrainedConfig:
+    config_dict.update(kwargs)
+    config_dict = _remap_general_mistral_args(config_dict)
+
+    if bool(config_dict.get("quantization")):
+        config_dict = _remap_mistral_quantization_args(config_dict)
+
+    if bool(config_dict.get("moe")):
+        config_dict["architectures"] = ["MixtralForCausalLM"]
+    else:
+        config_dict["architectures"] = ["MistralForCausalLM"]
+
+    if bool(config_dict.get("yarn")):
+        config_dict = _remap_mistral_yarn_args(config_dict)
+    if bool((config_dict.get("multimodal") or {}).get("vision_encoder_args")
+            or config_dict.get("vision_encoder")):
+        config_dict = _remap_mistral_vision_args(config_dict)
+
+    config = PretrainedConfig.from_dict(config_dict)
+
+    logger.debug("Initialized config %s", config)
+
+    return config
+
+
+def _remap_mistral_vision_args(config: dict) -> dict:
+    if config.get("multimodal"):
+        vision_config = config.pop("multimodal")
+    else:
+        vision_config = config.pop("vision_encoder")
+
+    quant_config = config.get("quantization_config")
+    config = {
+        "model_type": "pixtral",
+        "architectures": ["PixtralForConditionalGeneration"],
+        "text_config": PretrainedConfig.from_dict(config),
+        "vision_config": PretrainedConfig.from_dict(vision_config),
+    }
+    if quant_config:
+        config["quantization_config"] = quant_config
+    return config
+
+
+def _remap_mistral_yarn_args(config: dict) -> dict:
+    # Direct remaps: yarn.X -> rope_scaling.Y
+    # Source keys are from mistral.model.args.YarnArgs
+    _map = {
+        "beta": "beta_fast",
+        "alpha": "beta_slow",
+    }
+    yarn_config = config.get("yarn") or {}
+    renamed_yarn_config = {_map.get(k, k): v for k, v in yarn_config.items()}
+    config["rope_scaling"] = {
+        "rope_type": "yarn",
+        "mscale_all_dim": 1,  # We hardcoded this to 1
+        **renamed_yarn_config
+    }
+    return config
+
+
+def _remap_general_mistral_args(config: dict) -> dict:
+    # Mistral key -> HF key
+    config_mapping = {
+        "dim": "hidden_size",
+        "norm_eps": "rms_norm_eps",
+        "n_kv_heads": "num_key_value_heads",
+        "n_layers": "num_hidden_layers",
+        "n_heads": "num_attention_heads",
+        "hidden_dim": "intermediate_size",
+    }
+    # HF key -> (Mistral key, default value)
+    top_level_mapping_with_default = {
+        "model_type": ("model_type", "transformer"),
+        "hidden_act": ("activation", "silu"),
+        "tie_word_embeddings": ("tied_embeddings", False),
+        "max_seq_len": ("max_seq_len", 128_000),
+        "max_position_embeddings": ("max_position_embeddings", 128_000),
+    }
+
+    for key, new_key in config_mapping.items():
+        if key in config:
+            config[new_key] = config.pop(key)
+
+    for new_key, (key,
+                  default_value) in top_level_mapping_with_default.items():
+        config[new_key] = config.pop(key, default_value)
+
+    return config
+
+
+def _remap_mistral_quantization_args(config: dict) -> dict:
+    quantization = config.get("quantization", {})
+    if quantization.get("qformat_weight") == "fp8_e4m3":
+        # This maps to the FP8 static per-tensor quantization scheme
+        quantization_config = {
+            "quant_method": "fp8",
+            "activation_scheme": "static"
+        }
+    elif quantization.get("quant_method") == "compressed-tensors":
+        # Pass through the quantization config to compressed-tensors
+        quantization_config = quantization
+    else:
+        raise ValueError(
+            f"Found unknown quantization='{quantization}' in config")
+
+    config["quantization_config"] = quantization_config
+
+    return config