mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 17:38:04 +08:00
feat: support structured output in sandbox and tool mode
Signed-off-by: Stream <Stream_2@qq.com>
This commit is contained in:
@ -4,7 +4,7 @@ from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
from core.agent.entities import AgentEntity, AgentLog, AgentResult
|
||||
from core.agent.entities import AgentEntity, AgentLog, AgentOutputKind, AgentResult
|
||||
from core.agent.patterns.strategy_factory import StrategyFactory
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
|
||||
@ -13,6 +13,7 @@ from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
LLMUsage,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
@ -106,7 +107,6 @@ class AgentAppRunner(BaseAgentRunner):
|
||||
|
||||
# Initialize state variables
|
||||
current_agent_thought_id = None
|
||||
has_published_thought = False
|
||||
current_tool_name: str | None = None
|
||||
self._current_message_file_ids: list[str] = []
|
||||
|
||||
@ -118,7 +118,7 @@ class AgentAppRunner(BaseAgentRunner):
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=app_generate_entity.model_conf.parameters,
|
||||
stop=app_generate_entity.model_conf.stop,
|
||||
stream=True,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
# Consume generator and collect result
|
||||
@ -133,17 +133,10 @@ class AgentAppRunner(BaseAgentRunner):
|
||||
break
|
||||
|
||||
if isinstance(output, LLMResultChunk):
|
||||
# Handle LLM chunk
|
||||
if current_agent_thought_id and not has_published_thought:
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
has_published_thought = True
|
||||
# No more expect streaming data
|
||||
continue
|
||||
|
||||
yield output
|
||||
|
||||
elif isinstance(output, AgentLog):
|
||||
else:
|
||||
# Handle Agent Log using log_type for type-safe dispatch
|
||||
if output.status == AgentLog.LogStatus.START:
|
||||
if output.log_type == AgentLog.LogType.ROUND:
|
||||
@ -156,7 +149,6 @@ class AgentAppRunner(BaseAgentRunner):
|
||||
tool_input="",
|
||||
messages_ids=message_file_ids,
|
||||
)
|
||||
has_published_thought = False
|
||||
|
||||
elif output.log_type == AgentLog.LogType.TOOL_CALL:
|
||||
if current_agent_thought_id is None:
|
||||
@ -265,7 +257,22 @@ class AgentAppRunner(BaseAgentRunner):
|
||||
|
||||
# Process final result
|
||||
if isinstance(result, AgentResult):
|
||||
final_answer = result.text
|
||||
output_payload = result.output
|
||||
if isinstance(output_payload, AgentResult.StructuredOutput):
|
||||
if output_payload.output_kind == AgentOutputKind.ILLEGAL_OUTPUT:
|
||||
raise ValueError("Agent returned illegal output")
|
||||
if output_payload.output_kind not in {
|
||||
AgentOutputKind.FINAL_OUTPUT_ANSWER,
|
||||
AgentOutputKind.OUTPUT_TEXT,
|
||||
}:
|
||||
raise ValueError("Agent did not return text output")
|
||||
if not output_payload.output_text:
|
||||
raise ValueError("Agent returned empty text output")
|
||||
final_answer = output_payload.output_text
|
||||
else:
|
||||
if not output_payload:
|
||||
raise ValueError("Agent returned empty output")
|
||||
final_answer = str(output_payload)
|
||||
usage = result.usage or LLMUsage.empty_usage()
|
||||
|
||||
# Publish end event
|
||||
@ -282,6 +289,17 @@ class AgentAppRunner(BaseAgentRunner):
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
if False:
|
||||
yield LLMResultChunk(
|
||||
model="",
|
||||
prompt_messages=[],
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=""),
|
||||
usage=None,
|
||||
),
|
||||
)
|
||||
|
||||
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Initialize system message
|
||||
|
||||
@ -7,6 +7,7 @@ from typing import Union, cast
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentToolEntity, ExecutionContext
|
||||
from core.agent.output_tools import build_agent_output_tools
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
@ -36,6 +37,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolInvokeFrom,
|
||||
ToolParameter,
|
||||
)
|
||||
from core.tools.tool_manager import ToolManager
|
||||
@ -251,6 +253,14 @@ class BaseAgentRunner(AppRunner):
|
||||
# save tool entity
|
||||
tool_instances[dataset_tool.entity.identity.name] = dataset_tool
|
||||
|
||||
output_tools = build_agent_output_tools(
|
||||
tenant_id=self.tenant_id,
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
tool_invoke_from=ToolInvokeFrom.AGENT,
|
||||
)
|
||||
for tool in output_tools:
|
||||
tool_instances[tool.entity.identity.name] = tool
|
||||
|
||||
return tool_instances, prompt_messages_tools
|
||||
|
||||
def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import Any, Union
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.agent.output_tools import FINAL_OUTPUT_TOOL
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
||||
|
||||
|
||||
@ -41,9 +42,9 @@ class AgentScratchpadUnit(BaseModel):
|
||||
"""
|
||||
|
||||
action_name: str
|
||||
action_input: Union[dict, str]
|
||||
action_input: dict[str, Any] | str
|
||||
|
||||
def to_dict(self):
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""
|
||||
Convert to dictionary.
|
||||
"""
|
||||
@ -62,9 +63,9 @@ class AgentScratchpadUnit(BaseModel):
|
||||
"""
|
||||
Check if the scratchpad unit is final.
|
||||
"""
|
||||
return self.action is None or (
|
||||
"final" in self.action.action_name.lower() and "answer" in self.action.action_name.lower()
|
||||
)
|
||||
if self.action is None:
|
||||
return False
|
||||
return self.action.action_name.lower() == FINAL_OUTPUT_TOOL
|
||||
|
||||
|
||||
class AgentEntity(BaseModel):
|
||||
@ -125,7 +126,7 @@ class ExecutionContext(BaseModel):
|
||||
"tenant_id": self.tenant_id,
|
||||
}
|
||||
|
||||
def with_updates(self, **kwargs) -> "ExecutionContext":
|
||||
def with_updates(self, **kwargs: Any) -> "ExecutionContext":
|
||||
"""Create a new context with updated fields."""
|
||||
data = self.to_dict()
|
||||
data.update(kwargs)
|
||||
@ -178,12 +179,35 @@ class AgentLog(BaseModel):
|
||||
metadata: Mapping[LogMetadata, Any] = Field(default={}, description="The metadata of the log")
|
||||
|
||||
|
||||
class AgentOutputKind(StrEnum):
|
||||
"""
|
||||
Agent output kind.
|
||||
"""
|
||||
|
||||
OUTPUT_TEXT = "output_text"
|
||||
FINAL_OUTPUT_ANSWER = "final_output_answer"
|
||||
FINAL_STRUCTURED_OUTPUT = "final_structured_output"
|
||||
ILLEGAL_OUTPUT = "illegal_output"
|
||||
|
||||
|
||||
OutputKind = AgentOutputKind
|
||||
|
||||
|
||||
class AgentResult(BaseModel):
|
||||
"""
|
||||
Agent execution result.
|
||||
"""
|
||||
|
||||
text: str = Field(default="", description="The generated text")
|
||||
class StructuredOutput(BaseModel):
|
||||
"""
|
||||
Structured output payload from output tools.
|
||||
"""
|
||||
|
||||
output_kind: AgentOutputKind
|
||||
output_text: str | None = None
|
||||
output_data: Mapping[str, Any] | None = None
|
||||
|
||||
output: str | StructuredOutput = Field(default="", description="The generated output")
|
||||
files: list[Any] = Field(default_factory=list, description="Files produced during execution")
|
||||
usage: Any | None = Field(default=None, description="LLM usage statistics")
|
||||
finish_reason: str | None = Field(default=None, description="Reason for completion")
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import json
|
||||
import re
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
from typing import Any, Union, cast
|
||||
|
||||
from core.agent.entities import AgentScratchpadUnit
|
||||
from core.model_runtime.entities.llm_entities import LLMResultChunk
|
||||
@ -10,46 +10,52 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk
|
||||
class CotAgentOutputParser:
|
||||
@classmethod
|
||||
def handle_react_stream_output(
|
||||
cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict
|
||||
cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict[str, Any]
|
||||
) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
|
||||
def parse_action(action) -> Union[str, AgentScratchpadUnit.Action]:
|
||||
action_name = None
|
||||
action_input = None
|
||||
if isinstance(action, str):
|
||||
def parse_action(action: Any) -> Union[str, AgentScratchpadUnit.Action]:
|
||||
action_name: str | None = None
|
||||
action_input: Any | None = None
|
||||
parsed_action: Any = action
|
||||
if isinstance(parsed_action, str):
|
||||
try:
|
||||
action = json.loads(action, strict=False)
|
||||
parsed_action = json.loads(parsed_action, strict=False)
|
||||
except json.JSONDecodeError:
|
||||
return action or ""
|
||||
return parsed_action or ""
|
||||
|
||||
# cohere always returns a list
|
||||
if isinstance(action, list) and len(action) == 1:
|
||||
action = action[0]
|
||||
if isinstance(parsed_action, list):
|
||||
action_list: list[Any] = cast(list[Any], parsed_action)
|
||||
if len(action_list) == 1:
|
||||
parsed_action = action_list[0]
|
||||
|
||||
for key, value in action.items():
|
||||
if "input" in key.lower():
|
||||
action_input = value
|
||||
else:
|
||||
action_name = value
|
||||
if isinstance(parsed_action, dict):
|
||||
action_dict: dict[str, Any] = cast(dict[str, Any], parsed_action)
|
||||
for key, value in action_dict.items():
|
||||
if "input" in key.lower():
|
||||
action_input = value
|
||||
elif isinstance(value, str):
|
||||
action_name = value
|
||||
else:
|
||||
return json.dumps(parsed_action)
|
||||
|
||||
if action_name is not None and action_input is not None:
|
||||
return AgentScratchpadUnit.Action(
|
||||
action_name=action_name,
|
||||
action_input=action_input,
|
||||
)
|
||||
else:
|
||||
return json.dumps(action)
|
||||
return json.dumps(parsed_action)
|
||||
|
||||
def extra_json_from_code_block(code_block) -> list[Union[list, dict]]:
|
||||
def extra_json_from_code_block(code_block: str) -> list[dict[str, Any] | list[Any]]:
|
||||
blocks = re.findall(r"```[json]*\s*([\[{].*[]}])\s*```", code_block, re.DOTALL | re.IGNORECASE)
|
||||
if not blocks:
|
||||
return []
|
||||
try:
|
||||
json_blocks = []
|
||||
json_blocks: list[dict[str, Any] | list[Any]] = []
|
||||
for block in blocks:
|
||||
json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE)
|
||||
json_blocks.append(json.loads(json_text, strict=False))
|
||||
return json_blocks
|
||||
except:
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
code_block_cache = ""
|
||||
|
||||
57
api/core/agent/output_tools.py
Normal file
57
api/core/agent/output_tools.py
Normal file
@ -0,0 +1,57 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import ToolInvokeFrom, ToolParameter, ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
|
||||
OUTPUT_TOOL_PROVIDER = "agent_output"
|
||||
|
||||
OUTPUT_TEXT_TOOL = "output_text"
|
||||
FINAL_OUTPUT_TOOL = "final_output_answer"
|
||||
FINAL_STRUCTURED_OUTPUT_TOOL = "final_structured_output"
|
||||
ILLEGAL_OUTPUT_TOOL = "illegal_output"
|
||||
|
||||
OUTPUT_TOOL_NAMES: Sequence[str] = (
|
||||
OUTPUT_TEXT_TOOL,
|
||||
FINAL_OUTPUT_TOOL,
|
||||
FINAL_STRUCTURED_OUTPUT_TOOL,
|
||||
ILLEGAL_OUTPUT_TOOL,
|
||||
)
|
||||
|
||||
OUTPUT_TOOL_NAME_SET = set(OUTPUT_TOOL_NAMES)
|
||||
|
||||
|
||||
def build_agent_output_tools(
    *,
    tenant_id: str,
    invoke_from: InvokeFrom,
    tool_invoke_from: ToolInvokeFrom,
    structured_output_schema: dict[str, Any] | None = None,
) -> list[Tool]:
    """
    Resolve a runtime instance for every built-in agent output tool.

    One tool is fetched from ``ToolManager`` per entry of ``OUTPUT_TOOL_NAMES``, and the
    returned list keeps that order.

    :param tenant_id: tenant the tool runtimes are resolved for
    :param invoke_from: invocation source forwarded to ``ToolManager.get_tool_runtime``
    :param tool_invoke_from: tool invocation source forwarded to ``ToolManager.get_tool_runtime``
    :param structured_output_schema: when truthy, installed as the ``input_schema`` of the
        ``data`` parameter of the final structured output tool; that parameter is also
        switched to an OBJECT-typed, LLM-filled, required parameter
    :return: the resolved output tools
    """
    resolved: list[Tool] = []
    for name in OUTPUT_TOOL_NAMES:
        runtime = ToolManager.get_tool_runtime(
            provider_type=ToolProviderType.BUILT_IN,
            provider_id=OUTPUT_TOOL_PROVIDER,
            tool_name=name,
            tenant_id=tenant_id,
            invoke_from=invoke_from,
            tool_invoke_from=tool_invoke_from,
        )
        resolved.append(runtime)

        # Only the structured output tool is customised, and only when a schema was supplied.
        if name != FINAL_STRUCTURED_OUTPUT_TOOL or not structured_output_schema:
            continue

        # Patch a deep copy of the entity rather than the object the manager returned
        # (presumably so a shared tool definition is not mutated - confirm against ToolManager).
        runtime.entity = runtime.entity.model_copy(deep=True)
        data_parameters = [candidate for candidate in runtime.entity.parameters if candidate.name == "data"]
        for data_parameter in data_parameters:
            data_parameter.type = ToolParameter.ToolParameterType.OBJECT
            data_parameter.form = ToolParameter.ToolParameterForm.LLM
            data_parameter.required = True
            data_parameter.input_schema = structured_output_schema

    return resolved
|
||||
@ -1,10 +1,18 @@
|
||||
"""Function Call strategy implementation."""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Union
|
||||
from typing import Any, Literal, Protocol, Union, cast
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult
|
||||
from core.agent.entities import AgentLog, AgentOutputKind, AgentResult
|
||||
from core.agent.output_tools import (
|
||||
FINAL_OUTPUT_TOOL,
|
||||
FINAL_STRUCTURED_OUTPUT_TOOL,
|
||||
ILLEGAL_OUTPUT_TOOL,
|
||||
OUTPUT_TEXT_TOOL,
|
||||
OUTPUT_TOOL_NAME_SET,
|
||||
)
|
||||
from core.file import File
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
@ -42,9 +50,24 @@ class FunctionCallStrategy(AgentPattern):
|
||||
total_usage: dict[str, LLMUsage | None] = {"usage": None}
|
||||
messages: list[PromptMessage] = list(prompt_messages) # Create mutable copy
|
||||
final_text: str = ""
|
||||
structured_output_payload: dict[str, Any] | None = None
|
||||
output_text_payload: str | None = None
|
||||
finish_reason: str | None = None
|
||||
output_files: list[File] = [] # Track files produced by tools
|
||||
|
||||
class _LLMInvoker(Protocol):
|
||||
def invoke_llm(
|
||||
self,
|
||||
*,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict[str, Any],
|
||||
tools: list[PromptMessageTool],
|
||||
stop: list[str],
|
||||
stream: Literal[False],
|
||||
user: str | None,
|
||||
callbacks: list[Any],
|
||||
) -> LLMResult: ...
|
||||
|
||||
while function_call_state and iteration_step <= max_iterations:
|
||||
function_call_state = False
|
||||
round_log = self._create_log(
|
||||
@ -54,8 +77,11 @@ class FunctionCallStrategy(AgentPattern):
|
||||
data={},
|
||||
)
|
||||
yield round_log
|
||||
# On last iteration, remove tools to force final answer
|
||||
current_tools: list[PromptMessageTool] = [] if iteration_step == max_iterations else prompt_tools
|
||||
# On last iteration, restrict tools to output tools
|
||||
if iteration_step == max_iterations:
|
||||
current_tools = [tool for tool in prompt_tools if tool.name in OUTPUT_TOOL_NAME_SET]
|
||||
else:
|
||||
current_tools = prompt_tools
|
||||
model_log = self._create_log(
|
||||
label=f"{self.model_instance.model} Thought",
|
||||
log_type=AgentLog.LogType.THOUGHT,
|
||||
@ -72,31 +98,41 @@ class FunctionCallStrategy(AgentPattern):
|
||||
round_usage: dict[str, LLMUsage | None] = {"usage": None}
|
||||
|
||||
# Invoke model
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
|
||||
invoker = cast(_LLMInvoker, self.model_instance)
|
||||
chunks = invoker.invoke_llm(
|
||||
prompt_messages=messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=current_tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream=False,
|
||||
user=self.context.user_id,
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
# Process response
|
||||
tool_calls, response_content, chunk_finish_reason = yield from self._handle_chunks(
|
||||
chunks, round_usage, model_log
|
||||
chunks, round_usage, model_log, emit_chunks=False
|
||||
)
|
||||
messages.append(self._create_assistant_message(response_content, tool_calls))
|
||||
|
||||
if not tool_calls:
|
||||
tool_calls = [
|
||||
(
|
||||
str(uuid.uuid4()),
|
||||
ILLEGAL_OUTPUT_TOOL,
|
||||
{
|
||||
"raw": response_content,
|
||||
},
|
||||
)
|
||||
]
|
||||
response_content = ""
|
||||
|
||||
messages.append(self._create_assistant_message("", tool_calls))
|
||||
|
||||
# Accumulate to total usage
|
||||
round_usage_value = round_usage.get("usage")
|
||||
if round_usage_value:
|
||||
self._accumulate_usage(total_usage, round_usage_value)
|
||||
|
||||
# Update final text if no tool calls (this is likely the final answer)
|
||||
if not tool_calls:
|
||||
final_text = response_content
|
||||
|
||||
# Update finish reason
|
||||
if chunk_finish_reason:
|
||||
finish_reason = chunk_finish_reason
|
||||
@ -105,14 +141,27 @@ class FunctionCallStrategy(AgentPattern):
|
||||
tool_outputs: dict[str, str] = {}
|
||||
if tool_calls:
|
||||
function_call_state = True
|
||||
terminal_tool_seen = False
|
||||
# Execute tools
|
||||
for tool_call_id, tool_name, tool_args in tool_calls:
|
||||
if tool_name == OUTPUT_TEXT_TOOL:
|
||||
output_text_payload = self._format_output_text(tool_args.get("text"))
|
||||
elif tool_name == FINAL_STRUCTURED_OUTPUT_TOOL:
|
||||
data = tool_args.get("data")
|
||||
structured_output_payload = cast(dict[str, Any] | None, data)
|
||||
elif tool_name == FINAL_OUTPUT_TOOL:
|
||||
final_text = self._format_output_text(tool_args.get("text"))
|
||||
terminal_tool_seen = True
|
||||
|
||||
tool_response, tool_files, _ = yield from self._handle_tool_call(
|
||||
tool_name, tool_args, tool_call_id, messages, round_log
|
||||
)
|
||||
tool_outputs[tool_name] = tool_response
|
||||
# Track files produced by tools
|
||||
output_files.extend(tool_files)
|
||||
|
||||
if terminal_tool_seen:
|
||||
function_call_state = False
|
||||
yield self._finish_log(
|
||||
round_log,
|
||||
data={
|
||||
@ -131,8 +180,28 @@ class FunctionCallStrategy(AgentPattern):
|
||||
# Return final result
|
||||
from core.agent.entities import AgentResult
|
||||
|
||||
output_payload: str | AgentResult.StructuredOutput
|
||||
if final_text:
|
||||
output_payload = AgentResult.StructuredOutput(
|
||||
output_kind=AgentOutputKind.FINAL_OUTPUT_ANSWER,
|
||||
output_text=final_text,
|
||||
output_data=structured_output_payload,
|
||||
)
|
||||
elif output_text_payload:
|
||||
output_payload = AgentResult.StructuredOutput(
|
||||
output_kind=AgentOutputKind.OUTPUT_TEXT,
|
||||
output_text=str(output_text_payload),
|
||||
output_data=None,
|
||||
)
|
||||
else:
|
||||
output_payload = AgentResult.StructuredOutput(
|
||||
output_kind=AgentOutputKind.ILLEGAL_OUTPUT,
|
||||
output_text="Model failed to produce a final output.",
|
||||
output_data=None,
|
||||
)
|
||||
|
||||
return AgentResult(
|
||||
text=final_text,
|
||||
output=output_payload,
|
||||
files=output_files,
|
||||
usage=total_usage.get("usage") or LLMUsage.empty_usage(),
|
||||
finish_reason=finish_reason,
|
||||
@ -143,6 +212,8 @@ class FunctionCallStrategy(AgentPattern):
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
|
||||
llm_usage: dict[str, LLMUsage | None],
|
||||
start_log: AgentLog,
|
||||
*,
|
||||
emit_chunks: bool,
|
||||
) -> Generator[
|
||||
LLMResultChunk | AgentLog,
|
||||
None,
|
||||
@ -174,7 +245,8 @@ class FunctionCallStrategy(AgentPattern):
|
||||
if chunk.delta.finish_reason:
|
||||
finish_reason = chunk.delta.finish_reason
|
||||
|
||||
yield chunk
|
||||
if emit_chunks:
|
||||
yield chunk
|
||||
else:
|
||||
# Non-streaming response
|
||||
result: LLMResult = chunks
|
||||
@ -189,11 +261,12 @@ class FunctionCallStrategy(AgentPattern):
|
||||
self._accumulate_usage(llm_usage, result.usage)
|
||||
|
||||
# Convert to streaming format
|
||||
yield LLMResultChunk(
|
||||
model=result.model,
|
||||
prompt_messages=result.prompt_messages,
|
||||
delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage),
|
||||
)
|
||||
if emit_chunks:
|
||||
yield LLMResultChunk(
|
||||
model=result.model,
|
||||
prompt_messages=result.prompt_messages,
|
||||
delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage),
|
||||
)
|
||||
yield self._finish_log(
|
||||
start_log,
|
||||
data={
|
||||
@ -203,6 +276,14 @@ class FunctionCallStrategy(AgentPattern):
|
||||
)
|
||||
return tool_calls, response_content, finish_reason
|
||||
|
||||
@staticmethod
|
||||
def _format_output_text(value: Any) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return json.dumps(value, ensure_ascii=False)
|
||||
|
||||
def _create_assistant_message(
|
||||
self, content: str, tool_calls: list[tuple[str, str, dict[str, Any]]] | None = None
|
||||
) -> AssistantPromptMessage:
|
||||
|
||||
@ -4,10 +4,17 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
from typing import TYPE_CHECKING, Any, Union, cast
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult, AgentScratchpadUnit, ExecutionContext
|
||||
from core.agent.entities import AgentLog, AgentOutputKind, AgentResult, AgentScratchpadUnit, ExecutionContext
|
||||
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
|
||||
from core.agent.output_tools import (
|
||||
FINAL_OUTPUT_TOOL,
|
||||
FINAL_STRUCTURED_OUTPUT_TOOL,
|
||||
ILLEGAL_OUTPUT_TOOL,
|
||||
OUTPUT_TEXT_TOOL,
|
||||
OUTPUT_TOOL_NAME_SET,
|
||||
)
|
||||
from core.file import File
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import (
|
||||
@ -67,6 +74,8 @@ class ReActStrategy(AgentPattern):
|
||||
total_usage: dict[str, Any] = {"usage": None}
|
||||
output_files: list[File] = [] # Track files produced by tools
|
||||
final_text: str = ""
|
||||
structured_output_payload: dict[str, Any] | None = None
|
||||
output_text_payload: str | None = None
|
||||
finish_reason: str | None = None
|
||||
|
||||
# Add "Observation" to stop sequences
|
||||
@ -84,10 +93,13 @@ class ReActStrategy(AgentPattern):
|
||||
)
|
||||
yield round_log
|
||||
|
||||
# Build prompt with/without tools based on iteration
|
||||
include_tools = iteration_step < max_iterations
|
||||
# Build prompt with tool restrictions on last iteration
|
||||
if iteration_step == max_iterations:
|
||||
tools_for_prompt = [tool for tool in self.tools if tool.entity.identity.name in OUTPUT_TOOL_NAME_SET]
|
||||
else:
|
||||
tools_for_prompt = self.tools
|
||||
current_messages = self._build_prompt_with_react_format(
|
||||
prompt_messages, agent_scratchpad, include_tools, self.instruction
|
||||
prompt_messages, agent_scratchpad, tools_for_prompt, self.instruction
|
||||
)
|
||||
|
||||
model_log = self._create_log(
|
||||
@ -109,18 +121,21 @@ class ReActStrategy(AgentPattern):
|
||||
messages_to_use = current_messages
|
||||
|
||||
# Invoke model
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
|
||||
prompt_messages=messages_to_use,
|
||||
model_parameters=model_parameters,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=self.context.user_id or "",
|
||||
callbacks=[],
|
||||
chunks = cast(
|
||||
Union[Generator[LLMResultChunk, None, None], LLMResult],
|
||||
self.model_instance.invoke_llm(
|
||||
prompt_messages=messages_to_use,
|
||||
model_parameters=model_parameters,
|
||||
stop=stop,
|
||||
stream=False,
|
||||
user=self.context.user_id or "",
|
||||
callbacks=[],
|
||||
),
|
||||
)
|
||||
|
||||
# Process response
|
||||
scratchpad, chunk_finish_reason = yield from self._handle_chunks(
|
||||
chunks, round_usage, model_log, current_messages
|
||||
chunks, round_usage, model_log, current_messages, emit_chunks=False
|
||||
)
|
||||
agent_scratchpad.append(scratchpad)
|
||||
|
||||
@ -134,28 +149,46 @@ class ReActStrategy(AgentPattern):
|
||||
finish_reason = chunk_finish_reason
|
||||
|
||||
# Check if we have an action to execute
|
||||
if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
|
||||
react_state = True
|
||||
# Execute tool
|
||||
observation, tool_files = yield from self._handle_tool_call(
|
||||
scratchpad.action, current_messages, round_log
|
||||
if scratchpad.action is None:
|
||||
illegal_action = AgentScratchpadUnit.Action(
|
||||
action_name=ILLEGAL_OUTPUT_TOOL,
|
||||
action_input={"raw": scratchpad.thought or ""},
|
||||
)
|
||||
scratchpad.action = illegal_action
|
||||
scratchpad.action_str = json.dumps(illegal_action.to_dict())
|
||||
react_state = True
|
||||
observation, tool_files = yield from self._handle_tool_call(illegal_action, current_messages, round_log)
|
||||
scratchpad.observation = observation
|
||||
# Track files produced by tools
|
||||
output_files.extend(tool_files)
|
||||
|
||||
# Add observation to scratchpad for display
|
||||
yield self._create_text_chunk(f"\nObservation: {observation}\n", current_messages)
|
||||
else:
|
||||
# Extract final answer
|
||||
if scratchpad.action and scratchpad.action.action_input:
|
||||
final_answer = scratchpad.action.action_input
|
||||
if isinstance(final_answer, dict):
|
||||
final_answer = json.dumps(final_answer, ensure_ascii=False)
|
||||
final_text = str(final_answer)
|
||||
elif scratchpad.thought:
|
||||
# If no action but we have thought, use thought as final answer
|
||||
final_text = scratchpad.thought
|
||||
action_name = scratchpad.action.action_name
|
||||
if action_name == FINAL_OUTPUT_TOOL:
|
||||
if isinstance(scratchpad.action.action_input, dict):
|
||||
final_text = self._format_output_text(scratchpad.action.action_input.get("text"))
|
||||
else:
|
||||
final_text = self._format_output_text(scratchpad.action.action_input)
|
||||
observation, tool_files = yield from self._handle_tool_call(
|
||||
scratchpad.action, current_messages, round_log
|
||||
)
|
||||
scratchpad.observation = observation
|
||||
output_files.extend(tool_files)
|
||||
react_state = False
|
||||
else:
|
||||
if action_name == OUTPUT_TEXT_TOOL and isinstance(scratchpad.action.action_input, dict):
|
||||
output_text_payload = scratchpad.action.action_input.get("text")
|
||||
elif action_name == FINAL_STRUCTURED_OUTPUT_TOOL and isinstance(
|
||||
scratchpad.action.action_input, dict
|
||||
):
|
||||
data = scratchpad.action.action_input.get("data")
|
||||
if isinstance(data, dict):
|
||||
structured_output_payload = data
|
||||
|
||||
react_state = True
|
||||
observation, tool_files = yield from self._handle_tool_call(
|
||||
scratchpad.action, current_messages, round_log
|
||||
)
|
||||
scratchpad.observation = observation
|
||||
output_files.extend(tool_files)
|
||||
|
||||
yield self._finish_log(
|
||||
round_log,
|
||||
@ -173,15 +206,38 @@ class ReActStrategy(AgentPattern):
|
||||
|
||||
from core.agent.entities import AgentResult
|
||||
|
||||
output_payload: str | AgentResult.StructuredOutput
|
||||
if final_text:
|
||||
output_payload = AgentResult.StructuredOutput(
|
||||
output_kind=AgentOutputKind.FINAL_OUTPUT_ANSWER,
|
||||
output_text=final_text,
|
||||
output_data=structured_output_payload,
|
||||
)
|
||||
elif output_text_payload:
|
||||
output_payload = AgentResult.StructuredOutput(
|
||||
output_kind=AgentOutputKind.OUTPUT_TEXT,
|
||||
output_text=str(output_text_payload),
|
||||
output_data=structured_output_payload,
|
||||
)
|
||||
else:
|
||||
output_payload = AgentResult.StructuredOutput(
|
||||
output_kind=AgentOutputKind.ILLEGAL_OUTPUT,
|
||||
output_text="Model failed to produce a final output.",
|
||||
output_data=structured_output_payload,
|
||||
)
|
||||
|
||||
return AgentResult(
|
||||
text=final_text, files=output_files, usage=total_usage.get("usage"), finish_reason=finish_reason
|
||||
output=output_payload,
|
||||
files=output_files,
|
||||
usage=total_usage.get("usage"),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
|
||||
def _build_prompt_with_react_format(
|
||||
self,
|
||||
original_messages: list[PromptMessage],
|
||||
agent_scratchpad: list[AgentScratchpadUnit],
|
||||
include_tools: bool = True,
|
||||
tools: list[Tool] | None,
|
||||
instruction: str = "",
|
||||
) -> list[PromptMessage]:
|
||||
"""Build prompt messages with ReAct format."""
|
||||
@ -198,9 +254,9 @@ class ReActStrategy(AgentPattern):
|
||||
# Format tools
|
||||
tools_str = ""
|
||||
tool_names = []
|
||||
if include_tools and self.tools:
|
||||
if tools:
|
||||
# Convert tools to prompt message tools format
|
||||
prompt_tools = [tool.to_prompt_message_tool() for tool in self.tools]
|
||||
prompt_tools = [tool.to_prompt_message_tool() for tool in tools]
|
||||
tool_names = [tool.name for tool in prompt_tools]
|
||||
|
||||
# Format tools as JSON for comprehensive information
|
||||
@ -253,6 +309,8 @@ class ReActStrategy(AgentPattern):
|
||||
llm_usage: dict[str, Any],
|
||||
model_log: AgentLog,
|
||||
current_messages: list[PromptMessage],
|
||||
*,
|
||||
emit_chunks: bool,
|
||||
) -> Generator[
|
||||
LLMResultChunk | AgentLog,
|
||||
None,
|
||||
@ -306,14 +364,16 @@ class ReActStrategy(AgentPattern):
|
||||
scratchpad.action_str = action_str
|
||||
scratchpad.action = chunk
|
||||
|
||||
yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages)
|
||||
if emit_chunks:
|
||||
yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages)
|
||||
else:
|
||||
# Text chunk
|
||||
chunk_text = str(chunk)
|
||||
scratchpad.agent_response = (scratchpad.agent_response or "") + chunk_text
|
||||
scratchpad.thought = (scratchpad.thought or "") + chunk_text
|
||||
|
||||
yield self._create_text_chunk(chunk_text, current_messages)
|
||||
if emit_chunks:
|
||||
yield self._create_text_chunk(chunk_text, current_messages)
|
||||
|
||||
# Update usage
|
||||
if usage_dict.get("usage"):
|
||||
@ -337,6 +397,14 @@ class ReActStrategy(AgentPattern):
|
||||
|
||||
return scratchpad, finish_reason
|
||||
|
||||
@staticmethod
|
||||
def _format_output_text(value: Any) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return json.dumps(value, ensure_ascii=False)
|
||||
|
||||
def _handle_tool_call(
|
||||
self,
|
||||
action: AgentScratchpadUnit.Action,
|
||||
|
||||
@ -7,7 +7,7 @@ You have access to the following tools:
|
||||
{{tools}}
|
||||
|
||||
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||
Valid "action" values: "Final Answer" or {{tool_names}}
|
||||
Valid "action" values: {{tool_names}}. You must call "final_output_answer" to finish.
|
||||
|
||||
Provide only ONE action per $JSON_BLOB, as shown:
|
||||
|
||||
@ -32,12 +32,14 @@ Thought: I know what to respond
|
||||
Action:
|
||||
```
|
||||
{
|
||||
"action": "Final Answer",
|
||||
"action_input": "Final response to human"
|
||||
"action": "final_output_answer",
|
||||
"action_input": {
|
||||
"text": "Final response to human"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
|
||||
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Do not respond with plain text. Format is Action:```$JSON_BLOB```then Observation:.
|
||||
{{historic_messages}}
|
||||
Question: {{query}}
|
||||
{{agent_scratchpad}}
|
||||
@ -56,7 +58,7 @@ You have access to the following tools:
|
||||
{{tools}}
|
||||
|
||||
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||
Valid "action" values: "Final Answer" or {{tool_names}}
|
||||
Valid "action" values: {{tool_names}}. You must call "final_output_answer" to finish.
|
||||
|
||||
Provide only ONE action per $JSON_BLOB, as shown:
|
||||
|
||||
@ -81,12 +83,14 @@ Thought: I know what to respond
|
||||
Action:
|
||||
```
|
||||
{
|
||||
"action": "Final Answer",
|
||||
"action_input": "Final response to human"
|
||||
"action": "final_output_answer",
|
||||
"action_input": {
|
||||
"text": "Final response to human"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
|
||||
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Do not respond with plain text. Format is Action:```$JSON_BLOB```then Observation:.
|
||||
""" # noqa: E501
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user