Consolidate rendering parameters into RenderConfig dataclass (#24543)
Signed-off-by: sfeng33 <4florafeng@gmail.com>
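For context, this change replaces the loose keyword arguments that render_prompt and render_prompt_and_embeds used to accept (max_length, truncate_prompt_tokens, cache_salt, needs_detokenization) with a single config object. The sketch below is inferred only from the fields exercised in the updated tests; types, defaults, and any additional fields of the real RenderConfig in vllm/entrypoints/renderer.py are assumptions.

# Hypothetical sketch of the consolidated config; the field set is inferred
# from the tests in this diff, and defaults/types are assumptions.
from dataclasses import dataclass
from typing import Optional


@dataclass
class RenderConfig:
    # Upper bound on prompt length; longer prompts raise a
    # "maximum context length" error (see the over-limit test below).
    max_length: Optional[int] = None
    # Keep only the last N prompt tokens; -1 truncates to the model maximum.
    truncate_prompt_tokens: Optional[int] = None
    # Salt forwarded with the engine prompt (used for prefix caching).
    cache_salt: Optional[str] = None
    # For token-ID input, also return the detokenized prompt text.
    needs_detokenization: Optional[bool] = False


# Call sites then pass one object instead of separate keyword arguments, e.g.
#   await renderer.render_prompt(prompt_or_prompts=tokens,
#                                config=RenderConfig(max_length=100))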
@@ -10,7 +10,7 @@ import pybase64
 import pytest
 import torch
 
-from vllm.entrypoints.renderer import CompletionRenderer
+from vllm.entrypoints.renderer import CompletionRenderer, RenderConfig
 from vllm.inputs.data import is_embeds_prompt
 
 
@@ -56,8 +56,8 @@ class TestRenderPrompt:
     @pytest.mark.asyncio
     async def test_token_input(self, renderer):
         tokens = [101, 7592, 2088]
-        results = await renderer.render_prompt(prompt_or_prompts=tokens,
-                                               max_length=100)
+        results = await renderer.render_prompt(
+            prompt_or_prompts=tokens, config=RenderConfig(max_length=100))
 
         assert len(results) == 1
         assert results[0]["prompt_token_ids"] == tokens
@@ -65,8 +65,8 @@ class TestRenderPrompt:
     @pytest.mark.asyncio
     async def test_token_list_input(self, renderer):
         token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]]
-        results = await renderer.render_prompt(prompt_or_prompts=token_lists,
-                                               max_length=100)
+        results = await renderer.render_prompt(
+            prompt_or_prompts=token_lists, config=RenderConfig(max_length=100))
 
         assert len(results) == 3
         assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
@@ -80,8 +80,9 @@ class TestRenderPrompt:
         renderer.async_tokenizer_pool[
             renderer.tokenizer] = mock_async_tokenizer
 
-        results = await renderer.render_prompt(prompt_or_prompts="Hello world",
-                                               max_length=100)
+        results = await renderer.render_prompt(
+            prompt_or_prompts="Hello world",
+            config=RenderConfig(max_length=100))
 
         assert len(results) == 1
         assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
@@ -96,7 +97,8 @@ class TestRenderPrompt:
 
         text_list_input = ["Hello world", "How are you?", "Good morning"]
         results = await renderer.render_prompt(
-            prompt_or_prompts=text_list_input, max_length=100)
+            prompt_or_prompts=text_list_input,
+            config=RenderConfig(max_length=100))
 
         assert len(results) == 3
         for result in results:
@@ -110,8 +112,9 @@ class TestRenderPrompt:
         renderer.async_tokenizer_pool[
             renderer.tokenizer] = mock_async_tokenizer
 
-        results = await renderer.render_prompt(prompt_or_prompts="Hello world",
-                                               max_length=100)
+        results = await renderer.render_prompt(
+            prompt_or_prompts="Hello world",
+            config=RenderConfig(max_length=100))
 
         assert len(results) == 1
         call_args = mock_async_tokenizer.call_args
@@ -126,8 +129,9 @@ class TestRenderPrompt:
             renderer.tokenizer] = mock_async_tokenizer
 
         results = await renderer.render_prompt(prompt_or_prompts="Hello world",
-                                               max_length=100,
-                                               truncate_prompt_tokens=50)
+                                               config=RenderConfig(
+                                                   max_length=100,
+                                                   truncate_prompt_tokens=50))
 
         assert len(results) == 1
         call_args = mock_async_tokenizer.call_args
@@ -143,8 +147,9 @@ class TestRenderPrompt:
             renderer.tokenizer] = mock_async_tokenizer
 
         results = await renderer.render_prompt(prompt_or_prompts="Hello world",
-                                               max_length=200,
-                                               truncate_prompt_tokens=-1)
+                                               config=RenderConfig(
+                                                   max_length=200,
+                                                   truncate_prompt_tokens=-1))
 
         assert len(results) == 1
         call_args = mock_async_tokenizer.call_args
@@ -157,8 +162,9 @@ class TestRenderPrompt:
         long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108,
                        109]  # 10 tokens
         results = await renderer.render_prompt(prompt_or_prompts=long_tokens,
-                                               max_length=100,
-                                               truncate_prompt_tokens=5)
+                                               config=RenderConfig(
+                                                   max_length=100,
+                                                   truncate_prompt_tokens=5))
 
         assert len(results) == 1
         # Should keep the last 5 tokens: [105, 106, 107, 108, 109]
@@ -170,7 +176,7 @@ class TestRenderPrompt:
 
         with pytest.raises(ValueError, match="maximum context length"):
             await renderer.render_prompt(prompt_or_prompts=long_tokens,
-                                         max_length=100)
+                                         config=RenderConfig(max_length=100))
 
     @pytest.mark.asyncio
     async def test_no_tokenizer_for_text(self, mock_model_config):
@@ -181,7 +187,8 @@ class TestRenderPrompt:
 
         with pytest.raises(ValueError, match="No tokenizer available"):
             await renderer_no_tokenizer.render_prompt(
-                prompt_or_prompts="Hello world", max_length=100)
+                prompt_or_prompts="Hello world",
+                config=RenderConfig(max_length=100))
 
     @pytest.mark.asyncio
     async def test_token_input_with_needs_detokenization(
@@ -196,7 +203,7 @@ class TestRenderPrompt:
         tokens = [1, 2, 3, 4]
         results = await renderer.render_prompt(
             prompt_or_prompts=tokens,
-            needs_detokenization=True,
+            config=RenderConfig(needs_detokenization=True),
         )
 
         assert len(results) == 1
@@ -221,7 +228,9 @@ class TestRenderEmbedPrompt:
         embed_bytes = self._create_test_embed_bytes(test_tensor)
 
         results = await renderer.render_prompt_and_embeds(
-            prompt_embeds=embed_bytes, cache_salt="test_salt")
+            prompt_embeds=embed_bytes,
+            config=RenderConfig(cache_salt="test_salt"),
+        )
 
         assert len(results) == 1
         assert is_embeds_prompt(results[0])
@@ -240,7 +249,9 @@ class TestRenderEmbedPrompt:
         ]
 
         results = await renderer.render_prompt_and_embeds(
-            prompt_embeds=embed_bytes_list)
+            prompt_embeds=embed_bytes_list,
+            config=RenderConfig(),
+        )
 
         assert len(results) == 2
         for i, result in enumerate(results):
@@ -254,7 +265,9 @@ class TestRenderEmbedPrompt:
         embed_bytes = self._create_test_embed_bytes(test_tensor)
 
         results = await renderer.render_prompt_and_embeds(
-            prompt_embeds=embed_bytes, truncate_prompt_tokens=10)
+            prompt_embeds=embed_bytes,
+            config=RenderConfig(truncate_prompt_tokens=10),
+        )
 
         assert len(results) == 1
         # Should keep last 10 tokens
@@ -271,7 +284,9 @@ class TestRenderEmbedPrompt:
         embed_bytes = self._create_test_embed_bytes(test_tensor)
 
         results = await renderer.render_prompt_and_embeds(
-            prompt_embeds=embed_bytes)
+            prompt_embeds=embed_bytes,
+            config=RenderConfig(),
+        )
 
         assert len(results) == 1
         assert results[0]["prompt_embeds"].dtype == dtype
@@ -283,7 +298,9 @@ class TestRenderEmbedPrompt:
         embed_bytes = self._create_test_embed_bytes(test_tensor)
 
         results = await renderer.render_prompt_and_embeds(
-            prompt_embeds=embed_bytes)
+            prompt_embeds=embed_bytes,
+            config=RenderConfig(),
+        )
 
         assert len(results) == 1
         # Should be squeezed to 2D
@@ -303,7 +320,10 @@ class TestRenderEmbedPrompt:
         embed_bytes = self._create_test_embed_bytes(test_tensor)
 
         results = await renderer.render_prompt_and_embeds(
-            prompt_or_prompts="Hello world", prompt_embeds=embed_bytes)
+            prompt_or_prompts="Hello world",
+            prompt_embeds=embed_bytes,
+            config=RenderConfig(),
+        )
 
         assert len(results) == 2
         # First should be embed prompt