mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 17:08:03 +08:00
Merge branch 'main' into feat/agent-node-v2
This commit is contained in:
@ -20,6 +20,8 @@ from core.app.entities.queue_entities import (
|
||||
QueueTextChunkEvent,
|
||||
)
|
||||
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
|
||||
from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer
|
||||
from core.db.session_factory import session_factory
|
||||
from core.moderation.base import ModerationError
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
from core.variables.variables import VariableUnion
|
||||
@ -40,6 +42,7 @@ from models import Workflow
|
||||
from models.enums import UserFrom
|
||||
from models.model import App, Conversation, Message, MessageAnnotation
|
||||
from models.workflow import ConversationVariable
|
||||
from services.conversation_variable_updater import ConversationVariableUpdater
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -200,6 +203,10 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
)
|
||||
|
||||
workflow_entry.graph_engine.layer(persistence_layer)
|
||||
conversation_variable_layer = ConversationVariablePersistenceLayer(
|
||||
ConversationVariableUpdater(session_factory.get_session_maker())
|
||||
)
|
||||
workflow_entry.graph_engine.layer(conversation_variable_layer)
|
||||
for layer in self._graph_engine_layers:
|
||||
workflow_entry.graph_engine.layer(layer)
|
||||
|
||||
|
||||
@ -471,6 +471,25 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
if node_finish_resp:
|
||||
yield node_finish_resp
|
||||
|
||||
# For ANSWER nodes, check if we need to send a message_replace event
|
||||
# Only send if the final output differs from the accumulated task_state.answer
|
||||
# This happens when variables were updated by variable_assigner during workflow execution
|
||||
if event.node_type == NodeType.ANSWER and event.outputs:
|
||||
final_answer = event.outputs.get("answer")
|
||||
if final_answer is not None and final_answer != self._task_state.answer:
|
||||
logger.info(
|
||||
"ANSWER node final output '%s' differs from accumulated answer '%s', sending message_replace event",
|
||||
final_answer,
|
||||
self._task_state.answer,
|
||||
)
|
||||
# Update the task state answer
|
||||
self._task_state.answer = str(final_answer)
|
||||
# Send message_replace event to update the UI
|
||||
yield self._message_cycle_manager.message_replace_to_stream_response(
|
||||
answer=str(final_answer),
|
||||
reason="variable_update",
|
||||
)
|
||||
|
||||
def _handle_node_failed_events(
|
||||
self,
|
||||
event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],
|
||||
|
||||
@ -130,7 +130,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
pipeline=pipeline, workflow=workflow, start_node_id=start_node_id
|
||||
)
|
||||
documents: list[Document] = []
|
||||
if invoke_from == InvokeFrom.PUBLISHED and not is_retry and not args.get("original_document_id"):
|
||||
if invoke_from == InvokeFrom.PUBLISHED_PIPELINE and not is_retry and not args.get("original_document_id"):
|
||||
from services.dataset_service import DocumentService
|
||||
|
||||
for datasource_info in datasource_info_list:
|
||||
@ -156,7 +156,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
for i, datasource_info in enumerate(datasource_info_list):
|
||||
workflow_run_id = str(uuid.uuid4())
|
||||
document_id = args.get("original_document_id") or None
|
||||
if invoke_from == InvokeFrom.PUBLISHED and not is_retry:
|
||||
if invoke_from == InvokeFrom.PUBLISHED_PIPELINE and not is_retry:
|
||||
document_id = document_id or documents[i].id
|
||||
document_pipeline_execution_log = DocumentPipelineExecutionLog(
|
||||
document_id=document_id,
|
||||
|
||||
@ -42,7 +42,8 @@ class InvokeFrom(StrEnum):
|
||||
# DEBUGGER indicates that this invocation is from
|
||||
# the workflow (or chatflow) edit page.
|
||||
DEBUGGER = "debugger"
|
||||
PUBLISHED = "published"
|
||||
# PUBLISHED_PIPELINE indicates that this invocation runs a published RAG pipeline workflow.
|
||||
PUBLISHED_PIPELINE = "published"
|
||||
|
||||
# VALIDATION indicates that this invocation is from validation.
|
||||
VALIDATION = "validation"
|
||||
|
||||
@ -75,7 +75,7 @@ class AnnotationReplyFeature:
|
||||
AppAnnotationService.add_annotation_history(
|
||||
annotation.id,
|
||||
app_record.id,
|
||||
annotation.question,
|
||||
annotation.question_text,
|
||||
annotation.content,
|
||||
query,
|
||||
user_id,
|
||||
|
||||
60
api/core/app/layers/conversation_variable_persist_layer.py
Normal file
60
api/core/app/layers/conversation_variable_persist_layer.py
Normal file
@ -0,0 +1,60 @@
|
||||
import logging
|
||||
|
||||
from core.variables import Variable
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events import GraphEngineEvent, NodeRunSucceededEvent
|
||||
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConversationVariablePersistenceLayer(GraphEngineLayer):
|
||||
def __init__(self, conversation_variable_updater: ConversationVariableUpdater) -> None:
|
||||
super().__init__()
|
||||
self._conversation_variable_updater = conversation_variable_updater
|
||||
|
||||
def on_graph_start(self) -> None:
|
||||
pass
|
||||
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
if not isinstance(event, NodeRunSucceededEvent):
|
||||
return
|
||||
if event.node_type != NodeType.VARIABLE_ASSIGNER:
|
||||
return
|
||||
if self.graph_runtime_state is None:
|
||||
return
|
||||
|
||||
updated_variables = common_helpers.get_updated_variables(event.node_run_result.process_data) or []
|
||||
if not updated_variables:
|
||||
return
|
||||
|
||||
conversation_id = self.graph_runtime_state.system_variable.conversation_id
|
||||
if conversation_id is None:
|
||||
return
|
||||
|
||||
updated_any = False
|
||||
for item in updated_variables:
|
||||
selector = item.selector
|
||||
if len(selector) < 2:
|
||||
logger.warning("Conversation variable selector invalid. selector=%s", selector)
|
||||
continue
|
||||
if selector[0] != CONVERSATION_VARIABLE_NODE_ID:
|
||||
continue
|
||||
variable = self.graph_runtime_state.variable_pool.get(selector)
|
||||
if not isinstance(variable, Variable):
|
||||
logger.warning(
|
||||
"Conversation variable not found in variable pool. selector=%s",
|
||||
selector,
|
||||
)
|
||||
continue
|
||||
self._conversation_variable_updater.update(conversation_id=conversation_id, variable=variable)
|
||||
updated_any = True
|
||||
|
||||
if updated_any:
|
||||
self._conversation_variable_updater.flush()
|
||||
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
pass
|
||||
@ -66,6 +66,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
|
||||
"""
|
||||
if isinstance(session_factory, Engine):
|
||||
session_factory = sessionmaker(session_factory)
|
||||
super().__init__()
|
||||
self._session_maker = session_factory
|
||||
self._state_owner_user_id = state_owner_user_id
|
||||
self._generate_entity = generate_entity
|
||||
@ -98,8 +99,6 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
|
||||
if not isinstance(event, GraphRunPausedEvent):
|
||||
return
|
||||
|
||||
assert self.graph_runtime_state is not None
|
||||
|
||||
entity_wrapper: _GenerateEntityUnion
|
||||
if isinstance(self._generate_entity, WorkflowAppGenerateEntity):
|
||||
entity_wrapper = _WorkflowGenerateEntityWrapper(entity=self._generate_entity)
|
||||
|
||||
@ -33,6 +33,7 @@ class TriggerPostLayer(GraphEngineLayer):
|
||||
trigger_log_id: str,
|
||||
session_maker: sessionmaker[Session],
|
||||
):
|
||||
super().__init__()
|
||||
self.trigger_log_id = trigger_log_id
|
||||
self.start_time = start_time
|
||||
self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity
|
||||
@ -57,10 +58,6 @@ class TriggerPostLayer(GraphEngineLayer):
|
||||
elapsed_time = (datetime.now(UTC) - self.start_time).total_seconds()
|
||||
|
||||
# Extract relevant data from result
|
||||
if not self.graph_runtime_state:
|
||||
logger.exception("Graph runtime state is not set")
|
||||
return
|
||||
|
||||
outputs = self.graph_runtime_state.outputs
|
||||
|
||||
# BASICLY, workflow_execution_id is the same as workflow_run_id
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from configs import dify_config
|
||||
@ -30,7 +32,7 @@ class DatasourcePlugin(ABC):
|
||||
"""
|
||||
return DatasourceProviderType.LOCAL_FILE
|
||||
|
||||
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
|
||||
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> DatasourcePlugin:
|
||||
return self.__class__(
|
||||
entity=self.entity.model_copy(),
|
||||
runtime=runtime,
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
@ -31,7 +33,7 @@ class DatasourceProviderType(enum.StrEnum):
|
||||
ONLINE_DRIVE = "online_drive"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "DatasourceProviderType":
|
||||
def value_of(cls, value: str) -> DatasourceProviderType:
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
@ -81,7 +83,7 @@ class DatasourceParameter(PluginParameter):
|
||||
typ: DatasourceParameterType,
|
||||
required: bool,
|
||||
options: list[str] | None = None,
|
||||
) -> "DatasourceParameter":
|
||||
) -> DatasourceParameter:
|
||||
"""
|
||||
get a simple datasource parameter
|
||||
|
||||
@ -187,14 +189,14 @@ class DatasourceInvokeMeta(BaseModel):
|
||||
tool_config: dict | None = None
|
||||
|
||||
@classmethod
|
||||
def empty(cls) -> "DatasourceInvokeMeta":
|
||||
def empty(cls) -> DatasourceInvokeMeta:
|
||||
"""
|
||||
Get an empty instance of DatasourceInvokeMeta
|
||||
"""
|
||||
return cls(time_cost=0.0, error=None, tool_config={})
|
||||
|
||||
@classmethod
|
||||
def error_instance(cls, error: str) -> "DatasourceInvokeMeta":
|
||||
def error_instance(cls, error: str) -> DatasourceInvokeMeta:
|
||||
"""
|
||||
Get an instance of DatasourceInvokeMeta with error
|
||||
"""
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from sqlalchemy import Engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
_session_maker: sessionmaker | None = None
|
||||
_session_maker: sessionmaker[Session] | None = None
|
||||
|
||||
|
||||
def configure_session_factory(engine: Engine, expire_on_commit: bool = False):
|
||||
@ -10,7 +10,7 @@ def configure_session_factory(engine: Engine, expire_on_commit: bool = False):
|
||||
_session_maker = sessionmaker(bind=engine, expire_on_commit=expire_on_commit)
|
||||
|
||||
|
||||
def get_session_maker() -> sessionmaker:
|
||||
def get_session_maker() -> sessionmaker[Session]:
|
||||
if _session_maker is None:
|
||||
raise RuntimeError("Session factory not configured. Call configure_session_factory() first.")
|
||||
return _session_maker
|
||||
@ -27,7 +27,7 @@ class SessionFactory:
|
||||
configure_session_factory(engine, expire_on_commit)
|
||||
|
||||
@staticmethod
|
||||
def get_session_maker() -> sessionmaker:
|
||||
def get_session_maker() -> sessionmaker[Session]:
|
||||
return get_session_maker()
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
@ -75,7 +77,7 @@ class MCPProviderEntity(BaseModel):
|
||||
updated_at: datetime
|
||||
|
||||
@classmethod
|
||||
def from_db_model(cls, db_provider: "MCPToolProvider") -> "MCPProviderEntity":
|
||||
def from_db_model(cls, db_provider: MCPToolProvider) -> MCPProviderEntity:
|
||||
"""Create entity from database model with decryption"""
|
||||
|
||||
return cls(
|
||||
|
||||
@ -30,7 +30,6 @@ class SimpleModelProviderEntity(BaseModel):
|
||||
label: I18nObject
|
||||
icon_small: I18nObject | None = None
|
||||
icon_small_dark: I18nObject | None = None
|
||||
icon_large: I18nObject | None = None
|
||||
supported_model_types: list[ModelType]
|
||||
|
||||
def __init__(self, provider_entity: ProviderEntity):
|
||||
@ -44,7 +43,6 @@ class SimpleModelProviderEntity(BaseModel):
|
||||
label=provider_entity.label,
|
||||
icon_small=provider_entity.icon_small,
|
||||
icon_small_dark=provider_entity.icon_small_dark,
|
||||
icon_large=provider_entity.icon_large,
|
||||
supported_model_types=provider_entity.supported_model_types,
|
||||
)
|
||||
|
||||
@ -94,7 +92,6 @@ class DefaultModelProviderEntity(BaseModel):
|
||||
provider: str
|
||||
label: I18nObject
|
||||
icon_small: I18nObject | None = None
|
||||
icon_large: I18nObject | None = None
|
||||
supported_model_types: Sequence[ModelType] = []
|
||||
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import StrEnum, auto
|
||||
from typing import Union
|
||||
|
||||
@ -178,7 +180,7 @@ class BasicProviderConfig(BaseModel):
|
||||
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "ProviderConfig.Type":
|
||||
def value_of(cls, value: str) -> ProviderConfig.Type:
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
|
||||
@ -8,8 +8,9 @@ import urllib.parse
|
||||
from configs import dify_config
|
||||
|
||||
|
||||
def get_signed_file_url(upload_file_id: str, as_attachment=False) -> str:
|
||||
url = f"{dify_config.FILES_URL}/files/{upload_file_id}/file-preview"
|
||||
def get_signed_file_url(upload_file_id: str, as_attachment=False, for_external: bool = True) -> str:
|
||||
base_url = dify_config.FILES_URL if for_external else (dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL)
|
||||
url = f"{base_url}/files/{upload_file_id}/file-preview"
|
||||
|
||||
timestamp = str(int(time.time()))
|
||||
nonce = os.urandom(16).hex()
|
||||
|
||||
@ -112,17 +112,17 @@ class File(BaseModel):
|
||||
|
||||
return text
|
||||
|
||||
def generate_url(self) -> str | None:
|
||||
def generate_url(self, for_external: bool = True) -> str | None:
|
||||
if self.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
return self.remote_url
|
||||
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
if self.related_id is None:
|
||||
raise ValueError("Missing file related_id")
|
||||
return helpers.get_signed_file_url(upload_file_id=self.related_id)
|
||||
return helpers.get_signed_file_url(upload_file_id=self.related_id, for_external=for_external)
|
||||
elif self.transfer_method in [FileTransferMethod.TOOL_FILE, FileTransferMethod.DATASOURCE_FILE]:
|
||||
assert self.related_id is not None
|
||||
assert self.extension is not None
|
||||
return sign_tool_file(tool_file_id=self.related_id, extension=self.extension)
|
||||
return sign_tool_file(tool_file_id=self.related_id, extension=self.extension, for_external=for_external)
|
||||
return None
|
||||
|
||||
def to_plugin_parameter(self) -> dict[str, Any]:
|
||||
@ -133,7 +133,7 @@ class File(BaseModel):
|
||||
"extension": self.extension,
|
||||
"size": self.size,
|
||||
"type": self.type,
|
||||
"url": self.generate_url(),
|
||||
"url": self.generate_url(for_external=False),
|
||||
}
|
||||
|
||||
@model_validator(mode="after")
|
||||
|
||||
@ -76,7 +76,7 @@ class TemplateTransformer(ABC):
|
||||
Post-process the result to convert scientific notation strings back to numbers
|
||||
"""
|
||||
|
||||
def convert_scientific_notation(value):
|
||||
def convert_scientific_notation(value: Any) -> Any:
|
||||
if isinstance(value, str):
|
||||
# Check if the string looks like scientific notation
|
||||
if re.match(r"^-?\d+\.?\d*e[+-]\d+$", value, re.IGNORECASE):
|
||||
@ -90,7 +90,7 @@ class TemplateTransformer(ABC):
|
||||
return [convert_scientific_notation(v) for v in value]
|
||||
return value
|
||||
|
||||
return convert_scientific_notation(result) # type: ignore[no-any-return]
|
||||
return convert_scientific_notation(result)
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
|
||||
@ -88,7 +88,41 @@ def _get_user_provided_host_header(headers: dict | None) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def _inject_trace_headers(headers: dict | None) -> dict:
|
||||
"""
|
||||
Inject W3C traceparent header for distributed tracing.
|
||||
|
||||
When OTEL is enabled, HTTPXClientInstrumentor handles trace propagation automatically.
|
||||
When OTEL is disabled, we manually inject the traceparent header.
|
||||
"""
|
||||
if headers is None:
|
||||
headers = {}
|
||||
|
||||
# Skip if already present (case-insensitive check)
|
||||
for key in headers:
|
||||
if key.lower() == "traceparent":
|
||||
return headers
|
||||
|
||||
# Skip if OTEL is enabled - HTTPXClientInstrumentor handles this automatically
|
||||
if dify_config.ENABLE_OTEL:
|
||||
return headers
|
||||
|
||||
# Generate and inject traceparent for non-OTEL scenarios
|
||||
try:
|
||||
from core.helper.trace_id_helper import generate_traceparent_header
|
||||
|
||||
traceparent = generate_traceparent_header()
|
||||
if traceparent:
|
||||
headers["traceparent"] = traceparent
|
||||
except Exception:
|
||||
# Silently ignore errors to avoid breaking requests
|
||||
logger.debug("Failed to generate traceparent header", exc_info=True)
|
||||
|
||||
return headers
|
||||
|
||||
|
||||
def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
# Convert requests-style allow_redirects to httpx-style follow_redirects
|
||||
if "allow_redirects" in kwargs:
|
||||
allow_redirects = kwargs.pop("allow_redirects")
|
||||
if "follow_redirects" not in kwargs:
|
||||
@ -106,18 +140,21 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
|
||||
client = _get_ssrf_client(verify_option)
|
||||
|
||||
# Inject traceparent header for distributed tracing (when OTEL is not enabled)
|
||||
headers = kwargs.get("headers") or {}
|
||||
headers = _inject_trace_headers(headers)
|
||||
kwargs["headers"] = headers
|
||||
|
||||
# Preserve user-provided Host header
|
||||
# When using a forward proxy, httpx may override the Host header based on the URL.
|
||||
# We extract and preserve any explicitly set Host header to support virtual hosting.
|
||||
headers = kwargs.get("headers", {})
|
||||
user_provided_host = _get_user_provided_host_header(headers)
|
||||
|
||||
retries = 0
|
||||
while retries <= max_retries:
|
||||
try:
|
||||
# Build the request manually to preserve the Host header
|
||||
# httpx may override the Host header when using a proxy, so we use
|
||||
# the request API to explicitly set headers before sending
|
||||
# Preserve the user-provided Host header
|
||||
# httpx may override the Host header when using a proxy
|
||||
headers = {k: v for k, v in headers.items() if k.lower() != "host"}
|
||||
if user_provided_host is not None:
|
||||
headers["host"] = user_provided_host
|
||||
|
||||
@ -103,3 +103,60 @@ def parse_traceparent_header(traceparent: str) -> str | None:
|
||||
if len(parts) == 4 and len(parts[1]) == 32:
|
||||
return parts[1]
|
||||
return None
|
||||
|
||||
|
||||
def get_span_id_from_otel_context() -> str | None:
|
||||
"""
|
||||
Retrieve the current span ID from the active OpenTelemetry trace context.
|
||||
|
||||
Returns:
|
||||
A 16-character hex string representing the span ID, or None if not available.
|
||||
"""
|
||||
try:
|
||||
from opentelemetry.trace import get_current_span
|
||||
from opentelemetry.trace.span import INVALID_SPAN_ID
|
||||
|
||||
span = get_current_span()
|
||||
if not span:
|
||||
return None
|
||||
|
||||
span_context = span.get_span_context()
|
||||
if not span_context or span_context.span_id == INVALID_SPAN_ID:
|
||||
return None
|
||||
|
||||
return f"{span_context.span_id:016x}"
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def generate_traceparent_header() -> str | None:
|
||||
"""
|
||||
Generate a W3C traceparent header from the current context.
|
||||
|
||||
Uses OpenTelemetry context if available, otherwise uses the
|
||||
ContextVar-based trace_id from the logging context.
|
||||
|
||||
Format: {version}-{trace_id}-{span_id}-{flags}
|
||||
Example: 00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01
|
||||
|
||||
Returns:
|
||||
A valid traceparent header string, or None if generation fails.
|
||||
"""
|
||||
import uuid
|
||||
|
||||
# Try OTEL context first
|
||||
trace_id = get_trace_id_from_otel_context()
|
||||
span_id = get_span_id_from_otel_context()
|
||||
|
||||
if trace_id and span_id:
|
||||
return f"00-{trace_id}-{span_id}-01"
|
||||
|
||||
# Fallback: use ContextVar-based trace_id or generate new one
|
||||
from core.logging.context import get_trace_id as get_logging_trace_id
|
||||
|
||||
trace_id = get_logging_trace_id() or uuid.uuid4().hex
|
||||
|
||||
# Generate a new span_id (16 hex chars)
|
||||
span_id = uuid.uuid4().hex[:16]
|
||||
|
||||
return f"00-{trace_id}-{span_id}-01"
|
||||
|
||||
20
api/core/logging/__init__.py
Normal file
20
api/core/logging/__init__.py
Normal file
@ -0,0 +1,20 @@
|
||||
"""Structured logging components for Dify."""
|
||||
|
||||
from core.logging.context import (
|
||||
clear_request_context,
|
||||
get_request_id,
|
||||
get_trace_id,
|
||||
init_request_context,
|
||||
)
|
||||
from core.logging.filters import IdentityContextFilter, TraceContextFilter
|
||||
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||
|
||||
__all__ = [
|
||||
"IdentityContextFilter",
|
||||
"StructuredJSONFormatter",
|
||||
"TraceContextFilter",
|
||||
"clear_request_context",
|
||||
"get_request_id",
|
||||
"get_trace_id",
|
||||
"init_request_context",
|
||||
]
|
||||
35
api/core/logging/context.py
Normal file
35
api/core/logging/context.py
Normal file
@ -0,0 +1,35 @@
|
||||
"""Request context for logging - framework agnostic.
|
||||
|
||||
This module provides request-scoped context variables for logging,
|
||||
using Python's contextvars for thread-safe and async-safe storage.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from contextvars import ContextVar
|
||||
|
||||
_request_id: ContextVar[str] = ContextVar("log_request_id", default="")
|
||||
_trace_id: ContextVar[str] = ContextVar("log_trace_id", default="")
|
||||
|
||||
|
||||
def get_request_id() -> str:
|
||||
"""Get current request ID (10 hex chars)."""
|
||||
return _request_id.get()
|
||||
|
||||
|
||||
def get_trace_id() -> str:
|
||||
"""Get fallback trace ID when OTEL is unavailable (32 hex chars)."""
|
||||
return _trace_id.get()
|
||||
|
||||
|
||||
def init_request_context() -> None:
|
||||
"""Initialize request context. Call at start of each request."""
|
||||
req_id = uuid.uuid4().hex[:10]
|
||||
trace_id = uuid.uuid5(uuid.NAMESPACE_DNS, req_id).hex
|
||||
_request_id.set(req_id)
|
||||
_trace_id.set(trace_id)
|
||||
|
||||
|
||||
def clear_request_context() -> None:
|
||||
"""Clear request context. Call at end of request (optional)."""
|
||||
_request_id.set("")
|
||||
_trace_id.set("")
|
||||
94
api/core/logging/filters.py
Normal file
94
api/core/logging/filters.py
Normal file
@ -0,0 +1,94 @@
|
||||
"""Logging filters for structured logging."""
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
|
||||
import flask
|
||||
|
||||
from core.logging.context import get_request_id, get_trace_id
|
||||
|
||||
|
||||
class TraceContextFilter(logging.Filter):
|
||||
"""
|
||||
Filter that adds trace_id and span_id to log records.
|
||||
Integrates with OpenTelemetry when available, falls back to ContextVar-based trace_id.
|
||||
"""
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
# Get trace context from OpenTelemetry
|
||||
trace_id, span_id = self._get_otel_context()
|
||||
|
||||
# Set trace_id (fallback to ContextVar if no OTEL context)
|
||||
if trace_id:
|
||||
record.trace_id = trace_id
|
||||
else:
|
||||
record.trace_id = get_trace_id()
|
||||
|
||||
record.span_id = span_id or ""
|
||||
|
||||
# For backward compatibility, also set req_id
|
||||
record.req_id = get_request_id()
|
||||
|
||||
return True
|
||||
|
||||
def _get_otel_context(self) -> tuple[str, str]:
|
||||
"""Extract trace_id and span_id from OpenTelemetry context."""
|
||||
with contextlib.suppress(Exception):
|
||||
from opentelemetry.trace import get_current_span
|
||||
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
|
||||
|
||||
span = get_current_span()
|
||||
if span and span.get_span_context():
|
||||
ctx = span.get_span_context()
|
||||
if ctx.is_valid and ctx.trace_id != INVALID_TRACE_ID:
|
||||
trace_id = f"{ctx.trace_id:032x}"
|
||||
span_id = f"{ctx.span_id:016x}" if ctx.span_id != INVALID_SPAN_ID else ""
|
||||
return trace_id, span_id
|
||||
return "", ""
|
||||
|
||||
|
||||
class IdentityContextFilter(logging.Filter):
|
||||
"""
|
||||
Filter that adds user identity context to log records.
|
||||
Extracts tenant_id, user_id, and user_type from Flask-Login current_user.
|
||||
"""
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
identity = self._extract_identity()
|
||||
record.tenant_id = identity.get("tenant_id", "")
|
||||
record.user_id = identity.get("user_id", "")
|
||||
record.user_type = identity.get("user_type", "")
|
||||
return True
|
||||
|
||||
def _extract_identity(self) -> dict[str, str]:
|
||||
"""Extract identity from current_user if in request context."""
|
||||
try:
|
||||
if not flask.has_request_context():
|
||||
return {}
|
||||
from flask_login import current_user
|
||||
|
||||
# Check if user is authenticated using the proxy
|
||||
if not current_user.is_authenticated:
|
||||
return {}
|
||||
|
||||
# Access the underlying user object
|
||||
user = current_user
|
||||
|
||||
from models import Account
|
||||
from models.model import EndUser
|
||||
|
||||
identity: dict[str, str] = {}
|
||||
|
||||
if isinstance(user, Account):
|
||||
if user.current_tenant_id:
|
||||
identity["tenant_id"] = user.current_tenant_id
|
||||
identity["user_id"] = user.id
|
||||
identity["user_type"] = "account"
|
||||
elif isinstance(user, EndUser):
|
||||
identity["tenant_id"] = user.tenant_id
|
||||
identity["user_id"] = user.id
|
||||
identity["user_type"] = user.type or "end_user"
|
||||
|
||||
return identity
|
||||
except Exception:
|
||||
return {}
|
||||
107
api/core/logging/structured_formatter.py
Normal file
107
api/core/logging/structured_formatter.py
Normal file
@ -0,0 +1,107 @@
|
||||
"""Structured JSON log formatter for Dify."""
|
||||
|
||||
import logging
|
||||
import traceback
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
|
||||
class StructuredJSONFormatter(logging.Formatter):
|
||||
"""
|
||||
JSON log formatter following the specified schema:
|
||||
{
|
||||
"ts": "ISO 8601 UTC",
|
||||
"severity": "INFO|ERROR|WARN|DEBUG",
|
||||
"service": "service name",
|
||||
"caller": "file:line",
|
||||
"trace_id": "hex 32",
|
||||
"span_id": "hex 16",
|
||||
"identity": { "tenant_id", "user_id", "user_type" },
|
||||
"message": "log message",
|
||||
"attributes": { ... },
|
||||
"stack_trace": "..."
|
||||
}
|
||||
"""
|
||||
|
||||
SEVERITY_MAP: dict[int, str] = {
|
||||
logging.DEBUG: "DEBUG",
|
||||
logging.INFO: "INFO",
|
||||
logging.WARNING: "WARN",
|
||||
logging.ERROR: "ERROR",
|
||||
logging.CRITICAL: "ERROR",
|
||||
}
|
||||
|
||||
def __init__(self, service_name: str | None = None):
|
||||
super().__init__()
|
||||
self._service_name = service_name or dify_config.APPLICATION_NAME
|
||||
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
log_dict = self._build_log_dict(record)
|
||||
try:
|
||||
return orjson.dumps(log_dict).decode("utf-8")
|
||||
except TypeError:
|
||||
# Fallback: convert non-serializable objects to string
|
||||
import json
|
||||
|
||||
return json.dumps(log_dict, default=str, ensure_ascii=False)
|
||||
|
||||
def _build_log_dict(self, record: logging.LogRecord) -> dict[str, Any]:
|
||||
# Core fields
|
||||
log_dict: dict[str, Any] = {
|
||||
"ts": datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z"),
|
||||
"severity": self.SEVERITY_MAP.get(record.levelno, "INFO"),
|
||||
"service": self._service_name,
|
||||
"caller": f"{record.filename}:{record.lineno}",
|
||||
"message": record.getMessage(),
|
||||
}
|
||||
|
||||
# Trace context (from TraceContextFilter)
|
||||
trace_id = getattr(record, "trace_id", "")
|
||||
span_id = getattr(record, "span_id", "")
|
||||
|
||||
if trace_id:
|
||||
log_dict["trace_id"] = trace_id
|
||||
if span_id:
|
||||
log_dict["span_id"] = span_id
|
||||
|
||||
# Identity context (from IdentityContextFilter)
|
||||
identity = self._extract_identity(record)
|
||||
if identity:
|
||||
log_dict["identity"] = identity
|
||||
|
||||
# Dynamic attributes
|
||||
attributes = getattr(record, "attributes", None)
|
||||
if attributes:
|
||||
log_dict["attributes"] = attributes
|
||||
|
||||
# Stack trace for errors with exceptions
|
||||
if record.exc_info and record.levelno >= logging.ERROR:
|
||||
log_dict["stack_trace"] = self._format_exception(record.exc_info)
|
||||
|
||||
return log_dict
|
||||
|
||||
def _extract_identity(self, record: logging.LogRecord) -> dict[str, str] | None:
|
||||
tenant_id = getattr(record, "tenant_id", None)
|
||||
user_id = getattr(record, "user_id", None)
|
||||
user_type = getattr(record, "user_type", None)
|
||||
|
||||
if not any([tenant_id, user_id, user_type]):
|
||||
return None
|
||||
|
||||
identity: dict[str, str] = {}
|
||||
if tenant_id:
|
||||
identity["tenant_id"] = tenant_id
|
||||
if user_id:
|
||||
identity["user_id"] = user_id
|
||||
if user_type:
|
||||
identity["user_type"] = user_type
|
||||
return identity
|
||||
|
||||
def _format_exception(self, exc_info: tuple[Any, ...]) -> str:
|
||||
if exc_info and exc_info[0] is not None:
|
||||
return "".join(traceback.format_exception(*exc_info))
|
||||
return ""
|
||||
@ -68,13 +68,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
||||
request_id: RequestId,
|
||||
request_meta: RequestParams.Meta | None,
|
||||
request: ReceiveRequestT,
|
||||
session: """BaseSession[
|
||||
SendRequestT,
|
||||
SendNotificationT,
|
||||
SendResultT,
|
||||
ReceiveRequestT,
|
||||
ReceiveNotificationT
|
||||
]""",
|
||||
session: """BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]""",
|
||||
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
|
||||
):
|
||||
self.request_id = request_id
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import StrEnum, auto
|
||||
@ -17,7 +19,7 @@ class PromptMessageRole(StrEnum):
|
||||
TOOL = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "PromptMessageRole":
|
||||
def value_of(cls, value: str) -> PromptMessageRole:
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any
|
||||
@ -20,7 +22,7 @@ class ModelType(StrEnum):
|
||||
TTS = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, origin_model_type: str) -> "ModelType":
|
||||
def value_of(cls, origin_model_type: str) -> ModelType:
|
||||
"""
|
||||
Get model type from origin model type.
|
||||
|
||||
@ -103,7 +105,7 @@ class DefaultParameterName(StrEnum):
|
||||
JSON_SCHEMA = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: Any) -> "DefaultParameterName":
|
||||
def value_of(cls, value: Any) -> DefaultParameterName:
|
||||
"""
|
||||
Get parameter name from value.
|
||||
|
||||
|
||||
@ -100,7 +100,6 @@ class SimpleProviderEntity(BaseModel):
|
||||
label: I18nObject
|
||||
icon_small: I18nObject | None = None
|
||||
icon_small_dark: I18nObject | None = None
|
||||
icon_large: I18nObject | None = None
|
||||
supported_model_types: Sequence[ModelType]
|
||||
models: list[AIModelEntity] = []
|
||||
|
||||
@ -123,7 +122,6 @@ class ProviderEntity(BaseModel):
|
||||
label: I18nObject
|
||||
description: I18nObject | None = None
|
||||
icon_small: I18nObject | None = None
|
||||
icon_large: I18nObject | None = None
|
||||
icon_small_dark: I18nObject | None = None
|
||||
background: str | None = None
|
||||
help: ProviderHelpEntity | None = None
|
||||
@ -157,7 +155,6 @@ class ProviderEntity(BaseModel):
|
||||
provider=self.provider,
|
||||
label=self.label,
|
||||
icon_small=self.icon_small,
|
||||
icon_large=self.icon_large,
|
||||
supported_model_types=self.supported_model_types,
|
||||
models=self.models,
|
||||
)
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
@ -38,7 +40,7 @@ class ModelProviderFactory:
|
||||
plugin_providers = self.get_plugin_model_providers()
|
||||
return [provider.declaration for provider in plugin_providers]
|
||||
|
||||
def get_plugin_model_providers(self) -> Sequence["PluginModelProviderEntity"]:
|
||||
def get_plugin_model_providers(self) -> Sequence[PluginModelProviderEntity]:
|
||||
"""
|
||||
Get all plugin model providers
|
||||
:return: list of plugin model providers
|
||||
@ -76,7 +78,7 @@ class ModelProviderFactory:
|
||||
plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
|
||||
return plugin_model_provider_entity.declaration
|
||||
|
||||
def get_plugin_model_provider(self, provider: str) -> "PluginModelProviderEntity":
|
||||
def get_plugin_model_provider(self, provider: str) -> PluginModelProviderEntity:
|
||||
"""
|
||||
Get plugin model provider
|
||||
:param provider: provider name
|
||||
@ -285,7 +287,7 @@ class ModelProviderFactory:
|
||||
"""
|
||||
Get provider icon
|
||||
:param provider: provider name
|
||||
:param icon_type: icon type (icon_small or icon_large)
|
||||
:param icon_type: icon type (icon_small or icon_small_dark)
|
||||
:param lang: language (zh_Hans or en_US)
|
||||
:return: provider icon
|
||||
"""
|
||||
@ -309,13 +311,7 @@ class ModelProviderFactory:
|
||||
else:
|
||||
file_name = provider_schema.icon_small_dark.en_US
|
||||
else:
|
||||
if not provider_schema.icon_large:
|
||||
raise ValueError(f"Provider {provider} does not have large icon.")
|
||||
|
||||
if lang.lower() == "zh_hans":
|
||||
file_name = provider_schema.icon_large.zh_Hans
|
||||
else:
|
||||
file_name = provider_schema.icon_large.en_US
|
||||
raise ValueError(f"Unsupported icon type: {icon_type}.")
|
||||
|
||||
if not file_name:
|
||||
raise ValueError(f"Provider {provider} does not have icon.")
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
@ -242,7 +244,7 @@ class CredentialType(enum.StrEnum):
|
||||
return [item.value for item in cls]
|
||||
|
||||
@classmethod
|
||||
def of(cls, credential_type: str) -> "CredentialType":
|
||||
def of(cls, credential_type: str) -> CredentialType:
|
||||
type_name = credential_type.lower()
|
||||
if type_name in {"api-key", "api_key"}:
|
||||
return cls.API_KEY
|
||||
|
||||
@ -103,6 +103,9 @@ class BasePluginClient:
|
||||
prepared_headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY
|
||||
prepared_headers.setdefault("Accept-Encoding", "gzip, deflate, br")
|
||||
|
||||
# Inject traceparent header for distributed tracing
|
||||
self._inject_trace_headers(prepared_headers)
|
||||
|
||||
prepared_data: bytes | dict[str, Any] | str | None = (
|
||||
data if isinstance(data, (bytes, str, dict)) or data is None else None
|
||||
)
|
||||
@ -114,6 +117,31 @@ class BasePluginClient:
|
||||
|
||||
return str(url), prepared_headers, prepared_data, params, files
|
||||
|
||||
def _inject_trace_headers(self, headers: dict[str, str]) -> None:
|
||||
"""
|
||||
Inject W3C traceparent header for distributed tracing.
|
||||
|
||||
This ensures trace context is propagated to plugin daemon even if
|
||||
HTTPXClientInstrumentor doesn't cover module-level httpx functions.
|
||||
"""
|
||||
if not dify_config.ENABLE_OTEL:
|
||||
return
|
||||
|
||||
import contextlib
|
||||
|
||||
# Skip if already present (case-insensitive check)
|
||||
for key in headers:
|
||||
if key.lower() == "traceparent":
|
||||
return
|
||||
|
||||
# Inject traceparent - works as fallback when OTEL instrumentation doesn't cover this call
|
||||
with contextlib.suppress(Exception):
|
||||
from core.helper.trace_id_helper import generate_traceparent_header
|
||||
|
||||
traceparent = generate_traceparent_header()
|
||||
if traceparent:
|
||||
headers["traceparent"] = traceparent
|
||||
|
||||
def _stream_request(
|
||||
self,
|
||||
method: str,
|
||||
|
||||
@ -331,7 +331,6 @@ class ProviderManager:
|
||||
provider=provider_schema.provider,
|
||||
label=provider_schema.label,
|
||||
icon_small=provider_schema.icon_small,
|
||||
icon_large=provider_schema.icon_large,
|
||||
supported_model_types=provider_schema.supported_model_types,
|
||||
),
|
||||
)
|
||||
|
||||
@ -27,26 +27,44 @@ class CleanProcessor:
|
||||
pattern = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)"
|
||||
text = re.sub(pattern, "", text)
|
||||
|
||||
# Remove URL but keep Markdown image URLs
|
||||
# First, temporarily replace Markdown image URLs with a placeholder
|
||||
markdown_image_pattern = r"!\[.*?\]\((https?://[^\s)]+)\)"
|
||||
placeholders: list[str] = []
|
||||
# Remove URL but keep Markdown image URLs and link URLs
|
||||
# Replace the ENTIRE markdown link/image with a single placeholder to protect
|
||||
# the link text (which might also be a URL) from being removed
|
||||
markdown_link_pattern = r"\[([^\]]*)\]\((https?://[^)]+)\)"
|
||||
markdown_image_pattern = r"!\[.*?\]\((https?://[^)]+)\)"
|
||||
placeholders: list[tuple[str, str, str]] = [] # (type, text, url)
|
||||
|
||||
def replace_with_placeholder(match, placeholders=placeholders):
|
||||
def replace_markdown_with_placeholder(match, placeholders=placeholders):
|
||||
link_type = "link"
|
||||
link_text = match.group(1)
|
||||
url = match.group(2)
|
||||
placeholder = f"__MARKDOWN_PLACEHOLDER_{len(placeholders)}__"
|
||||
placeholders.append((link_type, link_text, url))
|
||||
return placeholder
|
||||
|
||||
def replace_image_with_placeholder(match, placeholders=placeholders):
|
||||
link_type = "image"
|
||||
url = match.group(1)
|
||||
placeholder = f"__MARKDOWN_IMAGE_URL_{len(placeholders)}__"
|
||||
placeholders.append(url)
|
||||
return f""
|
||||
placeholder = f"__MARKDOWN_PLACEHOLDER_{len(placeholders)}__"
|
||||
placeholders.append((link_type, "image", url))
|
||||
return placeholder
|
||||
|
||||
text = re.sub(markdown_image_pattern, replace_with_placeholder, text)
|
||||
# Protect markdown links first
|
||||
text = re.sub(markdown_link_pattern, replace_markdown_with_placeholder, text)
|
||||
# Then protect markdown images
|
||||
text = re.sub(markdown_image_pattern, replace_image_with_placeholder, text)
|
||||
|
||||
# Now remove all remaining URLs
|
||||
url_pattern = r"https?://[^\s)]+"
|
||||
url_pattern = r"https?://\S+"
|
||||
text = re.sub(url_pattern, "", text)
|
||||
|
||||
# Finally, restore the Markdown image URLs
|
||||
for i, url in enumerate(placeholders):
|
||||
text = text.replace(f"__MARKDOWN_IMAGE_URL_{i}__", url)
|
||||
# Restore the Markdown links and images
|
||||
for i, (link_type, text_or_alt, url) in enumerate(placeholders):
|
||||
placeholder = f"__MARKDOWN_PLACEHOLDER_{i}__"
|
||||
if link_type == "link":
|
||||
text = text.replace(placeholder, f"[{text_or_alt}]({url})")
|
||||
else: # image
|
||||
text = text.replace(placeholder, f"")
|
||||
return text
|
||||
|
||||
def filter_string(self, text):
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
@ -6,7 +8,7 @@ import re
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import clickzetta # type: ignore
|
||||
from pydantic import BaseModel, model_validator
|
||||
@ -76,7 +78,7 @@ class ClickzettaConnectionPool:
|
||||
Manages connection reuse across ClickzettaVector instances.
|
||||
"""
|
||||
|
||||
_instance: Optional["ClickzettaConnectionPool"] = None
|
||||
_instance: ClickzettaConnectionPool | None = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __init__(self):
|
||||
@ -89,7 +91,7 @@ class ClickzettaConnectionPool:
|
||||
self._start_cleanup_thread()
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "ClickzettaConnectionPool":
|
||||
def get_instance(cls) -> ClickzettaConnectionPool:
|
||||
"""Get singleton instance of connection pool."""
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
@ -104,7 +106,7 @@ class ClickzettaConnectionPool:
|
||||
f"{config.workspace}:{config.vcluster}:{config.schema_name}"
|
||||
)
|
||||
|
||||
def _create_connection(self, config: ClickzettaConfig) -> "Connection":
|
||||
def _create_connection(self, config: ClickzettaConfig) -> Connection:
|
||||
"""Create a new ClickZetta connection."""
|
||||
max_retries = 3
|
||||
retry_delay = 1.0
|
||||
@ -134,7 +136,7 @@ class ClickzettaConnectionPool:
|
||||
|
||||
raise RuntimeError(f"Failed to create ClickZetta connection after {max_retries} attempts")
|
||||
|
||||
def _configure_connection(self, connection: "Connection"):
|
||||
def _configure_connection(self, connection: Connection):
|
||||
"""Configure connection session settings."""
|
||||
try:
|
||||
with connection.cursor() as cursor:
|
||||
@ -181,7 +183,7 @@ class ClickzettaConnectionPool:
|
||||
except Exception:
|
||||
logger.exception("Failed to configure connection, continuing with defaults")
|
||||
|
||||
def _is_connection_valid(self, connection: "Connection") -> bool:
|
||||
def _is_connection_valid(self, connection: Connection) -> bool:
|
||||
"""Check if connection is still valid."""
|
||||
try:
|
||||
with connection.cursor() as cursor:
|
||||
@ -190,7 +192,7 @@ class ClickzettaConnectionPool:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def get_connection(self, config: ClickzettaConfig) -> "Connection":
|
||||
def get_connection(self, config: ClickzettaConfig) -> Connection:
|
||||
"""Get a connection from the pool or create a new one."""
|
||||
config_key = self._get_config_key(config)
|
||||
|
||||
@ -221,7 +223,7 @@ class ClickzettaConnectionPool:
|
||||
# No valid connection found, create new one
|
||||
return self._create_connection(config)
|
||||
|
||||
def return_connection(self, config: ClickzettaConfig, connection: "Connection"):
|
||||
def return_connection(self, config: ClickzettaConfig, connection: Connection):
|
||||
"""Return a connection to the pool."""
|
||||
config_key = self._get_config_key(config)
|
||||
|
||||
@ -315,22 +317,22 @@ class ClickzettaVector(BaseVector):
|
||||
self._connection_pool = ClickzettaConnectionPool.get_instance()
|
||||
self._init_write_queue()
|
||||
|
||||
def _get_connection(self) -> "Connection":
|
||||
def _get_connection(self) -> Connection:
|
||||
"""Get a connection from the pool."""
|
||||
return self._connection_pool.get_connection(self._config)
|
||||
|
||||
def _return_connection(self, connection: "Connection"):
|
||||
def _return_connection(self, connection: Connection):
|
||||
"""Return a connection to the pool."""
|
||||
self._connection_pool.return_connection(self._config, connection)
|
||||
|
||||
class ConnectionContext:
|
||||
"""Context manager for borrowing and returning connections."""
|
||||
|
||||
def __init__(self, vector_instance: "ClickzettaVector"):
|
||||
def __init__(self, vector_instance: ClickzettaVector):
|
||||
self.vector = vector_instance
|
||||
self.connection: Connection | None = None
|
||||
|
||||
def __enter__(self) -> "Connection":
|
||||
def __enter__(self) -> Connection:
|
||||
self.connection = self.vector._get_connection()
|
||||
return self.connection
|
||||
|
||||
@ -338,7 +340,7 @@ class ClickzettaVector(BaseVector):
|
||||
if self.connection:
|
||||
self.vector._return_connection(self.connection)
|
||||
|
||||
def get_connection_context(self) -> "ClickzettaVector.ConnectionContext":
|
||||
def get_connection_context(self) -> ClickzettaVector.ConnectionContext:
|
||||
"""Get a connection context manager."""
|
||||
return self.ConnectionContext(self)
|
||||
|
||||
@ -437,7 +439,7 @@ class ClickzettaVector(BaseVector):
|
||||
"""Return the vector database type."""
|
||||
return "clickzetta"
|
||||
|
||||
def _ensure_connection(self) -> "Connection":
|
||||
def _ensure_connection(self) -> Connection:
|
||||
"""Get a connection from the pool."""
|
||||
return self._get_connection()
|
||||
|
||||
@ -984,9 +986,11 @@ class ClickzettaVector(BaseVector):
|
||||
|
||||
# No need for dataset_id filter since each dataset has its own table
|
||||
|
||||
# Use simple quote escaping for LIKE clause
|
||||
escaped_query = query.replace("'", "''")
|
||||
filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%'")
|
||||
# Escape special characters for LIKE clause to prevent SQL injection
|
||||
from libs.helper import escape_like_pattern
|
||||
|
||||
escaped_query = escape_like_pattern(query).replace("'", "''")
|
||||
filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%' ESCAPE '\\\\'")
|
||||
where_clause = " AND ".join(filter_clauses)
|
||||
|
||||
search_sql = f"""
|
||||
|
||||
@ -287,11 +287,15 @@ class IrisVector(BaseVector):
|
||||
cursor.execute(sql, (query,))
|
||||
else:
|
||||
# Fallback to LIKE search (inefficient for large datasets)
|
||||
query_pattern = f"%{query}%"
|
||||
# Escape special characters for LIKE clause to prevent SQL injection
|
||||
from libs.helper import escape_like_pattern
|
||||
|
||||
escaped_query = escape_like_pattern(query)
|
||||
query_pattern = f"%{escaped_query}%"
|
||||
sql = f"""
|
||||
SELECT TOP {top_k} id, text, meta
|
||||
FROM {self.schema}.{self.table_name}
|
||||
WHERE text LIKE ?
|
||||
WHERE text LIKE ? ESCAPE '\\'
|
||||
"""
|
||||
cursor.execute(sql, (query_pattern,))
|
||||
|
||||
|
||||
@ -66,6 +66,8 @@ class WeaviateVector(BaseVector):
|
||||
in a Weaviate collection.
|
||||
"""
|
||||
|
||||
_DOCUMENT_ID_PROPERTY = "document_id"
|
||||
|
||||
def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list):
|
||||
"""
|
||||
Initializes the Weaviate vector store.
|
||||
@ -353,15 +355,12 @@ class WeaviateVector(BaseVector):
|
||||
return []
|
||||
|
||||
col = self._client.collections.use(self._collection_name)
|
||||
props = list({*self._attributes, "document_id", Field.TEXT_KEY.value})
|
||||
props = list({*self._attributes, self._DOCUMENT_ID_PROPERTY, Field.TEXT_KEY.value})
|
||||
|
||||
where = None
|
||||
doc_ids = kwargs.get("document_ids_filter") or []
|
||||
if doc_ids:
|
||||
ors = [Filter.by_property("document_id").equal(x) for x in doc_ids]
|
||||
where = ors[0]
|
||||
for f in ors[1:]:
|
||||
where = where | f
|
||||
where = Filter.by_property(self._DOCUMENT_ID_PROPERTY).contains_any(doc_ids)
|
||||
|
||||
top_k = int(kwargs.get("top_k", 4))
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
@ -408,10 +407,7 @@ class WeaviateVector(BaseVector):
|
||||
where = None
|
||||
doc_ids = kwargs.get("document_ids_filter") or []
|
||||
if doc_ids:
|
||||
ors = [Filter.by_property("document_id").equal(x) for x in doc_ids]
|
||||
where = ors[0]
|
||||
for f in ors[1:]:
|
||||
where = where | f
|
||||
where = Filter.by_property(self._DOCUMENT_ID_PROPERTY).contains_any(doc_ids)
|
||||
|
||||
top_k = int(kwargs.get("top_k", 4))
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
@ -22,7 +24,7 @@ class DatasetDocumentStore:
|
||||
self._document_id = document_id
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: dict[str, Any]) -> "DatasetDocumentStore":
|
||||
def from_dict(cls, config_dict: dict[str, Any]) -> DatasetDocumentStore:
|
||||
return cls(**config_dict)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
|
||||
@ -112,7 +112,7 @@ class ExtractProcessor:
|
||||
if file_extension in {".xlsx", ".xls"}:
|
||||
extractor = ExcelExtractor(file_path)
|
||||
elif file_extension == ".pdf":
|
||||
extractor = PdfExtractor(file_path)
|
||||
extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
|
||||
elif file_extension in {".md", ".markdown", ".mdx"}:
|
||||
extractor = (
|
||||
UnstructuredMarkdownExtractor(file_path, unstructured_api_url, unstructured_api_key)
|
||||
@ -148,7 +148,7 @@ class ExtractProcessor:
|
||||
if file_extension in {".xlsx", ".xls"}:
|
||||
extractor = ExcelExtractor(file_path)
|
||||
elif file_extension == ".pdf":
|
||||
extractor = PdfExtractor(file_path)
|
||||
extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
|
||||
elif file_extension in {".md", ".markdown", ".mdx"}:
|
||||
extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
|
||||
elif file_extension in {".htm", ".html"}:
|
||||
|
||||
@ -1,25 +1,57 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
|
||||
import contextlib
|
||||
import io
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Iterator
|
||||
|
||||
import pypdfium2
|
||||
import pypdfium2.raw as pdfium_c
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.extractor.blob.blob import Blob
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import UploadFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PdfExtractor(BaseExtractor):
|
||||
"""Load pdf files.
|
||||
|
||||
"""
|
||||
PdfExtractor is used to extract text and images from PDF files.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
file_path: Path to the PDF file.
|
||||
tenant_id: Workspace ID.
|
||||
user_id: ID of the user performing the extraction.
|
||||
file_cache_key: Optional cache key for the extracted text.
|
||||
"""
|
||||
|
||||
def __init__(self, file_path: str, file_cache_key: str | None = None):
|
||||
"""Initialize with file path."""
|
||||
# Magic bytes for image format detection: (magic_bytes, extension, mime_type)
|
||||
IMAGE_FORMATS = [
|
||||
(b"\xff\xd8\xff", "jpg", "image/jpeg"),
|
||||
(b"\x89PNG\r\n\x1a\n", "png", "image/png"),
|
||||
(b"\x00\x00\x00\x0c\x6a\x50\x20\x20\x0d\x0a\x87\x0a", "jp2", "image/jp2"),
|
||||
(b"GIF8", "gif", "image/gif"),
|
||||
(b"BM", "bmp", "image/bmp"),
|
||||
(b"II*\x00", "tiff", "image/tiff"),
|
||||
(b"MM\x00*", "tiff", "image/tiff"),
|
||||
(b"II+\x00", "tiff", "image/tiff"),
|
||||
(b"MM\x00+", "tiff", "image/tiff"),
|
||||
]
|
||||
MAX_MAGIC_LEN = max(len(m) for m, _, _ in IMAGE_FORMATS)
|
||||
|
||||
def __init__(self, file_path: str, tenant_id: str, user_id: str, file_cache_key: str | None = None):
|
||||
"""Initialize PdfExtractor."""
|
||||
self._file_path = file_path
|
||||
self._tenant_id = tenant_id
|
||||
self._user_id = user_id
|
||||
self._file_cache_key = file_cache_key
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
@ -50,7 +82,6 @@ class PdfExtractor(BaseExtractor):
|
||||
|
||||
def parse(self, blob: Blob) -> Iterator[Document]:
|
||||
"""Lazily parse the blob."""
|
||||
import pypdfium2 # type: ignore
|
||||
|
||||
with blob.as_bytes_io() as file_path:
|
||||
pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True)
|
||||
@ -59,8 +90,87 @@ class PdfExtractor(BaseExtractor):
|
||||
text_page = page.get_textpage()
|
||||
content = text_page.get_text_range()
|
||||
text_page.close()
|
||||
|
||||
image_content = self._extract_images(page)
|
||||
if image_content:
|
||||
content += "\n" + image_content
|
||||
|
||||
page.close()
|
||||
metadata = {"source": blob.source, "page": page_number}
|
||||
yield Document(page_content=content, metadata=metadata)
|
||||
finally:
|
||||
pdf_reader.close()
|
||||
|
||||
def _extract_images(self, page) -> str:
|
||||
"""
|
||||
Extract images from a PDF page, save them to storage and database,
|
||||
and return markdown image links.
|
||||
|
||||
Args:
|
||||
page: pypdfium2 page object.
|
||||
|
||||
Returns:
|
||||
Markdown string containing links to the extracted images.
|
||||
"""
|
||||
image_content = []
|
||||
upload_files = []
|
||||
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
|
||||
|
||||
try:
|
||||
image_objects = page.get_objects(filter=(pdfium_c.FPDF_PAGEOBJ_IMAGE,))
|
||||
for obj in image_objects:
|
||||
try:
|
||||
# Extract image bytes
|
||||
img_byte_arr = io.BytesIO()
|
||||
# Extract DCTDecode (JPEG) and JPXDecode (JPEG 2000) images directly
|
||||
# Fallback to png for other formats
|
||||
obj.extract(img_byte_arr, fb_format="png")
|
||||
img_bytes = img_byte_arr.getvalue()
|
||||
|
||||
if not img_bytes:
|
||||
continue
|
||||
|
||||
header = img_bytes[: self.MAX_MAGIC_LEN]
|
||||
image_ext = None
|
||||
mime_type = None
|
||||
for magic, ext, mime in self.IMAGE_FORMATS:
|
||||
if header.startswith(magic):
|
||||
image_ext = ext
|
||||
mime_type = mime
|
||||
break
|
||||
|
||||
if not image_ext or not mime_type:
|
||||
continue
|
||||
|
||||
file_uuid = str(uuid.uuid4())
|
||||
file_key = "image_files/" + self._tenant_id + "/" + file_uuid + "." + image_ext
|
||||
|
||||
storage.save(file_key, img_bytes)
|
||||
|
||||
# save file to db
|
||||
upload_file = UploadFile(
|
||||
tenant_id=self._tenant_id,
|
||||
storage_type=dify_config.STORAGE_TYPE,
|
||||
key=file_key,
|
||||
name=file_key,
|
||||
size=len(img_bytes),
|
||||
extension=image_ext,
|
||||
mime_type=mime_type,
|
||||
created_by=self._user_id,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_at=naive_utc_now(),
|
||||
used=True,
|
||||
used_by=self._user_id,
|
||||
used_at=naive_utc_now(),
|
||||
)
|
||||
upload_files.append(upload_file)
|
||||
image_content.append(f"")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to extract image from PDF: %s", e)
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning("Failed to get objects from PDF page: %s", e)
|
||||
if upload_files:
|
||||
db.session.add_all(upload_files)
|
||||
db.session.commit()
|
||||
return "\n".join(image_content)
|
||||
|
||||
@ -7,10 +7,11 @@ import re
|
||||
import tempfile
|
||||
import uuid
|
||||
from urllib.parse import urlparse
|
||||
from xml.etree import ElementTree
|
||||
|
||||
import httpx
|
||||
from docx import Document as DocxDocument
|
||||
from docx.oxml.ns import qn
|
||||
from docx.text.run import Run
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper import ssrf_proxy
|
||||
@ -229,44 +230,20 @@ class WordExtractor(BaseExtractor):
|
||||
|
||||
image_map = self._extract_images_from_docx(doc)
|
||||
|
||||
hyperlinks_url = None
|
||||
url_pattern = re.compile(r"http://[^\s+]+//|https://[^\s+]+")
|
||||
for para in doc.paragraphs:
|
||||
for run in para.runs:
|
||||
if run.text and hyperlinks_url:
|
||||
result = f" [{run.text}]({hyperlinks_url}) "
|
||||
run.text = result
|
||||
hyperlinks_url = None
|
||||
if "HYPERLINK" in run.element.xml:
|
||||
try:
|
||||
xml = ElementTree.XML(run.element.xml)
|
||||
x_child = [c for c in xml.iter() if c is not None]
|
||||
for x in x_child:
|
||||
if x is None:
|
||||
continue
|
||||
if x.tag.endswith("instrText"):
|
||||
if x.text is None:
|
||||
continue
|
||||
for i in url_pattern.findall(x.text):
|
||||
hyperlinks_url = str(i)
|
||||
except Exception:
|
||||
logger.exception("Failed to parse HYPERLINK xml")
|
||||
|
||||
def parse_paragraph(paragraph):
|
||||
paragraph_content = []
|
||||
|
||||
def append_image_link(image_id, has_drawing):
|
||||
def append_image_link(image_id, has_drawing, target_buffer):
|
||||
"""Helper to append image link from image_map based on relationship type."""
|
||||
rel = doc.part.rels[image_id]
|
||||
if rel.is_external:
|
||||
if image_id in image_map and not has_drawing:
|
||||
paragraph_content.append(image_map[image_id])
|
||||
target_buffer.append(image_map[image_id])
|
||||
else:
|
||||
image_part = rel.target_part
|
||||
if image_part in image_map and not has_drawing:
|
||||
paragraph_content.append(image_map[image_part])
|
||||
target_buffer.append(image_map[image_part])
|
||||
|
||||
for run in paragraph.runs:
|
||||
def process_run(run, target_buffer):
|
||||
# Helper to extract text and embedded images from a run element and append them to target_buffer
|
||||
if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"):
|
||||
# Process drawing type images
|
||||
drawing_elements = run.element.findall(
|
||||
@ -287,13 +264,13 @@ class WordExtractor(BaseExtractor):
|
||||
# External image: use embed_id as key
|
||||
if embed_id in image_map:
|
||||
has_drawing = True
|
||||
paragraph_content.append(image_map[embed_id])
|
||||
target_buffer.append(image_map[embed_id])
|
||||
else:
|
||||
# Internal image: use target_part as key
|
||||
image_part = doc.part.related_parts.get(embed_id)
|
||||
if image_part in image_map:
|
||||
has_drawing = True
|
||||
paragraph_content.append(image_map[image_part])
|
||||
target_buffer.append(image_map[image_part])
|
||||
# Process pict type images
|
||||
shape_elements = run.element.findall(
|
||||
".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}pict"
|
||||
@ -308,7 +285,7 @@ class WordExtractor(BaseExtractor):
|
||||
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
|
||||
)
|
||||
if image_id and image_id in doc.part.rels:
|
||||
append_image_link(image_id, has_drawing)
|
||||
append_image_link(image_id, has_drawing, target_buffer)
|
||||
# Find imagedata element in VML
|
||||
image_data = shape.find(".//{urn:schemas-microsoft-com:vml}imagedata")
|
||||
if image_data is not None:
|
||||
@ -316,9 +293,93 @@ class WordExtractor(BaseExtractor):
|
||||
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
|
||||
)
|
||||
if image_id and image_id in doc.part.rels:
|
||||
append_image_link(image_id, has_drawing)
|
||||
append_image_link(image_id, has_drawing, target_buffer)
|
||||
if run.text.strip():
|
||||
paragraph_content.append(run.text.strip())
|
||||
target_buffer.append(run.text.strip())
|
||||
|
||||
def process_hyperlink(hyperlink_elem, target_buffer):
|
||||
# Helper to extract text from a hyperlink element and append it to target_buffer
|
||||
r_id = hyperlink_elem.get(qn("r:id"))
|
||||
|
||||
# Extract text from runs inside the hyperlink
|
||||
link_text_parts = []
|
||||
for run_elem in hyperlink_elem.findall(qn("w:r")):
|
||||
run = Run(run_elem, paragraph)
|
||||
# Hyperlink text may be split across multiple runs (e.g., with different formatting),
|
||||
# so collect all run texts first
|
||||
if run.text:
|
||||
link_text_parts.append(run.text)
|
||||
|
||||
link_text = "".join(link_text_parts).strip()
|
||||
|
||||
# Resolve URL
|
||||
if r_id:
|
||||
try:
|
||||
rel = doc.part.rels.get(r_id)
|
||||
if rel and rel.is_external:
|
||||
link_text = f"[{link_text or rel.target_ref}]({rel.target_ref})"
|
||||
except Exception:
|
||||
logger.exception("Failed to resolve URL for hyperlink with r:id: %s", r_id)
|
||||
|
||||
if link_text:
|
||||
target_buffer.append(link_text)
|
||||
|
||||
paragraph_content = []
|
||||
# State for legacy HYPERLINK fields
|
||||
hyperlink_field_url = None
|
||||
hyperlink_field_text_parts: list = []
|
||||
is_collecting_field_text = False
|
||||
# Iterate through paragraph elements in document order
|
||||
for child in paragraph._element:
|
||||
tag = child.tag
|
||||
if tag == qn("w:r"):
|
||||
# Regular run
|
||||
run = Run(child, paragraph)
|
||||
|
||||
# Check for fldChar (begin/end/separate) and instrText for legacy hyperlinks
|
||||
fld_chars = child.findall(qn("w:fldChar"))
|
||||
instr_texts = child.findall(qn("w:instrText"))
|
||||
|
||||
# Handle Fields
|
||||
if fld_chars or instr_texts:
|
||||
# Process instrText to find HYPERLINK "url"
|
||||
for instr in instr_texts:
|
||||
if instr.text and "HYPERLINK" in instr.text:
|
||||
# Quick regex to extract URL
|
||||
match = re.search(r'HYPERLINK\s+"([^"]+)"', instr.text, re.IGNORECASE)
|
||||
if match:
|
||||
hyperlink_field_url = match.group(1)
|
||||
|
||||
# Process fldChar
|
||||
for fld_char in fld_chars:
|
||||
fld_char_type = fld_char.get(qn("w:fldCharType"))
|
||||
if fld_char_type == "begin":
|
||||
# Start of a field: reset legacy link state
|
||||
hyperlink_field_url = None
|
||||
hyperlink_field_text_parts = []
|
||||
is_collecting_field_text = False
|
||||
elif fld_char_type == "separate":
|
||||
# Separator: if we found a URL, start collecting visible text
|
||||
if hyperlink_field_url:
|
||||
is_collecting_field_text = True
|
||||
elif fld_char_type == "end":
|
||||
# End of field
|
||||
if is_collecting_field_text and hyperlink_field_url:
|
||||
# Create markdown link and append to main content
|
||||
display_text = "".join(hyperlink_field_text_parts).strip()
|
||||
if display_text:
|
||||
link_md = f"[{display_text}]({hyperlink_field_url})"
|
||||
paragraph_content.append(link_md)
|
||||
# Reset state
|
||||
hyperlink_field_url = None
|
||||
hyperlink_field_text_parts = []
|
||||
is_collecting_field_text = False
|
||||
|
||||
# Decide where to append content
|
||||
target_buffer = hyperlink_field_text_parts if is_collecting_field_text else paragraph_content
|
||||
process_run(run, target_buffer)
|
||||
elif tag == qn("w:hyperlink"):
|
||||
process_hyperlink(child, paragraph_content)
|
||||
return "".join(paragraph_content) if paragraph_content else ""
|
||||
|
||||
paragraphs = doc.paragraphs.copy()
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
@ -16,7 +18,7 @@ class TaskWrapper(BaseModel):
|
||||
return self.model_dump_json()
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, serialized_data: str) -> "TaskWrapper":
|
||||
def deserialize(cls, serialized_data: str) -> TaskWrapper:
|
||||
return cls.model_validate_json(serialized_data)
|
||||
|
||||
|
||||
|
||||
@ -515,6 +515,7 @@ class DatasetRetrieval:
|
||||
0
|
||||
].embedding_model_provider
|
||||
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
|
||||
dataset_count = len(available_datasets)
|
||||
with measure_time() as timer:
|
||||
cancel_event = threading.Event()
|
||||
thread_exceptions: list[Exception] = []
|
||||
@ -537,6 +538,7 @@ class DatasetRetrieval:
|
||||
"score_threshold": score_threshold,
|
||||
"query": query,
|
||||
"attachment_id": None,
|
||||
"dataset_count": dataset_count,
|
||||
"cancel_event": cancel_event,
|
||||
"thread_exceptions": thread_exceptions,
|
||||
},
|
||||
@ -562,6 +564,7 @@ class DatasetRetrieval:
|
||||
"score_threshold": score_threshold,
|
||||
"query": None,
|
||||
"attachment_id": attachment_id,
|
||||
"dataset_count": dataset_count,
|
||||
"cancel_event": cancel_event,
|
||||
"thread_exceptions": thread_exceptions,
|
||||
},
|
||||
@ -1195,18 +1198,24 @@ class DatasetRetrieval:
|
||||
|
||||
json_field = DatasetDocument.doc_metadata[metadata_name].as_string()
|
||||
|
||||
from libs.helper import escape_like_pattern
|
||||
|
||||
match condition:
|
||||
case "contains":
|
||||
filters.append(json_field.like(f"%{value}%"))
|
||||
escaped_value = escape_like_pattern(str(value))
|
||||
filters.append(json_field.like(f"%{escaped_value}%", escape="\\"))
|
||||
|
||||
case "not contains":
|
||||
filters.append(json_field.notlike(f"%{value}%"))
|
||||
escaped_value = escape_like_pattern(str(value))
|
||||
filters.append(json_field.notlike(f"%{escaped_value}%", escape="\\"))
|
||||
|
||||
case "start with":
|
||||
filters.append(json_field.like(f"{value}%"))
|
||||
escaped_value = escape_like_pattern(str(value))
|
||||
filters.append(json_field.like(f"{escaped_value}%", escape="\\"))
|
||||
|
||||
case "end with":
|
||||
filters.append(json_field.like(f"%{value}"))
|
||||
escaped_value = escape_like_pattern(str(value))
|
||||
filters.append(json_field.like(f"%{escaped_value}", escape="\\"))
|
||||
|
||||
case "is" | "=":
|
||||
if isinstance(value, str):
|
||||
@ -1422,6 +1431,7 @@ class DatasetRetrieval:
|
||||
score_threshold: float,
|
||||
query: str | None,
|
||||
attachment_id: str | None,
|
||||
dataset_count: int,
|
||||
cancel_event: threading.Event | None = None,
|
||||
thread_exceptions: list[Exception] | None = None,
|
||||
):
|
||||
@ -1470,37 +1480,38 @@ class DatasetRetrieval:
|
||||
if cancel_event and cancel_event.is_set():
|
||||
break
|
||||
|
||||
if reranking_enable:
|
||||
# do rerank for searched documents
|
||||
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
|
||||
if query:
|
||||
all_documents_item = data_post_processor.invoke(
|
||||
query=query,
|
||||
documents=all_documents_item,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_k,
|
||||
query_type=QueryType.TEXT_QUERY,
|
||||
)
|
||||
if attachment_id:
|
||||
all_documents_item = data_post_processor.invoke(
|
||||
documents=all_documents_item,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_k,
|
||||
query_type=QueryType.IMAGE_QUERY,
|
||||
query=attachment_id,
|
||||
)
|
||||
else:
|
||||
if index_type == IndexTechniqueType.ECONOMY:
|
||||
if not query:
|
||||
all_documents_item = []
|
||||
else:
|
||||
all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k)
|
||||
elif index_type == IndexTechniqueType.HIGH_QUALITY:
|
||||
all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold)
|
||||
# Skip second reranking when there is only one dataset
|
||||
if reranking_enable and dataset_count > 1:
|
||||
# do rerank for searched documents
|
||||
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
|
||||
if query:
|
||||
all_documents_item = data_post_processor.invoke(
|
||||
query=query,
|
||||
documents=all_documents_item,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_k,
|
||||
query_type=QueryType.TEXT_QUERY,
|
||||
)
|
||||
if attachment_id:
|
||||
all_documents_item = data_post_processor.invoke(
|
||||
documents=all_documents_item,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_k,
|
||||
query_type=QueryType.IMAGE_QUERY,
|
||||
query=attachment_id,
|
||||
)
|
||||
else:
|
||||
all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
|
||||
if all_documents_item:
|
||||
all_documents.extend(all_documents_item)
|
||||
if index_type == IndexTechniqueType.ECONOMY:
|
||||
if not query:
|
||||
all_documents_item = []
|
||||
else:
|
||||
all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k)
|
||||
elif index_type == IndexTechniqueType.HIGH_QUALITY:
|
||||
all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold)
|
||||
else:
|
||||
all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
|
||||
if all_documents_item:
|
||||
all_documents.extend(all_documents_item)
|
||||
except Exception as e:
|
||||
if cancel_event:
|
||||
cancel_event.set()
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from collections.abc import Mapping, MutableMapping
|
||||
from pathlib import Path
|
||||
from typing import Any, ClassVar, Optional
|
||||
from typing import Any, ClassVar
|
||||
|
||||
|
||||
class SchemaRegistry:
|
||||
@ -11,7 +13,7 @@ class SchemaRegistry:
|
||||
|
||||
logger: ClassVar[logging.Logger] = logging.getLogger(__name__)
|
||||
|
||||
_default_instance: ClassVar[Optional["SchemaRegistry"]] = None
|
||||
_default_instance: ClassVar[SchemaRegistry | None] = None
|
||||
_lock: ClassVar[threading.Lock] = threading.Lock()
|
||||
|
||||
def __init__(self, base_dir: str):
|
||||
@ -20,7 +22,7 @@ class SchemaRegistry:
|
||||
self.metadata: MutableMapping[str, MutableMapping[str, Any]] = {}
|
||||
|
||||
@classmethod
|
||||
def default_registry(cls) -> "SchemaRegistry":
|
||||
def default_registry(cls) -> SchemaRegistry:
|
||||
"""Returns the default schema registry for builtin schemas (thread-safe singleton)"""
|
||||
if cls._default_instance is None:
|
||||
with cls._lock:
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator
|
||||
from copy import deepcopy
|
||||
@ -25,7 +27,7 @@ class Tool(ABC):
|
||||
self.entity = entity
|
||||
self.runtime = runtime
|
||||
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "Tool":
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> Tool:
|
||||
"""
|
||||
fork a new tool with metadata
|
||||
:return: the new tool
|
||||
@ -221,7 +223,7 @@ class Tool(ABC):
|
||||
type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image)
|
||||
)
|
||||
|
||||
def create_file_message(self, file: "File") -> ToolInvokeMessage:
|
||||
def create_file_message(self, file: File) -> ToolInvokeMessage:
|
||||
return ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.FILE,
|
||||
message=ToolInvokeMessage.FileMessage(),
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
from core.tools.__base.tool import Tool
|
||||
@ -24,7 +26,7 @@ class BuiltinTool(Tool):
|
||||
super().__init__(**kwargs)
|
||||
self.provider = provider
|
||||
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "BuiltinTool":
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> BuiltinTool:
|
||||
"""
|
||||
fork a new tool with metadata
|
||||
:return: the new tool
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import Field
|
||||
from sqlalchemy import select
|
||||
|
||||
@ -32,7 +34,7 @@ class ApiToolProviderController(ToolProviderController):
|
||||
self.tools = []
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController":
|
||||
def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> ApiToolProviderController:
|
||||
credentials_schema = [
|
||||
ProviderConfig(
|
||||
name="auth_type",
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import contextlib
|
||||
from collections.abc import Mapping
|
||||
@ -55,7 +57,7 @@ class ToolProviderType(StrEnum):
|
||||
MCP = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "ToolProviderType":
|
||||
def value_of(cls, value: str) -> ToolProviderType:
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
@ -79,7 +81,7 @@ class ApiProviderSchemaType(StrEnum):
|
||||
OPENAI_ACTIONS = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "ApiProviderSchemaType":
|
||||
def value_of(cls, value: str) -> ApiProviderSchemaType:
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
@ -102,7 +104,7 @@ class ApiProviderAuthType(StrEnum):
|
||||
API_KEY_QUERY = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "ApiProviderAuthType":
|
||||
def value_of(cls, value: str) -> ApiProviderAuthType:
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
@ -307,7 +309,7 @@ class ToolParameter(PluginParameter):
|
||||
typ: ToolParameterType,
|
||||
required: bool,
|
||||
options: list[str] | None = None,
|
||||
) -> "ToolParameter":
|
||||
) -> ToolParameter:
|
||||
"""
|
||||
get a simple tool parameter
|
||||
|
||||
@ -429,14 +431,14 @@ class ToolInvokeMeta(BaseModel):
|
||||
tool_config: dict | None = None
|
||||
|
||||
@classmethod
|
||||
def empty(cls) -> "ToolInvokeMeta":
|
||||
def empty(cls) -> ToolInvokeMeta:
|
||||
"""
|
||||
Get an empty instance of ToolInvokeMeta
|
||||
"""
|
||||
return cls(time_cost=0.0, error=None, tool_config={})
|
||||
|
||||
@classmethod
|
||||
def error_instance(cls, error: str) -> "ToolInvokeMeta":
|
||||
def error_instance(cls, error: str) -> ToolInvokeMeta:
|
||||
"""
|
||||
Get an instance of ToolInvokeMeta with error
|
||||
"""
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
@ -118,7 +120,7 @@ class MCPTool(Tool):
|
||||
for item in json_list:
|
||||
yield self.create_json_message(item)
|
||||
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool":
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> MCPTool:
|
||||
return MCPTool(
|
||||
entity=self.entity,
|
||||
runtime=runtime,
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
@ -46,7 +48,7 @@ class PluginTool(Tool):
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "PluginTool":
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> PluginTool:
|
||||
return PluginTool(
|
||||
entity=self.entity,
|
||||
runtime=runtime,
|
||||
|
||||
@ -7,12 +7,12 @@ import time
|
||||
from configs import dify_config
|
||||
|
||||
|
||||
def sign_tool_file(tool_file_id: str, extension: str) -> str:
|
||||
def sign_tool_file(tool_file_id: str, extension: str, for_external: bool = True) -> str:
|
||||
"""
|
||||
sign file to get a temporary url for plugin access
|
||||
"""
|
||||
# Use internal URL for plugin/tool file access in Docker environments
|
||||
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
|
||||
# Use internal URL for plugin/tool file access in Docker environments, unless for_external is True
|
||||
base_url = dify_config.FILES_URL if for_external else (dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL)
|
||||
file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}"
|
||||
|
||||
timestamp = str(int(time.time()))
|
||||
|
||||
@ -378,7 +378,7 @@ class ApiBasedToolSchemaParser:
|
||||
@staticmethod
|
||||
def auto_parse_to_tool_bundle(
|
||||
content: str, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> tuple[list[ApiToolBundle], str]:
|
||||
) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]:
|
||||
"""
|
||||
auto parse to tool bundle
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ import re
|
||||
def remove_leading_symbols(text: str) -> str:
|
||||
"""
|
||||
Remove leading punctuation or symbols from the given text.
|
||||
Preserves markdown links like [text](url) at the start.
|
||||
|
||||
Args:
|
||||
text (str): The input text to process.
|
||||
@ -11,6 +12,11 @@ def remove_leading_symbols(text: str) -> str:
|
||||
Returns:
|
||||
str: The text with leading punctuation or symbols removed.
|
||||
"""
|
||||
# Check if text starts with a markdown link - preserve it
|
||||
markdown_link_pattern = r"^\[([^\]]+)\]\((https?://[^)]+)\)"
|
||||
if re.match(markdown_link_pattern, text):
|
||||
return text
|
||||
|
||||
# Match Unicode ranges for punctuation and symbols
|
||||
# FIXME this pattern is confused quick fix for #11868 maybe refactor it later
|
||||
pattern = r'^[\[\]\u2000-\u2025\u2027-\u206F\u2E00-\u2E7F\u3000-\u300F\u3011-\u303F"#$%&\'()*+,./:;<=>?@^_`~]+'
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
|
||||
from pydantic import Field
|
||||
@ -47,14 +49,13 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
self.provider_id = provider_id
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController":
|
||||
def from_db(cls, db_provider: WorkflowToolProvider) -> WorkflowToolProviderController:
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
app = session.get(App, db_provider.app_id)
|
||||
if not app:
|
||||
raise ValueError("app not found")
|
||||
|
||||
user = session.get(Account, db_provider.user_id) if db_provider.user_id else None
|
||||
|
||||
controller = WorkflowToolProviderController(
|
||||
entity=ToolProviderEntity(
|
||||
identity=ToolProviderIdentity(
|
||||
@ -67,7 +68,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
credentials_schema=[],
|
||||
plugin_id=None,
|
||||
),
|
||||
provider_id="",
|
||||
provider_id=db_provider.id,
|
||||
)
|
||||
|
||||
controller.tools = [
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
@ -181,7 +183,7 @@ class WorkflowTool(Tool):
|
||||
return found
|
||||
return None
|
||||
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool":
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> WorkflowTool:
|
||||
"""
|
||||
fork a new tool with metadata
|
||||
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.file.models import File
|
||||
|
||||
@ -52,7 +54,7 @@ class SegmentType(StrEnum):
|
||||
return self in _ARRAY_TYPES
|
||||
|
||||
@classmethod
|
||||
def infer_segment_type(cls, value: Any) -> Optional["SegmentType"]:
|
||||
def infer_segment_type(cls, value: Any) -> SegmentType | None:
|
||||
"""
|
||||
Attempt to infer the `SegmentType` based on the Python type of the `value` parameter.
|
||||
|
||||
@ -173,7 +175,7 @@ class SegmentType(StrEnum):
|
||||
raise AssertionError("this statement should be unreachable.")
|
||||
|
||||
@staticmethod
|
||||
def cast_value(value: Any, type_: "SegmentType"):
|
||||
def cast_value(value: Any, type_: SegmentType):
|
||||
# Cast Python's `bool` type to `int` when the runtime type requires
|
||||
# an integer or number.
|
||||
#
|
||||
@ -193,7 +195,7 @@ class SegmentType(StrEnum):
|
||||
return [int(i) for i in value]
|
||||
return value
|
||||
|
||||
def exposed_type(self) -> "SegmentType":
|
||||
def exposed_type(self) -> SegmentType:
|
||||
"""Returns the type exposed to the frontend.
|
||||
|
||||
The frontend treats `INTEGER` and `FLOAT` as `NUMBER`, so these are returned as `NUMBER` here.
|
||||
@ -202,7 +204,7 @@ class SegmentType(StrEnum):
|
||||
return SegmentType.NUMBER
|
||||
return self
|
||||
|
||||
def element_type(self) -> "SegmentType | None":
|
||||
def element_type(self) -> SegmentType | None:
|
||||
"""Return the element type of the current segment type, or `None` if the element type is undefined.
|
||||
|
||||
Raises:
|
||||
@ -217,7 +219,7 @@ class SegmentType(StrEnum):
|
||||
return _ARRAY_ELEMENT_TYPES_MAPPING.get(self)
|
||||
|
||||
@staticmethod
|
||||
def get_zero_value(t: "SegmentType"):
|
||||
def get_zero_value(t: SegmentType):
|
||||
# Lazy import to avoid circular dependency
|
||||
from factories import variable_factory
|
||||
|
||||
|
||||
@ -64,6 +64,9 @@ engine.layer(DebugLoggingLayer(level="INFO"))
|
||||
engine.layer(ExecutionLimitsLayer(max_nodes=100))
|
||||
```
|
||||
|
||||
`engine.layer()` binds the read-only runtime state before execution, so layer hooks
|
||||
can assume `graph_runtime_state` is available.
|
||||
|
||||
### Event-Driven Architecture
|
||||
|
||||
All node executions emit events for monitoring and integration:
|
||||
|
||||
@ -5,6 +5,8 @@ Models are independent of the storage mechanism and don't contain
|
||||
implementation details like tenant_id, app_id, etc.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
@ -59,7 +61,7 @@ class WorkflowExecution(BaseModel):
|
||||
graph: Mapping[str, Any],
|
||||
inputs: Mapping[str, Any],
|
||||
started_at: datetime,
|
||||
) -> "WorkflowExecution":
|
||||
) -> WorkflowExecution:
|
||||
return WorkflowExecution(
|
||||
id_=id_,
|
||||
workflow_id=workflow_id,
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
@ -175,7 +177,7 @@ class Graph:
|
||||
def _create_node_instances(
|
||||
cls,
|
||||
node_configs_map: dict[str, dict[str, object]],
|
||||
node_factory: "NodeFactory",
|
||||
node_factory: NodeFactory,
|
||||
) -> dict[str, Node]:
|
||||
"""
|
||||
Create node instances from configurations using the node factory.
|
||||
@ -197,7 +199,7 @@ class Graph:
|
||||
return nodes
|
||||
|
||||
@classmethod
|
||||
def new(cls) -> "GraphBuilder":
|
||||
def new(cls) -> GraphBuilder:
|
||||
"""Create a fluent builder for assembling a graph programmatically."""
|
||||
|
||||
return GraphBuilder(graph_cls=cls)
|
||||
@ -284,9 +286,9 @@ class Graph:
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, object],
|
||||
node_factory: "NodeFactory",
|
||||
node_factory: NodeFactory,
|
||||
root_node_id: str | None = None,
|
||||
) -> "Graph":
|
||||
) -> Graph:
|
||||
"""
|
||||
Initialize graph
|
||||
|
||||
@ -383,7 +385,7 @@ class GraphBuilder:
|
||||
self._edges: list[Edge] = []
|
||||
self._edge_counter = 0
|
||||
|
||||
def add_root(self, node: Node) -> "GraphBuilder":
|
||||
def add_root(self, node: Node) -> GraphBuilder:
|
||||
"""Register the root node. Must be called exactly once."""
|
||||
|
||||
if self._nodes:
|
||||
@ -398,7 +400,7 @@ class GraphBuilder:
|
||||
*,
|
||||
from_node_id: str | None = None,
|
||||
source_handle: str = "source",
|
||||
) -> "GraphBuilder":
|
||||
) -> GraphBuilder:
|
||||
"""Append a node and connect it from the specified predecessor."""
|
||||
|
||||
if not self._nodes:
|
||||
@ -419,7 +421,7 @@ class GraphBuilder:
|
||||
|
||||
return self
|
||||
|
||||
def connect(self, *, tail: str, head: str, source_handle: str = "source") -> "GraphBuilder":
|
||||
def connect(self, *, tail: str, head: str, source_handle: str = "source") -> GraphBuilder:
|
||||
"""Connect two existing nodes without adding a new node."""
|
||||
|
||||
if tail not in self._nodes_by_id:
|
||||
|
||||
@ -9,7 +9,7 @@ Each instance uses a unique key for its command queue.
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, final
|
||||
|
||||
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand
|
||||
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from extensions.ext_redis import RedisClientWrapper
|
||||
@ -113,6 +113,8 @@ class RedisChannel:
|
||||
return AbortCommand.model_validate(data)
|
||||
if command_type == CommandType.PAUSE:
|
||||
return PauseCommand.model_validate(data)
|
||||
if command_type == CommandType.UPDATE_VARIABLES:
|
||||
return UpdateVariablesCommand.model_validate(data)
|
||||
|
||||
# For other command types, use base class
|
||||
return GraphEngineCommand.model_validate(data)
|
||||
|
||||
@ -5,11 +5,12 @@ This package handles external commands sent to the engine
|
||||
during execution.
|
||||
"""
|
||||
|
||||
from .command_handlers import AbortCommandHandler, PauseCommandHandler
|
||||
from .command_handlers import AbortCommandHandler, PauseCommandHandler, UpdateVariablesCommandHandler
|
||||
from .command_processor import CommandProcessor
|
||||
|
||||
__all__ = [
|
||||
"AbortCommandHandler",
|
||||
"CommandProcessor",
|
||||
"PauseCommandHandler",
|
||||
"UpdateVariablesCommandHandler",
|
||||
]
|
||||
|
||||
@ -4,9 +4,10 @@ from typing import final
|
||||
from typing_extensions import override
|
||||
|
||||
from core.workflow.entities.pause_reason import SchedulingPause
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
from ..domain.graph_execution import GraphExecution
|
||||
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
|
||||
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
|
||||
from .command_processor import CommandHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -31,3 +32,25 @@ class PauseCommandHandler(CommandHandler):
|
||||
reason = command.reason
|
||||
pause_reason = SchedulingPause(message=reason)
|
||||
execution.pause(pause_reason)
|
||||
|
||||
|
||||
@final
|
||||
class UpdateVariablesCommandHandler(CommandHandler):
|
||||
def __init__(self, variable_pool: VariablePool) -> None:
|
||||
self._variable_pool = variable_pool
|
||||
|
||||
@override
|
||||
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
|
||||
assert isinstance(command, UpdateVariablesCommand)
|
||||
for update in command.updates:
|
||||
try:
|
||||
variable = update.value
|
||||
self._variable_pool.add(variable.selector, variable)
|
||||
logger.debug("Updated variable %s for workflow %s", variable.selector, execution.workflow_id)
|
||||
except ValueError as exc:
|
||||
logger.warning(
|
||||
"Skipping invalid variable selector %s for workflow %s: %s",
|
||||
getattr(update.value, "selector", None),
|
||||
execution.workflow_id,
|
||||
exc,
|
||||
)
|
||||
|
||||
@ -5,17 +5,21 @@ This module defines command types that can be sent to a running GraphEngine
|
||||
instance to control its execution flow.
|
||||
"""
|
||||
|
||||
from enum import StrEnum
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.variables.variables import VariableUnion
|
||||
|
||||
|
||||
class CommandType(StrEnum):
|
||||
"""Types of commands that can be sent to GraphEngine."""
|
||||
|
||||
ABORT = "abort"
|
||||
PAUSE = "pause"
|
||||
ABORT = auto()
|
||||
PAUSE = auto()
|
||||
UPDATE_VARIABLES = auto()
|
||||
|
||||
|
||||
class GraphEngineCommand(BaseModel):
|
||||
@ -37,3 +41,16 @@ class PauseCommand(GraphEngineCommand):
|
||||
|
||||
command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command")
|
||||
reason: str = Field(default="unknown reason", description="reason for pause")
|
||||
|
||||
|
||||
class VariableUpdate(BaseModel):
|
||||
"""Represents a single variable update instruction."""
|
||||
|
||||
value: VariableUnion = Field(description="New variable value")
|
||||
|
||||
|
||||
class UpdateVariablesCommand(GraphEngineCommand):
|
||||
"""Command to update a group of variables in the variable pool."""
|
||||
|
||||
command_type: CommandType = Field(default=CommandType.UPDATE_VARIABLES, description="Type of command")
|
||||
updates: Sequence[VariableUpdate] = Field(default_factory=list, description="Variable updates")
|
||||
|
||||
@ -5,9 +5,12 @@ This engine uses a modular architecture with separated packages following
|
||||
Domain-Driven Design principles for improved maintainability and testability.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextvars
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, cast, final
|
||||
|
||||
@ -30,8 +33,13 @@ from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWr
|
||||
if TYPE_CHECKING: # pragma: no cover - used only for static analysis
|
||||
from core.workflow.runtime.graph_runtime_state import GraphProtocol
|
||||
|
||||
from .command_processing import AbortCommandHandler, CommandProcessor, PauseCommandHandler
|
||||
from .entities.commands import AbortCommand, PauseCommand
|
||||
from .command_processing import (
|
||||
AbortCommandHandler,
|
||||
CommandProcessor,
|
||||
PauseCommandHandler,
|
||||
UpdateVariablesCommandHandler,
|
||||
)
|
||||
from .entities.commands import AbortCommand, PauseCommand, UpdateVariablesCommand
|
||||
from .error_handler import ErrorHandler
|
||||
from .event_management import EventHandler, EventManager
|
||||
from .graph_state_manager import GraphStateManager
|
||||
@ -70,10 +78,13 @@ class GraphEngine:
|
||||
scale_down_idle_time: float | None = None,
|
||||
) -> None:
|
||||
"""Initialize the graph engine with all subsystems and dependencies."""
|
||||
# stop event
|
||||
self._stop_event = threading.Event()
|
||||
|
||||
# Bind runtime state to current workflow context
|
||||
self._graph = graph
|
||||
self._graph_runtime_state = graph_runtime_state
|
||||
self._graph_runtime_state.stop_event = self._stop_event
|
||||
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
|
||||
self._command_channel = command_channel
|
||||
|
||||
@ -140,6 +151,9 @@ class GraphEngine:
|
||||
pause_handler = PauseCommandHandler()
|
||||
self._command_processor.register_handler(PauseCommand, pause_handler)
|
||||
|
||||
update_variables_handler = UpdateVariablesCommandHandler(self._graph_runtime_state.variable_pool)
|
||||
self._command_processor.register_handler(UpdateVariablesCommand, update_variables_handler)
|
||||
|
||||
# === Extensibility ===
|
||||
# Layers allow plugins to extend engine functionality
|
||||
self._layers: list[GraphEngineLayer] = []
|
||||
@ -169,6 +183,7 @@ class GraphEngine:
|
||||
max_workers=self._max_workers,
|
||||
scale_up_threshold=self._scale_up_threshold,
|
||||
scale_down_idle_time=self._scale_down_idle_time,
|
||||
stop_event=self._stop_event,
|
||||
)
|
||||
|
||||
# === Orchestration ===
|
||||
@ -199,6 +214,7 @@ class GraphEngine:
|
||||
event_handler=self._event_handler_registry,
|
||||
execution_coordinator=self._execution_coordinator,
|
||||
event_emitter=self._event_manager,
|
||||
stop_event=self._stop_event,
|
||||
)
|
||||
|
||||
# === Validation ===
|
||||
@ -212,9 +228,16 @@ class GraphEngine:
|
||||
if id(node.graph_runtime_state) != expected_state_id:
|
||||
raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance")
|
||||
|
||||
def layer(self, layer: GraphEngineLayer) -> "GraphEngine":
|
||||
def _bind_layer_context(
|
||||
self,
|
||||
layer: GraphEngineLayer,
|
||||
) -> None:
|
||||
layer.initialize(ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state), self._command_channel)
|
||||
|
||||
def layer(self, layer: GraphEngineLayer) -> GraphEngine:
|
||||
"""Add a layer for extending functionality."""
|
||||
self._layers.append(layer)
|
||||
self._bind_layer_context(layer)
|
||||
return self
|
||||
|
||||
def run(self) -> Generator[GraphEngineEvent, None, None]:
|
||||
@ -301,14 +324,7 @@ class GraphEngine:
|
||||
def _initialize_layers(self) -> None:
|
||||
"""Initialize layers with context."""
|
||||
self._event_manager.set_layers(self._layers)
|
||||
# Create a read-only wrapper for the runtime state
|
||||
read_only_state = ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state)
|
||||
for layer in self._layers:
|
||||
try:
|
||||
layer.initialize(read_only_state, self._command_channel)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to initialize layer %s: %s", layer.__class__.__name__, e)
|
||||
|
||||
try:
|
||||
layer.on_graph_start()
|
||||
except Exception as e:
|
||||
@ -316,6 +332,7 @@ class GraphEngine:
|
||||
|
||||
def _start_execution(self, *, resume: bool = False) -> None:
|
||||
"""Start execution subsystems."""
|
||||
self._stop_event.clear()
|
||||
paused_nodes: list[str] = []
|
||||
if resume:
|
||||
paused_nodes = self._graph_runtime_state.consume_paused_nodes()
|
||||
@ -343,13 +360,12 @@ class GraphEngine:
|
||||
|
||||
def _stop_execution(self) -> None:
|
||||
"""Stop execution subsystems."""
|
||||
self._stop_event.set()
|
||||
self._dispatcher.stop()
|
||||
self._worker_pool.stop()
|
||||
# Don't mark complete here as the dispatcher already does it
|
||||
|
||||
# Notify layers
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
for layer in self._layers:
|
||||
try:
|
||||
layer.on_graph_end(self._graph_execution.error)
|
||||
|
||||
@ -60,6 +60,7 @@ class SkipPropagator:
|
||||
if edge_states["has_taken"]:
|
||||
# Enqueue node
|
||||
self._state_manager.enqueue_node(downstream_node_id)
|
||||
self._state_manager.start_execution(downstream_node_id)
|
||||
return
|
||||
|
||||
# All edges are skipped, propagate skip to this node
|
||||
|
||||
@ -8,7 +8,7 @@ Pluggable middleware for engine extensions.
|
||||
|
||||
Abstract base class for layers.
|
||||
|
||||
- `initialize()` - Receive runtime context
|
||||
- `initialize()` - Receive runtime context (runtime state is bound here and always available to hooks)
|
||||
- `on_graph_start()` - Execution start hook
|
||||
- `on_event()` - Process all events
|
||||
- `on_graph_end()` - Execution end hook
|
||||
@ -34,6 +34,9 @@ engine.layer(debug_layer)
|
||||
engine.run()
|
||||
```
|
||||
|
||||
`engine.layer()` binds the read-only runtime state before execution, so
|
||||
`graph_runtime_state` is always available inside layer hooks.
|
||||
|
||||
## Custom Layers
|
||||
|
||||
```python
|
||||
|
||||
@ -13,6 +13,14 @@ from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.runtime import ReadOnlyGraphRuntimeState
|
||||
|
||||
|
||||
class GraphEngineLayerNotInitializedError(Exception):
|
||||
"""Raised when a layer's runtime state is accessed before initialization."""
|
||||
|
||||
def __init__(self, layer_name: str | None = None) -> None:
|
||||
name = layer_name or "GraphEngineLayer"
|
||||
super().__init__(f"{name} runtime state is not initialized. Bind the layer to a GraphEngine before access.")
|
||||
|
||||
|
||||
class GraphEngineLayer(ABC):
|
||||
"""
|
||||
Abstract base class for GraphEngine layers.
|
||||
@ -28,22 +36,27 @@ class GraphEngineLayer(ABC):
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the layer. Subclasses can override with custom parameters."""
|
||||
self.graph_runtime_state: ReadOnlyGraphRuntimeState | None = None
|
||||
self._graph_runtime_state: ReadOnlyGraphRuntimeState | None = None
|
||||
self.command_channel: CommandChannel | None = None
|
||||
|
||||
@property
|
||||
def graph_runtime_state(self) -> ReadOnlyGraphRuntimeState:
|
||||
if self._graph_runtime_state is None:
|
||||
raise GraphEngineLayerNotInitializedError(type(self).__name__)
|
||||
return self._graph_runtime_state
|
||||
|
||||
def initialize(self, graph_runtime_state: ReadOnlyGraphRuntimeState, command_channel: CommandChannel) -> None:
|
||||
"""
|
||||
Initialize the layer with engine dependencies.
|
||||
|
||||
Called by GraphEngine before execution starts to inject the read-only runtime state
|
||||
and command channel. This allows layers to observe engine context and send
|
||||
commands, but prevents direct state modification.
|
||||
|
||||
Called by GraphEngine to inject the read-only runtime state and command channel.
|
||||
This is invoked when the layer is registered with a `GraphEngine` instance.
|
||||
Implementations should be idempotent.
|
||||
Args:
|
||||
graph_runtime_state: Read-only view of the runtime state
|
||||
command_channel: Channel for sending commands to the engine
|
||||
"""
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
self._graph_runtime_state = graph_runtime_state
|
||||
self.command_channel = command_channel
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@ -109,10 +109,8 @@ class DebugLoggingLayer(GraphEngineLayer):
|
||||
self.logger.info("=" * 80)
|
||||
self.logger.info("🚀 GRAPH EXECUTION STARTED")
|
||||
self.logger.info("=" * 80)
|
||||
|
||||
if self.graph_runtime_state:
|
||||
# Log initial state
|
||||
self.logger.info("Initial State:")
|
||||
# Log initial state
|
||||
self.logger.info("Initial State:")
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
@ -243,8 +241,7 @@ class DebugLoggingLayer(GraphEngineLayer):
|
||||
self.logger.info(" Node retries: %s", self.retry_count)
|
||||
|
||||
# Log final state if available
|
||||
if self.graph_runtime_state and self.include_outputs:
|
||||
if self.graph_runtime_state.outputs:
|
||||
self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs))
|
||||
if self.include_outputs and self.graph_runtime_state.outputs:
|
||||
self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs))
|
||||
|
||||
self.logger.info("=" * 80)
|
||||
|
||||
@ -337,8 +337,6 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
if update_finished:
|
||||
execution.finished_at = naive_utc_now()
|
||||
runtime_state = self.graph_runtime_state
|
||||
if runtime_state is None:
|
||||
return
|
||||
execution.total_tokens = runtime_state.total_tokens
|
||||
execution.total_steps = runtime_state.node_run_steps
|
||||
execution.outputs = execution.outputs or runtime_state.outputs
|
||||
@ -404,6 +402,4 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
|
||||
def _system_variables(self) -> Mapping[str, Any]:
|
||||
runtime_state = self.graph_runtime_state
|
||||
if runtime_state is None:
|
||||
return {}
|
||||
return runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID)
|
||||
|
||||
@ -3,14 +3,20 @@ GraphEngine Manager for sending control commands via Redis channel.
|
||||
|
||||
This module provides a simplified interface for controlling workflow executions
|
||||
using the new Redis command channel, without requiring user permission checks.
|
||||
Supports stop, pause, and resume operations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
|
||||
from core.workflow.graph_engine.entities.commands import (
|
||||
AbortCommand,
|
||||
GraphEngineCommand,
|
||||
PauseCommand,
|
||||
UpdateVariablesCommand,
|
||||
VariableUpdate,
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -23,7 +29,6 @@ class GraphEngineManager:
|
||||
|
||||
This class provides a simple interface for controlling workflow executions
|
||||
by sending commands through Redis channels, without user validation.
|
||||
Supports stop and pause operations.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@ -45,6 +50,16 @@ class GraphEngineManager:
|
||||
pause_command = PauseCommand(reason=reason or "User requested pause")
|
||||
GraphEngineManager._send_command(task_id, pause_command)
|
||||
|
||||
@staticmethod
|
||||
def send_update_variables_command(task_id: str, updates: Sequence[VariableUpdate]) -> None:
|
||||
"""Send a command to update variables in a running workflow."""
|
||||
|
||||
if not updates:
|
||||
return
|
||||
|
||||
update_command = UpdateVariablesCommand(updates=updates)
|
||||
GraphEngineManager._send_command(task_id, update_command)
|
||||
|
||||
@staticmethod
|
||||
def _send_command(task_id: str, command: GraphEngineCommand) -> None:
|
||||
"""Send a command to the workflow-specific Redis channel."""
|
||||
|
||||
@ -44,6 +44,7 @@ class Dispatcher:
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
event_handler: "EventHandler",
|
||||
execution_coordinator: ExecutionCoordinator,
|
||||
stop_event: threading.Event,
|
||||
event_emitter: EventManager | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
@ -61,7 +62,7 @@ class Dispatcher:
|
||||
self._event_emitter = event_emitter
|
||||
|
||||
self._thread: threading.Thread | None = None
|
||||
self._stop_event = threading.Event()
|
||||
self._stop_event = stop_event
|
||||
self._start_time: float | None = None
|
||||
|
||||
def start(self) -> None:
|
||||
@ -69,16 +70,14 @@ class Dispatcher:
|
||||
if self._thread and self._thread.is_alive():
|
||||
return
|
||||
|
||||
self._stop_event.clear()
|
||||
self._start_time = time.time()
|
||||
self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the dispatcher thread."""
|
||||
self._stop_event.set()
|
||||
if self._thread and self._thread.is_alive():
|
||||
self._thread.join(timeout=10.0)
|
||||
self._thread.join(timeout=2.0)
|
||||
|
||||
def _dispatcher_loop(self) -> None:
|
||||
"""Main dispatcher loop."""
|
||||
|
||||
@ -2,6 +2,8 @@
|
||||
Factory for creating ReadyQueue instances from serialized state.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .in_memory import InMemoryReadyQueue
|
||||
@ -11,7 +13,7 @@ if TYPE_CHECKING:
|
||||
from .protocol import ReadyQueue
|
||||
|
||||
|
||||
def create_ready_queue_from_state(state: ReadyQueueState) -> "ReadyQueue":
|
||||
def create_ready_queue_from_state(state: ReadyQueueState) -> ReadyQueue:
|
||||
"""
|
||||
Create a ReadyQueue instance from a serialized state.
|
||||
|
||||
|
||||
@ -5,6 +5,8 @@ This module contains the private ResponseSession class used internally
|
||||
by ResponseStreamCoordinator to manage streaming sessions.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
@ -27,7 +29,7 @@ class ResponseSession:
|
||||
index: int = 0 # Current position in the template segments
|
||||
|
||||
@classmethod
|
||||
def from_node(cls, node: Node) -> "ResponseSession":
|
||||
def from_node(cls, node: Node) -> ResponseSession:
|
||||
"""
|
||||
Create a ResponseSession from an AnswerNode or EndNode.
|
||||
|
||||
|
||||
@ -42,6 +42,7 @@ class Worker(threading.Thread):
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
graph: Graph,
|
||||
layers: Sequence[GraphEngineLayer],
|
||||
stop_event: threading.Event,
|
||||
worker_id: int = 0,
|
||||
flask_app: Flask | None = None,
|
||||
context_vars: contextvars.Context | None = None,
|
||||
@ -65,13 +66,16 @@ class Worker(threading.Thread):
|
||||
self._worker_id = worker_id
|
||||
self._flask_app = flask_app
|
||||
self._context_vars = context_vars
|
||||
self._stop_event = threading.Event()
|
||||
self._last_task_time = time.time()
|
||||
self._stop_event = stop_event
|
||||
self._layers = layers if layers is not None else []
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Signal the worker to stop processing."""
|
||||
self._stop_event.set()
|
||||
"""Worker is controlled via shared stop_event from GraphEngine.
|
||||
|
||||
This method is a no-op retained for backward compatibility.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def is_idle(self) -> bool:
|
||||
|
||||
@ -41,6 +41,7 @@ class WorkerPool:
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
graph: Graph,
|
||||
layers: list[GraphEngineLayer],
|
||||
stop_event: threading.Event,
|
||||
flask_app: "Flask | None" = None,
|
||||
context_vars: "Context | None" = None,
|
||||
min_workers: int | None = None,
|
||||
@ -81,6 +82,7 @@ class WorkerPool:
|
||||
self._worker_counter = 0
|
||||
self._lock = threading.RLock()
|
||||
self._running = False
|
||||
self._stop_event = stop_event
|
||||
|
||||
# No longer tracking worker states with callbacks to avoid lock contention
|
||||
|
||||
@ -135,7 +137,7 @@ class WorkerPool:
|
||||
# Wait for workers to finish
|
||||
for worker in self._workers:
|
||||
if worker.is_alive():
|
||||
worker.join(timeout=10.0)
|
||||
worker.join(timeout=2.0)
|
||||
|
||||
self._workers.clear()
|
||||
|
||||
@ -152,6 +154,7 @@ class WorkerPool:
|
||||
worker_id=worker_id,
|
||||
flask_app=self._flask_app,
|
||||
context_vars=self._context_vars,
|
||||
stop_event=self._stop_event,
|
||||
)
|
||||
|
||||
worker.start()
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
@ -167,7 +169,7 @@ class AgentNode(Node[AgentNodeData]):
|
||||
variable_pool: VariablePool,
|
||||
node_data: AgentNodeData,
|
||||
for_log: bool = False,
|
||||
strategy: "PluginAgentStrategy",
|
||||
strategy: PluginAgentStrategy,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate parameters based on the given tool parameters, variable pool, and node data.
|
||||
@ -328,7 +330,7 @@ class AgentNode(Node[AgentNodeData]):
|
||||
def _generate_credentials(
|
||||
self,
|
||||
parameters: dict[str, Any],
|
||||
) -> "InvokeCredentials":
|
||||
) -> InvokeCredentials:
|
||||
"""
|
||||
Generate credentials based on the given agent parameters.
|
||||
"""
|
||||
@ -442,9 +444,7 @@ class AgentNode(Node[AgentNodeData]):
|
||||
model_schema.features.remove(feature)
|
||||
return model_schema
|
||||
|
||||
def _filter_mcp_type_tool(
|
||||
self, strategy: "PluginAgentStrategy", tools: list[dict[str, Any]]
|
||||
) -> list[dict[str, Any]]:
|
||||
def _filter_mcp_type_tool(self, strategy: PluginAgentStrategy, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Filter MCP type tool
|
||||
:param strategy: plugin agent strategy
|
||||
|
||||
@ -119,3 +119,14 @@ class AgentVariableTypeError(AgentNodeError):
|
||||
self.expected_type = expected_type
|
||||
self.actual_type = actual_type
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class AgentMaxIterationError(AgentNodeError):
|
||||
"""Exception raised when the agent exceeds the maximum iteration limit."""
|
||||
|
||||
def __init__(self, max_iteration: int):
|
||||
self.max_iteration = max_iteration
|
||||
super().__init__(
|
||||
f"Agent exceeded the maximum iteration limit of {max_iteration}. "
|
||||
f"The agent was unable to complete the task within the allowed number of iterations."
|
||||
)
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from abc import ABC
|
||||
from builtins import type as type_
|
||||
@ -111,7 +113,7 @@ class DefaultValue(BaseModel):
|
||||
raise DefaultValueTypeError(f"Cannot convert to number: {value}")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_value_type(self) -> "DefaultValue":
|
||||
def validate_value_type(self) -> DefaultValue:
|
||||
# Type validation configuration
|
||||
type_validators = {
|
||||
DefaultValueType.STRING: {
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import operator
|
||||
@ -62,7 +64,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Node(Generic[NodeDataT]):
|
||||
node_type: ClassVar["NodeType"]
|
||||
node_type: ClassVar[NodeType]
|
||||
execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
|
||||
_node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData
|
||||
|
||||
@ -201,14 +203,14 @@ class Node(Generic[NodeDataT]):
|
||||
return None
|
||||
|
||||
# Global registry populated via __init_subclass__
|
||||
_registry: ClassVar[dict["NodeType", dict[str, type["Node"]]]] = {}
|
||||
_registry: ClassVar[dict[NodeType, dict[str, type[Node]]]] = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> None:
|
||||
self._graph_init_params = graph_init_params
|
||||
self.id = id
|
||||
@ -244,7 +246,7 @@ class Node(Generic[NodeDataT]):
|
||||
return
|
||||
|
||||
@property
|
||||
def graph_init_params(self) -> "GraphInitParams":
|
||||
def graph_init_params(self) -> GraphInitParams:
|
||||
return self._graph_init_params
|
||||
|
||||
@property
|
||||
@ -267,6 +269,10 @@ class Node(Generic[NodeDataT]):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _should_stop(self) -> bool:
|
||||
"""Check if execution should be stopped."""
|
||||
return self.graph_runtime_state.stop_event.is_set()
|
||||
|
||||
def run(self) -> Generator[GraphNodeEventBase, None, None]:
|
||||
execution_id = self.ensure_execution_id()
|
||||
self._start_at = naive_utc_now()
|
||||
@ -335,6 +341,21 @@ class Node(Generic[NodeDataT]):
|
||||
yield event
|
||||
else:
|
||||
yield event
|
||||
|
||||
if self._should_stop():
|
||||
error_message = "Execution cancelled"
|
||||
yield NodeRunFailedEvent(
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=error_message,
|
||||
),
|
||||
error=error_message,
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception("Node %s failed to run", self._node_id)
|
||||
result = NodeRunResult(
|
||||
@ -441,7 +462,7 @@ class Node(Generic[NodeDataT]):
|
||||
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
|
||||
|
||||
@classmethod
|
||||
def get_node_type_classes_mapping(cls) -> Mapping["NodeType", Mapping[str, type["Node"]]]:
|
||||
def get_node_type_classes_mapping(cls) -> Mapping[NodeType, Mapping[str, type[Node]]]:
|
||||
"""Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry.
|
||||
|
||||
Import all modules under core.workflow.nodes so subclasses register themselves on import.
|
||||
|
||||
@ -4,6 +4,8 @@ This module provides a unified template structure for both Answer and End nodes,
|
||||
similar to SegmentGroup but focused on template representation without values.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
@ -58,7 +60,7 @@ class Template:
|
||||
segments: list[TemplateSegmentUnion]
|
||||
|
||||
@classmethod
|
||||
def from_answer_template(cls, template_str: str) -> "Template":
|
||||
def from_answer_template(cls, template_str: str) -> Template:
|
||||
"""Create a Template from an Answer node template string.
|
||||
|
||||
Example:
|
||||
@ -107,7 +109,7 @@ class Template:
|
||||
return cls(segments=segments)
|
||||
|
||||
@classmethod
|
||||
def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> "Template":
|
||||
def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> Template:
|
||||
"""Create a Template from an End node outputs configuration.
|
||||
|
||||
End nodes are treated as templates of concatenated variables with newlines.
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from decimal import Decimal
|
||||
from typing import Any, cast
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
|
||||
@ -13,6 +12,7 @@ from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.code.entities import CodeNodeData
|
||||
from core.workflow.nodes.code.limits import CodeNodeLimits
|
||||
|
||||
from .exc import (
|
||||
CodeNodeError,
|
||||
@ -20,9 +20,41 @@ from .exc import (
|
||||
OutputValidationError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
class CodeNode(Node[CodeNodeData]):
|
||||
node_type = NodeType.CODE
|
||||
_DEFAULT_CODE_PROVIDERS: ClassVar[tuple[type[CodeNodeProvider], ...]] = (
|
||||
Python3CodeProvider,
|
||||
JavascriptCodeProvider,
|
||||
)
|
||||
_limits: CodeNodeLimits
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
code_executor: type[CodeExecutor] | None = None,
|
||||
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
|
||||
code_limits: CodeNodeLimits,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self._code_executor: type[CodeExecutor] = code_executor or CodeExecutor
|
||||
self._code_providers: tuple[type[CodeNodeProvider], ...] = (
|
||||
tuple(code_providers) if code_providers else self._DEFAULT_CODE_PROVIDERS
|
||||
)
|
||||
self._limits = code_limits
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
@ -35,11 +67,16 @@ class CodeNode(Node[CodeNodeData]):
|
||||
if filters:
|
||||
code_language = cast(CodeLanguage, filters.get("code_language", CodeLanguage.PYTHON3))
|
||||
|
||||
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
||||
code_provider: type[CodeNodeProvider] = next(p for p in providers if p.is_accept_language(code_language))
|
||||
code_provider: type[CodeNodeProvider] = next(
|
||||
provider for provider in cls._DEFAULT_CODE_PROVIDERS if provider.is_accept_language(code_language)
|
||||
)
|
||||
|
||||
return code_provider.get_default_config()
|
||||
|
||||
@classmethod
|
||||
def default_code_providers(cls) -> tuple[type[CodeNodeProvider], ...]:
|
||||
return cls._DEFAULT_CODE_PROVIDERS
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
@ -60,7 +97,8 @@ class CodeNode(Node[CodeNodeData]):
|
||||
variables[variable_name] = variable.to_object() if variable else None
|
||||
# Run code
|
||||
try:
|
||||
result = CodeExecutor.execute_workflow_code_template(
|
||||
_ = self._select_code_provider(code_language)
|
||||
result = self._code_executor.execute_workflow_code_template(
|
||||
language=code_language,
|
||||
code=code,
|
||||
inputs=variables,
|
||||
@ -75,6 +113,12 @@ class CodeNode(Node[CodeNodeData]):
|
||||
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
|
||||
|
||||
def _select_code_provider(self, code_language: CodeLanguage) -> type[CodeNodeProvider]:
|
||||
for provider in self._code_providers:
|
||||
if provider.is_accept_language(code_language):
|
||||
return provider
|
||||
raise CodeNodeError(f"Unsupported code language: {code_language}")
|
||||
|
||||
def _check_string(self, value: str | None, variable: str) -> str | None:
|
||||
"""
|
||||
Check string
|
||||
@ -85,10 +129,10 @@ class CodeNode(Node[CodeNodeData]):
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
|
||||
if len(value) > self._limits.max_string_length:
|
||||
raise OutputValidationError(
|
||||
f"The length of output variable `{variable}` must be"
|
||||
f" less than {dify_config.CODE_MAX_STRING_LENGTH} characters"
|
||||
f" less than {self._limits.max_string_length} characters"
|
||||
)
|
||||
|
||||
return value.replace("\x00", "")
|
||||
@ -109,20 +153,20 @@ class CodeNode(Node[CodeNodeData]):
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
|
||||
if value > self._limits.max_number or value < self._limits.min_number:
|
||||
raise OutputValidationError(
|
||||
f"Output variable `{variable}` is out of range,"
|
||||
f" it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}."
|
||||
f" it must be between {self._limits.min_number} and {self._limits.max_number}."
|
||||
)
|
||||
|
||||
if isinstance(value, float):
|
||||
decimal_value = Decimal(str(value)).normalize()
|
||||
precision = -decimal_value.as_tuple().exponent if decimal_value.as_tuple().exponent < 0 else 0 # type: ignore[operator]
|
||||
# raise error if precision is too high
|
||||
if precision > dify_config.CODE_MAX_PRECISION:
|
||||
if precision > self._limits.max_precision:
|
||||
raise OutputValidationError(
|
||||
f"Output variable `{variable}` has too high precision,"
|
||||
f" it must be less than {dify_config.CODE_MAX_PRECISION} digits."
|
||||
f" it must be less than {self._limits.max_precision} digits."
|
||||
)
|
||||
|
||||
return value
|
||||
@ -137,8 +181,8 @@ class CodeNode(Node[CodeNodeData]):
|
||||
# TODO(QuantumGhost): Replace native Python lists with `Array*Segment` classes.
|
||||
# Note that `_transform_result` may produce lists containing `None` values,
|
||||
# which don't conform to the type requirements of `Array*Segment` classes.
|
||||
if depth > dify_config.CODE_MAX_DEPTH:
|
||||
raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.")
|
||||
if depth > self._limits.max_depth:
|
||||
raise DepthLimitError(f"Depth limit {self._limits.max_depth} reached, object too deep.")
|
||||
|
||||
transformed_result: dict[str, Any] = {}
|
||||
if output_schema is None:
|
||||
@ -272,10 +316,10 @@ class CodeNode(Node[CodeNodeData]):
|
||||
f"Output {prefix}{dot}{output_name} is not an array, got {type(value)} instead."
|
||||
)
|
||||
else:
|
||||
if len(value) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH:
|
||||
if len(value) > self._limits.max_number_array_length:
|
||||
raise OutputValidationError(
|
||||
f"The length of output variable `{prefix}{dot}{output_name}` must be"
|
||||
f" less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements."
|
||||
f" less than {self._limits.max_number_array_length} elements."
|
||||
)
|
||||
|
||||
for i, inner_value in enumerate(value):
|
||||
@ -305,10 +349,10 @@ class CodeNode(Node[CodeNodeData]):
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
else:
|
||||
if len(result[output_name]) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH:
|
||||
if len(result[output_name]) > self._limits.max_string_array_length:
|
||||
raise OutputValidationError(
|
||||
f"The length of output variable `{prefix}{dot}{output_name}` must be"
|
||||
f" less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements."
|
||||
f" less than {self._limits.max_string_array_length} elements."
|
||||
)
|
||||
|
||||
transformed_result[output_name] = [
|
||||
@ -326,10 +370,10 @@ class CodeNode(Node[CodeNodeData]):
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
else:
|
||||
if len(result[output_name]) > dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH:
|
||||
if len(result[output_name]) > self._limits.max_object_array_length:
|
||||
raise OutputValidationError(
|
||||
f"The length of output variable `{prefix}{dot}{output_name}` must be"
|
||||
f" less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements."
|
||||
f" less than {self._limits.max_object_array_length} elements."
|
||||
)
|
||||
|
||||
for i, value in enumerate(result[output_name]):
|
||||
|
||||
13
api/core/workflow/nodes/code/limits.py
Normal file
13
api/core/workflow/nodes/code/limits.py
Normal file
@ -0,0 +1,13 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CodeNodeLimits:
|
||||
max_string_length: int
|
||||
max_number: int | float
|
||||
min_number: int | float
|
||||
max_precision: int
|
||||
max_depth: int
|
||||
max_number_array_length: int
|
||||
max_string_array_length: int
|
||||
max_object_array_length: int
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
@ -134,7 +136,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
|
||||
# Instance attributes specific to LLMNode.
|
||||
# Output variable for file
|
||||
_file_outputs: list["File"]
|
||||
_file_outputs: list[File]
|
||||
|
||||
_llm_file_saver: LLMFileSaver
|
||||
|
||||
@ -142,8 +144,8 @@ class LLMNode(Node[LLMNodeData]):
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
*,
|
||||
llm_file_saver: LLMFileSaver | None = None,
|
||||
):
|
||||
@ -445,7 +447,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
structured_output_enabled: bool,
|
||||
structured_output: Mapping[str, Any] | None = None,
|
||||
file_saver: LLMFileSaver,
|
||||
file_outputs: list["File"],
|
||||
file_outputs: list[File],
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
reasoning_format: Literal["separated", "tagged"] = "tagged",
|
||||
@ -499,7 +501,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
*,
|
||||
invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None],
|
||||
file_saver: LLMFileSaver,
|
||||
file_outputs: list["File"],
|
||||
file_outputs: list[File],
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
reasoning_format: Literal["separated", "tagged"] = "tagged",
|
||||
@ -675,7 +677,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _image_file_to_markdown(file: "File", /):
|
||||
def _image_file_to_markdown(file: File, /):
|
||||
text_chunk = f"})"
|
||||
return text_chunk
|
||||
|
||||
@ -924,7 +926,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
def fetch_prompt_messages(
|
||||
*,
|
||||
sys_query: str | None = None,
|
||||
sys_files: Sequence["File"],
|
||||
sys_files: Sequence[File],
|
||||
context: str | None = None,
|
||||
memory: TokenBufferMemory | None = None,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
@ -935,7 +937,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
variable_pool: VariablePool,
|
||||
jinja2_variables: Sequence[VariableSelector],
|
||||
tenant_id: str,
|
||||
context_files: list["File"] | None = None,
|
||||
context_files: list[File] | None = None,
|
||||
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
|
||||
@ -1287,7 +1289,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
*,
|
||||
invoke_result: LLMResult | LLMResultWithStructuredOutput,
|
||||
saver: LLMFileSaver,
|
||||
file_outputs: list["File"],
|
||||
file_outputs: list[File],
|
||||
reasoning_format: Literal["separated", "tagged"] = "tagged",
|
||||
request_latency: float | None = None,
|
||||
) -> ModelInvokeCompletedEvent:
|
||||
@ -1329,7 +1331,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
*,
|
||||
content: ImagePromptMessageContent,
|
||||
file_saver: LLMFileSaver,
|
||||
) -> "File":
|
||||
) -> File:
|
||||
"""_save_multimodal_output saves multi-modal contents generated by LLM plugins.
|
||||
|
||||
There are two kinds of multimodal outputs:
|
||||
@ -1379,7 +1381,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
*,
|
||||
contents: str | list[PromptMessageContentUnionTypes] | None,
|
||||
file_saver: LLMFileSaver,
|
||||
file_outputs: list["File"],
|
||||
file_outputs: list[File],
|
||||
) -> Generator[str, None, None]:
|
||||
"""Convert intermediate prompt messages into strings and yield them to the caller.
|
||||
|
||||
|
||||
@ -1,10 +1,21 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.code_executor.code_executor import CodeExecutor
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph import NodeFactory
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
from core.workflow.nodes.code.limits import CodeNodeLimits
|
||||
from core.workflow.nodes.template_transform.template_renderer import (
|
||||
CodeExecutorJinja2TemplateRenderer,
|
||||
Jinja2TemplateRenderer,
|
||||
)
|
||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from libs.typing import is_str, is_str_dict
|
||||
|
||||
from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
@ -27,9 +38,29 @@ class DifyNodeFactory(NodeFactory):
|
||||
self,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
code_executor: type[CodeExecutor] | None = None,
|
||||
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
|
||||
code_limits: CodeNodeLimits | None = None,
|
||||
template_renderer: Jinja2TemplateRenderer | None = None,
|
||||
) -> None:
|
||||
self.graph_init_params = graph_init_params
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
self._code_executor: type[CodeExecutor] = code_executor or CodeExecutor
|
||||
self._code_providers: tuple[type[CodeNodeProvider], ...] = (
|
||||
tuple(code_providers) if code_providers else CodeNode.default_code_providers()
|
||||
)
|
||||
self._code_limits = code_limits or CodeNodeLimits(
|
||||
max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
|
||||
max_number=dify_config.CODE_MAX_NUMBER,
|
||||
min_number=dify_config.CODE_MIN_NUMBER,
|
||||
max_precision=dify_config.CODE_MAX_PRECISION,
|
||||
max_depth=dify_config.CODE_MAX_DEPTH,
|
||||
max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH,
|
||||
max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
|
||||
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
|
||||
)
|
||||
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
|
||||
|
||||
@override
|
||||
def create_node(self, node_config: dict[str, object]) -> Node:
|
||||
@ -72,6 +103,25 @@ class DifyNodeFactory(NodeFactory):
|
||||
raise ValueError(f"No latest version class found for node type: {node_type}")
|
||||
|
||||
# Create node instance
|
||||
if node_type == NodeType.CODE:
|
||||
return CodeNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
code_executor=self._code_executor,
|
||||
code_providers=self._code_providers,
|
||||
code_limits=self._code_limits,
|
||||
)
|
||||
if node_type == NodeType.TEMPLATE_TRANSFORM:
|
||||
return TemplateTransformNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
template_renderer=self._template_renderer,
|
||||
)
|
||||
|
||||
return node_class(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
|
||||
@ -0,0 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Protocol
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
|
||||
|
||||
class TemplateRenderError(ValueError):
|
||||
"""Raised when rendering a Jinja2 template fails."""
|
||||
|
||||
|
||||
class Jinja2TemplateRenderer(Protocol):
|
||||
"""Render Jinja2 templates for template transform nodes."""
|
||||
|
||||
def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
|
||||
"""Render a Jinja2 template with provided variables."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer):
|
||||
"""Adapter that renders Jinja2 templates via CodeExecutor."""
|
||||
|
||||
_code_executor: type[CodeExecutor]
|
||||
|
||||
def __init__(self, code_executor: type[CodeExecutor] | None = None) -> None:
|
||||
self._code_executor = code_executor or CodeExecutor
|
||||
|
||||
def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
|
||||
try:
|
||||
result = self._code_executor.execute_workflow_code_template(
|
||||
language=CodeLanguage.JINJA2, code=template, inputs=variables
|
||||
)
|
||||
except CodeExecutionError as exc:
|
||||
raise TemplateRenderError(str(exc)) from exc
|
||||
|
||||
rendered = result.get("result")
|
||||
if not isinstance(rendered, str):
|
||||
raise TemplateRenderError("Template render result must be a string.")
|
||||
return rendered
|
||||
@ -1,18 +1,44 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
|
||||
from core.workflow.nodes.template_transform.template_renderer import (
|
||||
CodeExecutorJinja2TemplateRenderer,
|
||||
Jinja2TemplateRenderer,
|
||||
TemplateRenderError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
|
||||
|
||||
|
||||
class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||
node_type = NodeType.TEMPLATE_TRANSFORM
|
||||
_template_renderer: Jinja2TemplateRenderer
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
template_renderer: Jinja2TemplateRenderer | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
@ -39,13 +65,11 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||
variables[variable_name] = value.to_object() if value else None
|
||||
# Run code
|
||||
try:
|
||||
result = CodeExecutor.execute_workflow_code_template(
|
||||
language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables
|
||||
)
|
||||
except CodeExecutionError as e:
|
||||
rendered = self._template_renderer.render_template(self.node_data.template, variables)
|
||||
except TemplateRenderError as e:
|
||||
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
|
||||
|
||||
if len(result["result"]) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
|
||||
if len(rendered) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
|
||||
return NodeRunResult(
|
||||
inputs=variables,
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
@ -53,7 +77,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||
)
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": result["result"]}
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": rendered}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -1,28 +0,0 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.variables.variables import Variable
|
||||
from extensions.ext_database import db
|
||||
from models import ConversationVariable
|
||||
|
||||
from .exc import VariableOperatorNodeError
|
||||
|
||||
|
||||
class ConversationVariableUpdaterImpl:
|
||||
def update(self, conversation_id: str, variable: Variable):
|
||||
stmt = select(ConversationVariable).where(
|
||||
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
row = session.scalar(stmt)
|
||||
if not row:
|
||||
raise VariableOperatorNodeError("conversation variable not found in the database")
|
||||
row.data = variable.model_dump_json()
|
||||
session.commit()
|
||||
|
||||
def flush(self):
|
||||
pass
|
||||
|
||||
|
||||
def conversation_variable_updater_factory() -> ConversationVariableUpdaterImpl:
|
||||
return ConversationVariableUpdaterImpl()
|
||||
@ -1,9 +1,8 @@
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, TypeAlias
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
@ -11,19 +10,14 @@ from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||
|
||||
from ..common.impl import conversation_variable_updater_factory
|
||||
from .node_data import VariableAssignerData, WriteMode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
|
||||
|
||||
|
||||
class VariableAssignerNode(Node[VariableAssignerData]):
|
||||
node_type = NodeType.VARIABLE_ASSIGNER
|
||||
_conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -31,7 +25,6 @@ class VariableAssignerNode(Node[VariableAssignerData]):
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory,
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
@ -39,7 +32,6 @@ class VariableAssignerNode(Node[VariableAssignerData]):
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self._conv_var_updater_factory = conv_var_updater_factory
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
@ -96,16 +88,7 @@ class VariableAssignerNode(Node[VariableAssignerData]):
|
||||
# Over write the variable.
|
||||
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)
|
||||
|
||||
# TODO: Move database operation to the pipeline.
|
||||
# Update conversation variable.
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
|
||||
if not conversation_id:
|
||||
raise VariableOperatorNodeError("conversation_id not found")
|
||||
conv_var_updater = self._conv_var_updater_factory()
|
||||
conv_var_updater.update(conversation_id=conversation_id.text, variable=updated_variable)
|
||||
conv_var_updater.flush()
|
||||
updated_variables = [common_helpers.variable_to_processed_data(assigned_variable_selector, updated_variable)]
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={
|
||||
|
||||
@ -1,24 +1,20 @@
|
||||
import json
|
||||
from collections.abc import Mapping, MutableMapping, Sequence
|
||||
from typing import Any, cast
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.variables.consts import SELECTORS_LENGTH
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
|
||||
|
||||
from . import helpers
|
||||
from .entities import VariableAssignerNodeData, VariableOperationItem
|
||||
from .enums import InputType, Operation
|
||||
from .exc import (
|
||||
ConversationIDNotFoundError,
|
||||
InputTypeNotSupportedError,
|
||||
InvalidDataError,
|
||||
InvalidInputValueError,
|
||||
@ -26,6 +22,10 @@ from .exc import (
|
||||
VariableNotFoundError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem):
|
||||
selector_node_id = item.variable_selector[0]
|
||||
@ -53,6 +53,20 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
|
||||
class VariableAssignerNode(Node[VariableAssignerNodeData]):
|
||||
node_type = NodeType.VARIABLE_ASSIGNER
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
|
||||
"""
|
||||
Check if this Variable Assigner node blocks the output of specific variables.
|
||||
@ -70,9 +84,6 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
|
||||
|
||||
return False
|
||||
|
||||
def _conv_var_updater_factory(self) -> ConversationVariableUpdater:
|
||||
return conversation_variable_updater_factory()
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "2"
|
||||
@ -179,26 +190,12 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
|
||||
# remove the duplicated items first.
|
||||
updated_variable_selectors = list(set(map(tuple, updated_variable_selectors)))
|
||||
|
||||
conv_var_updater = self._conv_var_updater_factory()
|
||||
# Update variables
|
||||
for selector in updated_variable_selectors:
|
||||
variable = self.graph_runtime_state.variable_pool.get(selector)
|
||||
if not isinstance(variable, Variable):
|
||||
raise VariableNotFoundError(variable_selector=selector)
|
||||
process_data[variable.name] = variable.value
|
||||
|
||||
if variable.selector[0] == CONVERSATION_VARIABLE_NODE_ID:
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
|
||||
if not conversation_id:
|
||||
if self.invoke_from != InvokeFrom.DEBUGGER:
|
||||
raise ConversationIDNotFoundError
|
||||
else:
|
||||
conversation_id = conversation_id.value
|
||||
conv_var_updater.update(
|
||||
conversation_id=cast(str, conversation_id),
|
||||
variable=variable,
|
||||
)
|
||||
conv_var_updater.flush()
|
||||
updated_variables = [
|
||||
common_helpers.variable_to_processed_data(selector, seg)
|
||||
for selector in updated_variable_selectors
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Protocol
|
||||
@ -23,7 +25,7 @@ class DraftVariableSaverFactory(Protocol):
|
||||
node_type: NodeType,
|
||||
node_execution_id: str,
|
||||
enclosing_node_id: str | None = None,
|
||||
) -> "DraftVariableSaver":
|
||||
) -> DraftVariableSaver:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import threading
|
||||
from collections.abc import Mapping, Sequence
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
@ -168,6 +169,7 @@ class GraphRuntimeState:
|
||||
self._pending_response_coordinator_dump: str | None = None
|
||||
self._pending_graph_execution_workflow_id: str | None = None
|
||||
self._paused_nodes: set[str] = set()
|
||||
self.stop_event: threading.Event = threading.Event()
|
||||
|
||||
if graph is not None:
|
||||
self.attach_graph(graph)
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Protocol
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
@ -9,7 +9,7 @@ from core.workflow.system_variable import SystemVariableReadOnlyView
|
||||
class ReadOnlyVariablePool(Protocol):
|
||||
"""Read-only interface for VariablePool."""
|
||||
|
||||
def get(self, node_id: str, variable_key: str) -> Segment | None:
|
||||
def get(self, selector: Sequence[str], /) -> Segment | None:
|
||||
"""Get a variable value (read-only)."""
|
||||
...
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Mapping, Sequence
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
@ -18,9 +18,9 @@ class ReadOnlyVariablePoolWrapper:
|
||||
def __init__(self, variable_pool: VariablePool) -> None:
|
||||
self._variable_pool = variable_pool
|
||||
|
||||
def get(self, node_id: str, variable_key: str) -> Segment | None:
|
||||
def get(self, selector: Sequence[str], /) -> Segment | None:
|
||||
"""Return a copy of a variable value if present."""
|
||||
value = self._variable_pool.get([node_id, variable_key])
|
||||
value = self._variable_pool.get(selector)
|
||||
return deepcopy(value) if value is not None else None
|
||||
|
||||
def get_all_by_node(self, node_id: str) -> Mapping[str, object]:
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
@ -267,6 +269,6 @@ class VariablePool(BaseModel):
|
||||
self.add(selector, value)
|
||||
|
||||
@classmethod
|
||||
def empty(cls) -> "VariablePool":
|
||||
def empty(cls) -> VariablePool:
|
||||
"""Create an empty variable pool."""
|
||||
return cls(system_variables=SystemVariable.empty())
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
from types import MappingProxyType
|
||||
from typing import Any
|
||||
@ -70,7 +72,7 @@ class SystemVariable(BaseModel):
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def empty(cls) -> "SystemVariable":
|
||||
def empty(cls) -> SystemVariable:
|
||||
return cls()
|
||||
|
||||
def to_dict(self) -> dict[SystemVariableKey, Any]:
|
||||
@ -114,7 +116,7 @@ class SystemVariable(BaseModel):
|
||||
d[SystemVariableKey.TIMESTAMP] = self.timestamp
|
||||
return d
|
||||
|
||||
def as_view(self) -> "SystemVariableReadOnlyView":
|
||||
def as_view(self) -> SystemVariableReadOnlyView:
|
||||
return SystemVariableReadOnlyView(self)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user