Files
ragflow/agent/component/agent_with_tools.py
Yongteng Lei dd839f30e8 Fix: code supports matplotlib (#13724)
### What problem does this PR solve?

Code as "final" node: 

![img_v3_02vs_aece4caf-8403-4939-9e68-9845a22c2cfg](https://github.com/user-attachments/assets/9d87b8df-da6b-401c-bf6d-8b807fe92c22)

Code as "mid" node:

![img_v3_02vv_f74f331f-d755-44ab-a18c-96fff8cbd34g](https://github.com/user-attachments/assets/c94ef3f9-2a6c-47cb-9d2b-19703d2752e4)


### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2026-03-20 20:32:00 +08:00

379 lines
16 KiB
Python

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import json
import logging
import os
import re
from copy import deepcopy
from functools import partial
from timeit import default_timer as timer
from typing import Any
import json_repair
from agent.component.llm import LLM, LLMParam
from agent.tools.base import LLMToolPluginCallSession, ToolBase, ToolMeta, ToolParamBase
from api.db.joint_services.tenant_model_service import get_model_config_by_type_and_name
from api.db.services.llm_service import LLMBundle
from api.db.services.mcp_server_service import MCPServerService
from api.db.services.tenant_llm_service import TenantLLMService
from common.connection_utils import timeout
from common.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
from rag.prompts.generator import citation_plus, citation_prompt, full_question, kb_prompt, message_fit_in, structured_output_prompt
class AgentParam(LLMParam, ToolParamBase):
    """
    Define the Agent component parameters.

    Combines LLM generation settings (via LLMParam) with tool-call metadata
    (via ToolParamBase) so an Agent can both run as a standalone node and be
    invoked as a tool by a supervising agent.
    """

    def __init__(self):
        # OpenAI function-calling style metadata advertised to a supervising
        # LLM when this agent is exposed as a callable tool. These description
        # strings are sent to the model verbatim.
        self.meta: ToolMeta = {
            "name": "agent",
            "description": "This is an agent for a specific task.",
            "parameters": {
                "user_prompt": {"type": "string", "description": "This is the order you need to send to the agent.", "default": "", "required": True},
                "reasoning": {
                    "type": "string",
                    # Fixed grammar ("choosing the this agent") — this text is part
                    # of the prompt the supervisor model reads.
                    "description": ("Supervisor's reasoning for choosing this agent. Explain why this agent is being invoked and what is expected of it."),
                    "required": True,
                },
                "context": {
                    "type": "string",
                    "description": (
                        "All relevant background information, prior facts, decisions, and state needed by the agent to solve the current query. Should be as detailed and self-contained as possible."
                    ),
                    "required": True,
                },
            },
        }
        super().__init__()
        self.function_name = "agent"
        self.tools = []  # serialized sub-tool component configs (dicts)
        self.mcp = []  # MCP server bindings: [{"mcp_id": ..., "tools": {...}}, ...]
        self.max_rounds = 5  # maximum tool-call rounds per invocation
        self.description = ""
        self.custom_header = {}  # extra HTTP headers forwarded to MCP tool calls
class Agent(LLM, ToolBase):
    """LLM component with tool-calling: drives a chat model that may invoke
    sub-component tools and MCP-hosted tools over multiple rounds, and can
    itself be exposed as a tool to a supervising agent.
    """

    component_name = "Agent"

    def __init__(self, canvas, id, param: LLMParam):
        # Initialize the LLM side explicitly (multiple inheritance with ToolBase).
        LLM.__init__(self, canvas, id, param)
        self.tools = {}
        # Instantiate each configured tool; suffix names with their index so
        # duplicate tools stay individually addressable ("search_0", "search_1", ...).
        for idx, cpn in enumerate(self._param.tools):
            cpn = self._load_tool_obj(cpn)
            original_name = cpn.get_meta()["function"]["name"]
            indexed_name = f"{original_name}_{idx}"
            self.tools[indexed_name] = cpn
        chat_model_config = get_model_config_by_type_and_name(self._canvas.get_tenant_id(), TenantLLMService.llm_id2llm_type(self._param.llm_id), self._param.llm_id)
        self.chat_mdl = LLMBundle(
            self._canvas.get_tenant_id(),
            chat_model_config,
            max_retries=self._param.max_retries,
            retry_interval=self._param.delay_after_error,
            max_rounds=self._param.max_rounds,
            verbose_tool_use=False,
        )
        # Rebuild each tool's OpenAI-style metadata under its indexed name so
        # the model's tool calls route back to the right instance.
        self.tool_meta = []
        for indexed_name, tool_obj in self.tools.items():
            original_meta = tool_obj.get_meta()
            indexed_meta = deepcopy(original_meta)
            indexed_meta["function"]["name"] = indexed_name
            self.tool_meta.append(indexed_meta)
        # Register MCP-hosted tools; one call session is shared by all tools
        # from the same MCP server.
        for mcp in self._param.mcp:
            _, mcp_server = MCPServerService.get_by_id(mcp["mcp_id"])
            custom_header = self._param.custom_header
            tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables, custom_header)
            for tnm, meta in mcp["tools"].items():
                self.tool_meta.append(mcp_tool_metadata_to_openai_tool(meta))
                self.tools[tnm] = tool_call_session
        # Progress callback bound to this component's id for canvas reporting.
        self.callback = partial(self._canvas.tool_use_callback, id)
        self.toolcall_session = LLMToolPluginCallSession(self.tools, self.callback)
        if self.tool_meta:
            self.chat_mdl.bind_tools(self.toolcall_session, self.tool_meta)

    def _fit_messages(self, prompt: str, msg: list[dict]) -> list[dict]:
        """Prepend *prompt* as the system message and trim the conversation to
        ~97% of the chat model's context length."""
        _, fitted_messages = message_fit_in(
            [{"role": "system", "content": prompt}, *msg],
            int(self.chat_mdl.max_length * 0.97),
        )
        return fitted_messages

    @staticmethod
    def _append_system_prompt(msg: list[dict], extra_prompt: str) -> None:
        """Append *extra_prompt* to the existing system message, in place.
        No-op when there is no system message or nothing to append."""
        if extra_prompt and msg and msg[0]["role"] == "system":
            msg[0]["content"] += "\n" + extra_prompt

    @staticmethod
    def _clean_formatted_answer(ans: str) -> str:
        """Strip a leading reasoning section (everything up to ``</think>``)
        and markdown ```json fences so the remainder can be parsed as JSON."""
        ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
        ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL)
        return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)

    def _load_tool_obj(self, cpn: dict) -> object:
        """Instantiate a tool component from its serialized config dict.

        Records a configuration error on the "_ERROR" output and re-raises
        when the tool's parameters fail validation.
        """
        # Local import avoids a circular dependency with agent.component.
        from agent.component import component_class

        tool_name = cpn["component_name"]
        param = component_class(tool_name + "Param")()
        param.update(cpn["params"])
        try:
            param.check()
        except Exception as e:
            self.set_output("_ERROR", cpn["component_name"] + f" configuration error: {e}")
            raise
        # Tool ids are namespaced under this agent: "<agent_id>--><tool_name>".
        cpn_id = f"{self._id}-->" + cpn.get("name", "").replace(" ", "_")
        return component_class(cpn["component_name"])(self._canvas, cpn_id, param)

    def get_meta(self) -> dict[str, Any]:
        """Return this agent's tool metadata for a supervising agent, using
        the id's last "-->" segment as the advertised function name."""
        self._param.function_name = self._id.split("-->")[-1]
        m = super().get_meta()
        # A user-configured prompt overrides the default user_prompt property
        # description in the advertised schema.
        if hasattr(self._param, "user_prompt") and self._param.user_prompt:
            m["function"]["parameters"]["properties"]["user_prompt"] = self._param.user_prompt
        return m

    def get_input_form(self) -> dict[str, dict]:
        """Collect user-facing input fields from this agent and its LLM tools."""
        res = {}
        for k, v in self.get_input_elements().items():
            res[k] = {"type": "line", "name": v["name"]}
        # NOTE(review): self._param.tools holds config *dicts* (see __init__),
        # so this isinstance check against LLM appears to never match — verify
        # whether iterating self.tools.values() was intended.
        for cpn in self._param.tools:
            if not isinstance(cpn, LLM):
                continue
            res.update(cpn.get_input_form())
        return res

    def _get_output_schema(self):
        """Return the configured structured-output JSON schema, or None.

        Accepts the schema directly under outputs["structured"], or nested one
        level under a "schema"/"structured" key; a candidate only counts when
        it declares at least one property.
        """
        try:
            cand = self._param.outputs.get("structured")
        except Exception:
            # outputs missing or not dict-like — no structured output configured.
            return None
        if isinstance(cand, dict):
            if isinstance(cand.get("properties"), dict) and len(cand["properties"]) > 0:
                return cand
            for k in ("schema", "structured"):
                if isinstance(cand.get(k), dict) and isinstance(cand[k].get("properties"), dict) and len(cand[k]["properties"]) > 0:
                    return cand[k]
        return None

    async def _force_format_to_schema_async(self, text: str, schema_prompt: str) -> str:
        """Ask the model to reformat *text* into schema-conforming raw JSON."""
        fmt_msgs = [
            {"role": "system", "content": schema_prompt + "\nIMPORTANT: Output ONLY valid JSON. No markdown, no extra text."},
            {"role": "user", "content": text},
        ]
        _, fmt_msgs = message_fit_in(fmt_msgs, int(self.chat_mdl.max_length * 0.97))
        return await self._generate_async(fmt_msgs)

    def _invoke(self, **kwargs):
        """Synchronous entry point: run the async implementation to completion."""
        return asyncio.run(self._invoke_async(**kwargs))

    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20 * 60)))
    async def _invoke_async(self, **kwargs):
        """Execute the agent: build the prompt, optionally stream to a
        downstream Message node, otherwise generate a full answer (optionally
        coerced to a structured-output schema) and set component outputs."""
        if self.check_if_canceled("Agent processing"):
            return
        # When invoked as a tool by a supervisor, assemble the user prompt
        # from the supervisor-provided reasoning / context / query parts.
        if kwargs.get("user_prompt"):
            usr_pmt = ""
            if kwargs.get("reasoning"):
                usr_pmt += "\nREASONING:\n{}\n".format(kwargs["reasoning"])
            if kwargs.get("context"):
                usr_pmt += "\nCONTEXT:\n{}\n".format(kwargs["context"])
            if usr_pmt:
                usr_pmt += "\nQUERY:\n{}\n".format(str(kwargs["user_prompt"]))
            else:
                usr_pmt = str(kwargs["user_prompt"])
            self._param.prompts = [{"role": "user", "content": usr_pmt}]
        # Without tools this component behaves as a plain LLM node.
        if not self.tools:
            if self.check_if_canceled("Agent processing"):
                return
            return await LLM._invoke_async(self, **kwargs)
        prompt, msg, user_defined_prompt = self._prepare_prompt_variables()
        output_schema = self._get_output_schema()
        schema_prompt = ""
        if output_schema:
            schema = json.dumps(output_schema, ensure_ascii=False, indent=2)
            schema_prompt = structured_output_prompt(schema)
        component = self._canvas.get_component(self._id)
        downstreams = component["downstream"] if component else []
        ex = self.exception_handler()
        has_message_downstream = any(self._canvas.get_component_obj(cid).component_name.lower() == "message" for cid in downstreams)
        # Stream only when a Message node consumes the output and neither an
        # exception-goto handler nor a structured schema needs the full answer.
        if has_message_downstream and not (ex and ex["goto"]) and not output_schema:
            self.set_output("content", partial(self.stream_output_with_tools_async, prompt, deepcopy(msg), user_defined_prompt))
            return
        msg = self._fit_messages(prompt, msg)
        self._append_system_prompt(msg, schema_prompt)
        ans = await self._generate_async(msg)
        if ans.find("**ERROR**") >= 0:
            logging.error(f"Agent._chat got error. response: {ans}")
            # Prefer the user-configured fallback value over surfacing the error.
            if self.get_exception_default_value():
                self.set_output("content", self.get_exception_default_value())
            else:
                self.set_output("_ERROR", ans)
            return
        if output_schema:
            # Parse (with repair); on failure re-ask the model to reformat,
            # up to max_retries + 1 attempts.
            error = ""
            for _ in range(self._param.max_retries + 1):
                try:
                    obj = json_repair.loads(self._clean_formatted_answer(ans))
                    self.set_output("structured", obj)
                    return obj
                except Exception:
                    error = "The answer cannot be parsed as JSON"
                    ans = await self._force_format_to_schema_async(ans, schema_prompt)
                    if ans.find("**ERROR**") >= 0:
                        continue
            self.set_output("_ERROR", error)
            return
        # Append tool-produced attachment text / artifact links not already
        # present in the answer.
        attachment_content = self._collect_tool_attachment_content(existing_text=ans)
        if attachment_content:
            ans += "\n\n" + attachment_content
        artifact_md = self._collect_tool_artifact_markdown(existing_text=ans)
        if artifact_md:
            ans += "\n\n" + artifact_md
        self.set_output("content", ans)
        return ans

    async def stream_output_with_tools_async(self, prompt, msg, user_defined_prompt={}):
        """Stream the answer delta-by-delta, optionally condensing multi-turn
        history first and appending citations and tool artifacts at the end."""
        # For longer conversations, condense history into one self-contained
        # question before generating.
        if len(msg) > 3:
            st = timer()
            user_request = await full_question(messages=msg, chat_mdl=self.chat_mdl)
            self.callback("Multi-turn conversation optimization", {}, user_request, elapsed_time=timer() - st)
            msg = [*msg[:-1], {"role": "user", "content": user_request}]
        msg = self._fit_messages(prompt, msg)
        # Cite only for top-level agents (no "-->" in id) that have retrieval
        # chunks available and citation enabled.
        need2cite = self._param.cite and self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0
        cited = False
        # Short conversations get inline citation instructions; longer ones
        # fall through to a separate citation pass after generation.
        if need2cite and len(msg) < 7:
            self._append_system_prompt(msg, citation_prompt())
            cited = True
        answer = ""
        async for delta in self._generate_streamly(msg):
            if self.check_if_canceled("Agent streaming"):
                return
            if delta.find("**ERROR**") >= 0:
                if self.get_exception_default_value():
                    self.set_output("content", self.get_exception_default_value())
                    yield self.get_exception_default_value()
                else:
                    self.set_output("_ERROR", delta)
                return
            # When a post-hoc citation pass is pending, accumulate silently.
            if not need2cite or cited:
                yield delta
            answer += delta
        if not need2cite or cited:
            # No separate citation pass: flush attachments/artifacts and finish.
            attachment_content = self._collect_tool_attachment_content(existing_text=answer)
            if attachment_content:
                yield "\n\n" + attachment_content
                answer += "\n\n" + attachment_content
            artifact_md = self._collect_tool_artifact_markdown(existing_text=answer)
            if artifact_md:
                yield "\n\n" + artifact_md
                answer += "\n\n" + artifact_md
            self.set_output("content", answer)
            return
        # Post-hoc citation pass: re-emit the whole answer with citations.
        st = timer()
        cited_answer = ""
        async for delta in self._gen_citations_async(answer):
            if self.check_if_canceled("Agent streaming"):
                return
            yield delta
            cited_answer += delta
        attachment_content = self._collect_tool_attachment_content(existing_text=cited_answer)
        if attachment_content:
            yield "\n\n" + attachment_content
            cited_answer += "\n\n" + attachment_content
        artifact_md = self._collect_tool_artifact_markdown(existing_text=cited_answer)
        if artifact_md:
            yield "\n\n" + artifact_md
            cited_answer += "\n\n" + artifact_md
        self.callback("gen_citations", {}, cited_answer, elapsed_time=timer() - st)
        self.set_output("content", cited_answer)

    async def _gen_citations_async(self, text):
        """Stream *text* rewritten with citations against the canvas's
        retrieval reference (chunks + doc aggregates)."""
        retrievals = self._canvas.get_reference()
        retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
        formated_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True)
        async for delta_ans in self._generate_streamly([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))}, {"role": "user", "content": text}]):
            yield delta_ans

    def _collect_tool_artifact_markdown(self, existing_text: str = "") -> str:
        """Gather markdown links/images for artifacts produced by tools
        (outputs["_ARTIFACTS"]), skipping any already embedded in
        *existing_text*."""
        md_parts = []
        for tool_obj in self.tools.values():
            # MCP call sessions and other non-component tools have no _param.
            if not hasattr(tool_obj, "_param") or not hasattr(tool_obj._param, "outputs"):
                continue
            artifacts_meta = tool_obj._param.outputs.get("_ARTIFACTS", {})
            artifacts = artifacts_meta.get("value") if isinstance(artifacts_meta, dict) else None
            if not artifacts:
                continue
            for art in artifacts:
                if not isinstance(art, dict):
                    continue
                url = art.get("url", "")
                # De-duplicate against markdown already present in the answer.
                if url and (f"![]({url})" in existing_text or f"![{art.get('name', '')}]({url})" in existing_text):
                    continue
                if art.get("mime_type", "").startswith("image/"):
                    md_parts.append(f"![{art['name']}]({url})")
                else:
                    md_parts.append(f"[Download {art['name']}]({url})")
        return "\n\n".join(md_parts)

    def _collect_tool_attachment_content(self, existing_text: str = "") -> str:
        """Gather plain-text attachment content produced by tools
        (outputs["_ATTACHMENT_CONTENT"]), skipping content already present in
        *existing_text*."""
        text_parts = []
        for tool_obj in self.tools.values():
            # MCP call sessions and other non-component tools have no _param.
            if not hasattr(tool_obj, "_param") or not hasattr(tool_obj._param, "outputs"):
                continue
            content_meta = tool_obj._param.outputs.get("_ATTACHMENT_CONTENT", {})
            content = content_meta.get("value") if isinstance(content_meta, dict) else None
            if not content or not isinstance(content, str):
                continue
            content = content.strip()
            if not content or content in existing_text:
                continue
            text_parts.append(content)
        return "\n\n".join(text_parts)

    def reset(self, only_output=False):
        """
        Reset all tools if they have a reset method. This avoids errors for tools like MCPToolCallSession.
        """
        # Clear this component's own outputs first.
        for k in self._param.outputs.keys():
            self._param.outputs[k]["value"] = None
        # Tools without a reset() (e.g. MCP call sessions) are skipped safely.
        for k, cpn in self.tools.items():
            if hasattr(cpn, "reset") and callable(cpn.reset):
                cpn.reset()
        if only_output:
            return
        for k in self._param.inputs.keys():
            self._param.inputs[k]["value"] = None
        self._param.debug_inputs = {}