refactor: consolidate sandbox management and initialization

- Moved sandbox-related classes and functions into a dedicated module for better organization.
- Updated the sandbox initialization process to streamline asset management and environment setup.
- Removed deprecated constants and refactored related code to utilize new sandbox entities.
- Enhanced the workflow context to support sandbox integration, allowing for improved state management during execution.
- Adjusted various components to utilize the new sandbox structure, ensuring compatibility across the application.
This commit is contained in:
Harry
2026-01-21 20:42:19 +08:00
parent 1fcff5f8d1
commit 9ed83a808a
30 changed files with 449 additions and 545 deletions

View File

@ -17,7 +17,6 @@ from core.workflow.context.execution_context import (
register_context_capturer,
reset_context_provider,
)
from core.workflow.context.models import SandboxContext
__all__ = [
"AppContext",
@ -25,7 +24,6 @@ __all__ = [
"ExecutionContext",
"IExecutionContext",
"NullAppContext",
"SandboxContext",
"capture_current_context",
"read_context",
"register_context",

View File

@ -1,13 +1,3 @@
from __future__ import annotations
from pydantic import AnyHttpUrl, BaseModel
class SandboxContext(BaseModel):
"""Typed context for sandbox integration. All fields optional by design."""
sandbox_url: AnyHttpUrl | None = None
sandbox_token: str | None = None # optional, if later needed for auth
__all__ = ["SandboxContext"]
__all__: list[str] = []

View File

@ -2,11 +2,9 @@ import logging
from collections.abc import Mapping, Sequence
from typing import Any
from core.sandbox import SandboxManager, sandbox_debug
from core.sandbox.vm import SandboxBuilder
from core.sandbox import sandbox_debug
from core.virtual_environment.__base.command_future import CommandCancelledError, CommandTimeoutError
from core.virtual_environment.__base.helpers import submit_command, with_connection
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import variable_template_parser
@ -24,19 +22,6 @@ COMMAND_NODE_TIMEOUT_SECONDS = 60
class CommandNode(Node[CommandNodeData]):
node_type = NodeType.COMMAND
# FIXME(Mairuis): should read sandbox from workflow run context...
def _get_sandbox(self) -> VirtualEnvironment | None:
workflow_execution_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id
if not workflow_execution_id:
return None
sandbox_by_workflow_run_id = SandboxManager.get(workflow_execution_id)
if sandbox_by_workflow_run_id is not None:
return sandbox_by_workflow_run_id
sandbox_by_draft_id = SandboxManager.get(SandboxBuilder.draft_id(self.user_id))
if sandbox_by_draft_id is not None:
return sandbox_by_draft_id
return None
def _render_template(self, template: str) -> str:
parser = VariableTemplateParser(template=template)
selectors = parser.extract_variable_selectors()
@ -65,7 +50,7 @@ class CommandNode(Node[CommandNodeData]):
return "1"
def _run(self) -> NodeRunResult:
sandbox = self._get_sandbox()
sandbox = self.graph_runtime_state.sandbox
if sandbox is None:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
@ -88,12 +73,12 @@ class CommandNode(Node[CommandNodeData]):
timeout = COMMAND_NODE_TIMEOUT_SECONDS if COMMAND_NODE_TIMEOUT_SECONDS > 0 else None
try:
with with_connection(sandbox) as conn:
with with_connection(sandbox.vm) as conn:
command = ["bash", "-c", raw_command]
sandbox_debug("command_node", "command", command)
future = submit_command(sandbox, conn, command, cwd=working_directory)
future = submit_command(sandbox.vm, conn, command, cwd=working_directory)
result = future.result(timeout=timeout)
outputs: dict[str, Any] = {

View File

@ -50,8 +50,8 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.sandbox import SandboxBashSession, SandboxManager
from core.sandbox.vm import SandboxBuilder
from core.sandbox import Sandbox
from core.sandbox.bash.session import SandboxBashSession
from core.tools.__base.tool import Tool
from core.tools.signature import sign_upload_file
from core.tools.tool_manager import ToolManager
@ -64,7 +64,6 @@ from core.variables import (
ObjectSegment,
StringSegment,
)
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams, ToolCall, ToolResult, ToolResultStatus
from core.workflow.entities.tool_entities import ToolCallResult
@ -174,19 +173,6 @@ class LLMNode(Node[LLMNodeData]):
def version(cls) -> str:
return "1"
# FIXME(Mairuis): should read sandbox from workflow run context...
def _get_sandbox(self) -> VirtualEnvironment | None:
workflow_execution_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id
if not workflow_execution_id:
return None
sandbox_by_workflow_run_id = SandboxManager.get(workflow_execution_id)
if sandbox_by_workflow_run_id is not None:
return sandbox_by_workflow_run_id
sandbox_by_draft_id = SandboxManager.get(SandboxBuilder.draft_id(self.user_id))
if sandbox_by_draft_id is not None:
return sandbox_by_draft_id
return None
def _run(self) -> Generator:
node_inputs: dict[str, Any] = {}
process_data: dict[str, Any] = {}
@ -301,8 +287,7 @@ class LLMNode(Node[LLMNodeData]):
generation_data: LLMGenerationData | None = None
structured_output: LLMStructuredOutput | None = None
# FIXME(Mairuis): should read sandbox from workflow run context...
sandbox = self._get_sandbox()
sandbox = self.graph_runtime_state.sandbox
if sandbox:
generator = self._invoke_llm_with_sandbox(
sandbox=sandbox,
@ -1839,7 +1824,7 @@ class LLMNode(Node[LLMNodeData]):
def _invoke_llm_with_sandbox(
self,
sandbox: VirtualEnvironment,
sandbox: Sandbox,
model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage],
stop: Sequence[str] | None,
@ -1849,23 +1834,14 @@ class LLMNode(Node[LLMNodeData]):
result: LLMGenerationData | None = None
with SandboxBashSession(
sandbox=sandbox,
tenant_id=self.tenant_id,
user_id=self.user_id,
node_id=self.id,
app_id=self.app_id,
# FIXME(Mairuis): should read from workflow run context...
assets_id=getattr(self, "assets_id", ""),
allow_tools=allow_tools,
) as sandbox_session:
with SandboxBashSession(sandbox=sandbox, node_id=self.id, allow_tools=allow_tools) as session:
prompt_files = self._extract_prompt_files(variable_pool)
model_features = self._get_model_features(model_instance)
strategy = StrategyFactory.create_strategy(
model_features=model_features,
model_instance=model_instance,
tools=[sandbox_session.bash_tool],
tools=[session.bash_tool],
files=prompt_files,
max_iterations=self._node_data.max_iterations or 100,
agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING,

View File

@ -11,6 +11,7 @@ from typing import Any, Protocol
from pydantic.json import pydantic_encoder
from core.model_runtime.entities.llm_entities import LLMUsage
from core.sandbox.sandbox import Sandbox
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.runtime.variable_pool import VariablePool
@ -171,6 +172,8 @@ class GraphRuntimeState:
self._paused_nodes: set[str] = set()
self.stop_event: threading.Event = threading.Event()
self._sandbox: Sandbox | None = None
if graph is not None:
self.attach_graph(graph)
@ -294,6 +297,16 @@ class GraphRuntimeState:
raise ValueError("tokens must be non-negative")
self._total_tokens += tokens
# ------------------------------------------------------------------
# Sandbox context (workflow-scoped)
# ------------------------------------------------------------------
@property
def sandbox(self) -> Sandbox | None:
    """Return the workflow-scoped sandbox, or None if none has been attached."""
    return self._sandbox
def set_sandbox(self, sandbox: Sandbox) -> None:
    """Attach a sandbox to this runtime state so nodes can read it during the run."""
    self._sandbox = sandbox
# ------------------------------------------------------------------
# Serialization
# ------------------------------------------------------------------

View File

@ -78,6 +78,10 @@ class ReadOnlyGraphRuntimeState(Protocol):
"""Get a single output value (returns a copy)."""
...
@property
def sandbox(self) -> Any:
    """Expose the workflow-scoped sandbox of the underlying state (read-only)."""
    ...
def dumps(self) -> str:
"""Serialize the runtime state into a JSON snapshot (read-only)."""
...

View File

@ -82,6 +82,10 @@ class ReadOnlyGraphRuntimeStateWrapper:
def get_output(self, key: str, default: Any = None) -> Any:
    """Delegate the output lookup for *key* to the wrapped runtime state."""
    return self._state.get_output(key, default)
@property
def sandbox(self) -> Any:
    """Read-only passthrough to the wrapped state's sandbox."""
    return self._state.sandbox
def dumps(self) -> str:
"""Serialize the underlying runtime state for external persistence."""
return self._state.dumps()

View File

@ -8,6 +8,7 @@ from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.sandbox import Sandbox
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams
from core.workflow.errors import WorkflowNodeRunFailedError
@ -128,6 +129,7 @@ class WorkflowEntry:
user_inputs: Mapping[str, Any],
variable_pool: VariablePool,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
sandbox: Sandbox | None = None,
) -> tuple[Node, Generator[GraphNodeEventBase, None, None]]:
"""
Single step run workflow node
@ -156,6 +158,9 @@ class WorkflowEntry:
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
if sandbox is not None:
graph_runtime_state.set_sandbox(sandbox)
# init workflow run state
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,