Merge branch 'main' into feat/agent-node-v2

This commit is contained in:
Novice
2026-01-07 17:34:23 +08:00
802 changed files with 41190 additions and 6172 deletions

View File

@ -20,6 +20,8 @@ from core.app.entities.queue_entities import (
QueueTextChunkEvent,
)
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer
from core.db.session_factory import session_factory
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.variables.variables import VariableUnion
@ -40,6 +42,7 @@ from models import Workflow
from models.enums import UserFrom
from models.model import App, Conversation, Message, MessageAnnotation
from models.workflow import ConversationVariable
from services.conversation_variable_updater import ConversationVariableUpdater
logger = logging.getLogger(__name__)
@ -200,6 +203,10 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
)
workflow_entry.graph_engine.layer(persistence_layer)
conversation_variable_layer = ConversationVariablePersistenceLayer(
ConversationVariableUpdater(session_factory.get_session_maker())
)
workflow_entry.graph_engine.layer(conversation_variable_layer)
for layer in self._graph_engine_layers:
workflow_entry.graph_engine.layer(layer)

View File

@ -471,6 +471,25 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if node_finish_resp:
yield node_finish_resp
# For ANSWER nodes, check if we need to send a message_replace event
# Only send if the final output differs from the accumulated task_state.answer
# This happens when variables were updated by variable_assigner during workflow execution
if event.node_type == NodeType.ANSWER and event.outputs:
final_answer = event.outputs.get("answer")
if final_answer is not None and final_answer != self._task_state.answer:
logger.info(
"ANSWER node final output '%s' differs from accumulated answer '%s', sending message_replace event",
final_answer,
self._task_state.answer,
)
# Update the task state answer
self._task_state.answer = str(final_answer)
# Send message_replace event to update the UI
yield self._message_cycle_manager.message_replace_to_stream_response(
answer=str(final_answer),
reason="variable_update",
)
def _handle_node_failed_events(
self,
event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],

View File

@ -130,7 +130,7 @@ class PipelineGenerator(BaseAppGenerator):
pipeline=pipeline, workflow=workflow, start_node_id=start_node_id
)
documents: list[Document] = []
if invoke_from == InvokeFrom.PUBLISHED and not is_retry and not args.get("original_document_id"):
if invoke_from == InvokeFrom.PUBLISHED_PIPELINE and not is_retry and not args.get("original_document_id"):
from services.dataset_service import DocumentService
for datasource_info in datasource_info_list:
@ -156,7 +156,7 @@ class PipelineGenerator(BaseAppGenerator):
for i, datasource_info in enumerate(datasource_info_list):
workflow_run_id = str(uuid.uuid4())
document_id = args.get("original_document_id") or None
if invoke_from == InvokeFrom.PUBLISHED and not is_retry:
if invoke_from == InvokeFrom.PUBLISHED_PIPELINE and not is_retry:
document_id = document_id or documents[i].id
document_pipeline_execution_log = DocumentPipelineExecutionLog(
document_id=document_id,

View File

@ -42,7 +42,8 @@ class InvokeFrom(StrEnum):
# DEBUGGER indicates that this invocation is from
# the workflow (or chatflow) edit page.
DEBUGGER = "debugger"
PUBLISHED = "published"
# PUBLISHED_PIPELINE indicates that this invocation runs a published RAG pipeline workflow.
PUBLISHED_PIPELINE = "published"
# VALIDATION indicates that this invocation is from validation.
VALIDATION = "validation"

View File

@ -75,7 +75,7 @@ class AnnotationReplyFeature:
AppAnnotationService.add_annotation_history(
annotation.id,
app_record.id,
annotation.question,
annotation.question_text,
annotation.content,
query,
user_id,

View File

@ -0,0 +1,60 @@
import logging
from core.variables import Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.enums import NodeType
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import GraphEngineEvent, NodeRunSucceededEvent
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
logger = logging.getLogger(__name__)
class ConversationVariablePersistenceLayer(GraphEngineLayer):
def __init__(self, conversation_variable_updater: ConversationVariableUpdater) -> None:
super().__init__()
self._conversation_variable_updater = conversation_variable_updater
def on_graph_start(self) -> None:
pass
def on_event(self, event: GraphEngineEvent) -> None:
if not isinstance(event, NodeRunSucceededEvent):
return
if event.node_type != NodeType.VARIABLE_ASSIGNER:
return
if self.graph_runtime_state is None:
return
updated_variables = common_helpers.get_updated_variables(event.node_run_result.process_data) or []
if not updated_variables:
return
conversation_id = self.graph_runtime_state.system_variable.conversation_id
if conversation_id is None:
return
updated_any = False
for item in updated_variables:
selector = item.selector
if len(selector) < 2:
logger.warning("Conversation variable selector invalid. selector=%s", selector)
continue
if selector[0] != CONVERSATION_VARIABLE_NODE_ID:
continue
variable = self.graph_runtime_state.variable_pool.get(selector)
if not isinstance(variable, Variable):
logger.warning(
"Conversation variable not found in variable pool. selector=%s",
selector,
)
continue
self._conversation_variable_updater.update(conversation_id=conversation_id, variable=variable)
updated_any = True
if updated_any:
self._conversation_variable_updater.flush()
def on_graph_end(self, error: Exception | None) -> None:
pass

View File

@ -66,6 +66,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
"""
if isinstance(session_factory, Engine):
session_factory = sessionmaker(session_factory)
super().__init__()
self._session_maker = session_factory
self._state_owner_user_id = state_owner_user_id
self._generate_entity = generate_entity
@ -98,8 +99,6 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
if not isinstance(event, GraphRunPausedEvent):
return
assert self.graph_runtime_state is not None
entity_wrapper: _GenerateEntityUnion
if isinstance(self._generate_entity, WorkflowAppGenerateEntity):
entity_wrapper = _WorkflowGenerateEntityWrapper(entity=self._generate_entity)

View File

@ -33,6 +33,7 @@ class TriggerPostLayer(GraphEngineLayer):
trigger_log_id: str,
session_maker: sessionmaker[Session],
):
super().__init__()
self.trigger_log_id = trigger_log_id
self.start_time = start_time
self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity
@ -57,10 +58,6 @@ class TriggerPostLayer(GraphEngineLayer):
elapsed_time = (datetime.now(UTC) - self.start_time).total_seconds()
# Extract relevant data from result
if not self.graph_runtime_state:
logger.exception("Graph runtime state is not set")
return
outputs = self.graph_runtime_state.outputs
# BASICLY, workflow_execution_id is the same as workflow_run_id

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from configs import dify_config
@ -30,7 +32,7 @@ class DatasourcePlugin(ABC):
"""
return DatasourceProviderType.LOCAL_FILE
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> DatasourcePlugin:
return self.__class__(
entity=self.entity.model_copy(),
runtime=runtime,

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import enum
from enum import StrEnum
from typing import Any
@ -31,7 +33,7 @@ class DatasourceProviderType(enum.StrEnum):
ONLINE_DRIVE = "online_drive"
@classmethod
def value_of(cls, value: str) -> "DatasourceProviderType":
def value_of(cls, value: str) -> DatasourceProviderType:
"""
Get value of given mode.
@ -81,7 +83,7 @@ class DatasourceParameter(PluginParameter):
typ: DatasourceParameterType,
required: bool,
options: list[str] | None = None,
) -> "DatasourceParameter":
) -> DatasourceParameter:
"""
get a simple datasource parameter
@ -187,14 +189,14 @@ class DatasourceInvokeMeta(BaseModel):
tool_config: dict | None = None
@classmethod
def empty(cls) -> "DatasourceInvokeMeta":
def empty(cls) -> DatasourceInvokeMeta:
"""
Get an empty instance of DatasourceInvokeMeta
"""
return cls(time_cost=0.0, error=None, tool_config={})
@classmethod
def error_instance(cls, error: str) -> "DatasourceInvokeMeta":
def error_instance(cls, error: str) -> DatasourceInvokeMeta:
"""
Get an instance of DatasourceInvokeMeta with error
"""

View File

@ -1,7 +1,7 @@
from sqlalchemy import Engine
from sqlalchemy.orm import Session, sessionmaker
_session_maker: sessionmaker | None = None
_session_maker: sessionmaker[Session] | None = None
def configure_session_factory(engine: Engine, expire_on_commit: bool = False):
@ -10,7 +10,7 @@ def configure_session_factory(engine: Engine, expire_on_commit: bool = False):
_session_maker = sessionmaker(bind=engine, expire_on_commit=expire_on_commit)
def get_session_maker() -> sessionmaker:
def get_session_maker() -> sessionmaker[Session]:
if _session_maker is None:
raise RuntimeError("Session factory not configured. Call configure_session_factory() first.")
return _session_maker
@ -27,7 +27,7 @@ class SessionFactory:
configure_session_factory(engine, expire_on_commit)
@staticmethod
def get_session_maker() -> sessionmaker:
def get_session_maker() -> sessionmaker[Session]:
return get_session_maker()
@staticmethod

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import json
from datetime import datetime
from enum import StrEnum
@ -75,7 +77,7 @@ class MCPProviderEntity(BaseModel):
updated_at: datetime
@classmethod
def from_db_model(cls, db_provider: "MCPToolProvider") -> "MCPProviderEntity":
def from_db_model(cls, db_provider: MCPToolProvider) -> MCPProviderEntity:
"""Create entity from database model with decryption"""
return cls(

View File

@ -30,7 +30,6 @@ class SimpleModelProviderEntity(BaseModel):
label: I18nObject
icon_small: I18nObject | None = None
icon_small_dark: I18nObject | None = None
icon_large: I18nObject | None = None
supported_model_types: list[ModelType]
def __init__(self, provider_entity: ProviderEntity):
@ -44,7 +43,6 @@ class SimpleModelProviderEntity(BaseModel):
label=provider_entity.label,
icon_small=provider_entity.icon_small,
icon_small_dark=provider_entity.icon_small_dark,
icon_large=provider_entity.icon_large,
supported_model_types=provider_entity.supported_model_types,
)
@ -94,7 +92,6 @@ class DefaultModelProviderEntity(BaseModel):
provider: str
label: I18nObject
icon_small: I18nObject | None = None
icon_large: I18nObject | None = None
supported_model_types: Sequence[ModelType] = []

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from enum import StrEnum, auto
from typing import Union
@ -178,7 +180,7 @@ class BasicProviderConfig(BaseModel):
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR
@classmethod
def value_of(cls, value: str) -> "ProviderConfig.Type":
def value_of(cls, value: str) -> ProviderConfig.Type:
"""
Get value of given mode.

View File

@ -8,8 +8,9 @@ import urllib.parse
from configs import dify_config
def get_signed_file_url(upload_file_id: str, as_attachment=False) -> str:
url = f"{dify_config.FILES_URL}/files/{upload_file_id}/file-preview"
def get_signed_file_url(upload_file_id: str, as_attachment=False, for_external: bool = True) -> str:
base_url = dify_config.FILES_URL if for_external else (dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL)
url = f"{base_url}/files/{upload_file_id}/file-preview"
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()

View File

@ -112,17 +112,17 @@ class File(BaseModel):
return text
def generate_url(self) -> str | None:
def generate_url(self, for_external: bool = True) -> str | None:
if self.transfer_method == FileTransferMethod.REMOTE_URL:
return self.remote_url
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
if self.related_id is None:
raise ValueError("Missing file related_id")
return helpers.get_signed_file_url(upload_file_id=self.related_id)
return helpers.get_signed_file_url(upload_file_id=self.related_id, for_external=for_external)
elif self.transfer_method in [FileTransferMethod.TOOL_FILE, FileTransferMethod.DATASOURCE_FILE]:
assert self.related_id is not None
assert self.extension is not None
return sign_tool_file(tool_file_id=self.related_id, extension=self.extension)
return sign_tool_file(tool_file_id=self.related_id, extension=self.extension, for_external=for_external)
return None
def to_plugin_parameter(self) -> dict[str, Any]:
@ -133,7 +133,7 @@ class File(BaseModel):
"extension": self.extension,
"size": self.size,
"type": self.type,
"url": self.generate_url(),
"url": self.generate_url(for_external=False),
}
@model_validator(mode="after")

View File

@ -76,7 +76,7 @@ class TemplateTransformer(ABC):
Post-process the result to convert scientific notation strings back to numbers
"""
def convert_scientific_notation(value):
def convert_scientific_notation(value: Any) -> Any:
if isinstance(value, str):
# Check if the string looks like scientific notation
if re.match(r"^-?\d+\.?\d*e[+-]\d+$", value, re.IGNORECASE):
@ -90,7 +90,7 @@ class TemplateTransformer(ABC):
return [convert_scientific_notation(v) for v in value]
return value
return convert_scientific_notation(result) # type: ignore[no-any-return]
return convert_scientific_notation(result)
@classmethod
@abstractmethod

View File

@ -88,7 +88,41 @@ def _get_user_provided_host_header(headers: dict | None) -> str | None:
return None
def _inject_trace_headers(headers: dict | None) -> dict:
"""
Inject W3C traceparent header for distributed tracing.
When OTEL is enabled, HTTPXClientInstrumentor handles trace propagation automatically.
When OTEL is disabled, we manually inject the traceparent header.
"""
if headers is None:
headers = {}
# Skip if already present (case-insensitive check)
for key in headers:
if key.lower() == "traceparent":
return headers
# Skip if OTEL is enabled - HTTPXClientInstrumentor handles this automatically
if dify_config.ENABLE_OTEL:
return headers
# Generate and inject traceparent for non-OTEL scenarios
try:
from core.helper.trace_id_helper import generate_traceparent_header
traceparent = generate_traceparent_header()
if traceparent:
headers["traceparent"] = traceparent
except Exception:
# Silently ignore errors to avoid breaking requests
logger.debug("Failed to generate traceparent header", exc_info=True)
return headers
def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
# Convert requests-style allow_redirects to httpx-style follow_redirects
if "allow_redirects" in kwargs:
allow_redirects = kwargs.pop("allow_redirects")
if "follow_redirects" not in kwargs:
@ -106,18 +140,21 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
client = _get_ssrf_client(verify_option)
# Inject traceparent header for distributed tracing (when OTEL is not enabled)
headers = kwargs.get("headers") or {}
headers = _inject_trace_headers(headers)
kwargs["headers"] = headers
# Preserve user-provided Host header
# When using a forward proxy, httpx may override the Host header based on the URL.
# We extract and preserve any explicitly set Host header to support virtual hosting.
headers = kwargs.get("headers", {})
user_provided_host = _get_user_provided_host_header(headers)
retries = 0
while retries <= max_retries:
try:
# Build the request manually to preserve the Host header
# httpx may override the Host header when using a proxy, so we use
# the request API to explicitly set headers before sending
# Preserve the user-provided Host header
# httpx may override the Host header when using a proxy
headers = {k: v for k, v in headers.items() if k.lower() != "host"}
if user_provided_host is not None:
headers["host"] = user_provided_host

View File

@ -103,3 +103,60 @@ def parse_traceparent_header(traceparent: str) -> str | None:
if len(parts) == 4 and len(parts[1]) == 32:
return parts[1]
return None
def get_span_id_from_otel_context() -> str | None:
"""
Retrieve the current span ID from the active OpenTelemetry trace context.
Returns:
A 16-character hex string representing the span ID, or None if not available.
"""
try:
from opentelemetry.trace import get_current_span
from opentelemetry.trace.span import INVALID_SPAN_ID
span = get_current_span()
if not span:
return None
span_context = span.get_span_context()
if not span_context or span_context.span_id == INVALID_SPAN_ID:
return None
return f"{span_context.span_id:016x}"
except Exception:
return None
def generate_traceparent_header() -> str | None:
"""
Generate a W3C traceparent header from the current context.
Uses OpenTelemetry context if available, otherwise uses the
ContextVar-based trace_id from the logging context.
Format: {version}-{trace_id}-{span_id}-{flags}
Example: 00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01
Returns:
A valid traceparent header string, or None if generation fails.
"""
import uuid
# Try OTEL context first
trace_id = get_trace_id_from_otel_context()
span_id = get_span_id_from_otel_context()
if trace_id and span_id:
return f"00-{trace_id}-{span_id}-01"
# Fallback: use ContextVar-based trace_id or generate new one
from core.logging.context import get_trace_id as get_logging_trace_id
trace_id = get_logging_trace_id() or uuid.uuid4().hex
# Generate a new span_id (16 hex chars)
span_id = uuid.uuid4().hex[:16]
return f"00-{trace_id}-{span_id}-01"

View File

@ -0,0 +1,20 @@
"""Structured logging components for Dify."""
from core.logging.context import (
clear_request_context,
get_request_id,
get_trace_id,
init_request_context,
)
from core.logging.filters import IdentityContextFilter, TraceContextFilter
from core.logging.structured_formatter import StructuredJSONFormatter
__all__ = [
"IdentityContextFilter",
"StructuredJSONFormatter",
"TraceContextFilter",
"clear_request_context",
"get_request_id",
"get_trace_id",
"init_request_context",
]

View File

@ -0,0 +1,35 @@
"""Request context for logging - framework agnostic.
This module provides request-scoped context variables for logging,
using Python's contextvars for thread-safe and async-safe storage.
"""
import uuid
from contextvars import ContextVar
_request_id: ContextVar[str] = ContextVar("log_request_id", default="")
_trace_id: ContextVar[str] = ContextVar("log_trace_id", default="")
def get_request_id() -> str:
"""Get current request ID (10 hex chars)."""
return _request_id.get()
def get_trace_id() -> str:
"""Get fallback trace ID when OTEL is unavailable (32 hex chars)."""
return _trace_id.get()
def init_request_context() -> None:
"""Initialize request context. Call at start of each request."""
req_id = uuid.uuid4().hex[:10]
trace_id = uuid.uuid5(uuid.NAMESPACE_DNS, req_id).hex
_request_id.set(req_id)
_trace_id.set(trace_id)
def clear_request_context() -> None:
"""Clear request context. Call at end of request (optional)."""
_request_id.set("")
_trace_id.set("")

View File

@ -0,0 +1,94 @@
"""Logging filters for structured logging."""
import contextlib
import logging
import flask
from core.logging.context import get_request_id, get_trace_id
class TraceContextFilter(logging.Filter):
"""
Filter that adds trace_id and span_id to log records.
Integrates with OpenTelemetry when available, falls back to ContextVar-based trace_id.
"""
def filter(self, record: logging.LogRecord) -> bool:
# Get trace context from OpenTelemetry
trace_id, span_id = self._get_otel_context()
# Set trace_id (fallback to ContextVar if no OTEL context)
if trace_id:
record.trace_id = trace_id
else:
record.trace_id = get_trace_id()
record.span_id = span_id or ""
# For backward compatibility, also set req_id
record.req_id = get_request_id()
return True
def _get_otel_context(self) -> tuple[str, str]:
"""Extract trace_id and span_id from OpenTelemetry context."""
with contextlib.suppress(Exception):
from opentelemetry.trace import get_current_span
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
span = get_current_span()
if span and span.get_span_context():
ctx = span.get_span_context()
if ctx.is_valid and ctx.trace_id != INVALID_TRACE_ID:
trace_id = f"{ctx.trace_id:032x}"
span_id = f"{ctx.span_id:016x}" if ctx.span_id != INVALID_SPAN_ID else ""
return trace_id, span_id
return "", ""
class IdentityContextFilter(logging.Filter):
"""
Filter that adds user identity context to log records.
Extracts tenant_id, user_id, and user_type from Flask-Login current_user.
"""
def filter(self, record: logging.LogRecord) -> bool:
identity = self._extract_identity()
record.tenant_id = identity.get("tenant_id", "")
record.user_id = identity.get("user_id", "")
record.user_type = identity.get("user_type", "")
return True
def _extract_identity(self) -> dict[str, str]:
"""Extract identity from current_user if in request context."""
try:
if not flask.has_request_context():
return {}
from flask_login import current_user
# Check if user is authenticated using the proxy
if not current_user.is_authenticated:
return {}
# Access the underlying user object
user = current_user
from models import Account
from models.model import EndUser
identity: dict[str, str] = {}
if isinstance(user, Account):
if user.current_tenant_id:
identity["tenant_id"] = user.current_tenant_id
identity["user_id"] = user.id
identity["user_type"] = "account"
elif isinstance(user, EndUser):
identity["tenant_id"] = user.tenant_id
identity["user_id"] = user.id
identity["user_type"] = user.type or "end_user"
return identity
except Exception:
return {}

View File

@ -0,0 +1,107 @@
"""Structured JSON log formatter for Dify."""
import logging
import traceback
from datetime import UTC, datetime
from typing import Any
import orjson
from configs import dify_config
class StructuredJSONFormatter(logging.Formatter):
"""
JSON log formatter following the specified schema:
{
"ts": "ISO 8601 UTC",
"severity": "INFO|ERROR|WARN|DEBUG",
"service": "service name",
"caller": "file:line",
"trace_id": "hex 32",
"span_id": "hex 16",
"identity": { "tenant_id", "user_id", "user_type" },
"message": "log message",
"attributes": { ... },
"stack_trace": "..."
}
"""
SEVERITY_MAP: dict[int, str] = {
logging.DEBUG: "DEBUG",
logging.INFO: "INFO",
logging.WARNING: "WARN",
logging.ERROR: "ERROR",
logging.CRITICAL: "ERROR",
}
def __init__(self, service_name: str | None = None):
super().__init__()
self._service_name = service_name or dify_config.APPLICATION_NAME
def format(self, record: logging.LogRecord) -> str:
log_dict = self._build_log_dict(record)
try:
return orjson.dumps(log_dict).decode("utf-8")
except TypeError:
# Fallback: convert non-serializable objects to string
import json
return json.dumps(log_dict, default=str, ensure_ascii=False)
def _build_log_dict(self, record: logging.LogRecord) -> dict[str, Any]:
# Core fields
log_dict: dict[str, Any] = {
"ts": datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z"),
"severity": self.SEVERITY_MAP.get(record.levelno, "INFO"),
"service": self._service_name,
"caller": f"{record.filename}:{record.lineno}",
"message": record.getMessage(),
}
# Trace context (from TraceContextFilter)
trace_id = getattr(record, "trace_id", "")
span_id = getattr(record, "span_id", "")
if trace_id:
log_dict["trace_id"] = trace_id
if span_id:
log_dict["span_id"] = span_id
# Identity context (from IdentityContextFilter)
identity = self._extract_identity(record)
if identity:
log_dict["identity"] = identity
# Dynamic attributes
attributes = getattr(record, "attributes", None)
if attributes:
log_dict["attributes"] = attributes
# Stack trace for errors with exceptions
if record.exc_info and record.levelno >= logging.ERROR:
log_dict["stack_trace"] = self._format_exception(record.exc_info)
return log_dict
def _extract_identity(self, record: logging.LogRecord) -> dict[str, str] | None:
tenant_id = getattr(record, "tenant_id", None)
user_id = getattr(record, "user_id", None)
user_type = getattr(record, "user_type", None)
if not any([tenant_id, user_id, user_type]):
return None
identity: dict[str, str] = {}
if tenant_id:
identity["tenant_id"] = tenant_id
if user_id:
identity["user_id"] = user_id
if user_type:
identity["user_type"] = user_type
return identity
def _format_exception(self, exc_info: tuple[Any, ...]) -> str:
if exc_info and exc_info[0] is not None:
return "".join(traceback.format_exception(*exc_info))
return ""

View File

@ -68,13 +68,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
request_id: RequestId,
request_meta: RequestParams.Meta | None,
request: ReceiveRequestT,
session: """BaseSession[
SendRequestT,
SendNotificationT,
SendResultT,
ReceiveRequestT,
ReceiveNotificationT
]""",
session: """BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]""",
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
):
self.request_id = request_id

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from abc import ABC
from collections.abc import Mapping, Sequence
from enum import StrEnum, auto
@ -17,7 +19,7 @@ class PromptMessageRole(StrEnum):
TOOL = auto()
@classmethod
def value_of(cls, value: str) -> "PromptMessageRole":
def value_of(cls, value: str) -> PromptMessageRole:
"""
Get value of given mode.

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from decimal import Decimal
from enum import StrEnum, auto
from typing import Any
@ -20,7 +22,7 @@ class ModelType(StrEnum):
TTS = auto()
@classmethod
def value_of(cls, origin_model_type: str) -> "ModelType":
def value_of(cls, origin_model_type: str) -> ModelType:
"""
Get model type from origin model type.
@ -103,7 +105,7 @@ class DefaultParameterName(StrEnum):
JSON_SCHEMA = auto()
@classmethod
def value_of(cls, value: Any) -> "DefaultParameterName":
def value_of(cls, value: Any) -> DefaultParameterName:
"""
Get parameter name from value.

View File

@ -100,7 +100,6 @@ class SimpleProviderEntity(BaseModel):
label: I18nObject
icon_small: I18nObject | None = None
icon_small_dark: I18nObject | None = None
icon_large: I18nObject | None = None
supported_model_types: Sequence[ModelType]
models: list[AIModelEntity] = []
@ -123,7 +122,6 @@ class ProviderEntity(BaseModel):
label: I18nObject
description: I18nObject | None = None
icon_small: I18nObject | None = None
icon_large: I18nObject | None = None
icon_small_dark: I18nObject | None = None
background: str | None = None
help: ProviderHelpEntity | None = None
@ -157,7 +155,6 @@ class ProviderEntity(BaseModel):
provider=self.provider,
label=self.label,
icon_small=self.icon_small,
icon_large=self.icon_large,
supported_model_types=self.supported_model_types,
models=self.models,
)

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import hashlib
import logging
from collections.abc import Sequence
@ -38,7 +40,7 @@ class ModelProviderFactory:
plugin_providers = self.get_plugin_model_providers()
return [provider.declaration for provider in plugin_providers]
def get_plugin_model_providers(self) -> Sequence["PluginModelProviderEntity"]:
def get_plugin_model_providers(self) -> Sequence[PluginModelProviderEntity]:
"""
Get all plugin model providers
:return: list of plugin model providers
@ -76,7 +78,7 @@ class ModelProviderFactory:
plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
return plugin_model_provider_entity.declaration
def get_plugin_model_provider(self, provider: str) -> "PluginModelProviderEntity":
def get_plugin_model_provider(self, provider: str) -> PluginModelProviderEntity:
"""
Get plugin model provider
:param provider: provider name
@ -285,7 +287,7 @@ class ModelProviderFactory:
"""
Get provider icon
:param provider: provider name
:param icon_type: icon type (icon_small or icon_large)
:param icon_type: icon type (icon_small or icon_small_dark)
:param lang: language (zh_Hans or en_US)
:return: provider icon
"""
@ -309,13 +311,7 @@ class ModelProviderFactory:
else:
file_name = provider_schema.icon_small_dark.en_US
else:
if not provider_schema.icon_large:
raise ValueError(f"Provider {provider} does not have large icon.")
if lang.lower() == "zh_hans":
file_name = provider_schema.icon_large.zh_Hans
else:
file_name = provider_schema.icon_large.en_US
raise ValueError(f"Unsupported icon type: {icon_type}.")
if not file_name:
raise ValueError(f"Provider {provider} does not have icon.")

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import enum
from collections.abc import Mapping, Sequence
from datetime import datetime
@ -242,7 +244,7 @@ class CredentialType(enum.StrEnum):
return [item.value for item in cls]
@classmethod
def of(cls, credential_type: str) -> "CredentialType":
def of(cls, credential_type: str) -> CredentialType:
type_name = credential_type.lower()
if type_name in {"api-key", "api_key"}:
return cls.API_KEY

View File

@ -103,6 +103,9 @@ class BasePluginClient:
prepared_headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY
prepared_headers.setdefault("Accept-Encoding", "gzip, deflate, br")
# Inject traceparent header for distributed tracing
self._inject_trace_headers(prepared_headers)
prepared_data: bytes | dict[str, Any] | str | None = (
data if isinstance(data, (bytes, str, dict)) or data is None else None
)
@ -114,6 +117,31 @@ class BasePluginClient:
return str(url), prepared_headers, prepared_data, params, files
def _inject_trace_headers(self, headers: dict[str, str]) -> None:
"""
Inject W3C traceparent header for distributed tracing.
This ensures trace context is propagated to plugin daemon even if
HTTPXClientInstrumentor doesn't cover module-level httpx functions.
"""
if not dify_config.ENABLE_OTEL:
return
import contextlib
# Skip if already present (case-insensitive check)
for key in headers:
if key.lower() == "traceparent":
return
# Inject traceparent - works as fallback when OTEL instrumentation doesn't cover this call
with contextlib.suppress(Exception):
from core.helper.trace_id_helper import generate_traceparent_header
traceparent = generate_traceparent_header()
if traceparent:
headers["traceparent"] = traceparent
def _stream_request(
self,
method: str,

View File

@ -331,7 +331,6 @@ class ProviderManager:
provider=provider_schema.provider,
label=provider_schema.label,
icon_small=provider_schema.icon_small,
icon_large=provider_schema.icon_large,
supported_model_types=provider_schema.supported_model_types,
),
)

View File

@ -27,26 +27,44 @@ class CleanProcessor:
pattern = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)"
text = re.sub(pattern, "", text)
# Remove URL but keep Markdown image URLs
# First, temporarily replace Markdown image URLs with a placeholder
markdown_image_pattern = r"!\[.*?\]\((https?://[^\s)]+)\)"
placeholders: list[str] = []
# Remove URL but keep Markdown image URLs and link URLs
# Replace the ENTIRE markdown link/image with a single placeholder to protect
# the link text (which might also be a URL) from being removed
markdown_link_pattern = r"\[([^\]]*)\]\((https?://[^)]+)\)"
markdown_image_pattern = r"!\[.*?\]\((https?://[^)]+)\)"
placeholders: list[tuple[str, str, str]] = [] # (type, text, url)
def replace_with_placeholder(match, placeholders=placeholders):
def replace_markdown_with_placeholder(match, placeholders=placeholders):
link_type = "link"
link_text = match.group(1)
url = match.group(2)
placeholder = f"__MARKDOWN_PLACEHOLDER_{len(placeholders)}__"
placeholders.append((link_type, link_text, url))
return placeholder
def replace_image_with_placeholder(match, placeholders=placeholders):
link_type = "image"
url = match.group(1)
placeholder = f"__MARKDOWN_IMAGE_URL_{len(placeholders)}__"
placeholders.append(url)
return f"![image]({placeholder})"
placeholder = f"__MARKDOWN_PLACEHOLDER_{len(placeholders)}__"
placeholders.append((link_type, "image", url))
return placeholder
text = re.sub(markdown_image_pattern, replace_with_placeholder, text)
# Protect markdown links first
text = re.sub(markdown_link_pattern, replace_markdown_with_placeholder, text)
# Then protect markdown images
text = re.sub(markdown_image_pattern, replace_image_with_placeholder, text)
# Now remove all remaining URLs
url_pattern = r"https?://[^\s)]+"
url_pattern = r"https?://\S+"
text = re.sub(url_pattern, "", text)
# Finally, restore the Markdown image URLs
for i, url in enumerate(placeholders):
text = text.replace(f"__MARKDOWN_IMAGE_URL_{i}__", url)
# Restore the Markdown links and images
for i, (link_type, text_or_alt, url) in enumerate(placeholders):
placeholder = f"__MARKDOWN_PLACEHOLDER_{i}__"
if link_type == "link":
text = text.replace(placeholder, f"[{text_or_alt}]({url})")
else: # image
text = text.replace(placeholder, f"![{text_or_alt}]({url})")
return text
def filter_string(self, text):

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import contextlib
import json
import logging
@ -6,7 +8,7 @@ import re
import threading
import time
import uuid
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
import clickzetta # type: ignore
from pydantic import BaseModel, model_validator
@ -76,7 +78,7 @@ class ClickzettaConnectionPool:
Manages connection reuse across ClickzettaVector instances.
"""
_instance: Optional["ClickzettaConnectionPool"] = None
_instance: ClickzettaConnectionPool | None = None
_lock = threading.Lock()
def __init__(self):
@ -89,7 +91,7 @@ class ClickzettaConnectionPool:
self._start_cleanup_thread()
@classmethod
def get_instance(cls) -> "ClickzettaConnectionPool":
def get_instance(cls) -> ClickzettaConnectionPool:
"""Get singleton instance of connection pool."""
if cls._instance is None:
with cls._lock:
@ -104,7 +106,7 @@ class ClickzettaConnectionPool:
f"{config.workspace}:{config.vcluster}:{config.schema_name}"
)
def _create_connection(self, config: ClickzettaConfig) -> "Connection":
def _create_connection(self, config: ClickzettaConfig) -> Connection:
"""Create a new ClickZetta connection."""
max_retries = 3
retry_delay = 1.0
@ -134,7 +136,7 @@ class ClickzettaConnectionPool:
raise RuntimeError(f"Failed to create ClickZetta connection after {max_retries} attempts")
def _configure_connection(self, connection: "Connection"):
def _configure_connection(self, connection: Connection):
"""Configure connection session settings."""
try:
with connection.cursor() as cursor:
@ -181,7 +183,7 @@ class ClickzettaConnectionPool:
except Exception:
logger.exception("Failed to configure connection, continuing with defaults")
def _is_connection_valid(self, connection: "Connection") -> bool:
def _is_connection_valid(self, connection: Connection) -> bool:
"""Check if connection is still valid."""
try:
with connection.cursor() as cursor:
@ -190,7 +192,7 @@ class ClickzettaConnectionPool:
except Exception:
return False
def get_connection(self, config: ClickzettaConfig) -> "Connection":
def get_connection(self, config: ClickzettaConfig) -> Connection:
"""Get a connection from the pool or create a new one."""
config_key = self._get_config_key(config)
@ -221,7 +223,7 @@ class ClickzettaConnectionPool:
# No valid connection found, create new one
return self._create_connection(config)
def return_connection(self, config: ClickzettaConfig, connection: "Connection"):
def return_connection(self, config: ClickzettaConfig, connection: Connection):
"""Return a connection to the pool."""
config_key = self._get_config_key(config)
@ -315,22 +317,22 @@ class ClickzettaVector(BaseVector):
self._connection_pool = ClickzettaConnectionPool.get_instance()
self._init_write_queue()
def _get_connection(self) -> "Connection":
def _get_connection(self) -> Connection:
"""Get a connection from the pool."""
return self._connection_pool.get_connection(self._config)
def _return_connection(self, connection: "Connection"):
def _return_connection(self, connection: Connection):
"""Return a connection to the pool."""
self._connection_pool.return_connection(self._config, connection)
class ConnectionContext:
"""Context manager for borrowing and returning connections."""
def __init__(self, vector_instance: "ClickzettaVector"):
def __init__(self, vector_instance: ClickzettaVector):
self.vector = vector_instance
self.connection: Connection | None = None
def __enter__(self) -> "Connection":
def __enter__(self) -> Connection:
self.connection = self.vector._get_connection()
return self.connection
@ -338,7 +340,7 @@ class ClickzettaVector(BaseVector):
if self.connection:
self.vector._return_connection(self.connection)
def get_connection_context(self) -> "ClickzettaVector.ConnectionContext":
def get_connection_context(self) -> ClickzettaVector.ConnectionContext:
"""Get a connection context manager."""
return self.ConnectionContext(self)
@ -437,7 +439,7 @@ class ClickzettaVector(BaseVector):
"""Return the vector database type."""
return "clickzetta"
def _ensure_connection(self) -> "Connection":
def _ensure_connection(self) -> Connection:
"""Get a connection from the pool."""
return self._get_connection()
@ -984,9 +986,11 @@ class ClickzettaVector(BaseVector):
# No need for dataset_id filter since each dataset has its own table
# Use simple quote escaping for LIKE clause
escaped_query = query.replace("'", "''")
filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%'")
# Escape special characters for LIKE clause to prevent SQL injection
from libs.helper import escape_like_pattern
escaped_query = escape_like_pattern(query).replace("'", "''")
filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%' ESCAPE '\\\\'")
where_clause = " AND ".join(filter_clauses)
search_sql = f"""

View File

@ -287,11 +287,15 @@ class IrisVector(BaseVector):
cursor.execute(sql, (query,))
else:
# Fallback to LIKE search (inefficient for large datasets)
query_pattern = f"%{query}%"
# Escape special characters for LIKE clause to prevent SQL injection
from libs.helper import escape_like_pattern
escaped_query = escape_like_pattern(query)
query_pattern = f"%{escaped_query}%"
sql = f"""
SELECT TOP {top_k} id, text, meta
FROM {self.schema}.{self.table_name}
WHERE text LIKE ?
WHERE text LIKE ? ESCAPE '\\'
"""
cursor.execute(sql, (query_pattern,))

View File

@ -66,6 +66,8 @@ class WeaviateVector(BaseVector):
in a Weaviate collection.
"""
_DOCUMENT_ID_PROPERTY = "document_id"
def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list):
"""
Initializes the Weaviate vector store.
@ -353,15 +355,12 @@ class WeaviateVector(BaseVector):
return []
col = self._client.collections.use(self._collection_name)
props = list({*self._attributes, "document_id", Field.TEXT_KEY.value})
props = list({*self._attributes, self._DOCUMENT_ID_PROPERTY, Field.TEXT_KEY.value})
where = None
doc_ids = kwargs.get("document_ids_filter") or []
if doc_ids:
ors = [Filter.by_property("document_id").equal(x) for x in doc_ids]
where = ors[0]
for f in ors[1:]:
where = where | f
where = Filter.by_property(self._DOCUMENT_ID_PROPERTY).contains_any(doc_ids)
top_k = int(kwargs.get("top_k", 4))
score_threshold = float(kwargs.get("score_threshold") or 0.0)
@ -408,10 +407,7 @@ class WeaviateVector(BaseVector):
where = None
doc_ids = kwargs.get("document_ids_filter") or []
if doc_ids:
ors = [Filter.by_property("document_id").equal(x) for x in doc_ids]
where = ors[0]
for f in ors[1:]:
where = where | f
where = Filter.by_property(self._DOCUMENT_ID_PROPERTY).contains_any(doc_ids)
top_k = int(kwargs.get("top_k", 4))

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import Any
@ -22,7 +24,7 @@ class DatasetDocumentStore:
self._document_id = document_id
@classmethod
def from_dict(cls, config_dict: dict[str, Any]) -> "DatasetDocumentStore":
def from_dict(cls, config_dict: dict[str, Any]) -> DatasetDocumentStore:
return cls(**config_dict)
def to_dict(self) -> dict[str, Any]:

View File

@ -112,7 +112,7 @@ class ExtractProcessor:
if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf":
extractor = PdfExtractor(file_path)
extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension in {".md", ".markdown", ".mdx"}:
extractor = (
UnstructuredMarkdownExtractor(file_path, unstructured_api_url, unstructured_api_key)
@ -148,7 +148,7 @@ class ExtractProcessor:
if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf":
extractor = PdfExtractor(file_path)
extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension in {".md", ".markdown", ".mdx"}:
extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
elif file_extension in {".htm", ".html"}:

View File

@ -1,25 +1,57 @@
"""Abstract interface for document loader implementations."""
import contextlib
import io
import logging
import uuid
from collections.abc import Iterator
import pypdfium2
import pypdfium2.raw as pdfium_c
from configs import dify_config
from core.rag.extractor.blob.blob import Blob
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from models.enums import CreatorUserRole
from models.model import UploadFile
logger = logging.getLogger(__name__)
class PdfExtractor(BaseExtractor):
"""Load pdf files.
"""
PdfExtractor is used to extract text and images from PDF files.
Args:
file_path: Path to the file to load.
file_path: Path to the PDF file.
tenant_id: Workspace ID.
user_id: ID of the user performing the extraction.
file_cache_key: Optional cache key for the extracted text.
"""
def __init__(self, file_path: str, file_cache_key: str | None = None):
"""Initialize with file path."""
# Magic bytes for image format detection: (magic_bytes, extension, mime_type)
IMAGE_FORMATS = [
(b"\xff\xd8\xff", "jpg", "image/jpeg"),
(b"\x89PNG\r\n\x1a\n", "png", "image/png"),
(b"\x00\x00\x00\x0c\x6a\x50\x20\x20\x0d\x0a\x87\x0a", "jp2", "image/jp2"),
(b"GIF8", "gif", "image/gif"),
(b"BM", "bmp", "image/bmp"),
(b"II*\x00", "tiff", "image/tiff"),
(b"MM\x00*", "tiff", "image/tiff"),
(b"II+\x00", "tiff", "image/tiff"),
(b"MM\x00+", "tiff", "image/tiff"),
]
MAX_MAGIC_LEN = max(len(m) for m, _, _ in IMAGE_FORMATS)
def __init__(self, file_path: str, tenant_id: str, user_id: str, file_cache_key: str | None = None):
"""Initialize PdfExtractor."""
self._file_path = file_path
self._tenant_id = tenant_id
self._user_id = user_id
self._file_cache_key = file_cache_key
def extract(self) -> list[Document]:
@ -50,7 +82,6 @@ class PdfExtractor(BaseExtractor):
def parse(self, blob: Blob) -> Iterator[Document]:
"""Lazily parse the blob."""
import pypdfium2 # type: ignore
with blob.as_bytes_io() as file_path:
pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True)
@ -59,8 +90,87 @@ class PdfExtractor(BaseExtractor):
text_page = page.get_textpage()
content = text_page.get_text_range()
text_page.close()
image_content = self._extract_images(page)
if image_content:
content += "\n" + image_content
page.close()
metadata = {"source": blob.source, "page": page_number}
yield Document(page_content=content, metadata=metadata)
finally:
pdf_reader.close()
def _extract_images(self, page) -> str:
"""
Extract images from a PDF page, save them to storage and database,
and return markdown image links.
Args:
page: pypdfium2 page object.
Returns:
Markdown string containing links to the extracted images.
"""
image_content = []
upload_files = []
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
try:
image_objects = page.get_objects(filter=(pdfium_c.FPDF_PAGEOBJ_IMAGE,))
for obj in image_objects:
try:
# Extract image bytes
img_byte_arr = io.BytesIO()
# Extract DCTDecode (JPEG) and JPXDecode (JPEG 2000) images directly
# Fallback to png for other formats
obj.extract(img_byte_arr, fb_format="png")
img_bytes = img_byte_arr.getvalue()
if not img_bytes:
continue
header = img_bytes[: self.MAX_MAGIC_LEN]
image_ext = None
mime_type = None
for magic, ext, mime in self.IMAGE_FORMATS:
if header.startswith(magic):
image_ext = ext
mime_type = mime
break
if not image_ext or not mime_type:
continue
file_uuid = str(uuid.uuid4())
file_key = "image_files/" + self._tenant_id + "/" + file_uuid + "." + image_ext
storage.save(file_key, img_bytes)
# save file to db
upload_file = UploadFile(
tenant_id=self._tenant_id,
storage_type=dify_config.STORAGE_TYPE,
key=file_key,
name=file_key,
size=len(img_bytes),
extension=image_ext,
mime_type=mime_type,
created_by=self._user_id,
created_by_role=CreatorUserRole.ACCOUNT,
created_at=naive_utc_now(),
used=True,
used_by=self._user_id,
used_at=naive_utc_now(),
)
upload_files.append(upload_file)
image_content.append(f"![image]({base_url}/files/{upload_file.id}/file-preview)")
except Exception as e:
logger.warning("Failed to extract image from PDF: %s", e)
continue
except Exception as e:
logger.warning("Failed to get objects from PDF page: %s", e)
if upload_files:
db.session.add_all(upload_files)
db.session.commit()
return "\n".join(image_content)

View File

@ -7,10 +7,11 @@ import re
import tempfile
import uuid
from urllib.parse import urlparse
from xml.etree import ElementTree
import httpx
from docx import Document as DocxDocument
from docx.oxml.ns import qn
from docx.text.run import Run
from configs import dify_config
from core.helper import ssrf_proxy
@ -229,44 +230,20 @@ class WordExtractor(BaseExtractor):
image_map = self._extract_images_from_docx(doc)
hyperlinks_url = None
url_pattern = re.compile(r"http://[^\s+]+//|https://[^\s+]+")
for para in doc.paragraphs:
for run in para.runs:
if run.text and hyperlinks_url:
result = f" [{run.text}]({hyperlinks_url}) "
run.text = result
hyperlinks_url = None
if "HYPERLINK" in run.element.xml:
try:
xml = ElementTree.XML(run.element.xml)
x_child = [c for c in xml.iter() if c is not None]
for x in x_child:
if x is None:
continue
if x.tag.endswith("instrText"):
if x.text is None:
continue
for i in url_pattern.findall(x.text):
hyperlinks_url = str(i)
except Exception:
logger.exception("Failed to parse HYPERLINK xml")
def parse_paragraph(paragraph):
paragraph_content = []
def append_image_link(image_id, has_drawing):
def append_image_link(image_id, has_drawing, target_buffer):
"""Helper to append image link from image_map based on relationship type."""
rel = doc.part.rels[image_id]
if rel.is_external:
if image_id in image_map and not has_drawing:
paragraph_content.append(image_map[image_id])
target_buffer.append(image_map[image_id])
else:
image_part = rel.target_part
if image_part in image_map and not has_drawing:
paragraph_content.append(image_map[image_part])
target_buffer.append(image_map[image_part])
for run in paragraph.runs:
def process_run(run, target_buffer):
# Helper to extract text and embedded images from a run element and append them to target_buffer
if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"):
# Process drawing type images
drawing_elements = run.element.findall(
@ -287,13 +264,13 @@ class WordExtractor(BaseExtractor):
# External image: use embed_id as key
if embed_id in image_map:
has_drawing = True
paragraph_content.append(image_map[embed_id])
target_buffer.append(image_map[embed_id])
else:
# Internal image: use target_part as key
image_part = doc.part.related_parts.get(embed_id)
if image_part in image_map:
has_drawing = True
paragraph_content.append(image_map[image_part])
target_buffer.append(image_map[image_part])
# Process pict type images
shape_elements = run.element.findall(
".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}pict"
@ -308,7 +285,7 @@ class WordExtractor(BaseExtractor):
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
)
if image_id and image_id in doc.part.rels:
append_image_link(image_id, has_drawing)
append_image_link(image_id, has_drawing, target_buffer)
# Find imagedata element in VML
image_data = shape.find(".//{urn:schemas-microsoft-com:vml}imagedata")
if image_data is not None:
@ -316,9 +293,93 @@ class WordExtractor(BaseExtractor):
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
)
if image_id and image_id in doc.part.rels:
append_image_link(image_id, has_drawing)
append_image_link(image_id, has_drawing, target_buffer)
if run.text.strip():
paragraph_content.append(run.text.strip())
target_buffer.append(run.text.strip())
def process_hyperlink(hyperlink_elem, target_buffer):
# Helper to extract text from a hyperlink element and append it to target_buffer
r_id = hyperlink_elem.get(qn("r:id"))
# Extract text from runs inside the hyperlink
link_text_parts = []
for run_elem in hyperlink_elem.findall(qn("w:r")):
run = Run(run_elem, paragraph)
# Hyperlink text may be split across multiple runs (e.g., with different formatting),
# so collect all run texts first
if run.text:
link_text_parts.append(run.text)
link_text = "".join(link_text_parts).strip()
# Resolve URL
if r_id:
try:
rel = doc.part.rels.get(r_id)
if rel and rel.is_external:
link_text = f"[{link_text or rel.target_ref}]({rel.target_ref})"
except Exception:
logger.exception("Failed to resolve URL for hyperlink with r:id: %s", r_id)
if link_text:
target_buffer.append(link_text)
paragraph_content = []
# State for legacy HYPERLINK fields
hyperlink_field_url = None
hyperlink_field_text_parts: list = []
is_collecting_field_text = False
# Iterate through paragraph elements in document order
for child in paragraph._element:
tag = child.tag
if tag == qn("w:r"):
# Regular run
run = Run(child, paragraph)
# Check for fldChar (begin/end/separate) and instrText for legacy hyperlinks
fld_chars = child.findall(qn("w:fldChar"))
instr_texts = child.findall(qn("w:instrText"))
# Handle Fields
if fld_chars or instr_texts:
# Process instrText to find HYPERLINK "url"
for instr in instr_texts:
if instr.text and "HYPERLINK" in instr.text:
# Quick regex to extract URL
match = re.search(r'HYPERLINK\s+"([^"]+)"', instr.text, re.IGNORECASE)
if match:
hyperlink_field_url = match.group(1)
# Process fldChar
for fld_char in fld_chars:
fld_char_type = fld_char.get(qn("w:fldCharType"))
if fld_char_type == "begin":
# Start of a field: reset legacy link state
hyperlink_field_url = None
hyperlink_field_text_parts = []
is_collecting_field_text = False
elif fld_char_type == "separate":
# Separator: if we found a URL, start collecting visible text
if hyperlink_field_url:
is_collecting_field_text = True
elif fld_char_type == "end":
# End of field
if is_collecting_field_text and hyperlink_field_url:
# Create markdown link and append to main content
display_text = "".join(hyperlink_field_text_parts).strip()
if display_text:
link_md = f"[{display_text}]({hyperlink_field_url})"
paragraph_content.append(link_md)
# Reset state
hyperlink_field_url = None
hyperlink_field_text_parts = []
is_collecting_field_text = False
# Decide where to append content
target_buffer = hyperlink_field_text_parts if is_collecting_field_text else paragraph_content
process_run(run, target_buffer)
elif tag == qn("w:hyperlink"):
process_hyperlink(child, paragraph_content)
return "".join(paragraph_content) if paragraph_content else ""
paragraphs = doc.paragraphs.copy()

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import json
from collections.abc import Sequence
from typing import Any
@ -16,7 +18,7 @@ class TaskWrapper(BaseModel):
return self.model_dump_json()
@classmethod
def deserialize(cls, serialized_data: str) -> "TaskWrapper":
def deserialize(cls, serialized_data: str) -> TaskWrapper:
return cls.model_validate_json(serialized_data)

View File

@ -515,6 +515,7 @@ class DatasetRetrieval:
0
].embedding_model_provider
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
dataset_count = len(available_datasets)
with measure_time() as timer:
cancel_event = threading.Event()
thread_exceptions: list[Exception] = []
@ -537,6 +538,7 @@ class DatasetRetrieval:
"score_threshold": score_threshold,
"query": query,
"attachment_id": None,
"dataset_count": dataset_count,
"cancel_event": cancel_event,
"thread_exceptions": thread_exceptions,
},
@ -562,6 +564,7 @@ class DatasetRetrieval:
"score_threshold": score_threshold,
"query": None,
"attachment_id": attachment_id,
"dataset_count": dataset_count,
"cancel_event": cancel_event,
"thread_exceptions": thread_exceptions,
},
@ -1195,18 +1198,24 @@ class DatasetRetrieval:
json_field = DatasetDocument.doc_metadata[metadata_name].as_string()
from libs.helper import escape_like_pattern
match condition:
case "contains":
filters.append(json_field.like(f"%{value}%"))
escaped_value = escape_like_pattern(str(value))
filters.append(json_field.like(f"%{escaped_value}%", escape="\\"))
case "not contains":
filters.append(json_field.notlike(f"%{value}%"))
escaped_value = escape_like_pattern(str(value))
filters.append(json_field.notlike(f"%{escaped_value}%", escape="\\"))
case "start with":
filters.append(json_field.like(f"{value}%"))
escaped_value = escape_like_pattern(str(value))
filters.append(json_field.like(f"{escaped_value}%", escape="\\"))
case "end with":
filters.append(json_field.like(f"%{value}"))
escaped_value = escape_like_pattern(str(value))
filters.append(json_field.like(f"%{escaped_value}", escape="\\"))
case "is" | "=":
if isinstance(value, str):
@ -1422,6 +1431,7 @@ class DatasetRetrieval:
score_threshold: float,
query: str | None,
attachment_id: str | None,
dataset_count: int,
cancel_event: threading.Event | None = None,
thread_exceptions: list[Exception] | None = None,
):
@ -1470,37 +1480,38 @@ class DatasetRetrieval:
if cancel_event and cancel_event.is_set():
break
if reranking_enable:
# do rerank for searched documents
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
if query:
all_documents_item = data_post_processor.invoke(
query=query,
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.TEXT_QUERY,
)
if attachment_id:
all_documents_item = data_post_processor.invoke(
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.IMAGE_QUERY,
query=attachment_id,
)
else:
if index_type == IndexTechniqueType.ECONOMY:
if not query:
all_documents_item = []
else:
all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k)
elif index_type == IndexTechniqueType.HIGH_QUALITY:
all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold)
# Skip second reranking when there is only one dataset
if reranking_enable and dataset_count > 1:
# do rerank for searched documents
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
if query:
all_documents_item = data_post_processor.invoke(
query=query,
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.TEXT_QUERY,
)
if attachment_id:
all_documents_item = data_post_processor.invoke(
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.IMAGE_QUERY,
query=attachment_id,
)
else:
all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
if all_documents_item:
all_documents.extend(all_documents_item)
if index_type == IndexTechniqueType.ECONOMY:
if not query:
all_documents_item = []
else:
all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k)
elif index_type == IndexTechniqueType.HIGH_QUALITY:
all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold)
else:
all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
if all_documents_item:
all_documents.extend(all_documents_item)
except Exception as e:
if cancel_event:
cancel_event.set()

View File

@ -1,9 +1,11 @@
from __future__ import annotations
import json
import logging
import threading
from collections.abc import Mapping, MutableMapping
from pathlib import Path
from typing import Any, ClassVar, Optional
from typing import Any, ClassVar
class SchemaRegistry:
@ -11,7 +13,7 @@ class SchemaRegistry:
logger: ClassVar[logging.Logger] = logging.getLogger(__name__)
_default_instance: ClassVar[Optional["SchemaRegistry"]] = None
_default_instance: ClassVar[SchemaRegistry | None] = None
_lock: ClassVar[threading.Lock] = threading.Lock()
def __init__(self, base_dir: str):
@ -20,7 +22,7 @@ class SchemaRegistry:
self.metadata: MutableMapping[str, MutableMapping[str, Any]] = {}
@classmethod
def default_registry(cls) -> "SchemaRegistry":
def default_registry(cls) -> SchemaRegistry:
"""Returns the default schema registry for builtin schemas (thread-safe singleton)"""
if cls._default_instance is None:
with cls._lock:

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Generator
from copy import deepcopy
@ -25,7 +27,7 @@ class Tool(ABC):
self.entity = entity
self.runtime = runtime
def fork_tool_runtime(self, runtime: ToolRuntime) -> "Tool":
def fork_tool_runtime(self, runtime: ToolRuntime) -> Tool:
"""
fork a new tool with metadata
:return: the new tool
@ -221,7 +223,7 @@ class Tool(ABC):
type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image)
)
def create_file_message(self, file: "File") -> ToolInvokeMessage:
def create_file_message(self, file: File) -> ToolInvokeMessage:
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.FILE,
message=ToolInvokeMessage.FileMessage(),

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
from core.tools.__base.tool import Tool
@ -24,7 +26,7 @@ class BuiltinTool(Tool):
super().__init__(**kwargs)
self.provider = provider
def fork_tool_runtime(self, runtime: ToolRuntime) -> "BuiltinTool":
def fork_tool_runtime(self, runtime: ToolRuntime) -> BuiltinTool:
"""
fork a new tool with metadata
:return: the new tool

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from pydantic import Field
from sqlalchemy import select
@ -32,7 +34,7 @@ class ApiToolProviderController(ToolProviderController):
self.tools = []
@classmethod
def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController":
def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> ApiToolProviderController:
credentials_schema = [
ProviderConfig(
name="auth_type",

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import base64
import contextlib
from collections.abc import Mapping
@ -55,7 +57,7 @@ class ToolProviderType(StrEnum):
MCP = auto()
@classmethod
def value_of(cls, value: str) -> "ToolProviderType":
def value_of(cls, value: str) -> ToolProviderType:
"""
Get value of given mode.
@ -79,7 +81,7 @@ class ApiProviderSchemaType(StrEnum):
OPENAI_ACTIONS = auto()
@classmethod
def value_of(cls, value: str) -> "ApiProviderSchemaType":
def value_of(cls, value: str) -> ApiProviderSchemaType:
"""
Get value of given mode.
@ -102,7 +104,7 @@ class ApiProviderAuthType(StrEnum):
API_KEY_QUERY = auto()
@classmethod
def value_of(cls, value: str) -> "ApiProviderAuthType":
def value_of(cls, value: str) -> ApiProviderAuthType:
"""
Get value of given mode.
@ -307,7 +309,7 @@ class ToolParameter(PluginParameter):
typ: ToolParameterType,
required: bool,
options: list[str] | None = None,
) -> "ToolParameter":
) -> ToolParameter:
"""
get a simple tool parameter
@ -429,14 +431,14 @@ class ToolInvokeMeta(BaseModel):
tool_config: dict | None = None
@classmethod
def empty(cls) -> "ToolInvokeMeta":
def empty(cls) -> ToolInvokeMeta:
"""
Get an empty instance of ToolInvokeMeta
"""
return cls(time_cost=0.0, error=None, tool_config={})
@classmethod
def error_instance(cls, error: str) -> "ToolInvokeMeta":
def error_instance(cls, error: str) -> ToolInvokeMeta:
"""
Get an instance of ToolInvokeMeta with error
"""

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import base64
import json
import logging
@ -118,7 +120,7 @@ class MCPTool(Tool):
for item in json_list:
yield self.create_json_message(item)
def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool":
def fork_tool_runtime(self, runtime: ToolRuntime) -> MCPTool:
return MCPTool(
entity=self.entity,
runtime=runtime,

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from collections.abc import Generator
from typing import Any
@ -46,7 +48,7 @@ class PluginTool(Tool):
message_id=message_id,
)
def fork_tool_runtime(self, runtime: ToolRuntime) -> "PluginTool":
def fork_tool_runtime(self, runtime: ToolRuntime) -> PluginTool:
return PluginTool(
entity=self.entity,
runtime=runtime,

View File

@ -7,12 +7,12 @@ import time
from configs import dify_config
def sign_tool_file(tool_file_id: str, extension: str) -> str:
def sign_tool_file(tool_file_id: str, extension: str, for_external: bool = True) -> str:
"""
sign file to get a temporary url for plugin access
"""
# Use internal URL for plugin/tool file access in Docker environments
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
# Use internal URL for plugin/tool file access in Docker environments, unless for_external is True
base_url = dify_config.FILES_URL if for_external else (dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL)
file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}"
timestamp = str(int(time.time()))

View File

@ -378,7 +378,7 @@ class ApiBasedToolSchemaParser:
@staticmethod
def auto_parse_to_tool_bundle(
content: str, extra_info: dict | None = None, warning: dict | None = None
) -> tuple[list[ApiToolBundle], str]:
) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]:
"""
auto parse to tool bundle

View File

@ -4,6 +4,7 @@ import re
def remove_leading_symbols(text: str) -> str:
"""
Remove leading punctuation or symbols from the given text.
Preserves markdown links like [text](url) at the start.
Args:
text (str): The input text to process.
@ -11,6 +12,11 @@ def remove_leading_symbols(text: str) -> str:
Returns:
str: The text with leading punctuation or symbols removed.
"""
# Check if text starts with a markdown link - preserve it
markdown_link_pattern = r"^\[([^\]]+)\]\((https?://[^)]+)\)"
if re.match(markdown_link_pattern, text):
return text
# Match Unicode ranges for punctuation and symbols
# FIXME this pattern is confused quick fix for #11868 maybe refactor it later
pattern = r'^[\[\]\u2000-\u2025\u2027-\u206F\u2E00-\u2E7F\u3000-\u300F\u3011-\u303F"#$%&\'()*+,./:;<=>?@^_`~]+'

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from collections.abc import Mapping
from pydantic import Field
@ -47,14 +49,13 @@ class WorkflowToolProviderController(ToolProviderController):
self.provider_id = provider_id
@classmethod
def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController":
def from_db(cls, db_provider: WorkflowToolProvider) -> WorkflowToolProviderController:
with session_factory.create_session() as session, session.begin():
app = session.get(App, db_provider.app_id)
if not app:
raise ValueError("app not found")
user = session.get(Account, db_provider.user_id) if db_provider.user_id else None
controller = WorkflowToolProviderController(
entity=ToolProviderEntity(
identity=ToolProviderIdentity(
@ -67,7 +68,7 @@ class WorkflowToolProviderController(ToolProviderController):
credentials_schema=[],
plugin_id=None,
),
provider_id="",
provider_id=db_provider.id,
)
controller.tools = [

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import json
import logging
from collections.abc import Generator, Mapping, Sequence
@ -181,7 +183,7 @@ class WorkflowTool(Tool):
return found
return None
def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool":
def fork_tool_runtime(self, runtime: ToolRuntime) -> WorkflowTool:
"""
fork a new tool with metadata

View File

@ -1,6 +1,8 @@
from __future__ import annotations
from collections.abc import Mapping
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
from core.file.models import File
@ -52,7 +54,7 @@ class SegmentType(StrEnum):
return self in _ARRAY_TYPES
@classmethod
def infer_segment_type(cls, value: Any) -> Optional["SegmentType"]:
def infer_segment_type(cls, value: Any) -> SegmentType | None:
"""
Attempt to infer the `SegmentType` based on the Python type of the `value` parameter.
@ -173,7 +175,7 @@ class SegmentType(StrEnum):
raise AssertionError("this statement should be unreachable.")
@staticmethod
def cast_value(value: Any, type_: "SegmentType"):
def cast_value(value: Any, type_: SegmentType):
# Cast Python's `bool` type to `int` when the runtime type requires
# an integer or number.
#
@ -193,7 +195,7 @@ class SegmentType(StrEnum):
return [int(i) for i in value]
return value
def exposed_type(self) -> "SegmentType":
def exposed_type(self) -> SegmentType:
"""Returns the type exposed to the frontend.
The frontend treats `INTEGER` and `FLOAT` as `NUMBER`, so these are returned as `NUMBER` here.
@ -202,7 +204,7 @@ class SegmentType(StrEnum):
return SegmentType.NUMBER
return self
def element_type(self) -> "SegmentType | None":
def element_type(self) -> SegmentType | None:
"""Return the element type of the current segment type, or `None` if the element type is undefined.
Raises:
@ -217,7 +219,7 @@ class SegmentType(StrEnum):
return _ARRAY_ELEMENT_TYPES_MAPPING.get(self)
@staticmethod
def get_zero_value(t: "SegmentType"):
def get_zero_value(t: SegmentType):
# Lazy import to avoid circular dependency
from factories import variable_factory

View File

@ -64,6 +64,9 @@ engine.layer(DebugLoggingLayer(level="INFO"))
engine.layer(ExecutionLimitsLayer(max_nodes=100))
```
`engine.layer()` binds the read-only runtime state before execution, so layer hooks
can assume `graph_runtime_state` is available.
### Event-Driven Architecture
All node executions emit events for monitoring and integration:

View File

@ -5,6 +5,8 @@ Models are independent of the storage mechanism and don't contain
implementation details like tenant_id, app_id, etc.
"""
from __future__ import annotations
from collections.abc import Mapping
from datetime import datetime
from typing import Any
@ -59,7 +61,7 @@ class WorkflowExecution(BaseModel):
graph: Mapping[str, Any],
inputs: Mapping[str, Any],
started_at: datetime,
) -> "WorkflowExecution":
) -> WorkflowExecution:
return WorkflowExecution(
id_=id_,
workflow_id=workflow_id,

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import logging
from collections import defaultdict
from collections.abc import Mapping, Sequence
@ -175,7 +177,7 @@ class Graph:
def _create_node_instances(
cls,
node_configs_map: dict[str, dict[str, object]],
node_factory: "NodeFactory",
node_factory: NodeFactory,
) -> dict[str, Node]:
"""
Create node instances from configurations using the node factory.
@ -197,7 +199,7 @@ class Graph:
return nodes
@classmethod
def new(cls) -> "GraphBuilder":
def new(cls) -> GraphBuilder:
"""Create a fluent builder for assembling a graph programmatically."""
return GraphBuilder(graph_cls=cls)
@ -284,9 +286,9 @@ class Graph:
cls,
*,
graph_config: Mapping[str, object],
node_factory: "NodeFactory",
node_factory: NodeFactory,
root_node_id: str | None = None,
) -> "Graph":
) -> Graph:
"""
Initialize graph
@ -383,7 +385,7 @@ class GraphBuilder:
self._edges: list[Edge] = []
self._edge_counter = 0
def add_root(self, node: Node) -> "GraphBuilder":
def add_root(self, node: Node) -> GraphBuilder:
"""Register the root node. Must be called exactly once."""
if self._nodes:
@ -398,7 +400,7 @@ class GraphBuilder:
*,
from_node_id: str | None = None,
source_handle: str = "source",
) -> "GraphBuilder":
) -> GraphBuilder:
"""Append a node and connect it from the specified predecessor."""
if not self._nodes:
@ -419,7 +421,7 @@ class GraphBuilder:
return self
def connect(self, *, tail: str, head: str, source_handle: str = "source") -> "GraphBuilder":
def connect(self, *, tail: str, head: str, source_handle: str = "source") -> GraphBuilder:
"""Connect two existing nodes without adding a new node."""
if tail not in self._nodes_by_id:

View File

@ -9,7 +9,7 @@ Each instance uses a unique key for its command queue.
import json
from typing import TYPE_CHECKING, Any, final
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
if TYPE_CHECKING:
from extensions.ext_redis import RedisClientWrapper
@ -113,6 +113,8 @@ class RedisChannel:
return AbortCommand.model_validate(data)
if command_type == CommandType.PAUSE:
return PauseCommand.model_validate(data)
if command_type == CommandType.UPDATE_VARIABLES:
return UpdateVariablesCommand.model_validate(data)
# For other command types, use base class
return GraphEngineCommand.model_validate(data)

View File

@ -5,11 +5,12 @@ This package handles external commands sent to the engine
during execution.
"""
from .command_handlers import AbortCommandHandler, PauseCommandHandler
from .command_handlers import AbortCommandHandler, PauseCommandHandler, UpdateVariablesCommandHandler
from .command_processor import CommandProcessor
__all__ = [
"AbortCommandHandler",
"CommandProcessor",
"PauseCommandHandler",
"UpdateVariablesCommandHandler",
]

View File

@ -4,9 +4,10 @@ from typing import final
from typing_extensions import override
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.runtime import VariablePool
from ..domain.graph_execution import GraphExecution
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
from .command_processor import CommandHandler
logger = logging.getLogger(__name__)
@ -31,3 +32,25 @@ class PauseCommandHandler(CommandHandler):
reason = command.reason
pause_reason = SchedulingPause(message=reason)
execution.pause(pause_reason)
@final
class UpdateVariablesCommandHandler(CommandHandler):
def __init__(self, variable_pool: VariablePool) -> None:
self._variable_pool = variable_pool
@override
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
assert isinstance(command, UpdateVariablesCommand)
for update in command.updates:
try:
variable = update.value
self._variable_pool.add(variable.selector, variable)
logger.debug("Updated variable %s for workflow %s", variable.selector, execution.workflow_id)
except ValueError as exc:
logger.warning(
"Skipping invalid variable selector %s for workflow %s: %s",
getattr(update.value, "selector", None),
execution.workflow_id,
exc,
)

View File

@ -5,17 +5,21 @@ This module defines command types that can be sent to a running GraphEngine
instance to control its execution flow.
"""
from enum import StrEnum
from collections.abc import Sequence
from enum import StrEnum, auto
from typing import Any
from pydantic import BaseModel, Field
from core.variables.variables import VariableUnion
class CommandType(StrEnum):
"""Types of commands that can be sent to GraphEngine."""
ABORT = "abort"
PAUSE = "pause"
ABORT = auto()
PAUSE = auto()
UPDATE_VARIABLES = auto()
class GraphEngineCommand(BaseModel):
@ -37,3 +41,16 @@ class PauseCommand(GraphEngineCommand):
command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command")
reason: str = Field(default="unknown reason", description="reason for pause")
class VariableUpdate(BaseModel):
"""Represents a single variable update instruction."""
value: VariableUnion = Field(description="New variable value")
class UpdateVariablesCommand(GraphEngineCommand):
"""Command to update a group of variables in the variable pool."""
command_type: CommandType = Field(default=CommandType.UPDATE_VARIABLES, description="Type of command")
updates: Sequence[VariableUpdate] = Field(default_factory=list, description="Variable updates")

View File

@ -5,9 +5,12 @@ This engine uses a modular architecture with separated packages following
Domain-Driven Design principles for improved maintainability and testability.
"""
from __future__ import annotations
import contextvars
import logging
import queue
import threading
from collections.abc import Generator
from typing import TYPE_CHECKING, cast, final
@ -30,8 +33,13 @@ from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWr
if TYPE_CHECKING: # pragma: no cover - used only for static analysis
from core.workflow.runtime.graph_runtime_state import GraphProtocol
from .command_processing import AbortCommandHandler, CommandProcessor, PauseCommandHandler
from .entities.commands import AbortCommand, PauseCommand
from .command_processing import (
AbortCommandHandler,
CommandProcessor,
PauseCommandHandler,
UpdateVariablesCommandHandler,
)
from .entities.commands import AbortCommand, PauseCommand, UpdateVariablesCommand
from .error_handler import ErrorHandler
from .event_management import EventHandler, EventManager
from .graph_state_manager import GraphStateManager
@ -70,10 +78,13 @@ class GraphEngine:
scale_down_idle_time: float | None = None,
) -> None:
"""Initialize the graph engine with all subsystems and dependencies."""
# stop event
self._stop_event = threading.Event()
# Bind runtime state to current workflow context
self._graph = graph
self._graph_runtime_state = graph_runtime_state
self._graph_runtime_state.stop_event = self._stop_event
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
self._command_channel = command_channel
@ -140,6 +151,9 @@ class GraphEngine:
pause_handler = PauseCommandHandler()
self._command_processor.register_handler(PauseCommand, pause_handler)
update_variables_handler = UpdateVariablesCommandHandler(self._graph_runtime_state.variable_pool)
self._command_processor.register_handler(UpdateVariablesCommand, update_variables_handler)
# === Extensibility ===
# Layers allow plugins to extend engine functionality
self._layers: list[GraphEngineLayer] = []
@ -169,6 +183,7 @@ class GraphEngine:
max_workers=self._max_workers,
scale_up_threshold=self._scale_up_threshold,
scale_down_idle_time=self._scale_down_idle_time,
stop_event=self._stop_event,
)
# === Orchestration ===
@ -199,6 +214,7 @@ class GraphEngine:
event_handler=self._event_handler_registry,
execution_coordinator=self._execution_coordinator,
event_emitter=self._event_manager,
stop_event=self._stop_event,
)
# === Validation ===
@ -212,9 +228,16 @@ class GraphEngine:
if id(node.graph_runtime_state) != expected_state_id:
raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance")
def layer(self, layer: GraphEngineLayer) -> "GraphEngine":
def _bind_layer_context(
self,
layer: GraphEngineLayer,
) -> None:
layer.initialize(ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state), self._command_channel)
def layer(self, layer: GraphEngineLayer) -> GraphEngine:
"""Add a layer for extending functionality."""
self._layers.append(layer)
self._bind_layer_context(layer)
return self
def run(self) -> Generator[GraphEngineEvent, None, None]:
@ -301,14 +324,7 @@ class GraphEngine:
def _initialize_layers(self) -> None:
"""Initialize layers with context."""
self._event_manager.set_layers(self._layers)
# Create a read-only wrapper for the runtime state
read_only_state = ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state)
for layer in self._layers:
try:
layer.initialize(read_only_state, self._command_channel)
except Exception as e:
logger.warning("Failed to initialize layer %s: %s", layer.__class__.__name__, e)
try:
layer.on_graph_start()
except Exception as e:
@ -316,6 +332,7 @@ class GraphEngine:
def _start_execution(self, *, resume: bool = False) -> None:
"""Start execution subsystems."""
self._stop_event.clear()
paused_nodes: list[str] = []
if resume:
paused_nodes = self._graph_runtime_state.consume_paused_nodes()
@ -343,13 +360,12 @@ class GraphEngine:
def _stop_execution(self) -> None:
"""Stop execution subsystems."""
self._stop_event.set()
self._dispatcher.stop()
self._worker_pool.stop()
# Don't mark complete here as the dispatcher already does it
# Notify layers
logger = logging.getLogger(__name__)
for layer in self._layers:
try:
layer.on_graph_end(self._graph_execution.error)

View File

@ -60,6 +60,7 @@ class SkipPropagator:
if edge_states["has_taken"]:
# Enqueue node
self._state_manager.enqueue_node(downstream_node_id)
self._state_manager.start_execution(downstream_node_id)
return
# All edges are skipped, propagate skip to this node

View File

@ -8,7 +8,7 @@ Pluggable middleware for engine extensions.
Abstract base class for layers.
- `initialize()` - Receive runtime context
- `initialize()` - Receive runtime context (runtime state is bound here and always available to hooks)
- `on_graph_start()` - Execution start hook
- `on_event()` - Process all events
- `on_graph_end()` - Execution end hook
@ -34,6 +34,9 @@ engine.layer(debug_layer)
engine.run()
```
`engine.layer()` binds the read-only runtime state before execution, so
`graph_runtime_state` is always available inside layer hooks.
## Custom Layers
```python

View File

@ -13,6 +13,14 @@ from core.workflow.nodes.base.node import Node
from core.workflow.runtime import ReadOnlyGraphRuntimeState
class GraphEngineLayerNotInitializedError(Exception):
"""Raised when a layer's runtime state is accessed before initialization."""
def __init__(self, layer_name: str | None = None) -> None:
name = layer_name or "GraphEngineLayer"
super().__init__(f"{name} runtime state is not initialized. Bind the layer to a GraphEngine before access.")
class GraphEngineLayer(ABC):
"""
Abstract base class for GraphEngine layers.
@ -28,22 +36,27 @@ class GraphEngineLayer(ABC):
def __init__(self) -> None:
"""Initialize the layer. Subclasses can override with custom parameters."""
self.graph_runtime_state: ReadOnlyGraphRuntimeState | None = None
self._graph_runtime_state: ReadOnlyGraphRuntimeState | None = None
self.command_channel: CommandChannel | None = None
@property
def graph_runtime_state(self) -> ReadOnlyGraphRuntimeState:
if self._graph_runtime_state is None:
raise GraphEngineLayerNotInitializedError(type(self).__name__)
return self._graph_runtime_state
def initialize(self, graph_runtime_state: ReadOnlyGraphRuntimeState, command_channel: CommandChannel) -> None:
"""
Initialize the layer with engine dependencies.
Called by GraphEngine before execution starts to inject the read-only runtime state
and command channel. This allows layers to observe engine context and send
commands, but prevents direct state modification.
Called by GraphEngine to inject the read-only runtime state and command channel.
This is invoked when the layer is registered with a `GraphEngine` instance.
Implementations should be idempotent.
Args:
graph_runtime_state: Read-only view of the runtime state
command_channel: Channel for sending commands to the engine
"""
self.graph_runtime_state = graph_runtime_state
self._graph_runtime_state = graph_runtime_state
self.command_channel = command_channel
@abstractmethod

View File

@ -109,10 +109,8 @@ class DebugLoggingLayer(GraphEngineLayer):
self.logger.info("=" * 80)
self.logger.info("🚀 GRAPH EXECUTION STARTED")
self.logger.info("=" * 80)
if self.graph_runtime_state:
# Log initial state
self.logger.info("Initial State:")
# Log initial state
self.logger.info("Initial State:")
@override
def on_event(self, event: GraphEngineEvent) -> None:
@ -243,8 +241,7 @@ class DebugLoggingLayer(GraphEngineLayer):
self.logger.info(" Node retries: %s", self.retry_count)
# Log final state if available
if self.graph_runtime_state and self.include_outputs:
if self.graph_runtime_state.outputs:
self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs))
if self.include_outputs and self.graph_runtime_state.outputs:
self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs))
self.logger.info("=" * 80)

View File

@ -337,8 +337,6 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
if update_finished:
execution.finished_at = naive_utc_now()
runtime_state = self.graph_runtime_state
if runtime_state is None:
return
execution.total_tokens = runtime_state.total_tokens
execution.total_steps = runtime_state.node_run_steps
execution.outputs = execution.outputs or runtime_state.outputs
@ -404,6 +402,4 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
def _system_variables(self) -> Mapping[str, Any]:
runtime_state = self.graph_runtime_state
if runtime_state is None:
return {}
return runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID)

View File

@ -3,14 +3,20 @@ GraphEngine Manager for sending control commands via Redis channel.
This module provides a simplified interface for controlling workflow executions
using the new Redis command channel, without requiring user permission checks.
Supports stop, pause, and resume operations.
"""
import logging
from collections.abc import Sequence
from typing import final
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
from core.workflow.graph_engine.entities.commands import (
AbortCommand,
GraphEngineCommand,
PauseCommand,
UpdateVariablesCommand,
VariableUpdate,
)
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
@ -23,7 +29,6 @@ class GraphEngineManager:
This class provides a simple interface for controlling workflow executions
by sending commands through Redis channels, without user validation.
Supports stop and pause operations.
"""
@staticmethod
@ -45,6 +50,16 @@ class GraphEngineManager:
pause_command = PauseCommand(reason=reason or "User requested pause")
GraphEngineManager._send_command(task_id, pause_command)
@staticmethod
def send_update_variables_command(task_id: str, updates: Sequence[VariableUpdate]) -> None:
"""Send a command to update variables in a running workflow."""
if not updates:
return
update_command = UpdateVariablesCommand(updates=updates)
GraphEngineManager._send_command(task_id, update_command)
@staticmethod
def _send_command(task_id: str, command: GraphEngineCommand) -> None:
"""Send a command to the workflow-specific Redis channel."""

View File

@ -44,6 +44,7 @@ class Dispatcher:
event_queue: queue.Queue[GraphNodeEventBase],
event_handler: "EventHandler",
execution_coordinator: ExecutionCoordinator,
stop_event: threading.Event,
event_emitter: EventManager | None = None,
) -> None:
"""
@ -61,7 +62,7 @@ class Dispatcher:
self._event_emitter = event_emitter
self._thread: threading.Thread | None = None
self._stop_event = threading.Event()
self._stop_event = stop_event
self._start_time: float | None = None
def start(self) -> None:
@ -69,16 +70,14 @@ class Dispatcher:
if self._thread and self._thread.is_alive():
return
self._stop_event.clear()
self._start_time = time.time()
self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True)
self._thread.start()
def stop(self) -> None:
"""Stop the dispatcher thread."""
self._stop_event.set()
if self._thread and self._thread.is_alive():
self._thread.join(timeout=10.0)
self._thread.join(timeout=2.0)
def _dispatcher_loop(self) -> None:
"""Main dispatcher loop."""

View File

@ -2,6 +2,8 @@
Factory for creating ReadyQueue instances from serialized state.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from .in_memory import InMemoryReadyQueue
@ -11,7 +13,7 @@ if TYPE_CHECKING:
from .protocol import ReadyQueue
def create_ready_queue_from_state(state: ReadyQueueState) -> "ReadyQueue":
def create_ready_queue_from_state(state: ReadyQueueState) -> ReadyQueue:
"""
Create a ReadyQueue instance from a serialized state.

View File

@ -5,6 +5,8 @@ This module contains the private ResponseSession class used internally
by ResponseStreamCoordinator to manage streaming sessions.
"""
from __future__ import annotations
from dataclasses import dataclass
from core.workflow.nodes.answer.answer_node import AnswerNode
@ -27,7 +29,7 @@ class ResponseSession:
index: int = 0 # Current position in the template segments
@classmethod
def from_node(cls, node: Node) -> "ResponseSession":
def from_node(cls, node: Node) -> ResponseSession:
"""
Create a ResponseSession from an AnswerNode or EndNode.

View File

@ -42,6 +42,7 @@ class Worker(threading.Thread):
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
layers: Sequence[GraphEngineLayer],
stop_event: threading.Event,
worker_id: int = 0,
flask_app: Flask | None = None,
context_vars: contextvars.Context | None = None,
@ -65,13 +66,16 @@ class Worker(threading.Thread):
self._worker_id = worker_id
self._flask_app = flask_app
self._context_vars = context_vars
self._stop_event = threading.Event()
self._last_task_time = time.time()
self._stop_event = stop_event
self._layers = layers if layers is not None else []
def stop(self) -> None:
"""Signal the worker to stop processing."""
self._stop_event.set()
"""Worker is controlled via shared stop_event from GraphEngine.
This method is a no-op retained for backward compatibility.
"""
pass
@property
def is_idle(self) -> bool:

View File

@ -41,6 +41,7 @@ class WorkerPool:
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
layers: list[GraphEngineLayer],
stop_event: threading.Event,
flask_app: "Flask | None" = None,
context_vars: "Context | None" = None,
min_workers: int | None = None,
@ -81,6 +82,7 @@ class WorkerPool:
self._worker_counter = 0
self._lock = threading.RLock()
self._running = False
self._stop_event = stop_event
# No longer tracking worker states with callbacks to avoid lock contention
@ -135,7 +137,7 @@ class WorkerPool:
# Wait for workers to finish
for worker in self._workers:
if worker.is_alive():
worker.join(timeout=10.0)
worker.join(timeout=2.0)
self._workers.clear()
@ -152,6 +154,7 @@ class WorkerPool:
worker_id=worker_id,
flask_app=self._flask_app,
context_vars=self._context_vars,
stop_event=self._stop_event,
)
worker.start()

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
@ -167,7 +169,7 @@ class AgentNode(Node[AgentNodeData]):
variable_pool: VariablePool,
node_data: AgentNodeData,
for_log: bool = False,
strategy: "PluginAgentStrategy",
strategy: PluginAgentStrategy,
) -> dict[str, Any]:
"""
Generate parameters based on the given tool parameters, variable pool, and node data.
@ -328,7 +330,7 @@ class AgentNode(Node[AgentNodeData]):
def _generate_credentials(
self,
parameters: dict[str, Any],
) -> "InvokeCredentials":
) -> InvokeCredentials:
"""
Generate credentials based on the given agent parameters.
"""
@ -442,9 +444,7 @@ class AgentNode(Node[AgentNodeData]):
model_schema.features.remove(feature)
return model_schema
def _filter_mcp_type_tool(
self, strategy: "PluginAgentStrategy", tools: list[dict[str, Any]]
) -> list[dict[str, Any]]:
def _filter_mcp_type_tool(self, strategy: PluginAgentStrategy, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Filter MCP type tool
:param strategy: plugin agent strategy

View File

@ -119,3 +119,14 @@ class AgentVariableTypeError(AgentNodeError):
self.expected_type = expected_type
self.actual_type = actual_type
super().__init__(message)
class AgentMaxIterationError(AgentNodeError):
"""Exception raised when the agent exceeds the maximum iteration limit."""
def __init__(self, max_iteration: int):
self.max_iteration = max_iteration
super().__init__(
f"Agent exceeded the maximum iteration limit of {max_iteration}. "
f"The agent was unable to complete the task within the allowed number of iterations."
)

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import json
from abc import ABC
from builtins import type as type_
@ -111,7 +113,7 @@ class DefaultValue(BaseModel):
raise DefaultValueTypeError(f"Cannot convert to number: {value}")
@model_validator(mode="after")
def validate_value_type(self) -> "DefaultValue":
def validate_value_type(self) -> DefaultValue:
# Type validation configuration
type_validators = {
DefaultValueType.STRING: {

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import importlib
import logging
import operator
@ -62,7 +64,7 @@ logger = logging.getLogger(__name__)
class Node(Generic[NodeDataT]):
node_type: ClassVar["NodeType"]
node_type: ClassVar[NodeType]
execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
_node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData
@ -201,14 +203,14 @@ class Node(Generic[NodeDataT]):
return None
# Global registry populated via __init_subclass__
_registry: ClassVar[dict["NodeType", dict[str, type["Node"]]]] = {}
_registry: ClassVar[dict[NodeType, dict[str, type[Node]]]] = {}
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
) -> None:
self._graph_init_params = graph_init_params
self.id = id
@ -244,7 +246,7 @@ class Node(Generic[NodeDataT]):
return
@property
def graph_init_params(self) -> "GraphInitParams":
def graph_init_params(self) -> GraphInitParams:
return self._graph_init_params
@property
@ -267,6 +269,10 @@ class Node(Generic[NodeDataT]):
"""
raise NotImplementedError
def _should_stop(self) -> bool:
"""Check if execution should be stopped."""
return self.graph_runtime_state.stop_event.is_set()
def run(self) -> Generator[GraphNodeEventBase, None, None]:
execution_id = self.ensure_execution_id()
self._start_at = naive_utc_now()
@ -335,6 +341,21 @@ class Node(Generic[NodeDataT]):
yield event
else:
yield event
if self._should_stop():
error_message = "Execution cancelled"
yield NodeRunFailedEvent(
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=error_message,
),
error=error_message,
)
return
except Exception as e:
logger.exception("Node %s failed to run", self._node_id)
result = NodeRunResult(
@ -441,7 +462,7 @@ class Node(Generic[NodeDataT]):
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
@classmethod
def get_node_type_classes_mapping(cls) -> Mapping["NodeType", Mapping[str, type["Node"]]]:
def get_node_type_classes_mapping(cls) -> Mapping[NodeType, Mapping[str, type[Node]]]:
"""Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry.
Import all modules under core.workflow.nodes so subclasses register themselves on import.

View File

@ -4,6 +4,8 @@ This module provides a unified template structure for both Answer and End nodes,
similar to SegmentGroup but focused on template representation without values.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
@ -58,7 +60,7 @@ class Template:
segments: list[TemplateSegmentUnion]
@classmethod
def from_answer_template(cls, template_str: str) -> "Template":
def from_answer_template(cls, template_str: str) -> Template:
"""Create a Template from an Answer node template string.
Example:
@ -107,7 +109,7 @@ class Template:
return cls(segments=segments)
@classmethod
def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> "Template":
def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> Template:
"""Create a Template from an End node outputs configuration.
End nodes are treated as templates of concatenated variables with newlines.

View File

@ -1,8 +1,7 @@
from collections.abc import Mapping, Sequence
from decimal import Decimal
from typing import Any, cast
from typing import TYPE_CHECKING, Any, ClassVar, cast
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
@ -13,6 +12,7 @@ from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.entities import CodeNodeData
from core.workflow.nodes.code.limits import CodeNodeLimits
from .exc import (
CodeNodeError,
@ -20,9 +20,41 @@ from .exc import (
OutputValidationError,
)
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
class CodeNode(Node[CodeNodeData]):
node_type = NodeType.CODE
_DEFAULT_CODE_PROVIDERS: ClassVar[tuple[type[CodeNodeProvider], ...]] = (
Python3CodeProvider,
JavascriptCodeProvider,
)
_limits: CodeNodeLimits
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
code_executor: type[CodeExecutor] | None = None,
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
code_limits: CodeNodeLimits,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._code_executor: type[CodeExecutor] = code_executor or CodeExecutor
self._code_providers: tuple[type[CodeNodeProvider], ...] = (
tuple(code_providers) if code_providers else self._DEFAULT_CODE_PROVIDERS
)
self._limits = code_limits
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@ -35,11 +67,16 @@ class CodeNode(Node[CodeNodeData]):
if filters:
code_language = cast(CodeLanguage, filters.get("code_language", CodeLanguage.PYTHON3))
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
code_provider: type[CodeNodeProvider] = next(p for p in providers if p.is_accept_language(code_language))
code_provider: type[CodeNodeProvider] = next(
provider for provider in cls._DEFAULT_CODE_PROVIDERS if provider.is_accept_language(code_language)
)
return code_provider.get_default_config()
@classmethod
def default_code_providers(cls) -> tuple[type[CodeNodeProvider], ...]:
return cls._DEFAULT_CODE_PROVIDERS
@classmethod
def version(cls) -> str:
return "1"
@ -60,7 +97,8 @@ class CodeNode(Node[CodeNodeData]):
variables[variable_name] = variable.to_object() if variable else None
# Run code
try:
result = CodeExecutor.execute_workflow_code_template(
_ = self._select_code_provider(code_language)
result = self._code_executor.execute_workflow_code_template(
language=code_language,
code=code,
inputs=variables,
@ -75,6 +113,12 @@ class CodeNode(Node[CodeNodeData]):
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
def _select_code_provider(self, code_language: CodeLanguage) -> type[CodeNodeProvider]:
for provider in self._code_providers:
if provider.is_accept_language(code_language):
return provider
raise CodeNodeError(f"Unsupported code language: {code_language}")
def _check_string(self, value: str | None, variable: str) -> str | None:
"""
Check string
@ -85,10 +129,10 @@ class CodeNode(Node[CodeNodeData]):
if value is None:
return None
if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
if len(value) > self._limits.max_string_length:
raise OutputValidationError(
f"The length of output variable `{variable}` must be"
f" less than {dify_config.CODE_MAX_STRING_LENGTH} characters"
f" less than {self._limits.max_string_length} characters"
)
return value.replace("\x00", "")
@ -109,20 +153,20 @@ class CodeNode(Node[CodeNodeData]):
if value is None:
return None
if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
if value > self._limits.max_number or value < self._limits.min_number:
raise OutputValidationError(
f"Output variable `{variable}` is out of range,"
f" it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}."
f" it must be between {self._limits.min_number} and {self._limits.max_number}."
)
if isinstance(value, float):
decimal_value = Decimal(str(value)).normalize()
precision = -decimal_value.as_tuple().exponent if decimal_value.as_tuple().exponent < 0 else 0 # type: ignore[operator]
# raise error if precision is too high
if precision > dify_config.CODE_MAX_PRECISION:
if precision > self._limits.max_precision:
raise OutputValidationError(
f"Output variable `{variable}` has too high precision,"
f" it must be less than {dify_config.CODE_MAX_PRECISION} digits."
f" it must be less than {self._limits.max_precision} digits."
)
return value
@ -137,8 +181,8 @@ class CodeNode(Node[CodeNodeData]):
# TODO(QuantumGhost): Replace native Python lists with `Array*Segment` classes.
# Note that `_transform_result` may produce lists containing `None` values,
# which don't conform to the type requirements of `Array*Segment` classes.
if depth > dify_config.CODE_MAX_DEPTH:
raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.")
if depth > self._limits.max_depth:
raise DepthLimitError(f"Depth limit {self._limits.max_depth} reached, object too deep.")
transformed_result: dict[str, Any] = {}
if output_schema is None:
@ -272,10 +316,10 @@ class CodeNode(Node[CodeNodeData]):
f"Output {prefix}{dot}{output_name} is not an array, got {type(value)} instead."
)
else:
if len(value) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH:
if len(value) > self._limits.max_number_array_length:
raise OutputValidationError(
f"The length of output variable `{prefix}{dot}{output_name}` must be"
f" less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements."
f" less than {self._limits.max_number_array_length} elements."
)
for i, inner_value in enumerate(value):
@ -305,10 +349,10 @@ class CodeNode(Node[CodeNodeData]):
f" got {type(result.get(output_name))} instead."
)
else:
if len(result[output_name]) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH:
if len(result[output_name]) > self._limits.max_string_array_length:
raise OutputValidationError(
f"The length of output variable `{prefix}{dot}{output_name}` must be"
f" less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements."
f" less than {self._limits.max_string_array_length} elements."
)
transformed_result[output_name] = [
@ -326,10 +370,10 @@ class CodeNode(Node[CodeNodeData]):
f" got {type(result.get(output_name))} instead."
)
else:
if len(result[output_name]) > dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH:
if len(result[output_name]) > self._limits.max_object_array_length:
raise OutputValidationError(
f"The length of output variable `{prefix}{dot}{output_name}` must be"
f" less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements."
f" less than {self._limits.max_object_array_length} elements."
)
for i, value in enumerate(result[output_name]):

View File

@ -0,0 +1,13 @@
from dataclasses import dataclass
@dataclass(frozen=True)
class CodeNodeLimits:
max_string_length: int
max_number: int | float
min_number: int | float
max_precision: int
max_depth: int
max_number_array_length: int
max_string_array_length: int
max_object_array_length: int

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import base64
import io
import json
@ -134,7 +136,7 @@ class LLMNode(Node[LLMNodeData]):
# Instance attributes specific to LLMNode.
# Output variable for file
_file_outputs: list["File"]
_file_outputs: list[File]
_llm_file_saver: LLMFileSaver
@ -142,8 +144,8 @@ class LLMNode(Node[LLMNodeData]):
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
*,
llm_file_saver: LLMFileSaver | None = None,
):
@ -445,7 +447,7 @@ class LLMNode(Node[LLMNodeData]):
structured_output_enabled: bool,
structured_output: Mapping[str, Any] | None = None,
file_saver: LLMFileSaver,
file_outputs: list["File"],
file_outputs: list[File],
node_id: str,
node_type: NodeType,
reasoning_format: Literal["separated", "tagged"] = "tagged",
@ -499,7 +501,7 @@ class LLMNode(Node[LLMNodeData]):
*,
invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None],
file_saver: LLMFileSaver,
file_outputs: list["File"],
file_outputs: list[File],
node_id: str,
node_type: NodeType,
reasoning_format: Literal["separated", "tagged"] = "tagged",
@ -675,7 +677,7 @@ class LLMNode(Node[LLMNodeData]):
)
@staticmethod
def _image_file_to_markdown(file: "File", /):
def _image_file_to_markdown(file: File, /):
text_chunk = f"![]({file.generate_url()})"
return text_chunk
@ -924,7 +926,7 @@ class LLMNode(Node[LLMNodeData]):
def fetch_prompt_messages(
*,
sys_query: str | None = None,
sys_files: Sequence["File"],
sys_files: Sequence[File],
context: str | None = None,
memory: TokenBufferMemory | None = None,
model_config: ModelConfigWithCredentialsEntity,
@ -935,7 +937,7 @@ class LLMNode(Node[LLMNodeData]):
variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector],
tenant_id: str,
context_files: list["File"] | None = None,
context_files: list[File] | None = None,
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
prompt_messages: list[PromptMessage] = []
@ -1287,7 +1289,7 @@ class LLMNode(Node[LLMNodeData]):
*,
invoke_result: LLMResult | LLMResultWithStructuredOutput,
saver: LLMFileSaver,
file_outputs: list["File"],
file_outputs: list[File],
reasoning_format: Literal["separated", "tagged"] = "tagged",
request_latency: float | None = None,
) -> ModelInvokeCompletedEvent:
@ -1329,7 +1331,7 @@ class LLMNode(Node[LLMNodeData]):
*,
content: ImagePromptMessageContent,
file_saver: LLMFileSaver,
) -> "File":
) -> File:
"""_save_multimodal_output saves multi-modal contents generated by LLM plugins.
There are two kinds of multimodal outputs:
@ -1379,7 +1381,7 @@ class LLMNode(Node[LLMNodeData]):
*,
contents: str | list[PromptMessageContentUnionTypes] | None,
file_saver: LLMFileSaver,
file_outputs: list["File"],
file_outputs: list[File],
) -> Generator[str, None, None]:
"""Convert intermediate prompt messages into strings and yield them to the caller.

View File

@ -1,10 +1,21 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING, final
from typing_extensions import override
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutor
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.workflow.enums import NodeType
from core.workflow.graph import NodeFactory
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.limits import CodeNodeLimits
from core.workflow.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
Jinja2TemplateRenderer,
)
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from libs.typing import is_str, is_str_dict
from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
@ -27,9 +38,29 @@ class DifyNodeFactory(NodeFactory):
self,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
code_executor: type[CodeExecutor] | None = None,
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
code_limits: CodeNodeLimits | None = None,
template_renderer: Jinja2TemplateRenderer | None = None,
) -> None:
self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state
self._code_executor: type[CodeExecutor] = code_executor or CodeExecutor
self._code_providers: tuple[type[CodeNodeProvider], ...] = (
tuple(code_providers) if code_providers else CodeNode.default_code_providers()
)
self._code_limits = code_limits or CodeNodeLimits(
max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
max_number=dify_config.CODE_MAX_NUMBER,
min_number=dify_config.CODE_MIN_NUMBER,
max_precision=dify_config.CODE_MAX_PRECISION,
max_depth=dify_config.CODE_MAX_DEPTH,
max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH,
max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
)
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
@override
def create_node(self, node_config: dict[str, object]) -> Node:
@ -72,6 +103,25 @@ class DifyNodeFactory(NodeFactory):
raise ValueError(f"No latest version class found for node type: {node_type}")
# Create node instance
if node_type == NodeType.CODE:
return CodeNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
code_executor=self._code_executor,
code_providers=self._code_providers,
code_limits=self._code_limits,
)
if node_type == NodeType.TEMPLATE_TRANSFORM:
return TemplateTransformNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
template_renderer=self._template_renderer,
)
return node_class(
id=node_id,
config=node_config,

View File

@ -0,0 +1,40 @@
from __future__ import annotations
from collections.abc import Mapping
from typing import Any, Protocol
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
class TemplateRenderError(ValueError):
"""Raised when rendering a Jinja2 template fails."""
class Jinja2TemplateRenderer(Protocol):
"""Render Jinja2 templates for template transform nodes."""
def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
"""Render a Jinja2 template with provided variables."""
raise NotImplementedError
class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer):
"""Adapter that renders Jinja2 templates via CodeExecutor."""
_code_executor: type[CodeExecutor]
def __init__(self, code_executor: type[CodeExecutor] | None = None) -> None:
self._code_executor = code_executor or CodeExecutor
def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
try:
result = self._code_executor.execute_workflow_code_template(
language=CodeLanguage.JINJA2, code=template, inputs=variables
)
except CodeExecutionError as exc:
raise TemplateRenderError(str(exc)) from exc
rendered = result.get("result")
if not isinstance(rendered, str):
raise TemplateRenderError("Template render result must be a string.")
return rendered

View File

@ -1,18 +1,44 @@
from collections.abc import Mapping, Sequence
from typing import Any
from typing import TYPE_CHECKING, Any
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
from core.workflow.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
Jinja2TemplateRenderer,
TemplateRenderError,
)
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
class TemplateTransformNode(Node[TemplateTransformNodeData]):
node_type = NodeType.TEMPLATE_TRANSFORM
_template_renderer: Jinja2TemplateRenderer
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
template_renderer: Jinja2TemplateRenderer | None = None,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@ -39,13 +65,11 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
variables[variable_name] = value.to_object() if value else None
# Run code
try:
result = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables
)
except CodeExecutionError as e:
rendered = self._template_renderer.render_template(self.node_data.template, variables)
except TemplateRenderError as e:
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
if len(result["result"]) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
if len(rendered) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
return NodeRunResult(
inputs=variables,
status=WorkflowNodeExecutionStatus.FAILED,
@ -53,7 +77,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": result["result"]}
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": rendered}
)
@classmethod

View File

@ -1,28 +0,0 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.variables.variables import Variable
from extensions.ext_database import db
from models import ConversationVariable
from .exc import VariableOperatorNodeError
class ConversationVariableUpdaterImpl:
def update(self, conversation_id: str, variable: Variable):
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)
with Session(db.engine) as session:
row = session.scalar(stmt)
if not row:
raise VariableOperatorNodeError("conversation variable not found in the database")
row.data = variable.model_dump_json()
session.commit()
def flush(self):
pass
def conversation_variable_updater_factory() -> ConversationVariableUpdaterImpl:
return ConversationVariableUpdaterImpl()

View File

@ -1,9 +1,8 @@
from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, TypeAlias
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.variables import SegmentType, Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities import GraphInitParams
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
@ -11,19 +10,14 @@ from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from ..common.impl import conversation_variable_updater_factory
from .node_data import VariableAssignerData, WriteMode
if TYPE_CHECKING:
from core.workflow.runtime import GraphRuntimeState
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
class VariableAssignerNode(Node[VariableAssignerData]):
node_type = NodeType.VARIABLE_ASSIGNER
_conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY
def __init__(
self,
@ -31,7 +25,6 @@ class VariableAssignerNode(Node[VariableAssignerData]):
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory,
):
super().__init__(
id=id,
@ -39,7 +32,6 @@ class VariableAssignerNode(Node[VariableAssignerData]):
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._conv_var_updater_factory = conv_var_updater_factory
@classmethod
def version(cls) -> str:
@ -96,16 +88,7 @@ class VariableAssignerNode(Node[VariableAssignerData]):
# Over write the variable.
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)
# TODO: Move database operation to the pipeline.
# Update conversation variable.
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
if not conversation_id:
raise VariableOperatorNodeError("conversation_id not found")
conv_var_updater = self._conv_var_updater_factory()
conv_var_updater.update(conversation_id=conversation_id.text, variable=updated_variable)
conv_var_updater.flush()
updated_variables = [common_helpers.variable_to_processed_data(assigned_variable_selector, updated_variable)]
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={

View File

@ -1,24 +1,20 @@
import json
from collections.abc import Mapping, MutableMapping, Sequence
from typing import Any, cast
from typing import TYPE_CHECKING, Any
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import SegmentType, Variable
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
from . import helpers
from .entities import VariableAssignerNodeData, VariableOperationItem
from .enums import InputType, Operation
from .exc import (
ConversationIDNotFoundError,
InputTypeNotSupportedError,
InvalidDataError,
InvalidInputValueError,
@ -26,6 +22,10 @@ from .exc import (
VariableNotFoundError,
)
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem):
selector_node_id = item.variable_selector[0]
@ -53,6 +53,20 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
class VariableAssignerNode(Node[VariableAssignerNodeData]):
node_type = NodeType.VARIABLE_ASSIGNER
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
):
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
"""
Check if this Variable Assigner node blocks the output of specific variables.
@ -70,9 +84,6 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
return False
def _conv_var_updater_factory(self) -> ConversationVariableUpdater:
return conversation_variable_updater_factory()
@classmethod
def version(cls) -> str:
return "2"
@ -179,26 +190,12 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
# remove the duplicated items first.
updated_variable_selectors = list(set(map(tuple, updated_variable_selectors)))
conv_var_updater = self._conv_var_updater_factory()
# Update variables
for selector in updated_variable_selectors:
variable = self.graph_runtime_state.variable_pool.get(selector)
if not isinstance(variable, Variable):
raise VariableNotFoundError(variable_selector=selector)
process_data[variable.name] = variable.value
if variable.selector[0] == CONVERSATION_VARIABLE_NODE_ID:
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
if not conversation_id:
if self.invoke_from != InvokeFrom.DEBUGGER:
raise ConversationIDNotFoundError
else:
conversation_id = conversation_id.value
conv_var_updater.update(
conversation_id=cast(str, conversation_id),
variable=variable,
)
conv_var_updater.flush()
updated_variables = [
common_helpers.variable_to_processed_data(selector, seg)
for selector in updated_variable_selectors

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import abc
from collections.abc import Mapping
from typing import Any, Protocol
@ -23,7 +25,7 @@ class DraftVariableSaverFactory(Protocol):
node_type: NodeType,
node_execution_id: str,
enclosing_node_id: str | None = None,
) -> "DraftVariableSaver":
) -> DraftVariableSaver:
pass

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import importlib
import json
import threading
from collections.abc import Mapping, Sequence
from copy import deepcopy
from dataclasses import dataclass
@ -168,6 +169,7 @@ class GraphRuntimeState:
self._pending_response_coordinator_dump: str | None = None
self._pending_graph_execution_workflow_id: str | None = None
self._paused_nodes: set[str] = set()
self.stop_event: threading.Event = threading.Event()
if graph is not None:
self.attach_graph(graph)

View File

@ -1,4 +1,4 @@
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from typing import Any, Protocol
from core.model_runtime.entities.llm_entities import LLMUsage
@ -9,7 +9,7 @@ from core.workflow.system_variable import SystemVariableReadOnlyView
class ReadOnlyVariablePool(Protocol):
"""Read-only interface for VariablePool."""
def get(self, node_id: str, variable_key: str) -> Segment | None:
def get(self, selector: Sequence[str], /) -> Segment | None:
"""Get a variable value (read-only)."""
...

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from copy import deepcopy
from typing import Any
@ -18,9 +18,9 @@ class ReadOnlyVariablePoolWrapper:
def __init__(self, variable_pool: VariablePool) -> None:
self._variable_pool = variable_pool
def get(self, node_id: str, variable_key: str) -> Segment | None:
def get(self, selector: Sequence[str], /) -> Segment | None:
"""Return a copy of a variable value if present."""
value = self._variable_pool.get([node_id, variable_key])
value = self._variable_pool.get(selector)
return deepcopy(value) if value is not None else None
def get_all_by_node(self, node_id: str) -> Mapping[str, object]:

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import re
from collections import defaultdict
from collections.abc import Mapping, Sequence
@ -267,6 +269,6 @@ class VariablePool(BaseModel):
self.add(selector, value)
@classmethod
def empty(cls) -> "VariablePool":
def empty(cls) -> VariablePool:
"""Create an empty variable pool."""
return cls(system_variables=SystemVariable.empty())

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from collections.abc import Mapping, Sequence
from types import MappingProxyType
from typing import Any
@ -70,7 +72,7 @@ class SystemVariable(BaseModel):
return data
@classmethod
def empty(cls) -> "SystemVariable":
def empty(cls) -> SystemVariable:
return cls()
def to_dict(self) -> dict[SystemVariableKey, Any]:
@ -114,7 +116,7 @@ class SystemVariable(BaseModel):
d[SystemVariableKey.TIMESTAMP] = self.timestamp
return d
def as_view(self) -> "SystemVariableReadOnlyView":
def as_view(self) -> SystemVariableReadOnlyView:
return SystemVariableReadOnlyView(self)