mirror of
https://github.com/langgenius/dify.git
synced 2026-05-02 08:28:03 +08:00
refactor: consolidate sandbox management and initialization
- Moved sandbox-related classes and functions into a dedicated module for better organization. - Updated the sandbox initialization process to streamline asset management and environment setup. - Removed deprecated constants and refactored related code to utilize new sandbox entities. - Enhanced the workflow context to support sandbox integration, allowing for improved state management during execution. - Adjusted various components to utilize the new sandbox structure, ensuring compatibility across the application.
This commit is contained in:
@ -17,7 +17,6 @@ from core.workflow.context.execution_context import (
|
||||
register_context_capturer,
|
||||
reset_context_provider,
|
||||
)
|
||||
from core.workflow.context.models import SandboxContext
|
||||
|
||||
__all__ = [
|
||||
"AppContext",
|
||||
@ -25,7 +24,6 @@ __all__ = [
|
||||
"ExecutionContext",
|
||||
"IExecutionContext",
|
||||
"NullAppContext",
|
||||
"SandboxContext",
|
||||
"capture_current_context",
|
||||
"read_context",
|
||||
"register_context",
|
||||
|
||||
@ -1,13 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import AnyHttpUrl, BaseModel
|
||||
|
||||
|
||||
class SandboxContext(BaseModel):
|
||||
"""Typed context for sandbox integration. All fields optional by design."""
|
||||
|
||||
sandbox_url: AnyHttpUrl | None = None
|
||||
sandbox_token: str | None = None # optional, if later needed for auth
|
||||
|
||||
|
||||
__all__ = ["SandboxContext"]
|
||||
__all__: list[str] = []
|
||||
|
||||
@ -2,11 +2,9 @@ import logging
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.sandbox import SandboxManager, sandbox_debug
|
||||
from core.sandbox.vm import SandboxBuilder
|
||||
from core.sandbox import sandbox_debug
|
||||
from core.virtual_environment.__base.command_future import CommandCancelledError, CommandTimeoutError
|
||||
from core.virtual_environment.__base.helpers import submit_command, with_connection
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base import variable_template_parser
|
||||
@ -24,19 +22,6 @@ COMMAND_NODE_TIMEOUT_SECONDS = 60
|
||||
class CommandNode(Node[CommandNodeData]):
|
||||
node_type = NodeType.COMMAND
|
||||
|
||||
# FIXME(Mairuis): should read sandbox from workflow run context...
|
||||
def _get_sandbox(self) -> VirtualEnvironment | None:
|
||||
workflow_execution_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id
|
||||
if not workflow_execution_id:
|
||||
return None
|
||||
sandbox_by_workflow_run_id = SandboxManager.get(workflow_execution_id)
|
||||
if sandbox_by_workflow_run_id is not None:
|
||||
return sandbox_by_workflow_run_id
|
||||
sandbox_by_draft_id = SandboxManager.get(SandboxBuilder.draft_id(self.user_id))
|
||||
if sandbox_by_draft_id is not None:
|
||||
return sandbox_by_draft_id
|
||||
return None
|
||||
|
||||
def _render_template(self, template: str) -> str:
|
||||
parser = VariableTemplateParser(template=template)
|
||||
selectors = parser.extract_variable_selectors()
|
||||
@ -65,7 +50,7 @@ class CommandNode(Node[CommandNodeData]):
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
sandbox = self._get_sandbox()
|
||||
sandbox = self.graph_runtime_state.sandbox
|
||||
if sandbox is None:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
@ -88,12 +73,12 @@ class CommandNode(Node[CommandNodeData]):
|
||||
timeout = COMMAND_NODE_TIMEOUT_SECONDS if COMMAND_NODE_TIMEOUT_SECONDS > 0 else None
|
||||
|
||||
try:
|
||||
with with_connection(sandbox) as conn:
|
||||
with with_connection(sandbox.vm) as conn:
|
||||
command = ["bash", "-c", raw_command]
|
||||
|
||||
sandbox_debug("command_node", "command", command)
|
||||
|
||||
future = submit_command(sandbox, conn, command, cwd=working_directory)
|
||||
future = submit_command(sandbox.vm, conn, command, cwd=working_directory)
|
||||
result = future.result(timeout=timeout)
|
||||
|
||||
outputs: dict[str, Any] = {
|
||||
|
||||
@ -50,8 +50,8 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.sandbox import SandboxBashSession, SandboxManager
|
||||
from core.sandbox.vm import SandboxBuilder
|
||||
from core.sandbox import Sandbox
|
||||
from core.sandbox.bash.session import SandboxBashSession
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.signature import sign_upload_file
|
||||
from core.tools.tool_manager import ToolManager
|
||||
@ -64,7 +64,6 @@ from core.variables import (
|
||||
ObjectSegment,
|
||||
StringSegment,
|
||||
)
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.entities import GraphInitParams, ToolCall, ToolResult, ToolResultStatus
|
||||
from core.workflow.entities.tool_entities import ToolCallResult
|
||||
@ -174,19 +173,6 @@ class LLMNode(Node[LLMNodeData]):
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
# FIXME(Mairuis): should read sandbox from workflow run context...
|
||||
def _get_sandbox(self) -> VirtualEnvironment | None:
|
||||
workflow_execution_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id
|
||||
if not workflow_execution_id:
|
||||
return None
|
||||
sandbox_by_workflow_run_id = SandboxManager.get(workflow_execution_id)
|
||||
if sandbox_by_workflow_run_id is not None:
|
||||
return sandbox_by_workflow_run_id
|
||||
sandbox_by_draft_id = SandboxManager.get(SandboxBuilder.draft_id(self.user_id))
|
||||
if sandbox_by_draft_id is not None:
|
||||
return sandbox_by_draft_id
|
||||
return None
|
||||
|
||||
def _run(self) -> Generator:
|
||||
node_inputs: dict[str, Any] = {}
|
||||
process_data: dict[str, Any] = {}
|
||||
@ -301,8 +287,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
generation_data: LLMGenerationData | None = None
|
||||
structured_output: LLMStructuredOutput | None = None
|
||||
|
||||
# FIXME(Mairuis): should read sandbox from workflow run context...
|
||||
sandbox = self._get_sandbox()
|
||||
sandbox = self.graph_runtime_state.sandbox
|
||||
if sandbox:
|
||||
generator = self._invoke_llm_with_sandbox(
|
||||
sandbox=sandbox,
|
||||
@ -1839,7 +1824,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
|
||||
def _invoke_llm_with_sandbox(
|
||||
self,
|
||||
sandbox: VirtualEnvironment,
|
||||
sandbox: Sandbox,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
stop: Sequence[str] | None,
|
||||
@ -1849,23 +1834,14 @@ class LLMNode(Node[LLMNodeData]):
|
||||
|
||||
result: LLMGenerationData | None = None
|
||||
|
||||
with SandboxBashSession(
|
||||
sandbox=sandbox,
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
node_id=self.id,
|
||||
app_id=self.app_id,
|
||||
# FIXME(Mairuis): should read from workflow run context...
|
||||
assets_id=getattr(self, "assets_id", ""),
|
||||
allow_tools=allow_tools,
|
||||
) as sandbox_session:
|
||||
with SandboxBashSession(sandbox=sandbox, node_id=self.id, allow_tools=allow_tools) as session:
|
||||
prompt_files = self._extract_prompt_files(variable_pool)
|
||||
model_features = self._get_model_features(model_instance)
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=model_instance,
|
||||
tools=[sandbox_session.bash_tool],
|
||||
tools=[session.bash_tool],
|
||||
files=prompt_files,
|
||||
max_iterations=self._node_data.max_iterations or 100,
|
||||
agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING,
|
||||
|
||||
@ -11,6 +11,7 @@ from typing import Any, Protocol
|
||||
from pydantic.json import pydantic_encoder
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.sandbox.sandbox import Sandbox
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
from core.workflow.runtime.variable_pool import VariablePool
|
||||
|
||||
@ -171,6 +172,8 @@ class GraphRuntimeState:
|
||||
self._paused_nodes: set[str] = set()
|
||||
self.stop_event: threading.Event = threading.Event()
|
||||
|
||||
self._sandbox: Sandbox | None = None
|
||||
|
||||
if graph is not None:
|
||||
self.attach_graph(graph)
|
||||
|
||||
@ -294,6 +297,16 @@ class GraphRuntimeState:
|
||||
raise ValueError("tokens must be non-negative")
|
||||
self._total_tokens += tokens
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Sandbox context (workflow-scoped)
|
||||
# ------------------------------------------------------------------
|
||||
@property
|
||||
def sandbox(self) -> Sandbox | None:
|
||||
return self._sandbox
|
||||
|
||||
def set_sandbox(self, sandbox: Sandbox) -> None:
|
||||
self._sandbox = sandbox
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Serialization
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@ -78,6 +78,10 @@ class ReadOnlyGraphRuntimeState(Protocol):
|
||||
"""Get a single output value (returns a copy)."""
|
||||
...
|
||||
|
||||
@property
|
||||
def sandbox(self) -> Any:
|
||||
...
|
||||
|
||||
def dumps(self) -> str:
|
||||
"""Serialize the runtime state into a JSON snapshot (read-only)."""
|
||||
...
|
||||
|
||||
@ -82,6 +82,10 @@ class ReadOnlyGraphRuntimeStateWrapper:
|
||||
def get_output(self, key: str, default: Any = None) -> Any:
|
||||
return self._state.get_output(key, default)
|
||||
|
||||
@property
|
||||
def sandbox(self) -> Any:
|
||||
return self._state.sandbox
|
||||
|
||||
def dumps(self) -> str:
|
||||
"""Serialize the underlying runtime state for external persistence."""
|
||||
return self._state.dumps()
|
||||
|
||||
@ -8,6 +8,7 @@ from configs import dify_config
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.models import File
|
||||
from core.sandbox import Sandbox
|
||||
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
@ -128,6 +129,7 @@ class WorkflowEntry:
|
||||
user_inputs: Mapping[str, Any],
|
||||
variable_pool: VariablePool,
|
||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
sandbox: Sandbox | None = None,
|
||||
) -> tuple[Node, Generator[GraphNodeEventBase, None, None]]:
|
||||
"""
|
||||
Single step run workflow node
|
||||
@ -156,6 +158,9 @@ class WorkflowEntry:
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
if sandbox is not None:
|
||||
graph_runtime_state.set_sandbox(sandbox)
|
||||
|
||||
# init workflow run state
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
|
||||
Reference in New Issue
Block a user