feat: add mention type variable

This commit is contained in:
Novice
2026-01-12 17:39:36 +08:00
parent d65ae68668
commit bb190f9610
23 changed files with 457 additions and 439 deletions

View File

@ -4,10 +4,8 @@ from .entities import (
BaseLoopNodeData,
BaseLoopState,
BaseNodeData,
VirtualNodeConfig,
)
from .usage_tracking_mixin import LLMUsageTrackingMixin
from .virtual_node_executor import VirtualNodeExecutionError, VirtualNodeExecutor
__all__ = [
"BaseIterationNodeData",
@ -16,7 +14,4 @@ __all__ = [
"BaseLoopState",
"BaseNodeData",
"LLMUsageTrackingMixin",
"VirtualNodeConfig",
"VirtualNodeExecutionError",
"VirtualNodeExecutor",
]

View File

@ -167,24 +167,6 @@ class DefaultValue(BaseModel):
return self
class VirtualNodeConfig(BaseModel):
    """Describe one virtual sub-node that is embedded inside a parent node."""

    # Identifier local to the parent node (e.g., "ext_1").
    # get_global_id() combines it with the parent's ID as "{parent_id}.{id}".
    id: str

    # Kind of node to run (e.g., "llm", "code", "tool").
    type: str

    # Complete data configuration of the node.
    data: dict[str, Any] = {}

    def get_global_id(self, parent_node_id: str) -> str:
        """Return the global node ID: the parent ID and the local ID joined by a dot."""
        local_id = self.id
        return f"{parent_node_id}.{local_id}"
class BaseNodeData(ABC, BaseModel):
title: str
desc: str | None = None
@ -193,8 +175,15 @@ class BaseNodeData(ABC, BaseModel):
default_value: list[DefaultValue] | None = None
retry_config: RetryConfig = RetryConfig()
# Virtual sub-nodes that execute before the main node
virtual_nodes: list[VirtualNodeConfig] = []
# Parent node ID when this node is used as an extractor.
# If set, this node is an "attached" extractor node that extracts values
# from list[PromptMessage] for the parent node's parameters.
parent_node_id: str | None = None
@property
def is_extractor_node(self) -> bool:
    """Tell whether this node is attached to a parent node as an extractor.

    A node counts as an extractor exactly when ``parent_node_id`` is set
    (i.e. is not ``None``).
    """
    parent_id = self.parent_node_id
    return not (parent_id is None)
@property
def default_value_dict(self) -> dict[str, Any]:

View File

@ -229,7 +229,6 @@ class Node(Generic[NodeDataT]):
self._node_id = node_id
self._node_execution_id: str = ""
self._start_at = naive_utc_now()
self._virtual_node_outputs: dict[str, Any] = {} # Outputs from virtual sub-nodes
raw_node_data = config.get("data") or {}
if not isinstance(raw_node_data, Mapping):
@ -271,51 +270,81 @@ class Node(Generic[NodeDataT]):
"""Check if execution should be stopped."""
return self.graph_runtime_state.stop_event.is_set()
def _execute_virtual_nodes(self) -> Generator[GraphNodeEventBase, None, dict[str, Any]]:
def _find_extractor_node_configs(self) -> list[dict[str, Any]]:
"""
Execute all virtual sub-nodes defined in node configuration.
Virtual nodes are complete node definitions that execute before the main node.
Each virtual node:
- Has its own global ID: "{parent_id}.{local_id}"
- Generates standard node events
- Stores outputs in the variable pool (via event handling)
- Supports retry via parent node's retry config
Find all extractor node configurations that have parent_node_id == self._node_id.
Returns:
dict mapping local_id -> outputs dict
List of node configuration dicts for extractor nodes
"""
from .virtual_node_executor import VirtualNodeExecutor
nodes = self.graph_config.get("nodes", [])
extractor_configs = []
for node_config in nodes:
node_data = node_config.get("data", {})
if node_data.get("parent_node_id") == self._node_id:
extractor_configs.append(node_config)
return extractor_configs
virtual_nodes = self.node_data.virtual_nodes
if not virtual_nodes:
return {}
executor = VirtualNodeExecutor(
graph_init_params=self._graph_init_params,
graph_runtime_state=self.graph_runtime_state,
parent_node_id=self._node_id,
parent_retry_config=self.retry_config,
)
return (yield from executor.execute_virtual_nodes(virtual_nodes))
@property
def virtual_node_outputs(self) -> dict[str, Any]:
    """
    Get the outputs from virtual sub-nodes.

    Returns:
        dict mapping local_id -> outputs dict
    """
    return self._virtual_node_outputs


def _execute_extractor_nodes(self) -> Generator[GraphNodeEventBase, None, None]:
    """
    Execute all extractor nodes associated with this node.

    Extractor nodes are nodes with parent_node_id == self._node_id.
    They are executed before the main node to extract values from list[PromptMessage].

    For each extractor the node class is resolved from its type and version,
    the node is instantiated and run, and every event it produces is yielded
    unchanged. When an extractor succeeds, each of its outputs is added to the
    variable pool under ``(extractor_node_id, variable_name)``.

    Extractor configs that cannot be executed (missing id or type, unknown node
    type, no registered class, or a node that names itself as its own parent)
    are skipped with a warning. Failure events are forwarded like any other
    event; this method does not raise or stop on them.

    Yields:
        The events produced by each extractor node's run()
    """
    extractor_configs = self._find_extractor_node_configs()
    logger.debug("[Extractor] Found %d extractor nodes for parent '%s'", len(extractor_configs), self._node_id)
    if not extractor_configs:
        return

    # Imported lazily (presumably to avoid an import cycle with the node mapping
    # module - confirm) and only once there is actually something to run.
    from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING

    for config in extractor_configs:
        node_id = config.get("id")
        # `or {}` mirrors the null-safe "data" handling used in __init__.
        node_data = config.get("data") or {}
        node_type_str = node_data.get("type")
        if not node_id or not node_type_str:
            logger.warning(
                "[Extractor] Skipping extractor node without id or type for parent '%s'", self._node_id
            )
            continue

        if node_id == self._node_id:
            # A node listed as its own parent would be instantiated and run from
            # inside its own run(), which calls this method again without end.
            logger.warning("[Extractor] Skipping node '%s': it references itself as parent", node_id)
            continue

        # Get node class
        try:
            node_type = NodeType(node_type_str)
        except ValueError:
            logger.warning(
                "[Extractor] Skipping extractor node '%s': unknown node type '%s'", node_id, node_type_str
            )
            continue

        node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
        if not node_mapping:
            logger.warning(
                "[Extractor] Skipping extractor node '%s': no class mapping for type '%s'", node_id, node_type_str
            )
            continue

        # Prefer the configured version and fall back to the latest registered one.
        node_version = str(node_data.get("version", "1"))
        node_cls = node_mapping.get(node_version) or node_mapping.get(LATEST_VERSION)
        if not node_cls:
            logger.warning(
                "[Extractor] Skipping extractor node '%s': no class for type '%s' version '%s'",
                node_id,
                node_type_str,
                node_version,
            )
            continue

        # Instantiate and execute the extractor node
        extractor_node = node_cls(
            id=node_id,
            config=config,
            graph_init_params=self._graph_init_params,
            graph_runtime_state=self.graph_runtime_state,
        )

        # Execute and process extractor node events
        for event in extractor_node.run():
            if isinstance(event, NodeRunSucceededEvent):
                # Store extractor node outputs in variable pool; tolerate a
                # result that carries no outputs.
                outputs = event.node_run_result.outputs or {}
                for variable_name, variable_value in outputs.items():
                    self.graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value)
            yield event
def run(self) -> Generator[GraphNodeEventBase, None, None]:
execution_id = self.ensure_execution_id()
self._start_at = naive_utc_now()
# Step 1: Execute virtual sub-nodes before main node execution
self._virtual_node_outputs = yield from self._execute_virtual_nodes()
# Step 1: Execute associated extractor nodes before main node execution
yield from self._execute_extractor_nodes()
# Create and push start event with required fields
start_event = NodeRunStartedEvent(

View File

@ -1,213 +0,0 @@
"""
Virtual Node Executor for running embedded sub-nodes within a parent node.
This module handles the execution of virtual nodes defined in a parent node's
`virtual_nodes` configuration. Virtual nodes are complete node definitions
that execute before the parent node.
Example configuration:
virtual_nodes:
- id: ext_1
type: llm
data:
model: {...}
prompt_template: [...]
"""
import time
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
from uuid import uuid4
from core.workflow.enums import NodeType
from core.workflow.graph_events import (
GraphNodeEventBase,
NodeRunFailedEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunSucceededEvent,
)
from libs.datetime_utils import naive_utc_now
from .entities import RetryConfig, VirtualNodeConfig
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
class VirtualNodeExecutionError(Exception):
    """Raised when a virtual node fails to execute.

    The failing node's ID and the underlying exception stay available as
    ``node_id`` and ``original_error``.
    """

    def __init__(self, node_id: str, original_error: Exception):
        message = f"Virtual node {node_id} execution failed: {original_error}"
        super().__init__(message)
        self.node_id = node_id
        self.original_error = original_error
class VirtualNodeExecutor:
    """
    Executes virtual sub-nodes embedded within a parent node.

    Virtual nodes are complete node definitions that execute before the parent node.
    Each virtual node:
    - Has its own global ID: "{parent_id}.{local_id}"
    - Generates standard node events
    - Has its outputs returned to the caller, keyed by its local ID
    - Supports retry via parent node's retry config
    """

    def __init__(
        self,
        *,
        graph_init_params: "GraphInitParams",
        graph_runtime_state: "GraphRuntimeState",
        parent_node_id: str,
        parent_retry_config: RetryConfig | None = None,
    ):
        """
        Args:
            graph_init_params: Passed through to every virtual node that is instantiated
            graph_runtime_state: Passed through to every virtual node that is instantiated
            parent_node_id: ID of the node the virtual nodes belong to
            parent_retry_config: Retry settings applied to each virtual node;
                a default RetryConfig() is used when omitted
        """
        self._graph_init_params = graph_init_params
        self._graph_runtime_state = graph_runtime_state
        self._parent_node_id = parent_node_id
        self._parent_retry_config = parent_retry_config or RetryConfig()

    def execute_virtual_nodes(
        self,
        virtual_nodes: list[VirtualNodeConfig],
    ) -> Generator[GraphNodeEventBase, None, dict[str, Any]]:
        """
        Execute all virtual nodes in order.

        Args:
            virtual_nodes: List of virtual node configurations

        Yields:
            Node events from each virtual node execution

        Returns:
            dict mapping local_id -> outputs dict

        Raises:
            VirtualNodeExecutionError: If a virtual node still fails after all retries
        """
        results: dict[str, Any] = {}
        for vnode_config in virtual_nodes:
            global_id = vnode_config.get_global_id(self._parent_node_id)
            # Execute with retry
            results[vnode_config.id] = yield from self._execute_with_retry(vnode_config, global_id)
        return results

    def _execute_with_retry(
        self,
        vnode_config: VirtualNodeConfig,
        global_id: str,
    ) -> Generator[GraphNodeEventBase, None, dict[str, Any]]:
        """
        Execute virtual node with retry support.

        The node is attempted ``max_retries + 1`` times. Before each retry a
        NodeRunRetryEvent is yielded and the configured retry interval is
        slept (blocking).

        Returns:
            The outputs of the first successful attempt

        Raises:
            VirtualNodeExecutionError: Wrapping the error of the last failed attempt
        """
        retry_config = self._parent_retry_config
        # Always make at least one attempt, even if max_retries is misconfigured
        # as a negative number.
        max_retries = max(retry_config.max_retries, 0)

        for attempt in range(max_retries + 1):
            try:
                return (yield from self._execute_single_node(vnode_config, global_id))
            except Exception as e:
                if attempt >= max_retries:
                    raise VirtualNodeExecutionError(global_id, e) from e

                # Yield retry event
                yield NodeRunRetryEvent(
                    id=str(uuid4()),
                    node_id=global_id,
                    node_type=self._get_node_type(vnode_config.type),
                    node_title=self._get_node_title(vnode_config),
                    start_at=naive_utc_now(),
                    error=str(e),
                    retry_index=attempt + 1,
                )
                time.sleep(retry_config.retry_interval_seconds)

        # Not reachable: the final attempt either returns or raises above.
        raise VirtualNodeExecutionError(global_id, Exception("Unknown error"))

    def _execute_single_node(
        self,
        vnode_config: VirtualNodeConfig,
        global_id: str,
    ) -> Generator[GraphNodeEventBase, None, dict[str, Any]]:
        """
        Execute a single virtual node by instantiating and running it.

        Yields:
            Every event produced by the node's run()

        Returns:
            The outputs carried by the node's NodeRunSucceededEvent, or {} if none

        Raises:
            ValueError: If no node class is registered for the node's type
            RuntimeError: If the node emits a NodeRunFailedEvent (raised after
                the failure event itself has been yielded)
        """
        from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING

        # Build node config
        node_config: dict[str, Any] = {
            "id": global_id,
            "data": {
                **vnode_config.data,
                "title": self._get_node_title(vnode_config),
            },
        }

        # Get the node class for this type
        node_type = self._get_node_type(vnode_config.type)
        node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
        if not node_mapping:
            raise ValueError(f"No class mapping found for node type: {node_type}")

        # Prefer the configured version and fall back to the latest registered one.
        node_version = str(vnode_config.data.get("version", "1"))
        node_cls = node_mapping.get(node_version) or node_mapping.get(LATEST_VERSION)
        if not node_cls:
            raise ValueError(f"No class found for node type: {node_type}")

        # Instantiate the node
        node = node_cls(
            id=global_id,
            config=node_config,
            graph_init_params=self._graph_init_params,
            graph_runtime_state=self._graph_runtime_state,
        )

        # Run and collect events
        outputs: dict[str, Any] = {}
        for event in node.run():
            # Mark event as coming from virtual node
            self._mark_event_as_virtual(event, vnode_config)
            yield event

            if isinstance(event, NodeRunSucceededEvent):
                outputs = event.node_run_result.outputs or {}
            elif isinstance(event, NodeRunFailedEvent):
                raise RuntimeError(event.error or "Virtual node execution failed")

        return outputs

    def _mark_event_as_virtual(
        self,
        event: GraphNodeEventBase,
        vnode_config: VirtualNodeConfig,
    ) -> None:
        """Mark event as coming from a virtual node.

        Only NodeRunStartedEvent instances are tagged; other events are left
        untouched. ``vnode_config`` is currently unused.
        """
        if isinstance(event, NodeRunStartedEvent):
            event.is_virtual = True
            event.parent_node_id = self._parent_node_id

    def _get_node_title(self, vnode_config: VirtualNodeConfig) -> str:
        """Return the configured title, or "Virtual: {local_id}" when none is set."""
        return vnode_config.data.get("title", f"Virtual: {vnode_config.id}")

    def _get_node_type(self, type_str: str) -> NodeType:
        """Convert type string to NodeType enum.

        Resolution order:
        1. the explicit table of known type strings below;
        2. a lookup by enum value, ``NodeType(type_str)``, so node types that
           are missing from the table are still resolved to their own type;
        3. ``NodeType.LLM`` as the historical default for unknown strings.
        """
        type_mapping = {
            "llm": NodeType.LLM,
            "code": NodeType.CODE,
            "tool": NodeType.TOOL,
            "if-else": NodeType.IF_ELSE,
            "question-classifier": NodeType.QUESTION_CLASSIFIER,
            "parameter-extractor": NodeType.PARAMETER_EXTRACTOR,
            "template-transform": NodeType.TEMPLATE_TRANSFORM,
            "variable-assigner": NodeType.VARIABLE_ASSIGNER,
            "http-request": NodeType.HTTP_REQUEST,
            "knowledge-retrieval": NodeType.KNOWLEDGE_RETRIEVAL,
        }
        known_type = type_mapping.get(type_str)
        if known_type is not None:
            return known_type
        try:
            return NodeType(type_str)
        except ValueError:
            # Kept for backward compatibility: unknown strings fall back to LLM.
            return NodeType.LLM