mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 18:08:07 +08:00
feat(sandbox): implement sandbox runtime checks and integrate bash tool invocation in LLMNode
This commit is contained in:
@ -0,0 +1,3 @@
|
||||
from core.tools.builtin_tool.providers.sandbox.bash_tool import SandboxBashTool
|
||||
|
||||
__all__ = ["SandboxBashTool"]
|
||||
96
api/core/tools/builtin_tool/providers/sandbox/bash_tool.py
Normal file
96
api/core/tools/builtin_tool/providers/sandbox/bash_tool.py
Normal file
@ -0,0 +1,96 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolDescription,
|
||||
ToolEntity,
|
||||
ToolIdentity,
|
||||
ToolInvokeMessage,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
|
||||
SANDBOX_BASH_TOOL_NAME = "bash"
|
||||
SANDBOX_BASH_TOOL_PROVIDER = "sandbox"
|
||||
COMMAND_TIMEOUT_SECONDS = 60
|
||||
|
||||
|
||||
class SandboxBashTool(Tool):
|
||||
def __init__(self, sandbox: VirtualEnvironment, tenant_id: str):
|
||||
self._sandbox = sandbox
|
||||
|
||||
entity = ToolEntity(
|
||||
identity=ToolIdentity(
|
||||
author="Dify",
|
||||
name=SANDBOX_BASH_TOOL_NAME,
|
||||
label=I18nObject(en_US="Bash", zh_Hans="Bash"),
|
||||
provider=SANDBOX_BASH_TOOL_PROVIDER,
|
||||
),
|
||||
parameters=[
|
||||
ToolParameter.get_simple_instance(
|
||||
name="command",
|
||||
llm_description="The bash command to execute in the sandbox environment",
|
||||
typ=ToolParameter.ToolParameterType.STRING,
|
||||
required=True,
|
||||
),
|
||||
],
|
||||
description=ToolDescription(
|
||||
human=I18nObject(
|
||||
en_US="Execute bash commands in the sandbox environment",
|
||||
zh_Hans="在沙盒环境中执行 bash 命令",
|
||||
),
|
||||
llm="Execute bash commands in the sandbox environment. "
|
||||
"Use this tool to run shell commands, scripts, or interact with the system. "
|
||||
"The command will be executed in an isolated sandbox environment.",
|
||||
),
|
||||
)
|
||||
|
||||
runtime = ToolRuntime(tenant_id=tenant_id)
|
||||
super().__init__(entity=entity, runtime=runtime)
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.BUILT_IN
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
command = tool_parameters.get("command", "")
|
||||
if not command:
|
||||
yield self.create_text_message("Error: No command provided")
|
||||
return
|
||||
|
||||
connection_handle = self._sandbox.establish_connection()
|
||||
try:
|
||||
cmd_list = ["sh", "-c", command]
|
||||
future = self._sandbox.run_command(connection_handle, cmd_list)
|
||||
timeout = COMMAND_TIMEOUT_SECONDS if COMMAND_TIMEOUT_SECONDS > 0 else None
|
||||
result = future.result(timeout=timeout)
|
||||
|
||||
stdout = result.stdout.decode("utf-8", errors="replace") if result.stdout else ""
|
||||
stderr = result.stderr.decode("utf-8", errors="replace") if result.stderr else ""
|
||||
exit_code = result.exit_code
|
||||
|
||||
output_parts: list[str] = []
|
||||
if stdout:
|
||||
output_parts.append(f"stdout:\n{stdout}")
|
||||
if stderr:
|
||||
output_parts.append(f"stderr:\n{stderr}")
|
||||
output_parts.append(f"exit_code: {exit_code}")
|
||||
|
||||
yield self.create_text_message("\n".join(output_parts))
|
||||
|
||||
except TimeoutError:
|
||||
yield self.create_text_message(f"Error: Command timed out after {COMMAND_TIMEOUT_SECONDS}s")
|
||||
except Exception as e:
|
||||
yield self.create_text_message(f"Error: {e!s}")
|
||||
finally:
|
||||
self._sandbox.release_connection(connection_handle)
|
||||
@ -1,6 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Final
|
||||
from typing import TYPE_CHECKING, Final
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.builtin_tool.providers.sandbox.bash_tool import SandboxBashTool
|
||||
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
|
||||
@ -83,6 +89,10 @@ class SandboxManager:
|
||||
shard_index = cls._shard_index(workflow_execution_id)
|
||||
return workflow_execution_id in cls._shards[shard_index]
|
||||
|
||||
@classmethod
|
||||
def is_sandbox_runtime(cls, workflow_execution_id: str) -> bool:
|
||||
return cls.has(workflow_execution_id)
|
||||
|
||||
@classmethod
|
||||
def clear(cls) -> None:
|
||||
for lock in cls._shard_locks:
|
||||
@ -98,3 +108,28 @@ class SandboxManager:
|
||||
@classmethod
|
||||
def count(cls) -> int:
|
||||
return sum(len(shard) for shard in cls._shards)
|
||||
|
||||
@classmethod
|
||||
def get_bash_tool(
|
||||
cls,
|
||||
workflow_execution_id: str,
|
||||
tenant_id: str,
|
||||
configured_tools: list[Tool],
|
||||
) -> SandboxBashTool:
|
||||
from core.tools.builtin_tool.providers.sandbox.bash_tool import SandboxBashTool
|
||||
|
||||
sandbox = cls.get(workflow_execution_id)
|
||||
if sandbox is None:
|
||||
raise RuntimeError(f"Sandbox not found for workflow_execution_id={workflow_execution_id}")
|
||||
|
||||
cls._initialize_tools_in_sandbox(sandbox, configured_tools)
|
||||
|
||||
return SandboxBashTool(sandbox=sandbox, tenant_id=tenant_id)
|
||||
|
||||
@classmethod
|
||||
def _initialize_tools_in_sandbox(
|
||||
cls,
|
||||
sandbox: VirtualEnvironment,
|
||||
configured_tools: list[Tool],
|
||||
) -> None:
|
||||
raise NotImplementedError("TODO: Initialize configured tools in sandbox environment")
|
||||
|
||||
@ -61,6 +61,7 @@ from core.variables import (
|
||||
ObjectSegment,
|
||||
StringSegment,
|
||||
)
|
||||
from core.virtual_environment.sandbox_manager import SandboxManager
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.entities import GraphInitParams, ToolCall, ToolResult, ToolResultStatus
|
||||
from core.workflow.entities.tool_entities import ToolCallResult
|
||||
@ -261,18 +262,33 @@ class LLMNode(Node[LLMNodeData]):
|
||||
generation_data: LLMGenerationData | None = None
|
||||
structured_output: LLMStructuredOutput | None = None
|
||||
|
||||
# Check if tools are configured
|
||||
if self.tool_call_enabled:
|
||||
# Use tool-enabled invocation (Agent V2 style)
|
||||
generator = self._invoke_llm_with_tools(
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
stop=stop,
|
||||
files=files,
|
||||
variable_pool=variable_pool,
|
||||
node_inputs=node_inputs,
|
||||
process_data=process_data,
|
||||
workflow_execution_id = variable_pool.system_variables.workflow_execution_id
|
||||
is_sandbox_runtime = (
|
||||
workflow_execution_id is not None
|
||||
and SandboxManager.is_sandbox_runtime(workflow_execution_id)
|
||||
)
|
||||
|
||||
if is_sandbox_runtime:
|
||||
generator = self._invoke_llm_with_sandbox(
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
stop=stop,
|
||||
files=files,
|
||||
variable_pool=variable_pool,
|
||||
node_inputs=node_inputs,
|
||||
process_data=process_data,
|
||||
)
|
||||
else:
|
||||
generator = self._invoke_llm_with_tools(
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
stop=stop,
|
||||
files=files,
|
||||
variable_pool=variable_pool,
|
||||
node_inputs=node_inputs,
|
||||
process_data=process_data,
|
||||
)
|
||||
else:
|
||||
# Use traditional LLM invocation
|
||||
generator = LLMNode.invoke_llm(
|
||||
@ -1565,7 +1581,52 @@ class LLMNode(Node[LLMNodeData]):
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Process outputs and return generation result
|
||||
result = yield from self._process_tool_outputs(outputs)
|
||||
return result
|
||||
|
||||
def _invoke_llm_with_sandbox(
|
||||
self,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
stop: Sequence[str] | None,
|
||||
files: Sequence[File],
|
||||
variable_pool: VariablePool,
|
||||
node_inputs: dict[str, Any],
|
||||
process_data: dict[str, Any],
|
||||
) -> Generator[NodeEventBase, None, LLMGenerationData]:
|
||||
from core.agent.entities import AgentEntity
|
||||
|
||||
workflow_execution_id = variable_pool.system_variables.workflow_execution_id
|
||||
if not workflow_execution_id:
|
||||
raise LLMNodeError("workflow_execution_id is required for sandbox runtime mode")
|
||||
|
||||
configured_tools = self._prepare_tool_instances(variable_pool)
|
||||
|
||||
bash_tool = SandboxManager.get_bash_tool(
|
||||
workflow_execution_id=workflow_execution_id,
|
||||
tenant_id=self.tenant_id,
|
||||
configured_tools=configured_tools,
|
||||
)
|
||||
|
||||
prompt_files = self._extract_prompt_files(variable_pool)
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=[],
|
||||
model_instance=model_instance,
|
||||
tools=[bash_tool],
|
||||
files=prompt_files,
|
||||
max_iterations=self._node_data.max_iterations or 10,
|
||||
context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id),
|
||||
agent_strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT,
|
||||
)
|
||||
|
||||
outputs = strategy.run(
|
||||
prompt_messages=list(prompt_messages),
|
||||
model_parameters=self._node_data.model.completion_params,
|
||||
stop=list(stop or []),
|
||||
stream=True,
|
||||
)
|
||||
|
||||
result = yield from self._process_tool_outputs(outputs)
|
||||
return result
|
||||
|
||||
|
||||
Reference in New Issue
Block a user