refactor(dify_graph): introduce run_context and delegate child engine creation (#32964)

This commit is contained in:
99
2026-03-05 14:31:28 +08:00
committed by GitHub
parent 89a859ae32
commit 7432b58f82
78 changed files with 1281 additions and 733 deletions

View File

@ -3,7 +3,7 @@ from typing import Any
from pydantic import BaseModel, Field
from dify_graph.enums import InvokeFrom, UserFrom
DIFY_RUN_CONTEXT_KEY = "_dify"
class GraphInitParams(BaseModel):
@ -18,11 +18,7 @@ class GraphInitParams(BaseModel):
"""
# init params
tenant_id: str = Field(..., description="tenant / workspace id")
app_id: str = Field(..., description="app id")
workflow_id: str = Field(..., description="workflow id")
graph_config: Mapping[str, Any] = Field(..., description="graph config")
user_id: str = Field(..., description="user id")
user_from: UserFrom = Field(..., description="user from, account or end-user")
invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger")
run_context: Mapping[str, Any] = Field(..., description="runtime context")
call_depth: int = Field(..., description="call depth")

View File

@ -33,39 +33,6 @@ class SystemVariableKey(StrEnum):
INVOKE_FROM = "invoke_from"
class UserFrom(StrEnum):
ACCOUNT = "account"
END_USER = "end-user"
class InvokeFrom(StrEnum):
SERVICE_API = "service-api"
WEB_APP = "web-app"
TRIGGER = "trigger"
EXPLORE = "explore"
DEBUGGER = "debugger"
PUBLISHED_PIPELINE = "published"
VALIDATION = "validation"
@classmethod
def value_of(cls, value: str) -> "InvokeFrom":
return cls(value)
def to_source(self) -> str:
"""Get source of invoke from.
:return: source
"""
source_mapping = {
InvokeFrom.WEB_APP: "web_app",
InvokeFrom.DEBUGGER: "dev",
InvokeFrom.EXPLORE: "explore_app",
InvokeFrom.TRIGGER: "trigger",
InvokeFrom.SERVICE_API: "api",
}
return source_mapping.get(self, "dev")
class NodeType(StrEnum):
START = "start"
END = "end"

View File

@ -9,7 +9,7 @@ from __future__ import annotations
import logging
import queue
from collections.abc import Generator
from collections.abc import Generator, Mapping
from typing import TYPE_CHECKING, cast, final
from dify_graph.context import capture_current_context
@ -27,6 +27,7 @@ from dify_graph.graph_events import (
GraphRunSucceededEvent,
)
from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper
from dify_graph.runtime.graph_runtime_state import ChildGraphEngineBuilderProtocol
if TYPE_CHECKING: # pragma: no cover - used only for static analysis
from dify_graph.runtime.graph_runtime_state import GraphProtocol
@ -49,6 +50,7 @@ from .protocols.command_channel import CommandChannel
from .worker_management import WorkerPool
if TYPE_CHECKING:
from dify_graph.entities import GraphInitParams
from dify_graph.graph_engine.domain.graph_execution import GraphExecution
from dify_graph.graph_engine.response_coordinator import ResponseStreamCoordinator
@ -74,6 +76,7 @@ class GraphEngine:
graph_runtime_state: GraphRuntimeState,
command_channel: CommandChannel,
config: GraphEngineConfig = _DEFAULT_CONFIG,
child_engine_builder: ChildGraphEngineBuilderProtocol | None = None,
) -> None:
"""Initialize the graph engine with all subsystems and dependencies."""
@ -83,6 +86,9 @@ class GraphEngine:
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
self._command_channel = command_channel
self._config = config
self._child_engine_builder = child_engine_builder
if child_engine_builder is not None:
self._graph_runtime_state.bind_child_engine_builder(child_engine_builder)
# Graph execution tracks the overall execution state
self._graph_execution = cast("GraphExecution", self._graph_runtime_state.graph_execution)
@ -214,6 +220,25 @@ class GraphEngine:
self._bind_layer_context(layer)
return self
def create_child_engine(
    self,
    *,
    workflow_id: str,
    graph_init_params: GraphInitParams,
    graph_runtime_state: GraphRuntimeState,
    graph_config: dict[str, object] | Mapping[str, object],
    root_node_id: str,
    layers: list[GraphEngineLayer] | tuple[GraphEngineLayer, ...] = (),
) -> GraphEngine:
    """Build a sub-graph engine by delegating to the shared runtime state.

    The runtime state owns the child-engine builder, so this method is a
    thin pass-through that forwards every argument unchanged.
    """
    runtime_state = self._graph_runtime_state
    return runtime_state.create_child_engine(
        workflow_id=workflow_id,
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
        graph_config=graph_config,
        root_node_id=root_node_id,
        layers=layers,
    )
def run(self) -> Generator[GraphEngineEvent, None, None]:
"""
Execute the graph using the modular architecture.

View File

@ -80,9 +80,11 @@ class AgentNode(Node[AgentNodeData]):
def _run(self) -> Generator[NodeEventBase, None, None]:
from core.plugin.impl.exc import PluginDaemonClientSideError
dify_ctx = self.require_dify_context()
try:
strategy = get_plugin_agent_strategy(
tenant_id=self.tenant_id,
tenant_id=dify_ctx.tenant_id,
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
agent_strategy_name=self.node_data.agent_strategy_name,
)
@ -120,8 +122,8 @@ class AgentNode(Node[AgentNodeData]):
try:
message_stream = strategy.invoke(
params=parameters,
user_id=self.user_id,
app_id=self.app_id,
user_id=dify_ctx.user_id,
app_id=dify_ctx.app_id,
conversation_id=conversation_id.text if conversation_id else None,
credentials=credentials,
)
@ -144,8 +146,8 @@ class AgentNode(Node[AgentNodeData]):
"agent_strategy": self.node_data.agent_strategy_name,
},
parameters_for_log=parameters_for_log,
user_id=self.user_id,
tenant_id=self.tenant_id,
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
node_type=self.node_type,
node_id=self._node_id,
node_execution_id=self.id,
@ -283,8 +285,13 @@ class AgentNode(Node[AgentNodeData]):
runtime_variable_pool: VariablePool | None = None
if node_data.version != "1" or node_data.tool_node_version is not None:
runtime_variable_pool = variable_pool
dify_ctx = self.require_dify_context()
tool_runtime = ToolManager.get_agent_tool_runtime(
self.tenant_id, self.app_id, entity, self.invoke_from, runtime_variable_pool
dify_ctx.tenant_id,
dify_ctx.app_id,
entity,
dify_ctx.invoke_from,
runtime_variable_pool,
)
if tool_runtime.entity.description:
tool_runtime.entity.description.llm = (
@ -396,7 +403,8 @@ class AgentNode(Node[AgentNodeData]):
from core.plugin.impl.plugin import PluginInstaller
manager = PluginInstaller()
plugins = manager.list_plugins(self.tenant_id)
dify_ctx = self.require_dify_context()
plugins = manager.list_plugins(dify_ctx.tenant_id)
try:
current_plugin = next(
plugin
@ -417,8 +425,11 @@ class AgentNode(Node[AgentNodeData]):
return None
conversation_id = conversation_id_variable.value
dify_ctx = self.require_dify_context()
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
stmt = select(Conversation).where(
Conversation.app_id == dify_ctx.app_id, Conversation.id == conversation_id
)
conversation = session.scalar(stmt)
if not conversation:
@ -429,9 +440,10 @@ class AgentNode(Node[AgentNodeData]):
return memory
def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
dify_ctx = self.require_dify_context()
provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id=self.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM
tenant_id=dify_ctx.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM
)
model_name = value.get("model", "")
model_credentials = provider_model_bundle.configuration.get_current_credentials(
@ -440,7 +452,7 @@ class AgentNode(Node[AgentNodeData]):
provider_name = provider_model_bundle.configuration.provider.provider
model_type_instance = provider_model_bundle.model_type_instance
model_instance = ModelManager().get_model_instance(
tenant_id=self.tenant_id,
tenant_id=dify_ctx.tenant_id,
provider=provider_name,
model_type=ModelType(value.get("model_type", "")),
model=model_name,

View File

@ -8,10 +8,11 @@ from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
from functools import singledispatchmethod
from types import MappingProxyType
from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin
from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, get_args, get_origin
from uuid import uuid4
from dify_graph.entities import AgentNodeStrategyInit, GraphInitParams
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
from dify_graph.enums import (
ErrorStrategy,
NodeExecutionType,
@ -64,10 +65,28 @@ from libs.datetime_utils import naive_utc_now
from .entities import BaseNodeData, RetryConfig
NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData)
_MISSING_RUN_CONTEXT_VALUE = object()
logger = logging.getLogger(__name__)
class DifyRunContextProtocol(Protocol):
    """Structural type for the dify run context resolved from run_context."""

    # Identity of the workspace, app, and acting user for this run.
    tenant_id: str
    app_id: str
    user_id: str
    # Kept as Any: callers may supply enum members or raw strings here.
    user_from: Any
    invoke_from: Any
class _MappingDifyRunContext:
def __init__(self, mapping: Mapping[str, Any]) -> None:
self.tenant_id = str(mapping["tenant_id"])
self.app_id = str(mapping["app_id"])
self.user_id = str(mapping["user_id"])
self.user_from = mapping["user_from"]
self.invoke_from = mapping["invoke_from"]
class Node(Generic[NodeDataT]):
"""BaseNode serves as the foundational class for all node implementations.
@ -227,14 +246,10 @@ class Node(Generic[NodeDataT]):
graph_runtime_state: GraphRuntimeState,
) -> None:
self._graph_init_params = graph_init_params
self._run_context = MappingProxyType(dict(graph_init_params.run_context))
self.id = id
self.tenant_id = graph_init_params.tenant_id
self.app_id = graph_init_params.app_id
self.workflow_id = graph_init_params.workflow_id
self.graph_config = graph_init_params.graph_config
self.user_id = graph_init_params.user_id
self.user_from = graph_init_params.user_from
self.invoke_from = graph_init_params.invoke_from
self.workflow_call_depth = graph_init_params.call_depth
self.graph_runtime_state = graph_runtime_state
self.state: NodeState = NodeState.UNKNOWN # node execution state
@ -263,6 +278,38 @@ class Node(Generic[NodeDataT]):
def graph_init_params(self) -> GraphInitParams:
return self._graph_init_params
@property
def run_context(self) -> Mapping[str, Any]:
    """Read-only mapping of per-run context captured at node construction."""
    return self._run_context
def get_run_context_value(self, key: str, default: Any = None) -> Any:
    """Look up *key* in the run context, falling back to *default* when absent."""
    context = self._run_context
    return context.get(key, default)
def require_run_context_value(self, key: str) -> Any:
    """Return the run-context entry for *key*, raising ValueError when missing."""
    # A module-level sentinel distinguishes "absent" from a stored None.
    sentinel = _MISSING_RUN_CONTEXT_VALUE
    value = self.get_run_context_value(key, sentinel)
    if value is sentinel:
        raise ValueError(f"run_context missing required key: {key}")
    return value
def require_dify_context(self) -> DifyRunContextProtocol:
    """Resolve and validate the dify context stored under DIFY_RUN_CONTEXT_KEY.

    Accepts either a mapping (wrapped in an attribute adapter) or an object
    already exposing the required attributes.

    Raises:
        ValueError: when the key is missing/None or a mapping lacks required keys.
        TypeError: when a non-mapping object lacks a required attribute.
    """
    ctx = self.require_run_context_value(DIFY_RUN_CONTEXT_KEY)
    if ctx is None:
        raise ValueError(f"run_context missing required key: {DIFY_RUN_CONTEXT_KEY}")
    required = ("tenant_id", "app_id", "user_id", "user_from", "invoke_from")
    if isinstance(ctx, Mapping):
        absent = [name for name in required if name not in ctx]
        if absent:
            raise ValueError(f"dify context missing required keys: {', '.join(absent)}")
        return _MappingDifyRunContext(ctx)
    # Non-mapping objects must already satisfy the protocol structurally.
    for name in required:
        if not hasattr(ctx, name):
            raise TypeError(f"invalid dify context object, missing attribute: {name}")
    return cast(DifyRunContextProtocol, ctx)
@property
def execution_id(self) -> str:
return self._node_execution_id

View File

@ -52,6 +52,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
Run the datasource node
"""
dify_ctx = self.require_dify_context()
node_data = self.node_data
variable_pool = self.graph_runtime_state.variable_pool
datasource_type_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])
@ -75,7 +76,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
datasource_info["icon"] = self.datasource_manager.get_icon_url(
provider_id=provider_id,
datasource_name=node_data.datasource_name or "",
tenant_id=self.tenant_id,
tenant_id=dify_ctx.tenant_id,
datasource_type=datasource_type.value,
)
@ -104,11 +105,11 @@ class DatasourceNode(Node[DatasourceNodeData]):
yield from self.datasource_manager.stream_node_events(
node_id=self._node_id,
user_id=self.user_id,
user_id=dify_ctx.user_id,
datasource_name=node_data.datasource_name or "",
datasource_type=datasource_type.value,
provider_id=provider_id,
tenant_id=self.tenant_id,
tenant_id=dify_ctx.tenant_id,
provider=node_data.provider_name,
plugin_id=node_data.plugin_id,
credential_id=credential_id,
@ -136,7 +137,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
raise DatasourceNodeError("File is not exist")
file_info = self.datasource_manager.get_upload_file_by_id(
file_id=related_id, tenant_id=self.tenant_id
file_id=related_id, tenant_id=dify_ctx.tenant_id
)
variable_pool.add([self._node_id, "file"], file_info)
# variable_pool.add([self.node_id, "file"], file_info.to_dict())

View File

@ -212,6 +212,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
"""
Extract files from response by checking both Content-Type header and URL
"""
dify_ctx = self.require_dify_context()
files: list[File] = []
is_file = response.is_file
content_type = response.content_type
@ -236,8 +237,8 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
tool_file_manager = self._tool_file_manager_factory()
tool_file = tool_file_manager.create_file_by_raw(
user_id=self.user_id,
tenant_id=self.tenant_id,
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
conversation_id=None,
file_binary=content,
mimetype=mime_type,
@ -249,7 +250,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
tenant_id=dify_ctx.tenant_id,
)
files.append(file)

View File

@ -4,7 +4,7 @@ from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.enums import InvokeFrom, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from dify_graph.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from dify_graph.node_events import (
HumanInputFormFilledEvent,
HumanInputFormTimeoutEvent,
@ -31,6 +31,8 @@ if TYPE_CHECKING:
_SELECTED_BRANCH_KEY = "selected_branch"
_INVOKE_FROM_DEBUGGER = "debugger"
_INVOKE_FROM_EXPLORE = "explore"
logger = logging.getLogger(__name__)
@ -155,30 +157,39 @@ class HumanInputNode(Node[HumanInputNodeData]):
return resolved_defaults
def _should_require_console_recipient(self) -> bool:
    """Whether a console recipient is mandatory for this invocation.

    Always true for debugger runs; for explore runs only when the webapp
    delivery is enabled on the node.
    """
    source = self._invoke_from_value()
    if source == _INVOKE_FROM_DEBUGGER:
        return True
    return source == _INVOKE_FROM_EXPLORE and self._node_data.is_webapp_enabled()
def _display_in_ui(self) -> bool:
    """Debugger runs always display the form; otherwise follow webapp enablement."""
    return (
        self._invoke_from_value() == _INVOKE_FROM_DEBUGGER
        or self._node_data.is_webapp_enabled()
    )
def _effective_delivery_methods(self) -> Sequence[DeliveryChannelConfig]:
    """Return enabled delivery methods adjusted for the invocation source.

    Debugger/explore runs drop webapp delivery; debugger runs additionally
    redirect email delivery to the invoking user.
    """
    user_id = self.require_dify_context().user_id
    source = self._invoke_from_value()
    console_like = source in {_INVOKE_FROM_DEBUGGER, _INVOKE_FROM_EXPLORE}
    effective: list[DeliveryChannelConfig] = []
    for method in self._node_data.delivery_methods:
        if not method.enabled:
            continue
        if console_like and method.type == DeliveryMethodType.WEBAPP:
            continue
        effective.append(
            apply_debug_email_recipient(
                method,
                enabled=source == _INVOKE_FROM_DEBUGGER,
                user_id=user_id,
            )
        )
    return effective
def _invoke_from_value(self) -> str:
    """Normalize the context's invoke_from into a plain string."""
    raw = self.require_dify_context().invoke_from
    # Enum members expose .value; plain strings pass through unchanged.
    return raw if isinstance(raw, str) else str(getattr(raw, "value", raw))
def _human_input_required_event(self, form_entity: HumanInputFormEntity) -> HumanInputRequired:
node_data = self._node_data
resolved_default_values = self.resolve_default_values()
@ -212,10 +223,11 @@ class HumanInputNode(Node[HumanInputNodeData]):
"""
repo = self._form_repository
form = repo.get_form(self._workflow_execution_id, self.id)
dify_ctx = self.require_dify_context()
if form is None:
display_in_ui = self._display_in_ui()
params = FormCreateParams(
app_id=self.app_id,
app_id=dify_ctx.app_id,
workflow_execution_id=self._workflow_execution_id,
node_id=self.id,
form_config=self._node_data,
@ -225,7 +237,9 @@ class HumanInputNode(Node[HumanInputNodeData]):
resolved_default_values=self.resolve_default_values(),
console_recipient_required=self._should_require_console_recipient(),
console_creator_account_id=(
self.user_id if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE} else None
dify_ctx.user_id
if self._invoke_from_value() in {_INVOKE_FROM_DEBUGGER, _INVOKE_FROM_EXPLORE}
else None
),
backstage_recipient_required=True,
)

View File

@ -587,24 +587,14 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
return
def _create_graph_engine(self, index: int, item: object):
# Import dependencies
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.workflow.node_factory import DifyNodeFactory
from dify_graph.entities import GraphInitParams
from dify_graph.graph import Graph
from dify_graph.graph_engine import GraphEngine, GraphEngineConfig
from dify_graph.graph_engine.command_channels import InMemoryChannel
from dify_graph.runtime import GraphRuntimeState
from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState
# Create GraphInitParams from node attributes
# Create GraphInitParams for child graph execution.
graph_init_params = GraphInitParams(
tenant_id=self.tenant_id,
app_id=self.app_id,
workflow_id=self.workflow_id,
graph_config=self.graph_config,
user_id=self.user_id,
user_from=self.user_from,
invoke_from=self.invoke_from,
run_context=self.run_context,
call_depth=self.workflow_call_depth,
)
# Create a deep copy of the variable pool for each iteration
@ -621,28 +611,17 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
total_tokens=0,
node_run_steps=0,
)
root_node_id = self.node_data.start_node_id
if root_node_id is None:
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found")
# Create a new node factory with the new GraphRuntimeState
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state_copy
)
# Initialize the iteration graph with the new node factory
iteration_graph = Graph.init(
graph_config=self.graph_config, node_factory=node_factory, root_node_id=self.node_data.start_node_id
)
if not iteration_graph:
raise IterationGraphNotFoundError("iteration graph not found")
# Create a new GraphEngine for this iteration
graph_engine = GraphEngine(
workflow_id=self.workflow_id,
graph=iteration_graph,
graph_runtime_state=graph_runtime_state_copy,
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
config=GraphEngineConfig(),
)
graph_engine.layer(LLMQuotaLayer())
return graph_engine
try:
return self.graph_runtime_state.create_child_engine(
workflow_id=self.workflow_id,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state_copy,
graph_config=self.graph_config,
root_node_id=root_node_id,
)
except ChildGraphNotFoundError as exc:
raise IterationGraphNotFoundError("iteration graph not found") from exc

View File

@ -3,7 +3,7 @@ from collections.abc import Mapping
from typing import TYPE_CHECKING, Any
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from dify_graph.enums import InvokeFrom, NodeExecutionType, NodeType, SystemVariableKey
from dify_graph.enums import NodeExecutionType, NodeType, SystemVariableKey
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.template import Template
@ -20,6 +20,7 @@ if TYPE_CHECKING:
from dify_graph.runtime import GraphRuntimeState
logger = logging.getLogger(__name__)
_INVOKE_FROM_DEBUGGER = "debugger"
class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
@ -58,7 +59,8 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
if not variable:
raise KnowledgeIndexNodeError("Index chunk variable is required.")
invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
is_preview = invoke_from.value == InvokeFrom.DEBUGGER if invoke_from else False
invoke_from_value = str(invoke_from.value) if invoke_from else None
is_preview = invoke_from_value == _INVOKE_FROM_DEBUGGER
chunks = variable.value
variables = {"chunks": chunks}

View File

@ -66,9 +66,10 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
self._rag_retrieval = rag_retrieval
if llm_file_saver is None:
dify_ctx = self.require_dify_context()
llm_file_saver = FileSaverImpl(
user_id=graph_init_params.user_id,
tenant_id=graph_init_params.tenant_id,
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
)
self._llm_file_saver = llm_file_saver
@ -160,6 +161,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
def _fetch_dataset_retriever(
self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]
) -> tuple[list[Source], LLMUsage]:
dify_ctx = self.require_dify_context()
dataset_ids = node_data.dataset_ids
query = variables.get("query")
attachments = variables.get("attachments")
@ -176,10 +178,10 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
model = node_data.single_retrieval_config.model
retrieval_resource_list = self._rag_retrieval.knowledge_retrieval(
request=KnowledgeRetrievalRequest(
tenant_id=self.tenant_id,
user_id=self.user_id,
app_id=self.app_id,
user_from=self.user_from.value,
tenant_id=dify_ctx.tenant_id,
user_id=dify_ctx.user_id,
app_id=dify_ctx.app_id,
user_from=dify_ctx.user_from.value,
dataset_ids=dataset_ids,
retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value,
completion_params=model.completion_params,
@ -229,10 +231,10 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
retrieval_resource_list = self._rag_retrieval.knowledge_retrieval(
request=KnowledgeRetrievalRequest(
app_id=self.app_id,
tenant_id=self.tenant_id,
user_id=self.user_id,
user_from=self.user_from.value,
app_id=dify_ctx.app_id,
tenant_id=dify_ctx.tenant_id,
user_id=dify_ctx.user_id,
user_from=dify_ctx.user_from.value,
dataset_ids=dataset_ids,
query=query,
retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value,

View File

@ -145,9 +145,10 @@ class LLMNode(Node[LLMNodeData]):
self._memory = memory
if llm_file_saver is None:
dify_ctx = self.require_dify_context()
llm_file_saver = FileSaverImpl(
user_id=graph_init_params.user_id,
tenant_id=graph_init_params.tenant_id,
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
)
self._llm_file_saver = llm_file_saver
@ -242,7 +243,7 @@ class LLMNode(Node[LLMNodeData]):
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
user_id=self.require_dify_context().user_id,
structured_output_enabled=self.node_data.structured_output_enabled,
structured_output=self.node_data.structured_output,
file_saver=self._llm_file_saver,
@ -702,7 +703,7 @@ class LLMNode(Node[LLMNodeData]):
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
tenant_id=self.tenant_id,
tenant_id=self.require_dify_context().tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,

View File

@ -412,24 +412,14 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
return build_segment_with_type(var_type, value)
def _create_graph_engine(self, start_at: datetime, root_node_id: str):
# Import dependencies
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.workflow.node_factory import DifyNodeFactory
from dify_graph.entities import GraphInitParams
from dify_graph.graph import Graph
from dify_graph.graph_engine import GraphEngine, GraphEngineConfig
from dify_graph.graph_engine.command_channels import InMemoryChannel
from dify_graph.runtime import GraphRuntimeState
# Create GraphInitParams from node attributes
# Create GraphInitParams for child graph execution.
graph_init_params = GraphInitParams(
tenant_id=self.tenant_id,
app_id=self.app_id,
workflow_id=self.workflow_id,
graph_config=self.graph_config,
user_id=self.user_id,
user_from=self.user_from,
invoke_from=self.invoke_from,
run_context=self.run_context,
call_depth=self.workflow_call_depth,
)
@ -439,22 +429,10 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
start_at=start_at.timestamp(),
)
# Create a new node factory with the new GraphRuntimeState
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state_copy
)
# Initialize the loop graph with the new node factory
loop_graph = Graph.init(graph_config=self.graph_config, node_factory=node_factory, root_node_id=root_node_id)
# Create a new GraphEngine for this iteration
graph_engine = GraphEngine(
return self.graph_runtime_state.create_child_engine(
workflow_id=self.workflow_id,
graph=loop_graph,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state_copy,
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
config=GraphEngineConfig(),
graph_config=self.graph_config,
root_node_id=root_node_id,
)
graph_engine.layer(LLMQuotaLayer())
return graph_engine

View File

@ -297,7 +297,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
tools=tools,
stop=list(stop),
stream=False,
user=self.user_id,
user=self.require_dify_context().user_id,
)
# handle invoke result

View File

@ -86,9 +86,10 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
self._memory = memory
if llm_file_saver is None:
dify_ctx = self.require_dify_context()
llm_file_saver = FileSaverImpl(
user_id=graph_init_params.user_id,
tenant_id=graph_init_params.tenant_id,
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
)
self._llm_file_saver = llm_file_saver
@ -160,7 +161,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
user_id=self.require_dify_context().user_id,
structured_output_enabled=False,
structured_output=None,
file_saver=self._llm_file_saver,

View File

@ -56,6 +56,8 @@ class ToolNode(Node[ToolNodeData]):
"""
from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
dify_ctx = self.require_dify_context()
# fetch tool icon
tool_info = {
"provider_type": self.node_data.provider_type.value,
@ -75,7 +77,12 @@ class ToolNode(Node[ToolNodeData]):
if self.node_data.version != "1" or self.node_data.tool_node_version is not None:
variable_pool = self.graph_runtime_state.variable_pool
tool_runtime = ToolManager.get_workflow_tool_runtime(
self.tenant_id, self.app_id, self._node_id, self.node_data, self.invoke_from, variable_pool
dify_ctx.tenant_id,
dify_ctx.app_id,
self._node_id,
self.node_data,
dify_ctx.invoke_from,
variable_pool,
)
except ToolNodeError as e:
yield StreamCompletedEvent(
@ -109,10 +116,10 @@ class ToolNode(Node[ToolNodeData]):
message_stream = ToolEngine.generic_invoke(
tool=tool_runtime,
tool_parameters=parameters,
user_id=self.user_id,
user_id=dify_ctx.user_id,
workflow_tool_callback=DifyWorkflowCallbackHandler(),
workflow_call_depth=self.workflow_call_depth,
app_id=self.app_id,
app_id=dify_ctx.app_id,
conversation_id=conversation_id.text if conversation_id else None,
)
except ToolNodeError as e:
@ -133,8 +140,8 @@ class ToolNode(Node[ToolNodeData]):
messages=message_stream,
tool_info=tool_info,
parameters_for_log=parameters_for_log,
user_id=self.user_id,
tenant_id=self.tenant_id,
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
node_id=self._node_id,
tool_runtime=tool_runtime,
)

View File

@ -69,6 +69,7 @@ class TriggerWebhookNode(Node[WebhookData]):
)
def generate_file_var(self, param_name: str, file: dict):
dify_ctx = self.require_dify_context()
related_id = file.get("related_id")
transfer_method_value = file.get("transfer_method")
if transfer_method_value:
@ -84,7 +85,7 @@ class TriggerWebhookNode(Node[WebhookData]):
try:
file_obj = file_factory.build_from_mapping(
mapping=file,
tenant_id=self.tenant_id,
tenant_id=dify_ctx.tenant_id,
)
file_segment = build_segment_with_type(SegmentType.FILE, file_obj)
return FileVariable(name=param_name, value=file_segment.value, selector=[self.id, param_name])

View File

@ -1,9 +1,17 @@
from .graph_runtime_state import GraphRuntimeState
from .graph_runtime_state import (
ChildEngineBuilderNotConfiguredError,
ChildEngineError,
ChildGraphNotFoundError,
GraphRuntimeState,
)
from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState, ReadOnlyVariablePool
from .read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper, ReadOnlyVariablePoolWrapper
from .variable_pool import VariablePool, VariableValue
__all__ = [
"ChildEngineBuilderNotConfiguredError",
"ChildEngineError",
"ChildGraphNotFoundError",
"GraphRuntimeState",
"ReadOnlyGraphRuntimeState",
"ReadOnlyGraphRuntimeStateWrapper",

View File

@ -15,6 +15,7 @@ from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.runtime.variable_pool import VariablePool
if TYPE_CHECKING:
from dify_graph.entities import GraphInitParams
from dify_graph.entities.pause_reason import PauseReason
@ -135,6 +136,31 @@ class GraphProtocol(Protocol):
def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ...
class ChildGraphEngineBuilderProtocol(Protocol):
    """Structural interface for factories that materialize child graph engines."""

    def build_child_engine(
        self,
        *,
        workflow_id: str,
        graph_init_params: GraphInitParams,
        graph_runtime_state: GraphRuntimeState,
        graph_config: Mapping[str, Any],
        root_node_id: str,
        layers: Sequence[object] = (),
    ) -> Any:
        """Create an engine rooted at *root_node_id* for the given sub-graph."""
        ...
class ChildEngineError(ValueError):
    """Root of the child-engine failure hierarchy; subclasses refine the cause."""
class ChildEngineBuilderNotConfiguredError(ChildEngineError):
    """A child engine was requested before any builder was bound."""
class ChildGraphNotFoundError(ChildEngineError):
    """The requested child graph entry point could not be resolved."""
class _GraphStateSnapshot(BaseModel):
"""Serializable graph state snapshot for node/edge states."""
@ -209,6 +235,7 @@ class GraphRuntimeState:
self._pending_graph_execution_workflow_id: str | None = None
self._paused_nodes: set[str] = set()
self._deferred_nodes: set[str] = set()
self._child_engine_builder: ChildGraphEngineBuilderProtocol | None = None
# Node and edges states needed to be restored into
# graph object.
@ -250,6 +277,31 @@ class GraphRuntimeState:
if self._graph is not None:
_ = self.response_coordinator
def bind_child_engine_builder(self, builder: ChildGraphEngineBuilderProtocol) -> None:
    """Register the factory that create_child_engine() will delegate to."""
    self._child_engine_builder = builder
def create_child_engine(
    self,
    *,
    workflow_id: str,
    graph_init_params: GraphInitParams,
    graph_runtime_state: GraphRuntimeState,
    graph_config: Mapping[str, Any],
    root_node_id: str,
    layers: Sequence[object] = (),
) -> Any:
    """Construct a child graph engine through the bound builder.

    Raises:
        ChildEngineBuilderNotConfiguredError: when no builder was bound via
            bind_child_engine_builder() before this call.
    """
    builder = self._child_engine_builder
    if builder is None:
        raise ChildEngineBuilderNotConfiguredError("Child engine builder is not configured.")
    return builder.build_child_engine(
        workflow_id=workflow_id,
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
        graph_config=graph_config,
        root_node_id=root_node_id,
        layers=layers,
    )
# ------------------------------------------------------------------
# Primary collaborators
# ------------------------------------------------------------------