diff --git a/agent/component/agent_with_tools.py b/agent/component/agent_with_tools.py index b0006fac4..3938dc03e 100644 --- a/agent/component/agent_with_tools.py +++ b/agent/component/agent_with_tools.py @@ -20,20 +20,20 @@ import os import re from copy import deepcopy from functools import partial +from timeit import default_timer as timer from typing import Any import json_repair -from timeit import default_timer as timer -from agent.tools.base import LLMToolPluginCallSession, ToolParamBase, ToolBase, ToolMeta -from api.db.services.llm_service import LLMBundle -from api.db.services.tenant_llm_service import TenantLLMService -from api.db.services.mcp_server_service import MCPServerService + +from agent.component.llm import LLM, LLMParam +from agent.tools.base import LLMToolPluginCallSession, ToolBase, ToolMeta, ToolParamBase from api.db.joint_services.tenant_model_service import get_model_config_by_type_and_name +from api.db.services.llm_service import LLMBundle +from api.db.services.mcp_server_service import MCPServerService +from api.db.services.tenant_llm_service import TenantLLMService from common.connection_utils import timeout -from rag.prompts.generator import next_step_async, COMPLETE_TASK, \ - citation_prompt, kb_prompt, citation_plus, full_question, message_fit_in, structured_output_prompt from common.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool -from agent.component.llm import LLMParam, LLM +from rag.prompts.generator import citation_plus, citation_prompt, full_question, kb_prompt, message_fit_in, structured_output_prompt class AgentParam(LLMParam, ToolParamBase): @@ -42,35 +42,25 @@ class AgentParam(LLMParam, ToolParamBase): """ def __init__(self): - self.meta:ToolMeta = { - "name": "agent", - "description": "This is an agent for a specific task.", - "parameters": { - "user_prompt": { - "type": "string", - "description": "This is the order you need to send to the agent.", - "default": "", - "required": True - }, - 
"reasoning": { - "type": "string", - "description": ( - "Supervisor's reasoning for choosing the this agent. " - "Explain why this agent is being invoked and what is expected of it." - ), - "required": True - }, - "context": { - "type": "string", - "description": ( - "All relevant background information, prior facts, decisions, " - "and state needed by the agent to solve the current query. " - "Should be as detailed and self-contained as possible." - ), - "required": True - }, - } - } + self.meta: ToolMeta = { + "name": "agent", + "description": "This is an agent for a specific task.", + "parameters": { + "user_prompt": {"type": "string", "description": "This is the order you need to send to the agent.", "default": "", "required": True}, + "reasoning": { + "type": "string", + "description": ("Supervisor's reasoning for choosing the this agent. Explain why this agent is being invoked and what is expected of it."), + "required": True, + }, + "context": { + "type": "string", + "description": ( + "All relevant background information, prior facts, decisions, and state needed by the agent to solve the current query. Should be as detailed and self-contained as possible." 
+ ), + "required": True, + }, + }, + } super().__init__() self.function_name = "agent" self.tools = [] @@ -92,12 +82,14 @@ class Agent(LLM, ToolBase): indexed_name = f"{original_name}_{idx}" self.tools[indexed_name] = cpn chat_model_config = get_model_config_by_type_and_name(self._canvas.get_tenant_id(), TenantLLMService.llm_id2llm_type(self._param.llm_id), self._param.llm_id) - self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), chat_model_config, - max_retries=self._param.max_retries, - retry_interval=self._param.delay_after_error, - max_rounds=self._param.max_rounds, - verbose_tool_use=True - ) + self.chat_mdl = LLMBundle( + self._canvas.get_tenant_id(), + chat_model_config, + max_retries=self._param.max_retries, + retry_interval=self._param.delay_after_error, + max_rounds=self._param.max_rounds, + verbose_tool_use=False, + ) self.tool_meta = [] for indexed_name, tool_obj in self.tools.items(): original_meta = tool_obj.get_meta() @@ -114,10 +106,30 @@ class Agent(LLM, ToolBase): self.tools[tnm] = tool_call_session self.callback = partial(self._canvas.tool_use_callback, id) self.toolcall_session = LLMToolPluginCallSession(self.tools, self.callback) - #self.chat_mdl.bind_tools(self.toolcall_session, self.tool_metas) + if self.tool_meta: + self.chat_mdl.bind_tools(self.toolcall_session, self.tool_meta) + + def _fit_messages(self, prompt: str, msg: list[dict]) -> list[dict]: + _, fitted_messages = message_fit_in( + [{"role": "system", "content": prompt}, *msg], + int(self.chat_mdl.max_length * 0.97), + ) + return fitted_messages + + @staticmethod + def _append_system_prompt(msg: list[dict], extra_prompt: str) -> None: + if extra_prompt and msg and msg[0]["role"] == "system": + msg[0]["content"] += "\n" + extra_prompt + + @staticmethod + def _clean_formatted_answer(ans: str) -> str: + ans = re.sub(r"^.*", "", ans, flags=re.DOTALL) + ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL) + return re.sub(r"```\n*$", "", ans, flags=re.DOTALL) def _load_tool_obj(self, 
cpn: dict) -> object: from agent.component import component_class + tool_name = cpn["component_name"] param = component_class(tool_name + "Param")() param.update(cpn["params"]) @@ -130,7 +142,7 @@ class Agent(LLM, ToolBase): return component_class(cpn["component_name"])(self._canvas, cpn_id, param) def get_meta(self) -> dict[str, Any]: - self._param.function_name= self._id.split("-->")[-1] + self._param.function_name = self._id.split("-->")[-1] m = super().get_meta() if hasattr(self._param, "user_prompt") and self._param.user_prompt: m["function"]["parameters"]["properties"]["user_prompt"] = self._param.user_prompt @@ -139,10 +151,7 @@ class Agent(LLM, ToolBase): def get_input_form(self) -> dict[str, dict]: res = {} for k, v in self.get_input_elements().items(): - res[k] = { - "type": "line", - "name": v["name"] - } + res[k] = {"type": "line", "name": v["name"]} for cpn in self._param.tools: if not isinstance(cpn, LLM): continue @@ -175,7 +184,7 @@ class Agent(LLM, ToolBase): def _invoke(self, **kwargs): return asyncio.run(self._invoke_async(**kwargs)) - @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60))) + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20 * 60))) async def _invoke_async(self, **kwargs): if self.check_if_canceled("Agent processing"): return @@ -204,19 +213,17 @@ class Agent(LLM, ToolBase): schema = json.dumps(output_schema, ensure_ascii=False, indent=2) schema_prompt = structured_output_prompt(schema) - downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else [] + component = self._canvas.get_component(self._id) + downstreams = component["downstream"] if component else [] ex = self.exception_handler() - if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]) and not output_schema: + has_message_downstream = any(self._canvas.get_component_obj(cid).component_name.lower() == "message" for cid in downstreams) 
+ if has_message_downstream and not (ex and ex["goto"]) and not output_schema: self.set_output("content", partial(self.stream_output_with_tools_async, prompt, deepcopy(msg), user_defined_prompt)) return - _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97)) - use_tools = [] - ans = "" - async for delta_ans, _tk in self._react_with_tools_streamly_async_simple(prompt, msg, use_tools, user_defined_prompt,schema_prompt=schema_prompt): - if self.check_if_canceled("Agent processing"): - return - ans += delta_ans + msg = self._fit_messages(prompt, msg) + self._append_system_prompt(msg, schema_prompt) + ans = await self._generate_async(msg) if ans.find("**ERROR**") >= 0: logging.error(f"Agent._chat got error. response: {ans}") @@ -230,14 +237,8 @@ class Agent(LLM, ToolBase): error = "" for _ in range(self._param.max_retries + 1): try: - def clean_formated_answer(ans: str) -> str: - ans = re.sub(r"^.*", "", ans, flags=re.DOTALL) - ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL) - return re.sub(r"```\n*$", "", ans, flags=re.DOTALL) - obj = json_repair.loads(clean_formated_answer(ans)) + obj = json_repair.loads(self._clean_formatted_answer(ans)) self.set_output("structured", obj) - if use_tools: - self.set_output("use_tools", use_tools) return obj except Exception: error = "The answer cannot be parsed as JSON" @@ -248,333 +249,118 @@ class Agent(LLM, ToolBase): self.set_output("_ERROR", error) return + attachment_content = self._collect_tool_attachment_content(existing_text=ans) + if attachment_content: + ans += "\n\n" + attachment_content + artifact_md = self._collect_tool_artifact_markdown(existing_text=ans) + if artifact_md: + ans += "\n\n" + artifact_md self.set_output("content", ans) - if use_tools: - self.set_output("use_tools", use_tools) return ans async def stream_output_with_tools_async(self, prompt, msg, user_defined_prompt={}): - _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], 
int(self.chat_mdl.max_length * 0.97)) - answer_without_toolcall = "" - use_tools = [] - async for delta_ans, _ in self._react_with_tools_streamly_async_simple(prompt, msg, use_tools, user_defined_prompt): + if len(msg) > 3: + st = timer() + user_request = await full_question(messages=msg, chat_mdl=self.chat_mdl) + self.callback("Multi-turn conversation optimization", {}, user_request, elapsed_time=timer() - st) + msg = [*msg[:-1], {"role": "user", "content": user_request}] + + msg = self._fit_messages(prompt, msg) + + need2cite = self._param.cite and self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0 + cited = False + if need2cite and len(msg) < 7: + self._append_system_prompt(msg, citation_prompt()) + cited = True + + answer = "" + async for delta in self._generate_streamly(msg): if self.check_if_canceled("Agent streaming"): return - - if delta_ans.find("**ERROR**") >= 0: + if delta.find("**ERROR**") >= 0: if self.get_exception_default_value(): self.set_output("content", self.get_exception_default_value()) yield self.get_exception_default_value() else: - self.set_output("_ERROR", delta_ans) - return - answer_without_toolcall += delta_ans - yield delta_ans - - self.set_output("content", answer_without_toolcall) - if use_tools: - self.set_output("use_tools", use_tools) - - async def _react_with_tools_streamly_async_simple(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""): - token_count = 0 - tool_metas = self.tool_meta - hist = deepcopy(history) - last_calling = "" - if len(hist) > 3: - st = timer() - user_request = await full_question(messages=history, chat_mdl=self.chat_mdl) - self.callback("Multi-turn conversation optimization", {}, user_request, elapsed_time=timer()-st) - else: - user_request = history[-1]["content"] - - def build_task_desc(prompt: str, user_request: str, user_defined_prompt: dict | None = None) -> str: - """Build a minimal task_desc by concatenating prompt, query, and tool 
schemas.""" - user_defined_prompt = user_defined_prompt or {} - - task_desc = ( - "### Agent Prompt\n" - f"{prompt}\n\n" - "### User Request\n" - f"{user_request}\n\n" - ) - - if user_defined_prompt: - udp_json = json.dumps(user_defined_prompt, ensure_ascii=False, indent=2) - task_desc += "\n### User Defined Prompts\n" + udp_json + "\n" - - return task_desc - - - async def use_tool_async(name, args): - nonlocal hist, use_tools, last_calling - logging.info(f"{last_calling=} == {name=}") - last_calling = name - tool_response = await self.toolcall_session.tool_call_async(name, args) - use_tools.append({ - "name": name, - "arguments": args, - "results": tool_response - }) - return name, tool_response - - async def complete(): - nonlocal hist - need2cite = self._param.cite and self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0 - if schema_prompt: - need2cite = False - cited = False - if hist and hist[0]["role"] == "system": - if schema_prompt: - hist[0]["content"] += "\n" + schema_prompt - if need2cite and len(hist) < 7: - hist[0]["content"] += citation_prompt() - cited = True - yield "", token_count - - _hist = hist - if len(hist) > 12: - _hist = [hist[0], hist[1], *hist[-10:]] - entire_txt = "" - async for delta_ans in self._generate_streamly(_hist): - if not need2cite or cited: - yield delta_ans, 0 - entire_txt += delta_ans - if not need2cite or cited: + self.set_output("_ERROR", delta) return + if not need2cite or cited: + yield delta + answer += delta - st = timer() - txt = "" - async for delta_ans in self._gen_citations_async(entire_txt): - if self.check_if_canceled("Agent streaming"): - return - yield delta_ans, 0 - txt += delta_ans - - self.callback("gen_citations", {}, txt, elapsed_time=timer()-st) - - def build_observation(tool_call_res: list[tuple]) -> str: - """ - Build a Observation from tool call results. - No LLM involved. 
- """ - if not tool_call_res: - return "" - - lines = ["Observation:"] - for name, result in tool_call_res: - lines.append(f"[{name} result]") - lines.append(str(result)) - - return "\n".join(lines) - - def append_user_content(hist, content): - if hist[-1]["role"] == "user": - hist[-1]["content"] += content - else: - hist.append({"role": "user", "content": content}) + if not need2cite or cited: + attachment_content = self._collect_tool_attachment_content(existing_text=answer) + if attachment_content: + yield "\n\n" + attachment_content + answer += "\n\n" + attachment_content + artifact_md = self._collect_tool_artifact_markdown(existing_text=answer) + if artifact_md: + yield "\n\n" + artifact_md + answer += "\n\n" + artifact_md + self.set_output("content", answer) + return st = timer() - task_desc = build_task_desc(prompt, user_request, user_defined_prompt) - self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st) - for _ in range(self._param.max_rounds + 1): + cited_answer = "" + async for delta in self._gen_citations_async(answer): if self.check_if_canceled("Agent streaming"): return - response, tk = await next_step_async(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt) - # self.callback("next_step", {}, str(response)[:256]+"...") - token_count += tk or 0 - hist.append({"role": "assistant", "content": response}) - try: - # Remove markdown code fences properly - cleaned_response = re.sub(r"^.*```json\s*", "", response, flags=re.DOTALL) - cleaned_response = re.sub(r"```\s*$", "", cleaned_response, flags=re.DOTALL) - functions = json_repair.loads(cleaned_response) - if not isinstance(functions, list): - raise TypeError(f"List should be returned, but `{functions}`") - for f in functions: - if not isinstance(f, dict): - raise TypeError(f"An object type should be returned, but `{f}`") - - tool_tasks = [] - for func in functions: - name = func["name"] - args = func["arguments"] - if name == COMPLETE_TASK: - append_user_content(hist, 
f"Respond with a formal answer. FORGET(DO NOT mention) about `{COMPLETE_TASK}`. The language for the response MUST be as the same as the first user request.\n") - async for txt, tkcnt in complete(): - yield txt, tkcnt - return - - tool_tasks.append(asyncio.create_task(use_tool_async(name, args))) - - results = await asyncio.gather(*tool_tasks) if tool_tasks else [] - st = timer() - reflection = build_observation(results) - append_user_content(hist, reflection) - self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st) - - except Exception as e: - logging.exception(msg=f"Wrong JSON argument format in LLM ReAct response: {e}") - e = f"\nTool call error, please correct the input parameter of response format and call it again.\n *** Exception ***\n{e}" - append_user_content(hist, str(e)) - - logging.warning( f"Exceed max rounds: {self._param.max_rounds}") - final_instruction = f""" -{user_request} -IMPORTANT: You have reached the conversation limit. Based on ALL the information and research you have gathered so far, please provide a DIRECT and COMPREHENSIVE final answer to the original request. -Instructions: -1. SYNTHESIZE all information collected during this conversation -2. Provide a COMPLETE response using existing data - do not suggest additional research -3. Structure your response as a FINAL DELIVERABLE, not a plan -4. If information is incomplete, state what you found and provide the best analysis possible with available data -5. DO NOT mention conversation limits or suggest further steps -6. Focus on delivering VALUE with the information already gathered -Respond immediately with your final comprehensive answer. 
- """ - if self.check_if_canceled("Agent final instruction"): - return - append_user_content(hist, final_instruction) - - async for txt, tkcnt in complete(): - yield txt, tkcnt - -# async def _react_with_tools_streamly_async(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""): -# token_count = 0 -# tool_metas = self.tool_meta -# hist = deepcopy(history) -# last_calling = "" -# if len(hist) > 3: -# st = timer() -# user_request = await full_question(messages=history, chat_mdl=self.chat_mdl) -# self.callback("Multi-turn conversation optimization", {}, user_request, elapsed_time=timer()-st) -# else: -# user_request = history[-1]["content"] - -# async def use_tool_async(name, args): -# nonlocal hist, use_tools, last_calling -# logging.info(f"{last_calling=} == {name=}") -# last_calling = name -# tool_response = await self.toolcall_session.tool_call_async(name, args) -# use_tools.append({ -# "name": name, -# "arguments": args, -# "results": tool_response -# }) -# # self.callback("add_memory", {}, "...") -# #self.add_memory(hist[-2]["content"], hist[-1]["content"], name, args, str(tool_response), user_defined_prompt) - -# return name, tool_response - -# async def complete(): -# nonlocal hist -# need2cite = self._param.cite and self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0 -# if schema_prompt: -# need2cite = False -# cited = False -# if hist and hist[0]["role"] == "system": -# if schema_prompt: -# hist[0]["content"] += "\n" + schema_prompt -# if need2cite and len(hist) < 7: -# hist[0]["content"] += citation_prompt() -# cited = True -# yield "", token_count - -# _hist = hist -# if len(hist) > 12: -# _hist = [hist[0], hist[1], *hist[-10:]] -# entire_txt = "" -# async for delta_ans in self._generate_streamly(_hist): -# if not need2cite or cited: -# yield delta_ans, 0 -# entire_txt += delta_ans -# if not need2cite or cited: -# return - -# st = timer() -# txt = "" -# async for delta_ans in 
self._gen_citations_async(entire_txt): -# if self.check_if_canceled("Agent streaming"): -# return -# yield delta_ans, 0 -# txt += delta_ans - -# self.callback("gen_citations", {}, txt, elapsed_time=timer()-st) - -# def append_user_content(hist, content): -# if hist[-1]["role"] == "user": -# hist[-1]["content"] += content -# else: -# hist.append({"role": "user", "content": content}) - -# st = timer() -# task_desc = await analyze_task_async(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt) -# self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st) -# for _ in range(self._param.max_rounds + 1): -# if self.check_if_canceled("Agent streaming"): -# return -# response, tk = await next_step_async(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt) -# # self.callback("next_step", {}, str(response)[:256]+"...") -# token_count += tk or 0 -# hist.append({"role": "assistant", "content": response}) -# try: -# functions = json_repair.loads(re.sub(r"```.*", "", response)) -# if not isinstance(functions, list): -# raise TypeError(f"List should be returned, but `{functions}`") -# for f in functions: -# if not isinstance(f, dict): -# raise TypeError(f"An object type should be returned, but `{f}`") - -# tool_tasks = [] -# for func in functions: -# name = func["name"] -# args = func["arguments"] -# if name == COMPLETE_TASK: -# append_user_content(hist, f"Respond with a formal answer. FORGET(DO NOT mention) about `{COMPLETE_TASK}`. 
The language for the response MUST be as the same as the first user request.\n") -# async for txt, tkcnt in complete(): -# yield txt, tkcnt -# return - -# tool_tasks.append(asyncio.create_task(use_tool_async(name, args))) - -# results = await asyncio.gather(*tool_tasks) if tool_tasks else [] -# st = timer() -# reflection = await reflect_async(self.chat_mdl, hist, results, user_defined_prompt) -# append_user_content(hist, reflection) -# self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st) - -# except Exception as e: -# logging.exception(msg=f"Wrong JSON argument format in LLM ReAct response: {e}") -# e = f"\nTool call error, please correct the input parameter of response format and call it again.\n *** Exception ***\n{e}" -# append_user_content(hist, str(e)) - -# logging.warning( f"Exceed max rounds: {self._param.max_rounds}") -# final_instruction = f""" -# {user_request} -# IMPORTANT: You have reached the conversation limit. Based on ALL the information and research you have gathered so far, please provide a DIRECT and COMPREHENSIVE final answer to the original request. -# Instructions: -# 1. SYNTHESIZE all information collected during this conversation -# 2. Provide a COMPLETE response using existing data - do not suggest additional research -# 3. Structure your response as a FINAL DELIVERABLE, not a plan -# 4. If information is incomplete, state what you found and provide the best analysis possible with available data -# 5. DO NOT mention conversation limits or suggest further steps -# 6. Focus on delivering VALUE with the information already gathered -# Respond immediately with your final comprehensive answer. 
-# """ -# if self.check_if_canceled("Agent final instruction"): -# return -# append_user_content(hist, final_instruction) - -# async for txt, tkcnt in complete(): -# yield txt, tkcnt + yield delta + cited_answer += delta + attachment_content = self._collect_tool_attachment_content(existing_text=cited_answer) + if attachment_content: + yield "\n\n" + attachment_content + cited_answer += "\n\n" + attachment_content + artifact_md = self._collect_tool_artifact_markdown(existing_text=cited_answer) + if artifact_md: + yield "\n\n" + artifact_md + cited_answer += "\n\n" + artifact_md + self.callback("gen_citations", {}, cited_answer, elapsed_time=timer() - st) + self.set_output("content", cited_answer) async def _gen_citations_async(self, text): retrievals = self._canvas.get_reference() retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())} formated_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True) - async for delta_ans in self._generate_streamly([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))}, - {"role": "user", "content": text} - ]): + async for delta_ans in self._generate_streamly([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))}, {"role": "user", "content": text}]): yield delta_ans + def _collect_tool_artifact_markdown(self, existing_text: str = "") -> str: + md_parts = [] + for tool_obj in self.tools.values(): + if not hasattr(tool_obj, "_param") or not hasattr(tool_obj._param, "outputs"): + continue + artifacts_meta = tool_obj._param.outputs.get("_ARTIFACTS", {}) + artifacts = artifacts_meta.get("value") if isinstance(artifacts_meta, dict) else None + if not artifacts: + continue + for art in artifacts: + if not isinstance(art, dict): + continue + url = art.get("url", "") + if url and (f"![]({url})" in existing_text or f"![{art.get('name', '')}]({url})" in existing_text): + continue + if art.get("mime_type", "").startswith("image/"): + 
md_parts.append(f"![{art['name']}]({url})") + else: + md_parts.append(f"[Download {art['name']}]({url})") + return "\n\n".join(md_parts) + + def _collect_tool_attachment_content(self, existing_text: str = "") -> str: + text_parts = [] + for tool_obj in self.tools.values(): + if not hasattr(tool_obj, "_param") or not hasattr(tool_obj._param, "outputs"): + continue + content_meta = tool_obj._param.outputs.get("_ATTACHMENT_CONTENT", {}) + content = content_meta.get("value") if isinstance(content_meta, dict) else None + if not content or not isinstance(content, str): + continue + content = content.strip() + if not content or content in existing_text: + continue + text_parts.append(content) + return "\n\n".join(text_parts) + def reset(self, only_output=False): """ Reset all tools if they have a reset method. This avoids errors for tools like MCPToolCallSession. diff --git a/agent/sandbox/README.md b/agent/sandbox/README.md index a86361872..bbd1e5cb6 100644 --- a/agent/sandbox/README.md +++ b/agent/sandbox/README.md @@ -189,7 +189,11 @@ Currently, the following languages are officially supported: ### 🐍 Python -To add Python dependencies, simply edit the following file: +Pre-installed packages: `requests`, `numpy`, `pandas`, `matplotlib`. + +> `matplotlib` uses the `Agg` (non-interactive) backend by default in the sandbox (`MPLBACKEND=Agg`). No display server is available, so always save figures to files (e.g. `fig.savefig("artifacts/chart.png")`) rather than calling `plt.show()`. + +To add more dependencies, edit: ```bash sandbox_base_image/python/requirements.txt @@ -199,6 +203,8 @@ Add any additional packages you need, one per line (just like a normal pip requi ### 🟨 Node.js +Pre-installed packages: `axios`. + To add Node.js dependencies: 1. 
Navigate to the Node.js base image directory: diff --git a/agent/sandbox/executor_manager/models/schemas.py b/agent/sandbox/executor_manager/models/schemas.py index 750db5bc8..9baa94b5f 100644 --- a/agent/sandbox/executor_manager/models/schemas.py +++ b/agent/sandbox/executor_manager/models/schemas.py @@ -21,6 +21,13 @@ from pydantic import BaseModel, Field, field_validator from models.enums import ResourceLimitType, ResultStatus, RuntimeErrorType, SupportLanguage, UnauthorizedAccessType +class ArtifactItem(BaseModel): + name: str + mime_type: str + size: int + content_b64: str + + class CodeExecutionResult(BaseModel): status: ResultStatus stdout: str @@ -37,6 +44,9 @@ class CodeExecutionResult(BaseModel): unauthorized_access_type: Optional[UnauthorizedAccessType] = None runtime_error_type: Optional[RuntimeErrorType] = None + # File artifacts produced by code execution (images, PDFs, CSVs, etc.) + artifacts: list[ArtifactItem] = [] + class CodeExecutionRequest(BaseModel): code_b64: str = Field(..., description="Base64 encoded code string") diff --git a/agent/sandbox/executor_manager/services/execution.py b/agent/sandbox/executor_manager/services/execution.py index eae366585..358d122c2 100644 --- a/agent/sandbox/executor_manager/services/execution.py +++ b/agent/sandbox/executor_manager/services/execution.py @@ -24,7 +24,7 @@ from core.config import TIMEOUT from core.container import allocate_container_blocking, release_container from core.logger import logger from models.enums import ResourceLimitType, ResultStatus, RuntimeErrorType, SupportLanguage, UnauthorizedAccessType -from models.schemas import CodeExecutionRequest, CodeExecutionResult +from models.schemas import ArtifactItem, CodeExecutionRequest, CodeExecutionResult from utils.common import async_run_command @@ -59,8 +59,12 @@ async def execute_code(req: CodeExecutionRequest): f.write("""import json import os import sys + +os.makedirs(os.path.join(os.getcwd(), "artifacts"), exist_ok=True) + 
sys.path.insert(0, os.path.dirname(__file__)) from main import main + if __name__ == "__main__": args = json.loads(sys.argv[1]) result = main(**args) @@ -180,12 +184,14 @@ if (fs.existsSync(mainPath)) { logger.info(f"{args_json=}") if returncode == 0: + artifacts = await _collect_artifacts(container, task_id, workdir) return CodeExecutionResult( status=ResultStatus.SUCCESS, stdout=str(stdout), stderr=stderr, exit_code=0, time_used_ms=time_used_ms, + artifacts=artifacts, ) elif returncode == 124: return CodeExecutionResult( @@ -229,6 +235,84 @@ if (fs.existsSync(mainPath)) { await release_container(container, language) +ALLOWED_ARTIFACT_EXTENSIONS = { + ".png": "image/png", + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".svg": "image/svg+xml", + ".pdf": "application/pdf", + ".csv": "text/csv", + ".json": "application/json", + ".html": "text/html", +} +MAX_ARTIFACT_COUNT = 10 +MAX_ARTIFACT_SIZE = 10 * 1024 * 1024 # 10MB per file + + +async def _collect_artifacts(container: str, task_id: str, host_workdir: str) -> list[ArtifactItem]: + artifacts_path = f"/workspace/{task_id}/artifacts" + + # List files in the artifacts directory inside the container + returncode, stdout, _ = await async_run_command( + "docker", "exec", container, "find", artifacts_path, + "-maxdepth", "1", "-type", "f", timeout=5, + ) + if returncode != 0 or not stdout.strip(): + return [] + + raw_names = [line.split("/")[-1] for line in stdout.strip().splitlines() if line.strip()] + # Sanitize: reject names with path traversal or control characters + filenames = [n for n in raw_names if n and "/" not in n and "\\" not in n and ".." 
not in n and not n.startswith(".")] + if not filenames: + return [] + + items: list[ArtifactItem] = [] + + for fname in filenames[:MAX_ARTIFACT_COUNT]: + ext = os.path.splitext(fname)[1].lower() + mime_type = ALLOWED_ARTIFACT_EXTENSIONS.get(ext) + if not mime_type: + logger.warning(f"Skipping artifact with disallowed extension: {fname}") + continue + + file_path = f"{artifacts_path}/{fname}" + + # Check file size inside the container + returncode, size_str, _ = await async_run_command( + "docker", "exec", container, "stat", "-c", "%s", file_path, timeout=5, + ) + if returncode != 0: + logger.warning(f"Failed to stat artifact {fname}") + continue + + file_size = int(size_str.strip()) + if file_size > MAX_ARTIFACT_SIZE: + logger.warning(f"Artifact {fname} too large ({file_size} bytes), skipping") + continue + if file_size == 0: + continue + + # Read file content via docker exec (docker cp doesn't work with gVisor tmpfs) + returncode, content_b64, stderr = await async_run_command( + "docker", "exec", container, "base64", file_path, timeout=30, + ) + if returncode != 0: + logger.warning(f"Failed to read artifact {fname}: {stderr}") + continue + + content_b64 = content_b64.replace("\n", "").strip() + + items.append(ArtifactItem( + name=fname, + mime_type=mime_type, + size=file_size, + content_b64=content_b64, + )) + logger.info(f"Collected artifact: {fname} ({file_size} bytes, {mime_type})") + + return items + + def analyze_error_result(stderr: str, exit_code: int) -> CodeExecutionResult: """Analyze the error result and classify it""" if "Permission denied" in stderr: diff --git a/agent/sandbox/providers/self_managed.py b/agent/sandbox/providers/self_managed.py index 7078f6f76..d4e0c6d68 100644 --- a/agent/sandbox/providers/self_managed.py +++ b/agent/sandbox/providers/self_managed.py @@ -199,6 +199,7 @@ class SelfManagedProvider(SandboxProvider): "memory_used_kb": result.get("memory_used_kb"), "detail": result.get("detail"), "instance_id": instance_id, + "artifacts": 
result.get("artifacts", []), } ) diff --git a/agent/sandbox/sandbox_base_image/python/Dockerfile b/agent/sandbox/sandbox_base_image/python/Dockerfile index 7b985764f..93227055f 100644 --- a/agent/sandbox/sandbox_base_image/python/Dockerfile +++ b/agent/sandbox/sandbox_base_image/python/Dockerfile @@ -2,12 +2,15 @@ FROM python:3.11-slim-bookworm COPY --from=ghcr.io/astral-sh/uv:0.7.5 /uv /uvx /bin/ ENV UV_INDEX_URL=https://pypi.tuna.tsinghua.edu.cn/simple +ENV MPLBACKEND=Agg +ENV MPLCONFIGDIR=/tmp/matplotlib COPY requirements.txt . RUN grep -rl 'deb.debian.org' /etc/apt/ | xargs sed -i 's|http[s]*://deb.debian.org|https://mirrors.tuna.tsinghua.edu.cn|g' && \ apt-get update && \ apt-get install -y curl gcc && \ + mkdir -p /tmp/matplotlib && \ uv pip install --system -r requirements.txt WORKDIR /workspace diff --git a/agent/sandbox/sandbox_base_image/python/requirements.txt b/agent/sandbox/sandbox_base_image/python/requirements.txt index 4ad150163..d199e9e91 100644 --- a/agent/sandbox/sandbox_base_image/python/requirements.txt +++ b/agent/sandbox/sandbox_base_image/python/requirements.txt @@ -1,3 +1,4 @@ numpy pandas +matplotlib requests diff --git a/agent/tools/base.py b/agent/tools/base.py index 1f629a252..f5a42de4d 100644 --- a/agent/tools/base.py +++ b/agent/tools/base.py @@ -57,17 +57,19 @@ class LLMToolPluginCallSession(ToolCallSession): async def tool_call_async(self, name: str, arguments: dict[str, Any]) -> Any: assert name in self.tools_map, f"LLM tool {name} does not exist" + logging.info(f"[ToolCall] invoke name={name} arguments={str(arguments)[:200]}") st = timer() tool_obj = self.tools_map[name] if isinstance(tool_obj, MCPToolCallSession): resp = await thread_pool_exec(tool_obj.tool_call, name, arguments, 60) + elif hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async): + resp = await tool_obj.invoke_async(**arguments) else: - if hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async): - 
resp = await tool_obj.invoke_async(**arguments) - else: - resp = await thread_pool_exec(tool_obj.invoke, **arguments) + resp = await thread_pool_exec(tool_obj.invoke, **arguments) - self.callback(name, arguments, resp, elapsed_time=timer()-st) + elapsed = timer() - st + logging.info(f"[ToolCall] done name={name} elapsed={elapsed:.2f}s result={str(resp)[:200]}") + self.callback(name, arguments, resp, elapsed_time=elapsed) return resp def get_tool_obj(self, name): @@ -101,13 +103,8 @@ class ToolParamBase(ComponentParamBase): if "enum" in p: params[k]["enum"] = p["enum"] - desc = self.meta["description"] - if hasattr(self, "description"): - desc = self.description - - function_name = self.meta["name"] - if hasattr(self, "function_name"): - function_name = self.function_name + desc = getattr(self, "description", None) or self.meta["description"] + function_name = getattr(self, "function_name", self.meta["name"]) return { "type": "function", diff --git a/agent/tools/code_exec.py b/agent/tools/code_exec.py index bc42415e0..c896de57c 100644 --- a/agent/tools/code_exec.py +++ b/agent/tools/code_exec.py @@ -18,6 +18,7 @@ import base64 import json import logging import os +import uuid from abc import ABC from typing import Optional @@ -25,8 +26,10 @@ from pydantic import BaseModel, Field, field_validator from strenum import StrEnum from agent.tools.base import ToolBase, ToolMeta, ToolParamBase +from api.db.services.file_service import FileService from common import settings from common.connection_utils import timeout +from common.constants import SANDBOX_ARTIFACT_BUCKET, SANDBOX_ARTIFACT_EXPIRE_DAYS class Language(StrEnum): @@ -70,6 +73,7 @@ class CodeExecParam(ToolParamBase): "name": "execute_code", "description": """ This tool has a sandbox that can execute code written in 'Python'/'Javascript'. It receives a piece of code and return a Json string. 
+ Here's a code example for Python(`main` function MUST be included): def main() -> dict: \"\"\" @@ -84,6 +88,26 @@ def main() -> dict: "result": fibonacci_recursive(100), } +To generate charts or files (images, PDFs, CSVs, etc.), save them to the `artifacts/` directory (relative to the working directory). The sandbox will automatically collect these files and return them. Example: +def main() -> dict: + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + import pandas as pd + + df = pd.DataFrame({"x": [1, 2, 3, 4], "y": [10, 20, 25, 30]}) + fig, ax = plt.subplots() + ax.plot(df["x"], df["y"]) + ax.set_title("Sample Chart") + fig.savefig("artifacts/chart.png", dpi=150, bbox_inches="tight") + plt.close(fig) + return {"summary": "Chart saved to artifacts/chart.png"} + +Available Python packages: pandas, numpy, matplotlib, requests. +Supported artifact file types: .png, .jpg, .jpeg, .svg, .pdf, .csv, .json, .html + +Collected artifacts are also parsed automatically and appended to the stable text output `content`. The content includes sections like `attachment1 (image): ...`, `attachment2 (pdf): ...`, so downstream nodes can consume a single text output without depending on unstable attachment-specific variables. 
+ Here's a code example for Javascript(`main` function MUST be included and exported): const axios = require('axios'); async function main(args) { @@ -125,6 +149,7 @@ module.exports = { main }; class CodeExec(ToolBase, ABC): component_name = "CodeExec" + _lifecycle_configured = False @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60))) def _invoke(self, **kwargs): @@ -148,6 +173,8 @@ class CodeExec(ToolBase, ABC): if self.check_if_canceled("CodeExec execution"): return self.output() + timeout_seconds = int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60)) + try: # Try using the new sandbox provider system first try: @@ -157,25 +184,13 @@ class CodeExec(ToolBase, ABC): return # Execute code using the provider system - result = sandbox_execute_code( - code=code, - language=language, - timeout=int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60)), - arguments=arguments - ) + result = sandbox_execute_code(code=code, language=language, timeout=timeout_seconds, arguments=arguments) if self.check_if_canceled("CodeExec execution"): return - # Process the result - if result.stderr: - self.set_output("_ERROR", result.stderr) - return - - parsed_stdout = self._deserialize_stdout(result.stdout) - logging.info(f"[CodeExec]: Provider system -> {parsed_stdout}") - self._populate_outputs(parsed_stdout, result.stdout) - return + artifacts = result.metadata.get("artifacts", []) if result.metadata else [] + return self._process_execution_result(result.stdout, result.stderr, "Provider system", artifacts) except (ImportError, RuntimeError) as provider_error: # Provider system not available or not configured, fall back to HTTP @@ -196,7 +211,7 @@ class CodeExec(ToolBase, ABC): self.set_output("_ERROR", "Task has been canceled") return self.output() - resp = requests.post(url=f"http://{settings.SANDBOX_HOST}:9385/run", json=code_req, timeout=int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60))) + resp = requests.post(url=f"http://{settings.SANDBOX_HOST}:9385/run", 
json=code_req, timeout=timeout_seconds) logging.info(f"http://{settings.SANDBOX_HOST}:9385/run, code_req: {code_req}, resp.status_code {resp.status_code}:") if self.check_if_canceled("CodeExec execution"): @@ -206,14 +221,12 @@ class CodeExec(ToolBase, ABC): resp.raise_for_status() body = resp.json() if body: - stderr = body.get("stderr") - if stderr: - self.set_output("_ERROR", stderr) - return self.output() - raw_stdout = body.get("stdout", "") - parsed_stdout = self._deserialize_stdout(raw_stdout) - logging.info(f"[CodeExec]: http://{settings.SANDBOX_HOST}:9385/run -> {parsed_stdout}") - self._populate_outputs(parsed_stdout, raw_stdout) + return self._process_execution_result( + body.get("stdout", ""), + body.get("stderr"), + f"http://{settings.SANDBOX_HOST}:9385/run", + body.get("artifacts", []), + ) else: self.set_output("_ERROR", "There is no response from sandbox") return self.output() @@ -226,6 +239,100 @@ class CodeExec(ToolBase, ABC): return self.output() + def _process_execution_result(self, stdout: str, stderr: str | None, source: str, artifacts: list | None = None): + if stderr and not stdout and not artifacts: + self.set_output("_ERROR", stderr) + return self.output() + + # Clear any stale error from previous runs or base class initialization + self.set_output("_ERROR", "") + + if stderr: + logging.warning(f"[CodeExec]: stderr (non-fatal): {stderr[:500]}") + + parsed_stdout = self._deserialize_stdout(stdout) + logging.info(f"[CodeExec]: {source} -> {parsed_stdout}") + self._populate_outputs(parsed_stdout, stdout) + content_parts = [] + base_content = self._build_content_text(parsed_stdout, raw_stdout=stdout) + if base_content: + content_parts.append(base_content) + + if artifacts: + artifact_urls = self._upload_artifacts(artifacts) + if artifact_urls: + self.set_output("_ARTIFACTS", artifact_urls) + attachment_text = self._build_attachment_content(artifacts, artifact_urls) + self.set_output("_ATTACHMENT_CONTENT", attachment_text) + if attachment_text: 
+ content_parts.append(attachment_text) + else: + self.set_output("_ATTACHMENT_CONTENT", "") + + self.set_output("content", "\n\n".join([part for part in content_parts if part]).strip()) + + return self.output() + + @classmethod + def _ensure_bucket_lifecycle(cls): + if cls._lifecycle_configured: + return + try: + storage = settings.STORAGE_IMPL + # Only MinIO/S3 backends expose .conn for lifecycle config + if not hasattr(storage, "conn") or storage.conn is None: + cls._lifecycle_configured = True + return + if not storage.conn.bucket_exists(SANDBOX_ARTIFACT_BUCKET): + storage.conn.make_bucket(SANDBOX_ARTIFACT_BUCKET) + from minio.commonconfig import Filter + from minio.lifecycleconfig import Expiration, LifecycleConfig, Rule + + rule = Rule( + rule_id="auto-expire", + status="Enabled", + rule_filter=Filter(prefix=""), + expiration=Expiration(days=SANDBOX_ARTIFACT_EXPIRE_DAYS), + ) + storage.conn.set_bucket_lifecycle(SANDBOX_ARTIFACT_BUCKET, LifecycleConfig([rule])) + logging.info(f"[CodeExec]: Set {SANDBOX_ARTIFACT_EXPIRE_DAYS}-day lifecycle on bucket '{SANDBOX_ARTIFACT_BUCKET}'") + cls._lifecycle_configured = True + except Exception as e: + # Do NOT set _lifecycle_configured so we retry next time + logging.warning(f"[CodeExec]: Failed to set bucket lifecycle: {e}") + + def _upload_artifacts(self, artifacts: list) -> list[dict]: + self._ensure_bucket_lifecycle() + uploaded = [] + for art in artifacts: + try: + name = art.get("name", "") if isinstance(art, dict) else getattr(art, "name", "") + content_b64 = art.get("content_b64", "") if isinstance(art, dict) else getattr(art, "content_b64", "") + mime_type = art.get("mime_type", "") if isinstance(art, dict) else getattr(art, "mime_type", "") + size = art.get("size", 0) if isinstance(art, dict) else getattr(art, "size", 0) + if not content_b64 or not name: + continue + + ext = os.path.splitext(name)[1].lower() + storage_name = f"{uuid.uuid4().hex}{ext}" + binary = base64.b64decode(content_b64) + + 
settings.STORAGE_IMPL.put(SANDBOX_ARTIFACT_BUCKET, storage_name, binary) + + url = f"/v1/document/artifact/{storage_name}" + uploaded.append( + { + "name": name, + "url": url, + "mime_type": mime_type, + "size": size, + } + ) + logging.info(f"[CodeExec]: Uploaded artifact {name} -> {url}") + except Exception as e: + logging.warning(f"[CodeExec]: Failed to upload artifact: {e}") + return uploaded + def _encode_code(self, code: str) -> str: return base64.b64encode(code.encode("utf-8")).decode("utf-8") @@ -357,6 +464,84 @@ class CodeExec(ToolBase, ABC): logging.info(f"[CodeExec]: populate scalar key='{key}' raw='{val}' coerced='{coerced}'") self.set_output(key, coerced) + def _build_attachment_content(self, artifacts: list, artifact_urls: list[dict] | None = None) -> str: + sections = [] + artifact_urls = artifact_urls or [] + + for idx, art in enumerate(artifacts, start=1): + key = f"attachment{idx}" + try: + name = art.get("name", "") if isinstance(art, dict) else getattr(art, "name", "") + content_b64 = art.get("content_b64", "") if isinstance(art, dict) else getattr(art, "content_b64", "") + mime_type = art.get("mime_type", "") if isinstance(art, dict) else getattr(art, "mime_type", "") + if not name or not content_b64: + continue + + blob = base64.b64decode(content_b64) + parsed = FileService.parse( + name, + blob, + False, + tenant_id=self._canvas.get_tenant_id(), + ) + attachment_type = self._normalize_attachment_type(name, mime_type) + section = self._format_attachment_section(key, attachment_type, name, parsed) + sections.append(section) + logging.info(f"[CodeExec]: parse attachment section key='{key}' from artifact='{name}'") + except Exception as e: + logging.warning(f"[CodeExec]: Failed to parse artifact for content section '{key}': {e}") + fallback_type = self._normalize_attachment_type( + art.get("name", "") if isinstance(art, dict) else getattr(art, "name", ""), + art.get("mime_type", "") if isinstance(art, dict) else getattr(art, "mime_type", ""), + ) 
+ fallback_name = art.get("name", "") if isinstance(art, dict) else getattr(art, "name", "") + fallback_url = "" + if idx - 1 < len(artifact_urls): + fallback_url = artifact_urls[idx - 1].get("url", "") + fallback_text = "Artifact generated but parse failed." + if fallback_url: + fallback_text += f" Download: {fallback_url}" + sections.append(self._format_attachment_section(key, fallback_type, fallback_name, fallback_text)) + + if sections: + return f"attachment_count: {len(sections)}\n\n" + "\n\n".join(sections) + return "attachment_count: 0" + + def _normalize_attachment_type(self, name: str, mime_type: str) -> str: + mime_type = str(mime_type or "").strip().lower() + if mime_type.startswith("image/"): + return "image" + if mime_type == "application/pdf": + return "pdf" + if mime_type == "text/csv": + return "csv" + if mime_type == "application/json": + return "json" + if mime_type == "text/html": + return "html" + + ext = os.path.splitext(name or "")[1].lower().lstrip(".") + return ext or "file" + + def _format_attachment_section(self, key: str, attachment_type: str, name: str, parsed: str) -> str: + title = f"{key} ({attachment_type})" + if name: + title += f": {name}" + body = parsed if isinstance(parsed, str) else json.dumps(parsed, ensure_ascii=False) + return f"{title}\n{body}".strip() + + def _build_content_text(self, parsed_stdout, raw_stdout: str = "") -> str: + if isinstance(parsed_stdout, str): + return parsed_stdout.strip() + if isinstance(parsed_stdout, (dict, list, tuple)): + try: + return json.dumps(parsed_stdout, ensure_ascii=False, indent=2).strip() + except Exception: + return str(parsed_stdout).strip() + if parsed_stdout is None: + return str(raw_stdout or "").strip() + return str(parsed_stdout).strip() + def _get_by_path(self, data, path: str): if not path: return None diff --git a/api/apps/document_app.py b/api/apps/document_app.py index 5192756e0..c9c20c911 100644 --- a/api/apps/document_app.py +++ b/api/apps/document_app.py @@ -18,36 +18,38 
@@ import os.path import pathlib import re from pathlib import Path, PurePosixPath, PureWindowsPath -from quart import request, make_response + +from quart import make_response, request + from api.apps import current_user, login_required from api.common.check_team_permission import check_kb_team_permission from api.constants import FILE_NAME_LEN_LIMIT, IMG_BASE64_PREFIX from api.db import VALID_FILE_TYPES, FileType from api.db.db_models import Task from api.db.services import duplicate_name -from api.db.services.document_service import DocumentService, doc_upload_and_parse from api.db.services.doc_metadata_service import DocMetadataService -from common.metadata_utils import meta_filter, convert_conditions, turn2jsonschema +from api.db.services.document_service import DocumentService, doc_upload_and_parse from api.db.services.file2document_service import File2DocumentService from api.db.services.file_service import FileService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.task_service import TaskService, cancel_all_task_of from api.db.services.user_service import UserTenantService -from common.misc_utils import get_uuid, thread_pool_exec from api.utils.api_utils import ( get_data_error_result, get_json_result, + get_request_json, server_error_response, validate_request, - get_request_json, ) from api.utils.file_utils import filename_type, thumbnail -from common.file_utils import get_project_base_directory -from common.constants import RetCode, VALID_TASK_STATUS, ParserType, TaskStatus from api.utils.web_utils import CONTENT_TYPE_MAP, apply_safe_file_response_headers, html2pdf, is_valid_url -from deepdoc.parser.html_parser import RAGFlowHtmlParser -from rag.nlp import search, rag_tokenizer from common import settings +from common.constants import SANDBOX_ARTIFACT_BUCKET, VALID_TASK_STATUS, ParserType, RetCode, TaskStatus +from common.file_utils import get_project_base_directory +from common.metadata_utils import 
 convert_conditions, meta_filter, turn2jsonschema +from common.misc_utils import get_uuid, thread_pool_exec +from deepdoc.parser.html_parser import RAGFlowHtmlParser +from rag.nlp import rag_tokenizer, search def _is_safe_download_filename(name: str) -> bool: @@ -75,6 +77,7 @@ async def upload(): return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR) file_objs = files.getlist("file") + def _close_file_objs(objs): for obj in objs: try: @@ -84,6 +87,7 @@ async def upload(): obj.stream.close() except Exception: pass + for file_obj in file_objs: if file_obj.filename == "": _close_file_objs(file_objs) @@ -239,7 +243,6 @@ async def list_docs(): kb_id = request.args.get("kb_id") if not kb_id: return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) - tenants = UserTenantService.query(user_id=current_user.id) for tenant in tenants: if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id): @@ -608,6 +611,7 @@ async def run(): req = await get_request_json() uid = current_user.id try: + def _run_sync(): for doc_id in req["doc_ids"]: if not DocumentService.accessible(doc_id, uid): @@ -670,6 +674,7 @@ async def rename(): req = await get_request_json() uid = current_user.id try: + def _rename_sync(): if not DocumentService.accessible(req["doc_id"], uid): return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) @@ -827,6 +832,44 @@ async def get_image(image_id): return server_error_response(e) +ARTIFACT_CONTENT_TYPES = { + ".png": "image/png", + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".svg": "image/svg+xml", + ".pdf": "application/pdf", + ".csv": "text/csv", + ".json": "application/json", + ".html": "text/html", +} + + +@manager.route("/artifact/<filename>", methods=["GET"]) # noqa: F821 +@login_required +async def get_artifact(filename): + try: + bucket = SANDBOX_ARTIFACT_BUCKET + # Validate filename: must be uuid hex + allowed extension, nothing else +
basename = os.path.basename(filename) + if basename != filename or "/" in filename or "\\" in filename: + return get_data_error_result(message="Invalid filename.") + ext = os.path.splitext(basename)[1].lower() + if ext not in ARTIFACT_CONTENT_TYPES: + return get_data_error_result(message="Invalid file type.") + data = await thread_pool_exec(settings.STORAGE_IMPL.get, bucket, basename) + if not data: + return get_data_error_result(message="Artifact not found.") + content_type = ARTIFACT_CONTENT_TYPES.get(ext, "application/octet-stream") + response = await make_response(data) + safe_filename = re.sub(r"[^\w.\-]", "_", basename) + apply_safe_file_response_headers(response, content_type, ext) + if not response.headers.get("Content-Disposition"): + response.headers.set("Content-Disposition", f'inline; filename="{safe_filename}"') + return response + except Exception as e: + return server_error_response(e) + + @manager.route("/upload_and_parse", methods=["POST"]) # noqa: F821 @login_required @validate_request("conversation_id") @@ -942,8 +985,8 @@ async def set_meta(): @manager.route("/upload_info", methods=["POST"]) # noqa: F821 async def upload_info(): files = await request.files - file = files['file'] if files and files.get("file") else None + file = files["file"] if files and files.get("file") else None try: return get_json_result(data=FileService.upload_info(current_user.id, file, request.args.get("url"))) except Exception as e: - return server_error_response(e) + return server_error_response(e) diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index d24cc4b94..6058c6b69 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -405,7 +405,10 @@ class LLMBundle(LLM4Tenant): async def async_chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs): total_tokens = 0 ans = "" - if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_streamly_with_tools"): + 
    _bundle_is_tools = self.is_tools + _mdl_is_tools = getattr(self.mdl, "is_tools", False) + _has_with_tools = hasattr(self.mdl, "async_chat_streamly_with_tools") + if _bundle_is_tools and _mdl_is_tools and _has_with_tools: stream_fn = getattr(self.mdl, "async_chat_streamly_with_tools", None) elif hasattr(self.mdl, "async_chat_streamly"): stream_fn = getattr(self.mdl, "async_chat_streamly", None) @@ -425,7 +428,7 @@ class LLMBundle(LLM4Tenant): total_tokens = txt break - if txt.endswith("</think>"): + if txt.endswith("</think>") and ans.endswith("</think>"): ans = ans[: -len("</think>")] if not self.verbose_tool_use: @@ -468,7 +471,7 @@ class LLMBundle(LLM4Tenant): total_tokens = txt break - if txt.endswith("</think>"): + if txt.endswith("</think>") and ans.endswith("</think>"): ans = ans[: -len("</think>")] if not self.verbose_tool_use: diff --git a/common/constants.py b/common/constants.py index 24530c457..274255401 100644 --- a/common/constants.py +++ b/common/constants.py @@ -14,11 +14,14 @@ # limitations under the License. # +import os from enum import Enum, IntEnum from strenum import StrEnum SERVICE_CONF = "service_conf.yaml" RAG_FLOW_SERVICE_NAME = "ragflow" +SANDBOX_ARTIFACT_BUCKET = os.environ.get("SANDBOX_ARTIFACT_BUCKET", "sandbox-artifacts") +SANDBOX_ARTIFACT_EXPIRE_DAYS = int(os.environ.get("SANDBOX_ARTIFACT_EXPIRE_DAYS", "7")) class CustomEnum(Enum): diff --git a/conf/llm_factories.json b/conf/llm_factories.json index 3e6a1fce7..b6d2b7ff1 100644 --- a/conf/llm_factories.json +++ b/conf/llm_factories.json @@ -379,6 +379,13 @@ "rank": "950", "url" : "https://dashscope.aliyuncs.com/compatible-mode/v1", "llm": [ + { + "llm_name": "qwen3.5-122b-a10b", + "tags": "LLM,CHAT,128k", + "max_tokens": 128000, + "model_type": "chat", + "is_tools": true + }, { "llm_name": "Moonshot-Kimi-K2-Instruct", "tags": "LLM,CHAT,128K", diff --git a/docker/.env b/docker/.env index 0440c5c01..858c053d8 100644 --- a/docker/.env +++ b/docker/.env @@ -261,6 +261,10 @@ REGISTER_ENABLED=1 # SANDBOX_ENABLE_SECCOMP=false # SANDBOX_MAX_MEMORY=256m #
b, k, m, g # SANDBOX_TIMEOUT=10s # s, m, 1m30s +# The MinIO bucket name for storing sandbox-generated artifacts (charts, files, etc.). +SANDBOX_ARTIFACT_BUCKET=sandbox-artifacts +# Number of days before sandbox artifacts are automatically deleted from storage. +SANDBOX_ARTIFACT_EXPIRE_DAYS=7 # Enable DocLing USE_DOCLING=false diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 8880d4d61..8cbd8933c 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -324,6 +324,34 @@ class Base(ABC): hist.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_res)}) return hist + def _append_history_batch(self, hist, results): + """ + Append a batch of tool calls to history following the OpenAI protocol: + one assistant message containing all tool_calls, followed by one tool message per call. + results: list of (tool_call, name, args, result, error) + """ + hist.append({ + "role": "assistant", + "tool_calls": [ + { + "index": tc.index, + "id": tc.id, + "function": {"name": tc.function.name, "arguments": tc.function.arguments}, + "type": "function", + } + for tc, _, _, _, _ in results + ], + }) + for tc, _, _, result, err in results: + if err: + content = str(err) + elif isinstance(result, dict): + content = json.dumps(result, ensure_ascii=False) + else: + content = str(result) + hist.append({"role": "tool", "tool_call_id": tc.id, "content": content}) + return hist + def bind_tools(self, toolcall_session, tools): if not (toolcall_session and tools): return @@ -360,18 +388,24 @@ class Base(ABC): return ans, tk_count - for tool_call in response.choices[0].message.tool_calls: - logging.info(f"Response {tool_call=}") - name = tool_call.function.name + async def _exec_tool(tc): + name = tc.function.name try: - args = json_repair.loads(tool_call.function.arguments) - tool_response = await thread_pool_exec(self.toolcall_session.tool_call, name, args) - history = self._append_history(history, tool_call, tool_response) - ans += 
self._verbose_tool_use(name, args, tool_response) + args = json_repair.loads(tc.function.arguments) + if hasattr(self.toolcall_session, "tool_call_async"): + result = await self.toolcall_session.tool_call_async(name, args) + else: + result = await thread_pool_exec(self.toolcall_session.tool_call, name, args) + return tc, name, args, result, None except Exception as e: - logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}") - history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)}) - ans += self._verbose_tool_use(name, {}, str(e)) + logging.exception(f"Tool call failed: {tc}") + return tc, name, {}, None, e + + logging.info(f"Response tool_calls={response.choices[0].message.tool_calls}") + results = await asyncio.gather(*[_exec_tool(tc) for tc in response.choices[0].message.tool_calls]) + history = self._append_history_batch(history, results) + for tc, name, args, result, err in results: + ans += self._verbose_tool_use(name, args, err if err else result) logging.warning(f"Exceed max rounds: {self.max_rounds}") history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"}) @@ -398,9 +432,9 @@ class Base(ABC): for attempt in range(self.max_retries + 1): history = deepcopy(hist) try: - for _ in range(self.max_rounds + 1): + for _round in range(self.max_rounds + 1): reasoning_start = False - logging.info(f"{tools=}") + logging.info(f"[ToolLoop] round={_round} model={self.model_name} tools={[t['function']['name'] for t in tools]}") response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf) @@ -450,22 +484,36 @@ class Base(ABC): if finish_reason == "length": yield self._length_stop("") - if answer: + if answer and not final_tool_calls: + logging.info(f"[ToolLoop] round={_round} completed with text response, exiting") yield total_tokens return 
- for tool_call in final_tool_calls.values(): - name = tool_call.function.name + async def _exec_tool(tc): + name = tc.function.name try: - args = json_repair.loads(tool_call.function.arguments) - yield self._verbose_tool_use(name, args, "Begin to call...") - tool_response = await thread_pool_exec(self.toolcall_session.tool_call, name, args) - history = self._append_history(history, tool_call, tool_response) - yield self._verbose_tool_use(name, args, tool_response) + args = json_repair.loads(tc.function.arguments) + if hasattr(self.toolcall_session, "tool_call_async"): + result = await self.toolcall_session.tool_call_async(name, args) + else: + result = await thread_pool_exec(self.toolcall_session.tool_call, name, args) + return tc, name, args, result, None except Exception as e: - logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}") - history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)}) - yield self._verbose_tool_use(name, {}, str(e)) + logging.exception(f"Tool call failed: {tc}") + return tc, name, {}, None, e + + tcs = list(final_tool_calls.values()) + logging.info(f"[ToolLoop] round={_round} executing {len(tcs)} tool(s): {[tc.function.name for tc in tcs]}") + for tc in tcs: + try: + args = json_repair.loads(tc.function.arguments) + except Exception: + args = {} + yield self._verbose_tool_use(tc.function.name, args, "Begin to call...") + results = await asyncio.gather(*[_exec_tool(tc) for tc in tcs]) + history = self._append_history_batch(history, results) + for tc, name, args, result, err in results: + yield self._verbose_tool_use(name, args, err if err else result) logging.warning(f"Exceed max rounds: {self.max_rounds}") history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"}) @@ -1419,6 +1467,34 @@ class LiteLLMBase(ABC): hist.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_res)}) return 
hist + def _append_history_batch(self, hist, results): + """ + Append a batch of tool calls to history following the OpenAI protocol: + one assistant message containing all tool_calls, followed by one tool message per call. + results: list of (tool_call, name, args, result, error) + """ + hist.append({ + "role": "assistant", + "tool_calls": [ + { + "index": tc.index, + "id": tc.id, + "function": {"name": tc.function.name, "arguments": tc.function.arguments}, + "type": "function", + } + for tc, _, _, _, _ in results + ], + }) + for tc, _, _, result, err in results: + if err: + content = str(err) + elif isinstance(result, dict): + content = json.dumps(result, ensure_ascii=False) + else: + content = str(result) + hist.append({"role": "tool", "tool_call_id": tc.id, "content": content}) + return hist + def bind_tools(self, toolcall_session, tools): if not (toolcall_session and tools): return @@ -1463,18 +1539,24 @@ class LiteLLMBase(ABC): ans = self._length_stop(ans) return ans, tk_count - for tool_call in message.tool_calls: - logging.info(f"Response {tool_call=}") - name = tool_call.function.name + async def _exec_tool(tc): + name = tc.function.name try: - args = json_repair.loads(tool_call.function.arguments) - tool_response = await thread_pool_exec(self.toolcall_session.tool_call, name, args) - history = self._append_history(history, tool_call, tool_response) - ans += self._verbose_tool_use(name, args, tool_response) + args = json_repair.loads(tc.function.arguments) + if hasattr(self.toolcall_session, "tool_call_async"): + result = await self.toolcall_session.tool_call_async(name, args) + else: + result = await thread_pool_exec(self.toolcall_session.tool_call, name, args) + return tc, name, args, result, None except Exception as e: - logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}") - history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)}) - ans += 
self._verbose_tool_use(name, {}, str(e)) + logging.exception(f"Tool call failed: {tc}") + return tc, name, {}, None, e + + logging.info(f"Response tool_calls={message.tool_calls}") + results = await asyncio.gather(*[_exec_tool(tc) for tc in message.tool_calls]) + history = self._append_history_batch(history, results) + for tc, name, args, result, err in results: + ans += self._verbose_tool_use(name, args, err if err else result) logging.warning(f"Exceed max rounds: {self.max_rounds}") history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"}) @@ -1503,9 +1585,9 @@ class LiteLLMBase(ABC): for attempt in range(self.max_retries + 1): history = deepcopy(hist) try: - for _ in range(self.max_rounds + 1): + for _round in range(self.max_rounds + 1): reasoning_start = False - logging.info(f"{tools=}") + logging.info(f"[ToolLoop] round={_round} model={self.model_name} tools={[t['function']['name'] for t in tools]}") completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf) response = await litellm.acompletion( @@ -1560,22 +1642,36 @@ class LiteLLMBase(ABC): if finish_reason == "length": yield self._length_stop("") - if answer: + if answer and not final_tool_calls: + logging.info(f"[ToolLoop] round={_round} completed with text response, exiting") yield total_tokens return - for tool_call in final_tool_calls.values(): - name = tool_call.function.name + async def _exec_tool(tc): + name = tc.function.name try: - args = json_repair.loads(tool_call.function.arguments) - yield self._verbose_tool_use(name, args, "Begin to call...") - tool_response = await thread_pool_exec(self.toolcall_session.tool_call, name, args) - history = self._append_history(history, tool_call, tool_response) - yield self._verbose_tool_use(name, args, tool_response) + args = json_repair.loads(tc.function.arguments) + if hasattr(self.toolcall_session, "tool_call_async"): + result = await self.toolcall_session.tool_call_async(name, args) + 
else: + result = await thread_pool_exec(self.toolcall_session.tool_call, name, args) + return tc, name, args, result, None except Exception as e: - logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}") - history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)}) - yield self._verbose_tool_use(name, {}, str(e)) + logging.exception(f"Tool call failed: {tc}") + return tc, name, {}, None, e + + tcs = list(final_tool_calls.values()) + logging.info(f"[ToolLoop] round={_round} executing {len(tcs)} tool(s): {[tc.function.name for tc in tcs]}") + for tc in tcs: + try: + args = json_repair.loads(tc.function.arguments) + except Exception: + args = {} + yield self._verbose_tool_use(tc.function.name, args, "Begin to call...") + results = await asyncio.gather(*[_exec_tool(tc) for tc in tcs]) + history = self._append_history_batch(history, results) + for tc, name, args, result, err in results: + yield self._verbose_tool_use(name, args, err if err else result) logging.warning(f"Exceed max rounds: {self.max_rounds}") history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"}) diff --git a/web/src/components/markdown-content/index.tsx b/web/src/components/markdown-content/index.tsx index 603e0552e..72247674f 100644 --- a/web/src/components/markdown-content/index.tsx +++ b/web/src/components/markdown-content/index.tsx @@ -211,7 +211,7 @@ const MarkdownContent = ({ const renderReference = useCallback( (text: string) => { - let replacedText = reactStringReplace(text, currentReg, (match, i) => { + const replacedText = reactStringReplace(text, currentReg, (match, i) => { const chunkIndex = getChunkIndex(match); return ( @@ -242,9 +242,7 @@ const MarkdownContent = ({ remarkPlugins={[remarkGfm, remarkMath]} components={ { - p: ({ children, node, ...props }: any) => ( -
<p {...props}>{children}</p>
-          ),
+          p: ({ children, ...props }: any) => <p {...props}>{children}</p>
, 'custom-typography': ({ children }: { children: string }) => renderReference(children), code(props: any) { diff --git a/web/src/components/next-markdown-content/index.module.less b/web/src/components/next-markdown-content/index.module.less index 3d544b112..71de7f615 100644 --- a/web/src/components/next-markdown-content/index.module.less +++ b/web/src/components/next-markdown-content/index.module.less @@ -79,3 +79,27 @@ display: inline-block; max-width: 40px; } + +.artifactImageWrapper { + display: block; + margin: 8px 0; +} + +.artifactImage { + max-width: 100%; + max-height: 60vh; + border-radius: 8px; + border: 1px solid #e5e7eb; + display: block; +} + +.artifactDownload { + display: inline-block; + margin-top: 4px; + font-size: 12px; + color: #1677ff; + text-decoration: none; + &:hover { + text-decoration: underline; + } +} diff --git a/web/src/components/next-markdown-content/index.tsx b/web/src/components/next-markdown-content/index.tsx index 903a526c3..c13cb6159 100644 --- a/web/src/components/next-markdown-content/index.tsx +++ b/web/src/components/next-markdown-content/index.tsx @@ -2,8 +2,10 @@ import Image from '@/components/image'; import SvgIcon from '@/components/svg-icon'; import { IReferenceChunk, IReferenceObject } from '@/interfaces/database/chat'; import { getExtension } from '@/utils/document-util'; +import { downloadFileFromBlob } from '@/utils/file-util'; +import request from '@/utils/request'; import DOMPurify from 'dompurify'; -import { memo, useCallback, useEffect, useMemo } from 'react'; +import { memo, useCallback, useEffect, useMemo, useState } from 'react'; import Markdown from 'react-markdown'; import SyntaxHighlighter from 'react-syntax-highlighter'; import rehypeKatex from 'rehype-katex'; @@ -38,9 +40,120 @@ import { HoverCardContent, HoverCardTrigger, } from '../ui/hover-card'; +import message from '../ui/message'; import styles from './index.module.less'; const getChunkIndex = (match: string) => parseCitationIndex(match); + +const 
isArtifactUrl = (url?: string) => + Boolean(url && url.includes('/document/artifact/')); + +const fetchArtifactBlob = async (url: string): Promise => { + const response = await request(url, { + method: 'GET', + responseType: 'blob', + }); + + return response.data as Blob; +}; + +const getArtifactName = (url?: string, fallback?: string) => + fallback || url?.split('/').pop()?.split('?')[0] || 'artifact'; + +function ArtifactLink({ + href, + className, + children, +}: { + href: string; + className?: string; + children: React.ReactNode; +}) { + const handleClick = useCallback( + async (e: React.MouseEvent) => { + e.preventDefault(); + try { + const blob = await fetchArtifactBlob(href); + const objectUrl = URL.createObjectURL(blob); + window.open(objectUrl, '_blank', 'noopener,noreferrer'); + window.setTimeout(() => URL.revokeObjectURL(objectUrl), 60 * 1000); + } catch { + message.error('Failed to open artifact'); + } + }, + [href], + ); + + return ( + + {children} + + ); +} + +function ArtifactImage({ + src, + alt, + downloadLabel, +}: { + src: string; + alt?: string; + downloadLabel: string; +}) { + const [imageSrc, setImageSrc] = useState(''); + + useEffect(() => { + let objectUrl = ''; + let active = true; + + const load = async () => { + try { + const blob = await fetchArtifactBlob(src); + objectUrl = URL.createObjectURL(blob); + if (active) { + setImageSrc(objectUrl); + } + } catch { + message.error('Failed to load artifact image'); + } + }; + + load(); + + return () => { + active = false; + if (objectUrl) { + URL.revokeObjectURL(objectUrl); + } + }; + }, [alt, src]); + + const handleDownload = useCallback(async () => { + try { + const blob = await fetchArtifactBlob(src); + downloadFileFromBlob(blob, getArtifactName(src, alt)); + } catch { + message.error('Failed to download artifact'); + } + }, [alt, src]); + + return ( + + {imageSrc ? 
( + {alt + ) : ( + + )} + + + ); +} // TODO: The display of the table is inconsistent with the display previously placed in the MessageItem. function MarkdownContent({ reference, @@ -213,7 +326,7 @@ function MarkdownContent({ const renderReference = useCallback( (text: string) => { - let replacedText = reactStringReplace(text, currentReg, (match, i) => { + const replacedText = reactStringReplace(text, currentReg, (match, i) => { const chunkIndex = getChunkIndex(match); return ( @@ -244,11 +357,44 @@ function MarkdownContent({ remarkPlugins={[remarkGfm, remarkMath]} components={ { - p: ({ children, node, ...props }: any) => ( -
<p {...props}>{children}</p>
-            ),
+            p: ({ children, ...props }: any) => <p {...props}>{children}</p>
, 'custom-typography': ({ children }: { children: string }) => renderReference(children), + a({ href, children, ...props }: any) { + if (isArtifactUrl(href)) { + return ( + + {children} + + ); + } + return ( + + {children} + + ); + }, + img({ src, alt, ...props }: any) { + if (isArtifactUrl(src)) { + return ( + + ); + } + return ( + + {alt + + ); + }, code(props: any) { const { children, className, ...rest } = props; const restProps = omit(rest, 'node'); diff --git a/web/src/pages/agent/form/code-form/index.tsx b/web/src/pages/agent/form/code-form/index.tsx index 2883fdf46..f9797ad24 100644 --- a/web/src/pages/agent/form/code-form/index.tsx +++ b/web/src/pages/agent/form/code-form/index.tsx @@ -42,6 +42,12 @@ const options = [ ].map((x) => ({ value: x, label: x })); const DynamicFieldName = 'outputs'; +const CodeSystemOutputs = { + content: { + type: 'string', + value: '', + }, +}; function CodeForm({ node }: INextOperatorForm) { const formData = node?.data.form as ICodeForm; @@ -159,7 +165,12 @@ function CodeForm({ node }: INextOperatorForm) { )}
- +
); diff --git a/web/src/utils/canvas-util.tsx b/web/src/utils/canvas-util.tsx index d3b825607..818dc9cf2 100644 --- a/web/src/utils/canvas-util.tsx +++ b/web/src/utils/canvas-util.tsx @@ -61,13 +61,28 @@ export function buildSecondaryOutputOptions( })); } +function getNodeOutputs(x: BaseNode) { + const outputs = x.data.form?.outputs ?? {}; + if (x.data.label !== Operator.Code) { + return outputs; + } + + return { + ...outputs, + content: outputs.content ?? { + type: JsonSchemaDataType.String, + value: '', + }, + }; +} + export function buildOutputOptions(x: BaseNode) { return { label: x.data.name, value: x.id, title: x.data.name, options: buildSecondaryOutputOptions( - x.data.form.outputs, + getNodeOutputs(x), x.id, x.data.name, , @@ -83,7 +98,7 @@ export function buildNodeOutputOptions({ nodeIds: string[]; }) { const nodeWithOutputList = nodes.filter( - (x) => nodeIds.some((y) => y === x.id) && !isEmpty(x.data?.form?.outputs), + (x) => nodeIds.some((y) => y === x.id) && !isEmpty(getNodeOutputs(x)), ); return nodeWithOutputList.map((x) => buildOutputOptions(x)); @@ -114,7 +129,7 @@ export function buildChildOutputOptions({ nodeId?: string; }) { const nodeWithOutputList = nodes.filter( - (x) => x.parentId === nodeId && !isEmpty(x.data?.form?.outputs), + (x) => x.parentId === nodeId && !isEmpty(getNodeOutputs(x)), ); return nodeWithOutputList.map((x) => buildOutputOptions(x));