[Config] Refactor mistral configs (#20570)

Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Patrick von Platen
2025-07-08 00:25:10 +02:00
committed by GitHub
parent 042d131f39
commit 14601f5fba
3 changed files with 167 additions and 113 deletions

View File

@ -491,6 +491,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"qscale_act": "input_scale",
"qscale_weight": "weight_scale",
"kv_fake_quantizer.qscale_act": "kv_scale",
"q_fake_quantizer.qscale_act": "attn.q_scale",
"k_fake_quantizer.qscale_act": "k_scale",
"v_fake_quantizer.qscale_act": "v_scale",
"wq": "q_proj",
"wk": "k_proj",
"wv": "v_proj",

View File

@ -7,7 +7,7 @@ import os
import time
from functools import cache, partial
from pathlib import Path
from typing import Any, Callable, Literal, Optional, TypeVar, Union
from typing import Any, Callable, Optional, TypeVar, Union
import huggingface_hub
from huggingface_hub import get_safetensors_metadata, hf_hub_download
@ -42,6 +42,7 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
SkyworkR1VChatConfig, SolarConfig,
Telechat2Config, UltravoxConfig)
# yapf: enable
from vllm.transformers_utils.configs.mistral import adapt_config_dict
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import resolve_obj_by_qualname
@ -394,7 +395,16 @@ def get_config(
config = _maybe_remap_hf_config_attrs(config)
elif config_format == ConfigFormat.MISTRAL:
config = load_params_config(model, revision, **kwargs)
# This function loads a params.json config which
# should be used when loading models in mistral format
config_dict = _download_mistral_config_file(model, revision)
if (max_position_embeddings :=
config_dict.get("max_position_embeddings")) is None:
max_position_embeddings = _maybe_retrieve_max_pos_from_hf(
model, revision, **kwargs)
config_dict["max_position_embeddings"] = max_position_embeddings
config = adapt_config_dict(config_dict)
else:
supported_formats = [
fmt.value for fmt in ConfigFormat if fmt != ConfigFormat.AUTO
@ -693,117 +703,6 @@ def maybe_register_config_serialize_by_value() -> None:
exc_info=e)
def load_params_config(model: Union[str, Path], revision: Optional[str],
**kwargs) -> PretrainedConfig:
# This function loads a params.json config which
# should be used when loading models in mistral format
config_file_name = "params.json"
config_dict = get_hf_file_to_dict(config_file_name, model, revision)
if config_dict is None:
raise ValueError(
f"Failed to load mistral '{config_file_name}' config for model "
f"{model}. Please check if the model is a mistral-format model "
f"and if the config file exists.")
assert isinstance(config_dict, dict)
config_mapping = {
"dim": "hidden_size",
"norm_eps": "rms_norm_eps",
"n_kv_heads": "num_key_value_heads",
"n_layers": "num_hidden_layers",
"n_heads": "num_attention_heads",
"hidden_dim": "intermediate_size",
}
def recurse_elems(elem: Any):
if isinstance(elem, dict):
config_dict = {}
for key, value in elem.items():
key = config_mapping.get(key, key)
config_dict[key] = recurse_elems(value)
return config_dict
else:
return elem
config_dict["model_type"] = config_dict.get("model_type", "transformer")
config_dict["hidden_act"] = config_dict.get("activation", "silu")
config_dict["tie_word_embeddings"] = config_dict.get(
"tie_embeddings", False)
if config_dict.get("max_position_embeddings") is None:
max_position_embeddings = 128_000
try:
trust_remote_code_val = kwargs.get("trust_remote_code", False)
hf_config = get_config(model=model,
trust_remote_code=trust_remote_code_val,
revision=revision,
config_format=ConfigFormat.HF)
if hf_value := hf_config.get_text_config().max_position_embeddings:
max_position_embeddings = hf_value
except Exception as e:
logger.warning(
"The params.json file is missing 'max_position_embeddings'"
" and could not get a value from the HF config."
" Defaulting to 128000",
exc_info=e)
config_dict["max_position_embeddings"] = max_position_embeddings
if config_dict.get("quantization") is not None:
quantization = config_dict.get("quantization", {})
if quantization.get("qformat_weight") == "fp8_e4m3":
# This maps to the FP8 static per-tensor quantization scheme
quantization_config = {
"quant_method": "fp8",
"activation_scheme": "static"
}
elif quantization.get("quant_method") == "compressed-tensors":
# Pass through the quantization config to compressed-tensors
quantization_config = quantization
else:
raise ValueError(
f"Found unknown quantization='{quantization}' in config")
config_dict["quantization_config"] = quantization_config
config_type: Literal["text",
"multimodal"] = "multimodal" if config_dict.get(
"vision_encoder") is not None else "text"
if config_dict.get("moe") is not None:
config_dict["architectures"] = ["MixtralForCausalLM"]
else:
config_dict["architectures"] = ["MistralForCausalLM"]
if config_type == "multimodal":
multimodal_config = config_dict.pop("vision_encoder")
quantization_config = config_dict.get("quantization_config", {})
config_dict = {
"text_config": config_dict,
"vision_config": multimodal_config
}
config_dict["architectures"] = ["PixtralForConditionalGeneration"]
config_dict["model_type"] = "pixtral"
if quantization_config:
config_dict["quantization_config"] = quantization_config
config_dict.update(kwargs)
config_dict = recurse_elems(config_dict)
# transform to HF config format
if config_type == "multimodal":
config_dict["text_config"] = PretrainedConfig(
**config_dict["text_config"])
config_dict["vision_config"] = PretrainedConfig(
**config_dict["vision_config"])
return PretrainedConfig(**config_dict)
def get_hf_image_processor_config(
model: Union[str, Path],
hf_token: Optional[Union[bool, str]] = None,
@ -920,3 +819,35 @@ def try_get_tokenizer_config(
)
except Exception:
return None
def _download_mistral_config_file(model, revision) -> dict:
config_file_name = "params.json"
config_dict = get_hf_file_to_dict(config_file_name, model, revision)
if config_dict is None:
raise ValueError(
f"Failed to load mistral '{config_file_name}' config for model "
f"{model}. Please check if the model is a mistral-format model "
f"and if the config file exists.")
assert isinstance(config_dict, dict)
return config_dict
def _maybe_retrieve_max_pos_from_hf(model, revision, **kwargs) -> int:
max_position_embeddings = 128_000
try:
trust_remote_code_val = kwargs.get("trust_remote_code", False)
hf_config = get_config(model=model,
trust_remote_code=trust_remote_code_val,
revision=revision,
config_format=ConfigFormat.HF)
if hf_value := hf_config.get_text_config().max_position_embeddings:
max_position_embeddings = hf_value
except Exception as e:
logger.warning(
"The params.json file is missing 'max_position_embeddings'"
" and could not get a value from the HF config."
" Defaulting to 128000",
exc_info=e)
return max_position_embeddings

View File

@ -0,0 +1,120 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from transformers import PretrainedConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
def adapt_config_dict(config_dict: dict[str, Any],
**kwargs) -> PretrainedConfig:
config_dict.update(kwargs)
config_dict = _remap_general_mistral_args(config_dict)
if bool(config_dict.get("quantization")):
config_dict = _remap_mistral_quantization_args(config_dict)
if bool(config_dict.get("moe")):
config_dict["architectures"] = ["MixtralForCausalLM"]
else:
config_dict["architectures"] = ["MistralForCausalLM"]
if bool(config_dict.get("yarn")):
config_dict = _remap_mistral_yarn_args(config_dict)
if bool((config_dict.get("multimodal") or {}).get("vision_encoder_args")
or config_dict.get("vision_encoder")):
config_dict = _remap_mistral_vision_args(config_dict)
config = PretrainedConfig.from_dict(config_dict)
logger.debug("Initialized config", config)
return config
def _remap_mistral_vision_args(config: dict) -> dict:
if config.get("multimodal"):
vision_config = config.pop("multimodal")
else:
vision_config = config.pop("vision_encoder")
quant_config = config.get("quantization_config")
config = {
"model_type": "pixtral",
"architectures": ["PixtralForConditionalGeneration"],
"text_config": PretrainedConfig.from_dict(config),
"vision_config": PretrainedConfig.from_dict(vision_config),
}
if quant_config:
config["quantization_config"] = quant_config
return config
def _remap_mistral_yarn_args(config: dict) -> dict:
# Direct remaps: yarn.X -> rope_scaling.Y
# Source keys are from mistral.model.args.YarnArgs
_map = {
"beta": "beta_fast",
"alpha": "beta_slow",
}
yarn_config = config.get("yarn") or {}
renamed_yarn_config = {_map.get(k, k): v for k, v in yarn_config.items()}
config["rope_scaling"] = {
"rope_type": "yarn",
"mscale_all_dim": 1, # We hardcoded this to 1
**renamed_yarn_config
}
return config
def _remap_general_mistral_args(config: dict) -> dict:
# Mistral key -> HF key
config_mapping = {
"dim": "hidden_size",
"norm_eps": "rms_norm_eps",
"n_kv_heads": "num_key_value_heads",
"n_layers": "num_hidden_layers",
"n_heads": "num_attention_heads",
"hidden_dim": "intermediate_size",
}
# HF key -> (Mistral key, default value)
top_level_mapping_with_default = {
"model_type": ("model_type", "transformer"),
"hidden_act": ("activation", "silu"),
"tie_word_embeddings": ("tied_embeddings", False),
"max_seq_len": ("max_seq_len", 128_000),
"max_position_embeddings": ("max_position_embeddings", 128_000),
}
for key, new_key in config_mapping.items():
if key in config:
config[new_key] = config.pop(key)
for new_key, (key,
default_value) in top_level_mapping_with_default.items():
config[new_key] = config.pop(key, default_value)
return config
def _remap_mistral_quantization_args(config: dict) -> dict:
quantization = config.get("quantization", {})
if quantization.get("qformat_weight") == "fp8_e4m3":
# This maps to the FP8 static per-tensor quantization scheme
quantization_config = {
"quant_method": "fp8",
"activation_scheme": "static"
}
elif quantization.get("quant_method") == "compressed-tensors":
# Pass through the quantization config to compressed-tensors
quantization_config = quantization
else:
raise ValueError(
f"Found unknown quantization='{quantization}' in config")
config["quantization_config"] = quantization_config
return config