[Frontend] Implement Tool Calling with tool_choice='required' (#13483)
Signed-off-by: Liangfu Chen <liangfc@amazon.com> Signed-off-by: Matt, Matthias <matthias.matt@tuwien.ac.at> Co-authored-by: Liangfu Chen <liangfc@amazon.com> Co-authored-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
@ -786,56 +786,135 @@ async def test_named_tool_use(client: openai.AsyncOpenAI, is_v1_server: bool,
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_required_tool_use_not_yet_supported(client: openai.AsyncOpenAI,
|
||||
sample_json_schema):
|
||||
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_required_tool_use(client: openai.AsyncOpenAI,
|
||||
is_v1_server: bool, model_name: str):
|
||||
if is_v1_server:
|
||||
pytest.skip("sample_json_schema has features unsupported on V1")
|
||||
pytest.skip(
|
||||
"tool_choice='required' requires features unsupported on V1")
|
||||
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": "you are a helpful assistant"
|
||||
}, {
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
f"Give an example JSON for an employee profile that "
|
||||
f"fits this schema: {sample_json_schema}"
|
||||
}]
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description":
|
||||
"The city to find the weather for, e.g. 'Vienna'",
|
||||
"default": "Vienna",
|
||||
},
|
||||
"country": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"The country that the city is in, e.g. 'Austria'",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"description":
|
||||
"The unit to fetch the temperature in",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
},
|
||||
"required": ["country", "unit"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_forecast",
|
||||
"description": "Get the weather forecast for a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description":
|
||||
"The city to get the forecast for, e.g. 'Vienna'",
|
||||
"default": "Vienna",
|
||||
},
|
||||
"country": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"The country that the city is in, e.g. 'Austria'",
|
||||
},
|
||||
"days": {
|
||||
"type":
|
||||
"integer",
|
||||
"description":
|
||||
"Number of days to get the forecast for (1-7)",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"description":
|
||||
"The unit to fetch the temperature in",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
},
|
||||
"required": ["country", "days", "unit"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
with pytest.raises(openai.BadRequestError):
|
||||
await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=messages,
|
||||
max_completion_tokens=1000,
|
||||
tools=[{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "dummy_function_name",
|
||||
"description": "This is a dummy function",
|
||||
"parameters": sample_json_schema
|
||||
}
|
||||
}],
|
||||
tool_choice="required")
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hi! How are you doing today?"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I'm doing well! How can I help you?"
|
||||
},
|
||||
{
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"Can you tell me what the current weather is in Berlin and the "\
|
||||
"forecast for the next 5 days, in fahrenheit?",
|
||||
},
|
||||
]
|
||||
|
||||
with pytest.raises(openai.BadRequestError):
|
||||
await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=messages,
|
||||
max_completion_tokens=1000,
|
||||
tools=[{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "dummy_function_name",
|
||||
"description": "This is a dummy function",
|
||||
"parameters": sample_json_schema
|
||||
}
|
||||
}],
|
||||
tool_choice="auto")
|
||||
# Non-streaming test
|
||||
chat_completion = await client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=model_name,
|
||||
tools=tools,
|
||||
tool_choice="required",
|
||||
extra_body=dict(guided_decoding_backend="outlines"),
|
||||
)
|
||||
|
||||
assert chat_completion.choices[0].message.tool_calls is not None
|
||||
assert len(chat_completion.choices[0].message.tool_calls) > 0
|
||||
|
||||
# Streaming test
|
||||
stream = await client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=model_name,
|
||||
tools=tools,
|
||||
tool_choice="required",
|
||||
extra_body=dict(guided_decoding_backend="outlines"),
|
||||
stream=True,
|
||||
)
|
||||
|
||||
output = []
|
||||
async for chunk in stream:
|
||||
if chunk.choices and chunk.choices[0].delta.tool_calls:
|
||||
output.extend(chunk.choices[0].delta.tool_calls)
|
||||
|
||||
assert len(output) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inconsistent_tool_choice_and_tools(client: openai.AsyncOpenAI,
|
||||
is_v1_server: bool,
|
||||
sample_json_schema):
|
||||
|
||||
if is_v1_server:
|
||||
|
||||
@ -43,7 +43,8 @@ def test_chat_completion_request_with_no_tools():
|
||||
assert request.tool_choice == 'none'
|
||||
|
||||
|
||||
def test_chat_completion_request_with_tool_choice_but_no_tools():
|
||||
@pytest.mark.parametrize('tool_choice', ['auto', 'required'])
|
||||
def test_chat_completion_request_with_tool_choice_but_no_tools(tool_choice):
|
||||
with pytest.raises(ValueError,
|
||||
match="When using `tool_choice`, `tools` must be set."):
|
||||
ChatCompletionRequest.model_validate({
|
||||
@ -54,7 +55,7 @@ def test_chat_completion_request_with_tool_choice_but_no_tools():
|
||||
'model':
|
||||
'facebook/opt-125m',
|
||||
'tool_choice':
|
||||
'auto'
|
||||
tool_choice
|
||||
})
|
||||
|
||||
with pytest.raises(ValueError,
|
||||
@ -67,7 +68,7 @@ def test_chat_completion_request_with_tool_choice_but_no_tools():
|
||||
'model':
|
||||
'facebook/opt-125m',
|
||||
'tool_choice':
|
||||
'auto',
|
||||
tool_choice,
|
||||
'tools':
|
||||
None
|
||||
})
|
||||
|
||||
336
tests/tool_use/test_tool_choice_required.py
Normal file
336
tests/tool_use/test_tool_choice_required.py
Normal file
@ -0,0 +1,336 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import json
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
ChatCompletionToolsParam)
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
|
||||
EXAMPLE_TOOLS = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"The city to find the weather for"
|
||||
", e.g. 'San Francisco'",
|
||||
},
|
||||
},
|
||||
"required": ["city"],
|
||||
"additionalProperties": False
|
||||
},
|
||||
},
|
||||
"strict": True
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_forecast",
|
||||
"description": "Get the weather forecast for a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"The city to get the forecast for, e.g. 'New York'",
|
||||
},
|
||||
"days": {
|
||||
"type":
|
||||
"integer",
|
||||
"description":
|
||||
"Number of days to get the forecast for (1-7)",
|
||||
},
|
||||
},
|
||||
"required": ["city", "days"],
|
||||
"additionalProperties": False
|
||||
},
|
||||
},
|
||||
"strict": True
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def _compile_and_check(tools: list[ChatCompletionToolsParam], sample_output,
|
||||
should_match: bool):
|
||||
self = MagicMock(tool_choice="required", tools=tools)
|
||||
schema = ChatCompletionRequest._get_guided_json_from_tool(self)
|
||||
assert isinstance(schema, dict)
|
||||
|
||||
# use build_regex_from_schema used in JSONLogitsProcessor to create Guide
|
||||
from outlines_core.fsm.json_schema import build_regex_from_schema
|
||||
regex = build_regex_from_schema(json.dumps(schema))
|
||||
compiled = re.compile(regex)
|
||||
matches = compiled.fullmatch(json.dumps(sample_output)) is not None
|
||||
|
||||
assert matches == should_match
|
||||
|
||||
|
||||
VALID_TOOL_OUTPUTS = [
|
||||
([{
|
||||
"name": "get_current_weather",
|
||||
"parameters": {
|
||||
"city": "Vienna"
|
||||
}
|
||||
}], True),
|
||||
([{
|
||||
"name": "get_current_weather",
|
||||
"parameters": {
|
||||
"city": "Vienna"
|
||||
}
|
||||
}, {
|
||||
"name": "get_current_weather",
|
||||
"parameters": {
|
||||
"city": "Berlin"
|
||||
}
|
||||
}], True),
|
||||
([{
|
||||
"name": "get_forecast",
|
||||
"parameters": {
|
||||
"city": "Vienna",
|
||||
"days": 7
|
||||
}
|
||||
}], True),
|
||||
([{
|
||||
"name": "get_forecast",
|
||||
"parameters": {
|
||||
"city": "Vienna",
|
||||
"days": 7
|
||||
}
|
||||
}, {
|
||||
"name": "get_current_weather",
|
||||
"parameters": {
|
||||
"city": "Vienna"
|
||||
}
|
||||
}], True),
|
||||
([{
|
||||
"name": "get_forecast",
|
||||
"parameters": {
|
||||
"city": "Vienna",
|
||||
"days": 7
|
||||
}
|
||||
}, {
|
||||
"name": "get_current_weather",
|
||||
"parameters": {
|
||||
"city": "Vienna"
|
||||
}
|
||||
}, {
|
||||
"name": "get_forecast",
|
||||
"parameters": {
|
||||
"city": "Berlin",
|
||||
"days": 7
|
||||
}
|
||||
}, {
|
||||
"name": "get_current_weather",
|
||||
"parameters": {
|
||||
"city": "Berlin"
|
||||
}
|
||||
}], True),
|
||||
]
|
||||
|
||||
VALID_TOOLS = [t[0] for t in VALID_TOOL_OUTPUTS]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sample_output, should_match",
|
||||
VALID_TOOL_OUTPUTS + [
|
||||
(None, False),
|
||||
([], False), # empty list cannot be generated
|
||||
({}, False), # empty object cannot be generated
|
||||
([{}], False), # list with empty object cannot be generated
|
||||
(
|
||||
[{ # function without required parameters cannot be generated
|
||||
"name": "get_current_weather"
|
||||
}],
|
||||
False),
|
||||
(
|
||||
[{ # function without required parameters cannot be generated
|
||||
"name": "get_current_weather",
|
||||
"parameters": {}
|
||||
}],
|
||||
False),
|
||||
(
|
||||
[{ # function without required parameters cannot be generated
|
||||
"name": "get_current_weather",
|
||||
"parameters": None
|
||||
}],
|
||||
False),
|
||||
(
|
||||
{ # tool call without lists cannot be generated
|
||||
"name": "get_current_weather",
|
||||
"parameters": {
|
||||
"city": "Vienna"
|
||||
}
|
||||
},
|
||||
False),
|
||||
(
|
||||
[{ # tool call with extra parameters cannot be generated
|
||||
"name": "get_current_weather",
|
||||
"parameters": {
|
||||
"city": "Vienna",
|
||||
"extra": "value"
|
||||
}
|
||||
}],
|
||||
False),
|
||||
(
|
||||
[{ # tool call where parameters are first cannot be generated
|
||||
"parameters": {
|
||||
"city": "Vienna"
|
||||
},
|
||||
"name": "get_current_weather"
|
||||
}],
|
||||
False),
|
||||
(
|
||||
[{ # tool call without all required parameters cannot be generated
|
||||
"name": "get_forecast",
|
||||
"parameters": {
|
||||
"city": "Vienna"
|
||||
}
|
||||
}],
|
||||
False),
|
||||
( # tool call with incorrect name/parameters cannot be generated
|
||||
[{
|
||||
"name": "get_weather",
|
||||
"parameters": {
|
||||
"city": "Vienna",
|
||||
"days": 7
|
||||
}
|
||||
}], False),
|
||||
( # tool call with both valid and empty function cannot be generated
|
||||
[{
|
||||
"name": "get_current_weather",
|
||||
"parameters": {
|
||||
"city": "Vienna"
|
||||
}
|
||||
}, {}], False),
|
||||
])
|
||||
def test_guided_json(sample_output, should_match):
|
||||
_compile_and_check(tools=TypeAdapter(
|
||||
list[ChatCompletionToolsParam]).validate_python(EXAMPLE_TOOLS),
|
||||
sample_output=sample_output,
|
||||
should_match=should_match)
|
||||
|
||||
|
||||
def update_parameters_none(
|
||||
tool: ChatCompletionToolsParam) -> ChatCompletionToolsParam:
|
||||
tool.function.parameters = None
|
||||
return tool
|
||||
|
||||
|
||||
def update_parameters_empty_dict(
|
||||
tool: ChatCompletionToolsParam) -> ChatCompletionToolsParam:
|
||||
tool.function.parameters = {}
|
||||
return tool
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sample_output, should_match",
|
||||
[
|
||||
(None, False),
|
||||
([], False), # empty list cannot be generated
|
||||
({}, False), # empty object cannot be generated
|
||||
([{}], False), # list with empty object cannot be generated
|
||||
(
|
||||
[{ # function without required parameters cannot be generated
|
||||
"name": "get_current_weather"
|
||||
}],
|
||||
False),
|
||||
(
|
||||
[{ # function without required parameters cannot be generated
|
||||
"name": "get_current_weather",
|
||||
"parameters": None
|
||||
}],
|
||||
False),
|
||||
(
|
||||
[{ # function with extra parameters cannot be generated
|
||||
"name": "get_current_weather",
|
||||
"parameters": {
|
||||
"extra": "value"
|
||||
}
|
||||
}],
|
||||
False),
|
||||
(
|
||||
[{ # only function with empty parameters object is valid
|
||||
"name": "get_current_weather",
|
||||
"parameters": {}
|
||||
}],
|
||||
True),
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"update_parameters",
|
||||
[update_parameters_none, update_parameters_empty_dict])
|
||||
def test_guided_json_without_parameters(sample_output, should_match,
|
||||
update_parameters):
|
||||
updated_tools = [deepcopy(EXAMPLE_TOOLS[0])]
|
||||
tools = TypeAdapter(
|
||||
list[ChatCompletionToolsParam]).validate_python(updated_tools)
|
||||
tools = list(map(update_parameters, tools))
|
||||
assert all([
|
||||
tool.function.parameters is None or tool.function.parameters == {}
|
||||
for tool in tools
|
||||
])
|
||||
_compile_and_check(tools=tools,
|
||||
sample_output=sample_output,
|
||||
should_match=should_match)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("output", VALID_TOOLS)
|
||||
@pytest.mark.parametrize("empty_params", [False, True])
|
||||
@pytest.mark.parametrize("delta_len", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
|
||||
def test_streaming_output_valid(output, empty_params, delta_len):
|
||||
self = MagicMock()
|
||||
|
||||
output = deepcopy(output)
|
||||
if empty_params:
|
||||
output = [{"name": o["name"], "parameters": {}} for o in output]
|
||||
output_json = json.dumps(output)
|
||||
|
||||
previous_text = ""
|
||||
function_name_returned = False
|
||||
messages = []
|
||||
for i in range(0, len(output_json), delta_len):
|
||||
delta_text = output_json[i:i + delta_len]
|
||||
current_text = previous_text + delta_text
|
||||
|
||||
delta_message, function_name_returned = (
|
||||
OpenAIServingChat.extract_tool_call_required_streaming(
|
||||
self,
|
||||
previous_text=previous_text,
|
||||
current_text=current_text,
|
||||
delta_text=delta_text,
|
||||
function_name_returned=function_name_returned))
|
||||
|
||||
if delta_message:
|
||||
messages.append(delta_message)
|
||||
|
||||
previous_text = current_text
|
||||
|
||||
assert len(messages) > 0
|
||||
combined_messages = "["
|
||||
for message in messages:
|
||||
if message.tool_calls[0].function.name:
|
||||
if len(combined_messages) > 1:
|
||||
combined_messages += "},"
|
||||
|
||||
combined_messages += '{"name": "' + \
|
||||
message.tool_calls[0].function.name + \
|
||||
'", "parameters": ' + \
|
||||
message.tool_calls[0].function.arguments
|
||||
else:
|
||||
combined_messages += message.tool_calls[0].function.arguments
|
||||
combined_messages += "}]"
|
||||
assert json.loads(combined_messages) == output
|
||||
assert json.dumps(json.loads(combined_messages)) == output_json
|
||||
Reference in New Issue
Block a user