Merge remote-tracking branch 'upstream/main' into feat/human-input-merge-again
@@ -6,6 +6,7 @@ BASE_CORS_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE, HEAD
SERVICE_API_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, "Authorization")
AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF_TOKEN)
FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN)
EMBED_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE)
EXPOSED_HEADERS: tuple[str, ...] = ("X-Version", "X-Env", "X-Trace-Id")

@@ -42,10 +43,28 @@ def init_app(app: DifyApp):
_apply_cors_once(
web_bp,
resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}},
supports_credentials=True,
allow_headers=list(AUTHENTICATED_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
resources={
# Embedded bot endpoints (unauthenticated, cross-origin safe)
r"^/chat-messages$": {
"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS,
"supports_credentials": False,
"allow_headers": list(EMBED_HEADERS),
"methods": ["GET", "POST", "OPTIONS"],
},
r"^/chat-messages/.*": {
"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS,
"supports_credentials": False,
"allow_headers": list(EMBED_HEADERS),
"methods": ["GET", "POST", "OPTIONS"],
},
# Default web application endpoints (authenticated)
r"/*": {
"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS,
"supports_credentials": True,
"allow_headers": list(AUTHENTICATED_HEADERS),
"methods": ["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
},
},
expose_headers=list(EXPOSED_HEADERS),
)
app.register_blueprint(web_bp)
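The new resources mapping gives each route pattern its own CORS policy: the embed-only /chat-messages endpoints run without credentials and with the minimal EMBED_HEADERS, while the catch-all r"/*" keeps the authenticated behaviour. A minimal, self-contained sketch of the same Flask-CORS mechanism (standalone app; the origins value and routes are illustrative, not Dify's actual configuration):

from flask import Flask
from flask_cors import CORS

app = Flask(__name__)
ALLOWED_ORIGINS = ["https://widgets.example.com"]  # placeholder origins

CORS(
    app,
    resources={
        # Embedded, unauthenticated endpoint: no cookies cross-origin.
        r"^/chat-messages$": {
            "origins": ALLOWED_ORIGINS,
            "supports_credentials": False,
            "methods": ["GET", "POST", "OPTIONS"],
        },
        # Everything else: authenticated app traffic with credentials.
        r"/*": {
            "origins": ALLOWED_ORIGINS,
            "supports_credentials": True,
            "methods": ["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
        },
    },
)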
@@ -12,9 +12,8 @@ from dify_app import DifyApp
def _get_celery_ssl_options() -> dict[str, Any] | None:
"""Get SSL configuration for Celery broker/backend connections."""
# Use REDIS_USE_SSL for consistency with the main Redis client
# Only apply SSL if we're using Redis as broker/backend
if not dify_config.REDIS_USE_SSL:
if not dify_config.BROKER_USE_SSL:
return None

# Check if Celery is actually using Redis
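The hunk above changes the gate from REDIS_USE_SSL to a dedicated BROKER_USE_SSL flag. For reference, when the gate passes and Redis is the broker/backend, Celery accepts an SSL option dict along these lines (a hedged sketch using Celery's documented broker_use_ssl keys; the certificate paths are placeholders, not Dify settings):

import ssl

# Assumed shape of the options returned by _get_celery_ssl_options().
ssl_options = {
    "ssl_cert_reqs": ssl.CERT_NONE,         # or ssl.CERT_REQUIRED with a CA bundle
    "ssl_ca_certs": "/path/to/ca.pem",      # placeholder path
    "ssl_certfile": "/path/to/client.crt",  # placeholder path
    "ssl_keyfile": "/path/to/client.key",   # placeholder path
}

# Typically applied to both broker and result backend when Redis is in use:
# celery_app.conf.update(broker_use_ssl=ssl_options, redis_backend_use_ssl=ssl_options)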
@@ -47,7 +46,11 @@ def _get_celery_ssl_options() -> dict[str, Any] | None:
def init_app(app: DifyApp) -> Celery:
class FlaskTask(Task):
def __call__(self, *args: object, **kwargs: object) -> object:
from core.logging.context import init_request_context

with app.app_context():
# Initialize logging context for this task (similar to before_request in Flask)
init_request_context()
return self.run(*args, **kwargs)

broker_transport_options = {}

@@ -166,6 +169,13 @@ def init_app(app: DifyApp) -> Celery:
"task": "schedule.clean_workflow_runlogs_precise.clean_workflow_runlogs_precise",
"schedule": crontab(minute="0", hour="2"),
}
if dify_config.ENABLE_WORKFLOW_RUN_CLEANUP_TASK:
# for saas only
imports.append("schedule.clean_workflow_runs_task")
beat_schedule["clean_workflow_runs_task"] = {
"task": "schedule.clean_workflow_runs_task.clean_workflow_runs_task",
"schedule": crontab(minute="0", hour="0"),
}
if dify_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK:
imports.append("schedule.workflow_schedule_task")
beat_schedule["workflow_schedule_task"] = {
@@ -4,13 +4,18 @@ from dify_app import DifyApp
def init_app(app: DifyApp):
from commands import (
add_qdrant_index,
archive_workflow_runs,
clean_expired_messages,
clean_workflow_runs,
cleanup_orphaned_draft_variables,
clear_free_plan_tenant_expired_logs,
clear_orphaned_file_records,
convert_to_agent_apps,
create_tenant,
delete_archived_workflow_runs,
extract_plugins,
extract_unique_plugins,
file_usage,
fix_app_site_missing,
install_plugins,
install_rag_pipeline_plugins,

@@ -21,6 +26,7 @@ def init_app(app: DifyApp):
reset_email,
reset_encrypt_key_pair,
reset_password,
restore_workflow_runs,
setup_datasource_oauth_client,
setup_system_tool_oauth_client,
setup_system_trigger_oauth_client,

@@ -47,6 +53,7 @@ def init_app(app: DifyApp):
clear_free_plan_tenant_expired_logs,
clear_orphaned_file_records,
remove_orphaned_files_on_storage,
file_usage,
setup_system_tool_oauth_client,
setup_system_trigger_oauth_client,
cleanup_orphaned_draft_variables,

@@ -54,6 +61,11 @@ def init_app(app: DifyApp):
setup_datasource_oauth_client,
transform_datasource_credentials,
install_rag_pipeline_plugins,
archive_workflow_runs,
delete_archived_workflow_runs,
restore_workflow_runs,
clean_workflow_runs,
clean_expired_messages,
]
for cmd in cmds_to_register:
app.cli.add_command(cmd)
@@ -53,3 +53,10 @@ def _setup_gevent_compatibility():
def init_app(app: DifyApp):
db.init_app(app)
_setup_gevent_compatibility()

# Eagerly build the engine so pool_size/max_overflow/etc. come from config
try:
with app.app_context():
_ = db.engine  # triggers engine creation with the configured options
except Exception:
logger.exception("Failed to initialize SQLAlchemy engine during app startup")
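Touching db.engine inside an app context forces Flask-SQLAlchemy to build the engine at startup, so pool settings come from configuration instead of being applied lazily on the first query. A hedged sketch of the kind of configuration this makes effective (the keys are standard Flask-SQLAlchemy/SQLAlchemy engine options; the values here are examples only, not Dify defaults):

app.config["SQLALCHEMY_ENGINE_OPTIONS"] = {
    "pool_size": 30,        # persistent connections kept in the pool
    "max_overflow": 10,     # extra connections allowed beyond pool_size
    "pool_recycle": 3600,   # recycle connections older than an hour
    "pool_pre_ping": True,  # validate connections before handing them out
}

# With the eager `_ = db.engine` above, a misconfigured pool fails (and is logged)
# at startup rather than on the first request that needs the database.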
api/extensions/ext_fastopenapi.py (new file, 45 lines)
@@ -0,0 +1,45 @@
from fastopenapi.routers import FlaskRouter
from flask_cors import CORS

from configs import dify_config
from controllers.fastopenapi import console_router
from dify_app import DifyApp
from extensions.ext_blueprints import AUTHENTICATED_HEADERS, EXPOSED_HEADERS

DOCS_PREFIX = "/fastopenapi"


def init_app(app: DifyApp) -> None:
docs_enabled = dify_config.SWAGGER_UI_ENABLED
docs_url = f"{DOCS_PREFIX}/docs" if docs_enabled else None
redoc_url = f"{DOCS_PREFIX}/redoc" if docs_enabled else None
openapi_url = f"{DOCS_PREFIX}/openapi.json" if docs_enabled else None

router = FlaskRouter(
app=app,
docs_url=docs_url,
redoc_url=redoc_url,
openapi_url=openapi_url,
openapi_version="3.0.0",
title="Dify API (FastOpenAPI PoC)",
version="1.0",
description="FastOpenAPI proof of concept for Dify API",
)

# Ensure route decorators are evaluated.
import controllers.console.ping as ping_module
from controllers.console import setup

_ = ping_module
_ = setup

router.include_router(console_router, prefix="/console/api")
CORS(
app,
resources={r"/console/api/.*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
supports_credentials=True,
allow_headers=list(AUTHENTICATED_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=list(EXPOSED_HEADERS),
)
app.extensions["fastopenapi"] = router
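The explicit controller imports matter because FastOpenAPI only learns about an operation when its decorator runs. A rough sketch of what a controller registered on console_router might look like (hedged: this assumes a FastAPI-style decorator API on fastopenapi routers, and the /ping handler is purely illustrative, not Dify's actual controller code):

# Hypothetical controller module.
from controllers.fastopenapi import console_router


@console_router.get("/ping")
def ping() -> dict[str, str]:
    # Importing this module is what "evaluates the route decorator",
    # which is why ext_fastopenapi imports controller modules explicitly.
    return {"result": "pong"}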
@@ -1,18 +1,19 @@
"""Logging extension for Dify Flask application."""

import logging
import os
import sys
import uuid
from logging.handlers import RotatingFileHandler

import flask

from configs import dify_config
from core.helper.trace_id_helper import get_trace_id_from_otel_context
from dify_app import DifyApp


def init_app(app: DifyApp):
"""Initialize logging with support for text or JSON format."""
log_handlers: list[logging.Handler] = []

# File handler
log_file = dify_config.LOG_FILE
if log_file:
log_dir = os.path.dirname(log_file)

@@ -25,27 +26,53 @@ def init_app(app: DifyApp):
)
)

# Always add StreamHandler to log to console
# Console handler
sh = logging.StreamHandler(sys.stdout)
log_handlers.append(sh)

# Apply RequestIdFilter to all handlers
for handler in log_handlers:
handler.addFilter(RequestIdFilter())
# Apply filters to all handlers
from core.logging.filters import IdentityContextFilter, TraceContextFilter

for handler in log_handlers:
handler.addFilter(TraceContextFilter())
handler.addFilter(IdentityContextFilter())

# Configure formatter based on format type
formatter = _create_formatter()
for handler in log_handlers:
handler.setFormatter(formatter)

# Configure root logger
logging.basicConfig(
level=dify_config.LOG_LEVEL,
format=dify_config.LOG_FORMAT,
datefmt=dify_config.LOG_DATEFORMAT,
handlers=log_handlers,
force=True,
)

# Apply RequestIdFormatter to all handlers
apply_request_id_formatter()

# Disable propagation for noisy loggers to avoid duplicate logs
logging.getLogger("sqlalchemy.engine").propagate = False

# Apply timezone if specified (only for text format)
if dify_config.LOG_OUTPUT_FORMAT == "text":
_apply_timezone(log_handlers)


def _create_formatter() -> logging.Formatter:
"""Create appropriate formatter based on configuration."""
if dify_config.LOG_OUTPUT_FORMAT == "json":
from core.logging.structured_formatter import StructuredJSONFormatter

return StructuredJSONFormatter()
else:
# Text format - use existing pattern with backward compatible formatter
return _TextFormatter(
fmt=dify_config.LOG_FORMAT,
datefmt=dify_config.LOG_DATEFORMAT,
)


def _apply_timezone(handlers: list[logging.Handler]):
"""Apply timezone conversion to text formatters."""
log_tz = dify_config.LOG_TZ
if log_tz:
from datetime import datetime

@@ -57,34 +84,51 @@ def init_app(app: DifyApp):
def time_converter(seconds):
return datetime.fromtimestamp(seconds, tz=timezone).timetuple()

for handler in logging.root.handlers:
for handler in handlers:
if handler.formatter:
handler.formatter.converter = time_converter
handler.formatter.converter = time_converter  # type: ignore[attr-defined]


def get_request_id():
if getattr(flask.g, "request_id", None):
return flask.g.request_id
class _TextFormatter(logging.Formatter):
"""Text formatter that ensures trace_id and req_id are always present."""

new_uuid = uuid.uuid4().hex[:10]
flask.g.request_id = new_uuid

return new_uuid
def format(self, record: logging.LogRecord) -> str:
if not hasattr(record, "req_id"):
record.req_id = ""
if not hasattr(record, "trace_id"):
record.trace_id = ""
if not hasattr(record, "span_id"):
record.span_id = ""
return super().format(record)


def get_request_id() -> str:
"""Get request ID for current request context.

Deprecated: Use core.logging.context.get_request_id() directly.
"""
from core.logging.context import get_request_id as _get_request_id

return _get_request_id()


# Backward compatibility aliases
class RequestIdFilter(logging.Filter):
# This is a logging filter that makes the request ID available for use in
# the logging format. Note that we're checking if we're in a request
# context, as we may want to log things before Flask is fully loaded.
def filter(self, record):
trace_id = get_trace_id_from_otel_context() or ""
record.req_id = get_request_id() if flask.has_request_context() else ""
record.trace_id = trace_id
"""Deprecated: Use TraceContextFilter from core.logging.filters instead."""

def filter(self, record: logging.LogRecord) -> bool:
from core.logging.context import get_request_id as _get_request_id
from core.logging.context import get_trace_id as _get_trace_id

record.req_id = _get_request_id()
record.trace_id = _get_trace_id()
return True


class RequestIdFormatter(logging.Formatter):
def format(self, record):
"""Deprecated: Use _TextFormatter instead."""

def format(self, record: logging.LogRecord) -> str:
if not hasattr(record, "req_id"):
record.req_id = ""
if not hasattr(record, "trace_id"):

@@ -93,6 +137,7 @@ class RequestIdFormatter(logging.Formatter):

def apply_request_id_formatter():
"""Deprecated: Formatter is now applied in init_app."""
for handler in logging.root.handlers:
if handler.formatter:
handler.formatter = RequestIdFormatter(dify_config.LOG_FORMAT, dify_config.LOG_DATEFORMAT)
@@ -10,6 +10,7 @@ import os

from dotenv import load_dotenv

from configs import dify_config
from dify_app import DifyApp

logger = logging.getLogger(__name__)

@@ -19,12 +20,17 @@ def is_enabled() -> bool:
"""
Check if logstore extension is enabled.

Logstore is considered enabled when:
1. All required Aliyun SLS environment variables are set
2. At least one repository configuration points to a logstore implementation

Returns:
True if all required Aliyun SLS environment variables are set, False otherwise
True if logstore should be initialized, False otherwise
"""
# Load environment variables from .env file
load_dotenv()

# Check if Aliyun SLS connection parameters are configured
required_vars = [
"ALIYUN_SLS_ACCESS_KEY_ID",
"ALIYUN_SLS_ACCESS_KEY_SECRET",

@@ -33,24 +39,32 @@ def is_enabled() -> bool:
"ALIYUN_SLS_PROJECT_NAME",
]

all_set = all(os.environ.get(var) for var in required_vars)
sls_vars_set = all(os.environ.get(var) for var in required_vars)

if not all_set:
logger.info("Logstore extension disabled: required Aliyun SLS environment variables not set")
if not sls_vars_set:
return False

return all_set
# Check if any repository configuration points to logstore implementation
repository_configs = [
dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY,
dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY,
dify_config.API_WORKFLOW_NODE_EXECUTION_REPOSITORY,
dify_config.API_WORKFLOW_RUN_REPOSITORY,
]

uses_logstore = any("logstore" in config.lower() for config in repository_configs)

if not uses_logstore:
return False

logger.info("Logstore extension enabled: SLS variables set and repository configured to use logstore")
return True
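Putting the two conditions together: is_enabled() returns True only when the SLS credentials are present in the environment and at least one repository setting contains the substring "logstore". A hedged sketch of an environment that would satisfy it (the values are placeholders, and the repository class path shown in the comment is hypothetical, included only to illustrate the substring check):

import os

# Placeholder values; real deployments set these via the environment or .env file.
os.environ.update(
    {
        "ALIYUN_SLS_ACCESS_KEY_ID": "<access-key-id>",
        "ALIYUN_SLS_ACCESS_KEY_SECRET": "<access-key-secret>",
        "ALIYUN_SLS_ENDPOINT": "<region>.log.aliyuncs.com",
        "ALIYUN_SLS_REGION": "<region>",
        "ALIYUN_SLS_PROJECT_NAME": "<project-name>",
    }
)

# In addition, at least one of the repository settings must reference a
# logstore-backed implementation, e.g. (hypothetical value):
# API_WORKFLOW_RUN_REPOSITORY = "extensions.logstore....LogstoreAPIWorkflowRunRepository"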
def init_app(app: DifyApp):
"""
Initialize logstore on application startup.

This function:
1. Creates Aliyun SLS project if it doesn't exist
2. Creates logstores (workflow_execution, workflow_node_execution) if they don't exist
3. Creates indexes with field configurations based on PostgreSQL table structures

This operation is idempotent and only executes once during application startup.
If initialization fails, the application continues running without logstore features.

Args:
app: The Dify application instance

@@ -58,17 +72,23 @@ def init_app(app: DifyApp):
try:
from extensions.logstore.aliyun_logstore import AliyunLogStore

logger.info("Initializing logstore...")
logger.info("Initializing Aliyun SLS Logstore...")

# Create logstore client and initialize project/logstores/indexes
# Create logstore client and initialize resources
logstore_client = AliyunLogStore()
logstore_client.init_project_logstore()

# Attach to app for potential later use
app.extensions["logstore"] = logstore_client

logger.info("Logstore initialized successfully")

except Exception:
logger.exception("Failed to initialize logstore")
# Don't raise - allow application to continue even if logstore init fails
# This ensures that the application can still run if logstore is misconfigured
logger.exception(
"Logstore initialization failed. Configuration: endpoint=%s, region=%s, project=%s, timeout=%ss. "
"Application will continue but logstore features will NOT work.",
os.environ.get("ALIYUN_SLS_ENDPOINT"),
os.environ.get("ALIYUN_SLS_REGION"),
os.environ.get("ALIYUN_SLS_PROJECT_NAME"),
os.environ.get("ALIYUN_SLS_CHECK_CONNECTIVITY_TIMEOUT", "30"),
)
# Don't raise - allow application to continue even if logstore setup fails
@@ -1,5 +1,8 @@
from __future__ import annotations

import logging
import os
import socket
import threading
import time
from collections.abc import Sequence

@@ -33,7 +36,7 @@ class AliyunLogStore:
Ensures only one instance exists to prevent multiple PG connection pools.
"""

_instance: "AliyunLogStore | None" = None
_instance: AliyunLogStore | None = None
_initialized: bool = False

# Track delayed PG connection for newly created projects

@@ -66,7 +69,7 @@ class AliyunLogStore:
"\t",
]

def __new__(cls) -> "AliyunLogStore":
def __new__(cls) -> AliyunLogStore:
"""Implement singleton pattern."""
if cls._instance is None:
cls._instance = super().__new__(cls)

@@ -177,9 +180,18 @@ class AliyunLogStore:
self.region: str = os.environ.get("ALIYUN_SLS_REGION", "")
self.project_name: str = os.environ.get("ALIYUN_SLS_PROJECT_NAME", "")
self.logstore_ttl: int = int(os.environ.get("ALIYUN_SLS_LOGSTORE_TTL", 365))
self.log_enabled: bool = os.environ.get("SQLALCHEMY_ECHO", "false").lower() == "true"
self.log_enabled: bool = (
os.environ.get("SQLALCHEMY_ECHO", "false").lower() == "true"
or os.environ.get("LOGSTORE_SQL_ECHO", "false").lower() == "true"
)
self.pg_mode_enabled: bool = os.environ.get("LOGSTORE_PG_MODE_ENABLED", "true").lower() == "true"

# Get timeout configuration
check_timeout = int(os.environ.get("ALIYUN_SLS_CHECK_CONNECTIVITY_TIMEOUT", 30))

# Pre-check endpoint connectivity to prevent indefinite hangs
self._check_endpoint_connectivity(self.endpoint, check_timeout)

# Initialize SDK client
self.client = LogClient(
self.endpoint, self.access_key_id, self.access_key_secret, auth_version=AUTH_VERSION_4, region=self.region

@@ -197,6 +209,49 @@ class AliyunLogStore:

self.__class__._initialized = True

@staticmethod
def _check_endpoint_connectivity(endpoint: str, timeout: int) -> None:
"""
Check if the SLS endpoint is reachable before creating LogClient.
Prevents indefinite hangs when the endpoint is unreachable.

Args:
endpoint: SLS endpoint URL
timeout: Connection timeout in seconds

Raises:
ConnectionError: If endpoint is not reachable
"""
# Parse endpoint URL to extract hostname and port
from urllib.parse import urlparse

parsed_url = urlparse(endpoint if "://" in endpoint else f"http://{endpoint}")
hostname = parsed_url.hostname
port = parsed_url.port or (443 if parsed_url.scheme == "https" else 80)

if not hostname:
raise ConnectionError(f"Invalid endpoint URL: {endpoint}")

sock = None
try:
# Create socket and set timeout
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(timeout)
sock.connect((hostname, port))
except Exception as e:
# Catch all exceptions and provide clear error message
error_type = type(e).__name__
raise ConnectionError(
f"Cannot connect to {hostname}:{port} (timeout={timeout}s): [{error_type}] {e}"
) from e
finally:
# Ensure socket is properly closed
if sock:
try:
sock.close()
except Exception:  # noqa: S110
pass  # Ignore errors during cleanup

@property
def supports_pg_protocol(self) -> bool:
"""Check if PG protocol is supported and enabled."""

@@ -218,19 +273,16 @@ class AliyunLogStore:
try:
self._use_pg_protocol = self._pg_client.init_connection()
if self._use_pg_protocol:
logger.info("Successfully connected to project %s using PG protocol", self.project_name)
logger.info("Using PG protocol for project %s", self.project_name)
# Check if scan_index is enabled for all logstores
self._check_and_disable_pg_if_scan_index_disabled()
return True
else:
logger.info("PG connection failed for project %s. Will use SDK mode.", self.project_name)
logger.info("Using SDK mode for project %s", self.project_name)
return False
except Exception as e:
logger.warning(
"Failed to establish PG connection for project %s: %s. Will use SDK mode.",
self.project_name,
str(e),
)
logger.info("Using SDK mode for project %s", self.project_name)
logger.debug("PG connection details: %s", str(e))
self._use_pg_protocol = False
return False

@@ -244,10 +296,6 @@ class AliyunLogStore:
if self._use_pg_protocol:
return

logger.info(
"Attempting delayed PG connection for newly created project %s ...",
self.project_name,
)
self._attempt_pg_connection_init()
self.__class__._pg_connection_timer = None

@@ -282,11 +330,7 @@ class AliyunLogStore:
if project_is_new:
# For newly created projects, schedule delayed PG connection
self._use_pg_protocol = False
logger.info(
"Project %s is newly created. Will use SDK mode and schedule PG connection attempt in %d seconds.",
self.project_name,
self.__class__._pg_connection_delay,
)
logger.info("Using SDK mode for project %s (newly created)", self.project_name)
if self.__class__._pg_connection_timer is not None:
self.__class__._pg_connection_timer.cancel()
self.__class__._pg_connection_timer = threading.Timer(

@@ -297,7 +341,6 @@ class AliyunLogStore:
self.__class__._pg_connection_timer.start()
else:
# For existing projects, attempt PG connection immediately
logger.info("Project %s already exists. Attempting PG connection...", self.project_name)
self._attempt_pg_connection_init()

def _check_and_disable_pg_if_scan_index_disabled(self) -> None:

@@ -316,9 +359,9 @@ class AliyunLogStore:
existing_config = self.get_existing_index_config(logstore_name)
if existing_config and not existing_config.scan_index:
logger.info(
"Logstore %s has scan_index=false, USE SDK mode for read/write operations. "
"PG protocol requires scan_index to be enabled.",
"Logstore %s requires scan_index enabled, using SDK mode for project %s",
logstore_name,
self.project_name,
)
self._use_pg_protocol = False
# Close PG connection if it was initialized

@@ -746,7 +789,6 @@ class AliyunLogStore:
reverse=reverse,
)

# Log query info if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore] GET_LOGS | logstore=%s | project=%s | query=%s | "

@@ -768,7 +810,6 @@ class AliyunLogStore:
for log in logs:
result.append(log.get_contents())

# Log result count if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore] GET_LOGS RESULT | logstore=%s | returned_count=%d",

@@ -843,7 +884,6 @@ class AliyunLogStore:
query=full_query,
)

# Log query info if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore-SDK] EXECUTE_SQL | logstore=%s | project=%s | from_time=%d | to_time=%d | full_query=%s",

@@ -851,8 +891,7 @@ class AliyunLogStore:
self.project_name,
from_time,
to_time,
query,
sql,
full_query,
)

try:

@@ -863,7 +902,6 @@ class AliyunLogStore:
for log in logs:
result.append(log.get_contents())

# Log result count if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore-SDK] EXECUTE_SQL RESULT | logstore=%s | returned_count=%d",
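The delayed-connection path above relies on a plain threading.Timer: a newly created project starts in SDK mode and a one-shot timer retries the PG handshake after _pg_connection_delay seconds. The same pattern in isolation (the delay value and callback are illustrative stand-ins for the class's real attributes):

import threading

PG_CONNECTION_DELAY = 60  # seconds; illustrative value


def attempt_pg_connection() -> None:
    # Stand-in for _attempt_pg_connection_init() in the real class.
    print("retrying PG connection now")


timer = threading.Timer(PG_CONNECTION_DELAY, attempt_pg_connection)
timer.daemon = True  # don't keep the process alive just for the retry
timer.start()
# timer.cancel() discards a pending retry, mirroring _pg_connection_timer.cancel().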
@@ -7,8 +7,7 @@ from contextlib import contextmanager
from typing import Any

import psycopg2
import psycopg2.pool
from psycopg2 import InterfaceError, OperationalError
from sqlalchemy import create_engine

from configs import dify_config

@@ -16,11 +15,7 @@ logger = logging.getLogger(__name__)


class AliyunLogStorePG:
"""
PostgreSQL protocol support for Aliyun SLS LogStore.

Handles PG connection pooling and operations for regions that support PG protocol.
"""
"""PostgreSQL protocol support for Aliyun SLS LogStore using SQLAlchemy connection pool."""

def __init__(self, access_key_id: str, access_key_secret: str, endpoint: str, project_name: str):
"""

@@ -36,24 +31,11 @@ class AliyunLogStorePG:
self._access_key_secret = access_key_secret
self._endpoint = endpoint
self.project_name = project_name
self._pg_pool: psycopg2.pool.SimpleConnectionPool | None = None
self._engine: Any = None  # SQLAlchemy Engine
self._use_pg_protocol = False

def _check_port_connectivity(self, host: str, port: int, timeout: float = 2.0) -> bool:
"""
Check if a TCP port is reachable using socket connection.

This provides a fast check before attempting full database connection,
preventing long waits when connecting to unsupported regions.

Args:
host: Hostname or IP address
port: Port number
timeout: Connection timeout in seconds (default: 2.0)

Returns:
True if port is reachable, False otherwise
"""
"""Fast TCP port check to avoid long waits on unsupported regions."""
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(timeout)

@@ -65,166 +47,101 @@ class AliyunLogStorePG:
return False

def init_connection(self) -> bool:
"""
Initialize PostgreSQL connection pool for SLS PG protocol support.

Attempts to connect to SLS using PostgreSQL protocol. If successful, sets
_use_pg_protocol to True and creates a connection pool. If connection fails
(region doesn't support PG protocol or other errors), returns False.

Returns:
True if PG protocol is supported and initialized, False otherwise
"""
"""Initialize SQLAlchemy connection pool with pool_recycle and TCP keepalive support."""
try:
# Extract hostname from endpoint (remove protocol if present)
pg_host = self._endpoint.replace("http://", "").replace("https://", "")

# Get pool configuration
pg_max_connections = int(os.environ.get("ALIYUN_SLS_PG_MAX_CONNECTIONS", 10))
# Pool configuration
pool_size = int(os.environ.get("ALIYUN_SLS_PG_POOL_SIZE", 5))
max_overflow = int(os.environ.get("ALIYUN_SLS_PG_MAX_OVERFLOW", 5))
pool_recycle = int(os.environ.get("ALIYUN_SLS_PG_POOL_RECYCLE", 3600))
pool_pre_ping = os.environ.get("ALIYUN_SLS_PG_POOL_PRE_PING", "false").lower() == "true"

logger.debug(
"Check PG protocol connection to SLS: host=%s, project=%s",
pg_host,
self.project_name,
)
logger.debug("Check PG protocol connection to SLS: host=%s, project=%s", pg_host, self.project_name)

# Fast port connectivity check before attempting full connection
# This prevents long waits when connecting to unsupported regions
# Fast port check to avoid long waits
if not self._check_port_connectivity(pg_host, 5432, timeout=1.0):
logger.info(
"USE SDK mode for read/write operations, host=%s",
pg_host,
)
logger.debug("Using SDK mode for host=%s", pg_host)
return False

# Create connection pool
self._pg_pool = psycopg2.pool.SimpleConnectionPool(
minconn=1,
maxconn=pg_max_connections,
host=pg_host,
port=5432,
database=self.project_name,
user=self._access_key_id,
password=self._access_key_secret,
sslmode="require",
connect_timeout=5,
application_name=f"Dify-{dify_config.project.version}",
# Build connection URL
from urllib.parse import quote_plus

username = quote_plus(self._access_key_id)
password = quote_plus(self._access_key_secret)
database_url = (
f"postgresql+psycopg2://{username}:{password}@{pg_host}:5432/{self.project_name}?sslmode=require"
)

# Note: Skip test query because SLS PG protocol only supports SELECT/INSERT on actual tables
# Connection pool creation success already indicates connectivity
# Create SQLAlchemy engine with connection pool
self._engine = create_engine(
database_url,
pool_size=pool_size,
max_overflow=max_overflow,
pool_recycle=pool_recycle,
pool_pre_ping=pool_pre_ping,
pool_timeout=30,
connect_args={
"connect_timeout": 5,
"application_name": f"Dify-{dify_config.project.version}-fixautocommit",
"keepalives": 1,
"keepalives_idle": 60,
"keepalives_interval": 10,
"keepalives_count": 5,
},
)

self._use_pg_protocol = True
logger.info(
"PG protocol initialized successfully for SLS project=%s. Will use PG for read/write operations.",
"PG protocol initialized for SLS project=%s (pool_size=%d, pool_recycle=%ds)",
self.project_name,
pool_size,
pool_recycle,
)
return True

except Exception as e:
# PG connection failed - fallback to SDK mode
self._use_pg_protocol = False
if self._pg_pool:
if self._engine:
try:
self._pg_pool.closeall()
self._engine.dispose()
except Exception:
logger.debug("Failed to close PG connection pool during cleanup, ignoring")
self._pg_pool = None
logger.debug("Failed to dispose engine during cleanup, ignoring")
self._engine = None

logger.info(
"PG protocol connection failed (region may not support PG protocol): %s. "
"Falling back to SDK mode for read/write operations.",
str(e),
)
return False

def _is_connection_valid(self, conn: Any) -> bool:
"""
Check if a connection is still valid.

Args:
conn: psycopg2 connection object

Returns:
True if connection is valid, False otherwise
"""
try:
# Check if connection is closed
if conn.closed:
return False

# Quick ping test - execute a lightweight query
# For SLS PG protocol, we can't use SELECT 1 without FROM,
# so we just check the connection status
with conn.cursor() as cursor:
cursor.execute("SELECT 1")
cursor.fetchone()
return True
except Exception:
logger.debug("Using SDK mode for region: %s", str(e))
return False

@contextmanager
def _get_connection(self):
"""
Context manager to get a PostgreSQL connection from the pool.
"""Get connection from SQLAlchemy pool. Pool handles recycle, invalidation, and keepalive automatically."""
if not self._engine:
raise RuntimeError("SQLAlchemy engine is not initialized")

Automatically validates and refreshes stale connections.

Note: Aliyun SLS PG protocol does not support transactions, so we always
use autocommit mode.

Yields:
psycopg2 connection object

Raises:
RuntimeError: If PG pool is not initialized
"""
if not self._pg_pool:
raise RuntimeError("PG connection pool is not initialized")

conn = self._pg_pool.getconn()
connection = self._engine.raw_connection()
try:
# Validate connection and get a fresh one if needed
if not self._is_connection_valid(conn):
logger.debug("Connection is stale, marking as bad and getting a new one")
# Mark connection as bad and get a new one
self._pg_pool.putconn(conn, close=True)
conn = self._pg_pool.getconn()

# Aliyun SLS PG protocol does not support transactions, always use autocommit
conn.autocommit = True
yield conn
connection.autocommit = True  # SLS PG protocol does not support transactions
yield connection
except Exception:
raise
finally:
# Return connection to pool (or close if it's bad)
if self._is_connection_valid(conn):
self._pg_pool.putconn(conn)
else:
self._pg_pool.putconn(conn, close=True)
connection.close()

def close(self) -> None:
"""Close the PostgreSQL connection pool."""
if self._pg_pool:
"""Dispose SQLAlchemy engine and close all connections."""
if self._engine:
try:
self._pg_pool.closeall()
logger.info("PG connection pool closed")
self._engine.dispose()
logger.info("SQLAlchemy engine disposed")
except Exception:
logger.exception("Failed to close PG connection pool")
logger.exception("Failed to dispose engine")

def _is_retriable_error(self, error: Exception) -> bool:
"""
Check if an error is retriable (connection-related issues).

Args:
error: Exception to check

Returns:
True if the error is retriable, False otherwise
"""
# Retry on connection-related errors
if isinstance(error, (OperationalError, InterfaceError)):
"""Check if error is retriable (connection-related issues)."""
# Check for psycopg2 connection errors directly
if isinstance(error, (psycopg2.OperationalError, psycopg2.InterfaceError)):
return True

# Check error message for specific connection issues
error_msg = str(error).lower()
retriable_patterns = [
"connection",

@@ -234,34 +151,18 @@ class AliyunLogStorePG:
"reset by peer",
"no route to host",
"network",
"operational error",
"interface error",
]
return any(pattern in error_msg for pattern in retriable_patterns)

def put_log(self, logstore: str, contents: Sequence[tuple[str, str]], log_enabled: bool = False) -> None:
"""
Write log to SLS using PostgreSQL protocol with automatic retry.

Note: SLS PG protocol only supports INSERT (not UPDATE). This uses append-only
writes with log_version field for versioning, same as SDK implementation.

Args:
logstore: Name of the logstore table
contents: List of (field_name, value) tuples
log_enabled: Whether to enable logging

Raises:
psycopg2.Error: If database operation fails after all retries
"""
"""Write log to SLS using INSERT with automatic retry (3 attempts with exponential backoff)."""
if not contents:
return

# Extract field names and values from contents
fields = [field_name for field_name, _ in contents]
values = [value for _, value in contents]

# Build INSERT statement with literal values
# Note: Aliyun SLS PG protocol doesn't support parameterized queries,
# so we need to use mogrify to safely create literal values
field_list = ", ".join([f'"{field}"' for field in fields])

if log_enabled:

@@ -272,67 +173,40 @@ class AliyunLogStorePG:
len(contents),
)

# Retry configuration
max_retries = 3
retry_delay = 0.1  # Start with 100ms
retry_delay = 0.1

for attempt in range(max_retries):
try:
with self._get_connection() as conn:
with conn.cursor() as cursor:
# Use mogrify to safely convert values to SQL literals
placeholders = ", ".join(["%s"] * len(fields))
values_literal = cursor.mogrify(f"({placeholders})", values).decode("utf-8")
insert_sql = f'INSERT INTO "{logstore}" ({field_list}) VALUES {values_literal}'
cursor.execute(insert_sql)
# Success - exit retry loop
return

except psycopg2.Error as e:
# Check if error is retriable
if not self._is_retriable_error(e):
# Not a retriable error (e.g., data validation error), fail immediately
logger.exception(
"Failed to put logs to logstore %s via PG protocol (non-retriable error)",
logstore,
)
logger.exception("Failed to put logs to logstore %s (non-retriable error)", logstore)
raise

# Retriable error - log and retry if we have attempts left
if attempt < max_retries - 1:
logger.warning(
"Failed to put logs to logstore %s via PG protocol (attempt %d/%d): %s. Retrying...",
"Failed to put logs to logstore %s (attempt %d/%d): %s. Retrying...",
logstore,
attempt + 1,
max_retries,
str(e),
)
time.sleep(retry_delay)
retry_delay *= 2  # Exponential backoff
retry_delay *= 2
else:
# Last attempt failed
logger.exception(
"Failed to put logs to logstore %s via PG protocol after %d attempts",
logstore,
max_retries,
)
logger.exception("Failed to put logs to logstore %s after %d attempts", logstore, max_retries)
raise
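To make the backoff concrete: with max_retries = 3 and an initial retry_delay of 0.1 that doubles after each retriable failure, the sleeps between attempts are 0.1 s and 0.2 s, so a call that ultimately fails spends roughly 0.3 s waiting before the final exception propagates. The same retry shape in isolation, independent of the SLS-specific code above (do_write is a placeholder callable):

import time


def with_retries(do_write, max_retries: int = 3, initial_delay: float = 0.1):
    """Retry a callable on failure with exponential backoff (sketch)."""
    delay = initial_delay
    for attempt in range(max_retries):
        try:
            return do_write()
        except Exception:
            if attempt == max_retries - 1:
                raise  # last attempt: propagate the error
            time.sleep(delay)  # 0.1s, then 0.2s with the defaults above
            delay *= 2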
def execute_sql(self, sql: str, logstore: str, log_enabled: bool = False) -> list[dict[str, Any]]:
"""
Execute SQL query using PostgreSQL protocol with automatic retry.

Args:
sql: SQL query string
logstore: Name of the logstore (for logging purposes)
log_enabled: Whether to enable logging

Returns:
List of result rows as dictionaries

Raises:
psycopg2.Error: If database operation fails after all retries
"""
"""Execute SQL query with automatic retry (3 attempts with exponential backoff)."""
if log_enabled:
logger.info(
"[LogStore-PG] EXECUTE_SQL | logstore=%s | project=%s | sql=%s",

@@ -341,20 +215,16 @@ class AliyunLogStorePG:
sql,
)

# Retry configuration
max_retries = 3
retry_delay = 0.1  # Start with 100ms
retry_delay = 0.1

for attempt in range(max_retries):
try:
with self._get_connection() as conn:
with conn.cursor() as cursor:
cursor.execute(sql)

# Get column names from cursor description
columns = [desc[0] for desc in cursor.description]

# Fetch all results and convert to list of dicts
result = []
for row in cursor.fetchall():
row_dict = {}

@@ -372,36 +242,31 @@ class AliyunLogStorePG:
return result

except psycopg2.Error as e:
# Check if error is retriable
if not self._is_retriable_error(e):
# Not a retriable error (e.g., SQL syntax error), fail immediately
logger.exception(
"Failed to execute SQL query on logstore %s via PG protocol (non-retriable error): sql=%s",
"Failed to execute SQL on logstore %s (non-retriable error): sql=%s",
logstore,
sql,
)
raise

# Retriable error - log and retry if we have attempts left
if attempt < max_retries - 1:
logger.warning(
"Failed to execute SQL query on logstore %s via PG protocol (attempt %d/%d): %s. Retrying...",
"Failed to execute SQL on logstore %s (attempt %d/%d): %s. Retrying...",
logstore,
attempt + 1,
max_retries,
str(e),
)
time.sleep(retry_delay)
retry_delay *= 2  # Exponential backoff
retry_delay *= 2
else:
# Last attempt failed
logger.exception(
"Failed to execute SQL query on logstore %s via PG protocol after %d attempts: sql=%s",
"Failed to execute SQL on logstore %s after %d attempts: sql=%s",
logstore,
max_retries,
sql,
)
raise

# This line should never be reached due to raise above, but makes type checker happy
return []
@@ -0,0 +1,29 @@
"""
LogStore repository utilities.
"""

from typing import Any


def safe_float(value: Any, default: float = 0.0) -> float:
"""
Safely convert a value to float, handling 'null' strings and None.
"""
if value is None or value in {"null", ""}:
return default
try:
return float(value)
except (ValueError, TypeError):
return default


def safe_int(value: Any, default: int = 0) -> int:
"""
Safely convert a value to int, handling 'null' strings and None.
"""
if value is None or value in {"null", ""}:
return default
try:
return int(float(value))
except (ValueError, TypeError):
return default
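These helpers exist because logstore rows come back as strings, including a literal "null" for missing numeric fields, and safe_int also accepts float-formatted strings. A few conversions that follow directly from the code above:

assert safe_int("3") == 3
assert safe_int("3.7") == 3                  # int(float("3.7"))
assert safe_int("null") == 0                 # 'null' string falls back to the default
assert safe_float(None, default=1.5) == 1.5
assert safe_float("0.25") == 0.25
assert safe_float("not-a-number") == 0.0     # conversion error falls back to the default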
@@ -15,6 +15,8 @@ from sqlalchemy.orm import sessionmaker

from core.workflow.enums import WorkflowNodeExecutionStatus
from extensions.logstore.aliyun_logstore import AliyunLogStore
from extensions.logstore.repositories import safe_float, safe_int
from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value
from models.workflow import WorkflowNodeExecutionModel
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository

@@ -53,9 +55,8 @@ def _dict_to_workflow_node_execution_model(data: dict[str, Any]) -> WorkflowNode
model.created_by_role = data.get("created_by_role") or ""
model.created_by = data.get("created_by") or ""

# Numeric fields with defaults
model.index = int(data.get("index", 0))
model.elapsed_time = float(data.get("elapsed_time", 0))
model.index = safe_int(data.get("index", 0))
model.elapsed_time = safe_float(data.get("elapsed_time", 0))

# Optional fields
model.workflow_run_id = data.get("workflow_run_id")

@@ -131,6 +132,12 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
node_id,
)
try:
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_workflow_id = escape_identifier(workflow_id)
escaped_node_id = escape_identifier(node_id)

# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of each record)

@@ -139,10 +146,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
WHERE tenant_id = '{tenant_id}'
AND app_id = '{app_id}'
AND workflow_id = '{workflow_id}'
AND node_id = '{node_id}'
WHERE tenant_id = '{escaped_tenant_id}'
AND app_id = '{escaped_app_id}'
AND workflow_id = '{escaped_workflow_id}'
AND node_id = '{escaped_node_id}'
AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
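The ROW_NUMBER() window is the versioning trick used throughout these repositories: every update is appended with a higher log_version, and the outer filter keeps only rn = 1, i.e. the newest row per id. A small worked example with illustrative data:

# Appended rows for one execution id (illustrative):
#   id=ex-1, log_version=1, status=running
#   id=ex-1, log_version=2, status=running
#   id=ex-1, log_version=3, status=succeeded
#
# ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) assigns rn=1
# to the log_version=3 row, so "WHERE rn = 1" returns only the latest state of
# each execution. This is how an append-only store emulates in-place updates.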
@@ -154,7 +161,8 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
else:
# Use SDK with LogStore query syntax
query = (
f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_id: {workflow_id} and node_id: {node_id}"
f"tenant_id: {escaped_tenant_id} and app_id: {escaped_app_id} "
f"and workflow_id: {escaped_workflow_id} and node_id: {escaped_node_id}"
)
from_time = 0
to_time = int(time.time())  # now

@@ -230,6 +238,11 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
workflow_run_id,
)
try:
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_workflow_run_id = escape_identifier(workflow_run_id)

# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of each record)

@@ -238,9 +251,9 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
WHERE tenant_id = '{tenant_id}'
AND app_id = '{app_id}'
AND workflow_run_id = '{workflow_run_id}'
WHERE tenant_id = '{escaped_tenant_id}'
AND app_id = '{escaped_app_id}'
AND workflow_run_id = '{escaped_workflow_run_id}'
AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 1000

@@ -251,7 +264,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
)
else:
# Use SDK with LogStore query syntax
query = f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_run_id: {workflow_run_id}"
query = (
f"tenant_id: {escaped_tenant_id} and app_id: {escaped_app_id} "
f"and workflow_run_id: {escaped_workflow_run_id}"
)
from_time = 0
to_time = int(time.time())  # now

@@ -318,16 +334,24 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
"""
logger.debug("get_execution_by_id: execution_id=%s, tenant_id=%s", execution_id, tenant_id)
try:
# Escape parameters to prevent SQL injection
escaped_execution_id = escape_identifier(execution_id)

# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record)
tenant_filter = f"AND tenant_id = '{tenant_id}'" if tenant_id else ""
if tenant_id:
escaped_tenant_id = escape_identifier(tenant_id)
tenant_filter = f"AND tenant_id = '{escaped_tenant_id}'"
else:
tenant_filter = ""

sql_query = f"""
SELECT * FROM (
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
WHERE id = '{execution_id}' {tenant_filter} AND __time__ > 0
WHERE id = '{escaped_execution_id}' {tenant_filter} AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 1
"""

@@ -337,10 +361,14 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
)
else:
# Use SDK with LogStore query syntax
# Note: Values must be quoted in LogStore query syntax to prevent injection
if tenant_id:
query = f"id: {execution_id} and tenant_id: {tenant_id}"
query = (
f"id:{escape_logstore_query_value(execution_id)} "
f"and tenant_id:{escape_logstore_query_value(tenant_id)}"
)
else:
query = f"id: {execution_id}"
query = f"id:{escape_logstore_query_value(execution_id)}"

from_time = 0
to_time = int(time.time())  # now
@@ -10,6 +10,7 @@ Key Features:
- Optimized deduplication using finished_at IS NOT NULL filter
- Window functions only when necessary (running status queries)
- Multi-tenant data isolation and security
- SQL injection prevention via parameter escaping
"""

import logging
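The new bullet refers to the escape_identifier / escape_sql_string / escape_logstore_query_value helpers imported in the following hunk. Their implementation lives in extensions.logstore.sql_escape and is not part of this diff; as a rough illustration of the idea (an assumption about the approach, not the actual code), escaping a string literal for these hand-built SQL filters amounts to doubling embedded single quotes:

def escape_sql_string_sketch(value: str) -> str:
    """Illustrative only: double single quotes so a value can sit inside '...' safely."""
    return value.replace("'", "''")


# "o'brien" -> "o''brien", so a filter like f"AND status='{escape_sql_string_sketch(status)}'"
# can no longer be broken out of by a quote character in the input.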
@@ -22,6 +23,8 @@ from typing import Any, cast

from sqlalchemy.orm import sessionmaker

from extensions.logstore.aliyun_logstore import AliyunLogStore
from extensions.logstore.repositories import safe_float, safe_int
from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import WorkflowRun

@@ -63,10 +66,9 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun:
model.created_by_role = data.get("created_by_role") or ""
model.created_by = data.get("created_by") or ""

# Numeric fields with defaults
model.total_tokens = int(data.get("total_tokens", 0))
model.total_steps = int(data.get("total_steps", 0))
model.exceptions_count = int(data.get("exceptions_count", 0))
model.total_tokens = safe_int(data.get("total_tokens", 0))
model.total_steps = safe_int(data.get("total_steps", 0))
model.exceptions_count = safe_int(data.get("exceptions_count", 0))

# Optional fields
model.graph = data.get("graph")

@@ -101,7 +103,8 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun:
if model.finished_at and model.created_at:
model.elapsed_time = (model.finished_at - model.created_at).total_seconds()
else:
model.elapsed_time = float(data.get("elapsed_time", 0))
# Use safe conversion to handle 'null' strings and None values
model.elapsed_time = safe_float(data.get("elapsed_time", 0))

return model

@@ -165,16 +168,26 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
status,
)
# Convert triggered_from to list if needed
if isinstance(triggered_from, WorkflowRunTriggeredFrom):
if isinstance(triggered_from, (WorkflowRunTriggeredFrom, str)):
triggered_from_list = [triggered_from]
else:
triggered_from_list = list(triggered_from)

# Build triggered_from filter
triggered_from_filter = " OR ".join([f"triggered_from='{tf.value}'" for tf in triggered_from_list])
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)

# Build status filter
status_filter = f"AND status='{status}'" if status else ""
# Build triggered_from filter with escaped values
# Support both enum and string values for triggered_from
triggered_from_filter = " OR ".join(
[
f"triggered_from='{escape_sql_string(tf.value if isinstance(tf, WorkflowRunTriggeredFrom) else tf)}'"
for tf in triggered_from_list
]
)

# Build status filter with escaped value
status_filter = f"AND status='{escape_sql_string(status)}'" if status else ""

# Build last_id filter for pagination
# Note: This is simplified. In production, you'd need to track created_at from last record

@@ -188,8 +201,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
SELECT * FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND ({triggered_from_filter})
{status_filter}
{last_id_filter}
@@ -232,6 +245,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
logger.debug("get_workflow_run_by_id: tenant_id=%s, app_id=%s, run_id=%s", tenant_id, app_id, run_id)

try:
# Escape parameters to prevent SQL injection
escaped_run_id = escape_identifier(run_id)
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)

# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record)

@@ -240,7 +258,10 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_execution_logstore}"
WHERE id = '{run_id}' AND tenant_id = '{tenant_id}' AND app_id = '{app_id}' AND __time__ > 0
WHERE id = '{escaped_run_id}'
AND tenant_id = '{escaped_tenant_id}'
AND app_id = '{escaped_app_id}'
AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
"""

@@ -250,7 +271,12 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
)
else:
# Use SDK with LogStore query syntax
query = f"id: {run_id} and tenant_id: {tenant_id} and app_id: {app_id}"
# Note: Values must be quoted in LogStore query syntax to prevent injection
query = (
f"id:{escape_logstore_query_value(run_id)} "
f"and tenant_id:{escape_logstore_query_value(tenant_id)} "
f"and app_id:{escape_logstore_query_value(app_id)}"
)
from_time = 0
to_time = int(time.time())  # now

@@ -323,6 +349,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
logger.debug("get_workflow_run_by_id_without_tenant: run_id=%s", run_id)

try:
# Escape parameter to prevent SQL injection
escaped_run_id = escape_identifier(run_id)

# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record)

@@ -331,7 +360,7 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_execution_logstore}"
WHERE id = '{run_id}' AND __time__ > 0
WHERE id = '{escaped_run_id}' AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
"""

@@ -341,7 +370,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
)
else:
# Use SDK with LogStore query syntax
query = f"id: {run_id}"
# Note: Values must be quoted in LogStore query syntax
query = f"id:{escape_logstore_query_value(run_id)}"
from_time = 0
to_time = int(time.time())  # now

@@ -410,6 +440,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
triggered_from,
status,
)
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_triggered_from = escape_sql_string(triggered_from)

# Build time range filter
time_filter = ""
if time_range:
|
||||
@ -418,6 +453,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
|
||||
# If status is provided, simple count
|
||||
if status:
|
||||
escaped_status = escape_sql_string(status)
|
||||
|
||||
if status == "running":
|
||||
# Running status requires window function
|
||||
sql = f"""
|
||||
@ -425,9 +462,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
FROM (
|
||||
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
|
||||
FROM {AliyunLogStore.workflow_execution_logstore}
|
||||
WHERE tenant_id='{tenant_id}'
|
||||
AND app_id='{app_id}'
|
||||
AND triggered_from='{triggered_from}'
|
||||
WHERE tenant_id='{escaped_tenant_id}'
|
||||
AND app_id='{escaped_app_id}'
|
||||
AND triggered_from='{escaped_triggered_from}'
|
||||
AND status='running'
|
||||
{time_filter}
|
||||
) t
|
||||
@ -438,10 +475,10 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
sql = f"""
|
||||
SELECT COUNT(DISTINCT id) as count
|
||||
FROM {AliyunLogStore.workflow_execution_logstore}
|
||||
WHERE tenant_id='{tenant_id}'
|
||||
AND app_id='{app_id}'
|
||||
AND triggered_from='{triggered_from}'
|
||||
AND status='{status}'
|
||||
WHERE tenant_id='{escaped_tenant_id}'
|
||||
AND app_id='{escaped_app_id}'
|
||||
AND triggered_from='{escaped_triggered_from}'
|
||||
AND status='{escaped_status}'
|
||||
AND finished_at IS NOT NULL
|
||||
{time_filter}
|
||||
"""
|
||||
@ -467,13 +504,14 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
# No status filter - get counts grouped by status
|
||||
# Use optimized query for finished runs, separate query for running
|
||||
try:
|
||||
# Escape parameters (already escaped above, reuse variables)
|
||||
# Count finished runs grouped by status
|
||||
finished_sql = f"""
|
||||
SELECT status, COUNT(DISTINCT id) as count
|
||||
FROM {AliyunLogStore.workflow_execution_logstore}
|
||||
WHERE tenant_id='{tenant_id}'
|
||||
AND app_id='{app_id}'
|
||||
AND triggered_from='{triggered_from}'
|
||||
WHERE tenant_id='{escaped_tenant_id}'
|
||||
AND app_id='{escaped_app_id}'
|
||||
AND triggered_from='{escaped_triggered_from}'
|
||||
AND finished_at IS NOT NULL
|
||||
{time_filter}
|
||||
GROUP BY status
|
||||
@ -485,9 +523,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
FROM (
|
||||
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
|
||||
FROM {AliyunLogStore.workflow_execution_logstore}
|
||||
WHERE tenant_id='{tenant_id}'
|
||||
AND app_id='{app_id}'
|
||||
AND triggered_from='{triggered_from}'
|
||||
WHERE tenant_id='{escaped_tenant_id}'
|
||||
AND app_id='{escaped_app_id}'
|
||||
AND triggered_from='{escaped_triggered_from}'
|
||||
AND status='running'
|
||||
{time_filter}
|
||||
) t
|
||||
@ -546,7 +584,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
logger.debug(
|
||||
"get_daily_runs_statistics: tenant_id=%s, app_id=%s, triggered_from=%s", tenant_id, app_id, triggered_from
|
||||
)
|
||||
# Build time range filter
|
||||
|
||||
# Escape parameters to prevent SQL injection
|
||||
escaped_tenant_id = escape_identifier(tenant_id)
|
||||
escaped_app_id = escape_identifier(app_id)
|
||||
escaped_triggered_from = escape_sql_string(triggered_from)
|
||||
|
||||
# Build time range filter (datetime.isoformat() is safe)
|
||||
time_filter = ""
|
||||
if start_date:
|
||||
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
|
||||
@ -557,9 +601,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
sql = f"""
|
||||
SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT id) as runs
|
||||
FROM {AliyunLogStore.workflow_execution_logstore}
|
||||
WHERE tenant_id='{tenant_id}'
|
||||
AND app_id='{app_id}'
|
||||
AND triggered_from='{triggered_from}'
|
||||
WHERE tenant_id='{escaped_tenant_id}'
|
||||
AND app_id='{escaped_app_id}'
|
||||
AND triggered_from='{escaped_triggered_from}'
|
||||
AND finished_at IS NOT NULL
|
||||
{time_filter}
|
||||
GROUP BY date
|
||||
@ -601,7 +645,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
app_id,
|
||||
triggered_from,
|
||||
)
|
||||
# Build time range filter
|
||||
|
||||
# Escape parameters to prevent SQL injection
|
||||
escaped_tenant_id = escape_identifier(tenant_id)
|
||||
escaped_app_id = escape_identifier(app_id)
|
||||
escaped_triggered_from = escape_sql_string(triggered_from)
|
||||
|
||||
# Build time range filter (datetime.isoformat() is safe)
|
||||
time_filter = ""
|
||||
if start_date:
|
||||
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
|
||||
@ -611,9 +661,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
sql = f"""
|
||||
SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT created_by) as terminal_count
|
||||
FROM {AliyunLogStore.workflow_execution_logstore}
|
||||
WHERE tenant_id='{tenant_id}'
|
||||
AND app_id='{app_id}'
|
||||
AND triggered_from='{triggered_from}'
|
||||
WHERE tenant_id='{escaped_tenant_id}'
|
||||
AND app_id='{escaped_app_id}'
|
||||
AND triggered_from='{escaped_triggered_from}'
|
||||
AND finished_at IS NOT NULL
|
||||
{time_filter}
|
||||
GROUP BY date
|
||||
@ -655,7 +705,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
app_id,
|
||||
triggered_from,
|
||||
)
|
||||
# Build time range filter
|
||||
|
||||
# Escape parameters to prevent SQL injection
|
||||
escaped_tenant_id = escape_identifier(tenant_id)
|
||||
escaped_app_id = escape_identifier(app_id)
|
||||
escaped_triggered_from = escape_sql_string(triggered_from)
|
||||
|
||||
# Build time range filter (datetime.isoformat() is safe)
|
||||
time_filter = ""
|
||||
if start_date:
|
||||
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
|
||||
@ -665,9 +721,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
sql = f"""
|
||||
SELECT DATE(from_unixtime(__time__)) as date, SUM(total_tokens) as token_count
|
||||
FROM {AliyunLogStore.workflow_execution_logstore}
|
||||
WHERE tenant_id='{tenant_id}'
|
||||
AND app_id='{app_id}'
|
||||
AND triggered_from='{triggered_from}'
|
||||
WHERE tenant_id='{escaped_tenant_id}'
|
||||
AND app_id='{escaped_app_id}'
|
||||
AND triggered_from='{escaped_triggered_from}'
|
||||
AND finished_at IS NOT NULL
|
||||
{time_filter}
|
||||
GROUP BY date
|
||||
@ -709,7 +765,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
app_id,
|
||||
triggered_from,
|
||||
)
|
||||
# Build time range filter
|
||||
|
||||
# Escape parameters to prevent SQL injection
|
||||
escaped_tenant_id = escape_identifier(tenant_id)
|
||||
escaped_app_id = escape_identifier(app_id)
|
||||
escaped_triggered_from = escape_sql_string(triggered_from)
|
||||
|
||||
# Build time range filter (datetime.isoformat() is safe)
|
||||
time_filter = ""
|
||||
if start_date:
|
||||
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
|
||||
@ -726,9 +788,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
created_by,
|
||||
COUNT(DISTINCT id) AS interactions
|
||||
FROM {AliyunLogStore.workflow_execution_logstore}
|
||||
WHERE tenant_id='{tenant_id}'
|
||||
AND app_id='{app_id}'
|
||||
AND triggered_from='{triggered_from}'
|
||||
WHERE tenant_id='{escaped_tenant_id}'
|
||||
AND app_id='{escaped_app_id}'
|
||||
AND triggered_from='{escaped_triggered_from}'
|
||||
AND finished_at IS NOT NULL
|
||||
{time_filter}
|
||||
GROUP BY date, created_by
|
||||
|
||||
@ -10,6 +10,7 @@ from sqlalchemy.orm import sessionmaker
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.entities import WorkflowExecution
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.logstore.aliyun_logstore import AliyunLogStore
from libs.helper import extract_tenant_id
from models import (
@ -67,7 +68,12 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):

        # Control flag for dual-write (write to both LogStore and SQL database)
        # Set to True to enable dual-write for safe migration, False to use LogStore only
        self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true"
        self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "false").lower() == "true"

        # Control flag for whether to write the `graph` field to LogStore.
        # If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field;
        # otherwise write an empty {} instead. Defaults to writing the `graph` field.
        self._enable_put_graph_field = os.environ.get("LOGSTORE_ENABLE_PUT_GRAPH_FIELD", "true").lower() == "true"

    def _to_logstore_model(self, domain_model: WorkflowExecution) -> list[tuple[str, str]]:
        """
@ -96,6 +102,9 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
        # Generate log_version as nanosecond timestamp for record versioning
        log_version = str(time.time_ns())

        # Use WorkflowRuntimeTypeConverter to handle complex types (Segment, File, etc.)
        json_converter = WorkflowRuntimeTypeConverter()

        logstore_model = [
            ("id", domain_model.id_),
            ("log_version", log_version),  # Add log_version field for append-only writes
@ -108,9 +117,24 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
            ),
            ("type", domain_model.workflow_type.value),
            ("version", domain_model.workflow_version),
            ("graph", json.dumps(domain_model.graph, ensure_ascii=False) if domain_model.graph else "{}"),
            ("inputs", json.dumps(domain_model.inputs, ensure_ascii=False) if domain_model.inputs else "{}"),
            ("outputs", json.dumps(domain_model.outputs, ensure_ascii=False) if domain_model.outputs else "{}"),
            (
                "graph",
                json.dumps(json_converter.to_json_encodable(domain_model.graph), ensure_ascii=False)
                if domain_model.graph and self._enable_put_graph_field
                else "{}",
            ),
            (
                "inputs",
                json.dumps(json_converter.to_json_encodable(domain_model.inputs), ensure_ascii=False)
                if domain_model.inputs
                else "{}",
            ),
            (
                "outputs",
                json.dumps(json_converter.to_json_encodable(domain_model.outputs), ensure_ascii=False)
                if domain_model.outputs
                else "{}",
            ),
            ("status", domain_model.status.value),
            ("error_message", domain_model.error_message or ""),
            ("total_tokens", str(domain_model.total_tokens)),

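Illustrative sketch, not part of the patch: the repository above never updates a record in place. Every save appends a full record tagged with a nanosecond `log_version`, and readers pick the newest row per id with the ROW_NUMBER() window query seen throughout this diff. The snippet below only illustrates that pairing; the logstore name and field list are placeholders.

    import time

    def make_versioned_record(fields: list[tuple[str, str]]) -> list[tuple[str, str]]:
        # Each save appends a complete record; a nanosecond timestamp guarantees
        # that a later write for the same id always carries a higher log_version.
        return [("log_version", str(time.time_ns())), *fields]

    # On read, keep only the newest version of each id (hypothetical logstore name):
    LATEST_VERSION_SQL = """
    SELECT * FROM (
        SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
        FROM workflow_execution_logstore
    ) t
    WHERE rn = 1
    """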
@ -24,6 +24,8 @@ from core.workflow.enums import NodeType
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.logstore.aliyun_logstore import AliyunLogStore
from extensions.logstore.repositories import safe_float, safe_int
from extensions.logstore.sql_escape import escape_identifier
from libs.helper import extract_tenant_id
from models import (
    Account,
@ -73,7 +75,7 @@ def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecut
        node_execution_id=data.get("node_execution_id"),
        workflow_id=data.get("workflow_id", ""),
        workflow_execution_id=data.get("workflow_run_id"),
        index=int(data.get("index", 0)),
        index=safe_int(data.get("index", 0)),
        predecessor_node_id=data.get("predecessor_node_id"),
        node_id=data.get("node_id", ""),
        node_type=NodeType(data.get("node_type", "start")),
@ -83,7 +85,7 @@ def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecut
        outputs=outputs,
        status=status,
        error=data.get("error"),
        elapsed_time=float(data.get("elapsed_time", 0.0)),
        elapsed_time=safe_float(data.get("elapsed_time", 0.0)),
        metadata=domain_metadata,
        created_at=created_at,
        finished_at=finished_at,
@ -147,7 +149,7 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):

        # Control flag for dual-write (write to both LogStore and SQL database)
        # Set to True to enable dual-write for safe migration, False to use LogStore only
        self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true"
        self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "false").lower() == "true"

    def _to_logstore_model(self, domain_model: WorkflowNodeExecution) -> Sequence[tuple[str, str]]:
        logger.debug(
@ -274,16 +276,34 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
        Save or update the inputs, process_data, or outputs associated with a specific
        node_execution record.

        For LogStore implementation, this is similar to save() since we always write
        complete records. We append a new record with updated data fields.
        For LogStore implementation, this is a no-op for the LogStore write because save()
        already writes all fields including inputs, process_data, and outputs. The caller
        typically calls save() first to persist status/metadata, then calls save_execution_data()
        to persist data fields. Since LogStore writes complete records atomically, we don't
        need a separate write here to avoid duplicate records.

        However, if dual-write is enabled, we still need to call the SQL repository's
        save_execution_data() method to properly update the SQL database.

        Args:
            execution: The NodeExecution instance with data to save
        """
        logger.debug("save_execution_data: id=%s, node_execution_id=%s", execution.id, execution.node_execution_id)
        # In LogStore, we simply write a new complete record with the data
        # The log_version timestamp will ensure this is treated as the latest version
        self.save(execution)
        logger.debug(
            "save_execution_data: no-op for LogStore (data already saved by save()): id=%s, node_execution_id=%s",
            execution.id,
            execution.node_execution_id,
        )
        # No-op for LogStore: save() already writes all fields including inputs, process_data, and outputs
        # Calling save() again would create a duplicate record in the append-only LogStore

        # Dual-write to SQL database if enabled (for safe migration)
        if self._enable_dual_write:
            try:
                self.sql_repository.save_execution_data(execution)
                logger.debug("Dual-write: saved node execution data to SQL database: id=%s", execution.id)
            except Exception:
                logger.exception("Failed to dual-write node execution data to SQL database: id=%s", execution.id)
                # Don't raise - LogStore write succeeded, SQL is just a backup

    def get_by_workflow_run(
        self,
@ -292,8 +312,8 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
    ) -> Sequence[WorkflowNodeExecution]:
        """
        Retrieve all NodeExecution instances for a specific workflow run.
        Uses LogStore SQL query with finished_at IS NOT NULL filter for deduplication.
        This ensures we only get the final version of each node execution.
        Uses LogStore SQL query with window function to get the latest version of each node execution.
        This ensures we only get the most recent version of each node execution record.
        Args:
            workflow_run_id: The workflow run ID
            order_config: Optional configuration for ordering results
@ -304,16 +324,19 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
            A list of NodeExecution instances

        Note:
            This method filters by finished_at IS NOT NULL to avoid duplicates from
            version updates. For complete history including intermediate states,
            a different query strategy would be needed.
            This method uses ROW_NUMBER() window function partitioned by node_execution_id
            to get the latest version (highest log_version) of each node execution.
        """
        logger.debug("get_by_workflow_run: workflow_run_id=%s, order_config=%s", workflow_run_id, order_config)
        # Build SQL query with deduplication using finished_at IS NOT NULL
        # This optimization avoids window functions for common case where we only
        # want the final state of each node execution
        # Build SQL query with deduplication using window function
        # ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC)
        # ensures we get the latest version of each node execution

        # Build ORDER BY clause
        # Escape parameters to prevent SQL injection
        escaped_workflow_run_id = escape_identifier(workflow_run_id)
        escaped_tenant_id = escape_identifier(self._tenant_id)

        # Build ORDER BY clause for outer query
        order_clause = ""
        if order_config and order_config.order_by:
            order_fields = []
@ -327,16 +350,23 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
        if order_fields:
            order_clause = "ORDER BY " + ", ".join(order_fields)

        sql = f"""
        SELECT *
        FROM {AliyunLogStore.workflow_node_execution_logstore}
        WHERE workflow_run_id='{workflow_run_id}'
        AND tenant_id='{self._tenant_id}'
        AND finished_at IS NOT NULL
        """

        # Build app_id filter for subquery
        app_id_filter = ""
        if self._app_id:
            sql += f" AND app_id='{self._app_id}'"
            escaped_app_id = escape_identifier(self._app_id)
            app_id_filter = f" AND app_id='{escaped_app_id}'"

        # Use window function to get latest version of each node execution
        sql = f"""
        SELECT * FROM (
            SELECT *, ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC) AS rn
            FROM {AliyunLogStore.workflow_node_execution_logstore}
            WHERE workflow_run_id='{escaped_workflow_run_id}'
            AND tenant_id='{escaped_tenant_id}'
            {app_id_filter}
        ) t
        WHERE rn = 1
        """

        if order_clause:
            sql += f" {order_clause}"

134 api/extensions/logstore/sql_escape.py Normal file
@ -0,0 +1,134 @@
"""
SQL Escape Utility for LogStore Queries

This module provides escaping utilities to prevent injection attacks in LogStore queries.

LogStore supports two query modes:
1. PG Protocol Mode: Uses SQL syntax with single quotes for strings
2. SDK Mode: Uses LogStore query syntax (key: value) with double quotes

Key Security Concerns:
- Prevent tenant A from accessing tenant B's data via injection
- SLS queries are read-only, so we focus on data access control
- Different escaping strategies for SQL vs LogStore query syntax
"""


def escape_sql_string(value: str) -> str:
    """
    Escape a string value for safe use in SQL queries.

    This function escapes single quotes by doubling them, which is the standard
    SQL escaping method. This prevents SQL injection by ensuring that user input
    cannot break out of string literals.

    Args:
        value: The string value to escape

    Returns:
        Escaped string safe for use in SQL queries

    Examples:
        >>> escape_sql_string("normal_value")
        "normal_value"
        >>> escape_sql_string("value' OR '1'='1")
        "value'' OR ''1''=''1"
        >>> escape_sql_string("tenant's_id")
        "tenant''s_id"

    Security:
        - Prevents breaking out of string literals
        - Stops injection attacks like: ' OR '1'='1
        - Protects against cross-tenant data access
    """
    if not value:
        return value

    # Escape single quotes by doubling them (standard SQL escaping)
    # This prevents breaking out of string literals in SQL queries
    return value.replace("'", "''")


def escape_identifier(value: str) -> str:
    """
    Escape an identifier (tenant_id, app_id, run_id, etc.) for safe SQL use.

    This function is for PG protocol mode (SQL syntax).
    For SDK mode, use escape_logstore_query_value() instead.

    Args:
        value: The identifier value to escape

    Returns:
        Escaped identifier safe for use in SQL queries

    Examples:
        >>> escape_identifier("550e8400-e29b-41d4-a716-446655440000")
        "550e8400-e29b-41d4-a716-446655440000"
        >>> escape_identifier("tenant_id' OR '1'='1")
        "tenant_id'' OR ''1''=''1"

    Security:
        - Prevents SQL injection via identifiers
        - Stops cross-tenant access attempts
        - Works for UUIDs, alphanumeric IDs, and similar identifiers
    """
    # For identifiers, use the same escaping as strings
    # This is simple and effective for preventing injection
    return escape_sql_string(value)


def escape_logstore_query_value(value: str) -> str:
    """
    Escape value for LogStore query syntax (SDK mode).

    LogStore query syntax rules:
    1. Keywords (and/or/not) are case-insensitive
    2. Single quotes are ordinary characters (no special meaning)
    3. Double quotes wrap values: key:"value"
    4. Backslash is the escape character:
       - \" for double quote inside value
       - \\ for backslash itself
    5. Parentheses can change query structure

    To prevent injection:
    - Wrap value in double quotes to treat special chars as literals
    - Escape backslashes and double quotes using backslash

    Args:
        value: The value to escape for LogStore query syntax

    Returns:
        Quoted and escaped value safe for LogStore query syntax (includes the quotes)

    Examples:
        >>> escape_logstore_query_value("normal_value")
        '"normal_value"'
        >>> escape_logstore_query_value("value or field:evil")
        '"value or field:evil"'  # 'or' and ':' are now literals
        >>> escape_logstore_query_value('value"test')
        '"value\\"test"'  # Internal double quote escaped
        >>> escape_logstore_query_value('value\\test')
        '"value\\\\test"'  # Backslash escaped

    Security:
        - Prevents injection via and/or/not keywords
        - Prevents injection via colons (:)
        - Prevents injection via parentheses
        - Protects against cross-tenant data access

    Note:
        Escape order is critical: backslash first, then double quotes.
        Otherwise, we'd double-escape the escape character itself.
    """
    if not value:
        return '""'

    # IMPORTANT: Escape backslashes FIRST, then double quotes
    # This prevents double-escaping (e.g., " -> \" -> \\" incorrectly)
    escaped = value.replace("\\", "\\\\")  # \ -> \\
    escaped = escaped.replace('"', '\\"')  # " -> \"

    # Wrap in double quotes to treat as literal string
    # This prevents and/or/not/:/() from being interpreted as operators
    return f'"{escaped}"'

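Illustrative usage sketch, not part of the patch: the three helpers above cover the two query modes described in the module docstring. The hostile values below are deliberately malicious examples; the table and field names are placeholders.

    from extensions.logstore.sql_escape import (
        escape_identifier,
        escape_logstore_query_value,
        escape_sql_string,
    )

    hostile = "abc' OR '1'='1"

    # PG protocol mode: the escaped value is interpolated inside single quotes.
    sql = f"SELECT * FROM logstore WHERE tenant_id='{escape_identifier(hostile)}'"
    # ... WHERE tenant_id='abc'' OR ''1''=''1'  -> stays inside the string literal

    # SDK mode: the helper returns a double-quoted literal, so and/or/: lose operator meaning.
    query = f"tenant_id:{escape_logstore_query_value('abc or app_id:evil')}"
    # -> tenant_id:"abc or app_id:evil"

    status_filter = f"AND status='{escape_sql_string('running')}'"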
@ -19,26 +19,43 @@ logger = logging.getLogger(__name__)


class ExceptionLoggingHandler(logging.Handler):
    """
    Handler that records exceptions to the current OpenTelemetry span.

    Unlike creating a new span, this records exceptions on the existing span
    to maintain trace context consistency throughout the request lifecycle.
    """

    def emit(self, record: logging.LogRecord):
        with contextlib.suppress(Exception):
            if record.exc_info:
                tracer = get_tracer_provider().get_tracer("dify.exception.logging")
                with tracer.start_as_current_span(
                    "log.exception",
                    attributes={
                        "log.level": record.levelname,
                        "log.message": record.getMessage(),
                        "log.logger": record.name,
                        "log.file.path": record.pathname,
                        "log.file.line": record.lineno,
                    },
                ) as span:
                    span.set_status(StatusCode.ERROR)
                    if record.exc_info[1]:
                        span.record_exception(record.exc_info[1])
                        span.set_attribute("exception.message", str(record.exc_info[1]))
                    if record.exc_info[0]:
                        span.set_attribute("exception.type", record.exc_info[0].__name__)
            if not record.exc_info:
                return

            from opentelemetry.trace import get_current_span

            span = get_current_span()
            if not span or not span.is_recording():
                return

            # Record exception on the current span instead of creating a new one
            span.set_status(StatusCode.ERROR, record.getMessage())

            # Add log context as span events/attributes
            span.add_event(
                "log.exception",
                attributes={
                    "log.level": record.levelname,
                    "log.message": record.getMessage(),
                    "log.logger": record.name,
                    "log.file.path": record.pathname,
                    "log.file.line": record.lineno,
                },
            )

            if record.exc_info[1]:
                span.record_exception(record.exc_info[1])
            if record.exc_info[0]:
                span.set_attribute("exception.type", record.exc_info[0].__name__)


def instrument_exception_logging() -> None:

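Illustrative sketch, not part of the patch: the body of instrument_exception_logging() is cut off by the hunk above, so the wiring shown here is an assumption, meant only to show how a handler like this typically makes logger.exception() calls land on the active span. It assumes an OpenTelemetry TracerProvider is already configured.

    import logging

    from opentelemetry import trace

    # Hypothetical registration; the real wiring lives in instrument_exception_logging().
    logging.getLogger().addHandler(ExceptionLoggingHandler())

    tracer = trace.get_tracer("dify.example")
    with tracer.start_as_current_span("workflow.node"):
        try:
            1 / 0
        except ZeroDivisionError:
            # emit() sees exc_info and records the exception on the current span.
            logging.getLogger(__name__).exception("node failed")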
20 api/extensions/otel/parser/__init__.py Normal file
@ -0,0 +1,20 @@
"""
OpenTelemetry node parsers for workflow nodes.

This module provides parsers that extract node-specific metadata and set
OpenTelemetry span attributes according to semantic conventions.
"""

from extensions.otel.parser.base import DefaultNodeOTelParser, NodeOTelParser, safe_json_dumps
from extensions.otel.parser.llm import LLMNodeOTelParser
from extensions.otel.parser.retrieval import RetrievalNodeOTelParser
from extensions.otel.parser.tool import ToolNodeOTelParser

__all__ = [
    "DefaultNodeOTelParser",
    "LLMNodeOTelParser",
    "NodeOTelParser",
    "RetrievalNodeOTelParser",
    "ToolNodeOTelParser",
    "safe_json_dumps",
]

117 api/extensions/otel/parser/base.py Normal file
@ -0,0 +1,117 @@
"""
Base parser interface and utilities for OpenTelemetry node parsers.
"""

import json
from typing import Any, Protocol

from opentelemetry.trace import Span
from opentelemetry.trace.status import Status, StatusCode
from pydantic import BaseModel

from core.file.models import File
from core.variables import Segment
from core.workflow.enums import NodeType
from core.workflow.graph_events import GraphNodeEventBase
from core.workflow.nodes.base.node import Node
from extensions.otel.semconv.gen_ai import ChainAttributes, GenAIAttributes


def safe_json_dumps(obj: Any, ensure_ascii: bool = False) -> str:
    """
    Safely serialize objects to JSON, handling non-serializable types.

    Handles:
    - Segment types (ArrayFileSegment, FileSegment, etc.) - converts to their value
    - File objects - converts to dict using to_dict()
    - BaseModel objects - converts using model_dump()
    - Other types - falls back to str() representation

    Args:
        obj: Object to serialize
        ensure_ascii: Whether to ensure ASCII encoding

    Returns:
        JSON string representation of the object
    """

    def _convert_value(value: Any) -> Any:
        """Recursively convert non-serializable values."""
        if value is None:
            return None
        if isinstance(value, (bool, int, float, str)):
            return value
        if isinstance(value, Segment):
            # Convert Segment to its underlying value
            return _convert_value(value.value)
        if isinstance(value, File):
            # Convert File to dict
            return value.to_dict()
        if isinstance(value, BaseModel):
            # Convert Pydantic model to dict
            return _convert_value(value.model_dump(mode="json"))
        if isinstance(value, dict):
            return {k: _convert_value(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            return [_convert_value(item) for item in value]
        # Fallback to string representation for unknown types
        return str(value)

    try:
        converted = _convert_value(obj)
        return json.dumps(converted, ensure_ascii=ensure_ascii)
    except (TypeError, ValueError) as e:
        # If conversion still fails, return error message as string
        return json.dumps(
            {"error": f"Failed to serialize: {type(obj).__name__}", "message": str(e)}, ensure_ascii=ensure_ascii
        )


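Illustrative usage sketch for safe_json_dumps, not part of the patch: the Price model below is a made-up pydantic example standing in for Dify's Segment/File types; it simply shows how mixed, non-JSON-native structures collapse into plain JSON.

    from pydantic import BaseModel

    class Price(BaseModel):
        amount: float
        currency: str

    payload = {
        "items": [Price(amount=1.5, currency="USD"), {"note": "plain dict"}],
        "meta": (1, 2, 3),  # tuples become JSON arrays
    }

    print(safe_json_dumps(payload))
    # {"items": [{"amount": 1.5, "currency": "USD"}, {"note": "plain dict"}], "meta": [1, 2, 3]}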
class NodeOTelParser(Protocol):
    """Parser interface for node-specific OpenTelemetry enrichment."""

    def parse(
        self, *, node: Node, span: "Span", error: Exception | None, result_event: GraphNodeEventBase | None = None
    ) -> None: ...


class DefaultNodeOTelParser:
    """Fallback parser used when no node-specific parser is registered."""

    def parse(
        self, *, node: Node, span: "Span", error: Exception | None, result_event: GraphNodeEventBase | None = None
    ) -> None:
        span.set_attribute("node.id", node.id)
        if node.execution_id:
            span.set_attribute("node.execution_id", node.execution_id)
        if hasattr(node, "node_type") and node.node_type:
            span.set_attribute("node.type", node.node_type.value)

        span.set_attribute(GenAIAttributes.FRAMEWORK, "dify")

        node_type = getattr(node, "node_type", None)
        if isinstance(node_type, NodeType):
            if node_type == NodeType.LLM:
                span.set_attribute(GenAIAttributes.SPAN_KIND, "LLM")
            elif node_type == NodeType.KNOWLEDGE_RETRIEVAL:
                span.set_attribute(GenAIAttributes.SPAN_KIND, "RETRIEVER")
            elif node_type == NodeType.TOOL:
                span.set_attribute(GenAIAttributes.SPAN_KIND, "TOOL")
            else:
                span.set_attribute(GenAIAttributes.SPAN_KIND, "TASK")
        else:
            span.set_attribute(GenAIAttributes.SPAN_KIND, "TASK")

        # Extract inputs and outputs from result_event
        if result_event and result_event.node_run_result:
            node_run_result = result_event.node_run_result
            if node_run_result.inputs:
                span.set_attribute(ChainAttributes.INPUT_VALUE, safe_json_dumps(node_run_result.inputs))
            if node_run_result.outputs:
                span.set_attribute(ChainAttributes.OUTPUT_VALUE, safe_json_dumps(node_run_result.outputs))

        if error:
            span.record_exception(error)
            span.set_status(Status(StatusCode.ERROR, str(error)))
        else:
            span.set_status(Status(StatusCode.OK))

155 api/extensions/otel/parser/llm.py Normal file
@ -0,0 +1,155 @@
"""
Parser for LLM nodes that captures LLM-specific metadata.
"""

import logging
from collections.abc import Mapping
from typing import Any

from opentelemetry.trace import Span

from core.workflow.graph_events import GraphNodeEventBase
from core.workflow.nodes.base.node import Node
from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps
from extensions.otel.semconv.gen_ai import LLMAttributes

logger = logging.getLogger(__name__)


def _format_input_messages(process_data: Mapping[str, Any]) -> str:
    """
    Format input messages from process_data for LLM spans.

    Args:
        process_data: Process data containing prompts

    Returns:
        JSON string of formatted input messages
    """
    try:
        if not isinstance(process_data, dict):
            return safe_json_dumps([])

        prompts = process_data.get("prompts", [])
        if not prompts:
            return safe_json_dumps([])

        valid_roles = {"system", "user", "assistant", "tool"}
        input_messages = []
        for prompt in prompts:
            if not isinstance(prompt, dict):
                continue

            role = prompt.get("role", "")
            text = prompt.get("text", "")

            if not role or role not in valid_roles:
                continue

            if text:
                message = {"role": role, "parts": [{"type": "text", "content": text}]}
                input_messages.append(message)

        return safe_json_dumps(input_messages)
    except Exception as e:
        logger.warning("Failed to format input messages: %s", e, exc_info=True)
        return safe_json_dumps([])


def _format_output_messages(outputs: Mapping[str, Any]) -> str:
    """
    Format output messages from outputs for LLM spans.

    Args:
        outputs: Output data containing text and finish_reason

    Returns:
        JSON string of formatted output messages
    """
    try:
        if not isinstance(outputs, dict):
            return safe_json_dumps([])

        text = outputs.get("text", "")
        finish_reason = outputs.get("finish_reason", "")

        if not text:
            return safe_json_dumps([])

        valid_finish_reasons = {"stop", "length", "content_filter", "tool_call", "error"}
        if finish_reason not in valid_finish_reasons:
            finish_reason = "stop"

        output_message = {
            "role": "assistant",
            "parts": [{"type": "text", "content": text}],
            "finish_reason": finish_reason,
        }

        return safe_json_dumps([output_message])
    except Exception as e:
        logger.warning("Failed to format output messages: %s", e, exc_info=True)
        return safe_json_dumps([])


class LLMNodeOTelParser:
    """Parser for LLM nodes that captures LLM-specific metadata."""

    def __init__(self) -> None:
        self._delegate = DefaultNodeOTelParser()

    def parse(
        self, *, node: Node, span: "Span", error: Exception | None, result_event: GraphNodeEventBase | None = None
    ) -> None:
        self._delegate.parse(node=node, span=span, error=error, result_event=result_event)

        if not result_event or not result_event.node_run_result:
            return

        node_run_result = result_event.node_run_result
        process_data = node_run_result.process_data or {}
        outputs = node_run_result.outputs or {}

        # Extract usage data (from process_data or outputs)
        usage_data = process_data.get("usage") or outputs.get("usage") or {}

        # Model and provider information
        model_name = process_data.get("model_name") or ""
        model_provider = process_data.get("model_provider") or ""

        if model_name:
            span.set_attribute(LLMAttributes.REQUEST_MODEL, model_name)
        if model_provider:
            span.set_attribute(LLMAttributes.PROVIDER_NAME, model_provider)

        # Token usage
        if usage_data:
            prompt_tokens = usage_data.get("prompt_tokens", 0)
            completion_tokens = usage_data.get("completion_tokens", 0)
            total_tokens = usage_data.get("total_tokens", 0)

            span.set_attribute(LLMAttributes.USAGE_INPUT_TOKENS, prompt_tokens)
            span.set_attribute(LLMAttributes.USAGE_OUTPUT_TOKENS, completion_tokens)
            span.set_attribute(LLMAttributes.USAGE_TOTAL_TOKENS, total_tokens)

        # Prompts and completion
        prompts = process_data.get("prompts", [])
        if prompts:
            prompts_json = safe_json_dumps(prompts)
            span.set_attribute(LLMAttributes.PROMPT, prompts_json)

        text_output = str(outputs.get("text", ""))
        if text_output:
            span.set_attribute(LLMAttributes.COMPLETION, text_output)

        # Finish reason
        finish_reason = outputs.get("finish_reason") or ""
        if finish_reason:
            span.set_attribute(LLMAttributes.RESPONSE_FINISH_REASON, finish_reason)

        # Structured input/output messages
        gen_ai_input_message = _format_input_messages(process_data)
        gen_ai_output_message = _format_output_messages(outputs)

        span.set_attribute(LLMAttributes.INPUT_MESSAGE, gen_ai_input_message)
        span.set_attribute(LLMAttributes.OUTPUT_MESSAGE, gen_ai_output_message)

105 api/extensions/otel/parser/retrieval.py Normal file
@ -0,0 +1,105 @@
"""
Parser for knowledge retrieval nodes that captures retrieval-specific metadata.
"""

import logging
from collections.abc import Sequence
from typing import Any

from opentelemetry.trace import Span

from core.variables import Segment
from core.workflow.graph_events import GraphNodeEventBase
from core.workflow.nodes.base.node import Node
from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps
from extensions.otel.semconv.gen_ai import RetrieverAttributes

logger = logging.getLogger(__name__)


def _format_retrieval_documents(retrieval_documents: list[Any]) -> list:
    """
    Format retrieval documents for semantic conventions.

    Args:
        retrieval_documents: List of retrieval document dictionaries

    Returns:
        List of formatted semantic documents
    """
    try:
        if not isinstance(retrieval_documents, list):
            return []

        semantic_documents = []
        for doc in retrieval_documents:
            if not isinstance(doc, dict):
                continue

            metadata = doc.get("metadata", {})
            content = doc.get("content", "")
            title = doc.get("title", "")
            score = metadata.get("score", 0.0)
            document_id = metadata.get("document_id", "")

            semantic_metadata = {}
            if title:
                semantic_metadata["title"] = title
            if metadata.get("source"):
                semantic_metadata["source"] = metadata["source"]
            elif metadata.get("_source"):
                semantic_metadata["source"] = metadata["_source"]
            if metadata.get("doc_metadata"):
                doc_metadata = metadata["doc_metadata"]
                if isinstance(doc_metadata, dict):
                    semantic_metadata.update(doc_metadata)

            semantic_doc = {
                "document": {"content": content, "metadata": semantic_metadata, "score": score, "id": document_id}
            }
            semantic_documents.append(semantic_doc)

        return semantic_documents
    except Exception as e:
        logger.warning("Failed to format retrieval documents: %s", e, exc_info=True)
        return []


class RetrievalNodeOTelParser:
    """Parser for knowledge retrieval nodes that captures retrieval-specific metadata."""

    def __init__(self) -> None:
        self._delegate = DefaultNodeOTelParser()

    def parse(
        self, *, node: Node, span: "Span", error: Exception | None, result_event: GraphNodeEventBase | None = None
    ) -> None:
        self._delegate.parse(node=node, span=span, error=error, result_event=result_event)

        if not result_event or not result_event.node_run_result:
            return

        node_run_result = result_event.node_run_result
        inputs = node_run_result.inputs or {}
        outputs = node_run_result.outputs or {}

        # Extract query from inputs
        query = str(inputs.get("query", "")) if inputs else ""
        if query:
            span.set_attribute(RetrieverAttributes.QUERY, query)

        # Extract and format retrieval documents from outputs
        result_value = outputs.get("result") if outputs else None
        retrieval_documents: list[Any] = []
        if result_value:
            value_to_check = result_value
            if isinstance(result_value, Segment):
                value_to_check = result_value.value

            if isinstance(value_to_check, (list, Sequence)):
                retrieval_documents = list(value_to_check)

        if retrieval_documents:
            semantic_retrieval_documents = _format_retrieval_documents(retrieval_documents)
            semantic_retrieval_documents_json = safe_json_dumps(semantic_retrieval_documents)
            span.set_attribute(RetrieverAttributes.DOCUMENT, semantic_retrieval_documents_json)

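Illustrative input/output pair for _format_retrieval_documents, not part of the patch; the document values are invented and only show the shape the parser emits onto the span.

    docs = [
        {
            "content": "Dify is an LLM app platform.",
            "title": "intro.md",
            "metadata": {"score": 0.92, "document_id": "doc-1", "source": "knowledge-base"},
        }
    ]

    # _format_retrieval_documents(docs) would produce:
    # [{"document": {"content": "Dify is an LLM app platform.",
    #                "metadata": {"title": "intro.md", "source": "knowledge-base"},
    #                "score": 0.92, "id": "doc-1"}}]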
47 api/extensions/otel/parser/tool.py Normal file
@ -0,0 +1,47 @@
"""
Parser for tool nodes that captures tool-specific metadata.
"""

from opentelemetry.trace import Span

from core.workflow.enums import WorkflowNodeExecutionMetadataKey
from core.workflow.graph_events import GraphNodeEventBase
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.tool.entities import ToolNodeData
from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps
from extensions.otel.semconv.gen_ai import ToolAttributes


class ToolNodeOTelParser:
    """Parser for tool nodes that captures tool-specific metadata."""

    def __init__(self) -> None:
        self._delegate = DefaultNodeOTelParser()

    def parse(
        self, *, node: Node, span: "Span", error: Exception | None, result_event: GraphNodeEventBase | None = None
    ) -> None:
        self._delegate.parse(node=node, span=span, error=error, result_event=result_event)

        tool_data = getattr(node, "_node_data", None)
        if not isinstance(tool_data, ToolNodeData):
            return

        span.set_attribute(ToolAttributes.TOOL_NAME, node.title)
        span.set_attribute(ToolAttributes.TOOL_TYPE, tool_data.provider_type.value)

        # Extract tool info from metadata (consistent with aliyun_trace)
        tool_info = {}
        if result_event and result_event.node_run_result:
            node_run_result = result_event.node_run_result
            if node_run_result.metadata:
                tool_info = node_run_result.metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO, {})

        if tool_info:
            span.set_attribute(ToolAttributes.TOOL_DESCRIPTION, safe_json_dumps(tool_info))

        if result_event and result_event.node_run_result and result_event.node_run_result.inputs:
            span.set_attribute(ToolAttributes.TOOL_CALL_ARGUMENTS, safe_json_dumps(result_event.node_run_result.inputs))

        if result_event and result_event.node_run_result and result_event.node_run_result.outputs:
            span.set_attribute(ToolAttributes.TOOL_CALL_RESULT, safe_json_dumps(result_event.node_run_result.outputs))

@ -1,6 +1,13 @@
"""Semantic convention shortcuts for Dify-specific spans."""

from .dify import DifySpanAttributes
from .gen_ai import GenAIAttributes
from .gen_ai import ChainAttributes, GenAIAttributes, LLMAttributes, RetrieverAttributes, ToolAttributes

__all__ = ["DifySpanAttributes", "GenAIAttributes"]
__all__ = [
    "ChainAttributes",
    "DifySpanAttributes",
    "GenAIAttributes",
    "LLMAttributes",
    "RetrieverAttributes",
    "ToolAttributes",
]

@ -62,3 +62,37 @@ class ToolAttributes:

    TOOL_CALL_RESULT = "gen_ai.tool.call.result"
    """Tool invocation result."""


class LLMAttributes:
    """LLM operation attribute keys."""

    REQUEST_MODEL = "gen_ai.request.model"
    """Model identifier."""

    PROVIDER_NAME = "gen_ai.provider.name"
    """Provider name."""

    USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens"
    """Number of input tokens."""

    USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
    """Number of output tokens."""

    USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens"
    """Total number of tokens."""

    PROMPT = "gen_ai.prompt"
    """Prompt text."""

    COMPLETION = "gen_ai.completion"
    """Completion text."""

    RESPONSE_FINISH_REASON = "gen_ai.response.finish_reason"
    """Finish reason for the response."""

    INPUT_MESSAGE = "gen_ai.input.messages"
    """Input messages in structured format."""

    OUTPUT_MESSAGE = "gen_ai.output.messages"
    """Output messages in structured format."""

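Illustrative sketch, not part of the patch: how these attribute keys end up on a span, mirroring what LLMNodeOTelParser does above. The tracer name and values are made up, and an OpenTelemetry TracerProvider is assumed to be configured.

    from opentelemetry import trace

    tracer = trace.get_tracer("dify.example")
    with tracer.start_as_current_span("llm") as span:
        span.set_attribute(LLMAttributes.REQUEST_MODEL, "gpt-4o-mini")
        span.set_attribute(LLMAttributes.PROVIDER_NAME, "openai")
        span.set_attribute(LLMAttributes.USAGE_INPUT_TOKENS, 128)
        span.set_attribute(LLMAttributes.USAGE_OUTPUT_TOKENS, 42)
        span.set_attribute(LLMAttributes.RESPONSE_FINISH_REASON, "stop")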
@ -5,6 +5,8 @@ automatic cleanup, backup and restore.
Supports complete lifecycle management for knowledge base files.
"""

from __future__ import annotations

import json
import logging
import operator
@ -48,7 +50,7 @@ class FileMetadata:
        return data

    @classmethod
    def from_dict(cls, data: dict) -> "FileMetadata":
    def from_dict(cls, data: dict) -> FileMetadata:
        """Create instance from dictionary"""
        data = data.copy()
        data["created_at"] = datetime.fromisoformat(data["created_at"])

@ -13,12 +13,20 @@ class TencentCosStorage(BaseStorage):
        super().__init__()

        self.bucket_name = dify_config.TENCENT_COS_BUCKET_NAME
        config = CosConfig(
            Region=dify_config.TENCENT_COS_REGION,
            SecretId=dify_config.TENCENT_COS_SECRET_ID,
            SecretKey=dify_config.TENCENT_COS_SECRET_KEY,
            Scheme=dify_config.TENCENT_COS_SCHEME,
        )
        if dify_config.TENCENT_COS_CUSTOM_DOMAIN:
            config = CosConfig(
                Domain=dify_config.TENCENT_COS_CUSTOM_DOMAIN,
                SecretId=dify_config.TENCENT_COS_SECRET_ID,
                SecretKey=dify_config.TENCENT_COS_SECRET_KEY,
                Scheme=dify_config.TENCENT_COS_SCHEME,
            )
        else:
            config = CosConfig(
                Region=dify_config.TENCENT_COS_REGION,
                SecretId=dify_config.TENCENT_COS_SECRET_ID,
                SecretKey=dify_config.TENCENT_COS_SECRET_KEY,
                Scheme=dify_config.TENCENT_COS_SCHEME,
            )
        self.client = CosS3Client(config)

    def save(self, filename, data):
