mirror of
https://github.com/langgenius/dify.git
synced 2026-04-27 14:08:18 +08:00
feat: add mention type variable
This commit is contained in:
@ -1047,6 +1047,8 @@ class ToolManager:
|
||||
continue
|
||||
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
|
||||
if tool_input.type == "variable":
|
||||
if not isinstance(tool_input.value, list):
|
||||
raise ToolParameterError(f"Invalid variable selector for {parameter.name}")
|
||||
variable = variable_pool.get(tool_input.value)
|
||||
if variable is None:
|
||||
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
|
||||
@ -1056,6 +1058,11 @@ class ToolManager:
|
||||
elif tool_input.type == "mixed":
|
||||
segment_group = variable_pool.convert_template(str(tool_input.value))
|
||||
parameter_value = segment_group.text
|
||||
elif tool_input.type == "mention":
|
||||
# Mention type not supported in agent mode
|
||||
raise ToolParameterError(
|
||||
f"Mention type not supported in agent for parameter '{parameter.name}'"
|
||||
)
|
||||
else:
|
||||
raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
|
||||
runtime_parameters[parameter.name] = parameter_value
|
||||
|
||||
@ -4,6 +4,7 @@ from .segments import (
|
||||
ArrayFileSegment,
|
||||
ArrayNumberSegment,
|
||||
ArrayObjectSegment,
|
||||
ArrayPromptMessageSegment,
|
||||
ArraySegment,
|
||||
ArrayStringSegment,
|
||||
FileSegment,
|
||||
@ -20,6 +21,7 @@ from .variables import (
|
||||
ArrayFileVariable,
|
||||
ArrayNumberVariable,
|
||||
ArrayObjectVariable,
|
||||
ArrayPromptMessageVariable,
|
||||
ArrayStringVariable,
|
||||
ArrayVariable,
|
||||
FileVariable,
|
||||
@ -41,6 +43,8 @@ __all__ = [
|
||||
"ArrayNumberVariable",
|
||||
"ArrayObjectSegment",
|
||||
"ArrayObjectVariable",
|
||||
"ArrayPromptMessageSegment",
|
||||
"ArrayPromptMessageVariable",
|
||||
"ArraySegment",
|
||||
"ArrayStringSegment",
|
||||
"ArrayStringVariable",
|
||||
|
||||
@ -6,6 +6,7 @@ from typing import Annotated, Any, TypeAlias
|
||||
from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator
|
||||
|
||||
from core.file import File
|
||||
from core.model_runtime.entities import PromptMessage
|
||||
|
||||
from .types import SegmentType
|
||||
|
||||
@ -208,6 +209,15 @@ class ArrayBooleanSegment(ArraySegment):
|
||||
value: Sequence[bool]
|
||||
|
||||
|
||||
class ArrayPromptMessageSegment(ArraySegment):
|
||||
value_type: SegmentType = SegmentType.ARRAY_PROMPT_MESSAGE
|
||||
value: Sequence[PromptMessage]
|
||||
|
||||
def to_object(self):
|
||||
"""Convert to JSON-serializable format for database storage and frontend."""
|
||||
return [msg.model_dump() for msg in self.value]
|
||||
|
||||
|
||||
def get_segment_discriminator(v: Any) -> SegmentType | None:
|
||||
if isinstance(v, Segment):
|
||||
return v.value_type
|
||||
@ -248,6 +258,7 @@ SegmentUnion: TypeAlias = Annotated[
|
||||
| Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
|
||||
| Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
|
||||
| Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)]
|
||||
| Annotated[ArrayPromptMessageSegment, Tag(SegmentType.ARRAY_PROMPT_MESSAGE)]
|
||||
),
|
||||
Discriminator(get_segment_discriminator),
|
||||
]
|
||||
|
||||
@ -45,6 +45,7 @@ class SegmentType(StrEnum):
|
||||
ARRAY_OBJECT = "array[object]"
|
||||
ARRAY_FILE = "array[file]"
|
||||
ARRAY_BOOLEAN = "array[boolean]"
|
||||
ARRAY_PROMPT_MESSAGE = "array[message]"
|
||||
|
||||
NONE = "none"
|
||||
|
||||
|
||||
@ -3,8 +3,10 @@ from typing import Any
|
||||
|
||||
import orjson
|
||||
|
||||
from core.model_runtime.entities import PromptMessage
|
||||
|
||||
from .segment_group import SegmentGroup
|
||||
from .segments import ArrayFileSegment, FileSegment, Segment
|
||||
from .segments import ArrayFileSegment, ArrayPromptMessageSegment, FileSegment, Segment
|
||||
|
||||
|
||||
def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[str]:
|
||||
@ -16,7 +18,7 @@ def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[
|
||||
|
||||
def segment_orjson_default(o: Any):
|
||||
"""Default function for orjson serialization of Segment types"""
|
||||
if isinstance(o, ArrayFileSegment):
|
||||
if isinstance(o, (ArrayFileSegment, ArrayPromptMessageSegment)):
|
||||
return [v.model_dump() for v in o.value]
|
||||
elif isinstance(o, FileSegment):
|
||||
return o.value.model_dump()
|
||||
@ -24,6 +26,8 @@ def segment_orjson_default(o: Any):
|
||||
return [segment_orjson_default(seg) for seg in o.value]
|
||||
elif isinstance(o, Segment):
|
||||
return o.value
|
||||
elif isinstance(o, PromptMessage):
|
||||
return o.model_dump()
|
||||
raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable")
|
||||
|
||||
|
||||
|
||||
@ -12,6 +12,7 @@ from .segments import (
|
||||
ArrayFileSegment,
|
||||
ArrayNumberSegment,
|
||||
ArrayObjectSegment,
|
||||
ArrayPromptMessageSegment,
|
||||
ArraySegment,
|
||||
ArrayStringSegment,
|
||||
BooleanSegment,
|
||||
@ -110,6 +111,10 @@ class ArrayBooleanVariable(ArrayBooleanSegment, ArrayVariable):
|
||||
pass
|
||||
|
||||
|
||||
class ArrayPromptMessageVariable(ArrayPromptMessageSegment, ArrayVariable):
|
||||
pass
|
||||
|
||||
|
||||
class RAGPipelineVariable(BaseModel):
|
||||
belong_to_node_id: str = Field(description="belong to which node id, shared means public")
|
||||
type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list")
|
||||
@ -160,6 +165,7 @@ VariableUnion: TypeAlias = Annotated[
|
||||
| Annotated[ArrayObjectVariable, Tag(SegmentType.ARRAY_OBJECT)]
|
||||
| Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)]
|
||||
| Annotated[ArrayBooleanVariable, Tag(SegmentType.ARRAY_BOOLEAN)]
|
||||
| Annotated[ArrayPromptMessageVariable, Tag(SegmentType.ARRAY_PROMPT_MESSAGE)]
|
||||
| Annotated[SecretVariable, Tag(SegmentType.SECRET)]
|
||||
),
|
||||
Discriminator(get_segment_discriminator),
|
||||
|
||||
@ -311,9 +311,9 @@ class Graph:
|
||||
# - custom-note: top-level type (node_config.type == "custom-note")
|
||||
# - group: data-level type (node_config.data.type == "group")
|
||||
node_configs = [
|
||||
node_config for node_config in node_configs
|
||||
if node_config.get("type", "") != "custom-note"
|
||||
and node_config.get("data", {}).get("type", "") != "group"
|
||||
node_config
|
||||
for node_config in node_configs
|
||||
if node_config.get("type", "") != "custom-note" and node_config.get("data", {}).get("type", "") != "group"
|
||||
]
|
||||
|
||||
# Parse node configurations
|
||||
|
||||
@ -125,9 +125,9 @@ class EventHandler:
|
||||
Args:
|
||||
event: The node started event
|
||||
"""
|
||||
# Check if this is a virtual node (extraction node)
|
||||
if self._is_virtual_node(event.node_id):
|
||||
self._handle_virtual_node_started(event)
|
||||
# Check if this is an extractor node (has parent_node_id)
|
||||
if self._is_extractor_node(event.node_id):
|
||||
self._handle_extractor_node_started(event)
|
||||
return
|
||||
|
||||
# Track execution in domain model
|
||||
@ -169,9 +169,9 @@ class EventHandler:
|
||||
Args:
|
||||
event: The node succeeded event
|
||||
"""
|
||||
# Check if this is a virtual node (extraction node)
|
||||
if self._is_virtual_node(event.node_id):
|
||||
self._handle_virtual_node_success(event)
|
||||
# Check if this is an extractor node (has parent_node_id)
|
||||
if self._is_extractor_node(event.node_id):
|
||||
self._handle_extractor_node_success(event)
|
||||
return
|
||||
|
||||
# Update domain model
|
||||
@ -236,9 +236,9 @@ class EventHandler:
|
||||
Args:
|
||||
event: The node failed event
|
||||
"""
|
||||
# Check if this is a virtual node (extraction node)
|
||||
if self._is_virtual_node(event.node_id):
|
||||
self._handle_virtual_node_failed(event)
|
||||
# Check if this is an extractor node (has parent_node_id)
|
||||
if self._is_extractor_node(event.node_id):
|
||||
self._handle_extractor_node_failed(event)
|
||||
return
|
||||
|
||||
# Update domain model
|
||||
@ -361,23 +361,23 @@ class EventHandler:
|
||||
else:
|
||||
self._graph_runtime_state.set_output(key, value)
|
||||
|
||||
def _is_virtual_node(self, node_id: str) -> bool:
|
||||
def _is_extractor_node(self, node_id: str) -> bool:
|
||||
"""
|
||||
Check if node_id represents a virtual sub-node.
|
||||
Check if node_id represents an extractor node (has parent_node_id).
|
||||
|
||||
Virtual nodes have IDs in the format: {parent_node_id}.{local_id}
|
||||
We check if the part before '.' exists in graph nodes.
|
||||
Extractor nodes extract values from list[PromptMessage] for their parent node.
|
||||
They have a parent_node_id field pointing to their parent node.
|
||||
"""
|
||||
if "." in node_id:
|
||||
parent_id = node_id.rsplit(".", 1)[0]
|
||||
return parent_id in self._graph.nodes
|
||||
return False
|
||||
node = self._graph.nodes.get(node_id)
|
||||
if node is None:
|
||||
return False
|
||||
return node.node_data.is_extractor_node
|
||||
|
||||
def _handle_virtual_node_started(self, event: NodeRunStartedEvent) -> None:
|
||||
def _handle_extractor_node_started(self, event: NodeRunStartedEvent) -> None:
|
||||
"""
|
||||
Handle virtual node started event.
|
||||
Handle extractor node started event.
|
||||
|
||||
Virtual nodes don't need full execution tracking, just collect the event.
|
||||
Extractor nodes don't need full execution tracking, just collect the event.
|
||||
"""
|
||||
# Track in response coordinator for stream ordering
|
||||
self._response_coordinator.track_node_execution(event.node_id, event.id)
|
||||
@ -385,11 +385,11 @@ class EventHandler:
|
||||
# Collect the event
|
||||
self._event_collector.collect(event)
|
||||
|
||||
def _handle_virtual_node_success(self, event: NodeRunSucceededEvent) -> None:
|
||||
def _handle_extractor_node_success(self, event: NodeRunSucceededEvent) -> None:
|
||||
"""
|
||||
Handle virtual node success event.
|
||||
Handle extractor node success event.
|
||||
|
||||
Virtual nodes (extraction nodes) need special handling:
|
||||
Extractor nodes need special handling:
|
||||
- Store outputs in variable pool (for reference by other nodes)
|
||||
- Accumulate token usage
|
||||
- Collect the event for logging
|
||||
@ -403,11 +403,11 @@ class EventHandler:
|
||||
# Collect the event
|
||||
self._event_collector.collect(event)
|
||||
|
||||
def _handle_virtual_node_failed(self, event: NodeRunFailedEvent) -> None:
|
||||
def _handle_extractor_node_failed(self, event: NodeRunFailedEvent) -> None:
|
||||
"""
|
||||
Handle virtual node failed event.
|
||||
Handle extractor node failed event.
|
||||
|
||||
Virtual nodes (extraction nodes) failures are collected for logging,
|
||||
Extractor node failures are collected for logging,
|
||||
but the parent node is responsible for handling the error.
|
||||
"""
|
||||
self._accumulate_node_usage(event.node_run_result.llm_usage)
|
||||
|
||||
@ -20,12 +20,6 @@ class NodeRunStartedEvent(GraphNodeEventBase):
|
||||
provider_type: str = ""
|
||||
provider_id: str = ""
|
||||
|
||||
# Virtual node fields for extraction
|
||||
is_virtual: bool = False
|
||||
parent_node_id: str | None = None
|
||||
extraction_source: str | None = None # e.g., "llm1.context"
|
||||
extraction_prompt: str | None = None
|
||||
|
||||
|
||||
class NodeRunStreamChunkEvent(GraphNodeEventBase):
|
||||
# Spec-compliant fields
|
||||
|
||||
@ -4,10 +4,8 @@ from .entities import (
|
||||
BaseLoopNodeData,
|
||||
BaseLoopState,
|
||||
BaseNodeData,
|
||||
VirtualNodeConfig,
|
||||
)
|
||||
from .usage_tracking_mixin import LLMUsageTrackingMixin
|
||||
from .virtual_node_executor import VirtualNodeExecutionError, VirtualNodeExecutor
|
||||
|
||||
__all__ = [
|
||||
"BaseIterationNodeData",
|
||||
@ -16,7 +14,4 @@ __all__ = [
|
||||
"BaseLoopState",
|
||||
"BaseNodeData",
|
||||
"LLMUsageTrackingMixin",
|
||||
"VirtualNodeConfig",
|
||||
"VirtualNodeExecutionError",
|
||||
"VirtualNodeExecutor",
|
||||
]
|
||||
|
||||
@ -167,24 +167,6 @@ class DefaultValue(BaseModel):
|
||||
return self
|
||||
|
||||
|
||||
class VirtualNodeConfig(BaseModel):
|
||||
"""Configuration for a virtual sub-node embedded within a parent node."""
|
||||
|
||||
# Local ID within parent node (e.g., "ext_1")
|
||||
# Will be converted to global ID: "{parent_id}.{id}"
|
||||
id: str
|
||||
|
||||
# Node type (e.g., "llm", "code", "tool")
|
||||
type: str
|
||||
|
||||
# Full node data configuration
|
||||
data: dict[str, Any] = {}
|
||||
|
||||
def get_global_id(self, parent_node_id: str) -> str:
|
||||
"""Get the global node ID by combining parent ID and local ID."""
|
||||
return f"{parent_node_id}.{self.id}"
|
||||
|
||||
|
||||
class BaseNodeData(ABC, BaseModel):
|
||||
title: str
|
||||
desc: str | None = None
|
||||
@ -193,8 +175,15 @@ class BaseNodeData(ABC, BaseModel):
|
||||
default_value: list[DefaultValue] | None = None
|
||||
retry_config: RetryConfig = RetryConfig()
|
||||
|
||||
# Virtual sub-nodes that execute before the main node
|
||||
virtual_nodes: list[VirtualNodeConfig] = []
|
||||
# Parent node ID when this node is used as an extractor.
|
||||
# If set, this node is an "attached" extractor node that extracts values
|
||||
# from list[PromptMessage] for the parent node's parameters.
|
||||
parent_node_id: str | None = None
|
||||
|
||||
@property
|
||||
def is_extractor_node(self) -> bool:
|
||||
"""Check if this node is an extractor node (has parent_node_id)."""
|
||||
return self.parent_node_id is not None
|
||||
|
||||
@property
|
||||
def default_value_dict(self) -> dict[str, Any]:
|
||||
|
||||
@ -229,7 +229,6 @@ class Node(Generic[NodeDataT]):
|
||||
self._node_id = node_id
|
||||
self._node_execution_id: str = ""
|
||||
self._start_at = naive_utc_now()
|
||||
self._virtual_node_outputs: dict[str, Any] = {} # Outputs from virtual sub-nodes
|
||||
|
||||
raw_node_data = config.get("data") or {}
|
||||
if not isinstance(raw_node_data, Mapping):
|
||||
@ -271,51 +270,81 @@ class Node(Generic[NodeDataT]):
|
||||
"""Check if execution should be stopped."""
|
||||
return self.graph_runtime_state.stop_event.is_set()
|
||||
|
||||
def _execute_virtual_nodes(self) -> Generator[GraphNodeEventBase, None, dict[str, Any]]:
|
||||
def _find_extractor_node_configs(self) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Execute all virtual sub-nodes defined in node configuration.
|
||||
|
||||
Virtual nodes are complete node definitions that execute before the main node.
|
||||
Each virtual node:
|
||||
- Has its own global ID: "{parent_id}.{local_id}"
|
||||
- Generates standard node events
|
||||
- Stores outputs in the variable pool (via event handling)
|
||||
- Supports retry via parent node's retry config
|
||||
Find all extractor node configurations that have parent_node_id == self._node_id.
|
||||
|
||||
Returns:
|
||||
dict mapping local_id -> outputs dict
|
||||
List of node configuration dicts for extractor nodes
|
||||
"""
|
||||
from .virtual_node_executor import VirtualNodeExecutor
|
||||
nodes = self.graph_config.get("nodes", [])
|
||||
extractor_configs = []
|
||||
for node_config in nodes:
|
||||
node_data = node_config.get("data", {})
|
||||
if node_data.get("parent_node_id") == self._node_id:
|
||||
extractor_configs.append(node_config)
|
||||
return extractor_configs
|
||||
|
||||
virtual_nodes = self.node_data.virtual_nodes
|
||||
if not virtual_nodes:
|
||||
return {}
|
||||
|
||||
executor = VirtualNodeExecutor(
|
||||
graph_init_params=self._graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
parent_node_id=self._node_id,
|
||||
parent_retry_config=self.retry_config,
|
||||
)
|
||||
|
||||
return (yield from executor.execute_virtual_nodes(virtual_nodes))
|
||||
|
||||
@property
|
||||
def virtual_node_outputs(self) -> dict[str, Any]:
|
||||
def _execute_extractor_nodes(self) -> Generator[GraphNodeEventBase, None, None]:
|
||||
"""
|
||||
Get the outputs from virtual sub-nodes.
|
||||
Execute all extractor nodes associated with this node.
|
||||
|
||||
Returns:
|
||||
dict mapping local_id -> outputs dict
|
||||
Extractor nodes are nodes with parent_node_id == self._node_id.
|
||||
They are executed before the main node to extract values from list[PromptMessage].
|
||||
"""
|
||||
return self._virtual_node_outputs
|
||||
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
|
||||
extractor_configs = self._find_extractor_node_configs()
|
||||
logger.debug("[Extractor] Found %d extractor nodes for parent '%s'", len(extractor_configs), self._node_id)
|
||||
if not extractor_configs:
|
||||
return
|
||||
|
||||
for config in extractor_configs:
|
||||
node_id = config.get("id")
|
||||
node_data = config.get("data", {})
|
||||
node_type_str = node_data.get("type")
|
||||
|
||||
if not node_id or not node_type_str:
|
||||
continue
|
||||
|
||||
# Get node class
|
||||
try:
|
||||
node_type = NodeType(node_type_str)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
|
||||
if not node_mapping:
|
||||
continue
|
||||
|
||||
node_version = str(node_data.get("version", "1"))
|
||||
node_cls = node_mapping.get(node_version) or node_mapping.get(LATEST_VERSION)
|
||||
if not node_cls:
|
||||
continue
|
||||
|
||||
# Instantiate and execute the extractor node
|
||||
extractor_node = node_cls(
|
||||
id=node_id,
|
||||
config=config,
|
||||
graph_init_params=self._graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
)
|
||||
|
||||
# Execute and process extractor node events
|
||||
for event in extractor_node.run():
|
||||
if isinstance(event, NodeRunSucceededEvent):
|
||||
# Store extractor node outputs in variable pool
|
||||
outputs = event.node_run_result.outputs
|
||||
for variable_name, variable_value in outputs.items():
|
||||
self.graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value)
|
||||
yield event
|
||||
|
||||
def run(self) -> Generator[GraphNodeEventBase, None, None]:
|
||||
execution_id = self.ensure_execution_id()
|
||||
self._start_at = naive_utc_now()
|
||||
|
||||
# Step 1: Execute virtual sub-nodes before main node execution
|
||||
self._virtual_node_outputs = yield from self._execute_virtual_nodes()
|
||||
# Step 1: Execute associated extractor nodes before main node execution
|
||||
yield from self._execute_extractor_nodes()
|
||||
|
||||
# Create and push start event with required fields
|
||||
start_event = NodeRunStartedEvent(
|
||||
|
||||
@ -1,213 +0,0 @@
|
||||
"""
|
||||
Virtual Node Executor for running embedded sub-nodes within a parent node.
|
||||
|
||||
This module handles the execution of virtual nodes defined in a parent node's
|
||||
`virtual_nodes` configuration. Virtual nodes are complete node definitions
|
||||
that execute before the parent node.
|
||||
|
||||
Example configuration:
|
||||
virtual_nodes:
|
||||
- id: ext_1
|
||||
type: llm
|
||||
data:
|
||||
model: {...}
|
||||
prompt_template: [...]
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from uuid import uuid4
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunRetryEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
from .entities import RetryConfig, VirtualNodeConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
class VirtualNodeExecutionError(Exception):
|
||||
"""Error during virtual node execution"""
|
||||
|
||||
def __init__(self, node_id: str, original_error: Exception):
|
||||
self.node_id = node_id
|
||||
self.original_error = original_error
|
||||
super().__init__(f"Virtual node {node_id} execution failed: {original_error}")
|
||||
|
||||
|
||||
class VirtualNodeExecutor:
|
||||
"""
|
||||
Executes virtual sub-nodes embedded within a parent node.
|
||||
|
||||
Virtual nodes are complete node definitions that execute before the parent node.
|
||||
Each virtual node:
|
||||
- Has its own global ID: "{parent_id}.{local_id}"
|
||||
- Generates standard node events
|
||||
- Stores outputs in the variable pool
|
||||
- Supports retry via parent node's retry config
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
parent_node_id: str,
|
||||
parent_retry_config: RetryConfig | None = None,
|
||||
):
|
||||
self._graph_init_params = graph_init_params
|
||||
self._graph_runtime_state = graph_runtime_state
|
||||
self._parent_node_id = parent_node_id
|
||||
self._parent_retry_config = parent_retry_config or RetryConfig()
|
||||
|
||||
def execute_virtual_nodes(
|
||||
self,
|
||||
virtual_nodes: list[VirtualNodeConfig],
|
||||
) -> Generator[GraphNodeEventBase, None, dict[str, Any]]:
|
||||
"""
|
||||
Execute all virtual nodes in order.
|
||||
|
||||
Args:
|
||||
virtual_nodes: List of virtual node configurations
|
||||
|
||||
Yields:
|
||||
Node events from each virtual node execution
|
||||
|
||||
Returns:
|
||||
dict mapping local_id -> outputs dict
|
||||
"""
|
||||
results: dict[str, Any] = {}
|
||||
|
||||
for vnode_config in virtual_nodes:
|
||||
global_id = vnode_config.get_global_id(self._parent_node_id)
|
||||
|
||||
# Execute with retry
|
||||
outputs = yield from self._execute_with_retry(vnode_config, global_id)
|
||||
results[vnode_config.id] = outputs
|
||||
|
||||
return results
|
||||
|
||||
def _execute_with_retry(
|
||||
self,
|
||||
vnode_config: VirtualNodeConfig,
|
||||
global_id: str,
|
||||
) -> Generator[GraphNodeEventBase, None, dict[str, Any]]:
|
||||
"""
|
||||
Execute virtual node with retry support.
|
||||
"""
|
||||
retry_config = self._parent_retry_config
|
||||
last_error: Exception | None = None
|
||||
|
||||
for attempt in range(retry_config.max_retries + 1):
|
||||
try:
|
||||
return (yield from self._execute_single_node(vnode_config, global_id))
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
|
||||
if attempt < retry_config.max_retries:
|
||||
# Yield retry event
|
||||
yield NodeRunRetryEvent(
|
||||
id=str(uuid4()),
|
||||
node_id=global_id,
|
||||
node_type=self._get_node_type(vnode_config.type),
|
||||
node_title=vnode_config.data.get("title", f"Virtual: {vnode_config.id}"),
|
||||
start_at=naive_utc_now(),
|
||||
error=str(e),
|
||||
retry_index=attempt + 1,
|
||||
)
|
||||
|
||||
time.sleep(retry_config.retry_interval_seconds)
|
||||
continue
|
||||
|
||||
raise VirtualNodeExecutionError(global_id, e) from e
|
||||
|
||||
raise last_error or VirtualNodeExecutionError(global_id, Exception("Unknown error"))
|
||||
|
||||
def _execute_single_node(
|
||||
self,
|
||||
vnode_config: VirtualNodeConfig,
|
||||
global_id: str,
|
||||
) -> Generator[GraphNodeEventBase, None, dict[str, Any]]:
|
||||
"""
|
||||
Execute a single virtual node by instantiating and running it.
|
||||
"""
|
||||
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
|
||||
# Build node config
|
||||
node_config: dict[str, Any] = {
|
||||
"id": global_id,
|
||||
"data": {
|
||||
**vnode_config.data,
|
||||
"title": vnode_config.data.get("title", f"Virtual: {vnode_config.id}"),
|
||||
},
|
||||
}
|
||||
|
||||
# Get the node class for this type
|
||||
node_type = self._get_node_type(vnode_config.type)
|
||||
node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
|
||||
if not node_mapping:
|
||||
raise ValueError(f"No class mapping found for node type: {node_type}")
|
||||
|
||||
node_version = str(vnode_config.data.get("version", "1"))
|
||||
node_cls = node_mapping.get(node_version) or node_mapping.get(LATEST_VERSION)
|
||||
if not node_cls:
|
||||
raise ValueError(f"No class found for node type: {node_type}")
|
||||
|
||||
# Instantiate the node
|
||||
node = node_cls(
|
||||
id=global_id,
|
||||
config=node_config,
|
||||
graph_init_params=self._graph_init_params,
|
||||
graph_runtime_state=self._graph_runtime_state,
|
||||
)
|
||||
|
||||
# Run and collect events
|
||||
outputs: dict[str, Any] = {}
|
||||
|
||||
for event in node.run():
|
||||
# Mark event as coming from virtual node
|
||||
self._mark_event_as_virtual(event, vnode_config)
|
||||
yield event
|
||||
|
||||
if isinstance(event, NodeRunSucceededEvent):
|
||||
outputs = event.node_run_result.outputs or {}
|
||||
elif isinstance(event, NodeRunFailedEvent):
|
||||
raise Exception(event.error or "Virtual node execution failed")
|
||||
|
||||
return outputs
|
||||
|
||||
def _mark_event_as_virtual(
|
||||
self,
|
||||
event: GraphNodeEventBase,
|
||||
vnode_config: VirtualNodeConfig,
|
||||
) -> None:
|
||||
"""Mark event as coming from a virtual node."""
|
||||
if isinstance(event, NodeRunStartedEvent):
|
||||
event.is_virtual = True
|
||||
event.parent_node_id = self._parent_node_id
|
||||
|
||||
def _get_node_type(self, type_str: str) -> NodeType:
|
||||
"""Convert type string to NodeType enum."""
|
||||
type_mapping = {
|
||||
"llm": NodeType.LLM,
|
||||
"code": NodeType.CODE,
|
||||
"tool": NodeType.TOOL,
|
||||
"if-else": NodeType.IF_ELSE,
|
||||
"question-classifier": NodeType.QUESTION_CLASSIFIER,
|
||||
"parameter-extractor": NodeType.PARAMETER_EXTRACTOR,
|
||||
"template-transform": NodeType.TEMPLATE_TRANSFORM,
|
||||
"variable-assigner": NodeType.VARIABLE_ASSIGNER,
|
||||
"http-request": NodeType.HTTP_REQUEST,
|
||||
"knowledge-retrieval": NodeType.KNOWLEDGE_RETRIEVAL,
|
||||
}
|
||||
return type_mapping.get(type_str, NodeType.LLM)
|
||||
@ -1,7 +1,7 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Literal
|
||||
from typing import Annotated, Any, Literal, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||
@ -58,9 +58,28 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
|
||||
jinja2_text: str | None = None
|
||||
|
||||
|
||||
class PromptMessageContext(BaseModel):
|
||||
"""Context variable reference in prompt template.
|
||||
|
||||
YAML/JSON format: { "$context": ["node_id", "variable_name"] }
|
||||
This will be expanded to list[PromptMessage] at runtime.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
value_selector: Sequence[str] = Field(alias="$context")
|
||||
|
||||
|
||||
# Union type for prompt template items (static message or context variable reference)
|
||||
PromptTemplateItem: TypeAlias = Annotated[
|
||||
LLMNodeChatModelMessage | PromptMessageContext,
|
||||
Field(discriminator=None),
|
||||
]
|
||||
|
||||
|
||||
class LLMNodeData(BaseNodeData):
|
||||
model: ModelConfig
|
||||
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
|
||||
prompt_template: Sequence[PromptTemplateItem] | LLMNodeCompletionModelPromptTemplate
|
||||
prompt_config: PromptConfig = Field(default_factory=PromptConfig)
|
||||
memory: MemoryConfig | None = None
|
||||
context: ContextConfig
|
||||
|
||||
@ -7,7 +7,7 @@ import logging
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
@ -52,6 +52,7 @@ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.tools.signature import sign_upload_file
|
||||
from core.variables import (
|
||||
ArrayFileSegment,
|
||||
ArrayPromptMessageSegment,
|
||||
ArraySegment,
|
||||
FileSegment,
|
||||
NoneSegment,
|
||||
@ -88,6 +89,7 @@ from .entities import (
|
||||
LLMNodeCompletionModelPromptTemplate,
|
||||
LLMNodeData,
|
||||
ModelConfig,
|
||||
PromptMessageContext,
|
||||
)
|
||||
from .exc import (
|
||||
InvalidContextStructureError,
|
||||
@ -160,8 +162,9 @@ class LLMNode(Node[LLMNodeData]):
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
||||
try:
|
||||
# init messages template
|
||||
self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template)
|
||||
# Parse prompt template to separate static messages and context references
|
||||
prompt_template = self.node_data.prompt_template
|
||||
static_messages, context_refs, template_order = self._parse_prompt_template()
|
||||
|
||||
# fetch variables and fetch values from variable pool
|
||||
inputs = self._fetch_inputs(node_data=self.node_data)
|
||||
@ -223,21 +226,40 @@ class LLMNode(Node[LLMNodeData]):
|
||||
):
|
||||
query = query_variable.text
|
||||
|
||||
prompt_messages, stop = LLMNode.fetch_prompt_messages(
|
||||
sys_query=query,
|
||||
sys_files=files,
|
||||
context=context,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
prompt_template=self.node_data.prompt_template,
|
||||
memory_config=self.node_data.memory,
|
||||
vision_enabled=self.node_data.vision.enabled,
|
||||
vision_detail=self.node_data.vision.configs.detail,
|
||||
variable_pool=variable_pool,
|
||||
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
|
||||
tenant_id=self.tenant_id,
|
||||
context_files=context_files,
|
||||
)
|
||||
# Get prompt messages
|
||||
prompt_messages: Sequence[PromptMessage]
|
||||
stop: Sequence[str] | None
|
||||
if isinstance(prompt_template, list) and context_refs:
|
||||
prompt_messages, stop = self._build_prompt_messages_with_context(
|
||||
context_refs=context_refs,
|
||||
template_order=template_order,
|
||||
static_messages=static_messages,
|
||||
query=query,
|
||||
files=files,
|
||||
context=context,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
context_files=context_files,
|
||||
)
|
||||
else:
|
||||
prompt_messages, stop = LLMNode.fetch_prompt_messages(
|
||||
sys_query=query,
|
||||
sys_files=files,
|
||||
context=context,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
prompt_template=cast(
|
||||
Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
|
||||
self.node_data.prompt_template,
|
||||
),
|
||||
memory_config=self.node_data.memory,
|
||||
vision_enabled=self.node_data.vision.enabled,
|
||||
vision_detail=self.node_data.vision.configs.detail,
|
||||
variable_pool=variable_pool,
|
||||
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
|
||||
tenant_id=self.tenant_id,
|
||||
context_files=context_files,
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
generator = LLMNode.invoke_llm(
|
||||
@ -304,7 +326,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
"reasoning_content": reasoning_content,
|
||||
"usage": jsonable_encoder(usage),
|
||||
"finish_reason": finish_reason,
|
||||
"context": self._build_context(prompt_messages, clean_text, model_config.mode),
|
||||
"context": self._build_context(prompt_messages, clean_text),
|
||||
}
|
||||
if structured_output:
|
||||
outputs["structured_output"] = structured_output.structured_output
|
||||
@ -602,17 +624,15 @@ class LLMNode(Node[LLMNodeData]):
|
||||
def _build_context(
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
assistant_response: str,
|
||||
model_mode: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Build context from prompt messages and assistant response.
|
||||
Excludes system messages and includes the current LLM response.
|
||||
Returns list[PromptMessage] for use with ArrayPromptMessageSegment.
|
||||
"""
|
||||
context_messages: list[PromptMessage] = [m for m in prompt_messages if m.role != PromptMessageRole.SYSTEM]
|
||||
context_messages.append(AssistantPromptMessage(content=assistant_response))
|
||||
return PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||
model_mode=model_mode, prompt_messages=context_messages
|
||||
)
|
||||
return context_messages
|
||||
|
||||
def _transform_chat_messages(
|
||||
self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
|
||||
@ -629,6 +649,106 @@ class LLMNode(Node[LLMNodeData]):
|
||||
|
||||
return messages
|
||||
|
||||
def _parse_prompt_template(
|
||||
self,
|
||||
) -> tuple[list[LLMNodeChatModelMessage], list[PromptMessageContext], list[tuple[int, str]]]:
|
||||
"""
|
||||
Parse prompt_template to separate static messages and context references.
|
||||
|
||||
Returns:
|
||||
Tuple of (static_messages, context_refs, template_order)
|
||||
- static_messages: list of LLMNodeChatModelMessage
|
||||
- context_refs: list of PromptMessageContext
|
||||
- template_order: list of (index, type) tuples preserving original order
|
||||
"""
|
||||
prompt_template = self.node_data.prompt_template
|
||||
static_messages: list[LLMNodeChatModelMessage] = []
|
||||
context_refs: list[PromptMessageContext] = []
|
||||
template_order: list[tuple[int, str]] = []
|
||||
|
||||
if isinstance(prompt_template, list):
|
||||
for idx, item in enumerate(prompt_template):
|
||||
if isinstance(item, PromptMessageContext):
|
||||
context_refs.append(item)
|
||||
template_order.append((idx, "context"))
|
||||
else:
|
||||
static_messages.append(item)
|
||||
template_order.append((idx, "static"))
|
||||
# Transform static messages for jinja2
|
||||
if static_messages:
|
||||
self.node_data.prompt_template = self._transform_chat_messages(static_messages)
|
||||
|
||||
return static_messages, context_refs, template_order
|
||||
|
||||
def _build_prompt_messages_with_context(
|
||||
self,
|
||||
*,
|
||||
context_refs: list[PromptMessageContext],
|
||||
template_order: list[tuple[int, str]],
|
||||
static_messages: list[LLMNodeChatModelMessage],
|
||||
query: str | None,
|
||||
files: Sequence[File],
|
||||
context: str | None,
|
||||
memory: BaseMemory | None,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
context_files: list[File],
|
||||
) -> tuple[list[PromptMessage], Sequence[str] | None]:
|
||||
"""
|
||||
Build prompt messages by combining static messages and context references in DSL order.
|
||||
|
||||
Returns:
|
||||
Tuple of (prompt_messages, stop_sequences)
|
||||
"""
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
||||
# Build a map from context index to its messages
|
||||
context_messages_map: dict[int, list[PromptMessage]] = {}
|
||||
context_idx = 0
|
||||
for idx, type_ in template_order:
|
||||
if type_ == "context":
|
||||
ctx_ref = context_refs[context_idx]
|
||||
ctx_var = variable_pool.get(ctx_ref.value_selector)
|
||||
if ctx_var is None:
|
||||
raise VariableNotFoundError(f"Variable {'.'.join(ctx_ref.value_selector)} not found")
|
||||
if not isinstance(ctx_var, ArrayPromptMessageSegment):
|
||||
raise InvalidVariableTypeError(f"Variable {'.'.join(ctx_ref.value_selector)} is not array[message]")
|
||||
context_messages_map[idx] = list(ctx_var.value)
|
||||
context_idx += 1
|
||||
|
||||
# Process static messages
|
||||
static_prompt_messages: Sequence[PromptMessage] = []
|
||||
stop: Sequence[str] | None = None
|
||||
if static_messages:
|
||||
static_prompt_messages, stop = LLMNode.fetch_prompt_messages(
|
||||
sys_query=query,
|
||||
sys_files=files,
|
||||
context=context,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
prompt_template=cast(Sequence[LLMNodeChatModelMessage], self.node_data.prompt_template),
|
||||
memory_config=self.node_data.memory,
|
||||
vision_enabled=self.node_data.vision.enabled,
|
||||
vision_detail=self.node_data.vision.configs.detail,
|
||||
variable_pool=variable_pool,
|
||||
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
|
||||
tenant_id=self.tenant_id,
|
||||
context_files=context_files,
|
||||
)
|
||||
|
||||
# Combine messages according to original DSL order
|
||||
combined_messages: list[PromptMessage] = []
|
||||
static_msg_iter = iter(static_prompt_messages)
|
||||
for idx, type_ in template_order:
|
||||
if type_ == "context":
|
||||
combined_messages.extend(context_messages_map[idx])
|
||||
else:
|
||||
if msg := next(static_msg_iter, None):
|
||||
combined_messages.append(msg)
|
||||
# Append any remaining static messages (e.g., memory messages)
|
||||
combined_messages.extend(static_msg_iter)
|
||||
|
||||
return combined_messages, stop
|
||||
|
||||
def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
@ -7,6 +8,31 @@ from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
|
||||
|
||||
class MentionValue(BaseModel):
|
||||
"""Value structure for mention type parameters.
|
||||
|
||||
Used when a tool parameter needs to be extracted from conversation context
|
||||
using an extractor LLM node.
|
||||
"""
|
||||
|
||||
# Variable selector for list[PromptMessage] input to extractor
|
||||
variable_selector: Sequence[str]
|
||||
|
||||
# ID of the extractor LLM node
|
||||
extractor_node_id: str
|
||||
|
||||
# Output variable selector from extractor node
|
||||
# e.g., ["text"], ["structured_output", "query"]
|
||||
output_selector: Sequence[str]
|
||||
|
||||
# Strategy when output is None
|
||||
null_strategy: Literal["raise_error", "use_default"] = "raise_error"
|
||||
|
||||
# Default value when null_strategy is "use_default"
|
||||
# Type should match the parameter's expected type
|
||||
default_value: Any = None
|
||||
|
||||
|
||||
class ToolEntity(BaseModel):
|
||||
provider_id: str
|
||||
provider_type: ToolProviderType
|
||||
@ -34,8 +60,8 @@ class ToolEntity(BaseModel):
|
||||
class ToolNodeData(BaseNodeData, ToolEntity):
|
||||
class ToolInput(BaseModel):
|
||||
# TODO: check this type
|
||||
value: Union[Any, list[str]]
|
||||
type: Literal["mixed", "variable", "constant"]
|
||||
value: Union[Any, list[str], MentionValue]
|
||||
type: Literal["mixed", "variable", "constant", "mention"]
|
||||
|
||||
@field_validator("type", mode="before")
|
||||
@classmethod
|
||||
@ -56,6 +82,17 @@ class ToolNodeData(BaseNodeData, ToolEntity):
|
||||
raise ValueError("value must be a list of strings")
|
||||
elif typ == "constant" and not isinstance(value, str | int | float | bool | dict):
|
||||
raise ValueError("value must be a string, int, float, bool or dict")
|
||||
elif typ == "mention":
|
||||
# Mention type: value should be a MentionValue or dict with required fields
|
||||
if isinstance(value, MentionValue):
|
||||
pass # Already validated by Pydantic
|
||||
elif isinstance(value, dict):
|
||||
if "extractor_node_id" not in value:
|
||||
raise ValueError("value must contain extractor_node_id for mention type")
|
||||
if "output_selector" not in value:
|
||||
raise ValueError("value must contain output_selector for mention type")
|
||||
else:
|
||||
raise ValueError("value must be a MentionValue or dict for mention type")
|
||||
return typ
|
||||
|
||||
tool_parameters: dict[str, ToolInput]
|
||||
|
||||
@ -1,7 +1,10 @@
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
@ -89,20 +92,18 @@ class ToolNode(Node[ToolNodeData]):
|
||||
)
|
||||
return
|
||||
|
||||
# get parameters (use virtual_node_outputs from base class)
|
||||
# get parameters
|
||||
tool_parameters = tool_runtime.get_merged_runtime_parameters() or []
|
||||
parameters = self._generate_parameters(
|
||||
tool_parameters=tool_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self.node_data,
|
||||
virtual_node_outputs=self.virtual_node_outputs,
|
||||
)
|
||||
parameters_for_log = self._generate_parameters(
|
||||
tool_parameters=tool_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self.node_data,
|
||||
for_log=True,
|
||||
virtual_node_outputs=self.virtual_node_outputs,
|
||||
)
|
||||
# get conversation id
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
||||
@ -178,7 +179,6 @@ class ToolNode(Node[ToolNodeData]):
|
||||
variable_pool: "VariablePool",
|
||||
node_data: ToolNodeData,
|
||||
for_log: bool = False,
|
||||
virtual_node_outputs: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate parameters based on the given tool parameters, variable pool, and node data.
|
||||
@ -188,16 +188,12 @@ class ToolNode(Node[ToolNodeData]):
|
||||
variable_pool (VariablePool): The variable pool containing the variables.
|
||||
node_data (ToolNodeData): The data associated with the tool node.
|
||||
for_log (bool): Whether to generate parameters for logging.
|
||||
virtual_node_outputs (dict[str, Any] | None): Outputs from virtual sub-nodes.
|
||||
Maps local_id -> outputs dict. Virtual node outputs are also in variable_pool
|
||||
with global IDs like "{parent_id}.{local_id}".
|
||||
|
||||
Returns:
|
||||
Mapping[str, Any]: A dictionary containing the generated parameters.
|
||||
|
||||
"""
|
||||
tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters}
|
||||
virtual_node_outputs = virtual_node_outputs or {}
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
for parameter_name in node_data.tool_parameters:
|
||||
@ -207,22 +203,39 @@ class ToolNode(Node[ToolNodeData]):
|
||||
continue
|
||||
tool_input = node_data.tool_parameters[parameter_name]
|
||||
if tool_input.type == "variable":
|
||||
# Check if this references a virtual node output (local ID like [ext_1, text])
|
||||
if not isinstance(tool_input.value, list):
|
||||
raise ToolParameterError(f"Invalid variable selector for parameter '{parameter_name}'")
|
||||
selector = tool_input.value
|
||||
if len(selector) >= 2 and selector[0] in virtual_node_outputs:
|
||||
# Reference to virtual node output
|
||||
local_id = selector[0]
|
||||
var_name = selector[1]
|
||||
outputs = virtual_node_outputs.get(local_id, {})
|
||||
parameter_value = outputs.get(var_name)
|
||||
variable = variable_pool.get(selector)
|
||||
if variable is None:
|
||||
if parameter.required:
|
||||
raise ToolParameterError(f"Variable {selector} does not exist")
|
||||
continue
|
||||
parameter_value = variable.value
|
||||
elif tool_input.type == "mention":
|
||||
# Mention type: get value from extractor node's output
|
||||
from .entities import MentionValue
|
||||
|
||||
mention_value = tool_input.value
|
||||
if isinstance(mention_value, MentionValue):
|
||||
mention_config = mention_value.model_dump()
|
||||
elif isinstance(mention_value, dict):
|
||||
mention_config = mention_value
|
||||
else:
|
||||
# Normal variable reference
|
||||
variable = variable_pool.get(selector)
|
||||
if variable is None:
|
||||
if parameter.required:
|
||||
raise ToolParameterError(f"Variable {selector} does not exist")
|
||||
raise ToolParameterError(f"Invalid mention value for parameter '{parameter_name}'")
|
||||
|
||||
try:
|
||||
parameter_value, found = variable_pool.resolve_mention(
|
||||
mention_config, parameter_name=parameter_name
|
||||
)
|
||||
if not found and parameter.required:
|
||||
raise ToolParameterError(
|
||||
f"Extractor output not found for required parameter '{parameter_name}'"
|
||||
)
|
||||
if not found:
|
||||
continue
|
||||
parameter_value = variable.value
|
||||
except ValueError as e:
|
||||
raise ToolParameterError(str(e)) from e
|
||||
elif tool_input.type in {"mixed", "constant"}:
|
||||
template = str(tool_input.value)
|
||||
segment_group = variable_pool.convert_template(template)
|
||||
@ -507,8 +520,12 @@ class ToolNode(Node[ToolNodeData]):
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
elif input.type == "variable":
|
||||
selector_key = ".".join(input.value)
|
||||
result[f"#{selector_key}#"] = input.value
|
||||
if isinstance(input.value, list):
|
||||
selector_key = ".".join(input.value)
|
||||
result[f"#{selector_key}#"] = input.value
|
||||
elif input.type == "mention":
|
||||
# Mention type handled by extractor node, no direct variable reference
|
||||
pass
|
||||
elif input.type == "constant":
|
||||
pass
|
||||
|
||||
|
||||
@ -268,6 +268,58 @@ class VariablePool(BaseModel):
|
||||
continue
|
||||
self.add(selector, value)
|
||||
|
||||
def resolve_mention(
|
||||
self,
|
||||
mention_config: Mapping[str, Any],
|
||||
/,
|
||||
*,
|
||||
parameter_name: str = "",
|
||||
) -> tuple[Any, bool]:
|
||||
"""
|
||||
Resolve a mention parameter value from an extractor node's output.
|
||||
|
||||
Mention parameters reference values extracted by an extractor LLM node
|
||||
from list[PromptMessage] context.
|
||||
|
||||
Args:
|
||||
mention_config: A dict containing:
|
||||
- extractor_node_id: ID of the extractor LLM node
|
||||
- output_selector: Selector path for the output variable (e.g., ["text"])
|
||||
- null_strategy: "raise_error" or "use_default"
|
||||
- default_value: Value to use when null_strategy is "use_default"
|
||||
parameter_name: Name of the parameter being resolved (for error messages)
|
||||
|
||||
Returns:
|
||||
Tuple of (resolved_value, found):
|
||||
- resolved_value: The extracted value, or default_value if not found
|
||||
- found: True if value was found, False if using default
|
||||
|
||||
Raises:
|
||||
ValueError: If extractor_node_id is missing, or if null_strategy is
|
||||
"raise_error" and the value is not found
|
||||
"""
|
||||
extractor_node_id = mention_config.get("extractor_node_id")
|
||||
if not extractor_node_id:
|
||||
raise ValueError(f"Missing extractor_node_id for mention parameter '{parameter_name}'")
|
||||
|
||||
output_selector = list(mention_config.get("output_selector", []))
|
||||
null_strategy = mention_config.get("null_strategy", "raise_error")
|
||||
default_value = mention_config.get("default_value")
|
||||
|
||||
# Build full selector: [extractor_node_id, ...output_selector]
|
||||
full_selector = [extractor_node_id] + output_selector
|
||||
variable = self.get(full_selector)
|
||||
|
||||
if variable is None:
|
||||
if null_strategy == "use_default":
|
||||
return default_value, False
|
||||
raise ValueError(
|
||||
f"Extractor node '{extractor_node_id}' output '{'.'.join(output_selector)}' "
|
||||
f"not found for parameter '{parameter_name}'"
|
||||
)
|
||||
|
||||
return variable.value, True
|
||||
|
||||
@classmethod
|
||||
def empty(cls) -> VariablePool:
|
||||
"""Create an empty variable pool."""
|
||||
|
||||
@ -4,6 +4,7 @@ from uuid import uuid4
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import File
|
||||
from core.model_runtime.entities import PromptMessage
|
||||
from core.variables.exc import VariableError
|
||||
from core.variables.segments import (
|
||||
ArrayAnySegment,
|
||||
@ -11,6 +12,7 @@ from core.variables.segments import (
|
||||
ArrayFileSegment,
|
||||
ArrayNumberSegment,
|
||||
ArrayObjectSegment,
|
||||
ArrayPromptMessageSegment,
|
||||
ArraySegment,
|
||||
ArrayStringSegment,
|
||||
BooleanSegment,
|
||||
@ -29,6 +31,7 @@ from core.variables.variables import (
|
||||
ArrayFileVariable,
|
||||
ArrayNumberVariable,
|
||||
ArrayObjectVariable,
|
||||
ArrayPromptMessageVariable,
|
||||
ArrayStringVariable,
|
||||
BooleanVariable,
|
||||
FileVariable,
|
||||
@ -61,6 +64,7 @@ SEGMENT_TO_VARIABLE_MAP = {
|
||||
ArrayFileSegment: ArrayFileVariable,
|
||||
ArrayNumberSegment: ArrayNumberVariable,
|
||||
ArrayObjectSegment: ArrayObjectVariable,
|
||||
ArrayPromptMessageSegment: ArrayPromptMessageVariable,
|
||||
ArrayStringSegment: ArrayStringVariable,
|
||||
BooleanSegment: BooleanVariable,
|
||||
FileSegment: FileVariable,
|
||||
@ -156,7 +160,13 @@ def build_segment(value: Any, /) -> Segment:
|
||||
return ObjectSegment(value=value)
|
||||
if isinstance(value, File):
|
||||
return FileSegment(value=value)
|
||||
if isinstance(value, PromptMessage):
|
||||
# Single PromptMessage should be wrapped in a list
|
||||
return ArrayPromptMessageSegment(value=[value])
|
||||
if isinstance(value, list):
|
||||
# Check if all items are PromptMessage
|
||||
if value and all(isinstance(item, PromptMessage) for item in value):
|
||||
return ArrayPromptMessageSegment(value=value)
|
||||
items = [build_segment(item) for item in value]
|
||||
types = {item.value_type for item in items}
|
||||
if all(isinstance(item, ArraySegment) for item in items):
|
||||
@ -200,6 +210,7 @@ _segment_factory: Mapping[SegmentType, type[Segment]] = {
|
||||
SegmentType.ARRAY_OBJECT: ArrayObjectSegment,
|
||||
SegmentType.ARRAY_FILE: ArrayFileSegment,
|
||||
SegmentType.ARRAY_BOOLEAN: ArrayBooleanSegment,
|
||||
SegmentType.ARRAY_PROMPT_MESSAGE: ArrayPromptMessageSegment,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -1291,7 +1291,7 @@ class WorkflowDraftVariable(Base):
|
||||
# which may differ from the original value's type. Typically, they are the same,
|
||||
# but in cases where the structurally truncated value still exceeds the size limit,
|
||||
# text slicing is applied, and the `value_type` is converted to `STRING`.
|
||||
value_type: Mapped[SegmentType] = mapped_column(EnumText(SegmentType, length=20))
|
||||
value_type: Mapped[SegmentType] = mapped_column(EnumText(SegmentType, length=21))
|
||||
|
||||
# The variable's value serialized as a JSON string
|
||||
#
|
||||
@ -1665,7 +1665,7 @@ class WorkflowDraftVariableFile(Base):
|
||||
|
||||
# The `value_type` field records the type of the original value.
|
||||
value_type: Mapped[SegmentType] = mapped_column(
|
||||
EnumText(SegmentType, length=20),
|
||||
EnumText(SegmentType, length=21),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
|
||||
@ -7,6 +7,7 @@ from typing import Any, Generic, TypeAlias, TypeVar, overload
|
||||
|
||||
from configs import dify_config
|
||||
from core.file.models import File
|
||||
from core.model_runtime.entities import PromptMessage
|
||||
from core.variables.segments import (
|
||||
ArrayFileSegment,
|
||||
ArraySegment,
|
||||
@ -287,6 +288,10 @@ class VariableTruncator(BaseTruncator):
|
||||
if isinstance(item, File):
|
||||
truncated_value.append(item)
|
||||
continue
|
||||
# Handle PromptMessage types - convert to dict for truncation
|
||||
if isinstance(item, PromptMessage):
|
||||
truncated_value.append(item)
|
||||
continue
|
||||
if i >= target_length:
|
||||
return _PartResult(truncated_value, used_size, True)
|
||||
if i > 0:
|
||||
|
||||
@ -1,77 +0,0 @@
|
||||
"""
|
||||
Unit tests for virtual node configuration.
|
||||
"""
|
||||
|
||||
from core.workflow.nodes.base.entities import VirtualNodeConfig
|
||||
|
||||
|
||||
class TestVirtualNodeConfig:
|
||||
"""Tests for VirtualNodeConfig entity."""
|
||||
|
||||
def test_create_basic_config(self):
|
||||
"""Test creating a basic virtual node config."""
|
||||
config = VirtualNodeConfig(
|
||||
id="ext_1",
|
||||
type="llm",
|
||||
data={
|
||||
"title": "Extract keywords",
|
||||
"model": {"provider": "openai", "name": "gpt-4o-mini"},
|
||||
},
|
||||
)
|
||||
|
||||
assert config.id == "ext_1"
|
||||
assert config.type == "llm"
|
||||
assert config.data["title"] == "Extract keywords"
|
||||
|
||||
def test_get_global_id(self):
|
||||
"""Test generating global ID from parent ID."""
|
||||
config = VirtualNodeConfig(
|
||||
id="ext_1",
|
||||
type="llm",
|
||||
data={},
|
||||
)
|
||||
|
||||
global_id = config.get_global_id("tool1")
|
||||
assert global_id == "tool1.ext_1"
|
||||
|
||||
def test_get_global_id_with_different_parents(self):
|
||||
"""Test global ID generation with different parent IDs."""
|
||||
config = VirtualNodeConfig(id="sub_node", type="code", data={})
|
||||
|
||||
assert config.get_global_id("parent1") == "parent1.sub_node"
|
||||
assert config.get_global_id("node_123") == "node_123.sub_node"
|
||||
|
||||
def test_empty_data(self):
|
||||
"""Test virtual node config with empty data."""
|
||||
config = VirtualNodeConfig(
|
||||
id="test",
|
||||
type="tool",
|
||||
)
|
||||
|
||||
assert config.id == "test"
|
||||
assert config.type == "tool"
|
||||
assert config.data == {}
|
||||
|
||||
def test_complex_data(self):
|
||||
"""Test virtual node config with complex data."""
|
||||
config = VirtualNodeConfig(
|
||||
id="llm_1",
|
||||
type="llm",
|
||||
data={
|
||||
"title": "Generate summary",
|
||||
"model": {
|
||||
"provider": "openai",
|
||||
"name": "gpt-4",
|
||||
"mode": "chat",
|
||||
"completion_params": {"temperature": 0.7, "max_tokens": 500},
|
||||
},
|
||||
"prompt_template": [
|
||||
{"role": "user", "text": "{{#llm1.context#}}"},
|
||||
{"role": "user", "text": "Please summarize the conversation"},
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
assert config.data["model"]["provider"] == "openai"
|
||||
assert len(config.data["prompt_template"]) == 2
|
||||
|
||||
@ -25,6 +25,12 @@ class _StubErrorHandler:
|
||||
"""Minimal error handler stub for tests."""
|
||||
|
||||
|
||||
class _StubNodeData:
|
||||
"""Simple node data stub with is_extractor_node property."""
|
||||
|
||||
is_extractor_node = False
|
||||
|
||||
|
||||
class _StubNode:
|
||||
"""Simple node stub exposing the attributes needed by the state manager."""
|
||||
|
||||
@ -36,6 +42,7 @@ class _StubNode:
|
||||
self.error_strategy = None
|
||||
self.retry_config = RetryConfig()
|
||||
self.retry = False
|
||||
self.node_data = _StubNodeData()
|
||||
|
||||
|
||||
def _build_event_handler(node_id: str) -> tuple[EventHandler, EventManager, GraphExecution]:
|
||||
|
||||
Reference in New Issue
Block a user