mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 18:08:07 +08:00
Merge branch 'main' into feat/grouping-branching
# Conflicts: # web/package.json
This commit is contained in:
@ -20,6 +20,8 @@ from core.app.entities.queue_entities import (
|
||||
QueueTextChunkEvent,
|
||||
)
|
||||
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
|
||||
from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer
|
||||
from core.db.session_factory import session_factory
|
||||
from core.moderation.base import ModerationError
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
from core.variables.variables import VariableUnion
|
||||
@ -40,6 +42,7 @@ from models import Workflow
|
||||
from models.enums import UserFrom
|
||||
from models.model import App, Conversation, Message, MessageAnnotation
|
||||
from models.workflow import ConversationVariable
|
||||
from services.conversation_variable_updater import ConversationVariableUpdater
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -200,6 +203,10 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
)
|
||||
|
||||
workflow_entry.graph_engine.layer(persistence_layer)
|
||||
conversation_variable_layer = ConversationVariablePersistenceLayer(
|
||||
ConversationVariableUpdater(session_factory.get_session_maker())
|
||||
)
|
||||
workflow_entry.graph_engine.layer(conversation_variable_layer)
|
||||
for layer in self._graph_engine_layers:
|
||||
workflow_entry.graph_engine.layer(layer)
|
||||
|
||||
|
||||
@ -75,7 +75,7 @@ class AnnotationReplyFeature:
|
||||
AppAnnotationService.add_annotation_history(
|
||||
annotation.id,
|
||||
app_record.id,
|
||||
annotation.question,
|
||||
annotation.question_text,
|
||||
annotation.content,
|
||||
query,
|
||||
user_id,
|
||||
|
||||
60
api/core/app/layers/conversation_variable_persist_layer.py
Normal file
60
api/core/app/layers/conversation_variable_persist_layer.py
Normal file
@ -0,0 +1,60 @@
|
||||
import logging
|
||||
|
||||
from core.variables import Variable
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events import GraphEngineEvent, NodeRunSucceededEvent
|
||||
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConversationVariablePersistenceLayer(GraphEngineLayer):
|
||||
def __init__(self, conversation_variable_updater: ConversationVariableUpdater) -> None:
|
||||
super().__init__()
|
||||
self._conversation_variable_updater = conversation_variable_updater
|
||||
|
||||
def on_graph_start(self) -> None:
|
||||
pass
|
||||
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
if not isinstance(event, NodeRunSucceededEvent):
|
||||
return
|
||||
if event.node_type != NodeType.VARIABLE_ASSIGNER:
|
||||
return
|
||||
if self.graph_runtime_state is None:
|
||||
return
|
||||
|
||||
updated_variables = common_helpers.get_updated_variables(event.node_run_result.process_data) or []
|
||||
if not updated_variables:
|
||||
return
|
||||
|
||||
conversation_id = self.graph_runtime_state.system_variable.conversation_id
|
||||
if conversation_id is None:
|
||||
return
|
||||
|
||||
updated_any = False
|
||||
for item in updated_variables:
|
||||
selector = item.selector
|
||||
if len(selector) < 2:
|
||||
logger.warning("Conversation variable selector invalid. selector=%s", selector)
|
||||
continue
|
||||
if selector[0] != CONVERSATION_VARIABLE_NODE_ID:
|
||||
continue
|
||||
variable = self.graph_runtime_state.variable_pool.get(selector)
|
||||
if not isinstance(variable, Variable):
|
||||
logger.warning(
|
||||
"Conversation variable not found in variable pool. selector=%s",
|
||||
selector,
|
||||
)
|
||||
continue
|
||||
self._conversation_variable_updater.update(conversation_id=conversation_id, variable=variable)
|
||||
updated_any = True
|
||||
|
||||
if updated_any:
|
||||
self._conversation_variable_updater.flush()
|
||||
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
pass
|
||||
@ -66,6 +66,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
|
||||
"""
|
||||
if isinstance(session_factory, Engine):
|
||||
session_factory = sessionmaker(session_factory)
|
||||
super().__init__()
|
||||
self._session_maker = session_factory
|
||||
self._state_owner_user_id = state_owner_user_id
|
||||
self._generate_entity = generate_entity
|
||||
@ -98,8 +99,6 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
|
||||
if not isinstance(event, GraphRunPausedEvent):
|
||||
return
|
||||
|
||||
assert self.graph_runtime_state is not None
|
||||
|
||||
entity_wrapper: _GenerateEntityUnion
|
||||
if isinstance(self._generate_entity, WorkflowAppGenerateEntity):
|
||||
entity_wrapper = _WorkflowGenerateEntityWrapper(entity=self._generate_entity)
|
||||
|
||||
@ -33,6 +33,7 @@ class TriggerPostLayer(GraphEngineLayer):
|
||||
trigger_log_id: str,
|
||||
session_maker: sessionmaker[Session],
|
||||
):
|
||||
super().__init__()
|
||||
self.trigger_log_id = trigger_log_id
|
||||
self.start_time = start_time
|
||||
self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity
|
||||
@ -57,10 +58,6 @@ class TriggerPostLayer(GraphEngineLayer):
|
||||
elapsed_time = (datetime.now(UTC) - self.start_time).total_seconds()
|
||||
|
||||
# Extract relevant data from result
|
||||
if not self.graph_runtime_state:
|
||||
logger.exception("Graph runtime state is not set")
|
||||
return
|
||||
|
||||
outputs = self.graph_runtime_state.outputs
|
||||
|
||||
# BASICLY, workflow_execution_id is the same as workflow_run_id
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from sqlalchemy import Engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
_session_maker: sessionmaker | None = None
|
||||
_session_maker: sessionmaker[Session] | None = None
|
||||
|
||||
|
||||
def configure_session_factory(engine: Engine, expire_on_commit: bool = False):
|
||||
@ -10,7 +10,7 @@ def configure_session_factory(engine: Engine, expire_on_commit: bool = False):
|
||||
_session_maker = sessionmaker(bind=engine, expire_on_commit=expire_on_commit)
|
||||
|
||||
|
||||
def get_session_maker() -> sessionmaker:
|
||||
def get_session_maker() -> sessionmaker[Session]:
|
||||
if _session_maker is None:
|
||||
raise RuntimeError("Session factory not configured. Call configure_session_factory() first.")
|
||||
return _session_maker
|
||||
@ -27,7 +27,7 @@ class SessionFactory:
|
||||
configure_session_factory(engine, expire_on_commit)
|
||||
|
||||
@staticmethod
|
||||
def get_session_maker() -> sessionmaker:
|
||||
def get_session_maker() -> sessionmaker[Session]:
|
||||
return get_session_maker()
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -8,8 +8,9 @@ import urllib.parse
|
||||
from configs import dify_config
|
||||
|
||||
|
||||
def get_signed_file_url(upload_file_id: str, as_attachment=False) -> str:
|
||||
url = f"{dify_config.FILES_URL}/files/{upload_file_id}/file-preview"
|
||||
def get_signed_file_url(upload_file_id: str, as_attachment=False, for_external: bool = True) -> str:
|
||||
base_url = dify_config.FILES_URL if for_external else (dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL)
|
||||
url = f"{base_url}/files/{upload_file_id}/file-preview"
|
||||
|
||||
timestamp = str(int(time.time()))
|
||||
nonce = os.urandom(16).hex()
|
||||
|
||||
@ -112,17 +112,17 @@ class File(BaseModel):
|
||||
|
||||
return text
|
||||
|
||||
def generate_url(self) -> str | None:
|
||||
def generate_url(self, for_external: bool = True) -> str | None:
|
||||
if self.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
return self.remote_url
|
||||
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
if self.related_id is None:
|
||||
raise ValueError("Missing file related_id")
|
||||
return helpers.get_signed_file_url(upload_file_id=self.related_id)
|
||||
return helpers.get_signed_file_url(upload_file_id=self.related_id, for_external=for_external)
|
||||
elif self.transfer_method in [FileTransferMethod.TOOL_FILE, FileTransferMethod.DATASOURCE_FILE]:
|
||||
assert self.related_id is not None
|
||||
assert self.extension is not None
|
||||
return sign_tool_file(tool_file_id=self.related_id, extension=self.extension)
|
||||
return sign_tool_file(tool_file_id=self.related_id, extension=self.extension, for_external=for_external)
|
||||
return None
|
||||
|
||||
def to_plugin_parameter(self) -> dict[str, Any]:
|
||||
@ -133,7 +133,7 @@ class File(BaseModel):
|
||||
"extension": self.extension,
|
||||
"size": self.size,
|
||||
"type": self.type,
|
||||
"url": self.generate_url(),
|
||||
"url": self.generate_url(for_external=False),
|
||||
}
|
||||
|
||||
@model_validator(mode="after")
|
||||
|
||||
@ -76,7 +76,7 @@ class TemplateTransformer(ABC):
|
||||
Post-process the result to convert scientific notation strings back to numbers
|
||||
"""
|
||||
|
||||
def convert_scientific_notation(value):
|
||||
def convert_scientific_notation(value: Any) -> Any:
|
||||
if isinstance(value, str):
|
||||
# Check if the string looks like scientific notation
|
||||
if re.match(r"^-?\d+\.?\d*e[+-]\d+$", value, re.IGNORECASE):
|
||||
@ -90,7 +90,7 @@ class TemplateTransformer(ABC):
|
||||
return [convert_scientific_notation(v) for v in value]
|
||||
return value
|
||||
|
||||
return convert_scientific_notation(result) # type: ignore[no-any-return]
|
||||
return convert_scientific_notation(result)
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
|
||||
@ -984,9 +984,11 @@ class ClickzettaVector(BaseVector):
|
||||
|
||||
# No need for dataset_id filter since each dataset has its own table
|
||||
|
||||
# Use simple quote escaping for LIKE clause
|
||||
escaped_query = query.replace("'", "''")
|
||||
filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%'")
|
||||
# Escape special characters for LIKE clause to prevent SQL injection
|
||||
from libs.helper import escape_like_pattern
|
||||
|
||||
escaped_query = escape_like_pattern(query).replace("'", "''")
|
||||
filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%' ESCAPE '\\\\'")
|
||||
where_clause = " AND ".join(filter_clauses)
|
||||
|
||||
search_sql = f"""
|
||||
|
||||
@ -287,11 +287,15 @@ class IrisVector(BaseVector):
|
||||
cursor.execute(sql, (query,))
|
||||
else:
|
||||
# Fallback to LIKE search (inefficient for large datasets)
|
||||
query_pattern = f"%{query}%"
|
||||
# Escape special characters for LIKE clause to prevent SQL injection
|
||||
from libs.helper import escape_like_pattern
|
||||
|
||||
escaped_query = escape_like_pattern(query)
|
||||
query_pattern = f"%{escaped_query}%"
|
||||
sql = f"""
|
||||
SELECT TOP {top_k} id, text, meta
|
||||
FROM {self.schema}.{self.table_name}
|
||||
WHERE text LIKE ?
|
||||
WHERE text LIKE ? ESCAPE '\\'
|
||||
"""
|
||||
cursor.execute(sql, (query_pattern,))
|
||||
|
||||
|
||||
@ -66,6 +66,8 @@ class WeaviateVector(BaseVector):
|
||||
in a Weaviate collection.
|
||||
"""
|
||||
|
||||
_DOCUMENT_ID_PROPERTY = "document_id"
|
||||
|
||||
def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list):
|
||||
"""
|
||||
Initializes the Weaviate vector store.
|
||||
@ -353,15 +355,12 @@ class WeaviateVector(BaseVector):
|
||||
return []
|
||||
|
||||
col = self._client.collections.use(self._collection_name)
|
||||
props = list({*self._attributes, "document_id", Field.TEXT_KEY.value})
|
||||
props = list({*self._attributes, self._DOCUMENT_ID_PROPERTY, Field.TEXT_KEY.value})
|
||||
|
||||
where = None
|
||||
doc_ids = kwargs.get("document_ids_filter") or []
|
||||
if doc_ids:
|
||||
ors = [Filter.by_property("document_id").equal(x) for x in doc_ids]
|
||||
where = ors[0]
|
||||
for f in ors[1:]:
|
||||
where = where | f
|
||||
where = Filter.by_property(self._DOCUMENT_ID_PROPERTY).contains_any(doc_ids)
|
||||
|
||||
top_k = int(kwargs.get("top_k", 4))
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
@ -408,10 +407,7 @@ class WeaviateVector(BaseVector):
|
||||
where = None
|
||||
doc_ids = kwargs.get("document_ids_filter") or []
|
||||
if doc_ids:
|
||||
ors = [Filter.by_property("document_id").equal(x) for x in doc_ids]
|
||||
where = ors[0]
|
||||
for f in ors[1:]:
|
||||
where = where | f
|
||||
where = Filter.by_property(self._DOCUMENT_ID_PROPERTY).contains_any(doc_ids)
|
||||
|
||||
top_k = int(kwargs.get("top_k", 4))
|
||||
|
||||
|
||||
@ -7,10 +7,11 @@ import re
|
||||
import tempfile
|
||||
import uuid
|
||||
from urllib.parse import urlparse
|
||||
from xml.etree import ElementTree
|
||||
|
||||
import httpx
|
||||
from docx import Document as DocxDocument
|
||||
from docx.oxml.ns import qn
|
||||
from docx.text.run import Run
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper import ssrf_proxy
|
||||
@ -229,44 +230,20 @@ class WordExtractor(BaseExtractor):
|
||||
|
||||
image_map = self._extract_images_from_docx(doc)
|
||||
|
||||
hyperlinks_url = None
|
||||
url_pattern = re.compile(r"http://[^\s+]+//|https://[^\s+]+")
|
||||
for para in doc.paragraphs:
|
||||
for run in para.runs:
|
||||
if run.text and hyperlinks_url:
|
||||
result = f" [{run.text}]({hyperlinks_url}) "
|
||||
run.text = result
|
||||
hyperlinks_url = None
|
||||
if "HYPERLINK" in run.element.xml:
|
||||
try:
|
||||
xml = ElementTree.XML(run.element.xml)
|
||||
x_child = [c for c in xml.iter() if c is not None]
|
||||
for x in x_child:
|
||||
if x is None:
|
||||
continue
|
||||
if x.tag.endswith("instrText"):
|
||||
if x.text is None:
|
||||
continue
|
||||
for i in url_pattern.findall(x.text):
|
||||
hyperlinks_url = str(i)
|
||||
except Exception:
|
||||
logger.exception("Failed to parse HYPERLINK xml")
|
||||
|
||||
def parse_paragraph(paragraph):
|
||||
paragraph_content = []
|
||||
|
||||
def append_image_link(image_id, has_drawing):
|
||||
def append_image_link(image_id, has_drawing, target_buffer):
|
||||
"""Helper to append image link from image_map based on relationship type."""
|
||||
rel = doc.part.rels[image_id]
|
||||
if rel.is_external:
|
||||
if image_id in image_map and not has_drawing:
|
||||
paragraph_content.append(image_map[image_id])
|
||||
target_buffer.append(image_map[image_id])
|
||||
else:
|
||||
image_part = rel.target_part
|
||||
if image_part in image_map and not has_drawing:
|
||||
paragraph_content.append(image_map[image_part])
|
||||
target_buffer.append(image_map[image_part])
|
||||
|
||||
for run in paragraph.runs:
|
||||
def process_run(run, target_buffer):
|
||||
# Helper to extract text and embedded images from a run element and append them to target_buffer
|
||||
if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"):
|
||||
# Process drawing type images
|
||||
drawing_elements = run.element.findall(
|
||||
@ -287,13 +264,13 @@ class WordExtractor(BaseExtractor):
|
||||
# External image: use embed_id as key
|
||||
if embed_id in image_map:
|
||||
has_drawing = True
|
||||
paragraph_content.append(image_map[embed_id])
|
||||
target_buffer.append(image_map[embed_id])
|
||||
else:
|
||||
# Internal image: use target_part as key
|
||||
image_part = doc.part.related_parts.get(embed_id)
|
||||
if image_part in image_map:
|
||||
has_drawing = True
|
||||
paragraph_content.append(image_map[image_part])
|
||||
target_buffer.append(image_map[image_part])
|
||||
# Process pict type images
|
||||
shape_elements = run.element.findall(
|
||||
".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}pict"
|
||||
@ -308,7 +285,7 @@ class WordExtractor(BaseExtractor):
|
||||
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
|
||||
)
|
||||
if image_id and image_id in doc.part.rels:
|
||||
append_image_link(image_id, has_drawing)
|
||||
append_image_link(image_id, has_drawing, target_buffer)
|
||||
# Find imagedata element in VML
|
||||
image_data = shape.find(".//{urn:schemas-microsoft-com:vml}imagedata")
|
||||
if image_data is not None:
|
||||
@ -316,9 +293,93 @@ class WordExtractor(BaseExtractor):
|
||||
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
|
||||
)
|
||||
if image_id and image_id in doc.part.rels:
|
||||
append_image_link(image_id, has_drawing)
|
||||
append_image_link(image_id, has_drawing, target_buffer)
|
||||
if run.text.strip():
|
||||
paragraph_content.append(run.text.strip())
|
||||
target_buffer.append(run.text.strip())
|
||||
|
||||
def process_hyperlink(hyperlink_elem, target_buffer):
|
||||
# Helper to extract text from a hyperlink element and append it to target_buffer
|
||||
r_id = hyperlink_elem.get(qn("r:id"))
|
||||
|
||||
# Extract text from runs inside the hyperlink
|
||||
link_text_parts = []
|
||||
for run_elem in hyperlink_elem.findall(qn("w:r")):
|
||||
run = Run(run_elem, paragraph)
|
||||
# Hyperlink text may be split across multiple runs (e.g., with different formatting),
|
||||
# so collect all run texts first
|
||||
if run.text:
|
||||
link_text_parts.append(run.text)
|
||||
|
||||
link_text = "".join(link_text_parts).strip()
|
||||
|
||||
# Resolve URL
|
||||
if r_id:
|
||||
try:
|
||||
rel = doc.part.rels.get(r_id)
|
||||
if rel and rel.is_external:
|
||||
link_text = f"[{link_text or rel.target_ref}]({rel.target_ref})"
|
||||
except Exception:
|
||||
logger.exception("Failed to resolve URL for hyperlink with r:id: %s", r_id)
|
||||
|
||||
if link_text:
|
||||
target_buffer.append(link_text)
|
||||
|
||||
paragraph_content = []
|
||||
# State for legacy HYPERLINK fields
|
||||
hyperlink_field_url = None
|
||||
hyperlink_field_text_parts: list = []
|
||||
is_collecting_field_text = False
|
||||
# Iterate through paragraph elements in document order
|
||||
for child in paragraph._element:
|
||||
tag = child.tag
|
||||
if tag == qn("w:r"):
|
||||
# Regular run
|
||||
run = Run(child, paragraph)
|
||||
|
||||
# Check for fldChar (begin/end/separate) and instrText for legacy hyperlinks
|
||||
fld_chars = child.findall(qn("w:fldChar"))
|
||||
instr_texts = child.findall(qn("w:instrText"))
|
||||
|
||||
# Handle Fields
|
||||
if fld_chars or instr_texts:
|
||||
# Process instrText to find HYPERLINK "url"
|
||||
for instr in instr_texts:
|
||||
if instr.text and "HYPERLINK" in instr.text:
|
||||
# Quick regex to extract URL
|
||||
match = re.search(r'HYPERLINK\s+"([^"]+)"', instr.text, re.IGNORECASE)
|
||||
if match:
|
||||
hyperlink_field_url = match.group(1)
|
||||
|
||||
# Process fldChar
|
||||
for fld_char in fld_chars:
|
||||
fld_char_type = fld_char.get(qn("w:fldCharType"))
|
||||
if fld_char_type == "begin":
|
||||
# Start of a field: reset legacy link state
|
||||
hyperlink_field_url = None
|
||||
hyperlink_field_text_parts = []
|
||||
is_collecting_field_text = False
|
||||
elif fld_char_type == "separate":
|
||||
# Separator: if we found a URL, start collecting visible text
|
||||
if hyperlink_field_url:
|
||||
is_collecting_field_text = True
|
||||
elif fld_char_type == "end":
|
||||
# End of field
|
||||
if is_collecting_field_text and hyperlink_field_url:
|
||||
# Create markdown link and append to main content
|
||||
display_text = "".join(hyperlink_field_text_parts).strip()
|
||||
if display_text:
|
||||
link_md = f"[{display_text}]({hyperlink_field_url})"
|
||||
paragraph_content.append(link_md)
|
||||
# Reset state
|
||||
hyperlink_field_url = None
|
||||
hyperlink_field_text_parts = []
|
||||
is_collecting_field_text = False
|
||||
|
||||
# Decide where to append content
|
||||
target_buffer = hyperlink_field_text_parts if is_collecting_field_text else paragraph_content
|
||||
process_run(run, target_buffer)
|
||||
elif tag == qn("w:hyperlink"):
|
||||
process_hyperlink(child, paragraph_content)
|
||||
return "".join(paragraph_content) if paragraph_content else ""
|
||||
|
||||
paragraphs = doc.paragraphs.copy()
|
||||
|
||||
@ -1198,18 +1198,24 @@ class DatasetRetrieval:
|
||||
|
||||
json_field = DatasetDocument.doc_metadata[metadata_name].as_string()
|
||||
|
||||
from libs.helper import escape_like_pattern
|
||||
|
||||
match condition:
|
||||
case "contains":
|
||||
filters.append(json_field.like(f"%{value}%"))
|
||||
escaped_value = escape_like_pattern(str(value))
|
||||
filters.append(json_field.like(f"%{escaped_value}%", escape="\\"))
|
||||
|
||||
case "not contains":
|
||||
filters.append(json_field.notlike(f"%{value}%"))
|
||||
escaped_value = escape_like_pattern(str(value))
|
||||
filters.append(json_field.notlike(f"%{escaped_value}%", escape="\\"))
|
||||
|
||||
case "start with":
|
||||
filters.append(json_field.like(f"{value}%"))
|
||||
escaped_value = escape_like_pattern(str(value))
|
||||
filters.append(json_field.like(f"{escaped_value}%", escape="\\"))
|
||||
|
||||
case "end with":
|
||||
filters.append(json_field.like(f"%{value}"))
|
||||
escaped_value = escape_like_pattern(str(value))
|
||||
filters.append(json_field.like(f"%{escaped_value}", escape="\\"))
|
||||
|
||||
case "is" | "=":
|
||||
if isinstance(value, str):
|
||||
@ -1474,38 +1480,38 @@ class DatasetRetrieval:
|
||||
if cancel_event and cancel_event.is_set():
|
||||
break
|
||||
|
||||
# Skip second reranking when there is only one dataset
|
||||
if reranking_enable and dataset_count > 1:
|
||||
# do rerank for searched documents
|
||||
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
|
||||
if query:
|
||||
all_documents_item = data_post_processor.invoke(
|
||||
query=query,
|
||||
documents=all_documents_item,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_k,
|
||||
query_type=QueryType.TEXT_QUERY,
|
||||
)
|
||||
if attachment_id:
|
||||
all_documents_item = data_post_processor.invoke(
|
||||
documents=all_documents_item,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_k,
|
||||
query_type=QueryType.IMAGE_QUERY,
|
||||
query=attachment_id,
|
||||
)
|
||||
else:
|
||||
if index_type == IndexTechniqueType.ECONOMY:
|
||||
if not query:
|
||||
all_documents_item = []
|
||||
else:
|
||||
all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k)
|
||||
elif index_type == IndexTechniqueType.HIGH_QUALITY:
|
||||
all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold)
|
||||
# Skip second reranking when there is only one dataset
|
||||
if reranking_enable and dataset_count > 1:
|
||||
# do rerank for searched documents
|
||||
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
|
||||
if query:
|
||||
all_documents_item = data_post_processor.invoke(
|
||||
query=query,
|
||||
documents=all_documents_item,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_k,
|
||||
query_type=QueryType.TEXT_QUERY,
|
||||
)
|
||||
if attachment_id:
|
||||
all_documents_item = data_post_processor.invoke(
|
||||
documents=all_documents_item,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_k,
|
||||
query_type=QueryType.IMAGE_QUERY,
|
||||
query=attachment_id,
|
||||
)
|
||||
else:
|
||||
all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
|
||||
if all_documents_item:
|
||||
all_documents.extend(all_documents_item)
|
||||
if index_type == IndexTechniqueType.ECONOMY:
|
||||
if not query:
|
||||
all_documents_item = []
|
||||
else:
|
||||
all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k)
|
||||
elif index_type == IndexTechniqueType.HIGH_QUALITY:
|
||||
all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold)
|
||||
else:
|
||||
all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
|
||||
if all_documents_item:
|
||||
all_documents.extend(all_documents_item)
|
||||
except Exception as e:
|
||||
if cancel_event:
|
||||
cancel_event.set()
|
||||
|
||||
@ -7,12 +7,12 @@ import time
|
||||
from configs import dify_config
|
||||
|
||||
|
||||
def sign_tool_file(tool_file_id: str, extension: str) -> str:
|
||||
def sign_tool_file(tool_file_id: str, extension: str, for_external: bool = True) -> str:
|
||||
"""
|
||||
sign file to get a temporary url for plugin access
|
||||
"""
|
||||
# Use internal URL for plugin/tool file access in Docker environments
|
||||
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
|
||||
# Use internal URL for plugin/tool file access in Docker environments, unless for_external is True
|
||||
base_url = dify_config.FILES_URL if for_external else (dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL)
|
||||
file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}"
|
||||
|
||||
timestamp = str(int(time.time()))
|
||||
|
||||
@ -64,6 +64,9 @@ engine.layer(DebugLoggingLayer(level="INFO"))
|
||||
engine.layer(ExecutionLimitsLayer(max_nodes=100))
|
||||
```
|
||||
|
||||
`engine.layer()` binds the read-only runtime state before execution, so layer hooks
|
||||
can assume `graph_runtime_state` is available.
|
||||
|
||||
### Event-Driven Architecture
|
||||
|
||||
All node executions emit events for monitoring and integration:
|
||||
|
||||
@ -9,7 +9,7 @@ Each instance uses a unique key for its command queue.
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, final
|
||||
|
||||
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand
|
||||
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from extensions.ext_redis import RedisClientWrapper
|
||||
@ -113,6 +113,8 @@ class RedisChannel:
|
||||
return AbortCommand.model_validate(data)
|
||||
if command_type == CommandType.PAUSE:
|
||||
return PauseCommand.model_validate(data)
|
||||
if command_type == CommandType.UPDATE_VARIABLES:
|
||||
return UpdateVariablesCommand.model_validate(data)
|
||||
|
||||
# For other command types, use base class
|
||||
return GraphEngineCommand.model_validate(data)
|
||||
|
||||
@ -5,11 +5,12 @@ This package handles external commands sent to the engine
|
||||
during execution.
|
||||
"""
|
||||
|
||||
from .command_handlers import AbortCommandHandler, PauseCommandHandler
|
||||
from .command_handlers import AbortCommandHandler, PauseCommandHandler, UpdateVariablesCommandHandler
|
||||
from .command_processor import CommandProcessor
|
||||
|
||||
__all__ = [
|
||||
"AbortCommandHandler",
|
||||
"CommandProcessor",
|
||||
"PauseCommandHandler",
|
||||
"UpdateVariablesCommandHandler",
|
||||
]
|
||||
|
||||
@ -4,9 +4,10 @@ from typing import final
|
||||
from typing_extensions import override
|
||||
|
||||
from core.workflow.entities.pause_reason import SchedulingPause
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
from ..domain.graph_execution import GraphExecution
|
||||
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
|
||||
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
|
||||
from .command_processor import CommandHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -31,3 +32,25 @@ class PauseCommandHandler(CommandHandler):
|
||||
reason = command.reason
|
||||
pause_reason = SchedulingPause(message=reason)
|
||||
execution.pause(pause_reason)
|
||||
|
||||
|
||||
@final
|
||||
class UpdateVariablesCommandHandler(CommandHandler):
|
||||
def __init__(self, variable_pool: VariablePool) -> None:
|
||||
self._variable_pool = variable_pool
|
||||
|
||||
@override
|
||||
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
|
||||
assert isinstance(command, UpdateVariablesCommand)
|
||||
for update in command.updates:
|
||||
try:
|
||||
variable = update.value
|
||||
self._variable_pool.add(variable.selector, variable)
|
||||
logger.debug("Updated variable %s for workflow %s", variable.selector, execution.workflow_id)
|
||||
except ValueError as exc:
|
||||
logger.warning(
|
||||
"Skipping invalid variable selector %s for workflow %s: %s",
|
||||
getattr(update.value, "selector", None),
|
||||
execution.workflow_id,
|
||||
exc,
|
||||
)
|
||||
|
||||
@ -5,17 +5,21 @@ This module defines command types that can be sent to a running GraphEngine
|
||||
instance to control its execution flow.
|
||||
"""
|
||||
|
||||
from enum import StrEnum
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.variables.variables import VariableUnion
|
||||
|
||||
|
||||
class CommandType(StrEnum):
|
||||
"""Types of commands that can be sent to GraphEngine."""
|
||||
|
||||
ABORT = "abort"
|
||||
PAUSE = "pause"
|
||||
ABORT = auto()
|
||||
PAUSE = auto()
|
||||
UPDATE_VARIABLES = auto()
|
||||
|
||||
|
||||
class GraphEngineCommand(BaseModel):
|
||||
@ -37,3 +41,16 @@ class PauseCommand(GraphEngineCommand):
|
||||
|
||||
command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command")
|
||||
reason: str = Field(default="unknown reason", description="reason for pause")
|
||||
|
||||
|
||||
class VariableUpdate(BaseModel):
|
||||
"""Represents a single variable update instruction."""
|
||||
|
||||
value: VariableUnion = Field(description="New variable value")
|
||||
|
||||
|
||||
class UpdateVariablesCommand(GraphEngineCommand):
|
||||
"""Command to update a group of variables in the variable pool."""
|
||||
|
||||
command_type: CommandType = Field(default=CommandType.UPDATE_VARIABLES, description="Type of command")
|
||||
updates: Sequence[VariableUpdate] = Field(default_factory=list, description="Variable updates")
|
||||
|
||||
@ -8,6 +8,7 @@ Domain-Driven Design principles for improved maintainability and testability.
|
||||
import contextvars
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, cast, final
|
||||
|
||||
@ -30,8 +31,13 @@ from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWr
|
||||
if TYPE_CHECKING: # pragma: no cover - used only for static analysis
|
||||
from core.workflow.runtime.graph_runtime_state import GraphProtocol
|
||||
|
||||
from .command_processing import AbortCommandHandler, CommandProcessor, PauseCommandHandler
|
||||
from .entities.commands import AbortCommand, PauseCommand
|
||||
from .command_processing import (
|
||||
AbortCommandHandler,
|
||||
CommandProcessor,
|
||||
PauseCommandHandler,
|
||||
UpdateVariablesCommandHandler,
|
||||
)
|
||||
from .entities.commands import AbortCommand, PauseCommand, UpdateVariablesCommand
|
||||
from .error_handler import ErrorHandler
|
||||
from .event_management import EventHandler, EventManager
|
||||
from .graph_state_manager import GraphStateManager
|
||||
@ -70,10 +76,13 @@ class GraphEngine:
|
||||
scale_down_idle_time: float | None = None,
|
||||
) -> None:
|
||||
"""Initialize the graph engine with all subsystems and dependencies."""
|
||||
# stop event
|
||||
self._stop_event = threading.Event()
|
||||
|
||||
# Bind runtime state to current workflow context
|
||||
self._graph = graph
|
||||
self._graph_runtime_state = graph_runtime_state
|
||||
self._graph_runtime_state.stop_event = self._stop_event
|
||||
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
|
||||
self._command_channel = command_channel
|
||||
|
||||
@ -140,6 +149,9 @@ class GraphEngine:
|
||||
pause_handler = PauseCommandHandler()
|
||||
self._command_processor.register_handler(PauseCommand, pause_handler)
|
||||
|
||||
update_variables_handler = UpdateVariablesCommandHandler(self._graph_runtime_state.variable_pool)
|
||||
self._command_processor.register_handler(UpdateVariablesCommand, update_variables_handler)
|
||||
|
||||
# === Extensibility ===
|
||||
# Layers allow plugins to extend engine functionality
|
||||
self._layers: list[GraphEngineLayer] = []
|
||||
@ -169,6 +181,7 @@ class GraphEngine:
|
||||
max_workers=self._max_workers,
|
||||
scale_up_threshold=self._scale_up_threshold,
|
||||
scale_down_idle_time=self._scale_down_idle_time,
|
||||
stop_event=self._stop_event,
|
||||
)
|
||||
|
||||
# === Orchestration ===
|
||||
@ -199,6 +212,7 @@ class GraphEngine:
|
||||
event_handler=self._event_handler_registry,
|
||||
execution_coordinator=self._execution_coordinator,
|
||||
event_emitter=self._event_manager,
|
||||
stop_event=self._stop_event,
|
||||
)
|
||||
|
||||
# === Validation ===
|
||||
@ -212,9 +226,16 @@ class GraphEngine:
|
||||
if id(node.graph_runtime_state) != expected_state_id:
|
||||
raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance")
|
||||
|
||||
def _bind_layer_context(
|
||||
self,
|
||||
layer: GraphEngineLayer,
|
||||
) -> None:
|
||||
layer.initialize(ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state), self._command_channel)
|
||||
|
||||
def layer(self, layer: GraphEngineLayer) -> "GraphEngine":
|
||||
"""Add a layer for extending functionality."""
|
||||
self._layers.append(layer)
|
||||
self._bind_layer_context(layer)
|
||||
return self
|
||||
|
||||
def run(self) -> Generator[GraphEngineEvent, None, None]:
|
||||
@ -301,14 +322,7 @@ class GraphEngine:
|
||||
def _initialize_layers(self) -> None:
|
||||
"""Initialize layers with context."""
|
||||
self._event_manager.set_layers(self._layers)
|
||||
# Create a read-only wrapper for the runtime state
|
||||
read_only_state = ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state)
|
||||
for layer in self._layers:
|
||||
try:
|
||||
layer.initialize(read_only_state, self._command_channel)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to initialize layer %s: %s", layer.__class__.__name__, e)
|
||||
|
||||
try:
|
||||
layer.on_graph_start()
|
||||
except Exception as e:
|
||||
@ -316,6 +330,7 @@ class GraphEngine:
|
||||
|
||||
def _start_execution(self, *, resume: bool = False) -> None:
|
||||
"""Start execution subsystems."""
|
||||
self._stop_event.clear()
|
||||
paused_nodes: list[str] = []
|
||||
if resume:
|
||||
paused_nodes = self._graph_runtime_state.consume_paused_nodes()
|
||||
@ -343,13 +358,12 @@ class GraphEngine:
|
||||
|
||||
def _stop_execution(self) -> None:
|
||||
"""Stop execution subsystems."""
|
||||
self._stop_event.set()
|
||||
self._dispatcher.stop()
|
||||
self._worker_pool.stop()
|
||||
# Don't mark complete here as the dispatcher already does it
|
||||
|
||||
# Notify layers
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
for layer in self._layers:
|
||||
try:
|
||||
layer.on_graph_end(self._graph_execution.error)
|
||||
|
||||
@ -8,7 +8,7 @@ Pluggable middleware for engine extensions.
|
||||
|
||||
Abstract base class for layers.
|
||||
|
||||
- `initialize()` - Receive runtime context
|
||||
- `initialize()` - Receive runtime context (runtime state is bound here and always available to hooks)
|
||||
- `on_graph_start()` - Execution start hook
|
||||
- `on_event()` - Process all events
|
||||
- `on_graph_end()` - Execution end hook
|
||||
@ -34,6 +34,9 @@ engine.layer(debug_layer)
|
||||
engine.run()
|
||||
```
|
||||
|
||||
`engine.layer()` binds the read-only runtime state before execution, so
|
||||
`graph_runtime_state` is always available inside layer hooks.
|
||||
|
||||
## Custom Layers
|
||||
|
||||
```python
|
||||
|
||||
@ -13,6 +13,14 @@ from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.runtime import ReadOnlyGraphRuntimeState
|
||||
|
||||
|
||||
class GraphEngineLayerNotInitializedError(Exception):
|
||||
"""Raised when a layer's runtime state is accessed before initialization."""
|
||||
|
||||
def __init__(self, layer_name: str | None = None) -> None:
|
||||
name = layer_name or "GraphEngineLayer"
|
||||
super().__init__(f"{name} runtime state is not initialized. Bind the layer to a GraphEngine before access.")
|
||||
|
||||
|
||||
class GraphEngineLayer(ABC):
|
||||
"""
|
||||
Abstract base class for GraphEngine layers.
|
||||
@ -28,22 +36,27 @@ class GraphEngineLayer(ABC):
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the layer. Subclasses can override with custom parameters."""
|
||||
self.graph_runtime_state: ReadOnlyGraphRuntimeState | None = None
|
||||
self._graph_runtime_state: ReadOnlyGraphRuntimeState | None = None
|
||||
self.command_channel: CommandChannel | None = None
|
||||
|
||||
@property
|
||||
def graph_runtime_state(self) -> ReadOnlyGraphRuntimeState:
|
||||
if self._graph_runtime_state is None:
|
||||
raise GraphEngineLayerNotInitializedError(type(self).__name__)
|
||||
return self._graph_runtime_state
|
||||
|
||||
def initialize(self, graph_runtime_state: ReadOnlyGraphRuntimeState, command_channel: CommandChannel) -> None:
|
||||
"""
|
||||
Initialize the layer with engine dependencies.
|
||||
|
||||
Called by GraphEngine before execution starts to inject the read-only runtime state
|
||||
and command channel. This allows layers to observe engine context and send
|
||||
commands, but prevents direct state modification.
|
||||
|
||||
Called by GraphEngine to inject the read-only runtime state and command channel.
|
||||
This is invoked when the layer is registered with a `GraphEngine` instance.
|
||||
Implementations should be idempotent.
|
||||
Args:
|
||||
graph_runtime_state: Read-only view of the runtime state
|
||||
command_channel: Channel for sending commands to the engine
|
||||
"""
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
self._graph_runtime_state = graph_runtime_state
|
||||
self.command_channel = command_channel
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@ -109,10 +109,8 @@ class DebugLoggingLayer(GraphEngineLayer):
|
||||
self.logger.info("=" * 80)
|
||||
self.logger.info("🚀 GRAPH EXECUTION STARTED")
|
||||
self.logger.info("=" * 80)
|
||||
|
||||
if self.graph_runtime_state:
|
||||
# Log initial state
|
||||
self.logger.info("Initial State:")
|
||||
# Log initial state
|
||||
self.logger.info("Initial State:")
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
@ -243,8 +241,7 @@ class DebugLoggingLayer(GraphEngineLayer):
|
||||
self.logger.info(" Node retries: %s", self.retry_count)
|
||||
|
||||
# Log final state if available
|
||||
if self.graph_runtime_state and self.include_outputs:
|
||||
if self.graph_runtime_state.outputs:
|
||||
self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs))
|
||||
if self.include_outputs and self.graph_runtime_state.outputs:
|
||||
self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs))
|
||||
|
||||
self.logger.info("=" * 80)
|
||||
|
||||
@ -337,8 +337,6 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
if update_finished:
|
||||
execution.finished_at = naive_utc_now()
|
||||
runtime_state = self.graph_runtime_state
|
||||
if runtime_state is None:
|
||||
return
|
||||
execution.total_tokens = runtime_state.total_tokens
|
||||
execution.total_steps = runtime_state.node_run_steps
|
||||
execution.outputs = execution.outputs or runtime_state.outputs
|
||||
@ -404,6 +402,4 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
|
||||
def _system_variables(self) -> Mapping[str, Any]:
|
||||
runtime_state = self.graph_runtime_state
|
||||
if runtime_state is None:
|
||||
return {}
|
||||
return runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID)
|
||||
|
||||
@ -3,14 +3,20 @@ GraphEngine Manager for sending control commands via Redis channel.
|
||||
|
||||
This module provides a simplified interface for controlling workflow executions
|
||||
using the new Redis command channel, without requiring user permission checks.
|
||||
Supports stop, pause, and resume operations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
|
||||
from core.workflow.graph_engine.entities.commands import (
|
||||
AbortCommand,
|
||||
GraphEngineCommand,
|
||||
PauseCommand,
|
||||
UpdateVariablesCommand,
|
||||
VariableUpdate,
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -23,7 +29,6 @@ class GraphEngineManager:
|
||||
|
||||
This class provides a simple interface for controlling workflow executions
|
||||
by sending commands through Redis channels, without user validation.
|
||||
Supports stop and pause operations.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@ -45,6 +50,16 @@ class GraphEngineManager:
|
||||
pause_command = PauseCommand(reason=reason or "User requested pause")
|
||||
GraphEngineManager._send_command(task_id, pause_command)
|
||||
|
||||
@staticmethod
|
||||
def send_update_variables_command(task_id: str, updates: Sequence[VariableUpdate]) -> None:
|
||||
"""Send a command to update variables in a running workflow."""
|
||||
|
||||
if not updates:
|
||||
return
|
||||
|
||||
update_command = UpdateVariablesCommand(updates=updates)
|
||||
GraphEngineManager._send_command(task_id, update_command)
|
||||
|
||||
@staticmethod
|
||||
def _send_command(task_id: str, command: GraphEngineCommand) -> None:
|
||||
"""Send a command to the workflow-specific Redis channel."""
|
||||
|
||||
@ -44,6 +44,7 @@ class Dispatcher:
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
event_handler: "EventHandler",
|
||||
execution_coordinator: ExecutionCoordinator,
|
||||
stop_event: threading.Event,
|
||||
event_emitter: EventManager | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
@ -61,7 +62,7 @@ class Dispatcher:
|
||||
self._event_emitter = event_emitter
|
||||
|
||||
self._thread: threading.Thread | None = None
|
||||
self._stop_event = threading.Event()
|
||||
self._stop_event = stop_event
|
||||
self._start_time: float | None = None
|
||||
|
||||
def start(self) -> None:
|
||||
@ -69,16 +70,14 @@ class Dispatcher:
|
||||
if self._thread and self._thread.is_alive():
|
||||
return
|
||||
|
||||
self._stop_event.clear()
|
||||
self._start_time = time.time()
|
||||
self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the dispatcher thread."""
|
||||
self._stop_event.set()
|
||||
if self._thread and self._thread.is_alive():
|
||||
self._thread.join(timeout=10.0)
|
||||
self._thread.join(timeout=2.0)
|
||||
|
||||
def _dispatcher_loop(self) -> None:
|
||||
"""Main dispatcher loop."""
|
||||
|
||||
@ -42,6 +42,7 @@ class Worker(threading.Thread):
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
graph: Graph,
|
||||
layers: Sequence[GraphEngineLayer],
|
||||
stop_event: threading.Event,
|
||||
worker_id: int = 0,
|
||||
flask_app: Flask | None = None,
|
||||
context_vars: contextvars.Context | None = None,
|
||||
@ -65,13 +66,16 @@ class Worker(threading.Thread):
|
||||
self._worker_id = worker_id
|
||||
self._flask_app = flask_app
|
||||
self._context_vars = context_vars
|
||||
self._stop_event = threading.Event()
|
||||
self._last_task_time = time.time()
|
||||
self._stop_event = stop_event
|
||||
self._layers = layers if layers is not None else []
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Signal the worker to stop processing."""
|
||||
self._stop_event.set()
|
||||
"""Worker is controlled via shared stop_event from GraphEngine.
|
||||
|
||||
This method is a no-op retained for backward compatibility.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def is_idle(self) -> bool:
|
||||
|
||||
@ -41,6 +41,7 @@ class WorkerPool:
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
graph: Graph,
|
||||
layers: list[GraphEngineLayer],
|
||||
stop_event: threading.Event,
|
||||
flask_app: "Flask | None" = None,
|
||||
context_vars: "Context | None" = None,
|
||||
min_workers: int | None = None,
|
||||
@ -81,6 +82,7 @@ class WorkerPool:
|
||||
self._worker_counter = 0
|
||||
self._lock = threading.RLock()
|
||||
self._running = False
|
||||
self._stop_event = stop_event
|
||||
|
||||
# No longer tracking worker states with callbacks to avoid lock contention
|
||||
|
||||
@ -135,7 +137,7 @@ class WorkerPool:
|
||||
# Wait for workers to finish
|
||||
for worker in self._workers:
|
||||
if worker.is_alive():
|
||||
worker.join(timeout=10.0)
|
||||
worker.join(timeout=2.0)
|
||||
|
||||
self._workers.clear()
|
||||
|
||||
@ -152,6 +154,7 @@ class WorkerPool:
|
||||
worker_id=worker_id,
|
||||
flask_app=self._flask_app,
|
||||
context_vars=self._context_vars,
|
||||
stop_event=self._stop_event,
|
||||
)
|
||||
|
||||
worker.start()
|
||||
|
||||
@ -264,6 +264,10 @@ class Node(Generic[NodeDataT]):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _should_stop(self) -> bool:
|
||||
"""Check if execution should be stopped."""
|
||||
return self.graph_runtime_state.stop_event.is_set()
|
||||
|
||||
def run(self) -> Generator[GraphNodeEventBase, None, None]:
|
||||
execution_id = self.ensure_execution_id()
|
||||
self._start_at = naive_utc_now()
|
||||
@ -332,6 +336,21 @@ class Node(Generic[NodeDataT]):
|
||||
yield event
|
||||
else:
|
||||
yield event
|
||||
|
||||
if self._should_stop():
|
||||
error_message = "Execution cancelled"
|
||||
yield NodeRunFailedEvent(
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=error_message,
|
||||
),
|
||||
error=error_message,
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception("Node %s failed to run", self._node_id)
|
||||
result = NodeRunResult(
|
||||
|
||||
@ -11,6 +11,11 @@ from core.workflow.graph import NodeFactory
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
from core.workflow.nodes.code.limits import CodeNodeLimits
|
||||
from core.workflow.nodes.template_transform.template_renderer import (
|
||||
CodeExecutorJinja2TemplateRenderer,
|
||||
Jinja2TemplateRenderer,
|
||||
)
|
||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from libs.typing import is_str, is_str_dict
|
||||
|
||||
from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
@ -37,6 +42,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
code_executor: type[CodeExecutor] | None = None,
|
||||
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
|
||||
code_limits: CodeNodeLimits | None = None,
|
||||
template_renderer: Jinja2TemplateRenderer | None = None,
|
||||
) -> None:
|
||||
self.graph_init_params = graph_init_params
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
@ -54,6 +60,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
|
||||
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
|
||||
)
|
||||
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
|
||||
|
||||
@override
|
||||
def create_node(self, node_config: dict[str, object]) -> Node:
|
||||
@ -106,6 +113,14 @@ class DifyNodeFactory(NodeFactory):
|
||||
code_providers=self._code_providers,
|
||||
code_limits=self._code_limits,
|
||||
)
|
||||
if node_type == NodeType.TEMPLATE_TRANSFORM:
|
||||
return TemplateTransformNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
template_renderer=self._template_renderer,
|
||||
)
|
||||
|
||||
return node_class(
|
||||
id=node_id,
|
||||
|
||||
@ -0,0 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Protocol
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
|
||||
|
||||
class TemplateRenderError(ValueError):
|
||||
"""Raised when rendering a Jinja2 template fails."""
|
||||
|
||||
|
||||
class Jinja2TemplateRenderer(Protocol):
|
||||
"""Render Jinja2 templates for template transform nodes."""
|
||||
|
||||
def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
|
||||
"""Render a Jinja2 template with provided variables."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer):
|
||||
"""Adapter that renders Jinja2 templates via CodeExecutor."""
|
||||
|
||||
_code_executor: type[CodeExecutor]
|
||||
|
||||
def __init__(self, code_executor: type[CodeExecutor] | None = None) -> None:
|
||||
self._code_executor = code_executor or CodeExecutor
|
||||
|
||||
def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
|
||||
try:
|
||||
result = self._code_executor.execute_workflow_code_template(
|
||||
language=CodeLanguage.JINJA2, code=template, inputs=variables
|
||||
)
|
||||
except CodeExecutionError as exc:
|
||||
raise TemplateRenderError(str(exc)) from exc
|
||||
|
||||
rendered = result.get("result")
|
||||
if not isinstance(rendered, str):
|
||||
raise TemplateRenderError("Template render result must be a string.")
|
||||
return rendered
|
||||
@ -1,18 +1,44 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
|
||||
from core.workflow.nodes.template_transform.template_renderer import (
|
||||
CodeExecutorJinja2TemplateRenderer,
|
||||
Jinja2TemplateRenderer,
|
||||
TemplateRenderError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
|
||||
|
||||
|
||||
class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||
node_type = NodeType.TEMPLATE_TRANSFORM
|
||||
_template_renderer: Jinja2TemplateRenderer
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
template_renderer: Jinja2TemplateRenderer | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
@ -39,13 +65,11 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||
variables[variable_name] = value.to_object() if value else None
|
||||
# Run code
|
||||
try:
|
||||
result = CodeExecutor.execute_workflow_code_template(
|
||||
language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables
|
||||
)
|
||||
except CodeExecutionError as e:
|
||||
rendered = self._template_renderer.render_template(self.node_data.template, variables)
|
||||
except TemplateRenderError as e:
|
||||
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
|
||||
|
||||
if len(result["result"]) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
|
||||
if len(rendered) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
|
||||
return NodeRunResult(
|
||||
inputs=variables,
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
@ -53,7 +77,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||
)
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": result["result"]}
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": rendered}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -1,28 +0,0 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.variables.variables import Variable
|
||||
from extensions.ext_database import db
|
||||
from models import ConversationVariable
|
||||
|
||||
from .exc import VariableOperatorNodeError
|
||||
|
||||
|
||||
class ConversationVariableUpdaterImpl:
|
||||
def update(self, conversation_id: str, variable: Variable):
|
||||
stmt = select(ConversationVariable).where(
|
||||
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
row = session.scalar(stmt)
|
||||
if not row:
|
||||
raise VariableOperatorNodeError("conversation variable not found in the database")
|
||||
row.data = variable.model_dump_json()
|
||||
session.commit()
|
||||
|
||||
def flush(self):
|
||||
pass
|
||||
|
||||
|
||||
def conversation_variable_updater_factory() -> ConversationVariableUpdaterImpl:
|
||||
return ConversationVariableUpdaterImpl()
|
||||
@ -1,9 +1,8 @@
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, TypeAlias
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
@ -11,19 +10,14 @@ from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||
|
||||
from ..common.impl import conversation_variable_updater_factory
|
||||
from .node_data import VariableAssignerData, WriteMode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
|
||||
|
||||
|
||||
class VariableAssignerNode(Node[VariableAssignerData]):
|
||||
node_type = NodeType.VARIABLE_ASSIGNER
|
||||
_conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -31,7 +25,6 @@ class VariableAssignerNode(Node[VariableAssignerData]):
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory,
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
@ -39,7 +32,6 @@ class VariableAssignerNode(Node[VariableAssignerData]):
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self._conv_var_updater_factory = conv_var_updater_factory
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
@ -96,16 +88,7 @@ class VariableAssignerNode(Node[VariableAssignerData]):
|
||||
# Over write the variable.
|
||||
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)
|
||||
|
||||
# TODO: Move database operation to the pipeline.
|
||||
# Update conversation variable.
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
|
||||
if not conversation_id:
|
||||
raise VariableOperatorNodeError("conversation_id not found")
|
||||
conv_var_updater = self._conv_var_updater_factory()
|
||||
conv_var_updater.update(conversation_id=conversation_id.text, variable=updated_variable)
|
||||
conv_var_updater.flush()
|
||||
updated_variables = [common_helpers.variable_to_processed_data(assigned_variable_selector, updated_variable)]
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={
|
||||
|
||||
@ -1,24 +1,20 @@
|
||||
import json
|
||||
from collections.abc import Mapping, MutableMapping, Sequence
|
||||
from typing import Any, cast
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.variables.consts import SELECTORS_LENGTH
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
|
||||
|
||||
from . import helpers
|
||||
from .entities import VariableAssignerNodeData, VariableOperationItem
|
||||
from .enums import InputType, Operation
|
||||
from .exc import (
|
||||
ConversationIDNotFoundError,
|
||||
InputTypeNotSupportedError,
|
||||
InvalidDataError,
|
||||
InvalidInputValueError,
|
||||
@ -26,6 +22,10 @@ from .exc import (
|
||||
VariableNotFoundError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem):
|
||||
selector_node_id = item.variable_selector[0]
|
||||
@ -53,6 +53,20 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
|
||||
class VariableAssignerNode(Node[VariableAssignerNodeData]):
|
||||
node_type = NodeType.VARIABLE_ASSIGNER
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
|
||||
"""
|
||||
Check if this Variable Assigner node blocks the output of specific variables.
|
||||
@ -70,9 +84,6 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
|
||||
|
||||
return False
|
||||
|
||||
def _conv_var_updater_factory(self) -> ConversationVariableUpdater:
|
||||
return conversation_variable_updater_factory()
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "2"
|
||||
@ -179,26 +190,12 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
|
||||
# remove the duplicated items first.
|
||||
updated_variable_selectors = list(set(map(tuple, updated_variable_selectors)))
|
||||
|
||||
conv_var_updater = self._conv_var_updater_factory()
|
||||
# Update variables
|
||||
for selector in updated_variable_selectors:
|
||||
variable = self.graph_runtime_state.variable_pool.get(selector)
|
||||
if not isinstance(variable, Variable):
|
||||
raise VariableNotFoundError(variable_selector=selector)
|
||||
process_data[variable.name] = variable.value
|
||||
|
||||
if variable.selector[0] == CONVERSATION_VARIABLE_NODE_ID:
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
|
||||
if not conversation_id:
|
||||
if self.invoke_from != InvokeFrom.DEBUGGER:
|
||||
raise ConversationIDNotFoundError
|
||||
else:
|
||||
conversation_id = conversation_id.value
|
||||
conv_var_updater.update(
|
||||
conversation_id=cast(str, conversation_id),
|
||||
variable=variable,
|
||||
)
|
||||
conv_var_updater.flush()
|
||||
updated_variables = [
|
||||
common_helpers.variable_to_processed_data(selector, seg)
|
||||
for selector in updated_variable_selectors
|
||||
|
||||
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import threading
|
||||
from collections.abc import Mapping, Sequence
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
@ -168,6 +169,7 @@ class GraphRuntimeState:
|
||||
self._pending_response_coordinator_dump: str | None = None
|
||||
self._pending_graph_execution_workflow_id: str | None = None
|
||||
self._paused_nodes: set[str] = set()
|
||||
self.stop_event: threading.Event = threading.Event()
|
||||
|
||||
if graph is not None:
|
||||
self.attach_graph(graph)
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Protocol
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
@ -9,7 +9,7 @@ from core.workflow.system_variable import SystemVariableReadOnlyView
|
||||
class ReadOnlyVariablePool(Protocol):
|
||||
"""Read-only interface for VariablePool."""
|
||||
|
||||
def get(self, node_id: str, variable_key: str) -> Segment | None:
|
||||
def get(self, selector: Sequence[str], /) -> Segment | None:
|
||||
"""Get a variable value (read-only)."""
|
||||
...
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Mapping, Sequence
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
@ -18,9 +18,9 @@ class ReadOnlyVariablePoolWrapper:
|
||||
def __init__(self, variable_pool: VariablePool) -> None:
|
||||
self._variable_pool = variable_pool
|
||||
|
||||
def get(self, node_id: str, variable_key: str) -> Segment | None:
|
||||
def get(self, selector: Sequence[str], /) -> Segment | None:
|
||||
"""Return a copy of a variable value if present."""
|
||||
value = self._variable_pool.get([node_id, variable_key])
|
||||
value = self._variable_pool.get(selector)
|
||||
return deepcopy(value) if value is not None else None
|
||||
|
||||
def get_all_by_node(self, node_id: str) -> Mapping[str, object]:
|
||||
|
||||
Reference in New Issue
Block a user