[Misc] Consolidate pooler config overrides (#10351)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2024-11-15 14:59:00 +08:00
committed by GitHub
parent 2ec8827288
commit 2ac6d0e75b
7 changed files with 141 additions and 190 deletions

View File

@@ -2,6 +2,7 @@ from argparse import ArgumentTypeError
import pytest
from vllm.config import PoolerConfig
from vllm.engine.arg_utils import EngineArgs, nullable_kvs
from vllm.utils import FlexibleArgumentParser
@@ -32,9 +33,13 @@ def test_limit_mm_per_prompt_parser(arg, expected):
def test_valid_pooling_config():
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args(["--pooling-type=MEAN"])
args = parser.parse_args([
'--override-pooler-config',
'{"pooling_type": "MEAN"}',
])
engine_args = EngineArgs.from_cli_args(args=args)
assert engine_args.pooling_type == 'MEAN'
assert engine_args.override_pooler_config == PoolerConfig(
pooling_type="MEAN", )
@pytest.mark.parametrize(

View File

@@ -1,6 +1,8 @@
from dataclasses import asdict
import pytest
from vllm.config import ModelConfig
from vllm.config import ModelConfig, PoolerConfig
from vllm.model_executor.layers.pooler import PoolingType
from vllm.platforms import current_platform
@@ -108,7 +110,7 @@ def test_get_sliding_window():
reason="Xformers backend is not supported on ROCm.")
def test_get_pooling_config():
model_id = "sentence-transformers/all-MiniLM-L12-v2"
minilm_model_config = ModelConfig(
model_config = ModelConfig(
model_id,
task="auto",
tokenizer=model_id,
@@ -119,39 +121,31 @@ def test_get_pooling_config():
revision=None,
)
minilm_pooling_config = minilm_model_config._init_pooler_config(
pooling_type=None,
pooling_norm=None,
pooling_returned_token_ids=None,
pooling_softmax=None,
pooling_step_tag_id=None)
pooling_config = model_config._init_pooler_config(None)
assert pooling_config is not None
assert minilm_pooling_config.pooling_norm
assert minilm_pooling_config.pooling_type == PoolingType.MEAN.name
assert pooling_config.normalize
assert pooling_config.pooling_type == PoolingType.MEAN.name
@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
def test_get_pooling_config_from_args():
model_id = "sentence-transformers/all-MiniLM-L12-v2"
minilm_model_config = ModelConfig(model_id,
task="auto",
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
revision=None)
model_config = ModelConfig(model_id,
task="auto",
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
revision=None)
minilm_pooling_config = minilm_model_config._init_pooler_config(
pooling_type='CLS',
pooling_norm=True,
pooling_returned_token_ids=None,
pooling_softmax=None,
pooling_step_tag_id=None)
override_config = PoolerConfig(pooling_type='CLS', normalize=True)
assert minilm_pooling_config.pooling_norm
assert minilm_pooling_config.pooling_type == PoolingType.CLS.name
pooling_config = model_config._init_pooler_config(override_config)
assert pooling_config is not None
assert asdict(pooling_config) == asdict(override_config)
@pytest.mark.skipif(current_platform.is_rocm(),