feat: add agent package

This commit is contained in:
Novice
2025-12-09 11:26:02 +08:00
parent 15fec024c0
commit 2b23c43434
71 changed files with 5945 additions and 1213 deletions

View File

@ -0,0 +1,358 @@
import logging
from collections.abc import Generator
from copy import deepcopy
from typing import Any
from core.agent.base_agent_runner import BaseAgentRunner
from core.agent.entities import AgentEntity, AgentLog, AgentResult
from core.agent.patterns.strategy_factory import StrategyFactory
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.file import file_manager
from core.model_runtime.entities import (
AssistantPromptMessage,
LLMResult,
LLMResultChunk,
LLMUsage,
PromptMessage,
PromptMessageContentType,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Message
logger = logging.getLogger(__name__)
class AgentAppRunner(BaseAgentRunner):
def _create_tool_invoke_hook(self, message: Message):
"""
Create a tool invoke hook that uses ToolEngine.agent_invoke.
This hook handles file creation and returns proper meta information.
"""
# Get trace manager from app generate entity
trace_manager = self.application_generate_entity.trace_manager
def tool_invoke_hook(
tool: Tool, tool_args: dict[str, Any], tool_name: str
) -> tuple[str, list[str], ToolInvokeMeta]:
"""Hook that uses agent_invoke for proper file and meta handling."""
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
tool=tool,
tool_parameters=tool_args,
user_id=self.user_id,
tenant_id=self.tenant_id,
message=message,
invoke_from=self.application_generate_entity.invoke_from,
agent_tool_callback=self.agent_callback,
trace_manager=trace_manager,
app_id=self.application_generate_entity.app_config.app_id,
message_id=message.id,
conversation_id=self.conversation.id,
)
# Publish files and track IDs
for message_file_id in message_files:
self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id),
PublishFrom.APPLICATION_MANAGER,
)
self._current_message_file_ids.append(message_file_id)
return tool_invoke_response, message_files, tool_invoke_meta
return tool_invoke_hook
def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
"""
Run Agent application
"""
self.query = query
app_generate_entity = self.application_generate_entity
app_config = self.app_config
assert app_config is not None, "app_config is required"
assert app_config.agent is not None, "app_config.agent is required"
# convert tools into ModelRuntime Tool format
tool_instances, _ = self._init_prompt_tools()
assert app_config.agent
# Create tool invoke hook for agent_invoke
tool_invoke_hook = self._create_tool_invoke_hook(message)
# Get instruction for ReAct strategy
instruction = self.app_config.prompt_template.simple_prompt_template or ""
# Use factory to create appropriate strategy
strategy = StrategyFactory.create_strategy(
model_features=self.model_features,
model_instance=self.model_instance,
tools=list(tool_instances.values()),
files=list(self.files),
max_iterations=app_config.agent.max_iteration,
context=self.build_execution_context(),
agent_strategy=self.config.strategy,
tool_invoke_hook=tool_invoke_hook,
instruction=instruction,
)
# Initialize state variables
current_agent_thought_id = None
has_published_thought = False
current_tool_name: str | None = None
self._current_message_file_ids = []
# organize prompt messages
prompt_messages = self._organize_prompt_messages()
# Run strategy
generator = strategy.run(
prompt_messages=prompt_messages,
model_parameters=app_generate_entity.model_conf.parameters,
stop=app_generate_entity.model_conf.stop,
stream=True,
)
# Consume generator and collect result
result: AgentResult | None = None
try:
while True:
try:
output = next(generator)
except StopIteration as e:
# Generator finished, get the return value
result = e.value
break
if isinstance(output, LLMResultChunk):
# Handle LLM chunk
if current_agent_thought_id and not has_published_thought:
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
PublishFrom.APPLICATION_MANAGER,
)
has_published_thought = True
yield output
elif isinstance(output, AgentLog):
# Handle Agent Log using log_type for type-safe dispatch
if output.status == AgentLog.LogStatus.START:
if output.log_type == AgentLog.LogType.ROUND:
# Start of a new round
message_file_ids: list[str] = []
current_agent_thought_id = self.create_agent_thought(
message_id=message.id,
message="",
tool_name="",
tool_input="",
messages_ids=message_file_ids,
)
has_published_thought = False
elif output.log_type == AgentLog.LogType.TOOL_CALL:
if current_agent_thought_id is None:
continue
# Tool call start - extract data from structured fields
current_tool_name = output.data.get("tool_name", "")
tool_input = output.data.get("tool_args", {})
self.save_agent_thought(
agent_thought_id=current_agent_thought_id,
tool_name=current_tool_name,
tool_input=tool_input,
thought=None,
observation=None,
tool_invoke_meta=None,
answer=None,
messages_ids=[],
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
PublishFrom.APPLICATION_MANAGER,
)
elif output.status == AgentLog.LogStatus.SUCCESS:
if output.log_type == AgentLog.LogType.THOUGHT:
pass
elif output.log_type == AgentLog.LogType.TOOL_CALL:
if current_agent_thought_id is None:
continue
# Tool call finished
tool_output = output.data.get("output")
# Get meta from strategy output (now properly populated)
tool_meta = output.data.get("meta")
# Wrap tool_meta with tool_name as key (required by agent_service)
if tool_meta and current_tool_name:
tool_meta = {current_tool_name: tool_meta}
self.save_agent_thought(
agent_thought_id=current_agent_thought_id,
tool_name=None,
tool_input=None,
thought=None,
observation=tool_output,
tool_invoke_meta=tool_meta,
answer=None,
messages_ids=self._current_message_file_ids,
)
# Clear message file ids after saving
self._current_message_file_ids = []
current_tool_name = None
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
PublishFrom.APPLICATION_MANAGER,
)
elif output.log_type == AgentLog.LogType.ROUND:
if current_agent_thought_id is None:
continue
# Round finished - save LLM usage and answer
llm_usage = output.metadata.get(AgentLog.LogMetadata.LLM_USAGE)
llm_result = output.data.get("llm_result")
final_answer = output.data.get("final_answer")
self.save_agent_thought(
agent_thought_id=current_agent_thought_id,
tool_name=None,
tool_input=None,
thought=llm_result,
observation=None,
tool_invoke_meta=None,
answer=final_answer,
messages_ids=[],
llm_usage=llm_usage,
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
PublishFrom.APPLICATION_MANAGER,
)
except Exception:
# Re-raise any other exceptions
raise
# Process final result
if isinstance(result, AgentResult):
final_answer = result.text
usage = result.usage or LLMUsage.empty_usage()
# Publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=self.model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=usage,
system_fingerprint="",
)
),
PublishFrom.APPLICATION_MANAGER,
)
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Initialize system message
"""
if not prompt_messages and prompt_template:
return [
SystemPromptMessage(content=prompt_template),
]
if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
return prompt_messages or []
def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Organize user query
"""
if self.files:
# get image detail config
image_detail_config = (
self.application_generate_entity.file_upload_config.image_config.detail
if (
self.application_generate_entity.file_upload_config
and self.application_generate_entity.file_upload_config.image_config
)
else None
)
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
for file in self.files:
prompt_message_contents.append(
file_manager.to_prompt_message_content(
file,
image_detail_config=image_detail_config,
)
)
prompt_message_contents.append(TextPromptMessageContent(data=query))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
prompt_messages.append(UserPromptMessage(content=query))
return prompt_messages
def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
As for now, gpt supports both fc and vision at the first iteration.
We need to remove the image messages from the prompt messages at the first iteration.
"""
prompt_messages = deepcopy(prompt_messages)
for prompt_message in prompt_messages:
if isinstance(prompt_message, UserPromptMessage):
if isinstance(prompt_message.content, list):
prompt_message.content = "\n".join(
[
content.data
if content.type == PromptMessageContentType.TEXT
else "[image]"
if content.type == PromptMessageContentType.IMAGE
else "[file]"
for content in prompt_message.content
]
)
return prompt_messages
def _organize_prompt_messages(self):
# For ReAct strategy, use the agent prompt template
if self.config.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT and self.config.prompt:
prompt_template = self.config.prompt.first_prompt
else:
prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
query_prompt_messages = self._organize_user_query(self.query or "", [])
self.history_prompt_messages = AgentHistoryPromptTransform(
model_config=self.model_config,
prompt_messages=[*query_prompt_messages, *self._current_thoughts],
history_messages=self.history_prompt_messages,
memory=self.memory,
).get_prompt()
prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
if len(self._current_thoughts) != 0:
# clear messages after the first iteration
prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
return prompt_messages

View File

@ -5,7 +5,7 @@ from typing import Union, cast
from sqlalchemy import select
from core.agent.entities import AgentEntity, AgentToolEntity
from core.agent.entities import AgentEntity, AgentToolEntity, ExecutionContext
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
@ -114,9 +114,20 @@ class BaseAgentRunner(AppRunner):
features = model_schema.features if model_schema and model_schema.features else []
self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
self.files = application_generate_entity.files if ModelFeature.VISION in features else []
self.model_features = features
self.query: str | None = ""
self._current_thoughts: list[PromptMessage] = []
def build_execution_context(self) -> ExecutionContext:
"""Build execution context."""
return ExecutionContext(
user_id=self.user_id,
app_id=self.app_config.app_id,
conversation_id=self.conversation.id,
message_id=self.message.id,
tenant_id=self.tenant_id,
)
def _repack_app_generate_entity(
self, app_generate_entity: AgentChatAppGenerateEntity
) -> AgentChatAppGenerateEntity:

View File

@ -1,431 +0,0 @@
import json
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping, Sequence
from typing import Any
from core.agent.base_agent_runner import BaseAgentRunner
from core.agent.entities import AgentScratchpadUnit
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageTool,
ToolPromptMessage,
UserPromptMessage,
)
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Message
logger = logging.getLogger(__name__)
class CotAgentRunner(BaseAgentRunner, ABC):
_is_first_iteration = True
_ignore_observation_providers = ["wenxin"]
_historic_prompt_messages: list[PromptMessage]
_agent_scratchpad: list[AgentScratchpadUnit]
_instruction: str
_query: str
_prompt_messages_tools: Sequence[PromptMessageTool]
def run(
self,
message: Message,
query: str,
inputs: Mapping[str, str],
) -> Generator:
"""
Run Cot agent application
"""
app_generate_entity = self.application_generate_entity
self._repack_app_generate_entity(app_generate_entity)
self._init_react_state(query)
trace_manager = app_generate_entity.trace_manager
# check model mode
if "Observation" not in app_generate_entity.model_conf.stop:
if app_generate_entity.model_conf.provider not in self._ignore_observation_providers:
app_generate_entity.model_conf.stop.append("Observation")
app_config = self.app_config
assert app_config.agent
# init instruction
inputs = inputs or {}
instruction = app_config.prompt_template.simple_prompt_template or ""
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
iteration_step = 1
max_iteration_steps = min(app_config.agent.max_iteration, 99) + 1
# convert tools into ModelRuntime Tool format
tool_instances, prompt_messages_tools = self._init_prompt_tools()
self._prompt_messages_tools = prompt_messages_tools
function_call_state = True
llm_usage: dict[str, LLMUsage | None] = {"usage": None}
final_answer = ""
prompt_messages: list = [] # Initialize prompt_messages
agent_thought_id = "" # Initialize agent_thought_id
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMUsage):
if not final_llm_usage_dict["usage"]:
final_llm_usage_dict["usage"] = usage
else:
llm_usage = final_llm_usage_dict["usage"]
llm_usage.prompt_tokens += usage.prompt_tokens
llm_usage.completion_tokens += usage.completion_tokens
llm_usage.total_tokens += usage.total_tokens
llm_usage.prompt_price += usage.prompt_price
llm_usage.completion_price += usage.completion_price
llm_usage.total_price += usage.total_price
model_instance = self.model_instance
while function_call_state and iteration_step <= max_iteration_steps:
# continue to run until there is not any tool call
function_call_state = False
if iteration_step == max_iteration_steps:
# the last iteration, remove all tools
self._prompt_messages_tools = []
message_file_ids: list[str] = []
agent_thought_id = self.create_agent_thought(
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
)
if iteration_step > 1:
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
)
# recalc llm max tokens
prompt_messages = self._organize_prompt_messages()
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
chunks = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=app_generate_entity.model_conf.parameters,
tools=[],
stop=app_generate_entity.model_conf.stop,
stream=True,
user=self.user_id,
callbacks=[],
)
usage_dict: dict[str, LLMUsage | None] = {}
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
scratchpad = AgentScratchpadUnit(
agent_response="",
thought="",
action_str="",
observation="",
action=None,
)
# publish agent thought if it's first iteration
if iteration_step == 1:
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
)
for chunk in react_chunks:
if isinstance(chunk, AgentScratchpadUnit.Action):
action = chunk
# detect action
assert scratchpad.agent_response is not None
scratchpad.agent_response += json.dumps(chunk.model_dump())
scratchpad.action_str = json.dumps(chunk.model_dump())
scratchpad.action = action
else:
assert scratchpad.agent_response is not None
scratchpad.agent_response += chunk
assert scratchpad.thought is not None
scratchpad.thought += chunk
yield LLMResultChunk(
model=self.model_config.model,
prompt_messages=prompt_messages,
system_fingerprint="",
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
)
assert scratchpad.thought is not None
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
self._agent_scratchpad.append(scratchpad)
# get llm usage
if "usage" in usage_dict:
if usage_dict["usage"] is not None:
increase_usage(llm_usage, usage_dict["usage"])
else:
usage_dict["usage"] = LLMUsage.empty_usage()
self.save_agent_thought(
agent_thought_id=agent_thought_id,
tool_name=(scratchpad.action.action_name if scratchpad.action and not scratchpad.is_final() else ""),
tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
tool_invoke_meta={},
thought=scratchpad.thought or "",
observation="",
answer=scratchpad.agent_response or "",
messages_ids=[],
llm_usage=usage_dict["usage"],
)
if not scratchpad.is_final():
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
)
if not scratchpad.action:
# failed to extract action, return final answer directly
final_answer = ""
else:
if scratchpad.action.action_name.lower() == "final answer":
# action is final answer, return final answer directly
try:
if isinstance(scratchpad.action.action_input, dict):
final_answer = json.dumps(scratchpad.action.action_input, ensure_ascii=False)
elif isinstance(scratchpad.action.action_input, str):
final_answer = scratchpad.action.action_input
else:
final_answer = f"{scratchpad.action.action_input}"
except TypeError:
final_answer = f"{scratchpad.action.action_input}"
else:
function_call_state = True
# action is tool call, invoke tool
tool_invoke_response, tool_invoke_meta = self._handle_invoke_action(
action=scratchpad.action,
tool_instances=tool_instances,
message_file_ids=message_file_ids,
trace_manager=trace_manager,
)
scratchpad.observation = tool_invoke_response
scratchpad.agent_response = tool_invoke_response
self.save_agent_thought(
agent_thought_id=agent_thought_id,
tool_name=scratchpad.action.action_name,
tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
thought=scratchpad.thought or "",
observation={scratchpad.action.action_name: tool_invoke_response},
tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
answer=scratchpad.agent_response,
messages_ids=message_file_ids,
llm_usage=usage_dict["usage"],
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
)
# update prompt tool message
for prompt_tool in self._prompt_messages_tools:
self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
iteration_step += 1
yield LLMResultChunk(
model=model_instance.model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"]
),
system_fingerprint="",
)
# save agent thought
self.save_agent_thought(
agent_thought_id=agent_thought_id,
tool_name="",
tool_input={},
tool_invoke_meta={},
thought=final_answer,
observation={},
answer=final_answer,
messages_ids=[],
)
# publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
system_fingerprint="",
)
),
PublishFrom.APPLICATION_MANAGER,
)
def _handle_invoke_action(
self,
action: AgentScratchpadUnit.Action,
tool_instances: Mapping[str, Tool],
message_file_ids: list[str],
trace_manager: TraceQueueManager | None = None,
) -> tuple[str, ToolInvokeMeta]:
"""
handle invoke action
:param action: action
:param tool_instances: tool instances
:param message_file_ids: message file ids
:param trace_manager: trace manager
:return: observation, meta
"""
# action is tool call, invoke tool
tool_call_name = action.action_name
tool_call_args = action.action_input
tool_instance = tool_instances.get(tool_call_name)
if not tool_instance:
answer = f"there is not a tool named {tool_call_name}"
return answer, ToolInvokeMeta.error_instance(answer)
if isinstance(tool_call_args, str):
try:
tool_call_args = json.loads(tool_call_args)
except json.JSONDecodeError:
pass
# invoke tool
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
tool=tool_instance,
tool_parameters=tool_call_args,
user_id=self.user_id,
tenant_id=self.tenant_id,
message=self.message,
invoke_from=self.application_generate_entity.invoke_from,
agent_tool_callback=self.agent_callback,
trace_manager=trace_manager,
)
# publish files
for message_file_id in message_files:
# publish message file
self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
)
# add message file ids
message_file_ids.append(message_file_id)
return tool_invoke_response, tool_invoke_meta
def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action:
"""
convert dict to action
"""
return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"])
def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: Mapping[str, Any]) -> str:
"""
fill in inputs from external data tools
"""
for key, value in inputs.items():
try:
instruction = instruction.replace(f"{{{{{key}}}}}", str(value))
except Exception:
continue
return instruction
def _init_react_state(self, query):
"""
init agent scratchpad
"""
self._query = query
self._agent_scratchpad = []
self._historic_prompt_messages = self._organize_historic_prompt_messages()
@abstractmethod
def _organize_prompt_messages(self) -> list[PromptMessage]:
"""
organize prompt messages
"""
def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
"""
format assistant message
"""
message = ""
for scratchpad in agent_scratchpad:
if scratchpad.is_final():
message += f"Final Answer: {scratchpad.agent_response}"
else:
message += f"Thought: {scratchpad.thought}\n\n"
if scratchpad.action_str:
message += f"Action: {scratchpad.action_str}\n\n"
if scratchpad.observation:
message += f"Observation: {scratchpad.observation}\n\n"
return message
def _organize_historic_prompt_messages(
self, current_session_messages: list[PromptMessage] | None = None
) -> list[PromptMessage]:
"""
organize historic prompt messages
"""
result: list[PromptMessage] = []
scratchpads: list[AgentScratchpadUnit] = []
current_scratchpad: AgentScratchpadUnit | None = None
for message in self.history_prompt_messages:
if isinstance(message, AssistantPromptMessage):
if not current_scratchpad:
assert isinstance(message.content, str)
current_scratchpad = AgentScratchpadUnit(
agent_response=message.content,
thought=message.content or "I am thinking about how to help you",
action_str="",
action=None,
observation=None,
)
scratchpads.append(current_scratchpad)
if message.tool_calls:
try:
current_scratchpad.action = AgentScratchpadUnit.Action(
action_name=message.tool_calls[0].function.name,
action_input=json.loads(message.tool_calls[0].function.arguments),
)
current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
except Exception:
logger.exception("Failed to parse tool call from assistant message")
elif isinstance(message, ToolPromptMessage):
if current_scratchpad:
assert isinstance(message.content, str)
current_scratchpad.observation = message.content
else:
raise NotImplementedError("expected str type")
elif isinstance(message, UserPromptMessage):
if scratchpads:
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
scratchpads = []
current_scratchpad = None
result.append(message)
if scratchpads:
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
historic_prompts = AgentHistoryPromptTransform(
model_config=self.model_config,
prompt_messages=current_session_messages or [],
history_messages=result,
memory=self.memory,
).get_prompt()
return historic_prompts

View File

@ -1,118 +0,0 @@
import json
from core.agent.cot_agent_runner import CotAgentRunner
from core.file import file_manager
from core.model_runtime.entities import (
AssistantPromptMessage,
PromptMessage,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.model_runtime.utils.encoders import jsonable_encoder
class CotChatAgentRunner(CotAgentRunner):
def _organize_system_prompt(self) -> SystemPromptMessage:
"""
Organize system prompt
"""
assert self.app_config.agent
assert self.app_config.agent.prompt
prompt_entity = self.app_config.agent.prompt
if not prompt_entity:
raise ValueError("Agent prompt configuration is not set")
first_prompt = prompt_entity.first_prompt
system_prompt = (
first_prompt.replace("{{instruction}}", self._instruction)
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
.replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
)
return SystemPromptMessage(content=system_prompt)
def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Organize user query
"""
if self.files:
# get image detail config
image_detail_config = (
self.application_generate_entity.file_upload_config.image_config.detail
if (
self.application_generate_entity.file_upload_config
and self.application_generate_entity.file_upload_config.image_config
)
else None
)
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
for file in self.files:
prompt_message_contents.append(
file_manager.to_prompt_message_content(
file,
image_detail_config=image_detail_config,
)
)
prompt_message_contents.append(TextPromptMessageContent(data=query))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
prompt_messages.append(UserPromptMessage(content=query))
return prompt_messages
def _organize_prompt_messages(self) -> list[PromptMessage]:
"""
Organize
"""
# organize system prompt
system_message = self._organize_system_prompt()
# organize current assistant messages
agent_scratchpad = self._agent_scratchpad
if not agent_scratchpad:
assistant_messages = []
else:
assistant_message = AssistantPromptMessage(content="")
assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str
for unit in agent_scratchpad:
if unit.is_final():
assert isinstance(assistant_message.content, str)
assistant_message.content += f"Final Answer: {unit.agent_response}"
else:
assert isinstance(assistant_message.content, str)
assistant_message.content += f"Thought: {unit.thought}\n\n"
if unit.action_str:
assistant_message.content += f"Action: {unit.action_str}\n\n"
if unit.observation:
assistant_message.content += f"Observation: {unit.observation}\n\n"
assistant_messages = [assistant_message]
# query messages
query_messages = self._organize_user_query(self._query, [])
if assistant_messages:
# organize historic prompt messages
historic_messages = self._organize_historic_prompt_messages(
[system_message, *query_messages, *assistant_messages, UserPromptMessage(content="continue")]
)
messages = [
system_message,
*historic_messages,
*query_messages,
*assistant_messages,
UserPromptMessage(content="continue"),
]
else:
# organize historic prompt messages
historic_messages = self._organize_historic_prompt_messages([system_message, *query_messages])
messages = [system_message, *historic_messages, *query_messages]
# join all messages
return messages

View File

@ -1,87 +0,0 @@
import json
from core.agent.cot_agent_runner import CotAgentRunner
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.utils.encoders import jsonable_encoder
class CotCompletionAgentRunner(CotAgentRunner):
def _organize_instruction_prompt(self) -> str:
"""
Organize instruction prompt
"""
if self.app_config.agent is None:
raise ValueError("Agent configuration is not set")
prompt_entity = self.app_config.agent.prompt
if prompt_entity is None:
raise ValueError("prompt entity is not set")
first_prompt = prompt_entity.first_prompt
system_prompt = (
first_prompt.replace("{{instruction}}", self._instruction)
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
.replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
)
return system_prompt
def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] | None = None) -> str:
"""
Organize historic prompt
"""
historic_prompt_messages = self._organize_historic_prompt_messages(current_session_messages)
historic_prompt = ""
for message in historic_prompt_messages:
if isinstance(message, UserPromptMessage):
historic_prompt += f"Question: {message.content}\n\n"
elif isinstance(message, AssistantPromptMessage):
if isinstance(message.content, str):
historic_prompt += message.content + "\n\n"
elif isinstance(message.content, list):
for content in message.content:
if not isinstance(content, TextPromptMessageContent):
continue
historic_prompt += content.data
return historic_prompt
def _organize_prompt_messages(self) -> list[PromptMessage]:
"""
Organize prompt messages
"""
# organize system prompt
system_prompt = self._organize_instruction_prompt()
# organize historic prompt messages
historic_prompt = self._organize_historic_prompt()
# organize current assistant messages
agent_scratchpad = self._agent_scratchpad
assistant_prompt = ""
for unit in agent_scratchpad or []:
if unit.is_final():
assistant_prompt += f"Final Answer: {unit.agent_response}"
else:
assistant_prompt += f"Thought: {unit.thought}\n\n"
if unit.action_str:
assistant_prompt += f"Action: {unit.action_str}\n\n"
if unit.observation:
assistant_prompt += f"Observation: {unit.observation}\n\n"
# query messages
query_prompt = f"Question: {self._query}"
# join all messages
prompt = (
system_prompt.replace("{{historic_messages}}", historic_prompt)
.replace("{{agent_scratchpad}}", assistant_prompt)
.replace("{{query}}", query_prompt)
)
return [UserPromptMessage(content=prompt)]

View File

@ -1,3 +1,5 @@
import uuid
from collections.abc import Mapping
from enum import StrEnum
from typing import Any, Union
@ -92,3 +94,94 @@ class AgentInvokeMessage(ToolInvokeMessage):
"""
pass
class ExecutionContext(BaseModel):
"""Execution context containing trace and audit information.
This context carries all the IDs and metadata that are not part of
the core business logic but needed for tracing, auditing, and
correlation purposes.
"""
user_id: str | None = None
app_id: str | None = None
conversation_id: str | None = None
message_id: str | None = None
tenant_id: str | None = None
@classmethod
def create_minimal(cls, user_id: str | None = None) -> "ExecutionContext":
"""Create a minimal context with only essential fields."""
return cls(user_id=user_id)
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for passing to legacy code."""
return {
"user_id": self.user_id,
"app_id": self.app_id,
"conversation_id": self.conversation_id,
"message_id": self.message_id,
"tenant_id": self.tenant_id,
}
def with_updates(self, **kwargs) -> "ExecutionContext":
"""Create a new context with updated fields."""
data = self.to_dict()
data.update(kwargs)
return ExecutionContext(
user_id=data.get("user_id"),
app_id=data.get("app_id"),
conversation_id=data.get("conversation_id"),
message_id=data.get("message_id"),
tenant_id=data.get("tenant_id"),
)
class AgentLog(BaseModel):
"""
Agent Log.
"""
class LogType(StrEnum):
"""Type of agent log entry."""
ROUND = "round" # A complete iteration round
THOUGHT = "thought" # LLM thinking/reasoning
TOOL_CALL = "tool_call" # Tool invocation
class LogMetadata(StrEnum):
STARTED_AT = "started_at"
FINISHED_AT = "finished_at"
ELAPSED_TIME = "elapsed_time"
TOTAL_PRICE = "total_price"
TOTAL_TOKENS = "total_tokens"
PROVIDER = "provider"
CURRENCY = "currency"
LLM_USAGE = "llm_usage"
class LogStatus(StrEnum):
START = "start"
ERROR = "error"
SUCCESS = "success"
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="The id of the log")
label: str = Field(..., description="The label of the log")
log_type: LogType = Field(..., description="The type of the log")
parent_id: str | None = Field(default=None, description="Leave empty for root log")
error: str | None = Field(default=None, description="The error message")
status: LogStatus = Field(..., description="The status of the log")
data: Mapping[str, Any] = Field(..., description="Detailed log data")
metadata: Mapping[LogMetadata, Any] = Field(default={}, description="The metadata of the log")
class AgentResult(BaseModel):
"""
Agent execution result.
"""
text: str = Field(default="", description="The generated text")
files: list[Any] = Field(default_factory=list, description="Files produced during execution")
usage: Any | None = Field(default=None, description="LLM usage statistics")
finish_reason: str | None = Field(default=None, description="Reason for completion")

View File

@ -1,465 +0,0 @@
import json
import logging
from collections.abc import Generator
from copy import deepcopy
from typing import Any, Union
from core.agent.base_agent_runner import BaseAgentRunner
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.file import file_manager
from core.model_runtime.entities import (
AssistantPromptMessage,
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
LLMUsage,
PromptMessage,
PromptMessageContentType,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Message
logger = logging.getLogger(__name__)
class FunctionCallAgentRunner(BaseAgentRunner):
def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
"""
Run FunctionCall agent application
"""
self.query = query
app_generate_entity = self.application_generate_entity
app_config = self.app_config
assert app_config is not None, "app_config is required"
assert app_config.agent is not None, "app_config.agent is required"
# convert tools into ModelRuntime Tool format
tool_instances, prompt_messages_tools = self._init_prompt_tools()
assert app_config.agent
iteration_step = 1
max_iteration_steps = min(app_config.agent.max_iteration, 99) + 1
# continue to run until there is not any tool call
function_call_state = True
llm_usage: dict[str, LLMUsage | None] = {"usage": None}
final_answer = ""
prompt_messages: list = [] # Initialize prompt_messages
# get tracing instance
trace_manager = app_generate_entity.trace_manager
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMUsage):
if not final_llm_usage_dict["usage"]:
final_llm_usage_dict["usage"] = usage
else:
llm_usage = final_llm_usage_dict["usage"]
llm_usage.prompt_tokens += usage.prompt_tokens
llm_usage.completion_tokens += usage.completion_tokens
llm_usage.total_tokens += usage.total_tokens
llm_usage.prompt_price += usage.prompt_price
llm_usage.completion_price += usage.completion_price
llm_usage.total_price += usage.total_price
model_instance = self.model_instance
while function_call_state and iteration_step <= max_iteration_steps:
function_call_state = False
if iteration_step == max_iteration_steps:
# the last iteration, remove all tools
prompt_messages_tools = []
message_file_ids: list[str] = []
agent_thought_id = self.create_agent_thought(
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
)
# recalc llm max tokens
prompt_messages = self._organize_prompt_messages()
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=app_generate_entity.model_conf.parameters,
tools=prompt_messages_tools,
stop=app_generate_entity.model_conf.stop,
stream=self.stream_tool_call,
user=self.user_id,
callbacks=[],
)
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
# save full response
response = ""
# save tool call names and inputs
tool_call_names = ""
tool_call_inputs = ""
current_llm_usage = None
if isinstance(chunks, Generator):
is_first_chunk = True
for chunk in chunks:
if is_first_chunk:
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
)
is_first_chunk = False
# check if there is any tool call
if self.check_tool_calls(chunk):
function_call_state = True
tool_calls.extend(self.extract_tool_calls(chunk) or [])
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
try:
tool_call_inputs = json.dumps(
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
)
except TypeError:
# fallback: force ASCII to handle non-serializable objects
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
if chunk.delta.message and chunk.delta.message.content:
if isinstance(chunk.delta.message.content, list):
for content in chunk.delta.message.content:
response += content.data
else:
response += str(chunk.delta.message.content)
if chunk.delta.usage:
increase_usage(llm_usage, chunk.delta.usage)
current_llm_usage = chunk.delta.usage
yield chunk
else:
result = chunks
# check if there is any tool call
if self.check_blocking_tool_calls(result):
function_call_state = True
tool_calls.extend(self.extract_blocking_tool_calls(result) or [])
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
try:
tool_call_inputs = json.dumps(
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
)
except TypeError:
# fallback: force ASCII to handle non-serializable objects
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
if result.usage:
increase_usage(llm_usage, result.usage)
current_llm_usage = result.usage
if result.message and result.message.content:
if isinstance(result.message.content, list):
for content in result.message.content:
response += content.data
else:
response += str(result.message.content)
if not result.message.content:
result.message.content = ""
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
)
yield LLMResultChunk(
model=model_instance.model,
prompt_messages=result.prompt_messages,
system_fingerprint=result.system_fingerprint,
delta=LLMResultChunkDelta(
index=0,
message=result.message,
usage=result.usage,
),
)
assistant_message = AssistantPromptMessage(content="", tool_calls=[])
if tool_calls:
assistant_message.tool_calls = [
AssistantPromptMessage.ToolCall(
id=tool_call[0],
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False)
),
)
for tool_call in tool_calls
]
else:
assistant_message.content = response
self._current_thoughts.append(assistant_message)
# save thought
self.save_agent_thought(
agent_thought_id=agent_thought_id,
tool_name=tool_call_names,
tool_input=tool_call_inputs,
thought=response,
tool_invoke_meta=None,
observation=None,
answer=response,
messages_ids=[],
llm_usage=current_llm_usage,
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
)
final_answer += response + "\n"
# call tools
tool_responses = []
for tool_call_id, tool_call_name, tool_call_args in tool_calls:
tool_instance = tool_instances.get(tool_call_name)
if not tool_instance:
tool_response = {
"tool_call_id": tool_call_id,
"tool_call_name": tool_call_name,
"tool_response": f"there is not a tool named {tool_call_name}",
"meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(),
}
else:
# invoke tool
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
tool=tool_instance,
tool_parameters=tool_call_args,
user_id=self.user_id,
tenant_id=self.tenant_id,
message=self.message,
invoke_from=self.application_generate_entity.invoke_from,
agent_tool_callback=self.agent_callback,
trace_manager=trace_manager,
app_id=self.application_generate_entity.app_config.app_id,
message_id=self.message.id,
conversation_id=self.conversation.id,
)
# publish files
for message_file_id in message_files:
# publish message file
self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
)
# add message file ids
message_file_ids.append(message_file_id)
tool_response = {
"tool_call_id": tool_call_id,
"tool_call_name": tool_call_name,
"tool_response": tool_invoke_response,
"meta": tool_invoke_meta.to_dict(),
}
tool_responses.append(tool_response)
if tool_response["tool_response"] is not None:
self._current_thoughts.append(
ToolPromptMessage(
content=str(tool_response["tool_response"]),
tool_call_id=tool_call_id,
name=tool_call_name,
)
)
if len(tool_responses) > 0:
# save agent thought
self.save_agent_thought(
agent_thought_id=agent_thought_id,
tool_name="",
tool_input="",
thought="",
tool_invoke_meta={
tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
},
observation={
tool_response["tool_call_name"]: tool_response["tool_response"]
for tool_response in tool_responses
},
answer="",
messages_ids=message_file_ids,
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
)
# update prompt tool
for prompt_tool in prompt_messages_tools:
self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
iteration_step += 1
# publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
system_fingerprint="",
)
),
PublishFrom.APPLICATION_MANAGER,
)
def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
"""
Check if there is any tool call in llm result chunk
"""
if llm_result_chunk.delta.message.tool_calls:
return True
return False
def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
"""
Check if there is any blocking tool call in llm result
"""
if llm_result.message.tool_calls:
return True
return False
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
"""
Extract tool calls from llm result chunk
Returns:
List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
"""
tool_calls = []
for prompt_message in llm_result_chunk.delta.message.tool_calls:
args = {}
if prompt_message.function.arguments != "":
args = json.loads(prompt_message.function.arguments)
tool_calls.append(
(
prompt_message.id,
prompt_message.function.name,
args,
)
)
return tool_calls
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
"""
Extract blocking tool calls from llm result
Returns:
List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
"""
tool_calls = []
for prompt_message in llm_result.message.tool_calls:
args = {}
if prompt_message.function.arguments != "":
args = json.loads(prompt_message.function.arguments)
tool_calls.append(
(
prompt_message.id,
prompt_message.function.name,
args,
)
)
return tool_calls
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Initialize system message
"""
if not prompt_messages and prompt_template:
return [
SystemPromptMessage(content=prompt_template),
]
if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
return prompt_messages or []
def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Organize user query
"""
if self.files:
# get image detail config
image_detail_config = (
self.application_generate_entity.file_upload_config.image_config.detail
if (
self.application_generate_entity.file_upload_config
and self.application_generate_entity.file_upload_config.image_config
)
else None
)
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
for file in self.files:
prompt_message_contents.append(
file_manager.to_prompt_message_content(
file,
image_detail_config=image_detail_config,
)
)
prompt_message_contents.append(TextPromptMessageContent(data=query))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
prompt_messages.append(UserPromptMessage(content=query))
return prompt_messages
def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
As for now, gpt supports both fc and vision at the first iteration.
We need to remove the image messages from the prompt messages at the first iteration.
"""
prompt_messages = deepcopy(prompt_messages)
for prompt_message in prompt_messages:
if isinstance(prompt_message, UserPromptMessage):
if isinstance(prompt_message.content, list):
prompt_message.content = "\n".join(
[
content.data
if content.type == PromptMessageContentType.TEXT
else "[image]"
if content.type == PromptMessageContentType.IMAGE
else "[file]"
for content in prompt_message.content
]
)
return prompt_messages
def _organize_prompt_messages(self):
prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
query_prompt_messages = self._organize_user_query(self.query or "", [])
self.history_prompt_messages = AgentHistoryPromptTransform(
model_config=self.model_config,
prompt_messages=[*query_prompt_messages, *self._current_thoughts],
history_messages=self.history_prompt_messages,
memory=self.memory,
).get_prompt()
prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
if len(self._current_thoughts) != 0:
# clear messages after the first iteration
prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
return prompt_messages

View File

@ -0,0 +1,67 @@
# Agent Patterns
A unified agent pattern module that provides common agent execution strategies for both Agent V2 nodes and Agent Applications in Dify.
## Overview
This module implements a strategy pattern for agent execution, automatically selecting the appropriate strategy based on model capabilities. It serves as the core engine for agent-based interactions across different components of the Dify platform.
## Key Features
### 1. Multiple Agent Strategies
- **Function Call Strategy**: Leverages native function/tool calling capabilities of advanced LLMs (e.g., GPT-4, Claude)
- **ReAct Strategy**: Implements the ReAct (Reasoning + Acting) approach for models without native function calling support
### 2. Automatic Strategy Selection
The `StrategyFactory` intelligently selects the optimal strategy based on model features:
- Models with `TOOL_CALL`, `MULTI_TOOL_CALL`, or `STREAM_TOOL_CALL` capabilities → Function Call Strategy
- Other models → ReAct Strategy
### 3. Unified Interface
- Common base class (`AgentPattern`) ensures consistent behavior across strategies
- Seamless integration with both workflow nodes and standalone agent applications
- Standardized input/output formats for easy consumption
### 4. Advanced Capabilities
- **Streaming Support**: Real-time response streaming for better user experience
- **File Handling**: Built-in support for processing and managing files during agent execution
- **Iteration Control**: Configurable maximum iterations with safety limits (capped at 99)
- **Tool Management**: Flexible tool integration supporting various tool types
- **Context Propagation**: Execution context for tracing, auditing, and debugging
## Architecture
```
agent/patterns/
├── base.py # Abstract base class defining the agent pattern interface
├── function_call.py # Implementation using native LLM function calling
├── react.py # Implementation using ReAct prompting approach
└── strategy_factory.py # Factory for automatic strategy selection
```
## Usage
The module is designed to be used by:
1. **Agent V2 Nodes**: In workflow orchestration for complex agent tasks
1. **Agent Applications**: For standalone conversational agents
1. **Custom Implementations**: As a foundation for building specialized agent behaviors
## Integration Points
- **Model Runtime**: Interfaces with Dify's model runtime for LLM interactions
- **Tool System**: Integrates with the tool framework for external capabilities
- **Memory Management**: Compatible with conversation memory systems
- **File Management**: Handles file inputs/outputs during agent execution
## Benefits
1. **Consistency**: Unified implementation reduces code duplication and maintenance overhead
1. **Flexibility**: Easy to extend with new strategies or customize existing ones
1. **Performance**: Optimized for each model's capabilities to ensure best performance
1. **Reliability**: Built-in safety mechanisms and error handling

View File

@ -0,0 +1,19 @@
"""Agent patterns module.
This module provides different strategies for agent execution:
- FunctionCallStrategy: Uses native function/tool calling
- ReActStrategy: Uses ReAct (Reasoning + Acting) approach
- StrategyFactory: Factory for creating strategies based on model features
"""
from .base import AgentPattern
from .function_call import FunctionCallStrategy
from .react import ReActStrategy
from .strategy_factory import StrategyFactory
__all__ = [
"AgentPattern",
"FunctionCallStrategy",
"ReActStrategy",
"StrategyFactory",
]

View File

@ -0,0 +1,444 @@
"""Base class for agent strategies."""
from __future__ import annotations
import json
import re
import time
from abc import ABC, abstractmethod
from collections.abc import Callable, Generator
from typing import TYPE_CHECKING, Any
from core.agent.entities import AgentLog, AgentResult, ExecutionContext
from core.file import File
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
AssistantPromptMessage,
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
PromptMessage,
PromptMessageTool,
)
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import TextPromptMessageContent
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMeta
if TYPE_CHECKING:
from core.tools.__base.tool import Tool
# Type alias for tool invoke hook
# Returns: (response_content, message_file_ids, tool_invoke_meta)
ToolInvokeHook = Callable[["Tool", dict[str, Any], str], tuple[str, list[str], ToolInvokeMeta]]
class AgentPattern(ABC):
"""Base class for agent execution strategies."""
def __init__(
self,
model_instance: ModelInstance,
tools: list[Tool],
context: ExecutionContext,
max_iterations: int = 10,
workflow_call_depth: int = 0,
files: list[File] = [],
tool_invoke_hook: ToolInvokeHook | None = None,
):
"""Initialize the agent strategy."""
self.model_instance = model_instance
self.tools = tools
self.context = context
self.max_iterations = min(max_iterations, 99) # Cap at 99 iterations
self.workflow_call_depth = workflow_call_depth
self.files: list[File] = files
self.tool_invoke_hook = tool_invoke_hook
@abstractmethod
def run(
self,
prompt_messages: list[PromptMessage],
model_parameters: dict[str, Any],
stop: list[str] = [],
stream: bool = True,
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
"""Execute the agent strategy."""
pass
def _accumulate_usage(self, total_usage: dict[str, Any], delta_usage: LLMUsage) -> None:
"""Accumulate LLM usage statistics."""
if not total_usage.get("usage"):
# Create a copy to avoid modifying the original
total_usage["usage"] = LLMUsage(
prompt_tokens=delta_usage.prompt_tokens,
prompt_unit_price=delta_usage.prompt_unit_price,
prompt_price_unit=delta_usage.prompt_price_unit,
prompt_price=delta_usage.prompt_price,
completion_tokens=delta_usage.completion_tokens,
completion_unit_price=delta_usage.completion_unit_price,
completion_price_unit=delta_usage.completion_price_unit,
completion_price=delta_usage.completion_price,
total_tokens=delta_usage.total_tokens,
total_price=delta_usage.total_price,
currency=delta_usage.currency,
latency=delta_usage.latency,
)
else:
current: LLMUsage = total_usage["usage"]
current.prompt_tokens += delta_usage.prompt_tokens
current.completion_tokens += delta_usage.completion_tokens
current.total_tokens += delta_usage.total_tokens
current.prompt_price += delta_usage.prompt_price
current.completion_price += delta_usage.completion_price
current.total_price += delta_usage.total_price
def _extract_content(self, content: Any) -> str:
"""Extract text content from message content."""
if isinstance(content, list):
# Content items are PromptMessageContentUnionTypes
text_parts = []
for c in content:
# Check if it's a TextPromptMessageContent (which has data attribute)
if isinstance(c, TextPromptMessageContent):
text_parts.append(c.data)
return "".join(text_parts)
return str(content)
def _has_tool_calls(self, chunk: LLMResultChunk) -> bool:
"""Check if chunk contains tool calls."""
# LLMResultChunk always has delta attribute
return bool(chunk.delta.message and chunk.delta.message.tool_calls)
def _has_tool_calls_result(self, result: LLMResult) -> bool:
"""Check if result contains tool calls (non-streaming)."""
# LLMResult always has message attribute
return bool(result.message and result.message.tool_calls)
def _extract_tool_calls(self, chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
"""Extract tool calls from streaming chunk."""
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
if chunk.delta.message and chunk.delta.message.tool_calls:
for tool_call in chunk.delta.message.tool_calls:
if tool_call.function:
try:
args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
except json.JSONDecodeError:
args = {}
tool_calls.append((tool_call.id or "", tool_call.function.name, args))
return tool_calls
def _extract_tool_calls_result(self, result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
"""Extract tool calls from non-streaming result."""
tool_calls = []
if result.message and result.message.tool_calls:
for tool_call in result.message.tool_calls:
if tool_call.function:
try:
args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
except json.JSONDecodeError:
args = {}
tool_calls.append((tool_call.id or "", tool_call.function.name, args))
return tool_calls
def _extract_text_from_message(self, message: PromptMessage) -> str:
"""Extract text content from a prompt message."""
# PromptMessage always has content attribute
content = message.content
if isinstance(content, str):
return content
elif isinstance(content, list):
# Extract text from content list
text_parts = []
for item in content:
if isinstance(item, TextPromptMessageContent):
text_parts.append(item.data)
return " ".join(text_parts)
return ""
def _create_log(
self,
label: str,
log_type: AgentLog.LogType,
status: AgentLog.LogStatus,
data: dict[str, Any] | None = None,
parent_id: str | None = None,
extra_metadata: dict[AgentLog.LogMetadata, Any] | None = None,
) -> AgentLog:
"""Create a new AgentLog with standard metadata."""
metadata = {
AgentLog.LogMetadata.STARTED_AT: time.perf_counter(),
}
if extra_metadata:
metadata.update(extra_metadata)
return AgentLog(
label=label,
log_type=log_type,
status=status,
data=data or {},
parent_id=parent_id,
metadata=metadata,
)
def _finish_log(
self,
log: AgentLog,
data: dict[str, Any] | None = None,
usage: LLMUsage | None = None,
) -> AgentLog:
"""Finish an AgentLog by updating its status and metadata."""
log.status = AgentLog.LogStatus.SUCCESS
if data is not None:
log.data = data
# Calculate elapsed time
started_at = log.metadata.get(AgentLog.LogMetadata.STARTED_AT, time.perf_counter())
finished_at = time.perf_counter()
# Update metadata
log.metadata = {
**log.metadata,
AgentLog.LogMetadata.FINISHED_AT: finished_at,
AgentLog.LogMetadata.ELAPSED_TIME: finished_at - started_at,
}
# Add usage information if provided
if usage:
log.metadata.update(
{
AgentLog.LogMetadata.TOTAL_PRICE: usage.total_price,
AgentLog.LogMetadata.CURRENCY: usage.currency,
AgentLog.LogMetadata.TOTAL_TOKENS: usage.total_tokens,
AgentLog.LogMetadata.LLM_USAGE: usage,
}
)
return log
def _replace_file_references(self, tool_args: dict[str, Any]) -> dict[str, Any]:
"""
Replace file references in tool arguments with actual File objects.
Args:
tool_args: Dictionary of tool arguments
Returns:
Updated tool arguments with file references replaced
"""
# Process each argument in the dictionary
processed_args: dict[str, Any] = {}
for key, value in tool_args.items():
processed_args[key] = self._process_file_reference(value)
return processed_args
def _process_file_reference(self, data: Any) -> Any:
"""
Recursively process data to replace file references.
Supports both single file [File: file_id] and multiple files [Files: file_id1, file_id2, ...].
Args:
data: The data to process (can be dict, list, str, or other types)
Returns:
Processed data with file references replaced
"""
single_file_pattern = re.compile(r"^\[File:\s*([^\]]+)\]$")
multiple_files_pattern = re.compile(r"^\[Files:\s*([^\]]+)\]$")
if isinstance(data, dict):
# Process dictionary recursively
return {key: self._process_file_reference(value) for key, value in data.items()}
elif isinstance(data, list):
# Process list recursively
return [self._process_file_reference(item) for item in data]
elif isinstance(data, str):
# Check for single file pattern [File: file_id]
single_match = single_file_pattern.match(data.strip())
if single_match:
file_id = single_match.group(1).strip()
# Find the file in self.files
for file in self.files:
if file.id and str(file.id) == file_id:
return file
# If file not found, return original value
return data
# Check for multiple files pattern [Files: file_id1, file_id2, ...]
multiple_match = multiple_files_pattern.match(data.strip())
if multiple_match:
file_ids_str = multiple_match.group(1).strip()
# Split by comma and strip whitespace
file_ids = [fid.strip() for fid in file_ids_str.split(",")]
# Find all matching files
matched_files: list[File] = []
for file_id in file_ids:
for file in self.files:
if file.id and str(file.id) == file_id:
matched_files.append(file)
break
# Return list of files if any were found, otherwise return original
return matched_files or data
return data
else:
# Return other types as-is
return data
def _create_text_chunk(self, text: str, prompt_messages: list[PromptMessage]) -> LLMResultChunk:
"""Create a text chunk for streaming."""
return LLMResultChunk(
model=self.model_instance.model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=text),
usage=None,
),
system_fingerprint="",
)
def _invoke_tool(
self,
tool_instance: Tool,
tool_args: dict[str, Any],
tool_name: str,
) -> tuple[str, list[File], ToolInvokeMeta | None]:
"""
Invoke a tool and collect its response.
Args:
tool_instance: The tool instance to invoke
tool_args: Tool arguments
tool_name: Name of the tool
Returns:
Tuple of (response_content, tool_files, tool_invoke_meta)
"""
# Process tool_args to replace file references with actual File objects
tool_args = self._replace_file_references(tool_args)
# If a tool invoke hook is set, use it instead of generic_invoke
if self.tool_invoke_hook:
response_content, _, tool_invoke_meta = self.tool_invoke_hook(tool_instance, tool_args, tool_name)
# Note: message_file_ids are stored in DB, we don't convert them to File objects here
# The caller (AgentAppRunner) handles file publishing
return response_content, [], tool_invoke_meta
# Default: use generic_invoke for workflow scenarios
# Import here to avoid circular import
from core.tools.tool_engine import DifyWorkflowCallbackHandler, ToolEngine
tool_response = ToolEngine().generic_invoke(
tool=tool_instance,
tool_parameters=tool_args,
user_id=self.context.user_id or "",
workflow_tool_callback=DifyWorkflowCallbackHandler(),
workflow_call_depth=self.workflow_call_depth,
app_id=self.context.app_id,
conversation_id=self.context.conversation_id,
message_id=self.context.message_id,
)
# Collect response and files
response_content = ""
tool_files: list[File] = []
for response in tool_response:
if response.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(response.message, ToolInvokeMessage.TextMessage)
response_content += response.message.text
elif response.type == ToolInvokeMessage.MessageType.LINK:
# Handle link messages
if isinstance(response.message, ToolInvokeMessage.TextMessage):
response_content += f"[Link: {response.message.text}]"
elif response.type == ToolInvokeMessage.MessageType.IMAGE:
# Handle image URL messages
if isinstance(response.message, ToolInvokeMessage.TextMessage):
response_content += f"[Image: {response.message.text}]"
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK:
# Handle image link messages
if isinstance(response.message, ToolInvokeMessage.TextMessage):
response_content += f"[Image: {response.message.text}]"
elif response.type == ToolInvokeMessage.MessageType.BINARY_LINK:
# Handle binary file link messages
if isinstance(response.message, ToolInvokeMessage.TextMessage):
filename = response.meta.get("filename", "file") if response.meta else "file"
response_content += f"[File: {filename} - {response.message.text}]"
elif response.type == ToolInvokeMessage.MessageType.JSON:
# Handle JSON messages
if isinstance(response.message, ToolInvokeMessage.JsonMessage):
response_content += json.dumps(response.message.json_object, ensure_ascii=False, indent=2)
elif response.type == ToolInvokeMessage.MessageType.BLOB:
# Handle blob messages - convert to text representation
if isinstance(response.message, ToolInvokeMessage.BlobMessage):
mime_type = (
response.meta.get("mime_type", "application/octet-stream")
if response.meta
else "application/octet-stream"
)
size = len(response.message.blob)
response_content += f"[Binary data: {mime_type}, size: {size} bytes]"
elif response.type == ToolInvokeMessage.MessageType.VARIABLE:
# Handle variable messages
if isinstance(response.message, ToolInvokeMessage.VariableMessage):
var_name = response.message.variable_name
var_value = response.message.variable_value
if isinstance(var_value, str):
response_content += var_value
else:
response_content += f"[Variable {var_name}: {json.dumps(var_value, ensure_ascii=False)}]"
elif response.type == ToolInvokeMessage.MessageType.BLOB_CHUNK:
# Handle blob chunk messages - these are parts of a larger blob
if isinstance(response.message, ToolInvokeMessage.BlobChunkMessage):
response_content += f"[Blob chunk {response.message.sequence}: {len(response.message.blob)} bytes]"
elif response.type == ToolInvokeMessage.MessageType.RETRIEVER_RESOURCES:
# Handle retriever resources messages
if isinstance(response.message, ToolInvokeMessage.RetrieverResourceMessage):
response_content += response.message.context
elif response.type == ToolInvokeMessage.MessageType.FILE:
# Extract file from meta
if response.meta and "file" in response.meta:
file = response.meta["file"]
if isinstance(file, File):
# Check if file is for model or tool output
if response.meta.get("target") == "self":
# File is for model - add to files for next prompt
self.files.append(file)
response_content += f"File '{file.filename}' has been loaded into your context."
else:
# File is tool output
tool_files.append(file)
return response_content, tool_files, None
def _find_tool_by_name(self, tool_name: str) -> Tool | None:
"""Find a tool instance by its name."""
for tool in self.tools:
if tool.entity.identity.name == tool_name:
return tool
return None
def _convert_tools_to_prompt_format(self) -> list[PromptMessageTool]:
"""Convert tools to prompt message format."""
prompt_tools: list[PromptMessageTool] = []
for tool in self.tools:
prompt_tools.append(tool.to_prompt_message_tool())
return prompt_tools
def _update_usage_with_empty(self, llm_usage: dict[str, Any]) -> None:
"""Initialize usage tracking with empty usage if not set."""
if "usage" not in llm_usage or llm_usage["usage"] is None:
llm_usage["usage"] = LLMUsage.empty_usage()

View File

@ -0,0 +1,273 @@
"""Function Call strategy implementation."""
import json
from collections.abc import Generator
from typing import Any, Union
from core.agent.entities import AgentLog, AgentResult
from core.file import File
from core.model_runtime.entities import (
AssistantPromptMessage,
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
LLMUsage,
PromptMessage,
PromptMessageTool,
ToolPromptMessage,
)
from core.tools.entities.tool_entities import ToolInvokeMeta
from .base import AgentPattern
class FunctionCallStrategy(AgentPattern):
"""Function Call strategy using model's native tool calling capability."""
def run(
self,
prompt_messages: list[PromptMessage],
model_parameters: dict[str, Any],
stop: list[str] = [],
stream: bool = True,
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
"""Execute the function call agent strategy."""
# Convert tools to prompt format
prompt_tools: list[PromptMessageTool] = self._convert_tools_to_prompt_format()
# Initialize tracking
iteration_step: int = 1
max_iterations: int = self.max_iterations + 1
function_call_state: bool = True
total_usage: dict[str, LLMUsage | None] = {"usage": None}
messages: list[PromptMessage] = list(prompt_messages) # Create mutable copy
final_text: str = ""
finish_reason: str | None = None
output_files: list[File] = [] # Track files produced by tools
while function_call_state and iteration_step <= max_iterations:
function_call_state = False
round_log = self._create_log(
label=f"ROUND {iteration_step}",
log_type=AgentLog.LogType.ROUND,
status=AgentLog.LogStatus.START,
data={"round_index": iteration_step},
)
yield round_log
# On last iteration, remove tools to force final answer
current_tools: list[PromptMessageTool] = [] if iteration_step == max_iterations else prompt_tools
model_log = self._create_log(
label=f"{self.model_instance.model} Thought",
log_type=AgentLog.LogType.THOUGHT,
status=AgentLog.LogStatus.START,
data={},
parent_id=round_log.id,
extra_metadata={
AgentLog.LogMetadata.PROVIDER: self.model_instance.provider,
},
)
yield model_log
# Track usage for this round only
round_usage: dict[str, LLMUsage | None] = {"usage": None}
# Invoke model
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
prompt_messages=messages,
model_parameters=model_parameters,
tools=current_tools,
stop=stop,
stream=stream,
user=self.context.user_id,
callbacks=[],
)
# Process response
tool_calls, response_content, chunk_finish_reason = yield from self._handle_chunks(
chunks, round_usage, model_log
)
messages.append(self._create_assistant_message(response_content, tool_calls))
# Accumulate to total usage
round_usage_value = round_usage.get("usage")
if round_usage_value:
self._accumulate_usage(total_usage, round_usage_value)
# Update final text if no tool calls (this is likely the final answer)
if not tool_calls:
final_text = response_content
# Update finish reason
if chunk_finish_reason:
finish_reason = chunk_finish_reason
# Process tool calls
tool_outputs: dict[str, str] = {}
if tool_calls:
function_call_state = True
# Execute tools
for tool_call_id, tool_name, tool_args in tool_calls:
tool_response, tool_files, _ = yield from self._handle_tool_call(
tool_name, tool_args, tool_call_id, messages, round_log
)
tool_outputs[tool_name] = tool_response
# Track files produced by tools
output_files.extend(tool_files)
yield self._finish_log(
round_log,
data={
"llm_result": response_content,
"tool_calls": [
{"name": tc[1], "args": tc[2], "output": tool_outputs.get(tc[1], "")} for tc in tool_calls
]
if tool_calls
else [],
"final_answer": final_text if not function_call_state else None,
},
usage=round_usage.get("usage"),
)
iteration_step += 1
# Return final result
from core.agent.entities import AgentResult
return AgentResult(
text=final_text,
files=output_files,
usage=total_usage.get("usage") or LLMUsage.empty_usage(),
finish_reason=finish_reason,
)
def _handle_chunks(
self,
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
llm_usage: dict[str, LLMUsage | None],
start_log: AgentLog,
) -> Generator[
LLMResultChunk | AgentLog,
None,
tuple[list[tuple[str, str, dict[str, Any]]], str, str | None],
]:
"""Handle LLM response chunks and extract tool calls and content.
Returns a tuple of (tool_calls, response_content, finish_reason).
"""
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
response_content: str = ""
finish_reason: str | None = None
if isinstance(chunks, Generator):
# Streaming response
for chunk in chunks:
# Extract tool calls
if self._has_tool_calls(chunk):
tool_calls.extend(self._extract_tool_calls(chunk))
# Extract content
if chunk.delta.message and chunk.delta.message.content:
response_content += self._extract_content(chunk.delta.message.content)
# Track usage
if chunk.delta.usage:
self._accumulate_usage(llm_usage, chunk.delta.usage)
# Capture finish reason
if chunk.delta.finish_reason:
finish_reason = chunk.delta.finish_reason
yield chunk
else:
# Non-streaming response
result: LLMResult = chunks
if self._has_tool_calls_result(result):
tool_calls.extend(self._extract_tool_calls_result(result))
if result.message and result.message.content:
response_content += self._extract_content(result.message.content)
if result.usage:
self._accumulate_usage(llm_usage, result.usage)
# Convert to streaming format
yield LLMResultChunk(
model=result.model,
prompt_messages=result.prompt_messages,
delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage),
)
yield self._finish_log(
start_log,
data={
"result": response_content,
},
usage=llm_usage.get("usage"),
)
return tool_calls, response_content, finish_reason
def _create_assistant_message(
self, content: str, tool_calls: list[tuple[str, str, dict[str, Any]]] | None = None
) -> AssistantPromptMessage:
"""Create assistant message with tool calls."""
if tool_calls is None:
return AssistantPromptMessage(content=content)
return AssistantPromptMessage(
content=content or "",
tool_calls=[
AssistantPromptMessage.ToolCall(
id=tc[0],
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tc[1], arguments=json.dumps(tc[2])),
)
for tc in tool_calls
],
)
def _handle_tool_call(
self,
tool_name: str,
tool_args: dict[str, Any],
tool_call_id: str,
messages: list[PromptMessage],
round_log: AgentLog,
) -> Generator[AgentLog, None, tuple[str, list[File], ToolInvokeMeta | None]]:
"""Handle a single tool call and return response with files and meta."""
# Find tool
tool_instance = self._find_tool_by_name(tool_name)
if not tool_instance:
raise ValueError(f"Tool {tool_name} not found")
# Create tool call log
tool_call_log = self._create_log(
label=f"CALL {tool_name}",
log_type=AgentLog.LogType.TOOL_CALL,
status=AgentLog.LogStatus.START,
data={
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"tool_args": tool_args,
},
parent_id=round_log.id,
)
yield tool_call_log
# Invoke tool using base class method
response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args, tool_name)
yield self._finish_log(
tool_call_log,
data={
**tool_call_log.data,
"output": response_content,
"files": len(tool_files),
"meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
},
)
final_content = response_content or "Tool executed successfully"
# Add tool response to messages
messages.append(
ToolPromptMessage(
content=final_content,
tool_call_id=tool_call_id,
name=tool_name,
)
)
return response_content, tool_files, tool_invoke_meta

View File

@ -0,0 +1,402 @@
"""ReAct strategy implementation."""
from __future__ import annotations
import json
from collections.abc import Generator
from typing import TYPE_CHECKING, Any, Union
from core.agent.entities import AgentLog, AgentResult, AgentScratchpadUnit, ExecutionContext
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
from core.file import File
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
AssistantPromptMessage,
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
PromptMessage,
SystemPromptMessage,
)
from .base import AgentPattern, ToolInvokeHook
if TYPE_CHECKING:
from core.tools.__base.tool import Tool
class ReActStrategy(AgentPattern):
"""ReAct strategy using reasoning and acting approach."""
def __init__(
self,
model_instance: ModelInstance,
tools: list[Tool],
context: ExecutionContext,
max_iterations: int = 10,
workflow_call_depth: int = 0,
files: list[File] = [],
tool_invoke_hook: ToolInvokeHook | None = None,
instruction: str = "",
):
"""Initialize the ReAct strategy with instruction support."""
super().__init__(
model_instance=model_instance,
tools=tools,
context=context,
max_iterations=max_iterations,
workflow_call_depth=workflow_call_depth,
files=files,
tool_invoke_hook=tool_invoke_hook,
)
self.instruction = instruction
def run(
self,
prompt_messages: list[PromptMessage],
model_parameters: dict[str, Any],
stop: list[str] = [],
stream: bool = True,
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
"""Execute the ReAct agent strategy."""
# Initialize tracking
agent_scratchpad: list[AgentScratchpadUnit] = []
iteration_step: int = 1
max_iterations: int = self.max_iterations + 1
react_state: bool = True
total_usage: dict[str, Any] = {"usage": None}
output_files: list[File] = [] # Track files produced by tools
final_text: str = ""
finish_reason: str | None = None
# Add "Observation" to stop sequences
if "Observation" not in stop:
stop = stop.copy()
stop.append("Observation")
while react_state and iteration_step <= max_iterations:
react_state = False
round_log = self._create_log(
label=f"ROUND {iteration_step}",
log_type=AgentLog.LogType.ROUND,
status=AgentLog.LogStatus.START,
data={"round_index": iteration_step},
)
yield round_log
# Build prompt with/without tools based on iteration
include_tools = iteration_step < max_iterations
current_messages = self._build_prompt_with_react_format(
prompt_messages, agent_scratchpad, include_tools, self.instruction
)
model_log = self._create_log(
label=f"{self.model_instance.model} Thought",
log_type=AgentLog.LogType.THOUGHT,
status=AgentLog.LogStatus.START,
data={},
parent_id=round_log.id,
extra_metadata={
AgentLog.LogMetadata.PROVIDER: self.model_instance.provider,
},
)
yield model_log
# Track usage for this round only
round_usage: dict[str, Any] = {"usage": None}
# Use current messages directly (files are handled by base class if needed)
messages_to_use = current_messages
# Invoke model
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
prompt_messages=messages_to_use,
model_parameters=model_parameters,
stop=stop,
stream=stream,
user=self.context.user_id or "",
callbacks=[],
)
# Process response
scratchpad, chunk_finish_reason = yield from self._handle_chunks(
chunks, round_usage, model_log, current_messages
)
agent_scratchpad.append(scratchpad)
# Accumulate to total usage
round_usage_value = round_usage.get("usage")
if round_usage_value:
self._accumulate_usage(total_usage, round_usage_value)
# Update finish reason
if chunk_finish_reason:
finish_reason = chunk_finish_reason
# Check if we have an action to execute
if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
react_state = True
# Execute tool
observation, tool_files = yield from self._handle_tool_call(
scratchpad.action, current_messages, round_log
)
scratchpad.observation = observation
# Track files produced by tools
output_files.extend(tool_files)
# Add observation to scratchpad for display
yield self._create_text_chunk(f"\nObservation: {observation}\n", current_messages)
else:
# Extract final answer
if scratchpad.action and scratchpad.action.action_input:
final_answer = scratchpad.action.action_input
if isinstance(final_answer, dict):
final_answer = json.dumps(final_answer, ensure_ascii=False)
final_text = str(final_answer)
elif scratchpad.thought:
# If no action but we have thought, use thought as final answer
final_text = scratchpad.thought
yield self._finish_log(
round_log,
data={
"thought": scratchpad.thought,
"action": scratchpad.action_str if scratchpad.action else None,
"observation": scratchpad.observation or None,
"final_answer": final_text if not react_state else None,
},
usage=round_usage.get("usage"),
)
iteration_step += 1
# Return final result
from core.agent.entities import AgentResult
return AgentResult(
text=final_text, files=output_files, usage=total_usage.get("usage"), finish_reason=finish_reason
)
def _build_prompt_with_react_format(
self,
original_messages: list[PromptMessage],
agent_scratchpad: list[AgentScratchpadUnit],
include_tools: bool = True,
instruction: str = "",
) -> list[PromptMessage]:
"""Build prompt messages with ReAct format."""
# Copy messages to avoid modifying original
messages = list(original_messages)
# Find and update the system prompt that should already exist
system_prompt_found = False
for i, msg in enumerate(messages):
if isinstance(msg, SystemPromptMessage):
system_prompt_found = True
# The system prompt from frontend already has the template, just replace placeholders
# Format tools
tools_str = ""
tool_names = []
if include_tools and self.tools:
# Convert tools to prompt message tools format
prompt_tools = [tool.to_prompt_message_tool() for tool in self.tools]
tool_names = [tool.name for tool in prompt_tools]
# Format tools as JSON for comprehensive information
from core.model_runtime.utils.encoders import jsonable_encoder
tools_str = json.dumps(jsonable_encoder(prompt_tools), indent=2)
tool_names_str = ", ".join(f'"{name}"' for name in tool_names)
else:
tools_str = "No tools available"
tool_names_str = ""
# Replace placeholders in the existing system prompt
updated_content = msg.content
assert isinstance(updated_content, str)
updated_content = updated_content.replace("{{instruction}}", instruction)
updated_content = updated_content.replace("{{tools}}", tools_str)
updated_content = updated_content.replace("{{tool_names}}", tool_names_str)
# Create new SystemPromptMessage with updated content
messages[i] = SystemPromptMessage(content=updated_content)
break
# If no system prompt found, that's unexpected but add scratchpad anyway
if not system_prompt_found:
# This shouldn't happen if frontend is working correctly
pass
# Format agent scratchpad
scratchpad_str = ""
if agent_scratchpad:
scratchpad_parts: list[str] = []
for unit in agent_scratchpad:
if unit.thought:
scratchpad_parts.append(f"Thought: {unit.thought}")
if unit.action_str:
scratchpad_parts.append(f"Action:\n```\n{unit.action_str}\n```")
if unit.observation:
scratchpad_parts.append(f"Observation: {unit.observation}")
scratchpad_str = "\n".join(scratchpad_parts)
# If there's a scratchpad, append it to the last message
if scratchpad_str:
messages.append(AssistantPromptMessage(content=scratchpad_str))
return messages
def _handle_chunks(
self,
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
llm_usage: dict[str, Any],
model_log: AgentLog,
current_messages: list[PromptMessage],
) -> Generator[
LLMResultChunk | AgentLog,
None,
tuple[AgentScratchpadUnit, str | None],
]:
"""Handle LLM response chunks and extract action/thought.
Returns a tuple of (scratchpad_unit, finish_reason).
"""
usage_dict: dict[str, Any] = {}
# Convert non-streaming to streaming format if needed
if isinstance(chunks, LLMResult):
# Create a generator from the LLMResult
def result_to_chunks() -> Generator[LLMResultChunk, None, None]:
yield LLMResultChunk(
model=chunks.model,
prompt_messages=chunks.prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=chunks.message,
usage=chunks.usage,
finish_reason=None, # LLMResult doesn't have finish_reason, only streaming chunks do
),
system_fingerprint=chunks.system_fingerprint or "",
)
streaming_chunks = result_to_chunks()
else:
streaming_chunks = chunks
react_chunks = CotAgentOutputParser.handle_react_stream_output(streaming_chunks, usage_dict)
# Initialize scratchpad unit
scratchpad = AgentScratchpadUnit(
agent_response="",
thought="",
action_str="",
observation="",
action=None,
)
finish_reason: str | None = None
# Process chunks
for chunk in react_chunks:
if isinstance(chunk, AgentScratchpadUnit.Action):
# Action detected
action_str = json.dumps(chunk.model_dump())
scratchpad.agent_response = (scratchpad.agent_response or "") + action_str
scratchpad.action_str = action_str
scratchpad.action = chunk
yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages)
else:
# Text chunk
chunk_text = str(chunk)
scratchpad.agent_response = (scratchpad.agent_response or "") + chunk_text
scratchpad.thought = (scratchpad.thought or "") + chunk_text
yield self._create_text_chunk(chunk_text, current_messages)
# Update usage
if usage_dict.get("usage"):
if llm_usage.get("usage"):
self._accumulate_usage(llm_usage, usage_dict["usage"])
else:
llm_usage["usage"] = usage_dict["usage"]
# Clean up thought
scratchpad.thought = (scratchpad.thought or "").strip() or "I am thinking about how to help you"
# Finish model log
yield self._finish_log(
model_log,
data={
"thought": scratchpad.thought,
"action": scratchpad.action_str if scratchpad.action else None,
},
usage=llm_usage.get("usage"),
)
return scratchpad, finish_reason
def _handle_tool_call(
self,
action: AgentScratchpadUnit.Action,
prompt_messages: list[PromptMessage],
round_log: AgentLog,
) -> Generator[AgentLog, None, tuple[str, list[File]]]:
"""Handle tool call and return observation with files."""
tool_name = action.action_name
tool_args: dict[str, Any] | str = action.action_input
# Start tool log
tool_log = self._create_log(
label=f"CALL {tool_name}",
log_type=AgentLog.LogType.TOOL_CALL,
status=AgentLog.LogStatus.START,
data={
"tool_name": tool_name,
"tool_args": tool_args,
},
parent_id=round_log.id,
)
yield tool_log
# Find tool instance
tool_instance = self._find_tool_by_name(tool_name)
if not tool_instance:
# Finish tool log with error
yield self._finish_log(
tool_log,
data={
**tool_log.data,
"error": f"Tool {tool_name} not found",
},
)
return f"Tool {tool_name} not found", []
# Ensure tool_args is a dict
tool_args_dict: dict[str, Any]
if isinstance(tool_args, str):
try:
tool_args_dict = json.loads(tool_args)
except json.JSONDecodeError:
tool_args_dict = {"input": tool_args}
elif not isinstance(tool_args, dict):
tool_args_dict = {"input": str(tool_args)}
else:
tool_args_dict = tool_args
# Invoke tool using base class method
response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args_dict, tool_name)
# Finish tool log
yield self._finish_log(
tool_log,
data={
**tool_log.data,
"output": response_content,
"files": len(tool_files),
"meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
},
)
return response_content or "Tool executed successfully", tool_files

View File

@ -0,0 +1,107 @@
"""Strategy factory for creating agent strategies."""
from __future__ import annotations
from typing import TYPE_CHECKING
from core.agent.entities import AgentEntity, ExecutionContext
from core.file.models import File
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelFeature
from .base import AgentPattern, ToolInvokeHook
from .function_call import FunctionCallStrategy
from .react import ReActStrategy
if TYPE_CHECKING:
from core.tools.__base.tool import Tool
class StrategyFactory:
"""Factory for creating agent strategies based on model features."""
# Tool calling related features
TOOL_CALL_FEATURES = {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL}
@staticmethod
def create_strategy(
model_features: list[ModelFeature],
model_instance: ModelInstance,
context: ExecutionContext,
tools: list[Tool],
files: list[File],
max_iterations: int = 10,
workflow_call_depth: int = 0,
agent_strategy: AgentEntity.Strategy | None = None,
tool_invoke_hook: ToolInvokeHook | None = None,
instruction: str = "",
) -> AgentPattern:
"""
Create an appropriate strategy based on model features.
Args:
model_features: List of model features/capabilities
model_instance: Model instance to use
context: Execution context containing trace/audit information
tools: Available tools
files: Available files
max_iterations: Maximum iterations for the strategy
workflow_call_depth: Depth of workflow calls
agent_strategy: Optional explicit strategy override
tool_invoke_hook: Optional hook for custom tool invocation (e.g., agent_invoke)
instruction: Optional instruction for ReAct strategy
Returns:
AgentStrategy instance
"""
# If explicit strategy is provided and it's Function Calling, try to use it if supported
if agent_strategy == AgentEntity.Strategy.FUNCTION_CALLING:
if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES:
return FunctionCallStrategy(
model_instance=model_instance,
context=context,
tools=tools,
files=files,
max_iterations=max_iterations,
workflow_call_depth=workflow_call_depth,
tool_invoke_hook=tool_invoke_hook,
)
# Fallback to ReAct if FC is requested but not supported
# If explicit strategy is Chain of Thought (ReAct)
if agent_strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
return ReActStrategy(
model_instance=model_instance,
context=context,
tools=tools,
files=files,
max_iterations=max_iterations,
workflow_call_depth=workflow_call_depth,
tool_invoke_hook=tool_invoke_hook,
instruction=instruction,
)
# Default auto-selection logic
if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES:
# Model supports native function calling
return FunctionCallStrategy(
model_instance=model_instance,
context=context,
tools=tools,
files=files,
max_iterations=max_iterations,
workflow_call_depth=workflow_call_depth,
tool_invoke_hook=tool_invoke_hook,
)
else:
# Use ReAct strategy for models without function calling
return ReActStrategy(
model_instance=model_instance,
context=context,
tools=tools,
files=files,
max_iterations=max_iterations,
workflow_call_depth=workflow_call_depth,
tool_invoke_hook=tool_invoke_hook,
instruction=instruction,
)