Merge branch 'main' into feat/grouping-branching

# Conflicts:
#	web/package.json
This commit is contained in:
zhsama
2026-01-06 22:00:01 +08:00
156 changed files with 5890 additions and 1553 deletions

View File

@ -20,6 +20,8 @@ from core.app.entities.queue_entities import (
QueueTextChunkEvent,
)
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer
from core.db.session_factory import session_factory
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.variables.variables import VariableUnion
@ -40,6 +42,7 @@ from models import Workflow
from models.enums import UserFrom
from models.model import App, Conversation, Message, MessageAnnotation
from models.workflow import ConversationVariable
from services.conversation_variable_updater import ConversationVariableUpdater
logger = logging.getLogger(__name__)
@ -200,6 +203,10 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
)
workflow_entry.graph_engine.layer(persistence_layer)
conversation_variable_layer = ConversationVariablePersistenceLayer(
ConversationVariableUpdater(session_factory.get_session_maker())
)
workflow_entry.graph_engine.layer(conversation_variable_layer)
for layer in self._graph_engine_layers:
workflow_entry.graph_engine.layer(layer)

View File

@ -75,7 +75,7 @@ class AnnotationReplyFeature:
AppAnnotationService.add_annotation_history(
annotation.id,
app_record.id,
annotation.question,
annotation.question_text,
annotation.content,
query,
user_id,

View File

@ -0,0 +1,60 @@
import logging
from core.variables import Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.enums import NodeType
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import GraphEngineEvent, NodeRunSucceededEvent
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
logger = logging.getLogger(__name__)
class ConversationVariablePersistenceLayer(GraphEngineLayer):
def __init__(self, conversation_variable_updater: ConversationVariableUpdater) -> None:
super().__init__()
self._conversation_variable_updater = conversation_variable_updater
def on_graph_start(self) -> None:
pass
def on_event(self, event: GraphEngineEvent) -> None:
if not isinstance(event, NodeRunSucceededEvent):
return
if event.node_type != NodeType.VARIABLE_ASSIGNER:
return
if self.graph_runtime_state is None:
return
updated_variables = common_helpers.get_updated_variables(event.node_run_result.process_data) or []
if not updated_variables:
return
conversation_id = self.graph_runtime_state.system_variable.conversation_id
if conversation_id is None:
return
updated_any = False
for item in updated_variables:
selector = item.selector
if len(selector) < 2:
logger.warning("Conversation variable selector invalid. selector=%s", selector)
continue
if selector[0] != CONVERSATION_VARIABLE_NODE_ID:
continue
variable = self.graph_runtime_state.variable_pool.get(selector)
if not isinstance(variable, Variable):
logger.warning(
"Conversation variable not found in variable pool. selector=%s",
selector,
)
continue
self._conversation_variable_updater.update(conversation_id=conversation_id, variable=variable)
updated_any = True
if updated_any:
self._conversation_variable_updater.flush()
def on_graph_end(self, error: Exception | None) -> None:
pass

View File

@ -66,6 +66,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
"""
if isinstance(session_factory, Engine):
session_factory = sessionmaker(session_factory)
super().__init__()
self._session_maker = session_factory
self._state_owner_user_id = state_owner_user_id
self._generate_entity = generate_entity
@ -98,8 +99,6 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
if not isinstance(event, GraphRunPausedEvent):
return
assert self.graph_runtime_state is not None
entity_wrapper: _GenerateEntityUnion
if isinstance(self._generate_entity, WorkflowAppGenerateEntity):
entity_wrapper = _WorkflowGenerateEntityWrapper(entity=self._generate_entity)

View File

@ -33,6 +33,7 @@ class TriggerPostLayer(GraphEngineLayer):
trigger_log_id: str,
session_maker: sessionmaker[Session],
):
super().__init__()
self.trigger_log_id = trigger_log_id
self.start_time = start_time
self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity
@ -57,10 +58,6 @@ class TriggerPostLayer(GraphEngineLayer):
elapsed_time = (datetime.now(UTC) - self.start_time).total_seconds()
# Extract relevant data from result
if not self.graph_runtime_state:
logger.exception("Graph runtime state is not set")
return
outputs = self.graph_runtime_state.outputs
# BASICLY, workflow_execution_id is the same as workflow_run_id

View File

@ -1,7 +1,7 @@
from sqlalchemy import Engine
from sqlalchemy.orm import Session, sessionmaker
_session_maker: sessionmaker | None = None
_session_maker: sessionmaker[Session] | None = None
def configure_session_factory(engine: Engine, expire_on_commit: bool = False):
@ -10,7 +10,7 @@ def configure_session_factory(engine: Engine, expire_on_commit: bool = False):
_session_maker = sessionmaker(bind=engine, expire_on_commit=expire_on_commit)
def get_session_maker() -> sessionmaker:
def get_session_maker() -> sessionmaker[Session]:
if _session_maker is None:
raise RuntimeError("Session factory not configured. Call configure_session_factory() first.")
return _session_maker
@ -27,7 +27,7 @@ class SessionFactory:
configure_session_factory(engine, expire_on_commit)
@staticmethod
def get_session_maker() -> sessionmaker:
def get_session_maker() -> sessionmaker[Session]:
return get_session_maker()
@staticmethod

View File

@ -8,8 +8,9 @@ import urllib.parse
from configs import dify_config
def get_signed_file_url(upload_file_id: str, as_attachment=False) -> str:
url = f"{dify_config.FILES_URL}/files/{upload_file_id}/file-preview"
def get_signed_file_url(upload_file_id: str, as_attachment=False, for_external: bool = True) -> str:
base_url = dify_config.FILES_URL if for_external else (dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL)
url = f"{base_url}/files/{upload_file_id}/file-preview"
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()

View File

@ -112,17 +112,17 @@ class File(BaseModel):
return text
def generate_url(self) -> str | None:
def generate_url(self, for_external: bool = True) -> str | None:
if self.transfer_method == FileTransferMethod.REMOTE_URL:
return self.remote_url
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
if self.related_id is None:
raise ValueError("Missing file related_id")
return helpers.get_signed_file_url(upload_file_id=self.related_id)
return helpers.get_signed_file_url(upload_file_id=self.related_id, for_external=for_external)
elif self.transfer_method in [FileTransferMethod.TOOL_FILE, FileTransferMethod.DATASOURCE_FILE]:
assert self.related_id is not None
assert self.extension is not None
return sign_tool_file(tool_file_id=self.related_id, extension=self.extension)
return sign_tool_file(tool_file_id=self.related_id, extension=self.extension, for_external=for_external)
return None
def to_plugin_parameter(self) -> dict[str, Any]:
@ -133,7 +133,7 @@ class File(BaseModel):
"extension": self.extension,
"size": self.size,
"type": self.type,
"url": self.generate_url(),
"url": self.generate_url(for_external=False),
}
@model_validator(mode="after")

View File

@ -76,7 +76,7 @@ class TemplateTransformer(ABC):
Post-process the result to convert scientific notation strings back to numbers
"""
def convert_scientific_notation(value):
def convert_scientific_notation(value: Any) -> Any:
if isinstance(value, str):
# Check if the string looks like scientific notation
if re.match(r"^-?\d+\.?\d*e[+-]\d+$", value, re.IGNORECASE):
@ -90,7 +90,7 @@ class TemplateTransformer(ABC):
return [convert_scientific_notation(v) for v in value]
return value
return convert_scientific_notation(result) # type: ignore[no-any-return]
return convert_scientific_notation(result)
@classmethod
@abstractmethod

View File

@ -984,9 +984,11 @@ class ClickzettaVector(BaseVector):
# No need for dataset_id filter since each dataset has its own table
# Use simple quote escaping for LIKE clause
escaped_query = query.replace("'", "''")
filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%'")
# Escape special characters for LIKE clause to prevent SQL injection
from libs.helper import escape_like_pattern
escaped_query = escape_like_pattern(query).replace("'", "''")
filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%' ESCAPE '\\\\'")
where_clause = " AND ".join(filter_clauses)
search_sql = f"""

View File

@ -287,11 +287,15 @@ class IrisVector(BaseVector):
cursor.execute(sql, (query,))
else:
# Fallback to LIKE search (inefficient for large datasets)
query_pattern = f"%{query}%"
# Escape special characters for LIKE clause to prevent SQL injection
from libs.helper import escape_like_pattern
escaped_query = escape_like_pattern(query)
query_pattern = f"%{escaped_query}%"
sql = f"""
SELECT TOP {top_k} id, text, meta
FROM {self.schema}.{self.table_name}
WHERE text LIKE ?
WHERE text LIKE ? ESCAPE '\\'
"""
cursor.execute(sql, (query_pattern,))

View File

@ -66,6 +66,8 @@ class WeaviateVector(BaseVector):
in a Weaviate collection.
"""
_DOCUMENT_ID_PROPERTY = "document_id"
def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list):
"""
Initializes the Weaviate vector store.
@ -353,15 +355,12 @@ class WeaviateVector(BaseVector):
return []
col = self._client.collections.use(self._collection_name)
props = list({*self._attributes, "document_id", Field.TEXT_KEY.value})
props = list({*self._attributes, self._DOCUMENT_ID_PROPERTY, Field.TEXT_KEY.value})
where = None
doc_ids = kwargs.get("document_ids_filter") or []
if doc_ids:
ors = [Filter.by_property("document_id").equal(x) for x in doc_ids]
where = ors[0]
for f in ors[1:]:
where = where | f
where = Filter.by_property(self._DOCUMENT_ID_PROPERTY).contains_any(doc_ids)
top_k = int(kwargs.get("top_k", 4))
score_threshold = float(kwargs.get("score_threshold") or 0.0)
@ -408,10 +407,7 @@ class WeaviateVector(BaseVector):
where = None
doc_ids = kwargs.get("document_ids_filter") or []
if doc_ids:
ors = [Filter.by_property("document_id").equal(x) for x in doc_ids]
where = ors[0]
for f in ors[1:]:
where = where | f
where = Filter.by_property(self._DOCUMENT_ID_PROPERTY).contains_any(doc_ids)
top_k = int(kwargs.get("top_k", 4))

View File

@ -7,10 +7,11 @@ import re
import tempfile
import uuid
from urllib.parse import urlparse
from xml.etree import ElementTree
import httpx
from docx import Document as DocxDocument
from docx.oxml.ns import qn
from docx.text.run import Run
from configs import dify_config
from core.helper import ssrf_proxy
@ -229,44 +230,20 @@ class WordExtractor(BaseExtractor):
image_map = self._extract_images_from_docx(doc)
hyperlinks_url = None
url_pattern = re.compile(r"http://[^\s+]+//|https://[^\s+]+")
for para in doc.paragraphs:
for run in para.runs:
if run.text and hyperlinks_url:
result = f" [{run.text}]({hyperlinks_url}) "
run.text = result
hyperlinks_url = None
if "HYPERLINK" in run.element.xml:
try:
xml = ElementTree.XML(run.element.xml)
x_child = [c for c in xml.iter() if c is not None]
for x in x_child:
if x is None:
continue
if x.tag.endswith("instrText"):
if x.text is None:
continue
for i in url_pattern.findall(x.text):
hyperlinks_url = str(i)
except Exception:
logger.exception("Failed to parse HYPERLINK xml")
def parse_paragraph(paragraph):
paragraph_content = []
def append_image_link(image_id, has_drawing):
def append_image_link(image_id, has_drawing, target_buffer):
"""Helper to append image link from image_map based on relationship type."""
rel = doc.part.rels[image_id]
if rel.is_external:
if image_id in image_map and not has_drawing:
paragraph_content.append(image_map[image_id])
target_buffer.append(image_map[image_id])
else:
image_part = rel.target_part
if image_part in image_map and not has_drawing:
paragraph_content.append(image_map[image_part])
target_buffer.append(image_map[image_part])
for run in paragraph.runs:
def process_run(run, target_buffer):
# Helper to extract text and embedded images from a run element and append them to target_buffer
if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"):
# Process drawing type images
drawing_elements = run.element.findall(
@ -287,13 +264,13 @@ class WordExtractor(BaseExtractor):
# External image: use embed_id as key
if embed_id in image_map:
has_drawing = True
paragraph_content.append(image_map[embed_id])
target_buffer.append(image_map[embed_id])
else:
# Internal image: use target_part as key
image_part = doc.part.related_parts.get(embed_id)
if image_part in image_map:
has_drawing = True
paragraph_content.append(image_map[image_part])
target_buffer.append(image_map[image_part])
# Process pict type images
shape_elements = run.element.findall(
".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}pict"
@ -308,7 +285,7 @@ class WordExtractor(BaseExtractor):
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
)
if image_id and image_id in doc.part.rels:
append_image_link(image_id, has_drawing)
append_image_link(image_id, has_drawing, target_buffer)
# Find imagedata element in VML
image_data = shape.find(".//{urn:schemas-microsoft-com:vml}imagedata")
if image_data is not None:
@ -316,9 +293,93 @@ class WordExtractor(BaseExtractor):
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
)
if image_id and image_id in doc.part.rels:
append_image_link(image_id, has_drawing)
append_image_link(image_id, has_drawing, target_buffer)
if run.text.strip():
paragraph_content.append(run.text.strip())
target_buffer.append(run.text.strip())
def process_hyperlink(hyperlink_elem, target_buffer):
# Helper to extract text from a hyperlink element and append it to target_buffer
r_id = hyperlink_elem.get(qn("r:id"))
# Extract text from runs inside the hyperlink
link_text_parts = []
for run_elem in hyperlink_elem.findall(qn("w:r")):
run = Run(run_elem, paragraph)
# Hyperlink text may be split across multiple runs (e.g., with different formatting),
# so collect all run texts first
if run.text:
link_text_parts.append(run.text)
link_text = "".join(link_text_parts).strip()
# Resolve URL
if r_id:
try:
rel = doc.part.rels.get(r_id)
if rel and rel.is_external:
link_text = f"[{link_text or rel.target_ref}]({rel.target_ref})"
except Exception:
logger.exception("Failed to resolve URL for hyperlink with r:id: %s", r_id)
if link_text:
target_buffer.append(link_text)
paragraph_content = []
# State for legacy HYPERLINK fields
hyperlink_field_url = None
hyperlink_field_text_parts: list = []
is_collecting_field_text = False
# Iterate through paragraph elements in document order
for child in paragraph._element:
tag = child.tag
if tag == qn("w:r"):
# Regular run
run = Run(child, paragraph)
# Check for fldChar (begin/end/separate) and instrText for legacy hyperlinks
fld_chars = child.findall(qn("w:fldChar"))
instr_texts = child.findall(qn("w:instrText"))
# Handle Fields
if fld_chars or instr_texts:
# Process instrText to find HYPERLINK "url"
for instr in instr_texts:
if instr.text and "HYPERLINK" in instr.text:
# Quick regex to extract URL
match = re.search(r'HYPERLINK\s+"([^"]+)"', instr.text, re.IGNORECASE)
if match:
hyperlink_field_url = match.group(1)
# Process fldChar
for fld_char in fld_chars:
fld_char_type = fld_char.get(qn("w:fldCharType"))
if fld_char_type == "begin":
# Start of a field: reset legacy link state
hyperlink_field_url = None
hyperlink_field_text_parts = []
is_collecting_field_text = False
elif fld_char_type == "separate":
# Separator: if we found a URL, start collecting visible text
if hyperlink_field_url:
is_collecting_field_text = True
elif fld_char_type == "end":
# End of field
if is_collecting_field_text and hyperlink_field_url:
# Create markdown link and append to main content
display_text = "".join(hyperlink_field_text_parts).strip()
if display_text:
link_md = f"[{display_text}]({hyperlink_field_url})"
paragraph_content.append(link_md)
# Reset state
hyperlink_field_url = None
hyperlink_field_text_parts = []
is_collecting_field_text = False
# Decide where to append content
target_buffer = hyperlink_field_text_parts if is_collecting_field_text else paragraph_content
process_run(run, target_buffer)
elif tag == qn("w:hyperlink"):
process_hyperlink(child, paragraph_content)
return "".join(paragraph_content) if paragraph_content else ""
paragraphs = doc.paragraphs.copy()

View File

@ -1198,18 +1198,24 @@ class DatasetRetrieval:
json_field = DatasetDocument.doc_metadata[metadata_name].as_string()
from libs.helper import escape_like_pattern
match condition:
case "contains":
filters.append(json_field.like(f"%{value}%"))
escaped_value = escape_like_pattern(str(value))
filters.append(json_field.like(f"%{escaped_value}%", escape="\\"))
case "not contains":
filters.append(json_field.notlike(f"%{value}%"))
escaped_value = escape_like_pattern(str(value))
filters.append(json_field.notlike(f"%{escaped_value}%", escape="\\"))
case "start with":
filters.append(json_field.like(f"{value}%"))
escaped_value = escape_like_pattern(str(value))
filters.append(json_field.like(f"{escaped_value}%", escape="\\"))
case "end with":
filters.append(json_field.like(f"%{value}"))
escaped_value = escape_like_pattern(str(value))
filters.append(json_field.like(f"%{escaped_value}", escape="\\"))
case "is" | "=":
if isinstance(value, str):
@ -1474,38 +1480,38 @@ class DatasetRetrieval:
if cancel_event and cancel_event.is_set():
break
# Skip second reranking when there is only one dataset
if reranking_enable and dataset_count > 1:
# do rerank for searched documents
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
if query:
all_documents_item = data_post_processor.invoke(
query=query,
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.TEXT_QUERY,
)
if attachment_id:
all_documents_item = data_post_processor.invoke(
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.IMAGE_QUERY,
query=attachment_id,
)
else:
if index_type == IndexTechniqueType.ECONOMY:
if not query:
all_documents_item = []
else:
all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k)
elif index_type == IndexTechniqueType.HIGH_QUALITY:
all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold)
# Skip second reranking when there is only one dataset
if reranking_enable and dataset_count > 1:
# do rerank for searched documents
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
if query:
all_documents_item = data_post_processor.invoke(
query=query,
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.TEXT_QUERY,
)
if attachment_id:
all_documents_item = data_post_processor.invoke(
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.IMAGE_QUERY,
query=attachment_id,
)
else:
all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
if all_documents_item:
all_documents.extend(all_documents_item)
if index_type == IndexTechniqueType.ECONOMY:
if not query:
all_documents_item = []
else:
all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k)
elif index_type == IndexTechniqueType.HIGH_QUALITY:
all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold)
else:
all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
if all_documents_item:
all_documents.extend(all_documents_item)
except Exception as e:
if cancel_event:
cancel_event.set()

View File

@ -7,12 +7,12 @@ import time
from configs import dify_config
def sign_tool_file(tool_file_id: str, extension: str) -> str:
def sign_tool_file(tool_file_id: str, extension: str, for_external: bool = True) -> str:
"""
sign file to get a temporary url for plugin access
"""
# Use internal URL for plugin/tool file access in Docker environments
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
# Use internal URL for plugin/tool file access in Docker environments, unless for_external is True
base_url = dify_config.FILES_URL if for_external else (dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL)
file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}"
timestamp = str(int(time.time()))

View File

@ -64,6 +64,9 @@ engine.layer(DebugLoggingLayer(level="INFO"))
engine.layer(ExecutionLimitsLayer(max_nodes=100))
```
`engine.layer()` binds the read-only runtime state before execution, so layer hooks
can assume `graph_runtime_state` is available.
### Event-Driven Architecture
All node executions emit events for monitoring and integration:

View File

@ -9,7 +9,7 @@ Each instance uses a unique key for its command queue.
import json
from typing import TYPE_CHECKING, Any, final
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
if TYPE_CHECKING:
from extensions.ext_redis import RedisClientWrapper
@ -113,6 +113,8 @@ class RedisChannel:
return AbortCommand.model_validate(data)
if command_type == CommandType.PAUSE:
return PauseCommand.model_validate(data)
if command_type == CommandType.UPDATE_VARIABLES:
return UpdateVariablesCommand.model_validate(data)
# For other command types, use base class
return GraphEngineCommand.model_validate(data)

View File

@ -5,11 +5,12 @@ This package handles external commands sent to the engine
during execution.
"""
from .command_handlers import AbortCommandHandler, PauseCommandHandler
from .command_handlers import AbortCommandHandler, PauseCommandHandler, UpdateVariablesCommandHandler
from .command_processor import CommandProcessor
__all__ = [
"AbortCommandHandler",
"CommandProcessor",
"PauseCommandHandler",
"UpdateVariablesCommandHandler",
]

View File

@ -4,9 +4,10 @@ from typing import final
from typing_extensions import override
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.runtime import VariablePool
from ..domain.graph_execution import GraphExecution
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
from .command_processor import CommandHandler
logger = logging.getLogger(__name__)
@ -31,3 +32,25 @@ class PauseCommandHandler(CommandHandler):
reason = command.reason
pause_reason = SchedulingPause(message=reason)
execution.pause(pause_reason)
@final
class UpdateVariablesCommandHandler(CommandHandler):
def __init__(self, variable_pool: VariablePool) -> None:
self._variable_pool = variable_pool
@override
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
assert isinstance(command, UpdateVariablesCommand)
for update in command.updates:
try:
variable = update.value
self._variable_pool.add(variable.selector, variable)
logger.debug("Updated variable %s for workflow %s", variable.selector, execution.workflow_id)
except ValueError as exc:
logger.warning(
"Skipping invalid variable selector %s for workflow %s: %s",
getattr(update.value, "selector", None),
execution.workflow_id,
exc,
)

View File

@ -5,17 +5,21 @@ This module defines command types that can be sent to a running GraphEngine
instance to control its execution flow.
"""
from enum import StrEnum
from collections.abc import Sequence
from enum import StrEnum, auto
from typing import Any
from pydantic import BaseModel, Field
from core.variables.variables import VariableUnion
class CommandType(StrEnum):
"""Types of commands that can be sent to GraphEngine."""
ABORT = "abort"
PAUSE = "pause"
ABORT = auto()
PAUSE = auto()
UPDATE_VARIABLES = auto()
class GraphEngineCommand(BaseModel):
@ -37,3 +41,16 @@ class PauseCommand(GraphEngineCommand):
command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command")
reason: str = Field(default="unknown reason", description="reason for pause")
class VariableUpdate(BaseModel):
"""Represents a single variable update instruction."""
value: VariableUnion = Field(description="New variable value")
class UpdateVariablesCommand(GraphEngineCommand):
"""Command to update a group of variables in the variable pool."""
command_type: CommandType = Field(default=CommandType.UPDATE_VARIABLES, description="Type of command")
updates: Sequence[VariableUpdate] = Field(default_factory=list, description="Variable updates")

View File

@ -8,6 +8,7 @@ Domain-Driven Design principles for improved maintainability and testability.
import contextvars
import logging
import queue
import threading
from collections.abc import Generator
from typing import TYPE_CHECKING, cast, final
@ -30,8 +31,13 @@ from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWr
if TYPE_CHECKING: # pragma: no cover - used only for static analysis
from core.workflow.runtime.graph_runtime_state import GraphProtocol
from .command_processing import AbortCommandHandler, CommandProcessor, PauseCommandHandler
from .entities.commands import AbortCommand, PauseCommand
from .command_processing import (
AbortCommandHandler,
CommandProcessor,
PauseCommandHandler,
UpdateVariablesCommandHandler,
)
from .entities.commands import AbortCommand, PauseCommand, UpdateVariablesCommand
from .error_handler import ErrorHandler
from .event_management import EventHandler, EventManager
from .graph_state_manager import GraphStateManager
@ -70,10 +76,13 @@ class GraphEngine:
scale_down_idle_time: float | None = None,
) -> None:
"""Initialize the graph engine with all subsystems and dependencies."""
# stop event
self._stop_event = threading.Event()
# Bind runtime state to current workflow context
self._graph = graph
self._graph_runtime_state = graph_runtime_state
self._graph_runtime_state.stop_event = self._stop_event
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
self._command_channel = command_channel
@ -140,6 +149,9 @@ class GraphEngine:
pause_handler = PauseCommandHandler()
self._command_processor.register_handler(PauseCommand, pause_handler)
update_variables_handler = UpdateVariablesCommandHandler(self._graph_runtime_state.variable_pool)
self._command_processor.register_handler(UpdateVariablesCommand, update_variables_handler)
# === Extensibility ===
# Layers allow plugins to extend engine functionality
self._layers: list[GraphEngineLayer] = []
@ -169,6 +181,7 @@ class GraphEngine:
max_workers=self._max_workers,
scale_up_threshold=self._scale_up_threshold,
scale_down_idle_time=self._scale_down_idle_time,
stop_event=self._stop_event,
)
# === Orchestration ===
@ -199,6 +212,7 @@ class GraphEngine:
event_handler=self._event_handler_registry,
execution_coordinator=self._execution_coordinator,
event_emitter=self._event_manager,
stop_event=self._stop_event,
)
# === Validation ===
@ -212,9 +226,16 @@ class GraphEngine:
if id(node.graph_runtime_state) != expected_state_id:
raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance")
def _bind_layer_context(
self,
layer: GraphEngineLayer,
) -> None:
layer.initialize(ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state), self._command_channel)
def layer(self, layer: GraphEngineLayer) -> "GraphEngine":
"""Add a layer for extending functionality."""
self._layers.append(layer)
self._bind_layer_context(layer)
return self
def run(self) -> Generator[GraphEngineEvent, None, None]:
@ -301,14 +322,7 @@ class GraphEngine:
def _initialize_layers(self) -> None:
"""Initialize layers with context."""
self._event_manager.set_layers(self._layers)
# Create a read-only wrapper for the runtime state
read_only_state = ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state)
for layer in self._layers:
try:
layer.initialize(read_only_state, self._command_channel)
except Exception as e:
logger.warning("Failed to initialize layer %s: %s", layer.__class__.__name__, e)
try:
layer.on_graph_start()
except Exception as e:
@ -316,6 +330,7 @@ class GraphEngine:
def _start_execution(self, *, resume: bool = False) -> None:
"""Start execution subsystems."""
self._stop_event.clear()
paused_nodes: list[str] = []
if resume:
paused_nodes = self._graph_runtime_state.consume_paused_nodes()
@ -343,13 +358,12 @@ class GraphEngine:
def _stop_execution(self) -> None:
"""Stop execution subsystems."""
self._stop_event.set()
self._dispatcher.stop()
self._worker_pool.stop()
# Don't mark complete here as the dispatcher already does it
# Notify layers
logger = logging.getLogger(__name__)
for layer in self._layers:
try:
layer.on_graph_end(self._graph_execution.error)

View File

@ -8,7 +8,7 @@ Pluggable middleware for engine extensions.
Abstract base class for layers.
- `initialize()` - Receive runtime context
- `initialize()` - Receive runtime context (runtime state is bound here and always available to hooks)
- `on_graph_start()` - Execution start hook
- `on_event()` - Process all events
- `on_graph_end()` - Execution end hook
@ -34,6 +34,9 @@ engine.layer(debug_layer)
engine.run()
```
`engine.layer()` binds the read-only runtime state before execution, so
`graph_runtime_state` is always available inside layer hooks.
## Custom Layers
```python

View File

@ -13,6 +13,14 @@ from core.workflow.nodes.base.node import Node
from core.workflow.runtime import ReadOnlyGraphRuntimeState
class GraphEngineLayerNotInitializedError(Exception):
"""Raised when a layer's runtime state is accessed before initialization."""
def __init__(self, layer_name: str | None = None) -> None:
name = layer_name or "GraphEngineLayer"
super().__init__(f"{name} runtime state is not initialized. Bind the layer to a GraphEngine before access.")
class GraphEngineLayer(ABC):
"""
Abstract base class for GraphEngine layers.
@ -28,22 +36,27 @@ class GraphEngineLayer(ABC):
def __init__(self) -> None:
"""Initialize the layer. Subclasses can override with custom parameters."""
self.graph_runtime_state: ReadOnlyGraphRuntimeState | None = None
self._graph_runtime_state: ReadOnlyGraphRuntimeState | None = None
self.command_channel: CommandChannel | None = None
@property
def graph_runtime_state(self) -> ReadOnlyGraphRuntimeState:
if self._graph_runtime_state is None:
raise GraphEngineLayerNotInitializedError(type(self).__name__)
return self._graph_runtime_state
def initialize(self, graph_runtime_state: ReadOnlyGraphRuntimeState, command_channel: CommandChannel) -> None:
"""
Initialize the layer with engine dependencies.
Called by GraphEngine before execution starts to inject the read-only runtime state
and command channel. This allows layers to observe engine context and send
commands, but prevents direct state modification.
Called by GraphEngine to inject the read-only runtime state and command channel.
This is invoked when the layer is registered with a `GraphEngine` instance.
Implementations should be idempotent.
Args:
graph_runtime_state: Read-only view of the runtime state
command_channel: Channel for sending commands to the engine
"""
self.graph_runtime_state = graph_runtime_state
self._graph_runtime_state = graph_runtime_state
self.command_channel = command_channel
@abstractmethod

View File

@ -109,10 +109,8 @@ class DebugLoggingLayer(GraphEngineLayer):
self.logger.info("=" * 80)
self.logger.info("🚀 GRAPH EXECUTION STARTED")
self.logger.info("=" * 80)
if self.graph_runtime_state:
# Log initial state
self.logger.info("Initial State:")
# Log initial state
self.logger.info("Initial State:")
@override
def on_event(self, event: GraphEngineEvent) -> None:
@ -243,8 +241,7 @@ class DebugLoggingLayer(GraphEngineLayer):
self.logger.info(" Node retries: %s", self.retry_count)
# Log final state if available
if self.graph_runtime_state and self.include_outputs:
if self.graph_runtime_state.outputs:
self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs))
if self.include_outputs and self.graph_runtime_state.outputs:
self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs))
self.logger.info("=" * 80)

View File

@ -337,8 +337,6 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
if update_finished:
execution.finished_at = naive_utc_now()
runtime_state = self.graph_runtime_state
if runtime_state is None:
return
execution.total_tokens = runtime_state.total_tokens
execution.total_steps = runtime_state.node_run_steps
execution.outputs = execution.outputs or runtime_state.outputs
@ -404,6 +402,4 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
def _system_variables(self) -> Mapping[str, Any]:
runtime_state = self.graph_runtime_state
if runtime_state is None:
return {}
return runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID)

View File

@ -3,14 +3,20 @@ GraphEngine Manager for sending control commands via Redis channel.
This module provides a simplified interface for controlling workflow executions
using the new Redis command channel, without requiring user permission checks.
Supports stop, pause, and resume operations.
"""
import logging
from collections.abc import Sequence
from typing import final
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
from core.workflow.graph_engine.entities.commands import (
AbortCommand,
GraphEngineCommand,
PauseCommand,
UpdateVariablesCommand,
VariableUpdate,
)
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
@ -23,7 +29,6 @@ class GraphEngineManager:
This class provides a simple interface for controlling workflow executions
by sending commands through Redis channels, without user validation.
Supports stop and pause operations.
"""
@staticmethod
@ -45,6 +50,16 @@ class GraphEngineManager:
pause_command = PauseCommand(reason=reason or "User requested pause")
GraphEngineManager._send_command(task_id, pause_command)
@staticmethod
def send_update_variables_command(task_id: str, updates: Sequence[VariableUpdate]) -> None:
"""Send a command to update variables in a running workflow."""
if not updates:
return
update_command = UpdateVariablesCommand(updates=updates)
GraphEngineManager._send_command(task_id, update_command)
@staticmethod
def _send_command(task_id: str, command: GraphEngineCommand) -> None:
"""Send a command to the workflow-specific Redis channel."""

View File

@ -44,6 +44,7 @@ class Dispatcher:
event_queue: queue.Queue[GraphNodeEventBase],
event_handler: "EventHandler",
execution_coordinator: ExecutionCoordinator,
stop_event: threading.Event,
event_emitter: EventManager | None = None,
) -> None:
"""
@ -61,7 +62,7 @@ class Dispatcher:
self._event_emitter = event_emitter
self._thread: threading.Thread | None = None
self._stop_event = threading.Event()
self._stop_event = stop_event
self._start_time: float | None = None
def start(self) -> None:
@ -69,16 +70,14 @@ class Dispatcher:
if self._thread and self._thread.is_alive():
return
self._stop_event.clear()
self._start_time = time.time()
self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True)
self._thread.start()
def stop(self) -> None:
"""Stop the dispatcher thread."""
self._stop_event.set()
if self._thread and self._thread.is_alive():
self._thread.join(timeout=10.0)
self._thread.join(timeout=2.0)
def _dispatcher_loop(self) -> None:
"""Main dispatcher loop."""

View File

@ -42,6 +42,7 @@ class Worker(threading.Thread):
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
layers: Sequence[GraphEngineLayer],
stop_event: threading.Event,
worker_id: int = 0,
flask_app: Flask | None = None,
context_vars: contextvars.Context | None = None,
@ -65,13 +66,16 @@ class Worker(threading.Thread):
self._worker_id = worker_id
self._flask_app = flask_app
self._context_vars = context_vars
self._stop_event = threading.Event()
self._last_task_time = time.time()
self._stop_event = stop_event
self._layers = layers if layers is not None else []
def stop(self) -> None:
"""Signal the worker to stop processing."""
self._stop_event.set()
"""Worker is controlled via shared stop_event from GraphEngine.
This method is a no-op retained for backward compatibility.
"""
pass
@property
def is_idle(self) -> bool:

View File

@ -41,6 +41,7 @@ class WorkerPool:
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
layers: list[GraphEngineLayer],
stop_event: threading.Event,
flask_app: "Flask | None" = None,
context_vars: "Context | None" = None,
min_workers: int | None = None,
@ -81,6 +82,7 @@ class WorkerPool:
self._worker_counter = 0
self._lock = threading.RLock()
self._running = False
self._stop_event = stop_event
# No longer tracking worker states with callbacks to avoid lock contention
@ -135,7 +137,7 @@ class WorkerPool:
# Wait for workers to finish
for worker in self._workers:
if worker.is_alive():
worker.join(timeout=10.0)
worker.join(timeout=2.0)
self._workers.clear()
@ -152,6 +154,7 @@ class WorkerPool:
worker_id=worker_id,
flask_app=self._flask_app,
context_vars=self._context_vars,
stop_event=self._stop_event,
)
worker.start()

View File

@ -264,6 +264,10 @@ class Node(Generic[NodeDataT]):
"""
raise NotImplementedError
def _should_stop(self) -> bool:
"""Check if execution should be stopped."""
return self.graph_runtime_state.stop_event.is_set()
def run(self) -> Generator[GraphNodeEventBase, None, None]:
execution_id = self.ensure_execution_id()
self._start_at = naive_utc_now()
@ -332,6 +336,21 @@ class Node(Generic[NodeDataT]):
yield event
else:
yield event
if self._should_stop():
error_message = "Execution cancelled"
yield NodeRunFailedEvent(
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=error_message,
),
error=error_message,
)
return
except Exception as e:
logger.exception("Node %s failed to run", self._node_id)
result = NodeRunResult(

View File

@ -11,6 +11,11 @@ from core.workflow.graph import NodeFactory
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.limits import CodeNodeLimits
from core.workflow.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
Jinja2TemplateRenderer,
)
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from libs.typing import is_str, is_str_dict
from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
@ -37,6 +42,7 @@ class DifyNodeFactory(NodeFactory):
code_executor: type[CodeExecutor] | None = None,
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
code_limits: CodeNodeLimits | None = None,
template_renderer: Jinja2TemplateRenderer | None = None,
) -> None:
self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state
@ -54,6 +60,7 @@ class DifyNodeFactory(NodeFactory):
max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
)
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
@override
def create_node(self, node_config: dict[str, object]) -> Node:
@ -106,6 +113,14 @@ class DifyNodeFactory(NodeFactory):
code_providers=self._code_providers,
code_limits=self._code_limits,
)
if node_type == NodeType.TEMPLATE_TRANSFORM:
return TemplateTransformNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
template_renderer=self._template_renderer,
)
return node_class(
id=node_id,

View File

@ -0,0 +1,40 @@
from __future__ import annotations
from collections.abc import Mapping
from typing import Any, Protocol
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
class TemplateRenderError(ValueError):
"""Raised when rendering a Jinja2 template fails."""
class Jinja2TemplateRenderer(Protocol):
"""Render Jinja2 templates for template transform nodes."""
def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
"""Render a Jinja2 template with provided variables."""
raise NotImplementedError
class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer):
"""Adapter that renders Jinja2 templates via CodeExecutor."""
_code_executor: type[CodeExecutor]
def __init__(self, code_executor: type[CodeExecutor] | None = None) -> None:
self._code_executor = code_executor or CodeExecutor
def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
try:
result = self._code_executor.execute_workflow_code_template(
language=CodeLanguage.JINJA2, code=template, inputs=variables
)
except CodeExecutionError as exc:
raise TemplateRenderError(str(exc)) from exc
rendered = result.get("result")
if not isinstance(rendered, str):
raise TemplateRenderError("Template render result must be a string.")
return rendered

View File

@ -1,18 +1,44 @@
from collections.abc import Mapping, Sequence
from typing import Any
from typing import TYPE_CHECKING, Any
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
from core.workflow.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
Jinja2TemplateRenderer,
TemplateRenderError,
)
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
class TemplateTransformNode(Node[TemplateTransformNodeData]):
node_type = NodeType.TEMPLATE_TRANSFORM
_template_renderer: Jinja2TemplateRenderer
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
template_renderer: Jinja2TemplateRenderer | None = None,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@ -39,13 +65,11 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
variables[variable_name] = value.to_object() if value else None
# Run code
try:
result = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables
)
except CodeExecutionError as e:
rendered = self._template_renderer.render_template(self.node_data.template, variables)
except TemplateRenderError as e:
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
if len(result["result"]) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
if len(rendered) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
return NodeRunResult(
inputs=variables,
status=WorkflowNodeExecutionStatus.FAILED,
@ -53,7 +77,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": result["result"]}
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": rendered}
)
@classmethod

View File

@ -1,28 +0,0 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.variables.variables import Variable
from extensions.ext_database import db
from models import ConversationVariable
from .exc import VariableOperatorNodeError
class ConversationVariableUpdaterImpl:
def update(self, conversation_id: str, variable: Variable):
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)
with Session(db.engine) as session:
row = session.scalar(stmt)
if not row:
raise VariableOperatorNodeError("conversation variable not found in the database")
row.data = variable.model_dump_json()
session.commit()
def flush(self):
pass
def conversation_variable_updater_factory() -> ConversationVariableUpdaterImpl:
return ConversationVariableUpdaterImpl()

View File

@ -1,9 +1,8 @@
from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, TypeAlias
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.variables import SegmentType, Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities import GraphInitParams
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
@ -11,19 +10,14 @@ from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from ..common.impl import conversation_variable_updater_factory
from .node_data import VariableAssignerData, WriteMode
if TYPE_CHECKING:
from core.workflow.runtime import GraphRuntimeState
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
class VariableAssignerNode(Node[VariableAssignerData]):
node_type = NodeType.VARIABLE_ASSIGNER
_conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY
def __init__(
self,
@ -31,7 +25,6 @@ class VariableAssignerNode(Node[VariableAssignerData]):
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory,
):
super().__init__(
id=id,
@ -39,7 +32,6 @@ class VariableAssignerNode(Node[VariableAssignerData]):
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._conv_var_updater_factory = conv_var_updater_factory
@classmethod
def version(cls) -> str:
@ -96,16 +88,7 @@ class VariableAssignerNode(Node[VariableAssignerData]):
# Over write the variable.
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)
# TODO: Move database operation to the pipeline.
# Update conversation variable.
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
if not conversation_id:
raise VariableOperatorNodeError("conversation_id not found")
conv_var_updater = self._conv_var_updater_factory()
conv_var_updater.update(conversation_id=conversation_id.text, variable=updated_variable)
conv_var_updater.flush()
updated_variables = [common_helpers.variable_to_processed_data(assigned_variable_selector, updated_variable)]
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={

View File

@ -1,24 +1,20 @@
import json
from collections.abc import Mapping, MutableMapping, Sequence
from typing import Any, cast
from typing import TYPE_CHECKING, Any
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import SegmentType, Variable
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
from . import helpers
from .entities import VariableAssignerNodeData, VariableOperationItem
from .enums import InputType, Operation
from .exc import (
ConversationIDNotFoundError,
InputTypeNotSupportedError,
InvalidDataError,
InvalidInputValueError,
@ -26,6 +22,10 @@ from .exc import (
VariableNotFoundError,
)
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem):
selector_node_id = item.variable_selector[0]
@ -53,6 +53,20 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
class VariableAssignerNode(Node[VariableAssignerNodeData]):
node_type = NodeType.VARIABLE_ASSIGNER
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
):
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
"""
Check if this Variable Assigner node blocks the output of specific variables.
@ -70,9 +84,6 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
return False
def _conv_var_updater_factory(self) -> ConversationVariableUpdater:
return conversation_variable_updater_factory()
@classmethod
def version(cls) -> str:
return "2"
@ -179,26 +190,12 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
# remove the duplicated items first.
updated_variable_selectors = list(set(map(tuple, updated_variable_selectors)))
conv_var_updater = self._conv_var_updater_factory()
# Update variables
for selector in updated_variable_selectors:
variable = self.graph_runtime_state.variable_pool.get(selector)
if not isinstance(variable, Variable):
raise VariableNotFoundError(variable_selector=selector)
process_data[variable.name] = variable.value
if variable.selector[0] == CONVERSATION_VARIABLE_NODE_ID:
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
if not conversation_id:
if self.invoke_from != InvokeFrom.DEBUGGER:
raise ConversationIDNotFoundError
else:
conversation_id = conversation_id.value
conv_var_updater.update(
conversation_id=cast(str, conversation_id),
variable=variable,
)
conv_var_updater.flush()
updated_variables = [
common_helpers.variable_to_processed_data(selector, seg)
for selector in updated_variable_selectors

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import importlib
import json
import threading
from collections.abc import Mapping, Sequence
from copy import deepcopy
from dataclasses import dataclass
@ -168,6 +169,7 @@ class GraphRuntimeState:
self._pending_response_coordinator_dump: str | None = None
self._pending_graph_execution_workflow_id: str | None = None
self._paused_nodes: set[str] = set()
self.stop_event: threading.Event = threading.Event()
if graph is not None:
self.attach_graph(graph)

View File

@ -1,4 +1,4 @@
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from typing import Any, Protocol
from core.model_runtime.entities.llm_entities import LLMUsage
@ -9,7 +9,7 @@ from core.workflow.system_variable import SystemVariableReadOnlyView
class ReadOnlyVariablePool(Protocol):
"""Read-only interface for VariablePool."""
def get(self, node_id: str, variable_key: str) -> Segment | None:
def get(self, selector: Sequence[str], /) -> Segment | None:
"""Get a variable value (read-only)."""
...

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from copy import deepcopy
from typing import Any
@ -18,9 +18,9 @@ class ReadOnlyVariablePoolWrapper:
def __init__(self, variable_pool: VariablePool) -> None:
self._variable_pool = variable_pool
def get(self, node_id: str, variable_key: str) -> Segment | None:
def get(self, selector: Sequence[str], /) -> Segment | None:
"""Return a copy of a variable value if present."""
value = self._variable_pool.get([node_id, variable_key])
value = self._variable_pool.get(selector)
return deepcopy(value) if value is not None else None
def get_all_by_node(self, node_id: str) -> Mapping[str, object]: