[Frontend][Doc][5/N] Improve all pooling task | Polish encode (pooling) api & Document. (#25524)
Signed-off-by: wang.yuqi <noooop@126.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
@ -30,11 +30,11 @@ If `--runner pooling` has been set (manually or automatically) but the model doe
|
||||
vLLM will attempt to automatically convert the model according to the architecture names
|
||||
shown in the table below.
|
||||
|
||||
| Architecture | `--convert` | Supported pooling tasks |
|
||||
|-------------------------------------------------|-------------|-------------------------------|
|
||||
| `*ForTextEncoding`, `*EmbeddingModel`, `*Model` | `embed` | `encode`, `embed` |
|
||||
| `*For*Classification`, `*ClassificationModel` | `classify` | `encode`, `classify`, `score` |
|
||||
| `*ForRewardModeling`, `*RewardModel` | `reward` | `encode` |
|
||||
| Architecture | `--convert` | Supported pooling tasks |
|
||||
|-------------------------------------------------|-------------|---------------------------------------|
|
||||
| `*ForTextEncoding`, `*EmbeddingModel`, `*Model` | `embed` | `token_embed`, `embed` |
|
||||
| `*For*Classification`, `*ClassificationModel` | `classify` | `token_classify`, `classify`, `score` |
|
||||
| `*ForRewardModeling`, `*RewardModel` | `reward` | `token_classify` |
|
||||
|
||||
!!! tip
|
||||
You can explicitly set `--convert <type>` to specify how to convert the model.
|
||||
@ -45,12 +45,14 @@ Each pooling model in vLLM supports one or more of these tasks according to
|
||||
[Pooler.get_supported_tasks][vllm.model_executor.layers.pooler.Pooler.get_supported_tasks],
|
||||
enabling the corresponding APIs:
|
||||
|
||||
| Task | APIs |
|
||||
|------------|--------------------------------------|
|
||||
| `encode` | `LLM.reward(...)` |
|
||||
| `embed` | `LLM.embed(...)`, `LLM.score(...)`\* |
|
||||
| `classify` | `LLM.classify(...)` |
|
||||
| `score` | `LLM.score(...)` |
|
||||
| Task | APIs |
|
||||
|------------------|-------------------------------------------------------------------------------|
|
||||
| `embed` | `LLM.embed(...)`, `LLM.score(...)`\*, `LLM.encode(..., pooling_task="embed")` |
|
||||
| `classify` | `LLM.classify(...)`, `LLM.encode(..., pooling_task="classify")` |
|
||||
| `score` | `LLM.score(...)` |
|
||||
| `token_classify` | `LLM.reward(...)`, `LLM.encode(..., pooling_task="token_classify")` |
|
||||
| `token_embed` | `LLM.encode(..., pooling_task="token_embed")` |
|
||||
| `plugin` | `LLM.encode(..., pooling_task="plugin")` |
|
||||
|
||||
\* The `LLM.score(...)` API falls back to `embed` task if the model does not support `score` task.
|
||||
|
||||
@ -144,7 +146,6 @@ A code example can be found here: [examples/offline_inference/basic/score.py](..
|
||||
### `LLM.reward`
|
||||
|
||||
The [reward][vllm.LLM.reward] method is available to all reward models in vLLM.
|
||||
It returns the extracted hidden states directly.
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
@ -161,15 +162,17 @@ A code example can be found here: [examples/offline_inference/basic/reward.py](.
|
||||
### `LLM.encode`
|
||||
|
||||
The [encode][vllm.LLM.encode] method is available to all pooling models in vLLM.
|
||||
It returns the extracted hidden states directly.
|
||||
|
||||
!!! note
|
||||
Please use one of the more specific methods or set the task directly when using `LLM.encode`:
|
||||
|
||||
- For embeddings, use `LLM.embed(...)` or `pooling_task="embed"`.
|
||||
- For classification logits, use `LLM.classify(...)` or `pooling_task="classify"`.
|
||||
- For rewards, use `LLM.reward(...)` or `pooling_task="reward"`.
|
||||
- For similarity scores, use `LLM.score(...)`.
|
||||
- For rewards, use `LLM.reward(...)` or `pooling_task="token_classify"`.
|
||||
- For token classification, use `pooling_task="token_classify"`.
|
||||
- For multi-vector retrieval, use `pooling_task="token_embed"`
|
||||
- For IO Processor Plugins , use `pooling_task="plugin"`
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
@ -185,10 +188,47 @@ print(f"Data: {data!r}")
|
||||
|
||||
Our [OpenAI-Compatible Server](../serving/openai_compatible_server.md) provides endpoints that correspond to the offline APIs:
|
||||
|
||||
- [Pooling API](../serving/openai_compatible_server.md#pooling-api) is similar to `LLM.encode`, being applicable to all types of pooling models.
|
||||
- [Embeddings API](../serving/openai_compatible_server.md#embeddings-api) is similar to `LLM.embed`, accepting both text and [multi-modal inputs](../features/multimodal_inputs.md) for embedding models.
|
||||
- [Classification API](../serving/openai_compatible_server.md#classification-api) is similar to `LLM.classify` and is applicable to sequence classification models.
|
||||
- [Score API](../serving/openai_compatible_server.md#score-api) is similar to `LLM.score` for cross-encoder models.
|
||||
- [Pooling API](../serving/openai_compatible_server.md#pooling-api) is similar to `LLM.encode`, being applicable to all types of pooling models.
|
||||
|
||||
!!! note
|
||||
Please use one of the more specific methods or set the task directly when using [Pooling API](../serving/openai_compatible_server.md#pooling-api) api.:
|
||||
|
||||
- For embeddings, use [Embeddings API](../serving/openai_compatible_server.md#embeddings-api) or `"task":"embed"`.
|
||||
- For classification logits, use [Classification API](../serving/openai_compatible_server.md#classification-api) or `task":"classify"`.
|
||||
- For similarity scores, use [Score API](../serving/openai_compatible_server.md#score-api).
|
||||
- For rewards, `task":"token_classify"`.
|
||||
- For token classification, use `task":"token_classify"`.
|
||||
- For multi-vector retrieval, use `task":"token_embed"`
|
||||
- For IO Processor Plugins , use `task":"plugin"`
|
||||
|
||||
```python
|
||||
# start a supported embeddings model server with `vllm serve`, e.g.
|
||||
# vllm serve intfloat/e5-small
|
||||
import requests
|
||||
|
||||
host = "localhost"
|
||||
port = "8000"
|
||||
model_name = "intfloat/e5-small"
|
||||
|
||||
api_url = f"http://{host}:{port}/pooling"
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
prompt = {"model": model_name, "input": prompts, "task": "embed"}
|
||||
|
||||
response = requests.post(api_url, json=prompt)
|
||||
|
||||
for output in response.json()["data"]:
|
||||
data = output["data"]
|
||||
print(f"Data: {data!r} (size={len(data)})")
|
||||
```
|
||||
|
||||
## Matryoshka Embeddings
|
||||
|
||||
@ -265,3 +305,16 @@ Expected output:
|
||||
```
|
||||
|
||||
An OpenAI client example can be found here: [examples/online_serving/pooling/openai_embedding_matryoshka_fy.py](../../examples/online_serving/pooling/openai_embedding_matryoshka_fy.py)
|
||||
|
||||
## Deprecated Features
|
||||
|
||||
### Encode task
|
||||
|
||||
We have split the `encode` task into two more specific token wise tasks: `token_embed` and `token_classify`:
|
||||
|
||||
- `token_embed` is the same as embed, using normalize as activation.
|
||||
- `token_classify` is the same as classify, default using softmax as activation.
|
||||
|
||||
### Remove softmax from PoolingParams
|
||||
|
||||
We are going to remove `softmax` and `activation` from `PoolingParams`. Instead, you should set `use_activation`, since we actually allow `classify` and `token_classify` to use any activation function.
|
||||
|
||||
Reference in New Issue
Block a user