mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 17:08:03 +08:00
feat: support structured output in sandbox and tool mode
Signed-off-by: Stream <Stream_2@qq.com>
This commit is contained in:
@ -4,10 +4,17 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
from typing import TYPE_CHECKING, Any, Union, cast
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult, AgentScratchpadUnit, ExecutionContext
|
||||
from core.agent.entities import AgentLog, AgentOutputKind, AgentResult, AgentScratchpadUnit, ExecutionContext
|
||||
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
|
||||
from core.agent.output_tools import (
|
||||
FINAL_OUTPUT_TOOL,
|
||||
FINAL_STRUCTURED_OUTPUT_TOOL,
|
||||
ILLEGAL_OUTPUT_TOOL,
|
||||
OUTPUT_TEXT_TOOL,
|
||||
OUTPUT_TOOL_NAME_SET,
|
||||
)
|
||||
from core.file import File
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import (
|
||||
@ -67,6 +74,8 @@ class ReActStrategy(AgentPattern):
|
||||
total_usage: dict[str, Any] = {"usage": None}
|
||||
output_files: list[File] = [] # Track files produced by tools
|
||||
final_text: str = ""
|
||||
structured_output_payload: dict[str, Any] | None = None
|
||||
output_text_payload: str | None = None
|
||||
finish_reason: str | None = None
|
||||
|
||||
# Add "Observation" to stop sequences
|
||||
@ -84,10 +93,13 @@ class ReActStrategy(AgentPattern):
|
||||
)
|
||||
yield round_log
|
||||
|
||||
# Build prompt with/without tools based on iteration
|
||||
include_tools = iteration_step < max_iterations
|
||||
# Build prompt with tool restrictions on last iteration
|
||||
if iteration_step == max_iterations:
|
||||
tools_for_prompt = [tool for tool in self.tools if tool.entity.identity.name in OUTPUT_TOOL_NAME_SET]
|
||||
else:
|
||||
tools_for_prompt = self.tools
|
||||
current_messages = self._build_prompt_with_react_format(
|
||||
prompt_messages, agent_scratchpad, include_tools, self.instruction
|
||||
prompt_messages, agent_scratchpad, tools_for_prompt, self.instruction
|
||||
)
|
||||
|
||||
model_log = self._create_log(
|
||||
@ -109,18 +121,21 @@ class ReActStrategy(AgentPattern):
|
||||
messages_to_use = current_messages
|
||||
|
||||
# Invoke model
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
|
||||
prompt_messages=messages_to_use,
|
||||
model_parameters=model_parameters,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=self.context.user_id or "",
|
||||
callbacks=[],
|
||||
chunks = cast(
|
||||
Union[Generator[LLMResultChunk, None, None], LLMResult],
|
||||
self.model_instance.invoke_llm(
|
||||
prompt_messages=messages_to_use,
|
||||
model_parameters=model_parameters,
|
||||
stop=stop,
|
||||
stream=False,
|
||||
user=self.context.user_id or "",
|
||||
callbacks=[],
|
||||
),
|
||||
)
|
||||
|
||||
# Process response
|
||||
scratchpad, chunk_finish_reason = yield from self._handle_chunks(
|
||||
chunks, round_usage, model_log, current_messages
|
||||
chunks, round_usage, model_log, current_messages, emit_chunks=False
|
||||
)
|
||||
agent_scratchpad.append(scratchpad)
|
||||
|
||||
@ -134,28 +149,46 @@ class ReActStrategy(AgentPattern):
|
||||
finish_reason = chunk_finish_reason
|
||||
|
||||
# Check if we have an action to execute
|
||||
if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
|
||||
react_state = True
|
||||
# Execute tool
|
||||
observation, tool_files = yield from self._handle_tool_call(
|
||||
scratchpad.action, current_messages, round_log
|
||||
if scratchpad.action is None:
|
||||
illegal_action = AgentScratchpadUnit.Action(
|
||||
action_name=ILLEGAL_OUTPUT_TOOL,
|
||||
action_input={"raw": scratchpad.thought or ""},
|
||||
)
|
||||
scratchpad.action = illegal_action
|
||||
scratchpad.action_str = json.dumps(illegal_action.to_dict())
|
||||
react_state = True
|
||||
observation, tool_files = yield from self._handle_tool_call(illegal_action, current_messages, round_log)
|
||||
scratchpad.observation = observation
|
||||
# Track files produced by tools
|
||||
output_files.extend(tool_files)
|
||||
|
||||
# Add observation to scratchpad for display
|
||||
yield self._create_text_chunk(f"\nObservation: {observation}\n", current_messages)
|
||||
else:
|
||||
# Extract final answer
|
||||
if scratchpad.action and scratchpad.action.action_input:
|
||||
final_answer = scratchpad.action.action_input
|
||||
if isinstance(final_answer, dict):
|
||||
final_answer = json.dumps(final_answer, ensure_ascii=False)
|
||||
final_text = str(final_answer)
|
||||
elif scratchpad.thought:
|
||||
# If no action but we have thought, use thought as final answer
|
||||
final_text = scratchpad.thought
|
||||
action_name = scratchpad.action.action_name
|
||||
if action_name == FINAL_OUTPUT_TOOL:
|
||||
if isinstance(scratchpad.action.action_input, dict):
|
||||
final_text = self._format_output_text(scratchpad.action.action_input.get("text"))
|
||||
else:
|
||||
final_text = self._format_output_text(scratchpad.action.action_input)
|
||||
observation, tool_files = yield from self._handle_tool_call(
|
||||
scratchpad.action, current_messages, round_log
|
||||
)
|
||||
scratchpad.observation = observation
|
||||
output_files.extend(tool_files)
|
||||
react_state = False
|
||||
else:
|
||||
if action_name == OUTPUT_TEXT_TOOL and isinstance(scratchpad.action.action_input, dict):
|
||||
output_text_payload = scratchpad.action.action_input.get("text")
|
||||
elif action_name == FINAL_STRUCTURED_OUTPUT_TOOL and isinstance(
|
||||
scratchpad.action.action_input, dict
|
||||
):
|
||||
data = scratchpad.action.action_input.get("data")
|
||||
if isinstance(data, dict):
|
||||
structured_output_payload = data
|
||||
|
||||
react_state = True
|
||||
observation, tool_files = yield from self._handle_tool_call(
|
||||
scratchpad.action, current_messages, round_log
|
||||
)
|
||||
scratchpad.observation = observation
|
||||
output_files.extend(tool_files)
|
||||
|
||||
yield self._finish_log(
|
||||
round_log,
|
||||
@ -173,15 +206,38 @@ class ReActStrategy(AgentPattern):
|
||||
|
||||
from core.agent.entities import AgentResult
|
||||
|
||||
output_payload: str | AgentResult.StructuredOutput
|
||||
if final_text:
|
||||
output_payload = AgentResult.StructuredOutput(
|
||||
output_kind=AgentOutputKind.FINAL_OUTPUT_ANSWER,
|
||||
output_text=final_text,
|
||||
output_data=structured_output_payload,
|
||||
)
|
||||
elif output_text_payload:
|
||||
output_payload = AgentResult.StructuredOutput(
|
||||
output_kind=AgentOutputKind.OUTPUT_TEXT,
|
||||
output_text=str(output_text_payload),
|
||||
output_data=structured_output_payload,
|
||||
)
|
||||
else:
|
||||
output_payload = AgentResult.StructuredOutput(
|
||||
output_kind=AgentOutputKind.ILLEGAL_OUTPUT,
|
||||
output_text="Model failed to produce a final output.",
|
||||
output_data=structured_output_payload,
|
||||
)
|
||||
|
||||
return AgentResult(
|
||||
text=final_text, files=output_files, usage=total_usage.get("usage"), finish_reason=finish_reason
|
||||
output=output_payload,
|
||||
files=output_files,
|
||||
usage=total_usage.get("usage"),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
|
||||
def _build_prompt_with_react_format(
|
||||
self,
|
||||
original_messages: list[PromptMessage],
|
||||
agent_scratchpad: list[AgentScratchpadUnit],
|
||||
include_tools: bool = True,
|
||||
tools: list[Tool] | None,
|
||||
instruction: str = "",
|
||||
) -> list[PromptMessage]:
|
||||
"""Build prompt messages with ReAct format."""
|
||||
@ -198,9 +254,9 @@ class ReActStrategy(AgentPattern):
|
||||
# Format tools
|
||||
tools_str = ""
|
||||
tool_names = []
|
||||
if include_tools and self.tools:
|
||||
if tools:
|
||||
# Convert tools to prompt message tools format
|
||||
prompt_tools = [tool.to_prompt_message_tool() for tool in self.tools]
|
||||
prompt_tools = [tool.to_prompt_message_tool() for tool in tools]
|
||||
tool_names = [tool.name for tool in prompt_tools]
|
||||
|
||||
# Format tools as JSON for comprehensive information
|
||||
@ -253,6 +309,8 @@ class ReActStrategy(AgentPattern):
|
||||
llm_usage: dict[str, Any],
|
||||
model_log: AgentLog,
|
||||
current_messages: list[PromptMessage],
|
||||
*,
|
||||
emit_chunks: bool,
|
||||
) -> Generator[
|
||||
LLMResultChunk | AgentLog,
|
||||
None,
|
||||
@ -306,14 +364,16 @@ class ReActStrategy(AgentPattern):
|
||||
scratchpad.action_str = action_str
|
||||
scratchpad.action = chunk
|
||||
|
||||
yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages)
|
||||
if emit_chunks:
|
||||
yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages)
|
||||
else:
|
||||
# Text chunk
|
||||
chunk_text = str(chunk)
|
||||
scratchpad.agent_response = (scratchpad.agent_response or "") + chunk_text
|
||||
scratchpad.thought = (scratchpad.thought or "") + chunk_text
|
||||
|
||||
yield self._create_text_chunk(chunk_text, current_messages)
|
||||
if emit_chunks:
|
||||
yield self._create_text_chunk(chunk_text, current_messages)
|
||||
|
||||
# Update usage
|
||||
if usage_dict.get("usage"):
|
||||
@ -337,6 +397,14 @@ class ReActStrategy(AgentPattern):
|
||||
|
||||
return scratchpad, finish_reason
|
||||
|
||||
@staticmethod
|
||||
def _format_output_text(value: Any) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return json.dumps(value, ensure_ascii=False)
|
||||
|
||||
def _handle_tool_call(
|
||||
self,
|
||||
action: AgentScratchpadUnit.Action,
|
||||
|
||||
Reference in New Issue
Block a user