feat: support structured output in sandbox and tool mode

Signed-off-by: Stream <Stream_2@qq.com>
This commit is contained in:
Stream
2026-01-30 06:46:38 +08:00
parent d3fc457331
commit 7926024569
22 changed files with 629 additions and 117 deletions

View File

@ -1,10 +1,18 @@
"""Function Call strategy implementation."""
import json
import uuid
from collections.abc import Generator
from typing import Any, Union
from typing import Any, Literal, Protocol, Union, cast
from core.agent.entities import AgentLog, AgentResult
from core.agent.entities import AgentLog, AgentOutputKind, AgentResult
from core.agent.output_tools import (
FINAL_OUTPUT_TOOL,
FINAL_STRUCTURED_OUTPUT_TOOL,
ILLEGAL_OUTPUT_TOOL,
OUTPUT_TEXT_TOOL,
OUTPUT_TOOL_NAME_SET,
)
from core.file import File
from core.model_runtime.entities import (
AssistantPromptMessage,
@ -42,9 +50,24 @@ class FunctionCallStrategy(AgentPattern):
total_usage: dict[str, LLMUsage | None] = {"usage": None}
messages: list[PromptMessage] = list(prompt_messages) # Create mutable copy
final_text: str = ""
structured_output_payload: dict[str, Any] | None = None
output_text_payload: str | None = None
finish_reason: str | None = None
output_files: list[File] = [] # Track files produced by tools
class _LLMInvoker(Protocol):
def invoke_llm(
self,
*,
prompt_messages: list[PromptMessage],
model_parameters: dict[str, Any],
tools: list[PromptMessageTool],
stop: list[str],
stream: Literal[False],
user: str | None,
callbacks: list[Any],
) -> LLMResult: ...
while function_call_state and iteration_step <= max_iterations:
function_call_state = False
round_log = self._create_log(
@ -54,8 +77,11 @@ class FunctionCallStrategy(AgentPattern):
data={},
)
yield round_log
# On last iteration, remove tools to force final answer
current_tools: list[PromptMessageTool] = [] if iteration_step == max_iterations else prompt_tools
# On last iteration, restrict tools to output tools
if iteration_step == max_iterations:
current_tools = [tool for tool in prompt_tools if tool.name in OUTPUT_TOOL_NAME_SET]
else:
current_tools = prompt_tools
model_log = self._create_log(
label=f"{self.model_instance.model} Thought",
log_type=AgentLog.LogType.THOUGHT,
@ -72,31 +98,41 @@ class FunctionCallStrategy(AgentPattern):
round_usage: dict[str, LLMUsage | None] = {"usage": None}
# Invoke model
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
invoker = cast(_LLMInvoker, self.model_instance)
chunks = invoker.invoke_llm(
prompt_messages=messages,
model_parameters=model_parameters,
tools=current_tools,
stop=stop,
stream=stream,
stream=False,
user=self.context.user_id,
callbacks=[],
)
# Process response
tool_calls, response_content, chunk_finish_reason = yield from self._handle_chunks(
chunks, round_usage, model_log
chunks, round_usage, model_log, emit_chunks=False
)
messages.append(self._create_assistant_message(response_content, tool_calls))
if not tool_calls:
tool_calls = [
(
str(uuid.uuid4()),
ILLEGAL_OUTPUT_TOOL,
{
"raw": response_content,
},
)
]
response_content = ""
messages.append(self._create_assistant_message("", tool_calls))
# Accumulate to total usage
round_usage_value = round_usage.get("usage")
if round_usage_value:
self._accumulate_usage(total_usage, round_usage_value)
# Update final text if no tool calls (this is likely the final answer)
if not tool_calls:
final_text = response_content
# Update finish reason
if chunk_finish_reason:
finish_reason = chunk_finish_reason
@ -105,14 +141,27 @@ class FunctionCallStrategy(AgentPattern):
tool_outputs: dict[str, str] = {}
if tool_calls:
function_call_state = True
terminal_tool_seen = False
# Execute tools
for tool_call_id, tool_name, tool_args in tool_calls:
if tool_name == OUTPUT_TEXT_TOOL:
output_text_payload = self._format_output_text(tool_args.get("text"))
elif tool_name == FINAL_STRUCTURED_OUTPUT_TOOL:
data = tool_args.get("data")
structured_output_payload = cast(dict[str, Any] | None, data)
elif tool_name == FINAL_OUTPUT_TOOL:
final_text = self._format_output_text(tool_args.get("text"))
terminal_tool_seen = True
tool_response, tool_files, _ = yield from self._handle_tool_call(
tool_name, tool_args, tool_call_id, messages, round_log
)
tool_outputs[tool_name] = tool_response
# Track files produced by tools
output_files.extend(tool_files)
if terminal_tool_seen:
function_call_state = False
yield self._finish_log(
round_log,
data={
@ -131,8 +180,28 @@ class FunctionCallStrategy(AgentPattern):
# Return final result
from core.agent.entities import AgentResult
output_payload: str | AgentResult.StructuredOutput
if final_text:
output_payload = AgentResult.StructuredOutput(
output_kind=AgentOutputKind.FINAL_OUTPUT_ANSWER,
output_text=final_text,
output_data=structured_output_payload,
)
elif output_text_payload:
output_payload = AgentResult.StructuredOutput(
output_kind=AgentOutputKind.OUTPUT_TEXT,
output_text=str(output_text_payload),
output_data=None,
)
else:
output_payload = AgentResult.StructuredOutput(
output_kind=AgentOutputKind.ILLEGAL_OUTPUT,
output_text="Model failed to produce a final output.",
output_data=None,
)
return AgentResult(
text=final_text,
output=output_payload,
files=output_files,
usage=total_usage.get("usage") or LLMUsage.empty_usage(),
finish_reason=finish_reason,
@ -143,6 +212,8 @@ class FunctionCallStrategy(AgentPattern):
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
llm_usage: dict[str, LLMUsage | None],
start_log: AgentLog,
*,
emit_chunks: bool,
) -> Generator[
LLMResultChunk | AgentLog,
None,
@ -174,7 +245,8 @@ class FunctionCallStrategy(AgentPattern):
if chunk.delta.finish_reason:
finish_reason = chunk.delta.finish_reason
yield chunk
if emit_chunks:
yield chunk
else:
# Non-streaming response
result: LLMResult = chunks
@ -189,11 +261,12 @@ class FunctionCallStrategy(AgentPattern):
self._accumulate_usage(llm_usage, result.usage)
# Convert to streaming format
yield LLMResultChunk(
model=result.model,
prompt_messages=result.prompt_messages,
delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage),
)
if emit_chunks:
yield LLMResultChunk(
model=result.model,
prompt_messages=result.prompt_messages,
delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage),
)
yield self._finish_log(
start_log,
data={
@ -203,6 +276,14 @@ class FunctionCallStrategy(AgentPattern):
)
return tool_calls, response_content, finish_reason
@staticmethod
def _format_output_text(value: Any) -> str:
if value is None:
return ""
if isinstance(value, str):
return value
return json.dumps(value, ensure_ascii=False)
def _create_assistant_message(
self, content: str, tool_calls: list[tuple[str, str, dict[str, Any]]] | None = None
) -> AssistantPromptMessage: