mirror of
https://github.com/langgenius/dify.git
synced 2026-04-23 20:25:56 +08:00
feat: support structured output in sandbox and tool mode
Signed-off-by: Stream <Stream_2@qq.com>
This commit is contained in:
@ -1,10 +1,18 @@
|
||||
"""Function Call strategy implementation."""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Union
|
||||
from typing import Any, Literal, Protocol, Union, cast
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult
|
||||
from core.agent.entities import AgentLog, AgentOutputKind, AgentResult
|
||||
from core.agent.output_tools import (
|
||||
FINAL_OUTPUT_TOOL,
|
||||
FINAL_STRUCTURED_OUTPUT_TOOL,
|
||||
ILLEGAL_OUTPUT_TOOL,
|
||||
OUTPUT_TEXT_TOOL,
|
||||
OUTPUT_TOOL_NAME_SET,
|
||||
)
|
||||
from core.file import File
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
@ -42,9 +50,24 @@ class FunctionCallStrategy(AgentPattern):
|
||||
total_usage: dict[str, LLMUsage | None] = {"usage": None}
|
||||
messages: list[PromptMessage] = list(prompt_messages) # Create mutable copy
|
||||
final_text: str = ""
|
||||
structured_output_payload: dict[str, Any] | None = None
|
||||
output_text_payload: str | None = None
|
||||
finish_reason: str | None = None
|
||||
output_files: list[File] = [] # Track files produced by tools
|
||||
|
||||
class _LLMInvoker(Protocol):
|
||||
def invoke_llm(
|
||||
self,
|
||||
*,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict[str, Any],
|
||||
tools: list[PromptMessageTool],
|
||||
stop: list[str],
|
||||
stream: Literal[False],
|
||||
user: str | None,
|
||||
callbacks: list[Any],
|
||||
) -> LLMResult: ...
|
||||
|
||||
while function_call_state and iteration_step <= max_iterations:
|
||||
function_call_state = False
|
||||
round_log = self._create_log(
|
||||
@ -54,8 +77,11 @@ class FunctionCallStrategy(AgentPattern):
|
||||
data={},
|
||||
)
|
||||
yield round_log
|
||||
# On last iteration, remove tools to force final answer
|
||||
current_tools: list[PromptMessageTool] = [] if iteration_step == max_iterations else prompt_tools
|
||||
# On last iteration, restrict tools to output tools
|
||||
if iteration_step == max_iterations:
|
||||
current_tools = [tool for tool in prompt_tools if tool.name in OUTPUT_TOOL_NAME_SET]
|
||||
else:
|
||||
current_tools = prompt_tools
|
||||
model_log = self._create_log(
|
||||
label=f"{self.model_instance.model} Thought",
|
||||
log_type=AgentLog.LogType.THOUGHT,
|
||||
@ -72,31 +98,41 @@ class FunctionCallStrategy(AgentPattern):
|
||||
round_usage: dict[str, LLMUsage | None] = {"usage": None}
|
||||
|
||||
# Invoke model
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
|
||||
invoker = cast(_LLMInvoker, self.model_instance)
|
||||
chunks = invoker.invoke_llm(
|
||||
prompt_messages=messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=current_tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream=False,
|
||||
user=self.context.user_id,
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
# Process response
|
||||
tool_calls, response_content, chunk_finish_reason = yield from self._handle_chunks(
|
||||
chunks, round_usage, model_log
|
||||
chunks, round_usage, model_log, emit_chunks=False
|
||||
)
|
||||
messages.append(self._create_assistant_message(response_content, tool_calls))
|
||||
|
||||
if not tool_calls:
|
||||
tool_calls = [
|
||||
(
|
||||
str(uuid.uuid4()),
|
||||
ILLEGAL_OUTPUT_TOOL,
|
||||
{
|
||||
"raw": response_content,
|
||||
},
|
||||
)
|
||||
]
|
||||
response_content = ""
|
||||
|
||||
messages.append(self._create_assistant_message("", tool_calls))
|
||||
|
||||
# Accumulate to total usage
|
||||
round_usage_value = round_usage.get("usage")
|
||||
if round_usage_value:
|
||||
self._accumulate_usage(total_usage, round_usage_value)
|
||||
|
||||
# Update final text if no tool calls (this is likely the final answer)
|
||||
if not tool_calls:
|
||||
final_text = response_content
|
||||
|
||||
# Update finish reason
|
||||
if chunk_finish_reason:
|
||||
finish_reason = chunk_finish_reason
|
||||
@ -105,14 +141,27 @@ class FunctionCallStrategy(AgentPattern):
|
||||
tool_outputs: dict[str, str] = {}
|
||||
if tool_calls:
|
||||
function_call_state = True
|
||||
terminal_tool_seen = False
|
||||
# Execute tools
|
||||
for tool_call_id, tool_name, tool_args in tool_calls:
|
||||
if tool_name == OUTPUT_TEXT_TOOL:
|
||||
output_text_payload = self._format_output_text(tool_args.get("text"))
|
||||
elif tool_name == FINAL_STRUCTURED_OUTPUT_TOOL:
|
||||
data = tool_args.get("data")
|
||||
structured_output_payload = cast(dict[str, Any] | None, data)
|
||||
elif tool_name == FINAL_OUTPUT_TOOL:
|
||||
final_text = self._format_output_text(tool_args.get("text"))
|
||||
terminal_tool_seen = True
|
||||
|
||||
tool_response, tool_files, _ = yield from self._handle_tool_call(
|
||||
tool_name, tool_args, tool_call_id, messages, round_log
|
||||
)
|
||||
tool_outputs[tool_name] = tool_response
|
||||
# Track files produced by tools
|
||||
output_files.extend(tool_files)
|
||||
|
||||
if terminal_tool_seen:
|
||||
function_call_state = False
|
||||
yield self._finish_log(
|
||||
round_log,
|
||||
data={
|
||||
@ -131,8 +180,28 @@ class FunctionCallStrategy(AgentPattern):
|
||||
# Return final result
|
||||
from core.agent.entities import AgentResult
|
||||
|
||||
output_payload: str | AgentResult.StructuredOutput
|
||||
if final_text:
|
||||
output_payload = AgentResult.StructuredOutput(
|
||||
output_kind=AgentOutputKind.FINAL_OUTPUT_ANSWER,
|
||||
output_text=final_text,
|
||||
output_data=structured_output_payload,
|
||||
)
|
||||
elif output_text_payload:
|
||||
output_payload = AgentResult.StructuredOutput(
|
||||
output_kind=AgentOutputKind.OUTPUT_TEXT,
|
||||
output_text=str(output_text_payload),
|
||||
output_data=None,
|
||||
)
|
||||
else:
|
||||
output_payload = AgentResult.StructuredOutput(
|
||||
output_kind=AgentOutputKind.ILLEGAL_OUTPUT,
|
||||
output_text="Model failed to produce a final output.",
|
||||
output_data=None,
|
||||
)
|
||||
|
||||
return AgentResult(
|
||||
text=final_text,
|
||||
output=output_payload,
|
||||
files=output_files,
|
||||
usage=total_usage.get("usage") or LLMUsage.empty_usage(),
|
||||
finish_reason=finish_reason,
|
||||
@ -143,6 +212,8 @@ class FunctionCallStrategy(AgentPattern):
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
|
||||
llm_usage: dict[str, LLMUsage | None],
|
||||
start_log: AgentLog,
|
||||
*,
|
||||
emit_chunks: bool,
|
||||
) -> Generator[
|
||||
LLMResultChunk | AgentLog,
|
||||
None,
|
||||
@ -174,7 +245,8 @@ class FunctionCallStrategy(AgentPattern):
|
||||
if chunk.delta.finish_reason:
|
||||
finish_reason = chunk.delta.finish_reason
|
||||
|
||||
yield chunk
|
||||
if emit_chunks:
|
||||
yield chunk
|
||||
else:
|
||||
# Non-streaming response
|
||||
result: LLMResult = chunks
|
||||
@ -189,11 +261,12 @@ class FunctionCallStrategy(AgentPattern):
|
||||
self._accumulate_usage(llm_usage, result.usage)
|
||||
|
||||
# Convert to streaming format
|
||||
yield LLMResultChunk(
|
||||
model=result.model,
|
||||
prompt_messages=result.prompt_messages,
|
||||
delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage),
|
||||
)
|
||||
if emit_chunks:
|
||||
yield LLMResultChunk(
|
||||
model=result.model,
|
||||
prompt_messages=result.prompt_messages,
|
||||
delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage),
|
||||
)
|
||||
yield self._finish_log(
|
||||
start_log,
|
||||
data={
|
||||
@ -203,6 +276,14 @@ class FunctionCallStrategy(AgentPattern):
|
||||
)
|
||||
return tool_calls, response_content, finish_reason
|
||||
|
||||
@staticmethod
|
||||
def _format_output_text(value: Any) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return json.dumps(value, ensure_ascii=False)
|
||||
|
||||
def _create_assistant_message(
|
||||
self, content: str, tool_calls: list[tuple[str, str, dict[str, Any]]] | None = None
|
||||
) -> AssistantPromptMessage:
|
||||
|
||||
Reference in New Issue
Block a user