[New Model]: support GTE NewModel (#17986)

This commit is contained in:
wang.yuqi
2025-05-14 16:31:31 +08:00
committed by GitHub
parent e7ef61c1f0
commit 63ad622233
11 changed files with 279 additions and 32 deletions

View File

@ -7,6 +7,7 @@ import numpy as np
import pytest
from tests.models.utils import EmbedModelInfo
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
# Most models on the STS12 task (See #17175):
# - Model implementation and minor changes in tensor dtype
@ -77,16 +78,22 @@ def run_mteb_embed_task_st(model_name, tasks):
return run_mteb_embed_task(model, tasks)
def mteb_test_embed_models(hf_runner, vllm_runner, model_info: EmbedModelInfo):
def mteb_test_embed_models(hf_runner,
vllm_runner,
model_info: EmbedModelInfo,
vllm_extra_kwargs=None):
if not model_info.enable_test:
# A model family has many models with the same architecture,
# and we don't need to test each one.
pytest.skip("Skipping test.")
vllm_extra_kwargs = vllm_extra_kwargs or {}
with vllm_runner(model_info.name,
task="embed",
max_model_len=None,
dtype=model_info.dtype) as vllm_model:
dtype=model_info.dtype,
**vllm_extra_kwargs) as vllm_model:
if model_info.architecture:
assert (model_info.architecture
@ -99,9 +106,9 @@ def mteb_test_embed_models(hf_runner, vllm_runner, model_info: EmbedModelInfo):
vllm_model.model.llm_engine.model_config.hf_config, "torch_dtype",
vllm_dtype)
with hf_runner(model_info.name,
is_sentence_transformer=True,
dtype=model_dtype) as hf_model:
with set_default_torch_dtype(model_dtype) and hf_runner(
model_info.name, is_sentence_transformer=True,
dtype=model_dtype) as hf_model:
st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)
print("VLLM:", vllm_dtype, vllm_main_score)

View File

@ -0,0 +1,104 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any
import pytest
from ...utils import EmbedModelInfo, run_embedding_correctness_test
MODELS = [
########## BertModel
EmbedModelInfo("thenlper/gte-large",
architecture="BertModel",
dtype="float32",
enable_test=True),
EmbedModelInfo("thenlper/gte-base",
architecture="BertModel",
dtype="float32",
enable_test=False),
EmbedModelInfo("thenlper/gte-small",
architecture="BertModel",
dtype="float32",
enable_test=False),
EmbedModelInfo("thenlper/gte-large-zh",
architecture="BertModel",
dtype="float32",
enable_test=False),
EmbedModelInfo("thenlper/gte-base-zh",
architecture="BertModel",
dtype="float32",
enable_test=False),
EmbedModelInfo("thenlper/gte-small-zh",
architecture="BertModel",
dtype="float32",
enable_test=False),
########### NewModel
EmbedModelInfo("Alibaba-NLP/gte-multilingual-base",
architecture="GteNewModel",
enable_test=True),
EmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5",
architecture="GteNewModel",
enable_test=True),
EmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5",
architecture="GteNewModel",
enable_test=True),
########### Qwen2ForCausalLM
EmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
architecture="Qwen2ForCausalLM",
enable_test=True),
EmbedModelInfo("Alibaba-NLP/gte-Qwen2-7B-instruct",
architecture="Qwen2ForCausalLM",
enable_test=False),
########## ModernBertModel
EmbedModelInfo("Alibaba-NLP/gte-modernbert-base",
architecture="ModernBertModel",
enable_test=True),
]
@pytest.mark.parametrize("model_info", MODELS)
def test_models_mteb(hf_runner, vllm_runner,
model_info: EmbedModelInfo) -> None:
pytest.skip("Skipping mteb test.")
from .mteb_utils import mteb_test_embed_models
vllm_extra_kwargs: dict[str, Any] = {}
if model_info.name == "Alibaba-NLP/gte-Qwen2-1.5B-instruct":
vllm_extra_kwargs["hf_overrides"] = {"is_causal": True}
if model_info.architecture == "GteNewModel":
vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}
mteb_test_embed_models(hf_runner, vllm_runner, model_info,
vllm_extra_kwargs)
@pytest.mark.parametrize("model_info", MODELS)
def test_models_correctness(hf_runner, vllm_runner, model_info: EmbedModelInfo,
example_prompts) -> None:
if not model_info.enable_test:
pytest.skip("Skipping test.")
# ST will strip the input texts, see test_embedding.py
example_prompts = [str(s).strip() for s in example_prompts]
vllm_extra_kwargs: dict[str, Any] = {}
if model_info.name == "Alibaba-NLP/gte-Qwen2-1.5B-instruct":
vllm_extra_kwargs["hf_overrides"] = {"is_causal": True}
if model_info.architecture == "GteNewModel":
vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}
with vllm_runner(model_info.name,
task="embed",
dtype=model_info.dtype,
max_model_len=None,
**vllm_extra_kwargs) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts)
with hf_runner(
model_info.name,
dtype=model_info.dtype,
is_sentence_transformer=True,
) as hf_model:
run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs)

View File

@ -23,6 +23,7 @@ MODELS = [
@pytest.mark.parametrize("model_info", MODELS)
def test_models_mteb(hf_runner, vllm_runner,
model_info: EmbedModelInfo) -> None:
pytest.skip("Skipping mteb test.")
from .mteb_utils import mteb_test_embed_models
mteb_test_embed_models(hf_runner, vllm_runner, model_info)
@ -33,6 +34,9 @@ def test_models_correctness(hf_runner, vllm_runner, model_info: EmbedModelInfo,
if not model_info.enable_test:
pytest.skip("Skipping test.")
# ST will strip the input texts, see test_embedding.py
example_prompts = [str(s).strip() for s in example_prompts]
with vllm_runner(model_info.name,
task="embed",
dtype=model_info.dtype,

View File

@ -46,6 +46,7 @@ def test_models_mteb(
vllm_runner,
model_info: EmbedModelInfo,
) -> None:
pytest.skip("Skipping mteb test.")
from .mteb_utils import mteb_test_embed_models
mteb_test_embed_models(hf_runner, vllm_runner, model_info)
@ -60,6 +61,9 @@ def test_models_correctness(
if not model_info.enable_test:
pytest.skip("Skipping test.")
# ST will strip the input texts, see test_embedding.py
example_prompts = [str(s).strip() for s in example_prompts]
with vllm_runner(model_info.name,
task="embed",
dtype=model_info.dtype,

View File

@ -256,11 +256,17 @@ _EMBEDDING_EXAMPLE_MODELS = {
"GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
"GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
trust_remote_code=True),
"GteNewModel": _HfExamplesInfo("Alibaba-NLP/gte-base-en-v1.5",
trust_remote_code=True,
hf_overrides={"architectures":
["GteNewModel"]}),
"InternLM2ForRewardModel": _HfExamplesInfo("internlm/internlm2-1_8b-reward",
trust_remote_code=True),
"JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"), # noqa: E501
"LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
"MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
"ModernBertModel": _HfExamplesInfo("Alibaba-NLP/gte-modernbert-base",
trust_remote_code=True),
"NomicBertModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-long", # noqa: E501
trust_remote_code=True),
"Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),