feat: add mention type variable

Novice
2026-01-12 17:39:36 +08:00
parent d65ae68668
commit bb190f9610
23 changed files with 457 additions and 439 deletions

View File

@ -1047,6 +1047,8 @@ class ToolManager:
continue
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
if tool_input.type == "variable":
if not isinstance(tool_input.value, list):
raise ToolParameterError(f"Invalid variable selector for {parameter.name}")
variable = variable_pool.get(tool_input.value)
if variable is None:
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
@ -1056,6 +1058,11 @@ class ToolManager:
elif tool_input.type == "mixed":
segment_group = variable_pool.convert_template(str(tool_input.value))
parameter_value = segment_group.text
elif tool_input.type == "mention":
# Mention type requires an extractor node, which is unavailable in agent mode
raise ToolParameterError(
f"Mention type is not supported in agent mode for parameter '{parameter.name}'"
)
else:
raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
runtime_parameters[parameter.name] = parameter_value
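
For reference, a hedged sketch of the tool_configurations shapes this dispatch expects; the parameter names and selectors below are invented for illustration:

# Hypothetical tool_configurations payloads, one per ToolInput.type.
tool_configurations = {
    "query": {"type": "variable", "value": ["start_node", "user_query"]},   # variable selector
    "prompt": {"type": "mixed", "value": "Summarize: {{#llm_node.text#}}"},  # template string
    "history": {  # mention: rejected in agent mode by the branch above
        "type": "mention",
        "value": {
            "variable_selector": ["llm_node", "context"],
            "extractor_node_id": "ext_1",
            "output_selector": ["text"],
        },
    },
}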

View File

@ -4,6 +4,7 @@ from .segments import (
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArrayPromptMessageSegment,
ArraySegment,
ArrayStringSegment,
FileSegment,
@ -20,6 +21,7 @@ from .variables import (
ArrayFileVariable,
ArrayNumberVariable,
ArrayObjectVariable,
ArrayPromptMessageVariable,
ArrayStringVariable,
ArrayVariable,
FileVariable,
@ -41,6 +43,8 @@ __all__ = [
"ArrayNumberVariable",
"ArrayObjectSegment",
"ArrayObjectVariable",
"ArrayPromptMessageSegment",
"ArrayPromptMessageVariable",
"ArraySegment",
"ArrayStringSegment",
"ArrayStringVariable",

View File

@ -6,6 +6,7 @@ from typing import Annotated, Any, TypeAlias
from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator
from core.file import File
from core.model_runtime.entities import PromptMessage
from .types import SegmentType
@ -208,6 +209,15 @@ class ArrayBooleanSegment(ArraySegment):
value: Sequence[bool]
class ArrayPromptMessageSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_PROMPT_MESSAGE
value: Sequence[PromptMessage]
def to_object(self):
"""Convert to JSON-serializable format for database storage and frontend."""
return [msg.model_dump() for msg in self.value]
def get_segment_discriminator(v: Any) -> SegmentType | None:
if isinstance(v, Segment):
return v.value_type
@ -248,6 +258,7 @@ SegmentUnion: TypeAlias = Annotated[
| Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
| Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
| Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)]
| Annotated[ArrayPromptMessageSegment, Tag(SegmentType.ARRAY_PROMPT_MESSAGE)]
),
Discriminator(get_segment_discriminator),
]
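
A minimal usage sketch for the new segment type; it assumes UserPromptMessage and AssistantPromptMessage are exported beside PromptMessage in core.model_runtime.entities (the LLM node changes below construct AssistantPromptMessage the same way):

# Sketch: build the segment and serialize it for storage.
from core.model_runtime.entities import AssistantPromptMessage, UserPromptMessage
from core.variables.segments import ArrayPromptMessageSegment

segment = ArrayPromptMessageSegment(
    value=[
        UserPromptMessage(content="What is the capital of France?"),
        AssistantPromptMessage(content="Paris."),
    ]
)
# to_object() yields plain dicts, suitable for DB storage or the frontend.
assert isinstance(segment.to_object()[0], dict)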

View File

@ -45,6 +45,7 @@ class SegmentType(StrEnum):
ARRAY_OBJECT = "array[object]"
ARRAY_FILE = "array[file]"
ARRAY_BOOLEAN = "array[boolean]"
ARRAY_PROMPT_MESSAGE = "array[message]"
NONE = "none"

View File

@ -3,8 +3,10 @@ from typing import Any
import orjson
from core.model_runtime.entities import PromptMessage
from .segment_group import SegmentGroup
from .segments import ArrayFileSegment, FileSegment, Segment
from .segments import ArrayFileSegment, ArrayPromptMessageSegment, FileSegment, Segment
def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[str]:
@ -16,7 +18,7 @@ def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[
def segment_orjson_default(o: Any):
"""Default function for orjson serialization of Segment types"""
if isinstance(o, ArrayFileSegment):
if isinstance(o, (ArrayFileSegment, ArrayPromptMessageSegment)):
return [v.model_dump() for v in o.value]
elif isinstance(o, FileSegment):
return o.value.model_dump()
@ -24,6 +26,8 @@ def segment_orjson_default(o: Any):
return [segment_orjson_default(seg) for seg in o.value]
elif isinstance(o, Segment):
return o.value
elif isinstance(o, PromptMessage):
return o.model_dump()
raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable")
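
A hedged usage sketch: passing the helper as orjson's default hook lets segment values (including embedded PromptMessage objects) serialize transparently. StringSegment is assumed to live beside the other segment classes:

import orjson
from core.variables.segments import StringSegment

# orjson cannot serialize Segment subclasses natively; the `default` hook
# unwraps them (returning o.value for a plain Segment).
payload = orjson.dumps({"answer": StringSegment(value="Paris")}, default=segment_orjson_default)
assert payload == b'{"answer":"Paris"}'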

View File

@ -12,6 +12,7 @@ from .segments import (
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArrayPromptMessageSegment,
ArraySegment,
ArrayStringSegment,
BooleanSegment,
@ -110,6 +111,10 @@ class ArrayBooleanVariable(ArrayBooleanSegment, ArrayVariable):
pass
class ArrayPromptMessageVariable(ArrayPromptMessageSegment, ArrayVariable):
pass
class RAGPipelineVariable(BaseModel):
belong_to_node_id: str = Field(description="belong to which node id, shared means public")
type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list")
@ -160,6 +165,7 @@ VariableUnion: TypeAlias = Annotated[
| Annotated[ArrayObjectVariable, Tag(SegmentType.ARRAY_OBJECT)]
| Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)]
| Annotated[ArrayBooleanVariable, Tag(SegmentType.ARRAY_BOOLEAN)]
| Annotated[ArrayPromptMessageVariable, Tag(SegmentType.ARRAY_PROMPT_MESSAGE)]
| Annotated[SecretVariable, Tag(SegmentType.SECRET)]
),
Discriminator(get_segment_discriminator),

View File

@ -311,9 +311,9 @@ class Graph:
# - custom-note: top-level type (node_config.type == "custom-note")
# - group: data-level type (node_config.data.type == "group")
node_configs = [
node_config for node_config in node_configs
if node_config.get("type", "") != "custom-note"
and node_config.get("data", {}).get("type", "") != "group"
node_config
for node_config in node_configs
if node_config.get("type", "") != "custom-note" and node_config.get("data", {}).get("type", "") != "group"
]
# Parse node configurations

View File

@ -125,9 +125,9 @@ class EventHandler:
Args:
event: The node started event
"""
# Check if this is a virtual node (extraction node)
if self._is_virtual_node(event.node_id):
self._handle_virtual_node_started(event)
# Check if this is an extractor node (has parent_node_id)
if self._is_extractor_node(event.node_id):
self._handle_extractor_node_started(event)
return
# Track execution in domain model
@ -169,9 +169,9 @@ class EventHandler:
Args:
event: The node succeeded event
"""
# Check if this is a virtual node (extraction node)
if self._is_virtual_node(event.node_id):
self._handle_virtual_node_success(event)
# Check if this is an extractor node (has parent_node_id)
if self._is_extractor_node(event.node_id):
self._handle_extractor_node_success(event)
return
# Update domain model
@ -236,9 +236,9 @@ class EventHandler:
Args:
event: The node failed event
"""
# Check if this is a virtual node (extraction node)
if self._is_virtual_node(event.node_id):
self._handle_virtual_node_failed(event)
# Check if this is an extractor node (has parent_node_id)
if self._is_extractor_node(event.node_id):
self._handle_extractor_node_failed(event)
return
# Update domain model
@ -361,23 +361,23 @@ class EventHandler:
else:
self._graph_runtime_state.set_output(key, value)
def _is_virtual_node(self, node_id: str) -> bool:
def _is_extractor_node(self, node_id: str) -> bool:
"""
Check if node_id represents a virtual sub-node.
Check if node_id represents an extractor node (has parent_node_id).
Virtual nodes have IDs in the format: {parent_node_id}.{local_id}
We check if the part before '.' exists in graph nodes.
Extractor nodes extract values from list[PromptMessage] for their parent node.
They have a parent_node_id field pointing to their parent node.
"""
if "." in node_id:
parent_id = node_id.rsplit(".", 1)[0]
return parent_id in self._graph.nodes
return False
node = self._graph.nodes.get(node_id)
if node is None:
return False
return node.node_data.is_extractor_node
def _handle_virtual_node_started(self, event: NodeRunStartedEvent) -> None:
def _handle_extractor_node_started(self, event: NodeRunStartedEvent) -> None:
"""
Handle virtual node started event.
Handle extractor node started event.
Virtual nodes don't need full execution tracking, just collect the event.
Extractor nodes don't need full execution tracking; just collect the event.
"""
# Track in response coordinator for stream ordering
self._response_coordinator.track_node_execution(event.node_id, event.id)
@ -385,11 +385,11 @@ class EventHandler:
# Collect the event
self._event_collector.collect(event)
def _handle_virtual_node_success(self, event: NodeRunSucceededEvent) -> None:
def _handle_extractor_node_success(self, event: NodeRunSucceededEvent) -> None:
"""
Handle virtual node success event.
Handle extractor node success event.
Virtual nodes (extraction nodes) need special handling:
Extractor nodes need special handling:
- Store outputs in variable pool (for reference by other nodes)
- Accumulate token usage
- Collect the event for logging
@ -403,11 +403,11 @@ class EventHandler:
# Collect the event
self._event_collector.collect(event)
def _handle_virtual_node_failed(self, event: NodeRunFailedEvent) -> None:
def _handle_extractor_node_failed(self, event: NodeRunFailedEvent) -> None:
"""
Handle virtual node failed event.
Handle extractor node failed event.
Virtual nodes (extraction nodes) failures are collected for logging,
Extractor node failures are collected for logging,
but the parent node is responsible for handling the error.
"""
self._accumulate_node_usage(event.node_run_result.llm_usage)

View File

@ -20,12 +20,6 @@ class NodeRunStartedEvent(GraphNodeEventBase):
provider_type: str = ""
provider_id: str = ""
# Virtual node fields for extraction
is_virtual: bool = False
parent_node_id: str | None = None
extraction_source: str | None = None # e.g., "llm1.context"
extraction_prompt: str | None = None
class NodeRunStreamChunkEvent(GraphNodeEventBase):
# Spec-compliant fields

View File

@ -4,10 +4,8 @@ from .entities import (
BaseLoopNodeData,
BaseLoopState,
BaseNodeData,
VirtualNodeConfig,
)
from .usage_tracking_mixin import LLMUsageTrackingMixin
from .virtual_node_executor import VirtualNodeExecutionError, VirtualNodeExecutor
__all__ = [
"BaseIterationNodeData",
@ -16,7 +14,4 @@ __all__ = [
"BaseLoopState",
"BaseNodeData",
"LLMUsageTrackingMixin",
"VirtualNodeConfig",
"VirtualNodeExecutionError",
"VirtualNodeExecutor",
]

View File

@ -167,24 +167,6 @@ class DefaultValue(BaseModel):
return self
class VirtualNodeConfig(BaseModel):
"""Configuration for a virtual sub-node embedded within a parent node."""
# Local ID within parent node (e.g., "ext_1")
# Will be converted to global ID: "{parent_id}.{id}"
id: str
# Node type (e.g., "llm", "code", "tool")
type: str
# Full node data configuration
data: dict[str, Any] = {}
def get_global_id(self, parent_node_id: str) -> str:
"""Get the global node ID by combining parent ID and local ID."""
return f"{parent_node_id}.{self.id}"
class BaseNodeData(ABC, BaseModel):
title: str
desc: str | None = None
@ -193,8 +175,15 @@ class BaseNodeData(ABC, BaseModel):
default_value: list[DefaultValue] | None = None
retry_config: RetryConfig = RetryConfig()
# Virtual sub-nodes that execute before the main node
virtual_nodes: list[VirtualNodeConfig] = []
# Parent node ID when this node is used as an extractor.
# If set, this node is an "attached" extractor node that extracts values
# from list[PromptMessage] for the parent node's parameters.
parent_node_id: str | None = None
@property
def is_extractor_node(self) -> bool:
"""Check if this node is an extractor node (has parent_node_id)."""
return self.parent_node_id is not None
@property
def default_value_dict(self) -> dict[str, Any]:
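
A small sketch of the new flag, assuming title is the only required field on a concrete BaseNodeData subclass:

class _DemoNodeData(BaseNodeData):
    """Hypothetical concrete node data, used only for illustration."""

attached = _DemoNodeData(title="Extract query", parent_node_id="tool_1")
standalone = _DemoNodeData(title="Regular node")
assert attached.is_extractor_node
assert not standalone.is_extractor_node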

View File

@ -229,7 +229,6 @@ class Node(Generic[NodeDataT]):
self._node_id = node_id
self._node_execution_id: str = ""
self._start_at = naive_utc_now()
self._virtual_node_outputs: dict[str, Any] = {} # Outputs from virtual sub-nodes
raw_node_data = config.get("data") or {}
if not isinstance(raw_node_data, Mapping):
@ -271,51 +270,81 @@ class Node(Generic[NodeDataT]):
"""Check if execution should be stopped."""
return self.graph_runtime_state.stop_event.is_set()
def _execute_virtual_nodes(self) -> Generator[GraphNodeEventBase, None, dict[str, Any]]:
def _find_extractor_node_configs(self) -> list[dict[str, Any]]:
"""
Execute all virtual sub-nodes defined in node configuration.
Virtual nodes are complete node definitions that execute before the main node.
Each virtual node:
- Has its own global ID: "{parent_id}.{local_id}"
- Generates standard node events
- Stores outputs in the variable pool (via event handling)
- Supports retry via parent node's retry config
Find all extractor node configurations that have parent_node_id == self._node_id.
Returns:
dict mapping local_id -> outputs dict
List of node configuration dicts for extractor nodes
"""
from .virtual_node_executor import VirtualNodeExecutor
nodes = self.graph_config.get("nodes", [])
extractor_configs = []
for node_config in nodes:
node_data = node_config.get("data", {})
if node_data.get("parent_node_id") == self._node_id:
extractor_configs.append(node_config)
return extractor_configs
virtual_nodes = self.node_data.virtual_nodes
if not virtual_nodes:
return {}
executor = VirtualNodeExecutor(
graph_init_params=self._graph_init_params,
graph_runtime_state=self.graph_runtime_state,
parent_node_id=self._node_id,
parent_retry_config=self.retry_config,
)
return (yield from executor.execute_virtual_nodes(virtual_nodes))
@property
def virtual_node_outputs(self) -> dict[str, Any]:
def _execute_extractor_nodes(self) -> Generator[GraphNodeEventBase, None, None]:
"""
Get the outputs from virtual sub-nodes.
Execute all extractor nodes associated with this node.
Returns:
dict mapping local_id -> outputs dict
Extractor nodes are nodes with parent_node_id == self._node_id.
They are executed before the main node to extract values from list[PromptMessage].
"""
return self._virtual_node_outputs
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
extractor_configs = self._find_extractor_node_configs()
logger.debug("[Extractor] Found %d extractor nodes for parent '%s'", len(extractor_configs), self._node_id)
if not extractor_configs:
return
for config in extractor_configs:
node_id = config.get("id")
node_data = config.get("data", {})
node_type_str = node_data.get("type")
if not node_id or not node_type_str:
continue
# Get node class
try:
node_type = NodeType(node_type_str)
except ValueError:
continue
node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
if not node_mapping:
continue
node_version = str(node_data.get("version", "1"))
node_cls = node_mapping.get(node_version) or node_mapping.get(LATEST_VERSION)
if not node_cls:
continue
# Instantiate and execute the extractor node
extractor_node = node_cls(
id=node_id,
config=config,
graph_init_params=self._graph_init_params,
graph_runtime_state=self.graph_runtime_state,
)
# Execute and process extractor node events
for event in extractor_node.run():
if isinstance(event, NodeRunSucceededEvent):
# Store extractor node outputs in variable pool
outputs = event.node_run_result.outputs
for variable_name, variable_value in outputs.items():
self.graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value)
yield event
def run(self) -> Generator[GraphNodeEventBase, None, None]:
execution_id = self.ensure_execution_id()
self._start_at = naive_utc_now()
# Step 1: Execute virtual sub-nodes before main node execution
self._virtual_node_outputs = yield from self._execute_virtual_nodes()
# Step 1: Execute associated extractor nodes before main node execution
yield from self._execute_extractor_nodes()
# Create and push start event with required fields
start_event = NodeRunStartedEvent(
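
For orientation, a hypothetical graph_config in which _find_extractor_node_configs, called from node "tool_1", would discover "ext_1" via its data.parent_node_id:

graph_config = {
    "nodes": [
        {"id": "tool_1", "data": {"type": "tool", "title": "Search"}},
        {
            "id": "ext_1",  # executed by _execute_extractor_nodes before tool_1 runs
            "data": {"type": "llm", "title": "Extract query", "parent_node_id": "tool_1"},
        },
    ],
    "edges": [],
}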

View File

@ -1,213 +0,0 @@
"""
Virtual Node Executor for running embedded sub-nodes within a parent node.
This module handles the execution of virtual nodes defined in a parent node's
`virtual_nodes` configuration. Virtual nodes are complete node definitions
that execute before the parent node.
Example configuration:
virtual_nodes:
- id: ext_1
type: llm
data:
model: {...}
prompt_template: [...]
"""
import time
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
from uuid import uuid4
from core.workflow.enums import NodeType
from core.workflow.graph_events import (
GraphNodeEventBase,
NodeRunFailedEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunSucceededEvent,
)
from libs.datetime_utils import naive_utc_now
from .entities import RetryConfig, VirtualNodeConfig
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
class VirtualNodeExecutionError(Exception):
"""Error during virtual node execution"""
def __init__(self, node_id: str, original_error: Exception):
self.node_id = node_id
self.original_error = original_error
super().__init__(f"Virtual node {node_id} execution failed: {original_error}")
class VirtualNodeExecutor:
"""
Executes virtual sub-nodes embedded within a parent node.
Virtual nodes are complete node definitions that execute before the parent node.
Each virtual node:
- Has its own global ID: "{parent_id}.{local_id}"
- Generates standard node events
- Stores outputs in the variable pool
- Supports retry via parent node's retry config
"""
def __init__(
self,
*,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
parent_node_id: str,
parent_retry_config: RetryConfig | None = None,
):
self._graph_init_params = graph_init_params
self._graph_runtime_state = graph_runtime_state
self._parent_node_id = parent_node_id
self._parent_retry_config = parent_retry_config or RetryConfig()
def execute_virtual_nodes(
self,
virtual_nodes: list[VirtualNodeConfig],
) -> Generator[GraphNodeEventBase, None, dict[str, Any]]:
"""
Execute all virtual nodes in order.
Args:
virtual_nodes: List of virtual node configurations
Yields:
Node events from each virtual node execution
Returns:
dict mapping local_id -> outputs dict
"""
results: dict[str, Any] = {}
for vnode_config in virtual_nodes:
global_id = vnode_config.get_global_id(self._parent_node_id)
# Execute with retry
outputs = yield from self._execute_with_retry(vnode_config, global_id)
results[vnode_config.id] = outputs
return results
def _execute_with_retry(
self,
vnode_config: VirtualNodeConfig,
global_id: str,
) -> Generator[GraphNodeEventBase, None, dict[str, Any]]:
"""
Execute virtual node with retry support.
"""
retry_config = self._parent_retry_config
last_error: Exception | None = None
for attempt in range(retry_config.max_retries + 1):
try:
return (yield from self._execute_single_node(vnode_config, global_id))
except Exception as e:
last_error = e
if attempt < retry_config.max_retries:
# Yield retry event
yield NodeRunRetryEvent(
id=str(uuid4()),
node_id=global_id,
node_type=self._get_node_type(vnode_config.type),
node_title=vnode_config.data.get("title", f"Virtual: {vnode_config.id}"),
start_at=naive_utc_now(),
error=str(e),
retry_index=attempt + 1,
)
time.sleep(retry_config.retry_interval_seconds)
continue
raise VirtualNodeExecutionError(global_id, e) from e
raise last_error or VirtualNodeExecutionError(global_id, Exception("Unknown error"))
def _execute_single_node(
self,
vnode_config: VirtualNodeConfig,
global_id: str,
) -> Generator[GraphNodeEventBase, None, dict[str, Any]]:
"""
Execute a single virtual node by instantiating and running it.
"""
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
# Build node config
node_config: dict[str, Any] = {
"id": global_id,
"data": {
**vnode_config.data,
"title": vnode_config.data.get("title", f"Virtual: {vnode_config.id}"),
},
}
# Get the node class for this type
node_type = self._get_node_type(vnode_config.type)
node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
if not node_mapping:
raise ValueError(f"No class mapping found for node type: {node_type}")
node_version = str(vnode_config.data.get("version", "1"))
node_cls = node_mapping.get(node_version) or node_mapping.get(LATEST_VERSION)
if not node_cls:
raise ValueError(f"No class found for node type: {node_type}")
# Instantiate the node
node = node_cls(
id=global_id,
config=node_config,
graph_init_params=self._graph_init_params,
graph_runtime_state=self._graph_runtime_state,
)
# Run and collect events
outputs: dict[str, Any] = {}
for event in node.run():
# Mark event as coming from virtual node
self._mark_event_as_virtual(event, vnode_config)
yield event
if isinstance(event, NodeRunSucceededEvent):
outputs = event.node_run_result.outputs or {}
elif isinstance(event, NodeRunFailedEvent):
raise Exception(event.error or "Virtual node execution failed")
return outputs
def _mark_event_as_virtual(
self,
event: GraphNodeEventBase,
vnode_config: VirtualNodeConfig,
) -> None:
"""Mark event as coming from a virtual node."""
if isinstance(event, NodeRunStartedEvent):
event.is_virtual = True
event.parent_node_id = self._parent_node_id
def _get_node_type(self, type_str: str) -> NodeType:
"""Convert type string to NodeType enum."""
type_mapping = {
"llm": NodeType.LLM,
"code": NodeType.CODE,
"tool": NodeType.TOOL,
"if-else": NodeType.IF_ELSE,
"question-classifier": NodeType.QUESTION_CLASSIFIER,
"parameter-extractor": NodeType.PARAMETER_EXTRACTOR,
"template-transform": NodeType.TEMPLATE_TRANSFORM,
"variable-assigner": NodeType.VARIABLE_ASSIGNER,
"http-request": NodeType.HTTP_REQUEST,
"knowledge-retrieval": NodeType.KNOWLEDGE_RETRIEVAL,
}
return type_mapping.get(type_str, NodeType.LLM)

View File

@ -1,7 +1,7 @@
from collections.abc import Mapping, Sequence
from typing import Any, Literal
from typing import Annotated, Any, Literal, TypeAlias
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator
from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
@ -58,9 +58,28 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
jinja2_text: str | None = None
class PromptMessageContext(BaseModel):
"""Context variable reference in prompt template.
YAML/JSON format: { "$context": ["node_id", "variable_name"] }
This will be expanded to list[PromptMessage] at runtime.
"""
model_config = ConfigDict(populate_by_name=True)
value_selector: Sequence[str] = Field(alias="$context")
# Union type for prompt template items (static message or context variable reference)
PromptTemplateItem: TypeAlias = Annotated[
LLMNodeChatModelMessage | PromptMessageContext,
Field(discriminator=None),
]
class LLMNodeData(BaseNodeData):
model: ModelConfig
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
prompt_template: Sequence[PromptTemplateItem] | LLMNodeCompletionModelPromptTemplate
prompt_config: PromptConfig = Field(default_factory=PromptConfig)
memory: MemoryConfig | None = None
context: ContextConfig
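
A hedged sketch of the two item shapes the new union accepts; the node and variable names are illustrative, and the module path is assumed:

from core.workflow.nodes.llm.entities import PromptMessageContext

# The "$context" alias populates value_selector (populate_by_name also
# accepts the field name directly).
ctx = PromptMessageContext.model_validate({"$context": ["llm_1", "context"]})
assert list(ctx.value_selector) == ["llm_1", "context"]

# In a node's prompt_template, static messages and context references interleave:
prompt_template = [
    {"role": "system", "text": "You are a helpful assistant."},
    {"$context": ["llm_1", "context"]},  # expands to list[PromptMessage] at runtime
    {"role": "user", "text": "{{#sys.query#}}"},
]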

View File

@ -7,7 +7,7 @@ import logging
import re
import time
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, Literal, cast
from sqlalchemy import select
@ -52,6 +52,7 @@ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.tools.signature import sign_upload_file
from core.variables import (
ArrayFileSegment,
ArrayPromptMessageSegment,
ArraySegment,
FileSegment,
NoneSegment,
@ -88,6 +89,7 @@ from .entities import (
LLMNodeCompletionModelPromptTemplate,
LLMNodeData,
ModelConfig,
PromptMessageContext,
)
from .exc import (
InvalidContextStructureError,
@ -160,8 +162,9 @@ class LLMNode(Node[LLMNodeData]):
variable_pool = self.graph_runtime_state.variable_pool
try:
# init messages template
self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template)
# Parse prompt template to separate static messages and context references
prompt_template = self.node_data.prompt_template
static_messages, context_refs, template_order = self._parse_prompt_template()
# fetch variables and fetch values from variable pool
inputs = self._fetch_inputs(node_data=self.node_data)
@ -223,21 +226,40 @@ class LLMNode(Node[LLMNodeData]):
):
query = query_variable.text
prompt_messages, stop = LLMNode.fetch_prompt_messages(
sys_query=query,
sys_files=files,
context=context,
memory=memory,
model_config=model_config,
prompt_template=self.node_data.prompt_template,
memory_config=self.node_data.memory,
vision_enabled=self.node_data.vision.enabled,
vision_detail=self.node_data.vision.configs.detail,
variable_pool=variable_pool,
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
tenant_id=self.tenant_id,
context_files=context_files,
)
# Get prompt messages
prompt_messages: Sequence[PromptMessage]
stop: Sequence[str] | None
if isinstance(prompt_template, list) and context_refs:
prompt_messages, stop = self._build_prompt_messages_with_context(
context_refs=context_refs,
template_order=template_order,
static_messages=static_messages,
query=query,
files=files,
context=context,
memory=memory,
model_config=model_config,
context_files=context_files,
)
else:
prompt_messages, stop = LLMNode.fetch_prompt_messages(
sys_query=query,
sys_files=files,
context=context,
memory=memory,
model_config=model_config,
prompt_template=cast(
Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
self.node_data.prompt_template,
),
memory_config=self.node_data.memory,
vision_enabled=self.node_data.vision.enabled,
vision_detail=self.node_data.vision.configs.detail,
variable_pool=variable_pool,
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
tenant_id=self.tenant_id,
context_files=context_files,
)
# handle invoke result
generator = LLMNode.invoke_llm(
@ -304,7 +326,7 @@ class LLMNode(Node[LLMNodeData]):
"reasoning_content": reasoning_content,
"usage": jsonable_encoder(usage),
"finish_reason": finish_reason,
"context": self._build_context(prompt_messages, clean_text, model_config.mode),
"context": self._build_context(prompt_messages, clean_text),
}
if structured_output:
outputs["structured_output"] = structured_output.structured_output
@ -602,17 +624,15 @@ class LLMNode(Node[LLMNodeData]):
def _build_context(
prompt_messages: Sequence[PromptMessage],
assistant_response: str,
model_mode: str,
) -> list[dict[str, Any]]:
) -> list[PromptMessage]:
"""
Build context from prompt messages and assistant response.
Excludes system messages and includes the current LLM response.
Returns list[PromptMessage] for use with ArrayPromptMessageSegment.
"""
context_messages: list[PromptMessage] = [m for m in prompt_messages if m.role != PromptMessageRole.SYSTEM]
context_messages.append(AssistantPromptMessage(content=assistant_response))
return PromptMessageUtil.prompt_messages_to_prompt_for_saving(
model_mode=model_mode, prompt_messages=context_messages
)
return context_messages
def _transform_chat_messages(
self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
@ -629,6 +649,106 @@ class LLMNode(Node[LLMNodeData]):
return messages
def _parse_prompt_template(
self,
) -> tuple[list[LLMNodeChatModelMessage], list[PromptMessageContext], list[tuple[int, str]]]:
"""
Parse prompt_template to separate static messages and context references.
Returns:
Tuple of (static_messages, context_refs, template_order)
- static_messages: list of LLMNodeChatModelMessage
- context_refs: list of PromptMessageContext
- template_order: list of (index, type) tuples preserving original order
"""
prompt_template = self.node_data.prompt_template
static_messages: list[LLMNodeChatModelMessage] = []
context_refs: list[PromptMessageContext] = []
template_order: list[tuple[int, str]] = []
if isinstance(prompt_template, list):
for idx, item in enumerate(prompt_template):
if isinstance(item, PromptMessageContext):
context_refs.append(item)
template_order.append((idx, "context"))
else:
static_messages.append(item)
template_order.append((idx, "static"))
# Transform static messages for jinja2
if static_messages:
self.node_data.prompt_template = self._transform_chat_messages(static_messages)
return static_messages, context_refs, template_order
def _build_prompt_messages_with_context(
self,
*,
context_refs: list[PromptMessageContext],
template_order: list[tuple[int, str]],
static_messages: list[LLMNodeChatModelMessage],
query: str | None,
files: Sequence[File],
context: str | None,
memory: BaseMemory | None,
model_config: ModelConfigWithCredentialsEntity,
context_files: list[File],
) -> tuple[list[PromptMessage], Sequence[str] | None]:
"""
Build prompt messages by combining static messages and context references in DSL order.
Returns:
Tuple of (prompt_messages, stop_sequences)
"""
variable_pool = self.graph_runtime_state.variable_pool
# Build a map from context index to its messages
context_messages_map: dict[int, list[PromptMessage]] = {}
context_idx = 0
for idx, type_ in template_order:
if type_ == "context":
ctx_ref = context_refs[context_idx]
ctx_var = variable_pool.get(ctx_ref.value_selector)
if ctx_var is None:
raise VariableNotFoundError(f"Variable {'.'.join(ctx_ref.value_selector)} not found")
if not isinstance(ctx_var, ArrayPromptMessageSegment):
raise InvalidVariableTypeError(f"Variable {'.'.join(ctx_ref.value_selector)} is not array[message]")
context_messages_map[idx] = list(ctx_var.value)
context_idx += 1
# Process static messages
static_prompt_messages: Sequence[PromptMessage] = []
stop: Sequence[str] | None = None
if static_messages:
static_prompt_messages, stop = LLMNode.fetch_prompt_messages(
sys_query=query,
sys_files=files,
context=context,
memory=memory,
model_config=model_config,
prompt_template=cast(Sequence[LLMNodeChatModelMessage], self.node_data.prompt_template),
memory_config=self.node_data.memory,
vision_enabled=self.node_data.vision.enabled,
vision_detail=self.node_data.vision.configs.detail,
variable_pool=variable_pool,
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
tenant_id=self.tenant_id,
context_files=context_files,
)
# Combine messages according to original DSL order
combined_messages: list[PromptMessage] = []
static_msg_iter = iter(static_prompt_messages)
for idx, type_ in template_order:
if type_ == "context":
combined_messages.extend(context_messages_map[idx])
else:
if msg := next(static_msg_iter, None):
combined_messages.append(msg)
# Append any remaining static messages (e.g., memory messages)
combined_messages.extend(static_msg_iter)
return combined_messages, stop
def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
variables: dict[str, Any] = {}
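
The order-preserving merge above can be hard to follow from the diff alone; here is a standalone toy reconstruction of the same bookkeeping, with plain strings standing in for rendered messages:

# "static" items are consumed one by one from the rendered static messages;
# "context" items are spliced in at their original template positions.
template_order = [(0, "static"), (1, "context"), (2, "static")]
static_rendered = ["system: hello", "user: question"]  # from fetch_prompt_messages
context_expanded = {1: ["user: ctx msg 1", "assistant: ctx msg 2"]}

combined, it = [], iter(static_rendered)
for idx, kind in template_order:
    if kind == "context":
        combined.extend(context_expanded[idx])
    elif (msg := next(it, None)) is not None:
        combined.append(msg)
combined.extend(it)  # trailing messages (e.g., memory) keep their place at the end
assert combined == ["system: hello", "user: ctx msg 1", "assistant: ctx msg 2", "user: question"]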

View File

@ -1,3 +1,4 @@
from collections.abc import Sequence
from typing import Any, Literal, Union
from pydantic import BaseModel, field_validator
@ -7,6 +8,31 @@ from core.tools.entities.tool_entities import ToolProviderType
from core.workflow.nodes.base.entities import BaseNodeData
class MentionValue(BaseModel):
"""Value structure for mention type parameters.
Used when a tool parameter needs to be extracted from conversation context
using an extractor LLM node.
"""
# Variable selector for list[PromptMessage] input to extractor
variable_selector: Sequence[str]
# ID of the extractor LLM node
extractor_node_id: str
# Output variable selector from extractor node
# e.g., ["text"], ["structured_output", "query"]
output_selector: Sequence[str]
# Strategy when output is None
null_strategy: Literal["raise_error", "use_default"] = "raise_error"
# Default value when null_strategy is "use_default"
# Type should match the parameter's expected type
default_value: Any = None
class ToolEntity(BaseModel):
provider_id: str
provider_type: ToolProviderType
@ -34,8 +60,8 @@ class ToolEntity(BaseModel):
class ToolNodeData(BaseNodeData, ToolEntity):
class ToolInput(BaseModel):
# TODO: check this type
value: Union[Any, list[str]]
type: Literal["mixed", "variable", "constant"]
value: Union[Any, list[str], MentionValue]
type: Literal["mixed", "variable", "constant", "mention"]
@field_validator("type", mode="before")
@classmethod
@ -56,6 +82,17 @@ class ToolNodeData(BaseNodeData, ToolEntity):
raise ValueError("value must be a list of strings")
elif typ == "constant" and not isinstance(value, str | int | float | bool | dict):
raise ValueError("value must be a string, int, float, bool or dict")
elif typ == "mention":
# Mention type: value should be a MentionValue or dict with required fields
if isinstance(value, MentionValue):
pass # Already validated by Pydantic
elif isinstance(value, dict):
if "extractor_node_id" not in value:
raise ValueError("value must contain extractor_node_id for mention type")
if "output_selector" not in value:
raise ValueError("value must contain output_selector for mention type")
else:
raise ValueError("value must be a MentionValue or dict for mention type")
return typ
tool_parameters: dict[str, ToolInput]
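
A hypothetical mention parameter, for illustration: the search query is extracted from conversation context by extractor node "ext_1", with an empty-string fallback when extraction yields nothing:

mention = MentionValue(
    variable_selector=["llm_1", "context"],   # list[PromptMessage] fed to the extractor
    extractor_node_id="ext_1",
    output_selector=["structured_output", "query"],
    null_strategy="use_default",
    default_value="",
)
assert mention.null_strategy == "use_default"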

View File

@ -1,7 +1,10 @@
import logging
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from sqlalchemy import select
logger = logging.getLogger(__name__)
from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
@ -89,20 +92,18 @@ class ToolNode(Node[ToolNodeData]):
)
return
# get parameters (use virtual_node_outputs from base class)
# get parameters
tool_parameters = tool_runtime.get_merged_runtime_parameters() or []
parameters = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
virtual_node_outputs=self.virtual_node_outputs,
)
parameters_for_log = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
for_log=True,
virtual_node_outputs=self.virtual_node_outputs,
)
# get conversation id
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
@ -178,7 +179,6 @@ class ToolNode(Node[ToolNodeData]):
variable_pool: "VariablePool",
node_data: ToolNodeData,
for_log: bool = False,
virtual_node_outputs: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""
Generate parameters based on the given tool parameters, variable pool, and node data.
@ -188,16 +188,12 @@ class ToolNode(Node[ToolNodeData]):
variable_pool (VariablePool): The variable pool containing the variables.
node_data (ToolNodeData): The data associated with the tool node.
for_log (bool): Whether to generate parameters for logging.
virtual_node_outputs (dict[str, Any] | None): Outputs from virtual sub-nodes.
Maps local_id -> outputs dict. Virtual node outputs are also in variable_pool
with global IDs like "{parent_id}.{local_id}".
Returns:
Mapping[str, Any]: A dictionary containing the generated parameters.
"""
tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters}
virtual_node_outputs = virtual_node_outputs or {}
result: dict[str, Any] = {}
for parameter_name in node_data.tool_parameters:
@ -207,22 +203,39 @@ class ToolNode(Node[ToolNodeData]):
continue
tool_input = node_data.tool_parameters[parameter_name]
if tool_input.type == "variable":
# Check if this references a virtual node output (local ID like [ext_1, text])
if not isinstance(tool_input.value, list):
raise ToolParameterError(f"Invalid variable selector for parameter '{parameter_name}'")
selector = tool_input.value
if len(selector) >= 2 and selector[0] in virtual_node_outputs:
# Reference to virtual node output
local_id = selector[0]
var_name = selector[1]
outputs = virtual_node_outputs.get(local_id, {})
parameter_value = outputs.get(var_name)
variable = variable_pool.get(selector)
if variable is None:
if parameter.required:
raise ToolParameterError(f"Variable {selector} does not exist")
continue
parameter_value = variable.value
elif tool_input.type == "mention":
# Mention type: get value from extractor node's output
from .entities import MentionValue
mention_value = tool_input.value
if isinstance(mention_value, MentionValue):
mention_config = mention_value.model_dump()
elif isinstance(mention_value, dict):
mention_config = mention_value
else:
# Normal variable reference
variable = variable_pool.get(selector)
if variable is None:
if parameter.required:
raise ToolParameterError(f"Variable {selector} does not exist")
raise ToolParameterError(f"Invalid mention value for parameter '{parameter_name}'")
try:
parameter_value, found = variable_pool.resolve_mention(
mention_config, parameter_name=parameter_name
)
if not found and parameter.required:
raise ToolParameterError(
f"Extractor output not found for required parameter '{parameter_name}'"
)
if not found:
continue
parameter_value = variable.value
except ValueError as e:
raise ToolParameterError(str(e)) from e
elif tool_input.type in {"mixed", "constant"}:
template = str(tool_input.value)
segment_group = variable_pool.convert_template(template)
@ -507,8 +520,12 @@ class ToolNode(Node[ToolNodeData]):
for selector in selectors:
result[selector.variable] = selector.value_selector
elif input.type == "variable":
selector_key = ".".join(input.value)
result[f"#{selector_key}#"] = input.value
if isinstance(input.value, list):
selector_key = ".".join(input.value)
result[f"#{selector_key}#"] = input.value
elif input.type == "mention":
# Mention type is resolved via its extractor node; there is no direct variable reference to collect here
pass
elif input.type == "constant":
pass

View File

@ -268,6 +268,58 @@ class VariablePool(BaseModel):
continue
self.add(selector, value)
def resolve_mention(
self,
mention_config: Mapping[str, Any],
/,
*,
parameter_name: str = "",
) -> tuple[Any, bool]:
"""
Resolve a mention parameter value from an extractor node's output.
Mention parameters reference values extracted by an extractor LLM node
from list[PromptMessage] context.
Args:
mention_config: A dict containing:
- extractor_node_id: ID of the extractor LLM node
- output_selector: Selector path for the output variable (e.g., ["text"])
- null_strategy: "raise_error" or "use_default"
- default_value: Value to use when null_strategy is "use_default"
parameter_name: Name of the parameter being resolved (for error messages)
Returns:
Tuple of (resolved_value, found):
- resolved_value: The extracted value, or default_value if not found
- found: True if value was found, False if using default
Raises:
ValueError: If extractor_node_id is missing, or if null_strategy is
"raise_error" and the value is not found
"""
extractor_node_id = mention_config.get("extractor_node_id")
if not extractor_node_id:
raise ValueError(f"Missing extractor_node_id for mention parameter '{parameter_name}'")
output_selector = list(mention_config.get("output_selector", []))
null_strategy = mention_config.get("null_strategy", "raise_error")
default_value = mention_config.get("default_value")
# Build full selector: [extractor_node_id, ...output_selector]
full_selector = [extractor_node_id] + output_selector
variable = self.get(full_selector)
if variable is None:
if null_strategy == "use_default":
return default_value, False
raise ValueError(
f"Extractor node '{extractor_node_id}' output '{'.'.join(output_selector)}' "
f"not found for parameter '{parameter_name}'"
)
return variable.value, True
@classmethod
def empty(cls) -> VariablePool:
"""Create an empty variable pool."""

View File

@ -4,6 +4,7 @@ from uuid import uuid4
from configs import dify_config
from core.file import File
from core.model_runtime.entities import PromptMessage
from core.variables.exc import VariableError
from core.variables.segments import (
ArrayAnySegment,
@ -11,6 +12,7 @@ from core.variables.segments import (
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArrayPromptMessageSegment,
ArraySegment,
ArrayStringSegment,
BooleanSegment,
@ -29,6 +31,7 @@ from core.variables.variables import (
ArrayFileVariable,
ArrayNumberVariable,
ArrayObjectVariable,
ArrayPromptMessageVariable,
ArrayStringVariable,
BooleanVariable,
FileVariable,
@ -61,6 +64,7 @@ SEGMENT_TO_VARIABLE_MAP = {
ArrayFileSegment: ArrayFileVariable,
ArrayNumberSegment: ArrayNumberVariable,
ArrayObjectSegment: ArrayObjectVariable,
ArrayPromptMessageSegment: ArrayPromptMessageVariable,
ArrayStringSegment: ArrayStringVariable,
BooleanSegment: BooleanVariable,
FileSegment: FileVariable,
@ -156,7 +160,13 @@ def build_segment(value: Any, /) -> Segment:
return ObjectSegment(value=value)
if isinstance(value, File):
return FileSegment(value=value)
if isinstance(value, PromptMessage):
# Single PromptMessage should be wrapped in a list
return ArrayPromptMessageSegment(value=[value])
if isinstance(value, list):
# Check if all items are PromptMessage
if value and all(isinstance(item, PromptMessage) for item in value):
return ArrayPromptMessageSegment(value=value)
items = [build_segment(item) for item in value]
types = {item.value_type for item in items}
if all(isinstance(item, ArraySegment) for item in items):
@ -200,6 +210,7 @@ _segment_factory: Mapping[SegmentType, type[Segment]] = {
SegmentType.ARRAY_OBJECT: ArrayObjectSegment,
SegmentType.ARRAY_FILE: ArrayFileSegment,
SegmentType.ARRAY_BOOLEAN: ArrayBooleanSegment,
SegmentType.ARRAY_PROMPT_MESSAGE: ArrayPromptMessageSegment,
}
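
A quick sketch of the new factory behavior, with build_segment and ArrayPromptMessageSegment from this module, and UserPromptMessage again assumed to be exported:

from core.model_runtime.entities import UserPromptMessage

# A homogeneous list of PromptMessage becomes an ArrayPromptMessageSegment.
seg = build_segment([UserPromptMessage(content="hi")])
assert isinstance(seg, ArrayPromptMessageSegment)

# A bare PromptMessage is normalized into a one-element array segment.
seg = build_segment(UserPromptMessage(content="hi"))
assert isinstance(seg, ArrayPromptMessageSegment) and len(seg.value) == 1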

View File

@ -1291,7 +1291,7 @@ class WorkflowDraftVariable(Base):
# which may differ from the original value's type. Typically, they are the same,
# but in cases where the structurally truncated value still exceeds the size limit,
# text slicing is applied, and the `value_type` is converted to `STRING`.
value_type: Mapped[SegmentType] = mapped_column(EnumText(SegmentType, length=20))
value_type: Mapped[SegmentType] = mapped_column(EnumText(SegmentType, length=21))
# The variable's value serialized as a JSON string
#
@ -1665,7 +1665,7 @@ class WorkflowDraftVariableFile(Base):
# The `value_type` field records the type of the original value.
value_type: Mapped[SegmentType] = mapped_column(
EnumText(SegmentType, length=20),
EnumText(SegmentType, length=21),
nullable=False,
)

View File

@ -7,6 +7,7 @@ from typing import Any, Generic, TypeAlias, TypeVar, overload
from configs import dify_config
from core.file.models import File
from core.model_runtime.entities import PromptMessage
from core.variables.segments import (
ArrayFileSegment,
ArraySegment,
@ -287,6 +288,10 @@ class VariableTruncator(BaseTruncator):
if isinstance(item, File):
truncated_value.append(item)
continue
# PromptMessage items are appended as-is, like File items, and are not truncated
if isinstance(item, PromptMessage):
truncated_value.append(item)
continue
if i >= target_length:
return _PartResult(truncated_value, used_size, True)
if i > 0:

View File

@ -1,77 +0,0 @@
"""
Unit tests for virtual node configuration.
"""
from core.workflow.nodes.base.entities import VirtualNodeConfig
class TestVirtualNodeConfig:
"""Tests for VirtualNodeConfig entity."""
def test_create_basic_config(self):
"""Test creating a basic virtual node config."""
config = VirtualNodeConfig(
id="ext_1",
type="llm",
data={
"title": "Extract keywords",
"model": {"provider": "openai", "name": "gpt-4o-mini"},
},
)
assert config.id == "ext_1"
assert config.type == "llm"
assert config.data["title"] == "Extract keywords"
def test_get_global_id(self):
"""Test generating global ID from parent ID."""
config = VirtualNodeConfig(
id="ext_1",
type="llm",
data={},
)
global_id = config.get_global_id("tool1")
assert global_id == "tool1.ext_1"
def test_get_global_id_with_different_parents(self):
"""Test global ID generation with different parent IDs."""
config = VirtualNodeConfig(id="sub_node", type="code", data={})
assert config.get_global_id("parent1") == "parent1.sub_node"
assert config.get_global_id("node_123") == "node_123.sub_node"
def test_empty_data(self):
"""Test virtual node config with empty data."""
config = VirtualNodeConfig(
id="test",
type="tool",
)
assert config.id == "test"
assert config.type == "tool"
assert config.data == {}
def test_complex_data(self):
"""Test virtual node config with complex data."""
config = VirtualNodeConfig(
id="llm_1",
type="llm",
data={
"title": "Generate summary",
"model": {
"provider": "openai",
"name": "gpt-4",
"mode": "chat",
"completion_params": {"temperature": 0.7, "max_tokens": 500},
},
"prompt_template": [
{"role": "user", "text": "{{#llm1.context#}}"},
{"role": "user", "text": "Please summarize the conversation"},
],
},
)
assert config.data["model"]["provider"] == "openai"
assert len(config.data["prompt_template"]) == 2

View File

@ -25,6 +25,12 @@ class _StubErrorHandler:
"""Minimal error handler stub for tests."""
class _StubNodeData:
"""Simple node data stub with is_extractor_node property."""
is_extractor_node = False
class _StubNode:
"""Simple node stub exposing the attributes needed by the state manager."""
@ -36,6 +42,7 @@ class _StubNode:
self.error_strategy = None
self.retry_config = RetryConfig()
self.retry = False
self.node_data = _StubNodeData()
def _build_event_handler(node_id: str) -> tuple[EventHandler, EventManager, GraphExecution]: