mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 01:18:05 +08:00
feat: support structured output in sandbox and tool mode
Signed-off-by: Stream <Stream_2@qq.com>
This commit is contained in:
@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_valid
|
||||
from core.agent.entities import AgentLog, AgentResult
|
||||
from core.file import File
|
||||
from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.llm_entities import LLMStructuredOutput, LLMUsage
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.workflow.entities import ToolCall, ToolCallResult
|
||||
@ -156,6 +156,9 @@ class LLMGenerationData(BaseModel):
|
||||
finish_reason: str | None = Field(None, description="Finish reason from LLM")
|
||||
files: list[File] = Field(default_factory=list, description="Generated files")
|
||||
trace: list[LLMTraceSegment] = Field(default_factory=list, description="Streaming trace in emitted order")
|
||||
structured_output: LLMStructuredOutput | None = Field(
|
||||
default=None, description="Structured output from tool-only agent runs"
|
||||
)
|
||||
|
||||
|
||||
class ThinkTagStreamParser:
|
||||
@ -284,6 +287,7 @@ class AggregatedResult(BaseModel):
|
||||
files: list[File] = Field(default_factory=list)
|
||||
usage: LLMUsage = Field(default_factory=LLMUsage.empty_usage)
|
||||
finish_reason: str | None = None
|
||||
structured_output: LLMStructuredOutput | None = None
|
||||
|
||||
|
||||
class AgentContext(BaseModel):
|
||||
|
||||
@ -12,7 +12,8 @@ from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentLog, AgentResult, AgentToolEntity, ExecutionContext
|
||||
from core.agent.entities import AgentEntity, AgentLog, AgentOutputKind, AgentResult, AgentToolEntity, ExecutionContext
|
||||
from core.agent.output_tools import build_agent_output_tools
|
||||
from core.agent.patterns import StrategyFactory
|
||||
from core.app.entities.app_asset_entities import AppAssetFileTree
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
@ -20,6 +21,7 @@ from core.app_assets.constants import AppAssetsAttrs
|
||||
from core.file import File, FileTransferMethod, FileType, file_manager
|
||||
from core.helper.code_executor import CodeExecutor, CodeLanguage
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.output_parser.file_ref import convert_file_refs_in_output
|
||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
||||
from core.memory.base import BaseMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
@ -62,6 +64,7 @@ from core.skill.entities.skill_document import SkillDocument
|
||||
from core.skill.entities.tool_dependencies import ToolDependencies, ToolDependency
|
||||
from core.skill.skill_compiler import SkillCompiler
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import ToolInvokeFrom
|
||||
from core.tools.signature import sign_upload_file
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.variables import (
|
||||
@ -355,6 +358,8 @@ class LLMNode(Node[LLMNodeData]):
|
||||
reasoning_content = ""
|
||||
usage = generation_data.usage
|
||||
finish_reason = generation_data.finish_reason
|
||||
if generation_data.structured_output:
|
||||
structured_output = generation_data.structured_output
|
||||
|
||||
# Unified process_data building
|
||||
process_data = {
|
||||
@ -1900,7 +1905,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
prompt_messages=list(prompt_messages),
|
||||
model_parameters=self._node_data.model.completion_params,
|
||||
stop=list(stop or []),
|
||||
stream=True,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
result = yield from self._process_tool_outputs(outputs)
|
||||
@ -1921,11 +1926,22 @@ class LLMNode(Node[LLMNodeData]):
|
||||
with SandboxBashSession(sandbox=sandbox, node_id=self.id, tools=tool_dependencies) as session:
|
||||
prompt_files = self._extract_prompt_files(variable_pool)
|
||||
model_features = self._get_model_features(model_instance)
|
||||
structured_output_schema = None
|
||||
if self._node_data.structured_output_enabled:
|
||||
structured_output_schema = LLMNode.fetch_structured_output_schema(
|
||||
structured_output=self._node_data.structured_output or {},
|
||||
)
|
||||
output_tools = build_agent_output_tools(
|
||||
tenant_id=self.tenant_id,
|
||||
invoke_from=self.invoke_from,
|
||||
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
||||
structured_output_schema=structured_output_schema,
|
||||
)
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=model_instance,
|
||||
tools=[session.bash_tool],
|
||||
tools=[session.bash_tool, *output_tools],
|
||||
files=prompt_files,
|
||||
max_iterations=self._node_data.max_iterations or 100,
|
||||
agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING,
|
||||
@ -1936,7 +1952,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
prompt_messages=list(prompt_messages),
|
||||
model_parameters=self._node_data.model.completion_params,
|
||||
stop=list(stop or []),
|
||||
stream=True,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
result = yield from self._process_tool_outputs(outputs)
|
||||
@ -2011,6 +2027,20 @@ class LLMNode(Node[LLMNodeData]):
|
||||
logger.warning("Failed to load tool %s: %s", tool, str(e))
|
||||
continue
|
||||
|
||||
structured_output_schema = None
|
||||
if self._node_data.structured_output_enabled:
|
||||
structured_output_schema = LLMNode.fetch_structured_output_schema(
|
||||
structured_output=self._node_data.structured_output or {},
|
||||
)
|
||||
tool_instances.extend(
|
||||
build_agent_output_tools(
|
||||
tenant_id=self.tenant_id,
|
||||
invoke_from=self.invoke_from,
|
||||
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
||||
structured_output_schema=structured_output_schema,
|
||||
)
|
||||
)
|
||||
|
||||
return tool_instances
|
||||
|
||||
def _extract_prompt_files(self, variable_pool: VariablePool) -> list[File]:
|
||||
@ -2480,6 +2510,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
finish_reason=aggregate.finish_reason,
|
||||
files=aggregate.files,
|
||||
trace=trace_state.trace_segments,
|
||||
structured_output=aggregate.structured_output,
|
||||
)
|
||||
|
||||
def _process_tool_outputs(
|
||||
@ -2494,19 +2525,54 @@ class LLMNode(Node[LLMNodeData]):
|
||||
if isinstance(output, AgentLog):
|
||||
yield from self._handle_agent_log_output(output, state.stream, state.trace, state.agent)
|
||||
else:
|
||||
yield from self._handle_llm_chunk_output(output, state.stream, state.trace, state.aggregate)
|
||||
continue
|
||||
except StopIteration as exception:
|
||||
if isinstance(getattr(exception, "value", None), AgentResult):
|
||||
state.agent.agent_result = exception.value
|
||||
|
||||
if state.agent.agent_result:
|
||||
state.aggregate.text = state.agent.agent_result.text or state.aggregate.text
|
||||
output_payload = state.agent.agent_result.output
|
||||
structured_output_data: Mapping[str, Any] | None = None
|
||||
if isinstance(output_payload, AgentResult.StructuredOutput):
|
||||
output_kind = output_payload.output_kind
|
||||
if output_kind == AgentOutputKind.ILLEGAL_OUTPUT:
|
||||
raise ValueError("Agent returned illegal output")
|
||||
if output_kind in {AgentOutputKind.FINAL_OUTPUT_ANSWER, AgentOutputKind.OUTPUT_TEXT}:
|
||||
if not output_payload.output_text:
|
||||
raise ValueError("Agent returned empty text output")
|
||||
state.aggregate.text = output_payload.output_text
|
||||
elif output_kind == AgentOutputKind.FINAL_STRUCTURED_OUTPUT:
|
||||
if output_payload.output_data is None:
|
||||
raise ValueError("Agent returned empty structured output")
|
||||
else:
|
||||
raise ValueError("Agent returned unsupported output kind")
|
||||
|
||||
if output_payload.output_data is not None:
|
||||
if not isinstance(output_payload.output_data, Mapping):
|
||||
raise ValueError("Agent returned invalid structured output")
|
||||
structured_output_data = output_payload.output_data
|
||||
else:
|
||||
if not output_payload:
|
||||
raise ValueError("Agent returned empty output")
|
||||
state.aggregate.text = str(output_payload)
|
||||
|
||||
state.aggregate.files = state.agent.agent_result.files
|
||||
if state.agent.agent_result.usage:
|
||||
state.aggregate.usage = state.agent.agent_result.usage
|
||||
if state.agent.agent_result.finish_reason:
|
||||
state.aggregate.finish_reason = state.agent.agent_result.finish_reason
|
||||
|
||||
if structured_output_data is not None:
|
||||
output_schema = LLMNode.fetch_structured_output_schema(
|
||||
structured_output=self._node_data.structured_output or {},
|
||||
)
|
||||
converted_output = convert_file_refs_in_output(
|
||||
output=structured_output_data,
|
||||
json_schema=output_schema,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
state.aggregate.structured_output = LLMStructuredOutput(structured_output=converted_output)
|
||||
|
||||
yield from self._flush_remaining_stream(state.stream, state.trace, state.aggregate)
|
||||
yield from self._close_streams()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user