mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 09:28:04 +08:00
refactor(dify_graph): introduce run_context and delegate child engine creation (#32964)
This commit is contained in:
@ -3,7 +3,7 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from dify_graph.enums import InvokeFrom, UserFrom
|
||||
DIFY_RUN_CONTEXT_KEY = "_dify"
|
||||
|
||||
|
||||
class GraphInitParams(BaseModel):
|
||||
@ -18,11 +18,7 @@ class GraphInitParams(BaseModel):
|
||||
"""
|
||||
|
||||
# init params
|
||||
tenant_id: str = Field(..., description="tenant / workspace id")
|
||||
app_id: str = Field(..., description="app id")
|
||||
workflow_id: str = Field(..., description="workflow id")
|
||||
graph_config: Mapping[str, Any] = Field(..., description="graph config")
|
||||
user_id: str = Field(..., description="user id")
|
||||
user_from: UserFrom = Field(..., description="user from, account or end-user")
|
||||
invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger")
|
||||
run_context: Mapping[str, Any] = Field(..., description="runtime context")
|
||||
call_depth: int = Field(..., description="call depth")
|
||||
|
||||
@ -33,39 +33,6 @@ class SystemVariableKey(StrEnum):
|
||||
INVOKE_FROM = "invoke_from"
|
||||
|
||||
|
||||
class UserFrom(StrEnum):
|
||||
ACCOUNT = "account"
|
||||
END_USER = "end-user"
|
||||
|
||||
|
||||
class InvokeFrom(StrEnum):
|
||||
SERVICE_API = "service-api"
|
||||
WEB_APP = "web-app"
|
||||
TRIGGER = "trigger"
|
||||
EXPLORE = "explore"
|
||||
DEBUGGER = "debugger"
|
||||
PUBLISHED_PIPELINE = "published"
|
||||
VALIDATION = "validation"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "InvokeFrom":
|
||||
return cls(value)
|
||||
|
||||
def to_source(self) -> str:
|
||||
"""Get source of invoke from.
|
||||
|
||||
:return: source
|
||||
"""
|
||||
source_mapping = {
|
||||
InvokeFrom.WEB_APP: "web_app",
|
||||
InvokeFrom.DEBUGGER: "dev",
|
||||
InvokeFrom.EXPLORE: "explore_app",
|
||||
InvokeFrom.TRIGGER: "trigger",
|
||||
InvokeFrom.SERVICE_API: "api",
|
||||
}
|
||||
return source_mapping.get(self, "dev")
|
||||
|
||||
|
||||
class NodeType(StrEnum):
|
||||
START = "start"
|
||||
END = "end"
|
||||
|
||||
@ -9,7 +9,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import queue
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import TYPE_CHECKING, cast, final
|
||||
|
||||
from dify_graph.context import capture_current_context
|
||||
@ -27,6 +27,7 @@ from dify_graph.graph_events import (
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper
|
||||
from dify_graph.runtime.graph_runtime_state import ChildGraphEngineBuilderProtocol
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover - used only for static analysis
|
||||
from dify_graph.runtime.graph_runtime_state import GraphProtocol
|
||||
@ -49,6 +50,7 @@ from .protocols.command_channel import CommandChannel
|
||||
from .worker_management import WorkerPool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.graph_engine.domain.graph_execution import GraphExecution
|
||||
from dify_graph.graph_engine.response_coordinator import ResponseStreamCoordinator
|
||||
|
||||
@ -74,6 +76,7 @@ class GraphEngine:
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
command_channel: CommandChannel,
|
||||
config: GraphEngineConfig = _DEFAULT_CONFIG,
|
||||
child_engine_builder: ChildGraphEngineBuilderProtocol | None = None,
|
||||
) -> None:
|
||||
"""Initialize the graph engine with all subsystems and dependencies."""
|
||||
|
||||
@ -83,6 +86,9 @@ class GraphEngine:
|
||||
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
|
||||
self._command_channel = command_channel
|
||||
self._config = config
|
||||
self._child_engine_builder = child_engine_builder
|
||||
if child_engine_builder is not None:
|
||||
self._graph_runtime_state.bind_child_engine_builder(child_engine_builder)
|
||||
|
||||
# Graph execution tracks the overall execution state
|
||||
self._graph_execution = cast("GraphExecution", self._graph_runtime_state.graph_execution)
|
||||
@ -214,6 +220,25 @@ class GraphEngine:
|
||||
self._bind_layer_context(layer)
|
||||
return self
|
||||
|
||||
def create_child_engine(
|
||||
self,
|
||||
*,
|
||||
workflow_id: str,
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
graph_config: dict[str, object] | Mapping[str, object],
|
||||
root_node_id: str,
|
||||
layers: list[GraphEngineLayer] | tuple[GraphEngineLayer, ...] = (),
|
||||
) -> GraphEngine:
|
||||
return self._graph_runtime_state.create_child_engine(
|
||||
workflow_id=workflow_id,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
graph_config=graph_config,
|
||||
root_node_id=root_node_id,
|
||||
layers=layers,
|
||||
)
|
||||
|
||||
def run(self) -> Generator[GraphEngineEvent, None, None]:
|
||||
"""
|
||||
Execute the graph using the modular architecture.
|
||||
|
||||
@ -80,9 +80,11 @@ class AgentNode(Node[AgentNodeData]):
|
||||
def _run(self) -> Generator[NodeEventBase, None, None]:
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
|
||||
dify_ctx = self.require_dify_context()
|
||||
|
||||
try:
|
||||
strategy = get_plugin_agent_strategy(
|
||||
tenant_id=self.tenant_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
|
||||
agent_strategy_name=self.node_data.agent_strategy_name,
|
||||
)
|
||||
@ -120,8 +122,8 @@ class AgentNode(Node[AgentNodeData]):
|
||||
try:
|
||||
message_stream = strategy.invoke(
|
||||
params=parameters,
|
||||
user_id=self.user_id,
|
||||
app_id=self.app_id,
|
||||
user_id=dify_ctx.user_id,
|
||||
app_id=dify_ctx.app_id,
|
||||
conversation_id=conversation_id.text if conversation_id else None,
|
||||
credentials=credentials,
|
||||
)
|
||||
@ -144,8 +146,8 @@ class AgentNode(Node[AgentNodeData]):
|
||||
"agent_strategy": self.node_data.agent_strategy_name,
|
||||
},
|
||||
parameters_for_log=parameters_for_log,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=dify_ctx.user_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
node_type=self.node_type,
|
||||
node_id=self._node_id,
|
||||
node_execution_id=self.id,
|
||||
@ -283,8 +285,13 @@ class AgentNode(Node[AgentNodeData]):
|
||||
runtime_variable_pool: VariablePool | None = None
|
||||
if node_data.version != "1" or node_data.tool_node_version is not None:
|
||||
runtime_variable_pool = variable_pool
|
||||
dify_ctx = self.require_dify_context()
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
self.tenant_id, self.app_id, entity, self.invoke_from, runtime_variable_pool
|
||||
dify_ctx.tenant_id,
|
||||
dify_ctx.app_id,
|
||||
entity,
|
||||
dify_ctx.invoke_from,
|
||||
runtime_variable_pool,
|
||||
)
|
||||
if tool_runtime.entity.description:
|
||||
tool_runtime.entity.description.llm = (
|
||||
@ -396,7 +403,8 @@ class AgentNode(Node[AgentNodeData]):
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
|
||||
manager = PluginInstaller()
|
||||
plugins = manager.list_plugins(self.tenant_id)
|
||||
dify_ctx = self.require_dify_context()
|
||||
plugins = manager.list_plugins(dify_ctx.tenant_id)
|
||||
try:
|
||||
current_plugin = next(
|
||||
plugin
|
||||
@ -417,8 +425,11 @@ class AgentNode(Node[AgentNodeData]):
|
||||
return None
|
||||
conversation_id = conversation_id_variable.value
|
||||
|
||||
dify_ctx = self.require_dify_context()
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
|
||||
stmt = select(Conversation).where(
|
||||
Conversation.app_id == dify_ctx.app_id, Conversation.id == conversation_id
|
||||
)
|
||||
conversation = session.scalar(stmt)
|
||||
|
||||
if not conversation:
|
||||
@ -429,9 +440,10 @@ class AgentNode(Node[AgentNodeData]):
|
||||
return memory
|
||||
|
||||
def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
|
||||
dify_ctx = self.require_dify_context()
|
||||
provider_manager = ProviderManager()
|
||||
provider_model_bundle = provider_manager.get_provider_model_bundle(
|
||||
tenant_id=self.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM
|
||||
tenant_id=dify_ctx.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM
|
||||
)
|
||||
model_name = value.get("model", "")
|
||||
model_credentials = provider_model_bundle.configuration.get_current_credentials(
|
||||
@ -440,7 +452,7 @@ class AgentNode(Node[AgentNodeData]):
|
||||
provider_name = provider_model_bundle.configuration.provider.provider
|
||||
model_type_instance = provider_model_bundle.model_type_instance
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
tenant_id=self.tenant_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
provider=provider_name,
|
||||
model_type=ModelType(value.get("model_type", "")),
|
||||
model=model_name,
|
||||
|
||||
@ -8,10 +8,11 @@ from abc import abstractmethod
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from functools import singledispatchmethod
|
||||
from types import MappingProxyType
|
||||
from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin
|
||||
from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, get_args, get_origin
|
||||
from uuid import uuid4
|
||||
|
||||
from dify_graph.entities import AgentNodeStrategyInit, GraphInitParams
|
||||
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
|
||||
from dify_graph.enums import (
|
||||
ErrorStrategy,
|
||||
NodeExecutionType,
|
||||
@ -64,10 +65,28 @@ from libs.datetime_utils import naive_utc_now
|
||||
from .entities import BaseNodeData, RetryConfig
|
||||
|
||||
NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData)
|
||||
_MISSING_RUN_CONTEXT_VALUE = object()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DifyRunContextProtocol(Protocol):
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
user_id: str
|
||||
user_from: Any
|
||||
invoke_from: Any
|
||||
|
||||
|
||||
class _MappingDifyRunContext:
|
||||
def __init__(self, mapping: Mapping[str, Any]) -> None:
|
||||
self.tenant_id = str(mapping["tenant_id"])
|
||||
self.app_id = str(mapping["app_id"])
|
||||
self.user_id = str(mapping["user_id"])
|
||||
self.user_from = mapping["user_from"]
|
||||
self.invoke_from = mapping["invoke_from"]
|
||||
|
||||
|
||||
class Node(Generic[NodeDataT]):
|
||||
"""BaseNode serves as the foundational class for all node implementations.
|
||||
|
||||
@ -227,14 +246,10 @@ class Node(Generic[NodeDataT]):
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> None:
|
||||
self._graph_init_params = graph_init_params
|
||||
self._run_context = MappingProxyType(dict(graph_init_params.run_context))
|
||||
self.id = id
|
||||
self.tenant_id = graph_init_params.tenant_id
|
||||
self.app_id = graph_init_params.app_id
|
||||
self.workflow_id = graph_init_params.workflow_id
|
||||
self.graph_config = graph_init_params.graph_config
|
||||
self.user_id = graph_init_params.user_id
|
||||
self.user_from = graph_init_params.user_from
|
||||
self.invoke_from = graph_init_params.invoke_from
|
||||
self.workflow_call_depth = graph_init_params.call_depth
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
self.state: NodeState = NodeState.UNKNOWN # node execution state
|
||||
@ -263,6 +278,38 @@ class Node(Generic[NodeDataT]):
|
||||
def graph_init_params(self) -> GraphInitParams:
|
||||
return self._graph_init_params
|
||||
|
||||
@property
|
||||
def run_context(self) -> Mapping[str, Any]:
|
||||
return self._run_context
|
||||
|
||||
def get_run_context_value(self, key: str, default: Any = None) -> Any:
|
||||
return self._run_context.get(key, default)
|
||||
|
||||
def require_run_context_value(self, key: str) -> Any:
|
||||
value = self.get_run_context_value(key, _MISSING_RUN_CONTEXT_VALUE)
|
||||
if value is _MISSING_RUN_CONTEXT_VALUE:
|
||||
raise ValueError(f"run_context missing required key: {key}")
|
||||
return value
|
||||
|
||||
def require_dify_context(self) -> DifyRunContextProtocol:
|
||||
raw_ctx = self.require_run_context_value(DIFY_RUN_CONTEXT_KEY)
|
||||
if raw_ctx is None:
|
||||
raise ValueError(f"run_context missing required key: {DIFY_RUN_CONTEXT_KEY}")
|
||||
|
||||
if isinstance(raw_ctx, Mapping):
|
||||
missing_keys = [
|
||||
key for key in ("tenant_id", "app_id", "user_id", "user_from", "invoke_from") if key not in raw_ctx
|
||||
]
|
||||
if missing_keys:
|
||||
raise ValueError(f"dify context missing required keys: {', '.join(missing_keys)}")
|
||||
return _MappingDifyRunContext(raw_ctx)
|
||||
|
||||
for attr in ("tenant_id", "app_id", "user_id", "user_from", "invoke_from"):
|
||||
if not hasattr(raw_ctx, attr):
|
||||
raise TypeError(f"invalid dify context object, missing attribute: {attr}")
|
||||
|
||||
return cast(DifyRunContextProtocol, raw_ctx)
|
||||
|
||||
@property
|
||||
def execution_id(self) -> str:
|
||||
return self._node_execution_id
|
||||
|
||||
@ -52,6 +52,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
Run the datasource node
|
||||
"""
|
||||
|
||||
dify_ctx = self.require_dify_context()
|
||||
node_data = self.node_data
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
datasource_type_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])
|
||||
@ -75,7 +76,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
datasource_info["icon"] = self.datasource_manager.get_icon_url(
|
||||
provider_id=provider_id,
|
||||
datasource_name=node_data.datasource_name or "",
|
||||
tenant_id=self.tenant_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
datasource_type=datasource_type.value,
|
||||
)
|
||||
|
||||
@ -104,11 +105,11 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
|
||||
yield from self.datasource_manager.stream_node_events(
|
||||
node_id=self._node_id,
|
||||
user_id=self.user_id,
|
||||
user_id=dify_ctx.user_id,
|
||||
datasource_name=node_data.datasource_name or "",
|
||||
datasource_type=datasource_type.value,
|
||||
provider_id=provider_id,
|
||||
tenant_id=self.tenant_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
provider=node_data.provider_name,
|
||||
plugin_id=node_data.plugin_id,
|
||||
credential_id=credential_id,
|
||||
@ -136,7 +137,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
raise DatasourceNodeError("File is not exist")
|
||||
|
||||
file_info = self.datasource_manager.get_upload_file_by_id(
|
||||
file_id=related_id, tenant_id=self.tenant_id
|
||||
file_id=related_id, tenant_id=dify_ctx.tenant_id
|
||||
)
|
||||
variable_pool.add([self._node_id, "file"], file_info)
|
||||
# variable_pool.add([self.node_id, "file"], file_info.to_dict())
|
||||
|
||||
@ -212,6 +212,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||
"""
|
||||
Extract files from response by checking both Content-Type header and URL
|
||||
"""
|
||||
dify_ctx = self.require_dify_context()
|
||||
files: list[File] = []
|
||||
is_file = response.is_file
|
||||
content_type = response.content_type
|
||||
@ -236,8 +237,8 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||
tool_file_manager = self._tool_file_manager_factory()
|
||||
|
||||
tool_file = tool_file_manager.create_file_by_raw(
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=dify_ctx.user_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
conversation_id=None,
|
||||
file_binary=content,
|
||||
mimetype=mime_type,
|
||||
@ -249,7 +250,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
)
|
||||
files.append(file)
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@ from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from dify_graph.entities.pause_reason import HumanInputRequired
|
||||
from dify_graph.enums import InvokeFrom, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.node_events import (
|
||||
HumanInputFormFilledEvent,
|
||||
HumanInputFormTimeoutEvent,
|
||||
@ -31,6 +31,8 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
_SELECTED_BRANCH_KEY = "selected_branch"
|
||||
_INVOKE_FROM_DEBUGGER = "debugger"
|
||||
_INVOKE_FROM_EXPLORE = "explore"
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -155,30 +157,39 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
return resolved_defaults
|
||||
|
||||
def _should_require_console_recipient(self) -> bool:
|
||||
if self.invoke_from == InvokeFrom.DEBUGGER:
|
||||
invoke_from = self._invoke_from_value()
|
||||
if invoke_from == _INVOKE_FROM_DEBUGGER:
|
||||
return True
|
||||
if self.invoke_from == InvokeFrom.EXPLORE:
|
||||
if invoke_from == _INVOKE_FROM_EXPLORE:
|
||||
return self._node_data.is_webapp_enabled()
|
||||
return False
|
||||
|
||||
def _display_in_ui(self) -> bool:
|
||||
if self.invoke_from == InvokeFrom.DEBUGGER:
|
||||
if self._invoke_from_value() == _INVOKE_FROM_DEBUGGER:
|
||||
return True
|
||||
return self._node_data.is_webapp_enabled()
|
||||
|
||||
def _effective_delivery_methods(self) -> Sequence[DeliveryChannelConfig]:
|
||||
dify_ctx = self.require_dify_context()
|
||||
invoke_from = self._invoke_from_value()
|
||||
enabled_methods = [method for method in self._node_data.delivery_methods if method.enabled]
|
||||
if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}:
|
||||
if invoke_from in {_INVOKE_FROM_DEBUGGER, _INVOKE_FROM_EXPLORE}:
|
||||
enabled_methods = [method for method in enabled_methods if method.type != DeliveryMethodType.WEBAPP]
|
||||
return [
|
||||
apply_debug_email_recipient(
|
||||
method,
|
||||
enabled=self.invoke_from == InvokeFrom.DEBUGGER,
|
||||
user_id=self.user_id or "",
|
||||
enabled=invoke_from == _INVOKE_FROM_DEBUGGER,
|
||||
user_id=dify_ctx.user_id,
|
||||
)
|
||||
for method in enabled_methods
|
||||
]
|
||||
|
||||
def _invoke_from_value(self) -> str:
|
||||
invoke_from = self.require_dify_context().invoke_from
|
||||
if isinstance(invoke_from, str):
|
||||
return invoke_from
|
||||
return str(getattr(invoke_from, "value", invoke_from))
|
||||
|
||||
def _human_input_required_event(self, form_entity: HumanInputFormEntity) -> HumanInputRequired:
|
||||
node_data = self._node_data
|
||||
resolved_default_values = self.resolve_default_values()
|
||||
@ -212,10 +223,11 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
"""
|
||||
repo = self._form_repository
|
||||
form = repo.get_form(self._workflow_execution_id, self.id)
|
||||
dify_ctx = self.require_dify_context()
|
||||
if form is None:
|
||||
display_in_ui = self._display_in_ui()
|
||||
params = FormCreateParams(
|
||||
app_id=self.app_id,
|
||||
app_id=dify_ctx.app_id,
|
||||
workflow_execution_id=self._workflow_execution_id,
|
||||
node_id=self.id,
|
||||
form_config=self._node_data,
|
||||
@ -225,7 +237,9 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
resolved_default_values=self.resolve_default_values(),
|
||||
console_recipient_required=self._should_require_console_recipient(),
|
||||
console_creator_account_id=(
|
||||
self.user_id if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE} else None
|
||||
dify_ctx.user_id
|
||||
if self._invoke_from_value() in {_INVOKE_FROM_DEBUGGER, _INVOKE_FROM_EXPLORE}
|
||||
else None
|
||||
),
|
||||
backstage_recipient_required=True,
|
||||
)
|
||||
|
||||
@ -587,24 +587,14 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
return
|
||||
|
||||
def _create_graph_engine(self, index: int, item: object):
|
||||
# Import dependencies
|
||||
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.graph import Graph
|
||||
from dify_graph.graph_engine import GraphEngine, GraphEngineConfig
|
||||
from dify_graph.graph_engine.command_channels import InMemoryChannel
|
||||
from dify_graph.runtime import GraphRuntimeState
|
||||
from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState
|
||||
|
||||
# Create GraphInitParams from node attributes
|
||||
# Create GraphInitParams for child graph execution.
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
workflow_id=self.workflow_id,
|
||||
graph_config=self.graph_config,
|
||||
user_id=self.user_id,
|
||||
user_from=self.user_from,
|
||||
invoke_from=self.invoke_from,
|
||||
run_context=self.run_context,
|
||||
call_depth=self.workflow_call_depth,
|
||||
)
|
||||
# Create a deep copy of the variable pool for each iteration
|
||||
@ -621,28 +611,17 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
total_tokens=0,
|
||||
node_run_steps=0,
|
||||
)
|
||||
root_node_id = self.node_data.start_node_id
|
||||
if root_node_id is None:
|
||||
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found")
|
||||
|
||||
# Create a new node factory with the new GraphRuntimeState
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state_copy
|
||||
)
|
||||
|
||||
# Initialize the iteration graph with the new node factory
|
||||
iteration_graph = Graph.init(
|
||||
graph_config=self.graph_config, node_factory=node_factory, root_node_id=self.node_data.start_node_id
|
||||
)
|
||||
|
||||
if not iteration_graph:
|
||||
raise IterationGraphNotFoundError("iteration graph not found")
|
||||
|
||||
# Create a new GraphEngine for this iteration
|
||||
graph_engine = GraphEngine(
|
||||
workflow_id=self.workflow_id,
|
||||
graph=iteration_graph,
|
||||
graph_runtime_state=graph_runtime_state_copy,
|
||||
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
||||
config=GraphEngineConfig(),
|
||||
)
|
||||
graph_engine.layer(LLMQuotaLayer())
|
||||
|
||||
return graph_engine
|
||||
try:
|
||||
return self.graph_runtime_state.create_child_engine(
|
||||
workflow_id=self.workflow_id,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state_copy,
|
||||
graph_config=self.graph_config,
|
||||
root_node_id=root_node_id,
|
||||
)
|
||||
except ChildGraphNotFoundError as exc:
|
||||
raise IterationGraphNotFoundError("iteration graph not found") from exc
|
||||
|
||||
@ -3,7 +3,7 @@ from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from dify_graph.enums import InvokeFrom, NodeExecutionType, NodeType, SystemVariableKey
|
||||
from dify_graph.enums import NodeExecutionType, NodeType, SystemVariableKey
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.base.template import Template
|
||||
@ -20,6 +20,7 @@ if TYPE_CHECKING:
|
||||
from dify_graph.runtime import GraphRuntimeState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_INVOKE_FROM_DEBUGGER = "debugger"
|
||||
|
||||
|
||||
class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
@ -58,7 +59,8 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
if not variable:
|
||||
raise KnowledgeIndexNodeError("Index chunk variable is required.")
|
||||
invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
|
||||
is_preview = invoke_from.value == InvokeFrom.DEBUGGER if invoke_from else False
|
||||
invoke_from_value = str(invoke_from.value) if invoke_from else None
|
||||
is_preview = invoke_from_value == _INVOKE_FROM_DEBUGGER
|
||||
|
||||
chunks = variable.value
|
||||
variables = {"chunks": chunks}
|
||||
|
||||
@ -66,9 +66,10 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
self._rag_retrieval = rag_retrieval
|
||||
|
||||
if llm_file_saver is None:
|
||||
dify_ctx = self.require_dify_context()
|
||||
llm_file_saver = FileSaverImpl(
|
||||
user_id=graph_init_params.user_id,
|
||||
tenant_id=graph_init_params.tenant_id,
|
||||
user_id=dify_ctx.user_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
)
|
||||
self._llm_file_saver = llm_file_saver
|
||||
|
||||
@ -160,6 +161,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
def _fetch_dataset_retriever(
|
||||
self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]
|
||||
) -> tuple[list[Source], LLMUsage]:
|
||||
dify_ctx = self.require_dify_context()
|
||||
dataset_ids = node_data.dataset_ids
|
||||
query = variables.get("query")
|
||||
attachments = variables.get("attachments")
|
||||
@ -176,10 +178,10 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
model = node_data.single_retrieval_config.model
|
||||
retrieval_resource_list = self._rag_retrieval.knowledge_retrieval(
|
||||
request=KnowledgeRetrievalRequest(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
app_id=self.app_id,
|
||||
user_from=self.user_from.value,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
user_id=dify_ctx.user_id,
|
||||
app_id=dify_ctx.app_id,
|
||||
user_from=dify_ctx.user_from.value,
|
||||
dataset_ids=dataset_ids,
|
||||
retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value,
|
||||
completion_params=model.completion_params,
|
||||
@ -229,10 +231,10 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
|
||||
retrieval_resource_list = self._rag_retrieval.knowledge_retrieval(
|
||||
request=KnowledgeRetrievalRequest(
|
||||
app_id=self.app_id,
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
user_from=self.user_from.value,
|
||||
app_id=dify_ctx.app_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
user_id=dify_ctx.user_id,
|
||||
user_from=dify_ctx.user_from.value,
|
||||
dataset_ids=dataset_ids,
|
||||
query=query,
|
||||
retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value,
|
||||
|
||||
@ -145,9 +145,10 @@ class LLMNode(Node[LLMNodeData]):
|
||||
self._memory = memory
|
||||
|
||||
if llm_file_saver is None:
|
||||
dify_ctx = self.require_dify_context()
|
||||
llm_file_saver = FileSaverImpl(
|
||||
user_id=graph_init_params.user_id,
|
||||
tenant_id=graph_init_params.tenant_id,
|
||||
user_id=dify_ctx.user_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
)
|
||||
self._llm_file_saver = llm_file_saver
|
||||
|
||||
@ -242,7 +243,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
stop=stop,
|
||||
user_id=self.user_id,
|
||||
user_id=self.require_dify_context().user_id,
|
||||
structured_output_enabled=self.node_data.structured_output_enabled,
|
||||
structured_output=self.node_data.structured_output,
|
||||
file_saver=self._llm_file_saver,
|
||||
@ -702,7 +703,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
tenant_id=self.tenant_id,
|
||||
tenant_id=self.require_dify_context().tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
remote_url=upload_file.source_url,
|
||||
|
||||
@ -412,24 +412,14 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
return build_segment_with_type(var_type, value)
|
||||
|
||||
def _create_graph_engine(self, start_at: datetime, root_node_id: str):
|
||||
# Import dependencies
|
||||
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.graph import Graph
|
||||
from dify_graph.graph_engine import GraphEngine, GraphEngineConfig
|
||||
from dify_graph.graph_engine.command_channels import InMemoryChannel
|
||||
from dify_graph.runtime import GraphRuntimeState
|
||||
|
||||
# Create GraphInitParams from node attributes
|
||||
# Create GraphInitParams for child graph execution.
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
workflow_id=self.workflow_id,
|
||||
graph_config=self.graph_config,
|
||||
user_id=self.user_id,
|
||||
user_from=self.user_from,
|
||||
invoke_from=self.invoke_from,
|
||||
run_context=self.run_context,
|
||||
call_depth=self.workflow_call_depth,
|
||||
)
|
||||
|
||||
@ -439,22 +429,10 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
start_at=start_at.timestamp(),
|
||||
)
|
||||
|
||||
# Create a new node factory with the new GraphRuntimeState
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state_copy
|
||||
)
|
||||
|
||||
# Initialize the loop graph with the new node factory
|
||||
loop_graph = Graph.init(graph_config=self.graph_config, node_factory=node_factory, root_node_id=root_node_id)
|
||||
|
||||
# Create a new GraphEngine for this iteration
|
||||
graph_engine = GraphEngine(
|
||||
return self.graph_runtime_state.create_child_engine(
|
||||
workflow_id=self.workflow_id,
|
||||
graph=loop_graph,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state_copy,
|
||||
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
||||
config=GraphEngineConfig(),
|
||||
graph_config=self.graph_config,
|
||||
root_node_id=root_node_id,
|
||||
)
|
||||
graph_engine.layer(LLMQuotaLayer())
|
||||
|
||||
return graph_engine
|
||||
|
||||
@ -297,7 +297,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
tools=tools,
|
||||
stop=list(stop),
|
||||
stream=False,
|
||||
user=self.user_id,
|
||||
user=self.require_dify_context().user_id,
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
|
||||
@ -86,9 +86,10 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
self._memory = memory
|
||||
|
||||
if llm_file_saver is None:
|
||||
dify_ctx = self.require_dify_context()
|
||||
llm_file_saver = FileSaverImpl(
|
||||
user_id=graph_init_params.user_id,
|
||||
tenant_id=graph_init_params.tenant_id,
|
||||
user_id=dify_ctx.user_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
)
|
||||
self._llm_file_saver = llm_file_saver
|
||||
|
||||
@ -160,7 +161,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
stop=stop,
|
||||
user_id=self.user_id,
|
||||
user_id=self.require_dify_context().user_id,
|
||||
structured_output_enabled=False,
|
||||
structured_output=None,
|
||||
file_saver=self._llm_file_saver,
|
||||
|
||||
@ -56,6 +56,8 @@ class ToolNode(Node[ToolNodeData]):
|
||||
"""
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
|
||||
|
||||
dify_ctx = self.require_dify_context()
|
||||
|
||||
# fetch tool icon
|
||||
tool_info = {
|
||||
"provider_type": self.node_data.provider_type.value,
|
||||
@ -75,7 +77,12 @@ class ToolNode(Node[ToolNodeData]):
|
||||
if self.node_data.version != "1" or self.node_data.tool_node_version is not None:
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
tool_runtime = ToolManager.get_workflow_tool_runtime(
|
||||
self.tenant_id, self.app_id, self._node_id, self.node_data, self.invoke_from, variable_pool
|
||||
dify_ctx.tenant_id,
|
||||
dify_ctx.app_id,
|
||||
self._node_id,
|
||||
self.node_data,
|
||||
dify_ctx.invoke_from,
|
||||
variable_pool,
|
||||
)
|
||||
except ToolNodeError as e:
|
||||
yield StreamCompletedEvent(
|
||||
@ -109,10 +116,10 @@ class ToolNode(Node[ToolNodeData]):
|
||||
message_stream = ToolEngine.generic_invoke(
|
||||
tool=tool_runtime,
|
||||
tool_parameters=parameters,
|
||||
user_id=self.user_id,
|
||||
user_id=dify_ctx.user_id,
|
||||
workflow_tool_callback=DifyWorkflowCallbackHandler(),
|
||||
workflow_call_depth=self.workflow_call_depth,
|
||||
app_id=self.app_id,
|
||||
app_id=dify_ctx.app_id,
|
||||
conversation_id=conversation_id.text if conversation_id else None,
|
||||
)
|
||||
except ToolNodeError as e:
|
||||
@ -133,8 +140,8 @@ class ToolNode(Node[ToolNodeData]):
|
||||
messages=message_stream,
|
||||
tool_info=tool_info,
|
||||
parameters_for_log=parameters_for_log,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=dify_ctx.user_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
node_id=self._node_id,
|
||||
tool_runtime=tool_runtime,
|
||||
)
|
||||
|
||||
@ -69,6 +69,7 @@ class TriggerWebhookNode(Node[WebhookData]):
|
||||
)
|
||||
|
||||
def generate_file_var(self, param_name: str, file: dict):
|
||||
dify_ctx = self.require_dify_context()
|
||||
related_id = file.get("related_id")
|
||||
transfer_method_value = file.get("transfer_method")
|
||||
if transfer_method_value:
|
||||
@ -84,7 +85,7 @@ class TriggerWebhookNode(Node[WebhookData]):
|
||||
try:
|
||||
file_obj = file_factory.build_from_mapping(
|
||||
mapping=file,
|
||||
tenant_id=self.tenant_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
)
|
||||
file_segment = build_segment_with_type(SegmentType.FILE, file_obj)
|
||||
return FileVariable(name=param_name, value=file_segment.value, selector=[self.id, param_name])
|
||||
|
||||
@ -1,9 +1,17 @@
|
||||
from .graph_runtime_state import GraphRuntimeState
|
||||
from .graph_runtime_state import (
|
||||
ChildEngineBuilderNotConfiguredError,
|
||||
ChildEngineError,
|
||||
ChildGraphNotFoundError,
|
||||
GraphRuntimeState,
|
||||
)
|
||||
from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState, ReadOnlyVariablePool
|
||||
from .read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper, ReadOnlyVariablePoolWrapper
|
||||
from .variable_pool import VariablePool, VariableValue
|
||||
|
||||
__all__ = [
|
||||
"ChildEngineBuilderNotConfiguredError",
|
||||
"ChildEngineError",
|
||||
"ChildGraphNotFoundError",
|
||||
"GraphRuntimeState",
|
||||
"ReadOnlyGraphRuntimeState",
|
||||
"ReadOnlyGraphRuntimeStateWrapper",
|
||||
|
||||
@ -15,6 +15,7 @@ from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
from dify_graph.runtime.variable_pool import VariablePool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.pause_reason import PauseReason
|
||||
|
||||
|
||||
@ -135,6 +136,31 @@ class GraphProtocol(Protocol):
|
||||
def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ...
|
||||
|
||||
|
||||
class ChildGraphEngineBuilderProtocol(Protocol):
|
||||
def build_child_engine(
|
||||
self,
|
||||
*,
|
||||
workflow_id: str,
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
graph_config: Mapping[str, Any],
|
||||
root_node_id: str,
|
||||
layers: Sequence[object] = (),
|
||||
) -> Any: ...
|
||||
|
||||
|
||||
class ChildEngineError(ValueError):
|
||||
"""Base error type for child-engine creation failures."""
|
||||
|
||||
|
||||
class ChildEngineBuilderNotConfiguredError(ChildEngineError):
|
||||
"""Raised when child-engine creation is requested without a bound builder."""
|
||||
|
||||
|
||||
class ChildGraphNotFoundError(ChildEngineError):
|
||||
"""Raised when the requested child graph entry point cannot be resolved."""
|
||||
|
||||
|
||||
class _GraphStateSnapshot(BaseModel):
|
||||
"""Serializable graph state snapshot for node/edge states."""
|
||||
|
||||
@ -209,6 +235,7 @@ class GraphRuntimeState:
|
||||
self._pending_graph_execution_workflow_id: str | None = None
|
||||
self._paused_nodes: set[str] = set()
|
||||
self._deferred_nodes: set[str] = set()
|
||||
self._child_engine_builder: ChildGraphEngineBuilderProtocol | None = None
|
||||
|
||||
# Node and edges states needed to be restored into
|
||||
# graph object.
|
||||
@ -250,6 +277,31 @@ class GraphRuntimeState:
|
||||
if self._graph is not None:
|
||||
_ = self.response_coordinator
|
||||
|
||||
def bind_child_engine_builder(self, builder: ChildGraphEngineBuilderProtocol) -> None:
|
||||
self._child_engine_builder = builder
|
||||
|
||||
def create_child_engine(
|
||||
self,
|
||||
*,
|
||||
workflow_id: str,
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
graph_config: Mapping[str, Any],
|
||||
root_node_id: str,
|
||||
layers: Sequence[object] = (),
|
||||
) -> Any:
|
||||
if self._child_engine_builder is None:
|
||||
raise ChildEngineBuilderNotConfiguredError("Child engine builder is not configured.")
|
||||
|
||||
return self._child_engine_builder.build_child_engine(
|
||||
workflow_id=workflow_id,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
graph_config=graph_config,
|
||||
root_node_id=root_node_id,
|
||||
layers=layers,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Primary collaborators
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user