mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 17:08:03 +08:00
feat: add mention type variable
This commit is contained in:
@ -4,10 +4,8 @@ from .entities import (
|
||||
BaseLoopNodeData,
|
||||
BaseLoopState,
|
||||
BaseNodeData,
|
||||
VirtualNodeConfig,
|
||||
)
|
||||
from .usage_tracking_mixin import LLMUsageTrackingMixin
|
||||
from .virtual_node_executor import VirtualNodeExecutionError, VirtualNodeExecutor
|
||||
|
||||
__all__ = [
|
||||
"BaseIterationNodeData",
|
||||
@ -16,7 +14,4 @@ __all__ = [
|
||||
"BaseLoopState",
|
||||
"BaseNodeData",
|
||||
"LLMUsageTrackingMixin",
|
||||
"VirtualNodeConfig",
|
||||
"VirtualNodeExecutionError",
|
||||
"VirtualNodeExecutor",
|
||||
]
|
||||
|
||||
@ -167,24 +167,6 @@ class DefaultValue(BaseModel):
|
||||
return self
|
||||
|
||||
|
||||
class VirtualNodeConfig(BaseModel):
|
||||
"""Configuration for a virtual sub-node embedded within a parent node."""
|
||||
|
||||
# Local ID within parent node (e.g., "ext_1")
|
||||
# Will be converted to global ID: "{parent_id}.{id}"
|
||||
id: str
|
||||
|
||||
# Node type (e.g., "llm", "code", "tool")
|
||||
type: str
|
||||
|
||||
# Full node data configuration
|
||||
data: dict[str, Any] = {}
|
||||
|
||||
def get_global_id(self, parent_node_id: str) -> str:
|
||||
"""Get the global node ID by combining parent ID and local ID."""
|
||||
return f"{parent_node_id}.{self.id}"
|
||||
|
||||
|
||||
class BaseNodeData(ABC, BaseModel):
|
||||
title: str
|
||||
desc: str | None = None
|
||||
@ -193,8 +175,15 @@ class BaseNodeData(ABC, BaseModel):
|
||||
default_value: list[DefaultValue] | None = None
|
||||
retry_config: RetryConfig = RetryConfig()
|
||||
|
||||
# Virtual sub-nodes that execute before the main node
|
||||
virtual_nodes: list[VirtualNodeConfig] = []
|
||||
# Parent node ID when this node is used as an extractor.
|
||||
# If set, this node is an "attached" extractor node that extracts values
|
||||
# from list[PromptMessage] for the parent node's parameters.
|
||||
parent_node_id: str | None = None
|
||||
|
||||
@property
|
||||
def is_extractor_node(self) -> bool:
|
||||
"""Check if this node is an extractor node (has parent_node_id)."""
|
||||
return self.parent_node_id is not None
|
||||
|
||||
@property
|
||||
def default_value_dict(self) -> dict[str, Any]:
|
||||
|
||||
@ -229,7 +229,6 @@ class Node(Generic[NodeDataT]):
|
||||
self._node_id = node_id
|
||||
self._node_execution_id: str = ""
|
||||
self._start_at = naive_utc_now()
|
||||
self._virtual_node_outputs: dict[str, Any] = {} # Outputs from virtual sub-nodes
|
||||
|
||||
raw_node_data = config.get("data") or {}
|
||||
if not isinstance(raw_node_data, Mapping):
|
||||
@ -271,51 +270,81 @@ class Node(Generic[NodeDataT]):
|
||||
"""Check if execution should be stopped."""
|
||||
return self.graph_runtime_state.stop_event.is_set()
|
||||
|
||||
def _execute_virtual_nodes(self) -> Generator[GraphNodeEventBase, None, dict[str, Any]]:
|
||||
def _find_extractor_node_configs(self) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Execute all virtual sub-nodes defined in node configuration.
|
||||
|
||||
Virtual nodes are complete node definitions that execute before the main node.
|
||||
Each virtual node:
|
||||
- Has its own global ID: "{parent_id}.{local_id}"
|
||||
- Generates standard node events
|
||||
- Stores outputs in the variable pool (via event handling)
|
||||
- Supports retry via parent node's retry config
|
||||
Find all extractor node configurations that have parent_node_id == self._node_id.
|
||||
|
||||
Returns:
|
||||
dict mapping local_id -> outputs dict
|
||||
List of node configuration dicts for extractor nodes
|
||||
"""
|
||||
from .virtual_node_executor import VirtualNodeExecutor
|
||||
nodes = self.graph_config.get("nodes", [])
|
||||
extractor_configs = []
|
||||
for node_config in nodes:
|
||||
node_data = node_config.get("data", {})
|
||||
if node_data.get("parent_node_id") == self._node_id:
|
||||
extractor_configs.append(node_config)
|
||||
return extractor_configs
|
||||
|
||||
virtual_nodes = self.node_data.virtual_nodes
|
||||
if not virtual_nodes:
|
||||
return {}
|
||||
|
||||
executor = VirtualNodeExecutor(
|
||||
graph_init_params=self._graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
parent_node_id=self._node_id,
|
||||
parent_retry_config=self.retry_config,
|
||||
)
|
||||
|
||||
return (yield from executor.execute_virtual_nodes(virtual_nodes))
|
||||
|
||||
@property
|
||||
def virtual_node_outputs(self) -> dict[str, Any]:
|
||||
def _execute_extractor_nodes(self) -> Generator[GraphNodeEventBase, None, None]:
|
||||
"""
|
||||
Get the outputs from virtual sub-nodes.
|
||||
Execute all extractor nodes associated with this node.
|
||||
|
||||
Returns:
|
||||
dict mapping local_id -> outputs dict
|
||||
Extractor nodes are nodes with parent_node_id == self._node_id.
|
||||
They are executed before the main node to extract values from list[PromptMessage].
|
||||
"""
|
||||
return self._virtual_node_outputs
|
||||
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
|
||||
extractor_configs = self._find_extractor_node_configs()
|
||||
logger.debug("[Extractor] Found %d extractor nodes for parent '%s'", len(extractor_configs), self._node_id)
|
||||
if not extractor_configs:
|
||||
return
|
||||
|
||||
for config in extractor_configs:
|
||||
node_id = config.get("id")
|
||||
node_data = config.get("data", {})
|
||||
node_type_str = node_data.get("type")
|
||||
|
||||
if not node_id or not node_type_str:
|
||||
continue
|
||||
|
||||
# Get node class
|
||||
try:
|
||||
node_type = NodeType(node_type_str)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
|
||||
if not node_mapping:
|
||||
continue
|
||||
|
||||
node_version = str(node_data.get("version", "1"))
|
||||
node_cls = node_mapping.get(node_version) or node_mapping.get(LATEST_VERSION)
|
||||
if not node_cls:
|
||||
continue
|
||||
|
||||
# Instantiate and execute the extractor node
|
||||
extractor_node = node_cls(
|
||||
id=node_id,
|
||||
config=config,
|
||||
graph_init_params=self._graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
)
|
||||
|
||||
# Execute and process extractor node events
|
||||
for event in extractor_node.run():
|
||||
if isinstance(event, NodeRunSucceededEvent):
|
||||
# Store extractor node outputs in variable pool
|
||||
outputs = event.node_run_result.outputs
|
||||
for variable_name, variable_value in outputs.items():
|
||||
self.graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value)
|
||||
yield event
|
||||
|
||||
def run(self) -> Generator[GraphNodeEventBase, None, None]:
|
||||
execution_id = self.ensure_execution_id()
|
||||
self._start_at = naive_utc_now()
|
||||
|
||||
# Step 1: Execute virtual sub-nodes before main node execution
|
||||
self._virtual_node_outputs = yield from self._execute_virtual_nodes()
|
||||
# Step 1: Execute associated extractor nodes before main node execution
|
||||
yield from self._execute_extractor_nodes()
|
||||
|
||||
# Create and push start event with required fields
|
||||
start_event = NodeRunStartedEvent(
|
||||
|
||||
@ -1,213 +0,0 @@
|
||||
"""
|
||||
Virtual Node Executor for running embedded sub-nodes within a parent node.
|
||||
|
||||
This module handles the execution of virtual nodes defined in a parent node's
|
||||
`virtual_nodes` configuration. Virtual nodes are complete node definitions
|
||||
that execute before the parent node.
|
||||
|
||||
Example configuration:
|
||||
virtual_nodes:
|
||||
- id: ext_1
|
||||
type: llm
|
||||
data:
|
||||
model: {...}
|
||||
prompt_template: [...]
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from uuid import uuid4
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunRetryEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
from .entities import RetryConfig, VirtualNodeConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
class VirtualNodeExecutionError(Exception):
|
||||
"""Error during virtual node execution"""
|
||||
|
||||
def __init__(self, node_id: str, original_error: Exception):
|
||||
self.node_id = node_id
|
||||
self.original_error = original_error
|
||||
super().__init__(f"Virtual node {node_id} execution failed: {original_error}")
|
||||
|
||||
|
||||
class VirtualNodeExecutor:
|
||||
"""
|
||||
Executes virtual sub-nodes embedded within a parent node.
|
||||
|
||||
Virtual nodes are complete node definitions that execute before the parent node.
|
||||
Each virtual node:
|
||||
- Has its own global ID: "{parent_id}.{local_id}"
|
||||
- Generates standard node events
|
||||
- Stores outputs in the variable pool
|
||||
- Supports retry via parent node's retry config
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
parent_node_id: str,
|
||||
parent_retry_config: RetryConfig | None = None,
|
||||
):
|
||||
self._graph_init_params = graph_init_params
|
||||
self._graph_runtime_state = graph_runtime_state
|
||||
self._parent_node_id = parent_node_id
|
||||
self._parent_retry_config = parent_retry_config or RetryConfig()
|
||||
|
||||
def execute_virtual_nodes(
|
||||
self,
|
||||
virtual_nodes: list[VirtualNodeConfig],
|
||||
) -> Generator[GraphNodeEventBase, None, dict[str, Any]]:
|
||||
"""
|
||||
Execute all virtual nodes in order.
|
||||
|
||||
Args:
|
||||
virtual_nodes: List of virtual node configurations
|
||||
|
||||
Yields:
|
||||
Node events from each virtual node execution
|
||||
|
||||
Returns:
|
||||
dict mapping local_id -> outputs dict
|
||||
"""
|
||||
results: dict[str, Any] = {}
|
||||
|
||||
for vnode_config in virtual_nodes:
|
||||
global_id = vnode_config.get_global_id(self._parent_node_id)
|
||||
|
||||
# Execute with retry
|
||||
outputs = yield from self._execute_with_retry(vnode_config, global_id)
|
||||
results[vnode_config.id] = outputs
|
||||
|
||||
return results
|
||||
|
||||
def _execute_with_retry(
|
||||
self,
|
||||
vnode_config: VirtualNodeConfig,
|
||||
global_id: str,
|
||||
) -> Generator[GraphNodeEventBase, None, dict[str, Any]]:
|
||||
"""
|
||||
Execute virtual node with retry support.
|
||||
"""
|
||||
retry_config = self._parent_retry_config
|
||||
last_error: Exception | None = None
|
||||
|
||||
for attempt in range(retry_config.max_retries + 1):
|
||||
try:
|
||||
return (yield from self._execute_single_node(vnode_config, global_id))
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
|
||||
if attempt < retry_config.max_retries:
|
||||
# Yield retry event
|
||||
yield NodeRunRetryEvent(
|
||||
id=str(uuid4()),
|
||||
node_id=global_id,
|
||||
node_type=self._get_node_type(vnode_config.type),
|
||||
node_title=vnode_config.data.get("title", f"Virtual: {vnode_config.id}"),
|
||||
start_at=naive_utc_now(),
|
||||
error=str(e),
|
||||
retry_index=attempt + 1,
|
||||
)
|
||||
|
||||
time.sleep(retry_config.retry_interval_seconds)
|
||||
continue
|
||||
|
||||
raise VirtualNodeExecutionError(global_id, e) from e
|
||||
|
||||
raise last_error or VirtualNodeExecutionError(global_id, Exception("Unknown error"))
|
||||
|
||||
def _execute_single_node(
|
||||
self,
|
||||
vnode_config: VirtualNodeConfig,
|
||||
global_id: str,
|
||||
) -> Generator[GraphNodeEventBase, None, dict[str, Any]]:
|
||||
"""
|
||||
Execute a single virtual node by instantiating and running it.
|
||||
"""
|
||||
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
|
||||
# Build node config
|
||||
node_config: dict[str, Any] = {
|
||||
"id": global_id,
|
||||
"data": {
|
||||
**vnode_config.data,
|
||||
"title": vnode_config.data.get("title", f"Virtual: {vnode_config.id}"),
|
||||
},
|
||||
}
|
||||
|
||||
# Get the node class for this type
|
||||
node_type = self._get_node_type(vnode_config.type)
|
||||
node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
|
||||
if not node_mapping:
|
||||
raise ValueError(f"No class mapping found for node type: {node_type}")
|
||||
|
||||
node_version = str(vnode_config.data.get("version", "1"))
|
||||
node_cls = node_mapping.get(node_version) or node_mapping.get(LATEST_VERSION)
|
||||
if not node_cls:
|
||||
raise ValueError(f"No class found for node type: {node_type}")
|
||||
|
||||
# Instantiate the node
|
||||
node = node_cls(
|
||||
id=global_id,
|
||||
config=node_config,
|
||||
graph_init_params=self._graph_init_params,
|
||||
graph_runtime_state=self._graph_runtime_state,
|
||||
)
|
||||
|
||||
# Run and collect events
|
||||
outputs: dict[str, Any] = {}
|
||||
|
||||
for event in node.run():
|
||||
# Mark event as coming from virtual node
|
||||
self._mark_event_as_virtual(event, vnode_config)
|
||||
yield event
|
||||
|
||||
if isinstance(event, NodeRunSucceededEvent):
|
||||
outputs = event.node_run_result.outputs or {}
|
||||
elif isinstance(event, NodeRunFailedEvent):
|
||||
raise Exception(event.error or "Virtual node execution failed")
|
||||
|
||||
return outputs
|
||||
|
||||
def _mark_event_as_virtual(
|
||||
self,
|
||||
event: GraphNodeEventBase,
|
||||
vnode_config: VirtualNodeConfig,
|
||||
) -> None:
|
||||
"""Mark event as coming from a virtual node."""
|
||||
if isinstance(event, NodeRunStartedEvent):
|
||||
event.is_virtual = True
|
||||
event.parent_node_id = self._parent_node_id
|
||||
|
||||
def _get_node_type(self, type_str: str) -> NodeType:
|
||||
"""Convert type string to NodeType enum."""
|
||||
type_mapping = {
|
||||
"llm": NodeType.LLM,
|
||||
"code": NodeType.CODE,
|
||||
"tool": NodeType.TOOL,
|
||||
"if-else": NodeType.IF_ELSE,
|
||||
"question-classifier": NodeType.QUESTION_CLASSIFIER,
|
||||
"parameter-extractor": NodeType.PARAMETER_EXTRACTOR,
|
||||
"template-transform": NodeType.TEMPLATE_TRANSFORM,
|
||||
"variable-assigner": NodeType.VARIABLE_ASSIGNER,
|
||||
"http-request": NodeType.HTTP_REQUEST,
|
||||
"knowledge-retrieval": NodeType.KNOWLEDGE_RETRIEVAL,
|
||||
}
|
||||
return type_mapping.get(type_str, NodeType.LLM)
|
||||
Reference in New Issue
Block a user