feat: enterprise otel exporter (#33138)

Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
Co-authored-by: Yunlu Wen <yunlu.wen@dify.ai>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Xiyuan Chen
2026-03-27 00:56:31 -07:00
committed by GitHub
parent 689761bfcb
commit 5a8a68cab8
51 changed files with 9650 additions and 46 deletions

View File

@ -0,0 +1,554 @@
"""Unit tests for lookup helper functions in core.ops.ops_trace_manager.
Covers:
- _lookup_app_and_workspace_names
- _lookup_credential_name
- _lookup_llm_credential_info
- TraceTask._get_user_id_from_metadata
"""
from unittest.mock import MagicMock, patch
import pytest
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_db_and_session_patches(scalar_side_effect=None, scalar_return_value=None):
"""Return (mock_db, cm, session) ready to patch 'core.ops.ops_trace_manager.db'
and 'core.ops.ops_trace_manager.Session'.
Provide either scalar_side_effect (list, for multiple calls) or
scalar_return_value (single value).
"""
mock_db = MagicMock()
mock_db.engine = MagicMock()
session = MagicMock()
if scalar_side_effect is not None:
session.scalar.side_effect = scalar_side_effect
else:
session.scalar.return_value = scalar_return_value
cm = MagicMock()
cm.__enter__ = MagicMock(return_value=session)
cm.__exit__ = MagicMock(return_value=False)
return mock_db, cm, session
# ---------------------------------------------------------------------------
# _lookup_app_and_workspace_names
# ---------------------------------------------------------------------------
class TestLookupAppAndWorkspaceNames:
"""Tests for _lookup_app_and_workspace_names(app_id, tenant_id)."""
def test_both_found(self):
"""Returns (app_name, workspace_name) when both records exist."""
from core.ops.ops_trace_manager import _lookup_app_and_workspace_names
mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=["MyApp", "MyWorkspace"])
with (
patch("core.ops.ops_trace_manager.db", mock_db),
patch("core.ops.ops_trace_manager.Session", return_value=cm),
):
app_name, workspace_name = _lookup_app_and_workspace_names("app-123", "tenant-456")
assert app_name == "MyApp"
assert workspace_name == "MyWorkspace"
def test_app_only_found(self):
"""Returns (app_name, '') when tenant record is absent."""
from core.ops.ops_trace_manager import _lookup_app_and_workspace_names
mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=["MyApp", None])
with (
patch("core.ops.ops_trace_manager.db", mock_db),
patch("core.ops.ops_trace_manager.Session", return_value=cm),
):
app_name, workspace_name = _lookup_app_and_workspace_names("app-123", "tenant-456")
assert app_name == "MyApp"
assert workspace_name == ""
def test_tenant_only_found(self):
"""Returns ('', workspace_name) when app record is absent."""
from core.ops.ops_trace_manager import _lookup_app_and_workspace_names
mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=[None, "MyWorkspace"])
with (
patch("core.ops.ops_trace_manager.db", mock_db),
patch("core.ops.ops_trace_manager.Session", return_value=cm),
):
app_name, workspace_name = _lookup_app_and_workspace_names("app-123", "tenant-456")
assert app_name == ""
assert workspace_name == "MyWorkspace"
def test_neither_found(self):
"""Returns ('', '') when both DB lookups return None."""
from core.ops.ops_trace_manager import _lookup_app_and_workspace_names
mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=[None, None])
with (
patch("core.ops.ops_trace_manager.db", mock_db),
patch("core.ops.ops_trace_manager.Session", return_value=cm),
):
app_name, workspace_name = _lookup_app_and_workspace_names("app-123", "tenant-456")
assert app_name == ""
assert workspace_name == ""
def test_none_inputs_skips_db(self):
"""Returns ('', '') immediately when both IDs are None — no DB access."""
from core.ops.ops_trace_manager import _lookup_app_and_workspace_names
mock_db = MagicMock()
mock_session_cls = MagicMock()
with (
patch("core.ops.ops_trace_manager.db", mock_db),
patch("core.ops.ops_trace_manager.Session", mock_session_cls),
):
app_name, workspace_name = _lookup_app_and_workspace_names(None, None)
mock_session_cls.assert_not_called()
assert app_name == ""
assert workspace_name == ""
def test_app_id_none_only_queries_tenant(self):
"""When app_id is None, only the tenant query is issued."""
from core.ops.ops_trace_manager import _lookup_app_and_workspace_names
mock_db, cm, session = _make_db_and_session_patches(scalar_return_value="OnlyWorkspace")
with (
patch("core.ops.ops_trace_manager.db", mock_db),
patch("core.ops.ops_trace_manager.Session", return_value=cm),
):
app_name, workspace_name = _lookup_app_and_workspace_names(None, "tenant-456")
assert app_name == ""
assert workspace_name == "OnlyWorkspace"
assert session.scalar.call_count == 1
def test_tenant_id_none_only_queries_app(self):
"""When tenant_id is None, only the app query is issued."""
from core.ops.ops_trace_manager import _lookup_app_and_workspace_names
mock_db, cm, session = _make_db_and_session_patches(scalar_return_value="OnlyApp")
with (
patch("core.ops.ops_trace_manager.db", mock_db),
patch("core.ops.ops_trace_manager.Session", return_value=cm),
):
app_name, workspace_name = _lookup_app_and_workspace_names("app-123", None)
assert app_name == "OnlyApp"
assert workspace_name == ""
assert session.scalar.call_count == 1
# ---------------------------------------------------------------------------
# _lookup_credential_name
# ---------------------------------------------------------------------------
class TestLookupCredentialName:
"""Tests for _lookup_credential_name(credential_id, provider_type)."""
@pytest.mark.parametrize("provider_type", ["builtin", "plugin", "api", "workflow", "mcp"])
def test_known_provider_types_return_name(self, provider_type):
"""Each valid provider_type results in a DB query and returns the credential name."""
from core.ops.ops_trace_manager import _lookup_credential_name
mock_db, cm, session = _make_db_and_session_patches(scalar_return_value="CredentialA")
with (
patch("core.ops.ops_trace_manager.db", mock_db),
patch("core.ops.ops_trace_manager.Session", return_value=cm),
):
result = _lookup_credential_name("cred-123", provider_type)
assert result == "CredentialA"
session.scalar.assert_called_once()
def test_credential_not_found_returns_empty_string(self):
"""Returns '' when DB yields None for the given credential_id."""
from core.ops.ops_trace_manager import _lookup_credential_name
mock_db, cm, _session = _make_db_and_session_patches(scalar_return_value=None)
with (
patch("core.ops.ops_trace_manager.db", mock_db),
patch("core.ops.ops_trace_manager.Session", return_value=cm),
):
result = _lookup_credential_name("cred-999", "api")
assert result == ""
def test_invalid_provider_type_returns_empty_string_without_db(self):
"""Returns '' immediately for an unrecognised provider_type — no DB access."""
from core.ops.ops_trace_manager import _lookup_credential_name
mock_db = MagicMock()
mock_session_cls = MagicMock()
with (
patch("core.ops.ops_trace_manager.db", mock_db),
patch("core.ops.ops_trace_manager.Session", mock_session_cls),
):
result = _lookup_credential_name("cred-123", "unknown_type")
mock_session_cls.assert_not_called()
assert result == ""
def test_none_credential_id_returns_empty_string_without_db(self):
"""Returns '' immediately when credential_id is None — no DB access."""
from core.ops.ops_trace_manager import _lookup_credential_name
mock_db = MagicMock()
mock_session_cls = MagicMock()
with (
patch("core.ops.ops_trace_manager.db", mock_db),
patch("core.ops.ops_trace_manager.Session", mock_session_cls),
):
result = _lookup_credential_name(None, "api")
mock_session_cls.assert_not_called()
assert result == ""
def test_none_provider_type_returns_empty_string_without_db(self):
"""Returns '' immediately when provider_type is None — no DB access."""
from core.ops.ops_trace_manager import _lookup_credential_name
mock_db = MagicMock()
mock_session_cls = MagicMock()
with (
patch("core.ops.ops_trace_manager.db", mock_db),
patch("core.ops.ops_trace_manager.Session", mock_session_cls),
):
result = _lookup_credential_name("cred-123", None)
mock_session_cls.assert_not_called()
assert result == ""
def test_builtin_and_plugin_map_to_same_model(self):
"""Both 'builtin' and 'plugin' provider_types query BuiltinToolProvider."""
from core.ops.ops_trace_manager import _PROVIDER_TYPE_TO_MODEL
from models.tools import BuiltinToolProvider
assert _PROVIDER_TYPE_TO_MODEL["builtin"] is BuiltinToolProvider
assert _PROVIDER_TYPE_TO_MODEL["plugin"] is BuiltinToolProvider
def test_api_maps_to_api_tool_provider(self):
"""'api' maps to ApiToolProvider."""
from core.ops.ops_trace_manager import _PROVIDER_TYPE_TO_MODEL
from models.tools import ApiToolProvider
assert _PROVIDER_TYPE_TO_MODEL["api"] is ApiToolProvider
def test_workflow_maps_to_workflow_tool_provider(self):
"""'workflow' maps to WorkflowToolProvider."""
from core.ops.ops_trace_manager import _PROVIDER_TYPE_TO_MODEL
from models.tools import WorkflowToolProvider
assert _PROVIDER_TYPE_TO_MODEL["workflow"] is WorkflowToolProvider
def test_mcp_maps_to_mcp_tool_provider(self):
"""'mcp' maps to MCPToolProvider."""
from core.ops.ops_trace_manager import _PROVIDER_TYPE_TO_MODEL
from models.tools import MCPToolProvider
assert _PROVIDER_TYPE_TO_MODEL["mcp"] is MCPToolProvider
# ---------------------------------------------------------------------------
# _lookup_llm_credential_info
# ---------------------------------------------------------------------------
class TestLookupLlmCredentialInfo:
"""Tests for _lookup_llm_credential_info(tenant_id, provider, model, model_type)."""
def _provider_record(self, credential_id: str | None = None) -> MagicMock:
record = MagicMock()
record.credential_id = credential_id
return record
def _model_record(self, credential_id: str | None = None) -> MagicMock:
record = MagicMock()
record.credential_id = credential_id
return record
def test_model_level_credential_found(self):
"""Returns model-level credential_id and name when ProviderModel has a credential."""
from core.ops.ops_trace_manager import _lookup_llm_credential_info
provider_record = self._provider_record(credential_id=None)
model_record = self._model_record(credential_id="model-cred-id")
# scalar calls: (1) Provider, (2) ProviderModel, (3) ProviderModelCredential.credential_name
mock_db, cm, _session = _make_db_and_session_patches(
scalar_side_effect=[provider_record, model_record, "ModelCredName"]
)
with (
patch("core.ops.ops_trace_manager.db", mock_db),
patch("core.ops.ops_trace_manager.Session", return_value=cm),
):
cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4")
assert cred_id == "model-cred-id"
assert cred_name == "ModelCredName"
def test_provider_level_fallback_when_no_model_credential(self):
"""Falls back to provider-level credential when ProviderModel has no credential_id."""
from core.ops.ops_trace_manager import _lookup_llm_credential_info
provider_record = self._provider_record(credential_id="prov-cred-id")
model_record = self._model_record(credential_id=None)
# scalar calls: (1) Provider, (2) ProviderModel (no cred), (3) ProviderCredential.credential_name
mock_db, cm, _session = _make_db_and_session_patches(
scalar_side_effect=[provider_record, model_record, "ProvCredName"]
)
with (
patch("core.ops.ops_trace_manager.db", mock_db),
patch("core.ops.ops_trace_manager.Session", return_value=cm),
):
cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4")
assert cred_id == "prov-cred-id"
assert cred_name == "ProvCredName"
def test_provider_level_fallback_when_no_model_record(self):
"""Falls back to provider-level credential when no ProviderModel row exists."""
from core.ops.ops_trace_manager import _lookup_llm_credential_info
provider_record = self._provider_record(credential_id="prov-cred-id")
# scalar calls: (1) Provider, (2) ProviderModel → None, (3) ProviderCredential.credential_name
mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=[provider_record, None, "ProvCredName"])
with (
patch("core.ops.ops_trace_manager.db", mock_db),
patch("core.ops.ops_trace_manager.Session", return_value=cm),
):
cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4")
assert cred_id == "prov-cred-id"
assert cred_name == "ProvCredName"
def test_no_model_arg_uses_provider_level_only(self):
"""When model is None, skips ProviderModel query and uses provider credential."""
from core.ops.ops_trace_manager import _lookup_llm_credential_info
provider_record = self._provider_record(credential_id="prov-cred-id")
# scalar calls: (1) Provider, (2) ProviderCredential.credential_name — no ProviderModel
mock_db, cm, session = _make_db_and_session_patches(scalar_side_effect=[provider_record, "ProvCredName"])
with (
patch("core.ops.ops_trace_manager.db", mock_db),
patch("core.ops.ops_trace_manager.Session", return_value=cm),
):
cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", None)
assert cred_id == "prov-cred-id"
assert cred_name == "ProvCredName"
assert session.scalar.call_count == 2
def test_provider_not_found_returns_none_and_empty(self):
"""Returns (None, '') when Provider record does not exist."""
from core.ops.ops_trace_manager import _lookup_llm_credential_info
mock_db, cm, _session = _make_db_and_session_patches(scalar_return_value=None)
with (
patch("core.ops.ops_trace_manager.db", mock_db),
patch("core.ops.ops_trace_manager.Session", return_value=cm),
):
cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4")
assert cred_id is None
assert cred_name == ""
def test_none_tenant_id_returns_none_and_empty_without_db(self):
"""Returns (None, '') immediately when tenant_id is None — no DB access."""
from core.ops.ops_trace_manager import _lookup_llm_credential_info
mock_db = MagicMock()
mock_session_cls = MagicMock()
with (
patch("core.ops.ops_trace_manager.db", mock_db),
patch("core.ops.ops_trace_manager.Session", mock_session_cls),
):
cred_id, cred_name = _lookup_llm_credential_info(None, "openai", "gpt-4")
mock_session_cls.assert_not_called()
assert cred_id is None
assert cred_name == ""
def test_none_provider_returns_none_and_empty_without_db(self):
"""Returns (None, '') immediately when provider is None — no DB access."""
from core.ops.ops_trace_manager import _lookup_llm_credential_info
mock_db = MagicMock()
mock_session_cls = MagicMock()
with (
patch("core.ops.ops_trace_manager.db", mock_db),
patch("core.ops.ops_trace_manager.Session", mock_session_cls),
):
cred_id, cred_name = _lookup_llm_credential_info("tenant-1", None, "gpt-4")
mock_session_cls.assert_not_called()
assert cred_id is None
assert cred_name == ""
def test_db_error_on_outer_query_returns_none_and_empty(self):
"""Returns (None, '') and logs a warning when the outer DB query raises."""
from core.ops.ops_trace_manager import _lookup_llm_credential_info
mock_db, cm, session = _make_db_and_session_patches()
session.scalar.side_effect = Exception("DB connection failed")
with (
patch("core.ops.ops_trace_manager.db", mock_db),
patch("core.ops.ops_trace_manager.Session", return_value=cm),
):
cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4")
assert cred_id is None
assert cred_name == ""
def test_credential_name_lookup_failure_returns_id_with_empty_name(self):
"""When credential name sub-query fails, returns cred_id but '' for name."""
from core.ops.ops_trace_manager import _lookup_llm_credential_info
provider_record = self._provider_record(credential_id="prov-cred-id")
# Provider found, no model record, then name lookup raises
mock_db, cm, _session = _make_db_and_session_patches(
scalar_side_effect=[provider_record, None, Exception("deleted")]
)
with (
patch("core.ops.ops_trace_manager.db", mock_db),
patch("core.ops.ops_trace_manager.Session", return_value=cm),
):
cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4")
assert cred_id == "prov-cred-id"
assert cred_name == ""
def test_no_credential_on_provider_or_model_returns_none_id(self):
"""Returns (None, '') when neither provider nor model has a credential_id."""
from core.ops.ops_trace_manager import _lookup_llm_credential_info
provider_record = self._provider_record(credential_id=None)
model_record = self._model_record(credential_id=None)
mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=[provider_record, model_record])
with (
patch("core.ops.ops_trace_manager.db", mock_db),
patch("core.ops.ops_trace_manager.Session", return_value=cm),
):
cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4")
assert cred_id is None
assert cred_name == ""
# ---------------------------------------------------------------------------
# TraceTask._get_user_id_from_metadata
# ---------------------------------------------------------------------------
class TestGetUserIdFromMetadata:
"""Tests for TraceTask._get_user_id_from_metadata(metadata).
Pure dict logic — no DB access required.
"""
@pytest.fixture
def get_user_id(self):
"""Return the classmethod under test."""
from core.ops.ops_trace_manager import TraceTask
return TraceTask._get_user_id_from_metadata
def test_from_end_user_id_has_highest_priority(self, get_user_id):
"""from_end_user_id takes precedence over all other keys."""
metadata = {
"from_end_user_id": "eu-abc",
"from_account_id": "acc-xyz",
"user_id": "u-123",
}
assert get_user_id(metadata) == "end_user:eu-abc"
def test_from_account_id_used_when_no_end_user(self, get_user_id):
"""from_account_id is used when from_end_user_id is absent."""
metadata = {
"from_account_id": "acc-xyz",
"user_id": "u-123",
}
assert get_user_id(metadata) == "account:acc-xyz"
def test_user_id_used_when_no_end_user_or_account(self, get_user_id):
"""user_id is used when both higher-priority keys are absent."""
metadata = {"user_id": "u-123"}
assert get_user_id(metadata) == "user:u-123"
def test_returns_anonymous_when_all_keys_absent(self, get_user_id):
"""Returns 'anonymous' when metadata has none of the expected keys."""
assert get_user_id({}) == "anonymous"
def test_empty_string_end_user_id_is_skipped(self, get_user_id):
"""Empty string for from_end_user_id is falsy and falls through to next key."""
metadata = {
"from_end_user_id": "",
"from_account_id": "acc-xyz",
}
assert get_user_id(metadata) == "account:acc-xyz"
def test_empty_string_account_id_is_skipped(self, get_user_id):
"""Empty string for from_account_id is falsy and falls through to user_id."""
metadata = {
"from_end_user_id": "",
"from_account_id": "",
"user_id": "u-123",
}
assert get_user_id(metadata) == "user:u-123"
def test_empty_string_user_id_falls_through_to_anonymous(self, get_user_id):
"""Empty string for user_id is falsy, so 'anonymous' is returned."""
metadata = {
"from_end_user_id": "",
"from_account_id": "",
"user_id": "",
}
assert get_user_id(metadata) == "anonymous"
def test_only_from_end_user_id_present(self, get_user_id):
"""Minimal case: only from_end_user_id present."""
assert get_user_id({"from_end_user_id": "eu-only"}) == "end_user:eu-only"
def test_irrelevant_keys_do_not_interfere(self, get_user_id):
"""Extra metadata keys have no effect on the result."""
metadata = {"invoke_from": "web", "app_id": "a1"}
assert get_user_id(metadata) == "anonymous"

View File

@ -86,6 +86,7 @@ def make_message_data(**overrides):
created_at = datetime(2025, 2, 20, 12, 0, 0)
base = {
"id": "msg-id",
"app_id": "app-id",
"conversation_id": "conv-id",
"created_at": created_at,
"updated_at": created_at + timedelta(seconds=3),
@ -182,6 +183,9 @@ class DummySessionContext:
def __exit__(self, exc_type, exc_val, exc_tb):
return False
def execute(self, *args, **kwargs):
return self
def scalar(self, *args, **kwargs):
if self._index >= len(self._values):
return None
@ -189,6 +193,12 @@ class DummySessionContext:
self._index += 1
return value
def scalars(self, *args, **kwargs):
return self
def all(self):
return []
@pytest.fixture(autouse=True)
def patch_provider_map(monkeypatch):
@ -454,7 +464,7 @@ def test_trace_task_message_trace(trace_task_message, mock_db):
def test_trace_task_workflow_trace(workflow_repo_fixture, mock_db):
DummySessionContext.scalar_values = ["wf-app-log", "message-ref"]
execution = SimpleNamespace(id_="run-id")
execution = SimpleNamespace(id_="run-id", total_tokens=0)
task = TraceTask(
trace_type=TraceTaskName.WORKFLOW_TRACE, workflow_execution=execution, conversation_id="conv", user_id="user"
)

View File

@ -0,0 +1,194 @@
"""Unit tests for TraceQueueManager telemetry guard.
Verifies that TraceQueueManager.add_trace_task() only enqueues tasks when at
least one consumer is active:
- Enterprise telemetry is enabled (_enterprise_telemetry_enabled=True), OR
- A third-party trace instance (Langfuse, etc.) is configured
When neither is active, tasks are silently dropped to avoid unnecessary work.
When BOTH are false, tasks are silently dropped (correct behavior).
"""
import queue
import sys
import types
from unittest.mock import MagicMock, patch
import pytest
@pytest.fixture
def trace_queue_manager_and_task(monkeypatch):
"""Fixture to provide TraceQueueManager and TraceTask with delayed imports."""
module_name = "core.ops.ops_trace_manager"
if module_name not in sys.modules:
ops_stub = types.ModuleType(module_name)
class StubTraceTask:
def __init__(self, trace_type):
self.trace_type = trace_type
self.app_id = None
class StubTraceQueueManager:
def __init__(self, app_id=None):
self.app_id = app_id
from core.telemetry.gateway import is_enterprise_telemetry_enabled
self._enterprise_telemetry_enabled = is_enterprise_telemetry_enabled()
self.trace_instance = StubOpsTraceManager.get_ops_trace_instance(app_id)
def add_trace_task(self, trace_task):
if self._enterprise_telemetry_enabled or self.trace_instance:
trace_task.app_id = self.app_id
from core.ops.ops_trace_manager import trace_manager_queue
trace_manager_queue.put(trace_task)
class StubOpsTraceManager:
@staticmethod
def get_ops_trace_instance(app_id):
return None
ops_stub.TraceQueueManager = StubTraceQueueManager
ops_stub.TraceTask = StubTraceTask
ops_stub.OpsTraceManager = StubOpsTraceManager
ops_stub.trace_manager_queue = MagicMock(spec=queue.Queue)
monkeypatch.setitem(sys.modules, module_name, ops_stub)
from core.ops.entities.trace_entity import TraceTaskName
ops_module = __import__(module_name, fromlist=["TraceQueueManager", "TraceTask"])
TraceQueueManager = ops_module.TraceQueueManager
TraceTask = ops_module.TraceTask
return TraceQueueManager, TraceTask, TraceTaskName
class TestTraceQueueManagerTelemetryGuard:
"""Test TraceQueueManager's telemetry guard in add_trace_task()."""
def test_task_not_enqueued_when_telemetry_disabled_and_no_trace_instance(self, trace_queue_manager_and_task):
"""Verify task is NOT enqueued when telemetry disabled and no trace instance.
This is the core guard: when _enterprise_telemetry_enabled=False AND
trace_instance=None, the task should be silently dropped.
"""
TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task
mock_queue = MagicMock(spec=queue.Queue)
trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE)
with (
patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False),
patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None),
patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue),
):
manager = TraceQueueManager(app_id="test-app-id")
manager.add_trace_task(trace_task)
mock_queue.put.assert_not_called()
def test_task_enqueued_when_telemetry_enabled(self, trace_queue_manager_and_task):
"""Verify task IS enqueued when enterprise telemetry is enabled.
When _enterprise_telemetry_enabled=True, the task should be enqueued
regardless of trace_instance state.
"""
TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task
mock_queue = MagicMock(spec=queue.Queue)
trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE)
with (
patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True),
patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None),
patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue),
):
manager = TraceQueueManager(app_id="test-app-id")
manager.add_trace_task(trace_task)
mock_queue.put.assert_called_once()
called_task = mock_queue.put.call_args[0][0]
assert called_task.app_id == "test-app-id"
def test_task_enqueued_when_trace_instance_configured(self, trace_queue_manager_and_task):
"""Verify task IS enqueued when third-party trace instance is configured.
When trace_instance is not None (e.g., Langfuse configured), the task
should be enqueued even if enterprise telemetry is disabled.
"""
TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task
mock_queue = MagicMock(spec=queue.Queue)
mock_trace_instance = MagicMock()
trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE)
with (
patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False),
patch(
"core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=mock_trace_instance
),
patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue),
):
manager = TraceQueueManager(app_id="test-app-id")
manager.add_trace_task(trace_task)
mock_queue.put.assert_called_once()
called_task = mock_queue.put.call_args[0][0]
assert called_task.app_id == "test-app-id"
def test_task_enqueued_when_both_telemetry_and_trace_instance_enabled(self, trace_queue_manager_and_task):
"""Verify task IS enqueued when both telemetry and trace instance are enabled.
When both _enterprise_telemetry_enabled=True AND trace_instance is set,
the task should definitely be enqueued.
"""
TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task
mock_queue = MagicMock(spec=queue.Queue)
mock_trace_instance = MagicMock()
trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE)
with (
patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True),
patch(
"core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=mock_trace_instance
),
patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue),
):
manager = TraceQueueManager(app_id="test-app-id")
manager.add_trace_task(trace_task)
mock_queue.put.assert_called_once()
called_task = mock_queue.put.call_args[0][0]
assert called_task.app_id == "test-app-id"
def test_app_id_set_before_enqueue(self, trace_queue_manager_and_task):
"""Verify app_id is set on the task before enqueuing.
The guard logic sets trace_task.app_id = self.app_id before calling
trace_manager_queue.put(trace_task). This test verifies that behavior.
"""
TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task
mock_queue = MagicMock(spec=queue.Queue)
trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE)
with (
patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True),
patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None),
patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue),
):
manager = TraceQueueManager(app_id="expected-app-id")
manager.add_trace_task(trace_task)
called_task = mock_queue.put.call_args[0][0]
assert called_task.app_id == "expected-app-id"

View File

@ -0,0 +1,181 @@
"""Unit tests for core.telemetry.emit() routing and enterprise-only filtering."""
from __future__ import annotations
import queue
import sys
import types
from unittest.mock import MagicMock, patch
import pytest
from core.ops.entities.trace_entity import TraceTaskName
from core.telemetry.events import TelemetryContext, TelemetryEvent
@pytest.fixture
def telemetry_test_setup(monkeypatch):
module_name = "core.ops.ops_trace_manager"
ops_stub = types.ModuleType(module_name)
class StubTraceTask:
def __init__(self, trace_type, **kwargs):
self.trace_type = trace_type
self.app_id = None
self.kwargs = kwargs
class StubTraceQueueManager:
def __init__(self, app_id=None, user_id=None):
self.app_id = app_id
self.user_id = user_id
self.trace_instance = StubOpsTraceManager.get_ops_trace_instance(app_id)
def add_trace_task(self, trace_task):
trace_task.app_id = self.app_id
from core.ops.ops_trace_manager import trace_manager_queue
trace_manager_queue.put(trace_task)
class StubOpsTraceManager:
@staticmethod
def get_ops_trace_instance(app_id):
return None
ops_stub.TraceQueueManager = StubTraceQueueManager
ops_stub.TraceTask = StubTraceTask
ops_stub.OpsTraceManager = StubOpsTraceManager
ops_stub.trace_manager_queue = MagicMock(spec=queue.Queue)
monkeypatch.setitem(sys.modules, module_name, ops_stub)
from core.telemetry import emit
return emit, ops_stub.trace_manager_queue
class TestTelemetryEmit:
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
def test_emit_enterprise_trace_creates_trace_task(self, mock_ee, telemetry_test_setup):
emit_fn, mock_queue = telemetry_test_setup
event = TelemetryEvent(
name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
context=TelemetryContext(
tenant_id="test-tenant",
user_id="test-user",
app_id="test-app",
),
payload={"key": "value"},
)
emit_fn(event)
mock_queue.put.assert_called_once()
called_task = mock_queue.put.call_args[0][0]
assert called_task.trace_type == TraceTaskName.DRAFT_NODE_EXECUTION_TRACE
def test_emit_community_trace_enqueued(self, telemetry_test_setup):
emit_fn, mock_queue = telemetry_test_setup
event = TelemetryEvent(
name=TraceTaskName.WORKFLOW_TRACE,
context=TelemetryContext(
tenant_id="test-tenant",
user_id="test-user",
app_id="test-app",
),
payload={},
)
emit_fn(event)
mock_queue.put.assert_called_once()
def test_emit_enterprise_only_trace_dropped_when_ee_disabled(self, telemetry_test_setup):
emit_fn, mock_queue = telemetry_test_setup
event = TelemetryEvent(
name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
context=TelemetryContext(
tenant_id="test-tenant",
user_id="test-user",
app_id="test-app",
),
payload={},
)
emit_fn(event)
mock_queue.put.assert_not_called()
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
def test_emit_all_enterprise_only_traces_allowed_when_ee_enabled(self, mock_ee, telemetry_test_setup):
emit_fn, mock_queue = telemetry_test_setup
enterprise_only_traces = [
TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
TraceTaskName.NODE_EXECUTION_TRACE,
TraceTaskName.PROMPT_GENERATION_TRACE,
]
for trace_name in enterprise_only_traces:
mock_queue.reset_mock()
event = TelemetryEvent(
name=trace_name,
context=TelemetryContext(
tenant_id="test-tenant",
user_id="test-user",
app_id="test-app",
),
payload={},
)
emit_fn(event)
mock_queue.put.assert_called_once()
called_task = mock_queue.put.call_args[0][0]
assert called_task.trace_type == trace_name
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
def test_emit_passes_name_directly_to_trace_task(self, mock_ee, telemetry_test_setup):
emit_fn, mock_queue = telemetry_test_setup
event = TelemetryEvent(
name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
context=TelemetryContext(
tenant_id="test-tenant",
user_id="test-user",
app_id="test-app",
),
payload={"extra": "data"},
)
emit_fn(event)
mock_queue.put.assert_called_once()
called_task = mock_queue.put.call_args[0][0]
assert called_task.trace_type == TraceTaskName.DRAFT_NODE_EXECUTION_TRACE
assert isinstance(called_task.trace_type, TraceTaskName)
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
def test_emit_with_provided_trace_manager(self, mock_ee, telemetry_test_setup):
emit_fn, mock_queue = telemetry_test_setup
mock_trace_manager = MagicMock()
mock_trace_manager.add_trace_task = MagicMock()
event = TelemetryEvent(
name=TraceTaskName.NODE_EXECUTION_TRACE,
context=TelemetryContext(
tenant_id="test-tenant",
user_id="test-user",
app_id="test-app",
),
payload={},
)
emit_fn(event, trace_manager=mock_trace_manager)
mock_trace_manager.add_trace_task.assert_called_once()
called_task = mock_trace_manager.add_trace_task.call_args[0][0]
assert called_task.trace_type == TraceTaskName.NODE_EXECUTION_TRACE

View File

@ -0,0 +1,225 @@
from __future__ import annotations
import sys
from unittest.mock import MagicMock, patch
import pytest
from core.telemetry.gateway import emit, is_enterprise_telemetry_enabled
from enterprise.telemetry.contracts import TelemetryCase
class TestTelemetryCoreExports:
def test_is_enterprise_telemetry_enabled_exported(self) -> None:
from core.telemetry.gateway import is_enterprise_telemetry_enabled as exported_func
assert callable(exported_func)
@pytest.fixture
def mock_ops_trace_manager():
mock_module = MagicMock()
mock_trace_task_class = MagicMock()
mock_trace_task_class.return_value = MagicMock()
mock_module.TraceTask = mock_trace_task_class
mock_module.TraceQueueManager = MagicMock()
mock_trace_entity = MagicMock()
mock_trace_task_name = MagicMock()
mock_trace_task_name.return_value = "workflow"
mock_trace_entity.TraceTaskName = mock_trace_task_name
with (
patch.dict(sys.modules, {"core.ops.ops_trace_manager": mock_module}),
patch.dict(sys.modules, {"core.ops.entities.trace_entity": mock_trace_entity}),
):
yield mock_module, mock_trace_entity
class TestGatewayIntegrationTraceRouting:
@pytest.fixture
def mock_trace_manager(self) -> MagicMock:
return MagicMock()
@pytest.mark.usefixtures("mock_ops_trace_manager")
def test_ce_eligible_trace_routed_to_trace_manager(
self,
mock_trace_manager: MagicMock,
) -> None:
with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True):
context = {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"}
payload = {"workflow_run_id": "run-abc"}
emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_called_once()
@pytest.mark.usefixtures("mock_ops_trace_manager")
def test_ce_eligible_trace_routed_when_ee_disabled(
self,
mock_trace_manager: MagicMock,
) -> None:
with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False):
context = {"app_id": "app-123", "user_id": "user-456"}
payload = {"workflow_run_id": "run-abc"}
emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_called_once()
@pytest.mark.usefixtures("mock_ops_trace_manager")
def test_enterprise_only_trace_dropped_when_ee_disabled(
self,
mock_trace_manager: MagicMock,
) -> None:
with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False):
context = {"app_id": "app-123", "user_id": "user-456"}
payload = {"node_id": "node-abc"}
emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_not_called()
@pytest.mark.usefixtures("mock_ops_trace_manager")
def test_enterprise_only_trace_routed_when_ee_enabled(
self,
mock_trace_manager: MagicMock,
) -> None:
with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True):
context = {"app_id": "app-123", "user_id": "user-456"}
payload = {"node_id": "node-abc"}
emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_called_once()
class TestGatewayIntegrationMetricRouting:
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
def test_metric_case_routes_to_celery_task(
self,
mock_ee_enabled: MagicMock,
) -> None:
from enterprise.telemetry.contracts import TelemetryEnvelope
with patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") as mock_delay:
context = {"tenant_id": "tenant-123"}
payload = {"app_id": "app-abc", "name": "My App"}
emit(TelemetryCase.APP_CREATED, context, payload)
mock_delay.assert_called_once()
envelope_json = mock_delay.call_args[0][0]
envelope = TelemetryEnvelope.model_validate_json(envelope_json)
assert envelope.case == TelemetryCase.APP_CREATED
assert envelope.tenant_id == "tenant-123"
assert envelope.payload["app_id"] == "app-abc"
@pytest.mark.usefixtures("mock_ops_trace_manager")
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
def test_tool_execution_trace_routed(
self,
mock_ee_enabled: MagicMock,
) -> None:
mock_trace_manager = MagicMock()
context = {"tenant_id": "tenant-123", "app_id": "app-123"}
payload = {"tool_name": "test_tool", "tool_inputs": {}, "tool_outputs": "result"}
emit(TelemetryCase.TOOL_EXECUTION, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_called_once()
@pytest.mark.usefixtures("mock_ops_trace_manager")
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
def test_moderation_check_trace_routed(
self,
mock_ee_enabled: MagicMock,
) -> None:
mock_trace_manager = MagicMock()
context = {"tenant_id": "tenant-123", "app_id": "app-123"}
payload = {"message_id": "msg-123", "moderation_result": {"flagged": False}}
emit(TelemetryCase.MODERATION_CHECK, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_called_once()
class TestGatewayIntegrationCEEligibility:
@pytest.fixture
def mock_trace_manager(self) -> MagicMock:
return MagicMock()
@pytest.mark.usefixtures("mock_ops_trace_manager")
def test_workflow_run_is_ce_eligible(
self,
mock_trace_manager: MagicMock,
) -> None:
with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False):
context = {"app_id": "app-123", "user_id": "user-456"}
payload = {"workflow_run_id": "run-abc"}
emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_called_once()
@pytest.mark.usefixtures("mock_ops_trace_manager")
def test_message_run_is_ce_eligible(
self,
mock_trace_manager: MagicMock,
) -> None:
with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False):
context = {"app_id": "app-123", "user_id": "user-456"}
payload = {"message_id": "msg-abc", "conversation_id": "conv-123"}
emit(TelemetryCase.MESSAGE_RUN, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_called_once()
@pytest.mark.usefixtures("mock_ops_trace_manager")
def test_node_execution_not_ce_eligible(
self,
mock_trace_manager: MagicMock,
) -> None:
with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False):
context = {"app_id": "app-123", "user_id": "user-456"}
payload = {"node_id": "node-abc"}
emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_not_called()
@pytest.mark.usefixtures("mock_ops_trace_manager")
def test_draft_node_execution_not_ce_eligible(
self,
mock_trace_manager: MagicMock,
) -> None:
with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False):
context = {"app_id": "app-123", "user_id": "user-456"}
payload = {"node_execution_data": {}}
emit(TelemetryCase.DRAFT_NODE_EXECUTION, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_not_called()
@pytest.mark.usefixtures("mock_ops_trace_manager")
def test_prompt_generation_not_ce_eligible(
self,
mock_trace_manager: MagicMock,
) -> None:
with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False):
context = {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"}
payload = {"operation_type": "generate", "instruction": "test"}
emit(TelemetryCase.PROMPT_GENERATION, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_not_called()
class TestIsEnterpriseTelemetryEnabled:
def test_returns_false_when_exporter_import_fails(self) -> None:
with patch.dict(sys.modules, {"enterprise.telemetry.exporter": None}):
result = is_enterprise_telemetry_enabled()
assert result is False
def test_function_is_callable(self) -> None:
assert callable(is_enterprise_telemetry_enabled)

View File

@ -0,0 +1,230 @@
"""Unit tests for telemetry gateway contracts."""
from __future__ import annotations
import pytest
from pydantic import ValidationError
from core.telemetry.gateway import CASE_ROUTING
from enterprise.telemetry.contracts import CaseRoute, SignalType, TelemetryCase, TelemetryEnvelope
class TestTelemetryCase:
"""Tests for TelemetryCase enum."""
def test_all_cases_defined(self) -> None:
"""Verify all 14 telemetry cases are defined."""
expected_cases = {
"WORKFLOW_RUN",
"NODE_EXECUTION",
"DRAFT_NODE_EXECUTION",
"MESSAGE_RUN",
"TOOL_EXECUTION",
"MODERATION_CHECK",
"SUGGESTED_QUESTION",
"DATASET_RETRIEVAL",
"GENERATE_NAME",
"PROMPT_GENERATION",
"APP_CREATED",
"APP_UPDATED",
"APP_DELETED",
"FEEDBACK_CREATED",
}
actual_cases = {case.name for case in TelemetryCase}
assert actual_cases == expected_cases
def test_case_values(self) -> None:
"""Verify case enum values are correct."""
assert TelemetryCase.WORKFLOW_RUN.value == "workflow_run"
assert TelemetryCase.NODE_EXECUTION.value == "node_execution"
assert TelemetryCase.DRAFT_NODE_EXECUTION.value == "draft_node_execution"
assert TelemetryCase.MESSAGE_RUN.value == "message_run"
assert TelemetryCase.TOOL_EXECUTION.value == "tool_execution"
assert TelemetryCase.MODERATION_CHECK.value == "moderation_check"
assert TelemetryCase.SUGGESTED_QUESTION.value == "suggested_question"
assert TelemetryCase.DATASET_RETRIEVAL.value == "dataset_retrieval"
assert TelemetryCase.GENERATE_NAME.value == "generate_name"
assert TelemetryCase.PROMPT_GENERATION.value == "prompt_generation"
assert TelemetryCase.APP_CREATED.value == "app_created"
assert TelemetryCase.APP_UPDATED.value == "app_updated"
assert TelemetryCase.APP_DELETED.value == "app_deleted"
assert TelemetryCase.FEEDBACK_CREATED.value == "feedback_created"
class TestCaseRoute:
"""Tests for CaseRoute model."""
def test_valid_trace_route(self) -> None:
"""Verify valid trace route creation."""
route = CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True)
assert route.signal_type == SignalType.TRACE
assert route.ce_eligible is True
def test_valid_metric_log_route(self) -> None:
"""Verify valid metric_log route creation."""
route = CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False)
assert route.signal_type == SignalType.METRIC_LOG
assert route.ce_eligible is False
def test_invalid_signal_type(self) -> None:
"""Verify invalid signal_type is rejected."""
with pytest.raises(ValidationError):
CaseRoute(signal_type="invalid", ce_eligible=True)
class TestTelemetryEnvelope:
"""Tests for TelemetryEnvelope model."""
def test_valid_envelope_minimal(self) -> None:
"""Verify valid minimal envelope creation."""
envelope = TelemetryEnvelope(
case=TelemetryCase.WORKFLOW_RUN,
tenant_id="tenant-123",
event_id="event-456",
payload={"key": "value"},
)
assert envelope.case == TelemetryCase.WORKFLOW_RUN
assert envelope.tenant_id == "tenant-123"
assert envelope.event_id == "event-456"
assert envelope.payload == {"key": "value"}
assert envelope.metadata is None
def test_valid_envelope_full(self) -> None:
"""Verify valid envelope with all fields."""
metadata = {"payload_ref": "telemetry/tenant-789/event-012.json"}
envelope = TelemetryEnvelope(
case=TelemetryCase.MESSAGE_RUN,
tenant_id="tenant-789",
event_id="event-012",
payload={"message": "hello"},
metadata=metadata,
)
assert envelope.case == TelemetryCase.MESSAGE_RUN
assert envelope.tenant_id == "tenant-789"
assert envelope.event_id == "event-012"
assert envelope.payload == {"message": "hello"}
assert envelope.metadata == metadata
def test_missing_required_case(self) -> None:
"""Verify missing case field is rejected."""
with pytest.raises(ValidationError):
TelemetryEnvelope(
tenant_id="tenant-123",
event_id="event-456",
payload={"key": "value"},
)
def test_missing_required_tenant_id(self) -> None:
"""Verify missing tenant_id field is rejected."""
with pytest.raises(ValidationError):
TelemetryEnvelope(
case=TelemetryCase.WORKFLOW_RUN,
event_id="event-456",
payload={"key": "value"},
)
def test_missing_required_event_id(self) -> None:
"""Verify missing event_id field is rejected."""
with pytest.raises(ValidationError):
TelemetryEnvelope(
case=TelemetryCase.WORKFLOW_RUN,
tenant_id="tenant-123",
payload={"key": "value"},
)
def test_missing_required_payload(self) -> None:
"""Verify missing payload field is rejected."""
with pytest.raises(ValidationError):
TelemetryEnvelope(
case=TelemetryCase.WORKFLOW_RUN,
tenant_id="tenant-123",
event_id="event-456",
)
def test_metadata_none(self) -> None:
"""Verify metadata can be None."""
envelope = TelemetryEnvelope(
case=TelemetryCase.WORKFLOW_RUN,
tenant_id="tenant-123",
event_id="event-456",
payload={"key": "value"},
metadata=None,
)
assert envelope.metadata is None
class TestCaseRouting:
"""Tests for CASE_ROUTING table."""
def test_all_cases_routed(self) -> None:
"""Verify all 14 cases have routing entries."""
assert len(CASE_ROUTING) == 14
for case in TelemetryCase:
assert case in CASE_ROUTING
def test_trace_ce_eligible_cases(self) -> None:
"""Verify trace cases with CE eligibility."""
ce_eligible_trace_cases = {
TelemetryCase.WORKFLOW_RUN,
TelemetryCase.MESSAGE_RUN,
}
for case in ce_eligible_trace_cases:
route = CASE_ROUTING[case]
assert route.signal_type == SignalType.TRACE
assert route.ce_eligible is True
def test_trace_enterprise_only_cases(self) -> None:
"""Verify trace cases that are enterprise-only."""
enterprise_only_trace_cases = {
TelemetryCase.NODE_EXECUTION,
TelemetryCase.DRAFT_NODE_EXECUTION,
TelemetryCase.PROMPT_GENERATION,
}
for case in enterprise_only_trace_cases:
route = CASE_ROUTING[case]
assert route.signal_type == SignalType.TRACE
assert route.ce_eligible is False
def test_metric_log_cases(self) -> None:
"""Verify metric/log-only cases."""
metric_log_cases = {
TelemetryCase.APP_CREATED,
TelemetryCase.APP_UPDATED,
TelemetryCase.APP_DELETED,
TelemetryCase.FEEDBACK_CREATED,
}
for case in metric_log_cases:
route = CASE_ROUTING[case]
assert route.signal_type == SignalType.METRIC_LOG
assert route.ce_eligible is False
def test_routing_table_completeness(self) -> None:
"""Verify routing table covers all cases with correct types."""
trace_cases = {
TelemetryCase.WORKFLOW_RUN,
TelemetryCase.MESSAGE_RUN,
TelemetryCase.NODE_EXECUTION,
TelemetryCase.DRAFT_NODE_EXECUTION,
TelemetryCase.PROMPT_GENERATION,
TelemetryCase.TOOL_EXECUTION,
TelemetryCase.MODERATION_CHECK,
TelemetryCase.SUGGESTED_QUESTION,
TelemetryCase.DATASET_RETRIEVAL,
TelemetryCase.GENERATE_NAME,
}
metric_log_cases = {
TelemetryCase.APP_CREATED,
TelemetryCase.APP_UPDATED,
TelemetryCase.APP_DELETED,
TelemetryCase.FEEDBACK_CREATED,
}
all_cases = trace_cases | metric_log_cases
assert len(all_cases) == 14
assert all_cases == set(TelemetryCase)
for case in trace_cases:
assert CASE_ROUTING[case].signal_type == SignalType.TRACE
for case in metric_log_cases:
assert CASE_ROUTING[case].signal_type == SignalType.METRIC_LOG

View File

@ -0,0 +1,519 @@
"""Unit tests for enterprise/telemetry/draft_trace.py."""
from __future__ import annotations
from datetime import UTC, datetime
from unittest.mock import MagicMock, patch
from graphon.enums import WorkflowNodeExecutionMetadataKey
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_execution(**overrides) -> MagicMock:
"""Return a minimal WorkflowNodeExecutionModel mock."""
execution = MagicMock()
execution.tenant_id = overrides.get("tenant_id", "tenant-1")
execution.app_id = overrides.get("app_id", "app-1")
execution.workflow_id = overrides.get("workflow_id", "wf-1")
execution.id = overrides.get("id", "exec-1")
execution.node_id = overrides.get("node_id", "node-1")
execution.node_type = overrides.get("node_type", "llm")
execution.title = overrides.get("title", "My LLM Node")
execution.status = overrides.get("status", "succeeded")
execution.error = overrides.get("error")
execution.elapsed_time = overrides.get("elapsed_time", 1.5)
execution.index = overrides.get("index", 1)
execution.predecessor_node_id = overrides.get("predecessor_node_id")
execution.created_at = overrides.get("created_at", datetime(2024, 1, 1, tzinfo=UTC))
execution.finished_at = overrides.get("finished_at", datetime(2024, 1, 1, 0, 0, 5, tzinfo=UTC))
execution.workflow_run_id = overrides.get("workflow_run_id", "run-1")
execution.inputs_dict = overrides.get("inputs_dict", {"prompt": "hello"})
execution.outputs_dict = overrides.get("outputs_dict", {"answer": "world"})
execution.process_data_dict = overrides.get("process_data_dict", {})
execution.execution_metadata_dict = overrides.get("execution_metadata_dict", {})
return execution
# ---------------------------------------------------------------------------
# _build_node_execution_data
# ---------------------------------------------------------------------------
class TestBuildNodeExecutionData:
def test_basic_fields_populated(self) -> None:
from enterprise.telemetry.draft_trace import _build_node_execution_data
execution = _make_execution()
result = _build_node_execution_data(
execution=execution,
outputs=None,
workflow_execution_id="run-override",
)
assert result["workflow_id"] == "wf-1"
assert result["tenant_id"] == "tenant-1"
assert result["app_id"] == "app-1"
assert result["node_execution_id"] == "exec-1"
assert result["node_id"] == "node-1"
assert result["node_type"] == "llm"
assert result["title"] == "My LLM Node"
assert result["status"] == "succeeded"
assert result["error"] is None
assert result["elapsed_time"] == 1.5
assert result["index"] == 1
def test_workflow_execution_id_prefers_parameter(self) -> None:
from enterprise.telemetry.draft_trace import _build_node_execution_data
execution = _make_execution(workflow_run_id="run-from-model")
result = _build_node_execution_data(
execution=execution,
outputs=None,
workflow_execution_id="explicit-run",
)
assert result["workflow_execution_id"] == "explicit-run"
def test_workflow_execution_id_falls_back_to_run_id(self) -> None:
from enterprise.telemetry.draft_trace import _build_node_execution_data
execution = _make_execution(workflow_run_id="run-from-model")
result = _build_node_execution_data(
execution=execution,
outputs=None,
workflow_execution_id=None,
)
assert result["workflow_execution_id"] == "run-from-model"
def test_workflow_execution_id_falls_back_to_execution_id(self) -> None:
from enterprise.telemetry.draft_trace import _build_node_execution_data
execution = _make_execution(workflow_run_id=None, id="exec-fallback")
result = _build_node_execution_data(
execution=execution,
outputs=None,
workflow_execution_id=None,
)
assert result["workflow_execution_id"] == "exec-fallback"
def test_outputs_param_overrides_execution_outputs(self) -> None:
from enterprise.telemetry.draft_trace import _build_node_execution_data
execution = _make_execution(outputs_dict={"from_model": True})
result = _build_node_execution_data(
execution=execution,
outputs={"from_param": True},
workflow_execution_id=None,
)
assert result["node_outputs"] == {"from_param": True}
def test_outputs_none_uses_execution_outputs_dict(self) -> None:
from enterprise.telemetry.draft_trace import _build_node_execution_data
execution = _make_execution(outputs_dict={"from_model": True})
result = _build_node_execution_data(
execution=execution,
outputs=None,
workflow_execution_id=None,
)
assert result["node_outputs"] == {"from_model": True}
def test_metadata_token_fields_default_to_zero(self) -> None:
from enterprise.telemetry.draft_trace import _build_node_execution_data
execution = _make_execution(execution_metadata_dict={})
result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None)
assert result["total_tokens"] == 0
assert result["total_price"] == 0.0
assert result["currency"] is None
def test_metadata_token_fields_populated_from_metadata(self) -> None:
from enterprise.telemetry.draft_trace import _build_node_execution_data
metadata = {
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 200,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: 0.05,
WorkflowNodeExecutionMetadataKey.CURRENCY: "USD",
}
execution = _make_execution(execution_metadata_dict=metadata)
result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None)
assert result["total_tokens"] == 200
assert result["total_price"] == 0.05
assert result["currency"] == "USD"
def test_tool_name_extracted_from_tool_info_dict(self) -> None:
from enterprise.telemetry.draft_trace import _build_node_execution_data
metadata = {
WorkflowNodeExecutionMetadataKey.TOOL_INFO: {"tool_name": "web_search"},
}
execution = _make_execution(execution_metadata_dict=metadata)
result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None)
assert result["tool_name"] == "web_search"
def test_tool_name_is_none_when_tool_info_not_dict(self) -> None:
from enterprise.telemetry.draft_trace import _build_node_execution_data
metadata = {WorkflowNodeExecutionMetadataKey.TOOL_INFO: "not-a-dict"}
execution = _make_execution(execution_metadata_dict=metadata)
result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None)
assert result["tool_name"] is None
def test_tool_name_is_none_when_tool_info_absent(self) -> None:
from enterprise.telemetry.draft_trace import _build_node_execution_data
execution = _make_execution(execution_metadata_dict={})
result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None)
assert result["tool_name"] is None
def test_iteration_and_loop_fields(self) -> None:
from enterprise.telemetry.draft_trace import _build_node_execution_data
metadata = {
WorkflowNodeExecutionMetadataKey.ITERATION_ID: "iter-1",
WorkflowNodeExecutionMetadataKey.ITERATION_INDEX: 3,
WorkflowNodeExecutionMetadataKey.LOOP_ID: "loop-1",
WorkflowNodeExecutionMetadataKey.LOOP_INDEX: 2,
WorkflowNodeExecutionMetadataKey.PARALLEL_ID: "par-1",
}
execution = _make_execution(execution_metadata_dict=metadata)
result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None)
assert result["iteration_id"] == "iter-1"
assert result["iteration_index"] == 3
assert result["loop_id"] == "loop-1"
assert result["loop_index"] == 2
assert result["parallel_id"] == "par-1"
def test_node_inputs_and_process_data_included(self) -> None:
from enterprise.telemetry.draft_trace import _build_node_execution_data
execution = _make_execution(
inputs_dict={"q": "test"},
process_data_dict={"step": 1},
)
result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None)
assert result["node_inputs"] == {"q": "test"}
assert result["process_data"] == {"step": 1}
# ---------------------------------------------------------------------------
# enqueue_draft_node_execution_trace
# ---------------------------------------------------------------------------
class TestEnqueueDraftNodeExecutionTrace:
@patch("enterprise.telemetry.draft_trace.telemetry_emit")
def test_emits_telemetry_event(self, mock_emit: MagicMock) -> None:
from core.telemetry import TelemetryEvent, TraceTaskName
from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace
execution = _make_execution()
enqueue_draft_node_execution_trace(
execution=execution,
outputs={"result": "ok"},
workflow_execution_id="run-x",
user_id="user-1",
)
mock_emit.assert_called_once()
event: TelemetryEvent = mock_emit.call_args[0][0]
assert event.name == TraceTaskName.DRAFT_NODE_EXECUTION_TRACE
assert event.context.tenant_id == "tenant-1"
assert event.context.user_id == "user-1"
assert event.context.app_id == "app-1"
@patch("enterprise.telemetry.draft_trace.telemetry_emit")
def test_payload_contains_node_execution_data(self, mock_emit: MagicMock) -> None:
from core.telemetry import TelemetryEvent
from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace
execution = _make_execution()
enqueue_draft_node_execution_trace(
execution=execution,
outputs=None,
workflow_execution_id=None,
user_id="user-2",
)
event: TelemetryEvent = mock_emit.call_args[0][0]
node_data = event.payload["node_execution_data"]
assert node_data["workflow_id"] == "wf-1"
assert node_data["node_type"] == "llm"
assert node_data["status"] == "succeeded"
@patch("enterprise.telemetry.draft_trace.telemetry_emit")
def test_outputs_forwarded_to_build(self, mock_emit: MagicMock) -> None:
from core.telemetry import TelemetryEvent
from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace
execution = _make_execution(outputs_dict={"default": True})
enqueue_draft_node_execution_trace(
execution=execution,
outputs={"explicit": True},
workflow_execution_id=None,
user_id="user-3",
)
event: TelemetryEvent = mock_emit.call_args[0][0]
assert event.payload["node_execution_data"]["node_outputs"] == {"explicit": True}
@patch("enterprise.telemetry.draft_trace.telemetry_emit")
def test_none_outputs_uses_execution_outputs(self, mock_emit: MagicMock) -> None:
from core.telemetry import TelemetryEvent
from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace
execution = _make_execution(outputs_dict={"from_model": "yes"})
enqueue_draft_node_execution_trace(
execution=execution,
outputs=None,
workflow_execution_id=None,
user_id="user-4",
)
event: TelemetryEvent = mock_emit.call_args[0][0]
assert event.payload["node_execution_data"]["node_outputs"] == {"from_model": "yes"}
# ---------------------------------------------------------------------------
# End-to-end token/model data flow: _build_node_execution_data →
# ops_trace_manager.draft_node_execution_trace → DraftNodeExecutionTrace
# ---------------------------------------------------------------------------
def _make_llm_execution() -> MagicMock:
"""Return a WorkflowNodeExecutionModel mock that mimics a real LLM node.
The field values match what graphon/nodes/llm/node.py produces:
- process_data_dict contains model_provider, model_name, and usage
- outputs_dict contains usage with prompt/completion breakdown
- execution_metadata_dict contains total_tokens/total_price/currency
"""
return _make_execution(
tenant_id="tenant-flow",
app_id="app-flow",
workflow_id="wf-flow",
id="exec-flow",
node_id="node-llm",
node_type="llm",
title="GPT-4o Node",
status="succeeded",
elapsed_time=2.3,
workflow_run_id=None,
process_data_dict={
"model_mode": "chat",
"model_provider": "openai",
"model_name": "gpt-4o",
"prompts": [{"role": "user", "text": "hello"}],
"usage": {
"prompt_tokens": 50,
"prompt_unit_price": 0.00001,
"prompt_price_unit": 0.001,
"prompt_price": 0.0005,
"completion_tokens": 30,
"completion_unit_price": 0.00003,
"completion_price_unit": 0.001,
"completion_price": 0.0009,
"total_tokens": 80,
"total_price": 0.0014,
"currency": "USD",
"latency": 2.3,
},
"finish_reason": "stop",
},
outputs_dict={
"text": "world",
"usage": {
"prompt_tokens": 50,
"completion_tokens": 30,
"total_tokens": 80,
"total_price": 0.0014,
"currency": "USD",
},
"finish_reason": "stop",
},
execution_metadata_dict={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 80,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: 0.0014,
WorkflowNodeExecutionMetadataKey.CURRENCY: "USD",
},
)
class TestDraftTraceTokenDataFlow:
"""End-to-end test: verify all token and model fields survive from
_build_node_execution_data through ops_trace_manager.draft_node_execution_trace
to the DraftNodeExecutionTrace that enterprise_trace.py consumes.
"""
def test_all_token_and_model_fields_reach_trace_info(self) -> None:
"""Simulate the full draft trace data flow for an LLM node and
assert every token/model field that enterprise_trace._emit_node_execution_trace
reads is populated correctly on the resulting DraftNodeExecutionTrace."""
from enterprise.telemetry.draft_trace import _build_node_execution_data
execution = _make_llm_execution()
node_data = _build_node_execution_data(
execution=execution,
outputs=None,
workflow_execution_id="run-flow",
)
# Simulate what ops_trace_manager.draft_node_execution_trace does:
# it calls node_execution_trace(node_execution_data=node_data) which
# reads top-level keys from node_data. Verify all expected keys exist.
expected_keys = {
# Token fields — read by enterprise_trace._emit_node_execution_trace
"total_tokens",
"total_price",
"currency",
"prompt_tokens",
"completion_tokens",
# Model fields — read for span attrs and metric labels
"model_provider",
"model_name",
# Node identity — read for span attrs
"node_type",
"node_execution_id",
"node_id",
"title",
"status",
"error",
"elapsed_time",
# Workflow context
"workflow_id",
"workflow_execution_id",
"tenant_id",
"app_id",
# Structure fields
"index",
"predecessor_node_id",
"iteration_id",
"iteration_index",
"loop_id",
"loop_index",
"parallel_id",
# Tool field
"tool_name",
# Content fields
"node_inputs",
"node_outputs",
"process_data",
# Timestamps
"created_at",
"finished_at",
}
assert set(node_data.keys()) == expected_keys
# Verify token/model values are correct (not None/zero when data exists)
assert node_data["total_tokens"] == 80
assert node_data["total_price"] == 0.0014
assert node_data["currency"] == "USD"
assert node_data["prompt_tokens"] == 50
assert node_data["completion_tokens"] == 30
assert node_data["model_provider"] == "openai"
assert node_data["model_name"] == "gpt-4o"
assert node_data["node_type"] == "llm"
def test_non_llm_node_has_none_for_model_and_token_breakdown(self) -> None:
"""For non-LLM nodes (e.g. code, IF), model and token breakdown
should be None, but total_tokens from metadata should still work."""
from enterprise.telemetry.draft_trace import _build_node_execution_data
execution = _make_execution(
node_type="code",
process_data_dict={"code": "print('hi')"},
outputs_dict={"result": "hi"},
execution_metadata_dict={},
)
result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None)
assert result["model_provider"] is None
assert result["model_name"] is None
assert result["prompt_tokens"] is None
assert result["completion_tokens"] is None
assert result["total_tokens"] == 0
def test_none_process_data_and_none_outputs(self) -> None:
"""Both process_data_dict and outputs_dict are None — exercises
the `or {}` fallback and isinstance guard together."""
from enterprise.telemetry.draft_trace import _build_node_execution_data
execution = _make_execution(process_data_dict=None, outputs_dict=None)
result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None)
assert result["model_provider"] is None
assert result["model_name"] is None
assert result["prompt_tokens"] is None
assert result["completion_tokens"] is None
def test_node_data_feeds_into_draft_node_execution_trace(self) -> None:
"""Verify the node_data dict can be consumed by
ops_trace_manager.draft_node_execution_trace without error and
produces a DraftNodeExecutionTrace with correct token/model fields."""
from enterprise.telemetry.draft_trace import _build_node_execution_data
execution = _make_llm_execution()
node_data = _build_node_execution_data(
execution=execution,
outputs=None,
workflow_execution_id="run-e2e",
)
# Directly construct DraftNodeExecutionTrace the way
# ops_trace_manager.node_execution_trace does (lines 1315-1350),
# skipping DB lookups by providing minimal metadata.
from core.ops.entities.trace_entity import DraftNodeExecutionTrace
trace_info = DraftNodeExecutionTrace(
workflow_id=node_data.get("workflow_id", ""),
workflow_run_id=node_data.get("workflow_execution_id", ""),
tenant_id=node_data.get("tenant_id", ""),
node_execution_id=node_data.get("node_execution_id", ""),
node_id=node_data.get("node_id", ""),
node_type=node_data.get("node_type", ""),
title=node_data.get("title", ""),
status=node_data.get("status", ""),
error=node_data.get("error"),
elapsed_time=node_data.get("elapsed_time", 0.0),
index=node_data.get("index", 0),
predecessor_node_id=node_data.get("predecessor_node_id"),
total_tokens=node_data.get("total_tokens", 0),
total_price=node_data.get("total_price", 0.0),
currency=node_data.get("currency"),
model_provider=node_data.get("model_provider"),
model_name=node_data.get("model_name"),
prompt_tokens=node_data.get("prompt_tokens"),
completion_tokens=node_data.get("completion_tokens"),
tool_name=node_data.get("tool_name"),
iteration_id=node_data.get("iteration_id"),
iteration_index=node_data.get("iteration_index"),
loop_id=node_data.get("loop_id"),
loop_index=node_data.get("loop_index"),
parallel_id=node_data.get("parallel_id"),
node_inputs=node_data.get("node_inputs"),
node_outputs=node_data.get("node_outputs"),
process_data=node_data.get("process_data"),
start_time=node_data.get("created_at"),
end_time=node_data.get("finished_at"),
metadata={},
)
# These are the fields enterprise_trace._emit_node_execution_trace reads
assert trace_info.total_tokens == 80
assert trace_info.prompt_tokens == 50
assert trace_info.completion_tokens == 30
assert trace_info.model_provider == "openai"
assert trace_info.model_name == "gpt-4o"
assert trace_info.node_type == "llm"
assert trace_info.total_price == 0.0014
assert trace_info.currency == "USD"

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,54 @@
from unittest.mock import MagicMock, patch
import pytest
from enterprise.telemetry import event_handlers
from enterprise.telemetry.contracts import TelemetryCase
@pytest.fixture
def mock_gateway_emit():
with patch("core.telemetry.gateway.emit") as mock:
yield mock
def test_handle_app_created_calls_task(mock_gateway_emit):
sender = MagicMock()
sender.id = "app-123"
sender.tenant_id = "tenant-456"
sender.mode = "chat"
event_handlers._handle_app_created(sender)
mock_gateway_emit.assert_called_once_with(
case=TelemetryCase.APP_CREATED,
context={"tenant_id": "tenant-456"},
payload={"app_id": "app-123", "mode": "chat"},
)
def test_handle_app_created_no_exporter(mock_gateway_emit):
"""Gateway handles exporter availability internally; handler always calls gateway."""
sender = MagicMock()
sender.id = "app-123"
sender.tenant_id = "tenant-456"
event_handlers._handle_app_created(sender)
mock_gateway_emit.assert_called_once()
def test_handlers_create_valid_envelopes(mock_gateway_emit):
"""Verify handlers pass correct TelemetryCase and payload structure."""
sender = MagicMock()
sender.id = "app-123"
sender.tenant_id = "tenant-456"
sender.mode = "chat"
event_handlers._handle_app_created(sender)
call_kwargs = mock_gateway_emit.call_args[1]
assert call_kwargs["case"] == TelemetryCase.APP_CREATED
assert call_kwargs["context"]["tenant_id"] == "tenant-456"
assert call_kwargs["payload"]["app_id"] == "app-123"
assert call_kwargs["payload"]["mode"] == "chat"

View File

@ -0,0 +1,628 @@
"""Unit tests for EnterpriseExporter and _ExporterFactory."""
from __future__ import annotations
from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from configs.enterprise import EnterpriseTelemetryConfig
from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryHistogram
from enterprise.telemetry.exporter import EnterpriseExporter, _datetime_to_ns, _parse_otlp_headers
def test_config_api_key_default_empty():
"""Test that ENTERPRISE_OTLP_API_KEY defaults to empty string."""
config = EnterpriseTelemetryConfig()
assert config.ENTERPRISE_OTLP_API_KEY == ""
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
def test_api_key_only_injects_bearer_header(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None:
"""Test that API key alone injects Bearer authorization header."""
mock_config = SimpleNamespace(
ENTERPRISE_OTLP_ENDPOINT="https://collector.example.com",
ENTERPRISE_OTLP_HEADERS="",
ENTERPRISE_OTLP_PROTOCOL="grpc",
ENTERPRISE_SERVICE_NAME="dify",
ENTERPRISE_OTEL_SAMPLING_RATE=1.0,
ENTERPRISE_INCLUDE_CONTENT=True,
ENTERPRISE_OTLP_API_KEY="test-secret-key",
)
EnterpriseExporter(mock_config)
# Verify span exporter was called with Bearer header
assert mock_span_exporter.call_args is not None
headers = mock_span_exporter.call_args.kwargs.get("headers")
assert headers is not None
assert ("authorization", "Bearer test-secret-key") in headers
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
def test_empty_api_key_no_auth_header(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None:
"""Test that empty API key does not inject authorization header."""
mock_config = SimpleNamespace(
ENTERPRISE_OTLP_ENDPOINT="https://collector.example.com",
ENTERPRISE_OTLP_HEADERS="",
ENTERPRISE_OTLP_PROTOCOL="grpc",
ENTERPRISE_SERVICE_NAME="dify",
ENTERPRISE_OTEL_SAMPLING_RATE=1.0,
ENTERPRISE_INCLUDE_CONTENT=True,
ENTERPRISE_OTLP_API_KEY="",
)
EnterpriseExporter(mock_config)
# Verify span exporter was called without authorization header
assert mock_span_exporter.call_args is not None
headers = mock_span_exporter.call_args.kwargs.get("headers")
# Headers should be None or not contain authorization
if headers is not None:
assert not any(key == "authorization" for key, _ in headers)
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
def test_api_key_and_custom_headers_merge(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None:
"""Test that API key and custom headers are merged correctly."""
mock_config = SimpleNamespace(
ENTERPRISE_OTLP_ENDPOINT="https://collector.example.com",
ENTERPRISE_OTLP_HEADERS="x-custom=foo",
ENTERPRISE_OTLP_PROTOCOL="grpc",
ENTERPRISE_SERVICE_NAME="dify",
ENTERPRISE_OTEL_SAMPLING_RATE=1.0,
ENTERPRISE_INCLUDE_CONTENT=True,
ENTERPRISE_OTLP_API_KEY="test-key",
)
EnterpriseExporter(mock_config)
# Verify both headers are present
assert mock_span_exporter.call_args is not None
headers = mock_span_exporter.call_args.kwargs.get("headers")
assert headers is not None
assert ("authorization", "Bearer test-key") in headers
assert ("x-custom", "foo") in headers
@patch("enterprise.telemetry.exporter.logger")
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
def test_api_key_overrides_conflicting_header(
mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock, mock_logger: MagicMock
) -> None:
"""Test that API key overrides conflicting authorization header and logs warning."""
mock_config = SimpleNamespace(
ENTERPRISE_OTLP_ENDPOINT="https://collector.example.com",
ENTERPRISE_OTLP_HEADERS="authorization=Basic+old",
ENTERPRISE_OTLP_PROTOCOL="grpc",
ENTERPRISE_SERVICE_NAME="dify",
ENTERPRISE_OTEL_SAMPLING_RATE=1.0,
ENTERPRISE_INCLUDE_CONTENT=True,
ENTERPRISE_OTLP_API_KEY="test-key",
)
EnterpriseExporter(mock_config)
# Verify Bearer header takes precedence
assert mock_span_exporter.call_args is not None
headers = mock_span_exporter.call_args.kwargs.get("headers")
assert headers is not None
assert ("authorization", "Bearer test-key") in headers
# Verify old authorization header is not present
assert ("authorization", "Basic old") not in headers
# Verify warning was logged
mock_logger.warning.assert_called_once()
assert mock_logger.warning.call_args is not None
warning_message = mock_logger.warning.call_args[0][0]
assert "ENTERPRISE_OTLP_API_KEY is set" in warning_message
assert "authorization" in warning_message
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
def test_https_endpoint_uses_secure_grpc(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None:
"""Test that https:// endpoint enables TLS (insecure=False) for gRPC."""
mock_config = SimpleNamespace(
ENTERPRISE_OTLP_ENDPOINT="https://collector.example.com",
ENTERPRISE_OTLP_HEADERS="",
ENTERPRISE_OTLP_PROTOCOL="grpc",
ENTERPRISE_SERVICE_NAME="dify",
ENTERPRISE_OTEL_SAMPLING_RATE=1.0,
ENTERPRISE_INCLUDE_CONTENT=True,
ENTERPRISE_OTLP_API_KEY="test-key",
)
EnterpriseExporter(mock_config)
# Verify insecure=False for both exporters (https:// scheme)
assert mock_span_exporter.call_args is not None
assert mock_span_exporter.call_args.kwargs["insecure"] is False
assert mock_metric_exporter.call_args is not None
assert mock_metric_exporter.call_args.kwargs["insecure"] is False
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
def test_http_endpoint_uses_insecure_grpc(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None:
"""Test that http:// endpoint uses insecure gRPC (insecure=True)."""
mock_config = SimpleNamespace(
ENTERPRISE_OTLP_ENDPOINT="http://collector.example.com",
ENTERPRISE_OTLP_HEADERS="",
ENTERPRISE_OTLP_PROTOCOL="grpc",
ENTERPRISE_SERVICE_NAME="dify",
ENTERPRISE_OTEL_SAMPLING_RATE=1.0,
ENTERPRISE_INCLUDE_CONTENT=True,
ENTERPRISE_OTLP_API_KEY="",
)
EnterpriseExporter(mock_config)
# Verify insecure=True for both exporters (http:// scheme)
assert mock_span_exporter.call_args is not None
assert mock_span_exporter.call_args.kwargs["insecure"] is True
assert mock_metric_exporter.call_args is not None
assert mock_metric_exporter.call_args.kwargs["insecure"] is True
@patch("enterprise.telemetry.exporter.HTTPSpanExporter")
@patch("enterprise.telemetry.exporter.HTTPMetricExporter")
def test_insecure_not_passed_to_http_exporters(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None:
"""Test that insecure parameter is not passed to HTTP exporters."""
mock_config = SimpleNamespace(
ENTERPRISE_OTLP_ENDPOINT="http://collector.example.com",
ENTERPRISE_OTLP_HEADERS="",
ENTERPRISE_OTLP_PROTOCOL="http",
ENTERPRISE_SERVICE_NAME="dify",
ENTERPRISE_OTEL_SAMPLING_RATE=1.0,
ENTERPRISE_INCLUDE_CONTENT=True,
ENTERPRISE_OTLP_API_KEY="test-key",
)
EnterpriseExporter(mock_config)
# Verify insecure kwarg is NOT in HTTP exporter calls
assert mock_span_exporter.call_args is not None
assert "insecure" not in mock_span_exporter.call_args.kwargs
assert mock_metric_exporter.call_args is not None
assert "insecure" not in mock_metric_exporter.call_args.kwargs
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
def test_api_key_with_special_chars_preserved(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None:
"""Test that API key with special characters is preserved without mangling."""
special_key = "abc+def/ghi=jkl=="
mock_config = SimpleNamespace(
ENTERPRISE_OTLP_ENDPOINT="https://collector.example.com",
ENTERPRISE_OTLP_HEADERS="",
ENTERPRISE_OTLP_PROTOCOL="grpc",
ENTERPRISE_SERVICE_NAME="dify",
ENTERPRISE_OTEL_SAMPLING_RATE=1.0,
ENTERPRISE_INCLUDE_CONTENT=True,
ENTERPRISE_OTLP_API_KEY=special_key,
)
EnterpriseExporter(mock_config)
# Verify special characters are preserved in Bearer header
assert mock_span_exporter.call_args is not None
headers = mock_span_exporter.call_args.kwargs.get("headers")
assert headers is not None
assert ("authorization", f"Bearer {special_key}") in headers
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
def test_no_scheme_localhost_uses_insecure(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None:
"""Test that endpoint without scheme defaults to insecure for localhost."""
mock_config = SimpleNamespace(
ENTERPRISE_OTLP_ENDPOINT="localhost:4317",
ENTERPRISE_OTLP_HEADERS="",
ENTERPRISE_OTLP_PROTOCOL="grpc",
ENTERPRISE_SERVICE_NAME="dify",
ENTERPRISE_OTEL_SAMPLING_RATE=1.0,
ENTERPRISE_INCLUDE_CONTENT=True,
ENTERPRISE_OTLP_API_KEY="",
)
EnterpriseExporter(mock_config)
# Verify insecure=True for localhost without scheme
assert mock_span_exporter.call_args is not None
assert mock_span_exporter.call_args.kwargs["insecure"] is True
assert mock_metric_exporter.call_args is not None
assert mock_metric_exporter.call_args.kwargs["insecure"] is True
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
def test_no_scheme_production_uses_insecure(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None:
"""Test that endpoint without scheme defaults to insecure (not https://)."""
mock_config = SimpleNamespace(
ENTERPRISE_OTLP_ENDPOINT="collector.example.com:4317",
ENTERPRISE_OTLP_HEADERS="",
ENTERPRISE_OTLP_PROTOCOL="grpc",
ENTERPRISE_SERVICE_NAME="dify",
ENTERPRISE_OTEL_SAMPLING_RATE=1.0,
ENTERPRISE_INCLUDE_CONTENT=True,
ENTERPRISE_OTLP_API_KEY="",
)
EnterpriseExporter(mock_config)
# Verify insecure=True for any endpoint without https:// scheme
assert mock_span_exporter.call_args is not None
assert mock_span_exporter.call_args.kwargs["insecure"] is True
assert mock_metric_exporter.call_args is not None
assert mock_metric_exporter.call_args.kwargs["insecure"] is True
# ---------------------------------------------------------------------------
# _parse_otlp_headers (line 55 — pair without "=" is skipped)
# ---------------------------------------------------------------------------
def test_parse_otlp_headers_empty_returns_empty_dict() -> None:
assert _parse_otlp_headers("") == {}
def test_parse_otlp_headers_value_may_contain_equals() -> None:
result = _parse_otlp_headers("token=abc=def==")
assert result == {"token": "abc=def=="}
def test_parse_otlp_headers_url_encoded() -> None:
result = _parse_otlp_headers("key=%E4%BD%A0%E5%A5%BD")
assert result == {"key": "你好"}
# ---------------------------------------------------------------------------
# _datetime_to_ns (lines 64-68)
# ---------------------------------------------------------------------------
def test_datetime_to_ns_naive_treated_as_utc() -> None:
"""Naive datetime must be interpreted as UTC (line 64-65)."""
naive = datetime(2024, 1, 1, 0, 0, 0) # no tzinfo
aware_utc = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC)
assert _datetime_to_ns(naive) == _datetime_to_ns(aware_utc)
def test_datetime_to_ns_tz_aware_converted_to_utc() -> None:
"""Timezone-aware datetime must be converted to UTC before computing ns (line 66-67)."""
import zoneinfo
eastern = zoneinfo.ZoneInfo("America/New_York")
dt_east = datetime(2024, 6, 1, 12, 0, 0, tzinfo=eastern) # UTC-4 in summer
dt_utc = dt_east.astimezone(UTC)
assert _datetime_to_ns(dt_east) == _datetime_to_ns(dt_utc)
def test_datetime_to_ns_returns_integer_nanoseconds() -> None:
dt = datetime(2024, 1, 1, 0, 0, 1, tzinfo=UTC)
result = _datetime_to_ns(dt)
# 2024-01-01 00:00:01 UTC = epoch + some_seconds; result should be in nanoseconds
assert isinstance(result, int)
# 1 second past epoch start of 2024 — should be > 1_700_000_000_000_000_000 (rough lower bound)
assert result > 1_700_000_000_000_000_000
# ---------------------------------------------------------------------------
# EnterpriseExporter constructor — include_content property (line 115 / 288-289)
# ---------------------------------------------------------------------------
def _make_grpc_config(**overrides) -> SimpleNamespace:
defaults = {
"ENTERPRISE_OTLP_ENDPOINT": "https://collector.example.com",
"ENTERPRISE_OTLP_HEADERS": "",
"ENTERPRISE_OTLP_PROTOCOL": "grpc",
"ENTERPRISE_SERVICE_NAME": "dify",
"ENTERPRISE_OTEL_SAMPLING_RATE": 1.0,
"ENTERPRISE_INCLUDE_CONTENT": True,
"ENTERPRISE_OTLP_API_KEY": "",
}
defaults.update(overrides)
return SimpleNamespace(**defaults)
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
def test_include_content_true_stored_on_exporter(
mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock
) -> None:
"""include_content=True is stored as a public attribute (line 115)."""
exporter = EnterpriseExporter(_make_grpc_config(ENTERPRISE_INCLUDE_CONTENT=True))
assert exporter.include_content is True
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
def test_include_content_false_stored_on_exporter(
mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock
) -> None:
"""include_content=False is preserved (lines 288-289 path exercised by callers)."""
exporter = EnterpriseExporter(_make_grpc_config(ENTERPRISE_INCLUDE_CONTENT=False))
assert exporter.include_content is False
# ---------------------------------------------------------------------------
# EnterpriseExporter constructor — gRPC setup (lines 64-68 exporter-init path)
# ---------------------------------------------------------------------------
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
def test_grpc_exporter_created_with_correct_endpoint(
mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock
) -> None:
"""GRPCSpanExporter and GRPCMetricExporter receive the configured endpoint."""
EnterpriseExporter(_make_grpc_config(ENTERPRISE_OTLP_ENDPOINT="https://my-collector:4317"))
assert mock_span_exporter.call_args.kwargs["endpoint"] == "https://my-collector:4317"
assert mock_metric_exporter.call_args.kwargs["endpoint"] == "https://my-collector:4317"
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
def test_grpc_exporter_empty_endpoint_passes_none(
mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock
) -> None:
"""Empty string endpoint is normalised to None for both gRPC exporters."""
EnterpriseExporter(_make_grpc_config(ENTERPRISE_OTLP_ENDPOINT=""))
assert mock_span_exporter.call_args.kwargs["endpoint"] is None
assert mock_metric_exporter.call_args.kwargs["endpoint"] is None
# ---------------------------------------------------------------------------
# EnterpriseExporter.export_span (lines 204-271)
# ---------------------------------------------------------------------------
def _make_exporter_with_mock_tracer() -> tuple[EnterpriseExporter, MagicMock, MagicMock]:
"""Return (exporter, mock_tracer, mock_span) with OTEL internals fully mocked."""
mock_span = MagicMock()
mock_span.__enter__ = MagicMock(return_value=mock_span)
mock_span.__exit__ = MagicMock(return_value=False)
mock_tracer = MagicMock()
mock_tracer.start_as_current_span.return_value = mock_span
with (
patch("enterprise.telemetry.exporter.GRPCSpanExporter"),
patch("enterprise.telemetry.exporter.GRPCMetricExporter"),
):
exporter = EnterpriseExporter(_make_grpc_config())
exporter._tracer = mock_tracer
return exporter, mock_tracer, mock_span
@patch("enterprise.telemetry.exporter.set_correlation_id")
@patch("enterprise.telemetry.exporter.set_span_id_source")
def test_export_span_sets_and_clears_context(mock_set_span: MagicMock, mock_set_corr: MagicMock) -> None:
"""export_span sets correlation/span context before the span and clears them in finally."""
exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer()
exporter.export_span(
name="test.span",
attributes={"k": "v"},
correlation_id="corr-1",
span_id_source="span-src-1",
)
# Context was set at the start of the call
mock_set_corr.assert_any_call("corr-1")
mock_set_span.assert_any_call("span-src-1")
# Context was cleared in finally
mock_set_corr.assert_called_with(None)
mock_set_span.assert_called_with(None)
def test_export_span_sets_attributes_on_span() -> None:
"""All non-None attribute values are set on the span via set_attribute."""
exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer()
exporter.export_span(
name="test.span",
attributes={"key1": "value1", "key2": None, "key3": 42},
)
# set_attribute should be called for non-None values only
calls = list(mock_span.set_attribute.call_args_list)
keys_set = {c[0][0] for c in calls}
assert "key1" in keys_set
assert "key3" in keys_set
assert "key2" not in keys_set
def test_export_span_no_end_time_uses_end_on_exit() -> None:
"""When end_time is None, end_on_exit=True is passed to start_as_current_span."""
exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer()
exporter.export_span(name="test.span", attributes={})
_, kwargs = mock_tracer.start_as_current_span.call_args
assert kwargs["end_on_exit"] is True
def test_export_span_with_end_time_calls_span_end() -> None:
"""When end_time is provided, span.end() is called with the converted ns timestamp."""
exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer()
start = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC)
end = datetime(2024, 1, 1, 0, 0, 5, tzinfo=UTC)
exporter.export_span(name="test.span", attributes={}, start_time=start, end_time=end)
mock_span.end.assert_called_once()
end_ns = mock_span.end.call_args.kwargs["end_time"]
assert end_ns == _datetime_to_ns(end)
def test_export_span_with_start_time_passed_to_start_as_current_span() -> None:
"""When start_time is provided it is converted to ns and passed to start_as_current_span."""
exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer()
start = datetime(2024, 3, 1, 12, 0, 0, tzinfo=UTC)
exporter.export_span(name="test.span", attributes={}, start_time=start)
_, kwargs = mock_tracer.start_as_current_span.call_args
assert kwargs["start_time"] == _datetime_to_ns(start)
def test_export_span_root_span_no_parent_context() -> None:
"""When span_id_source == correlation_id the span is root — no parent context."""
exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer()
uid = "123e4567-e89b-12d3-a456-426614174000"
exporter.export_span(
name="root.span",
attributes={},
correlation_id=uid,
span_id_source=uid,
)
_, kwargs = mock_tracer.start_as_current_span.call_args
assert kwargs["context"] is None
def test_export_span_child_span_has_parent_context() -> None:
"""When correlation_id != span_id_source the child span gets a parent context."""
exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer()
corr_uid = "123e4567-e89b-12d3-a456-426614174000"
node_uid = "987fbc97-4bed-5078-9f07-9141ba07c9f3"
exporter.export_span(
name="child.span",
attributes={},
correlation_id=corr_uid,
span_id_source=node_uid,
)
_, kwargs = mock_tracer.start_as_current_span.call_args
assert kwargs["context"] is not None
def test_export_span_cross_workflow_parent_context() -> None:
"""When parent_span_id_source is set, the cross-workflow parent context is built."""
exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer()
corr_uid = "123e4567-e89b-12d3-a456-426614174000"
parent_uid = "987fbc97-4bed-5078-9f07-9141ba07c9f3"
exporter.export_span(
name="cross.span",
attributes={},
correlation_id=corr_uid,
parent_span_id_source=parent_uid,
)
_, kwargs = mock_tracer.start_as_current_span.call_args
assert kwargs["context"] is not None
@patch("enterprise.telemetry.exporter.logger")
def test_export_span_logs_exception_on_error(mock_logger: MagicMock) -> None:
"""If the span block raises, the exception is logged and context is still cleared."""
exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer()
mock_tracer.start_as_current_span.side_effect = RuntimeError("boom")
exporter.export_span(name="bad.span", attributes={}) # must not raise
mock_logger.exception.assert_called_once()
assert "bad.span" in mock_logger.exception.call_args[0][1]
@patch("enterprise.telemetry.exporter.logger")
def test_export_span_invalid_trace_correlation_logs_warning(mock_logger: MagicMock) -> None:
"""Invalid UUID for trace_correlation_override triggers a warning log."""
exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer()
parent_uid = "987fbc97-4bed-5078-9f07-9141ba07c9f3"
exporter.export_span(
name="link.span",
attributes={},
correlation_id="not-a-valid-uuid",
parent_span_id_source=parent_uid,
)
mock_logger.warning.assert_called()
# ---------------------------------------------------------------------------
# EnterpriseExporter.increment_counter (lines 276-278)
# ---------------------------------------------------------------------------
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
def test_increment_counter_calls_add_on_counter(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None:
"""increment_counter calls .add() on the matching counter instrument."""
exporter = EnterpriseExporter(_make_grpc_config())
mock_counter = MagicMock()
exporter._counters[EnterpriseTelemetryCounter.TOKENS] = mock_counter
labels = {"tenant_id": "t1", "app_id": "app-1"}
exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, 50, labels)
mock_counter.add.assert_called_once_with(50, labels)
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
def test_increment_counter_unknown_name_is_noop(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None:
"""increment_counter silently does nothing when the counter is not found."""
exporter = EnterpriseExporter(_make_grpc_config())
exporter._counters.clear()
# Should not raise
exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, 5, {})
# ---------------------------------------------------------------------------
# EnterpriseExporter.record_histogram (lines 283-285)
# ---------------------------------------------------------------------------
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
def test_record_histogram_calls_record_on_histogram(
mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock
) -> None:
"""record_histogram calls .record() on the matching histogram instrument."""
exporter = EnterpriseExporter(_make_grpc_config())
mock_histogram = MagicMock()
exporter._histograms[EnterpriseTelemetryHistogram.WORKFLOW_DURATION] = mock_histogram
labels = {"tenant_id": "t1"}
exporter.record_histogram(EnterpriseTelemetryHistogram.WORKFLOW_DURATION, 3.14, labels)
mock_histogram.record.assert_called_once_with(3.14, labels)
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
def test_record_histogram_unknown_name_is_noop(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None:
"""record_histogram silently does nothing when the histogram is not found."""
exporter = EnterpriseExporter(_make_grpc_config())
exporter._histograms.clear()
# Should not raise
exporter.record_histogram(EnterpriseTelemetryHistogram.WORKFLOW_DURATION, 1.0, {})

View File

@ -0,0 +1,272 @@
from __future__ import annotations
import sys
from unittest.mock import MagicMock, patch
import pytest
from core.ops.entities.trace_entity import TraceTaskName
from core.telemetry.gateway import (
CASE_ROUTING,
CASE_TO_TRACE_TASK,
PAYLOAD_SIZE_THRESHOLD_BYTES,
emit,
)
from enterprise.telemetry.contracts import SignalType, TelemetryCase, TelemetryEnvelope
class TestCaseRoutingTable:
def test_all_cases_have_routing(self) -> None:
for case in TelemetryCase:
assert case in CASE_ROUTING, f"Missing routing for {case}"
def test_trace_cases(self) -> None:
trace_cases = [
TelemetryCase.WORKFLOW_RUN,
TelemetryCase.MESSAGE_RUN,
TelemetryCase.NODE_EXECUTION,
TelemetryCase.DRAFT_NODE_EXECUTION,
TelemetryCase.PROMPT_GENERATION,
]
for case in trace_cases:
assert CASE_ROUTING[case].signal_type is SignalType.TRACE, f"{case} should be trace"
def test_metric_log_cases(self) -> None:
metric_log_cases = [
TelemetryCase.APP_CREATED,
TelemetryCase.APP_UPDATED,
TelemetryCase.APP_DELETED,
TelemetryCase.FEEDBACK_CREATED,
]
for case in metric_log_cases:
assert CASE_ROUTING[case].signal_type is SignalType.METRIC_LOG, f"{case} should be metric_log"
def test_ce_eligible_cases(self) -> None:
ce_eligible_cases = [
TelemetryCase.WORKFLOW_RUN,
TelemetryCase.MESSAGE_RUN,
TelemetryCase.TOOL_EXECUTION,
TelemetryCase.MODERATION_CHECK,
TelemetryCase.SUGGESTED_QUESTION,
TelemetryCase.DATASET_RETRIEVAL,
TelemetryCase.GENERATE_NAME,
]
for case in ce_eligible_cases:
assert CASE_ROUTING[case].ce_eligible is True, f"{case} should be CE eligible"
def test_enterprise_only_cases(self) -> None:
enterprise_only_cases = [
TelemetryCase.NODE_EXECUTION,
TelemetryCase.DRAFT_NODE_EXECUTION,
TelemetryCase.PROMPT_GENERATION,
]
for case in enterprise_only_cases:
assert CASE_ROUTING[case].ce_eligible is False, f"{case} should be enterprise-only"
def test_trace_cases_have_task_name_mapping(self) -> None:
trace_cases = [c for c in TelemetryCase if CASE_ROUTING[c].signal_type is SignalType.TRACE]
for case in trace_cases:
assert case in CASE_TO_TRACE_TASK, f"Missing TraceTaskName mapping for {case}"
@pytest.fixture
def mock_ops_trace_manager():
mock_module = MagicMock()
mock_trace_task_class = MagicMock()
mock_trace_task_class.return_value = MagicMock()
mock_module.TraceTask = mock_trace_task_class
mock_module.TraceQueueManager = MagicMock()
mock_trace_entity = MagicMock()
mock_trace_task_name = MagicMock()
mock_trace_task_name.return_value = "workflow"
mock_trace_entity.TraceTaskName = mock_trace_task_name
with (
patch.dict(sys.modules, {"core.ops.ops_trace_manager": mock_module}),
patch.dict(sys.modules, {"core.ops.entities.trace_entity": mock_trace_entity}),
):
yield mock_module, mock_trace_entity
class TestGatewayTraceRouting:
@pytest.fixture
def mock_trace_manager(self) -> MagicMock:
return MagicMock()
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
def test_trace_case_routes_to_trace_manager(
self,
mock_ee_enabled: MagicMock,
mock_trace_manager: MagicMock,
mock_ops_trace_manager: tuple[MagicMock, MagicMock],
) -> None:
context = {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"}
payload = {"workflow_run_id": "run-abc"}
emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_called_once()
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False)
def test_ce_eligible_trace_enqueued_when_ee_disabled(
self,
mock_ee_enabled: MagicMock,
mock_trace_manager: MagicMock,
mock_ops_trace_manager: tuple[MagicMock, MagicMock],
) -> None:
context = {"app_id": "app-123", "user_id": "user-456"}
payload = {"workflow_run_id": "run-abc"}
emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_called_once()
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False)
def test_enterprise_only_trace_dropped_when_ee_disabled(
self,
mock_ee_enabled: MagicMock,
mock_trace_manager: MagicMock,
mock_ops_trace_manager: tuple[MagicMock, MagicMock],
) -> None:
context = {"app_id": "app-123", "user_id": "user-456"}
payload = {"node_id": "node-abc"}
emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_not_called()
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
def test_enterprise_only_trace_enqueued_when_ee_enabled(
self,
mock_ee_enabled: MagicMock,
mock_trace_manager: MagicMock,
mock_ops_trace_manager: tuple[MagicMock, MagicMock],
) -> None:
context = {"app_id": "app-123", "user_id": "user-456"}
payload = {"node_id": "node-abc"}
emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_called_once()
class TestGatewayMetricLogRouting:
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
@patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay")
def test_metric_case_routes_to_celery_task(
self,
mock_delay: MagicMock,
mock_ee_enabled: MagicMock,
) -> None:
context = {"tenant_id": "tenant-123"}
payload = {"app_id": "app-abc", "name": "My App"}
emit(TelemetryCase.APP_CREATED, context, payload)
mock_delay.assert_called_once()
envelope_json = mock_delay.call_args[0][0]
envelope = TelemetryEnvelope.model_validate_json(envelope_json)
assert envelope.case == TelemetryCase.APP_CREATED
assert envelope.tenant_id == "tenant-123"
assert envelope.payload["app_id"] == "app-abc"
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
@patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay")
def test_envelope_has_unique_event_id(
self,
mock_delay: MagicMock,
mock_ee_enabled: MagicMock,
) -> None:
context = {"tenant_id": "tenant-123"}
payload = {"app_id": "app-abc"}
emit(TelemetryCase.APP_CREATED, context, payload)
emit(TelemetryCase.APP_CREATED, context, payload)
assert mock_delay.call_count == 2
envelope1 = TelemetryEnvelope.model_validate_json(mock_delay.call_args_list[0][0][0])
envelope2 = TelemetryEnvelope.model_validate_json(mock_delay.call_args_list[1][0][0])
assert envelope1.event_id != envelope2.event_id
class TestGatewayPayloadSizing:
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
@patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay")
def test_small_payload_inlined(
self,
mock_delay: MagicMock,
mock_ee_enabled: MagicMock,
) -> None:
context = {"tenant_id": "tenant-123"}
payload = {"key": "small_value"}
emit(TelemetryCase.APP_CREATED, context, payload)
envelope_json = mock_delay.call_args[0][0]
envelope = TelemetryEnvelope.model_validate_json(envelope_json)
assert envelope.payload == payload
assert envelope.metadata is None
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
@patch("core.telemetry.gateway.storage")
@patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay")
def test_large_payload_stored(
self,
mock_delay: MagicMock,
mock_storage: MagicMock,
mock_ee_enabled: MagicMock,
) -> None:
context = {"tenant_id": "tenant-123"}
large_value = "x" * (PAYLOAD_SIZE_THRESHOLD_BYTES + 1000)
payload = {"key": large_value}
emit(TelemetryCase.APP_CREATED, context, payload)
mock_storage.save.assert_called_once()
storage_key = mock_storage.save.call_args[0][0]
assert storage_key.startswith("telemetry/tenant-123/")
envelope_json = mock_delay.call_args[0][0]
envelope = TelemetryEnvelope.model_validate_json(envelope_json)
assert envelope.payload == {}
assert envelope.metadata is not None
assert envelope.metadata["payload_ref"] == storage_key
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
@patch("core.telemetry.gateway.storage")
@patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay")
def test_large_payload_fallback_on_storage_error(
self,
mock_delay: MagicMock,
mock_storage: MagicMock,
mock_ee_enabled: MagicMock,
) -> None:
mock_storage.save.side_effect = Exception("Storage failure")
context = {"tenant_id": "tenant-123"}
large_value = "x" * (PAYLOAD_SIZE_THRESHOLD_BYTES + 1000)
payload = {"key": large_value}
emit(TelemetryCase.APP_CREATED, context, payload)
envelope_json = mock_delay.call_args[0][0]
envelope = TelemetryEnvelope.model_validate_json(envelope_json)
assert envelope.payload == payload
assert envelope.metadata is None
class TestTraceTaskNameMapping:
def test_workflow_run_mapping(self) -> None:
assert CASE_TO_TRACE_TASK[TelemetryCase.WORKFLOW_RUN] is TraceTaskName.WORKFLOW_TRACE
def test_message_run_mapping(self) -> None:
assert CASE_TO_TRACE_TASK[TelemetryCase.MESSAGE_RUN] is TraceTaskName.MESSAGE_TRACE
def test_node_execution_mapping(self) -> None:
assert CASE_TO_TRACE_TASK[TelemetryCase.NODE_EXECUTION] is TraceTaskName.NODE_EXECUTION_TRACE
def test_draft_node_execution_mapping(self) -> None:
assert CASE_TO_TRACE_TASK[TelemetryCase.DRAFT_NODE_EXECUTION] is TraceTaskName.DRAFT_NODE_EXECUTION_TRACE
def test_prompt_generation_mapping(self) -> None:
assert CASE_TO_TRACE_TASK[TelemetryCase.PROMPT_GENERATION] is TraceTaskName.PROMPT_GENERATION_TRACE

View File

@ -0,0 +1,201 @@
"""Unit tests for enterprise/telemetry/id_generator.py."""
from __future__ import annotations
import uuid
from unittest.mock import patch
# ---------------------------------------------------------------------------
# compute_deterministic_span_id
# ---------------------------------------------------------------------------
class TestComputeDeterministicSpanId:
def test_returns_lower_64_bits_of_uuid(self) -> None:
from enterprise.telemetry.id_generator import compute_deterministic_span_id
uid = "123e4567-e89b-12d3-a456-426614174000"
expected = uuid.UUID(uid).int & ((1 << 64) - 1)
assert compute_deterministic_span_id(uid) == expected
def test_non_zero_result_returned_unchanged(self) -> None:
from enterprise.telemetry.id_generator import compute_deterministic_span_id
# This UUID has non-zero lower 64 bits
uid = "123e4567-e89b-12d3-a456-426614174000"
result = compute_deterministic_span_id(uid)
assert result != 0
def test_zero_lower_bits_returns_one(self) -> None:
"""When the lower 64 bits of the UUID int are 0, the function must return 1 (OTEL requirement)."""
from enterprise.telemetry.id_generator import compute_deterministic_span_id
# Craft a UUID whose lower 64 bits are 0: upper 64 bits are 1, lower 64 bits are 0.
# int = (1 << 64), UUID fields constructed to produce this integer.
target_int = 1 << 64 # lower 64 bits are 0x0000000000000000
crafted_uuid = uuid.UUID(int=target_int)
result = compute_deterministic_span_id(str(crafted_uuid))
assert result == 1
def test_raises_on_invalid_uuid(self) -> None:
import pytest
from enterprise.telemetry.id_generator import compute_deterministic_span_id
with pytest.raises((ValueError, AttributeError)):
compute_deterministic_span_id("not-a-uuid")
def test_different_uuids_produce_different_span_ids(self) -> None:
from enterprise.telemetry.id_generator import compute_deterministic_span_id
uid1 = "123e4567-e89b-12d3-a456-426614174000"
uid2 = "987fbc97-4bed-5078-9f07-9141ba07c9f3"
assert compute_deterministic_span_id(uid1) != compute_deterministic_span_id(uid2)
def test_deterministic_same_input_same_output(self) -> None:
from enterprise.telemetry.id_generator import compute_deterministic_span_id
uid = "123e4567-e89b-12d3-a456-426614174000"
assert compute_deterministic_span_id(uid) == compute_deterministic_span_id(uid)
# ---------------------------------------------------------------------------
# Context variable helpers
# ---------------------------------------------------------------------------
class TestContextVariableHelpers:
def test_set_and_get_correlation_id(self) -> None:
from enterprise.telemetry.id_generator import get_correlation_id, set_correlation_id
set_correlation_id("corr-123")
assert get_correlation_id() == "corr-123"
def test_clear_correlation_id(self) -> None:
from enterprise.telemetry.id_generator import get_correlation_id, set_correlation_id
set_correlation_id("corr-abc")
set_correlation_id(None)
assert get_correlation_id() is None
def test_correlation_id_default_is_none(self) -> None:
from enterprise.telemetry.id_generator import get_correlation_id, set_correlation_id
set_correlation_id(None)
assert get_correlation_id() is None
def test_set_span_id_source_stored_in_context(self) -> None:
from enterprise.telemetry.id_generator import _span_id_source_context, set_span_id_source
set_span_id_source("span-src-1")
assert _span_id_source_context.get() == "span-src-1"
def test_clear_span_id_source(self) -> None:
from enterprise.telemetry.id_generator import _span_id_source_context, set_span_id_source
set_span_id_source("span-src-1")
set_span_id_source(None)
assert _span_id_source_context.get() is None
# ---------------------------------------------------------------------------
# CorrelationIdGenerator.generate_trace_id
# ---------------------------------------------------------------------------
class TestCorrelationIdGeneratorGenerateTraceId:
def setup_method(self) -> None:
from enterprise.telemetry.id_generator import set_correlation_id
set_correlation_id(None)
def test_returns_uuid_int_when_correlation_id_set(self) -> None:
from enterprise.telemetry.id_generator import CorrelationIdGenerator, set_correlation_id
uid = "123e4567-e89b-12d3-a456-426614174000"
set_correlation_id(uid)
gen = CorrelationIdGenerator()
trace_id = gen.generate_trace_id()
assert trace_id == uuid.UUID(uid).int
def test_returns_random_when_no_correlation_id(self) -> None:
from enterprise.telemetry.id_generator import CorrelationIdGenerator, set_correlation_id
set_correlation_id(None)
gen = CorrelationIdGenerator()
# Should return a non-zero int without raising
trace_id = gen.generate_trace_id()
assert isinstance(trace_id, int)
assert trace_id > 0
def test_returns_random_when_correlation_id_is_invalid_uuid(self) -> None:
from enterprise.telemetry.id_generator import CorrelationIdGenerator, set_correlation_id
set_correlation_id("not-a-valid-uuid")
gen = CorrelationIdGenerator()
with patch("enterprise.telemetry.id_generator.random.getrandbits", return_value=42) as mock_rng:
trace_id = gen.generate_trace_id()
mock_rng.assert_called_once_with(128)
assert trace_id == 42
# ---------------------------------------------------------------------------
# CorrelationIdGenerator.generate_span_id
# ---------------------------------------------------------------------------
class TestCorrelationIdGeneratorGenerateSpanId:
def setup_method(self) -> None:
from enterprise.telemetry.id_generator import set_span_id_source
set_span_id_source(None)
def test_uses_deterministic_span_id_when_source_set(self) -> None:
from enterprise.telemetry.id_generator import (
CorrelationIdGenerator,
compute_deterministic_span_id,
set_span_id_source,
)
uid = "123e4567-e89b-12d3-a456-426614174000"
set_span_id_source(uid)
gen = CorrelationIdGenerator()
span_id = gen.generate_span_id()
assert span_id == compute_deterministic_span_id(uid)
def test_returns_random_when_no_source(self) -> None:
from enterprise.telemetry.id_generator import CorrelationIdGenerator, set_span_id_source
set_span_id_source(None)
gen = CorrelationIdGenerator()
span_id = gen.generate_span_id()
assert isinstance(span_id, int)
assert span_id != 0
def test_returns_random_when_source_is_invalid_uuid(self) -> None:
from enterprise.telemetry.id_generator import CorrelationIdGenerator, set_span_id_source
set_span_id_source("not-a-uuid")
gen = CorrelationIdGenerator()
with patch("enterprise.telemetry.id_generator.random.getrandbits", return_value=7) as mock_rng:
span_id = gen.generate_span_id()
assert span_id == 7
def test_random_span_id_retried_if_zero(self) -> None:
"""generate_span_id must never return 0 — it retries until non-zero."""
from enterprise.telemetry.id_generator import CorrelationIdGenerator, set_span_id_source
set_span_id_source(None)
gen = CorrelationIdGenerator()
# First call returns 0 (invalid), second returns 99
with patch("enterprise.telemetry.id_generator.random.getrandbits", side_effect=[0, 99]):
span_id = gen.generate_span_id()
assert span_id == 99
def test_generate_span_id_always_non_zero(self) -> None:
from enterprise.telemetry.id_generator import CorrelationIdGenerator, set_span_id_source
set_span_id_source(None)
gen = CorrelationIdGenerator()
for _ in range(20):
assert gen.generate_span_id() != 0

View File

@ -0,0 +1,511 @@
"""Unit tests for EnterpriseMetricHandler."""
import json
from unittest.mock import MagicMock, patch
import pytest
from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope
from enterprise.telemetry.metric_handler import EnterpriseMetricHandler
@pytest.fixture
def mock_redis():
with patch("enterprise.telemetry.metric_handler.redis_client") as mock:
yield mock
@pytest.fixture
def sample_envelope():
return TelemetryEnvelope(
case=TelemetryCase.APP_CREATED,
tenant_id="test-tenant",
event_id="test-event-123",
payload={"app_id": "app-123", "name": "Test App"},
)
def test_dispatch_app_created(sample_envelope, mock_redis):
mock_redis.set.return_value = True
handler = EnterpriseMetricHandler()
with patch.object(handler, "_on_app_created") as mock_handler:
handler.handle(sample_envelope)
mock_handler.assert_called_once_with(sample_envelope)
def test_dispatch_app_updated(mock_redis):
mock_redis.set.return_value = True
envelope = TelemetryEnvelope(
case=TelemetryCase.APP_UPDATED,
tenant_id="test-tenant",
event_id="test-event-456",
payload={},
)
handler = EnterpriseMetricHandler()
with patch.object(handler, "_on_app_updated") as mock_handler:
handler.handle(envelope)
mock_handler.assert_called_once_with(envelope)
def test_dispatch_app_deleted(mock_redis):
mock_redis.set.return_value = True
envelope = TelemetryEnvelope(
case=TelemetryCase.APP_DELETED,
tenant_id="test-tenant",
event_id="test-event-789",
payload={},
)
handler = EnterpriseMetricHandler()
with patch.object(handler, "_on_app_deleted") as mock_handler:
handler.handle(envelope)
mock_handler.assert_called_once_with(envelope)
def test_dispatch_feedback_created(mock_redis):
mock_redis.set.return_value = True
envelope = TelemetryEnvelope(
case=TelemetryCase.FEEDBACK_CREATED,
tenant_id="test-tenant",
event_id="test-event-abc",
payload={},
)
handler = EnterpriseMetricHandler()
with patch.object(handler, "_on_feedback_created") as mock_handler:
handler.handle(envelope)
mock_handler.assert_called_once_with(envelope)
def test_dispatch_message_run(mock_redis):
mock_redis.set.return_value = True
envelope = TelemetryEnvelope(
case=TelemetryCase.MESSAGE_RUN,
tenant_id="test-tenant",
event_id="test-event-msg",
payload={},
)
handler = EnterpriseMetricHandler()
with patch.object(handler, "_on_message_run") as mock_handler:
handler.handle(envelope)
mock_handler.assert_called_once_with(envelope)
def test_dispatch_tool_execution(mock_redis):
mock_redis.set.return_value = True
envelope = TelemetryEnvelope(
case=TelemetryCase.TOOL_EXECUTION,
tenant_id="test-tenant",
event_id="test-event-tool",
payload={},
)
handler = EnterpriseMetricHandler()
with patch.object(handler, "_on_tool_execution") as mock_handler:
handler.handle(envelope)
mock_handler.assert_called_once_with(envelope)
def test_dispatch_moderation_check(mock_redis):
mock_redis.set.return_value = True
envelope = TelemetryEnvelope(
case=TelemetryCase.MODERATION_CHECK,
tenant_id="test-tenant",
event_id="test-event-mod",
payload={},
)
handler = EnterpriseMetricHandler()
with patch.object(handler, "_on_moderation_check") as mock_handler:
handler.handle(envelope)
mock_handler.assert_called_once_with(envelope)
def test_dispatch_suggested_question(mock_redis):
mock_redis.set.return_value = True
envelope = TelemetryEnvelope(
case=TelemetryCase.SUGGESTED_QUESTION,
tenant_id="test-tenant",
event_id="test-event-sq",
payload={},
)
handler = EnterpriseMetricHandler()
with patch.object(handler, "_on_suggested_question") as mock_handler:
handler.handle(envelope)
mock_handler.assert_called_once_with(envelope)
def test_dispatch_dataset_retrieval(mock_redis):
mock_redis.set.return_value = True
envelope = TelemetryEnvelope(
case=TelemetryCase.DATASET_RETRIEVAL,
tenant_id="test-tenant",
event_id="test-event-ds",
payload={},
)
handler = EnterpriseMetricHandler()
with patch.object(handler, "_on_dataset_retrieval") as mock_handler:
handler.handle(envelope)
mock_handler.assert_called_once_with(envelope)
def test_dispatch_generate_name(mock_redis):
mock_redis.set.return_value = True
envelope = TelemetryEnvelope(
case=TelemetryCase.GENERATE_NAME,
tenant_id="test-tenant",
event_id="test-event-gn",
payload={},
)
handler = EnterpriseMetricHandler()
with patch.object(handler, "_on_generate_name") as mock_handler:
handler.handle(envelope)
mock_handler.assert_called_once_with(envelope)
def test_dispatch_prompt_generation(mock_redis):
mock_redis.set.return_value = True
envelope = TelemetryEnvelope(
case=TelemetryCase.PROMPT_GENERATION,
tenant_id="test-tenant",
event_id="test-event-pg",
payload={},
)
handler = EnterpriseMetricHandler()
with patch.object(handler, "_on_prompt_generation") as mock_handler:
handler.handle(envelope)
mock_handler.assert_called_once_with(envelope)
def test_all_known_cases_have_handlers(mock_redis):
mock_redis.set.return_value = True
handler = EnterpriseMetricHandler()
for case in TelemetryCase:
envelope = TelemetryEnvelope(
case=case,
tenant_id="test-tenant",
event_id=f"test-{case.value}",
payload={},
)
handler.handle(envelope)
def test_idempotency_duplicate(sample_envelope, mock_redis):
mock_redis.set.return_value = None
handler = EnterpriseMetricHandler()
with patch.object(handler, "_on_app_created") as mock_handler:
handler.handle(sample_envelope)
mock_handler.assert_not_called()
def test_idempotency_first_seen(sample_envelope, mock_redis):
mock_redis.set.return_value = True
handler = EnterpriseMetricHandler()
is_dup = handler._is_duplicate(sample_envelope)
assert is_dup is False
mock_redis.set.assert_called_once_with(
"telemetry:dedup:test-tenant:test-event-123",
b"1",
nx=True,
ex=3600,
)
def test_idempotency_redis_failure_fails_open(sample_envelope, mock_redis, caplog):
mock_redis.set.side_effect = Exception("Redis unavailable")
handler = EnterpriseMetricHandler()
is_dup = handler._is_duplicate(sample_envelope)
assert is_dup is False
assert "Redis unavailable for deduplication check" in caplog.text
def test_rehydration_uses_payload(sample_envelope):
handler = EnterpriseMetricHandler()
payload = handler._rehydrate(sample_envelope)
assert payload == {"app_id": "app-123", "name": "Test App"}
def test_rehydration_from_storage():
"""Verify _rehydrate loads payload from object storage via payload_ref."""
stored_data = {"app_id": "app-stored", "mode": "workflow"}
envelope = TelemetryEnvelope(
case=TelemetryCase.APP_CREATED,
tenant_id="test-tenant",
event_id="test-event-fb",
payload={},
metadata={"payload_ref": "telemetry/test-tenant/test-event-fb.json"},
)
handler = EnterpriseMetricHandler()
with patch("enterprise.telemetry.metric_handler.storage") as mock_storage:
mock_storage.load.return_value = json.dumps(stored_data).encode("utf-8")
payload = handler._rehydrate(envelope)
assert payload == stored_data
mock_storage.load.assert_called_once_with("telemetry/test-tenant/test-event-fb.json")
def test_rehydration_storage_failure_emits_degraded_event():
"""Verify _rehydrate emits degraded event when storage load fails."""
envelope = TelemetryEnvelope(
case=TelemetryCase.APP_CREATED,
tenant_id="test-tenant",
event_id="test-event-fail",
payload={},
metadata={"payload_ref": "telemetry/test-tenant/test-event-fail.json"},
)
handler = EnterpriseMetricHandler()
with (
patch("enterprise.telemetry.metric_handler.storage") as mock_storage,
patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit,
):
mock_storage.load.side_effect = Exception("Storage unavailable")
payload = handler._rehydrate(envelope)
from enterprise.telemetry.entities import EnterpriseTelemetryEvent
assert payload == {}
mock_emit.assert_called_once()
call_args = mock_emit.call_args
assert call_args[1]["event_name"] == EnterpriseTelemetryEvent.REHYDRATION_FAILED
assert "dify.telemetry.error" in call_args[1]["attributes"]
def test_rehydration_emits_degraded_event_on_empty_payload():
"""Verify _rehydrate emits degraded event when payload is empty and no ref exists."""
envelope = TelemetryEnvelope(
case=TelemetryCase.APP_CREATED,
tenant_id="test-tenant",
event_id="test-event-empty",
payload={},
)
handler = EnterpriseMetricHandler()
with patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit:
payload = handler._rehydrate(envelope)
from enterprise.telemetry.entities import EnterpriseTelemetryEvent
assert payload == {}
mock_emit.assert_called_once()
call_args = mock_emit.call_args
assert call_args[1]["event_name"] == EnterpriseTelemetryEvent.REHYDRATION_FAILED
assert "dify.telemetry.error" in call_args[1]["attributes"]
def test_on_app_created_emits_correct_event(mock_redis):
mock_redis.set.return_value = True
envelope = TelemetryEnvelope(
case=TelemetryCase.APP_CREATED,
tenant_id="tenant-123",
event_id="event-456",
payload={"app_id": "app-789", "mode": "chat"},
)
handler = EnterpriseMetricHandler()
with (
patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter,
patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit,
):
mock_exporter = MagicMock()
mock_get_exporter.return_value = mock_exporter
handler._on_app_created(envelope)
from enterprise.telemetry.entities import EnterpriseTelemetryEvent
mock_emit.assert_called_once()
call_args = mock_emit.call_args
assert call_args[1]["event_name"] == EnterpriseTelemetryEvent.APP_CREATED
assert call_args[1]["tenant_id"] == "tenant-123"
attrs = call_args[1]["attributes"]
assert attrs["dify.app_id"] == "app-789"
assert attrs["dify.tenant_id"] == "tenant-123"
assert attrs["dify.event.id"] == "event-456"
assert attrs["dify.app.mode"] == "chat"
assert "dify.app.created_at" in attrs
from enterprise.telemetry.entities import EnterpriseTelemetryCounter
mock_exporter.increment_counter.assert_called_once()
counter_call = mock_exporter.increment_counter.call_args
assert counter_call[0][0] == EnterpriseTelemetryCounter.APP_CREATED
assert counter_call[0][1] == 1
assert counter_call[0][2]["tenant_id"] == "tenant-123"
assert counter_call[0][2]["app_id"] == "app-789"
assert counter_call[0][2]["mode"] == "chat"
def test_on_app_updated_emits_correct_event(mock_redis):
mock_redis.set.return_value = True
envelope = TelemetryEnvelope(
case=TelemetryCase.APP_UPDATED,
tenant_id="tenant-123",
event_id="event-456",
payload={"app_id": "app-789"},
)
handler = EnterpriseMetricHandler()
with (
patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter,
patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit,
):
mock_exporter = MagicMock()
mock_get_exporter.return_value = mock_exporter
handler._on_app_updated(envelope)
from enterprise.telemetry.entities import EnterpriseTelemetryEvent
mock_emit.assert_called_once()
call_args = mock_emit.call_args
assert call_args[1]["event_name"] == EnterpriseTelemetryEvent.APP_UPDATED
assert call_args[1]["tenant_id"] == "tenant-123"
attrs = call_args[1]["attributes"]
assert attrs["dify.app_id"] == "app-789"
assert attrs["dify.tenant_id"] == "tenant-123"
assert attrs["dify.event.id"] == "event-456"
assert "dify.app.updated_at" in attrs
from enterprise.telemetry.entities import EnterpriseTelemetryCounter
mock_exporter.increment_counter.assert_called_once()
counter_call = mock_exporter.increment_counter.call_args
assert counter_call[0][0] == EnterpriseTelemetryCounter.APP_UPDATED
assert counter_call[0][1] == 1
assert counter_call[0][2]["tenant_id"] == "tenant-123"
assert counter_call[0][2]["app_id"] == "app-789"
def test_on_app_deleted_emits_correct_event(mock_redis):
mock_redis.set.return_value = True
envelope = TelemetryEnvelope(
case=TelemetryCase.APP_DELETED,
tenant_id="tenant-123",
event_id="event-456",
payload={"app_id": "app-789"},
)
handler = EnterpriseMetricHandler()
with (
patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter,
patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit,
):
mock_exporter = MagicMock()
mock_get_exporter.return_value = mock_exporter
handler._on_app_deleted(envelope)
from enterprise.telemetry.entities import EnterpriseTelemetryEvent
mock_emit.assert_called_once()
call_args = mock_emit.call_args
assert call_args[1]["event_name"] == EnterpriseTelemetryEvent.APP_DELETED
assert call_args[1]["tenant_id"] == "tenant-123"
attrs = call_args[1]["attributes"]
assert attrs["dify.app_id"] == "app-789"
assert attrs["dify.tenant_id"] == "tenant-123"
assert attrs["dify.event.id"] == "event-456"
assert "dify.app.deleted_at" in attrs
from enterprise.telemetry.entities import EnterpriseTelemetryCounter
mock_exporter.increment_counter.assert_called_once()
counter_call = mock_exporter.increment_counter.call_args
assert counter_call[0][0] == EnterpriseTelemetryCounter.APP_DELETED
assert counter_call[0][1] == 1
assert counter_call[0][2]["tenant_id"] == "tenant-123"
assert counter_call[0][2]["app_id"] == "app-789"
def test_on_feedback_created_emits_correct_event(mock_redis):
mock_redis.set.return_value = True
envelope = TelemetryEnvelope(
case=TelemetryCase.FEEDBACK_CREATED,
tenant_id="tenant-123",
event_id="event-456",
payload={
"message_id": "msg-001",
"app_id": "app-789",
"conversation_id": "conv-123",
"from_end_user_id": "user-456",
"from_account_id": None,
"rating": "like",
"from_source": "api",
"content": "Great!",
},
)
handler = EnterpriseMetricHandler()
with (
patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter,
patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit,
):
mock_exporter = MagicMock()
mock_exporter.include_content = True
mock_get_exporter.return_value = mock_exporter
handler._on_feedback_created(envelope)
mock_emit.assert_called_once()
call_args = mock_emit.call_args
assert call_args[1]["event_name"] == "dify.feedback.created"
assert call_args[1]["attributes"]["dify.message.id"] == "msg-001"
assert call_args[1]["attributes"]["dify.feedback.content"] == "Great!"
assert "dify.feedback.created_at" in call_args[1]["attributes"]
assert call_args[1]["tenant_id"] == "tenant-123"
assert call_args[1]["user_id"] == "user-456"
mock_exporter.increment_counter.assert_called_once()
counter_args = mock_exporter.increment_counter.call_args
assert counter_args[0][2]["app_id"] == "app-789"
assert counter_args[0][2]["rating"] == "like"
def test_on_feedback_created_without_content(mock_redis):
mock_redis.set.return_value = True
envelope = TelemetryEnvelope(
case=TelemetryCase.FEEDBACK_CREATED,
tenant_id="tenant-123",
event_id="event-456",
payload={
"message_id": "msg-001",
"app_id": "app-789",
"conversation_id": "conv-123",
"from_end_user_id": "user-456",
"from_account_id": None,
"rating": "like",
"from_source": "api",
"content": "Great!",
},
)
handler = EnterpriseMetricHandler()
with (
patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter,
patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit,
):
mock_exporter = MagicMock()
mock_exporter.include_content = False
mock_get_exporter.return_value = mock_exporter
handler._on_feedback_created(envelope)
mock_emit.assert_called_once()
call_args = mock_emit.call_args
assert "dify.feedback.content" not in call_args[1]["attributes"]

View File

@ -0,0 +1,327 @@
"""Unit tests for enterprise/telemetry/telemetry_log.py."""
from __future__ import annotations
import uuid
from unittest.mock import MagicMock, patch
# ---------------------------------------------------------------------------
# compute_trace_id_hex
# ---------------------------------------------------------------------------
class TestComputeTraceIdHex:
def setup_method(self) -> None:
# Clear lru_cache between tests to avoid cross-test pollution
from enterprise.telemetry.telemetry_log import compute_trace_id_hex
compute_trace_id_hex.cache_clear()
def test_none_returns_empty(self) -> None:
from enterprise.telemetry.telemetry_log import compute_trace_id_hex
assert compute_trace_id_hex(None) == ""
def test_empty_string_returns_empty(self) -> None:
from enterprise.telemetry.telemetry_log import compute_trace_id_hex
assert compute_trace_id_hex("") == ""
def test_already_32_hex_chars_returned_as_is(self) -> None:
from enterprise.telemetry.telemetry_log import compute_trace_id_hex
hex_id = "a" * 32
assert compute_trace_id_hex(hex_id) == hex_id
def test_valid_uuid_string_converted_to_32_hex(self) -> None:
from enterprise.telemetry.telemetry_log import compute_trace_id_hex
uid = "123e4567-e89b-12d3-a456-426614174000"
result = compute_trace_id_hex(uid)
assert len(result) == 32
assert all(ch in "0123456789abcdef" for ch in result)
# Round-trip: int of the UUID should equal the int parsed from result
assert int(result, 16) == uuid.UUID(uid).int
def test_invalid_string_returns_empty(self) -> None:
from enterprise.telemetry.telemetry_log import compute_trace_id_hex
assert compute_trace_id_hex("not-a-uuid") == ""
def test_whitespace_stripped(self) -> None:
from enterprise.telemetry.telemetry_log import compute_trace_id_hex
uid = " 123e4567-e89b-12d3-a456-426614174000 "
result = compute_trace_id_hex(uid)
assert len(result) == 32
def test_uppercase_uuid_accepted(self) -> None:
from enterprise.telemetry.telemetry_log import compute_trace_id_hex
uid = "123E4567-E89B-12D3-A456-426614174000"
result = compute_trace_id_hex(uid)
assert len(result) == 32
def test_result_is_cached(self) -> None:
from enterprise.telemetry.telemetry_log import compute_trace_id_hex
uid = "123e4567-e89b-12d3-a456-426614174000"
r1 = compute_trace_id_hex(uid)
r2 = compute_trace_id_hex(uid)
assert r1 == r2
info = compute_trace_id_hex.cache_info()
assert info.hits >= 1
# ---------------------------------------------------------------------------
# compute_span_id_hex
# ---------------------------------------------------------------------------
class TestComputeSpanIdHex:
def setup_method(self) -> None:
from enterprise.telemetry.telemetry_log import compute_span_id_hex
compute_span_id_hex.cache_clear()
def test_none_returns_empty(self) -> None:
from enterprise.telemetry.telemetry_log import compute_span_id_hex
assert compute_span_id_hex(None) == ""
def test_empty_string_returns_empty(self) -> None:
from enterprise.telemetry.telemetry_log import compute_span_id_hex
assert compute_span_id_hex("") == ""
def test_already_16_hex_chars_returned_as_is(self) -> None:
from enterprise.telemetry.telemetry_log import compute_span_id_hex
hex_id = "abcdef0123456789"
assert compute_span_id_hex(hex_id) == hex_id
def test_valid_uuid_produces_16_hex_span_id(self) -> None:
from enterprise.telemetry.telemetry_log import compute_span_id_hex
uid = "123e4567-e89b-12d3-a456-426614174000"
result = compute_span_id_hex(uid)
assert len(result) == 16
assert all(ch in "0123456789abcdef" for ch in result)
def test_invalid_string_returns_empty(self) -> None:
from enterprise.telemetry.telemetry_log import compute_span_id_hex
assert compute_span_id_hex("not-a-uuid-at-all!") == ""
def test_result_is_cached(self) -> None:
from enterprise.telemetry.telemetry_log import compute_span_id_hex
uid = "123e4567-e89b-12d3-a456-426614174000"
compute_span_id_hex(uid)
compute_span_id_hex(uid)
info = compute_span_id_hex.cache_info()
assert info.hits >= 1
# ---------------------------------------------------------------------------
# emit_telemetry_log
# ---------------------------------------------------------------------------
class TestEmitTelemetryLog:
def setup_method(self) -> None:
from enterprise.telemetry.telemetry_log import compute_span_id_hex, compute_trace_id_hex
compute_trace_id_hex.cache_clear()
compute_span_id_hex.cache_clear()
@patch("enterprise.telemetry.telemetry_log.logger")
def test_logs_info_with_event_name_and_signal(self, mock_logger: MagicMock) -> None:
from enterprise.telemetry.telemetry_log import emit_telemetry_log
mock_logger.isEnabledFor.return_value = True
emit_telemetry_log(
event_name="dify.workflow.run",
attributes={"tenant_id": "t1"},
signal="metric_only",
)
mock_logger.info.assert_called_once()
args, kwargs = mock_logger.info.call_args
assert args[0] == "telemetry.%s"
assert args[1] == "metric_only"
extra = kwargs["extra"]
assert extra["attributes"]["dify.event.name"] == "dify.workflow.run"
assert extra["attributes"]["dify.event.signal"] == "metric_only"
@patch("enterprise.telemetry.telemetry_log.logger")
def test_no_log_when_info_disabled(self, mock_logger: MagicMock) -> None:
from enterprise.telemetry.telemetry_log import emit_telemetry_log
mock_logger.isEnabledFor.return_value = False
emit_telemetry_log(event_name="dify.workflow.run", attributes={})
mock_logger.info.assert_not_called()
@patch("enterprise.telemetry.telemetry_log.logger")
def test_trace_id_added_to_extra_when_valid_uuid(self, mock_logger: MagicMock) -> None:
from enterprise.telemetry.telemetry_log import emit_telemetry_log
mock_logger.isEnabledFor.return_value = True
uid = "123e4567-e89b-12d3-a456-426614174000"
emit_telemetry_log(event_name="test.event", attributes={}, trace_id_source=uid)
extra = mock_logger.info.call_args.kwargs["extra"]
assert "trace_id" in extra
assert len(extra["trace_id"]) == 32
@patch("enterprise.telemetry.telemetry_log.logger")
def test_trace_id_absent_when_invalid_source(self, mock_logger: MagicMock) -> None:
from enterprise.telemetry.telemetry_log import emit_telemetry_log
mock_logger.isEnabledFor.return_value = True
emit_telemetry_log(event_name="test.event", attributes={}, trace_id_source="bad-id")
extra = mock_logger.info.call_args.kwargs["extra"]
assert "trace_id" not in extra
@patch("enterprise.telemetry.telemetry_log.logger")
def test_span_id_added_to_extra_when_valid_uuid(self, mock_logger: MagicMock) -> None:
from enterprise.telemetry.telemetry_log import emit_telemetry_log
mock_logger.isEnabledFor.return_value = True
uid = "123e4567-e89b-12d3-a456-426614174000"
emit_telemetry_log(event_name="test.event", attributes={}, span_id_source=uid)
extra = mock_logger.info.call_args.kwargs["extra"]
assert "span_id" in extra
assert len(extra["span_id"]) == 16
@patch("enterprise.telemetry.telemetry_log.logger")
def test_tenant_id_added_when_provided(self, mock_logger: MagicMock) -> None:
from enterprise.telemetry.telemetry_log import emit_telemetry_log
mock_logger.isEnabledFor.return_value = True
emit_telemetry_log(event_name="test.event", attributes={}, tenant_id="tenant-99")
extra = mock_logger.info.call_args.kwargs["extra"]
assert extra["tenant_id"] == "tenant-99"
@patch("enterprise.telemetry.telemetry_log.logger")
def test_user_id_added_when_provided(self, mock_logger: MagicMock) -> None:
from enterprise.telemetry.telemetry_log import emit_telemetry_log
mock_logger.isEnabledFor.return_value = True
emit_telemetry_log(event_name="test.event", attributes={}, user_id="user-42")
extra = mock_logger.info.call_args.kwargs["extra"]
assert extra["user_id"] == "user-42"
@patch("enterprise.telemetry.telemetry_log.logger")
def test_tenant_and_user_id_absent_when_not_provided(self, mock_logger: MagicMock) -> None:
from enterprise.telemetry.telemetry_log import emit_telemetry_log
mock_logger.isEnabledFor.return_value = True
emit_telemetry_log(event_name="test.event", attributes={})
extra = mock_logger.info.call_args.kwargs["extra"]
assert "tenant_id" not in extra
assert "user_id" not in extra
@patch("enterprise.telemetry.telemetry_log.logger")
def test_caller_attributes_merged_into_attrs(self, mock_logger: MagicMock) -> None:
from enterprise.telemetry.telemetry_log import emit_telemetry_log
mock_logger.isEnabledFor.return_value = True
emit_telemetry_log(
event_name="dify.node.run",
attributes={"node_type": "code", "elapsed": 0.5},
)
extra = mock_logger.info.call_args.kwargs["extra"]
assert extra["attributes"]["node_type"] == "code"
assert extra["attributes"]["elapsed"] == 0.5
@patch("enterprise.telemetry.telemetry_log.logger")
def test_signal_span_detail_forwarded(self, mock_logger: MagicMock) -> None:
from enterprise.telemetry.telemetry_log import emit_telemetry_log
mock_logger.isEnabledFor.return_value = True
emit_telemetry_log(event_name="test.event", attributes={}, signal="span_detail")
args = mock_logger.info.call_args[0]
assert args[1] == "span_detail"
extra = mock_logger.info.call_args.kwargs["extra"]
assert extra["attributes"]["dify.event.signal"] == "span_detail"
# ---------------------------------------------------------------------------
# emit_metric_only_event
# ---------------------------------------------------------------------------
class TestEmitMetricOnlyEvent:
def setup_method(self) -> None:
from enterprise.telemetry.telemetry_log import compute_span_id_hex, compute_trace_id_hex
compute_trace_id_hex.cache_clear()
compute_span_id_hex.cache_clear()
@patch("enterprise.telemetry.telemetry_log.logger")
def test_delegates_to_emit_telemetry_log_with_metric_only_signal(self, mock_logger: MagicMock) -> None:
from enterprise.telemetry.telemetry_log import emit_metric_only_event
mock_logger.isEnabledFor.return_value = True
emit_metric_only_event(
event_name="dify.app.created",
attributes={"app_id": "app-1"},
tenant_id="t1",
user_id="u1",
)
mock_logger.info.assert_called_once()
extra = mock_logger.info.call_args.kwargs["extra"]
assert extra["attributes"]["dify.event.signal"] == "metric_only"
assert extra["attributes"]["dify.event.name"] == "dify.app.created"
assert extra["attributes"]["app_id"] == "app-1"
assert extra["tenant_id"] == "t1"
assert extra["user_id"] == "u1"
@patch("enterprise.telemetry.telemetry_log.logger")
def test_trace_and_span_ids_passed_through(self, mock_logger: MagicMock) -> None:
from enterprise.telemetry.telemetry_log import emit_metric_only_event
mock_logger.isEnabledFor.return_value = True
uid = "123e4567-e89b-12d3-a456-426614174000"
emit_metric_only_event(
event_name="dify.workflow.run",
attributes={},
trace_id_source=uid,
span_id_source=uid,
)
extra = mock_logger.info.call_args.kwargs["extra"]
assert "trace_id" in extra
assert "span_id" in extra
@patch("enterprise.telemetry.telemetry_log.logger")
def test_no_log_emitted_when_logger_disabled(self, mock_logger: MagicMock) -> None:
from enterprise.telemetry.telemetry_log import emit_metric_only_event
mock_logger.isEnabledFor.return_value = False
emit_metric_only_event(event_name="dify.workflow.run", attributes={})
mock_logger.info.assert_not_called()

View File

@ -0,0 +1,206 @@
from unittest.mock import MagicMock, patch
import pytest
@pytest.fixture
def mock_db():
with patch("services.app_service.db") as mock_db:
mock_db.session = MagicMock()
yield mock_db
@pytest.fixture
def _mock_deps():
with (
patch("services.app_service.BillingService"),
patch("services.app_service.FeatureService"),
patch("services.app_service.EnterpriseService"),
patch("services.app_service.remove_app_and_related_data_task"),
):
yield
@pytest.fixture
def app_model():
app = MagicMock()
app.id = "app-123"
app.tenant_id = "tenant-456"
app.name = "Old Name"
app.icon_type = "emoji"
app.icon = "🤖"
app.icon_background = "#fff"
app.enable_site = False
app.enable_api = False
return app
def _make_collector(target: list):
def handler(sender, **kw):
target.append(sender)
return handler
@pytest.mark.usefixtures("mock_db", "_mock_deps")
class TestAppWasDeletedSignal:
def test_sends_signal(self, app_model):
from events.app_event import app_was_deleted
from services.app_service import AppService
received = []
handler = _make_collector(received)
app_was_deleted.connect(handler)
try:
AppService().delete_app(app_model)
finally:
app_was_deleted.disconnect(handler)
assert received == [app_model]
def test_signal_fires_before_db_delete(self, app_model, mock_db):
from events.app_event import app_was_deleted
from services.app_service import AppService
call_order: list[str] = []
def handler(sender, **kw):
call_order.append("signal")
app_was_deleted.connect(handler)
mock_db.session.delete.side_effect = lambda _: call_order.append("db_delete")
try:
AppService().delete_app(app_model)
finally:
app_was_deleted.disconnect(handler)
assert call_order.index("signal") < call_order.index("db_delete")
@pytest.mark.usefixtures("mock_db")
class TestAppWasUpdatedSignal:
def test_update_app(self, app_model):
from events.app_event import app_was_updated
from services.app_service import AppService
received = []
handler = _make_collector(received)
app_was_updated.connect(handler)
with patch("services.app_service.current_user", MagicMock(id="user-1")):
try:
AppService().update_app(
app_model,
{
"name": "New",
"description": "Desc",
"icon_type": "emoji",
"icon": "🤖",
"icon_background": "#fff",
"use_icon_as_answer_icon": False,
"max_active_requests": 0,
},
)
finally:
app_was_updated.disconnect(handler)
assert received == [app_model]
def test_update_app_name(self, app_model):
from events.app_event import app_was_updated
from services.app_service import AppService
received = []
handler = _make_collector(received)
app_was_updated.connect(handler)
with patch("services.app_service.current_user", MagicMock(id="user-1")):
try:
AppService().update_app_name(app_model, "New Name")
finally:
app_was_updated.disconnect(handler)
assert received == [app_model]
def test_update_app_icon(self, app_model):
from events.app_event import app_was_updated
from services.app_service import AppService
received = []
handler = _make_collector(received)
app_was_updated.connect(handler)
with patch("services.app_service.current_user", MagicMock(id="user-1")):
try:
AppService().update_app_icon(app_model, "🎉", "#000")
finally:
app_was_updated.disconnect(handler)
assert received == [app_model]
def test_update_app_site_status_sends_when_changed(self, app_model):
from events.app_event import app_was_updated
from services.app_service import AppService
received = []
handler = _make_collector(received)
app_was_updated.connect(handler)
with patch("services.app_service.current_user", MagicMock(id="user-1")):
try:
app_model.enable_site = False
AppService().update_app_site_status(app_model, True)
finally:
app_was_updated.disconnect(handler)
assert received == [app_model]
def test_update_app_site_status_skips_when_unchanged(self, app_model):
from events.app_event import app_was_updated
from services.app_service import AppService
received = []
handler = _make_collector(received)
app_was_updated.connect(handler)
try:
app_model.enable_site = True
AppService().update_app_site_status(app_model, True)
finally:
app_was_updated.disconnect(handler)
assert received == []
def test_update_app_api_status_sends_when_changed(self, app_model):
from events.app_event import app_was_updated
from services.app_service import AppService
received = []
handler = _make_collector(received)
app_was_updated.connect(handler)
with patch("services.app_service.current_user", MagicMock(id="user-1")):
try:
app_model.enable_api = False
AppService().update_app_api_status(app_model, True)
finally:
app_was_updated.disconnect(handler)
assert received == [app_model]
def test_update_app_api_status_skips_when_unchanged(self, app_model):
from events.app_event import app_was_updated
from services.app_service import AppService
received = []
handler = _make_collector(received)
app_was_updated.connect(handler)
try:
app_model.enable_api = True
AppService().update_app_api_status(app_model, True)
finally:
app_was_updated.disconnect(handler)
assert received == []

View File

@ -0,0 +1,69 @@
"""Unit tests for enterprise telemetry Celery task."""
import json
from unittest.mock import MagicMock, patch
import pytest
from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope
from tasks.enterprise_telemetry_task import process_enterprise_telemetry
@pytest.fixture
def sample_envelope_json():
envelope = TelemetryEnvelope(
case=TelemetryCase.APP_CREATED,
tenant_id="test-tenant",
event_id="test-event-123",
payload={"app_id": "app-123"},
)
return envelope.model_dump_json()
def test_process_enterprise_telemetry_success(sample_envelope_json):
with patch("tasks.enterprise_telemetry_task.EnterpriseMetricHandler") as mock_handler_class:
mock_handler = MagicMock()
mock_handler_class.return_value = mock_handler
process_enterprise_telemetry(sample_envelope_json)
mock_handler.handle.assert_called_once()
call_args = mock_handler.handle.call_args[0][0]
assert isinstance(call_args, TelemetryEnvelope)
assert call_args.case == TelemetryCase.APP_CREATED
assert call_args.tenant_id == "test-tenant"
assert call_args.event_id == "test-event-123"
def test_process_enterprise_telemetry_invalid_json(caplog):
invalid_json = "not valid json"
process_enterprise_telemetry(invalid_json)
assert "Failed to process enterprise telemetry envelope" in caplog.text
def test_process_enterprise_telemetry_handler_exception(sample_envelope_json, caplog):
with patch("tasks.enterprise_telemetry_task.EnterpriseMetricHandler") as mock_handler_class:
mock_handler = MagicMock()
mock_handler.handle.side_effect = Exception("Handler error")
mock_handler_class.return_value = mock_handler
process_enterprise_telemetry(sample_envelope_json)
assert "Failed to process enterprise telemetry envelope" in caplog.text
def test_process_enterprise_telemetry_validation_error(caplog):
invalid_envelope = json.dumps(
{
"case": "INVALID_CASE",
"tenant_id": "test-tenant",
"event_id": "test-event",
"payload": {},
}
)
process_enterprise_telemetry(invalid_envelope)
assert "Failed to process enterprise telemetry envelope" in caplog.text