Merge remote-tracking branch 'upstream/main' into feat/human-input-merge-again

This commit is contained in:
QuantumGhost
2026-01-28 16:21:37 +08:00
4167 changed files with 345823 additions and 171263 deletions

View File

@ -6,6 +6,7 @@ BASE_CORS_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE, HEAD
SERVICE_API_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, "Authorization")
AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF_TOKEN)
FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN)
EMBED_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE)
EXPOSED_HEADERS: tuple[str, ...] = ("X-Version", "X-Env", "X-Trace-Id")
@ -42,10 +43,28 @@ def init_app(app: DifyApp):
_apply_cors_once(
web_bp,
resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}},
supports_credentials=True,
allow_headers=list(AUTHENTICATED_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
resources={
# Embedded bot endpoints (unauthenticated, cross-origin safe)
r"^/chat-messages$": {
"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS,
"supports_credentials": False,
"allow_headers": list(EMBED_HEADERS),
"methods": ["GET", "POST", "OPTIONS"],
},
r"^/chat-messages/.*": {
"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS,
"supports_credentials": False,
"allow_headers": list(EMBED_HEADERS),
"methods": ["GET", "POST", "OPTIONS"],
},
# Default web application endpoints (authenticated)
r"/*": {
"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS,
"supports_credentials": True,
"allow_headers": list(AUTHENTICATED_HEADERS),
"methods": ["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
},
},
expose_headers=list(EXPOSED_HEADERS),
)
app.register_blueprint(web_bp)

View File

@ -12,9 +12,8 @@ from dify_app import DifyApp
def _get_celery_ssl_options() -> dict[str, Any] | None:
"""Get SSL configuration for Celery broker/backend connections."""
# Use REDIS_USE_SSL for consistency with the main Redis client
# Only apply SSL if we're using Redis as broker/backend
if not dify_config.REDIS_USE_SSL:
if not dify_config.BROKER_USE_SSL:
return None
# Check if Celery is actually using Redis
@ -47,7 +46,11 @@ def _get_celery_ssl_options() -> dict[str, Any] | None:
def init_app(app: DifyApp) -> Celery:
class FlaskTask(Task):
def __call__(self, *args: object, **kwargs: object) -> object:
from core.logging.context import init_request_context
with app.app_context():
# Initialize logging context for this task (similar to before_request in Flask)
init_request_context()
return self.run(*args, **kwargs)
broker_transport_options = {}
@ -166,6 +169,13 @@ def init_app(app: DifyApp) -> Celery:
"task": "schedule.clean_workflow_runlogs_precise.clean_workflow_runlogs_precise",
"schedule": crontab(minute="0", hour="2"),
}
if dify_config.ENABLE_WORKFLOW_RUN_CLEANUP_TASK:
# for saas only
imports.append("schedule.clean_workflow_runs_task")
beat_schedule["clean_workflow_runs_task"] = {
"task": "schedule.clean_workflow_runs_task.clean_workflow_runs_task",
"schedule": crontab(minute="0", hour="0"),
}
if dify_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK:
imports.append("schedule.workflow_schedule_task")
beat_schedule["workflow_schedule_task"] = {

View File

@ -4,13 +4,18 @@ from dify_app import DifyApp
def init_app(app: DifyApp):
from commands import (
add_qdrant_index,
archive_workflow_runs,
clean_expired_messages,
clean_workflow_runs,
cleanup_orphaned_draft_variables,
clear_free_plan_tenant_expired_logs,
clear_orphaned_file_records,
convert_to_agent_apps,
create_tenant,
delete_archived_workflow_runs,
extract_plugins,
extract_unique_plugins,
file_usage,
fix_app_site_missing,
install_plugins,
install_rag_pipeline_plugins,
@ -21,6 +26,7 @@ def init_app(app: DifyApp):
reset_email,
reset_encrypt_key_pair,
reset_password,
restore_workflow_runs,
setup_datasource_oauth_client,
setup_system_tool_oauth_client,
setup_system_trigger_oauth_client,
@ -47,6 +53,7 @@ def init_app(app: DifyApp):
clear_free_plan_tenant_expired_logs,
clear_orphaned_file_records,
remove_orphaned_files_on_storage,
file_usage,
setup_system_tool_oauth_client,
setup_system_trigger_oauth_client,
cleanup_orphaned_draft_variables,
@ -54,6 +61,11 @@ def init_app(app: DifyApp):
setup_datasource_oauth_client,
transform_datasource_credentials,
install_rag_pipeline_plugins,
archive_workflow_runs,
delete_archived_workflow_runs,
restore_workflow_runs,
clean_workflow_runs,
clean_expired_messages,
]
for cmd in cmds_to_register:
app.cli.add_command(cmd)

View File

@ -53,3 +53,10 @@ def _setup_gevent_compatibility():
def init_app(app: DifyApp):
    """Bind the SQLAlchemy extension to *app* and build its engine up front.

    After ``db`` is registered and the gevent compatibility setup has run,
    ``db.engine`` is accessed once inside an application context so the engine
    is created during startup. Any failure while doing so is logged and
    otherwise ignored.

    Args:
        app: The Dify application instance.
    """
    db.init_app(app)
    _setup_gevent_compatibility()

    # Eagerly build the engine so pool_size/max_overflow/etc. come from config.
    try:
        app_ctx = app.app_context()
        with app_ctx:
            # The attribute access itself is what triggers engine creation.
            _ = db.engine
    except Exception:
        logger.exception("Failed to initialize SQLAlchemy engine during app startup")

View File

@ -0,0 +1,45 @@
from fastopenapi.routers import FlaskRouter
from flask_cors import CORS
from configs import dify_config
from controllers.fastopenapi import console_router
from dify_app import DifyApp
from extensions.ext_blueprints import AUTHENTICATED_HEADERS, EXPOSED_HEADERS
DOCS_PREFIX = "/fastopenapi"
def init_app(app: DifyApp) -> None:
    """Attach the FastOpenAPI console router to *app* and enable CORS for it.

    The documentation endpoints (docs, redoc, openapi.json) only receive URLs
    when ``dify_config.SWAGGER_UI_ENABLED`` is truthy; otherwise ``None`` is
    passed for each of them. The router that is created is stored under
    ``app.extensions["fastopenapi"]``.

    Args:
        app: The Dify application instance to extend.
    """
    docs_url: str | None
    redoc_url: str | None
    openapi_url: str | None
    if dify_config.SWAGGER_UI_ENABLED:
        docs_url = f"{DOCS_PREFIX}/docs"
        redoc_url = f"{DOCS_PREFIX}/redoc"
        openapi_url = f"{DOCS_PREFIX}/openapi.json"
    else:
        docs_url = redoc_url = openapi_url = None

    router = FlaskRouter(
        app=app,
        docs_url=docs_url,
        redoc_url=redoc_url,
        openapi_url=openapi_url,
        openapi_version="3.0.0",
        title="Dify API (FastOpenAPI PoC)",
        version="1.0",
        description="FastOpenAPI proof of concept for Dify API",
    )

    # Ensure route decorators are evaluated: importing these modules is done
    # for its side effect, and the dummy assignments keep the names "used".
    import controllers.console.ping as ping_module
    from controllers.console import setup

    _ = ping_module
    _ = setup

    router.include_router(console_router, prefix="/console/api")

    cors_resources = {r"/console/api/.*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}}
    CORS(
        app,
        resources=cors_resources,
        supports_credentials=True,
        allow_headers=list(AUTHENTICATED_HEADERS),
        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
        expose_headers=list(EXPOSED_HEADERS),
    )

    app.extensions["fastopenapi"] = router

View File

@ -1,18 +1,19 @@
"""Logging extension for Dify Flask application."""
import logging
import os
import sys
import uuid
from logging.handlers import RotatingFileHandler
import flask
from configs import dify_config
from core.helper.trace_id_helper import get_trace_id_from_otel_context
from dify_app import DifyApp
def init_app(app: DifyApp):
"""Initialize logging with support for text or JSON format."""
log_handlers: list[logging.Handler] = []
# File handler
log_file = dify_config.LOG_FILE
if log_file:
log_dir = os.path.dirname(log_file)
@ -25,27 +26,53 @@ def init_app(app: DifyApp):
)
)
# Always add StreamHandler to log to console
# Console handler
sh = logging.StreamHandler(sys.stdout)
log_handlers.append(sh)
# Apply RequestIdFilter to all handlers
for handler in log_handlers:
handler.addFilter(RequestIdFilter())
# Apply filters to all handlers
from core.logging.filters import IdentityContextFilter, TraceContextFilter
for handler in log_handlers:
handler.addFilter(TraceContextFilter())
handler.addFilter(IdentityContextFilter())
# Configure formatter based on format type
formatter = _create_formatter()
for handler in log_handlers:
handler.setFormatter(formatter)
# Configure root logger
logging.basicConfig(
level=dify_config.LOG_LEVEL,
format=dify_config.LOG_FORMAT,
datefmt=dify_config.LOG_DATEFORMAT,
handlers=log_handlers,
force=True,
)
# Apply RequestIdFormatter to all handlers
apply_request_id_formatter()
# Disable propagation for noisy loggers to avoid duplicate logs
logging.getLogger("sqlalchemy.engine").propagate = False
# Apply timezone if specified (only for text format)
if dify_config.LOG_OUTPUT_FORMAT == "text":
_apply_timezone(log_handlers)
def _create_formatter() -> logging.Formatter:
    """Build the log formatter selected by ``dify_config.LOG_OUTPUT_FORMAT``.

    Returns:
        A ``StructuredJSONFormatter`` when the output format is ``"json"``;
        for any other value, a ``_TextFormatter`` configured with
        ``LOG_FORMAT`` and ``LOG_DATEFORMAT``.
    """
    if dify_config.LOG_OUTPUT_FORMAT != "json":
        # Text format - use existing pattern with backward compatible formatter
        text_formatter = _TextFormatter(
            fmt=dify_config.LOG_FORMAT,
            datefmt=dify_config.LOG_DATEFORMAT,
        )
        return text_formatter

    # The structured formatter is only imported when JSON output is requested.
    from core.logging.structured_formatter import StructuredJSONFormatter

    return StructuredJSONFormatter()
def _apply_timezone(handlers: list[logging.Handler]):
"""Apply timezone conversion to text formatters."""
log_tz = dify_config.LOG_TZ
if log_tz:
from datetime import datetime
@ -57,34 +84,51 @@ def init_app(app: DifyApp):
def time_converter(seconds):
return datetime.fromtimestamp(seconds, tz=timezone).timetuple()
for handler in logging.root.handlers:
for handler in handlers:
if handler.formatter:
handler.formatter.converter = time_converter
handler.formatter.converter = time_converter # type: ignore[attr-defined]
def get_request_id():
if getattr(flask.g, "request_id", None):
return flask.g.request_id
class _TextFormatter(logging.Formatter):
"""Text formatter that ensures trace_id and req_id are always present."""
new_uuid = uuid.uuid4().hex[:10]
flask.g.request_id = new_uuid
return new_uuid
def format(self, record: logging.LogRecord) -> str:
if not hasattr(record, "req_id"):
record.req_id = ""
if not hasattr(record, "trace_id"):
record.trace_id = ""
if not hasattr(record, "span_id"):
record.span_id = ""
return super().format(record)
def get_request_id() -> str:
    """Return the request ID provided by ``core.logging.context``.

    Deprecated: Use core.logging.context.get_request_id() directly.
    """
    # Imported at call time rather than at module import time.
    from core.logging.context import get_request_id as _context_get_request_id

    request_id = _context_get_request_id()
    return request_id
# Backward compatibility aliases
class RequestIdFilter(logging.Filter):
# This is a logging filter that makes the request ID available for use in
# the logging format. Note that we're checking if we're in a request
# context, as we may want to log things before Flask is fully loaded.
def filter(self, record):
trace_id = get_trace_id_from_otel_context() or ""
record.req_id = get_request_id() if flask.has_request_context() else ""
record.trace_id = trace_id
"""Deprecated: Use TraceContextFilter from core.logging.filters instead."""
def filter(self, record: logging.LogRecord) -> bool:
from core.logging.context import get_request_id as _get_request_id
from core.logging.context import get_trace_id as _get_trace_id
record.req_id = _get_request_id()
record.trace_id = _get_trace_id()
return True
class RequestIdFormatter(logging.Formatter):
def format(self, record):
"""Deprecated: Use _TextFormatter instead."""
def format(self, record: logging.LogRecord) -> str:
if not hasattr(record, "req_id"):
record.req_id = ""
if not hasattr(record, "trace_id"):
@ -93,6 +137,7 @@ class RequestIdFormatter(logging.Formatter):
def apply_request_id_formatter():
    """Deprecated: Formatter is now applied in init_app.

    Replaces the formatter of every root-logger handler that already has one
    with a new ``RequestIdFormatter`` built from ``LOG_FORMAT`` and
    ``LOG_DATEFORMAT``. Handlers without a formatter are left untouched.
    """
    for root_handler in logging.root.handlers:
        if not root_handler.formatter:
            continue
        root_handler.formatter = RequestIdFormatter(dify_config.LOG_FORMAT, dify_config.LOG_DATEFORMAT)

View File

@ -10,6 +10,7 @@ import os
from dotenv import load_dotenv
from configs import dify_config
from dify_app import DifyApp
logger = logging.getLogger(__name__)
@ -19,12 +20,17 @@ def is_enabled() -> bool:
"""
Check if logstore extension is enabled.
Logstore is considered enabled when:
1. All required Aliyun SLS environment variables are set
2. At least one repository configuration points to a logstore implementation
Returns:
True if all required Aliyun SLS environment variables are set, False otherwise
True if logstore should be initialized, False otherwise
"""
# Load environment variables from .env file
load_dotenv()
# Check if Aliyun SLS connection parameters are configured
required_vars = [
"ALIYUN_SLS_ACCESS_KEY_ID",
"ALIYUN_SLS_ACCESS_KEY_SECRET",
@ -33,24 +39,32 @@ def is_enabled() -> bool:
"ALIYUN_SLS_PROJECT_NAME",
]
all_set = all(os.environ.get(var) for var in required_vars)
sls_vars_set = all(os.environ.get(var) for var in required_vars)
if not all_set:
logger.info("Logstore extension disabled: required Aliyun SLS environment variables not set")
if not sls_vars_set:
return False
return all_set
# Check if any repository configuration points to logstore implementation
repository_configs = [
dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY,
dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY,
dify_config.API_WORKFLOW_NODE_EXECUTION_REPOSITORY,
dify_config.API_WORKFLOW_RUN_REPOSITORY,
]
uses_logstore = any("logstore" in config.lower() for config in repository_configs)
if not uses_logstore:
return False
logger.info("Logstore extension enabled: SLS variables set and repository configured to use logstore")
return True
def init_app(app: DifyApp):
"""
Initialize logstore on application startup.
This function:
1. Creates Aliyun SLS project if it doesn't exist
2. Creates logstores (workflow_execution, workflow_node_execution) if they don't exist
3. Creates indexes with field configurations based on PostgreSQL table structures
This operation is idempotent and only executes once during application startup.
If initialization fails, the application continues running without logstore features.
Args:
app: The Dify application instance
@ -58,17 +72,23 @@ def init_app(app: DifyApp):
try:
from extensions.logstore.aliyun_logstore import AliyunLogStore
logger.info("Initializing logstore...")
logger.info("Initializing Aliyun SLS Logstore...")
# Create logstore client and initialize project/logstores/indexes
# Create logstore client and initialize resources
logstore_client = AliyunLogStore()
logstore_client.init_project_logstore()
# Attach to app for potential later use
app.extensions["logstore"] = logstore_client
logger.info("Logstore initialized successfully")
except Exception:
logger.exception("Failed to initialize logstore")
# Don't raise - allow application to continue even if logstore init fails
# This ensures that the application can still run if logstore is misconfigured
logger.exception(
"Logstore initialization failed. Configuration: endpoint=%s, region=%s, project=%s, timeout=%ss. "
"Application will continue but logstore features will NOT work.",
os.environ.get("ALIYUN_SLS_ENDPOINT"),
os.environ.get("ALIYUN_SLS_REGION"),
os.environ.get("ALIYUN_SLS_PROJECT_NAME"),
os.environ.get("ALIYUN_SLS_CHECK_CONNECTIVITY_TIMEOUT", "30"),
)
# Don't raise - allow application to continue even if logstore setup fails

View File

@ -1,5 +1,8 @@
from __future__ import annotations
import logging
import os
import socket
import threading
import time
from collections.abc import Sequence
@ -33,7 +36,7 @@ class AliyunLogStore:
Ensures only one instance exists to prevent multiple PG connection pools.
"""
_instance: "AliyunLogStore | None" = None
_instance: AliyunLogStore | None = None
_initialized: bool = False
# Track delayed PG connection for newly created projects
@ -66,7 +69,7 @@ class AliyunLogStore:
"\t",
]
def __new__(cls) -> "AliyunLogStore":
def __new__(cls) -> AliyunLogStore:
"""Implement singleton pattern."""
if cls._instance is None:
cls._instance = super().__new__(cls)
@ -177,9 +180,18 @@ class AliyunLogStore:
self.region: str = os.environ.get("ALIYUN_SLS_REGION", "")
self.project_name: str = os.environ.get("ALIYUN_SLS_PROJECT_NAME", "")
self.logstore_ttl: int = int(os.environ.get("ALIYUN_SLS_LOGSTORE_TTL", 365))
self.log_enabled: bool = os.environ.get("SQLALCHEMY_ECHO", "false").lower() == "true"
self.log_enabled: bool = (
os.environ.get("SQLALCHEMY_ECHO", "false").lower() == "true"
or os.environ.get("LOGSTORE_SQL_ECHO", "false").lower() == "true"
)
self.pg_mode_enabled: bool = os.environ.get("LOGSTORE_PG_MODE_ENABLED", "true").lower() == "true"
# Get timeout configuration
check_timeout = int(os.environ.get("ALIYUN_SLS_CHECK_CONNECTIVITY_TIMEOUT", 30))
# Pre-check endpoint connectivity to prevent indefinite hangs
self._check_endpoint_connectivity(self.endpoint, check_timeout)
# Initialize SDK client
self.client = LogClient(
self.endpoint, self.access_key_id, self.access_key_secret, auth_version=AUTH_VERSION_4, region=self.region
@ -197,6 +209,49 @@ class AliyunLogStore:
self.__class__._initialized = True
@staticmethod
def _check_endpoint_connectivity(endpoint: str, timeout: int) -> None:
"""
Check if the SLS endpoint is reachable before creating LogClient.
Prevents indefinite hangs when the endpoint is unreachable.
Args:
endpoint: SLS endpoint URL
timeout: Connection timeout in seconds
Raises:
ConnectionError: If endpoint is not reachable
"""
# Parse endpoint URL to extract hostname and port
from urllib.parse import urlparse
parsed_url = urlparse(endpoint if "://" in endpoint else f"http://{endpoint}")
hostname = parsed_url.hostname
port = parsed_url.port or (443 if parsed_url.scheme == "https" else 80)
if not hostname:
raise ConnectionError(f"Invalid endpoint URL: {endpoint}")
sock = None
try:
# Create socket and set timeout
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(timeout)
sock.connect((hostname, port))
except Exception as e:
# Catch all exceptions and provide clear error message
error_type = type(e).__name__
raise ConnectionError(
f"Cannot connect to {hostname}:{port} (timeout={timeout}s): [{error_type}] {e}"
) from e
finally:
# Ensure socket is properly closed
if sock:
try:
sock.close()
except Exception: # noqa: S110
pass # Ignore errors during cleanup
@property
def supports_pg_protocol(self) -> bool:
"""Check if PG protocol is supported and enabled."""
@ -218,19 +273,16 @@ class AliyunLogStore:
try:
self._use_pg_protocol = self._pg_client.init_connection()
if self._use_pg_protocol:
logger.info("Successfully connected to project %s using PG protocol", self.project_name)
logger.info("Using PG protocol for project %s", self.project_name)
# Check if scan_index is enabled for all logstores
self._check_and_disable_pg_if_scan_index_disabled()
return True
else:
logger.info("PG connection failed for project %s. Will use SDK mode.", self.project_name)
logger.info("Using SDK mode for project %s", self.project_name)
return False
except Exception as e:
logger.warning(
"Failed to establish PG connection for project %s: %s. Will use SDK mode.",
self.project_name,
str(e),
)
logger.info("Using SDK mode for project %s", self.project_name)
logger.debug("PG connection details: %s", str(e))
self._use_pg_protocol = False
return False
@ -244,10 +296,6 @@ class AliyunLogStore:
if self._use_pg_protocol:
return
logger.info(
"Attempting delayed PG connection for newly created project %s ...",
self.project_name,
)
self._attempt_pg_connection_init()
self.__class__._pg_connection_timer = None
@ -282,11 +330,7 @@ class AliyunLogStore:
if project_is_new:
# For newly created projects, schedule delayed PG connection
self._use_pg_protocol = False
logger.info(
"Project %s is newly created. Will use SDK mode and schedule PG connection attempt in %d seconds.",
self.project_name,
self.__class__._pg_connection_delay,
)
logger.info("Using SDK mode for project %s (newly created)", self.project_name)
if self.__class__._pg_connection_timer is not None:
self.__class__._pg_connection_timer.cancel()
self.__class__._pg_connection_timer = threading.Timer(
@ -297,7 +341,6 @@ class AliyunLogStore:
self.__class__._pg_connection_timer.start()
else:
# For existing projects, attempt PG connection immediately
logger.info("Project %s already exists. Attempting PG connection...", self.project_name)
self._attempt_pg_connection_init()
def _check_and_disable_pg_if_scan_index_disabled(self) -> None:
@ -316,9 +359,9 @@ class AliyunLogStore:
existing_config = self.get_existing_index_config(logstore_name)
if existing_config and not existing_config.scan_index:
logger.info(
"Logstore %s has scan_index=false, USE SDK mode for read/write operations. "
"PG protocol requires scan_index to be enabled.",
"Logstore %s requires scan_index enabled, using SDK mode for project %s",
logstore_name,
self.project_name,
)
self._use_pg_protocol = False
# Close PG connection if it was initialized
@ -746,7 +789,6 @@ class AliyunLogStore:
reverse=reverse,
)
# Log query info if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore] GET_LOGS | logstore=%s | project=%s | query=%s | "
@ -768,7 +810,6 @@ class AliyunLogStore:
for log in logs:
result.append(log.get_contents())
# Log result count if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore] GET_LOGS RESULT | logstore=%s | returned_count=%d",
@ -843,7 +884,6 @@ class AliyunLogStore:
query=full_query,
)
# Log query info if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore-SDK] EXECUTE_SQL | logstore=%s | project=%s | from_time=%d | to_time=%d | full_query=%s",
@ -851,8 +891,7 @@ class AliyunLogStore:
self.project_name,
from_time,
to_time,
query,
sql,
full_query,
)
try:
@ -863,7 +902,6 @@ class AliyunLogStore:
for log in logs:
result.append(log.get_contents())
# Log result count if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore-SDK] EXECUTE_SQL RESULT | logstore=%s | returned_count=%d",

View File

@ -7,8 +7,7 @@ from contextlib import contextmanager
from typing import Any
import psycopg2
import psycopg2.pool
from psycopg2 import InterfaceError, OperationalError
from sqlalchemy import create_engine
from configs import dify_config
@ -16,11 +15,7 @@ logger = logging.getLogger(__name__)
class AliyunLogStorePG:
"""
PostgreSQL protocol support for Aliyun SLS LogStore.
Handles PG connection pooling and operations for regions that support PG protocol.
"""
"""PostgreSQL protocol support for Aliyun SLS LogStore using SQLAlchemy connection pool."""
def __init__(self, access_key_id: str, access_key_secret: str, endpoint: str, project_name: str):
"""
@ -36,24 +31,11 @@ class AliyunLogStorePG:
self._access_key_secret = access_key_secret
self._endpoint = endpoint
self.project_name = project_name
self._pg_pool: psycopg2.pool.SimpleConnectionPool | None = None
self._engine: Any = None # SQLAlchemy Engine
self._use_pg_protocol = False
def _check_port_connectivity(self, host: str, port: int, timeout: float = 2.0) -> bool:
"""
Check if a TCP port is reachable using socket connection.
This provides a fast check before attempting full database connection,
preventing long waits when connecting to unsupported regions.
Args:
host: Hostname or IP address
port: Port number
timeout: Connection timeout in seconds (default: 2.0)
Returns:
True if port is reachable, False otherwise
"""
"""Fast TCP port check to avoid long waits on unsupported regions."""
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(timeout)
@ -65,166 +47,101 @@ class AliyunLogStorePG:
return False
def init_connection(self) -> bool:
"""
Initialize PostgreSQL connection pool for SLS PG protocol support.
Attempts to connect to SLS using PostgreSQL protocol. If successful, sets
_use_pg_protocol to True and creates a connection pool. If connection fails
(region doesn't support PG protocol or other errors), returns False.
Returns:
True if PG protocol is supported and initialized, False otherwise
"""
"""Initialize SQLAlchemy connection pool with pool_recycle and TCP keepalive support."""
try:
# Extract hostname from endpoint (remove protocol if present)
pg_host = self._endpoint.replace("http://", "").replace("https://", "")
# Get pool configuration
pg_max_connections = int(os.environ.get("ALIYUN_SLS_PG_MAX_CONNECTIONS", 10))
# Pool configuration
pool_size = int(os.environ.get("ALIYUN_SLS_PG_POOL_SIZE", 5))
max_overflow = int(os.environ.get("ALIYUN_SLS_PG_MAX_OVERFLOW", 5))
pool_recycle = int(os.environ.get("ALIYUN_SLS_PG_POOL_RECYCLE", 3600))
pool_pre_ping = os.environ.get("ALIYUN_SLS_PG_POOL_PRE_PING", "false").lower() == "true"
logger.debug(
"Check PG protocol connection to SLS: host=%s, project=%s",
pg_host,
self.project_name,
)
logger.debug("Check PG protocol connection to SLS: host=%s, project=%s", pg_host, self.project_name)
# Fast port connectivity check before attempting full connection
# This prevents long waits when connecting to unsupported regions
# Fast port check to avoid long waits
if not self._check_port_connectivity(pg_host, 5432, timeout=1.0):
logger.info(
"USE SDK mode for read/write operations, host=%s",
pg_host,
)
logger.debug("Using SDK mode for host=%s", pg_host)
return False
# Create connection pool
self._pg_pool = psycopg2.pool.SimpleConnectionPool(
minconn=1,
maxconn=pg_max_connections,
host=pg_host,
port=5432,
database=self.project_name,
user=self._access_key_id,
password=self._access_key_secret,
sslmode="require",
connect_timeout=5,
application_name=f"Dify-{dify_config.project.version}",
# Build connection URL
from urllib.parse import quote_plus
username = quote_plus(self._access_key_id)
password = quote_plus(self._access_key_secret)
database_url = (
f"postgresql+psycopg2://{username}:{password}@{pg_host}:5432/{self.project_name}?sslmode=require"
)
# Note: Skip test query because SLS PG protocol only supports SELECT/INSERT on actual tables
# Connection pool creation success already indicates connectivity
# Create SQLAlchemy engine with connection pool
self._engine = create_engine(
database_url,
pool_size=pool_size,
max_overflow=max_overflow,
pool_recycle=pool_recycle,
pool_pre_ping=pool_pre_ping,
pool_timeout=30,
connect_args={
"connect_timeout": 5,
"application_name": f"Dify-{dify_config.project.version}-fixautocommit",
"keepalives": 1,
"keepalives_idle": 60,
"keepalives_interval": 10,
"keepalives_count": 5,
},
)
self._use_pg_protocol = True
logger.info(
"PG protocol initialized successfully for SLS project=%s. Will use PG for read/write operations.",
"PG protocol initialized for SLS project=%s (pool_size=%d, pool_recycle=%ds)",
self.project_name,
pool_size,
pool_recycle,
)
return True
except Exception as e:
# PG connection failed - fallback to SDK mode
self._use_pg_protocol = False
if self._pg_pool:
if self._engine:
try:
self._pg_pool.closeall()
self._engine.dispose()
except Exception:
logger.debug("Failed to close PG connection pool during cleanup, ignoring")
self._pg_pool = None
logger.debug("Failed to dispose engine during cleanup, ignoring")
self._engine = None
logger.info(
"PG protocol connection failed (region may not support PG protocol): %s. "
"Falling back to SDK mode for read/write operations.",
str(e),
)
return False
def _is_connection_valid(self, conn: Any) -> bool:
"""
Check if a connection is still valid.
Args:
conn: psycopg2 connection object
Returns:
True if connection is valid, False otherwise
"""
try:
# Check if connection is closed
if conn.closed:
return False
# Quick ping test - execute a lightweight query
# For SLS PG protocol, we can't use SELECT 1 without FROM,
# so we just check the connection status
with conn.cursor() as cursor:
cursor.execute("SELECT 1")
cursor.fetchone()
return True
except Exception:
logger.debug("Using SDK mode for region: %s", str(e))
return False
@contextmanager
def _get_connection(self):
"""
Context manager to get a PostgreSQL connection from the pool.
"""Get connection from SQLAlchemy pool. Pool handles recycle, invalidation, and keepalive automatically."""
if not self._engine:
raise RuntimeError("SQLAlchemy engine is not initialized")
Automatically validates and refreshes stale connections.
Note: Aliyun SLS PG protocol does not support transactions, so we always
use autocommit mode.
Yields:
psycopg2 connection object
Raises:
RuntimeError: If PG pool is not initialized
"""
if not self._pg_pool:
raise RuntimeError("PG connection pool is not initialized")
conn = self._pg_pool.getconn()
connection = self._engine.raw_connection()
try:
# Validate connection and get a fresh one if needed
if not self._is_connection_valid(conn):
logger.debug("Connection is stale, marking as bad and getting a new one")
# Mark connection as bad and get a new one
self._pg_pool.putconn(conn, close=True)
conn = self._pg_pool.getconn()
# Aliyun SLS PG protocol does not support transactions, always use autocommit
conn.autocommit = True
yield conn
connection.autocommit = True # SLS PG protocol does not support transactions
yield connection
except Exception:
raise
finally:
# Return connection to pool (or close if it's bad)
if self._is_connection_valid(conn):
self._pg_pool.putconn(conn)
else:
self._pg_pool.putconn(conn, close=True)
connection.close()
def close(self) -> None:
"""Close the PostgreSQL connection pool."""
if self._pg_pool:
"""Dispose SQLAlchemy engine and close all connections."""
if self._engine:
try:
self._pg_pool.closeall()
logger.info("PG connection pool closed")
self._engine.dispose()
logger.info("SQLAlchemy engine disposed")
except Exception:
logger.exception("Failed to close PG connection pool")
logger.exception("Failed to dispose engine")
def _is_retriable_error(self, error: Exception) -> bool:
"""
Check if an error is retriable (connection-related issues).
Args:
error: Exception to check
Returns:
True if the error is retriable, False otherwise
"""
# Retry on connection-related errors
if isinstance(error, (OperationalError, InterfaceError)):
"""Check if error is retriable (connection-related issues)."""
# Check for psycopg2 connection errors directly
if isinstance(error, (psycopg2.OperationalError, psycopg2.InterfaceError)):
return True
# Check error message for specific connection issues
error_msg = str(error).lower()
retriable_patterns = [
"connection",
@ -234,34 +151,18 @@ class AliyunLogStorePG:
"reset by peer",
"no route to host",
"network",
"operational error",
"interface error",
]
return any(pattern in error_msg for pattern in retriable_patterns)
def put_log(self, logstore: str, contents: Sequence[tuple[str, str]], log_enabled: bool = False) -> None:
"""
Write log to SLS using PostgreSQL protocol with automatic retry.
Note: SLS PG protocol only supports INSERT (not UPDATE). This uses append-only
writes with log_version field for versioning, same as SDK implementation.
Args:
logstore: Name of the logstore table
contents: List of (field_name, value) tuples
log_enabled: Whether to enable logging
Raises:
psycopg2.Error: If database operation fails after all retries
"""
"""Write log to SLS using INSERT with automatic retry (3 attempts with exponential backoff)."""
if not contents:
return
# Extract field names and values from contents
fields = [field_name for field_name, _ in contents]
values = [value for _, value in contents]
# Build INSERT statement with literal values
# Note: Aliyun SLS PG protocol doesn't support parameterized queries,
# so we need to use mogrify to safely create literal values
field_list = ", ".join([f'"{field}"' for field in fields])
if log_enabled:
@ -272,67 +173,40 @@ class AliyunLogStorePG:
len(contents),
)
# Retry configuration
max_retries = 3
retry_delay = 0.1 # Start with 100ms
retry_delay = 0.1
for attempt in range(max_retries):
try:
with self._get_connection() as conn:
with conn.cursor() as cursor:
# Use mogrify to safely convert values to SQL literals
placeholders = ", ".join(["%s"] * len(fields))
values_literal = cursor.mogrify(f"({placeholders})", values).decode("utf-8")
insert_sql = f'INSERT INTO "{logstore}" ({field_list}) VALUES {values_literal}'
cursor.execute(insert_sql)
# Success - exit retry loop
return
except psycopg2.Error as e:
# Check if error is retriable
if not self._is_retriable_error(e):
# Not a retriable error (e.g., data validation error), fail immediately
logger.exception(
"Failed to put logs to logstore %s via PG protocol (non-retriable error)",
logstore,
)
logger.exception("Failed to put logs to logstore %s (non-retriable error)", logstore)
raise
# Retriable error - log and retry if we have attempts left
if attempt < max_retries - 1:
logger.warning(
"Failed to put logs to logstore %s via PG protocol (attempt %d/%d): %s. Retrying...",
"Failed to put logs to logstore %s (attempt %d/%d): %s. Retrying...",
logstore,
attempt + 1,
max_retries,
str(e),
)
time.sleep(retry_delay)
retry_delay *= 2 # Exponential backoff
retry_delay *= 2
else:
# Last attempt failed
logger.exception(
"Failed to put logs to logstore %s via PG protocol after %d attempts",
logstore,
max_retries,
)
logger.exception("Failed to put logs to logstore %s after %d attempts", logstore, max_retries)
raise
def execute_sql(self, sql: str, logstore: str, log_enabled: bool = False) -> list[dict[str, Any]]:
"""
Execute SQL query using PostgreSQL protocol with automatic retry.
Args:
sql: SQL query string
logstore: Name of the logstore (for logging purposes)
log_enabled: Whether to enable logging
Returns:
List of result rows as dictionaries
Raises:
psycopg2.Error: If database operation fails after all retries
"""
"""Execute SQL query with automatic retry (3 attempts with exponential backoff)."""
if log_enabled:
logger.info(
"[LogStore-PG] EXECUTE_SQL | logstore=%s | project=%s | sql=%s",
@ -341,20 +215,16 @@ class AliyunLogStorePG:
sql,
)
# Retry configuration
max_retries = 3
retry_delay = 0.1 # Start with 100ms
retry_delay = 0.1
for attempt in range(max_retries):
try:
with self._get_connection() as conn:
with conn.cursor() as cursor:
cursor.execute(sql)
# Get column names from cursor description
columns = [desc[0] for desc in cursor.description]
# Fetch all results and convert to list of dicts
result = []
for row in cursor.fetchall():
row_dict = {}
@ -372,36 +242,31 @@ class AliyunLogStorePG:
return result
except psycopg2.Error as e:
# Check if error is retriable
if not self._is_retriable_error(e):
# Not a retriable error (e.g., SQL syntax error), fail immediately
logger.exception(
"Failed to execute SQL query on logstore %s via PG protocol (non-retriable error): sql=%s",
"Failed to execute SQL on logstore %s (non-retriable error): sql=%s",
logstore,
sql,
)
raise
# Retriable error - log and retry if we have attempts left
if attempt < max_retries - 1:
logger.warning(
"Failed to execute SQL query on logstore %s via PG protocol (attempt %d/%d): %s. Retrying...",
"Failed to execute SQL on logstore %s (attempt %d/%d): %s. Retrying...",
logstore,
attempt + 1,
max_retries,
str(e),
)
time.sleep(retry_delay)
retry_delay *= 2 # Exponential backoff
retry_delay *= 2
else:
# Last attempt failed
logger.exception(
"Failed to execute SQL query on logstore %s via PG protocol after %d attempts: sql=%s",
"Failed to execute SQL on logstore %s after %d attempts: sql=%s",
logstore,
max_retries,
sql,
)
raise
# This line should never be reached due to raise above, but makes type checker happy
return []

View File

@ -0,0 +1,29 @@
"""
LogStore repository utilities.
"""
from typing import Any
def safe_float(value: Any, default: float = 0.0) -> float:
    """
    Safely convert a value to float, handling 'null' strings and None.

    Args:
        value: Raw field value; may be None, a number, a numeric string, or
            one of the placeholder strings "null" / "".
        default: Value returned when ``value`` is missing or not convertible.

    Returns:
        ``float(value)`` when the conversion succeeds, otherwise ``default``.
    """
    # Only strings can be the "null"/"" placeholders. Checking the type first
    # keeps unhashable inputs (list, dict) away from a hash-based membership
    # test, which would raise TypeError before reaching the try block below.
    if value is None or (isinstance(value, str) and value in ("null", "")):
        return default
    try:
        return float(value)
    except (ValueError, TypeError, OverflowError):
        # OverflowError: integers too large to be represented as a float.
        return default
def safe_int(value: Any, default: int = 0) -> int:
    """
    Safely convert a value to int, handling 'null' strings and None.

    Args:
        value: Raw field value; may be None, a number, a numeric string
            (integer or decimal notation), or the placeholder strings
            "null" / "".
        default: Value returned when ``value`` is missing or not convertible.

    Returns:
        The integer value (decimal input is truncated toward zero), or
        ``default`` when no conversion is possible.
    """
    # Only strings can be the "null"/"" placeholders. Checking the type first
    # keeps unhashable inputs (list, dict) away from a hash-based membership
    # test, which would raise TypeError before reaching the try blocks below.
    if value is None or (isinstance(value, str) and value in ("null", "")):
        return default
    try:
        # Direct conversion first: going through float would silently lose
        # precision for integers above 2**53.
        return int(value)
    except (ValueError, TypeError, OverflowError):
        pass
    try:
        # Fallback for decimal / exponent strings such as "3.7" or "1e3".
        return int(float(value))
    except (ValueError, TypeError, OverflowError):
        # OverflowError: "inf" parses as a float but has no integer value.
        return default

View File

@ -15,6 +15,8 @@ from sqlalchemy.orm import sessionmaker
from core.workflow.enums import WorkflowNodeExecutionStatus
from extensions.logstore.aliyun_logstore import AliyunLogStore
from extensions.logstore.repositories import safe_float, safe_int
from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value
from models.workflow import WorkflowNodeExecutionModel
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
@ -53,9 +55,8 @@ def _dict_to_workflow_node_execution_model(data: dict[str, Any]) -> WorkflowNode
model.created_by_role = data.get("created_by_role") or ""
model.created_by = data.get("created_by") or ""
# Numeric fields with defaults
model.index = int(data.get("index", 0))
model.elapsed_time = float(data.get("elapsed_time", 0))
model.index = safe_int(data.get("index", 0))
model.elapsed_time = safe_float(data.get("elapsed_time", 0))
# Optional fields
model.workflow_run_id = data.get("workflow_run_id")
@ -131,6 +132,12 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
node_id,
)
try:
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_workflow_id = escape_identifier(workflow_id)
escaped_node_id = escape_identifier(node_id)
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of each record)
@ -139,10 +146,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
WHERE tenant_id = '{tenant_id}'
AND app_id = '{app_id}'
AND workflow_id = '{workflow_id}'
AND node_id = '{node_id}'
WHERE tenant_id = '{escaped_tenant_id}'
AND app_id = '{escaped_app_id}'
AND workflow_id = '{escaped_workflow_id}'
AND node_id = '{escaped_node_id}'
AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
@ -154,7 +161,8 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
else:
# Use SDK with LogStore query syntax
query = (
f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_id: {workflow_id} and node_id: {node_id}"
f"tenant_id: {escaped_tenant_id} and app_id: {escaped_app_id} "
f"and workflow_id: {escaped_workflow_id} and node_id: {escaped_node_id}"
)
from_time = 0
to_time = int(time.time()) # now
@ -230,6 +238,11 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
workflow_run_id,
)
try:
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_workflow_run_id = escape_identifier(workflow_run_id)
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of each record)
@ -238,9 +251,9 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
WHERE tenant_id = '{tenant_id}'
AND app_id = '{app_id}'
AND workflow_run_id = '{workflow_run_id}'
WHERE tenant_id = '{escaped_tenant_id}'
AND app_id = '{escaped_app_id}'
AND workflow_run_id = '{escaped_workflow_run_id}'
AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 1000
@ -251,7 +264,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
)
else:
# Use SDK with LogStore query syntax
query = f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_run_id: {workflow_run_id}"
query = (
f"tenant_id: {escaped_tenant_id} and app_id: {escaped_app_id} "
f"and workflow_run_id: {escaped_workflow_run_id}"
)
from_time = 0
to_time = int(time.time()) # now
@ -318,16 +334,24 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
"""
logger.debug("get_execution_by_id: execution_id=%s, tenant_id=%s", execution_id, tenant_id)
try:
# Escape parameters to prevent SQL injection
escaped_execution_id = escape_identifier(execution_id)
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record)
tenant_filter = f"AND tenant_id = '{tenant_id}'" if tenant_id else ""
if tenant_id:
escaped_tenant_id = escape_identifier(tenant_id)
tenant_filter = f"AND tenant_id = '{escaped_tenant_id}'"
else:
tenant_filter = ""
sql_query = f"""
SELECT * FROM (
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
WHERE id = '{execution_id}' {tenant_filter} AND __time__ > 0
WHERE id = '{escaped_execution_id}' {tenant_filter} AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 1
"""
@ -337,10 +361,14 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
)
else:
# Use SDK with LogStore query syntax
# Note: Values must be quoted in LogStore query syntax to prevent injection
if tenant_id:
query = f"id: {execution_id} and tenant_id: {tenant_id}"
query = (
f"id:{escape_logstore_query_value(execution_id)} "
f"and tenant_id:{escape_logstore_query_value(tenant_id)}"
)
else:
query = f"id: {execution_id}"
query = f"id:{escape_logstore_query_value(execution_id)}"
from_time = 0
to_time = int(time.time()) # now

View File

@ -10,6 +10,7 @@ Key Features:
- Optimized deduplication using finished_at IS NOT NULL filter
- Window functions only when necessary (running status queries)
- Multi-tenant data isolation and security
- SQL injection prevention via parameter escaping
"""
import logging
@ -22,6 +23,8 @@ from typing import Any, cast
from sqlalchemy.orm import sessionmaker
from extensions.logstore.aliyun_logstore import AliyunLogStore
from extensions.logstore.repositories import safe_float, safe_int
from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import WorkflowRun
@ -63,10 +66,9 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun:
model.created_by_role = data.get("created_by_role") or ""
model.created_by = data.get("created_by") or ""
# Numeric fields with defaults
model.total_tokens = int(data.get("total_tokens", 0))
model.total_steps = int(data.get("total_steps", 0))
model.exceptions_count = int(data.get("exceptions_count", 0))
model.total_tokens = safe_int(data.get("total_tokens", 0))
model.total_steps = safe_int(data.get("total_steps", 0))
model.exceptions_count = safe_int(data.get("exceptions_count", 0))
# Optional fields
model.graph = data.get("graph")
@ -101,7 +103,8 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun:
if model.finished_at and model.created_at:
model.elapsed_time = (model.finished_at - model.created_at).total_seconds()
else:
model.elapsed_time = float(data.get("elapsed_time", 0))
# Use safe conversion to handle 'null' strings and None values
model.elapsed_time = safe_float(data.get("elapsed_time", 0))
return model
@ -165,16 +168,26 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
status,
)
# Convert triggered_from to list if needed
if isinstance(triggered_from, WorkflowRunTriggeredFrom):
if isinstance(triggered_from, (WorkflowRunTriggeredFrom, str)):
triggered_from_list = [triggered_from]
else:
triggered_from_list = list(triggered_from)
# Build triggered_from filter
triggered_from_filter = " OR ".join([f"triggered_from='{tf.value}'" for tf in triggered_from_list])
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
# Build status filter
status_filter = f"AND status='{status}'" if status else ""
# Build triggered_from filter with escaped values
# Support both enum and string values for triggered_from
triggered_from_filter = " OR ".join(
[
f"triggered_from='{escape_sql_string(tf.value if isinstance(tf, WorkflowRunTriggeredFrom) else tf)}'"
for tf in triggered_from_list
]
)
# Build status filter with escaped value
status_filter = f"AND status='{escape_sql_string(status)}'" if status else ""
# Build last_id filter for pagination
# Note: This is simplified. In production, you'd need to track created_at from last record
@ -188,8 +201,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
SELECT * FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND ({triggered_from_filter})
{status_filter}
{last_id_filter}
@ -232,6 +245,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
logger.debug("get_workflow_run_by_id: tenant_id=%s, app_id=%s, run_id=%s", tenant_id, app_id, run_id)
try:
# Escape parameters to prevent SQL injection
escaped_run_id = escape_identifier(run_id)
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record)
@ -240,7 +258,10 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_execution_logstore}"
WHERE id = '{run_id}' AND tenant_id = '{tenant_id}' AND app_id = '{app_id}' AND __time__ > 0
WHERE id = '{escaped_run_id}'
AND tenant_id = '{escaped_tenant_id}'
AND app_id = '{escaped_app_id}'
AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
"""
@ -250,7 +271,12 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
)
else:
# Use SDK with LogStore query syntax
query = f"id: {run_id} and tenant_id: {tenant_id} and app_id: {app_id}"
# Note: Values must be quoted in LogStore query syntax to prevent injection
query = (
f"id:{escape_logstore_query_value(run_id)} "
f"and tenant_id:{escape_logstore_query_value(tenant_id)} "
f"and app_id:{escape_logstore_query_value(app_id)}"
)
from_time = 0
to_time = int(time.time()) # now
@ -323,6 +349,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
logger.debug("get_workflow_run_by_id_without_tenant: run_id=%s", run_id)
try:
# Escape parameter to prevent SQL injection
escaped_run_id = escape_identifier(run_id)
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record)
@ -331,7 +360,7 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_execution_logstore}"
WHERE id = '{run_id}' AND __time__ > 0
WHERE id = '{escaped_run_id}' AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
"""
@ -341,7 +370,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
)
else:
# Use SDK with LogStore query syntax
query = f"id: {run_id}"
# Note: Values must be quoted in LogStore query syntax
query = f"id:{escape_logstore_query_value(run_id)}"
from_time = 0
to_time = int(time.time()) # now
@ -410,6 +440,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
triggered_from,
status,
)
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_triggered_from = escape_sql_string(triggered_from)
# Build time range filter
time_filter = ""
if time_range:
@ -418,6 +453,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
# If status is provided, simple count
if status:
escaped_status = escape_sql_string(status)
if status == "running":
# Running status requires window function
sql = f"""
@ -425,9 +462,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND status='running'
{time_filter}
) t
@ -438,10 +475,10 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
sql = f"""
SELECT COUNT(DISTINCT id) as count
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
AND status='{status}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND status='{escaped_status}'
AND finished_at IS NOT NULL
{time_filter}
"""
@ -467,13 +504,14 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
# No status filter - get counts grouped by status
# Use optimized query for finished runs, separate query for running
try:
# Escape parameters (already escaped above, reuse variables)
# Count finished runs grouped by status
finished_sql = f"""
SELECT status, COUNT(DISTINCT id) as count
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY status
@ -485,9 +523,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND status='running'
{time_filter}
) t
@ -546,7 +584,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
logger.debug(
"get_daily_runs_statistics: tenant_id=%s, app_id=%s, triggered_from=%s", tenant_id, app_id, triggered_from
)
# Build time range filter
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_triggered_from = escape_sql_string(triggered_from)
# Build time range filter (datetime.isoformat() is safe)
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
@ -557,9 +601,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
sql = f"""
SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT id) as runs
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date
@ -601,7 +645,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
app_id,
triggered_from,
)
# Build time range filter
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_triggered_from = escape_sql_string(triggered_from)
# Build time range filter (datetime.isoformat() is safe)
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
@ -611,9 +661,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
sql = f"""
SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT created_by) as terminal_count
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date
@ -655,7 +705,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
app_id,
triggered_from,
)
# Build time range filter
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_triggered_from = escape_sql_string(triggered_from)
# Build time range filter (datetime.isoformat() is safe)
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
@ -665,9 +721,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
sql = f"""
SELECT DATE(from_unixtime(__time__)) as date, SUM(total_tokens) as token_count
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date
@ -709,7 +765,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
app_id,
triggered_from,
)
# Build time range filter
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_triggered_from = escape_sql_string(triggered_from)
# Build time range filter (datetime.isoformat() is safe)
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
@ -726,9 +788,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
created_by,
COUNT(DISTINCT id) AS interactions
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date, created_by

View File

@ -10,6 +10,7 @@ from sqlalchemy.orm import sessionmaker
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.entities import WorkflowExecution
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.logstore.aliyun_logstore import AliyunLogStore
from libs.helper import extract_tenant_id
from models import (
@ -67,7 +68,12 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
# Control flag for dual-write (write to both LogStore and SQL database)
# Set to True to enable dual-write for safe migration, False to use LogStore only
self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true"
self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "false").lower() == "true"
# Control flag for whether to write the `graph` field to LogStore.
# If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field;
# otherwise write an empty {} instead. Defaults to writing the `graph` field.
self._enable_put_graph_field = os.environ.get("LOGSTORE_ENABLE_PUT_GRAPH_FIELD", "true").lower() == "true"
def _to_logstore_model(self, domain_model: WorkflowExecution) -> list[tuple[str, str]]:
"""
@ -96,6 +102,9 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
# Generate log_version as nanosecond timestamp for record versioning
log_version = str(time.time_ns())
# Use WorkflowRuntimeTypeConverter to handle complex types (Segment, File, etc.)
json_converter = WorkflowRuntimeTypeConverter()
logstore_model = [
("id", domain_model.id_),
("log_version", log_version), # Add log_version field for append-only writes
@ -108,9 +117,24 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
),
("type", domain_model.workflow_type.value),
("version", domain_model.workflow_version),
("graph", json.dumps(domain_model.graph, ensure_ascii=False) if domain_model.graph else "{}"),
("inputs", json.dumps(domain_model.inputs, ensure_ascii=False) if domain_model.inputs else "{}"),
("outputs", json.dumps(domain_model.outputs, ensure_ascii=False) if domain_model.outputs else "{}"),
(
"graph",
json.dumps(json_converter.to_json_encodable(domain_model.graph), ensure_ascii=False)
if domain_model.graph and self._enable_put_graph_field
else "{}",
),
(
"inputs",
json.dumps(json_converter.to_json_encodable(domain_model.inputs), ensure_ascii=False)
if domain_model.inputs
else "{}",
),
(
"outputs",
json.dumps(json_converter.to_json_encodable(domain_model.outputs), ensure_ascii=False)
if domain_model.outputs
else "{}",
),
("status", domain_model.status.value),
("error_message", domain_model.error_message or ""),
("total_tokens", str(domain_model.total_tokens)),

View File

@ -24,6 +24,8 @@ from core.workflow.enums import NodeType
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.logstore.aliyun_logstore import AliyunLogStore
from extensions.logstore.repositories import safe_float, safe_int
from extensions.logstore.sql_escape import escape_identifier
from libs.helper import extract_tenant_id
from models import (
Account,
@ -73,7 +75,7 @@ def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecut
node_execution_id=data.get("node_execution_id"),
workflow_id=data.get("workflow_id", ""),
workflow_execution_id=data.get("workflow_run_id"),
index=int(data.get("index", 0)),
index=safe_int(data.get("index", 0)),
predecessor_node_id=data.get("predecessor_node_id"),
node_id=data.get("node_id", ""),
node_type=NodeType(data.get("node_type", "start")),
@ -83,7 +85,7 @@ def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecut
outputs=outputs,
status=status,
error=data.get("error"),
elapsed_time=float(data.get("elapsed_time", 0.0)),
elapsed_time=safe_float(data.get("elapsed_time", 0.0)),
metadata=domain_metadata,
created_at=created_at,
finished_at=finished_at,
@ -147,7 +149,7 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
# Control flag for dual-write (write to both LogStore and SQL database)
# Set to True to enable dual-write for safe migration, False to use LogStore only
self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true"
self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "false").lower() == "true"
def _to_logstore_model(self, domain_model: WorkflowNodeExecution) -> Sequence[tuple[str, str]]:
logger.debug(
@ -274,16 +276,34 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
Save or update the inputs, process_data, or outputs associated with a specific
node_execution record.
For LogStore implementation, this is similar to save() since we always write
complete records. We append a new record with updated data fields.
For LogStore implementation, this is a no-op for the LogStore write because save()
already writes all fields including inputs, process_data, and outputs. The caller
typically calls save() first to persist status/metadata, then calls save_execution_data()
to persist data fields. Since LogStore writes complete records atomically, we don't
need a separate write here to avoid duplicate records.
However, if dual-write is enabled, we still need to call the SQL repository's
save_execution_data() method to properly update the SQL database.
Args:
execution: The NodeExecution instance with data to save
"""
logger.debug("save_execution_data: id=%s, node_execution_id=%s", execution.id, execution.node_execution_id)
# In LogStore, we simply write a new complete record with the data
# The log_version timestamp will ensure this is treated as the latest version
self.save(execution)
logger.debug(
"save_execution_data: no-op for LogStore (data already saved by save()): id=%s, node_execution_id=%s",
execution.id,
execution.node_execution_id,
)
# No-op for LogStore: save() already writes all fields including inputs, process_data, and outputs
# Calling save() again would create a duplicate record in the append-only LogStore
# Dual-write to SQL database if enabled (for safe migration)
if self._enable_dual_write:
try:
self.sql_repository.save_execution_data(execution)
logger.debug("Dual-write: saved node execution data to SQL database: id=%s", execution.id)
except Exception:
logger.exception("Failed to dual-write node execution data to SQL database: id=%s", execution.id)
# Don't raise - LogStore write succeeded, SQL is just a backup
def get_by_workflow_run(
self,
@ -292,8 +312,8 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all NodeExecution instances for a specific workflow run.
Uses LogStore SQL query with finished_at IS NOT NULL filter for deduplication.
This ensures we only get the final version of each node execution.
Uses LogStore SQL query with window function to get the latest version of each node execution.
This ensures we only get the most recent version of each node execution record.
Args:
workflow_run_id: The workflow run ID
order_config: Optional configuration for ordering results
@ -304,16 +324,19 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
A list of NodeExecution instances
Note:
This method filters by finished_at IS NOT NULL to avoid duplicates from
version updates. For complete history including intermediate states,
a different query strategy would be needed.
This method uses ROW_NUMBER() window function partitioned by node_execution_id
to get the latest version (highest log_version) of each node execution.
"""
logger.debug("get_by_workflow_run: workflow_run_id=%s, order_config=%s", workflow_run_id, order_config)
# Build SQL query with deduplication using finished_at IS NOT NULL
# This optimization avoids window functions for common case where we only
# want the final state of each node execution
# Build SQL query with deduplication using window function
# ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC)
# ensures we get the latest version of each node execution
# Build ORDER BY clause
# Escape parameters to prevent SQL injection
escaped_workflow_run_id = escape_identifier(workflow_run_id)
escaped_tenant_id = escape_identifier(self._tenant_id)
# Build ORDER BY clause for outer query
order_clause = ""
if order_config and order_config.order_by:
order_fields = []
@ -327,16 +350,23 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
if order_fields:
order_clause = "ORDER BY " + ", ".join(order_fields)
sql = f"""
SELECT *
FROM {AliyunLogStore.workflow_node_execution_logstore}
WHERE workflow_run_id='{workflow_run_id}'
AND tenant_id='{self._tenant_id}'
AND finished_at IS NOT NULL
"""
# Build app_id filter for subquery
app_id_filter = ""
if self._app_id:
sql += f" AND app_id='{self._app_id}'"
escaped_app_id = escape_identifier(self._app_id)
app_id_filter = f" AND app_id='{escaped_app_id}'"
# Use window function to get latest version of each node execution
sql = f"""
SELECT * FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_node_execution_logstore}
WHERE workflow_run_id='{escaped_workflow_run_id}'
AND tenant_id='{escaped_tenant_id}'
{app_id_filter}
) t
WHERE rn = 1
"""
if order_clause:
sql += f" {order_clause}"

View File

@ -0,0 +1,134 @@
"""
SQL Escape Utility for LogStore Queries
This module provides escaping utilities to prevent injection attacks in LogStore queries.
LogStore supports two query modes:
1. PG Protocol Mode: Uses SQL syntax with single quotes for strings
2. SDK Mode: Uses LogStore query syntax (key: value) with double quotes
Key Security Concerns:
- Prevent tenant A from accessing tenant B's data via injection
- SLS queries are read-only, so we focus on data access control
- Different escaping strategies for SQL vs LogStore query syntax
"""
def escape_sql_string(value: str) -> str:
    """
    Make a string safe to embed inside a single-quoted SQL literal.

    Every single quote in ``value`` is doubled, which is the standard SQL way
    of writing a literal quote inside a string. The caller is still
    responsible for wrapping the result in single quotes.

    Args:
        value: The string value to escape

    Returns:
        ``value`` with each ``'`` replaced by ``''``. Falsy input (such as the
        empty string) is returned unchanged.

    Examples:
        >>> escape_sql_string("normal_value")
        'normal_value'
        >>> escape_sql_string("value' OR '1'='1")
        "value'' OR ''1''=''1"
        >>> escape_sql_string("tenant's_id")
        "tenant''s_id"

    Security:
        - Input cannot terminate the surrounding string literal early
        - Stops injection attempts such as: ' OR '1'='1
    """
    # Nothing to escape for empty/falsy input; hand it back untouched.
    return value.replace("'", "''") if value else value
def escape_identifier(value: str) -> str:
    """
    Escape an ID value (tenant_id, app_id, run_id, ...) for a SQL query.

    Intended for PG protocol mode, where IDs are interpolated as
    single-quoted literals. For SDK mode (LogStore query syntax) use
    ``escape_logstore_query_value()`` instead.

    Args:
        value: The identifier value to escape.

    Returns:
        Whatever ``escape_sql_string`` returns for ``value``.

    Examples:
        >>> escape_identifier("550e8400-e29b-41d4-a716-446655440000")
        '550e8400-e29b-41d4-a716-446655440000'
        >>> escape_identifier("tenant_id' OR '1'='1")
        "tenant_id'' OR ''1''=''1"
    """
    # IDs end up inside ordinary quoted literals, so the generic string
    # escaping is exactly what is needed; keep a single implementation.
    escaped = escape_sql_string(value)
    return escaped
def escape_logstore_query_value(value: str) -> str:
    r"""
    Quote and escape a value for LogStore query syntax (SDK mode).

    Relevant LogStore query syntax rules:
        1. Keywords (and/or/not) are case-insensitive.
        2. Single quotes are ordinary characters.
        3. Double quotes wrap values: key:"value".
        4. Backslash is the escape character: \" is a literal double
           quote, \\ is a literal backslash.
        5. Parentheses can change query structure.

    Wrapping the value in double quotes makes keywords, colons and
    parentheses literal; escaping backslashes and double quotes keeps the
    value from closing its own quotes.

    Args:
        value: The value to escape for LogStore query syntax.

    Returns:
        The escaped value including its surrounding double quotes. Falsy
        input (empty string or ``None``) yields ``'""'``.

    Examples:
        >>> escape_logstore_query_value("normal_value")
        '"normal_value"'
        >>> escape_logstore_query_value("value or field:evil")
        '"value or field:evil"'
        >>> escape_logstore_query_value('value"test')
        '"value\\"test"'
        >>> escape_logstore_query_value('value\\test')
        '"value\\\\test"'
    """
    if not value:
        return '""'

    # Map both special characters in one pass over the input. Each source
    # character is translated exactly once, so a backslash introduced while
    # escaping a quote is never escaped again -- the same result as
    # replacing backslashes first and double quotes second.
    escape_table = str.maketrans({"\\": "\\\\", '"': '\\"'})
    return '"' + value.translate(escape_table) + '"'

View File

@ -19,26 +19,43 @@ logger = logging.getLogger(__name__)
class ExceptionLoggingHandler(logging.Handler):
    """
    Handler that records logged exceptions on the current OpenTelemetry span.

    No new span is created: the exception is attached to the span that is
    already active, so trace context stays consistent throughout the
    request lifecycle.
    """

    def emit(self, record: logging.LogRecord) -> None:
        """
        Attach the exception carried by ``record`` to the current span.

        Records without ``exc_info`` are ignored, as are calls made while no
        recording span is active.

        Args:
            record: The log record being handled.
        """
        # Telemetry is best-effort: a failure here must never propagate into
        # the logging call that triggered this handler.
        with contextlib.suppress(Exception):
            if not record.exc_info:
                return

            from opentelemetry.trace import get_current_span

            span = get_current_span()
            if not span or not span.is_recording():
                return

            message = record.getMessage()

            # Mark the existing span as failed instead of opening a new one.
            span.set_status(StatusCode.ERROR, message)

            # Keep the log context as an event on that span.
            span.add_event(
                "log.exception",
                attributes={
                    "log.level": record.levelname,
                    "log.message": message,
                    "log.logger": record.name,
                    "log.file.path": record.pathname,
                    "log.file.line": record.lineno,
                },
            )

            exc_type, exc_value = record.exc_info[0], record.exc_info[1]
            if exc_value:
                span.record_exception(exc_value)
            if exc_type:
                span.set_attribute("exception.type", exc_type.__name__)
def instrument_exception_logging() -> None:

View File

@ -0,0 +1,20 @@
"""
OpenTelemetry node parsers for workflow nodes.
This module provides parsers that extract node-specific metadata and set
OpenTelemetry span attributes according to semantic conventions.
"""
from extensions.otel.parser.base import DefaultNodeOTelParser, NodeOTelParser, safe_json_dumps
from extensions.otel.parser.llm import LLMNodeOTelParser
from extensions.otel.parser.retrieval import RetrievalNodeOTelParser
from extensions.otel.parser.tool import ToolNodeOTelParser
# Names exported by this package; each one is imported from a submodule above.
__all__ = [
    "DefaultNodeOTelParser",
    "LLMNodeOTelParser",
    "NodeOTelParser",
    "RetrievalNodeOTelParser",
    "ToolNodeOTelParser",
    "safe_json_dumps",
]

View File

@ -0,0 +1,117 @@
"""
Base parser interface and utilities for OpenTelemetry node parsers.
"""
import json
from typing import Any, Protocol
from opentelemetry.trace import Span
from opentelemetry.trace.status import Status, StatusCode
from pydantic import BaseModel
from core.file.models import File
from core.variables import Segment
from core.workflow.enums import NodeType
from core.workflow.graph_events import GraphNodeEventBase
from core.workflow.nodes.base.node import Node
from extensions.otel.semconv.gen_ai import ChainAttributes, GenAIAttributes
def safe_json_dumps(obj: Any, ensure_ascii: bool = False) -> str:
    """
    Serialize ``obj`` to JSON, tolerating values ``json`` cannot encode.

    Values are first reduced to JSON-native types:
        - ``None``, bool, int, float and str pass through unchanged
        - ``Segment`` instances are replaced by their converted ``value``
        - ``File`` instances are replaced by ``to_dict()``
        - ``BaseModel`` instances are replaced by ``model_dump(mode="json")``
        - dict values and list/tuple items are converted recursively
        - anything else becomes its ``str()`` representation

    Args:
        obj: Object to serialize.
        ensure_ascii: Forwarded to ``json.dumps``.

    Returns:
        The JSON text, or a JSON object with ``error`` and ``message`` keys
        if a ``TypeError`` or ``ValueError`` is raised along the way.
    """

    def _to_jsonable(item: Any) -> Any:
        """Recursively reduce ``item`` to JSON-native types."""
        # Already JSON-native: hand back untouched.
        if item is None or isinstance(item, (bool, int, float, str)):
            return item
        # Project types are checked before the generic BaseModel branch.
        if isinstance(item, Segment):
            return _to_jsonable(item.value)
        if isinstance(item, File):
            return item.to_dict()
        if isinstance(item, BaseModel):
            return _to_jsonable(item.model_dump(mode="json"))
        # Containers: convert the contents (dict keys are left as they are).
        if isinstance(item, dict):
            return {key: _to_jsonable(member) for key, member in item.items()}
        if isinstance(item, (list, tuple)):
            return list(map(_to_jsonable, item))
        # Last resort for unknown types.
        return str(item)

    try:
        return json.dumps(_to_jsonable(obj), ensure_ascii=ensure_ascii)
    except (TypeError, ValueError) as exc:
        # Report the failure as JSON rather than raising.
        failure = {"error": f"Failed to serialize: {type(obj).__name__}", "message": str(exc)}
        return json.dumps(failure, ensure_ascii=ensure_ascii)
class NodeOTelParser(Protocol):
    """Structural interface implemented by per-node OpenTelemetry parsers."""

    def parse(
        self,
        *,
        node: Node,
        span: "Span",
        error: Exception | None,
        result_event: GraphNodeEventBase | None = None,
    ) -> None:
        """
        Enrich ``span`` with information about ``node``.

        Args:
            node: The workflow node the span belongs to.
            span: The span to enrich.
            error: The exception associated with the node, or ``None``.
            result_event: Optional event associated with the node.
        """
        ...
class DefaultNodeOTelParser:
    """Generic parser applied when a node type has no dedicated parser."""

    def parse(
        self,
        *,
        node: Node,
        span: "Span",
        error: Exception | None,
        result_event: GraphNodeEventBase | None = None,
    ) -> None:
        """
        Set the common node attributes, span kind, inputs/outputs and status.

        Args:
            node: Node whose id, execution id and type are recorded.
            span: Span that receives the attributes and the final status.
            error: When truthy, recorded on the span and reported as an
                ERROR status; otherwise the status is set to OK.
            result_event: Optional event whose ``node_run_result`` supplies
                the input and output values.
        """
        # Identity of the node.
        span.set_attribute("node.id", node.id)
        if node.execution_id:
            span.set_attribute("node.execution_id", node.execution_id)

        node_type = getattr(node, "node_type", None)
        if node_type:
            span.set_attribute("node.type", node_type.value)

        span.set_attribute(GenAIAttributes.FRAMEWORK, "dify")

        # Span kind: the three specialised node types get their own kind,
        # everything else (including a non-NodeType value) is a generic TASK.
        kind_by_type = {
            NodeType.LLM: "LLM",
            NodeType.KNOWLEDGE_RETRIEVAL: "RETRIEVER",
            NodeType.TOOL: "TOOL",
        }
        span_kind = kind_by_type.get(node_type, "TASK") if isinstance(node_type, NodeType) else "TASK"
        span.set_attribute(GenAIAttributes.SPAN_KIND, span_kind)

        # Inputs and outputs come from the run result, when there is one.
        run_result = result_event.node_run_result if result_event else None
        if run_result:
            if run_result.inputs:
                span.set_attribute(ChainAttributes.INPUT_VALUE, safe_json_dumps(run_result.inputs))
            if run_result.outputs:
                span.set_attribute(ChainAttributes.OUTPUT_VALUE, safe_json_dumps(run_result.outputs))

        # Final status.
        if not error:
            span.set_status(Status(StatusCode.OK))
            return
        span.record_exception(error)
        span.set_status(Status(StatusCode.ERROR, str(error)))

View File

@ -0,0 +1,155 @@
"""
Parser for LLM nodes that captures LLM-specific metadata.
"""
import logging
from collections.abc import Mapping
from typing import Any
from opentelemetry.trace import Span
from core.workflow.graph_events import GraphNodeEventBase
from core.workflow.nodes.base.node import Node
from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps
from extensions.otel.semconv.gen_ai import LLMAttributes
logger = logging.getLogger(__name__)
def _format_input_messages(process_data: Mapping[str, Any]) -> str:
    """
    Build the structured input-message JSON for an LLM span.

    Reads ``process_data["prompts"]`` and keeps every dict entry whose
    ``role`` is one of system/user/assistant/tool and whose ``text`` is
    non-empty.

    Args:
        process_data: Process data containing prompts.

    Returns:
        JSON string with the list of messages; the JSON for an empty list
        when ``process_data`` is not a dict, has no prompts, or formatting
        fails.
    """
    try:
        prompts = process_data.get("prompts", []) if isinstance(process_data, dict) else None
        if not prompts:
            return safe_json_dumps([])

        allowed_roles = {"system", "user", "assistant", "tool"}
        messages = []
        for entry in prompts:
            if not isinstance(entry, dict):
                continue
            role = entry.get("role", "")
            text = entry.get("text", "")
            # Keep only recognised roles that actually carry text.
            if role and role in allowed_roles and text:
                messages.append({"role": role, "parts": [{"type": "text", "content": text}]})

        return safe_json_dumps(messages)
    except Exception as e:
        # Best effort: tracing must not fail because of odd prompt data.
        logger.warning("Failed to format input messages: %s", e, exc_info=True)
        return safe_json_dumps([])
def _format_output_messages(outputs: Mapping[str, Any]) -> str:
    """
    Build the structured output-message JSON for an LLM span.

    Produces a single assistant message from ``outputs["text"]``. A
    ``finish_reason`` outside stop/length/content_filter/tool_call/error
    is reported as ``"stop"``.

    Args:
        outputs: Output data containing text and finish_reason.

    Returns:
        JSON string with a one-element message list; the JSON for an empty
        list when ``outputs`` is not a dict, has no text, or formatting
        fails.
    """
    try:
        if not isinstance(outputs, dict):
            return safe_json_dumps([])

        completion = outputs.get("text", "")
        reason = outputs.get("finish_reason", "")
        if not completion:
            return safe_json_dumps([])

        known_reasons = {"stop", "length", "content_filter", "tool_call", "error"}
        message = {
            "role": "assistant",
            "parts": [{"type": "text", "content": completion}],
            # Unknown or missing reasons fall back to "stop".
            "finish_reason": reason if reason in known_reasons else "stop",
        }
        return safe_json_dumps([message])
    except Exception as e:
        # Best effort: tracing must not fail because of odd output data.
        logger.warning("Failed to format output messages: %s", e, exc_info=True)
        return safe_json_dumps([])
class LLMNodeOTelParser:
    """Node parser that adds LLM-specific attributes on top of the defaults."""

    def __init__(self) -> None:
        # Common attributes are produced by the default parser.
        self._delegate = DefaultNodeOTelParser()

    def parse(
        self,
        *,
        node: Node,
        span: "Span",
        error: Exception | None,
        result_event: GraphNodeEventBase | None = None,
    ) -> None:
        """
        Apply the default attributes, then the LLM-specific ones.

        Model/provider, token usage, prompts, completion text, finish
        reason and the structured input/output messages are read from the
        run result's ``process_data`` and ``outputs``. Nothing beyond the
        defaults is set when there is no run result.

        Args:
            node: The LLM node being traced.
            span: Span that receives the attributes.
            error: Exception forwarded to the default parser.
            result_event: Optional event carrying ``node_run_result``.
        """
        self._delegate.parse(node=node, span=span, error=error, result_event=result_event)

        run_result = result_event.node_run_result if result_event else None
        if not run_result:
            return

        process_data = run_result.process_data or {}
        outputs = run_result.outputs or {}

        # Usage is looked up in process_data first, then in outputs.
        usage = process_data.get("usage") or outputs.get("usage") or {}

        # Model and provider: set only when a value is present.
        for attribute, value in (
            (LLMAttributes.REQUEST_MODEL, process_data.get("model_name") or ""),
            (LLMAttributes.PROVIDER_NAME, process_data.get("model_provider") or ""),
        ):
            if value:
                span.set_attribute(attribute, value)

        # Token counts: a missing counter is reported as 0.
        if usage:
            for attribute, key in (
                (LLMAttributes.USAGE_INPUT_TOKENS, "prompt_tokens"),
                (LLMAttributes.USAGE_OUTPUT_TOKENS, "completion_tokens"),
                (LLMAttributes.USAGE_TOTAL_TOKENS, "total_tokens"),
            ):
                span.set_attribute(attribute, usage.get(key, 0))

        # Raw prompts and completion text.
        prompts = process_data.get("prompts", [])
        if prompts:
            span.set_attribute(LLMAttributes.PROMPT, safe_json_dumps(prompts))

        completion = str(outputs.get("text", ""))
        if completion:
            span.set_attribute(LLMAttributes.COMPLETION, completion)

        finish_reason = outputs.get("finish_reason") or ""
        if finish_reason:
            span.set_attribute(LLMAttributes.RESPONSE_FINISH_REASON, finish_reason)

        # Structured messages are always set, even when they are empty lists.
        input_messages = _format_input_messages(process_data)
        output_messages = _format_output_messages(outputs)
        span.set_attribute(LLMAttributes.INPUT_MESSAGE, input_messages)
        span.set_attribute(LLMAttributes.OUTPUT_MESSAGE, output_messages)

View File

@ -0,0 +1,105 @@
"""
Parser for knowledge retrieval nodes that captures retrieval-specific metadata.
"""
import logging
from collections.abc import Sequence
from typing import Any
from opentelemetry.trace import Span
from core.variables import Segment
from core.workflow.graph_events import GraphNodeEventBase
from core.workflow.nodes.base.node import Node
from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps
from extensions.otel.semconv.gen_ai import RetrieverAttributes
logger = logging.getLogger(__name__)
def _format_retrieval_documents(retrieval_documents: list[Any]) -> list[dict[str, Any]]:
    """
    Format retrieval documents for semantic conventions.

    Each dict in ``retrieval_documents`` becomes
    ``{"document": {"content", "metadata", "score", "id"}}``. The semantic
    metadata holds the document ``title``, the ``source`` (or ``_source``)
    metadata entry, and every key of a dict-valued ``doc_metadata`` entry.
    Entries that are not dicts are skipped.

    Args:
        retrieval_documents: List of retrieval document dictionaries.

    Returns:
        List of formatted semantic documents; an empty list when the input
        is not a list or formatting fails.
    """
    try:
        if not isinstance(retrieval_documents, list):
            return []

        semantic_documents: list[dict[str, Any]] = []
        for doc in retrieval_documents:
            if not isinstance(doc, dict):
                continue

            # `doc.get("metadata", {})` would return None for an explicit
            # null value; the AttributeError that follows would land in the
            # outer handler and discard every document, not just this one.
            metadata = doc.get("metadata") or {}
            content = doc.get("content", "")
            title = doc.get("title", "")
            score = metadata.get("score", 0.0)
            document_id = metadata.get("document_id", "")

            semantic_metadata: dict[str, Any] = {}
            if title:
                semantic_metadata["title"] = title
            # Prefer "source"; fall back to "_source".
            source = metadata.get("source") or metadata.get("_source")
            if source:
                semantic_metadata["source"] = source
            doc_metadata = metadata.get("doc_metadata")
            if doc_metadata and isinstance(doc_metadata, dict):
                semantic_metadata.update(doc_metadata)

            semantic_documents.append(
                {"document": {"content": content, "metadata": semantic_metadata, "score": score, "id": document_id}}
            )

        return semantic_documents
    except Exception as e:
        # Best effort: tracing must not fail because of odd document data.
        logger.warning("Failed to format retrieval documents: %s", e, exc_info=True)
        return []
class RetrievalNodeOTelParser:
    """Node parser that adds knowledge-retrieval attributes on top of the defaults."""

    def __init__(self) -> None:
        # Common attributes are produced by the default parser.
        self._delegate = DefaultNodeOTelParser()

    def parse(
        self,
        *,
        node: Node,
        span: "Span",
        error: Exception | None,
        result_event: GraphNodeEventBase | None = None,
    ) -> None:
        """
        Apply the default attributes, then the query and retrieved documents.

        The query is read from ``inputs["query"]`` and the documents from
        ``outputs["result"]`` (unwrapped first when it is a ``Segment``).

        Args:
            node: The retrieval node being traced.
            span: Span that receives the attributes.
            error: Exception forwarded to the default parser.
            result_event: Optional event carrying ``node_run_result``.
        """
        self._delegate.parse(node=node, span=span, error=error, result_event=result_event)

        run_result = result_event.node_run_result if result_event else None
        if not run_result:
            return

        inputs = run_result.inputs or {}
        outputs = run_result.outputs or {}

        # Query text, stringified; skipped when empty.
        query = str(inputs.get("query", "")) if inputs else ""
        if query:
            span.set_attribute(RetrieverAttributes.QUERY, query)

        # Retrieved documents.
        raw_result = outputs.get("result") if outputs else None
        if not raw_result:
            return

        candidate = raw_result.value if isinstance(raw_result, Segment) else raw_result
        if not isinstance(candidate, (list, Sequence)):
            return

        documents: list[Any] = list(candidate)
        if not documents:
            return

        formatted = _format_retrieval_documents(documents)
        span.set_attribute(RetrieverAttributes.DOCUMENT, safe_json_dumps(formatted))

View File

@ -0,0 +1,47 @@
"""
Parser for tool nodes that captures tool-specific metadata.
"""
from opentelemetry.trace import Span
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
from core.workflow.graph_events import GraphNodeEventBase
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.tool.entities import ToolNodeData
from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps
from extensions.otel.semconv.gen_ai import ToolAttributes
class ToolNodeOTelParser:
    """Node parser that adds tool-call attributes on top of the defaults."""

    def __init__(self) -> None:
        # Common attributes are produced by the default parser.
        self._delegate = DefaultNodeOTelParser()

    def parse(
        self,
        *,
        node: Node,
        span: "Span",
        error: Exception | None,
        result_event: GraphNodeEventBase | None = None,
    ) -> None:
        """
        Apply the default attributes, then the tool-specific ones.

        Nothing beyond the defaults is set unless the node's ``_node_data``
        is a ``ToolNodeData``. The tool description, call arguments and
        call result are taken from the run result when one is available.

        Args:
            node: The tool node being traced.
            span: Span that receives the attributes.
            error: Exception forwarded to the default parser.
            result_event: Optional event carrying ``node_run_result``.
        """
        self._delegate.parse(node=node, span=span, error=error, result_event=result_event)

        node_data = getattr(node, "_node_data", None)
        if not isinstance(node_data, ToolNodeData):
            return

        span.set_attribute(ToolAttributes.TOOL_NAME, node.title)
        span.set_attribute(ToolAttributes.TOOL_TYPE, node_data.provider_type.value)

        # Everything below is read from the run result; without one we are done.
        run_result = result_event.node_run_result if result_event else None
        if not run_result:
            return

        # Tool info comes from the execution metadata (consistent with aliyun_trace).
        metadata = run_result.metadata
        tool_info = metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO, {}) if metadata else {}
        if tool_info:
            span.set_attribute(ToolAttributes.TOOL_DESCRIPTION, safe_json_dumps(tool_info))

        if run_result.inputs:
            span.set_attribute(ToolAttributes.TOOL_CALL_ARGUMENTS, safe_json_dumps(run_result.inputs))
        if run_result.outputs:
            span.set_attribute(ToolAttributes.TOOL_CALL_RESULT, safe_json_dumps(run_result.outputs))

View File

@ -1,6 +1,13 @@
"""Semantic convention shortcuts for Dify-specific spans."""
from .dify import DifySpanAttributes
from .gen_ai import GenAIAttributes
from .gen_ai import ChainAttributes, GenAIAttributes, LLMAttributes, RetrieverAttributes, ToolAttributes
__all__ = ["DifySpanAttributes", "GenAIAttributes"]
__all__ = [
"ChainAttributes",
"DifySpanAttributes",
"GenAIAttributes",
"LLMAttributes",
"RetrieverAttributes",
"ToolAttributes",
]

View File

@ -62,3 +62,37 @@ class ToolAttributes:
TOOL_CALL_RESULT = "gen_ai.tool.call.result"
"""Tool invocation result."""
class LLMAttributes:
    """Span attribute keys used for LLM operations."""

    REQUEST_MODEL = "gen_ai.request.model"
    """Identifier of the requested model."""

    PROVIDER_NAME = "gen_ai.provider.name"
    """Name of the model provider."""

    USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens"
    """Count of input tokens."""

    USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
    """Count of output tokens."""

    USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens"
    """Count of all tokens."""

    PROMPT = "gen_ai.prompt"
    """Text of the prompt."""

    COMPLETION = "gen_ai.completion"
    """Text of the completion."""

    RESPONSE_FINISH_REASON = "gen_ai.response.finish_reason"
    """Reason the response finished."""

    INPUT_MESSAGE = "gen_ai.input.messages"
    """Structured input messages."""

    OUTPUT_MESSAGE = "gen_ai.output.messages"
    """Structured output messages."""

View File

@ -5,6 +5,8 @@ automatic cleanup, backup and restore.
Supports complete lifecycle management for knowledge base files.
"""
from __future__ import annotations
import json
import logging
import operator
@ -48,7 +50,7 @@ class FileMetadata:
return data
@classmethod
def from_dict(cls, data: dict) -> "FileMetadata":
def from_dict(cls, data: dict) -> FileMetadata:
"""Create instance from dictionary"""
data = data.copy()
data["created_at"] = datetime.fromisoformat(data["created_at"])

View File

@ -13,12 +13,20 @@ class TencentCosStorage(BaseStorage):
super().__init__()
self.bucket_name = dify_config.TENCENT_COS_BUCKET_NAME
config = CosConfig(
Region=dify_config.TENCENT_COS_REGION,
SecretId=dify_config.TENCENT_COS_SECRET_ID,
SecretKey=dify_config.TENCENT_COS_SECRET_KEY,
Scheme=dify_config.TENCENT_COS_SCHEME,
)
if dify_config.TENCENT_COS_CUSTOM_DOMAIN:
config = CosConfig(
Domain=dify_config.TENCENT_COS_CUSTOM_DOMAIN,
SecretId=dify_config.TENCENT_COS_SECRET_ID,
SecretKey=dify_config.TENCENT_COS_SECRET_KEY,
Scheme=dify_config.TENCENT_COS_SCHEME,
)
else:
config = CosConfig(
Region=dify_config.TENCENT_COS_REGION,
SecretId=dify_config.TENCENT_COS_SECRET_ID,
SecretKey=dify_config.TENCENT_COS_SECRET_KEY,
Scheme=dify_config.TENCENT_COS_SCHEME,
)
self.client = CosS3Client(config)
def save(self, filename, data):