mirror of
https://github.com/langgenius/dify.git
synced 2026-05-02 16:38:04 +08:00
@ -5,13 +5,10 @@ import uuid
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Literal, Protocol, Union, cast
|
||||
|
||||
from core.agent.entities import AgentLog, AgentOutputKind, AgentResult
|
||||
from core.agent.entities import AgentLog, AgentResult
|
||||
from core.agent.output_tools import (
|
||||
FINAL_OUTPUT_TOOL,
|
||||
FINAL_STRUCTURED_OUTPUT_TOOL,
|
||||
ILLEGAL_OUTPUT_TOOL,
|
||||
OUTPUT_TEXT_TOOL,
|
||||
OUTPUT_TOOL_NAME_SET,
|
||||
TERMINAL_OUTPUT_MESSAGE,
|
||||
)
|
||||
from core.file import File
|
||||
from core.model_runtime.entities import (
|
||||
@ -36,21 +33,11 @@ class FunctionCallStrategy(AgentPattern):
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict[str, Any],
|
||||
stop: list[str] = [],
|
||||
stream: bool = True,
|
||||
stop: list[str]
|
||||
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
||||
"""Execute the function call agent strategy."""
|
||||
# Convert tools to prompt format
|
||||
prompt_tools: list[PromptMessageTool] = self._convert_tools_to_prompt_format()
|
||||
tool_instance_names = {tool.entity.identity.name for tool in self.tools}
|
||||
available_output_tool_names = {tool.name for tool in prompt_tools if tool.name in OUTPUT_TOOL_NAME_SET}
|
||||
if FINAL_STRUCTURED_OUTPUT_TOOL in available_output_tool_names:
|
||||
terminal_tool_name = FINAL_STRUCTURED_OUTPUT_TOOL
|
||||
elif FINAL_OUTPUT_TOOL in available_output_tool_names:
|
||||
terminal_tool_name = FINAL_OUTPUT_TOOL
|
||||
else:
|
||||
raise ValueError("No terminal output tool configured")
|
||||
allow_illegal_output = ILLEGAL_OUTPUT_TOOL in tool_instance_names
|
||||
|
||||
# Initialize tracking
|
||||
iteration_step: int = 1
|
||||
@ -60,10 +47,10 @@ class FunctionCallStrategy(AgentPattern):
|
||||
messages: list[PromptMessage] = list(prompt_messages) # Create mutable copy
|
||||
final_text: str = ""
|
||||
structured_output_payload: dict[str, Any] | None = None
|
||||
final_tool_args: dict[str, Any] = {"!!!": "!!!"}
|
||||
output_text_payload: str | None = None
|
||||
finish_reason: str | None = None
|
||||
output_files: list[File] = [] # Track files produced by tools
|
||||
terminal_output_seen = False
|
||||
|
||||
class _LLMInvoker(Protocol):
|
||||
def invoke_llm(
|
||||
@ -87,11 +74,7 @@ class FunctionCallStrategy(AgentPattern):
|
||||
data={},
|
||||
)
|
||||
yield round_log
|
||||
# On last iteration, restrict tools to output tools
|
||||
if iteration_step == max_iterations:
|
||||
current_tools = [tool for tool in prompt_tools if tool.name in available_output_tool_names]
|
||||
else:
|
||||
current_tools = prompt_tools
|
||||
|
||||
model_log = self._create_log(
|
||||
label=f"{self.model_instance.model} Thought",
|
||||
log_type=AgentLog.LogType.THOUGHT,
|
||||
@ -112,7 +95,7 @@ class FunctionCallStrategy(AgentPattern):
|
||||
chunks = invoker.invoke_llm(
|
||||
prompt_messages=messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=current_tools,
|
||||
tools=prompt_tools,
|
||||
stop=stop,
|
||||
stream=False,
|
||||
user=self.context.user_id,
|
||||
@ -124,19 +107,15 @@ class FunctionCallStrategy(AgentPattern):
|
||||
chunks, round_usage, model_log, emit_chunks=False
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
if not allow_illegal_output:
|
||||
raise ValueError("Model did not call any tools")
|
||||
tool_calls = [
|
||||
(
|
||||
str(uuid.uuid4()),
|
||||
ILLEGAL_OUTPUT_TOOL,
|
||||
{
|
||||
"raw": response_content,
|
||||
},
|
||||
)
|
||||
]
|
||||
response_content = ""
|
||||
if response_content:
|
||||
replaced_tool_call = (
|
||||
str(uuid.uuid4()),
|
||||
ILLEGAL_OUTPUT_TOOL,
|
||||
{
|
||||
"raw": response_content,
|
||||
},
|
||||
)
|
||||
tool_calls.append(replaced_tool_call)
|
||||
|
||||
messages.append(self._create_assistant_message("", tool_calls))
|
||||
|
||||
@ -149,35 +128,23 @@ class FunctionCallStrategy(AgentPattern):
|
||||
if chunk_finish_reason:
|
||||
finish_reason = chunk_finish_reason
|
||||
|
||||
assert len(tool_calls) > 0
|
||||
|
||||
# Process tool calls
|
||||
tool_outputs: dict[str, str] = {}
|
||||
if tool_calls:
|
||||
function_call_state = True
|
||||
terminal_tool_seen = False
|
||||
# Execute tools
|
||||
for tool_call_id, tool_name, tool_args in tool_calls:
|
||||
if tool_name == OUTPUT_TEXT_TOOL:
|
||||
output_text_payload = self._format_output_text(tool_args.get("text"))
|
||||
elif tool_name == FINAL_STRUCTURED_OUTPUT_TOOL:
|
||||
data = tool_args.get("data")
|
||||
structured_output_payload = cast(dict[str, Any] | None, data)
|
||||
if tool_name == terminal_tool_name:
|
||||
terminal_tool_seen = True
|
||||
elif tool_name == FINAL_OUTPUT_TOOL:
|
||||
final_text = self._format_output_text(tool_args.get("text"))
|
||||
if tool_name == terminal_tool_name:
|
||||
terminal_tool_seen = True
|
||||
|
||||
tool_response, tool_files, _ = yield from self._handle_tool_call(
|
||||
tool_name, tool_args, tool_call_id, messages, round_log
|
||||
)
|
||||
tool_outputs[tool_name] = tool_response
|
||||
# Track files produced by tools
|
||||
output_files.extend(tool_files)
|
||||
|
||||
if terminal_tool_seen:
|
||||
terminal_output_seen = True
|
||||
function_call_state = True
|
||||
# Execute tools
|
||||
for tool_call_id, tool_name, tool_args in tool_calls:
|
||||
tool_response, tool_files, _ = yield from self._handle_tool_call(
|
||||
tool_name, tool_args, tool_call_id, messages, round_log
|
||||
)
|
||||
tool_outputs[tool_name] = tool_response
|
||||
# Track files produced by tools
|
||||
output_files.extend(tool_files)
|
||||
if tool_response == TERMINAL_OUTPUT_MESSAGE:
|
||||
function_call_state = False
|
||||
final_tool_args = tool_args
|
||||
|
||||
yield self._finish_log(
|
||||
round_log,
|
||||
data={
|
||||
@ -196,31 +163,16 @@ class FunctionCallStrategy(AgentPattern):
|
||||
# Return final result
|
||||
from core.agent.entities import AgentResult
|
||||
|
||||
output_payload: str | AgentResult.StructuredOutput
|
||||
if terminal_tool_name == FINAL_STRUCTURED_OUTPUT_TOOL and terminal_output_seen:
|
||||
output_payload = AgentResult.StructuredOutput(
|
||||
output_kind=AgentOutputKind.FINAL_STRUCTURED_OUTPUT,
|
||||
output_text=None,
|
||||
output_data=structured_output_payload,
|
||||
)
|
||||
elif final_text:
|
||||
output_payload = AgentResult.StructuredOutput(
|
||||
output_kind=AgentOutputKind.FINAL_OUTPUT_ANSWER,
|
||||
output_text=final_text,
|
||||
output_data=structured_output_payload,
|
||||
)
|
||||
elif output_text_payload:
|
||||
output_payload = AgentResult.StructuredOutput(
|
||||
output_kind=AgentOutputKind.OUTPUT_TEXT,
|
||||
output_text=str(output_text_payload),
|
||||
output_data=None,
|
||||
)
|
||||
output_payload: str | dict
|
||||
output_text = final_tool_args.get("text")
|
||||
output_structured_payload = final_tool_args.get("data")
|
||||
|
||||
if isinstance(output_structured_payload, dict):
|
||||
output_payload = output_structured_payload
|
||||
elif isinstance(output_text, str):
|
||||
output_payload = output_text
|
||||
else:
|
||||
output_payload = AgentResult.StructuredOutput(
|
||||
output_kind=AgentOutputKind.ILLEGAL_OUTPUT,
|
||||
output_text="Model failed to produce a final output.",
|
||||
output_data=None,
|
||||
)
|
||||
raise ValueError("Final output is not a string or structured data.")
|
||||
|
||||
return AgentResult(
|
||||
output=output_payload,
|
||||
|
||||
Reference in New Issue
Block a user