[Config] Refactor mistral configs (#20570)
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Commit 14601f5fba (parent 042d131f39), committed via GitHub.
@@ -491,6 +491,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
        "qscale_act": "input_scale",
        "qscale_weight": "weight_scale",
        "kv_fake_quantizer.qscale_act": "kv_scale",
        "q_fake_quantizer.qscale_act": "attn.q_scale",
        "k_fake_quantizer.qscale_act": "k_scale",
        "v_fake_quantizer.qscale_act": "v_scale",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
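The mapping above translates Mistral-checkpoint weight and scale names to their HF-style equivalents. Below is a minimal, illustrative sketch of how such a substring mapping can be applied to checkpoint tensor names; the helper, the subset of keys, and the sample name are assumptions for illustration, not vLLM's actual weight-loading code.

# Illustrative only: apply a substring-based key mapping to tensor names.
mistral_mapping = {
    "wq": "q_proj",
    "wk": "k_proj",
    "wv": "v_proj",
    "qscale_act": "input_scale",
    "qscale_weight": "weight_scale",
}

def remap_name(name: str) -> str:
    # Apply the first matching substring substitution, if any.
    for src, dst in mistral_mapping.items():
        if src in name:
            return name.replace(src, dst)
    return name

print(remap_name("layers.0.attention.wq.weight"))
# -> layers.0.attention.q_proj.weight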
@@ -7,7 +7,7 @@ import os
import time
from functools import cache, partial
from pathlib import Path
from typing import Any, Callable, Literal, Optional, TypeVar, Union
from typing import Any, Callable, Optional, TypeVar, Union

import huggingface_hub
from huggingface_hub import get_safetensors_metadata, hf_hub_download
@@ -42,6 +42,7 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
                                             SkyworkR1VChatConfig, SolarConfig,
                                             Telechat2Config, UltravoxConfig)
# yapf: enable
from vllm.transformers_utils.configs.mistral import adapt_config_dict
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import resolve_obj_by_qualname
@@ -394,7 +395,16 @@ def get_config(
        config = _maybe_remap_hf_config_attrs(config)

    elif config_format == ConfigFormat.MISTRAL:
        config = load_params_config(model, revision, **kwargs)
        # This function loads a params.json config which
        # should be used when loading models in mistral format
        config_dict = _download_mistral_config_file(model, revision)
        if (max_position_embeddings :=
                config_dict.get("max_position_embeddings")) is None:
            max_position_embeddings = _maybe_retrieve_max_pos_from_hf(
                model, revision, **kwargs)
            config_dict["max_position_embeddings"] = max_position_embeddings

        config = adapt_config_dict(config_dict)
    else:
        supported_formats = [
            fmt.value for fmt in ConfigFormat if fmt != ConfigFormat.AUTO
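The new ConfigFormat.MISTRAL branch replaces the old monolithic load_params_config call with three explicit steps: download params.json, backfill max_position_embeddings from the HF config only when the key is missing, then adapt the dict. A self-contained sketch of that backfill pattern on a plain dict (the sample values and the hf_lookup callable are hypothetical):

# A minimal sketch of the backfill step above; values are illustrative.
def backfill_max_pos(config_dict: dict, hf_lookup) -> dict:
    # Only consult the (slower) HF config when params.json lacks the key.
    if (max_position_embeddings :=
            config_dict.get("max_position_embeddings")) is None:
        max_position_embeddings = hf_lookup() or 128_000
        config_dict["max_position_embeddings"] = max_position_embeddings
    return config_dict

params = {"dim": 4096, "n_layers": 32}           # hypothetical params.json
print(backfill_max_pos(params, lambda: 32_768))  # adds max_position_embeddings=32768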
@@ -693,117 +703,6 @@ def maybe_register_config_serialize_by_value() -> None:
            exc_info=e)


def load_params_config(model: Union[str, Path], revision: Optional[str],
                       **kwargs) -> PretrainedConfig:
    # This function loads a params.json config which
    # should be used when loading models in mistral format

    config_file_name = "params.json"

    config_dict = get_hf_file_to_dict(config_file_name, model, revision)
    if config_dict is None:
        raise ValueError(
            f"Failed to load mistral '{config_file_name}' config for model "
            f"{model}. Please check if the model is a mistral-format model "
            f"and if the config file exists.")
    assert isinstance(config_dict, dict)

    config_mapping = {
        "dim": "hidden_size",
        "norm_eps": "rms_norm_eps",
        "n_kv_heads": "num_key_value_heads",
        "n_layers": "num_hidden_layers",
        "n_heads": "num_attention_heads",
        "hidden_dim": "intermediate_size",
    }

    def recurse_elems(elem: Any):
        if isinstance(elem, dict):
            config_dict = {}
            for key, value in elem.items():
                key = config_mapping.get(key, key)
                config_dict[key] = recurse_elems(value)

            return config_dict
        else:
            return elem

    config_dict["model_type"] = config_dict.get("model_type", "transformer")
    config_dict["hidden_act"] = config_dict.get("activation", "silu")
    config_dict["tie_word_embeddings"] = config_dict.get(
        "tie_embeddings", False)

    if config_dict.get("max_position_embeddings") is None:
        max_position_embeddings = 128_000
        try:
            trust_remote_code_val = kwargs.get("trust_remote_code", False)
            hf_config = get_config(model=model,
                                   trust_remote_code=trust_remote_code_val,
                                   revision=revision,
                                   config_format=ConfigFormat.HF)
            if hf_value := hf_config.get_text_config().max_position_embeddings:
                max_position_embeddings = hf_value
        except Exception as e:
            logger.warning(
                "The params.json file is missing 'max_position_embeddings'"
                " and could not get a value from the HF config."
                " Defaulting to 128000",
                exc_info=e)
        config_dict["max_position_embeddings"] = max_position_embeddings

    if config_dict.get("quantization") is not None:
        quantization = config_dict.get("quantization", {})
        if quantization.get("qformat_weight") == "fp8_e4m3":
            # This maps to the FP8 static per-tensor quantization scheme
            quantization_config = {
                "quant_method": "fp8",
                "activation_scheme": "static"
            }
        elif quantization.get("quant_method") == "compressed-tensors":
            # Pass through the quantization config to compressed-tensors
            quantization_config = quantization
        else:
            raise ValueError(
                f"Found unknown quantization='{quantization}' in config")

        config_dict["quantization_config"] = quantization_config

    config_type: Literal["text",
                         "multimodal"] = "multimodal" if config_dict.get(
                             "vision_encoder") is not None else "text"

    if config_dict.get("moe") is not None:
        config_dict["architectures"] = ["MixtralForCausalLM"]
    else:
        config_dict["architectures"] = ["MistralForCausalLM"]

    if config_type == "multimodal":
        multimodal_config = config_dict.pop("vision_encoder")
        quantization_config = config_dict.get("quantization_config", {})

        config_dict = {
            "text_config": config_dict,
            "vision_config": multimodal_config
        }
        config_dict["architectures"] = ["PixtralForConditionalGeneration"]
        config_dict["model_type"] = "pixtral"
        if quantization_config:
            config_dict["quantization_config"] = quantization_config

    config_dict.update(kwargs)

    config_dict = recurse_elems(config_dict)

    # transform to HF config format
    if config_type == "multimodal":
        config_dict["text_config"] = PretrainedConfig(
            **config_dict["text_config"])
        config_dict["vision_config"] = PretrainedConfig(
            **config_dict["vision_config"])

    return PretrainedConfig(**config_dict)


def get_hf_image_processor_config(
    model: Union[str, Path],
    hf_token: Optional[Union[bool, str]] = None,
@@ -920,3 +819,35 @@ def try_get_tokenizer_config(
        )
    except Exception:
        return None


def _download_mistral_config_file(model, revision) -> dict:
    config_file_name = "params.json"
    config_dict = get_hf_file_to_dict(config_file_name, model, revision)
    if config_dict is None:
        raise ValueError(
            f"Failed to load mistral '{config_file_name}' config for model "
            f"{model}. Please check if the model is a mistral-format model "
            f"and if the config file exists.")
    assert isinstance(config_dict, dict)
    return config_dict


def _maybe_retrieve_max_pos_from_hf(model, revision, **kwargs) -> int:
    max_position_embeddings = 128_000
    try:
        trust_remote_code_val = kwargs.get("trust_remote_code", False)
        hf_config = get_config(model=model,
                               trust_remote_code=trust_remote_code_val,
                               revision=revision,
                               config_format=ConfigFormat.HF)
        if hf_value := hf_config.get_text_config().max_position_embeddings:
            max_position_embeddings = hf_value
    except Exception as e:
        logger.warning(
            "The params.json file is missing 'max_position_embeddings'"
            " and could not get a value from the HF config."
            " Defaulting to 128000",
            exc_info=e)

    return max_position_embeddings
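With the download and the max-position fallback factored into these helpers, callers only have to select the config format. A usage sketch, assuming get_config and ConfigFormat are importable from vllm.transformers_utils.config (as this module suggests) and that the repo ships a mistral-style params.json; the model id is an example, not part of the diff:

from vllm.transformers_utils.config import ConfigFormat, get_config

# Illustrative model id; any repo with a mistral-style params.json works.
config = get_config("mistralai/Mistral-7B-Instruct-v0.3",
                    trust_remote_code=False,
                    config_format=ConfigFormat.MISTRAL)
print(config.architectures)            # e.g. ['MistralForCausalLM']
print(config.max_position_embeddings)  # backfilled if params.json omits it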
vllm/transformers_utils/configs/mistral.py (new file, 120 lines)
@@ -0,0 +1,120 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

from transformers import PretrainedConfig

from vllm.logger import init_logger

logger = init_logger(__name__)


def adapt_config_dict(config_dict: dict[str, Any],
                      **kwargs) -> PretrainedConfig:
    config_dict.update(kwargs)
    config_dict = _remap_general_mistral_args(config_dict)

    if bool(config_dict.get("quantization")):
        config_dict = _remap_mistral_quantization_args(config_dict)

    if bool(config_dict.get("moe")):
        config_dict["architectures"] = ["MixtralForCausalLM"]
    else:
        config_dict["architectures"] = ["MistralForCausalLM"]

    if bool(config_dict.get("yarn")):
        config_dict = _remap_mistral_yarn_args(config_dict)
    if bool((config_dict.get("multimodal") or {}).get("vision_encoder_args")
            or config_dict.get("vision_encoder")):
        config_dict = _remap_mistral_vision_args(config_dict)

    config = PretrainedConfig.from_dict(config_dict)

    logger.debug("Initialized config %s", config)

    return config
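A quick usage sketch for the entry point above; the params.json values describe a hypothetical text-only model and simply exercise the key remapping:

from vllm.transformers_utils.configs.mistral import adapt_config_dict

params = {                      # hypothetical params.json contents
    "dim": 4096,
    "n_layers": 32,
    "n_heads": 32,
    "n_kv_heads": 8,
    "hidden_dim": 14336,
    "norm_eps": 1e-5,
    "vocab_size": 32768,
    "rope_theta": 1e6,
}
cfg = adapt_config_dict(params)
assert cfg.hidden_size == 4096
assert cfg.num_hidden_layers == 32
assert cfg.architectures == ["MistralForCausalLM"]
assert cfg.hidden_act == "silu"     # default applied by the general remap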
def _remap_mistral_vision_args(config: dict) -> dict:
    if config.get("multimodal"):
        vision_config = config.pop("multimodal")
    else:
        vision_config = config.pop("vision_encoder")

    quant_config = config.get("quantization_config")
    config = {
        "model_type": "pixtral",
        "architectures": ["PixtralForConditionalGeneration"],
        "text_config": PretrainedConfig.from_dict(config),
        "vision_config": PretrainedConfig.from_dict(vision_config),
    }
    if quant_config:
        config["quantization_config"] = quant_config
    return config
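When params.json carries a vision tower, the text fields are nested under text_config, the tower under vision_config, and the architecture switches to Pixtral. A sketch with hypothetical field values, reusing adapt_config_dict from above:

from vllm.transformers_utils.configs.mistral import adapt_config_dict

params = {                      # hypothetical multimodal params.json
    "dim": 5120, "n_layers": 40, "n_heads": 32, "n_kv_heads": 8,
    "hidden_dim": 14336, "norm_eps": 1e-5,
    "vision_encoder": {         # hypothetical vision-tower settings
        "hidden_size": 1024,
        "num_attention_heads": 16,
        "image_size": 1024,
        "patch_size": 16,
    },
}
cfg = adapt_config_dict(params)
assert cfg.architectures == ["PixtralForConditionalGeneration"]
assert cfg.text_config.hidden_size == 5120
assert cfg.vision_config.patch_size == 16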
def _remap_mistral_yarn_args(config: dict) -> dict:
    # Direct remaps: yarn.X -> rope_scaling.Y
    # Source keys are from mistral.model.args.YarnArgs
    _map = {
        "beta": "beta_fast",
        "alpha": "beta_slow",
    }
    yarn_config = config.get("yarn") or {}
    renamed_yarn_config = {_map.get(k, k): v for k, v in yarn_config.items()}
    config["rope_scaling"] = {
        "rope_type": "yarn",
        "mscale_all_dim": 1,  # We hardcoded this to 1
        **renamed_yarn_config
    }
    return config
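For a params.json with a yarn block, the keys are renamed into HF's rope_scaling form: beta/alpha become beta_fast/beta_slow, other keys pass through, and rope_type is pinned to "yarn". A small sketch calling the helper directly with hypothetical values:

from vllm.transformers_utils.configs.mistral import _remap_mistral_yarn_args

cfg = {"yarn": {"beta": 32, "alpha": 1, "factor": 8.0}}   # hypothetical values
cfg = _remap_mistral_yarn_args(cfg)
# cfg["rope_scaling"] == {"rope_type": "yarn", "mscale_all_dim": 1,
#                         "beta_fast": 32, "beta_slow": 1, "factor": 8.0}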
def _remap_general_mistral_args(config: dict) -> dict:
    # Mistral key -> HF key
    config_mapping = {
        "dim": "hidden_size",
        "norm_eps": "rms_norm_eps",
        "n_kv_heads": "num_key_value_heads",
        "n_layers": "num_hidden_layers",
        "n_heads": "num_attention_heads",
        "hidden_dim": "intermediate_size",
    }
    # HF key -> (Mistral key, default value)
    top_level_mapping_with_default = {
        "model_type": ("model_type", "transformer"),
        "hidden_act": ("activation", "silu"),
        "tie_word_embeddings": ("tied_embeddings", False),
        "max_seq_len": ("max_seq_len", 128_000),
        "max_position_embeddings": ("max_position_embeddings", 128_000),
    }

    for key, new_key in config_mapping.items():
        if key in config:
            config[new_key] = config.pop(key)

    for new_key, (key,
                  default_value) in top_level_mapping_with_default.items():
        config[new_key] = config.pop(key, default_value)

    return config
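A before/after sketch of the general remap on a minimal dict (hypothetical values); missing top-level keys receive defaults, and the Mistral-side "tied_embeddings" spelling maps to HF's "tie_word_embeddings":

from vllm.transformers_utils.configs.mistral import _remap_general_mistral_args

raw = {"dim": 2048, "n_layers": 24, "tied_embeddings": True}  # hypothetical
out = _remap_general_mistral_args(dict(raw))
# out["hidden_size"] == 2048, out["num_hidden_layers"] == 24
# out["tie_word_embeddings"] is True
# out["hidden_act"] == "silu", out["max_position_embeddings"] == 128_000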
def _remap_mistral_quantization_args(config: dict) -> dict:
    quantization = config.get("quantization", {})
    if quantization.get("qformat_weight") == "fp8_e4m3":
        # This maps to the FP8 static per-tensor quantization scheme
        quantization_config = {
            "quant_method": "fp8",
            "activation_scheme": "static"
        }
    elif quantization.get("quant_method") == "compressed-tensors":
        # Pass through the quantization config to compressed-tensors
        quantization_config = quantization
    else:
        raise ValueError(
            f"Found unknown quantization='{quantization}' in config")

    config["quantization_config"] = quantization_config

    return config
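Finally, a sketch of the two quantization spellings the remap accepts (field values hypothetical); any other scheme raises ValueError:

from vllm.transformers_utils.configs.mistral import (
    _remap_mistral_quantization_args)

fp8 = _remap_mistral_quantization_args(
    {"quantization": {"qformat_weight": "fp8_e4m3"}})
# fp8["quantization_config"] == {"quant_method": "fp8",
#                                "activation_scheme": "static"}

ct = _remap_mistral_quantization_args(
    {"quantization": {"quant_method": "compressed-tensors",
                      "config_groups": {}}})   # hypothetical passthrough
# ct["quantization_config"] is the original quantization dict, unchanged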