Merge main HEAD (segment 5) into sandboxed-agent-rebase

Resolve 83 conflicts: 10 backend, 62 frontend, 11 config/lock files.
Preserve sandbox/agent/collaboration features while adopting main's
UI refactorings (Dialog/AlertDialog/Popover), model provider updates,
and enterprise features.

Made-with: Cursor
Novice
2026-03-23 14:20:06 +08:00
1671 changed files with 124822 additions and 22302 deletions

View File

@ -259,6 +259,9 @@ _END_STATE = frozenset(
class WorkflowNodeExecutionMetadataKey(StrEnum):
"""
Node Run Metadata Key.
Values in this enum are persisted as execution metadata and must stay in sync
with every node that writes `NodeRunResult.metadata`.
"""
TOTAL_TOKENS = "total_tokens"
@ -282,6 +285,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
DATASOURCE_INFO = "datasource_info"
LLM_CONTENT_SEQUENCE = "llm_content_sequence"
LLM_TRACE = "llm_trace"
TRIGGER_INFO = "trigger_info"
COMPLETED_REASON = "completed_reason" # completed reason for loop node
PARENT_NODE_ID = "parent_node_id" # parent node id for nested nodes (extractor nodes)

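The docstring above ties this enum to every node that writes `NodeRunResult.metadata`. A minimal sketch of that contract, with the import paths and the metadata field assumed from other hunks in this commit rather than verified:

# Illustrative only: assumes NodeRunResult accepts a metadata mapping keyed by this enum.
from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from dify_graph.node_events import NodeRunResult

def build_trigger_result(trigger_payload: dict) -> NodeRunResult:
    # Key execution metadata with the enum member, not a raw string, so that
    # readers of WorkflowNodeExecutionMetadataKey stay in sync with writers.
    return NodeRunResult(
        status=WorkflowNodeExecutionStatus.SUCCEEDED,
        metadata={WorkflowNodeExecutionMetadataKey.TRIGGER_INFO: trigger_payload},
    )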
View File

@ -159,6 +159,7 @@ class ErrorHandler:
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
finished_at=event.finished_at,
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.EXCEPTION,
inputs=event.node_run_result.inputs,
@ -198,6 +199,7 @@ class ErrorHandler:
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
finished_at=event.finished_at,
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.EXCEPTION,
inputs=event.node_run_result.inputs,

View File

@ -6,6 +6,5 @@ of responses based on upstream node outputs and constants.
"""
from .coordinator import ResponseStreamCoordinator
from .session import RESPONSE_SESSION_NODE_TYPES
__all__ = ["RESPONSE_SESSION_NODE_TYPES", "ResponseStreamCoordinator"]
__all__ = ["ResponseStreamCoordinator"]

View File

@ -3,10 +3,6 @@ Internal response session management for response coordinator.
This module contains the private ResponseSession class used internally
by ResponseStreamCoordinator to manage streaming sessions.
`RESPONSE_SESSION_NODE_TYPES` is intentionally mutable so downstream applications
can opt additional response-capable node types into session creation without
patching the coordinator.
"""
from __future__ import annotations
@ -14,7 +10,6 @@ from __future__ import annotations
from dataclasses import dataclass
from typing import Protocol, cast
from dify_graph.enums import BuiltinNodeTypes, NodeType
from dify_graph.nodes.base.template import Template
from dify_graph.runtime.graph_runtime_state import NodeProtocol
@ -25,12 +20,6 @@ class _ResponseSessionNodeProtocol(NodeProtocol, Protocol):
def get_streaming_template(self) -> Template: ...
RESPONSE_SESSION_NODE_TYPES: list[NodeType] = [
BuiltinNodeTypes.ANSWER,
BuiltinNodeTypes.END,
]
@dataclass
class ResponseSession:
"""
@ -49,8 +38,8 @@ class ResponseSession:
Create a ResponseSession from a response-capable node.
The parameter is typed as `NodeProtocol` because the graph is exposed behind a protocol at the runtime layer.
At runtime this must be a node whose `node_type` is listed in `RESPONSE_SESSION_NODE_TYPES`
and which implements `get_streaming_template()`.
At runtime this must be a node that implements `get_streaming_template()`. The coordinator decides which
graph nodes should be treated as response-capable before they reach this factory.
Args:
node: Node from the materialized workflow graph.
@ -59,15 +48,8 @@ class ResponseSession:
ResponseSession configured with the node's streaming template
Raises:
TypeError: If node is not a supported response node type.
TypeError: If node does not implement the response-session streaming contract.
"""
if node.node_type not in RESPONSE_SESSION_NODE_TYPES:
supported_node_types = ", ".join(RESPONSE_SESSION_NODE_TYPES)
raise TypeError(
"ResponseSession.from_node only supports node types in "
f"RESPONSE_SESSION_NODE_TYPES: {supported_node_types}"
)
response_node = cast(_ResponseSessionNodeProtocol, node)
try:
template = response_node.get_streaming_template()

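With `RESPONSE_SESSION_NODE_TYPES` removed, the node-type decision moves to the coordinator and `from_node` only enforces the streaming contract. A hypothetical caller-side gate, assuming the coordinator keeps its own set of response-capable types:

# Hypothetical sketch; only ResponseSession.from_node, BuiltinNodeTypes, and the
# relative-import style are taken from this diff, the rest is illustrative.
from dify_graph.enums import BuiltinNodeTypes

from .session import ResponseSession

_RESPONSE_CAPABLE_TYPES = {BuiltinNodeTypes.ANSWER, BuiltinNodeTypes.END}

def maybe_open_session(node):
    # The coordinator decides which node types are response-capable...
    if node.node_type not in _RESPONSE_CAPABLE_TYPES:
        return None
    try:
        # ...and from_node only verifies the get_streaming_template() contract.
        return ResponseSession.from_node(node)
    except TypeError:
        return None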
View File

@ -15,10 +15,13 @@ from typing import TYPE_CHECKING, final
from typing_extensions import override
from dify_graph.context import IExecutionContext
from dify_graph.enums import WorkflowNodeExecutionStatus
from dify_graph.graph import Graph
from dify_graph.graph_engine.layers.base import GraphEngineLayer
from dify_graph.graph_events import GraphNodeEventBase, NodeRunFailedEvent, is_node_result_event
from dify_graph.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunStartedEvent, is_node_result_event
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from libs.datetime_utils import naive_utc_now
from .ready_queue import ReadyQueue
@ -65,6 +68,7 @@ class Worker(threading.Thread):
self._stop_event = threading.Event()
self._layers = layers if layers is not None else []
self._last_task_time = time.time()
self._current_node_started_at: datetime | None = None
def stop(self) -> None:
"""Signal the worker to stop processing."""
@ -104,18 +108,15 @@ class Worker(threading.Thread):
self._last_task_time = time.time()
node = self._graph.nodes[node_id]
try:
self._current_node_started_at = None
self._execute_node(node)
self._ready_queue.task_done()
except Exception as e:
error_event = NodeRunFailedEvent(
id=node.execution_id,
node_id=node.id,
node_type=node.node_type,
in_iteration_id=None,
error=str(e),
start_at=datetime.now(),
self._event_queue.put(
self._build_fallback_failure_event(node, e, started_at=self._current_node_started_at)
)
self._event_queue.put(error_event)
finally:
self._current_node_started_at = None
def _execute_node(self, node: Node) -> None:
"""
@ -136,6 +137,8 @@ class Worker(threading.Thread):
try:
node_events = node.run()
for event in node_events:
if isinstance(event, NodeRunStartedEvent) and event.id == node.execution_id:
self._current_node_started_at = event.start_at
self._event_queue.put(event)
if is_node_result_event(event):
result_event = event
@ -149,6 +152,8 @@ class Worker(threading.Thread):
try:
node_events = node.run()
for event in node_events:
if isinstance(event, NodeRunStartedEvent) and event.id == node.execution_id:
self._current_node_started_at = event.start_at
self._event_queue.put(event)
if is_node_result_event(event):
result_event = event
@ -177,3 +182,24 @@ class Worker(threading.Thread):
except Exception:
# Silently ignore layer errors to prevent disrupting node execution
continue
def _build_fallback_failure_event(
self, node: Node, error: Exception, *, started_at: datetime | None = None
) -> NodeRunFailedEvent:
"""Build a failed event when worker-level execution aborts before a node emits its own result event."""
failure_time = naive_utc_now()
error_message = str(error)
return NodeRunFailedEvent(
id=node.execution_id,
node_id=node.id,
node_type=node.node_type,
in_iteration_id=None,
error=error_message,
start_at=started_at or failure_time,
finished_at=failure_time,
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=error_message,
error_type=type(error).__name__,
),
)

View File

@ -75,16 +75,19 @@ class NodeRunRetrieverResourceEvent(GraphNodeEventBase):
class NodeRunSucceededEvent(GraphNodeEventBase):
start_at: datetime = Field(..., description="node start time")
finished_at: datetime | None = Field(default=None, description="node finish time")
class NodeRunFailedEvent(GraphNodeEventBase):
error: str = Field(..., description="error")
start_at: datetime = Field(..., description="node start time")
finished_at: datetime | None = Field(default=None, description="node finish time")
class NodeRunExceptionEvent(GraphNodeEventBase):
error: str = Field(..., description="error")
start_at: datetime = Field(..., description="node start time")
finished_at: datetime | None = Field(default=None, description="node finish time")
class NodeRunRetryEvent(NodeRunStartedEvent):

View File

@ -455,11 +455,13 @@ class Node(Generic[NodeDataT]):
error=str(e),
error_type="WorkflowNodeError",
)
finished_at = naive_utc_now()
yield NodeRunFailedEvent(
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
finished_at=finished_at,
node_run_result=result,
error=str(e),
)
@ -617,6 +619,7 @@ class Node(Generic[NodeDataT]):
return self._node_data
def _convert_node_run_result_to_graph_node_event(self, result: NodeRunResult) -> GraphNodeEventBase:
finished_at = naive_utc_now()
match result.status:
case WorkflowNodeExecutionStatus.FAILED:
return NodeRunFailedEvent(
@ -624,6 +627,7 @@ class Node(Generic[NodeDataT]):
node_id=self.id,
node_type=self.node_type,
start_at=self._start_at,
finished_at=finished_at,
node_run_result=result,
error=result.error,
)
@ -633,6 +637,7 @@ class Node(Generic[NodeDataT]):
node_id=self.id,
node_type=self.node_type,
start_at=self._start_at,
finished_at=finished_at,
node_run_result=result,
)
case _:
@ -717,6 +722,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent:
finished_at = naive_utc_now()
match event.node_run_result.status:
case WorkflowNodeExecutionStatus.SUCCEEDED:
return NodeRunSucceededEvent(
@ -724,6 +730,7 @@ class Node(Generic[NodeDataT]):
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
finished_at=finished_at,
node_run_result=event.node_run_result,
)
case WorkflowNodeExecutionStatus.FAILED:
@ -732,6 +739,7 @@ class Node(Generic[NodeDataT]):
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
finished_at=finished_at,
node_run_result=event.node_run_result,
error=event.node_run_result.error,
)

View File

@ -101,6 +101,8 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
timeout=self._get_request_timeout(self.node_data),
variable_pool=self.graph_runtime_state.variable_pool,
http_request_config=self._http_request_config,
# Must be 0 to disable executor-level retries, as the graph engine handles them.
# This is critical to prevent nested retries.
max_retries=0,
ssl_verify=self.node_data.ssl_verify,
http_client=self._http_client,

View File

@ -8,6 +8,8 @@ from collections.abc import Mapping, Sequence
from datetime import datetime, timedelta
from typing import Annotated, Any, ClassVar, Literal, Self
import bleach
import markdown
from pydantic import BaseModel, Field, field_validator, model_validator
from dify_graph.entities.base_node_data import BaseNodeData
@ -58,6 +60,39 @@ class EmailDeliveryConfig(BaseModel):
"""Configuration for email delivery method."""
URL_PLACEHOLDER: ClassVar[str] = "{{#url#}}"
_SUBJECT_NEWLINE_PATTERN: ClassVar[re.Pattern[str]] = re.compile(r"[\r\n]+")
_ALLOWED_HTML_TAGS: ClassVar[list[str]] = [
"a",
"blockquote",
"br",
"code",
"em",
"h1",
"h2",
"h3",
"h4",
"h5",
"h6",
"hr",
"li",
"ol",
"p",
"pre",
"strong",
"table",
"tbody",
"td",
"th",
"thead",
"tr",
"ul",
]
_ALLOWED_HTML_ATTRIBUTES: ClassVar[dict[str, list[str]]] = {
"a": ["href", "title"],
"td": ["align"],
"th": ["align"],
}
_ALLOWED_PROTOCOLS: ClassVar[list[str]] = ["http", "https", "mailto"]
recipients: EmailRecipients
@ -98,6 +133,43 @@ class EmailDeliveryConfig(BaseModel):
return templated_body
return variable_pool.convert_template(templated_body).text
@classmethod
def render_markdown_body(cls, body: str) -> str:
"""Render markdown to safe HTML for email delivery."""
sanitized_markdown = bleach.clean(
body,
tags=[],
attributes={},
strip=True,
strip_comments=True,
)
rendered_html = markdown.markdown(
sanitized_markdown,
extensions=["nl2br", "tables"],
extension_configs={"tables": {"use_align_attribute": True}},
)
return bleach.clean(
rendered_html,
tags=cls._ALLOWED_HTML_TAGS,
attributes=cls._ALLOWED_HTML_ATTRIBUTES,
protocols=cls._ALLOWED_PROTOCOLS,
strip=True,
strip_comments=True,
)
@classmethod
def sanitize_subject(cls, subject: str) -> str:
"""Sanitize email subject to plain text and prevent CRLF injection."""
sanitized_subject = bleach.clean(
subject,
tags=[],
attributes={},
strip=True,
strip_comments=True,
)
sanitized_subject = cls._SUBJECT_NEWLINE_PATTERN.sub(" ", sanitized_subject)
return " ".join(sanitized_subject.split())
class _DeliveryMethodBase(BaseModel):
"""Base delivery method configuration."""

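A quick usage sketch of the two helpers above (assumes `EmailDeliveryConfig` is already in scope; its module path is not shown in this diff):

subject = EmailDeliveryConfig.sanitize_subject("Status update\r\nBcc: attacker@example.com")
# CR/LF sequences collapse to single spaces, defeating header injection via the subject.

body = EmailDeliveryConfig.render_markdown_body("# Summary\n\n**3 failures**, see <script>x()</script> notes")
# Raw HTML is stripped before Markdown rendering, and the rendered output is
# re-sanitized against the tag, attribute, and protocol allow-lists defined above.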
View File

@ -236,7 +236,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
future_to_index: dict[
Future[
tuple[
datetime,
float,
list[GraphNodeEventBase],
object | None,
dict[str, Variable],
@ -261,7 +261,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
try:
result = future.result()
(
iter_start_at,
iteration_duration,
events,
output_value,
conversation_snapshot,
@ -274,8 +274,9 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
# Yield all events from this iteration
yield from events
# Update tokens and timing
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
# The worker computes duration before we replay buffered events here,
# so slow downstream consumers don't inflate per-iteration timing.
iter_run_map[str(index)] = iteration_duration
usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage)
@ -305,7 +306,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
index: int,
item: object,
execution_context: "IExecutionContext",
) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
) -> tuple[float, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
"""Execute a single iteration in parallel mode and return results."""
with execution_context:
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
@ -327,9 +328,10 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
conversation_snapshot = self._extract_conversation_variable_snapshot(
variable_pool=graph_engine.graph_runtime_state.variable_pool
)
iteration_duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
return (
iter_start_at,
iteration_duration,
events,
output_value,
conversation_snapshot,

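The timing change above records each iteration's duration inside the worker, before buffered events are replayed. A generic illustration of the same pattern in plain Python (no project APIs involved):

import time
from concurrent.futures import ThreadPoolExecutor, as_completed

def run_iteration(index: int) -> tuple[float, list[str]]:
    started = time.monotonic()
    events = [f"iteration {index}: event {i}" for i in range(3)]  # buffered, not yet consumed
    duration = time.monotonic() - started  # measured before anyone replays the events
    return duration, events

with ThreadPoolExecutor(max_workers=4) as pool:
    futures = {pool.submit(run_iteration, i): i for i in range(8)}
    run_map: dict[str, float] = {}
    for future in as_completed(futures):
        duration, events = future.result()
        for _event in events:
            pass  # however slowly these are consumed, duration is already fixed
        run_map[str(futures[future])] = duration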
View File

@ -1,3 +1,5 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import Any, cast
@ -9,38 +11,53 @@ from core.memory.base import BaseMemory
from core.model_manager import ModelInstance
from core.prompt.entities.advanced_prompt_entities import MemoryConfig, MemoryMode
from dify_graph.enums import SystemVariableKey
from dify_graph.file import FileType, file_manager
from dify_graph.file.models import File
from dify_graph.model_runtime.entities import PromptMessageRole
from dify_graph.model_runtime.entities.message_entities import (
AssistantPromptMessage,
from dify_graph.model_runtime.entities import (
ImagePromptMessageContent,
MultiModalPromptMessageContent,
PromptMessage,
PromptMessageContentUnionTypes,
PromptMessageContentType,
PromptMessageRole,
TextPromptMessageContent,
ToolPromptMessage,
)
from dify_graph.model_runtime.entities.model_entities import AIModelEntity
from dify_graph.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessageContentUnionTypes,
SystemPromptMessage,
UserPromptMessage,
)
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelPropertyKey
from dify_graph.model_runtime.memory import PromptMessageMemory
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from dify_graph.nodes.base.entities import VariableSelector
from dify_graph.nodes.llm.entities import LLMGenerationData
from dify_graph.runtime import VariablePool
from dify_graph.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from dify_graph.variables import ArrayFileSegment, FileSegment
from dify_graph.variables.segments import ArrayAnySegment, NoneSegment, StringSegment
from .exc import InvalidVariableTypeError
from .entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate
from .exc import (
InvalidVariableTypeError,
MemoryRolePrefixRequiredError,
NoPromptFoundError,
TemplateTypeNotSupportError,
)
from .protocols import TemplateRenderer
def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity:
model_schema = cast(LargeLanguageModel, model_instance.model_type_instance).get_model_schema(
model_instance.model_name,
model_instance.credentials,
dict(model_instance.credentials),
)
if not model_schema:
raise ValueError(f"Model schema not found for {model_instance.model_name}")
return model_schema
def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence["File"]:
def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence[File]:
variable = variable_pool.get(selector)
if variable is None:
return []
@ -285,3 +302,370 @@ def _restore_message_content(message: PromptMessage) -> PromptMessage:
restored_content.append(item)
return message.model_copy(update={"content": restored_content})
def fetch_prompt_messages(
*,
sys_query: str | None = None,
sys_files: Sequence[File],
context: str | None = None,
memory: PromptMessageMemory | None = None,
model_instance: ModelInstance,
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
stop: Sequence[str] | None = None,
memory_config: MemoryConfig | None = None,
vision_enabled: bool = False,
vision_detail: ImagePromptMessageContent.DETAIL,
variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector],
context_files: list[File] | None = None,
template_renderer: TemplateRenderer | None = None,
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
prompt_messages: list[PromptMessage] = []
model_schema = fetch_model_schema(model_instance=model_instance)
if isinstance(prompt_template, list):
prompt_messages.extend(
handle_list_messages(
messages=prompt_template,
context=context,
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
vision_detail_config=vision_detail,
template_renderer=template_renderer,
)
)
prompt_messages.extend(
handle_memory_chat_mode(
memory=memory,
memory_config=memory_config,
model_instance=model_instance,
)
)
if sys_query:
prompt_messages.extend(
handle_list_messages(
messages=[
LLMNodeChatModelMessage(
text=sys_query,
role=PromptMessageRole.USER,
edition_type="basic",
)
],
context="",
jinja2_variables=[],
variable_pool=variable_pool,
vision_detail_config=vision_detail,
template_renderer=template_renderer,
)
)
elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
prompt_messages.extend(
handle_completion_template(
template=prompt_template,
context=context,
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
template_renderer=template_renderer,
)
)
memory_text = handle_memory_completion_mode(
memory=memory,
memory_config=memory_config,
model_instance=model_instance,
)
prompt_content = prompt_messages[0].content
if isinstance(prompt_content, str):
prompt_content = str(prompt_content)
if "#histories#" in prompt_content:
prompt_content = prompt_content.replace("#histories#", memory_text)
else:
prompt_content = memory_text + "\n" + prompt_content
prompt_messages[0].content = prompt_content
elif isinstance(prompt_content, list):
for content_item in prompt_content:
if isinstance(content_item, TextPromptMessageContent):
if "#histories#" in content_item.data:
content_item.data = content_item.data.replace("#histories#", memory_text)
else:
content_item.data = memory_text + "\n" + content_item.data
else:
raise ValueError("Invalid prompt content type")
if sys_query:
if isinstance(prompt_content, str):
prompt_messages[0].content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
elif isinstance(prompt_content, list):
for content_item in prompt_content:
if isinstance(content_item, TextPromptMessageContent):
content_item.data = sys_query + "\n" + content_item.data
else:
raise ValueError("Invalid prompt content type")
else:
raise TemplateTypeNotSupportError(type_name=str(type(prompt_template)))
_append_file_prompts(
prompt_messages=prompt_messages,
files=sys_files,
vision_enabled=vision_enabled,
vision_detail=vision_detail,
)
_append_file_prompts(
prompt_messages=prompt_messages,
files=context_files or [],
vision_enabled=vision_enabled,
vision_detail=vision_detail,
)
filtered_prompt_messages: list[PromptMessage] = []
for prompt_message in prompt_messages:
if isinstance(prompt_message.content, list):
prompt_message_content: list[PromptMessageContentUnionTypes] = []
for content_item in prompt_message.content:
if not model_schema.features:
if content_item.type == PromptMessageContentType.TEXT:
prompt_message_content.append(content_item)
continue
if (
(
content_item.type == PromptMessageContentType.IMAGE
and ModelFeature.VISION not in model_schema.features
)
or (
content_item.type == PromptMessageContentType.DOCUMENT
and ModelFeature.DOCUMENT not in model_schema.features
)
or (
content_item.type == PromptMessageContentType.VIDEO
and ModelFeature.VIDEO not in model_schema.features
)
or (
content_item.type == PromptMessageContentType.AUDIO
and ModelFeature.AUDIO not in model_schema.features
)
):
continue
prompt_message_content.append(content_item)
if not prompt_message_content:
continue
if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT:
prompt_message.content = prompt_message_content[0].data
else:
prompt_message.content = prompt_message_content
filtered_prompt_messages.append(prompt_message)
elif not prompt_message.is_empty():
filtered_prompt_messages.append(prompt_message)
if len(filtered_prompt_messages) == 0:
raise NoPromptFoundError(
"No prompt found in the LLM configuration. Please ensure a prompt is properly configured before proceeding."
)
return filtered_prompt_messages, stop
def handle_list_messages(
*,
messages: Sequence[LLMNodeChatModelMessage],
context: str | None,
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
vision_detail_config: ImagePromptMessageContent.DETAIL,
template_renderer: TemplateRenderer | None = None,
) -> Sequence[PromptMessage]:
prompt_messages: list[PromptMessage] = []
for message in messages:
if message.edition_type == "jinja2":
result_text = render_jinja2_message(
template=message.jinja2_text or "",
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
template_renderer=template_renderer,
)
prompt_messages.append(
combine_message_content_with_role(
contents=[TextPromptMessageContent(data=result_text)],
role=message.role,
)
)
continue
template = message.text.replace("{#context#}", context) if context else message.text
segment_group = variable_pool.convert_template(template)
file_contents: list[PromptMessageContentUnionTypes] = []
for segment in segment_group.value:
if isinstance(segment, ArrayFileSegment):
for file in segment.value:
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
file_contents.append(
file_manager.to_prompt_message_content(file, image_detail_config=vision_detail_config)
)
elif isinstance(segment, FileSegment):
file = segment.value
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
file_contents.append(
file_manager.to_prompt_message_content(file, image_detail_config=vision_detail_config)
)
if segment_group.text:
prompt_messages.append(
combine_message_content_with_role(
contents=[TextPromptMessageContent(data=segment_group.text)],
role=message.role,
)
)
if file_contents:
prompt_messages.append(combine_message_content_with_role(contents=file_contents, role=message.role))
return prompt_messages
def render_jinja2_message(
*,
template: str,
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
template_renderer: TemplateRenderer | None = None,
) -> str:
if not template:
return ""
if template_renderer is None:
raise ValueError("template_renderer is required for jinja2 prompt rendering")
jinja2_inputs: dict[str, Any] = {}
for jinja2_variable in jinja2_variables:
variable = variable_pool.get(jinja2_variable.value_selector)
jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
return template_renderer.render_jinja2(template=template, inputs=jinja2_inputs)
def handle_completion_template(
*,
template: LLMNodeCompletionModelPromptTemplate,
context: str | None,
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
template_renderer: TemplateRenderer | None = None,
) -> Sequence[PromptMessage]:
if template.edition_type == "jinja2":
result_text = render_jinja2_message(
template=template.jinja2_text or "",
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
template_renderer=template_renderer,
)
else:
template_text = template.text.replace("{#context#}", context) if context else template.text
result_text = variable_pool.convert_template(template_text).text
return [
combine_message_content_with_role(
contents=[TextPromptMessageContent(data=result_text)],
role=PromptMessageRole.USER,
)
]
def combine_message_content_with_role(
*,
contents: str | list[PromptMessageContentUnionTypes] | None = None,
role: PromptMessageRole,
) -> PromptMessage:
match role:
case PromptMessageRole.USER:
return UserPromptMessage(content=contents)
case PromptMessageRole.ASSISTANT:
return AssistantPromptMessage(content=contents)
case PromptMessageRole.SYSTEM:
return SystemPromptMessage(content=contents)
case _:
raise NotImplementedError(f"Role {role} is not supported")
def calculate_rest_token(*, prompt_messages: list[PromptMessage], model_instance: ModelInstance) -> int:
rest_tokens = 2000
runtime_model_schema = fetch_model_schema(model_instance=model_instance)
runtime_model_parameters = model_instance.parameters
model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
max_tokens = 0
for parameter_rule in runtime_model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
runtime_model_parameters.get(parameter_rule.name)
or runtime_model_parameters.get(str(parameter_rule.use_template))
or 0
)
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
return rest_tokens
def handle_memory_chat_mode(
*,
memory: PromptMessageMemory | None,
memory_config: MemoryConfig | None,
model_instance: ModelInstance,
) -> Sequence[PromptMessage]:
if not memory or not memory_config:
return []
rest_tokens = calculate_rest_token(prompt_messages=[], model_instance=model_instance)
return memory.get_history_prompt_messages(
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
)
def handle_memory_completion_mode(
*,
memory: PromptMessageMemory | None,
memory_config: MemoryConfig | None,
model_instance: ModelInstance,
) -> str:
if not memory or not memory_config:
return ""
rest_tokens = calculate_rest_token(prompt_messages=[], model_instance=model_instance)
if not memory_config.role_prefix:
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
return fetch_memory_text(
memory=memory,
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
human_prefix=memory_config.role_prefix.user,
ai_prefix=memory_config.role_prefix.assistant,
)
def _append_file_prompts(
*,
prompt_messages: list[PromptMessage],
files: Sequence[File],
vision_enabled: bool,
vision_detail: ImagePromptMessageContent.DETAIL,
) -> None:
if not vision_enabled or not files:
return
file_prompts = [file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) for file in files]
if (
prompt_messages
and isinstance(prompt_messages[-1], UserPromptMessage)
and isinstance(prompt_messages[-1].content, list)
):
existing_contents = prompt_messages[-1].content
assert isinstance(existing_contents, list)
prompt_messages[-1] = UserPromptMessage(content=file_prompts + existing_contents)
else:
prompt_messages.append(UserPromptMessage(content=file_prompts))

View File

@ -54,11 +54,10 @@ from dify_graph.enums import (
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from dify_graph.file import File, FileTransferMethod, FileType, file_manager
from dify_graph.file import File, FileTransferMethod, FileType
from dify_graph.model_runtime.entities import (
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
TextPromptMessageContent,
)
from dify_graph.model_runtime.entities.llm_entities import (
@ -69,14 +68,7 @@ from dify_graph.model_runtime.entities.llm_entities import (
LLMStructuredOutput,
LLMUsage,
)
from dify_graph.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessageContentUnionTypes,
PromptMessageRole,
SystemPromptMessage,
UserPromptMessage,
)
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from dify_graph.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
from dify_graph.model_runtime.memory import PromptMessageMemory
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.node_events import (
@ -95,14 +87,13 @@ from dify_graph.node_events.node import ChunkType, ThoughtEndChunkEvent, Thought
from dify_graph.nodes.base.entities import VariableSelector
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
from dify_graph.nodes.protocols import HttpClientProtocol
from dify_graph.runtime import VariablePool
from dify_graph.variables import (
ArrayFileSegment,
ArrayPromptMessageSegment,
ArraySegment,
FileSegment,
NoneSegment,
ObjectSegment,
StringSegment,
@ -133,9 +124,6 @@ from .exc import (
InvalidContextStructureError,
InvalidVariableTypeError,
LLMNodeError,
MemoryRolePrefixRequiredError,
NoPromptFoundError,
TemplateTypeNotSupportError,
VariableNotFoundError,
)
from .file_saver import FileSaverImpl, LLMFileSaver
@ -162,6 +150,7 @@ class LLMNode(Node[LLMNodeData]):
_model_factory: ModelFactory
_model_instance: ModelInstance
_memory: PromptMessageMemory | None
_template_renderer: TemplateRenderer
def __init__(
self,
@ -174,6 +163,7 @@ class LLMNode(Node[LLMNodeData]):
model_factory: ModelFactory,
model_instance: ModelInstance,
http_client: HttpClientProtocol,
template_renderer: TemplateRenderer,
memory: PromptMessageMemory | None = None,
llm_file_saver: LLMFileSaver | None = None,
):
@ -190,6 +180,7 @@ class LLMNode(Node[LLMNodeData]):
self._model_factory = model_factory
self._model_instance = model_instance
self._memory = memory
self._template_renderer = template_renderer
if llm_file_saver is None:
dify_ctx = self.require_dify_context()
@ -1326,7 +1317,6 @@ class LLMNode(Node[LLMNodeData]):
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
if isinstance(prompt_template, list):
# For chat model
prompt_messages.extend(
LLMNode.handle_list_messages(
messages=prompt_template,
@ -1338,16 +1328,13 @@ class LLMNode(Node[LLMNodeData]):
)
)
# Get memory messages for chat mode
memory_messages = _handle_memory_chat_mode(
memory=memory,
memory_config=memory_config,
model_instance=model_instance,
)
# Extend prompt_messages with memory messages
prompt_messages.extend(memory_messages)
# Add current query to the prompt messages
if sys_query:
message = LLMNodeChatModelMessage(
text=sys_query,
@ -1365,7 +1352,6 @@ class LLMNode(Node[LLMNodeData]):
)
elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
# For completion model
prompt_messages.extend(
_handle_completion_template(
template=prompt_template,
@ -1375,15 +1361,12 @@ class LLMNode(Node[LLMNodeData]):
)
)
# Get memory text for completion model
memory_text = _handle_memory_completion_mode(
memory=memory,
memory_config=memory_config,
model_instance=model_instance,
)
# Insert histories into the prompt
prompt_content = prompt_messages[0].content
# For issue #11247 - Check if prompt content is a string or a list
prompt_content_type = type(prompt_content)
if prompt_content_type == str:
prompt_content = str(prompt_content)
@ -1403,7 +1386,6 @@ class LLMNode(Node[LLMNodeData]):
else:
raise ValueError("Invalid prompt content type")
# Add current query to the prompt message
if sys_query:
if prompt_content_type == str:
prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
@ -1418,14 +1400,11 @@ class LLMNode(Node[LLMNodeData]):
else:
raise TemplateTypeNotSupportError(type_name=str(type(prompt_template)))
# The sys_files will be deprecated later
if vision_enabled and sys_files:
file_prompts = []
for file in sys_files:
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
file_prompts.append(file_prompt)
# If last prompt is a user prompt, add files into its contents,
# otherwise append a new user prompt
if (
len(prompt_messages) > 0
and isinstance(prompt_messages[-1], UserPromptMessage)
@ -1435,14 +1414,11 @@ class LLMNode(Node[LLMNodeData]):
else:
prompt_messages.append(UserPromptMessage(content=file_prompts))
# The context_files
if vision_enabled and context_files:
file_prompts = []
for file in context_files:
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
file_prompts.append(file_prompt)
# If last prompt is a user prompt, add files into its contents,
# otherwise append a new user prompt
if (
len(prompt_messages) > 0
and isinstance(prompt_messages[-1], UserPromptMessage)
@ -1452,20 +1428,17 @@ class LLMNode(Node[LLMNodeData]):
else:
prompt_messages.append(UserPromptMessage(content=file_prompts))
# Remove empty messages and filter unsupported content
filtered_prompt_messages = []
for prompt_message in prompt_messages:
if isinstance(prompt_message.content, list):
prompt_message_content: list[PromptMessageContentUnionTypes] = []
for content_item in prompt_message.content:
# Skip content if features are not defined
if not model_schema.features:
if content_item.type != PromptMessageContentType.TEXT:
continue
prompt_message_content.append(content_item)
continue
# Skip content if corresponding feature is not supported
if (
(
content_item.type == PromptMessageContentType.IMAGE
@ -1680,7 +1653,6 @@ class LLMNode(Node[LLMNodeData]):
prompt_messages.append(prompt_message)
if file_contents:
# Create message with image contents
prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role)
prompt_messages.append(prompt_message)
@ -2824,7 +2796,6 @@ def _handle_memory_chat_mode(
model_instance: ModelInstance,
) -> Sequence[PromptMessage]:
memory_messages: Sequence[PromptMessage] = []
# Get messages from memory for chat model
if memory and memory_config:
rest_tokens = _calculate_rest_token(
prompt_messages=[],
@ -2844,7 +2815,6 @@ def _handle_memory_completion_mode(
model_instance: ModelInstance,
) -> str:
memory_text = ""
# Get history text from memory for completion model
if memory and memory_config:
rest_tokens = _calculate_rest_token(
prompt_messages=[],
@ -2869,17 +2839,6 @@ def _handle_completion_template(
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
) -> Sequence[PromptMessage]:
"""Handle completion template processing outside of LLMNode class.
Args:
template: The completion model prompt template
context: Optional context string
jinja2_variables: Variables for jinja2 template rendering
variable_pool: Variable pool for template conversion
Returns:
Sequence of prompt messages
"""
prompt_messages = []
if template.edition_type == "jinja2":
result_text = _render_jinja2_message(

View File

@ -1,5 +1,6 @@
from __future__ import annotations
from collections.abc import Mapping
from typing import Any, Protocol
from core.model_manager import ModelInstance
@ -19,3 +20,11 @@ class ModelFactory(Protocol):
def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
"""Create a model instance that is ready for schema lookup and invocation."""
...
class TemplateRenderer(Protocol):
"""Port for rendering prompt templates used by LLM-compatible nodes."""
def render_jinja2(self, *, template: str, inputs: Mapping[str, Any]) -> str:
"""Render the given Jinja2 template into plain text."""
...

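One way the application layer might satisfy the new `TemplateRenderer` port, sketched against Jinja2's sandboxed environment (an assumption, not the project's actual wiring):

# Illustrative adapter; the renderer actually injected into LLMNode may differ.
from collections.abc import Mapping
from typing import Any

from jinja2.sandbox import SandboxedEnvironment

class SandboxedJinja2Renderer:
    """Satisfies TemplateRenderer structurally (Protocol duck typing)."""

    def __init__(self) -> None:
        self._env = SandboxedEnvironment(autoescape=False)

    def render_jinja2(self, *, template: str, inputs: Mapping[str, Any]) -> str:
        # Render the given Jinja2 template into plain text, per the port's contract.
        return self._env.from_string(template).render(**inputs)

An instance can then be passed as the `template_renderer` argument that LLMNode and QuestionClassifierNode now require.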
View File

@ -28,7 +28,7 @@ from dify_graph.nodes.llm import (
llm_utils,
)
from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
from dify_graph.nodes.protocols import HttpClientProtocol
from libs.json_in_md_parser import parse_and_check_json_markdown
@ -59,6 +59,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
_model_factory: "ModelFactory"
_model_instance: ModelInstance
_memory: PromptMessageMemory | None
_template_renderer: TemplateRenderer
def __init__(
self,
@ -71,6 +72,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
model_factory: "ModelFactory",
model_instance: ModelInstance,
http_client: HttpClientProtocol,
template_renderer: TemplateRenderer,
memory: PromptMessageMemory | None = None,
llm_file_saver: LLMFileSaver | None = None,
):
@ -87,6 +89,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
self._model_factory = model_factory
self._model_instance = model_instance
self._memory = memory
self._template_renderer = template_renderer
if llm_file_saver is None:
dify_ctx = self.require_dify_context()
@ -141,7 +144,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
# If both self._get_prompt_template and self._fetch_prompt_messages append a user prompt,
# two consecutive user prompts will be generated, which causes a model error.
# To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end.
prompt_messages, stop = LLMNode.fetch_prompt_messages(
prompt_messages, stop = llm_utils.fetch_prompt_messages(
prompt_template=prompt_template,
sys_query="",
memory=memory,
@ -152,6 +155,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
vision_detail=node_data.vision.configs.detail,
variable_pool=variable_pool,
jinja2_variables=[],
template_renderer=self._template_renderer,
)
result_text = ""
@ -291,7 +295,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
prompt_template = self._get_prompt_template(node_data, query, None, 2000)
prompt_messages, _ = LLMNode.fetch_prompt_messages(
prompt_messages, _ = llm_utils.fetch_prompt_messages(
prompt_template=prompt_template,
sys_query="",
sys_files=[],
@ -304,6 +308,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
vision_detail=node_data.vision.configs.detail,
variable_pool=self.graph_runtime_state.variable_pool,
jinja2_variables=[],
template_renderer=self._template_renderer,
)
rest_tokens = 2000

View File

@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any
from dify_graph.file.models import File
if TYPE_CHECKING:
pass
from dify_graph.variables.segments import Segment
class ArrayValidation(StrEnum):
@ -220,7 +220,7 @@ class SegmentType(StrEnum):
return _ARRAY_ELEMENT_TYPES_MAPPING.get(self)
@staticmethod
def get_zero_value(t: SegmentType):
def get_zero_value(t: SegmentType) -> Segment:
# Lazy import to avoid circular dependency
from factories import variable_factory