Consolidate rendering parameters into RenderConfig dataclass (#24543)

Signed-off-by: sfeng33 <4florafeng@gmail.com>
This commit is contained in:
Flora Feng
2025-09-10 01:44:47 -07:00
committed by GitHub
parent feaf202e93
commit 77f62613f9
8 changed files with 167 additions and 108 deletions

View File

@ -10,7 +10,7 @@ import pybase64
import pytest
import torch
from vllm.entrypoints.renderer import CompletionRenderer
from vllm.entrypoints.renderer import CompletionRenderer, RenderConfig
from vllm.inputs.data import is_embeds_prompt
@ -56,8 +56,8 @@ class TestRenderPrompt:
@pytest.mark.asyncio
async def test_token_input(self, renderer):
tokens = [101, 7592, 2088]
results = await renderer.render_prompt(prompt_or_prompts=tokens,
max_length=100)
results = await renderer.render_prompt(
prompt_or_prompts=tokens, config=RenderConfig(max_length=100))
assert len(results) == 1
assert results[0]["prompt_token_ids"] == tokens
@ -65,8 +65,8 @@ class TestRenderPrompt:
@pytest.mark.asyncio
async def test_token_list_input(self, renderer):
token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]]
results = await renderer.render_prompt(prompt_or_prompts=token_lists,
max_length=100)
results = await renderer.render_prompt(
prompt_or_prompts=token_lists, config=RenderConfig(max_length=100))
assert len(results) == 3
assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
@ -80,8 +80,9 @@ class TestRenderPrompt:
renderer.async_tokenizer_pool[
renderer.tokenizer] = mock_async_tokenizer
results = await renderer.render_prompt(prompt_or_prompts="Hello world",
max_length=100)
results = await renderer.render_prompt(
prompt_or_prompts="Hello world",
config=RenderConfig(max_length=100))
assert len(results) == 1
assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
@ -96,7 +97,8 @@ class TestRenderPrompt:
text_list_input = ["Hello world", "How are you?", "Good morning"]
results = await renderer.render_prompt(
prompt_or_prompts=text_list_input, max_length=100)
prompt_or_prompts=text_list_input,
config=RenderConfig(max_length=100))
assert len(results) == 3
for result in results:
@ -110,8 +112,9 @@ class TestRenderPrompt:
renderer.async_tokenizer_pool[
renderer.tokenizer] = mock_async_tokenizer
results = await renderer.render_prompt(prompt_or_prompts="Hello world",
max_length=100)
results = await renderer.render_prompt(
prompt_or_prompts="Hello world",
config=RenderConfig(max_length=100))
assert len(results) == 1
call_args = mock_async_tokenizer.call_args
@ -126,8 +129,9 @@ class TestRenderPrompt:
renderer.tokenizer] = mock_async_tokenizer
results = await renderer.render_prompt(prompt_or_prompts="Hello world",
max_length=100,
truncate_prompt_tokens=50)
config=RenderConfig(
max_length=100,
truncate_prompt_tokens=50))
assert len(results) == 1
call_args = mock_async_tokenizer.call_args
@ -143,8 +147,9 @@ class TestRenderPrompt:
renderer.tokenizer] = mock_async_tokenizer
results = await renderer.render_prompt(prompt_or_prompts="Hello world",
max_length=200,
truncate_prompt_tokens=-1)
config=RenderConfig(
max_length=200,
truncate_prompt_tokens=-1))
assert len(results) == 1
call_args = mock_async_tokenizer.call_args
@ -157,8 +162,9 @@ class TestRenderPrompt:
long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108,
109] # 10 tokens
results = await renderer.render_prompt(prompt_or_prompts=long_tokens,
max_length=100,
truncate_prompt_tokens=5)
config=RenderConfig(
max_length=100,
truncate_prompt_tokens=5))
assert len(results) == 1
# Should keep the last 5 tokens: [105, 106, 107, 108, 109]
@ -170,7 +176,7 @@ class TestRenderPrompt:
with pytest.raises(ValueError, match="maximum context length"):
await renderer.render_prompt(prompt_or_prompts=long_tokens,
max_length=100)
config=RenderConfig(max_length=100))
@pytest.mark.asyncio
async def test_no_tokenizer_for_text(self, mock_model_config):
@ -181,7 +187,8 @@ class TestRenderPrompt:
with pytest.raises(ValueError, match="No tokenizer available"):
await renderer_no_tokenizer.render_prompt(
prompt_or_prompts="Hello world", max_length=100)
prompt_or_prompts="Hello world",
config=RenderConfig(max_length=100))
@pytest.mark.asyncio
async def test_token_input_with_needs_detokenization(
@ -196,7 +203,7 @@ class TestRenderPrompt:
tokens = [1, 2, 3, 4]
results = await renderer.render_prompt(
prompt_or_prompts=tokens,
needs_detokenization=True,
config=RenderConfig(needs_detokenization=True),
)
assert len(results) == 1
@ -221,7 +228,9 @@ class TestRenderEmbedPrompt:
embed_bytes = self._create_test_embed_bytes(test_tensor)
results = await renderer.render_prompt_and_embeds(
prompt_embeds=embed_bytes, cache_salt="test_salt")
prompt_embeds=embed_bytes,
config=RenderConfig(cache_salt="test_salt"),
)
assert len(results) == 1
assert is_embeds_prompt(results[0])
@ -240,7 +249,9 @@ class TestRenderEmbedPrompt:
]
results = await renderer.render_prompt_and_embeds(
prompt_embeds=embed_bytes_list)
prompt_embeds=embed_bytes_list,
config=RenderConfig(),
)
assert len(results) == 2
for i, result in enumerate(results):
@ -254,7 +265,9 @@ class TestRenderEmbedPrompt:
embed_bytes = self._create_test_embed_bytes(test_tensor)
results = await renderer.render_prompt_and_embeds(
prompt_embeds=embed_bytes, truncate_prompt_tokens=10)
prompt_embeds=embed_bytes,
config=RenderConfig(truncate_prompt_tokens=10),
)
assert len(results) == 1
# Should keep last 10 tokens
@ -271,7 +284,9 @@ class TestRenderEmbedPrompt:
embed_bytes = self._create_test_embed_bytes(test_tensor)
results = await renderer.render_prompt_and_embeds(
prompt_embeds=embed_bytes)
prompt_embeds=embed_bytes,
config=RenderConfig(),
)
assert len(results) == 1
assert results[0]["prompt_embeds"].dtype == dtype
@ -283,7 +298,9 @@ class TestRenderEmbedPrompt:
embed_bytes = self._create_test_embed_bytes(test_tensor)
results = await renderer.render_prompt_and_embeds(
prompt_embeds=embed_bytes)
prompt_embeds=embed_bytes,
config=RenderConfig(),
)
assert len(results) == 1
# Should be squeezed to 2D
@ -303,7 +320,10 @@ class TestRenderEmbedPrompt:
embed_bytes = self._create_test_embed_bytes(test_tensor)
results = await renderer.render_prompt_and_embeds(
prompt_or_prompts="Hello world", prompt_embeds=embed_bytes)
prompt_or_prompts="Hello world",
prompt_embeds=embed_bytes,
config=RenderConfig(),
)
assert len(results) == 2
# First should be embed prompt