diff --git a/api/app_factory.py b/api/app_factory.py index 066eb2ae2c..76838f9925 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -143,6 +143,7 @@ def initialize_extensions(app: DifyApp): ext_commands, ext_compress, ext_database, + ext_enterprise_telemetry, ext_fastopenapi, ext_forward_refs, ext_hosting_provider, @@ -193,6 +194,7 @@ def initialize_extensions(app: DifyApp): ext_commands, ext_fastopenapi, ext_otel, + ext_enterprise_telemetry, ext_request_logging, ext_session_factory, ] diff --git a/api/configs/app_config.py b/api/configs/app_config.py index d3b1cf9d5b..831f0a49e0 100644 --- a/api/configs/app_config.py +++ b/api/configs/app_config.py @@ -8,7 +8,7 @@ from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, Settings from libs.file_utils import search_file_upwards from .deploy import DeploymentConfig -from .enterprise import EnterpriseFeatureConfig +from .enterprise import EnterpriseFeatureConfig, EnterpriseTelemetryConfig from .extra import ExtraServiceConfig from .feature import FeatureConfig from .middleware import MiddlewareConfig @@ -73,6 +73,8 @@ class DifyConfig( # Enterprise feature configs # **Before using, please contact business@dify.ai by email to inquire about licensing matters.** EnterpriseFeatureConfig, + # Enterprise telemetry configs + EnterpriseTelemetryConfig, ): model_config = SettingsConfigDict( # read from dotenv format config file diff --git a/api/configs/enterprise/__init__.py b/api/configs/enterprise/__init__.py index f8447c6979..8a6a921a4e 100644 --- a/api/configs/enterprise/__init__.py +++ b/api/configs/enterprise/__init__.py @@ -22,3 +22,52 @@ class EnterpriseFeatureConfig(BaseSettings): ENTERPRISE_REQUEST_TIMEOUT: int = Field( ge=1, description="Maximum timeout in seconds for enterprise requests", default=5 ) + + +class EnterpriseTelemetryConfig(BaseSettings): + """ + Configuration for enterprise telemetry. 
+ """ + + ENTERPRISE_TELEMETRY_ENABLED: bool = Field( + description="Enable enterprise telemetry collection (also requires ENTERPRISE_ENABLED=true).", + default=False, + ) + + ENTERPRISE_OTLP_ENDPOINT: str = Field( + description="Enterprise OTEL collector endpoint.", + default="", + ) + + ENTERPRISE_OTLP_HEADERS: str = Field( + description="Auth headers for OTLP export (key=value,key2=value2).", + default="", + ) + + ENTERPRISE_OTLP_PROTOCOL: str = Field( + description="OTLP protocol: 'http' or 'grpc' (default: http).", + default="http", + ) + + ENTERPRISE_OTLP_API_KEY: str = Field( + description="Bearer token for enterprise OTLP export authentication.", + default="", + ) + + ENTERPRISE_INCLUDE_CONTENT: bool = Field( + description="Include input/output content in traces (privacy toggle).", + # Setting the default value to False to avoid accidentally log PII data in traces. + default=False, + ) + + ENTERPRISE_SERVICE_NAME: str = Field( + description="Service name for OTEL resource.", + default="dify", + ) + + ENTERPRISE_OTEL_SAMPLING_RATE: float = Field( + description="Sampling rate for enterprise traces (0.0 to 1.0, default 1.0 = 100%).", + default=1.0, + ge=0.0, + le=1.0, + ) diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index 50a2cdea63..45b2f635ba 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -9,8 +9,8 @@ from pydantic import BaseModel, ConfigDict, field_serializer, field_validator class BaseTraceInfo(BaseModel): message_id: str | None = None message_data: Any | None = None - inputs: Union[str, dict[str, Any], list] | None = None - outputs: Union[str, dict[str, Any], list] | None = None + inputs: Union[str, dict[str, Any], list[Any]] | None = None + outputs: Union[str, dict[str, Any], list[Any]] | None = None start_time: datetime | None = None end_time: datetime | None = None metadata: dict[str, Any] @@ -18,7 +18,7 @@ class BaseTraceInfo(BaseModel): 
@field_validator("inputs", "outputs") @classmethod - def ensure_type(cls, v): + def ensure_type(cls, v: str | dict[str, Any] | list[Any] | None) -> str | dict[str, Any] | list[Any] | None: if v is None: return None if isinstance(v, str | dict | list): @@ -27,6 +27,48 @@ class BaseTraceInfo(BaseModel): model_config = ConfigDict(protected_namespaces=()) + @property + def resolved_trace_id(self) -> str | None: + """Get trace_id with intelligent fallback. + + Priority: + 1. External trace_id (from X-Trace-Id header) + 2. workflow_run_id (if this trace type has it) + 3. message_id (as final fallback) + """ + if self.trace_id: + return self.trace_id + + # Try workflow_run_id (only exists on workflow-related traces) + workflow_run_id = getattr(self, "workflow_run_id", None) + if workflow_run_id: + return workflow_run_id + + # Final fallback to message_id + return str(self.message_id) if self.message_id else None + + @property + def resolved_parent_context(self) -> tuple[str | None, str | None]: + """Resolve cross-workflow parent linking from metadata. + + Extracts typed parent IDs from the untyped ``parent_trace_context`` + metadata dict (set by tool_node when invoking nested workflows). + + Returns: + (trace_correlation_override, parent_span_id_source) where + trace_correlation_override is the outer workflow_run_id and + parent_span_id_source is the outer node_execution_id. 
+ """ + parent_ctx = self.metadata.get("parent_trace_context") + if not isinstance(parent_ctx, dict): + return None, None + trace_override = parent_ctx.get("parent_workflow_run_id") + parent_span = parent_ctx.get("parent_node_execution_id") + return ( + trace_override if isinstance(trace_override, str) else None, + parent_span if isinstance(parent_span, str) else None, + ) + @field_serializer("start_time", "end_time") def serialize_datetime(self, dt: datetime | None) -> str | None: if dt is None: @@ -48,7 +90,10 @@ class WorkflowTraceInfo(BaseTraceInfo): workflow_run_version: str error: str | None = None total_tokens: int + prompt_tokens: int | None = None + completion_tokens: int | None = None file_list: list[str] + invoked_by: str | None = None query: str metadata: dict[str, Any] @@ -59,7 +104,7 @@ class MessageTraceInfo(BaseTraceInfo): answer_tokens: int total_tokens: int error: str | None = None - file_list: Union[str, dict[str, Any], list] | None = None + file_list: Union[str, dict[str, Any], list[Any]] | None = None message_file_data: Any | None = None conversation_mode: str gen_ai_server_time_to_first_token: float | None = None @@ -106,7 +151,7 @@ class ToolTraceInfo(BaseTraceInfo): tool_config: dict[str, Any] time_cost: Union[int, float] tool_parameters: dict[str, Any] - file_url: Union[str, None, list] = None + file_url: Union[str, None, list[str]] = None class GenerateNameTraceInfo(BaseTraceInfo): @@ -114,6 +159,79 @@ class GenerateNameTraceInfo(BaseTraceInfo): tenant_id: str +class PromptGenerationTraceInfo(BaseTraceInfo): + """Trace information for prompt generation operations (rule-generate, code-generate, etc.).""" + + tenant_id: str + user_id: str + app_id: str | None = None + + operation_type: str + instruction: str + + prompt_tokens: int + completion_tokens: int + total_tokens: int + + model_provider: str + model_name: str + + latency: float + + total_price: float | None = None + currency: str | None = None + + error: str | None = None + + 
model_config = ConfigDict(protected_namespaces=()) + + +class WorkflowNodeTraceInfo(BaseTraceInfo): + workflow_id: str + workflow_run_id: str + tenant_id: str + node_execution_id: str + node_id: str + node_type: str + title: str + + status: str + error: str | None = None + elapsed_time: float + + index: int + predecessor_node_id: str | None = None + + total_tokens: int = 0 + total_price: float = 0.0 + currency: str | None = None + + model_provider: str | None = None + model_name: str | None = None + prompt_tokens: int | None = None + completion_tokens: int | None = None + + tool_name: str | None = None + + iteration_id: str | None = None + iteration_index: int | None = None + loop_id: str | None = None + loop_index: int | None = None + parallel_id: str | None = None + + node_inputs: Mapping[str, Any] | None = None + node_outputs: Mapping[str, Any] | None = None + process_data: Mapping[str, Any] | None = None + + invoked_by: str | None = None + + model_config = ConfigDict(protected_namespaces=()) + + +class DraftNodeExecutionTrace(WorkflowNodeTraceInfo): + pass + + class TaskData(BaseModel): app_id: str trace_info_type: str @@ -128,11 +246,31 @@ trace_info_info_map = { "DatasetRetrievalTraceInfo": DatasetRetrievalTraceInfo, "ToolTraceInfo": ToolTraceInfo, "GenerateNameTraceInfo": GenerateNameTraceInfo, + "PromptGenerationTraceInfo": PromptGenerationTraceInfo, + "WorkflowNodeTraceInfo": WorkflowNodeTraceInfo, + "DraftNodeExecutionTrace": DraftNodeExecutionTrace, } +class OperationType(StrEnum): + """Operation type for token metric labels. + + Used as a metric attribute on ``dify.tokens.input`` / ``dify.tokens.output`` + counters so consumers can break down token usage by operation. 
+ """ + + WORKFLOW = "workflow" + NODE_EXECUTION = "node_execution" + MESSAGE = "message" + RULE_GENERATE = "rule_generate" + CODE_GENERATE = "code_generate" + STRUCTURED_OUTPUT = "structured_output" + INSTRUCTION_MODIFY = "instruction_modify" + + class TraceTaskName(StrEnum): CONVERSATION_TRACE = "conversation" + DRAFT_NODE_EXECUTION_TRACE = "draft_node_execution" WORKFLOW_TRACE = "workflow" MESSAGE_TRACE = "message" MODERATION_TRACE = "moderation" @@ -140,4 +278,6 @@ class TraceTaskName(StrEnum): DATASET_RETRIEVAL_TRACE = "dataset_retrieval" TOOL_TRACE = "tool" GENERATE_NAME_TRACE = "generate_conversation_name" + PROMPT_GENERATION_TRACE = "prompt_generation" + NODE_EXECUTION_TRACE = "node_execution" DATASOURCE_TRACE = "datasource" diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 87a7579f3a..0a2a0642f1 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -15,22 +15,32 @@ from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token -from core.ops.entities.config_entity import OPS_FILE_PATH, TracingProviderEnum +from core.ops.entities.config_entity import ( + OPS_FILE_PATH, + TracingProviderEnum, +) from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, + DraftNodeExecutionTrace, GenerateNameTraceInfo, MessageTraceInfo, ModerationTraceInfo, + PromptGenerationTraceInfo, SuggestedQuestionTraceInfo, TaskData, ToolTraceInfo, TraceTaskName, + WorkflowNodeTraceInfo, WorkflowTraceInfo, ) from core.ops.utils import get_message_data +from extensions.ext_database import db from extensions.ext_storage import storage -from models.engine import db +from models.account import Tenant +from models.dataset import Dataset from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig +from models.provider import Provider, ProviderCredential, 
ProviderModel, ProviderModelCredential, ProviderType +from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider from models.workflow import WorkflowAppLog from tasks.ops_trace_task import process_trace_tasks @@ -40,9 +50,144 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +def _lookup_app_and_workspace_names(app_id: str | None, tenant_id: str | None) -> tuple[str, str]: + """Return (app_name, workspace_name) for the given IDs. Falls back to empty strings.""" + app_name = "" + workspace_name = "" + if not app_id and not tenant_id: + return app_name, workspace_name + with Session(db.engine) as session: + if app_id: + name = session.scalar(select(App.name).where(App.id == app_id)) + if name: + app_name = name + if tenant_id: + name = session.scalar(select(Tenant.name).where(Tenant.id == tenant_id)) + if name: + workspace_name = name + return app_name, workspace_name + + +_PROVIDER_TYPE_TO_MODEL: dict[str, type] = { + "builtin": BuiltinToolProvider, + "plugin": BuiltinToolProvider, + "api": ApiToolProvider, + "workflow": WorkflowToolProvider, + "mcp": MCPToolProvider, +} + + +def _lookup_credential_name(credential_id: str | None, provider_type: str | None) -> str: + if not credential_id: + return "" + model_cls = _PROVIDER_TYPE_TO_MODEL.get(provider_type or "") + if not model_cls: + return "" + with Session(db.engine) as session: + name = session.scalar(select(model_cls.name).where(model_cls.id == credential_id)) # type: ignore[attr-defined] + return str(name) if name else "" + + +def _lookup_llm_credential_info( + tenant_id: str | None, provider: str | None, model: str | None, model_type: str | None = "llm" +) -> tuple[str | None, str]: + """ + Lookup LLM credential ID and name for the given provider and model. + Returns (credential_id, credential_name). + + Handles async timing issues gracefully - if credential is deleted between lookups, + returns the ID but empty name rather than failing. 
+ """ + if not tenant_id or not provider: + return None, "" + + try: + with Session(db.engine) as session: + # Try to find provider-level or model-level configuration + provider_record = session.scalar( + select(Provider).where( + Provider.tenant_id == tenant_id, + Provider.provider_name == provider, + Provider.provider_type == ProviderType.CUSTOM, + ) + ) + + if not provider_record: + return None, "" + + # Check if there's a model-specific config + credential_id = None + credential_name = "" + is_model_level = False + + if model: + # Try model-level first + model_record = session.scalar( + select(ProviderModel).where( + ProviderModel.tenant_id == tenant_id, + ProviderModel.provider_name == provider, + ProviderModel.model_name == model, + ProviderModel.model_type == model_type, + ) + ) + + if model_record and model_record.credential_id: + credential_id = model_record.credential_id + is_model_level = True + + if not credential_id and provider_record.credential_id: + # Fall back to provider-level credential + credential_id = provider_record.credential_id + is_model_level = False + + # Lookup credential_name if we have credential_id + if credential_id: + try: + if is_model_level: + # Query ProviderModelCredential + cred_name = session.scalar( + select(ProviderModelCredential.credential_name).where( + ProviderModelCredential.id == credential_id + ) + ) + else: + # Query ProviderCredential + cred_name = session.scalar( + select(ProviderCredential.credential_name).where(ProviderCredential.id == credential_id) + ) + + if cred_name: + credential_name = str(cred_name) + except Exception as e: + # Credential might have been deleted between lookups (async timing) + # Return ID but empty name rather than failing + logger.warning( + "Failed to lookup credential name for credential_id=%s (provider=%s, model=%s): %s", + credential_id, + provider, + model, + str(e), + exc_info=True, + ) + + return credential_id, credential_name + except Exception as e: + # Database query failed or 
other unexpected error + # Return empty rather than propagating error to telemetry emission + logger.warning( + "Failed to lookup LLM credential info for tenant_id=%s, provider=%s, model=%s: %s", + tenant_id, + provider, + model, + str(e), + exc_info=True, + ) + return None, "" + + class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]): - def __getitem__(self, key: str) -> dict[str, Any]: - match key: + def __getitem__(self, provider: str) -> dict[str, Any]: + match provider: case TracingProviderEnum.LANGFUSE: from core.ops.entities.config_entity import LangfuseConfig from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace @@ -149,7 +294,7 @@ class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]): } case _: - raise KeyError(f"Unsupported tracing provider: {key}") + raise KeyError(f"Unsupported tracing provider: {provider}") provider_config_map = OpsTraceProviderConfigMap() @@ -314,6 +459,10 @@ class OpsTraceManager: if app_id is None: return None + # Handle storage_id format (tenant-{uuid}) - not a real app_id + if isinstance(app_id, str) and app_id.startswith("tenant-"): + return None + app: App | None = db.session.query(App).where(App.id == app_id).first() if app is None: @@ -466,8 +615,6 @@ class TraceTask: @classmethod def _get_workflow_run_repo(cls): - from repositories.factory import DifyAPIRepositoryFactory - if cls._workflow_run_repo is None: with cls._repo_lock: if cls._workflow_run_repo is None: @@ -478,6 +625,77 @@ class TraceTask: cls._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) return cls._workflow_run_repo + @classmethod + def _calculate_workflow_token_split( + cls, session: "Session", workflow_run_id: str, tenant_id: str + ) -> tuple[int, int]: + """Sum prompt/completion tokens across all node executions for a workflow run. 
+ + Reads from the ``outputs`` column (where LLM nodes store ``usage.prompt_tokens`` + and ``usage.completion_tokens``) rather than ``execution_metadata``, which only + carries ``total_tokens``. Projects only the ``outputs`` column to avoid loading + large JSON blobs unnecessarily. + """ + import json + + from models.workflow import WorkflowNodeExecutionModel + + rows = ( + session.execute( + select(WorkflowNodeExecutionModel.outputs).where( + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id, + ) + ) + .scalars() + .all() + ) + + total_prompt = 0 + total_completion = 0 + + for raw in rows: + if not raw: + continue + try: + outputs = json.loads(raw) if isinstance(raw, str) else raw + except (ValueError, TypeError): + continue + if not isinstance(outputs, dict): + continue + usage = outputs.get("usage") + if not isinstance(usage, dict): + continue + prompt = usage.get("prompt_tokens") + if isinstance(prompt, (int, float)): + total_prompt += int(prompt) + completion = usage.get("completion_tokens") + if isinstance(completion, (int, float)): + total_completion += int(completion) + + return (total_prompt, total_completion) + + @classmethod + def _get_user_id_from_metadata(cls, metadata: dict[str, Any]) -> str: + """Extract user ID from metadata, prioritizing end_user over account. + + Returns the actual user ID (end_user or account) who invoked the workflow, + regardless of invoke_from context. 
+ """ + # Priority 1: End user (external users via API/WebApp) + if user_id := metadata.get("from_end_user_id"): + return f"end_user:{user_id}" + + # Priority 2: Account user (internal users via console/debugger) + if user_id := metadata.get("from_account_id"): + return f"account:{user_id}" + + # Priority 3: User (internal users via console/debugger) + if user_id := metadata.get("user_id"): + return f"user:{user_id}" + + return "anonymous" + def __init__( self, trace_type: Any, @@ -491,6 +709,7 @@ class TraceTask: self.trace_type = trace_type self.message_id = message_id self.workflow_run_id = workflow_execution.id_ if workflow_execution else None + self.workflow_total_tokens: int | None = workflow_execution.total_tokens if workflow_execution else None self.conversation_id = conversation_id self.user_id = user_id self.timer = timer @@ -498,6 +717,8 @@ class TraceTask: self.app_id = None self.trace_id = None self.kwargs = kwargs + if user_id is not None and "user_id" not in self.kwargs: + self.kwargs["user_id"] = user_id external_trace_id = kwargs.get("external_trace_id") if external_trace_id: self.trace_id = external_trace_id @@ -509,9 +730,12 @@ class TraceTask: preprocess_map = { TraceTaskName.CONVERSATION_TRACE: lambda: self.conversation_trace(**self.kwargs), TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace( - workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id + workflow_run_id=self.workflow_run_id, + conversation_id=self.conversation_id, + user_id=self.user_id, + total_tokens_override=self.workflow_total_tokens, ), - TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id), + TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id, **self.kwargs), TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace( message_id=self.message_id, timer=self.timer, **self.kwargs ), @@ -527,6 +751,9 @@ class TraceTask: TraceTaskName.GENERATE_NAME_TRACE: lambda: 
self.generate_name_trace( conversation_id=self.conversation_id, timer=self.timer, **self.kwargs ), + TraceTaskName.PROMPT_GENERATION_TRACE: lambda: self.prompt_generation_trace(**self.kwargs), + TraceTaskName.NODE_EXECUTION_TRACE: lambda: self.node_execution_trace(**self.kwargs), + TraceTaskName.DRAFT_NODE_EXECUTION_TRACE: lambda: self.draft_node_execution_trace(**self.kwargs), } return preprocess_map.get(self.trace_type, lambda: None)() @@ -541,6 +768,7 @@ class TraceTask: workflow_run_id: str | None, conversation_id: str | None, user_id: str | None, + total_tokens_override: int | None = None, ): if not workflow_run_id: return {} @@ -560,7 +788,7 @@ class TraceTask: workflow_run_version = workflow_run.version error = workflow_run.error or "" - total_tokens = workflow_run.total_tokens + total_tokens = total_tokens_override if total_tokens_override is not None else workflow_run.total_tokens file_list = workflow_run_inputs.get("sys.file") or [] query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or "" @@ -581,8 +809,18 @@ class TraceTask: Message.workflow_run_id == workflow_run_id, ) message_id = session.scalar(message_data_stmt) + prompt_tokens, completion_tokens = self._calculate_workflow_token_split( + session, workflow_run_id=workflow_run_id, tenant_id=tenant_id + ) - metadata = { + from core.telemetry.gateway import is_enterprise_telemetry_enabled + + if is_enterprise_telemetry_enabled(): + app_name, workspace_name = _lookup_app_and_workspace_names(workflow_run.app_id, tenant_id) + else: + app_name, workspace_name = "", "" + + metadata: dict[str, Any] = { "workflow_id": workflow_id, "conversation_id": conversation_id, "workflow_run_id": workflow_run_id, @@ -595,8 +833,14 @@ class TraceTask: "triggered_from": workflow_run.triggered_from, "user_id": user_id, "app_id": workflow_run.app_id, + "app_name": app_name, + "workspace_name": workspace_name, } + parent_trace_context = self.kwargs.get("parent_trace_context") + if 
parent_trace_context: + metadata["parent_trace_context"] = parent_trace_context + workflow_trace_info = WorkflowTraceInfo( trace_id=self.trace_id, workflow_data=workflow_run.to_dict(), @@ -611,6 +855,8 @@ class TraceTask: workflow_run_version=workflow_run_version, error=error, total_tokens=total_tokens, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, file_list=file_list, query=query, metadata=metadata, @@ -618,10 +864,11 @@ class TraceTask: message_id=message_id, start_time=workflow_run.created_at, end_time=workflow_run.finished_at, + invoked_by=self._get_user_id_from_metadata(metadata), ) return workflow_trace_info - def message_trace(self, message_id: str | None): + def message_trace(self, message_id: str | None, **kwargs): if not message_id: return {} message_data = get_message_data(message_id) @@ -644,6 +891,19 @@ class TraceTask: streaming_metrics = self._extract_streaming_metrics(message_data) + tenant_id = "" + with Session(db.engine) as session: + tid = session.scalar(select(App.tenant_id).where(App.id == message_data.app_id)) + if tid: + tenant_id = str(tid) + + from core.telemetry.gateway import is_enterprise_telemetry_enabled + + if is_enterprise_telemetry_enabled(): + app_name, workspace_name = _lookup_app_and_workspace_names(message_data.app_id, tenant_id) + else: + app_name, workspace_name = "", "" + metadata = { "conversation_id": message_data.conversation_id, "ls_provider": message_data.model_provider, @@ -655,7 +915,14 @@ class TraceTask: "workflow_run_id": message_data.workflow_run_id, "from_source": message_data.from_source, "message_id": message_id, + "tenant_id": tenant_id, + "app_id": message_data.app_id, + "user_id": message_data.from_end_user_id or message_data.from_account_id, + "app_name": app_name, + "workspace_name": workspace_name, } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id message_tokens = message_data.message_tokens @@ -672,7 +939,9 @@ class 
TraceTask: outputs=message_data.answer, file_list=file_list, start_time=created_at, - end_time=created_at + timedelta(seconds=message_data.provider_response_latency), + end_time=message_data.updated_at + if message_data.updated_at and message_data.updated_at > created_at + else created_at + timedelta(seconds=message_data.provider_response_latency), metadata=metadata, message_file_data=message_file_data, conversation_mode=conversation_mode, @@ -697,6 +966,8 @@ class TraceTask: "preset_response": moderation_result.preset_response, "query": moderation_result.query, } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id # get workflow_app_log_id workflow_app_log_id = None @@ -738,6 +1009,8 @@ class TraceTask: "workflow_run_id": message_data.workflow_run_id, "from_source": message_data.from_source, } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id # get workflow_app_log_id workflow_app_log_id = None @@ -777,6 +1050,52 @@ class TraceTask: if not message_data: return {} + tenant_id = "" + with Session(db.engine) as session: + tid = session.scalar(select(App.tenant_id).where(App.id == message_data.app_id)) + if tid: + tenant_id = str(tid) + + from core.telemetry.gateway import is_enterprise_telemetry_enabled + + if is_enterprise_telemetry_enabled(): + app_name, workspace_name = _lookup_app_and_workspace_names(message_data.app_id, tenant_id) + else: + app_name, workspace_name = "", "" + + doc_list = [doc.model_dump() for doc in documents] if documents else [] + dataset_ids: set[str] = set() + for doc in doc_list: + doc_meta = doc.get("metadata") or {} + did = doc_meta.get("dataset_id") + if did: + dataset_ids.add(did) + + embedding_models: dict[str, dict[str, str]] = {} + if dataset_ids: + with Session(db.engine) as session: + rows = session.execute( + select(Dataset.id, Dataset.embedding_model, Dataset.embedding_model_provider).where( + 
Dataset.id.in_(list(dataset_ids)) + ) + ).all() + for row in rows: + embedding_models[str(row[0])] = { + "embedding_model": row[1] or "", + "embedding_model_provider": row[2] or "", + } + + # Extract rerank model info from retrieval_model kwargs + rerank_model_provider = "" + rerank_model_name = "" + if "retrieval_model" in kwargs: + retrieval_model = kwargs["retrieval_model"] + if isinstance(retrieval_model, dict): + reranking_model = retrieval_model.get("reranking_model") + if isinstance(reranking_model, dict): + rerank_model_provider = reranking_model.get("reranking_provider_name", "") + rerank_model_name = reranking_model.get("reranking_model_name", "") + metadata = { "message_id": message_id, "ls_provider": message_data.model_provider, @@ -787,13 +1106,23 @@ class TraceTask: "agent_based": message_data.agent_based, "workflow_run_id": message_data.workflow_run_id, "from_source": message_data.from_source, + "tenant_id": tenant_id, + "app_id": message_data.app_id, + "user_id": message_data.from_end_user_id or message_data.from_account_id, + "app_name": app_name, + "workspace_name": workspace_name, + "embedding_models": embedding_models, + "rerank_model_provider": rerank_model_provider, + "rerank_model_name": rerank_model_name, } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id dataset_retrieval_trace_info = DatasetRetrievalTraceInfo( trace_id=self.trace_id, message_id=message_id, inputs=message_data.query or message_data.inputs, - documents=[doc.model_dump() for doc in documents] if documents else [], + documents=doc_list, start_time=timer.get("start"), end_time=timer.get("end"), metadata=metadata, @@ -836,6 +1165,10 @@ class TraceTask: "error": error, "tool_parameters": tool_parameters, } + if message_data.workflow_run_id: + metadata["workflow_run_id"] = message_data.workflow_run_id + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id file_url 
= "" message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first() @@ -890,6 +1223,8 @@ class TraceTask: "conversation_id": conversation_id, "tenant_id": tenant_id, } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id generate_name_trace_info = GenerateNameTraceInfo( trace_id=self.trace_id, @@ -904,6 +1239,182 @@ class TraceTask: return generate_name_trace_info + def prompt_generation_trace(self, **kwargs) -> PromptGenerationTraceInfo | dict: + tenant_id = kwargs.get("tenant_id", "") + user_id = kwargs.get("user_id", "") + app_id = kwargs.get("app_id") + operation_type = kwargs.get("operation_type", "") + instruction = kwargs.get("instruction", "") + generated_output = kwargs.get("generated_output", "") + + prompt_tokens = kwargs.get("prompt_tokens", 0) + completion_tokens = kwargs.get("completion_tokens", 0) + total_tokens = kwargs.get("total_tokens", 0) + + model_provider = kwargs.get("model_provider", "") + model_name = kwargs.get("model_name", "") + + latency = kwargs.get("latency", 0.0) + + timer = kwargs.get("timer") + start_time = timer.get("start") if timer else None + end_time = timer.get("end") if timer else None + + total_price = kwargs.get("total_price") + currency = kwargs.get("currency") + + error = kwargs.get("error") + + app_name = None + workspace_name = None + if app_id: + app_name, workspace_name = _lookup_app_and_workspace_names(app_id, tenant_id) + + metadata = { + "tenant_id": tenant_id, + "user_id": user_id, + "app_id": app_id or "", + "app_name": app_name, + "workspace_name": workspace_name, + "operation_type": operation_type, + "model_provider": model_provider, + "model_name": model_name, + } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id + + return PromptGenerationTraceInfo( + trace_id=self.trace_id, + inputs=instruction, + outputs=generated_output, + start_time=start_time, + 
end_time=end_time, + metadata=metadata, + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + operation_type=operation_type, + instruction=instruction, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + model_provider=model_provider, + model_name=model_name, + latency=latency, + total_price=total_price, + currency=currency, + error=error, + ) + + def node_execution_trace(self, **kwargs) -> WorkflowNodeTraceInfo | dict: + node_data: dict = kwargs.get("node_execution_data", {}) + if not node_data: + return {} + + from core.telemetry.gateway import is_enterprise_telemetry_enabled + + if is_enterprise_telemetry_enabled(): + app_name, workspace_name = _lookup_app_and_workspace_names( + node_data.get("app_id"), node_data.get("tenant_id") + ) + else: + app_name, workspace_name = "", "" + + # Try tool credential lookup first + credential_id = node_data.get("credential_id") + if is_enterprise_telemetry_enabled(): + credential_name = _lookup_credential_name(credential_id, node_data.get("credential_provider_type")) + # If no credential_id found (e.g., LLM nodes), try LLM credential lookup + if not credential_id: + llm_cred_id, llm_cred_name = _lookup_llm_credential_info( + tenant_id=node_data.get("tenant_id"), + provider=node_data.get("model_provider"), + model=node_data.get("model_name"), + model_type="llm", + ) + if llm_cred_id: + credential_id = llm_cred_id + credential_name = llm_cred_name + else: + credential_name = "" + metadata: dict[str, Any] = { + "tenant_id": node_data.get("tenant_id"), + "app_id": node_data.get("app_id"), + "app_name": app_name, + "workspace_name": workspace_name, + "user_id": node_data.get("user_id"), + "invoke_from": node_data.get("invoke_from"), + "credential_id": credential_id, + "credential_name": credential_name, + "dataset_ids": node_data.get("dataset_ids"), + "dataset_names": node_data.get("dataset_names"), + "plugin_name": node_data.get("plugin_name"), + } + + parent_trace_context = 
node_data.get("parent_trace_context") + if parent_trace_context: + metadata["parent_trace_context"] = parent_trace_context + + message_id: str | None = None + conversation_id = node_data.get("conversation_id") + workflow_execution_id = node_data.get("workflow_execution_id") + if conversation_id and workflow_execution_id and not parent_trace_context: + with Session(db.engine) as session: + msg_id = session.scalar( + select(Message.id).where( + Message.conversation_id == conversation_id, + Message.workflow_run_id == workflow_execution_id, + ) + ) + if msg_id: + message_id = str(msg_id) + metadata["message_id"] = message_id + if conversation_id: + metadata["conversation_id"] = conversation_id + + return WorkflowNodeTraceInfo( + trace_id=self.trace_id, + message_id=message_id, + start_time=node_data.get("created_at"), + end_time=node_data.get("finished_at"), + metadata=metadata, + workflow_id=node_data.get("workflow_id", ""), + workflow_run_id=node_data.get("workflow_execution_id", ""), + tenant_id=node_data.get("tenant_id", ""), + node_execution_id=node_data.get("node_execution_id", ""), + node_id=node_data.get("node_id", ""), + node_type=node_data.get("node_type", ""), + title=node_data.get("title", ""), + status=node_data.get("status", ""), + error=node_data.get("error"), + elapsed_time=node_data.get("elapsed_time", 0.0), + index=node_data.get("index", 0), + predecessor_node_id=node_data.get("predecessor_node_id"), + total_tokens=node_data.get("total_tokens", 0), + total_price=node_data.get("total_price", 0.0), + currency=node_data.get("currency"), + model_provider=node_data.get("model_provider"), + model_name=node_data.get("model_name"), + prompt_tokens=node_data.get("prompt_tokens"), + completion_tokens=node_data.get("completion_tokens"), + tool_name=node_data.get("tool_name"), + iteration_id=node_data.get("iteration_id"), + iteration_index=node_data.get("iteration_index"), + loop_id=node_data.get("loop_id"), + loop_index=node_data.get("loop_index"), + 
parallel_id=node_data.get("parallel_id"), + node_inputs=node_data.get("node_inputs"), + node_outputs=node_data.get("node_outputs"), + process_data=node_data.get("process_data"), + invoked_by=self._get_user_id_from_metadata(metadata), + ) + + def draft_node_execution_trace(self, **kwargs) -> DraftNodeExecutionTrace | dict: + node_trace = self.node_execution_trace(**kwargs) + if not isinstance(node_trace, WorkflowNodeTraceInfo): + return node_trace + return DraftNodeExecutionTrace(**node_trace.model_dump()) + def _extract_streaming_metrics(self, message_data) -> dict: if not message_data.message_metadata: return {} @@ -937,13 +1448,17 @@ class TraceQueueManager: self.user_id = user_id self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id) self.flask_app = current_app._get_current_object() # type: ignore + + from core.telemetry.gateway import is_enterprise_telemetry_enabled + + self._enterprise_telemetry_enabled = is_enterprise_telemetry_enabled() if trace_manager_timer is None: self.start_timer() def add_trace_task(self, trace_task: TraceTask): global trace_manager_timer, trace_manager_queue try: - if self.trace_instance: + if self._enterprise_telemetry_enabled or self.trace_instance: trace_task.app_id = self.app_id trace_manager_queue.put(trace_task) except Exception: @@ -979,20 +1494,27 @@ class TraceQueueManager: def send_to_celery(self, tasks: list[TraceTask]): with self.flask_app.app_context(): for task in tasks: - if task.app_id is None: - continue + storage_id = task.app_id + if storage_id is None: + tenant_id = task.kwargs.get("tenant_id") + if tenant_id: + storage_id = f"tenant-{tenant_id}" + else: + logger.warning("Skipping trace without app_id or tenant_id, trace_type: %s", task.trace_type) + continue + file_id = uuid4().hex trace_info = task.execute() task_data = TaskData( - app_id=task.app_id, + app_id=storage_id, trace_info_type=type(trace_info).__name__, trace_info=trace_info.model_dump() if trace_info else None, ) - file_path = 
@dataclass(frozen=True)
class TelemetryContext:
    """Identity attached to a telemetry event: which tenant/user/app produced it."""

    tenant_id: str | None = None
    user_id: str | None = None
    app_id: str | None = None


@dataclass(frozen=True)
class TelemetryEvent:
    """A single telemetry event, keyed by its ``TraceTaskName``."""

    name: TraceTaskName
    context: TelemetryContext
    payload: dict[str, Any]


def emit(event: TelemetryEvent, trace_manager: TraceQueueManager | None = None) -> None:
    """Emit a telemetry event.

    Looks up the ``TelemetryCase`` corresponding to the event's
    ``TraceTaskName`` and hands off to ``core.telemetry.gateway.emit()``.
    Events whose name has no case mapping are silently ignored.
    """
    mapped_case = get_trace_task_to_case().get(event.name)
    if mapped_case is None:
        # Unknown TraceTaskName — nothing to route.
        return

    ctx = event.context
    routing_context: dict[str, object] = {
        "tenant_id": ctx.tenant_id,
        "user_id": ctx.user_id,
        "app_id": ctx.app_id,
    }
    gateway_emit(mapped_case, routing_context, event.payload, trace_manager)
logger = logging.getLogger(__name__)

# Payloads serialized above this size are offloaded to object storage rather
# than carried inline in the telemetry envelope.
PAYLOAD_SIZE_THRESHOLD_BYTES = 1 * 1024 * 1024

# ---------------------------------------------------------------------------
# Routing table — authoritative mapping for all editions
# ---------------------------------------------------------------------------

# Lazily-built caches: populated on first access so the enterprise contracts
# module is only imported when actually needed.
_case_to_trace_task: dict[TelemetryCase, TraceTaskName] | None = None
_case_routing: dict[TelemetryCase, CaseRoute] | None = None


def _get_case_to_trace_task() -> dict[TelemetryCase, TraceTaskName]:
    """Return the TelemetryCase → TraceTaskName table, building it on first use."""
    global _case_to_trace_task
    if _case_to_trace_task is None:
        from enterprise.telemetry.contracts import TelemetryCase

        _case_to_trace_task = dict(
            (
                (TelemetryCase.WORKFLOW_RUN, TraceTaskName.WORKFLOW_TRACE),
                (TelemetryCase.MESSAGE_RUN, TraceTaskName.MESSAGE_TRACE),
                (TelemetryCase.NODE_EXECUTION, TraceTaskName.NODE_EXECUTION_TRACE),
                (TelemetryCase.DRAFT_NODE_EXECUTION, TraceTaskName.DRAFT_NODE_EXECUTION_TRACE),
                (TelemetryCase.PROMPT_GENERATION, TraceTaskName.PROMPT_GENERATION_TRACE),
                (TelemetryCase.TOOL_EXECUTION, TraceTaskName.TOOL_TRACE),
                (TelemetryCase.MODERATION_CHECK, TraceTaskName.MODERATION_TRACE),
                (TelemetryCase.SUGGESTED_QUESTION, TraceTaskName.SUGGESTED_QUESTION_TRACE),
                (TelemetryCase.DATASET_RETRIEVAL, TraceTaskName.DATASET_RETRIEVAL_TRACE),
                (TelemetryCase.GENERATE_NAME, TraceTaskName.GENERATE_NAME_TRACE),
            )
        )
    return _case_to_trace_task


def get_trace_task_to_case() -> dict[TraceTaskName, TelemetryCase]:
    """Return TraceTaskName → TelemetryCase (inverse of _get_case_to_trace_task)."""
    return {task_name: case for case, task_name in _get_case_to_trace_task().items()}


def _get_case_routing() -> dict[TelemetryCase, CaseRoute]:
    """Return the TelemetryCase → CaseRoute table, building it on first use."""
    global _case_routing
    if _case_routing is None:
        from enterprise.telemetry.contracts import CaseRoute, SignalType, TelemetryCase

        routing: dict[TelemetryCase, CaseRoute] = {}
        # TRACE — CE-eligible (flow in both CE and EE)
        for ce_case in (
            TelemetryCase.WORKFLOW_RUN,
            TelemetryCase.MESSAGE_RUN,
            TelemetryCase.TOOL_EXECUTION,
            TelemetryCase.MODERATION_CHECK,
            TelemetryCase.SUGGESTED_QUESTION,
            TelemetryCase.DATASET_RETRIEVAL,
            TelemetryCase.GENERATE_NAME,
        ):
            routing[ce_case] = CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True)
        # TRACE — enterprise-only
        for ee_case in (
            TelemetryCase.NODE_EXECUTION,
            TelemetryCase.DRAFT_NODE_EXECUTION,
            TelemetryCase.PROMPT_GENERATION,
        ):
            routing[ee_case] = CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False)
        # METRIC_LOG — enterprise-only (signal-driven, not trace)
        for metric_case in (
            TelemetryCase.APP_CREATED,
            TelemetryCase.APP_UPDATED,
            TelemetryCase.APP_DELETED,
            TelemetryCase.FEEDBACK_CREATED,
        ):
            routing[metric_case] = CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False)
        _case_routing = routing
    return _case_routing


def __getattr__(name: str) -> dict:
    """Lazy module-level access to routing tables (PEP 562)."""
    table_loaders = {
        "CASE_ROUTING": _get_case_routing,
        "CASE_TO_TRACE_TASK": _get_case_to_trace_task,
    }
    loader = table_loaders.get(name)
    if loader is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    return loader()


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def is_enterprise_telemetry_enabled() -> bool:
    """True when the EE exporter is importable and reports telemetry enabled.

    Any failure (missing enterprise package, exporter error) is treated as
    "disabled" so CE deployments silently no-op.
    """
    try:
        from enterprise.telemetry.exporter import is_enterprise_telemetry_enabled as _ee_check

        return _ee_check()
    except Exception:
        return False


def _handle_payload_sizing(
    payload: dict[str, Any],
    tenant_id: str,
    event_id: str,
) -> tuple[dict[str, Any], str | None]:
    """Inline or offload *payload* based on its serialized size.

    Returns ``(payload_for_envelope, storage_key | None)``. Payloads
    exceeding ``PAYLOAD_SIZE_THRESHOLD_BYTES`` are written to object
    storage and replaced with an empty dict; serialization or storage
    failures fall back to inlining the original payload.
    """
    try:
        serialized = json.dumps(payload).encode("utf-8")
    except (TypeError, ValueError):
        logger.warning("Failed to serialize payload for sizing: event_id=%s", event_id)
        return payload, None

    if len(serialized) <= PAYLOAD_SIZE_THRESHOLD_BYTES:
        return payload, None

    storage_key = f"telemetry/{tenant_id}/{event_id}.json"
    try:
        storage.save(storage_key, serialized)
    except Exception:
        logger.warning("Failed to store large payload, inlining instead: event_id=%s", event_id, exc_info=True)
        return payload, None

    logger.debug("Stored large payload to storage: key=%s, size=%d", storage_key, len(serialized))
    return {}, storage_key
Enterprise-only traces are silently dropped when EE is + disabled. + + METRIC_LOG events are dispatched to the enterprise Celery queue; + silently dropped when enterprise telemetry is unavailable. + """ + route = _get_case_routing().get(case) + if route is None: + logger.warning("Unknown telemetry case: %s, dropping event", case) + return + + if not route.ce_eligible and not is_enterprise_telemetry_enabled(): + logger.debug("Dropping EE-only event: case=%s (EE disabled)", case) + return + + if route.signal_type == SignalType.TRACE: + _emit_trace(case, context, payload, trace_manager) + else: + _emit_metric_log(case, context, payload) + + +def _emit_trace( + case: TelemetryCase, + context: dict[str, Any], + payload: dict[str, Any], + trace_manager: TraceQueueManager | None, +) -> None: + from core.ops.ops_trace_manager import TraceQueueManager as LocalTraceQueueManager + from core.ops.ops_trace_manager import TraceTask + + trace_task_name = _get_case_to_trace_task().get(case) + if trace_task_name is None: + logger.warning("No TraceTaskName mapping for case: %s", case) + return + + queue_manager = trace_manager or LocalTraceQueueManager( + app_id=context.get("app_id"), + user_id=context.get("user_id"), + ) + queue_manager.add_trace_task(TraceTask(trace_task_name, user_id=context.get("user_id"), **payload)) + logger.debug("Enqueued trace task: case=%s, app_id=%s", case, context.get("app_id")) + + +def _emit_metric_log( + case: TelemetryCase, + context: dict[str, Any], + payload: dict[str, Any], +) -> None: + """Build envelope and dispatch to enterprise Celery queue. + + No-ops when the enterprise telemetry task is not importable (CE mode). 
+ """ + try: + from tasks.enterprise_telemetry_task import process_enterprise_telemetry + except ImportError: + logger.debug("Enterprise metric/log dispatch unavailable, dropping: case=%s", case) + return + + tenant_id = context.get("tenant_id") or "" + event_id = str(uuid.uuid4()) + + payload_for_envelope, payload_ref = _handle_payload_sizing(payload, tenant_id, event_id) + + from enterprise.telemetry.contracts import TelemetryEnvelope + + envelope = TelemetryEnvelope( + case=case, + tenant_id=tenant_id, + event_id=event_id, + payload=payload_for_envelope, + metadata={"payload_ref": payload_ref} if payload_ref else None, + ) + + process_enterprise_telemetry.delay(envelope.model_dump_json()) + logger.debug( + "Enqueued metric/log event: case=%s, tenant_id=%s, event_id=%s", + case, + tenant_id, + event_id, + ) diff --git a/api/enterprise/__init__.py b/api/enterprise/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/enterprise/telemetry/DATA_DICTIONARY.md b/api/enterprise/telemetry/DATA_DICTIONARY.md new file mode 100644 index 0000000000..60d482cd1c --- /dev/null +++ b/api/enterprise/telemetry/DATA_DICTIONARY.md @@ -0,0 +1,525 @@ +# Dify Enterprise Telemetry Data Dictionary + +Quick reference for all telemetry signals emitted by Dify Enterprise. For configuration and architecture details, see [README.md](./README.md). + +## Resource Attributes + +Attached to every signal (Span, Metric, Log). 
+ +| Attribute | Type | Example | +|-----------|------|---------| +| `service.name` | string | `dify` | +| `host.name` | string | `dify-api-7f8b` | + +## Traces (Spans) + +### `dify.workflow.run` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.trace_id` | string | Business trace ID (Workflow Run ID) | +| `dify.tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.workflow.id` | string | Workflow definition ID | +| `dify.workflow.run_id` | string | Unique ID for this run | +| `dify.workflow.status` | string | `succeeded`, `failed`, `stopped`, etc. | +| `dify.workflow.error` | string | Error message if failed | +| `dify.workflow.elapsed_time` | float | Total execution time (seconds) | +| `dify.invoke_from` | string | `api`, `webapp`, `debug` | +| `dify.conversation.id` | string | Conversation ID (optional) | +| `dify.message.id` | string | Message ID (optional) | +| `dify.invoked_by` | string | User ID who triggered the run | +| `gen_ai.usage.total_tokens` | int | Total tokens across all nodes (optional) | +| `gen_ai.user.id` | string | End-user identifier (optional) | +| `dify.parent.trace_id` | string | Parent workflow trace ID (optional) | +| `dify.parent.workflow.run_id` | string | Parent workflow run ID (optional) | +| `dify.parent.node.execution_id` | string | Parent node execution ID (optional) | +| `dify.parent.app.id` | string | Parent app ID (optional) | + +### `dify.node.execution` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.trace_id` | string | Business trace ID | +| `dify.tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.workflow.id` | string | Workflow definition ID | +| `dify.workflow.run_id` | string | Workflow Run ID | +| `dify.message.id` | string | Message ID (optional) | +| `dify.conversation.id` | string | Conversation ID (optional) | +| `dify.node.execution_id` | string 
| Unique node execution ID | +| `dify.node.id` | string | Node ID in workflow graph | +| `dify.node.type` | string | Node type (see appendix) | +| `dify.node.title` | string | Display title | +| `dify.node.status` | string | `succeeded`, `failed` | +| `dify.node.error` | string | Error message if failed | +| `dify.node.elapsed_time` | float | Execution time (seconds) | +| `dify.node.index` | int | Execution order index | +| `dify.node.predecessor_node_id` | string | Triggering node ID | +| `dify.node.iteration_id` | string | Iteration ID (optional) | +| `dify.node.loop_id` | string | Loop ID (optional) | +| `dify.node.parallel_id` | string | Parallel branch ID (optional) | +| `dify.node.invoked_by` | string | User ID who triggered execution | +| `gen_ai.usage.input_tokens` | int | Prompt tokens (LLM nodes only) | +| `gen_ai.usage.output_tokens` | int | Completion tokens (LLM nodes only) | +| `gen_ai.usage.total_tokens` | int | Total tokens (LLM nodes only) | +| `gen_ai.request.model` | string | LLM model name (LLM nodes only) | +| `gen_ai.provider.name` | string | LLM provider name (LLM nodes only) | +| `gen_ai.user.id` | string | End-user identifier (optional) | + +### `dify.node.execution.draft` + +Same attributes as `dify.node.execution`. Emitted during Preview/Debug runs. + +## Counters + +All counters are cumulative and emitted at 100% accuracy. + +### Token Counters + +| Metric | Unit | Description | +|--------|------|-------------| +| `dify.tokens.total` | `{token}` | Total tokens consumed | +| `dify.tokens.input` | `{token}` | Input (prompt) tokens | +| `dify.tokens.output` | `{token}` | Output (completion) tokens | + +**Labels:** + +- `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name`, `node_type` (if node_execution) + +⚠️ **Warning:** `dify.tokens.total` at workflow level includes all node tokens. Filter by `operation_type` to avoid double-counting. 
+
+#### Token Hierarchy & Query Patterns
+
+Token metrics are emitted at multiple layers. Understanding the hierarchy prevents double-counting:
+
+```
+App-level total
+├── workflow           ← sum of all node_execution tokens (DO NOT add both)
+│   └── node_execution ← per-node breakdown
+├── message            ← independent (non-workflow chat apps only)
+├── rule_generate      ← independent helper LLM call
+├── code_generate      ← independent helper LLM call
+├── structured_output  ← independent helper LLM call
+└── instruction_modify ← independent helper LLM call
+```
+
+**Key rule:** `workflow` tokens already include all `node_execution` tokens. Never sum both.
+
+**Available labels on token metrics:** `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name`, `node_type`.
+App name is only available on span attributes (`dify.app.name`), not metric labels — use `app_id` for metric queries.
+
+**Common queries** (PromQL):
+
+```promql
+# ── Totals ──────────────────────────────────────────────────
+# App-level total (exclude node_execution to avoid double-counting)
+sum by (app_id) (dify_tokens_total{operation_type!="node_execution"})
+
+# Single app total
+sum (dify_tokens_total{app_id="<app_id>", operation_type!="node_execution"})
+
+# Per-tenant totals
+sum by (tenant_id) (dify_tokens_total{operation_type!="node_execution"})
+
+# ── Drill-down ──────────────────────────────────────────────
+# Workflow-level tokens for an app
+sum (dify_tokens_total{app_id="<app_id>", operation_type="workflow"})
+
+# Node-level breakdown within an app
+sum by (node_type) (dify_tokens_total{app_id="<app_id>", operation_type="node_execution"})
+
+# Model breakdown for an app
+sum by (model_provider, model_name) (dify_tokens_total{app_id="<app_id>"})
+
+# Input vs output per model
+sum by (model_name) (dify_tokens_input_total{app_id="<app_id>"})
+sum by (model_name) (dify_tokens_output_total{app_id="<app_id>"})
+
+# ── Rates ───────────────────────────────────────────────────
+# Token consumption rate (per hour)
+sum(rate(dify_tokens_total{operation_type!="node_execution"}[1h])) + +# Per-app consumption rate +sum by (app_id) (rate(dify_tokens_total{operation_type!="node_execution"}[1h])) +``` + +**Finding `app_id` from app name** (trace query — Tempo / Jaeger): + +``` +{ resource.dify.app.name = "My Chatbot" } | select(resource.dify.app.id) +``` + +### Request Counters + +| Metric | Unit | Description | +|--------|------|-------------| +| `dify.requests.total` | `{request}` | Total operations count | + +**Labels by type:** + +| `type` | Additional Labels | +|--------|-------------------| +| `workflow` | `tenant_id`, `app_id`, `status`, `invoke_from` | +| `node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name`, `status` | +| `draft_node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name`, `status` | +| `message` | `tenant_id`, `app_id`, `model_provider`, `model_name`, `status`, `invoke_from` | +| `tool` | `tenant_id`, `app_id`, `tool_name` | +| `moderation` | `tenant_id`, `app_id` | +| `suggested_question` | `tenant_id`, `app_id`, `model_provider`, `model_name` | +| `dataset_retrieval` | `tenant_id`, `app_id` | +| `generate_name` | `tenant_id`, `app_id` | +| `prompt_generation` | `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name`, `status` | + +### Error Counters + +| Metric | Unit | Description | +|--------|------|-------------| +| `dify.errors.total` | `{error}` | Total failed operations | + +**Labels by type:** + +| `type` | Additional Labels | +|--------|-------------------| +| `workflow` | `tenant_id`, `app_id` | +| `node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name` | +| `draft_node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name` | +| `message` | `tenant_id`, `app_id`, `model_provider`, `model_name` | +| `tool` | `tenant_id`, `app_id`, `tool_name` | +| `prompt_generation` | `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name` | + +### Other 
Counters + +| Metric | Unit | Labels | +|--------|------|--------| +| `dify.feedback.total` | `{feedback}` | `tenant_id`, `app_id`, `rating` | +| `dify.dataset.retrievals.total` | `{retrieval}` | `tenant_id`, `app_id`, `dataset_id`, `embedding_model_provider`, `embedding_model`, `rerank_model_provider`, `rerank_model` | +| `dify.app.created.total` | `{app}` | `tenant_id`, `app_id`, `mode` | +| `dify.app.updated.total` | `{app}` | `tenant_id`, `app_id` | +| `dify.app.deleted.total` | `{app}` | `tenant_id`, `app_id` | + +## Histograms + +| Metric | Unit | Labels | +|--------|------|--------| +| `dify.workflow.duration` | `s` | `tenant_id`, `app_id`, `status` | +| `dify.node.duration` | `s` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name`, `plugin_name` | +| `dify.message.duration` | `s` | `tenant_id`, `app_id`, `model_provider`, `model_name` | +| `dify.message.time_to_first_token` | `s` | `tenant_id`, `app_id`, `model_provider`, `model_name` | +| `dify.tool.duration` | `s` | `tenant_id`, `app_id`, `tool_name` | +| `dify.prompt_generation.duration` | `s` | `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name` | + +## Structured Logs + +### Span Companion Logs + +Logs that accompany spans. 
Signal type: `span_detail` + +#### `dify.workflow.run` Companion Log + +**Common attributes:** All span attributes (see Traces section) plus: + +| Additional Attribute | Type | Always Present | Description | +|---------------------|------|----------------|-------------| +| `dify.app.name` | string | No | Application display name | +| `dify.workspace.name` | string | No | Workspace display name | +| `dify.workflow.version` | string | Yes | Workflow definition version | +| `dify.workflow.inputs` | string/JSON | Yes | Input parameters (content-gated) | +| `dify.workflow.outputs` | string/JSON | Yes | Output results (content-gated) | +| `dify.workflow.query` | string | No | User query text (content-gated) | + +**Event attributes:** + +- `dify.event.name`: `"dify.workflow.run"` +- `dify.event.signal`: `"span_detail"` +- `trace_id`, `span_id`, `tenant_id`, `user_id` + +#### `dify.node.execution` and `dify.node.execution.draft` Companion Logs + +**Common attributes:** All span attributes (see Traces section) plus: + +| Additional Attribute | Type | Always Present | Description | +|---------------------|------|----------------|-------------| +| `dify.app.name` | string | No | Application display name | +| `dify.workspace.name` | string | No | Workspace display name | +| `dify.invoke_from` | string | No | Invocation source | +| `gen_ai.tool.name` | string | No | Tool name (tool nodes only) | +| `dify.node.total_price` | float | No | Cost (LLM nodes only) | +| `dify.node.currency` | string | No | Currency code (LLM nodes only) | +| `dify.node.iteration_index` | int | No | Iteration index (iteration nodes) | +| `dify.node.loop_index` | int | No | Loop index (loop nodes) | +| `dify.plugin.name` | string | No | Plugin name (tool/knowledge nodes) | +| `dify.credential.name` | string | No | Credential name (plugin nodes) | +| `dify.credential.id` | string | No | Credential ID (plugin nodes) | +| `dify.dataset.ids` | JSON array | No | Dataset IDs (knowledge nodes) | +| 
`dify.dataset.names` | JSON array | No | Dataset names (knowledge nodes) | +| `dify.node.inputs` | string/JSON | Yes | Node inputs (content-gated) | +| `dify.node.outputs` | string/JSON | Yes | Node outputs (content-gated) | +| `dify.node.process_data` | string/JSON | No | Processing data (content-gated) | + +**Event attributes:** + +- `dify.event.name`: `"dify.node.execution"` or `"dify.node.execution.draft"` +- `dify.event.signal`: `"span_detail"` +- `trace_id`, `span_id`, `tenant_id`, `user_id` + +### Standalone Logs + +Logs without structural spans. Signal type: `metric_only` + +#### `dify.message.run` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.message.run"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID (32-char hex) | +| `span_id` | string | OTEL span ID (16-char hex) | +| `tenant_id` | string | Tenant identifier | +| `user_id` | string | User identifier (optional) | +| `dify.app_id` | string | Application identifier | +| `dify.message.id` | string | Message identifier | +| `dify.conversation.id` | string | Conversation ID (optional) | +| `dify.workflow.run_id` | string | Workflow run ID (optional) | +| `dify.invoke_from` | string | `service-api`, `web-app`, `debugger`, `explore` | +| `gen_ai.provider.name` | string | LLM provider | +| `gen_ai.request.model` | string | LLM model | +| `gen_ai.usage.input_tokens` | int | Input tokens | +| `gen_ai.usage.output_tokens` | int | Output tokens | +| `gen_ai.usage.total_tokens` | int | Total tokens | +| `dify.message.status` | string | `succeeded`, `failed` | +| `dify.message.error` | string | Error message (if failed) | +| `dify.message.duration` | float | Duration (seconds) | +| `dify.message.time_to_first_token` | float | TTFT (seconds) | +| `dify.message.inputs` | string/JSON | Inputs (content-gated) | +| `dify.message.outputs` | string/JSON | Outputs (content-gated) | + +#### `dify.tool.execution` + 
+| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.tool.execution"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.message.id` | string | Message identifier | +| `dify.tool.name` | string | Tool name | +| `dify.tool.duration` | float | Duration (seconds) | +| `dify.tool.status` | string | `succeeded`, `failed` | +| `dify.tool.error` | string | Error message (if failed) | +| `dify.tool.inputs` | string/JSON | Inputs (content-gated) | +| `dify.tool.outputs` | string/JSON | Outputs (content-gated) | +| `dify.tool.parameters` | string/JSON | Parameters (content-gated) | +| `dify.tool.config` | string/JSON | Configuration (content-gated) | + +#### `dify.moderation.check` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.moderation.check"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.message.id` | string | Message identifier | +| `dify.moderation.type` | string | `input`, `output` | +| `dify.moderation.action` | string | `pass`, `block`, `flag` | +| `dify.moderation.flagged` | boolean | Whether flagged | +| `dify.moderation.categories` | JSON array | Flagged categories | +| `dify.moderation.query` | string | Content (content-gated) | + +#### `dify.suggested_question.generation` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.suggested_question.generation"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` 
| string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.message.id` | string | Message identifier | +| `dify.suggested_question.count` | int | Number of questions | +| `dify.suggested_question.duration` | float | Duration (seconds) | +| `dify.suggested_question.status` | string | `succeeded`, `failed` | +| `dify.suggested_question.error` | string | Error message (if failed) | +| `dify.suggested_question.questions` | JSON array | Questions (content-gated) | + +#### `dify.dataset.retrieval` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.dataset.retrieval"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.message.id` | string | Message identifier | +| `dify.dataset.id` | string | Dataset identifier | +| `dify.dataset.name` | string | Dataset name | +| `dify.dataset.embedding_providers` | JSON array | Embedding model providers (one per dataset) | +| `dify.dataset.embedding_models` | JSON array | Embedding models (one per dataset) | +| `dify.retrieval.rerank_provider` | string | Rerank model provider | +| `dify.retrieval.rerank_model` | string | Rerank model name | +| `dify.retrieval.query` | string | Search query (content-gated) | +| `dify.retrieval.document_count` | int | Documents retrieved | +| `dify.retrieval.duration` | float | Duration (seconds) | +| `dify.retrieval.status` | string | `succeeded`, `failed` | +| `dify.retrieval.error` | string | Error message (if failed) | +| `dify.dataset.documents` | JSON array | Documents (content-gated) | + +#### `dify.generate_name.execution` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.generate_name.execution"` | +| `dify.event.signal` | string | `"metric_only"` | +| 
`trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.conversation.id` | string | Conversation identifier | +| `dify.generate_name.duration` | float | Duration (seconds) | +| `dify.generate_name.status` | string | `succeeded`, `failed` | +| `dify.generate_name.error` | string | Error message (if failed) | +| `dify.generate_name.inputs` | string/JSON | Inputs (content-gated) | +| `dify.generate_name.outputs` | string | Generated name (content-gated) | + +#### `dify.prompt_generation.execution` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.prompt_generation.execution"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.prompt_generation.operation_type` | string | Operation type (see appendix) | +| `gen_ai.provider.name` | string | LLM provider | +| `gen_ai.request.model` | string | LLM model | +| `gen_ai.usage.input_tokens` | int | Input tokens | +| `gen_ai.usage.output_tokens` | int | Output tokens | +| `gen_ai.usage.total_tokens` | int | Total tokens | +| `dify.prompt_generation.duration` | float | Duration (seconds) | +| `dify.prompt_generation.status` | string | `succeeded`, `failed` | +| `dify.prompt_generation.error` | string | Error message (if failed) | +| `dify.prompt_generation.instruction` | string | Instruction (content-gated) | +| `dify.prompt_generation.output` | string/JSON | Output (content-gated) | + +#### `dify.app.created` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.app.created"` | +| `dify.event.signal` | string | `"metric_only"` | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | 
string | Application identifier | +| `dify.app.mode` | string | `chat`, `completion`, `agent-chat`, `workflow` | +| `dify.app.created_at` | string | Timestamp (ISO 8601) | + +#### `dify.app.updated` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.app.updated"` | +| `dify.event.signal` | string | `"metric_only"` | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.app.updated_at` | string | Timestamp (ISO 8601) | + +#### `dify.app.deleted` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.app.deleted"` | +| `dify.event.signal` | string | `"metric_only"` | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.app.deleted_at` | string | Timestamp (ISO 8601) | + +#### `dify.feedback.created` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.feedback.created"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.message.id` | string | Message identifier | +| `dify.feedback.rating` | string | `like`, `dislike`, `null` | +| `dify.feedback.content` | string | Feedback text (content-gated) | +| `dify.feedback.created_at` | string | Timestamp (ISO 8601) | + +#### `dify.telemetry.rehydration_failed` + +Diagnostic event for telemetry system health monitoring. 
+ +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.telemetry.rehydration_failed"` | +| `dify.event.signal` | string | `"metric_only"` | +| `tenant_id` | string | Tenant identifier | +| `dify.telemetry.error` | string | Error message | +| `dify.telemetry.payload_type` | string | Payload type (see appendix) | +| `dify.telemetry.correlation_id` | string | Correlation ID | + +## Content-Gated Attributes + +When `ENTERPRISE_INCLUDE_CONTENT=false`, these attributes are replaced with reference strings (`ref:{id_type}={uuid}`). + +| Attribute | Signal | +|-----------|--------| +| `dify.workflow.inputs` | `dify.workflow.run` | +| `dify.workflow.outputs` | `dify.workflow.run` | +| `dify.workflow.query` | `dify.workflow.run` | +| `dify.node.inputs` | `dify.node.execution` | +| `dify.node.outputs` | `dify.node.execution` | +| `dify.node.process_data` | `dify.node.execution` | +| `dify.message.inputs` | `dify.message.run` | +| `dify.message.outputs` | `dify.message.run` | +| `dify.tool.inputs` | `dify.tool.execution` | +| `dify.tool.outputs` | `dify.tool.execution` | +| `dify.tool.parameters` | `dify.tool.execution` | +| `dify.tool.config` | `dify.tool.execution` | +| `dify.moderation.query` | `dify.moderation.check` | +| `dify.suggested_question.questions` | `dify.suggested_question.generation` | +| `dify.retrieval.query` | `dify.dataset.retrieval` | +| `dify.dataset.documents` | `dify.dataset.retrieval` | +| `dify.generate_name.inputs` | `dify.generate_name.execution` | +| `dify.generate_name.outputs` | `dify.generate_name.execution` | +| `dify.prompt_generation.instruction` | `dify.prompt_generation.execution` | +| `dify.prompt_generation.output` | `dify.prompt_generation.execution` | +| `dify.feedback.content` | `dify.feedback.created` | + +## Appendix + +### Operation Types + +- `workflow`, `node_execution`, `message`, `rule_generate`, `code_generate`, `structured_output`, `instruction_modify` + +### Node Types 
+ +- `start`, `end`, `answer`, `llm`, `knowledge-retrieval`, `knowledge-index`, `if-else`, `code`, `template-transform`, `question-classifier`, `http-request`, `tool`, `datasource`, `variable-aggregator`, `loop`, `iteration`, `parameter-extractor`, `assigner`, `document-extractor`, `list-operator`, `agent`, `trigger-webhook`, `trigger-schedule`, `trigger-plugin`, `human-input` + +### Workflow Statuses + +- `running`, `succeeded`, `failed`, `stopped`, `partial-succeeded`, `paused` + +### Payload Types + +- `workflow`, `node`, `message`, `tool`, `moderation`, `suggested_question`, `dataset_retrieval`, `generate_name`, `prompt_generation`, `app`, `feedback` + +### Null Value Behavior + +**Spans:** Attributes with `null` values are omitted. + +**Logs:** Attributes with `null` values appear as `null` in JSON. + +**Content-Gated:** Replaced with reference strings, not set to `null`. diff --git a/api/enterprise/telemetry/README.md b/api/enterprise/telemetry/README.md new file mode 100644 index 0000000000..e43c0b1ea2 --- /dev/null +++ b/api/enterprise/telemetry/README.md @@ -0,0 +1,121 @@ +# Dify Enterprise Telemetry + +This document provides an overview of the Dify Enterprise OpenTelemetry (OTEL) exporter and how to configure it for integration with observability stacks like Prometheus, Grafana, Jaeger, or Honeycomb. + +## Overview + +Dify Enterprise uses a "slim span + rich companion log" architecture to provide high-fidelity observability without overwhelming trace storage. + +- **Traces (Spans)**: Capture the structure, identity, and timing of high-level operations (Workflows and Nodes). +- **Structured Logs**: Provide deep context (inputs, outputs, metadata) for every event, correlated to spans via `trace_id` and `span_id`. +- **Metrics**: Provide 100% accurate counters and histograms for usage, performance, and error tracking. 
+ +### Signal Architecture + +```mermaid +graph TD + A[Workflow Run] -->|Span| B(dify.workflow.run) + A -->|Log| C(dify.workflow.run detail) + B ---|trace_id| C + + D[Node Execution] -->|Span| E(dify.node.execution) + D -->|Log| F(dify.node.execution detail) + E ---|span_id| F + + G[Message/Tool/etc] -->|Log| H(dify.* event) + G -->|Metric| I(dify.* counter/histogram) +``` + +## Configuration + +The Enterprise OTEL exporter is configured via environment variables. + +| Variable | Description | Default | +|----------|-------------|---------| +| `ENTERPRISE_ENABLED` | Master switch for all enterprise features. | `false` | +| `ENTERPRISE_TELEMETRY_ENABLED` | Master switch for enterprise telemetry. | `false` | +| `ENTERPRISE_OTLP_ENDPOINT` | OTLP collector endpoint (e.g., `http://otel-collector:4318`). | - | +| `ENTERPRISE_OTLP_HEADERS` | Custom headers for OTLP requests (e.g., `x-scope-orgid=tenant1`). | - | +| `ENTERPRISE_OTLP_PROTOCOL` | OTLP transport protocol (`http` or `grpc`). | `http` | +| `ENTERPRISE_OTLP_API_KEY` | Bearer token for authentication. | - | +| `ENTERPRISE_INCLUDE_CONTENT` | Whether to include sensitive content (inputs/outputs) in logs. | `false` | +| `ENTERPRISE_SERVICE_NAME` | Service name reported to OTEL. | `dify` | +| `ENTERPRISE_OTEL_SAMPLING_RATE` | Sampling rate for traces (0.0 to 1.0). Metrics are always 100%. | `1.0` | + +## Correlation Model + +Dify uses deterministic ID generation to ensure signals are correlated across different services and asynchronous tasks. + +### ID Generation Rules + +- `trace_id`: Derived from the correlation ID (workflow_run_id or node_execution_id for drafts) using `int(UUID(correlation_id))` +- `span_id`: Derived from the source ID using the lower 64 bits of `UUID(source_id)` + +### Scenario A: Simple Workflow + +A single workflow run with multiple nodes. All spans and logs share the same `trace_id` (derived from `workflow_run_id`). 
+ +``` +trace_id = UUID(workflow_run_id) +├── [root span] dify.workflow.run (span_id = hash(workflow_run_id)) +│ ├── [child] dify.node.execution - "Start" (span_id = hash(node_exec_id_1)) +│ ├── [child] dify.node.execution - "LLM" (span_id = hash(node_exec_id_2)) +│ └── [child] dify.node.execution - "End" (span_id = hash(node_exec_id_3)) +``` + +### Scenario B: Nested Sub-Workflow + +A workflow calling another workflow via a Tool or Sub-workflow node. The child workflow's spans are linked to the parent via `parent_span_id`. Both workflows share the same trace_id. + +``` +trace_id = UUID(outer_workflow_run_id) ← shared across both workflows +├── [root] dify.workflow.run (outer) (span_id = hash(outer_workflow_run_id)) +│ ├── dify.node.execution - "Start Node" +│ ├── dify.node.execution - "Tool Node" (triggers sub-workflow) +│ │ └── [child] dify.workflow.run (inner) (span_id = hash(inner_workflow_run_id)) +│ │ ├── dify.node.execution - "Inner Start" +│ │ └── dify.node.execution - "Inner End" +│ └── dify.node.execution - "End Node" +``` + +**Key attributes for nested workflows:** + +- Inner workflow's `dify.parent.trace_id` = outer `workflow_run_id` +- Inner workflow's `dify.parent.node.execution_id` = tool node's `execution_id` +- Inner workflow's `dify.parent.workflow.run_id` = outer `workflow_run_id` +- Inner workflow's `dify.parent.app.id` = outer `app_id` + +### Scenario C: Draft Node Execution + +A single node run in isolation (debugger/preview mode). It creates its own trace where the node span is the root. + +``` +trace_id = UUID(node_execution_id) ← own trace, NOT part of any workflow +└── dify.node.execution.draft (span_id = hash(node_execution_id)) +``` + +**Key difference:** Draft executions use `node_execution_id` as the correlation_id, so they are NOT children of any workflow trace. 
+ +## Content Gating + +When `ENTERPRISE_INCLUDE_CONTENT` is set to `false`, sensitive content attributes (inputs, outputs, queries) are replaced with reference strings (e.g., `ref:workflow_run_id=...`) to prevent data leakage to the OTEL collector. + +**Reference String Format:** + +``` +ref:{id_type}={uuid} +``` + +**Examples:** + +``` +ref:workflow_run_id=550e8400-e29b-41d4-a716-446655440000 +ref:node_execution_id=660e8400-e29b-41d4-a716-446655440001 +ref:message_id=770e8400-e29b-41d4-a716-446655440002 +``` + +To retrieve actual content when gating is enabled, query the Dify database using the provided UUID. + +## Reference + +For a complete list of telemetry signals, attributes, and data structures, see [DATA_DICTIONARY.md](./DATA_DICTIONARY.md). diff --git a/api/enterprise/telemetry/__init__.py b/api/enterprise/telemetry/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/enterprise/telemetry/contracts.py b/api/enterprise/telemetry/contracts.py new file mode 100644 index 0000000000..91398cb8cb --- /dev/null +++ b/api/enterprise/telemetry/contracts.py @@ -0,0 +1,73 @@ +"""Telemetry gateway contracts and data structures. + +This module defines the envelope format for telemetry events and the routing +configuration that determines how each event type is processed. 
+""" + +from __future__ import annotations + +from enum import StrEnum +from typing import Any + +from pydantic import BaseModel, ConfigDict + + +class TelemetryCase(StrEnum): + """Enumeration of all known telemetry event cases.""" + + WORKFLOW_RUN = "workflow_run" + NODE_EXECUTION = "node_execution" + DRAFT_NODE_EXECUTION = "draft_node_execution" + MESSAGE_RUN = "message_run" + TOOL_EXECUTION = "tool_execution" + MODERATION_CHECK = "moderation_check" + SUGGESTED_QUESTION = "suggested_question" + DATASET_RETRIEVAL = "dataset_retrieval" + GENERATE_NAME = "generate_name" + PROMPT_GENERATION = "prompt_generation" + APP_CREATED = "app_created" + APP_UPDATED = "app_updated" + APP_DELETED = "app_deleted" + FEEDBACK_CREATED = "feedback_created" + + +class SignalType(StrEnum): + """Signal routing type for telemetry cases.""" + + TRACE = "trace" + METRIC_LOG = "metric_log" + + +class CaseRoute(BaseModel): + """Routing configuration for a telemetry case. + + Attributes: + signal_type: The type of signal (trace or metric_log). + ce_eligible: Whether this case is eligible for community edition tracing. + """ + + signal_type: SignalType + ce_eligible: bool + + +class TelemetryEnvelope(BaseModel): + """Envelope for telemetry events. + + Attributes: + case: The telemetry case type. + tenant_id: The tenant identifier. + event_id: Unique event identifier for deduplication. + payload: The main event payload (inline for small payloads, + empty when offloaded to storage via ``payload_ref``). + metadata: Optional metadata dictionary. When the gateway + offloads a large payload to object storage, this contains + ``{"payload_ref": ""}``. 
+ """ + + model_config = ConfigDict(extra="forbid", use_enum_values=False) + + case: TelemetryCase + tenant_id: str + event_id: str + payload: dict[str, Any] + metadata: dict[str, Any] | None = None diff --git a/api/enterprise/telemetry/draft_trace.py b/api/enterprise/telemetry/draft_trace.py new file mode 100644 index 0000000000..dff558988c --- /dev/null +++ b/api/enterprise/telemetry/draft_trace.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName +from core.telemetry import emit as telemetry_emit +from graphon.enums import WorkflowNodeExecutionMetadataKey +from models.workflow import WorkflowNodeExecutionModel + + +def enqueue_draft_node_execution_trace( + *, + execution: WorkflowNodeExecutionModel, + outputs: Mapping[str, Any] | None, + workflow_execution_id: str | None, + user_id: str, +) -> None: + node_data = _build_node_execution_data( + execution=execution, + outputs=outputs, + workflow_execution_id=workflow_execution_id, + ) + telemetry_emit( + TelemetryEvent( + name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id=execution.tenant_id, + user_id=user_id, + app_id=execution.app_id, + ), + payload={"node_execution_data": node_data}, + ) + ) + + +def _build_node_execution_data( + *, + execution: WorkflowNodeExecutionModel, + outputs: Mapping[str, Any] | None, + workflow_execution_id: str | None, +) -> dict[str, Any]: + metadata = execution.execution_metadata_dict + node_outputs = outputs if outputs is not None else execution.outputs_dict + execution_id = workflow_execution_id or execution.workflow_run_id or execution.id + process_data = execution.process_data_dict or {} + + # Extract token breakdown from outputs.usage (set by LLM node) + usage: Mapping[str, Any] = {} + if isinstance(node_outputs, Mapping): + raw_usage = node_outputs.get("usage") + if isinstance(raw_usage, Mapping): + 
usage = raw_usage + + return { + "workflow_id": execution.workflow_id, + "workflow_execution_id": execution_id, + "tenant_id": execution.tenant_id, + "app_id": execution.app_id, + "node_execution_id": execution.id, + "node_id": execution.node_id, + "node_type": execution.node_type, + "title": execution.title, + "status": execution.status, + "error": execution.error, + "elapsed_time": execution.elapsed_time, + "index": execution.index, + "predecessor_node_id": execution.predecessor_node_id, + "created_at": execution.created_at, + "finished_at": execution.finished_at, + "total_tokens": metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS, 0), + "total_price": metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_PRICE, 0.0), + "currency": metadata.get(WorkflowNodeExecutionMetadataKey.CURRENCY), + "model_provider": process_data.get("model_provider"), + "model_name": process_data.get("model_name"), + "prompt_tokens": usage.get("prompt_tokens"), + "completion_tokens": usage.get("completion_tokens"), + "tool_name": (metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO) or {}).get("tool_name") + if isinstance(metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO), dict) + else None, + "iteration_id": metadata.get(WorkflowNodeExecutionMetadataKey.ITERATION_ID), + "iteration_index": metadata.get(WorkflowNodeExecutionMetadataKey.ITERATION_INDEX), + "loop_id": metadata.get(WorkflowNodeExecutionMetadataKey.LOOP_ID), + "loop_index": metadata.get(WorkflowNodeExecutionMetadataKey.LOOP_INDEX), + "parallel_id": metadata.get(WorkflowNodeExecutionMetadataKey.PARALLEL_ID), + "node_inputs": execution.inputs_dict, + "node_outputs": node_outputs, + "process_data": execution.process_data_dict, + } diff --git a/api/enterprise/telemetry/enterprise_trace.py b/api/enterprise/telemetry/enterprise_trace.py new file mode 100644 index 0000000000..fc17d9d93e --- /dev/null +++ b/api/enterprise/telemetry/enterprise_trace.py @@ -0,0 +1,966 @@ +"""Enterprise trace handler — duck-typed, NOT a 
BaseTraceInstance subclass. + +Invoked directly in the Celery task, not through OpsTraceManager dispatch. +Only requires a matching ``trace(trace_info)`` method signature. + +Signal strategy: +- **Traces (spans)**: workflow run, node execution, draft node execution only. +- **Metrics + structured logs**: all other event types. + +Token metric labels (unified structure): +All token metrics (dify.tokens.input, dify.tokens.output, dify.tokens.total) use the +same label set for consistent filtering and aggregation: +- tenant_id: Tenant identifier +- app_id: Application identifier +- operation_type: Source of token usage (workflow | node_execution | message | rule_generate | etc.) +- model_provider: LLM provider name (empty string if not applicable) +- model_name: LLM model name (empty string if not applicable) +- node_type: Workflow node type (empty string if not node_execution) + +This unified structure allows filtering by operation_type to separate: +- Workflow-level aggregates (operation_type=workflow) +- Individual node executions (operation_type=node_execution) +- Direct message calls (operation_type=message) +- Prompt generation operations (operation_type=rule_generate, code_generate, etc.) + +Without this, tokens are double-counted when querying totals (workflow totals include +node totals, since workflow.total_tokens is the sum of all node tokens). 
+""" + +from __future__ import annotations + +import json +import logging +from typing import Any, cast + +from opentelemetry.util.types import AttributeValue + +from core.ops.entities.trace_entity import ( + BaseTraceInfo, + DatasetRetrievalTraceInfo, + DraftNodeExecutionTrace, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + OperationType, + PromptGenerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + WorkflowNodeTraceInfo, + WorkflowTraceInfo, +) +from enterprise.telemetry.entities import ( + EnterpriseTelemetryCounter, + EnterpriseTelemetryEvent, + EnterpriseTelemetryHistogram, + EnterpriseTelemetrySpan, + TokenMetricLabels, +) +from enterprise.telemetry.telemetry_log import emit_metric_only_event, emit_telemetry_log + +logger = logging.getLogger(__name__) + + +class EnterpriseOtelTrace: + """Duck-typed enterprise trace handler. + + ``*_trace`` methods emit spans (workflow/node only) or structured logs + (all other events), plus metrics at 100 % accuracy. 
+ """ + + def __init__(self) -> None: + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if exporter is None: + raise RuntimeError("EnterpriseOtelTrace instantiated but exporter is not initialized") + self._exporter = exporter + + def trace(self, trace_info: BaseTraceInfo) -> None: + if isinstance(trace_info, WorkflowTraceInfo): + self._workflow_trace(trace_info) + elif isinstance(trace_info, MessageTraceInfo): + self._message_trace(trace_info) + elif isinstance(trace_info, ToolTraceInfo): + self._tool_trace(trace_info) + elif isinstance(trace_info, DraftNodeExecutionTrace): + self._draft_node_execution_trace(trace_info) + elif isinstance(trace_info, WorkflowNodeTraceInfo): + self._node_execution_trace(trace_info) + elif isinstance(trace_info, ModerationTraceInfo): + self._moderation_trace(trace_info) + elif isinstance(trace_info, SuggestedQuestionTraceInfo): + self._suggested_question_trace(trace_info) + elif isinstance(trace_info, DatasetRetrievalTraceInfo): + self._dataset_retrieval_trace(trace_info) + elif isinstance(trace_info, GenerateNameTraceInfo): + self._generate_name_trace(trace_info) + elif isinstance(trace_info, PromptGenerationTraceInfo): + self._prompt_generation_trace(trace_info) + else: + raise AssertionError("this statment should be unreachable") + + def _common_attrs(self, trace_info: BaseTraceInfo) -> dict[str, Any]: + metadata = self._metadata(trace_info) + tenant_id, app_id, user_id = self._context_ids(trace_info, metadata) + return { + "dify.trace_id": trace_info.resolved_trace_id, + "dify.tenant_id": tenant_id, + "dify.app_id": app_id, + "dify.app.name": metadata.get("app_name"), + "dify.workspace.name": metadata.get("workspace_name"), + "gen_ai.user.id": user_id, + "dify.message.id": trace_info.message_id, + } + + def _metadata(self, trace_info: BaseTraceInfo) -> dict[str, Any]: + return trace_info.metadata + + def _context_ids( + self, + trace_info: BaseTraceInfo, + 
metadata: dict[str, Any], + ) -> tuple[str | None, str | None, str | None]: + tenant_id = getattr(trace_info, "tenant_id", None) or metadata.get("tenant_id") + app_id = getattr(trace_info, "app_id", None) or metadata.get("app_id") + user_id = getattr(trace_info, "user_id", None) or metadata.get("user_id") + return tenant_id, app_id, user_id + + def _labels(self, **values: AttributeValue) -> dict[str, AttributeValue]: + return dict(values) + + def _safe_payload_value(self, value: Any) -> str | dict[str, Any] | list[object] | None: + if isinstance(value, str): + return value + if isinstance(value, dict): + return cast(dict[str, Any], value) + if isinstance(value, list): + items: list[object] = [] + for item in cast(list[object], value): + items.append(item) + return items + return None + + def _content_or_ref(self, value: Any, ref: str) -> Any: + if self._exporter.include_content: + return self._maybe_json(value) + return ref + + def _maybe_json(self, value: Any) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + try: + return json.dumps(value, default=str) + except (TypeError, ValueError): + return str(value) + + # ------------------------------------------------------------------ + # SPAN-emitting handlers (workflow, node execution, draft node) + # ------------------------------------------------------------------ + + def _workflow_trace(self, info: WorkflowTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + # -- Span attrs: identity + structure + status + timing + gen_ai scalars -- + span_attrs: dict[str, Any] = { + "dify.trace_id": info.resolved_trace_id, + "dify.tenant_id": tenant_id, + "dify.app_id": app_id, + "dify.workflow.id": info.workflow_id, + "dify.workflow.run_id": info.workflow_run_id, + "dify.workflow.status": info.workflow_run_status, + "dify.workflow.error": info.error, + "dify.workflow.elapsed_time": info.workflow_run_elapsed_time, + 
"dify.invoke_from": metadata.get("triggered_from"), + "dify.conversation.id": info.conversation_id, + "dify.message.id": info.message_id, + "dify.invoked_by": info.invoked_by, + "gen_ai.usage.total_tokens": info.total_tokens, + "gen_ai.user.id": user_id, + } + + trace_correlation_override, parent_span_id_source = info.resolved_parent_context + + parent_ctx = metadata.get("parent_trace_context") + if isinstance(parent_ctx, dict): + parent_ctx_dict = cast(dict[str, Any], parent_ctx) + span_attrs["dify.parent.trace_id"] = parent_ctx_dict.get("trace_id") + span_attrs["dify.parent.node.execution_id"] = parent_ctx_dict.get("parent_node_execution_id") + span_attrs["dify.parent.workflow.run_id"] = parent_ctx_dict.get("parent_workflow_run_id") + span_attrs["dify.parent.app.id"] = parent_ctx_dict.get("parent_app_id") + + self._exporter.export_span( + EnterpriseTelemetrySpan.WORKFLOW_RUN, + span_attrs, + correlation_id=info.workflow_run_id, + span_id_source=info.workflow_run_id, + start_time=info.start_time, + end_time=info.end_time, + trace_correlation_override=trace_correlation_override, + parent_span_id_source=parent_span_id_source, + ) + + # -- Companion log: ALL attrs (span + detail) for full picture -- + log_attrs: dict[str, Any] = {**span_attrs} + log_attrs.update( + { + "dify.app.name": metadata.get("app_name"), + "dify.workspace.name": metadata.get("workspace_name"), + "gen_ai.user.id": user_id, + "gen_ai.usage.total_tokens": info.total_tokens, + "dify.workflow.version": info.workflow_run_version, + } + ) + + ref = f"ref:workflow_run_id={info.workflow_run_id}" + log_attrs["dify.workflow.inputs"] = self._content_or_ref(info.workflow_run_inputs, ref) + log_attrs["dify.workflow.outputs"] = self._content_or_ref(info.workflow_run_outputs, ref) + log_attrs["dify.workflow.query"] = self._content_or_ref(info.query, ref) + + emit_telemetry_log( + event_name=EnterpriseTelemetryEvent.WORKFLOW_RUN, + attributes=log_attrs, + signal="span_detail", + 
trace_id_source=info.workflow_run_id, + span_id_source=info.workflow_run_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + # -- Metrics -- + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + ) + token_labels = TokenMetricLabels( + tenant_id=tenant_id or "", + app_id=app_id or "", + operation_type=OperationType.WORKFLOW, + model_provider="", + model_name="", + node_type="", + ).to_dict() + self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels) + if info.prompt_tokens is not None and info.prompt_tokens > 0: + self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels) + if info.completion_tokens is not None and info.completion_tokens > 0: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels + ) + invoke_from = metadata.get("triggered_from", "") + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="workflow", + status=info.workflow_run_status, + invoke_from=invoke_from, + ), + ) + # Prefer wall-clock timestamps over the elapsed_time field: elapsed_time defaults + # to 0 in the DB and can be stale if the Celery write races with the trace task. + # start_time = workflow_run.created_at, end_time = workflow_run.finished_at. 
+ if info.start_time and info.end_time: + workflow_duration = (info.end_time - info.start_time).total_seconds() + elif info.workflow_run_elapsed_time: + workflow_duration = float(info.workflow_run_elapsed_time) + else: + workflow_duration = 0.0 + self._exporter.record_histogram( + EnterpriseTelemetryHistogram.WORKFLOW_DURATION, + workflow_duration, + self._labels( + **labels, + status=info.workflow_run_status, + ), + ) + + if info.error: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.ERRORS, + 1, + self._labels( + **labels, + type="workflow", + ), + ) + + def _node_execution_trace(self, info: WorkflowNodeTraceInfo) -> None: + self._emit_node_execution_trace(info, EnterpriseTelemetrySpan.NODE_EXECUTION, "node") + + def _draft_node_execution_trace(self, info: DraftNodeExecutionTrace) -> None: + self._emit_node_execution_trace( + info, + EnterpriseTelemetrySpan.DRAFT_NODE_EXECUTION, + "draft_node", + correlation_id_override=info.node_execution_id, + trace_correlation_override_param=info.workflow_run_id, + ) + + def _emit_node_execution_trace( + self, + info: WorkflowNodeTraceInfo, + span_name: EnterpriseTelemetrySpan, + request_type: str, + correlation_id_override: str | None = None, + trace_correlation_override_param: str | None = None, + ) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + # -- Span attrs: identity + structure + status + timing + gen_ai scalars -- + span_attrs: dict[str, Any] = { + "dify.trace_id": info.resolved_trace_id, + "dify.tenant_id": tenant_id, + "dify.app_id": app_id, + "dify.workflow.id": info.workflow_id, + "dify.workflow.run_id": info.workflow_run_id, + "dify.message.id": info.message_id, + "dify.conversation.id": metadata.get("conversation_id"), + "dify.node.execution_id": info.node_execution_id, + "dify.node.id": info.node_id, + "dify.node.type": info.node_type, + "dify.node.title": info.title, + "dify.node.status": info.status, + "dify.node.error": 
info.error, + "dify.node.elapsed_time": info.elapsed_time, + "dify.node.index": info.index, + "dify.node.predecessor_node_id": info.predecessor_node_id, + "dify.node.iteration_id": info.iteration_id, + "dify.node.loop_id": info.loop_id, + "dify.node.parallel_id": info.parallel_id, + "dify.node.invoked_by": info.invoked_by, + "gen_ai.usage.input_tokens": info.prompt_tokens, + "gen_ai.usage.output_tokens": info.completion_tokens, + "gen_ai.usage.total_tokens": info.total_tokens, + "gen_ai.request.model": info.model_name, + "gen_ai.provider.name": info.model_provider, + "gen_ai.user.id": user_id, + } + + resolved_override, _ = info.resolved_parent_context + trace_correlation_override = trace_correlation_override_param or resolved_override + + effective_correlation_id = correlation_id_override or info.workflow_run_id + self._exporter.export_span( + span_name, + span_attrs, + correlation_id=effective_correlation_id, + span_id_source=info.node_execution_id, + start_time=info.start_time, + end_time=info.end_time, + trace_correlation_override=trace_correlation_override, + ) + + # -- Companion log: ALL attrs (span + detail) -- + log_attrs: dict[str, Any] = {**span_attrs} + log_attrs.update( + { + "dify.app.name": metadata.get("app_name"), + "dify.workspace.name": metadata.get("workspace_name"), + "dify.invoke_from": metadata.get("invoke_from"), + "gen_ai.user.id": user_id, + "gen_ai.usage.total_tokens": info.total_tokens, + "dify.node.total_price": info.total_price, + "dify.node.currency": info.currency, + "gen_ai.provider.name": info.model_provider, + "gen_ai.request.model": info.model_name, + "gen_ai.tool.name": info.tool_name, + "dify.node.iteration_index": info.iteration_index, + "dify.node.loop_index": info.loop_index, + "dify.plugin.name": metadata.get("plugin_name"), + "dify.credential.name": metadata.get("credential_name"), + "dify.credential.id": metadata.get("credential_id"), + "dify.dataset.ids": self._maybe_json(metadata.get("dataset_ids")), + 
"dify.dataset.names": self._maybe_json(metadata.get("dataset_names")), + } + ) + + ref = f"ref:node_execution_id={info.node_execution_id}" + log_attrs["dify.node.inputs"] = self._content_or_ref(info.node_inputs, ref) + log_attrs["dify.node.outputs"] = self._content_or_ref(info.node_outputs, ref) + log_attrs["dify.node.process_data"] = self._content_or_ref(info.process_data, ref) + + emit_telemetry_log( + event_name=span_name.value, + attributes=log_attrs, + signal="span_detail", + trace_id_source=info.workflow_run_id, + span_id_source=info.node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + # -- Metrics -- + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + node_type=info.node_type, + model_provider=info.model_provider or "", + ) + if info.total_tokens: + token_labels = TokenMetricLabels( + tenant_id=tenant_id or "", + app_id=app_id or "", + operation_type=OperationType.NODE_EXECUTION, + model_provider=info.model_provider or "", + model_name=info.model_name or "", + node_type=info.node_type, + ).to_dict() + self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels) + if info.prompt_tokens is not None and info.prompt_tokens > 0: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels + ) + if info.completion_tokens is not None and info.completion_tokens > 0: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type=request_type, + status=info.status, + model_name=info.model_name or "", + ), + ) + duration_labels = dict(labels) + duration_labels["model_name"] = info.model_name or "" + plugin_name = metadata.get("plugin_name") + if plugin_name and info.node_type in {"tool", "knowledge-retrieval"}: + duration_labels["plugin_name"] = plugin_name + 
self._exporter.record_histogram(EnterpriseTelemetryHistogram.NODE_DURATION, info.elapsed_time, duration_labels) + + if info.error: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.ERRORS, + 1, + self._labels( + **labels, + type=request_type, + model_name=info.model_name or "", + ), + ) + + # ------------------------------------------------------------------ + # METRIC-ONLY handlers (structured log + counters/histograms) + # ------------------------------------------------------------------ + + def _message_trace(self, info: MessageTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs.update( + { + "dify.invoke_from": metadata.get("from_source"), + "dify.conversation.id": metadata.get("conversation_id"), + "dify.conversation.mode": info.conversation_mode, + "gen_ai.provider.name": metadata.get("ls_provider"), + "gen_ai.request.model": metadata.get("ls_model_name"), + "gen_ai.usage.input_tokens": info.message_tokens, + "gen_ai.usage.output_tokens": info.answer_tokens, + "gen_ai.usage.total_tokens": info.total_tokens, + "dify.message.status": metadata.get("status"), + "dify.message.error": info.error, + "dify.message.from_source": metadata.get("from_source"), + "dify.message.from_end_user_id": metadata.get("from_end_user_id"), + "dify.message.from_account_id": metadata.get("from_account_id"), + "dify.streaming": info.is_streaming_request, + "dify.message.time_to_first_token": info.gen_ai_server_time_to_first_token, + "dify.message.streaming_duration": info.llm_streaming_time_to_generate, + "dify.workflow.run_id": metadata.get("workflow_run_id"), + } + ) + + if info.start_time and info.end_time: + attrs["dify.message.duration"] = (info.end_time - info.start_time).total_seconds() + + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + ref = 
f"ref:message_id={info.message_id}" + inputs = self._safe_payload_value(info.inputs) + outputs = self._safe_payload_value(info.outputs) + attrs["dify.message.inputs"] = self._content_or_ref(inputs, ref) + attrs["dify.message.outputs"] = self._content_or_ref(outputs, ref) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.MESSAGE_RUN, + attributes=attrs, + trace_id_source=metadata.get("workflow_run_id") or (str(info.message_id) if info.message_id else None), + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + model_provider=metadata.get("ls_provider") or "", + model_name=metadata.get("ls_model_name") or "", + ) + token_labels = TokenMetricLabels( + tenant_id=tenant_id or "", + app_id=app_id or "", + operation_type=OperationType.MESSAGE, + model_provider=metadata.get("ls_provider") or "", + model_name=metadata.get("ls_model_name") or "", + node_type="", + ).to_dict() + self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels) + if info.message_tokens > 0: + self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.message_tokens, token_labels) + if info.answer_tokens > 0: + self._exporter.increment_counter(EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.answer_tokens, token_labels) + invoke_from = metadata.get("from_source", "") + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="message", + status=metadata.get("status", ""), + invoke_from=invoke_from, + ), + ) + + if info.start_time and info.end_time: + duration = (info.end_time - info.start_time).total_seconds() + self._exporter.record_histogram(EnterpriseTelemetryHistogram.MESSAGE_DURATION, duration, labels) + + if info.gen_ai_server_time_to_first_token is not None: + self._exporter.record_histogram( + EnterpriseTelemetryHistogram.MESSAGE_TTFT, 
info.gen_ai_server_time_to_first_token, labels + ) + + if info.error: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.ERRORS, + 1, + self._labels( + **labels, + type="message", + ), + ) + + def _tool_trace(self, info: ToolTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs.update( + { + "dify.tool.name": info.tool_name, + "dify.tool.duration": float(info.time_cost), + "dify.tool.status": "failed" if info.error else "succeeded", + "dify.tool.error": info.error, + "dify.workflow.run_id": metadata.get("workflow_run_id"), + } + ) + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + ref = f"ref:message_id={info.message_id}" + attrs["dify.tool.inputs"] = self._content_or_ref(info.tool_inputs, ref) + attrs["dify.tool.outputs"] = self._content_or_ref(info.tool_outputs, ref) + attrs["dify.tool.parameters"] = self._content_or_ref(info.tool_parameters, ref) + attrs["dify.tool.config"] = self._content_or_ref(info.tool_config, ref) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.TOOL_EXECUTION, + attributes=attrs, + trace_id_source=info.resolved_trace_id, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + tool_name=info.tool_name, + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="tool", + ), + ) + self._exporter.record_histogram(EnterpriseTelemetryHistogram.TOOL_DURATION, float(info.time_cost), labels) + + if info.error: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.ERRORS, + 1, + self._labels( + **labels, + type="tool", + ), + ) + + def _moderation_trace(self, info: ModerationTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, 
user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs.update( + { + "dify.moderation.flagged": info.flagged, + "dify.moderation.action": info.action, + "dify.moderation.preset_response": info.preset_response, + "dify.moderation.type": metadata.get("moderation_type", "input"), + "dify.moderation.categories": self._maybe_json(metadata.get("moderation_categories", [])), + "dify.workflow.run_id": metadata.get("workflow_run_id"), + } + ) + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + attrs["dify.moderation.query"] = self._content_or_ref( + info.query, + f"ref:message_id={info.message_id}", + ) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.MODERATION_CHECK, + attributes=attrs, + trace_id_source=info.resolved_trace_id, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="moderation", + ), + ) + + def _suggested_question_trace(self, info: SuggestedQuestionTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + duration: float | None = None + if info.start_time is not None and info.end_time is not None: + duration = (info.end_time - info.start_time).total_seconds() + error = info.error or (info.metadata.get("error") if info.metadata else None) + status = "failed" if error else (info.status or "succeeded") + attrs.update( + { + "gen_ai.usage.total_tokens": info.total_tokens, + "dify.suggested_question.status": status, + "dify.suggested_question.error": error, + "dify.suggested_question.duration": duration, + "gen_ai.provider.name": info.model_provider, + "gen_ai.request.model": info.model_id, + 
"dify.suggested_question.count": len(info.suggested_question), + "dify.workflow.run_id": metadata.get("workflow_run_id"), + } + ) + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + attrs["dify.suggested_question.questions"] = self._content_or_ref( + info.suggested_question, + f"ref:message_id={info.message_id}", + ) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.SUGGESTED_QUESTION_GENERATION, + attributes=attrs, + trace_id_source=info.resolved_trace_id, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="suggested_question", + model_provider=info.model_provider or "", + model_name=info.model_id or "", + ), + ) + + def _dataset_retrieval_trace(self, info: DatasetRetrievalTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs["dify.retrieval.error"] = info.error + attrs["dify.retrieval.status"] = "failed" if info.error else "succeeded" + if info.start_time and info.end_time: + attrs["dify.retrieval.duration"] = (info.end_time - info.start_time).total_seconds() + attrs["dify.workflow.run_id"] = metadata.get("workflow_run_id") + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + docs: list[dict[str, Any]] = [] + documents_any: Any = info.documents + documents_list: list[Any] = cast(list[Any], documents_any) if isinstance(documents_any, list) else [] + for entry in documents_list: + if isinstance(entry, dict): + entry_dict: dict[str, Any] = cast(dict[str, Any], entry) + docs.append(entry_dict) + dataset_ids: list[str] = [] + dataset_names: list[str] = 
[] + structured_docs: list[dict[str, Any]] = [] + for doc in docs: + meta_raw = doc.get("metadata") + meta: dict[str, Any] = cast(dict[str, Any], meta_raw) if isinstance(meta_raw, dict) else {} + did = meta.get("dataset_id") + dname = meta.get("dataset_name") + if did and did not in dataset_ids: + dataset_ids.append(did) + if dname and dname not in dataset_names: + dataset_names.append(dname) + structured_docs.append( + { + "dataset_id": did, + "document_id": meta.get("document_id"), + "segment_id": meta.get("segment_id"), + "score": meta.get("score"), + } + ) + + attrs["dify.dataset.id"] = self._maybe_json(dataset_ids) + attrs["dify.dataset.name"] = self._maybe_json(dataset_names) + attrs["dify.retrieval.document_count"] = len(docs) + + embedding_models_raw: Any = metadata.get("embedding_models") + embedding_models: dict[str, Any] = ( + cast(dict[str, Any], embedding_models_raw) if isinstance(embedding_models_raw, dict) else {} + ) + if embedding_models: + providers: list[str] = [] + models: list[str] = [] + for ds_info in embedding_models.values(): + if isinstance(ds_info, dict): + ds_info_dict: dict[str, Any] = cast(dict[str, Any], ds_info) + p = ds_info_dict.get("embedding_model_provider", "") + m = ds_info_dict.get("embedding_model", "") + if p and p not in providers: + providers.append(p) + if m and m not in models: + models.append(m) + attrs["dify.dataset.embedding_providers"] = self._maybe_json(providers) + attrs["dify.dataset.embedding_models"] = self._maybe_json(models) + + # Add rerank model to logs + rerank_provider = metadata.get("rerank_model_provider", "") + rerank_model = metadata.get("rerank_model_name", "") + if rerank_provider or rerank_model: + attrs["dify.retrieval.rerank_provider"] = rerank_provider + attrs["dify.retrieval.rerank_model"] = rerank_model + + ref = f"ref:message_id={info.message_id}" + retrieval_inputs = self._safe_payload_value(info.inputs) + attrs["dify.retrieval.query"] = self._content_or_ref(retrieval_inputs, ref) + 
attrs["dify.dataset.documents"] = self._content_or_ref(structured_docs, ref) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.DATASET_RETRIEVAL, + attributes=attrs, + trace_id_source=metadata.get("workflow_run_id") or (str(info.message_id) if info.message_id else None), + span_id_source=node_execution_id or (str(info.message_id) if info.message_id else None), + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="dataset_retrieval", + ), + ) + + for did in dataset_ids: + # Get embedding model for this specific dataset + ds_embedding_info = embedding_models.get(did, {}) + embedding_provider = ds_embedding_info.get("embedding_model_provider", "") + embedding_model = ds_embedding_info.get("embedding_model", "") + + # Get rerank model (same for all datasets in this retrieval) + rerank_provider = metadata.get("rerank_model_provider", "") + rerank_model = metadata.get("rerank_model_name", "") + + self._exporter.increment_counter( + EnterpriseTelemetryCounter.DATASET_RETRIEVALS, + 1, + self._labels( + **labels, + dataset_id=did, + embedding_model_provider=embedding_provider, + embedding_model=embedding_model, + rerank_model_provider=rerank_provider, + rerank_model=rerank_model, + ), + ) + + def _generate_name_trace(self, info: GenerateNameTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs["dify.conversation.id"] = info.conversation_id + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + duration: float | None = None + if info.start_time is not None and info.end_time is not None: + duration = (info.end_time - info.start_time).total_seconds() + error: str | None = 
metadata.get("error") if metadata else None + status = "failed" if error else "succeeded" + attrs["dify.generate_name.duration"] = duration + attrs["dify.generate_name.status"] = status + attrs["dify.generate_name.error"] = error + + ref = f"ref:conversation_id={info.conversation_id}" + inputs = self._safe_payload_value(info.inputs) + outputs = self._safe_payload_value(info.outputs) + attrs["dify.generate_name.inputs"] = self._content_or_ref(inputs, ref) + attrs["dify.generate_name.outputs"] = self._content_or_ref(outputs, ref) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.GENERATE_NAME_EXECUTION, + attributes=attrs, + trace_id_source=info.resolved_trace_id, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="generate_name", + ), + ) + + def _prompt_generation_trace(self, info: PromptGenerationTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = { + "dify.trace_id": info.resolved_trace_id, + "dify.tenant_id": tenant_id, + "gen_ai.user.id": user_id, + "dify.app_id": app_id or "", + "dify.app.name": metadata.get("app_name"), + "dify.workspace.name": metadata.get("workspace_name"), + "dify.prompt_generation.operation_type": info.operation_type, + "gen_ai.provider.name": info.model_provider, + "gen_ai.request.model": info.model_name, + "gen_ai.usage.input_tokens": info.prompt_tokens, + "gen_ai.usage.output_tokens": info.completion_tokens, + "gen_ai.usage.total_tokens": info.total_tokens, + "dify.prompt_generation.duration": info.latency, + "dify.prompt_generation.status": "failed" if info.error else "succeeded", + "dify.prompt_generation.error": info.error, + } + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + 
attrs["dify.node.execution_id"] = node_execution_id + + if info.total_price is not None: + attrs["dify.prompt_generation.total_price"] = info.total_price + attrs["dify.prompt_generation.currency"] = info.currency + + ref = f"ref:trace_id={info.trace_id}" + outputs = self._safe_payload_value(info.outputs) + attrs["dify.prompt_generation.instruction"] = self._content_or_ref(info.instruction, ref) + attrs["dify.prompt_generation.output"] = self._content_or_ref(outputs, ref) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.PROMPT_GENERATION_EXECUTION, + attributes=attrs, + trace_id_source=info.resolved_trace_id, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + token_labels = TokenMetricLabels( + tenant_id=tenant_id or "", + app_id=app_id or "", + operation_type=info.operation_type, + model_provider=info.model_provider, + model_name=info.model_name, + node_type="", + ).to_dict() + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + operation_type=info.operation_type, + model_provider=info.model_provider, + model_name=info.model_name, + ) + + self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels) + if info.prompt_tokens > 0: + self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels) + if info.completion_tokens > 0: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels + ) + + prompt_status = "failed" if info.error else "succeeded" + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="prompt_generation", + status=prompt_status, + ), + ) + + self._exporter.record_histogram( + EnterpriseTelemetryHistogram.PROMPT_GENERATION_DURATION, + info.latency, + labels, + ) + + if info.error: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.ERRORS, + 1, + self._labels( + 
**labels, + type="prompt_generation", + ), + ) diff --git a/api/enterprise/telemetry/entities/__init__.py b/api/enterprise/telemetry/entities/__init__.py new file mode 100644 index 0000000000..4a9bd3dbf8 --- /dev/null +++ b/api/enterprise/telemetry/entities/__init__.py @@ -0,0 +1,121 @@ +from enum import StrEnum +from typing import cast + +from opentelemetry.util.types import AttributeValue +from pydantic import BaseModel, ConfigDict + + +class EnterpriseTelemetrySpan(StrEnum): + WORKFLOW_RUN = "dify.workflow.run" + NODE_EXECUTION = "dify.node.execution" + DRAFT_NODE_EXECUTION = "dify.node.execution.draft" + + +class EnterpriseTelemetryEvent(StrEnum): + """Event names for enterprise telemetry logs.""" + + APP_CREATED = "dify.app.created" + APP_UPDATED = "dify.app.updated" + APP_DELETED = "dify.app.deleted" + FEEDBACK_CREATED = "dify.feedback.created" + WORKFLOW_RUN = "dify.workflow.run" + MESSAGE_RUN = "dify.message.run" + TOOL_EXECUTION = "dify.tool.execution" + MODERATION_CHECK = "dify.moderation.check" + SUGGESTED_QUESTION_GENERATION = "dify.suggested_question.generation" + DATASET_RETRIEVAL = "dify.dataset.retrieval" + GENERATE_NAME_EXECUTION = "dify.generate_name.execution" + PROMPT_GENERATION_EXECUTION = "dify.prompt_generation.execution" + REHYDRATION_FAILED = "dify.telemetry.rehydration_failed" + + +class EnterpriseTelemetryCounter(StrEnum): + TOKENS = "tokens" + INPUT_TOKENS = "input_tokens" + OUTPUT_TOKENS = "output_tokens" + REQUESTS = "requests" + ERRORS = "errors" + FEEDBACK = "feedback" + DATASET_RETRIEVALS = "dataset_retrievals" + APP_CREATED = "app_created" + APP_UPDATED = "app_updated" + APP_DELETED = "app_deleted" + + +class EnterpriseTelemetryHistogram(StrEnum): + WORKFLOW_DURATION = "workflow_duration" + NODE_DURATION = "node_duration" + MESSAGE_DURATION = "message_duration" + MESSAGE_TTFT = "message_ttft" + TOOL_DURATION = "tool_duration" + PROMPT_GENERATION_DURATION = "prompt_generation_duration" + + +class TokenMetricLabels(BaseModel): + 
"""Unified label structure for all dify.token.* metrics. + + All token counters (dify.tokens.input, dify.tokens.output, dify.tokens.total) MUST + use this exact label set to ensure consistent filtering and aggregation across + different operation types. + + Attributes: + tenant_id: Tenant identifier. + app_id: Application identifier. + operation_type: Source of token usage (workflow | node_execution | message | + rule_generate | code_generate | structured_output | instruction_modify). + model_provider: LLM provider name. Empty string if not applicable (e.g., workflow-level). + model_name: LLM model name. Empty string if not applicable (e.g., workflow-level). + node_type: Workflow node type. Empty string unless operation_type=node_execution. + + Usage: + labels = TokenMetricLabels( + tenant_id="tenant-123", + app_id="app-456", + operation_type=OperationType.WORKFLOW, + model_provider="", + model_name="", + node_type="", + ) + exporter.increment_counter( + EnterpriseTelemetryCounter.INPUT_TOKENS, + 100, + labels.to_dict() + ) + + Design rationale: + Without this unified structure, tokens get double-counted when querying totals + because workflow.total_tokens is already the sum of all node tokens. The + operation_type label allows filtering to separate workflow-level aggregates from + node-level detail, while keeping the same label cardinality for consistent queries. 
+ """ + + tenant_id: str + app_id: str + operation_type: str + model_provider: str + model_name: str + node_type: str + + model_config = ConfigDict(extra="forbid", frozen=True) + + def to_dict(self) -> dict[str, AttributeValue]: + return cast( + dict[str, AttributeValue], + { + "tenant_id": self.tenant_id, + "app_id": self.app_id, + "operation_type": self.operation_type, + "model_provider": self.model_provider, + "model_name": self.model_name, + "node_type": self.node_type, + }, + ) + + +__all__ = [ + "EnterpriseTelemetryCounter", + "EnterpriseTelemetryEvent", + "EnterpriseTelemetryHistogram", + "EnterpriseTelemetrySpan", + "TokenMetricLabels", +] diff --git a/api/enterprise/telemetry/event_handlers.py b/api/enterprise/telemetry/event_handlers.py new file mode 100644 index 0000000000..d8b4208c69 --- /dev/null +++ b/api/enterprise/telemetry/event_handlers.py @@ -0,0 +1,72 @@ +"""Blinker signal handlers for enterprise telemetry. + +Registered at import time via ``@signal.connect`` decorators. +Import must happen during ``ext_enterprise_telemetry.init_app()`` to +ensure handlers fire. Each handler delegates to ``core.telemetry.gateway`` +which handles routing, EE-gating, and dispatch. + +All handlers are best-effort: exceptions are caught and logged so that +telemetry failures never break user-facing operations. 
+""" + +from __future__ import annotations + +import logging + +from events.app_event import app_was_created, app_was_deleted, app_was_updated + +logger = logging.getLogger(__name__) + +__all__ = [ + "_handle_app_created", + "_handle_app_deleted", + "_handle_app_updated", +] + + +@app_was_created.connect +def _handle_app_created(sender: object, **kwargs: object) -> None: + try: + from core.telemetry.gateway import emit as gateway_emit + from enterprise.telemetry.contracts import TelemetryCase + + gateway_emit( + case=TelemetryCase.APP_CREATED, + context={"tenant_id": str(getattr(sender, "tenant_id", "") or "")}, + payload={ + "app_id": getattr(sender, "id", None), + "mode": getattr(sender, "mode", None), + }, + ) + except Exception: + logger.warning("Failed to emit app_created telemetry", exc_info=True) + + +@app_was_updated.connect +def _handle_app_updated(sender: object, **kwargs: object) -> None: + try: + from core.telemetry.gateway import emit as gateway_emit + from enterprise.telemetry.contracts import TelemetryCase + + gateway_emit( + case=TelemetryCase.APP_UPDATED, + context={"tenant_id": str(getattr(sender, "tenant_id", "") or "")}, + payload={"app_id": getattr(sender, "id", None)}, + ) + except Exception: + logger.warning("Failed to emit app_updated telemetry", exc_info=True) + + +@app_was_deleted.connect +def _handle_app_deleted(sender: object, **kwargs: object) -> None: + try: + from core.telemetry.gateway import emit as gateway_emit + from enterprise.telemetry.contracts import TelemetryCase + + gateway_emit( + case=TelemetryCase.APP_DELETED, + context={"tenant_id": str(getattr(sender, "tenant_id", "") or "")}, + payload={"app_id": getattr(sender, "id", None)}, + ) + except Exception: + logger.warning("Failed to emit app_deleted telemetry", exc_info=True) diff --git a/api/enterprise/telemetry/exporter.py b/api/enterprise/telemetry/exporter.py new file mode 100644 index 0000000000..b2f860764f --- /dev/null +++ b/api/enterprise/telemetry/exporter.py @@ 
-0,0 +1,283 @@ +"""Enterprise OTEL exporter — shared by EnterpriseOtelTrace, event handlers, and direct instrumentation. + +Uses dedicated TracerProvider and MeterProvider instances (configurable sampling, +independent from ext_otel.py infrastructure). + +Initialized once during Flask extension init (single-threaded via ext_enterprise_telemetry.py). +Accessed via ``ext_enterprise_telemetry.get_enterprise_exporter()`` from any thread/process. +""" + +import logging +import socket +import uuid +from datetime import UTC, datetime +from typing import Any, cast + +from opentelemetry import trace +from opentelemetry.baggage import get_all +from opentelemetry.baggage.propagation import W3CBaggagePropagator +from opentelemetry.context import Context +from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter as GRPCMetricExporter +from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCSpanExporter +from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as HTTPMetricExporter +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPSpanExporter +from opentelemetry.sdk.metrics import MeterProvider +from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor +from opentelemetry.sdk.trace.sampling import ParentBasedTraceIdRatio +from opentelemetry.semconv.resource import ResourceAttributes +from opentelemetry.trace import SpanContext, TraceFlags +from opentelemetry.util.types import Attributes, AttributeValue + +from configs import dify_config +from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryHistogram +from enterprise.telemetry.id_generator import ( + CorrelationIdGenerator, + compute_deterministic_span_id, + set_correlation_id, + 
set_span_id_source, +) + +logger = logging.getLogger(__name__) + + +def is_enterprise_telemetry_enabled() -> bool: + return bool(dify_config.ENTERPRISE_ENABLED and dify_config.ENTERPRISE_TELEMETRY_ENABLED) + + +def _parse_otlp_headers(raw: str) -> dict[str, str]: + ctx = W3CBaggagePropagator().extract({"baggage": raw}) + return {k: v for k, v in get_all(ctx).items() if isinstance(v, str)} + + +def _datetime_to_ns(dt: datetime) -> int: + """Convert a datetime to nanoseconds since epoch (OTEL convention).""" + # Ensure we always interpret naive datetimes as UTC instead of local time. + if dt.tzinfo is None: + dt = dt.replace(tzinfo=UTC) + else: + dt = dt.astimezone(UTC) + return int(dt.timestamp() * 1_000_000_000) + + +class _ExporterFactory: + def __init__(self, protocol: str, endpoint: str, headers: dict[str, str], insecure: bool): + self._protocol = protocol + self._endpoint = endpoint + self._headers = headers + self._grpc_headers = tuple(headers.items()) if headers else None + self._http_headers = headers or None + self._insecure = insecure + + def create_trace_exporter(self) -> HTTPSpanExporter | GRPCSpanExporter: + if self._protocol == "grpc": + return GRPCSpanExporter( + endpoint=self._endpoint or None, + headers=self._grpc_headers, + insecure=self._insecure, + ) + trace_endpoint = f"{self._endpoint}/v1/traces" if self._endpoint else "" + return HTTPSpanExporter(endpoint=trace_endpoint or None, headers=self._http_headers) + + def create_metric_exporter(self) -> HTTPMetricExporter | GRPCMetricExporter: + if self._protocol == "grpc": + return GRPCMetricExporter( + endpoint=self._endpoint or None, + headers=self._grpc_headers, + insecure=self._insecure, + ) + metric_endpoint = f"{self._endpoint}/v1/metrics" if self._endpoint else "" + return HTTPMetricExporter(endpoint=metric_endpoint or None, headers=self._http_headers) + + +class EnterpriseExporter: + """Shared OTEL exporter for all enterprise telemetry. 
+ + ``export_span`` creates spans with optional real timestamps, deterministic + span/trace IDs, and cross-workflow parent linking. + ``increment_counter`` / ``record_histogram`` emit OTEL metrics at 100% accuracy. + """ + + def __init__(self, config: object) -> None: + endpoint: str = getattr(config, "ENTERPRISE_OTLP_ENDPOINT", "") + headers_raw: str = getattr(config, "ENTERPRISE_OTLP_HEADERS", "") + protocol: str = (getattr(config, "ENTERPRISE_OTLP_PROTOCOL", "http") or "http").lower() + service_name: str = getattr(config, "ENTERPRISE_SERVICE_NAME", "dify") + sampling_rate: float = getattr(config, "ENTERPRISE_OTEL_SAMPLING_RATE", 1.0) + self.include_content: bool = getattr(config, "ENTERPRISE_INCLUDE_CONTENT", False) + api_key: str = getattr(config, "ENTERPRISE_OTLP_API_KEY", "") + + # Auto-detect TLS: https:// uses secure, everything else is insecure + insecure = not endpoint.startswith("https://") + + resource = Resource( + attributes={ + ResourceAttributes.SERVICE_NAME: service_name, + ResourceAttributes.HOST_NAME: socket.gethostname(), + } + ) + sampler = ParentBasedTraceIdRatio(sampling_rate) + id_generator = CorrelationIdGenerator() + self._tracer_provider = TracerProvider(resource=resource, sampler=sampler, id_generator=id_generator) + + headers = _parse_otlp_headers(headers_raw) + if api_key: + if "authorization" in headers: + logger.warning( + "ENTERPRISE_OTLP_API_KEY is set but ENTERPRISE_OTLP_HEADERS also contains " + "'authorization'; the API key will take precedence."
+ ) + headers["authorization"] = f"Bearer {api_key}" + factory = _ExporterFactory(protocol, endpoint, headers, insecure=insecure) + + trace_exporter = factory.create_trace_exporter() + self._tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter)) + self._tracer = self._tracer_provider.get_tracer("dify.enterprise") + + metric_exporter = factory.create_metric_exporter() + self._meter_provider = MeterProvider( + resource=resource, + metric_readers=[PeriodicExportingMetricReader(metric_exporter)], + ) + meter = self._meter_provider.get_meter("dify.enterprise") + self._counters = { + EnterpriseTelemetryCounter.TOKENS: meter.create_counter("dify.tokens.total", unit="{token}"), + EnterpriseTelemetryCounter.INPUT_TOKENS: meter.create_counter("dify.tokens.input", unit="{token}"), + EnterpriseTelemetryCounter.OUTPUT_TOKENS: meter.create_counter("dify.tokens.output", unit="{token}"), + EnterpriseTelemetryCounter.REQUESTS: meter.create_counter("dify.requests.total", unit="{request}"), + EnterpriseTelemetryCounter.ERRORS: meter.create_counter("dify.errors.total", unit="{error}"), + EnterpriseTelemetryCounter.FEEDBACK: meter.create_counter("dify.feedback.total", unit="{feedback}"), + EnterpriseTelemetryCounter.DATASET_RETRIEVALS: meter.create_counter( + "dify.dataset.retrievals.total", unit="{retrieval}" + ), + EnterpriseTelemetryCounter.APP_CREATED: meter.create_counter("dify.app.created.total", unit="{app}"), + EnterpriseTelemetryCounter.APP_UPDATED: meter.create_counter("dify.app.updated.total", unit="{app}"), + EnterpriseTelemetryCounter.APP_DELETED: meter.create_counter("dify.app.deleted.total", unit="{app}"), + } + self._histograms = { + EnterpriseTelemetryHistogram.WORKFLOW_DURATION: meter.create_histogram("dify.workflow.duration", unit="s"), + EnterpriseTelemetryHistogram.NODE_DURATION: meter.create_histogram("dify.node.duration", unit="s"), + EnterpriseTelemetryHistogram.MESSAGE_DURATION: meter.create_histogram("dify.message.duration", unit="s"), + 
EnterpriseTelemetryHistogram.MESSAGE_TTFT: meter.create_histogram( + "dify.message.time_to_first_token", unit="s" + ), + EnterpriseTelemetryHistogram.TOOL_DURATION: meter.create_histogram("dify.tool.duration", unit="s"), + EnterpriseTelemetryHistogram.PROMPT_GENERATION_DURATION: meter.create_histogram( + "dify.prompt_generation.duration", unit="s" + ), + } + + def export_span( + self, + name: str, + attributes: dict[str, Any], + correlation_id: str | None = None, + span_id_source: str | None = None, + start_time: datetime | None = None, + end_time: datetime | None = None, + trace_correlation_override: str | None = None, + parent_span_id_source: str | None = None, + ) -> None: + """Export an OTEL span with optional deterministic IDs and real timestamps. + + Args: + name: Span operation name. + attributes: Span attributes dict. + correlation_id: Source for trace_id derivation (groups spans in one trace). + span_id_source: Source for deterministic span_id (e.g. workflow_run_id or node_execution_id). + start_time: Real span start time. When None, uses current time. + end_time: Real span end time. When None, span ends immediately. + trace_correlation_override: Override trace_id source (for cross-workflow linking). + When set, trace_id is derived from this instead of ``correlation_id``. + parent_span_id_source: Override parent span_id source (for cross-workflow linking). + When set, parent span_id is derived from this value. When None and + ``correlation_id`` is set, parent is the workflow root span. + """ + effective_trace_correlation = trace_correlation_override or correlation_id + set_correlation_id(effective_trace_correlation) + set_span_id_source(span_id_source) + + try: + parent_context: Context | None = None + # A span is the "root" of its correlation group when span_id_source == correlation_id + # (i.e. a workflow root span). All other spans are children. + if parent_span_id_source: + # Cross-workflow linking: parent is an explicit span (e.g. 
tool node in outer workflow) + parent_span_id = compute_deterministic_span_id(parent_span_id_source) + try: + parent_trace_id = int(uuid.UUID(effective_trace_correlation)) if effective_trace_correlation else 0 + except (ValueError, AttributeError): + logger.warning( + "Invalid trace correlation UUID for cross-workflow link: %s, span=%s", + effective_trace_correlation, + name, + ) + parent_trace_id = 0 + if parent_trace_id: + parent_span_context = SpanContext( + trace_id=parent_trace_id, + span_id=parent_span_id, + is_remote=True, + trace_flags=TraceFlags(TraceFlags.SAMPLED), + ) + parent_context = trace.set_span_in_context(trace.NonRecordingSpan(parent_span_context)) + elif correlation_id and correlation_id != span_id_source: + # Child span: parent is the correlation-group root (workflow root span) + parent_span_id = compute_deterministic_span_id(correlation_id) + try: + parent_trace_id = int(uuid.UUID(effective_trace_correlation or correlation_id)) + except (ValueError, AttributeError): + logger.warning( + "Invalid trace correlation UUID for child span link: %s, span=%s", + effective_trace_correlation or correlation_id, + name, + ) + parent_trace_id = 0 + if parent_trace_id: + parent_span_context = SpanContext( + trace_id=parent_trace_id, + span_id=parent_span_id, + is_remote=True, + trace_flags=TraceFlags(TraceFlags.SAMPLED), + ) + parent_context = trace.set_span_in_context(trace.NonRecordingSpan(parent_span_context)) + + span_start_time = _datetime_to_ns(start_time) if start_time is not None else None + span_end_on_exit = end_time is None + + with self._tracer.start_as_current_span( + name, + context=parent_context, + start_time=span_start_time, + end_on_exit=span_end_on_exit, + ) as span: + for key, value in attributes.items(): + if value is not None: + span.set_attribute(key, value) + if end_time is not None: + span.end(end_time=_datetime_to_ns(end_time)) + except Exception: + logger.exception("Failed to export span %s", name) + finally: + 
set_correlation_id(None) + set_span_id_source(None) + + def increment_counter( + self, name: EnterpriseTelemetryCounter, value: int, labels: dict[str, AttributeValue] + ) -> None: + counter = self._counters.get(name) + if counter: + counter.add(value, cast(Attributes, labels)) + + def record_histogram( + self, name: EnterpriseTelemetryHistogram, value: float, labels: dict[str, AttributeValue] + ) -> None: + histogram = self._histograms.get(name) + if histogram: + histogram.record(value, cast(Attributes, labels)) + + def shutdown(self) -> None: + self._tracer_provider.shutdown() + self._meter_provider.shutdown() diff --git a/api/enterprise/telemetry/id_generator.py b/api/enterprise/telemetry/id_generator.py new file mode 100644 index 0000000000..f3e5d6d0d6 --- /dev/null +++ b/api/enterprise/telemetry/id_generator.py @@ -0,0 +1,75 @@ +"""Custom OTEL ID Generator for correlation-based trace/span ID derivation. + +Uses contextvars for thread-safe correlation_id -> trace_id mapping. +When a span_id_source is set, the span_id is derived deterministically +from that value, enabling any span to reference another as parent +without depending on span creation order. +""" + +import random +import uuid +from contextvars import ContextVar + +from opentelemetry.sdk.trace.id_generator import IdGenerator + +_correlation_id_context: ContextVar[str | None] = ContextVar("correlation_id", default=None) +_span_id_source_context: ContextVar[str | None] = ContextVar("span_id_source", default=None) + + +def set_correlation_id(correlation_id: str | None) -> None: + _correlation_id_context.set(correlation_id) + + +def get_correlation_id() -> str | None: + return _correlation_id_context.get() + + +def set_span_id_source(source_id: str | None) -> None: + """Set the source for deterministic span_id generation. + + When set, ``generate_span_id()`` derives the span_id from this value + (lower 64 bits of the UUID). 
Pass the ``workflow_run_id`` for workflow + root spans or ``node_execution_id`` for node spans. + """ + _span_id_source_context.set(source_id) + + +def compute_deterministic_span_id(source_id: str) -> int: + """Derive a deterministic span_id from any UUID string. + + Uses the lower 64 bits of the UUID, guaranteeing non-zero output + (OTEL requires span_id != 0). + """ + span_id = uuid.UUID(source_id).int & ((1 << 64) - 1) + return span_id if span_id != 0 else 1 + + +class CorrelationIdGenerator(IdGenerator): + """ID generator that derives trace_id and optionally span_id from context. + + - trace_id: always derived from correlation_id (groups all spans in one trace) + - span_id: derived from span_id_source when set (enables deterministic + parent-child linking), otherwise random + """ + + def generate_trace_id(self) -> int: + correlation_id = _correlation_id_context.get() + if correlation_id: + try: + return uuid.UUID(correlation_id).int + except (ValueError, AttributeError): + pass + return random.getrandbits(128) + + def generate_span_id(self) -> int: + source = _span_id_source_context.get() + if source: + try: + return compute_deterministic_span_id(source) + except (ValueError, AttributeError): + pass + + span_id = random.getrandbits(64) + while span_id == 0: + span_id = random.getrandbits(64) + return span_id diff --git a/api/enterprise/telemetry/metric_handler.py b/api/enterprise/telemetry/metric_handler.py new file mode 100644 index 0000000000..ffd9a7e2b5 --- /dev/null +++ b/api/enterprise/telemetry/metric_handler.py @@ -0,0 +1,421 @@ +"""Enterprise metric/log event handler. + +This module processes metric and log telemetry events after they've been +dequeued from the enterprise_telemetry Celery queue. It handles case routing, +idempotency checking, and payload rehydration. 
+"""
+
+from __future__ import annotations
+
+import json
+import logging
+from datetime import UTC, datetime
+from typing import Any
+
+from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope
+from extensions.ext_redis import redis_client
+from extensions.ext_storage import storage
+
+logger = logging.getLogger(__name__)
+
+
+class EnterpriseMetricHandler:
+    """Handler for enterprise metric and log telemetry events.
+
+    Processes envelopes from the enterprise_telemetry queue, routing each
+    case to the appropriate handler method. Implements idempotency checking
+    and payload rehydration.
+    """
+
+    def _increment_diagnostic_counter(self, counter_name: str, labels: dict[str, str] | None = None) -> None:
+        """Record a diagnostic counter event (currently a debug-log stub; no OTEL counter is emitted yet).
+
+        Args:
+            counter_name: Name of the counter (e.g., 'processed_total', 'deduped_total').
+            labels: Optional labels for the counter.
+        """
+        try:
+            from extensions.ext_enterprise_telemetry import get_enterprise_exporter
+
+            exporter = get_enterprise_exporter()
+            if not exporter:
+                return
+
+            full_counter_name = f"enterprise_telemetry.handler.{counter_name}"
+            logger.debug(
+                "Diagnostic counter: %s, labels=%s",
+                full_counter_name,
+                labels or {},
+            )
+        except Exception:
+            logger.debug("Failed to increment diagnostic counter: %s", counter_name, exc_info=True)
+
+    def handle(self, envelope: TelemetryEnvelope) -> None:
+        """Main entry point for processing telemetry envelopes.
+
+        Args:
+            envelope: The telemetry envelope to process.
+ """ + # Check for duplicate events + if self._is_duplicate(envelope): + logger.debug( + "Skipping duplicate event: tenant_id=%s, event_id=%s", + envelope.tenant_id, + envelope.event_id, + ) + self._increment_diagnostic_counter("deduped_total") + return + + # Route to appropriate handler based on case + case = envelope.case + if case == TelemetryCase.APP_CREATED: + self._on_app_created(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "app_created"}) + elif case == TelemetryCase.APP_UPDATED: + self._on_app_updated(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "app_updated"}) + elif case == TelemetryCase.APP_DELETED: + self._on_app_deleted(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "app_deleted"}) + elif case == TelemetryCase.FEEDBACK_CREATED: + self._on_feedback_created(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "feedback_created"}) + elif case == TelemetryCase.MESSAGE_RUN: + self._on_message_run(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "message_run"}) + elif case == TelemetryCase.TOOL_EXECUTION: + self._on_tool_execution(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "tool_execution"}) + elif case == TelemetryCase.MODERATION_CHECK: + self._on_moderation_check(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "moderation_check"}) + elif case == TelemetryCase.SUGGESTED_QUESTION: + self._on_suggested_question(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "suggested_question"}) + elif case == TelemetryCase.DATASET_RETRIEVAL: + self._on_dataset_retrieval(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "dataset_retrieval"}) + elif case == TelemetryCase.GENERATE_NAME: + self._on_generate_name(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "generate_name"}) + elif case == 
TelemetryCase.PROMPT_GENERATION: + self._on_prompt_generation(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "prompt_generation"}) + else: + logger.warning( + "Unknown telemetry case: %s (tenant_id=%s, event_id=%s)", + case, + envelope.tenant_id, + envelope.event_id, + ) + + def _is_duplicate(self, envelope: TelemetryEnvelope) -> bool: + """Check if this event has already been processed. + + Uses Redis with TTL for deduplication. Returns True if duplicate, + False if first time seeing this event. + + Args: + envelope: The telemetry envelope to check. + + Returns: + True if this event_id has been seen before, False otherwise. + """ + dedup_key = f"telemetry:dedup:{envelope.tenant_id}:{envelope.event_id}" + + try: + # Atomic set-if-not-exists with 1h TTL + # Returns True if key was set (first time), None if already exists (duplicate) + was_set = redis_client.set(dedup_key, b"1", nx=True, ex=3600) + return was_set is None + except Exception: + # Fail open: if Redis is unavailable, process the event + # (prefer occasional duplicate over lost data) + logger.warning( + "Redis unavailable for deduplication check, processing event anyway: %s", + envelope.event_id, + exc_info=True, + ) + return False + + def _rehydrate(self, envelope: TelemetryEnvelope) -> dict[str, Any]: + """Rehydrate payload from storage reference or inline data. + + If the envelope payload is empty and metadata contains a + ``payload_ref``, the full payload is loaded from object storage + (where the gateway wrote it as JSON). When both the inline + payload and storage resolution fail, a degraded-event marker + is emitted so the gap is observable. + + Args: + envelope: The telemetry envelope containing payload data. + + Returns: + The rehydrated payload dictionary, or ``{}`` on total failure. + """ + payload = envelope.payload + + # Resolve from object storage when the gateway offloaded a large payload. 
+ if not payload and envelope.metadata: + payload_ref = envelope.metadata.get("payload_ref") + if payload_ref: + try: + payload_bytes = storage.load(payload_ref) + payload = json.loads(payload_bytes.decode("utf-8")) + logger.debug("Loaded payload from storage: key=%s", payload_ref) + except Exception: + logger.warning( + "Failed to load payload from storage: key=%s, event_id=%s", + payload_ref, + envelope.event_id, + exc_info=True, + ) + + if not payload: + # Storage resolution failed or no data available — emit degraded event. + logger.error( + "Payload rehydration failed for event_id=%s, tenant_id=%s, case=%s", + envelope.event_id, + envelope.tenant_id, + envelope.case, + ) + from enterprise.telemetry.entities import EnterpriseTelemetryEvent + from enterprise.telemetry.telemetry_log import emit_metric_only_event + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.REHYDRATION_FAILED, + attributes={ + "tenant_id": envelope.tenant_id, + "dify.telemetry.error": f"Payload rehydration failed for event_id={envelope.event_id}", + "dify.telemetry.payload_type": envelope.case, + "dify.telemetry.correlation_id": envelope.event_id, + }, + tenant_id=envelope.tenant_id, + ) + self._increment_diagnostic_counter("rehydration_failed_total") + return {} + + return payload + + # Stub methods for each metric/log case + # These will be implemented in later tasks with actual emission logic + + def _on_app_created(self, envelope: TelemetryEnvelope) -> None: + """Handle app created event.""" + from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent + from enterprise.telemetry.telemetry_log import emit_metric_only_event + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if not exporter: + logger.debug("No exporter available for APP_CREATED: event_id=%s", envelope.event_id) + return + + payload = self._rehydrate(envelope) + if not payload: + return + + attrs = { + 
"dify.app_id": payload.get("app_id"), + "dify.tenant_id": envelope.tenant_id, + "dify.event.id": envelope.event_id, + "dify.app.mode": payload.get("mode"), + "dify.app.created_at": datetime.now(UTC).isoformat(), + } + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.APP_CREATED, + attributes=attrs, + tenant_id=envelope.tenant_id, + ) + exporter.increment_counter( + EnterpriseTelemetryCounter.APP_CREATED, + 1, + { + "tenant_id": envelope.tenant_id, + "app_id": str(payload.get("app_id", "")), + "mode": str(payload.get("mode", "")), + }, + ) + + def _on_app_updated(self, envelope: TelemetryEnvelope) -> None: + """Handle app updated event.""" + from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent + from enterprise.telemetry.telemetry_log import emit_metric_only_event + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if not exporter: + logger.debug("No exporter available for APP_UPDATED: event_id=%s", envelope.event_id) + return + + payload = self._rehydrate(envelope) + if not payload: + return + + attrs = { + "dify.app_id": payload.get("app_id"), + "dify.tenant_id": envelope.tenant_id, + "dify.event.id": envelope.event_id, + "dify.app.updated_at": datetime.now(UTC).isoformat(), + } + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.APP_UPDATED, + attributes=attrs, + tenant_id=envelope.tenant_id, + ) + exporter.increment_counter( + EnterpriseTelemetryCounter.APP_UPDATED, + 1, + { + "tenant_id": envelope.tenant_id, + "app_id": str(payload.get("app_id", "")), + }, + ) + + def _on_app_deleted(self, envelope: TelemetryEnvelope) -> None: + """Handle app deleted event.""" + from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent + from enterprise.telemetry.telemetry_log import emit_metric_only_event + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = 
get_enterprise_exporter() + if not exporter: + logger.debug("No exporter available for APP_DELETED: event_id=%s", envelope.event_id) + return + + payload = self._rehydrate(envelope) + if not payload: + return + + attrs = { + "dify.app_id": payload.get("app_id"), + "dify.tenant_id": envelope.tenant_id, + "dify.event.id": envelope.event_id, + "dify.app.deleted_at": datetime.now(UTC).isoformat(), + } + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.APP_DELETED, + attributes=attrs, + tenant_id=envelope.tenant_id, + ) + exporter.increment_counter( + EnterpriseTelemetryCounter.APP_DELETED, + 1, + { + "tenant_id": envelope.tenant_id, + "app_id": str(payload.get("app_id", "")), + }, + ) + + def _on_feedback_created(self, envelope: TelemetryEnvelope) -> None: + """Handle feedback created event.""" + from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent + from enterprise.telemetry.telemetry_log import emit_metric_only_event + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if not exporter: + logger.debug("No exporter available for FEEDBACK_CREATED: event_id=%s", envelope.event_id) + return + + payload = self._rehydrate(envelope) + if not payload: + return + + include_content = exporter.include_content + attrs: dict = { + "dify.message.id": payload.get("message_id"), + "dify.tenant_id": envelope.tenant_id, + "dify.event.id": envelope.event_id, + "dify.app_id": payload.get("app_id"), + "dify.conversation.id": payload.get("conversation_id"), + "gen_ai.user.id": payload.get("from_end_user_id") or payload.get("from_account_id"), + "dify.feedback.rating": payload.get("rating"), + "dify.feedback.from_source": payload.get("from_source"), + "dify.feedback.created_at": datetime.now(UTC).isoformat(), + } + if include_content: + attrs["dify.feedback.content"] = payload.get("content") + + user_id = payload.get("from_end_user_id") or payload.get("from_account_id") + 
emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.FEEDBACK_CREATED, + attributes=attrs, + tenant_id=envelope.tenant_id, + user_id=str(user_id or ""), + ) + exporter.increment_counter( + EnterpriseTelemetryCounter.FEEDBACK, + 1, + { + "tenant_id": envelope.tenant_id, + "app_id": str(payload.get("app_id", "")), + "rating": str(payload.get("rating", "")), + }, + ) + + def _on_message_run(self, envelope: TelemetryEnvelope) -> None: + """Handle message run event. + + Intentionally a no-op: metrics and structured logs for message runs are + emitted directly by EnterpriseOtelTrace._message_trace at trace time, + not through the metric handler queue path. + """ + logger.debug("Processing MESSAGE_RUN: event_id=%s", envelope.event_id) + + def _on_tool_execution(self, envelope: TelemetryEnvelope) -> None: + """Handle tool execution event. + + Intentionally a no-op: metrics and structured logs for tool executions + are emitted directly by EnterpriseOtelTrace._tool_trace at trace time, + not through the metric handler queue path. + """ + logger.debug("Processing TOOL_EXECUTION: event_id=%s", envelope.event_id) + + def _on_moderation_check(self, envelope: TelemetryEnvelope) -> None: + """Handle moderation check event. + + Intentionally a no-op: metrics and structured logs for moderation checks + are emitted directly by EnterpriseOtelTrace._moderation_trace at trace time, + not through the metric handler queue path. + """ + logger.debug("Processing MODERATION_CHECK: event_id=%s", envelope.event_id) + + def _on_suggested_question(self, envelope: TelemetryEnvelope) -> None: + """Handle suggested question event. + + Intentionally a no-op: metrics and structured logs for suggested questions + are emitted directly by EnterpriseOtelTrace._suggested_question_trace at + trace time, not through the metric handler queue path. 
+ """ + logger.debug("Processing SUGGESTED_QUESTION: event_id=%s", envelope.event_id) + + def _on_dataset_retrieval(self, envelope: TelemetryEnvelope) -> None: + """Handle dataset retrieval event. + + Intentionally a no-op: metrics and structured logs for dataset retrievals + are emitted directly by EnterpriseOtelTrace._dataset_retrieval_trace at + trace time, not through the metric handler queue path. + """ + logger.debug("Processing DATASET_RETRIEVAL: event_id=%s", envelope.event_id) + + def _on_generate_name(self, envelope: TelemetryEnvelope) -> None: + """Handle generate name event. + + Intentionally a no-op: metrics and structured logs for generate name + operations are emitted directly by EnterpriseOtelTrace._generate_name_trace + at trace time, not through the metric handler queue path. + """ + logger.debug("Processing GENERATE_NAME: event_id=%s", envelope.event_id) + + def _on_prompt_generation(self, envelope: TelemetryEnvelope) -> None: + """Handle prompt generation event. + + Intentionally a no-op: metrics and structured logs for prompt generation + operations are emitted directly by EnterpriseOtelTrace._prompt_generation_trace + at trace time, not through the metric handler queue path. + """ + logger.debug("Processing PROMPT_GENERATION: event_id=%s", envelope.event_id) diff --git a/api/enterprise/telemetry/telemetry_log.py b/api/enterprise/telemetry/telemetry_log.py new file mode 100644 index 0000000000..8cce4a9fcd --- /dev/null +++ b/api/enterprise/telemetry/telemetry_log.py @@ -0,0 +1,122 @@ +"""Structured-log emitter for enterprise telemetry events. + +Emits structured JSON log lines correlated with OTEL traces via trace_id. +Picked up by ``StructuredJSONFormatter`` → stdout/Loki/Elastic. 
+""" + +from __future__ import annotations + +import logging +import uuid +from functools import lru_cache +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from enterprise.telemetry.entities import EnterpriseTelemetryEvent + +logger = logging.getLogger("dify.telemetry") + + +@lru_cache(maxsize=4096) +def compute_trace_id_hex(uuid_str: str | None) -> str: + """Convert a business UUID string to a 32-hex OTEL-compatible trace_id. + + Returns empty string when *uuid_str* is ``None`` or invalid. + """ + if not uuid_str: + return "" + normalized = uuid_str.strip().lower() + if len(normalized) == 32 and all(ch in "0123456789abcdef" for ch in normalized): + return normalized + try: + return f"{uuid.UUID(normalized).int:032x}" + except (ValueError, AttributeError): + return "" + + +@lru_cache(maxsize=4096) +def compute_span_id_hex(uuid_str: str | None) -> str: + if not uuid_str: + return "" + normalized = uuid_str.strip().lower() + if len(normalized) == 16 and all(ch in "0123456789abcdef" for ch in normalized): + return normalized + try: + from enterprise.telemetry.id_generator import compute_deterministic_span_id + + return f"{compute_deterministic_span_id(normalized):016x}" + except (ValueError, AttributeError): + return "" + + +def emit_telemetry_log( + *, + event_name: str | EnterpriseTelemetryEvent, + attributes: dict[str, Any], + signal: str = "metric_only", + trace_id_source: str | None = None, + span_id_source: str | None = None, + tenant_id: str | None = None, + user_id: str | None = None, +) -> None: + """Emit a structured log line for a telemetry event. + + Parameters + ---------- + event_name: + Canonical event name, e.g. ``"dify.workflow.run"``. + attributes: + All event-specific attributes (already built by the caller). + signal: + ``"metric_only"`` for events with no span, ``"span_detail"`` + for detail logs accompanying a slim span. + trace_id_source: + A UUID string (e.g. 
``workflow_run_id``) used to derive a 32-hex + trace_id for cross-signal correlation. + tenant_id: + Tenant identifier (for the ``IdentityContextFilter``). + user_id: + User identifier (for the ``IdentityContextFilter``). + """ + if not logger.isEnabledFor(logging.INFO): + return + attrs = { + "dify.event.name": event_name, + "dify.event.signal": signal, + **attributes, + } + + extra: dict[str, Any] = {"attributes": attrs} + + trace_id_hex = compute_trace_id_hex(trace_id_source) + if trace_id_hex: + extra["trace_id"] = trace_id_hex + span_id_hex = compute_span_id_hex(span_id_source) + if span_id_hex: + extra["span_id"] = span_id_hex + if tenant_id: + extra["tenant_id"] = tenant_id + if user_id: + extra["user_id"] = user_id + + logger.info("telemetry.%s", signal, extra=extra) + + +def emit_metric_only_event( + *, + event_name: str | EnterpriseTelemetryEvent, + attributes: dict[str, Any], + trace_id_source: str | None = None, + span_id_source: str | None = None, + tenant_id: str | None = None, + user_id: str | None = None, +) -> None: + emit_telemetry_log( + event_name=event_name, + attributes=attributes, + signal="metric_only", + trace_id_source=trace_id_source, + span_id_source=span_id_source, + tenant_id=tenant_id, + user_id=user_id, + ) diff --git a/api/events/app_event.py b/api/events/app_event.py index f2ce71bbbb..2fba0028f9 100644 --- a/api/events/app_event.py +++ b/api/events/app_event.py @@ -11,3 +11,9 @@ app_published_workflow_was_updated = signal("app-published-workflow-was-updated" # sender: app, kwargs: synced_draft_workflow app_draft_workflow_was_synced = signal("app-draft-workflow-was-synced") + +# sender: app +app_was_updated = signal("app-was-updated") + +# sender: app +app_was_deleted = signal("app-was-deleted") diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 367a4c1ede..4eed34436a 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -204,6 +204,8 @@ def init_app(app: DifyApp) -> Celery: 
"schedule": timedelta(minutes=dify_config.API_TOKEN_LAST_USED_UPDATE_INTERVAL), } + if dify_config.ENTERPRISE_ENABLED and dify_config.ENTERPRISE_TELEMETRY_ENABLED: + imports.append("tasks.enterprise_telemetry_task") celery_app.conf.update(beat_schedule=beat_schedule, imports=imports) return celery_app diff --git a/api/extensions/ext_enterprise_telemetry.py b/api/extensions/ext_enterprise_telemetry.py new file mode 100644 index 0000000000..b3cfa01aee --- /dev/null +++ b/api/extensions/ext_enterprise_telemetry.py @@ -0,0 +1,50 @@ +"""Flask extension for enterprise telemetry lifecycle management. + +Initializes the EnterpriseExporter singleton during ``create_app()`` +(single-threaded), registers blinker event handlers, and hooks atexit +for graceful shutdown. + +Skipped entirely when either ``ENTERPRISE_ENABLED`` or ``ENTERPRISE_TELEMETRY_ENABLED`` +is false (``is_enabled()`` gate). +""" + +from __future__ import annotations + +import atexit +import logging +from typing import TYPE_CHECKING + +from configs import dify_config + +if TYPE_CHECKING: + from dify_app import DifyApp + from enterprise.telemetry.exporter import EnterpriseExporter + +logger = logging.getLogger(__name__) + +_exporter: EnterpriseExporter | None = None + + +def is_enabled() -> bool: + return bool(dify_config.ENTERPRISE_ENABLED and dify_config.ENTERPRISE_TELEMETRY_ENABLED) + + +def init_app(app: DifyApp) -> None: + global _exporter + + if not is_enabled(): + return + + from enterprise.telemetry.exporter import EnterpriseExporter + + _exporter = EnterpriseExporter(dify_config) + atexit.register(_exporter.shutdown) + + # Import to trigger @signal.connect decorator registration + import enterprise.telemetry.event_handlers # noqa: F401 # type: ignore[reportUnusedImport] + + logger.info("Enterprise telemetry initialized") + + +def get_enterprise_exporter() -> EnterpriseExporter | None: + return _exporter diff --git a/api/extensions/ext_otel.py b/api/extensions/ext_otel.py index a5baa21018..63edbe93e7 
100644 --- a/api/extensions/ext_otel.py +++ b/api/extensions/ext_otel.py @@ -78,16 +78,24 @@ def init_app(app: DifyApp): protocol = (dify_config.OTEL_EXPORTER_OTLP_PROTOCOL or "").lower() if dify_config.OTEL_EXPORTER_TYPE == "otlp": if protocol == "grpc": + # Auto-detect TLS: https:// uses secure, everything else is insecure + endpoint = dify_config.OTLP_BASE_ENDPOINT + insecure = not endpoint.startswith("https://") + + # Header field names must consist of lowercase letters, check RFC7540 + grpc_headers = ( + (("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),) if dify_config.OTLP_API_KEY else () + ) + exporter = GRPCSpanExporter( - endpoint=dify_config.OTLP_BASE_ENDPOINT, - # Header field names must consist of lowercase letters, check RFC7540 - headers=(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),), - insecure=True, + endpoint=endpoint, + headers=grpc_headers, + insecure=insecure, ) metric_exporter = GRPCMetricExporter( - endpoint=dify_config.OTLP_BASE_ENDPOINT, - headers=(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),), - insecure=True, + endpoint=endpoint, + headers=grpc_headers, + insecure=insecure, ) else: headers = {"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"} if dify_config.OTLP_API_KEY else None diff --git a/api/extensions/otel/parser/__init__.py b/api/extensions/otel/parser/__init__.py index 164db7c275..c671e8b409 100644 --- a/api/extensions/otel/parser/__init__.py +++ b/api/extensions/otel/parser/__init__.py @@ -5,7 +5,7 @@ This module provides parsers that extract node-specific metadata and set OpenTelemetry span attributes according to semantic conventions. 
""" -from extensions.otel.parser.base import DefaultNodeOTelParser, NodeOTelParser, safe_json_dumps +from extensions.otel.parser.base import DefaultNodeOTelParser, NodeOTelParser, safe_json_dumps, should_include_content from extensions.otel.parser.llm import LLMNodeOTelParser from extensions.otel.parser.retrieval import RetrievalNodeOTelParser from extensions.otel.parser.tool import ToolNodeOTelParser @@ -17,4 +17,5 @@ __all__ = [ "RetrievalNodeOTelParser", "ToolNodeOTelParser", "safe_json_dumps", + "should_include_content", ] diff --git a/api/extensions/otel/parser/base.py b/api/extensions/otel/parser/base.py index a2f552cac1..eefcaa126e 100644 --- a/api/extensions/otel/parser/base.py +++ b/api/extensions/otel/parser/base.py @@ -1,5 +1,10 @@ """ Base parser interface and utilities for OpenTelemetry node parsers. + +Content gating: ``should_include_content()`` controls whether content-bearing +span attributes (inputs, outputs, prompts, completions, documents) are written. +Gate is only active in EE (``ENTERPRISE_ENABLED=True``) when +``ENTERPRISE_INCLUDE_CONTENT=False``; CE behaviour is unchanged. """ import json @@ -9,6 +14,7 @@ from opentelemetry.trace import Span from opentelemetry.trace.status import Status, StatusCode from pydantic import BaseModel +from configs import dify_config from extensions.otel.semconv.gen_ai import ChainAttributes, GenAIAttributes from graphon.enums import BuiltinNodeTypes from graphon.file.models import File @@ -17,6 +23,16 @@ from graphon.nodes.base.node import Node from graphon.variables import Segment +def should_include_content() -> bool: + """Return True if content should be written to spans. + + CE (ENTERPRISE_ENABLED=False): always True — no behaviour change. + """ + if not dify_config.ENTERPRISE_ENABLED: + return True + return dify_config.ENTERPRISE_INCLUDE_CONTENT + + def safe_json_dumps(obj: Any, ensure_ascii: bool = False) -> str: """ Safely serialize objects to JSON, handling non-serializable types. 
@@ -101,10 +117,11 @@ class DefaultNodeOTelParser: # Extract inputs and outputs from result_event if result_event and result_event.node_run_result: node_run_result = result_event.node_run_result - if node_run_result.inputs: - span.set_attribute(ChainAttributes.INPUT_VALUE, safe_json_dumps(node_run_result.inputs)) - if node_run_result.outputs: - span.set_attribute(ChainAttributes.OUTPUT_VALUE, safe_json_dumps(node_run_result.outputs)) + if should_include_content(): + if node_run_result.inputs: + span.set_attribute(ChainAttributes.INPUT_VALUE, safe_json_dumps(node_run_result.inputs)) + if node_run_result.outputs: + span.set_attribute(ChainAttributes.OUTPUT_VALUE, safe_json_dumps(node_run_result.outputs)) if error: span.record_exception(error) diff --git a/api/extensions/otel/semconv/dify.py b/api/extensions/otel/semconv/dify.py index a20b9b358d..301ddd11aa 100644 --- a/api/extensions/otel/semconv/dify.py +++ b/api/extensions/otel/semconv/dify.py @@ -21,3 +21,15 @@ class DifySpanAttributes: INVOKE_FROM = "dify.invoke_from" """Invocation source, e.g. SERVICE_API, WEB_APP, DEBUGGER.""" + + INVOKED_BY = "dify.invoked_by" + """Invoked by, e.g. 
end_user, account, user.""" + + USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens" + """Number of input tokens (prompt tokens) used.""" + + USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens" + """Number of output tokens (completion tokens) generated.""" + + USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens" + """Total number of tokens used.""" diff --git a/api/pyrefly-local-excludes.txt b/api/pyrefly-local-excludes.txt index dc0adbf50d..cf002df2a9 100644 --- a/api/pyrefly-local-excludes.txt +++ b/api/pyrefly-local-excludes.txt @@ -109,6 +109,15 @@ core/trigger/debug/event_selectors.py core/trigger/entities/entities.py core/trigger/provider.py core/workflow/workflow_entry.py +enterprise/telemetry/contracts.py +enterprise/telemetry/draft_trace.py +enterprise/telemetry/enterprise_trace.py +enterprise/telemetry/entities/__init__.py +enterprise/telemetry/event_handlers.py +enterprise/telemetry/exporter.py +enterprise/telemetry/id_generator.py +enterprise/telemetry/metric_handler.py +enterprise/telemetry/telemetry_log.py graphon/entities/workflow_execution.py graphon/file/file_manager.py graphon/graph_engine/error_handler.py diff --git a/api/services/app_service.py b/api/services/app_service.py index a9ec357455..9413a93fc4 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -12,7 +12,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager -from events.app_event import app_was_created +from events.app_event import app_was_created, app_was_deleted, app_was_updated from extensions.ext_database import db from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel @@ -281,6 +281,8 @@ class AppService: app.updated_at = naive_utc_now() 
db.session.commit() + app_was_updated.send(app) + return app def update_app_name(self, app: App, name: str) -> App: @@ -296,6 +298,8 @@ class AppService: app.updated_at = naive_utc_now() db.session.commit() + app_was_updated.send(app) + return app def update_app_icon(self, app: App, icon: str, icon_background: str) -> App: @@ -313,6 +317,8 @@ class AppService: app.updated_at = naive_utc_now() db.session.commit() + app_was_updated.send(app) + return app def update_app_site_status(self, app: App, enable_site: bool) -> App: @@ -330,6 +336,8 @@ class AppService: app.updated_at = naive_utc_now() db.session.commit() + app_was_updated.send(app) + return app def update_app_api_status(self, app: App, enable_api: bool) -> App: @@ -348,6 +356,8 @@ class AppService: app.updated_at = naive_utc_now() db.session.commit() + app_was_updated.send(app) + return app def delete_app(self, app: App): @@ -355,6 +365,8 @@ class AppService: Delete app :param app: App instance """ + app_was_deleted.send(app) + db.session.delete(app) db.session.commit() diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 8a28537528..46a6221fcc 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -46,6 +46,7 @@ from core.workflow.system_variables import ( ) from core.workflow.variable_pool_initializer import add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry +from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace from extensions.ext_database import db from graphon.entities.workflow_node_execution import ( WorkflowNodeExecution, @@ -577,6 +578,13 @@ class RagPipelineService: outputs=workflow_node_execution.outputs, ) session.commit() + if workflow_node_execution_db_model is not None: + enqueue_draft_node_execution_trace( + execution=workflow_node_execution_db_model, + outputs=workflow_node_execution.outputs, + workflow_execution_id=None, + 
user_id=account.id, + ) return workflow_node_execution_db_model def run_datasource_workflow_node( @@ -1339,6 +1347,12 @@ class RagPipelineService: outputs=workflow_node_execution.outputs, ) session.commit() + enqueue_draft_node_execution_trace( + execution=workflow_node_execution_db_model, + outputs=workflow_node_execution.outputs, + workflow_execution_id=None, + user_id=current_user.id, + ) return workflow_node_execution_db_model def get_recommended_plugins(self, type: str) -> dict: diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 785f6f108c..bef99458be 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -27,6 +27,7 @@ from core.workflow.node_runtime import DifyHumanInputNodeRuntime, apply_dify_deb from core.workflow.system_variables import build_bootstrap_variables, build_system_variables, default_system_variables from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry +from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace from enums.cloud_plan import CloudPlan from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db @@ -849,6 +850,13 @@ class WorkflowService: draft_var_saver.save(process_data=node_execution.process_data, outputs=outputs) session.commit() + enqueue_draft_node_execution_trace( + execution=workflow_node_execution, + outputs=outputs, + workflow_execution_id=None, + user_id=account.id, + ) + return workflow_node_execution def get_human_input_form_preview( diff --git a/api/tasks/enterprise_telemetry_task.py b/api/tasks/enterprise_telemetry_task.py new file mode 100644 index 0000000000..7d5ea7c0a5 --- /dev/null +++ b/api/tasks/enterprise_telemetry_task.py @@ -0,0 +1,52 @@ +"""Celery worker for enterprise metric/log telemetry events. 
+ +This module defines the Celery task that processes telemetry envelopes +from the enterprise_telemetry queue. It deserializes envelopes and +dispatches them to the EnterpriseMetricHandler. +""" + +import json +import logging + +from celery import shared_task + +from enterprise.telemetry.contracts import TelemetryEnvelope +from enterprise.telemetry.metric_handler import EnterpriseMetricHandler + +logger = logging.getLogger(__name__) + + +@shared_task(queue="enterprise_telemetry") +def process_enterprise_telemetry(envelope_json: str) -> None: + """Process enterprise metric/log telemetry envelope. + + This task is enqueued by the TelemetryGateway for metric/log-only + events. It deserializes the envelope and dispatches to the handler. + + Best-effort processing: logs errors but never raises, to avoid + failing user requests due to telemetry issues. + + Args: + envelope_json: JSON-serialized TelemetryEnvelope. + """ + try: + # Deserialize envelope + envelope_dict = json.loads(envelope_json) + envelope = TelemetryEnvelope.model_validate(envelope_dict) + + # Process through handler + handler = EnterpriseMetricHandler() + handler.handle(envelope) + + logger.debug( + "Successfully processed telemetry envelope: tenant_id=%s, event_id=%s, case=%s", + envelope.tenant_id, + envelope.event_id, + envelope.case, + ) + except Exception: + # Best-effort: log and drop on error, never fail user request + logger.warning( + "Failed to process enterprise telemetry envelope, dropping event", + exc_info=True, + ) diff --git a/api/tasks/ops_trace_task.py b/api/tasks/ops_trace_task.py index 72e3b42ca7..c95b8db078 100644 --- a/api/tasks/ops_trace_task.py +++ b/api/tasks/ops_trace_task.py @@ -39,17 +39,36 @@ def process_trace_tasks(file_info): trace_info["documents"] = [Document.model_validate(doc) for doc in trace_info["documents"]] try: + trace_type = trace_info_info_map.get(trace_info_type) + if trace_type: + trace_info = trace_type(**trace_info) + + from 
extensions.ext_enterprise_telemetry import is_enabled as is_ee_telemetry_enabled + + if is_ee_telemetry_enabled(): + from enterprise.telemetry.enterprise_trace import EnterpriseOtelTrace + + try: + EnterpriseOtelTrace().trace(trace_info) + except Exception: + logger.exception("Enterprise trace failed for app_id: %s", app_id) + if trace_instance: with current_app.app_context(): - trace_type = trace_info_info_map.get(trace_info_type) - if trace_type: - trace_info = trace_type(**trace_info) trace_instance.trace(trace_info) + logger.info("Processing trace tasks success, app_id: %s", app_id) except Exception as e: - logger.info("error:\n\n\n%s\n\n\n\n", e) + logger.exception("Processing trace tasks failed, app_id: %s", app_id) failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}" redis_client.incr(failed_key) - logger.info("Processing trace tasks failed, app_id: %s", app_id) finally: - storage.delete(file_path) + try: + storage.delete(file_path) + except Exception as e: + logger.warning( + "Failed to delete trace file %s for app_id %s: %s", + file_path, + app_id, + e, + ) diff --git a/api/tests/unit_tests/core/ops/test_lookup_helpers.py b/api/tests/unit_tests/core/ops/test_lookup_helpers.py new file mode 100644 index 0000000000..86aa68643d --- /dev/null +++ b/api/tests/unit_tests/core/ops/test_lookup_helpers.py @@ -0,0 +1,554 @@ +"""Unit tests for lookup helper functions in core.ops.ops_trace_manager. 
+ +Covers: +- _lookup_app_and_workspace_names +- _lookup_credential_name +- _lookup_llm_credential_info +- TraceTask._get_user_id_from_metadata +""" + +from unittest.mock import MagicMock, patch + +import pytest + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_db_and_session_patches(scalar_side_effect=None, scalar_return_value=None): + """Return (mock_db, cm, session) ready to patch 'core.ops.ops_trace_manager.db' + and 'core.ops.ops_trace_manager.Session'. + + Provide either scalar_side_effect (list, for multiple calls) or + scalar_return_value (single value). + """ + mock_db = MagicMock() + mock_db.engine = MagicMock() + + session = MagicMock() + if scalar_side_effect is not None: + session.scalar.side_effect = scalar_side_effect + else: + session.scalar.return_value = scalar_return_value + + cm = MagicMock() + cm.__enter__ = MagicMock(return_value=session) + cm.__exit__ = MagicMock(return_value=False) + + return mock_db, cm, session + + +# --------------------------------------------------------------------------- +# _lookup_app_and_workspace_names +# --------------------------------------------------------------------------- + + +class TestLookupAppAndWorkspaceNames: + """Tests for _lookup_app_and_workspace_names(app_id, tenant_id).""" + + def test_both_found(self): + """Returns (app_name, workspace_name) when both records exist.""" + from core.ops.ops_trace_manager import _lookup_app_and_workspace_names + + mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=["MyApp", "MyWorkspace"]) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + app_name, workspace_name = _lookup_app_and_workspace_names("app-123", "tenant-456") + + assert app_name == "MyApp" + assert workspace_name == "MyWorkspace" + + def test_app_only_found(self): + 
"""Returns (app_name, '') when tenant record is absent.""" + from core.ops.ops_trace_manager import _lookup_app_and_workspace_names + + mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=["MyApp", None]) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + app_name, workspace_name = _lookup_app_and_workspace_names("app-123", "tenant-456") + + assert app_name == "MyApp" + assert workspace_name == "" + + def test_tenant_only_found(self): + """Returns ('', workspace_name) when app record is absent.""" + from core.ops.ops_trace_manager import _lookup_app_and_workspace_names + + mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=[None, "MyWorkspace"]) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + app_name, workspace_name = _lookup_app_and_workspace_names("app-123", "tenant-456") + + assert app_name == "" + assert workspace_name == "MyWorkspace" + + def test_neither_found(self): + """Returns ('', '') when both DB lookups return None.""" + from core.ops.ops_trace_manager import _lookup_app_and_workspace_names + + mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=[None, None]) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + app_name, workspace_name = _lookup_app_and_workspace_names("app-123", "tenant-456") + + assert app_name == "" + assert workspace_name == "" + + def test_none_inputs_skips_db(self): + """Returns ('', '') immediately when both IDs are None — no DB access.""" + from core.ops.ops_trace_manager import _lookup_app_and_workspace_names + + mock_db = MagicMock() + mock_session_cls = MagicMock() + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", mock_session_cls), + ): + app_name, workspace_name = 
_lookup_app_and_workspace_names(None, None) + + mock_session_cls.assert_not_called() + assert app_name == "" + assert workspace_name == "" + + def test_app_id_none_only_queries_tenant(self): + """When app_id is None, only the tenant query is issued.""" + from core.ops.ops_trace_manager import _lookup_app_and_workspace_names + + mock_db, cm, session = _make_db_and_session_patches(scalar_return_value="OnlyWorkspace") + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + app_name, workspace_name = _lookup_app_and_workspace_names(None, "tenant-456") + + assert app_name == "" + assert workspace_name == "OnlyWorkspace" + assert session.scalar.call_count == 1 + + def test_tenant_id_none_only_queries_app(self): + """When tenant_id is None, only the app query is issued.""" + from core.ops.ops_trace_manager import _lookup_app_and_workspace_names + + mock_db, cm, session = _make_db_and_session_patches(scalar_return_value="OnlyApp") + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + app_name, workspace_name = _lookup_app_and_workspace_names("app-123", None) + + assert app_name == "OnlyApp" + assert workspace_name == "" + assert session.scalar.call_count == 1 + + +# --------------------------------------------------------------------------- +# _lookup_credential_name +# --------------------------------------------------------------------------- + + +class TestLookupCredentialName: + """Tests for _lookup_credential_name(credential_id, provider_type).""" + + @pytest.mark.parametrize("provider_type", ["builtin", "plugin", "api", "workflow", "mcp"]) + def test_known_provider_types_return_name(self, provider_type): + """Each valid provider_type results in a DB query and returns the credential name.""" + from core.ops.ops_trace_manager import _lookup_credential_name + + mock_db, cm, session = 
_make_db_and_session_patches(scalar_return_value="CredentialA") + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + result = _lookup_credential_name("cred-123", provider_type) + + assert result == "CredentialA" + session.scalar.assert_called_once() + + def test_credential_not_found_returns_empty_string(self): + """Returns '' when DB yields None for the given credential_id.""" + from core.ops.ops_trace_manager import _lookup_credential_name + + mock_db, cm, _session = _make_db_and_session_patches(scalar_return_value=None) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + result = _lookup_credential_name("cred-999", "api") + + assert result == "" + + def test_invalid_provider_type_returns_empty_string_without_db(self): + """Returns '' immediately for an unrecognised provider_type — no DB access.""" + from core.ops.ops_trace_manager import _lookup_credential_name + + mock_db = MagicMock() + mock_session_cls = MagicMock() + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", mock_session_cls), + ): + result = _lookup_credential_name("cred-123", "unknown_type") + + mock_session_cls.assert_not_called() + assert result == "" + + def test_none_credential_id_returns_empty_string_without_db(self): + """Returns '' immediately when credential_id is None — no DB access.""" + from core.ops.ops_trace_manager import _lookup_credential_name + + mock_db = MagicMock() + mock_session_cls = MagicMock() + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", mock_session_cls), + ): + result = _lookup_credential_name(None, "api") + + mock_session_cls.assert_not_called() + assert result == "" + + def test_none_provider_type_returns_empty_string_without_db(self): + """Returns '' immediately when provider_type is None — no DB 
access.""" + from core.ops.ops_trace_manager import _lookup_credential_name + + mock_db = MagicMock() + mock_session_cls = MagicMock() + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", mock_session_cls), + ): + result = _lookup_credential_name("cred-123", None) + + mock_session_cls.assert_not_called() + assert result == "" + + def test_builtin_and_plugin_map_to_same_model(self): + """Both 'builtin' and 'plugin' provider_types query BuiltinToolProvider.""" + from core.ops.ops_trace_manager import _PROVIDER_TYPE_TO_MODEL + from models.tools import BuiltinToolProvider + + assert _PROVIDER_TYPE_TO_MODEL["builtin"] is BuiltinToolProvider + assert _PROVIDER_TYPE_TO_MODEL["plugin"] is BuiltinToolProvider + + def test_api_maps_to_api_tool_provider(self): + """'api' maps to ApiToolProvider.""" + from core.ops.ops_trace_manager import _PROVIDER_TYPE_TO_MODEL + from models.tools import ApiToolProvider + + assert _PROVIDER_TYPE_TO_MODEL["api"] is ApiToolProvider + + def test_workflow_maps_to_workflow_tool_provider(self): + """'workflow' maps to WorkflowToolProvider.""" + from core.ops.ops_trace_manager import _PROVIDER_TYPE_TO_MODEL + from models.tools import WorkflowToolProvider + + assert _PROVIDER_TYPE_TO_MODEL["workflow"] is WorkflowToolProvider + + def test_mcp_maps_to_mcp_tool_provider(self): + """'mcp' maps to MCPToolProvider.""" + from core.ops.ops_trace_manager import _PROVIDER_TYPE_TO_MODEL + from models.tools import MCPToolProvider + + assert _PROVIDER_TYPE_TO_MODEL["mcp"] is MCPToolProvider + + +# --------------------------------------------------------------------------- +# _lookup_llm_credential_info +# --------------------------------------------------------------------------- + + +class TestLookupLlmCredentialInfo: + """Tests for _lookup_llm_credential_info(tenant_id, provider, model, model_type).""" + + def _provider_record(self, credential_id: str | None = None) -> MagicMock: + record = MagicMock() + 
record.credential_id = credential_id + return record + + def _model_record(self, credential_id: str | None = None) -> MagicMock: + record = MagicMock() + record.credential_id = credential_id + return record + + def test_model_level_credential_found(self): + """Returns model-level credential_id and name when ProviderModel has a credential.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + provider_record = self._provider_record(credential_id=None) + model_record = self._model_record(credential_id="model-cred-id") + + # scalar calls: (1) Provider, (2) ProviderModel, (3) ProviderModelCredential.credential_name + mock_db, cm, _session = _make_db_and_session_patches( + scalar_side_effect=[provider_record, model_record, "ModelCredName"] + ) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4") + + assert cred_id == "model-cred-id" + assert cred_name == "ModelCredName" + + def test_provider_level_fallback_when_no_model_credential(self): + """Falls back to provider-level credential when ProviderModel has no credential_id.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + provider_record = self._provider_record(credential_id="prov-cred-id") + model_record = self._model_record(credential_id=None) + + # scalar calls: (1) Provider, (2) ProviderModel (no cred), (3) ProviderCredential.credential_name + mock_db, cm, _session = _make_db_and_session_patches( + scalar_side_effect=[provider_record, model_record, "ProvCredName"] + ) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4") + + assert cred_id == "prov-cred-id" + assert cred_name == "ProvCredName" + + def test_provider_level_fallback_when_no_model_record(self): 
+ """Falls back to provider-level credential when no ProviderModel row exists.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + provider_record = self._provider_record(credential_id="prov-cred-id") + + # scalar calls: (1) Provider, (2) ProviderModel → None, (3) ProviderCredential.credential_name + mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=[provider_record, None, "ProvCredName"]) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4") + + assert cred_id == "prov-cred-id" + assert cred_name == "ProvCredName" + + def test_no_model_arg_uses_provider_level_only(self): + """When model is None, skips ProviderModel query and uses provider credential.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + provider_record = self._provider_record(credential_id="prov-cred-id") + + # scalar calls: (1) Provider, (2) ProviderCredential.credential_name — no ProviderModel + mock_db, cm, session = _make_db_and_session_patches(scalar_side_effect=[provider_record, "ProvCredName"]) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", None) + + assert cred_id == "prov-cred-id" + assert cred_name == "ProvCredName" + assert session.scalar.call_count == 2 + + def test_provider_not_found_returns_none_and_empty(self): + """Returns (None, '') when Provider record does not exist.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + mock_db, cm, _session = _make_db_and_session_patches(scalar_return_value=None) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = 
_lookup_llm_credential_info("tenant-1", "openai", "gpt-4") + + assert cred_id is None + assert cred_name == "" + + def test_none_tenant_id_returns_none_and_empty_without_db(self): + """Returns (None, '') immediately when tenant_id is None — no DB access.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + mock_db = MagicMock() + mock_session_cls = MagicMock() + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", mock_session_cls), + ): + cred_id, cred_name = _lookup_llm_credential_info(None, "openai", "gpt-4") + + mock_session_cls.assert_not_called() + assert cred_id is None + assert cred_name == "" + + def test_none_provider_returns_none_and_empty_without_db(self): + """Returns (None, '') immediately when provider is None — no DB access.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + mock_db = MagicMock() + mock_session_cls = MagicMock() + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", mock_session_cls), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", None, "gpt-4") + + mock_session_cls.assert_not_called() + assert cred_id is None + assert cred_name == "" + + def test_db_error_on_outer_query_returns_none_and_empty(self): + """Returns (None, '') and logs a warning when the outer DB query raises.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + mock_db, cm, session = _make_db_and_session_patches() + session.scalar.side_effect = Exception("DB connection failed") + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4") + + assert cred_id is None + assert cred_name == "" + + def test_credential_name_lookup_failure_returns_id_with_empty_name(self): + """When credential name sub-query fails, returns cred_id 
but '' for name.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + provider_record = self._provider_record(credential_id="prov-cred-id") + + # Provider found, no model record, then name lookup raises + mock_db, cm, _session = _make_db_and_session_patches( + scalar_side_effect=[provider_record, None, Exception("deleted")] + ) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4") + + assert cred_id == "prov-cred-id" + assert cred_name == "" + + def test_no_credential_on_provider_or_model_returns_none_id(self): + """Returns (None, '') when neither provider nor model has a credential_id.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + provider_record = self._provider_record(credential_id=None) + model_record = self._model_record(credential_id=None) + + mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=[provider_record, model_record]) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4") + + assert cred_id is None + assert cred_name == "" + + +# --------------------------------------------------------------------------- +# TraceTask._get_user_id_from_metadata +# --------------------------------------------------------------------------- + + +class TestGetUserIdFromMetadata: + """Tests for TraceTask._get_user_id_from_metadata(metadata). + + Pure dict logic — no DB access required. 
+ """ + + @pytest.fixture + def get_user_id(self): + """Return the classmethod under test.""" + from core.ops.ops_trace_manager import TraceTask + + return TraceTask._get_user_id_from_metadata + + def test_from_end_user_id_has_highest_priority(self, get_user_id): + """from_end_user_id takes precedence over all other keys.""" + metadata = { + "from_end_user_id": "eu-abc", + "from_account_id": "acc-xyz", + "user_id": "u-123", + } + assert get_user_id(metadata) == "end_user:eu-abc" + + def test_from_account_id_used_when_no_end_user(self, get_user_id): + """from_account_id is used when from_end_user_id is absent.""" + metadata = { + "from_account_id": "acc-xyz", + "user_id": "u-123", + } + assert get_user_id(metadata) == "account:acc-xyz" + + def test_user_id_used_when_no_end_user_or_account(self, get_user_id): + """user_id is used when both higher-priority keys are absent.""" + metadata = {"user_id": "u-123"} + assert get_user_id(metadata) == "user:u-123" + + def test_returns_anonymous_when_all_keys_absent(self, get_user_id): + """Returns 'anonymous' when metadata has none of the expected keys.""" + assert get_user_id({}) == "anonymous" + + def test_empty_string_end_user_id_is_skipped(self, get_user_id): + """Empty string for from_end_user_id is falsy and falls through to next key.""" + metadata = { + "from_end_user_id": "", + "from_account_id": "acc-xyz", + } + assert get_user_id(metadata) == "account:acc-xyz" + + def test_empty_string_account_id_is_skipped(self, get_user_id): + """Empty string for from_account_id is falsy and falls through to user_id.""" + metadata = { + "from_end_user_id": "", + "from_account_id": "", + "user_id": "u-123", + } + assert get_user_id(metadata) == "user:u-123" + + def test_empty_string_user_id_falls_through_to_anonymous(self, get_user_id): + """Empty string for user_id is falsy, so 'anonymous' is returned.""" + metadata = { + "from_end_user_id": "", + "from_account_id": "", + "user_id": "", + } + assert get_user_id(metadata) == 
"anonymous" + + def test_only_from_end_user_id_present(self, get_user_id): + """Minimal case: only from_end_user_id present.""" + assert get_user_id({"from_end_user_id": "eu-only"}) == "end_user:eu-only" + + def test_irrelevant_keys_do_not_interfere(self, get_user_id): + """Extra metadata keys have no effect on the result.""" + metadata = {"invoke_from": "web", "app_id": "a1"} + assert get_user_id(metadata) == "anonymous" diff --git a/api/tests/unit_tests/core/ops/test_ops_trace_manager.py b/api/tests/unit_tests/core/ops/test_ops_trace_manager.py index 2d325ccb0e..f81806c941 100644 --- a/api/tests/unit_tests/core/ops/test_ops_trace_manager.py +++ b/api/tests/unit_tests/core/ops/test_ops_trace_manager.py @@ -86,6 +86,7 @@ def make_message_data(**overrides): created_at = datetime(2025, 2, 20, 12, 0, 0) base = { "id": "msg-id", + "app_id": "app-id", "conversation_id": "conv-id", "created_at": created_at, "updated_at": created_at + timedelta(seconds=3), @@ -182,6 +183,9 @@ class DummySessionContext: def __exit__(self, exc_type, exc_val, exc_tb): return False + def execute(self, *args, **kwargs): + return self + def scalar(self, *args, **kwargs): if self._index >= len(self._values): return None @@ -189,6 +193,12 @@ class DummySessionContext: self._index += 1 return value + def scalars(self, *args, **kwargs): + return self + + def all(self): + return [] + @pytest.fixture(autouse=True) def patch_provider_map(monkeypatch): @@ -454,7 +464,7 @@ def test_trace_task_message_trace(trace_task_message, mock_db): def test_trace_task_workflow_trace(workflow_repo_fixture, mock_db): DummySessionContext.scalar_values = ["wf-app-log", "message-ref"] - execution = SimpleNamespace(id_="run-id") + execution = SimpleNamespace(id_="run-id", total_tokens=0) task = TraceTask( trace_type=TraceTaskName.WORKFLOW_TRACE, workflow_execution=execution, conversation_id="conv", user_id="user" ) diff --git a/api/tests/unit_tests/core/ops/test_trace_queue_manager.py 
b/api/tests/unit_tests/core/ops/test_trace_queue_manager.py new file mode 100644 index 0000000000..a4903054e0 --- /dev/null +++ b/api/tests/unit_tests/core/ops/test_trace_queue_manager.py @@ -0,0 +1,194 @@ +"""Unit tests for TraceQueueManager telemetry guard. + +Verifies that TraceQueueManager.add_trace_task() only enqueues tasks when at +least one consumer is active: +- Enterprise telemetry is enabled (_enterprise_telemetry_enabled=True), OR +- A third-party trace instance (Langfuse, etc.) is configured + +When neither is active, tasks are silently dropped to avoid unnecessary work. + +Dropping silently (rather than raising) keeps producers unaffected when no consumer is configured. +""" + +import queue +import sys +import types +from unittest.mock import MagicMock, patch + +import pytest + + +@pytest.fixture +def trace_queue_manager_and_task(monkeypatch): + """Fixture to provide TraceQueueManager and TraceTask with delayed imports.""" + module_name = "core.ops.ops_trace_manager" + if module_name not in sys.modules: + ops_stub = types.ModuleType(module_name) + + class StubTraceTask: + def __init__(self, trace_type): + self.trace_type = trace_type + self.app_id = None + + class StubTraceQueueManager: + def __init__(self, app_id=None): + self.app_id = app_id + from core.telemetry.gateway import is_enterprise_telemetry_enabled + + self._enterprise_telemetry_enabled = is_enterprise_telemetry_enabled() + self.trace_instance = StubOpsTraceManager.get_ops_trace_instance(app_id) + + def add_trace_task(self, trace_task): + if self._enterprise_telemetry_enabled or self.trace_instance: + trace_task.app_id = self.app_id + from core.ops.ops_trace_manager import trace_manager_queue + + trace_manager_queue.put(trace_task) + + class StubOpsTraceManager: + @staticmethod + def get_ops_trace_instance(app_id): + return None + + ops_stub.TraceQueueManager = StubTraceQueueManager + ops_stub.TraceTask = StubTraceTask + ops_stub.OpsTraceManager = StubOpsTraceManager + ops_stub.trace_manager_queue = 
MagicMock(spec=queue.Queue) + monkeypatch.setitem(sys.modules, module_name, ops_stub) + + from core.ops.entities.trace_entity import TraceTaskName + + ops_module = __import__(module_name, fromlist=["TraceQueueManager", "TraceTask"]) + TraceQueueManager = ops_module.TraceQueueManager + TraceTask = ops_module.TraceTask + + return TraceQueueManager, TraceTask, TraceTaskName + + +class TestTraceQueueManagerTelemetryGuard: + """Test TraceQueueManager's telemetry guard in add_trace_task().""" + + def test_task_not_enqueued_when_telemetry_disabled_and_no_trace_instance(self, trace_queue_manager_and_task): + """Verify task is NOT enqueued when telemetry disabled and no trace instance. + + This is the core guard: when _enterprise_telemetry_enabled=False AND + trace_instance=None, the task should be silently dropped. + """ + TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task + + mock_queue = MagicMock(spec=queue.Queue) + + trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE) + + with ( + patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False), + patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None), + patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue), + ): + manager = TraceQueueManager(app_id="test-app-id") + manager.add_trace_task(trace_task) + + mock_queue.put.assert_not_called() + + def test_task_enqueued_when_telemetry_enabled(self, trace_queue_manager_and_task): + """Verify task IS enqueued when enterprise telemetry is enabled. + + When _enterprise_telemetry_enabled=True, the task should be enqueued + regardless of trace_instance state. 
+ """ + TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task + + mock_queue = MagicMock(spec=queue.Queue) + + trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE) + + with ( + patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True), + patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None), + patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue), + ): + manager = TraceQueueManager(app_id="test-app-id") + manager.add_trace_task(trace_task) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.app_id == "test-app-id" + + def test_task_enqueued_when_trace_instance_configured(self, trace_queue_manager_and_task): + """Verify task IS enqueued when third-party trace instance is configured. + + When trace_instance is not None (e.g., Langfuse configured), the task + should be enqueued even if enterprise telemetry is disabled. + """ + TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task + + mock_queue = MagicMock(spec=queue.Queue) + + mock_trace_instance = MagicMock() + + trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE) + + with ( + patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False), + patch( + "core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=mock_trace_instance + ), + patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue), + ): + manager = TraceQueueManager(app_id="test-app-id") + manager.add_trace_task(trace_task) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.app_id == "test-app-id" + + def test_task_enqueued_when_both_telemetry_and_trace_instance_enabled(self, trace_queue_manager_and_task): + """Verify task IS enqueued when both telemetry and trace instance are enabled. 
+ + When both _enterprise_telemetry_enabled=True AND trace_instance is set, + the task should definitely be enqueued. + """ + TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task + + mock_queue = MagicMock(spec=queue.Queue) + + mock_trace_instance = MagicMock() + + trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE) + + with ( + patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True), + patch( + "core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=mock_trace_instance + ), + patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue), + ): + manager = TraceQueueManager(app_id="test-app-id") + manager.add_trace_task(trace_task) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.app_id == "test-app-id" + + def test_app_id_set_before_enqueue(self, trace_queue_manager_and_task): + """Verify app_id is set on the task before enqueuing. + + The guard logic sets trace_task.app_id = self.app_id before calling + trace_manager_queue.put(trace_task). This test verifies that behavior. 
+ """ + TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task + + mock_queue = MagicMock(spec=queue.Queue) + + trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE) + + with ( + patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True), + patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None), + patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue), + ): + manager = TraceQueueManager(app_id="expected-app-id") + manager.add_trace_task(trace_task) + + called_task = mock_queue.put.call_args[0][0] + assert called_task.app_id == "expected-app-id" diff --git a/api/tests/unit_tests/core/telemetry/test_facade.py b/api/tests/unit_tests/core/telemetry/test_facade.py new file mode 100644 index 0000000000..36e8e1bbb1 --- /dev/null +++ b/api/tests/unit_tests/core/telemetry/test_facade.py @@ -0,0 +1,181 @@ +"""Unit tests for core.telemetry.emit() routing and enterprise-only filtering.""" + +from __future__ import annotations + +import queue +import sys +import types +from unittest.mock import MagicMock, patch + +import pytest + +from core.ops.entities.trace_entity import TraceTaskName +from core.telemetry.events import TelemetryContext, TelemetryEvent + + +@pytest.fixture +def telemetry_test_setup(monkeypatch): + module_name = "core.ops.ops_trace_manager" + ops_stub = types.ModuleType(module_name) + + class StubTraceTask: + def __init__(self, trace_type, **kwargs): + self.trace_type = trace_type + self.app_id = None + self.kwargs = kwargs + + class StubTraceQueueManager: + def __init__(self, app_id=None, user_id=None): + self.app_id = app_id + self.user_id = user_id + self.trace_instance = StubOpsTraceManager.get_ops_trace_instance(app_id) + + def add_trace_task(self, trace_task): + trace_task.app_id = self.app_id + from core.ops.ops_trace_manager import trace_manager_queue + + trace_manager_queue.put(trace_task) + + class StubOpsTraceManager: + @staticmethod + def 
get_ops_trace_instance(app_id): + return None + + ops_stub.TraceQueueManager = StubTraceQueueManager + ops_stub.TraceTask = StubTraceTask + ops_stub.OpsTraceManager = StubOpsTraceManager + ops_stub.trace_manager_queue = MagicMock(spec=queue.Queue) + monkeypatch.setitem(sys.modules, module_name, ops_stub) + + from core.telemetry import emit + + return emit, ops_stub.trace_manager_queue + + +class TestTelemetryEmit: + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_emit_enterprise_trace_creates_trace_task(self, mock_ee, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + event = TelemetryEvent( + name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={"key": "value"}, + ) + + emit_fn(event) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.trace_type == TraceTaskName.DRAFT_NODE_EXECUTION_TRACE + + def test_emit_community_trace_enqueued(self, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + event = TelemetryEvent( + name=TraceTaskName.WORKFLOW_TRACE, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={}, + ) + + emit_fn(event) + + mock_queue.put.assert_called_once() + + def test_emit_enterprise_only_trace_dropped_when_ee_disabled(self, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + event = TelemetryEvent( + name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={}, + ) + + emit_fn(event) + + mock_queue.put.assert_not_called() + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_emit_all_enterprise_only_traces_allowed_when_ee_enabled(self, mock_ee, telemetry_test_setup): + emit_fn, 
mock_queue = telemetry_test_setup + + enterprise_only_traces = [ + TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + TraceTaskName.NODE_EXECUTION_TRACE, + TraceTaskName.PROMPT_GENERATION_TRACE, + ] + + for trace_name in enterprise_only_traces: + mock_queue.reset_mock() + + event = TelemetryEvent( + name=trace_name, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={}, + ) + + emit_fn(event) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.trace_type == trace_name + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_emit_passes_name_directly_to_trace_task(self, mock_ee, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + event = TelemetryEvent( + name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={"extra": "data"}, + ) + + emit_fn(event) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.trace_type == TraceTaskName.DRAFT_NODE_EXECUTION_TRACE + assert isinstance(called_task.trace_type, TraceTaskName) + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_emit_with_provided_trace_manager(self, mock_ee, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + mock_trace_manager = MagicMock() + mock_trace_manager.add_trace_task = MagicMock() + + event = TelemetryEvent( + name=TraceTaskName.NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={}, + ) + + emit_fn(event, trace_manager=mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + called_task = mock_trace_manager.add_trace_task.call_args[0][0] + assert called_task.trace_type == 
TraceTaskName.NODE_EXECUTION_TRACE diff --git a/api/tests/unit_tests/core/telemetry/test_gateway_integration.py b/api/tests/unit_tests/core/telemetry/test_gateway_integration.py new file mode 100644 index 0000000000..a68fce5e7f --- /dev/null +++ b/api/tests/unit_tests/core/telemetry/test_gateway_integration.py @@ -0,0 +1,225 @@ +from __future__ import annotations + +import sys +from unittest.mock import MagicMock, patch + +import pytest + +from core.telemetry.gateway import emit, is_enterprise_telemetry_enabled +from enterprise.telemetry.contracts import TelemetryCase + + +class TestTelemetryCoreExports: + def test_is_enterprise_telemetry_enabled_exported(self) -> None: + from core.telemetry.gateway import is_enterprise_telemetry_enabled as exported_func + + assert callable(exported_func) + + +@pytest.fixture +def mock_ops_trace_manager(): + mock_module = MagicMock() + mock_trace_task_class = MagicMock() + mock_trace_task_class.return_value = MagicMock() + mock_module.TraceTask = mock_trace_task_class + mock_module.TraceQueueManager = MagicMock() + + mock_trace_entity = MagicMock() + mock_trace_task_name = MagicMock() + mock_trace_task_name.return_value = "workflow" + mock_trace_entity.TraceTaskName = mock_trace_task_name + + with ( + patch.dict(sys.modules, {"core.ops.ops_trace_manager": mock_module}), + patch.dict(sys.modules, {"core.ops.entities.trace_entity": mock_trace_entity}), + ): + yield mock_module, mock_trace_entity + + +class TestGatewayIntegrationTraceRouting: + @pytest.fixture + def mock_trace_manager(self) -> MagicMock: + return MagicMock() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_ce_eligible_trace_routed_to_trace_manager( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True): + context = {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"} + payload = {"workflow_run_id": "run-abc"} + + emit(TelemetryCase.WORKFLOW_RUN, 
context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_ce_eligible_trace_routed_when_ee_disabled( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"workflow_run_id": "run-abc"} + + emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_enterprise_only_trace_dropped_when_ee_disabled( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_id": "node-abc"} + + emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_not_called() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_enterprise_only_trace_routed_when_ee_enabled( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_id": "node-abc"} + + emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + +class TestGatewayIntegrationMetricRouting: + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_metric_case_routes_to_celery_task( + self, + mock_ee_enabled: MagicMock, + ) -> None: + from enterprise.telemetry.contracts import TelemetryEnvelope + + with patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") as mock_delay: + context = {"tenant_id": "tenant-123"} + 
payload = {"app_id": "app-abc", "name": "My App"} + + emit(TelemetryCase.APP_CREATED, context, payload) + + mock_delay.assert_called_once() + envelope_json = mock_delay.call_args[0][0] + envelope = TelemetryEnvelope.model_validate_json(envelope_json) + assert envelope.case == TelemetryCase.APP_CREATED + assert envelope.tenant_id == "tenant-123" + assert envelope.payload["app_id"] == "app-abc" + + @pytest.mark.usefixtures("mock_ops_trace_manager") + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_tool_execution_trace_routed( + self, + mock_ee_enabled: MagicMock, + ) -> None: + mock_trace_manager = MagicMock() + context = {"tenant_id": "tenant-123", "app_id": "app-123"} + payload = {"tool_name": "test_tool", "tool_inputs": {}, "tool_outputs": "result"} + + emit(TelemetryCase.TOOL_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_moderation_check_trace_routed( + self, + mock_ee_enabled: MagicMock, + ) -> None: + mock_trace_manager = MagicMock() + context = {"tenant_id": "tenant-123", "app_id": "app-123"} + payload = {"message_id": "msg-123", "moderation_result": {"flagged": False}} + + emit(TelemetryCase.MODERATION_CHECK, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + +class TestGatewayIntegrationCEEligibility: + @pytest.fixture + def mock_trace_manager(self) -> MagicMock: + return MagicMock() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_workflow_run_is_ce_eligible( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"workflow_run_id": "run-abc"} + + 
emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_message_run_is_ce_eligible( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"message_id": "msg-abc", "conversation_id": "conv-123"} + + emit(TelemetryCase.MESSAGE_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_node_execution_not_ce_eligible( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_id": "node-abc"} + + emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_not_called() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_draft_node_execution_not_ce_eligible( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_execution_data": {}} + + emit(TelemetryCase.DRAFT_NODE_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_not_called() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_prompt_generation_not_ce_eligible( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"} + payload = {"operation_type": "generate", "instruction": "test"} + + 
emit(TelemetryCase.PROMPT_GENERATION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_not_called() + + +class TestIsEnterpriseTelemetryEnabled: + def test_returns_false_when_exporter_import_fails(self) -> None: + with patch.dict(sys.modules, {"enterprise.telemetry.exporter": None}): + result = is_enterprise_telemetry_enabled() + assert result is False + + def test_function_is_callable(self) -> None: + assert callable(is_enterprise_telemetry_enabled) diff --git a/api/tests/unit_tests/enterprise/telemetry/__init__.py b/api/tests/unit_tests/enterprise/telemetry/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/enterprise/telemetry/test_contracts.py b/api/tests/unit_tests/enterprise/telemetry/test_contracts.py new file mode 100644 index 0000000000..7453525bfc --- /dev/null +++ b/api/tests/unit_tests/enterprise/telemetry/test_contracts.py @@ -0,0 +1,230 @@ +"""Unit tests for telemetry gateway contracts.""" + +from __future__ import annotations + +import pytest +from pydantic import ValidationError + +from core.telemetry.gateway import CASE_ROUTING +from enterprise.telemetry.contracts import CaseRoute, SignalType, TelemetryCase, TelemetryEnvelope + + +class TestTelemetryCase: + """Tests for TelemetryCase enum.""" + + def test_all_cases_defined(self) -> None: + """Verify all 14 telemetry cases are defined.""" + expected_cases = { + "WORKFLOW_RUN", + "NODE_EXECUTION", + "DRAFT_NODE_EXECUTION", + "MESSAGE_RUN", + "TOOL_EXECUTION", + "MODERATION_CHECK", + "SUGGESTED_QUESTION", + "DATASET_RETRIEVAL", + "GENERATE_NAME", + "PROMPT_GENERATION", + "APP_CREATED", + "APP_UPDATED", + "APP_DELETED", + "FEEDBACK_CREATED", + } + actual_cases = {case.name for case in TelemetryCase} + assert actual_cases == expected_cases + + def test_case_values(self) -> None: + """Verify case enum values are correct.""" + assert TelemetryCase.WORKFLOW_RUN.value == "workflow_run" + assert TelemetryCase.NODE_EXECUTION.value 
== "node_execution" + assert TelemetryCase.DRAFT_NODE_EXECUTION.value == "draft_node_execution" + assert TelemetryCase.MESSAGE_RUN.value == "message_run" + assert TelemetryCase.TOOL_EXECUTION.value == "tool_execution" + assert TelemetryCase.MODERATION_CHECK.value == "moderation_check" + assert TelemetryCase.SUGGESTED_QUESTION.value == "suggested_question" + assert TelemetryCase.DATASET_RETRIEVAL.value == "dataset_retrieval" + assert TelemetryCase.GENERATE_NAME.value == "generate_name" + assert TelemetryCase.PROMPT_GENERATION.value == "prompt_generation" + assert TelemetryCase.APP_CREATED.value == "app_created" + assert TelemetryCase.APP_UPDATED.value == "app_updated" + assert TelemetryCase.APP_DELETED.value == "app_deleted" + assert TelemetryCase.FEEDBACK_CREATED.value == "feedback_created" + + +class TestCaseRoute: + """Tests for CaseRoute model.""" + + def test_valid_trace_route(self) -> None: + """Verify valid trace route creation.""" + route = CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True) + assert route.signal_type == SignalType.TRACE + assert route.ce_eligible is True + + def test_valid_metric_log_route(self) -> None: + """Verify valid metric_log route creation.""" + route = CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False) + assert route.signal_type == SignalType.METRIC_LOG + assert route.ce_eligible is False + + def test_invalid_signal_type(self) -> None: + """Verify invalid signal_type is rejected.""" + with pytest.raises(ValidationError): + CaseRoute(signal_type="invalid", ce_eligible=True) + + +class TestTelemetryEnvelope: + """Tests for TelemetryEnvelope model.""" + + def test_valid_envelope_minimal(self) -> None: + """Verify valid minimal envelope creation.""" + envelope = TelemetryEnvelope( + case=TelemetryCase.WORKFLOW_RUN, + tenant_id="tenant-123", + event_id="event-456", + payload={"key": "value"}, + ) + assert envelope.case == TelemetryCase.WORKFLOW_RUN + assert envelope.tenant_id == "tenant-123" + assert envelope.event_id 
== "event-456" + assert envelope.payload == {"key": "value"} + assert envelope.metadata is None + + def test_valid_envelope_full(self) -> None: + """Verify valid envelope with all fields.""" + metadata = {"payload_ref": "telemetry/tenant-789/event-012.json"} + envelope = TelemetryEnvelope( + case=TelemetryCase.MESSAGE_RUN, + tenant_id="tenant-789", + event_id="event-012", + payload={"message": "hello"}, + metadata=metadata, + ) + assert envelope.case == TelemetryCase.MESSAGE_RUN + assert envelope.tenant_id == "tenant-789" + assert envelope.event_id == "event-012" + assert envelope.payload == {"message": "hello"} + assert envelope.metadata == metadata + + def test_missing_required_case(self) -> None: + """Verify missing case field is rejected.""" + with pytest.raises(ValidationError): + TelemetryEnvelope( + tenant_id="tenant-123", + event_id="event-456", + payload={"key": "value"}, + ) + + def test_missing_required_tenant_id(self) -> None: + """Verify missing tenant_id field is rejected.""" + with pytest.raises(ValidationError): + TelemetryEnvelope( + case=TelemetryCase.WORKFLOW_RUN, + event_id="event-456", + payload={"key": "value"}, + ) + + def test_missing_required_event_id(self) -> None: + """Verify missing event_id field is rejected.""" + with pytest.raises(ValidationError): + TelemetryEnvelope( + case=TelemetryCase.WORKFLOW_RUN, + tenant_id="tenant-123", + payload={"key": "value"}, + ) + + def test_missing_required_payload(self) -> None: + """Verify missing payload field is rejected.""" + with pytest.raises(ValidationError): + TelemetryEnvelope( + case=TelemetryCase.WORKFLOW_RUN, + tenant_id="tenant-123", + event_id="event-456", + ) + + def test_metadata_none(self) -> None: + """Verify metadata can be None.""" + envelope = TelemetryEnvelope( + case=TelemetryCase.WORKFLOW_RUN, + tenant_id="tenant-123", + event_id="event-456", + payload={"key": "value"}, + metadata=None, + ) + assert envelope.metadata is None + + +class TestCaseRouting: + """Tests for 
CASE_ROUTING table.""" + + def test_all_cases_routed(self) -> None: + """Verify all 14 cases have routing entries.""" + assert len(CASE_ROUTING) == 14 + for case in TelemetryCase: + assert case in CASE_ROUTING + + def test_trace_ce_eligible_cases(self) -> None: + """Verify trace cases with CE eligibility.""" + ce_eligible_trace_cases = { + TelemetryCase.WORKFLOW_RUN, + TelemetryCase.MESSAGE_RUN, + } + for case in ce_eligible_trace_cases: + route = CASE_ROUTING[case] + assert route.signal_type == SignalType.TRACE + assert route.ce_eligible is True + + def test_trace_enterprise_only_cases(self) -> None: + """Verify trace cases that are enterprise-only.""" + enterprise_only_trace_cases = { + TelemetryCase.NODE_EXECUTION, + TelemetryCase.DRAFT_NODE_EXECUTION, + TelemetryCase.PROMPT_GENERATION, + } + for case in enterprise_only_trace_cases: + route = CASE_ROUTING[case] + assert route.signal_type == SignalType.TRACE + assert route.ce_eligible is False + + def test_metric_log_cases(self) -> None: + """Verify metric/log-only cases.""" + metric_log_cases = { + TelemetryCase.APP_CREATED, + TelemetryCase.APP_UPDATED, + TelemetryCase.APP_DELETED, + TelemetryCase.FEEDBACK_CREATED, + } + for case in metric_log_cases: + route = CASE_ROUTING[case] + assert route.signal_type == SignalType.METRIC_LOG + assert route.ce_eligible is False + + def test_routing_table_completeness(self) -> None: + """Verify routing table covers all cases with correct types.""" + trace_cases = { + TelemetryCase.WORKFLOW_RUN, + TelemetryCase.MESSAGE_RUN, + TelemetryCase.NODE_EXECUTION, + TelemetryCase.DRAFT_NODE_EXECUTION, + TelemetryCase.PROMPT_GENERATION, + TelemetryCase.TOOL_EXECUTION, + TelemetryCase.MODERATION_CHECK, + TelemetryCase.SUGGESTED_QUESTION, + TelemetryCase.DATASET_RETRIEVAL, + TelemetryCase.GENERATE_NAME, + } + metric_log_cases = { + TelemetryCase.APP_CREATED, + TelemetryCase.APP_UPDATED, + TelemetryCase.APP_DELETED, + TelemetryCase.FEEDBACK_CREATED, + } + + all_cases = trace_cases | 
metric_log_cases + assert len(all_cases) == 14 + assert all_cases == set(TelemetryCase) + + for case in trace_cases: + assert CASE_ROUTING[case].signal_type == SignalType.TRACE + + for case in metric_log_cases: + assert CASE_ROUTING[case].signal_type == SignalType.METRIC_LOG diff --git a/api/tests/unit_tests/enterprise/telemetry/test_draft_trace.py b/api/tests/unit_tests/enterprise/telemetry/test_draft_trace.py new file mode 100644 index 0000000000..c8c8de8595 --- /dev/null +++ b/api/tests/unit_tests/enterprise/telemetry/test_draft_trace.py @@ -0,0 +1,519 @@ +"""Unit tests for enterprise/telemetry/draft_trace.py.""" + +from __future__ import annotations + +from datetime import UTC, datetime +from unittest.mock import MagicMock, patch + +from graphon.enums import WorkflowNodeExecutionMetadataKey + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_execution(**overrides) -> MagicMock: + """Return a minimal WorkflowNodeExecutionModel mock.""" + execution = MagicMock() + execution.tenant_id = overrides.get("tenant_id", "tenant-1") + execution.app_id = overrides.get("app_id", "app-1") + execution.workflow_id = overrides.get("workflow_id", "wf-1") + execution.id = overrides.get("id", "exec-1") + execution.node_id = overrides.get("node_id", "node-1") + execution.node_type = overrides.get("node_type", "llm") + execution.title = overrides.get("title", "My LLM Node") + execution.status = overrides.get("status", "succeeded") + execution.error = overrides.get("error") + execution.elapsed_time = overrides.get("elapsed_time", 1.5) + execution.index = overrides.get("index", 1) + execution.predecessor_node_id = overrides.get("predecessor_node_id") + execution.created_at = overrides.get("created_at", datetime(2024, 1, 1, tzinfo=UTC)) + execution.finished_at = overrides.get("finished_at", datetime(2024, 1, 1, 0, 0, 5, tzinfo=UTC)) + 
execution.workflow_run_id = overrides.get("workflow_run_id", "run-1") + execution.inputs_dict = overrides.get("inputs_dict", {"prompt": "hello"}) + execution.outputs_dict = overrides.get("outputs_dict", {"answer": "world"}) + execution.process_data_dict = overrides.get("process_data_dict", {}) + execution.execution_metadata_dict = overrides.get("execution_metadata_dict", {}) + return execution + + +# --------------------------------------------------------------------------- +# _build_node_execution_data +# --------------------------------------------------------------------------- + + +class TestBuildNodeExecutionData: + def test_basic_fields_populated(self) -> None: + from enterprise.telemetry.draft_trace import _build_node_execution_data + + execution = _make_execution() + result = _build_node_execution_data( + execution=execution, + outputs=None, + workflow_execution_id="run-override", + ) + + assert result["workflow_id"] == "wf-1" + assert result["tenant_id"] == "tenant-1" + assert result["app_id"] == "app-1" + assert result["node_execution_id"] == "exec-1" + assert result["node_id"] == "node-1" + assert result["node_type"] == "llm" + assert result["title"] == "My LLM Node" + assert result["status"] == "succeeded" + assert result["error"] is None + assert result["elapsed_time"] == 1.5 + assert result["index"] == 1 + + def test_workflow_execution_id_prefers_parameter(self) -> None: + from enterprise.telemetry.draft_trace import _build_node_execution_data + + execution = _make_execution(workflow_run_id="run-from-model") + result = _build_node_execution_data( + execution=execution, + outputs=None, + workflow_execution_id="explicit-run", + ) + assert result["workflow_execution_id"] == "explicit-run" + + def test_workflow_execution_id_falls_back_to_run_id(self) -> None: + from enterprise.telemetry.draft_trace import _build_node_execution_data + + execution = _make_execution(workflow_run_id="run-from-model") + result = _build_node_execution_data( + 
execution=execution, + outputs=None, + workflow_execution_id=None, + ) + assert result["workflow_execution_id"] == "run-from-model" + + def test_workflow_execution_id_falls_back_to_execution_id(self) -> None: + from enterprise.telemetry.draft_trace import _build_node_execution_data + + execution = _make_execution(workflow_run_id=None, id="exec-fallback") + result = _build_node_execution_data( + execution=execution, + outputs=None, + workflow_execution_id=None, + ) + assert result["workflow_execution_id"] == "exec-fallback" + + def test_outputs_param_overrides_execution_outputs(self) -> None: + from enterprise.telemetry.draft_trace import _build_node_execution_data + + execution = _make_execution(outputs_dict={"from_model": True}) + result = _build_node_execution_data( + execution=execution, + outputs={"from_param": True}, + workflow_execution_id=None, + ) + assert result["node_outputs"] == {"from_param": True} + + def test_outputs_none_uses_execution_outputs_dict(self) -> None: + from enterprise.telemetry.draft_trace import _build_node_execution_data + + execution = _make_execution(outputs_dict={"from_model": True}) + result = _build_node_execution_data( + execution=execution, + outputs=None, + workflow_execution_id=None, + ) + assert result["node_outputs"] == {"from_model": True} + + def test_metadata_token_fields_default_to_zero(self) -> None: + from enterprise.telemetry.draft_trace import _build_node_execution_data + + execution = _make_execution(execution_metadata_dict={}) + result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None) + + assert result["total_tokens"] == 0 + assert result["total_price"] == 0.0 + assert result["currency"] is None + + def test_metadata_token_fields_populated_from_metadata(self) -> None: + from enterprise.telemetry.draft_trace import _build_node_execution_data + + metadata = { + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 200, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: 0.05, + 
WorkflowNodeExecutionMetadataKey.CURRENCY: "USD", + } + execution = _make_execution(execution_metadata_dict=metadata) + result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None) + + assert result["total_tokens"] == 200 + assert result["total_price"] == 0.05 + assert result["currency"] == "USD" + + def test_tool_name_extracted_from_tool_info_dict(self) -> None: + from enterprise.telemetry.draft_trace import _build_node_execution_data + + metadata = { + WorkflowNodeExecutionMetadataKey.TOOL_INFO: {"tool_name": "web_search"}, + } + execution = _make_execution(execution_metadata_dict=metadata) + result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None) + + assert result["tool_name"] == "web_search" + + def test_tool_name_is_none_when_tool_info_not_dict(self) -> None: + from enterprise.telemetry.draft_trace import _build_node_execution_data + + metadata = {WorkflowNodeExecutionMetadataKey.TOOL_INFO: "not-a-dict"} + execution = _make_execution(execution_metadata_dict=metadata) + result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None) + + assert result["tool_name"] is None + + def test_tool_name_is_none_when_tool_info_absent(self) -> None: + from enterprise.telemetry.draft_trace import _build_node_execution_data + + execution = _make_execution(execution_metadata_dict={}) + result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None) + + assert result["tool_name"] is None + + def test_iteration_and_loop_fields(self) -> None: + from enterprise.telemetry.draft_trace import _build_node_execution_data + + metadata = { + WorkflowNodeExecutionMetadataKey.ITERATION_ID: "iter-1", + WorkflowNodeExecutionMetadataKey.ITERATION_INDEX: 3, + WorkflowNodeExecutionMetadataKey.LOOP_ID: "loop-1", + WorkflowNodeExecutionMetadataKey.LOOP_INDEX: 2, + WorkflowNodeExecutionMetadataKey.PARALLEL_ID: "par-1", + } + execution = 
_make_execution(execution_metadata_dict=metadata) + result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None) + + assert result["iteration_id"] == "iter-1" + assert result["iteration_index"] == 3 + assert result["loop_id"] == "loop-1" + assert result["loop_index"] == 2 + assert result["parallel_id"] == "par-1" + + def test_node_inputs_and_process_data_included(self) -> None: + from enterprise.telemetry.draft_trace import _build_node_execution_data + + execution = _make_execution( + inputs_dict={"q": "test"}, + process_data_dict={"step": 1}, + ) + result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None) + + assert result["node_inputs"] == {"q": "test"} + assert result["process_data"] == {"step": 1} + + +# --------------------------------------------------------------------------- +# enqueue_draft_node_execution_trace +# --------------------------------------------------------------------------- + + +class TestEnqueueDraftNodeExecutionTrace: + @patch("enterprise.telemetry.draft_trace.telemetry_emit") + def test_emits_telemetry_event(self, mock_emit: MagicMock) -> None: + from core.telemetry import TelemetryEvent, TraceTaskName + from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace + + execution = _make_execution() + enqueue_draft_node_execution_trace( + execution=execution, + outputs={"result": "ok"}, + workflow_execution_id="run-x", + user_id="user-1", + ) + + mock_emit.assert_called_once() + event: TelemetryEvent = mock_emit.call_args[0][0] + assert event.name == TraceTaskName.DRAFT_NODE_EXECUTION_TRACE + assert event.context.tenant_id == "tenant-1" + assert event.context.user_id == "user-1" + assert event.context.app_id == "app-1" + + @patch("enterprise.telemetry.draft_trace.telemetry_emit") + def test_payload_contains_node_execution_data(self, mock_emit: MagicMock) -> None: + from core.telemetry import TelemetryEvent + from 
enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace + + execution = _make_execution() + enqueue_draft_node_execution_trace( + execution=execution, + outputs=None, + workflow_execution_id=None, + user_id="user-2", + ) + + event: TelemetryEvent = mock_emit.call_args[0][0] + node_data = event.payload["node_execution_data"] + assert node_data["workflow_id"] == "wf-1" + assert node_data["node_type"] == "llm" + assert node_data["status"] == "succeeded" + + @patch("enterprise.telemetry.draft_trace.telemetry_emit") + def test_outputs_forwarded_to_build(self, mock_emit: MagicMock) -> None: + from core.telemetry import TelemetryEvent + from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace + + execution = _make_execution(outputs_dict={"default": True}) + enqueue_draft_node_execution_trace( + execution=execution, + outputs={"explicit": True}, + workflow_execution_id=None, + user_id="user-3", + ) + + event: TelemetryEvent = mock_emit.call_args[0][0] + assert event.payload["node_execution_data"]["node_outputs"] == {"explicit": True} + + @patch("enterprise.telemetry.draft_trace.telemetry_emit") + def test_none_outputs_uses_execution_outputs(self, mock_emit: MagicMock) -> None: + from core.telemetry import TelemetryEvent + from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace + + execution = _make_execution(outputs_dict={"from_model": "yes"}) + enqueue_draft_node_execution_trace( + execution=execution, + outputs=None, + workflow_execution_id=None, + user_id="user-4", + ) + + event: TelemetryEvent = mock_emit.call_args[0][0] + assert event.payload["node_execution_data"]["node_outputs"] == {"from_model": "yes"} + + +# --------------------------------------------------------------------------- +# End-to-end token/model data flow: _build_node_execution_data → +# ops_trace_manager.draft_node_execution_trace → DraftNodeExecutionTrace +# --------------------------------------------------------------------------- + 
+ +def _make_llm_execution() -> MagicMock: + """Return a WorkflowNodeExecutionModel mock that mimics a real LLM node. + + The field values match what graphon/nodes/llm/node.py produces: + - process_data_dict contains model_provider, model_name, and usage + - outputs_dict contains usage with prompt/completion breakdown + - execution_metadata_dict contains total_tokens/total_price/currency + """ + return _make_execution( + tenant_id="tenant-flow", + app_id="app-flow", + workflow_id="wf-flow", + id="exec-flow", + node_id="node-llm", + node_type="llm", + title="GPT-4o Node", + status="succeeded", + elapsed_time=2.3, + workflow_run_id=None, + process_data_dict={ + "model_mode": "chat", + "model_provider": "openai", + "model_name": "gpt-4o", + "prompts": [{"role": "user", "text": "hello"}], + "usage": { + "prompt_tokens": 50, + "prompt_unit_price": 0.00001, + "prompt_price_unit": 0.001, + "prompt_price": 0.0005, + "completion_tokens": 30, + "completion_unit_price": 0.00003, + "completion_price_unit": 0.001, + "completion_price": 0.0009, + "total_tokens": 80, + "total_price": 0.0014, + "currency": "USD", + "latency": 2.3, + }, + "finish_reason": "stop", + }, + outputs_dict={ + "text": "world", + "usage": { + "prompt_tokens": 50, + "completion_tokens": 30, + "total_tokens": 80, + "total_price": 0.0014, + "currency": "USD", + }, + "finish_reason": "stop", + }, + execution_metadata_dict={ + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 80, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: 0.0014, + WorkflowNodeExecutionMetadataKey.CURRENCY: "USD", + }, + ) + + +class TestDraftTraceTokenDataFlow: + """End-to-end test: verify all token and model fields survive from + _build_node_execution_data through ops_trace_manager.draft_node_execution_trace + to the DraftNodeExecutionTrace that enterprise_trace.py consumes. 
+ """ + + def test_all_token_and_model_fields_reach_trace_info(self) -> None: + """Simulate the full draft trace data flow for an LLM node and + assert every token/model field that enterprise_trace._emit_node_execution_trace + reads is populated correctly on the resulting DraftNodeExecutionTrace.""" + from enterprise.telemetry.draft_trace import _build_node_execution_data + + execution = _make_llm_execution() + node_data = _build_node_execution_data( + execution=execution, + outputs=None, + workflow_execution_id="run-flow", + ) + + # Simulate what ops_trace_manager.draft_node_execution_trace does: + # it calls node_execution_trace(node_execution_data=node_data) which + # reads top-level keys from node_data. Verify all expected keys exist. + expected_keys = { + # Token fields — read by enterprise_trace._emit_node_execution_trace + "total_tokens", + "total_price", + "currency", + "prompt_tokens", + "completion_tokens", + # Model fields — read for span attrs and metric labels + "model_provider", + "model_name", + # Node identity — read for span attrs + "node_type", + "node_execution_id", + "node_id", + "title", + "status", + "error", + "elapsed_time", + # Workflow context + "workflow_id", + "workflow_execution_id", + "tenant_id", + "app_id", + # Structure fields + "index", + "predecessor_node_id", + "iteration_id", + "iteration_index", + "loop_id", + "loop_index", + "parallel_id", + # Tool field + "tool_name", + # Content fields + "node_inputs", + "node_outputs", + "process_data", + # Timestamps + "created_at", + "finished_at", + } + assert set(node_data.keys()) == expected_keys + + # Verify token/model values are correct (not None/zero when data exists) + assert node_data["total_tokens"] == 80 + assert node_data["total_price"] == 0.0014 + assert node_data["currency"] == "USD" + assert node_data["prompt_tokens"] == 50 + assert node_data["completion_tokens"] == 30 + assert node_data["model_provider"] == "openai" + assert node_data["model_name"] == "gpt-4o" + assert 
node_data["node_type"] == "llm" + + def test_non_llm_node_has_none_for_model_and_token_breakdown(self) -> None: + """For non-LLM nodes (e.g. code, IF), model and token breakdown + should be None, but total_tokens from metadata should still work.""" + from enterprise.telemetry.draft_trace import _build_node_execution_data + + execution = _make_execution( + node_type="code", + process_data_dict={"code": "print('hi')"}, + outputs_dict={"result": "hi"}, + execution_metadata_dict={}, + ) + result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None) + + assert result["model_provider"] is None + assert result["model_name"] is None + assert result["prompt_tokens"] is None + assert result["completion_tokens"] is None + assert result["total_tokens"] == 0 + + def test_none_process_data_and_none_outputs(self) -> None: + """Both process_data_dict and outputs_dict are None — exercises + the `or {}` fallback and isinstance guard together.""" + from enterprise.telemetry.draft_trace import _build_node_execution_data + + execution = _make_execution(process_data_dict=None, outputs_dict=None) + result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None) + + assert result["model_provider"] is None + assert result["model_name"] is None + assert result["prompt_tokens"] is None + assert result["completion_tokens"] is None + + def test_node_data_feeds_into_draft_node_execution_trace(self) -> None: + """Verify the node_data dict can be consumed by + ops_trace_manager.draft_node_execution_trace without error and + produces a DraftNodeExecutionTrace with correct token/model fields.""" + + from enterprise.telemetry.draft_trace import _build_node_execution_data + + execution = _make_llm_execution() + node_data = _build_node_execution_data( + execution=execution, + outputs=None, + workflow_execution_id="run-e2e", + ) + + # Directly construct DraftNodeExecutionTrace the way + # ops_trace_manager.node_execution_trace 
does (lines 1315-1350), + # skipping DB lookups by providing minimal metadata. + from core.ops.entities.trace_entity import DraftNodeExecutionTrace + + trace_info = DraftNodeExecutionTrace( + workflow_id=node_data.get("workflow_id", ""), + workflow_run_id=node_data.get("workflow_execution_id", ""), + tenant_id=node_data.get("tenant_id", ""), + node_execution_id=node_data.get("node_execution_id", ""), + node_id=node_data.get("node_id", ""), + node_type=node_data.get("node_type", ""), + title=node_data.get("title", ""), + status=node_data.get("status", ""), + error=node_data.get("error"), + elapsed_time=node_data.get("elapsed_time", 0.0), + index=node_data.get("index", 0), + predecessor_node_id=node_data.get("predecessor_node_id"), + total_tokens=node_data.get("total_tokens", 0), + total_price=node_data.get("total_price", 0.0), + currency=node_data.get("currency"), + model_provider=node_data.get("model_provider"), + model_name=node_data.get("model_name"), + prompt_tokens=node_data.get("prompt_tokens"), + completion_tokens=node_data.get("completion_tokens"), + tool_name=node_data.get("tool_name"), + iteration_id=node_data.get("iteration_id"), + iteration_index=node_data.get("iteration_index"), + loop_id=node_data.get("loop_id"), + loop_index=node_data.get("loop_index"), + parallel_id=node_data.get("parallel_id"), + node_inputs=node_data.get("node_inputs"), + node_outputs=node_data.get("node_outputs"), + process_data=node_data.get("process_data"), + start_time=node_data.get("created_at"), + end_time=node_data.get("finished_at"), + metadata={}, + ) + + # These are the fields enterprise_trace._emit_node_execution_trace reads + assert trace_info.total_tokens == 80 + assert trace_info.prompt_tokens == 50 + assert trace_info.completion_tokens == 30 + assert trace_info.model_provider == "openai" + assert trace_info.model_name == "gpt-4o" + assert trace_info.node_type == "llm" + assert trace_info.total_price == 0.0014 + assert trace_info.currency == "USD" diff --git 
a/api/tests/unit_tests/enterprise/telemetry/test_enterprise_trace.py b/api/tests/unit_tests/enterprise/telemetry/test_enterprise_trace.py new file mode 100644 index 0000000000..bb1f78b80c --- /dev/null +++ b/api/tests/unit_tests/enterprise/telemetry/test_enterprise_trace.py @@ -0,0 +1,1327 @@ +"""Unit tests for EnterpriseOtelTrace.""" + +from __future__ import annotations + +import json +from datetime import UTC, datetime +from unittest.mock import MagicMock, patch + +import pytest + +from core.ops.entities.trace_entity import ( + DatasetRetrievalTraceInfo, + DraftNodeExecutionTrace, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + PromptGenerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + WorkflowNodeTraceInfo, + WorkflowTraceInfo, +) +from enterprise.telemetry.entities import ( + EnterpriseTelemetryCounter, + EnterpriseTelemetryEvent, + EnterpriseTelemetryHistogram, + EnterpriseTelemetrySpan, +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def mock_exporter(): + exporter = MagicMock() + exporter.include_content = True + return exporter + + +@pytest.fixture +def trace_handler(mock_exporter): + with patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter", return_value=mock_exporter): + from enterprise.telemetry.enterprise_trace import EnterpriseOtelTrace + + handler = EnterpriseOtelTrace() + return handler + + +# --------------------------------------------------------------------------- +# Factory helpers +# --------------------------------------------------------------------------- + +_T0 = datetime(2024, 1, 10, 12, 0, 0, tzinfo=UTC) +_T1 = datetime(2024, 1, 10, 12, 0, 5, tzinfo=UTC) + + +def make_workflow_info(**overrides) -> WorkflowTraceInfo: + defaults: dict = { + "workflow_id": "wf-001", + "tenant_id": "tenant-abc", + "workflow_run_id": "run-001", + 
"workflow_run_elapsed_time": 5.0, + "workflow_run_status": "succeeded", + "workflow_run_inputs": {"query": "hello"}, + "workflow_run_outputs": {"answer": "world"}, + "workflow_run_version": "1", + "total_tokens": 100, + "prompt_tokens": 60, + "completion_tokens": 40, + "file_list": [], + "query": "hello", + "start_time": _T0, + "end_time": _T1, + "metadata": { + "app_id": "app-001", + "tenant_id": "tenant-abc", + "app_name": "MyApp", + "workspace_name": "WS", + "triggered_from": "api", + }, + } + defaults.update(overrides) + return WorkflowTraceInfo(**defaults) + + +def make_node_info(**overrides) -> WorkflowNodeTraceInfo: + defaults: dict = { + "workflow_id": "wf-001", + "workflow_run_id": "run-001", + "tenant_id": "tenant-abc", + "node_execution_id": "ne-001", + "node_id": "node-001", + "node_type": "llm", + "title": "LLM Node", + "status": "succeeded", + "elapsed_time": 2.5, + "index": 1, + "total_tokens": 80, + "prompt_tokens": 50, + "completion_tokens": 30, + "model_provider": "openai", + "model_name": "gpt-4", + "start_time": _T0, + "end_time": _T1, + "metadata": { + "app_id": "app-001", + "tenant_id": "tenant-abc", + "app_name": "MyApp", + }, + } + defaults.update(overrides) + return WorkflowNodeTraceInfo(**defaults) + + +def make_draft_node_info(**overrides) -> DraftNodeExecutionTrace: + defaults: dict = { + "workflow_id": "wf-001", + "workflow_run_id": "run-draft-001", + "tenant_id": "tenant-abc", + "node_execution_id": "ne-draft-001", + "node_id": "node-001", + "node_type": "llm", + "title": "Draft LLM", + "status": "succeeded", + "elapsed_time": 1.2, + "index": 0, + "total_tokens": 50, + "start_time": _T0, + "end_time": _T1, + "metadata": {"app_id": "app-001", "tenant_id": "tenant-abc"}, + } + defaults.update(overrides) + return DraftNodeExecutionTrace(**defaults) + + +def make_message_info(**overrides) -> MessageTraceInfo: + defaults: dict = { + "message_id": "msg-001", + "conversation_model": "gpt-4", + "message_tokens": 40, + "answer_tokens": 60, + 
"total_tokens": 100, + "conversation_mode": "chat", + "start_time": _T0, + "end_time": _T1, + "inputs": "user input", + "outputs": "assistant output", + "metadata": { + "app_id": "app-001", + "tenant_id": "tenant-abc", + "from_source": "api", + "ls_provider": "openai", + "ls_model_name": "gpt-4", + "status": "succeeded", + }, + } + defaults.update(overrides) + return MessageTraceInfo(**defaults) + + +def make_tool_info(**overrides) -> ToolTraceInfo: + defaults: dict = { + "message_id": "msg-001", + "tool_name": "web_search", + "tool_inputs": {"query": "test"}, + "tool_outputs": "search results", + "tool_config": {"max_results": 5}, + "tool_parameters": {"verbose": True}, + "time_cost": 1.5, + "metadata": {"app_id": "app-001", "tenant_id": "tenant-abc"}, + } + defaults.update(overrides) + return ToolTraceInfo(**defaults) + + +def make_moderation_info(**overrides) -> ModerationTraceInfo: + defaults: dict = { + "message_id": "msg-001", + "flagged": False, + "action": "pass", + "preset_response": "", + "query": "is this ok?", + "metadata": {"app_id": "app-001", "tenant_id": "tenant-abc"}, + } + defaults.update(overrides) + return ModerationTraceInfo(**defaults) + + +def make_suggested_question_info(**overrides) -> SuggestedQuestionTraceInfo: + defaults: dict = { + "message_id": "msg-001", + "total_tokens": 30, + "suggested_question": ["Question A?", "Question B?"], + "level": "info", + "status": "succeeded", + "model_provider": "openai", + "model_id": "gpt-3.5-turbo", + "start_time": _T0, + "end_time": _T1, + "metadata": {"app_id": "app-001", "tenant_id": "tenant-abc"}, + } + defaults.update(overrides) + return SuggestedQuestionTraceInfo(**defaults) + + +def make_dataset_retrieval_info(**overrides) -> DatasetRetrievalTraceInfo: + defaults: dict = { + "message_id": "msg-001", + "documents": [ + { + "metadata": { + "dataset_id": "ds-001", + "dataset_name": "MyDataset", + "document_id": "doc-001", + "segment_id": "seg-001", + "score": 0.95, + } + } + ], + "inputs": 
"search query", + "metadata": { + "app_id": "app-001", + "tenant_id": "tenant-abc", + "embedding_models": { + "ds-001": { + "embedding_model_provider": "openai", + "embedding_model": "text-embedding-3-small", + } + }, + }, + } + defaults.update(overrides) + return DatasetRetrievalTraceInfo(**defaults) + + +def make_generate_name_info(**overrides) -> GenerateNameTraceInfo: + defaults: dict = { + "message_id": "msg-001", + "tenant_id": "tenant-abc", + "conversation_id": "conv-001", + "inputs": "some content", + "outputs": "My Conversation", + "start_time": _T0, + "end_time": _T1, + "metadata": {"app_id": "app-001", "tenant_id": "tenant-abc"}, + } + defaults.update(overrides) + return GenerateNameTraceInfo(**defaults) + + +def make_prompt_generation_info(**overrides) -> PromptGenerationTraceInfo: + defaults: dict = { + "tenant_id": "tenant-abc", + "user_id": "user-001", + "app_id": "app-001", + "operation_type": "rule_generate", + "instruction": "Generate a helpful prompt", + "prompt_tokens": 50, + "completion_tokens": 100, + "total_tokens": 150, + "model_provider": "openai", + "model_name": "gpt-4", + "latency": 3.2, + "metadata": {"app_id": "app-001", "tenant_id": "tenant-abc"}, + } + defaults.update(overrides) + return PromptGenerationTraceInfo(**defaults) + + +# --------------------------------------------------------------------------- +# Constructor +# --------------------------------------------------------------------------- + + +def test_init_raises_when_exporter_is_none(): + with patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter", return_value=None): + from enterprise.telemetry.enterprise_trace import EnterpriseOtelTrace + + with pytest.raises(RuntimeError, match="exporter is not initialized"): + EnterpriseOtelTrace() + + +def test_init_succeeds_with_valid_exporter(mock_exporter): + with patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter", return_value=mock_exporter): + from enterprise.telemetry.enterprise_trace import 
EnterpriseOtelTrace + + handler = EnterpriseOtelTrace() + assert handler._exporter is mock_exporter + + +# --------------------------------------------------------------------------- +# Helper methods +# --------------------------------------------------------------------------- + + +class TestSafePayloadValue: + def test_string_passthrough(self, trace_handler): + assert trace_handler._safe_payload_value("hello") == "hello" + + def test_dict_passthrough(self, trace_handler): + d = {"key": "val"} + assert trace_handler._safe_payload_value(d) == d + + def test_list_passthrough(self, trace_handler): + lst = [1, 2, 3] + assert trace_handler._safe_payload_value(lst) == lst + + def test_none_returns_none(self, trace_handler): + assert trace_handler._safe_payload_value(None) is None + + def test_int_returns_none(self, trace_handler): + assert trace_handler._safe_payload_value(42) is None + + def test_bool_returns_none(self, trace_handler): + assert trace_handler._safe_payload_value(True) is None + + +class TestMaybeJson: + def test_none_returns_none(self, trace_handler): + assert trace_handler._maybe_json(None) is None + + def test_string_passthrough(self, trace_handler): + assert trace_handler._maybe_json("hello") == "hello" + + def test_dict_serialised(self, trace_handler): + result = trace_handler._maybe_json({"a": 1}) + assert result == json.dumps({"a": 1}) + + def test_list_serialised(self, trace_handler): + result = trace_handler._maybe_json([1, 2]) + assert result == "[1, 2]" + + def test_non_serialisable_falls_back_to_str(self, trace_handler): + class Unserializable: + def __repr__(self): + return "Unserializable()" + + obj = Unserializable() + result = trace_handler._maybe_json(obj) + assert isinstance(result, str) + + +class TestContentOrRef: + def test_returns_content_when_include_content_true(self, trace_handler, mock_exporter): + mock_exporter.include_content = True + result = trace_handler._content_or_ref("actual content", "ref:x=1") + assert result == 
"actual content" + + def test_returns_ref_when_include_content_false(self, trace_handler, mock_exporter): + mock_exporter.include_content = False + result = trace_handler._content_or_ref("actual content", "ref:x=1") + assert result == "ref:x=1" + + def test_dict_serialised_when_include_content_true(self, trace_handler, mock_exporter): + mock_exporter.include_content = True + result = trace_handler._content_or_ref({"key": "val"}, "ref:x=1") + assert result == json.dumps({"key": "val"}) + + def test_none_returns_none_when_include_content_true(self, trace_handler, mock_exporter): + mock_exporter.include_content = True + result = trace_handler._content_or_ref(None, "ref:x=1") + assert result is None + + +# --------------------------------------------------------------------------- +# trace() dispatcher +# --------------------------------------------------------------------------- + + +class TestTraceDispatcher: + def test_dispatches_workflow_trace(self, trace_handler): + with patch.object(trace_handler, "_workflow_trace") as mock_method: + info = make_workflow_info() + trace_handler.trace(info) + mock_method.assert_called_once_with(info) + + def test_dispatches_message_trace(self, trace_handler): + with patch.object(trace_handler, "_message_trace") as mock_method: + info = make_message_info() + trace_handler.trace(info) + mock_method.assert_called_once_with(info) + + def test_dispatches_tool_trace(self, trace_handler): + with patch.object(trace_handler, "_tool_trace") as mock_method: + info = make_tool_info() + trace_handler.trace(info) + mock_method.assert_called_once_with(info) + + def test_dispatches_draft_node_execution_trace(self, trace_handler): + with patch.object(trace_handler, "_draft_node_execution_trace") as mock_method: + info = make_draft_node_info() + trace_handler.trace(info) + mock_method.assert_called_once_with(info) + + def test_dispatches_node_execution_trace(self, trace_handler): + with patch.object(trace_handler, "_node_execution_trace") as 
mock_method: + info = make_node_info() + trace_handler.trace(info) + mock_method.assert_called_once_with(info) + + def test_dispatches_moderation_trace(self, trace_handler): + with patch.object(trace_handler, "_moderation_trace") as mock_method: + info = make_moderation_info() + trace_handler.trace(info) + mock_method.assert_called_once_with(info) + + def test_dispatches_suggested_question_trace(self, trace_handler): + with patch.object(trace_handler, "_suggested_question_trace") as mock_method: + info = make_suggested_question_info() + trace_handler.trace(info) + mock_method.assert_called_once_with(info) + + def test_dispatches_dataset_retrieval_trace(self, trace_handler): + with patch.object(trace_handler, "_dataset_retrieval_trace") as mock_method: + info = make_dataset_retrieval_info() + trace_handler.trace(info) + mock_method.assert_called_once_with(info) + + def test_dispatches_generate_name_trace(self, trace_handler): + with patch.object(trace_handler, "_generate_name_trace") as mock_method: + info = make_generate_name_info() + trace_handler.trace(info) + mock_method.assert_called_once_with(info) + + def test_dispatches_prompt_generation_trace(self, trace_handler): + with patch.object(trace_handler, "_prompt_generation_trace") as mock_method: + info = make_prompt_generation_info() + trace_handler.trace(info) + mock_method.assert_called_once_with(info) + + def test_draft_node_dispatched_before_node(self, trace_handler): + """DraftNodeExecutionTrace is a subclass of WorkflowNodeTraceInfo; + it must be dispatched to _draft_node_execution_trace, not _node_execution_trace.""" + with ( + patch.object(trace_handler, "_draft_node_execution_trace") as mock_draft, + patch.object(trace_handler, "_node_execution_trace") as mock_node, + ): + info = make_draft_node_info() + trace_handler.trace(info) + mock_draft.assert_called_once_with(info) + mock_node.assert_not_called() + + +# --------------------------------------------------------------------------- +# 
_workflow_trace +# --------------------------------------------------------------------------- + + +class TestWorkflowTrace: + def test_emits_correct_span_attributes(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log") as mock_log: + info = make_workflow_info() + trace_handler._workflow_trace(info) + + mock_exporter.export_span.assert_called_once() + span_call = mock_exporter.export_span.call_args + assert span_call[0][0] == EnterpriseTelemetrySpan.WORKFLOW_RUN + attrs = span_call[0][1] + assert attrs["dify.workflow.run_id"] == "run-001" + assert attrs["dify.workflow.id"] == "wf-001" + assert attrs["dify.tenant_id"] == "tenant-abc" + assert attrs["dify.workflow.status"] == "succeeded" + assert attrs["gen_ai.usage.total_tokens"] == 100 + + def test_span_timing_passed_correctly(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + info = make_workflow_info() + trace_handler._workflow_trace(info) + + span_call = mock_exporter.export_span.call_args + assert span_call[1]["start_time"] == _T0 + assert span_call[1]["end_time"] == _T1 + + def test_emits_companion_log_with_event_name(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log") as mock_log: + trace_handler._workflow_trace(make_workflow_info()) + + mock_log.assert_called_once() + assert mock_log.call_args[1]["event_name"] == EnterpriseTelemetryEvent.WORKFLOW_RUN + assert mock_log.call_args[1]["tenant_id"] == "tenant-abc" + + def test_companion_log_includes_content_when_enabled(self, trace_handler, mock_exporter): + mock_exporter.include_content = True + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log") as mock_log: + trace_handler._workflow_trace(make_workflow_info()) + + log_attrs = mock_log.call_args[1]["attributes"] + assert log_attrs["dify.workflow.inputs"] == json.dumps({"query": "hello"}) + assert 
log_attrs["dify.workflow.outputs"] == json.dumps({"answer": "world"}) + + def test_companion_log_uses_ref_when_content_disabled(self, trace_handler, mock_exporter): + mock_exporter.include_content = False + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log") as mock_log: + trace_handler._workflow_trace(make_workflow_info()) + + log_attrs = mock_log.call_args[1]["attributes"] + assert log_attrs["dify.workflow.inputs"].startswith("ref:workflow_run_id=") + assert log_attrs["dify.workflow.outputs"].startswith("ref:workflow_run_id=") + + def test_increments_token_counter(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + trace_handler._workflow_trace(make_workflow_info()) + + token_calls = [ + c for c in mock_exporter.increment_counter.call_args_list if c[0][0] == EnterpriseTelemetryCounter.TOKENS + ] + assert len(token_calls) == 1 + assert token_calls[0][0][1] == 100 + + def test_increments_input_and_output_token_counters(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + trace_handler._workflow_trace(make_workflow_info()) + + all_calls = mock_exporter.increment_counter.call_args_list + counter_names = [c[0][0] for c in all_calls] + assert EnterpriseTelemetryCounter.INPUT_TOKENS in counter_names + assert EnterpriseTelemetryCounter.OUTPUT_TOKENS in counter_names + + def test_no_input_token_counter_when_prompt_tokens_zero(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + info = make_workflow_info(prompt_tokens=0) + trace_handler._workflow_trace(info) + + all_calls = mock_exporter.increment_counter.call_args_list + counter_names = [c[0][0] for c in all_calls] + assert EnterpriseTelemetryCounter.INPUT_TOKENS not in counter_names + + def test_records_workflow_duration_histogram(self, trace_handler, mock_exporter): + with 
patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + trace_handler._workflow_trace(make_workflow_info()) + + mock_exporter.record_histogram.assert_called_once() + hist_call = mock_exporter.record_histogram.call_args + assert hist_call[0][0] == EnterpriseTelemetryHistogram.WORKFLOW_DURATION + assert hist_call[0][1] == pytest.approx(5.0) + + def test_duration_falls_back_to_elapsed_time_when_timestamps_missing(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + info = make_workflow_info(start_time=None, end_time=None, workflow_run_elapsed_time=7.3) + trace_handler._workflow_trace(info) + + hist_call = mock_exporter.record_histogram.call_args + assert hist_call[0][1] == pytest.approx(7.3) + + def test_duration_defaults_to_zero_when_no_timing(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + info = make_workflow_info(start_time=None, end_time=None, workflow_run_elapsed_time=0) + trace_handler._workflow_trace(info) + + hist_call = mock_exporter.record_histogram.call_args + assert hist_call[0][1] == pytest.approx(0.0) + + def test_error_path_increments_error_counter(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + info = make_workflow_info(error="Something went wrong", workflow_run_status="failed") + trace_handler._workflow_trace(info) + + error_calls = [ + c for c in mock_exporter.increment_counter.call_args_list if c[0][0] == EnterpriseTelemetryCounter.ERRORS + ] + assert len(error_calls) == 1 + + def test_no_error_counter_on_success(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + trace_handler._workflow_trace(make_workflow_info()) + + error_calls = [ + c for c in mock_exporter.increment_counter.call_args_list if c[0][0] == EnterpriseTelemetryCounter.ERRORS + ] + assert len(error_calls) == 0 + + def 
test_parent_trace_context_injected_into_span_attrs(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + info = make_workflow_info( + metadata={ + "app_id": "app-001", + "tenant_id": "tenant-abc", + "parent_trace_context": { + "trace_id": "outer-trace", + "parent_node_execution_id": "outer-ne-001", + "parent_workflow_run_id": "outer-run-001", + "parent_app_id": "outer-app-001", + }, + } + ) + trace_handler._workflow_trace(info) + + attrs = mock_exporter.export_span.call_args[0][1] + assert attrs["dify.parent.trace_id"] == "outer-trace" + assert attrs["dify.parent.node.execution_id"] == "outer-ne-001" + assert attrs["dify.parent.workflow.run_id"] == "outer-run-001" + assert attrs["dify.parent.app.id"] == "outer-app-001" + + +# --------------------------------------------------------------------------- +# _node_execution_trace / _emit_node_execution_trace +# --------------------------------------------------------------------------- + + +class TestNodeExecutionTrace: + def test_emits_span_with_node_execution_span_name(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + trace_handler._node_execution_trace(make_node_info()) + + span_call = mock_exporter.export_span.call_args + assert span_call[0][0] == EnterpriseTelemetrySpan.NODE_EXECUTION + + def test_span_contains_core_node_attributes(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + trace_handler._node_execution_trace(make_node_info()) + + attrs = mock_exporter.export_span.call_args[0][1] + assert attrs["dify.node.execution_id"] == "ne-001" + assert attrs["dify.node.id"] == "node-001" + assert attrs["dify.node.type"] == "llm" + assert attrs["dify.node.status"] == "succeeded" + assert attrs["gen_ai.request.model"] == "gpt-4" + assert attrs["gen_ai.provider.name"] == "openai" + + def 
test_increments_token_counters_when_tokens_present(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + trace_handler._node_execution_trace(make_node_info()) + + counter_names = [c[0][0] for c in mock_exporter.increment_counter.call_args_list] + assert EnterpriseTelemetryCounter.TOKENS in counter_names + assert EnterpriseTelemetryCounter.INPUT_TOKENS in counter_names + assert EnterpriseTelemetryCounter.OUTPUT_TOKENS in counter_names + + def test_no_token_counters_when_total_tokens_zero(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + trace_handler._node_execution_trace(make_node_info(total_tokens=0)) + + counter_names = [c[0][0] for c in mock_exporter.increment_counter.call_args_list] + assert EnterpriseTelemetryCounter.TOKENS not in counter_names + assert EnterpriseTelemetryCounter.INPUT_TOKENS not in counter_names + + def test_records_node_duration_histogram(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + trace_handler._node_execution_trace(make_node_info()) + + hist_call = mock_exporter.record_histogram.call_args + assert hist_call[0][0] == EnterpriseTelemetryHistogram.NODE_DURATION + assert hist_call[0][1] == pytest.approx(2.5) + + def test_error_path_increments_error_counter(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + trace_handler._node_execution_trace(make_node_info(error="Node failed", status="failed")) + + error_calls = [ + c for c in mock_exporter.increment_counter.call_args_list if c[0][0] == EnterpriseTelemetryCounter.ERRORS + ] + assert len(error_calls) == 1 + + def test_emits_companion_log_with_span_name_as_event(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log") as mock_log: + trace_handler._node_execution_trace(make_node_info()) + + 
mock_log.assert_called_once() + assert mock_log.call_args[1]["event_name"] == EnterpriseTelemetrySpan.NODE_EXECUTION.value + + def test_plugin_name_added_to_duration_labels_for_tool_node(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + info = make_node_info( + node_type="tool", + metadata={ + "app_id": "app-001", + "tenant_id": "tenant-abc", + "plugin_name": "my-plugin", + }, + ) + trace_handler._node_execution_trace(info) + + hist_call = mock_exporter.record_histogram.call_args + duration_labels = hist_call[0][2] + assert duration_labels.get("plugin_name") == "my-plugin" + + def test_plugin_name_not_added_for_non_tool_node(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + info = make_node_info( + node_type="llm", + metadata={ + "app_id": "app-001", + "tenant_id": "tenant-abc", + "plugin_name": "my-plugin", + }, + ) + trace_handler._node_execution_trace(info) + + hist_call = mock_exporter.record_histogram.call_args + duration_labels = hist_call[0][2] + assert "plugin_name" not in duration_labels + + def test_companion_log_inputs_use_ref_when_content_disabled(self, trace_handler, mock_exporter): + mock_exporter.include_content = False + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log") as mock_log: + trace_handler._node_execution_trace( + make_node_info(node_inputs={"prompt": "hello"}, node_outputs={"text": "world"}) + ) + + log_attrs = mock_log.call_args[1]["attributes"] + assert log_attrs["dify.node.inputs"].startswith("ref:node_execution_id=") + assert log_attrs["dify.node.outputs"].startswith("ref:node_execution_id=") + + +# --------------------------------------------------------------------------- +# _draft_node_execution_trace +# --------------------------------------------------------------------------- + + +class TestDraftNodeExecutionTrace: + def test_uses_draft_span_name(self, trace_handler, mock_exporter): + 
with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + trace_handler._draft_node_execution_trace(make_draft_node_info()) + + span_call = mock_exporter.export_span.call_args + assert span_call[0][0] == EnterpriseTelemetrySpan.DRAFT_NODE_EXECUTION + + def test_correlation_id_is_node_execution_id(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + info = make_draft_node_info() + trace_handler._draft_node_execution_trace(info) + + span_call = mock_exporter.export_span.call_args + assert span_call[1]["correlation_id"] == "ne-draft-001" + + def test_trace_correlation_override_is_workflow_run_id(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + info = make_draft_node_info() + trace_handler._draft_node_execution_trace(info) + + span_call = mock_exporter.export_span.call_args + assert span_call[1]["trace_correlation_override"] == "run-draft-001" + + def test_companion_log_uses_draft_span_name(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log") as mock_log: + trace_handler._draft_node_execution_trace(make_draft_node_info()) + + assert mock_log.call_args[1]["event_name"] == EnterpriseTelemetrySpan.DRAFT_NODE_EXECUTION.value + + +# --------------------------------------------------------------------------- +# _message_trace +# --------------------------------------------------------------------------- + + +class TestMessageTrace: + def test_emits_event_with_correct_name(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._message_trace(make_message_info()) + + mock_emit.assert_called_once() + assert mock_emit.call_args[1]["event_name"] == EnterpriseTelemetryEvent.MESSAGE_RUN + + def test_emits_correct_tenant_and_user(self, trace_handler, mock_exporter): + with 
patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._message_trace(make_message_info()) + + assert mock_emit.call_args[1]["tenant_id"] == "tenant-abc" + + def test_duration_computed_from_timestamps(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._message_trace(make_message_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.message.duration"] == pytest.approx(5.0) + + def test_no_duration_when_timestamps_missing(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._message_trace(make_message_info(start_time=None, end_time=None)) + + attrs = mock_emit.call_args[1]["attributes"] + assert "dify.message.duration" not in attrs + + def test_records_duration_histogram_when_timestamps_present(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + trace_handler._message_trace(make_message_info()) + + hist_calls = [ + c + for c in mock_exporter.record_histogram.call_args_list + if c[0][0] == EnterpriseTelemetryHistogram.MESSAGE_DURATION + ] + assert len(hist_calls) == 1 + assert hist_calls[0][0][1] == pytest.approx(5.0) + + def test_no_duration_histogram_when_timestamps_missing(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + trace_handler._message_trace(make_message_info(start_time=None, end_time=None)) + + hist_names = [c[0][0] for c in mock_exporter.record_histogram.call_args_list] + assert EnterpriseTelemetryHistogram.MESSAGE_DURATION not in hist_names + + def test_records_ttft_histogram_when_present(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + 
trace_handler._message_trace(make_message_info(gen_ai_server_time_to_first_token=0.42)) + + ttft_calls = [ + c + for c in mock_exporter.record_histogram.call_args_list + if c[0][0] == EnterpriseTelemetryHistogram.MESSAGE_TTFT + ] + assert len(ttft_calls) == 1 + assert ttft_calls[0][0][1] == pytest.approx(0.42) + + def test_no_ttft_histogram_when_not_present(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + trace_handler._message_trace(make_message_info(gen_ai_server_time_to_first_token=None)) + + hist_names = [c[0][0] for c in mock_exporter.record_histogram.call_args_list] + assert EnterpriseTelemetryHistogram.MESSAGE_TTFT not in hist_names + + def test_increments_token_counters(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + trace_handler._message_trace(make_message_info()) + + counter_names = [c[0][0] for c in mock_exporter.increment_counter.call_args_list] + assert EnterpriseTelemetryCounter.TOKENS in counter_names + assert EnterpriseTelemetryCounter.INPUT_TOKENS in counter_names + assert EnterpriseTelemetryCounter.OUTPUT_TOKENS in counter_names + + def test_error_path_increments_error_counter(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + trace_handler._message_trace(make_message_info(error="LLM failed")) + + error_calls = [ + c for c in mock_exporter.increment_counter.call_args_list if c[0][0] == EnterpriseTelemetryCounter.ERRORS + ] + assert len(error_calls) == 1 + + def test_inputs_and_outputs_gated_by_include_content(self, trace_handler, mock_exporter): + mock_exporter.include_content = False + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._message_trace(make_message_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.message.inputs"].startswith("ref:message_id=") + 
assert attrs["dify.message.outputs"].startswith("ref:message_id=") + + +# --------------------------------------------------------------------------- +# _tool_trace +# --------------------------------------------------------------------------- + + +class TestToolTrace: + def test_emits_event_with_correct_name(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._tool_trace(make_tool_info()) + + assert mock_emit.call_args[1]["event_name"] == EnterpriseTelemetryEvent.TOOL_EXECUTION + + def test_status_is_succeeded_on_success(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._tool_trace(make_tool_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.tool.status"] == "succeeded" + + def test_status_is_failed_on_error(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._tool_trace(make_tool_info(error="Tool error")) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.tool.status"] == "failed" + + def test_records_tool_duration_histogram(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + trace_handler._tool_trace(make_tool_info()) + + hist_call = mock_exporter.record_histogram.call_args + assert hist_call[0][0] == EnterpriseTelemetryHistogram.TOOL_DURATION + assert hist_call[0][1] == pytest.approx(1.5) + + def test_error_increments_error_counter(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + trace_handler._tool_trace(make_tool_info(error="Tool crashed")) + + error_calls = [ + c for c in mock_exporter.increment_counter.call_args_list if c[0][0] == EnterpriseTelemetryCounter.ERRORS + ] + assert len(error_calls) == 1 + + def 
test_inputs_and_outputs_gated_by_include_content(self, trace_handler, mock_exporter): + mock_exporter.include_content = False + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._tool_trace(make_tool_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.tool.inputs"].startswith("ref:message_id=") + assert attrs["dify.tool.outputs"].startswith("ref:message_id=") + + def test_inputs_present_when_include_content_true(self, trace_handler, mock_exporter): + mock_exporter.include_content = True + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._tool_trace(make_tool_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.tool.inputs"] == json.dumps({"query": "test"}) + assert attrs["dify.tool.outputs"] == "search results" + + def test_increments_requests_counter(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + trace_handler._tool_trace(make_tool_info()) + + request_calls = [ + c for c in mock_exporter.increment_counter.call_args_list if c[0][0] == EnterpriseTelemetryCounter.REQUESTS + ] + assert len(request_calls) == 1 + assert request_calls[0][0][2]["type"] == "tool" + + +# --------------------------------------------------------------------------- +# _moderation_trace +# --------------------------------------------------------------------------- + + +class TestModerationTrace: + def test_emits_event_with_correct_name(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._moderation_trace(make_moderation_info()) + + assert mock_emit.call_args[1]["event_name"] == EnterpriseTelemetryEvent.MODERATION_CHECK + + def test_flagged_true_sets_attribute(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as 
mock_emit: + trace_handler._moderation_trace(make_moderation_info(flagged=True)) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.moderation.flagged"] is True + + def test_flagged_false_sets_attribute(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._moderation_trace(make_moderation_info(flagged=False)) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.moderation.flagged"] is False + + def test_query_gated_by_include_content(self, trace_handler, mock_exporter): + mock_exporter.include_content = False + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._moderation_trace(make_moderation_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.moderation.query"].startswith("ref:message_id=") + + def test_query_present_when_include_content_true(self, trace_handler, mock_exporter): + mock_exporter.include_content = True + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._moderation_trace(make_moderation_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.moderation.query"] == "is this ok?" 
+ + def test_increments_requests_counter(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + trace_handler._moderation_trace(make_moderation_info()) + + request_calls = [ + c for c in mock_exporter.increment_counter.call_args_list if c[0][0] == EnterpriseTelemetryCounter.REQUESTS + ] + assert len(request_calls) == 1 + assert request_calls[0][0][2]["type"] == "moderation" + + +# --------------------------------------------------------------------------- +# _suggested_question_trace +# --------------------------------------------------------------------------- + + +class TestSuggestedQuestionTrace: + def test_emits_event_with_correct_name(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._suggested_question_trace(make_suggested_question_info()) + + assert mock_emit.call_args[1]["event_name"] == EnterpriseTelemetryEvent.SUGGESTED_QUESTION_GENERATION + + def test_duration_computed_from_timestamps(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._suggested_question_trace(make_suggested_question_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.suggested_question.duration"] == pytest.approx(5.0) + + def test_duration_is_none_when_timestamps_missing(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._suggested_question_trace(make_suggested_question_info(start_time=None, end_time=None)) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.suggested_question.duration"] is None + + def test_status_is_failed_when_error_present(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + 
trace_handler._suggested_question_trace(make_suggested_question_info(error="Generation failed")) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.suggested_question.status"] == "failed" + + def test_status_falls_back_to_succeeded_when_no_error(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._suggested_question_trace(make_suggested_question_info(status=None, error=None)) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.suggested_question.status"] == "succeeded" + + def test_question_count_attribute(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._suggested_question_trace(make_suggested_question_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.suggested_question.count"] == 2 + + def test_questions_gated_by_include_content(self, trace_handler, mock_exporter): + mock_exporter.include_content = False + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._suggested_question_trace(make_suggested_question_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.suggested_question.questions"].startswith("ref:message_id=") + + def test_increments_requests_counter(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + trace_handler._suggested_question_trace(make_suggested_question_info()) + + request_calls = [ + c for c in mock_exporter.increment_counter.call_args_list if c[0][0] == EnterpriseTelemetryCounter.REQUESTS + ] + assert len(request_calls) == 1 + assert request_calls[0][0][2]["type"] == "suggested_question" + + +# --------------------------------------------------------------------------- +# _dataset_retrieval_trace +# 
--------------------------------------------------------------------------- + + +class TestDatasetRetrievalTrace: + def test_emits_event_with_correct_name(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info()) + + assert mock_emit.call_args[1]["event_name"] == EnterpriseTelemetryEvent.DATASET_RETRIEVAL + + def test_document_count_attribute(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.retrieval.document_count"] == 1 + + def test_dataset_ids_extracted(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert "ds-001" in attrs["dify.dataset.id"] + + def test_empty_documents_has_zero_count(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info(documents=[])) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.retrieval.document_count"] == 0 + + def test_status_succeeded_when_no_error(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.retrieval.status"] == "succeeded" + + def test_status_failed_when_error_present(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: 
+ trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info(error="DB error")) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.retrieval.status"] == "failed" + + def test_embedding_model_attributes_set_when_present(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert "dify.dataset.embedding_providers" in attrs + assert "dify.dataset.embedding_models" in attrs + + def test_no_embedding_model_attributes_when_not_provided(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._dataset_retrieval_trace( + make_dataset_retrieval_info(metadata={"app_id": "app-001", "tenant_id": "tenant-abc"}) + ) + + attrs = mock_emit.call_args[1]["attributes"] + assert "dify.dataset.embedding_providers" not in attrs + assert "dify.dataset.embedding_models" not in attrs + + def test_rerank_attributes_set_when_present(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._dataset_retrieval_trace( + make_dataset_retrieval_info( + metadata={ + "app_id": "app-001", + "tenant_id": "tenant-abc", + "rerank_model_provider": "cohere", + "rerank_model_name": "rerank-english", + } + ) + ) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.retrieval.rerank_provider"] == "cohere" + assert attrs["dify.retrieval.rerank_model"] == "rerank-english" + + def test_no_rerank_attributes_when_not_present(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._dataset_retrieval_trace( + make_dataset_retrieval_info(metadata={"app_id": "app-001", "tenant_id": "tenant-abc"}) + ) + + attrs = 
mock_emit.call_args[1]["attributes"] + assert "dify.retrieval.rerank_provider" not in attrs + assert "dify.retrieval.rerank_model" not in attrs + + def test_dataset_retrieval_counter_incremented_per_dataset(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info()) + + ds_calls = [ + c + for c in mock_exporter.increment_counter.call_args_list + if c[0][0] == EnterpriseTelemetryCounter.DATASET_RETRIEVALS + ] + assert len(ds_calls) == 1 + assert ds_calls[0][0][2]["dataset_id"] == "ds-001" + + def test_no_dataset_retrieval_counter_when_no_documents(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info(documents=[])) + + ds_calls = [ + c + for c in mock_exporter.increment_counter.call_args_list + if c[0][0] == EnterpriseTelemetryCounter.DATASET_RETRIEVALS + ] + assert len(ds_calls) == 0 + + def test_query_gated_by_include_content(self, trace_handler, mock_exporter): + mock_exporter.include_content = False + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.retrieval.query"].startswith("ref:message_id=") + + +# --------------------------------------------------------------------------- +# _generate_name_trace +# --------------------------------------------------------------------------- + + +class TestGenerateNameTrace: + def test_emits_event_with_correct_name(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._generate_name_trace(make_generate_name_info()) + + assert mock_emit.call_args[1]["event_name"] == 
EnterpriseTelemetryEvent.GENERATE_NAME_EXECUTION + + def test_duration_computed_from_timestamps(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._generate_name_trace(make_generate_name_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.generate_name.duration"] == pytest.approx(5.0) + + def test_no_duration_when_timestamps_missing(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._generate_name_trace(make_generate_name_info(start_time=None, end_time=None)) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.generate_name.duration"] is None + + def test_status_succeeded_on_success(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._generate_name_trace(make_generate_name_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.generate_name.status"] == "succeeded" + + def test_status_failed_when_metadata_has_error(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._generate_name_trace( + make_generate_name_info( + metadata={ + "app_id": "app-001", + "tenant_id": "tenant-abc", + "error": "Name generation failed", + } + ) + ) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.generate_name.status"] == "failed" + + def test_inputs_and_outputs_gated_by_include_content(self, trace_handler, mock_exporter): + mock_exporter.include_content = False + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._generate_name_trace(make_generate_name_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.generate_name.inputs"].startswith("ref:conversation_id=") 
+ assert attrs["dify.generate_name.outputs"].startswith("ref:conversation_id=") + + def test_increments_requests_counter(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + trace_handler._generate_name_trace(make_generate_name_info()) + + request_calls = [ + c for c in mock_exporter.increment_counter.call_args_list if c[0][0] == EnterpriseTelemetryCounter.REQUESTS + ] + assert len(request_calls) == 1 + assert request_calls[0][0][2]["type"] == "generate_name" + + +# --------------------------------------------------------------------------- +# _prompt_generation_trace +# --------------------------------------------------------------------------- + + +class TestPromptGenerationTrace: + def test_emits_event_with_correct_name(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._prompt_generation_trace(make_prompt_generation_info()) + + assert mock_emit.call_args[1]["event_name"] == EnterpriseTelemetryEvent.PROMPT_GENERATION_EXECUTION + + def test_status_succeeded_on_success(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._prompt_generation_trace(make_prompt_generation_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.prompt_generation.status"] == "succeeded" + + def test_status_failed_when_error_present(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._prompt_generation_trace(make_prompt_generation_info(error="Generation error")) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.prompt_generation.status"] == "failed" + + def test_token_counters_incremented(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + 
trace_handler._prompt_generation_trace(make_prompt_generation_info()) + + counter_names = [c[0][0] for c in mock_exporter.increment_counter.call_args_list] + assert EnterpriseTelemetryCounter.TOKENS in counter_names + assert EnterpriseTelemetryCounter.INPUT_TOKENS in counter_names + assert EnterpriseTelemetryCounter.OUTPUT_TOKENS in counter_names + + def test_records_duration_histogram(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + trace_handler._prompt_generation_trace(make_prompt_generation_info()) + + hist_calls = [ + c + for c in mock_exporter.record_histogram.call_args_list + if c[0][0] == EnterpriseTelemetryHistogram.PROMPT_GENERATION_DURATION + ] + assert len(hist_calls) == 1 + assert hist_calls[0][0][1] == pytest.approx(3.2) + + def test_total_price_attribute_set_when_present(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._prompt_generation_trace(make_prompt_generation_info(total_price=0.05, currency="USD")) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.prompt_generation.total_price"] == pytest.approx(0.05) + assert attrs["dify.prompt_generation.currency"] == "USD" + + def test_no_total_price_attribute_when_none(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._prompt_generation_trace(make_prompt_generation_info(total_price=None)) + + attrs = mock_emit.call_args[1]["attributes"] + assert "dify.prompt_generation.total_price" not in attrs + + def test_error_increments_error_counter(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + trace_handler._prompt_generation_trace(make_prompt_generation_info(error="Prompt failed")) + + error_calls = [ + c for c in mock_exporter.increment_counter.call_args_list if c[0][0] == 
EnterpriseTelemetryCounter.ERRORS + ] + assert len(error_calls) == 1 + + def test_no_error_counter_on_success(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + trace_handler._prompt_generation_trace(make_prompt_generation_info()) + + error_calls = [ + c for c in mock_exporter.increment_counter.call_args_list if c[0][0] == EnterpriseTelemetryCounter.ERRORS + ] + assert len(error_calls) == 0 + + def test_instruction_gated_by_include_content(self, trace_handler, mock_exporter): + mock_exporter.include_content = False + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._prompt_generation_trace(make_prompt_generation_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.prompt_generation.instruction"].startswith("ref:trace_id=") + + def test_operation_type_label_used_in_token_counters(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + trace_handler._prompt_generation_trace(make_prompt_generation_info(operation_type="code_generate")) + + token_calls = [ + c for c in mock_exporter.increment_counter.call_args_list if c[0][0] == EnterpriseTelemetryCounter.TOKENS + ] + assert len(token_calls) == 1 + assert token_calls[0][0][2]["operation_type"] == "code_generate" + + def test_emits_correct_tenant_id(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._prompt_generation_trace(make_prompt_generation_info()) + + assert mock_emit.call_args[1]["tenant_id"] == "tenant-abc" diff --git a/api/tests/unit_tests/enterprise/telemetry/test_event_handlers.py b/api/tests/unit_tests/enterprise/telemetry/test_event_handlers.py new file mode 100644 index 0000000000..b70c0260d5 --- /dev/null +++ b/api/tests/unit_tests/enterprise/telemetry/test_event_handlers.py @@ -0,0 +1,54 @@ +from 
unittest.mock import MagicMock, patch + +import pytest + +from enterprise.telemetry import event_handlers +from enterprise.telemetry.contracts import TelemetryCase + + +@pytest.fixture +def mock_gateway_emit(): + with patch("core.telemetry.gateway.emit") as mock: + yield mock + + +def test_handle_app_created_calls_task(mock_gateway_emit): + sender = MagicMock() + sender.id = "app-123" + sender.tenant_id = "tenant-456" + sender.mode = "chat" + + event_handlers._handle_app_created(sender) + + mock_gateway_emit.assert_called_once_with( + case=TelemetryCase.APP_CREATED, + context={"tenant_id": "tenant-456"}, + payload={"app_id": "app-123", "mode": "chat"}, + ) + + +def test_handle_app_created_no_exporter(mock_gateway_emit): + """Gateway handles exporter availability internally; handler always calls gateway.""" + sender = MagicMock() + sender.id = "app-123" + sender.tenant_id = "tenant-456" + + event_handlers._handle_app_created(sender) + + mock_gateway_emit.assert_called_once() + + +def test_handlers_create_valid_envelopes(mock_gateway_emit): + """Verify handlers pass correct TelemetryCase and payload structure.""" + sender = MagicMock() + sender.id = "app-123" + sender.tenant_id = "tenant-456" + sender.mode = "chat" + + event_handlers._handle_app_created(sender) + + call_kwargs = mock_gateway_emit.call_args[1] + assert call_kwargs["case"] == TelemetryCase.APP_CREATED + assert call_kwargs["context"]["tenant_id"] == "tenant-456" + assert call_kwargs["payload"]["app_id"] == "app-123" + assert call_kwargs["payload"]["mode"] == "chat" diff --git a/api/tests/unit_tests/enterprise/telemetry/test_exporter.py b/api/tests/unit_tests/enterprise/telemetry/test_exporter.py new file mode 100644 index 0000000000..6bdae13923 --- /dev/null +++ b/api/tests/unit_tests/enterprise/telemetry/test_exporter.py @@ -0,0 +1,628 @@ +"""Unit tests for EnterpriseExporter and _ExporterFactory.""" + +from __future__ import annotations + +from datetime import UTC, datetime +from types import 
SimpleNamespace +from unittest.mock import MagicMock, patch + +from configs.enterprise import EnterpriseTelemetryConfig +from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryHistogram +from enterprise.telemetry.exporter import EnterpriseExporter, _datetime_to_ns, _parse_otlp_headers + + +def test_config_api_key_default_empty(): + """Test that ENTERPRISE_OTLP_API_KEY defaults to empty string.""" + config = EnterpriseTelemetryConfig() + assert config.ENTERPRISE_OTLP_API_KEY == "" + + +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_api_key_only_injects_bearer_header(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None: + """Test that API key alone injects Bearer authorization header.""" + mock_config = SimpleNamespace( + ENTERPRISE_OTLP_ENDPOINT="https://collector.example.com", + ENTERPRISE_OTLP_HEADERS="", + ENTERPRISE_OTLP_PROTOCOL="grpc", + ENTERPRISE_SERVICE_NAME="dify", + ENTERPRISE_OTEL_SAMPLING_RATE=1.0, + ENTERPRISE_INCLUDE_CONTENT=True, + ENTERPRISE_OTLP_API_KEY="test-secret-key", + ) + + EnterpriseExporter(mock_config) + + # Verify span exporter was called with Bearer header + assert mock_span_exporter.call_args is not None + headers = mock_span_exporter.call_args.kwargs.get("headers") + assert headers is not None + assert ("authorization", "Bearer test-secret-key") in headers + + +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_empty_api_key_no_auth_header(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None: + """Test that empty API key does not inject authorization header.""" + mock_config = SimpleNamespace( + ENTERPRISE_OTLP_ENDPOINT="https://collector.example.com", + ENTERPRISE_OTLP_HEADERS="", + ENTERPRISE_OTLP_PROTOCOL="grpc", + ENTERPRISE_SERVICE_NAME="dify", + ENTERPRISE_OTEL_SAMPLING_RATE=1.0, + 
ENTERPRISE_INCLUDE_CONTENT=True, + ENTERPRISE_OTLP_API_KEY="", + ) + + EnterpriseExporter(mock_config) + + # Verify span exporter was called without authorization header + assert mock_span_exporter.call_args is not None + headers = mock_span_exporter.call_args.kwargs.get("headers") + # Headers should be None or not contain authorization + if headers is not None: + assert not any(key == "authorization" for key, _ in headers) + + +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_api_key_and_custom_headers_merge(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None: + """Test that API key and custom headers are merged correctly.""" + mock_config = SimpleNamespace( + ENTERPRISE_OTLP_ENDPOINT="https://collector.example.com", + ENTERPRISE_OTLP_HEADERS="x-custom=foo", + ENTERPRISE_OTLP_PROTOCOL="grpc", + ENTERPRISE_SERVICE_NAME="dify", + ENTERPRISE_OTEL_SAMPLING_RATE=1.0, + ENTERPRISE_INCLUDE_CONTENT=True, + ENTERPRISE_OTLP_API_KEY="test-key", + ) + + EnterpriseExporter(mock_config) + + # Verify both headers are present + assert mock_span_exporter.call_args is not None + headers = mock_span_exporter.call_args.kwargs.get("headers") + assert headers is not None + assert ("authorization", "Bearer test-key") in headers + assert ("x-custom", "foo") in headers + + +@patch("enterprise.telemetry.exporter.logger") +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_api_key_overrides_conflicting_header( + mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock, mock_logger: MagicMock +) -> None: + """Test that API key overrides conflicting authorization header and logs warning.""" + mock_config = SimpleNamespace( + ENTERPRISE_OTLP_ENDPOINT="https://collector.example.com", + ENTERPRISE_OTLP_HEADERS="authorization=Basic+old", + ENTERPRISE_OTLP_PROTOCOL="grpc", + ENTERPRISE_SERVICE_NAME="dify", + 
ENTERPRISE_OTEL_SAMPLING_RATE=1.0, + ENTERPRISE_INCLUDE_CONTENT=True, + ENTERPRISE_OTLP_API_KEY="test-key", + ) + + EnterpriseExporter(mock_config) + + # Verify Bearer header takes precedence + assert mock_span_exporter.call_args is not None + headers = mock_span_exporter.call_args.kwargs.get("headers") + assert headers is not None + assert ("authorization", "Bearer test-key") in headers + # Verify old authorization header is not present + assert ("authorization", "Basic old") not in headers + + # Verify warning was logged + mock_logger.warning.assert_called_once() + assert mock_logger.warning.call_args is not None + warning_message = mock_logger.warning.call_args[0][0] + assert "ENTERPRISE_OTLP_API_KEY is set" in warning_message + assert "authorization" in warning_message + + +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_https_endpoint_uses_secure_grpc(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None: + """Test that https:// endpoint enables TLS (insecure=False) for gRPC.""" + mock_config = SimpleNamespace( + ENTERPRISE_OTLP_ENDPOINT="https://collector.example.com", + ENTERPRISE_OTLP_HEADERS="", + ENTERPRISE_OTLP_PROTOCOL="grpc", + ENTERPRISE_SERVICE_NAME="dify", + ENTERPRISE_OTEL_SAMPLING_RATE=1.0, + ENTERPRISE_INCLUDE_CONTENT=True, + ENTERPRISE_OTLP_API_KEY="test-key", + ) + + EnterpriseExporter(mock_config) + + # Verify insecure=False for both exporters (https:// scheme) + assert mock_span_exporter.call_args is not None + assert mock_span_exporter.call_args.kwargs["insecure"] is False + + assert mock_metric_exporter.call_args is not None + assert mock_metric_exporter.call_args.kwargs["insecure"] is False + + +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_http_endpoint_uses_insecure_grpc(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None: + """Test that 
http:// endpoint uses insecure gRPC (insecure=True).""" + mock_config = SimpleNamespace( + ENTERPRISE_OTLP_ENDPOINT="http://collector.example.com", + ENTERPRISE_OTLP_HEADERS="", + ENTERPRISE_OTLP_PROTOCOL="grpc", + ENTERPRISE_SERVICE_NAME="dify", + ENTERPRISE_OTEL_SAMPLING_RATE=1.0, + ENTERPRISE_INCLUDE_CONTENT=True, + ENTERPRISE_OTLP_API_KEY="", + ) + + EnterpriseExporter(mock_config) + + # Verify insecure=True for both exporters (http:// scheme) + assert mock_span_exporter.call_args is not None + assert mock_span_exporter.call_args.kwargs["insecure"] is True + + assert mock_metric_exporter.call_args is not None + assert mock_metric_exporter.call_args.kwargs["insecure"] is True + + +@patch("enterprise.telemetry.exporter.HTTPSpanExporter") +@patch("enterprise.telemetry.exporter.HTTPMetricExporter") +def test_insecure_not_passed_to_http_exporters(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None: + """Test that insecure parameter is not passed to HTTP exporters.""" + mock_config = SimpleNamespace( + ENTERPRISE_OTLP_ENDPOINT="http://collector.example.com", + ENTERPRISE_OTLP_HEADERS="", + ENTERPRISE_OTLP_PROTOCOL="http", + ENTERPRISE_SERVICE_NAME="dify", + ENTERPRISE_OTEL_SAMPLING_RATE=1.0, + ENTERPRISE_INCLUDE_CONTENT=True, + ENTERPRISE_OTLP_API_KEY="test-key", + ) + + EnterpriseExporter(mock_config) + + # Verify insecure kwarg is NOT in HTTP exporter calls + assert mock_span_exporter.call_args is not None + assert "insecure" not in mock_span_exporter.call_args.kwargs + + assert mock_metric_exporter.call_args is not None + assert "insecure" not in mock_metric_exporter.call_args.kwargs + + +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_api_key_with_special_chars_preserved(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None: + """Test that API key with special characters is preserved without mangling.""" + special_key = "abc+def/ghi=jkl==" + 
mock_config = SimpleNamespace( + ENTERPRISE_OTLP_ENDPOINT="https://collector.example.com", + ENTERPRISE_OTLP_HEADERS="", + ENTERPRISE_OTLP_PROTOCOL="grpc", + ENTERPRISE_SERVICE_NAME="dify", + ENTERPRISE_OTEL_SAMPLING_RATE=1.0, + ENTERPRISE_INCLUDE_CONTENT=True, + ENTERPRISE_OTLP_API_KEY=special_key, + ) + + EnterpriseExporter(mock_config) + + # Verify special characters are preserved in Bearer header + assert mock_span_exporter.call_args is not None + headers = mock_span_exporter.call_args.kwargs.get("headers") + assert headers is not None + assert ("authorization", f"Bearer {special_key}") in headers + + +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_no_scheme_localhost_uses_insecure(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None: + """Test that endpoint without scheme defaults to insecure for localhost.""" + mock_config = SimpleNamespace( + ENTERPRISE_OTLP_ENDPOINT="localhost:4317", + ENTERPRISE_OTLP_HEADERS="", + ENTERPRISE_OTLP_PROTOCOL="grpc", + ENTERPRISE_SERVICE_NAME="dify", + ENTERPRISE_OTEL_SAMPLING_RATE=1.0, + ENTERPRISE_INCLUDE_CONTENT=True, + ENTERPRISE_OTLP_API_KEY="", + ) + + EnterpriseExporter(mock_config) + + # Verify insecure=True for localhost without scheme + assert mock_span_exporter.call_args is not None + assert mock_span_exporter.call_args.kwargs["insecure"] is True + + assert mock_metric_exporter.call_args is not None + assert mock_metric_exporter.call_args.kwargs["insecure"] is True + + +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_no_scheme_production_uses_insecure(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None: + """Test that endpoint without scheme defaults to insecure (not https://).""" + mock_config = SimpleNamespace( + ENTERPRISE_OTLP_ENDPOINT="collector.example.com:4317", + ENTERPRISE_OTLP_HEADERS="", + 
ENTERPRISE_OTLP_PROTOCOL="grpc", + ENTERPRISE_SERVICE_NAME="dify", + ENTERPRISE_OTEL_SAMPLING_RATE=1.0, + ENTERPRISE_INCLUDE_CONTENT=True, + ENTERPRISE_OTLP_API_KEY="", + ) + + EnterpriseExporter(mock_config) + + # Verify insecure=True for any endpoint without https:// scheme + assert mock_span_exporter.call_args is not None + assert mock_span_exporter.call_args.kwargs["insecure"] is True + + assert mock_metric_exporter.call_args is not None + assert mock_metric_exporter.call_args.kwargs["insecure"] is True + + +# --------------------------------------------------------------------------- +# _parse_otlp_headers (line 55 — pair without "=" is skipped) +# --------------------------------------------------------------------------- + + +def test_parse_otlp_headers_empty_returns_empty_dict() -> None: + assert _parse_otlp_headers("") == {} + + +def test_parse_otlp_headers_value_may_contain_equals() -> None: + result = _parse_otlp_headers("token=abc=def==") + assert result == {"token": "abc=def=="} + + +def test_parse_otlp_headers_url_encoded() -> None: + result = _parse_otlp_headers("key=%E4%BD%A0%E5%A5%BD") + + assert result == {"key": "你好"} + + +# --------------------------------------------------------------------------- +# _datetime_to_ns (lines 64-68) +# --------------------------------------------------------------------------- + + +def test_datetime_to_ns_naive_treated_as_utc() -> None: + """Naive datetime must be interpreted as UTC (line 64-65).""" + naive = datetime(2024, 1, 1, 0, 0, 0) # no tzinfo + aware_utc = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC) + assert _datetime_to_ns(naive) == _datetime_to_ns(aware_utc) + + +def test_datetime_to_ns_tz_aware_converted_to_utc() -> None: + """Timezone-aware datetime must be converted to UTC before computing ns (line 66-67).""" + import zoneinfo + + eastern = zoneinfo.ZoneInfo("America/New_York") + dt_east = datetime(2024, 6, 1, 12, 0, 0, tzinfo=eastern) # UTC-4 in summer + dt_utc = dt_east.astimezone(UTC) + assert 
_datetime_to_ns(dt_east) == _datetime_to_ns(dt_utc) + + +def test_datetime_to_ns_returns_integer_nanoseconds() -> None: + dt = datetime(2024, 1, 1, 0, 0, 1, tzinfo=UTC) + result = _datetime_to_ns(dt) + # 2024-01-01 00:00:01 UTC = epoch + some_seconds; result should be in nanoseconds + assert isinstance(result, int) + # 1 second past epoch start of 2024 — should be > 1_700_000_000_000_000_000 (rough lower bound) + assert result > 1_700_000_000_000_000_000 + + +# --------------------------------------------------------------------------- +# EnterpriseExporter constructor — include_content property (line 115 / 288-289) +# --------------------------------------------------------------------------- + + +def _make_grpc_config(**overrides) -> SimpleNamespace: + defaults = { + "ENTERPRISE_OTLP_ENDPOINT": "https://collector.example.com", + "ENTERPRISE_OTLP_HEADERS": "", + "ENTERPRISE_OTLP_PROTOCOL": "grpc", + "ENTERPRISE_SERVICE_NAME": "dify", + "ENTERPRISE_OTEL_SAMPLING_RATE": 1.0, + "ENTERPRISE_INCLUDE_CONTENT": True, + "ENTERPRISE_OTLP_API_KEY": "", + } + defaults.update(overrides) + return SimpleNamespace(**defaults) + + +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_include_content_true_stored_on_exporter( + mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock +) -> None: + """include_content=True is stored as a public attribute (line 115).""" + exporter = EnterpriseExporter(_make_grpc_config(ENTERPRISE_INCLUDE_CONTENT=True)) + assert exporter.include_content is True + + +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_include_content_false_stored_on_exporter( + mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock +) -> None: + """include_content=False is preserved (lines 288-289 path exercised by callers).""" + exporter = 
EnterpriseExporter(_make_grpc_config(ENTERPRISE_INCLUDE_CONTENT=False)) + assert exporter.include_content is False + + +# --------------------------------------------------------------------------- +# EnterpriseExporter constructor — gRPC setup (lines 64-68 exporter-init path) +# --------------------------------------------------------------------------- + + +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_grpc_exporter_created_with_correct_endpoint( + mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock +) -> None: + """GRPCSpanExporter and GRPCMetricExporter receive the configured endpoint.""" + EnterpriseExporter(_make_grpc_config(ENTERPRISE_OTLP_ENDPOINT="https://my-collector:4317")) + + assert mock_span_exporter.call_args.kwargs["endpoint"] == "https://my-collector:4317" + assert mock_metric_exporter.call_args.kwargs["endpoint"] == "https://my-collector:4317" + + +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_grpc_exporter_empty_endpoint_passes_none( + mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock +) -> None: + """Empty string endpoint is normalised to None for both gRPC exporters.""" + EnterpriseExporter(_make_grpc_config(ENTERPRISE_OTLP_ENDPOINT="")) + + assert mock_span_exporter.call_args.kwargs["endpoint"] is None + assert mock_metric_exporter.call_args.kwargs["endpoint"] is None + + +# --------------------------------------------------------------------------- +# EnterpriseExporter.export_span (lines 204-271) +# --------------------------------------------------------------------------- + + +def _make_exporter_with_mock_tracer() -> tuple[EnterpriseExporter, MagicMock, MagicMock]: + """Return (exporter, mock_tracer, mock_span) with OTEL internals fully mocked.""" + mock_span = MagicMock() + mock_span.__enter__ = MagicMock(return_value=mock_span) + mock_span.__exit__ = 
MagicMock(return_value=False) + + mock_tracer = MagicMock() + mock_tracer.start_as_current_span.return_value = mock_span + + with ( + patch("enterprise.telemetry.exporter.GRPCSpanExporter"), + patch("enterprise.telemetry.exporter.GRPCMetricExporter"), + ): + exporter = EnterpriseExporter(_make_grpc_config()) + + exporter._tracer = mock_tracer + return exporter, mock_tracer, mock_span + + +@patch("enterprise.telemetry.exporter.set_correlation_id") +@patch("enterprise.telemetry.exporter.set_span_id_source") +def test_export_span_sets_and_clears_context(mock_set_span: MagicMock, mock_set_corr: MagicMock) -> None: + """export_span sets correlation/span context before the span and clears them in finally.""" + exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer() + + exporter.export_span( + name="test.span", + attributes={"k": "v"}, + correlation_id="corr-1", + span_id_source="span-src-1", + ) + + # Context was set at the start of the call + mock_set_corr.assert_any_call("corr-1") + mock_set_span.assert_any_call("span-src-1") + # Context was cleared in finally + mock_set_corr.assert_called_with(None) + mock_set_span.assert_called_with(None) + + +def test_export_span_sets_attributes_on_span() -> None: + """All non-None attribute values are set on the span via set_attribute.""" + exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer() + + exporter.export_span( + name="test.span", + attributes={"key1": "value1", "key2": None, "key3": 42}, + ) + + # set_attribute should be called for non-None values only + calls = list(mock_span.set_attribute.call_args_list) + keys_set = {c[0][0] for c in calls} + assert "key1" in keys_set + assert "key3" in keys_set + assert "key2" not in keys_set + + +def test_export_span_no_end_time_uses_end_on_exit() -> None: + """When end_time is None, end_on_exit=True is passed to start_as_current_span.""" + exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer() + + exporter.export_span(name="test.span", 
attributes={}) + + _, kwargs = mock_tracer.start_as_current_span.call_args + assert kwargs["end_on_exit"] is True + + +def test_export_span_with_end_time_calls_span_end() -> None: + """When end_time is provided, span.end() is called with the converted ns timestamp.""" + exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer() + + start = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC) + end = datetime(2024, 1, 1, 0, 0, 5, tzinfo=UTC) + + exporter.export_span(name="test.span", attributes={}, start_time=start, end_time=end) + + mock_span.end.assert_called_once() + end_ns = mock_span.end.call_args.kwargs["end_time"] + assert end_ns == _datetime_to_ns(end) + + +def test_export_span_with_start_time_passed_to_start_as_current_span() -> None: + """When start_time is provided it is converted to ns and passed to start_as_current_span.""" + exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer() + + start = datetime(2024, 3, 1, 12, 0, 0, tzinfo=UTC) + exporter.export_span(name="test.span", attributes={}, start_time=start) + + _, kwargs = mock_tracer.start_as_current_span.call_args + assert kwargs["start_time"] == _datetime_to_ns(start) + + +def test_export_span_root_span_no_parent_context() -> None: + """When span_id_source == correlation_id the span is root — no parent context.""" + exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer() + + uid = "123e4567-e89b-12d3-a456-426614174000" + exporter.export_span( + name="root.span", + attributes={}, + correlation_id=uid, + span_id_source=uid, + ) + + _, kwargs = mock_tracer.start_as_current_span.call_args + assert kwargs["context"] is None + + +def test_export_span_child_span_has_parent_context() -> None: + """When correlation_id != span_id_source the child span gets a parent context.""" + exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer() + + corr_uid = "123e4567-e89b-12d3-a456-426614174000" + node_uid = "987fbc97-4bed-5078-9f07-9141ba07c9f3" + + exporter.export_span( + 
name="child.span", + attributes={}, + correlation_id=corr_uid, + span_id_source=node_uid, + ) + + _, kwargs = mock_tracer.start_as_current_span.call_args + assert kwargs["context"] is not None + + +def test_export_span_cross_workflow_parent_context() -> None: + """When parent_span_id_source is set, the cross-workflow parent context is built.""" + exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer() + + corr_uid = "123e4567-e89b-12d3-a456-426614174000" + parent_uid = "987fbc97-4bed-5078-9f07-9141ba07c9f3" + + exporter.export_span( + name="cross.span", + attributes={}, + correlation_id=corr_uid, + parent_span_id_source=parent_uid, + ) + + _, kwargs = mock_tracer.start_as_current_span.call_args + assert kwargs["context"] is not None + + +@patch("enterprise.telemetry.exporter.logger") +def test_export_span_logs_exception_on_error(mock_logger: MagicMock) -> None: + """If the span block raises, the exception is logged and context is still cleared.""" + exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer() + + mock_tracer.start_as_current_span.side_effect = RuntimeError("boom") + + exporter.export_span(name="bad.span", attributes={}) # must not raise + + mock_logger.exception.assert_called_once() + assert "bad.span" in mock_logger.exception.call_args[0][1] + + +@patch("enterprise.telemetry.exporter.logger") +def test_export_span_invalid_trace_correlation_logs_warning(mock_logger: MagicMock) -> None: + """Invalid UUID for trace_correlation_override triggers a warning log.""" + exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer() + + parent_uid = "987fbc97-4bed-5078-9f07-9141ba07c9f3" + exporter.export_span( + name="link.span", + attributes={}, + correlation_id="not-a-valid-uuid", + parent_span_id_source=parent_uid, + ) + + mock_logger.warning.assert_called() + + +# --------------------------------------------------------------------------- +# EnterpriseExporter.increment_counter (lines 276-278) +# 
--------------------------------------------------------------------------- + + +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_increment_counter_calls_add_on_counter(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None: + """increment_counter calls .add() on the matching counter instrument.""" + exporter = EnterpriseExporter(_make_grpc_config()) + + mock_counter = MagicMock() + exporter._counters[EnterpriseTelemetryCounter.TOKENS] = mock_counter + + labels = {"tenant_id": "t1", "app_id": "app-1"} + exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, 50, labels) + + mock_counter.add.assert_called_once_with(50, labels) + + +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_increment_counter_unknown_name_is_noop(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None: + """increment_counter silently does nothing when the counter is not found.""" + exporter = EnterpriseExporter(_make_grpc_config()) + exporter._counters.clear() + + # Should not raise + exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, 5, {}) + + +# --------------------------------------------------------------------------- +# EnterpriseExporter.record_histogram (lines 283-285) +# --------------------------------------------------------------------------- + + +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_record_histogram_calls_record_on_histogram( + mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock +) -> None: + """record_histogram calls .record() on the matching histogram instrument.""" + exporter = EnterpriseExporter(_make_grpc_config()) + + mock_histogram = MagicMock() + exporter._histograms[EnterpriseTelemetryHistogram.WORKFLOW_DURATION] = mock_histogram + + labels = {"tenant_id": "t1"} + 
exporter.record_histogram(EnterpriseTelemetryHistogram.WORKFLOW_DURATION, 3.14, labels) + + mock_histogram.record.assert_called_once_with(3.14, labels) + + +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_record_histogram_unknown_name_is_noop(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None: + """record_histogram silently does nothing when the histogram is not found.""" + exporter = EnterpriseExporter(_make_grpc_config()) + exporter._histograms.clear() + + # Should not raise + exporter.record_histogram(EnterpriseTelemetryHistogram.WORKFLOW_DURATION, 1.0, {}) diff --git a/api/tests/unit_tests/enterprise/telemetry/test_gateway.py b/api/tests/unit_tests/enterprise/telemetry/test_gateway.py new file mode 100644 index 0000000000..7e6ae64693 --- /dev/null +++ b/api/tests/unit_tests/enterprise/telemetry/test_gateway.py @@ -0,0 +1,272 @@ +from __future__ import annotations + +import sys +from unittest.mock import MagicMock, patch + +import pytest + +from core.ops.entities.trace_entity import TraceTaskName +from core.telemetry.gateway import ( + CASE_ROUTING, + CASE_TO_TRACE_TASK, + PAYLOAD_SIZE_THRESHOLD_BYTES, + emit, +) +from enterprise.telemetry.contracts import SignalType, TelemetryCase, TelemetryEnvelope + + +class TestCaseRoutingTable: + def test_all_cases_have_routing(self) -> None: + for case in TelemetryCase: + assert case in CASE_ROUTING, f"Missing routing for {case}" + + def test_trace_cases(self) -> None: + trace_cases = [ + TelemetryCase.WORKFLOW_RUN, + TelemetryCase.MESSAGE_RUN, + TelemetryCase.NODE_EXECUTION, + TelemetryCase.DRAFT_NODE_EXECUTION, + TelemetryCase.PROMPT_GENERATION, + ] + for case in trace_cases: + assert CASE_ROUTING[case].signal_type is SignalType.TRACE, f"{case} should be trace" + + def test_metric_log_cases(self) -> None: + metric_log_cases = [ + TelemetryCase.APP_CREATED, + TelemetryCase.APP_UPDATED, + TelemetryCase.APP_DELETED, + 
TelemetryCase.FEEDBACK_CREATED, + ] + for case in metric_log_cases: + assert CASE_ROUTING[case].signal_type is SignalType.METRIC_LOG, f"{case} should be metric_log" + + def test_ce_eligible_cases(self) -> None: + ce_eligible_cases = [ + TelemetryCase.WORKFLOW_RUN, + TelemetryCase.MESSAGE_RUN, + TelemetryCase.TOOL_EXECUTION, + TelemetryCase.MODERATION_CHECK, + TelemetryCase.SUGGESTED_QUESTION, + TelemetryCase.DATASET_RETRIEVAL, + TelemetryCase.GENERATE_NAME, + ] + for case in ce_eligible_cases: + assert CASE_ROUTING[case].ce_eligible is True, f"{case} should be CE eligible" + + def test_enterprise_only_cases(self) -> None: + enterprise_only_cases = [ + TelemetryCase.NODE_EXECUTION, + TelemetryCase.DRAFT_NODE_EXECUTION, + TelemetryCase.PROMPT_GENERATION, + ] + for case in enterprise_only_cases: + assert CASE_ROUTING[case].ce_eligible is False, f"{case} should be enterprise-only" + + def test_trace_cases_have_task_name_mapping(self) -> None: + trace_cases = [c for c in TelemetryCase if CASE_ROUTING[c].signal_type is SignalType.TRACE] + for case in trace_cases: + assert case in CASE_TO_TRACE_TASK, f"Missing TraceTaskName mapping for {case}" + + +@pytest.fixture +def mock_ops_trace_manager(): + mock_module = MagicMock() + mock_trace_task_class = MagicMock() + mock_trace_task_class.return_value = MagicMock() + mock_module.TraceTask = mock_trace_task_class + mock_module.TraceQueueManager = MagicMock() + + mock_trace_entity = MagicMock() + mock_trace_task_name = MagicMock() + mock_trace_task_name.return_value = "workflow" + mock_trace_entity.TraceTaskName = mock_trace_task_name + + with ( + patch.dict(sys.modules, {"core.ops.ops_trace_manager": mock_module}), + patch.dict(sys.modules, {"core.ops.entities.trace_entity": mock_trace_entity}), + ): + yield mock_module, mock_trace_entity + + +class TestGatewayTraceRouting: + @pytest.fixture + def mock_trace_manager(self) -> MagicMock: + return MagicMock() + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", 
return_value=True) + def test_trace_case_routes_to_trace_manager( + self, + mock_ee_enabled: MagicMock, + mock_trace_manager: MagicMock, + mock_ops_trace_manager: tuple[MagicMock, MagicMock], + ) -> None: + context = {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"} + payload = {"workflow_run_id": "run-abc"} + + emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False) + def test_ce_eligible_trace_enqueued_when_ee_disabled( + self, + mock_ee_enabled: MagicMock, + mock_trace_manager: MagicMock, + mock_ops_trace_manager: tuple[MagicMock, MagicMock], + ) -> None: + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"workflow_run_id": "run-abc"} + + emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False) + def test_enterprise_only_trace_dropped_when_ee_disabled( + self, + mock_ee_enabled: MagicMock, + mock_trace_manager: MagicMock, + mock_ops_trace_manager: tuple[MagicMock, MagicMock], + ) -> None: + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_id": "node-abc"} + + emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_not_called() + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_enterprise_only_trace_enqueued_when_ee_enabled( + self, + mock_ee_enabled: MagicMock, + mock_trace_manager: MagicMock, + mock_ops_trace_manager: tuple[MagicMock, MagicMock], + ) -> None: + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_id": "node-abc"} + + emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager) + + 
mock_trace_manager.add_trace_task.assert_called_once() + + +class TestGatewayMetricLogRouting: + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + @patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") + def test_metric_case_routes_to_celery_task( + self, + mock_delay: MagicMock, + mock_ee_enabled: MagicMock, + ) -> None: + context = {"tenant_id": "tenant-123"} + payload = {"app_id": "app-abc", "name": "My App"} + + emit(TelemetryCase.APP_CREATED, context, payload) + + mock_delay.assert_called_once() + envelope_json = mock_delay.call_args[0][0] + envelope = TelemetryEnvelope.model_validate_json(envelope_json) + assert envelope.case == TelemetryCase.APP_CREATED + assert envelope.tenant_id == "tenant-123" + assert envelope.payload["app_id"] == "app-abc" + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + @patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") + def test_envelope_has_unique_event_id( + self, + mock_delay: MagicMock, + mock_ee_enabled: MagicMock, + ) -> None: + context = {"tenant_id": "tenant-123"} + payload = {"app_id": "app-abc"} + + emit(TelemetryCase.APP_CREATED, context, payload) + emit(TelemetryCase.APP_CREATED, context, payload) + + assert mock_delay.call_count == 2 + envelope1 = TelemetryEnvelope.model_validate_json(mock_delay.call_args_list[0][0][0]) + envelope2 = TelemetryEnvelope.model_validate_json(mock_delay.call_args_list[1][0][0]) + assert envelope1.event_id != envelope2.event_id + + +class TestGatewayPayloadSizing: + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + @patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") + def test_small_payload_inlined( + self, + mock_delay: MagicMock, + mock_ee_enabled: MagicMock, + ) -> None: + context = {"tenant_id": "tenant-123"} + payload = {"key": "small_value"} + + emit(TelemetryCase.APP_CREATED, context, payload) + + 
envelope_json = mock_delay.call_args[0][0] + envelope = TelemetryEnvelope.model_validate_json(envelope_json) + assert envelope.payload == payload + assert envelope.metadata is None + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + @patch("core.telemetry.gateway.storage") + @patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") + def test_large_payload_stored( + self, + mock_delay: MagicMock, + mock_storage: MagicMock, + mock_ee_enabled: MagicMock, + ) -> None: + context = {"tenant_id": "tenant-123"} + large_value = "x" * (PAYLOAD_SIZE_THRESHOLD_BYTES + 1000) + payload = {"key": large_value} + + emit(TelemetryCase.APP_CREATED, context, payload) + + mock_storage.save.assert_called_once() + storage_key = mock_storage.save.call_args[0][0] + assert storage_key.startswith("telemetry/tenant-123/") + + envelope_json = mock_delay.call_args[0][0] + envelope = TelemetryEnvelope.model_validate_json(envelope_json) + assert envelope.payload == {} + assert envelope.metadata is not None + assert envelope.metadata["payload_ref"] == storage_key + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + @patch("core.telemetry.gateway.storage") + @patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") + def test_large_payload_fallback_on_storage_error( + self, + mock_delay: MagicMock, + mock_storage: MagicMock, + mock_ee_enabled: MagicMock, + ) -> None: + mock_storage.save.side_effect = Exception("Storage failure") + context = {"tenant_id": "tenant-123"} + large_value = "x" * (PAYLOAD_SIZE_THRESHOLD_BYTES + 1000) + payload = {"key": large_value} + + emit(TelemetryCase.APP_CREATED, context, payload) + + envelope_json = mock_delay.call_args[0][0] + envelope = TelemetryEnvelope.model_validate_json(envelope_json) + assert envelope.payload == payload + assert envelope.metadata is None + + +class TestTraceTaskNameMapping: + def test_workflow_run_mapping(self) -> None: + assert 
CASE_TO_TRACE_TASK[TelemetryCase.WORKFLOW_RUN] is TraceTaskName.WORKFLOW_TRACE + + def test_message_run_mapping(self) -> None: + assert CASE_TO_TRACE_TASK[TelemetryCase.MESSAGE_RUN] is TraceTaskName.MESSAGE_TRACE + + def test_node_execution_mapping(self) -> None: + assert CASE_TO_TRACE_TASK[TelemetryCase.NODE_EXECUTION] is TraceTaskName.NODE_EXECUTION_TRACE + + def test_draft_node_execution_mapping(self) -> None: + assert CASE_TO_TRACE_TASK[TelemetryCase.DRAFT_NODE_EXECUTION] is TraceTaskName.DRAFT_NODE_EXECUTION_TRACE + + def test_prompt_generation_mapping(self) -> None: + assert CASE_TO_TRACE_TASK[TelemetryCase.PROMPT_GENERATION] is TraceTaskName.PROMPT_GENERATION_TRACE diff --git a/api/tests/unit_tests/enterprise/telemetry/test_id_generator.py b/api/tests/unit_tests/enterprise/telemetry/test_id_generator.py new file mode 100644 index 0000000000..dc2be14ebf --- /dev/null +++ b/api/tests/unit_tests/enterprise/telemetry/test_id_generator.py @@ -0,0 +1,201 @@ +"""Unit tests for enterprise/telemetry/id_generator.py.""" + +from __future__ import annotations + +import uuid +from unittest.mock import patch + +# --------------------------------------------------------------------------- +# compute_deterministic_span_id +# --------------------------------------------------------------------------- + + +class TestComputeDeterministicSpanId: + def test_returns_lower_64_bits_of_uuid(self) -> None: + from enterprise.telemetry.id_generator import compute_deterministic_span_id + + uid = "123e4567-e89b-12d3-a456-426614174000" + expected = uuid.UUID(uid).int & ((1 << 64) - 1) + assert compute_deterministic_span_id(uid) == expected + + def test_non_zero_result_returned_unchanged(self) -> None: + from enterprise.telemetry.id_generator import compute_deterministic_span_id + + # This UUID has non-zero lower 64 bits + uid = "123e4567-e89b-12d3-a456-426614174000" + result = compute_deterministic_span_id(uid) + assert result != 0 + + def test_zero_lower_bits_returns_one(self) -> None: + 
"""When the lower 64 bits of the UUID int are 0, the function must return 1 (OTEL requirement).""" + from enterprise.telemetry.id_generator import compute_deterministic_span_id + + # Craft a UUID whose lower 64 bits are 0: upper 64 bits are 1, lower 64 bits are 0. + # int = (1 << 64), UUID fields constructed to produce this integer. + target_int = 1 << 64 # lower 64 bits are 0x0000000000000000 + crafted_uuid = uuid.UUID(int=target_int) + result = compute_deterministic_span_id(str(crafted_uuid)) + assert result == 1 + + def test_raises_on_invalid_uuid(self) -> None: + import pytest + + from enterprise.telemetry.id_generator import compute_deterministic_span_id + + with pytest.raises((ValueError, AttributeError)): + compute_deterministic_span_id("not-a-uuid") + + def test_different_uuids_produce_different_span_ids(self) -> None: + from enterprise.telemetry.id_generator import compute_deterministic_span_id + + uid1 = "123e4567-e89b-12d3-a456-426614174000" + uid2 = "987fbc97-4bed-5078-9f07-9141ba07c9f3" + assert compute_deterministic_span_id(uid1) != compute_deterministic_span_id(uid2) + + def test_deterministic_same_input_same_output(self) -> None: + from enterprise.telemetry.id_generator import compute_deterministic_span_id + + uid = "123e4567-e89b-12d3-a456-426614174000" + assert compute_deterministic_span_id(uid) == compute_deterministic_span_id(uid) + + +# --------------------------------------------------------------------------- +# Context variable helpers +# --------------------------------------------------------------------------- + + +class TestContextVariableHelpers: + def test_set_and_get_correlation_id(self) -> None: + from enterprise.telemetry.id_generator import get_correlation_id, set_correlation_id + + set_correlation_id("corr-123") + assert get_correlation_id() == "corr-123" + + def test_clear_correlation_id(self) -> None: + from enterprise.telemetry.id_generator import get_correlation_id, set_correlation_id + + set_correlation_id("corr-abc") + 
set_correlation_id(None) + assert get_correlation_id() is None + + def test_correlation_id_default_is_none(self) -> None: + from enterprise.telemetry.id_generator import get_correlation_id, set_correlation_id + + set_correlation_id(None) + assert get_correlation_id() is None + + def test_set_span_id_source_stored_in_context(self) -> None: + from enterprise.telemetry.id_generator import _span_id_source_context, set_span_id_source + + set_span_id_source("span-src-1") + assert _span_id_source_context.get() == "span-src-1" + + def test_clear_span_id_source(self) -> None: + from enterprise.telemetry.id_generator import _span_id_source_context, set_span_id_source + + set_span_id_source("span-src-1") + set_span_id_source(None) + assert _span_id_source_context.get() is None + + +# --------------------------------------------------------------------------- +# CorrelationIdGenerator.generate_trace_id +# --------------------------------------------------------------------------- + + +class TestCorrelationIdGeneratorGenerateTraceId: + def setup_method(self) -> None: + from enterprise.telemetry.id_generator import set_correlation_id + + set_correlation_id(None) + + def test_returns_uuid_int_when_correlation_id_set(self) -> None: + from enterprise.telemetry.id_generator import CorrelationIdGenerator, set_correlation_id + + uid = "123e4567-e89b-12d3-a456-426614174000" + set_correlation_id(uid) + gen = CorrelationIdGenerator() + trace_id = gen.generate_trace_id() + assert trace_id == uuid.UUID(uid).int + + def test_returns_random_when_no_correlation_id(self) -> None: + from enterprise.telemetry.id_generator import CorrelationIdGenerator, set_correlation_id + + set_correlation_id(None) + gen = CorrelationIdGenerator() + # Should return a non-zero int without raising + trace_id = gen.generate_trace_id() + assert isinstance(trace_id, int) + assert trace_id > 0 + + def test_returns_random_when_correlation_id_is_invalid_uuid(self) -> None: + from enterprise.telemetry.id_generator 
import CorrelationIdGenerator, set_correlation_id + + set_correlation_id("not-a-valid-uuid") + gen = CorrelationIdGenerator() + with patch("enterprise.telemetry.id_generator.random.getrandbits", return_value=42) as mock_rng: + trace_id = gen.generate_trace_id() + mock_rng.assert_called_once_with(128) + assert trace_id == 42 + + +# --------------------------------------------------------------------------- +# CorrelationIdGenerator.generate_span_id +# --------------------------------------------------------------------------- + + +class TestCorrelationIdGeneratorGenerateSpanId: + def setup_method(self) -> None: + from enterprise.telemetry.id_generator import set_span_id_source + + set_span_id_source(None) + + def test_uses_deterministic_span_id_when_source_set(self) -> None: + from enterprise.telemetry.id_generator import ( + CorrelationIdGenerator, + compute_deterministic_span_id, + set_span_id_source, + ) + + uid = "123e4567-e89b-12d3-a456-426614174000" + set_span_id_source(uid) + gen = CorrelationIdGenerator() + span_id = gen.generate_span_id() + assert span_id == compute_deterministic_span_id(uid) + + def test_returns_random_when_no_source(self) -> None: + from enterprise.telemetry.id_generator import CorrelationIdGenerator, set_span_id_source + + set_span_id_source(None) + gen = CorrelationIdGenerator() + span_id = gen.generate_span_id() + assert isinstance(span_id, int) + assert span_id != 0 + + def test_returns_random_when_source_is_invalid_uuid(self) -> None: + from enterprise.telemetry.id_generator import CorrelationIdGenerator, set_span_id_source + + set_span_id_source("not-a-uuid") + gen = CorrelationIdGenerator() + with patch("enterprise.telemetry.id_generator.random.getrandbits", return_value=7) as mock_rng: + span_id = gen.generate_span_id() + assert span_id == 7 + + def test_random_span_id_retried_if_zero(self) -> None: + """generate_span_id must never return 0 — it retries until non-zero.""" + from enterprise.telemetry.id_generator import 
CorrelationIdGenerator, set_span_id_source + + set_span_id_source(None) + gen = CorrelationIdGenerator() + # First call returns 0 (invalid), second returns 99 + with patch("enterprise.telemetry.id_generator.random.getrandbits", side_effect=[0, 99]): + span_id = gen.generate_span_id() + assert span_id == 99 + + def test_generate_span_id_always_non_zero(self) -> None: + from enterprise.telemetry.id_generator import CorrelationIdGenerator, set_span_id_source + + set_span_id_source(None) + gen = CorrelationIdGenerator() + for _ in range(20): + assert gen.generate_span_id() != 0 diff --git a/api/tests/unit_tests/enterprise/telemetry/test_metric_handler.py b/api/tests/unit_tests/enterprise/telemetry/test_metric_handler.py new file mode 100644 index 0000000000..56c42a57d5 --- /dev/null +++ b/api/tests/unit_tests/enterprise/telemetry/test_metric_handler.py @@ -0,0 +1,511 @@ +"""Unit tests for EnterpriseMetricHandler.""" + +import json +from unittest.mock import MagicMock, patch + +import pytest + +from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope +from enterprise.telemetry.metric_handler import EnterpriseMetricHandler + + +@pytest.fixture +def mock_redis(): + with patch("enterprise.telemetry.metric_handler.redis_client") as mock: + yield mock + + +@pytest.fixture +def sample_envelope(): + return TelemetryEnvelope( + case=TelemetryCase.APP_CREATED, + tenant_id="test-tenant", + event_id="test-event-123", + payload={"app_id": "app-123", "name": "Test App"}, + ) + + +def test_dispatch_app_created(sample_envelope, mock_redis): + mock_redis.set.return_value = True + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_app_created") as mock_handler: + handler.handle(sample_envelope) + mock_handler.assert_called_once_with(sample_envelope) + + +def test_dispatch_app_updated(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_UPDATED, + tenant_id="test-tenant", + 
event_id="test-event-456", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_app_updated") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_dispatch_app_deleted(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_DELETED, + tenant_id="test-tenant", + event_id="test-event-789", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_app_deleted") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_dispatch_feedback_created(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.FEEDBACK_CREATED, + tenant_id="test-tenant", + event_id="test-event-abc", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_feedback_created") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_dispatch_message_run(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.MESSAGE_RUN, + tenant_id="test-tenant", + event_id="test-event-msg", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_message_run") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_dispatch_tool_execution(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.TOOL_EXECUTION, + tenant_id="test-tenant", + event_id="test-event-tool", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_tool_execution") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_dispatch_moderation_check(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + 
case=TelemetryCase.MODERATION_CHECK, + tenant_id="test-tenant", + event_id="test-event-mod", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_moderation_check") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_dispatch_suggested_question(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.SUGGESTED_QUESTION, + tenant_id="test-tenant", + event_id="test-event-sq", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_suggested_question") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_dispatch_dataset_retrieval(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.DATASET_RETRIEVAL, + tenant_id="test-tenant", + event_id="test-event-ds", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_dataset_retrieval") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_dispatch_generate_name(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.GENERATE_NAME, + tenant_id="test-tenant", + event_id="test-event-gn", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_generate_name") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_dispatch_prompt_generation(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.PROMPT_GENERATION, + tenant_id="test-tenant", + event_id="test-event-pg", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_prompt_generation") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def 
test_all_known_cases_have_handlers(mock_redis): + mock_redis.set.return_value = True + handler = EnterpriseMetricHandler() + + for case in TelemetryCase: + envelope = TelemetryEnvelope( + case=case, + tenant_id="test-tenant", + event_id=f"test-{case.value}", + payload={}, + ) + handler.handle(envelope) + + +def test_idempotency_duplicate(sample_envelope, mock_redis): + mock_redis.set.return_value = None + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_app_created") as mock_handler: + handler.handle(sample_envelope) + mock_handler.assert_not_called() + + +def test_idempotency_first_seen(sample_envelope, mock_redis): + mock_redis.set.return_value = True + + handler = EnterpriseMetricHandler() + is_dup = handler._is_duplicate(sample_envelope) + + assert is_dup is False + mock_redis.set.assert_called_once_with( + "telemetry:dedup:test-tenant:test-event-123", + b"1", + nx=True, + ex=3600, + ) + + +def test_idempotency_redis_failure_fails_open(sample_envelope, mock_redis, caplog): + mock_redis.set.side_effect = Exception("Redis unavailable") + + handler = EnterpriseMetricHandler() + is_dup = handler._is_duplicate(sample_envelope) + + assert is_dup is False + assert "Redis unavailable for deduplication check" in caplog.text + + +def test_rehydration_uses_payload(sample_envelope): + handler = EnterpriseMetricHandler() + payload = handler._rehydrate(sample_envelope) + + assert payload == {"app_id": "app-123", "name": "Test App"} + + +def test_rehydration_from_storage(): + """Verify _rehydrate loads payload from object storage via payload_ref.""" + stored_data = {"app_id": "app-stored", "mode": "workflow"} + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_CREATED, + tenant_id="test-tenant", + event_id="test-event-fb", + payload={}, + metadata={"payload_ref": "telemetry/test-tenant/test-event-fb.json"}, + ) + + handler = EnterpriseMetricHandler() + with patch("enterprise.telemetry.metric_handler.storage") as mock_storage: + 
mock_storage.load.return_value = json.dumps(stored_data).encode("utf-8") + payload = handler._rehydrate(envelope) + + assert payload == stored_data + mock_storage.load.assert_called_once_with("telemetry/test-tenant/test-event-fb.json") + + +def test_rehydration_storage_failure_emits_degraded_event(): + """Verify _rehydrate emits degraded event when storage load fails.""" + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_CREATED, + tenant_id="test-tenant", + event_id="test-event-fail", + payload={}, + metadata={"payload_ref": "telemetry/test-tenant/test-event-fail.json"}, + ) + + handler = EnterpriseMetricHandler() + with ( + patch("enterprise.telemetry.metric_handler.storage") as mock_storage, + patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit, + ): + mock_storage.load.side_effect = Exception("Storage unavailable") + payload = handler._rehydrate(envelope) + + from enterprise.telemetry.entities import EnterpriseTelemetryEvent + + assert payload == {} + mock_emit.assert_called_once() + call_args = mock_emit.call_args + assert call_args[1]["event_name"] == EnterpriseTelemetryEvent.REHYDRATION_FAILED + assert "dify.telemetry.error" in call_args[1]["attributes"] + + +def test_rehydration_emits_degraded_event_on_empty_payload(): + """Verify _rehydrate emits degraded event when payload is empty and no ref exists.""" + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_CREATED, + tenant_id="test-tenant", + event_id="test-event-empty", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit: + payload = handler._rehydrate(envelope) + + from enterprise.telemetry.entities import EnterpriseTelemetryEvent + + assert payload == {} + mock_emit.assert_called_once() + call_args = mock_emit.call_args + assert call_args[1]["event_name"] == EnterpriseTelemetryEvent.REHYDRATION_FAILED + assert "dify.telemetry.error" in call_args[1]["attributes"] + + +def 
test_on_app_created_emits_correct_event(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_CREATED, + tenant_id="tenant-123", + event_id="event-456", + payload={"app_id": "app-789", "mode": "chat"}, + ) + + handler = EnterpriseMetricHandler() + with ( + patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter, + patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit, + ): + mock_exporter = MagicMock() + mock_get_exporter.return_value = mock_exporter + + handler._on_app_created(envelope) + + from enterprise.telemetry.entities import EnterpriseTelemetryEvent + + mock_emit.assert_called_once() + call_args = mock_emit.call_args + assert call_args[1]["event_name"] == EnterpriseTelemetryEvent.APP_CREATED + assert call_args[1]["tenant_id"] == "tenant-123" + attrs = call_args[1]["attributes"] + assert attrs["dify.app_id"] == "app-789" + assert attrs["dify.tenant_id"] == "tenant-123" + assert attrs["dify.event.id"] == "event-456" + assert attrs["dify.app.mode"] == "chat" + assert "dify.app.created_at" in attrs + + from enterprise.telemetry.entities import EnterpriseTelemetryCounter + + mock_exporter.increment_counter.assert_called_once() + counter_call = mock_exporter.increment_counter.call_args + assert counter_call[0][0] == EnterpriseTelemetryCounter.APP_CREATED + assert counter_call[0][1] == 1 + assert counter_call[0][2]["tenant_id"] == "tenant-123" + assert counter_call[0][2]["app_id"] == "app-789" + assert counter_call[0][2]["mode"] == "chat" + + +def test_on_app_updated_emits_correct_event(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_UPDATED, + tenant_id="tenant-123", + event_id="event-456", + payload={"app_id": "app-789"}, + ) + + handler = EnterpriseMetricHandler() + with ( + patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter, + 
patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit, + ): + mock_exporter = MagicMock() + mock_get_exporter.return_value = mock_exporter + + handler._on_app_updated(envelope) + + from enterprise.telemetry.entities import EnterpriseTelemetryEvent + + mock_emit.assert_called_once() + call_args = mock_emit.call_args + assert call_args[1]["event_name"] == EnterpriseTelemetryEvent.APP_UPDATED + assert call_args[1]["tenant_id"] == "tenant-123" + attrs = call_args[1]["attributes"] + assert attrs["dify.app_id"] == "app-789" + assert attrs["dify.tenant_id"] == "tenant-123" + assert attrs["dify.event.id"] == "event-456" + assert "dify.app.updated_at" in attrs + + from enterprise.telemetry.entities import EnterpriseTelemetryCounter + + mock_exporter.increment_counter.assert_called_once() + counter_call = mock_exporter.increment_counter.call_args + assert counter_call[0][0] == EnterpriseTelemetryCounter.APP_UPDATED + assert counter_call[0][1] == 1 + assert counter_call[0][2]["tenant_id"] == "tenant-123" + assert counter_call[0][2]["app_id"] == "app-789" + + +def test_on_app_deleted_emits_correct_event(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_DELETED, + tenant_id="tenant-123", + event_id="event-456", + payload={"app_id": "app-789"}, + ) + + handler = EnterpriseMetricHandler() + with ( + patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter, + patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit, + ): + mock_exporter = MagicMock() + mock_get_exporter.return_value = mock_exporter + + handler._on_app_deleted(envelope) + + from enterprise.telemetry.entities import EnterpriseTelemetryEvent + + mock_emit.assert_called_once() + call_args = mock_emit.call_args + assert call_args[1]["event_name"] == EnterpriseTelemetryEvent.APP_DELETED + assert call_args[1]["tenant_id"] == "tenant-123" + attrs = call_args[1]["attributes"] + assert 
attrs["dify.app_id"] == "app-789" + assert attrs["dify.tenant_id"] == "tenant-123" + assert attrs["dify.event.id"] == "event-456" + assert "dify.app.deleted_at" in attrs + + from enterprise.telemetry.entities import EnterpriseTelemetryCounter + + mock_exporter.increment_counter.assert_called_once() + counter_call = mock_exporter.increment_counter.call_args + assert counter_call[0][0] == EnterpriseTelemetryCounter.APP_DELETED + assert counter_call[0][1] == 1 + assert counter_call[0][2]["tenant_id"] == "tenant-123" + assert counter_call[0][2]["app_id"] == "app-789" + + +def test_on_feedback_created_emits_correct_event(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.FEEDBACK_CREATED, + tenant_id="tenant-123", + event_id="event-456", + payload={ + "message_id": "msg-001", + "app_id": "app-789", + "conversation_id": "conv-123", + "from_end_user_id": "user-456", + "from_account_id": None, + "rating": "like", + "from_source": "api", + "content": "Great!", + }, + ) + + handler = EnterpriseMetricHandler() + with ( + patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter, + patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit, + ): + mock_exporter = MagicMock() + mock_exporter.include_content = True + mock_get_exporter.return_value = mock_exporter + + handler._on_feedback_created(envelope) + + mock_emit.assert_called_once() + call_args = mock_emit.call_args + assert call_args[1]["event_name"] == "dify.feedback.created" + assert call_args[1]["attributes"]["dify.message.id"] == "msg-001" + assert call_args[1]["attributes"]["dify.feedback.content"] == "Great!" 
+ assert "dify.feedback.created_at" in call_args[1]["attributes"] + assert call_args[1]["tenant_id"] == "tenant-123" + assert call_args[1]["user_id"] == "user-456" + + mock_exporter.increment_counter.assert_called_once() + counter_args = mock_exporter.increment_counter.call_args + assert counter_args[0][2]["app_id"] == "app-789" + assert counter_args[0][2]["rating"] == "like" + + +def test_on_feedback_created_without_content(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.FEEDBACK_CREATED, + tenant_id="tenant-123", + event_id="event-456", + payload={ + "message_id": "msg-001", + "app_id": "app-789", + "conversation_id": "conv-123", + "from_end_user_id": "user-456", + "from_account_id": None, + "rating": "like", + "from_source": "api", + "content": "Great!", + }, + ) + + handler = EnterpriseMetricHandler() + with ( + patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter, + patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit, + ): + mock_exporter = MagicMock() + mock_exporter.include_content = False + mock_get_exporter.return_value = mock_exporter + + handler._on_feedback_created(envelope) + + mock_emit.assert_called_once() + call_args = mock_emit.call_args + assert "dify.feedback.content" not in call_args[1]["attributes"] diff --git a/api/tests/unit_tests/enterprise/telemetry/test_telemetry_log.py b/api/tests/unit_tests/enterprise/telemetry/test_telemetry_log.py new file mode 100644 index 0000000000..0edd0ace27 --- /dev/null +++ b/api/tests/unit_tests/enterprise/telemetry/test_telemetry_log.py @@ -0,0 +1,327 @@ +"""Unit tests for enterprise/telemetry/telemetry_log.py.""" + +from __future__ import annotations + +import uuid +from unittest.mock import MagicMock, patch + +# --------------------------------------------------------------------------- +# compute_trace_id_hex +# --------------------------------------------------------------------------- + 
+ +class TestComputeTraceIdHex: + def setup_method(self) -> None: + # Clear lru_cache between tests to avoid cross-test pollution + from enterprise.telemetry.telemetry_log import compute_trace_id_hex + + compute_trace_id_hex.cache_clear() + + def test_none_returns_empty(self) -> None: + from enterprise.telemetry.telemetry_log import compute_trace_id_hex + + assert compute_trace_id_hex(None) == "" + + def test_empty_string_returns_empty(self) -> None: + from enterprise.telemetry.telemetry_log import compute_trace_id_hex + + assert compute_trace_id_hex("") == "" + + def test_already_32_hex_chars_returned_as_is(self) -> None: + from enterprise.telemetry.telemetry_log import compute_trace_id_hex + + hex_id = "a" * 32 + assert compute_trace_id_hex(hex_id) == hex_id + + def test_valid_uuid_string_converted_to_32_hex(self) -> None: + from enterprise.telemetry.telemetry_log import compute_trace_id_hex + + uid = "123e4567-e89b-12d3-a456-426614174000" + result = compute_trace_id_hex(uid) + assert len(result) == 32 + assert all(ch in "0123456789abcdef" for ch in result) + # Round-trip: int of the UUID should equal the int parsed from result + assert int(result, 16) == uuid.UUID(uid).int + + def test_invalid_string_returns_empty(self) -> None: + from enterprise.telemetry.telemetry_log import compute_trace_id_hex + + assert compute_trace_id_hex("not-a-uuid") == "" + + def test_whitespace_stripped(self) -> None: + from enterprise.telemetry.telemetry_log import compute_trace_id_hex + + uid = " 123e4567-e89b-12d3-a456-426614174000 " + result = compute_trace_id_hex(uid) + assert len(result) == 32 + + def test_uppercase_uuid_accepted(self) -> None: + from enterprise.telemetry.telemetry_log import compute_trace_id_hex + + uid = "123E4567-E89B-12D3-A456-426614174000" + result = compute_trace_id_hex(uid) + assert len(result) == 32 + + def test_result_is_cached(self) -> None: + from enterprise.telemetry.telemetry_log import compute_trace_id_hex + + uid = 
"123e4567-e89b-12d3-a456-426614174000" + r1 = compute_trace_id_hex(uid) + r2 = compute_trace_id_hex(uid) + assert r1 == r2 + info = compute_trace_id_hex.cache_info() + assert info.hits >= 1 + + +# --------------------------------------------------------------------------- +# compute_span_id_hex +# --------------------------------------------------------------------------- + + +class TestComputeSpanIdHex: + def setup_method(self) -> None: + from enterprise.telemetry.telemetry_log import compute_span_id_hex + + compute_span_id_hex.cache_clear() + + def test_none_returns_empty(self) -> None: + from enterprise.telemetry.telemetry_log import compute_span_id_hex + + assert compute_span_id_hex(None) == "" + + def test_empty_string_returns_empty(self) -> None: + from enterprise.telemetry.telemetry_log import compute_span_id_hex + + assert compute_span_id_hex("") == "" + + def test_already_16_hex_chars_returned_as_is(self) -> None: + from enterprise.telemetry.telemetry_log import compute_span_id_hex + + hex_id = "abcdef0123456789" + assert compute_span_id_hex(hex_id) == hex_id + + def test_valid_uuid_produces_16_hex_span_id(self) -> None: + from enterprise.telemetry.telemetry_log import compute_span_id_hex + + uid = "123e4567-e89b-12d3-a456-426614174000" + result = compute_span_id_hex(uid) + assert len(result) == 16 + assert all(ch in "0123456789abcdef" for ch in result) + + def test_invalid_string_returns_empty(self) -> None: + from enterprise.telemetry.telemetry_log import compute_span_id_hex + + assert compute_span_id_hex("not-a-uuid-at-all!") == "" + + def test_result_is_cached(self) -> None: + from enterprise.telemetry.telemetry_log import compute_span_id_hex + + uid = "123e4567-e89b-12d3-a456-426614174000" + compute_span_id_hex(uid) + compute_span_id_hex(uid) + info = compute_span_id_hex.cache_info() + assert info.hits >= 1 + + +# --------------------------------------------------------------------------- +# emit_telemetry_log +# 
--------------------------------------------------------------------------- + + +class TestEmitTelemetryLog: + def setup_method(self) -> None: + from enterprise.telemetry.telemetry_log import compute_span_id_hex, compute_trace_id_hex + + compute_trace_id_hex.cache_clear() + compute_span_id_hex.cache_clear() + + @patch("enterprise.telemetry.telemetry_log.logger") + def test_logs_info_with_event_name_and_signal(self, mock_logger: MagicMock) -> None: + from enterprise.telemetry.telemetry_log import emit_telemetry_log + + mock_logger.isEnabledFor.return_value = True + + emit_telemetry_log( + event_name="dify.workflow.run", + attributes={"tenant_id": "t1"}, + signal="metric_only", + ) + + mock_logger.info.assert_called_once() + args, kwargs = mock_logger.info.call_args + assert args[0] == "telemetry.%s" + assert args[1] == "metric_only" + extra = kwargs["extra"] + assert extra["attributes"]["dify.event.name"] == "dify.workflow.run" + assert extra["attributes"]["dify.event.signal"] == "metric_only" + + @patch("enterprise.telemetry.telemetry_log.logger") + def test_no_log_when_info_disabled(self, mock_logger: MagicMock) -> None: + from enterprise.telemetry.telemetry_log import emit_telemetry_log + + mock_logger.isEnabledFor.return_value = False + + emit_telemetry_log(event_name="dify.workflow.run", attributes={}) + + mock_logger.info.assert_not_called() + + @patch("enterprise.telemetry.telemetry_log.logger") + def test_trace_id_added_to_extra_when_valid_uuid(self, mock_logger: MagicMock) -> None: + from enterprise.telemetry.telemetry_log import emit_telemetry_log + + mock_logger.isEnabledFor.return_value = True + uid = "123e4567-e89b-12d3-a456-426614174000" + + emit_telemetry_log(event_name="test.event", attributes={}, trace_id_source=uid) + + extra = mock_logger.info.call_args.kwargs["extra"] + assert "trace_id" in extra + assert len(extra["trace_id"]) == 32 + + @patch("enterprise.telemetry.telemetry_log.logger") + def test_trace_id_absent_when_invalid_source(self, 
mock_logger: MagicMock) -> None: + from enterprise.telemetry.telemetry_log import emit_telemetry_log + + mock_logger.isEnabledFor.return_value = True + + emit_telemetry_log(event_name="test.event", attributes={}, trace_id_source="bad-id") + + extra = mock_logger.info.call_args.kwargs["extra"] + assert "trace_id" not in extra + + @patch("enterprise.telemetry.telemetry_log.logger") + def test_span_id_added_to_extra_when_valid_uuid(self, mock_logger: MagicMock) -> None: + from enterprise.telemetry.telemetry_log import emit_telemetry_log + + mock_logger.isEnabledFor.return_value = True + uid = "123e4567-e89b-12d3-a456-426614174000" + + emit_telemetry_log(event_name="test.event", attributes={}, span_id_source=uid) + + extra = mock_logger.info.call_args.kwargs["extra"] + assert "span_id" in extra + assert len(extra["span_id"]) == 16 + + @patch("enterprise.telemetry.telemetry_log.logger") + def test_tenant_id_added_when_provided(self, mock_logger: MagicMock) -> None: + from enterprise.telemetry.telemetry_log import emit_telemetry_log + + mock_logger.isEnabledFor.return_value = True + + emit_telemetry_log(event_name="test.event", attributes={}, tenant_id="tenant-99") + + extra = mock_logger.info.call_args.kwargs["extra"] + assert extra["tenant_id"] == "tenant-99" + + @patch("enterprise.telemetry.telemetry_log.logger") + def test_user_id_added_when_provided(self, mock_logger: MagicMock) -> None: + from enterprise.telemetry.telemetry_log import emit_telemetry_log + + mock_logger.isEnabledFor.return_value = True + + emit_telemetry_log(event_name="test.event", attributes={}, user_id="user-42") + + extra = mock_logger.info.call_args.kwargs["extra"] + assert extra["user_id"] == "user-42" + + @patch("enterprise.telemetry.telemetry_log.logger") + def test_tenant_and_user_id_absent_when_not_provided(self, mock_logger: MagicMock) -> None: + from enterprise.telemetry.telemetry_log import emit_telemetry_log + + mock_logger.isEnabledFor.return_value = True + + 
emit_telemetry_log(event_name="test.event", attributes={}) + + extra = mock_logger.info.call_args.kwargs["extra"] + assert "tenant_id" not in extra + assert "user_id" not in extra + + @patch("enterprise.telemetry.telemetry_log.logger") + def test_caller_attributes_merged_into_attrs(self, mock_logger: MagicMock) -> None: + from enterprise.telemetry.telemetry_log import emit_telemetry_log + + mock_logger.isEnabledFor.return_value = True + + emit_telemetry_log( + event_name="dify.node.run", + attributes={"node_type": "code", "elapsed": 0.5}, + ) + + extra = mock_logger.info.call_args.kwargs["extra"] + assert extra["attributes"]["node_type"] == "code" + assert extra["attributes"]["elapsed"] == 0.5 + + @patch("enterprise.telemetry.telemetry_log.logger") + def test_signal_span_detail_forwarded(self, mock_logger: MagicMock) -> None: + from enterprise.telemetry.telemetry_log import emit_telemetry_log + + mock_logger.isEnabledFor.return_value = True + + emit_telemetry_log(event_name="test.event", attributes={}, signal="span_detail") + + args = mock_logger.info.call_args[0] + assert args[1] == "span_detail" + extra = mock_logger.info.call_args.kwargs["extra"] + assert extra["attributes"]["dify.event.signal"] == "span_detail" + + +# --------------------------------------------------------------------------- +# emit_metric_only_event +# --------------------------------------------------------------------------- + + +class TestEmitMetricOnlyEvent: + def setup_method(self) -> None: + from enterprise.telemetry.telemetry_log import compute_span_id_hex, compute_trace_id_hex + + compute_trace_id_hex.cache_clear() + compute_span_id_hex.cache_clear() + + @patch("enterprise.telemetry.telemetry_log.logger") + def test_delegates_to_emit_telemetry_log_with_metric_only_signal(self, mock_logger: MagicMock) -> None: + from enterprise.telemetry.telemetry_log import emit_metric_only_event + + mock_logger.isEnabledFor.return_value = True + + emit_metric_only_event( + 
event_name="dify.app.created", + attributes={"app_id": "app-1"}, + tenant_id="t1", + user_id="u1", + ) + + mock_logger.info.assert_called_once() + extra = mock_logger.info.call_args.kwargs["extra"] + assert extra["attributes"]["dify.event.signal"] == "metric_only" + assert extra["attributes"]["dify.event.name"] == "dify.app.created" + assert extra["attributes"]["app_id"] == "app-1" + assert extra["tenant_id"] == "t1" + assert extra["user_id"] == "u1" + + @patch("enterprise.telemetry.telemetry_log.logger") + def test_trace_and_span_ids_passed_through(self, mock_logger: MagicMock) -> None: + from enterprise.telemetry.telemetry_log import emit_metric_only_event + + mock_logger.isEnabledFor.return_value = True + uid = "123e4567-e89b-12d3-a456-426614174000" + + emit_metric_only_event( + event_name="dify.workflow.run", + attributes={}, + trace_id_source=uid, + span_id_source=uid, + ) + + extra = mock_logger.info.call_args.kwargs["extra"] + assert "trace_id" in extra + assert "span_id" in extra + + @patch("enterprise.telemetry.telemetry_log.logger") + def test_no_log_emitted_when_logger_disabled(self, mock_logger: MagicMock) -> None: + from enterprise.telemetry.telemetry_log import emit_metric_only_event + + mock_logger.isEnabledFor.return_value = False + + emit_metric_only_event(event_name="dify.workflow.run", attributes={}) + + mock_logger.info.assert_not_called() diff --git a/api/tests/unit_tests/events/test_app_event_signals.py b/api/tests/unit_tests/events/test_app_event_signals.py new file mode 100644 index 0000000000..29582a50f6 --- /dev/null +++ b/api/tests/unit_tests/events/test_app_event_signals.py @@ -0,0 +1,206 @@ +from unittest.mock import MagicMock, patch + +import pytest + + +@pytest.fixture +def mock_db(): + with patch("services.app_service.db") as mock_db: + mock_db.session = MagicMock() + yield mock_db + + +@pytest.fixture +def _mock_deps(): + with ( + patch("services.app_service.BillingService"), + patch("services.app_service.FeatureService"), + 
patch("services.app_service.EnterpriseService"), + patch("services.app_service.remove_app_and_related_data_task"), + ): + yield + + +@pytest.fixture +def app_model(): + app = MagicMock() + app.id = "app-123" + app.tenant_id = "tenant-456" + app.name = "Old Name" + app.icon_type = "emoji" + app.icon = "🤖" + app.icon_background = "#fff" + app.enable_site = False + app.enable_api = False + return app + + +def _make_collector(target: list): + def handler(sender, **kw): + target.append(sender) + + return handler + + +@pytest.mark.usefixtures("mock_db", "_mock_deps") +class TestAppWasDeletedSignal: + def test_sends_signal(self, app_model): + from events.app_event import app_was_deleted + from services.app_service import AppService + + received = [] + handler = _make_collector(received) + app_was_deleted.connect(handler) + try: + AppService().delete_app(app_model) + finally: + app_was_deleted.disconnect(handler) + + assert received == [app_model] + + def test_signal_fires_before_db_delete(self, app_model, mock_db): + from events.app_event import app_was_deleted + from services.app_service import AppService + + call_order: list[str] = [] + + def handler(sender, **kw): + call_order.append("signal") + + app_was_deleted.connect(handler) + mock_db.session.delete.side_effect = lambda _: call_order.append("db_delete") + + try: + AppService().delete_app(app_model) + finally: + app_was_deleted.disconnect(handler) + + assert call_order.index("signal") < call_order.index("db_delete") + + +@pytest.mark.usefixtures("mock_db") +class TestAppWasUpdatedSignal: + def test_update_app(self, app_model): + from events.app_event import app_was_updated + from services.app_service import AppService + + received = [] + handler = _make_collector(received) + app_was_updated.connect(handler) + + with patch("services.app_service.current_user", MagicMock(id="user-1")): + try: + AppService().update_app( + app_model, + { + "name": "New", + "description": "Desc", + "icon_type": "emoji", + "icon": "🤖", + 
"icon_background": "#fff", + "use_icon_as_answer_icon": False, + "max_active_requests": 0, + }, + ) + finally: + app_was_updated.disconnect(handler) + + assert received == [app_model] + + def test_update_app_name(self, app_model): + from events.app_event import app_was_updated + from services.app_service import AppService + + received = [] + handler = _make_collector(received) + app_was_updated.connect(handler) + + with patch("services.app_service.current_user", MagicMock(id="user-1")): + try: + AppService().update_app_name(app_model, "New Name") + finally: + app_was_updated.disconnect(handler) + + assert received == [app_model] + + def test_update_app_icon(self, app_model): + from events.app_event import app_was_updated + from services.app_service import AppService + + received = [] + handler = _make_collector(received) + app_was_updated.connect(handler) + + with patch("services.app_service.current_user", MagicMock(id="user-1")): + try: + AppService().update_app_icon(app_model, "🎉", "#000") + finally: + app_was_updated.disconnect(handler) + + assert received == [app_model] + + def test_update_app_site_status_sends_when_changed(self, app_model): + from events.app_event import app_was_updated + from services.app_service import AppService + + received = [] + handler = _make_collector(received) + app_was_updated.connect(handler) + + with patch("services.app_service.current_user", MagicMock(id="user-1")): + try: + app_model.enable_site = False + AppService().update_app_site_status(app_model, True) + finally: + app_was_updated.disconnect(handler) + + assert received == [app_model] + + def test_update_app_site_status_skips_when_unchanged(self, app_model): + from events.app_event import app_was_updated + from services.app_service import AppService + + received = [] + handler = _make_collector(received) + app_was_updated.connect(handler) + + try: + app_model.enable_site = True + AppService().update_app_site_status(app_model, True) + finally: + 
app_was_updated.disconnect(handler) + + assert received == [] + + def test_update_app_api_status_sends_when_changed(self, app_model): + from events.app_event import app_was_updated + from services.app_service import AppService + + received = [] + handler = _make_collector(received) + app_was_updated.connect(handler) + + with patch("services.app_service.current_user", MagicMock(id="user-1")): + try: + app_model.enable_api = False + AppService().update_app_api_status(app_model, True) + finally: + app_was_updated.disconnect(handler) + + assert received == [app_model] + + def test_update_app_api_status_skips_when_unchanged(self, app_model): + from events.app_event import app_was_updated + from services.app_service import AppService + + received = [] + handler = _make_collector(received) + app_was_updated.connect(handler) + + try: + app_model.enable_api = True + AppService().update_app_api_status(app_model, True) + finally: + app_was_updated.disconnect(handler) + + assert received == [] diff --git a/api/tests/unit_tests/tasks/test_enterprise_telemetry_task.py b/api/tests/unit_tests/tasks/test_enterprise_telemetry_task.py new file mode 100644 index 0000000000..b48c69a146 --- /dev/null +++ b/api/tests/unit_tests/tasks/test_enterprise_telemetry_task.py @@ -0,0 +1,69 @@ +"""Unit tests for enterprise telemetry Celery task.""" + +import json +from unittest.mock import MagicMock, patch + +import pytest + +from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope +from tasks.enterprise_telemetry_task import process_enterprise_telemetry + + +@pytest.fixture +def sample_envelope_json(): + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_CREATED, + tenant_id="test-tenant", + event_id="test-event-123", + payload={"app_id": "app-123"}, + ) + return envelope.model_dump_json() + + +def test_process_enterprise_telemetry_success(sample_envelope_json): + with patch("tasks.enterprise_telemetry_task.EnterpriseMetricHandler") as mock_handler_class: + mock_handler = 
MagicMock() + mock_handler_class.return_value = mock_handler + + process_enterprise_telemetry(sample_envelope_json) + + mock_handler.handle.assert_called_once() + call_args = mock_handler.handle.call_args[0][0] + assert isinstance(call_args, TelemetryEnvelope) + assert call_args.case == TelemetryCase.APP_CREATED + assert call_args.tenant_id == "test-tenant" + assert call_args.event_id == "test-event-123" + + +def test_process_enterprise_telemetry_invalid_json(caplog): + invalid_json = "not valid json" + + process_enterprise_telemetry(invalid_json) + + assert "Failed to process enterprise telemetry envelope" in caplog.text + + +def test_process_enterprise_telemetry_handler_exception(sample_envelope_json, caplog): + with patch("tasks.enterprise_telemetry_task.EnterpriseMetricHandler") as mock_handler_class: + mock_handler = MagicMock() + mock_handler.handle.side_effect = Exception("Handler error") + mock_handler_class.return_value = mock_handler + + process_enterprise_telemetry(sample_envelope_json) + + assert "Failed to process enterprise telemetry envelope" in caplog.text + + +def test_process_enterprise_telemetry_validation_error(caplog): + invalid_envelope = json.dumps( + { + "case": "INVALID_CASE", + "tenant_id": "test-tenant", + "event_id": "test-event", + "payload": {}, + } + ) + + process_enterprise_telemetry(invalid_envelope) + + assert "Failed to process enterprise telemetry envelope" in caplog.text