[Misc] Consolidate pooler config overrides (#10351)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -2,6 +2,7 @@ from argparse import ArgumentTypeError
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config import PoolerConfig
|
||||
from vllm.engine.arg_utils import EngineArgs, nullable_kvs
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
@@ -32,9 +33,13 @@ def test_limit_mm_per_prompt_parser(arg, expected):
|
||||
|
||||
def test_valid_pooling_config():
|
||||
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
|
||||
args = parser.parse_args(["--pooling-type=MEAN"])
|
||||
args = parser.parse_args([
|
||||
'--override-pooler-config',
|
||||
'{"pooling_type": "MEAN"}',
|
||||
])
|
||||
engine_args = EngineArgs.from_cli_args(args=args)
|
||||
assert engine_args.pooling_type == 'MEAN'
|
||||
assert engine_args.override_pooler_config == PoolerConfig(
|
||||
pooling_type="MEAN", )
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from dataclasses import asdict
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config import ModelConfig, PoolerConfig
|
||||
from vllm.model_executor.layers.pooler import PoolingType
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -108,7 +110,7 @@ def test_get_sliding_window():
|
||||
reason="Xformers backend is not supported on ROCm.")
|
||||
def test_get_pooling_config():
|
||||
model_id = "sentence-transformers/all-MiniLM-L12-v2"
|
||||
minilm_model_config = ModelConfig(
|
||||
model_config = ModelConfig(
|
||||
model_id,
|
||||
task="auto",
|
||||
tokenizer=model_id,
|
||||
@@ -119,39 +121,31 @@ def test_get_pooling_config():
|
||||
revision=None,
|
||||
)
|
||||
|
||||
minilm_pooling_config = minilm_model_config._init_pooler_config(
|
||||
pooling_type=None,
|
||||
pooling_norm=None,
|
||||
pooling_returned_token_ids=None,
|
||||
pooling_softmax=None,
|
||||
pooling_step_tag_id=None)
|
||||
pooling_config = model_config._init_pooler_config(None)
|
||||
assert pooling_config is not None
|
||||
|
||||
assert minilm_pooling_config.pooling_norm
|
||||
assert minilm_pooling_config.pooling_type == PoolingType.MEAN.name
|
||||
assert pooling_config.normalize
|
||||
assert pooling_config.pooling_type == PoolingType.MEAN.name
|
||||
|
||||
|
||||
@pytest.mark.skipif(current_platform.is_rocm(),
|
||||
reason="Xformers backend is not supported on ROCm.")
|
||||
def test_get_pooling_config_from_args():
|
||||
model_id = "sentence-transformers/all-MiniLM-L12-v2"
|
||||
minilm_model_config = ModelConfig(model_id,
|
||||
task="auto",
|
||||
tokenizer=model_id,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
revision=None)
|
||||
model_config = ModelConfig(model_id,
|
||||
task="auto",
|
||||
tokenizer=model_id,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
revision=None)
|
||||
|
||||
minilm_pooling_config = minilm_model_config._init_pooler_config(
|
||||
pooling_type='CLS',
|
||||
pooling_norm=True,
|
||||
pooling_returned_token_ids=None,
|
||||
pooling_softmax=None,
|
||||
pooling_step_tag_id=None)
|
||||
override_config = PoolerConfig(pooling_type='CLS', normalize=True)
|
||||
|
||||
assert minilm_pooling_config.pooling_norm
|
||||
assert minilm_pooling_config.pooling_type == PoolingType.CLS.name
|
||||
pooling_config = model_config._init_pooler_config(override_config)
|
||||
assert pooling_config is not None
|
||||
assert asdict(pooling_config) == asdict(override_config)
|
||||
|
||||
|
||||
@pytest.mark.skipif(current_platform.is_rocm(),
|
||||
|
||||
Reference in New Issue
Block a user