Compare commits
6 Commits: gpu_ids2...fix-doc-bu

| SHA1 |
|---|
| 1db4b78a13 |
| fdadb6f43a |
| 41060c6e08 |
| 3de2ed767f |
| 299252ea82 |
| d6902ce79f |
@@ -272,3 +272,80 @@ The new format of `--lora-modules` is mainly to support the display of parent mo
]
}
```

## Default LoRA Models For Multimodal Models

Some models, e.g., [Granite Speech](https://huggingface.co/ibm-granite/granite-speech-3.3-8b) and [Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct), contain LoRA adapter(s) that are expected to always be applied when a given modality is present. This can be tedious to manage with the approaches above, since it requires the user to send a `LoRARequest` (offline) or to route requests between the base model and the LoRA model (server) depending on the content of each request's multimodal data.

To this end, we allow registration of default multimodal LoRAs: users can map each modality to a LoRA adapter, which is then applied automatically whenever the corresponding inputs are present. Note that currently only one LoRA can be applied per prompt; if several modalities are provided and each of them has a registered LoRA, none of them will be applied.

Example usage for offline inference:

```python
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

model_id = "ibm-granite/granite-speech-3.3-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id)


def get_prompt(question: str, has_audio: bool):
    """Build the input prompt to send to vLLM."""
    if has_audio:
        question = f"<|audio|>{question}"
    chat = [
        {
            "role": "user",
            "content": question
        }
    ]
    return tokenizer.apply_chat_template(chat, tokenize=False)


model = LLM(
    model=model_id,
    enable_lora=True,
    max_lora_rank=64,
    max_model_len=2048,
    limit_mm_per_prompt={"audio": 1},
    # Will always pass a `LoRARequest` with the `model_id`
    # whenever audio is contained in the request data.
    default_mm_loras={"audio": model_id},
    enforce_eager=True,
)

question = "can you transcribe the speech into a written format?"
prompt_with_audio = get_prompt(
    question=question,
    has_audio=True,
)
audio = AudioAsset("mary_had_lamb").audio_and_sample_rate

inputs = {
    "prompt": prompt_with_audio,
    "multi_modal_data": {
        "audio": audio,
    }
}

outputs = model.generate(
    inputs,
    sampling_params=SamplingParams(
        temperature=0.2,
        max_tokens=64,
    ),
)
```

You can also pass a JSON dictionary to `--default-mm-loras` mapping modalities to LoRA model IDs. For example, when starting the server:

```bash
vllm serve ibm-granite/granite-speech-3.3-2b \
    --max-model-len 2048 \
    --enable-lora \
    --default-mm-loras '{"audio":"ibm-granite/granite-speech-3.3-2b"}' \
    --max-lora-rank 64
```

Note: Default multimodal LoRAs are currently only available for `.generate` and chat completions.
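
For the server path, the sketch below shows what a chat completions request against the `vllm serve` command above could look like. This is a minimal illustration, not part of the PR: it assumes the server is reachable at the default `http://localhost:8000/v1` endpoint with no real API key, and the audio URL is a placeholder you would replace with a file the server can fetch. Because the request contains audio content, the registered default audio LoRA is applied automatically even though the base model name is requested.

```python
from openai import OpenAI

# Assumption: the server started above is listening on the default port.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

chat_completion = client.chat.completions.create(
    model="ibm-granite/granite-speech-3.3-2b",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Can you transcribe this audio?"},
            # Placeholder URL; point this at an audio file the server can reach.
            {"type": "audio_url", "audio_url": {"url": "https://example.com/sample.wav"}},
        ],
    }],
    max_completion_tokens=64,
    temperature=0.0,
)
print(chat_completion.choices[0].message.content)
```

Since default multimodal LoRAs are also merged into the server's named LoRA modules (under the modality name), the same request could alternatively target that LoRA name directly instead of the base model name.
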
@@ -13,6 +13,7 @@ ARGPARSE_DOC_DIR = ROOT_DIR / "docs/argparse"
sys.path.insert(0, str(ROOT_DIR))
sys.modules["aiohttp"] = MagicMock()
sys.modules["blake3"] = MagicMock()
sys.modules["gguf"] = MagicMock()
sys.modules["vllm._C"] = MagicMock()

from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs  # noqa: E402

@@ -21,7 +21,9 @@ prometheus-fastapi-instrumentator >= 7.0.0
tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer >= 0.10.11, < 0.11
llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
outlines == 0.1.11
outlines_core == 0.2.10
# required for outlines backend disk cache
diskcache == 5.6.3
lark == 1.2.2
xgrammar == 0.1.19; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64"
typing_extensions >= 4.10

@ -16,14 +16,18 @@ from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
|
||||
|
||||
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
GUIDED_DECODING_BACKENDS = [
|
||||
|
||||
# Separate backends which support grammars vs ones
|
||||
# which only support regex based constraints in tests.
|
||||
GRAMMAR_DECODING_BACKENDS = [
|
||||
# (backend, disable_any_whitespace),
|
||||
("outlines", False),
|
||||
("lm-format-enforcer", False),
|
||||
("xgrammar", True),
|
||||
("guidance", True),
|
||||
]
|
||||
|
||||
ALL_DECODING_BACKENDS = ([("outlines", False)] + GRAMMAR_DECODING_BACKENDS)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def llm():
|
||||
@ -39,7 +43,7 @@ def llm():
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||
GUIDED_DECODING_BACKENDS)
|
||||
ALL_DECODING_BACKENDS)
|
||||
def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
|
||||
disable_any_whitespace: bool):
|
||||
sampling_params = SamplingParams(
|
||||
@ -49,6 +53,7 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
|
||||
regex=sample_regex,
|
||||
backend=guided_decoding_backend,
|
||||
disable_any_whitespace=disable_any_whitespace))
|
||||
|
||||
outputs = llm.generate(prompts=[
|
||||
f"Give an example IPv4 address with this regex: {sample_regex}"
|
||||
] * 2,
|
||||
@ -69,7 +74,7 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||
GUIDED_DECODING_BACKENDS)
|
||||
ALL_DECODING_BACKENDS)
|
||||
def test_guided_json_completion(sample_json_schema, llm,
|
||||
guided_decoding_backend: str,
|
||||
disable_any_whitespace: bool):
|
||||
@ -103,7 +108,7 @@ def test_guided_json_completion(sample_json_schema, llm,
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||
GUIDED_DECODING_BACKENDS)
|
||||
ALL_DECODING_BACKENDS)
|
||||
def test_guided_complex_json_completion(sample_complex_json_schema, llm,
|
||||
guided_decoding_backend: str,
|
||||
disable_any_whitespace: bool):
|
||||
@ -138,7 +143,7 @@ def test_guided_complex_json_completion(sample_complex_json_schema, llm,
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||
GUIDED_DECODING_BACKENDS)
|
||||
ALL_DECODING_BACKENDS)
|
||||
def test_guided_definition_json_completion(sample_definition_json_schema, llm,
|
||||
guided_decoding_backend: str,
|
||||
disable_any_whitespace: bool):
|
||||
@ -173,7 +178,7 @@ def test_guided_definition_json_completion(sample_definition_json_schema, llm,
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||
GUIDED_DECODING_BACKENDS)
|
||||
ALL_DECODING_BACKENDS)
|
||||
def test_guided_enum_json_completion(sample_enum_json_schema, llm,
|
||||
guided_decoding_backend: str,
|
||||
disable_any_whitespace: bool):
|
||||
@ -218,7 +223,7 @@ def test_guided_enum_json_completion(sample_enum_json_schema, llm,
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||
GUIDED_DECODING_BACKENDS)
|
||||
ALL_DECODING_BACKENDS)
|
||||
def test_guided_choice_completion(sample_guided_choice, llm,
|
||||
guided_decoding_backend: str,
|
||||
disable_any_whitespace: bool):
|
||||
@ -248,7 +253,7 @@ def test_guided_choice_completion(sample_guided_choice, llm,
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||
GUIDED_DECODING_BACKENDS)
|
||||
GRAMMAR_DECODING_BACKENDS)
|
||||
def test_guided_grammar(sample_sql_statements, llm,
|
||||
guided_decoding_backend: str,
|
||||
disable_any_whitespace: bool):
|
||||
@ -344,7 +349,7 @@ def test_disable_guided_decoding_fallback(sample_regex, llm):
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||
GUIDED_DECODING_BACKENDS)
|
||||
GRAMMAR_DECODING_BACKENDS)
|
||||
def test_guided_json_object(llm, guided_decoding_backend: str,
|
||||
disable_any_whitespace: bool):
|
||||
sampling_params = SamplingParams(
|
||||
@ -377,7 +382,9 @@ def test_guided_json_object(llm, guided_decoding_backend: str,
|
||||
|
||||
# Parse to verify it is valid JSON
|
||||
parsed_json = json.loads(generated_text)
|
||||
assert isinstance(parsed_json, dict)
|
||||
# A list is not what was intended, but is still valid
|
||||
# json.
|
||||
assert isinstance(parsed_json, (dict, list))
|
||||
|
||||
|
||||
class CarType(str, Enum):
|
||||
@ -395,7 +402,7 @@ class CarDescription(BaseModel):
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||
GUIDED_DECODING_BACKENDS)
|
||||
ALL_DECODING_BACKENDS)
|
||||
def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str,
|
||||
disable_any_whitespace: bool):
|
||||
json_schema = CarDescription.model_json_schema()
|
||||
@ -427,7 +434,7 @@ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str,
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||
GUIDED_DECODING_BACKENDS)
|
||||
ALL_DECODING_BACKENDS)
|
||||
def test_guided_number_range_json_completion(llm, guided_decoding_backend: str,
|
||||
disable_any_whitespace: bool):
|
||||
sample_output_schema = {
|
||||
|
||||
tests/entrypoints/openai/test_default_mm_loras.py (new file, 107 lines)
@@ -0,0 +1,107 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
|
||||
import openai # use the official client for correctness check
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from ...conftest import AudioTestAssets
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
# NOTE - the tests in this module are currently analogous to test_chat, but are
|
||||
# separated to avoid OOM killing due to module-scoped servers, since we
|
||||
# need a multimodal model for these tests.
|
||||
|
||||
# Contains a modality specific lora alongside the base model
|
||||
MULTIMODAL_MODEL_NAME = snapshot_download(
|
||||
"microsoft/Phi-4-multimodal-instruct")
|
||||
AUDIO_LORA_PATH = os.path.join(MULTIMODAL_MODEL_NAME, "speech-lora")
|
||||
|
||||
ACTIVE_MM_LORA_RESPONSE = "Spoken text: The first words I spoke in the original chronograph, a little piece of practical poetry. Mary had a little lamb, it slept with quite a snow, and everywhere that Mary went, the lamb was sure to go." # noqa: E501
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def monkeypatch_module():
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
mpatch = MonkeyPatch()
|
||||
yield mpatch
|
||||
mpatch.undo()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=[False, True])
|
||||
def multimodal_server(request, monkeypatch_module): # noqa: F811
|
||||
|
||||
use_v1 = request.param
|
||||
monkeypatch_module.setenv('VLLM_USE_V1', '1' if use_v1 else '0')
|
||||
|
||||
args = [
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
"half",
|
||||
"--max-model-len",
|
||||
"12800",
|
||||
"--enforce-eager",
|
||||
# lora config below
|
||||
"--enable-lora",
|
||||
"--lora-modules",
|
||||
f"speech={AUDIO_LORA_PATH}",
|
||||
"--max-lora-rank",
|
||||
"320",
|
||||
"--max-num-seqs",
|
||||
"2",
|
||||
"--trust-remote-code",
|
||||
"--gpu-memory-utilization",
|
||||
"0.8",
|
||||
"--default-mm-loras",
|
||||
f"{{\"audio\": \"{AUDIO_LORA_PATH}\"}}",
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MULTIMODAL_MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def multi_modal_client(multimodal_server):
|
||||
async with multimodal_server.get_async_client() as async_client:
|
||||
yield async_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
# base model with default lora should give the same response as lora model
|
||||
"model_name",
|
||||
[MULTIMODAL_MODEL_NAME, "speech"],
|
||||
)
|
||||
async def test_default_mm_lora_chat_completions(
|
||||
model_name: str,
|
||||
multi_modal_client: openai.AsyncOpenAI,
|
||||
audio_assets: AudioTestAssets,
|
||||
):
|
||||
messages = [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": "Can you transcribe this audio?",
|
||||
}, {
|
||||
"type": "audio_url",
|
||||
"audio_url": {
|
||||
"url": audio_assets[0].url
|
||||
},
|
||||
}]
|
||||
}]
|
||||
|
||||
chat_completion = await multi_modal_client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
max_completion_tokens=128,
|
||||
temperature=0.0)
|
||||
|
||||
assert len(chat_completion.choices) > 0
|
||||
|
||||
message = chat_completion.choices[0].message
|
||||
assert message.content is not None and len(message.content) >= 0
|
||||
assert message.content == ACTIVE_MM_LORA_RESPONSE
|
||||
tests/kernels/moe/test_count_expert_num_tokens.py (new file, 140 lines)
@@ -0,0 +1,140 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Tests compute_expert_num_tokens kernels
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.utils import count_expert_num_tokens
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TestTensors:
|
||||
|
||||
topk_ids: torch.Tensor
|
||||
expert_map: Optional[torch.Tensor] = None
|
||||
|
||||
def to_device(self, device: str):
|
||||
self.topk_ids = self.topk_ids.to(device=device)
|
||||
if self.expert_map is not None:
|
||||
self.expert_map = self.expert_map.to(device=device)
|
||||
|
||||
@staticmethod
|
||||
def make(num_tokens: int, num_topk: int, num_experts: int, device: str,
|
||||
topk_ids_dtype: torch.dtype) -> "TestTensors":
|
||||
|
||||
# make topk ids
|
||||
topk_ids = torch.empty((num_tokens, num_topk),
|
||||
device=device,
|
||||
dtype=torch.int64)
|
||||
for x in range(num_tokens):
|
||||
topk_ids[x] = torch.randperm(num_experts)[:num_topk]
|
||||
topk_ids = topk_ids.to(dtype=torch.int64)
|
||||
return TestTensors(topk_ids=topk_ids)
|
||||
|
||||
def with_ep_rank(self, ep_rank: int, num_global_experts: int,
|
||||
num_local_experts: int, device: str):
|
||||
# make an expert map
|
||||
expert_map = torch.empty((num_global_experts),
|
||||
device=device,
|
||||
dtype=torch.int32)
|
||||
expert_map.fill_(-1)
|
||||
s = ep_rank * num_local_experts
|
||||
e = s + num_local_experts
|
||||
expert_map[s:e] = torch.tensor(list(range(num_local_experts)),
|
||||
device=device)
|
||||
|
||||
return TestTensors(topk_ids=self.topk_ids.clone(),
|
||||
expert_map=expert_map)
|
||||
|
||||
|
||||
def ref_impl(tt: TestTensors, expert_num_tokens: torch.Tensor):
|
||||
# do the reference in cpu
|
||||
tt.to_device("cpu")
|
||||
expert_ids, counts = tt.topk_ids.unique(return_counts=True)
|
||||
|
||||
for eid, count in zip(expert_ids, counts):
|
||||
if eid != -1 and tt.expert_map is not None:
|
||||
eid = tt.expert_map[eid]
|
||||
|
||||
if eid == -1:
|
||||
continue
|
||||
|
||||
expert_num_tokens[eid] += count
|
||||
|
||||
|
||||
def do_test_compute_expert_num_tokens(num_tokens: int, num_topk: int,
|
||||
num_experts: int, ep_size: int,
|
||||
topk_ids_dtype: torch.dtype):
|
||||
|
||||
assert num_topk <= num_experts
|
||||
|
||||
tt = TestTensors.make(num_tokens,
|
||||
num_topk,
|
||||
num_experts,
|
||||
topk_ids_dtype=topk_ids_dtype,
|
||||
device="cpu")
|
||||
|
||||
num_global_experts = num_experts
|
||||
assert num_global_experts % ep_size == 0
|
||||
num_local_experts = num_global_experts // ep_size
|
||||
for ep_rank in range(ep_size):
|
||||
tt_rank = tt.with_ep_rank(ep_rank, num_global_experts,
|
||||
num_local_experts, "cpu")
|
||||
|
||||
ref_expert_num_tokens = torch.zeros((num_local_experts),
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
ref_impl(tt_rank, ref_expert_num_tokens)
|
||||
ref_expert_num_tokens = ref_expert_num_tokens.to("cuda")
|
||||
|
||||
tt_rank.to_device("cuda")
|
||||
# Test with expert_map
|
||||
triton_expert_num_tokens_w_emap = count_expert_num_tokens(
|
||||
tt_rank.topk_ids, num_local_experts, tt_rank.expert_map)
|
||||
|
||||
# Test without expert map
|
||||
topk_ids = tt_rank.expert_map[tt_rank.topk_ids].to(topk_ids_dtype)
|
||||
triton_expert_num_tokens_wo_emap = count_expert_num_tokens(
|
||||
topk_ids, num_local_experts, expert_map=None)
|
||||
|
||||
torch.testing.assert_close(ref_expert_num_tokens,
|
||||
triton_expert_num_tokens_w_emap,
|
||||
atol=0,
|
||||
rtol=0)
|
||||
torch.testing.assert_close(ref_expert_num_tokens,
|
||||
triton_expert_num_tokens_wo_emap,
|
||||
atol=0,
|
||||
rtol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_tokens", [1, 4, 8, 11, 19, 128, 127, 405, 1024, 3333, 6666, 7317])
|
||||
@pytest.mark.parametrize("num_topk", [2, 6, 8])
|
||||
@pytest.mark.parametrize("num_experts", [64])
|
||||
@pytest.mark.parametrize("ep_size", [1, 2, 4])
|
||||
@pytest.mark.parametrize("topk_ids_dtype", [torch.int64])
|
||||
def test_compute_expert_num_tokens(num_tokens: int, num_topk: int,
|
||||
num_experts: int, ep_size: int,
|
||||
topk_ids_dtype: torch.dtype):
|
||||
do_test_compute_expert_num_tokens(num_tokens, num_topk, num_experts,
|
||||
ep_size, topk_ids_dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("numel", list(range(1, 8192, 11)))
|
||||
@pytest.mark.parametrize("num_experts", [32])
|
||||
@pytest.mark.parametrize("ep_size", [2])
|
||||
@pytest.mark.parametrize("topk_ids_dtype", [torch.int64])
|
||||
def test_compute_expert_num_tokens_from_numel(numel: int, num_experts: int,
|
||||
ep_size: int,
|
||||
topk_ids_dtype: torch.dtype):
|
||||
do_test_compute_expert_num_tokens(num_tokens=numel,
|
||||
num_topk=1,
|
||||
num_experts=num_experts,
|
||||
ep_size=ep_size,
|
||||
topk_ids_dtype=topk_ids_dtype)
|
||||
tests/lora/test_default_mm_loras.py (new file, 118 lines)
@@ -0,0 +1,118 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Tests for applying default registered multimodal loras.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
from ..conftest import AudioTestAssets, VllmRunner
|
||||
|
||||
MODEL_PATH = snapshot_download("microsoft/Phi-4-multimodal-instruct")
|
||||
AUDIO_LORA_PATH = os.path.join(MODEL_PATH, "speech-lora")
|
||||
IMAGE_LORA_PATH = os.path.join(MODEL_PATH, "vision-lora")
|
||||
|
||||
AUDIO_PROMPT = "<|user|><|audio_1|>Can you transcribe this audio?<|end|><|assistant|>" # noqa: E501
|
||||
|
||||
# Responses are greedy decoded; we just check the end of
|
||||
# the generated text. If the lora is inactive, this model
|
||||
# generates commentary on the transcription.
|
||||
RESPONSE_SUFFIX_WITH_LORA = "Spoken text: The first words I spoke in the original chronograph, a little piece of practical poetry. Mary had a little lamb, it slept with quite a snow, and everywhere that Mary went, the lamb was sure to go." # noqa: E501
|
||||
RESPONSE_SUFFIX_WITHOUT_LORA = "Certainly! Here is the transcription of the audio you provided:\n\nThe first words I spoke in the original phonograph record: A little piece of practical poetry. Mary had a little lamb; its fleece was white as snow, and everywhere that Mary went, the lamb was sure to go." # noqa: E501
|
||||
|
||||
VLLM_RUNNER_BASE_KWARGS = {
|
||||
"model_name": MODEL_PATH,
|
||||
"dtype": "half",
|
||||
"enable_lora": "True",
|
||||
"max_num_seqs": 2,
|
||||
"max_lora_rank": 320,
|
||||
"max_model_len": 12800,
|
||||
"gpu_memory_utilization": 0.8,
|
||||
"limit_mm_per_prompt": {
|
||||
"audio": 1
|
||||
},
|
||||
"enforce_eager": True,
|
||||
}
|
||||
|
||||
|
||||
def run_test(vllm_runner, audio_assets, lora_request, expected_suffix,
|
||||
**kwargs):
|
||||
inputs = [([AUDIO_PROMPT], [audio_assets[0].audio_and_sample_rate[0]])]
|
||||
|
||||
# Apply any additional kwargs as overrides to the base kwargs
|
||||
vllm_runner_kwargs = {**VLLM_RUNNER_BASE_KWARGS, **kwargs}
|
||||
|
||||
with vllm_runner(**vllm_runner_kwargs) as vllm_model:
|
||||
vllm_outputs_with_default_lora = [
|
||||
vllm_model.generate_greedy(
|
||||
prompts,
|
||||
max_tokens=128,
|
||||
audios=audios,
|
||||
lora_request=lora_request,
|
||||
) for prompts, audios in inputs
|
||||
]
|
||||
|
||||
assert vllm_outputs_with_default_lora[-1][-1][-1].endswith(
|
||||
expected_suffix)
|
||||
|
||||
|
||||
def test_active_default_mm_lora(
|
||||
vllm_runner: type[VllmRunner],
|
||||
audio_assets: AudioTestAssets,
|
||||
):
|
||||
"""Ensure that we can use the default audio lora."""
|
||||
run_test(
|
||||
vllm_runner,
|
||||
audio_assets,
|
||||
lora_request=None,
|
||||
default_mm_loras={"audio": AUDIO_LORA_PATH},
|
||||
expected_suffix=RESPONSE_SUFFIX_WITH_LORA,
|
||||
)
|
||||
|
||||
|
||||
def test_inactive_default_mm_lora(
|
||||
vllm_runner: type[VllmRunner],
|
||||
audio_assets: AudioTestAssets,
|
||||
):
|
||||
"""Ensure that modalities are filtered properly."""
|
||||
# Default image lora won't be active since we only pass audio
|
||||
run_test(
|
||||
vllm_runner,
|
||||
audio_assets,
|
||||
lora_request=None,
|
||||
default_mm_loras={"image": IMAGE_LORA_PATH},
|
||||
expected_suffix=RESPONSE_SUFFIX_WITHOUT_LORA,
|
||||
)
|
||||
|
||||
|
||||
def test_default_mm_lora_succeeds_with_redundant_lora_request(
|
||||
vllm_runner: type[VllmRunner],
|
||||
audio_assets: AudioTestAssets,
|
||||
):
|
||||
"""Ensure that redundantly providing the lora works."""
|
||||
run_test(
|
||||
vllm_runner,
|
||||
audio_assets,
|
||||
lora_request=LoRARequest("audio", 1, AUDIO_LORA_PATH),
|
||||
default_mm_loras={"audio": AUDIO_LORA_PATH},
|
||||
expected_suffix=RESPONSE_SUFFIX_WITH_LORA,
|
||||
)
|
||||
|
||||
|
||||
def test_default_mm_lora_fails_with_overridden_lora_request(
|
||||
vllm_runner: type[VllmRunner],
|
||||
audio_assets: AudioTestAssets,
|
||||
):
|
||||
"""Ensure that if the lora_request conflicts with default_mm_loras,
|
||||
we use the lora_request."""
|
||||
run_test(
|
||||
vllm_runner,
|
||||
audio_assets,
|
||||
lora_request=LoRARequest("speech", 2, AUDIO_LORA_PATH),
|
||||
default_mm_loras={"audio": IMAGE_LORA_PATH},
|
||||
expected_suffix=RESPONSE_SUFFIX_WITH_LORA,
|
||||
)
|
||||
@ -46,20 +46,15 @@ def test_guided_logits_processors(zephyr_7B_tokenzer, sample_regex,
|
||||
whitespace_pattern=None,
|
||||
reasoner=None)
|
||||
|
||||
token_ids = zephyr_7B_tokenzer.encode(
|
||||
f"Give an example IPv4 address with this regex: {sample_regex}")
|
||||
tensor = torch.rand(32000)
|
||||
original_tensor = torch.clone(tensor)
|
||||
regex_LP(token_ids, tensor)
|
||||
tensor = regex_LP([], tensor)
|
||||
assert tensor.shape == original_tensor.shape
|
||||
assert not torch.allclose(tensor, original_tensor)
|
||||
|
||||
token_ids = zephyr_7B_tokenzer.encode(
|
||||
f"Give an employee profile that fits this schema: {sample_json_schema}"
|
||||
)
|
||||
tensor = torch.rand(32000)
|
||||
original_tensor = torch.clone(tensor)
|
||||
json_LP(token_ids, tensor)
|
||||
tensor = json_LP([], tensor)
|
||||
assert tensor.shape == original_tensor.shape
|
||||
assert not torch.allclose(tensor, original_tensor)
|
||||
|
||||
@ -81,8 +76,6 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
|
||||
seed=0,
|
||||
dtype="bfloat16",
|
||||
)
|
||||
token_ids = zephyr_7B_tokenzer.encode(
|
||||
f"Give an example IPv4 address with this regex: {sample_regex}")
|
||||
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
|
||||
|
||||
regex_lp = get_local_guided_decoding_logits_processor(
|
||||
@ -92,13 +85,11 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
|
||||
assert regex_lp is not None
|
||||
tensor = torch.rand(32000)
|
||||
original_tensor = torch.clone(tensor)
|
||||
tensor = regex_lp(token_ids, tensor)
|
||||
# allowed tokens at state 0
|
||||
tensor = regex_lp([], tensor)
|
||||
assert tensor.shape == original_tensor.shape
|
||||
assert not torch.allclose(tensor, original_tensor)
|
||||
|
||||
token_ids = zephyr_7B_tokenzer.encode(
|
||||
f"Give an employee profile that fits this schema: {sample_json_schema}"
|
||||
)
|
||||
json_request = GuidedDecodingParams(json=sample_json_schema,
|
||||
backend=backend)
|
||||
json_lp = await get_guided_decoding_logits_processor(
|
||||
@ -106,7 +97,7 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
|
||||
assert json_lp is not None
|
||||
tensor = torch.rand(32000)
|
||||
original_tensor = torch.clone(tensor)
|
||||
tensor = json_lp(token_ids, tensor)
|
||||
tensor = json_lp([], tensor)
|
||||
assert tensor.shape == original_tensor.shape
|
||||
assert not torch.allclose(tensor, original_tensor)
|
||||
|
||||
@ -130,7 +121,6 @@ async def test_guided_logits_processor_with_reasoning(
|
||||
dtype="bfloat16",
|
||||
)
|
||||
token_ids = deepseek_r1_qwen_tokenizer.encode(
|
||||
f"Give an example IPv4 address with this regex: {sample_regex}."
|
||||
"<think>here is the thinking process")
|
||||
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
|
||||
|
||||
@ -141,14 +131,13 @@ async def test_guided_logits_processor_with_reasoning(
|
||||
regex_request, deepseek_r1_qwen_tokenizer, config,
|
||||
reasoning_backend)
|
||||
assert regex_lp is not None
|
||||
tensor = torch.rand(32000)
|
||||
tensor = torch.rand(151664)
|
||||
original_tensor = torch.clone(tensor)
|
||||
tensor = regex_lp(token_ids, tensor)
|
||||
assert tensor.shape == original_tensor.shape
|
||||
assert torch.allclose(tensor, original_tensor)
|
||||
|
||||
token_ids = deepseek_r1_qwen_tokenizer.encode(
|
||||
f"Give an employee profile that fits this schema: {sample_json_schema}."
|
||||
"<think>here is the thinking process")
|
||||
json_request = GuidedDecodingParams(json=sample_json_schema,
|
||||
backend=backend)
|
||||
@ -158,7 +147,7 @@ async def test_guided_logits_processor_with_reasoning(
|
||||
await get_guided_decoding_logits_processor(
|
||||
json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
|
||||
assert json_lp is not None
|
||||
tensor = torch.rand(32000)
|
||||
tensor = torch.rand(151664)
|
||||
original_tensor = torch.clone(tensor)
|
||||
tensor = json_lp(token_ids, tensor)
|
||||
assert tensor.shape == original_tensor.shape
|
||||
@ -166,8 +155,7 @@ async def test_guided_logits_processor_with_reasoning(
|
||||
|
||||
# Thinking is over, so the tensor should change.
|
||||
token_ids = deepseek_r1_qwen_tokenizer.encode(
|
||||
f"Give an employee profile that fits this schema: {sample_json_schema}."
|
||||
"<think>here is the thinking process</think> Then")
|
||||
"<think>here is the thinking process</think>")
|
||||
json_request = GuidedDecodingParams(json=sample_json_schema,
|
||||
backend=backend)
|
||||
json_lp = get_local_guided_decoding_logits_processor(
|
||||
@ -176,7 +164,7 @@ async def test_guided_logits_processor_with_reasoning(
|
||||
await get_guided_decoding_logits_processor(
|
||||
json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
|
||||
assert json_lp is not None
|
||||
tensor = torch.rand(32000)
|
||||
tensor = torch.rand(151664)
|
||||
original_tensor = torch.clone(tensor)
|
||||
tensor = json_lp(token_ids, tensor)
|
||||
assert tensor.shape == original_tensor.shape
|
||||
|
||||
@ -72,7 +72,7 @@ def _compile_and_check(tools: list[ChatCompletionToolsParam], sample_output,
|
||||
assert isinstance(schema, dict)
|
||||
|
||||
# use build_regex_from_schema used in JSONLogitsProcessor to create Guide
|
||||
from outlines_core.fsm.json_schema import build_regex_from_schema
|
||||
from outlines_core.json_schema import build_regex_from_schema
|
||||
regex = build_regex_from_schema(json.dumps(schema))
|
||||
compiled = re.compile(regex)
|
||||
matches = compiled.fullmatch(json.dumps(sample_output)) is not None
|
||||
|
||||
@ -41,6 +41,10 @@ PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
|
||||
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None),
|
||||
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None),
|
||||
("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None),
|
||||
("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", None),
|
||||
("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None),
|
||||
("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto",
|
||||
NGRAM_SPEC_CONFIG),
|
||||
#FIXME: This test is flaky on CI thus disabled
|
||||
#("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"),
|
||||
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto",
|
||||
@ -106,13 +110,15 @@ def test_structured_output(
|
||||
enforce_eager = bool(not current_platform.is_tpu())
|
||||
# Use a single LLM instance for several scenarios to
|
||||
# speed up the test suite.
|
||||
llm = LLM(model=model_name,
|
||||
enforce_eager=enforce_eager,
|
||||
max_model_len=1024,
|
||||
guided_decoding_backend=guided_decoding_backend,
|
||||
guided_decoding_disable_any_whitespace=True,
|
||||
tokenizer_mode=tokenizer_mode,
|
||||
speculative_config=speculative_config)
|
||||
llm = LLM(
|
||||
model=model_name,
|
||||
enforce_eager=enforce_eager,
|
||||
max_model_len=1024,
|
||||
guided_decoding_backend=guided_decoding_backend,
|
||||
guided_decoding_disable_any_whitespace=(guided_decoding_backend
|
||||
in {"xgrammar", "guidance"}),
|
||||
tokenizer_mode=tokenizer_mode,
|
||||
speculative_config=speculative_config)
|
||||
|
||||
#
|
||||
# Test 1: Generate JSON output based on a provided schema
|
||||
@ -146,32 +152,33 @@ def test_structured_output(
|
||||
#
|
||||
# Test 2: Generate JSON object without a schema
|
||||
#
|
||||
sampling_params = SamplingParams(
|
||||
temperature=1.0,
|
||||
max_tokens=4096,
|
||||
n=2,
|
||||
guided_decoding=GuidedDecodingParams(json_object=True))
|
||||
if guided_decoding_backend != "outlines":
|
||||
sampling_params = SamplingParams(
|
||||
temperature=1.0,
|
||||
max_tokens=4096,
|
||||
n=2,
|
||||
guided_decoding=GuidedDecodingParams(json_object=True))
|
||||
|
||||
outputs = llm.generate(
|
||||
prompts=("Generate a JSON object with curly braces for a person with "
|
||||
"name and age fields for John Smith who is 31 years old. "
|
||||
"Make the response as short as possible."),
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True)
|
||||
outputs = llm.generate(prompts=(
|
||||
"Generate a JSON object with curly braces for a person with "
|
||||
"name and age fields for John Smith who is 31 years old. "
|
||||
"Make the response as short as possible."),
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True)
|
||||
|
||||
assert outputs is not None
|
||||
for output in outputs:
|
||||
assert output is not None
|
||||
assert isinstance(output, RequestOutput)
|
||||
assert outputs is not None
|
||||
for output in outputs:
|
||||
assert output is not None
|
||||
assert isinstance(output, RequestOutput)
|
||||
|
||||
for i in range(2):
|
||||
generated_text = output.outputs[i].text
|
||||
print(generated_text)
|
||||
assert generated_text is not None
|
||||
for i in range(2):
|
||||
generated_text = output.outputs[i].text
|
||||
print(generated_text)
|
||||
assert generated_text is not None
|
||||
|
||||
# Parse to verify it is a valid JSON object
|
||||
parsed_json = json.loads(generated_text)
|
||||
assert isinstance(parsed_json, dict)
|
||||
# Parse to verify it is a valid JSON object
|
||||
parsed_json = json.loads(generated_text)
|
||||
assert isinstance(parsed_json, dict)
|
||||
|
||||
#
|
||||
# Test 3: test a jsonschema incompatible with xgrammar
|
||||
@ -210,97 +217,98 @@ def test_structured_output(
|
||||
parsed_json = json.loads(generated_text)
|
||||
assert isinstance(parsed_json, dict)
|
||||
|
||||
#
|
||||
# Test 4: Generate SQL statement using EBNF grammar
|
||||
#
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
max_tokens=1000,
|
||||
guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf))
|
||||
outputs = llm.generate(
|
||||
prompts=(
|
||||
"Generate a sql statement that selects col_1 from "
|
||||
"table_1 where it is equal to 1. Make the response as short as "
|
||||
"possible."),
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True,
|
||||
)
|
||||
|
||||
assert outputs is not None
|
||||
for output in outputs:
|
||||
assert output is not None
|
||||
assert isinstance(output, RequestOutput)
|
||||
prompt = output.prompt
|
||||
|
||||
generated_text = output.outputs[0].text
|
||||
assert generated_text is not None
|
||||
|
||||
# remove spaces for comparison b/c we removed them in the grammar
|
||||
ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
|
||||
" ", "")
|
||||
|
||||
assert generated_text.strip() == ground_truth
|
||||
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
#
|
||||
# Test 5: Generate SQL statement using Lark grammar
|
||||
#
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
max_tokens=1000,
|
||||
guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark))
|
||||
outputs = llm.generate(
|
||||
prompts=(
|
||||
"Generate a sql statement that selects col_1 from "
|
||||
"table_1 where it is equal to 1. Make the response as short as "
|
||||
"possible."),
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True,
|
||||
)
|
||||
|
||||
assert outputs is not None
|
||||
for output in outputs:
|
||||
assert output is not None
|
||||
assert isinstance(output, RequestOutput)
|
||||
prompt = output.prompt
|
||||
|
||||
generated_text = output.outputs[0].text
|
||||
assert generated_text is not None
|
||||
|
||||
# use Lark to parse the output, and make sure it's a valid parse tree
|
||||
from lark import Lark
|
||||
parser = Lark(sample_sql_lark)
|
||||
parser.parse(generated_text)
|
||||
|
||||
# remove spaces for comparison b/c we removed them in the grammar
|
||||
ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
|
||||
" ", "")
|
||||
|
||||
assert generated_text.strip() == ground_truth
|
||||
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
#
|
||||
# Test 6: Test invalid grammar input
|
||||
#
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
max_tokens=1000,
|
||||
guided_decoding=GuidedDecodingParams(grammar="not a grammar"))
|
||||
with pytest.raises(ValueError, match="Failed to convert the grammar "):
|
||||
llm.generate(
|
||||
if guided_decoding_backend != "outlines":
|
||||
#
|
||||
# Test 4: Generate SQL statement using EBNF grammar
|
||||
#
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
max_tokens=1000,
|
||||
guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf))
|
||||
outputs = llm.generate(
|
||||
prompts=(
|
||||
"Generate a sql statement that selects col_1 from "
|
||||
"table_1 where it is equal to 1. Make the response as short "
|
||||
"as possible."),
|
||||
"table_1 where it is equal to 1. Make the response as short as "
|
||||
"possible."),
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True,
|
||||
)
|
||||
|
||||
assert outputs is not None
|
||||
for output in outputs:
|
||||
assert output is not None
|
||||
assert isinstance(output, RequestOutput)
|
||||
prompt = output.prompt
|
||||
|
||||
generated_text = output.outputs[0].text
|
||||
assert generated_text is not None
|
||||
|
||||
# remove spaces for comparison b/c we removed them in the grammar
|
||||
ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
|
||||
" ", "")
|
||||
|
||||
assert generated_text.strip() == ground_truth
|
||||
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
#
|
||||
# Test 5: Generate SQL statement using Lark grammar
|
||||
#
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
max_tokens=1000,
|
||||
guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark))
|
||||
outputs = llm.generate(
|
||||
prompts=(
|
||||
"Generate a sql statement that selects col_1 from "
|
||||
"table_1 where it is equal to 1. Make the response as short as "
|
||||
"possible."),
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True,
|
||||
)
|
||||
|
||||
assert outputs is not None
|
||||
for output in outputs:
|
||||
assert output is not None
|
||||
assert isinstance(output, RequestOutput)
|
||||
prompt = output.prompt
|
||||
|
||||
generated_text = output.outputs[0].text
|
||||
assert generated_text is not None
|
||||
|
||||
# use Lark to parse the output, and make sure it's a valid parse tree
|
||||
from lark import Lark
|
||||
parser = Lark(sample_sql_lark)
|
||||
parser.parse(generated_text)
|
||||
|
||||
# remove spaces for comparison b/c we removed them in the grammar
|
||||
ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
|
||||
" ", "")
|
||||
|
||||
assert generated_text.strip() == ground_truth
|
||||
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
#
|
||||
# Test 6: Test invalid grammar input
|
||||
#
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
max_tokens=1000,
|
||||
guided_decoding=GuidedDecodingParams(grammar="not a grammar"))
|
||||
with pytest.raises(ValueError, match="Failed to convert the grammar "):
|
||||
llm.generate(
|
||||
prompts=
|
||||
("Generate a sql statement that selects col_1 from "
|
||||
"table_1 where it is equal to 1. Make the response as short "
|
||||
"as possible."),
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True,
|
||||
)
|
||||
|
||||
#
|
||||
# Test 7: Generate text based on a regex pattern
|
||||
#
|
||||
@ -421,35 +429,36 @@ def test_structured_output(
|
||||
output_json = json.loads(generated_text)
|
||||
jsonschema.validate(instance=output_json, schema=json_schema)
|
||||
|
||||
#
|
||||
# Test 11: Generate structured output using structural_tag format
|
||||
#
|
||||
structural_tag_config = {
|
||||
"type":
|
||||
"structural_tag",
|
||||
"structures": [{
|
||||
"begin": "<function=get_weather>",
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string"
|
||||
}
|
||||
if guided_decoding_backend != "outlines":
|
||||
#
|
||||
# Test 11: Generate structured output using structural_tag format
|
||||
#
|
||||
structural_tag_config = {
|
||||
"type":
|
||||
"structural_tag",
|
||||
"structures": [{
|
||||
"begin": "<function=get_weather>",
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"additionalProperties": False
|
||||
},
|
||||
"additionalProperties": False
|
||||
},
|
||||
"end": "</function>"
|
||||
}],
|
||||
"triggers": ["<function="]
|
||||
}
|
||||
"end": "</function>"
|
||||
}],
|
||||
"triggers": ["<function="]
|
||||
}
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.0,
|
||||
max_tokens=4096,
|
||||
guided_decoding=GuidedDecodingParams(
|
||||
structural_tag=json.dumps(structural_tag_config)))
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.0,
|
||||
max_tokens=4096,
|
||||
guided_decoding=GuidedDecodingParams(
|
||||
structural_tag=json.dumps(structural_tag_config)))
|
||||
|
||||
prompt = """
|
||||
prompt = """
|
||||
You have access to the following function to retrieve the weather in a city:
|
||||
|
||||
{
|
||||
@ -469,7 +478,7 @@ where
|
||||
|
||||
start_tag => `<function`
|
||||
parameters => a JSON dict with the function argument name
|
||||
as key and function argument value as value.
|
||||
as key and function argument value as value.
|
||||
end_tag => `</function>`
|
||||
|
||||
Here is an example,
|
||||
@ -488,37 +497,37 @@ Given the previous instructions, what is the weather in New York City? \
|
||||
Make the response as short as possible.
|
||||
"""
|
||||
|
||||
# Change this once other backends support structural_tag
|
||||
outputs = llm.generate(prompts=prompt,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True)
|
||||
assert outputs is not None
|
||||
# Change this once other backends support structural_tag
|
||||
outputs = llm.generate(prompts=prompt,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True)
|
||||
assert outputs is not None
|
||||
|
||||
for output in outputs:
|
||||
assert output is not None
|
||||
assert isinstance(output, RequestOutput)
|
||||
generated_text = output.outputs[0].text
|
||||
assert generated_text is not None
|
||||
for output in outputs:
|
||||
assert output is not None
|
||||
assert isinstance(output, RequestOutput)
|
||||
generated_text = output.outputs[0].text
|
||||
assert generated_text is not None
|
||||
|
||||
# Search for function call pattern in the response
|
||||
function_call_pattern = r'<function=get_weather>(.*?)</function>'
|
||||
matches = re.findall(function_call_pattern, generated_text)
|
||||
# Search for function call pattern in the response
|
||||
function_call_pattern = r'<function=get_weather>(.*?)</function>'
|
||||
matches = re.findall(function_call_pattern, generated_text)
|
||||
|
||||
if not matches:
|
||||
print(f"Warning: No function calls found in response: "
|
||||
f"{generated_text!r}")
|
||||
continue
|
||||
if not matches:
|
||||
print(f"Warning: No function calls found in response: "
|
||||
f"{generated_text!r}")
|
||||
continue
|
||||
|
||||
# Take the first function call if multiple are found
|
||||
json_str = matches[0]
|
||||
try:
|
||||
json_content = json.loads(json_str)
|
||||
assert "city" in json_content
|
||||
assert isinstance(json_content["city"], str)
|
||||
print(f"Found valid function call: {generated_text!r}")
|
||||
except (json.JSONDecodeError, AssertionError) as e:
|
||||
pytest.fail("Invalid function call format: "
|
||||
f"{generated_text!r}\nError: {str(e)}")
|
||||
# Take the first function call if multiple are found
|
||||
json_str = matches[0]
|
||||
try:
|
||||
json_content = json.loads(json_str)
|
||||
assert "city" in json_content
|
||||
assert isinstance(json_content["city"], str)
|
||||
print(f"Found valid function call: {generated_text!r}")
|
||||
except (json.JSONDecodeError, AssertionError) as e:
|
||||
pytest.fail("Invalid function call format: "
|
||||
f"{generated_text!r}\nError: {str(e)}")
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
|
||||
@ -33,6 +33,7 @@ import vllm.envs as envs
|
||||
from vllm import version
|
||||
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.config import (
|
||||
ConfigFormat, get_config, get_hf_image_processor_config,
|
||||
@ -2989,6 +2990,16 @@ class LoRAConfig:
|
||||
trained with those scaling factors to be used at the same time. If not
|
||||
specified, only adapters trained with the base model scaling factor are
|
||||
allowed."""
|
||||
default_mm_loras: Optional[dict[str, str]] = None
|
||||
"""Dictionary mapping specific modalities to LoRA model paths; this field
|
||||
is only applicable to multimodal models and should be leveraged when a
|
||||
model always expects a LoRA to be active when a given modality is present.
|
||||
Note that currently, if a request provides multiple additional
|
||||
modalities, each of which have their own LoRA, we do NOT apply
|
||||
default_mm_loras because we currently only support one lora adapter
|
||||
per prompt. When run in offline mode, the lora IDs for n modalities
|
||||
will be automatically assigned to 1-n with the names of the modalities
|
||||
in alphabetic order."""
|
||||
bias_enabled: bool = False
|
||||
"""Enable bias for LoRA adapters."""
|
||||
|
||||
@ -3580,7 +3591,8 @@ def get_served_model_name(model: str,
|
||||
|
||||
GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer",
|
||||
"xgrammar", "guidance"]
|
||||
GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]
|
||||
|
||||
GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance", "outlines"]
|
||||
GuidedDecodingBackend = Literal[GuidedDecodingBackendV0,
|
||||
GuidedDecodingBackendV1]
|
||||
|
||||
|
||||
@ -395,6 +395,8 @@ class EngineArgs:
|
||||
enable_lora_bias: bool = LoRAConfig.bias_enabled
|
||||
max_loras: int = LoRAConfig.max_loras
|
||||
max_lora_rank: int = LoRAConfig.max_lora_rank
|
||||
default_mm_loras: Optional[Dict[str, str]] = \
|
||||
LoRAConfig.default_mm_loras
|
||||
fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
|
||||
max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
|
||||
lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
|
||||
@ -807,6 +809,8 @@ class EngineArgs:
|
||||
**lora_kwargs["max_cpu_loras"])
|
||||
lora_group.add_argument("--fully-sharded-loras",
|
||||
**lora_kwargs["fully_sharded_loras"])
|
||||
lora_group.add_argument("--default-mm-loras",
|
||||
**lora_kwargs["default_mm_loras"])
|
||||
|
||||
# PromptAdapter related configs
|
||||
prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig)
|
||||
@ -1284,10 +1288,16 @@ class EngineArgs:
|
||||
disable_hybrid_kv_cache_manager,
|
||||
)
|
||||
|
||||
if not model_config.is_multimodal_model and self.default_mm_loras:
|
||||
raise ValueError(
|
||||
"Default modality-specific LoRA(s) were provided for a "
|
||||
"non multimodal model")
|
||||
|
||||
lora_config = LoRAConfig(
|
||||
bias_enabled=self.enable_lora_bias,
|
||||
max_lora_rank=self.max_lora_rank,
|
||||
max_loras=self.max_loras,
|
||||
default_mm_loras=self.default_mm_loras,
|
||||
fully_sharded_loras=self.fully_sharded_loras,
|
||||
lora_extra_vocab_size=self.lora_extra_vocab_size,
|
||||
long_lora_scaling_factors=self.long_lora_scaling_factors,
|
||||
|
||||
@ -499,6 +499,10 @@ class LLM:
|
||||
_validate_truncation_size(self.llm_engine.model_config.max_model_len,
|
||||
truncate_prompt_tokens, tokenization_kwargs)
|
||||
|
||||
# Add any modality specific loras to the corresponding prompts
|
||||
lora_request = self._get_modality_specific_lora_reqs(
|
||||
parsed_prompts, lora_request)
|
||||
|
||||
self._validate_and_add_requests(
|
||||
prompts=parsed_prompts,
|
||||
params=sampling_params,
|
||||
@ -513,6 +517,83 @@ class LLM:
|
||||
outputs = self._run_engine(use_tqdm=use_tqdm)
|
||||
return self.engine_class.validate_outputs(outputs, RequestOutput)
|
||||
|
||||
def _get_modality_specific_lora_reqs(
|
||||
self, parsed_prompts: Union[PromptType, Sequence[PromptType]],
|
||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]]):
|
||||
# Grab the lora config off the vllm config on the engine,
|
||||
# since this is the same for both v0 & v1.
|
||||
lora_config = self.llm_engine.vllm_config.lora_config
|
||||
|
||||
# If there's no lora config / default_mm_loras, or the model
|
||||
# isn't multimodal, leave the lora as is.
|
||||
if (lora_config is None
|
||||
or not self.llm_engine.model_config.is_multimodal_model
|
||||
or (lora_config and lora_config.default_mm_loras is None)):
|
||||
return lora_request
|
||||
|
||||
if not isinstance(parsed_prompts, Sequence):
|
||||
parsed_prompts = [parsed_prompts]
|
||||
|
||||
optional_loras = ([lora_request] * len(parsed_prompts)
|
||||
if not isinstance(lora_request, Sequence) else
|
||||
lora_request)
|
||||
|
||||
return [
|
||||
self._resolve_single_prompt_mm_lora(
|
||||
parsed_prompt,
|
||||
opt_lora_req,
|
||||
lora_config.default_mm_loras,
|
||||
) for parsed_prompt, opt_lora_req in zip(parsed_prompts,
|
||||
optional_loras)
|
||||
]
|
||||
|
||||
def _resolve_single_prompt_mm_lora(self, parsed_prompt: PromptType,
|
||||
lora_request: Optional[LoRARequest],
|
||||
default_mm_loras: Optional[dict[str,
|
||||
str]]):
|
||||
if (not default_mm_loras or not isinstance(parsed_prompt, dict)
|
||||
or "multi_modal_data" not in parsed_prompt):
|
||||
return lora_request
|
||||
|
||||
parsed_prompt = cast(Union[TextPrompt, TokensPrompt], parsed_prompt)
|
||||
|
||||
intersection = set(
|
||||
parsed_prompt["multi_modal_data"].keys()).intersection(
|
||||
default_mm_loras.keys())
|
||||
if not intersection:
|
||||
return lora_request
|
||||
if len(intersection) > 1:
|
||||
# TODO: Would be nice to be able to have multiple loras per prompt
|
||||
logger.warning(
|
||||
"Multiple modality specific loras were registered and would be"
|
||||
" used by a single prompt consuming several modalities; "
|
||||
" currently we only support one lora per request; as such,"
|
||||
" lora(s) registered with modalities: %s"
|
||||
" will be skipped", intersection)
|
||||
return lora_request
|
||||
|
||||
# Build the LoRA request; the ID of the default mm lora is the
|
||||
# index of the modality name sorted alphabetically + 1.
|
||||
modality_name = intersection.pop()
|
||||
modality_lora_path = default_mm_loras[modality_name]
|
||||
modality_lora_id = sorted(default_mm_loras).index(modality_name) + 1
|
||||
|
||||
# If we have a collision, warn if there is a collision,
|
||||
# but always send the explicitly provided request.
|
||||
if lora_request:
|
||||
if lora_request.lora_int_id != modality_lora_id:
|
||||
logger.warning(
|
||||
"A modality with a registered lora and a lora_request "
|
||||
"with a different ID were provided; falling back to the "
|
||||
"lora_request as we only apply one LoRARequest per prompt")
|
||||
return lora_request
|
||||
|
||||
return LoRARequest(
|
||||
modality_name,
|
||||
modality_lora_id,
|
||||
modality_lora_path,
|
||||
)
|
||||
|
||||
def collective_rpc(self,
|
||||
method: Union[str, Callable[..., _R]],
|
||||
timeout: Optional[float] = None,
|
||||
|
||||
@ -87,6 +87,7 @@ from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
||||
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
|
||||
LoRAModulePath,
|
||||
OpenAIServingModels)
|
||||
from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
|
||||
from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses
|
||||
@ -1481,11 +1482,28 @@ async def init_app_state(
|
||||
"This discrepancy may lead to performance degradation.",
|
||||
resolved_chat_template, args.model)
|
||||
|
||||
# Merge default_mm_loras into the static lora_modules
|
||||
default_mm_loras = (vllm_config.lora_config.default_mm_loras
|
||||
if vllm_config.lora_config is not None else {})
|
||||
|
||||
lora_modules = args.lora_modules
|
||||
if default_mm_loras:
|
||||
default_mm_lora_paths = [
|
||||
LoRAModulePath(
|
||||
name=modality,
|
||||
path=lora_path,
|
||||
) for modality, lora_path in default_mm_loras.items()
|
||||
]
|
||||
if args.lora_modules is None:
|
||||
lora_modules = default_mm_lora_paths
|
||||
else:
|
||||
lora_modules += default_mm_lora_paths
|
||||
|
||||
state.openai_serving_models = OpenAIServingModels(
|
||||
engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
base_model_paths=base_model_paths,
|
||||
lora_modules=args.lora_modules,
|
||||
lora_modules=lora_modules,
|
||||
prompt_adapters=args.prompt_adapters,
|
||||
)
|
||||
await state.openai_serving_models.init_static_loras()
|
||||
|
||||
@ -153,7 +153,8 @@ class OpenAIServingChat(OpenAIServing):
|
||||
(
|
||||
lora_request,
|
||||
prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(request)
|
||||
) = self._maybe_get_adapters(request,
|
||||
supports_default_mm_loras=True)
|
||||
|
||||
model_name = self._get_model_name(request.model, lora_request)
|
||||
|
||||
|
||||
@ -458,20 +458,74 @@ class OpenAIServing:
|
||||
err_type="NotFoundError",
|
||||
status_code=HTTPStatus.NOT_FOUND)
|
||||
|
||||
def _get_active_default_mm_loras(
|
||||
self, request: AnyRequest) -> Optional[LoRARequest]:
|
||||
"""Determine if there are any active default multimodal loras."""
|
||||
# TODO: Currently this is only enabled for chat completions
|
||||
# to be better aligned with only being enabled for .generate
|
||||
# when run offline. It would be nice to support additional
|
||||
# tasks types in the future.
|
||||
message_types = self._get_message_types(request)
|
||||
default_mm_loras = set()
|
||||
|
||||
for lora in self.models.lora_requests.values():
|
||||
# Best effort match for default multimodal lora adapters;
|
||||
# There is probably a better way to do this, but currently
|
||||
# this matches against the set of 'types' in any content lists
|
||||
# up until '_', e.g., to match audio_url -> audio
|
||||
if lora.lora_name in message_types:
|
||||
default_mm_loras.add(lora)
|
||||
|
||||
# Currently only support default modality specific loras if
|
||||
# we have exactly one lora matched on the request.
|
||||
if len(default_mm_loras) == 1:
|
||||
return default_mm_loras.pop()
|
||||
return None
|
||||
|
||||
def _maybe_get_adapters(
|
||||
self, request: AnyRequest
|
||||
self,
|
||||
request: AnyRequest,
|
||||
supports_default_mm_loras: bool = False,
|
||||
) -> Union[tuple[None, None], tuple[LoRARequest, None], tuple[
|
||||
None, PromptAdapterRequest]]:
|
||||
if self._is_model_supported(request.model):
|
||||
return None, None
|
||||
|
||||
if request.model in self.models.lora_requests:
|
||||
return self.models.lora_requests[request.model], None
|
||||
|
||||
# Currently only support default modality specific loras
|
||||
# if we have exactly one lora matched on the request.
|
||||
if supports_default_mm_loras:
|
||||
default_mm_lora = self._get_active_default_mm_loras(request)
|
||||
if default_mm_lora is not None:
|
||||
return default_mm_lora, None
|
||||
|
||||
if self._is_model_supported(request.model):
|
||||
return None, None
|
||||
|
||||
for prompt_adapter in self.models.prompt_adapter_requests:
|
||||
if request.model == prompt_adapter.prompt_adapter_name:
|
||||
return None, prompt_adapter
|
||||
# if _check_model has been called earlier, this will be unreachable
|
||||
raise ValueError(f"The model `{request.model}` does not exist.")
|
||||
|
||||
def _get_message_types(self, request: AnyRequest) -> set[str]:
|
||||
"""Retrieve the set of types from message content dicts up
|
||||
until `_`; we use this to match potential multimodal data
|
||||
with default per modality loras.
|
||||
"""
|
||||
message_types: set[str] = set()
|
||||
|
||||
if not hasattr(request, "messages"):
|
||||
return message_types
|
||||
|
||||
for message in request.messages:
|
||||
if (isinstance(message, dict) and "content" in message
|
||||
and isinstance(message["content"], list)):
|
||||
for content_dict in message["content"]:
|
||||
if "type" in content_dict:
|
||||
message_types.add(content_dict["type"].split("_")[0])
|
||||
return message_types
|
||||
|
||||
async def _normalize_prompt_text_to_input(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
|
||||
@@ -216,8 +216,8 @@ class ServingScores(OpenAIServing):
            # cross_encoder models defaults to using pad_token.
            tokenized_prompts = await asyncio.gather(*(
                tokenize_async(
                    text=t1,  # type: ignore[arg-type]
                    text_pair=t2,  # type: ignore[arg-type]
                    **tokenization_kwargs) for t1, t2 in input_pairs))
        else:
            # `llm as reranker` models defaults to not using pad_token.
11 vllm/envs.py

@@ -117,6 +117,7 @@ if TYPE_CHECKING:
    VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False
    VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
    VLLM_V0_USE_OUTLINES_CACHE: bool = False
    VLLM_V1_USE_OUTLINES_CACHE: bool = False
    VLLM_TPU_BUCKET_PADDING_GAP: int = 0
    VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
    VLLM_USE_DEEP_GEMM: bool = False
@@ -338,10 +339,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
    "CUDA_VISIBLE_DEVICES":
    lambda: os.environ.get("CUDA_VISIBLE_DEVICES", None),

    # used to control the visible devices in the distributed setting
    "VLLM_VISIBLE_DEVICES":
    lambda: os.environ.get("VLLM_VISIBLE_DEVICES", None),

    # timeout for each iteration in the engine
    "VLLM_ENGINE_ITERATION_TIMEOUT_S":
    lambda: int(os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60")),
@@ -851,6 +848,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
    "VLLM_V0_USE_OUTLINES_CACHE":
    lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1",

    # Whether to turn on the outlines cache for V1
    # This cache is unbounded and on disk, so it's not safe to use in
    # an environment with potentially malicious users.
    "VLLM_V1_USE_OUTLINES_CACHE":
    lambda: os.environ.get("VLLM_V1_USE_OUTLINES_CACHE", "0") == "1",

    # Gap between padding buckets for the forward pass. So we have
    # 8, we will run forward pass with [16, 24, 32, ...].
    "VLLM_TPU_BUCKET_PADDING_GAP":
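If the new flag is useful in a deployment, it is set like any other vLLM environment variable before the engine is created; a quick sketch (the variable name comes from the diff above, the trust caveat from its comment):

```python
import os

# Opt in to the unbounded on-disk outlines cache for the V1 engine.
# Per the comment above, only do this when all clients are trusted.
os.environ["VLLM_V1_USE_OUTLINES_CACHE"] = "1"

from vllm import LLM  # create the engine only after setting the variable
```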
@@ -79,20 +79,33 @@ def maybe_backend_fallback(
        fallback_or_error(
            guided_params,
            "xgrammar does not support Lark grammars and the "
            "grammar failed to convert to GBNF.", "guidance")

    # If the xgrammar module cannot be imported successfully,
    # we should still allow users to use guided decoding with a fallback.
    elif not xgr_installed:
        fallback_or_error(
            guided_params,
            "xgrammar module cannot be imported successfully.", "guidance")

    if guided_params.backend == "outlines":
        if guided_params.json_object is not None:
            # outlines doesn't support json_object, fallback to guidance
            fallback_or_error(guided_params,
                              "outlines does not support json_object.",
                              "guidance")
        elif guided_params.grammar is not None:
            # outlines grammar support has been removed, fallback to guidance
            # if it is a lark-based grammar and xgrammar otherwise
            if grammar_is_likely_lark(guided_params.grammar):
                fallback_or_error(guided_params,
                                  "outlines no longer supports grammars.",
                                  "guidance")
            else:
                # The grammar is likely already in GBNF format.
                fallback_or_error(guided_params,
                                  "outlines no longer supports grammars.",
                                  "xgrammar")

    return guided_params

@@ -111,7 +124,6 @@ async def get_guided_decoding_logits_processor(

    guided_params = maybe_backend_fallback(guided_params)

    if guided_params.backend == 'outlines':
        # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
        from vllm.model_executor.guided_decoding.outlines_decoding import (  # noqa
@@ -152,7 +164,6 @@ def get_local_guided_decoding_logits_processor(
        reasoning_backend)
    reasoner = reasoner_class(tokenizer)

    if guided_params.backend == 'outlines':
        # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
        from vllm.model_executor.guided_decoding.outlines_decoding import (  # noqa

@@ -12,7 +12,7 @@ from regex import escape as regex_escape
from transformers import PreTrainedTokenizerBase

from vllm.model_executor.guided_decoding.outlines_logits_processors import (
    JSONLogitsProcessor, RegexLogitsProcessor)
from vllm.reasoning import ReasoningParser
from vllm.sampling_params import GuidedDecodingParams
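As a usage-level illustration of the new fallback behaviour: a request that still names the `outlines` backend together with a grammar should be transparently redirected (`guidance` for Lark-style grammars, `xgrammar` for GBNF). A rough sketch, with the parameter names assumed from `vllm.sampling_params`:

```python
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

# A Lark-style grammar on the outlines backend should now be redirected to
# "guidance"; a GBNF-style grammar should instead go to "xgrammar".
lark_grammar = '?start: "yes" | "no"'

params = SamplingParams(
    guided_decoding=GuidedDecodingParams(grammar=lark_grammar,
                                         backend="outlines"),
    temperature=0.0,
)
```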
@ -21,36 +21,8 @@ class GuidedDecodingMode(Enum):
|
||||
JSON = "json"
|
||||
REGEX = "regex"
|
||||
CHOICE = "choice"
|
||||
GRAMMAR = "grammar"
|
||||
|
||||
|
||||
# https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/json.lark
|
||||
# the main difference is that we changed the start: value to
|
||||
# start: object | array, so we are denying scalar values as the root of the
|
||||
# JSON. Starting with scalars as the root seems to cause llama to generate
|
||||
# without stop.
|
||||
JSON_GRAMMAR = r"""
|
||||
?start: object | array
|
||||
|
||||
?value: object
|
||||
| array
|
||||
| UNESCAPED_STRING
|
||||
| SIGNED_NUMBER -> number
|
||||
| "true" -> true
|
||||
| "false" -> false
|
||||
| "null" -> null
|
||||
|
||||
array : "[" [value ("," value)*] "]"
|
||||
object : "{" [pair ("," pair)*] "}"
|
||||
pair : UNESCAPED_STRING ":" value
|
||||
|
||||
%import common.UNESCAPED_STRING
|
||||
%import common.SIGNED_NUMBER
|
||||
%import common.WS
|
||||
|
||||
%ignore WS
|
||||
"""
|
||||
|
||||
global_thread_pool = None # used for generating logits processor fsm
|
||||
|
||||
# It's not yet clear that using more provides a benefit, and it could
|
||||
@ -60,16 +32,12 @@ _MAX_THREADPOOL_WORKERS = 16
|
||||
|
||||
|
||||
async def get_outlines_guided_decoding_logits_processor(
|
||||
guided_params: GuidedDecodingParams,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
reasoner: Optional[ReasoningParser],
|
||||
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
|
||||
None]:
|
||||
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase,
|
||||
reasoner: Optional[ReasoningParser]
|
||||
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]:
|
||||
"""
|
||||
Given an OpenAI-compatible request, check for guided decoding parameters
|
||||
and get the necessary logits processor for the given guide.
|
||||
We cache logit processors by (guide, tokenizer), and on cache hit
|
||||
we make a shallow copy to reuse the same underlying FSM.
|
||||
"""
|
||||
global global_thread_pool
|
||||
guide, mode = _get_guide_and_mode(guided_params)
|
||||
@ -83,7 +51,6 @@ async def get_outlines_guided_decoding_logits_processor(
|
||||
global_thread_pool = concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=max_workers)
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
return await loop.run_in_executor(global_thread_pool,
|
||||
_get_logits_processor, guide, tokenizer,
|
||||
mode, guided_params.whitespace_pattern,
|
||||
@ -91,16 +58,12 @@ async def get_outlines_guided_decoding_logits_processor(
|
||||
|
||||
|
||||
def get_local_outlines_guided_decoding_logits_processor(
|
||||
guided_params: GuidedDecodingParams,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
reasoner: Optional[ReasoningParser],
|
||||
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
|
||||
None]:
|
||||
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase,
|
||||
reasoner: Optional[ReasoningParser]
|
||||
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]:
|
||||
"""
|
||||
Given an OpenAI-compatible request, check for guided decoding parameters
|
||||
and get the necessary logits processor for the given guide.
|
||||
We cache logit processors by (guide, tokenizer), and on cache hit
|
||||
we make a shallow copy to reuse the same underlying FSM.
|
||||
"""
|
||||
guide, mode = _get_guide_and_mode(guided_params)
|
||||
if not guide or not mode:
|
||||
@ -130,9 +93,10 @@ def _get_guide_and_mode(
|
||||
choices_regex = "(" + "|".join(choices) + ")"
|
||||
return choices_regex, GuidedDecodingMode.CHOICE
|
||||
elif guided_params.grammar:
|
||||
return guided_params.grammar, GuidedDecodingMode.GRAMMAR
|
||||
elif guided_params.json_object:
|
||||
return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
|
||||
raise ValueError(
|
||||
"The `outlines` guided decoding backend no longer supports grammar "
|
||||
"guided generation. Please use either the `xgrammar` or `guidance` "
|
||||
"backend")
|
||||
else:
|
||||
return None, None
|
||||
|
||||
@ -143,13 +107,11 @@ def _get_logits_processor(
|
||||
mode: GuidedDecodingMode,
|
||||
whitespace_pattern: Union[str, None],
|
||||
reasoner: Optional[ReasoningParser],
|
||||
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]:
|
||||
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]:
|
||||
if mode == GuidedDecodingMode.JSON:
|
||||
return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern,
|
||||
reasoner)
|
||||
elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
|
||||
return RegexLogitsProcessor(guide, tokenizer, reasoner)
|
||||
elif mode == GuidedDecodingMode.GRAMMAR:
|
||||
return CFGLogitsProcessor(guide, tokenizer, reasoner)
|
||||
else:
|
||||
raise ValueError(f"Unknown guided decoding mode {mode}")
|
||||
|
||||
@ -1,168 +1,124 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Copyright 2024-present the Outlines developers
|
||||
from __future__ import annotations
|
||||
|
||||
# Copyright 2024- the Outlines developers
|
||||
# This file is adapted from
|
||||
# https://github.com/outlines-dev/outlines/blob/main/outlines/serve/vllm.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import copy
|
||||
import hashlib
|
||||
import importlib.metadata
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from functools import lru_cache
|
||||
from typing import Callable, Optional, Union
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import regex as re
|
||||
import torch
|
||||
from outlines import grammars
|
||||
from outlines.caching import cache, disable_cache
|
||||
from outlines.fsm.guide import (CFGGuide, CFGState, Generate, Guide,
|
||||
RegexGuide, Write)
|
||||
from outlines.fsm.parsing import PartialLark
|
||||
from outlines_core.fsm.json_schema import build_regex_from_schema
|
||||
from cachetools import LRUCache
|
||||
from diskcache import Cache
|
||||
from outlines_core import Guide, Index, Vocabulary
|
||||
from outlines_core.json_schema import build_regex_from_schema
|
||||
from outlines_core.kernels.torch import (_apply_token_bitmask_inplace_kernel,
|
||||
allocate_token_bitmask)
|
||||
from pydantic import BaseModel
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from transformers.file_utils import SPIECE_UNDERLINE
|
||||
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.reasoning import ReasoningParser
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if envs.VLLM_V0_USE_OUTLINES_CACHE:
|
||||
logger.warning("Enabling outlines cache. This is an unbounded on-disk "
|
||||
"cache. It may consume a lot of disk space and should "
|
||||
"not be used with untrusted clients.")
|
||||
else:
|
||||
disable_cache()
|
||||
CACHE = None
|
||||
|
||||
|
||||
class BaseLogitsProcessor:
|
||||
|
||||
def __init__(self, guide: Guide, reasoner: Optional[ReasoningParser]):
|
||||
def __init__(self, guide: Guide, eos_token_id: int,
|
||||
reasoner: Optional[ReasoningParser]) -> None:
|
||||
self._guide: Guide = guide
|
||||
self._eos_token_id: int = eos_token_id
|
||||
self._reasoner: Optional[ReasoningParser] = reasoner
|
||||
# CFGState is used for the FSM state for CFGGuide
|
||||
self._fsm_state: defaultdict[int, Union[int,
|
||||
CFGState]] = defaultdict(int)
|
||||
|
||||
def clone(self) -> "BaseLogitsProcessor":
|
||||
cloned = copy.copy(self)
|
||||
cloned._guide = self._guide.copy()
|
||||
cloned._fsm_state = copy.deepcopy(self._fsm_state)
|
||||
return cloned
|
||||
self._mask: Optional[torch.Tensor] = None
|
||||
|
||||
def __call__(self, input_ids: list[int],
|
||||
scores: torch.Tensor) -> torch.Tensor:
|
||||
"""Use the FSM to bias the logits before sampling the next token."""
|
||||
if self._mask is None:
|
||||
self._mask = allocate_token_bitmask(scores.size(-1))
|
||||
|
||||
# Skip the structured logits processing if reasoning is not finished.
|
||||
# reasoner is not None only when `--reasoning-parser` is set.
|
||||
if self._reasoner is not None:
|
||||
if not self._reasoner.is_reasoning_end(input_ids):
|
||||
return scores
|
||||
else:
|
||||
# Remove the reasoning tokens from the input_ids
|
||||
# We need this because our implementation relies on the
|
||||
# hash of the input_ids to store the FSM state.
|
||||
input_ids = self._reasoner.extract_content_ids(input_ids)
|
||||
if self._reasoner is not None and not self._reasoner.is_reasoning_end(
|
||||
input_ids):
|
||||
return scores
|
||||
|
||||
seq_id = hash(tuple(input_ids))
|
||||
# Remove the reasoning tokens from the input_ids
|
||||
# We need this because our implementation relies on the
|
||||
# input_ids sequence to store the FSM state.
|
||||
input_ids = (self._reasoner.extract_content_ids(input_ids)
|
||||
if self._reasoner is not None else input_ids)
|
||||
|
||||
if len(input_ids) > 0:
|
||||
last_token = input_ids[-1]
|
||||
last_seq_id = hash(tuple(input_ids[:-1]))
|
||||
self._fsm_state[seq_id] = self._guide.get_next_state(
|
||||
state=self._fsm_state[last_seq_id], token_id=last_token)
|
||||
else:
|
||||
# Note: this is a hack.
|
||||
# Lark pickling does not work properly (silent failure),
|
||||
# which breaks the RPC (which uses python pickleing).
|
||||
# We need to find a better solution.
|
||||
# On the first time this is called, we simply re-create
|
||||
# the Lark object.
|
||||
if isinstance(self._guide, CFGGuide):
|
||||
self._guide.parser = PartialLark(
|
||||
self._guide.cfg_string,
|
||||
parser="lalr",
|
||||
import_paths=[grammars.GRAMMAR_PATH],
|
||||
)
|
||||
self._fsm_state[seq_id] = CFGState(
|
||||
parser_state=self._guide.parser.parse(""), prev_token=None)
|
||||
# Vllm V0 engine has a weird bug where we have to repeat
|
||||
# the eos token id twice for generation to stop, or at least
|
||||
# that is what we have to do from here in any case.
|
||||
# This is a patch until a better solution can be pushed
|
||||
# to outlines_core
|
||||
if input_ids and input_ids[-1] != self._eos_token_id:
|
||||
self._guide.advance(token_id=input_ids[-1], return_tokens=False)
|
||||
|
||||
instruction = self._guide.get_next_instruction(
|
||||
state=self._fsm_state[seq_id])
|
||||
self._guide.write_mask_into(
|
||||
data_ptr=self._mask.data_ptr(),
|
||||
numel=self._mask.numel(),
|
||||
element_size=self._mask.element_size(),
|
||||
)
|
||||
|
||||
if type(instruction) == Generate: # noqa: E721
|
||||
allowed_tokens = instruction.tokens
|
||||
elif type(instruction) == Write: # noqa: E721
|
||||
# TODO: support fast forward tokens
|
||||
allowed_tokens = [instruction.tokens[0]]
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Unsupported instruction type {type(instruction)}")
|
||||
# Any allowed tokens beyond the length of the scores will
|
||||
# be ignored by the kernel, taking care of the issue with
|
||||
# models such as Llama 3.2 Vision with an `<|image|>` token
|
||||
# with id 128256, but scores.shape == torch.Size([128256])
|
||||
_apply_token_bitmask_inplace_kernel(
|
||||
logits=scores.unsqueeze(dim=0),
|
||||
# mask must be on same device
|
||||
mask=self._mask.to(scores.device, non_blocking=True))
|
||||
self._mask.to("cpu", non_blocking=True)
|
||||
|
||||
mask = torch.full((scores.shape[-1], ),
|
||||
-torch.inf,
|
||||
device=scores.device)
|
||||
# The tokenizer may support more token ids than the model can generate,
|
||||
# eg. Llama 3.2 Vision models have an `<|image|>` token with id 128256
|
||||
# but scores.shape == torch.Size([128256])
|
||||
# Using NumPy is faster for filtering token ids
|
||||
allowed_tokens = np.array(allowed_tokens, dtype=np.int64)
|
||||
allowed_tokens = torch.tensor(allowed_tokens, device=scores.device)
|
||||
allowed_tokens = allowed_tokens.masked_select(
|
||||
allowed_tokens < scores.shape[-1])
|
||||
mask.index_fill_(0, allowed_tokens, 0)
|
||||
if current_platform.is_hpu():
|
||||
# Workaround for HPU bug where add_() raise RuntimeError:
|
||||
# synNodeCreateWithId failed for node: strided_insert
|
||||
# with synStatus 1 [Invalid argument], hopefully it will
|
||||
# be fixed in the future releases of the HPU runtime.
|
||||
scores = scores.add(mask)
|
||||
else:
|
||||
scores.add_(mask)
|
||||
return scores
|
||||
|
||||
def clone(self) -> BaseLogitsProcessor:
|
||||
guide = copy.deepcopy(self._guide)
|
||||
guide.reset()
|
||||
return BaseLogitsProcessor(guide=guide,
|
||||
eos_token_id=self._eos_token_id,
|
||||
reasoner=self._reasoner)
|
||||
|
||||
|
||||
class RegexLogitsProcessor(BaseLogitsProcessor):
|
||||
|
||||
@classmethod
|
||||
@cache()
|
||||
def _get_guide(cls, regex_string: str,
|
||||
tokenizer: PreTrainedTokenizerBase) -> Guide:
|
||||
tokenizer = _adapt_tokenizer(tokenizer)
|
||||
return RegexGuide.from_regex(regex_string, tokenizer)
|
||||
global CACHE
|
||||
if CACHE is None:
|
||||
CACHE = get_cache()
|
||||
vocabulary = get_vocabulary(tokenizer) # type: ignore[arg-type]
|
||||
cache_key = f"{vocabulary._hash}_{regex_string}"
|
||||
if CACHE is not None and cache_key in CACHE:
|
||||
return Guide(CACHE[cache_key])
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
regex_string: str,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
reasoner: Optional[ReasoningParser],
|
||||
):
|
||||
"""Compile the FSM that drives the regex-structured generation.
|
||||
index = Index(regex_string, vocabulary.inner)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
regex_string
|
||||
A string that represents a regular expression
|
||||
tokenizer
|
||||
The model's tokenizer
|
||||
if CACHE is not None:
|
||||
CACHE[cache_key] = index
|
||||
|
||||
"""
|
||||
return Guide(index)
|
||||
|
||||
def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase,
|
||||
reasoner: Optional[ReasoningParser]) -> None:
|
||||
super().__init__(
|
||||
RegexLogitsProcessor._get_guide(regex_string, tokenizer), reasoner)
|
||||
guide=RegexLogitsProcessor._get_guide(regex_string, tokenizer),
|
||||
eos_token_id=tokenizer.eos_token_id, # type: ignore
|
||||
reasoner=reasoner)
|
||||
|
||||
|
||||
class JSONLogitsProcessor(RegexLogitsProcessor):
|
||||
@ -170,22 +126,8 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
|
||||
def __init__(self, schema: Union[str, dict, BaseModel],
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
whitespace_pattern: Union[str, None],
|
||||
reasoner: Optional[ReasoningParser]):
|
||||
"""Compile the FSM that drives the JSON-guided generation.
|
||||
reasoner: Optional[ReasoningParser]) -> None:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
schema
|
||||
A JSON schema that encodes the structure we want the model to
|
||||
generate
|
||||
tokenizer
|
||||
The model's tokenizer
|
||||
whitespace_pattern
|
||||
Pattern to use for JSON syntactic whitespace (doesn't impact
|
||||
string literals)
|
||||
Example: allow only a single space or newline with
|
||||
`whitespace_pattern=r"[\n ]?"`
|
||||
"""
|
||||
if isinstance(schema, type(BaseModel)):
|
||||
schema_str = json.dumps(schema.model_json_schema())
|
||||
elif isinstance(schema, dict):
|
||||
@ -197,63 +139,42 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
|
||||
f"Cannot parse schema {schema}. The schema must be either "
|
||||
f"a Pydantic object, a dictionary or a string that contains "
|
||||
f"the JSON Schema specification")
|
||||
|
||||
regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
|
||||
super().__init__(regex_string, tokenizer, reasoner)
|
||||
|
||||
|
||||
class CFGLogitsProcessor(BaseLogitsProcessor):
|
||||
|
||||
@classmethod
|
||||
@cache()
|
||||
def _get_guide(cls, cfg: str, tokenizer: PreTrainedTokenizerBase) -> Guide:
|
||||
tokenizer = _adapt_tokenizer(tokenizer)
|
||||
return CFGGuide(cfg, tokenizer)
|
||||
|
||||
def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase,
|
||||
reasoner: Optional[ReasoningParser]):
|
||||
"""Compile the FSM that drives the context free grammar generation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
cfg
|
||||
A string that represents a context-free grammar
|
||||
tokenizer
|
||||
The model's tokenizer
|
||||
|
||||
"""
|
||||
super().__init__(CFGLogitsProcessor._get_guide(cfg, tokenizer),
|
||||
reasoner)
|
||||
self._guide = self._guide.copy()
|
||||
|
||||
def clone(self) -> "CFGLogitsProcessor":
|
||||
cloned = copy.copy(self)
|
||||
cloned._fsm_state = copy.deepcopy(self._fsm_state)
|
||||
cloned._guide = self._guide.copy()
|
||||
return cloned
|
||||
|
||||
|
||||
@lru_cache(maxsize=32)
|
||||
def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
|
||||
"""Adapt vLLM's tokenizer to use to compile the FSM.
|
||||
|
||||
The API of Outlines tokenizers is slightly different to that of
|
||||
`transformers`. The decoder of outlines, returns a list whereas
|
||||
the decode of vLLM returns an str. To sync the vLLM decoder with
|
||||
outlines internal api, the decoder should be adapted. In addition
|
||||
we need to handle the missing spaces to Llama's tokenizer to be
|
||||
able to compile FSMs for this model.
|
||||
|
||||
class OutlinesVocabulary:
|
||||
"""
|
||||
Wrapper class for `outlines_core.Vocabulary`,
|
||||
which allows us to store a hash with the vocabulary
|
||||
"""
|
||||
if getattr(tokenizer, "_outlines_adapted", False):
|
||||
return tokenizer
|
||||
|
||||
tokenizer = copy.deepcopy(tokenizer)
|
||||
def __init__(self, vocabulary: Vocabulary) -> None:
|
||||
# Actual vocabulary object
|
||||
self.inner = vocabulary
|
||||
# Have to do abs(hash()) because python hashes can
|
||||
# be negative, and we are using hash as a cache key.
|
||||
hex_str = hashlib.sha256(
|
||||
vocabulary.__repr__().encode('utf-8')).hexdigest()
|
||||
hash_int = int(hex_str, 16)
|
||||
self._hash = hash_int
|
||||
|
||||
tokenizer.vocabulary = tokenizer.get_vocab()
|
||||
tokenizer.special_tokens = set(tokenizer.all_special_tokens)
|
||||
|
||||
re_llama_byte_token = re.compile(r"^<0x[0-9A-F]{2}>$")
|
||||
re_replacement_seq = re.compile(r"^.{0,6}<7D>+.{0,6}$")
|
||||
|
||||
|
||||
def _reduced_vocabulary(tokenizer: AnyTokenizer,
|
||||
eos_token_id: int) -> dict[bytes, list[int]]:
|
||||
"""Create a map from vocabulary tokens to lists of equivalent token ids.
|
||||
|
||||
Returns:
|
||||
A Dict of token string -> equivalent token ids
|
||||
"""
|
||||
unicode_to_bytes = {v: k for k, v in bytes_to_unicode().items()}
|
||||
|
||||
def convert_token_to_string(token: str) -> str:
|
||||
from transformers.file_utils import SPIECE_UNDERLINE
|
||||
|
||||
string = tokenizer.convert_tokens_to_string([token])
|
||||
|
||||
@ -264,21 +185,123 @@ def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
|
||||
|
||||
return string
|
||||
|
||||
def change_decoder(
|
||||
decoder: Callable[[list[int]],
|
||||
str]) -> Callable[[list[int]], list[str]]:
|
||||
"""Sync vLLM's decoder with the outlines by returning list."""
|
||||
vocabulary: dict[bytes, list[int]] = {}
|
||||
empty_token_ids: list[int] = []
|
||||
for token, token_idx in tokenizer.get_vocab().items():
|
||||
if token in tokenizer.all_special_tokens: # type: ignore
|
||||
continue
|
||||
|
||||
def new_decoder(inp_tokens: list[int]) -> list[str]:
|
||||
if (isinstance(inp_tokens, list) and len(inp_tokens) == 1
|
||||
and isinstance(inp_tokens[0], list)):
|
||||
inp_tokens = inp_tokens[0]
|
||||
return [decoder(inp_tokens)]
|
||||
token_str = convert_token_to_string(token)
|
||||
if token_str:
|
||||
if isinstance(token, (bytes, bytearray)):
|
||||
# For BPE tokenizers where tokens are stored as bytes.
|
||||
|
||||
return new_decoder
|
||||
# safe to ignore since token_str is of type (bytearray, bytes)
|
||||
# by this point.
|
||||
token_bytes = bytes(token_str) # type: ignore[arg-type]
|
||||
|
||||
tokenizer.convert_token_to_string = convert_token_to_string
|
||||
tokenizer.decode = change_decoder(tokenizer.decode)
|
||||
setattr(tokenizer, "_outlines_adapted", True) # noqa: B010
|
||||
elif "\ufffd" in token_str and not re_replacement_seq.match(
|
||||
token_str):
|
||||
# Handle tokens with invalid UTF-8 sequences.
|
||||
if re_llama_byte_token.match(token):
|
||||
# Llama-like tokenizers use <0xXX> for incomplete sequences.
|
||||
token_bytes = bytes([int(token[3:5], 16)])
|
||||
else:
|
||||
# GPT2 tokenizers: map each byte back using unicode_to_bytes
|
||||
byte_vals = [unicode_to_bytes.get(c) for c in token]
|
||||
if None in byte_vals:
|
||||
raise RuntimeError(
|
||||
f"Cannot convert token `{token}`"
|
||||
f" ({token_idx}) to bytes: {token_str}")
|
||||
# safe to ignore, since if None in byte_vals,
|
||||
# an error is thrown.
|
||||
token_bytes = bytes(byte_vals) # type: ignore[arg-type]
|
||||
else:
|
||||
token_bytes = token_str.encode('utf-8')
|
||||
|
||||
return tokenizer
|
||||
if token_idx != eos_token_id:
|
||||
vocabulary.setdefault(token_bytes, []).append(token_idx)
|
||||
else:
|
||||
empty_token_ids.append(token_idx)
|
||||
|
||||
return vocabulary
|
||||
|
||||
|
||||
def get_vocabulary(tokenizer: AnyTokenizer) -> Vocabulary:
|
||||
"""Get the `Vocabulary` object for a given tokenizer.
|
||||
"""
|
||||
if hasattr(tokenizer, "_outlines_vocabulary"):
|
||||
return tokenizer._outlines_vocabulary # type: ignore
|
||||
|
||||
try:
|
||||
if hasattr(
|
||||
tokenizer,
|
||||
"eos_token_id",
|
||||
) and tokenizer.eos_token_id is not None:
|
||||
eos_token_id = tokenizer.eos_token_id
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Error during guided decoding setup: Tokenizer"
|
||||
f" ({type(tokenizer)}) has no `eos_token_id` property, "
|
||||
"but `eos_token_id` is required for guided decoding"
|
||||
" to work properly.")
|
||||
|
||||
reduced_vocab = _reduced_vocabulary(
|
||||
tokenizer,
|
||||
eos_token_id #type: ignore
|
||||
)
|
||||
vocabulary = OutlinesVocabulary(Vocabulary(eos_token_id,
|
||||
reduced_vocab))
|
||||
tokenizer._outlines_vocabulary = vocabulary # type: ignore
|
||||
|
||||
return vocabulary
|
||||
except AttributeError as e:
|
||||
raise ValueError(f"Cannot get the vocabulary of the tokenizer "
|
||||
f"({type(tokenizer)}). The tokenizer should have a "
|
||||
"get_vocab method.") from e
|
||||
|
||||
|
||||
def get_cache_path() -> str:
|
||||
"""Get the context object that contains previously-computed return values"""
|
||||
outlines_cache_dir = os.getenv("OUTLINES_CACHE_DIR")
|
||||
xdg_cache_home = os.getenv("XDG_CACHE_HOME")
|
||||
home_dir = os.path.expanduser("~")
|
||||
|
||||
if outlines_cache_dir:
|
||||
# OUTLINES_CACHE_DIR takes precedence
|
||||
return outlines_cache_dir
|
||||
elif xdg_cache_home:
|
||||
return os.path.join(xdg_cache_home, ".cache", "outlines")
|
||||
# If homedir is "/", we may be inside a container, and thus writing to
|
||||
# root would be problematic, so we fallback to using a tempfile.
|
||||
# Also validate the path exists, since os.path.expanduser does
|
||||
# not garuntee existence.
|
||||
elif os.path.isdir(home_dir) and home_dir != "/":
|
||||
# Default Unix fallback: ~/.cache/outlines
|
||||
return os.path.join(home_dir, ".cache", "outlines")
|
||||
else:
|
||||
import tempfile
|
||||
|
||||
# home_dir may be / inside a docker container without existing user
|
||||
tempdir = tempfile.gettempdir()
|
||||
return os.path.join(tempdir, ".cache", "outlines")
|
||||
|
||||
|
||||
def get_cache():
|
||||
"""Get the Cache instance to be used for index caching"""
|
||||
|
||||
cache_dir = get_cache_path()
|
||||
if envs.VLLM_V0_USE_OUTLINES_CACHE:
|
||||
logger.warning("Enabling outlines cache. This is an unbounded on-disk "
|
||||
"cache. It may consume a lot of disk space and should "
|
||||
"not be used with untrusted clients.")
|
||||
cache = Cache(cache_dir, eviction_policy="none", cull_limit=0)
|
||||
outlines_version = importlib.metadata.version("outlines_core")
|
||||
|
||||
cached_version = cache.get('__version__', None)
|
||||
if cached_version != outlines_version:
|
||||
cache.clear()
|
||||
cache.set('__version__', outlines_version)
|
||||
return cache
|
||||
else:
|
||||
return LRUCache(maxsize=128)
|
||||
|
||||
@@ -98,7 +98,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
        M_sum = round_up(M_sum, block_m)
        workspace1 = (M_sum, max(N * 2, K))
        workspace2 = (M_sum, max(N, K))
        output = (M, topk, K)
        return (workspace1, workspace2, output, a.dtype)

    def apply(
@@ -172,7 +172,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
        dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
            (a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids)

        torch.index_select(mm2_out, 0, inv_perm, out=output.view((-1, K)))


def deep_gemm_moe_fp8(
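The shape change matters because the final `index_select` now writes through a flattened view of the `(M, topk, K)` output tensor. A small standalone check (not vLLM code) that such a view shares storage with the original tensor, so the scatter lands in the right place:

```python
import torch

M, topk, K = 4, 2, 8
output = torch.empty(M, topk, K)
flat = output.view(-1, K)                       # shares storage with `output`
src = torch.randn(M * topk, K)
inv_perm = torch.randperm(M * topk)

torch.index_select(src, 0, inv_perm, out=flat)  # writes land in `output`
assert torch.equal(output.view(-1, K), src[inv_perm])
```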
@ -10,7 +10,8 @@ import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
||||
from vllm.model_executor.layers.fused_moe.utils import ( # yapf: disable
|
||||
_resize_cache, count_expert_num_tokens)
|
||||
from vllm.utils import cdiv
|
||||
|
||||
#
|
||||
@ -421,6 +422,177 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
f"{fused_experts.__class__.__name__}."
|
||||
f"{fused_experts.activation_formats[0]}")
|
||||
|
||||
def _do_fused_experts(
|
||||
self, fused_out: Optional[torch.Tensor], a1: torch.Tensor,
|
||||
a1q: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
|
||||
topk_ids: torch.Tensor, activation: str, global_num_experts: int,
|
||||
local_num_experts: int, expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
expert_tokens_meta: Optional[ExpertTokensMetadata]
|
||||
) -> torch.Tensor:
|
||||
|
||||
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
|
||||
|
||||
(workspace13_shape, workspace2_shape, fused_out_shape,
|
||||
workspace_dtype) = self.fused_experts.workspace_shapes(
|
||||
a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts)
|
||||
|
||||
# We can reuse the memory between cache1 and cache3 because by the
|
||||
# time we need cache3, we're done with cache1.
|
||||
workspace13 = torch.empty(prod(workspace13_shape),
|
||||
device=a1.device,
|
||||
dtype=workspace_dtype)
|
||||
workspace2 = torch.empty(prod(workspace2_shape),
|
||||
device=a1.device,
|
||||
dtype=workspace_dtype)
|
||||
|
||||
assert fused_out is None or fused_out.shape == fused_out_shape, (
|
||||
f"fused_out {fused_out.shape} but expected {fused_out_shape}")
|
||||
if fused_out is None:
|
||||
# reuse workspace13 for the output
|
||||
fused_out = _resize_cache(workspace13, fused_out_shape)
|
||||
|
||||
self.fused_experts.apply(fused_out,
|
||||
a1q,
|
||||
w1,
|
||||
w2,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=w1_zp,
|
||||
w2_zp=w2_zp,
|
||||
a1q_scale=a1q_scale,
|
||||
a2_scale=a2_scale,
|
||||
workspace13=workspace13,
|
||||
workspace2=workspace2,
|
||||
expert_tokens_meta=expert_tokens_meta)
|
||||
|
||||
return fused_out
|
||||
|
||||
def _maybe_chunk_fused_experts(
|
||||
self, a1: torch.Tensor, a1q: torch.Tensor, w1: torch.Tensor,
|
||||
w2: torch.Tensor, topk_ids: torch.Tensor, activation: str,
|
||||
global_num_experts: int, local_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
expert_tokens_meta: Optional[ExpertTokensMetadata]
|
||||
) -> torch.Tensor:
|
||||
|
||||
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
|
||||
|
||||
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||
num_chunks = cdiv(M, CHUNK_SIZE)
|
||||
|
||||
if not self.fused_experts.supports_chunking() or num_chunks == 1:
|
||||
return self._do_fused_experts(
|
||||
fused_out=None,
|
||||
a1=a1,
|
||||
a1q=a1q,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
local_num_experts=local_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=w1_zp,
|
||||
w2_zp=w2_zp,
|
||||
a1q_scale=a1q_scale,
|
||||
a2_scale=a2_scale,
|
||||
expert_tokens_meta=expert_tokens_meta)
|
||||
|
||||
# Chunking required case
|
||||
assert num_chunks > 1
|
||||
|
||||
# Construct the entire output that can then be processed in chunks.
|
||||
(_, _, fused_out_shape,
|
||||
_) = self.fused_experts.workspace_shapes(a1, a1q, M, N, K, top_k,
|
||||
global_num_experts,
|
||||
local_num_experts)
|
||||
fused_out = torch.empty(fused_out_shape,
|
||||
device=a1q.device,
|
||||
dtype=a1.dtype)
|
||||
|
||||
def slice_input_tensors(
|
||||
chunk_idx: int
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor], torch.Tensor]:
|
||||
s = chunk_idx * CHUNK_SIZE
|
||||
e = min(s + CHUNK_SIZE, M)
|
||||
return (a1q[s:e], _chunk_scales(a1q_scale, s, e),
|
||||
_chunk_scales(a2_scale, s, e), topk_ids[s:e])
|
||||
|
||||
def slice_output_tensor(chunk_idx: int) -> torch.Tensor:
|
||||
assert fused_out.size(0) % M == 0, (
|
||||
f"fused_out shape {fused_out.shape} vs M {M}")
|
||||
factor = fused_out.size(0) // M
|
||||
out_chunk_size = CHUNK_SIZE * factor
|
||||
s = chunk_idx * out_chunk_size
|
||||
e = min(s + out_chunk_size, fused_out.size(0))
|
||||
return fused_out[s:e]
|
||||
|
||||
def slice_expert_tokens_metadata(
|
||||
full_expert_tokens_meta: ExpertTokensMetadata,
|
||||
chunk_topk_ids: torch.Tensor, local_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor]) -> ExpertTokensMetadata:
|
||||
# The existing expert_num_tokens is for the entire a1q
|
||||
# input. Chunking forces recomputation of the number
|
||||
# of tokens assigned to each expert.
|
||||
c_expert_num_tokens = count_expert_num_tokens(
|
||||
chunk_topk_ids, local_num_experts, expert_map)
|
||||
|
||||
c_expert_num_tokens_cpu = None
|
||||
need_expert_num_tokens_cpu = (
|
||||
full_expert_tokens_meta.expert_num_tokens_cpu is not None)
|
||||
if need_expert_num_tokens_cpu:
|
||||
c_expert_num_tokens_cpu = c_expert_num_tokens.to(
|
||||
"cpu", non_blocking=True)
|
||||
|
||||
return ExpertTokensMetadata(
|
||||
expert_num_tokens=c_expert_num_tokens,
|
||||
expert_num_tokens_cpu=c_expert_num_tokens_cpu)
|
||||
|
||||
for chunk_idx in range(num_chunks):
|
||||
c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids = (
|
||||
slice_input_tensors(chunk_idx))
|
||||
|
||||
c_expert_tokens_meta = None
|
||||
if expert_tokens_meta is not None:
|
||||
c_expert_tokens_meta = slice_expert_tokens_metadata(
|
||||
expert_tokens_meta, c_topk_ids, local_num_experts,
|
||||
expert_map)
|
||||
|
||||
self._do_fused_experts(fused_out=slice_output_tensor(chunk_idx),
|
||||
a1=a1,
|
||||
a1q=c_a1q,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_ids=c_topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
local_num_experts=local_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=w1_zp,
|
||||
w2_zp=w2_zp,
|
||||
a1q_scale=c_a1q_scale,
|
||||
a2_scale=c_a2_scale,
|
||||
expert_tokens_meta=c_expert_tokens_meta)
|
||||
|
||||
return fused_out
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -512,110 +684,23 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
# and can never run into the tensor.numel() == 0 case.
|
||||
fused_out = torch.empty_like(a1q).to(dtype=a1.dtype)
|
||||
else:
|
||||
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
|
||||
|
||||
if self.fused_experts.enable_chunking():
|
||||
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||
num_chunks = cdiv(M, CHUNK_SIZE)
|
||||
else:
|
||||
CHUNK_SIZE = M
|
||||
num_chunks = 1
|
||||
|
||||
if num_chunks == 1:
|
||||
(workspace13_shape, workspace2_shape, fused_out_shape,
|
||||
workspace_dtype) = self.fused_experts.workspace_shapes(
|
||||
a1, a1q, M, N, K, top_k, global_num_experts,
|
||||
local_num_experts)
|
||||
else:
|
||||
# Use the full M to get the final output shape.
|
||||
_, _, fused_out_shape, _ = (
|
||||
self.fused_experts.workspace_shapes(
|
||||
a1, a1q, M, N, K, top_k, global_num_experts,
|
||||
local_num_experts))
|
||||
# Use the CHUNK_SIZE to get the workspace shapes.
|
||||
workspace13_shape, workspace2_shape, _, workspace_dtype = (
|
||||
self.fused_experts.workspace_shapes(
|
||||
a1, a1q, CHUNK_SIZE, N, K, top_k, global_num_experts,
|
||||
local_num_experts))
|
||||
|
||||
# We can reuse the memory between cache1 and cache3 because by the
|
||||
# time we need cache3, we're done with cache1.
|
||||
workspace13 = torch.empty(prod(workspace13_shape),
|
||||
device=a1.device,
|
||||
dtype=workspace_dtype)
|
||||
workspace2 = torch.empty(prod(workspace2_shape),
|
||||
device=a1.device,
|
||||
dtype=workspace_dtype)
|
||||
|
||||
if num_chunks == 1:
|
||||
fused_out = _resize_cache(workspace13, fused_out_shape)
|
||||
|
||||
self.fused_experts.apply(
|
||||
fused_out,
|
||||
a1q,
|
||||
w1,
|
||||
w2,
|
||||
topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=w1_zp,
|
||||
w2_zp=w2_zp,
|
||||
a1q_scale=a1q_scale,
|
||||
a2_scale=a2_scale,
|
||||
workspace13=workspace13,
|
||||
workspace2=workspace2,
|
||||
expert_tokens_meta=expert_tokens_meta,
|
||||
)
|
||||
else:
|
||||
# The leading output dimension may not be equal to M, so
|
||||
# we compute output indices separately.
|
||||
M_out = fused_out_shape[0]
|
||||
assert M_out >= M
|
||||
factor = M_out // M
|
||||
assert factor > 0
|
||||
OUT_CHUNK_SIZE = CHUNK_SIZE * factor
|
||||
|
||||
fused_out = torch.empty(fused_out_shape,
|
||||
device=a1q.device,
|
||||
dtype=workspace_dtype)
|
||||
|
||||
assert cdiv(M_out, OUT_CHUNK_SIZE) == num_chunks, (
|
||||
f"{cdiv(M_out, OUT_CHUNK_SIZE)} == {num_chunks}")
|
||||
|
||||
for chunk in range(num_chunks):
|
||||
begin_chunk_idx = chunk * CHUNK_SIZE
|
||||
end_chunk_idx = min((chunk + 1) * CHUNK_SIZE, M)
|
||||
begin_out_idx = chunk * OUT_CHUNK_SIZE
|
||||
end_out_idx = min((chunk + 1) * OUT_CHUNK_SIZE, M_out)
|
||||
curr_a1q = a1q[begin_chunk_idx:end_chunk_idx]
|
||||
curr_a1q_scale = _chunk_scales(a1q_scale, begin_chunk_idx,
|
||||
end_chunk_idx)
|
||||
curr_a2_scale = _chunk_scales(a2_scale, begin_chunk_idx,
|
||||
end_chunk_idx)
|
||||
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
|
||||
|
||||
self.fused_experts.apply(
|
||||
fused_out[begin_out_idx:end_out_idx],
|
||||
curr_a1q,
|
||||
w1,
|
||||
w2,
|
||||
curr_topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=w1_zp,
|
||||
w2_zp=w2_zp,
|
||||
a1q_scale=curr_a1q_scale,
|
||||
a2_scale=curr_a2_scale,
|
||||
workspace13=workspace13,
|
||||
workspace2=workspace2,
|
||||
expert_tokens_meta=expert_tokens_meta,
|
||||
)
|
||||
fused_out = self._maybe_chunk_fused_experts(
|
||||
a1=a1,
|
||||
a1q=a1q,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
local_num_experts=local_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=w1_zp,
|
||||
w2_zp=w2_zp,
|
||||
a1q_scale=a1q_scale,
|
||||
a2_scale=a2_scale,
|
||||
expert_tokens_meta=expert_tokens_meta)
|
||||
|
||||
self.prepare_finalize.finalize(output, fused_out, topk_weights,
|
||||
topk_ids, apply_router_weight_on_input)
|
||||
|
||||
@@ -6,11 +6,14 @@ import pplx_kernels as pplx
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import (
    _validate_scale_shape, moe_kernel_quantize_input)
from vllm.utils import cdiv, round_up

logger = init_logger(__name__)


def pplx_hidden_dim_scale_bytes(
    max_num_tokens: int,
@@ -101,9 +104,15 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
        hidden_dim = a1.size(-1)  # K

        assert topk_ids.size(0) == num_tokens
        # expert_map should be None because with expert map, -1 id is used for
        # non-local token; this causes error when casting ids to the
        # topk_indices_dtype() int32
        if expert_map is not None:
            logger.warn_once(
                "The PPLX backend does not support expert mapping. "
                "The provided `expert_map` will be ignored.")
            expert_map = None  # noqa: F841

        # Is this always going to be a1.device?
        device = a1.device
@ -13,9 +13,81 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||
quant_dequant_mxfp4)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils import cdiv
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _count_expert_num_tokens(topk_ids_ptr, expert_num_tokens_ptr, num_experts,
|
||||
topk_numel, expert_map,
|
||||
HAS_EXPERT_MAP: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr):
|
||||
|
||||
curr_expert = tl.program_id(0)
|
||||
|
||||
offsets = tl.arange(0, BLOCK_SIZE)
|
||||
topk_ids_ptrs = topk_ids_ptr + offsets
|
||||
|
||||
acc = tl.zeros((BLOCK_SIZE, ), dtype=tl.int32)
|
||||
for x in range(tl.cdiv(topk_numel, BLOCK_SIZE)):
|
||||
mask = offsets < (topk_numel - x * BLOCK_SIZE)
|
||||
expert_ids = tl.load(topk_ids_ptrs, mask=mask, other=-1)
|
||||
if HAS_EXPERT_MAP:
|
||||
expert_map_ptrs = expert_map + expert_ids
|
||||
expert_map_mask = expert_ids >= 0
|
||||
expert_ids = tl.load(expert_map_ptrs,
|
||||
mask=expert_map_mask,
|
||||
other=-1)
|
||||
|
||||
has_curr_expert = tl.where(expert_ids == curr_expert, 1, 0)
|
||||
acc = acc + has_curr_expert
|
||||
topk_ids_ptrs += BLOCK_SIZE
|
||||
|
||||
if curr_expert < num_experts:
|
||||
tl.store(expert_num_tokens_ptr + curr_expert, tl.sum(acc))
|
||||
|
||||
|
||||
def count_expert_num_tokens(
|
||||
topk_ids: torch.Tensor, num_local_experts: int,
|
||||
expert_map: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
"""
|
||||
Count the number to tokens assigned to each expert.
|
||||
|
||||
Parameters:
|
||||
- topk_ids (torch.Tensor): Tensor mapping each token to its
|
||||
list of experts.
|
||||
- num_local_experts (int): Number of experts in this rank.
|
||||
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
|
||||
from the global expert space to the local expert space of the expert
|
||||
parallel shard.
|
||||
|
||||
Returns:
|
||||
A tensor of size num_local_experts, where tensor[i] holds the number
|
||||
of tokens assigned to the ith expert.
|
||||
"""
|
||||
assert topk_ids.dtype.is_signed, (
|
||||
"The kernel uses -1 to represent invalid topk_ids")
|
||||
expert_num_tokens = torch.empty((num_local_experts),
|
||||
device=topk_ids.device,
|
||||
dtype=torch.int32)
|
||||
|
||||
grid = num_local_experts
|
||||
BLOCK_SIZE = min(topk_ids.numel(), 1024)
|
||||
BLOCK_SIZE = triton.next_power_of_2(BLOCK_SIZE)
|
||||
|
||||
_count_expert_num_tokens[(grid, )](
|
||||
topk_ids,
|
||||
expert_num_tokens,
|
||||
num_local_experts,
|
||||
topk_ids.numel(),
|
||||
expert_map,
|
||||
HAS_EXPERT_MAP=expert_map is not None,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
|
||||
return expert_num_tokens
|
||||
|
||||
|
||||
def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
|
||||
"""
|
||||
Shrink the given tensor and apply the given view to it. This is
|
||||
|
||||
@@ -23,6 +23,8 @@ from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
from vllm.v1.structured_output.backend_guidance import (
    validate_guidance_grammar)
from vllm.v1.structured_output.backend_outlines import (
    validate_structured_output_request_outlines)
from vllm.v1.structured_output.backend_xgrammar import (
    validate_xgrammar_grammar)

@@ -193,6 +195,9 @@ class Processor:
            # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
            # Without tokenizer these are disallowed in grammars.
            validate_guidance_grammar(params, tokenizer=None)
        elif engine_level_backend == "outlines":
            # outlines backend
            validate_structured_output_request_outlines(params)
        else:
            # NOTE: engine_level_backend must be "auto" here, because we have
            # checked supported_backends above.

@@ -88,6 +88,15 @@ class StructuredOutputManager:
                tokenizer=self.tokenizer,
                vocab_size=vocab_size,
            )
        elif backend == "outlines":
            from vllm.v1.structured_output.backend_outlines import (
                OutlinesBackend)

            self.backend = OutlinesBackend(
                self.vllm_config,
                tokenizer=self.tokenizer,
                vocab_size=vocab_size,
            )
        else:
            raise ValueError(
                f"Unsupported structured output backend: {backend}")
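For end users this wires the new backend into V1 structured output selection. A hedged usage sketch, assuming the backend is picked through the existing `guided_decoding_backend` engine argument (check the vLLM docs for the exact flag name) and a small hypothetical model:

```python
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# Assumption: the outlines backend is selected via the existing
# `guided_decoding_backend` engine argument.
llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct",
          guided_decoding_backend="outlines")

outputs = llm.generate(
    "Answer yes or no: is the sky blue?",
    SamplingParams(
        guided_decoding=GuidedDecodingParams(choice=["yes", "no"]),
        max_tokens=2,
    ),
)
```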
319 vllm/v1/structured_output/backend_outlines.py (new file)
@@ -0,0 +1,319 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright 2025-present the Outlines developers
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import importlib
|
||||
import json
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from regex import escape as regex_escape
|
||||
|
||||
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
|
||||
OutlinesVocabulary, get_cache, get_vocabulary)
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import LazyLoader
|
||||
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
|
||||
StructuredOutputGrammar,
|
||||
StructuredOutputOptions)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import outlines_core as oc
|
||||
import outlines_core.json_schema as json_schema
|
||||
else:
|
||||
oc = LazyLoader("oc", globals(), "outlines_core")
|
||||
json_schema = LazyLoader("json_schema", globals(),
|
||||
"outlines_core.json_schema")
|
||||
|
||||
# Python 3.11+ sre_parse and sre_constants
|
||||
# are deprecated, so we must import them from re
|
||||
if sys.version_info >= (3, 11):
|
||||
# Hack to get around pre-commit regex module rule
|
||||
# because going through re is the only way to get sre_parse
|
||||
# and sre_constants in Python 3.11+
|
||||
_re = importlib.import_module("re")
|
||||
sre_parse = _re._parser
|
||||
sre_constants = _re._constants
|
||||
else:
|
||||
import sre_constants
|
||||
import sre_parse
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutlinesBackend(StructuredOutputBackend):
|
||||
|
||||
def __post_init__(self):
|
||||
self.vocabulary = get_vocabulary(self.tokenizer)
|
||||
self.cache = get_cache()
|
||||
|
||||
def _compile_index(self, regex_string: str,
|
||||
vocabulary: OutlinesVocabulary) -> oc.Index:
|
||||
cache_key = f"{vocabulary._hash}_{regex_string}"
|
||||
if cache_key in self.cache:
|
||||
return self.cache[cache_key]
|
||||
|
||||
index = oc.Index(regex_string, vocabulary.inner)
|
||||
self.cache[cache_key] = index
|
||||
|
||||
return index
|
||||
|
||||
def compile_grammar(self, request_type: StructuredOutputOptions,
|
||||
grammar_spec: str) -> StructuredOutputGrammar:
|
||||
if request_type == StructuredOutputOptions.JSON:
|
||||
regex = json_schema.build_regex_from_schema(grammar_spec)
|
||||
elif request_type == StructuredOutputOptions.REGEX:
|
||||
regex = grammar_spec
|
||||
elif request_type == StructuredOutputOptions.CHOICE:
|
||||
choices = ast.literal_eval(grammar_spec)
|
||||
choices = [regex_escape(c) for c in choices]
|
||||
regex = "(" + "|".join(choices) + ")"
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid request type for Outlines backend ({request_type!s})"
|
||||
)
|
||||
index = self._compile_index(regex, self.vocabulary)
|
||||
max_rollback_tokens = (
|
||||
self.vllm_config.speculative_config.num_speculative_tokens
|
||||
if self.vllm_config.speculative_config is not None else 0)
|
||||
return OutlinesGrammar(vocab_size=self.vocab_size,
|
||||
guide=oc.Guide(
|
||||
index, max_rollback=max_rollback_tokens))
|
||||
|
||||
def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor:
|
||||
return torch.full(
|
||||
(max_num_seqs, (self.vocab_size + 31) // 32),
|
||||
-1,
|
||||
dtype=torch.int32,
|
||||
pin_memory=torch.cuda.is_available(),
|
||||
)
|
||||
|
||||
def destroy(self):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutlinesGrammar(StructuredOutputGrammar):
|
||||
|
||||
vocab_size: int
|
||||
guide: oc.Guide = field(hash=False)
|
||||
num_processed_tokens: int = field(default_factory=lambda: 0,
|
||||
repr=False,
|
||||
hash=False,
|
||||
init=False)
|
||||
|
||||
# outlines_core signals done on DFA accept; vLLM expects done after EOS.
|
||||
# We delay the finished flag by one step so EOS can still be emitted.
|
||||
_prev_finished: bool = field(default=False,
|
||||
init=False,
|
||||
repr=False,
|
||||
hash=False)
|
||||
|
||||
def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
|
||||
"""Accepts a list of tokens and advances the FSM.
|
||||
|
||||
Returns True if the FSM was advanced successfully.
|
||||
Returns False if the FSM failed to advance.
|
||||
"""
|
||||
if self.guide.accepts_tokens(tokens):
|
||||
# Advance cannot fail because we checked Guide.accepts_tokens()
|
||||
for t in tokens:
|
||||
self.guide.advance(t)
|
||||
self.num_processed_tokens += 1
|
||||
return True
|
||||
return False
|
||||
|
||||
def rollback(self, num_tokens: int) -> None:
|
||||
self.guide.rollback_state(num_tokens)
|
||||
self.num_processed_tokens -= num_tokens
|
||||
|
||||
def validate_tokens(self, tokens: list[int]) -> list[int]:
|
||||
accepted: list[int] = []
|
||||
for tok in tokens:
|
||||
accepted.append(tok)
|
||||
if not self.guide.accepts_tokens(accepted):
|
||||
accepted.pop()
|
||||
break
|
||||
return accepted
|
||||
|
||||
def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
|
||||
mask = bitmask[idx]
|
||||
self.guide.write_mask_into(mask.data_ptr(), mask.numel(),
|
||||
mask.element_size())
|
||||
|
||||
def is_terminated(self) -> bool:
|
||||
curr = self.guide.is_finished()
|
||||
prev = self._prev_finished
|
||||
self._prev_finished = curr
|
||||
return prev
|
||||
|
||||
def reset(self):
|
||||
self.num_processed_tokens = 0
|
||||
self._prev_finished = False
|
||||
self.guide.reset()
|
||||
|
||||
|
||||
def validate_structured_output_request_outlines(params: SamplingParams):
|
||||
if params.guided_decoding is None:
|
||||
return
|
||||
|
||||
gd_params = params.guided_decoding
|
||||
|
||||
if gd_params.regex:
|
||||
validate_regex_is_buildable(gd_params.regex)
|
||||
elif gd_params.json:
|
||||
if isinstance(gd_params.json, str):
|
||||
try:
|
||||
# make sure schema is valid json
|
||||
json.loads(gd_params.json)
|
||||
schema = gd_params.json
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError("Invalid JSON grammar specification.") from e
|
||||
else:
|
||||
try:
|
||||
schema = json.dumps(gd_params.json)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Error serializing guided decoding jsonschema: {e}"
|
||||
) from e
|
||||
pattern = json_schema.build_regex_from_schema(schema)
|
||||
validate_regex_is_buildable(pattern)
|
||||
elif gd_params.choice:
|
||||
choices = [regex_escape(str(choice)) for choice in gd_params.choice]
|
||||
regex = "(" + "|".join(choices) + ")"
|
||||
validate_regex_is_buildable(regex)
|
||||
elif gd_params.grammar:
|
||||
raise ValueError("Outlines guided decoding backend "
|
||||
"does not support grammar specifications")
|
||||
|
||||
|
||||
def _prefix_needs_context(parsed) -> bool:
|
||||
"""Return True if there's a look-around/anchor before any consumer."""
|
||||
|
||||
def subpattern_consumes(parsed) -> bool:
|
||||
"""Return True if subpattern can consume at least one character."""
|
||||
tokens = parsed.data if hasattr(parsed, 'data') else parsed
|
||||
for ttype, tval in tokens:
|
||||
# literal, character class, or dot always consumes
|
||||
if ttype in (sre_parse.LITERAL, sre_parse.IN, sre_parse.ANY):
|
||||
return True
|
||||
# quantified subpattern: check inner pattern
|
||||
elif ttype == sre_parse.MAX_REPEAT:
|
||||
_, mx, sub = tval
|
||||
if mx != 0 and subpattern_consumes(sub):
|
||||
return True
|
||||
# alternation: if any branch consumes, the whole does
|
||||
elif ttype == sre_parse.BRANCH:
|
||||
_, branches = tval
|
||||
if any(subpattern_consumes(br) for br in branches):
|
||||
return True
|
||||
# grouped subpattern: recurse into its contents
|
||||
elif ttype == sre_parse.SUBPATTERN and subpattern_consumes(
|
||||
tval[3]):
|
||||
return True
|
||||
# No consumers, return False
|
||||
return False
|
||||
|
||||
tokens = parsed.data if hasattr(parsed, 'data') else parsed
|
||||
for ttype, tval in tokens:
|
||||
# Direct anchors or look-around
|
||||
if ttype == sre_parse.AT or ttype in (sre_constants.ASSERT,
|
||||
sre_constants.ASSERT_NOT):
|
||||
return True
|
||||
|
||||
# Nested subpattern: check
|
||||
if ttype == sre_parse.SUBPATTERN:
|
||||
# tval: (group, add_flags, del_flags, subpattern)
|
||||
if _prefix_needs_context(tval[3]):
|
||||
return True
|
||||
if subpattern_consumes(tval[3]):
|
||||
return False
|
||||
|
||||
# if any branch has a prefix anchor => True,
|
||||
# else if at least one branch consumes => prefix ends => False
|
||||
elif ttype == sre_parse.BRANCH:
|
||||
saw_consumer = False
|
||||
for br in tval[1]:
|
||||
if _prefix_needs_context(br):
|
||||
return True
|
||||
if subpattern_consumes(br):
|
||||
saw_consumer = True
|
||||
if saw_consumer:
|
||||
return False
|
||||
|
||||
# Immediate consumer tokens
|
||||
elif ttype in (sre_parse.LITERAL, sre_parse.IN, sre_parse.ANY):
|
||||
return False
|
||||
|
||||
# if subpattern has anchor => True, if it can consume => stop
|
||||
elif ttype == sre_parse.MAX_REPEAT:
|
||||
if _prefix_needs_context(tval[2]):
|
||||
return True
|
||||
if subpattern_consumes(tval[2]):
|
||||
return False
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _check_unsupported(parsed) -> None:
|
||||
"""Check for regex features unsupported by regex-automata"""
|
||||
tokens = parsed.data if hasattr(parsed, 'data') else parsed
|
||||
for ttype, tval in tokens:
|
||||
|
||||
# backreference
|
||||
if ttype in (sre_parse.GROUPREF, sre_parse.GROUPREF_EXISTS):
|
||||
raise ValueError("Backreferences are unsupported.")
|
||||
|
||||
# look-around assertion
|
||||
elif ttype in (sre_constants.ASSERT, sre_constants.ASSERT_NOT):
|
||||
raise ValueError("Look-Around assertion are unsupported.")
|
||||
|
||||
# unicode word boundaries
|
||||
elif ttype == sre_parse.AT:
|
||||
if tval in (sre_constants.AT_BOUNDARY,
|
||||
sre_constants.AT_NON_BOUNDARY):
|
||||
raise ValueError("Unicode word boundaries are unsupported.")
|
||||
|
||||
elif ttype == sre_parse.BRANCH:
|
||||
# tval is (None, branches)
|
||||
for branch in tval[1]:
|
||||
_check_unsupported(branch)
|
||||
|
||||
# tval is (min, max, subpattern)
|
||||
elif ttype == sre_parse.MAX_REPEAT:
|
||||
_check_unsupported(tval[2])
|
||||
|
||||
|
||||
def validate_regex_is_buildable(pattern: str) -> None:
|
||||
"""
|
||||
Validates that the input regex is not using unsupported features
|
||||
of the `regex-automata` crate (outlines_core regex engine) and has a
|
||||
universal start state.
|
||||
definition of universal start state used can be found at:
|
||||
https://docs.rs/regex-automata/latest/regex_automata/dfa/trait.Automaton.html#method.universal_start_state
|
||||
"""
|
||||
try:
|
||||
parsed = sre_parse.parse(pattern)
|
||||
|
||||
except sre_constants.error as e:
|
||||
raise ValueError(f"Error parsing regex: {e}") from e
|
||||
|
||||
try:
|
||||
_check_unsupported(parsed)
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
f"Regex uses unsupported feature for guided decoding: {e}. "
|
||||
"Only basic matching constructs are supported—lookarounds, "
|
||||
"backreferences, and unicode boundaries are not.") from e
|
||||
|
||||
if _prefix_needs_context(parsed):
|
||||
raise ValueError(
|
||||
"Regex does not have a anchored universal start state"
|
||||
"This means that the Regex uses anchors (^) or look-arounds "
|
||||
"in a way which requires context before any token is matched."
|
||||
"Guided decoding needs regexes that can match without needing "
|
||||
"that context. Try rewriting the pattern without using these "
|
||||
f"constructs. Pattern:\n{pattern}")
|
||||
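A quick check of which patterns the validator accepts; the import path matches the new file above, and the failing pattern uses a word boundary, which the validator explicitly rejects:

```python
from vllm.v1.structured_output.backend_outlines import (
    validate_regex_is_buildable)

validate_regex_is_buildable(r"(yes|no)")      # plain alternation: accepted
try:
    validate_regex_is_buildable(r"\bword\b")  # unicode word boundary: rejected
except ValueError as err:
    print(err)
```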
@@ -135,14 +135,7 @@ class Worker(WorkerBase):

         # This env var set by Ray causes exceptions with graph building.
         os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)

-        device_id = self.local_rank
-        if envs.VLLM_VISIBLE_DEVICES is not None:
-            devices = [
-                int(dev) for dev in
-                (x.strip() for x in envs.VLLM_VISIBLE_DEVICES.split(','))
-            ]
-            device_id = devices[self.local_rank]
-        self.device = torch.device(f"cuda:{device_id}")
+        self.device = torch.device(f"cuda:{self.local_rank}")
         current_platform.set_device(self.device)

         _check_if_gpu_supports_dtype(self.model_config.dtype)
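With the `VLLM_VISIBLE_DEVICES` mapping reverted, per-worker device binding falls back to logical ranks, and physical GPU selection is again steered through `CUDA_VISIBLE_DEVICES`. A small sketch with hypothetical values:

```python
import os

# Must be set before CUDA is initialized; hypothetical two-GPU selection.
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"

import torch

local_rank = 0
device = torch.device(f"cuda:{local_rank}")  # logical cuda:0 -> physical GPU 2
```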