mirror of
https://github.com/langgenius/dify.git
synced 2026-04-29 15:08:06 +08:00
feat: enterprise otel exporter (#33138)
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com> Co-authored-by: Yunlu Wen <yunlu.wen@dify.ai> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
554
api/tests/unit_tests/core/ops/test_lookup_helpers.py
Normal file
554
api/tests/unit_tests/core/ops/test_lookup_helpers.py
Normal file
@ -0,0 +1,554 @@
|
||||
"""Unit tests for lookup helper functions in core.ops.ops_trace_manager.
|
||||
|
||||
Covers:
|
||||
- _lookup_app_and_workspace_names
|
||||
- _lookup_credential_name
|
||||
- _lookup_llm_credential_info
|
||||
- TraceTask._get_user_id_from_metadata
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_db_and_session_patches(scalar_side_effect=None, scalar_return_value=None):
|
||||
"""Return (mock_db, cm, session) ready to patch 'core.ops.ops_trace_manager.db'
|
||||
and 'core.ops.ops_trace_manager.Session'.
|
||||
|
||||
Provide either scalar_side_effect (list, for multiple calls) or
|
||||
scalar_return_value (single value).
|
||||
"""
|
||||
mock_db = MagicMock()
|
||||
mock_db.engine = MagicMock()
|
||||
|
||||
session = MagicMock()
|
||||
if scalar_side_effect is not None:
|
||||
session.scalar.side_effect = scalar_side_effect
|
||||
else:
|
||||
session.scalar.return_value = scalar_return_value
|
||||
|
||||
cm = MagicMock()
|
||||
cm.__enter__ = MagicMock(return_value=session)
|
||||
cm.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
return mock_db, cm, session
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _lookup_app_and_workspace_names
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLookupAppAndWorkspaceNames:
|
||||
"""Tests for _lookup_app_and_workspace_names(app_id, tenant_id)."""
|
||||
|
||||
def test_both_found(self):
|
||||
"""Returns (app_name, workspace_name) when both records exist."""
|
||||
from core.ops.ops_trace_manager import _lookup_app_and_workspace_names
|
||||
|
||||
mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=["MyApp", "MyWorkspace"])
|
||||
|
||||
with (
|
||||
patch("core.ops.ops_trace_manager.db", mock_db),
|
||||
patch("core.ops.ops_trace_manager.Session", return_value=cm),
|
||||
):
|
||||
app_name, workspace_name = _lookup_app_and_workspace_names("app-123", "tenant-456")
|
||||
|
||||
assert app_name == "MyApp"
|
||||
assert workspace_name == "MyWorkspace"
|
||||
|
||||
def test_app_only_found(self):
|
||||
"""Returns (app_name, '') when tenant record is absent."""
|
||||
from core.ops.ops_trace_manager import _lookup_app_and_workspace_names
|
||||
|
||||
mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=["MyApp", None])
|
||||
|
||||
with (
|
||||
patch("core.ops.ops_trace_manager.db", mock_db),
|
||||
patch("core.ops.ops_trace_manager.Session", return_value=cm),
|
||||
):
|
||||
app_name, workspace_name = _lookup_app_and_workspace_names("app-123", "tenant-456")
|
||||
|
||||
assert app_name == "MyApp"
|
||||
assert workspace_name == ""
|
||||
|
||||
def test_tenant_only_found(self):
|
||||
"""Returns ('', workspace_name) when app record is absent."""
|
||||
from core.ops.ops_trace_manager import _lookup_app_and_workspace_names
|
||||
|
||||
mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=[None, "MyWorkspace"])
|
||||
|
||||
with (
|
||||
patch("core.ops.ops_trace_manager.db", mock_db),
|
||||
patch("core.ops.ops_trace_manager.Session", return_value=cm),
|
||||
):
|
||||
app_name, workspace_name = _lookup_app_and_workspace_names("app-123", "tenant-456")
|
||||
|
||||
assert app_name == ""
|
||||
assert workspace_name == "MyWorkspace"
|
||||
|
||||
def test_neither_found(self):
|
||||
"""Returns ('', '') when both DB lookups return None."""
|
||||
from core.ops.ops_trace_manager import _lookup_app_and_workspace_names
|
||||
|
||||
mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=[None, None])
|
||||
|
||||
with (
|
||||
patch("core.ops.ops_trace_manager.db", mock_db),
|
||||
patch("core.ops.ops_trace_manager.Session", return_value=cm),
|
||||
):
|
||||
app_name, workspace_name = _lookup_app_and_workspace_names("app-123", "tenant-456")
|
||||
|
||||
assert app_name == ""
|
||||
assert workspace_name == ""
|
||||
|
||||
def test_none_inputs_skips_db(self):
|
||||
"""Returns ('', '') immediately when both IDs are None — no DB access."""
|
||||
from core.ops.ops_trace_manager import _lookup_app_and_workspace_names
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_session_cls = MagicMock()
|
||||
|
||||
with (
|
||||
patch("core.ops.ops_trace_manager.db", mock_db),
|
||||
patch("core.ops.ops_trace_manager.Session", mock_session_cls),
|
||||
):
|
||||
app_name, workspace_name = _lookup_app_and_workspace_names(None, None)
|
||||
|
||||
mock_session_cls.assert_not_called()
|
||||
assert app_name == ""
|
||||
assert workspace_name == ""
|
||||
|
||||
def test_app_id_none_only_queries_tenant(self):
|
||||
"""When app_id is None, only the tenant query is issued."""
|
||||
from core.ops.ops_trace_manager import _lookup_app_and_workspace_names
|
||||
|
||||
mock_db, cm, session = _make_db_and_session_patches(scalar_return_value="OnlyWorkspace")
|
||||
|
||||
with (
|
||||
patch("core.ops.ops_trace_manager.db", mock_db),
|
||||
patch("core.ops.ops_trace_manager.Session", return_value=cm),
|
||||
):
|
||||
app_name, workspace_name = _lookup_app_and_workspace_names(None, "tenant-456")
|
||||
|
||||
assert app_name == ""
|
||||
assert workspace_name == "OnlyWorkspace"
|
||||
assert session.scalar.call_count == 1
|
||||
|
||||
def test_tenant_id_none_only_queries_app(self):
|
||||
"""When tenant_id is None, only the app query is issued."""
|
||||
from core.ops.ops_trace_manager import _lookup_app_and_workspace_names
|
||||
|
||||
mock_db, cm, session = _make_db_and_session_patches(scalar_return_value="OnlyApp")
|
||||
|
||||
with (
|
||||
patch("core.ops.ops_trace_manager.db", mock_db),
|
||||
patch("core.ops.ops_trace_manager.Session", return_value=cm),
|
||||
):
|
||||
app_name, workspace_name = _lookup_app_and_workspace_names("app-123", None)
|
||||
|
||||
assert app_name == "OnlyApp"
|
||||
assert workspace_name == ""
|
||||
assert session.scalar.call_count == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _lookup_credential_name
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLookupCredentialName:
|
||||
"""Tests for _lookup_credential_name(credential_id, provider_type)."""
|
||||
|
||||
@pytest.mark.parametrize("provider_type", ["builtin", "plugin", "api", "workflow", "mcp"])
|
||||
def test_known_provider_types_return_name(self, provider_type):
|
||||
"""Each valid provider_type results in a DB query and returns the credential name."""
|
||||
from core.ops.ops_trace_manager import _lookup_credential_name
|
||||
|
||||
mock_db, cm, session = _make_db_and_session_patches(scalar_return_value="CredentialA")
|
||||
|
||||
with (
|
||||
patch("core.ops.ops_trace_manager.db", mock_db),
|
||||
patch("core.ops.ops_trace_manager.Session", return_value=cm),
|
||||
):
|
||||
result = _lookup_credential_name("cred-123", provider_type)
|
||||
|
||||
assert result == "CredentialA"
|
||||
session.scalar.assert_called_once()
|
||||
|
||||
def test_credential_not_found_returns_empty_string(self):
|
||||
"""Returns '' when DB yields None for the given credential_id."""
|
||||
from core.ops.ops_trace_manager import _lookup_credential_name
|
||||
|
||||
mock_db, cm, _session = _make_db_and_session_patches(scalar_return_value=None)
|
||||
|
||||
with (
|
||||
patch("core.ops.ops_trace_manager.db", mock_db),
|
||||
patch("core.ops.ops_trace_manager.Session", return_value=cm),
|
||||
):
|
||||
result = _lookup_credential_name("cred-999", "api")
|
||||
|
||||
assert result == ""
|
||||
|
||||
def test_invalid_provider_type_returns_empty_string_without_db(self):
|
||||
"""Returns '' immediately for an unrecognised provider_type — no DB access."""
|
||||
from core.ops.ops_trace_manager import _lookup_credential_name
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_session_cls = MagicMock()
|
||||
|
||||
with (
|
||||
patch("core.ops.ops_trace_manager.db", mock_db),
|
||||
patch("core.ops.ops_trace_manager.Session", mock_session_cls),
|
||||
):
|
||||
result = _lookup_credential_name("cred-123", "unknown_type")
|
||||
|
||||
mock_session_cls.assert_not_called()
|
||||
assert result == ""
|
||||
|
||||
def test_none_credential_id_returns_empty_string_without_db(self):
|
||||
"""Returns '' immediately when credential_id is None — no DB access."""
|
||||
from core.ops.ops_trace_manager import _lookup_credential_name
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_session_cls = MagicMock()
|
||||
|
||||
with (
|
||||
patch("core.ops.ops_trace_manager.db", mock_db),
|
||||
patch("core.ops.ops_trace_manager.Session", mock_session_cls),
|
||||
):
|
||||
result = _lookup_credential_name(None, "api")
|
||||
|
||||
mock_session_cls.assert_not_called()
|
||||
assert result == ""
|
||||
|
||||
def test_none_provider_type_returns_empty_string_without_db(self):
|
||||
"""Returns '' immediately when provider_type is None — no DB access."""
|
||||
from core.ops.ops_trace_manager import _lookup_credential_name
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_session_cls = MagicMock()
|
||||
|
||||
with (
|
||||
patch("core.ops.ops_trace_manager.db", mock_db),
|
||||
patch("core.ops.ops_trace_manager.Session", mock_session_cls),
|
||||
):
|
||||
result = _lookup_credential_name("cred-123", None)
|
||||
|
||||
mock_session_cls.assert_not_called()
|
||||
assert result == ""
|
||||
|
||||
def test_builtin_and_plugin_map_to_same_model(self):
|
||||
"""Both 'builtin' and 'plugin' provider_types query BuiltinToolProvider."""
|
||||
from core.ops.ops_trace_manager import _PROVIDER_TYPE_TO_MODEL
|
||||
from models.tools import BuiltinToolProvider
|
||||
|
||||
assert _PROVIDER_TYPE_TO_MODEL["builtin"] is BuiltinToolProvider
|
||||
assert _PROVIDER_TYPE_TO_MODEL["plugin"] is BuiltinToolProvider
|
||||
|
||||
def test_api_maps_to_api_tool_provider(self):
|
||||
"""'api' maps to ApiToolProvider."""
|
||||
from core.ops.ops_trace_manager import _PROVIDER_TYPE_TO_MODEL
|
||||
from models.tools import ApiToolProvider
|
||||
|
||||
assert _PROVIDER_TYPE_TO_MODEL["api"] is ApiToolProvider
|
||||
|
||||
def test_workflow_maps_to_workflow_tool_provider(self):
|
||||
"""'workflow' maps to WorkflowToolProvider."""
|
||||
from core.ops.ops_trace_manager import _PROVIDER_TYPE_TO_MODEL
|
||||
from models.tools import WorkflowToolProvider
|
||||
|
||||
assert _PROVIDER_TYPE_TO_MODEL["workflow"] is WorkflowToolProvider
|
||||
|
||||
def test_mcp_maps_to_mcp_tool_provider(self):
|
||||
"""'mcp' maps to MCPToolProvider."""
|
||||
from core.ops.ops_trace_manager import _PROVIDER_TYPE_TO_MODEL
|
||||
from models.tools import MCPToolProvider
|
||||
|
||||
assert _PROVIDER_TYPE_TO_MODEL["mcp"] is MCPToolProvider
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _lookup_llm_credential_info
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLookupLlmCredentialInfo:
|
||||
"""Tests for _lookup_llm_credential_info(tenant_id, provider, model, model_type)."""
|
||||
|
||||
def _provider_record(self, credential_id: str | None = None) -> MagicMock:
|
||||
record = MagicMock()
|
||||
record.credential_id = credential_id
|
||||
return record
|
||||
|
||||
def _model_record(self, credential_id: str | None = None) -> MagicMock:
|
||||
record = MagicMock()
|
||||
record.credential_id = credential_id
|
||||
return record
|
||||
|
||||
def test_model_level_credential_found(self):
|
||||
"""Returns model-level credential_id and name when ProviderModel has a credential."""
|
||||
from core.ops.ops_trace_manager import _lookup_llm_credential_info
|
||||
|
||||
provider_record = self._provider_record(credential_id=None)
|
||||
model_record = self._model_record(credential_id="model-cred-id")
|
||||
|
||||
# scalar calls: (1) Provider, (2) ProviderModel, (3) ProviderModelCredential.credential_name
|
||||
mock_db, cm, _session = _make_db_and_session_patches(
|
||||
scalar_side_effect=[provider_record, model_record, "ModelCredName"]
|
||||
)
|
||||
|
||||
with (
|
||||
patch("core.ops.ops_trace_manager.db", mock_db),
|
||||
patch("core.ops.ops_trace_manager.Session", return_value=cm),
|
||||
):
|
||||
cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4")
|
||||
|
||||
assert cred_id == "model-cred-id"
|
||||
assert cred_name == "ModelCredName"
|
||||
|
||||
def test_provider_level_fallback_when_no_model_credential(self):
|
||||
"""Falls back to provider-level credential when ProviderModel has no credential_id."""
|
||||
from core.ops.ops_trace_manager import _lookup_llm_credential_info
|
||||
|
||||
provider_record = self._provider_record(credential_id="prov-cred-id")
|
||||
model_record = self._model_record(credential_id=None)
|
||||
|
||||
# scalar calls: (1) Provider, (2) ProviderModel (no cred), (3) ProviderCredential.credential_name
|
||||
mock_db, cm, _session = _make_db_and_session_patches(
|
||||
scalar_side_effect=[provider_record, model_record, "ProvCredName"]
|
||||
)
|
||||
|
||||
with (
|
||||
patch("core.ops.ops_trace_manager.db", mock_db),
|
||||
patch("core.ops.ops_trace_manager.Session", return_value=cm),
|
||||
):
|
||||
cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4")
|
||||
|
||||
assert cred_id == "prov-cred-id"
|
||||
assert cred_name == "ProvCredName"
|
||||
|
||||
def test_provider_level_fallback_when_no_model_record(self):
|
||||
"""Falls back to provider-level credential when no ProviderModel row exists."""
|
||||
from core.ops.ops_trace_manager import _lookup_llm_credential_info
|
||||
|
||||
provider_record = self._provider_record(credential_id="prov-cred-id")
|
||||
|
||||
# scalar calls: (1) Provider, (2) ProviderModel → None, (3) ProviderCredential.credential_name
|
||||
mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=[provider_record, None, "ProvCredName"])
|
||||
|
||||
with (
|
||||
patch("core.ops.ops_trace_manager.db", mock_db),
|
||||
patch("core.ops.ops_trace_manager.Session", return_value=cm),
|
||||
):
|
||||
cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4")
|
||||
|
||||
assert cred_id == "prov-cred-id"
|
||||
assert cred_name == "ProvCredName"
|
||||
|
||||
def test_no_model_arg_uses_provider_level_only(self):
|
||||
"""When model is None, skips ProviderModel query and uses provider credential."""
|
||||
from core.ops.ops_trace_manager import _lookup_llm_credential_info
|
||||
|
||||
provider_record = self._provider_record(credential_id="prov-cred-id")
|
||||
|
||||
# scalar calls: (1) Provider, (2) ProviderCredential.credential_name — no ProviderModel
|
||||
mock_db, cm, session = _make_db_and_session_patches(scalar_side_effect=[provider_record, "ProvCredName"])
|
||||
|
||||
with (
|
||||
patch("core.ops.ops_trace_manager.db", mock_db),
|
||||
patch("core.ops.ops_trace_manager.Session", return_value=cm),
|
||||
):
|
||||
cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", None)
|
||||
|
||||
assert cred_id == "prov-cred-id"
|
||||
assert cred_name == "ProvCredName"
|
||||
assert session.scalar.call_count == 2
|
||||
|
||||
def test_provider_not_found_returns_none_and_empty(self):
|
||||
"""Returns (None, '') when Provider record does not exist."""
|
||||
from core.ops.ops_trace_manager import _lookup_llm_credential_info
|
||||
|
||||
mock_db, cm, _session = _make_db_and_session_patches(scalar_return_value=None)
|
||||
|
||||
with (
|
||||
patch("core.ops.ops_trace_manager.db", mock_db),
|
||||
patch("core.ops.ops_trace_manager.Session", return_value=cm),
|
||||
):
|
||||
cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4")
|
||||
|
||||
assert cred_id is None
|
||||
assert cred_name == ""
|
||||
|
||||
def test_none_tenant_id_returns_none_and_empty_without_db(self):
|
||||
"""Returns (None, '') immediately when tenant_id is None — no DB access."""
|
||||
from core.ops.ops_trace_manager import _lookup_llm_credential_info
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_session_cls = MagicMock()
|
||||
|
||||
with (
|
||||
patch("core.ops.ops_trace_manager.db", mock_db),
|
||||
patch("core.ops.ops_trace_manager.Session", mock_session_cls),
|
||||
):
|
||||
cred_id, cred_name = _lookup_llm_credential_info(None, "openai", "gpt-4")
|
||||
|
||||
mock_session_cls.assert_not_called()
|
||||
assert cred_id is None
|
||||
assert cred_name == ""
|
||||
|
||||
def test_none_provider_returns_none_and_empty_without_db(self):
|
||||
"""Returns (None, '') immediately when provider is None — no DB access."""
|
||||
from core.ops.ops_trace_manager import _lookup_llm_credential_info
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_session_cls = MagicMock()
|
||||
|
||||
with (
|
||||
patch("core.ops.ops_trace_manager.db", mock_db),
|
||||
patch("core.ops.ops_trace_manager.Session", mock_session_cls),
|
||||
):
|
||||
cred_id, cred_name = _lookup_llm_credential_info("tenant-1", None, "gpt-4")
|
||||
|
||||
mock_session_cls.assert_not_called()
|
||||
assert cred_id is None
|
||||
assert cred_name == ""
|
||||
|
||||
def test_db_error_on_outer_query_returns_none_and_empty(self):
|
||||
"""Returns (None, '') and logs a warning when the outer DB query raises."""
|
||||
from core.ops.ops_trace_manager import _lookup_llm_credential_info
|
||||
|
||||
mock_db, cm, session = _make_db_and_session_patches()
|
||||
session.scalar.side_effect = Exception("DB connection failed")
|
||||
|
||||
with (
|
||||
patch("core.ops.ops_trace_manager.db", mock_db),
|
||||
patch("core.ops.ops_trace_manager.Session", return_value=cm),
|
||||
):
|
||||
cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4")
|
||||
|
||||
assert cred_id is None
|
||||
assert cred_name == ""
|
||||
|
||||
def test_credential_name_lookup_failure_returns_id_with_empty_name(self):
|
||||
"""When credential name sub-query fails, returns cred_id but '' for name."""
|
||||
from core.ops.ops_trace_manager import _lookup_llm_credential_info
|
||||
|
||||
provider_record = self._provider_record(credential_id="prov-cred-id")
|
||||
|
||||
# Provider found, no model record, then name lookup raises
|
||||
mock_db, cm, _session = _make_db_and_session_patches(
|
||||
scalar_side_effect=[provider_record, None, Exception("deleted")]
|
||||
)
|
||||
|
||||
with (
|
||||
patch("core.ops.ops_trace_manager.db", mock_db),
|
||||
patch("core.ops.ops_trace_manager.Session", return_value=cm),
|
||||
):
|
||||
cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4")
|
||||
|
||||
assert cred_id == "prov-cred-id"
|
||||
assert cred_name == ""
|
||||
|
||||
def test_no_credential_on_provider_or_model_returns_none_id(self):
|
||||
"""Returns (None, '') when neither provider nor model has a credential_id."""
|
||||
from core.ops.ops_trace_manager import _lookup_llm_credential_info
|
||||
|
||||
provider_record = self._provider_record(credential_id=None)
|
||||
model_record = self._model_record(credential_id=None)
|
||||
|
||||
mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=[provider_record, model_record])
|
||||
|
||||
with (
|
||||
patch("core.ops.ops_trace_manager.db", mock_db),
|
||||
patch("core.ops.ops_trace_manager.Session", return_value=cm),
|
||||
):
|
||||
cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4")
|
||||
|
||||
assert cred_id is None
|
||||
assert cred_name == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TraceTask._get_user_id_from_metadata
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetUserIdFromMetadata:
|
||||
"""Tests for TraceTask._get_user_id_from_metadata(metadata).
|
||||
|
||||
Pure dict logic — no DB access required.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def get_user_id(self):
|
||||
"""Return the classmethod under test."""
|
||||
from core.ops.ops_trace_manager import TraceTask
|
||||
|
||||
return TraceTask._get_user_id_from_metadata
|
||||
|
||||
def test_from_end_user_id_has_highest_priority(self, get_user_id):
|
||||
"""from_end_user_id takes precedence over all other keys."""
|
||||
metadata = {
|
||||
"from_end_user_id": "eu-abc",
|
||||
"from_account_id": "acc-xyz",
|
||||
"user_id": "u-123",
|
||||
}
|
||||
assert get_user_id(metadata) == "end_user:eu-abc"
|
||||
|
||||
def test_from_account_id_used_when_no_end_user(self, get_user_id):
|
||||
"""from_account_id is used when from_end_user_id is absent."""
|
||||
metadata = {
|
||||
"from_account_id": "acc-xyz",
|
||||
"user_id": "u-123",
|
||||
}
|
||||
assert get_user_id(metadata) == "account:acc-xyz"
|
||||
|
||||
def test_user_id_used_when_no_end_user_or_account(self, get_user_id):
|
||||
"""user_id is used when both higher-priority keys are absent."""
|
||||
metadata = {"user_id": "u-123"}
|
||||
assert get_user_id(metadata) == "user:u-123"
|
||||
|
||||
def test_returns_anonymous_when_all_keys_absent(self, get_user_id):
|
||||
"""Returns 'anonymous' when metadata has none of the expected keys."""
|
||||
assert get_user_id({}) == "anonymous"
|
||||
|
||||
def test_empty_string_end_user_id_is_skipped(self, get_user_id):
|
||||
"""Empty string for from_end_user_id is falsy and falls through to next key."""
|
||||
metadata = {
|
||||
"from_end_user_id": "",
|
||||
"from_account_id": "acc-xyz",
|
||||
}
|
||||
assert get_user_id(metadata) == "account:acc-xyz"
|
||||
|
||||
def test_empty_string_account_id_is_skipped(self, get_user_id):
|
||||
"""Empty string for from_account_id is falsy and falls through to user_id."""
|
||||
metadata = {
|
||||
"from_end_user_id": "",
|
||||
"from_account_id": "",
|
||||
"user_id": "u-123",
|
||||
}
|
||||
assert get_user_id(metadata) == "user:u-123"
|
||||
|
||||
def test_empty_string_user_id_falls_through_to_anonymous(self, get_user_id):
|
||||
"""Empty string for user_id is falsy, so 'anonymous' is returned."""
|
||||
metadata = {
|
||||
"from_end_user_id": "",
|
||||
"from_account_id": "",
|
||||
"user_id": "",
|
||||
}
|
||||
assert get_user_id(metadata) == "anonymous"
|
||||
|
||||
def test_only_from_end_user_id_present(self, get_user_id):
|
||||
"""Minimal case: only from_end_user_id present."""
|
||||
assert get_user_id({"from_end_user_id": "eu-only"}) == "end_user:eu-only"
|
||||
|
||||
def test_irrelevant_keys_do_not_interfere(self, get_user_id):
|
||||
"""Extra metadata keys have no effect on the result."""
|
||||
metadata = {"invoke_from": "web", "app_id": "a1"}
|
||||
assert get_user_id(metadata) == "anonymous"
|
||||
@ -86,6 +86,7 @@ def make_message_data(**overrides):
|
||||
created_at = datetime(2025, 2, 20, 12, 0, 0)
|
||||
base = {
|
||||
"id": "msg-id",
|
||||
"app_id": "app-id",
|
||||
"conversation_id": "conv-id",
|
||||
"created_at": created_at,
|
||||
"updated_at": created_at + timedelta(seconds=3),
|
||||
@ -182,6 +183,9 @@ class DummySessionContext:
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
return False
|
||||
|
||||
def execute(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def scalar(self, *args, **kwargs):
|
||||
if self._index >= len(self._values):
|
||||
return None
|
||||
@ -189,6 +193,12 @@ class DummySessionContext:
|
||||
self._index += 1
|
||||
return value
|
||||
|
||||
def scalars(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def all(self):
|
||||
return []
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_provider_map(monkeypatch):
|
||||
@ -454,7 +464,7 @@ def test_trace_task_message_trace(trace_task_message, mock_db):
|
||||
|
||||
def test_trace_task_workflow_trace(workflow_repo_fixture, mock_db):
|
||||
DummySessionContext.scalar_values = ["wf-app-log", "message-ref"]
|
||||
execution = SimpleNamespace(id_="run-id")
|
||||
execution = SimpleNamespace(id_="run-id", total_tokens=0)
|
||||
task = TraceTask(
|
||||
trace_type=TraceTaskName.WORKFLOW_TRACE, workflow_execution=execution, conversation_id="conv", user_id="user"
|
||||
)
|
||||
|
||||
194
api/tests/unit_tests/core/ops/test_trace_queue_manager.py
Normal file
194
api/tests/unit_tests/core/ops/test_trace_queue_manager.py
Normal file
@ -0,0 +1,194 @@
|
||||
"""Unit tests for TraceQueueManager telemetry guard.
|
||||
|
||||
Verifies that TraceQueueManager.add_trace_task() only enqueues tasks when at
|
||||
least one consumer is active:
|
||||
- Enterprise telemetry is enabled (_enterprise_telemetry_enabled=True), OR
|
||||
- A third-party trace instance (Langfuse, etc.) is configured
|
||||
|
||||
When neither is active, tasks are silently dropped to avoid unnecessary work.
|
||||
|
||||
When BOTH are false, tasks are silently dropped (correct behavior).
|
||||
"""
|
||||
|
||||
import queue
|
||||
import sys
|
||||
import types
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def trace_queue_manager_and_task(monkeypatch):
|
||||
"""Fixture to provide TraceQueueManager and TraceTask with delayed imports."""
|
||||
module_name = "core.ops.ops_trace_manager"
|
||||
if module_name not in sys.modules:
|
||||
ops_stub = types.ModuleType(module_name)
|
||||
|
||||
class StubTraceTask:
|
||||
def __init__(self, trace_type):
|
||||
self.trace_type = trace_type
|
||||
self.app_id = None
|
||||
|
||||
class StubTraceQueueManager:
|
||||
def __init__(self, app_id=None):
|
||||
self.app_id = app_id
|
||||
from core.telemetry.gateway import is_enterprise_telemetry_enabled
|
||||
|
||||
self._enterprise_telemetry_enabled = is_enterprise_telemetry_enabled()
|
||||
self.trace_instance = StubOpsTraceManager.get_ops_trace_instance(app_id)
|
||||
|
||||
def add_trace_task(self, trace_task):
|
||||
if self._enterprise_telemetry_enabled or self.trace_instance:
|
||||
trace_task.app_id = self.app_id
|
||||
from core.ops.ops_trace_manager import trace_manager_queue
|
||||
|
||||
trace_manager_queue.put(trace_task)
|
||||
|
||||
class StubOpsTraceManager:
|
||||
@staticmethod
|
||||
def get_ops_trace_instance(app_id):
|
||||
return None
|
||||
|
||||
ops_stub.TraceQueueManager = StubTraceQueueManager
|
||||
ops_stub.TraceTask = StubTraceTask
|
||||
ops_stub.OpsTraceManager = StubOpsTraceManager
|
||||
ops_stub.trace_manager_queue = MagicMock(spec=queue.Queue)
|
||||
monkeypatch.setitem(sys.modules, module_name, ops_stub)
|
||||
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
|
||||
ops_module = __import__(module_name, fromlist=["TraceQueueManager", "TraceTask"])
|
||||
TraceQueueManager = ops_module.TraceQueueManager
|
||||
TraceTask = ops_module.TraceTask
|
||||
|
||||
return TraceQueueManager, TraceTask, TraceTaskName
|
||||
|
||||
|
||||
class TestTraceQueueManagerTelemetryGuard:
|
||||
"""Test TraceQueueManager's telemetry guard in add_trace_task()."""
|
||||
|
||||
def test_task_not_enqueued_when_telemetry_disabled_and_no_trace_instance(self, trace_queue_manager_and_task):
|
||||
"""Verify task is NOT enqueued when telemetry disabled and no trace instance.
|
||||
|
||||
This is the core guard: when _enterprise_telemetry_enabled=False AND
|
||||
trace_instance=None, the task should be silently dropped.
|
||||
"""
|
||||
TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task
|
||||
|
||||
mock_queue = MagicMock(spec=queue.Queue)
|
||||
|
||||
trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE)
|
||||
|
||||
with (
|
||||
patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False),
|
||||
patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None),
|
||||
patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue),
|
||||
):
|
||||
manager = TraceQueueManager(app_id="test-app-id")
|
||||
manager.add_trace_task(trace_task)
|
||||
|
||||
mock_queue.put.assert_not_called()
|
||||
|
||||
def test_task_enqueued_when_telemetry_enabled(self, trace_queue_manager_and_task):
|
||||
"""Verify task IS enqueued when enterprise telemetry is enabled.
|
||||
|
||||
When _enterprise_telemetry_enabled=True, the task should be enqueued
|
||||
regardless of trace_instance state.
|
||||
"""
|
||||
TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task
|
||||
|
||||
mock_queue = MagicMock(spec=queue.Queue)
|
||||
|
||||
trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE)
|
||||
|
||||
with (
|
||||
patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True),
|
||||
patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None),
|
||||
patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue),
|
||||
):
|
||||
manager = TraceQueueManager(app_id="test-app-id")
|
||||
manager.add_trace_task(trace_task)
|
||||
|
||||
mock_queue.put.assert_called_once()
|
||||
called_task = mock_queue.put.call_args[0][0]
|
||||
assert called_task.app_id == "test-app-id"
|
||||
|
||||
def test_task_enqueued_when_trace_instance_configured(self, trace_queue_manager_and_task):
|
||||
"""Verify task IS enqueued when third-party trace instance is configured.
|
||||
|
||||
When trace_instance is not None (e.g., Langfuse configured), the task
|
||||
should be enqueued even if enterprise telemetry is disabled.
|
||||
"""
|
||||
TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task
|
||||
|
||||
mock_queue = MagicMock(spec=queue.Queue)
|
||||
|
||||
mock_trace_instance = MagicMock()
|
||||
|
||||
trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE)
|
||||
|
||||
with (
|
||||
patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False),
|
||||
patch(
|
||||
"core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=mock_trace_instance
|
||||
),
|
||||
patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue),
|
||||
):
|
||||
manager = TraceQueueManager(app_id="test-app-id")
|
||||
manager.add_trace_task(trace_task)
|
||||
|
||||
mock_queue.put.assert_called_once()
|
||||
called_task = mock_queue.put.call_args[0][0]
|
||||
assert called_task.app_id == "test-app-id"
|
||||
|
||||
def test_task_enqueued_when_both_telemetry_and_trace_instance_enabled(self, trace_queue_manager_and_task):
|
||||
"""Verify task IS enqueued when both telemetry and trace instance are enabled.
|
||||
|
||||
When both _enterprise_telemetry_enabled=True AND trace_instance is set,
|
||||
the task should definitely be enqueued.
|
||||
"""
|
||||
TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task
|
||||
|
||||
mock_queue = MagicMock(spec=queue.Queue)
|
||||
|
||||
mock_trace_instance = MagicMock()
|
||||
|
||||
trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE)
|
||||
|
||||
with (
|
||||
patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True),
|
||||
patch(
|
||||
"core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=mock_trace_instance
|
||||
),
|
||||
patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue),
|
||||
):
|
||||
manager = TraceQueueManager(app_id="test-app-id")
|
||||
manager.add_trace_task(trace_task)
|
||||
|
||||
mock_queue.put.assert_called_once()
|
||||
called_task = mock_queue.put.call_args[0][0]
|
||||
assert called_task.app_id == "test-app-id"
|
||||
|
||||
def test_app_id_set_before_enqueue(self, trace_queue_manager_and_task):
|
||||
"""Verify app_id is set on the task before enqueuing.
|
||||
|
||||
The guard logic sets trace_task.app_id = self.app_id before calling
|
||||
trace_manager_queue.put(trace_task). This test verifies that behavior.
|
||||
"""
|
||||
TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task
|
||||
|
||||
mock_queue = MagicMock(spec=queue.Queue)
|
||||
|
||||
trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE)
|
||||
|
||||
with (
|
||||
patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True),
|
||||
patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None),
|
||||
patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue),
|
||||
):
|
||||
manager = TraceQueueManager(app_id="expected-app-id")
|
||||
manager.add_trace_task(trace_task)
|
||||
|
||||
called_task = mock_queue.put.call_args[0][0]
|
||||
assert called_task.app_id == "expected-app-id"
|
||||
181
api/tests/unit_tests/core/telemetry/test_facade.py
Normal file
181
api/tests/unit_tests/core/telemetry/test_facade.py
Normal file
@ -0,0 +1,181 @@
|
||||
"""Unit tests for core.telemetry.emit() routing and enterprise-only filtering."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import queue
|
||||
import sys
|
||||
import types
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.telemetry.events import TelemetryContext, TelemetryEvent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def telemetry_test_setup(monkeypatch):
|
||||
module_name = "core.ops.ops_trace_manager"
|
||||
ops_stub = types.ModuleType(module_name)
|
||||
|
||||
class StubTraceTask:
|
||||
def __init__(self, trace_type, **kwargs):
|
||||
self.trace_type = trace_type
|
||||
self.app_id = None
|
||||
self.kwargs = kwargs
|
||||
|
||||
class StubTraceQueueManager:
|
||||
def __init__(self, app_id=None, user_id=None):
|
||||
self.app_id = app_id
|
||||
self.user_id = user_id
|
||||
self.trace_instance = StubOpsTraceManager.get_ops_trace_instance(app_id)
|
||||
|
||||
def add_trace_task(self, trace_task):
|
||||
trace_task.app_id = self.app_id
|
||||
from core.ops.ops_trace_manager import trace_manager_queue
|
||||
|
||||
trace_manager_queue.put(trace_task)
|
||||
|
||||
class StubOpsTraceManager:
|
||||
@staticmethod
|
||||
def get_ops_trace_instance(app_id):
|
||||
return None
|
||||
|
||||
ops_stub.TraceQueueManager = StubTraceQueueManager
|
||||
ops_stub.TraceTask = StubTraceTask
|
||||
ops_stub.OpsTraceManager = StubOpsTraceManager
|
||||
ops_stub.trace_manager_queue = MagicMock(spec=queue.Queue)
|
||||
monkeypatch.setitem(sys.modules, module_name, ops_stub)
|
||||
|
||||
from core.telemetry import emit
|
||||
|
||||
return emit, ops_stub.trace_manager_queue
|
||||
|
||||
|
||||
class TestTelemetryEmit:
|
||||
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
|
||||
def test_emit_enterprise_trace_creates_trace_task(self, mock_ee, telemetry_test_setup):
|
||||
emit_fn, mock_queue = telemetry_test_setup
|
||||
|
||||
event = TelemetryEvent(
|
||||
name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
|
||||
context=TelemetryContext(
|
||||
tenant_id="test-tenant",
|
||||
user_id="test-user",
|
||||
app_id="test-app",
|
||||
),
|
||||
payload={"key": "value"},
|
||||
)
|
||||
|
||||
emit_fn(event)
|
||||
|
||||
mock_queue.put.assert_called_once()
|
||||
called_task = mock_queue.put.call_args[0][0]
|
||||
assert called_task.trace_type == TraceTaskName.DRAFT_NODE_EXECUTION_TRACE
|
||||
|
||||
def test_emit_community_trace_enqueued(self, telemetry_test_setup):
|
||||
emit_fn, mock_queue = telemetry_test_setup
|
||||
|
||||
event = TelemetryEvent(
|
||||
name=TraceTaskName.WORKFLOW_TRACE,
|
||||
context=TelemetryContext(
|
||||
tenant_id="test-tenant",
|
||||
user_id="test-user",
|
||||
app_id="test-app",
|
||||
),
|
||||
payload={},
|
||||
)
|
||||
|
||||
emit_fn(event)
|
||||
|
||||
mock_queue.put.assert_called_once()
|
||||
|
||||
def test_emit_enterprise_only_trace_dropped_when_ee_disabled(self, telemetry_test_setup):
|
||||
emit_fn, mock_queue = telemetry_test_setup
|
||||
|
||||
event = TelemetryEvent(
|
||||
name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
|
||||
context=TelemetryContext(
|
||||
tenant_id="test-tenant",
|
||||
user_id="test-user",
|
||||
app_id="test-app",
|
||||
),
|
||||
payload={},
|
||||
)
|
||||
|
||||
emit_fn(event)
|
||||
|
||||
mock_queue.put.assert_not_called()
|
||||
|
||||
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
|
||||
def test_emit_all_enterprise_only_traces_allowed_when_ee_enabled(self, mock_ee, telemetry_test_setup):
|
||||
emit_fn, mock_queue = telemetry_test_setup
|
||||
|
||||
enterprise_only_traces = [
|
||||
TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
|
||||
TraceTaskName.NODE_EXECUTION_TRACE,
|
||||
TraceTaskName.PROMPT_GENERATION_TRACE,
|
||||
]
|
||||
|
||||
for trace_name in enterprise_only_traces:
|
||||
mock_queue.reset_mock()
|
||||
|
||||
event = TelemetryEvent(
|
||||
name=trace_name,
|
||||
context=TelemetryContext(
|
||||
tenant_id="test-tenant",
|
||||
user_id="test-user",
|
||||
app_id="test-app",
|
||||
),
|
||||
payload={},
|
||||
)
|
||||
|
||||
emit_fn(event)
|
||||
|
||||
mock_queue.put.assert_called_once()
|
||||
called_task = mock_queue.put.call_args[0][0]
|
||||
assert called_task.trace_type == trace_name
|
||||
|
||||
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
|
||||
def test_emit_passes_name_directly_to_trace_task(self, mock_ee, telemetry_test_setup):
|
||||
emit_fn, mock_queue = telemetry_test_setup
|
||||
|
||||
event = TelemetryEvent(
|
||||
name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
|
||||
context=TelemetryContext(
|
||||
tenant_id="test-tenant",
|
||||
user_id="test-user",
|
||||
app_id="test-app",
|
||||
),
|
||||
payload={"extra": "data"},
|
||||
)
|
||||
|
||||
emit_fn(event)
|
||||
|
||||
mock_queue.put.assert_called_once()
|
||||
called_task = mock_queue.put.call_args[0][0]
|
||||
assert called_task.trace_type == TraceTaskName.DRAFT_NODE_EXECUTION_TRACE
|
||||
assert isinstance(called_task.trace_type, TraceTaskName)
|
||||
|
||||
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
|
||||
def test_emit_with_provided_trace_manager(self, mock_ee, telemetry_test_setup):
|
||||
emit_fn, mock_queue = telemetry_test_setup
|
||||
|
||||
mock_trace_manager = MagicMock()
|
||||
mock_trace_manager.add_trace_task = MagicMock()
|
||||
|
||||
event = TelemetryEvent(
|
||||
name=TraceTaskName.NODE_EXECUTION_TRACE,
|
||||
context=TelemetryContext(
|
||||
tenant_id="test-tenant",
|
||||
user_id="test-user",
|
||||
app_id="test-app",
|
||||
),
|
||||
payload={},
|
||||
)
|
||||
|
||||
emit_fn(event, trace_manager=mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_called_once()
|
||||
called_task = mock_trace_manager.add_trace_task.call_args[0][0]
|
||||
assert called_task.trace_type == TraceTaskName.NODE_EXECUTION_TRACE
|
||||
225
api/tests/unit_tests/core/telemetry/test_gateway_integration.py
Normal file
225
api/tests/unit_tests/core/telemetry/test_gateway_integration.py
Normal file
@ -0,0 +1,225 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.telemetry.gateway import emit, is_enterprise_telemetry_enabled
|
||||
from enterprise.telemetry.contracts import TelemetryCase
|
||||
|
||||
|
||||
class TestTelemetryCoreExports:
|
||||
def test_is_enterprise_telemetry_enabled_exported(self) -> None:
|
||||
from core.telemetry.gateway import is_enterprise_telemetry_enabled as exported_func
|
||||
|
||||
assert callable(exported_func)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ops_trace_manager():
|
||||
mock_module = MagicMock()
|
||||
mock_trace_task_class = MagicMock()
|
||||
mock_trace_task_class.return_value = MagicMock()
|
||||
mock_module.TraceTask = mock_trace_task_class
|
||||
mock_module.TraceQueueManager = MagicMock()
|
||||
|
||||
mock_trace_entity = MagicMock()
|
||||
mock_trace_task_name = MagicMock()
|
||||
mock_trace_task_name.return_value = "workflow"
|
||||
mock_trace_entity.TraceTaskName = mock_trace_task_name
|
||||
|
||||
with (
|
||||
patch.dict(sys.modules, {"core.ops.ops_trace_manager": mock_module}),
|
||||
patch.dict(sys.modules, {"core.ops.entities.trace_entity": mock_trace_entity}),
|
||||
):
|
||||
yield mock_module, mock_trace_entity
|
||||
|
||||
|
||||
class TestGatewayIntegrationTraceRouting:
|
||||
@pytest.fixture
|
||||
def mock_trace_manager(self) -> MagicMock:
|
||||
return MagicMock()
|
||||
|
||||
@pytest.mark.usefixtures("mock_ops_trace_manager")
|
||||
def test_ce_eligible_trace_routed_to_trace_manager(
|
||||
self,
|
||||
mock_trace_manager: MagicMock,
|
||||
) -> None:
|
||||
with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True):
|
||||
context = {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"}
|
||||
payload = {"workflow_run_id": "run-abc"}
|
||||
|
||||
emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_called_once()
|
||||
|
||||
@pytest.mark.usefixtures("mock_ops_trace_manager")
|
||||
def test_ce_eligible_trace_routed_when_ee_disabled(
|
||||
self,
|
||||
mock_trace_manager: MagicMock,
|
||||
) -> None:
|
||||
with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False):
|
||||
context = {"app_id": "app-123", "user_id": "user-456"}
|
||||
payload = {"workflow_run_id": "run-abc"}
|
||||
|
||||
emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_called_once()
|
||||
|
||||
@pytest.mark.usefixtures("mock_ops_trace_manager")
|
||||
def test_enterprise_only_trace_dropped_when_ee_disabled(
|
||||
self,
|
||||
mock_trace_manager: MagicMock,
|
||||
) -> None:
|
||||
with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False):
|
||||
context = {"app_id": "app-123", "user_id": "user-456"}
|
||||
payload = {"node_id": "node-abc"}
|
||||
|
||||
emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_not_called()
|
||||
|
||||
@pytest.mark.usefixtures("mock_ops_trace_manager")
|
||||
def test_enterprise_only_trace_routed_when_ee_enabled(
|
||||
self,
|
||||
mock_trace_manager: MagicMock,
|
||||
) -> None:
|
||||
with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True):
|
||||
context = {"app_id": "app-123", "user_id": "user-456"}
|
||||
payload = {"node_id": "node-abc"}
|
||||
|
||||
emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_called_once()
|
||||
|
||||
|
||||
class TestGatewayIntegrationMetricRouting:
|
||||
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
|
||||
def test_metric_case_routes_to_celery_task(
|
||||
self,
|
||||
mock_ee_enabled: MagicMock,
|
||||
) -> None:
|
||||
from enterprise.telemetry.contracts import TelemetryEnvelope
|
||||
|
||||
with patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") as mock_delay:
|
||||
context = {"tenant_id": "tenant-123"}
|
||||
payload = {"app_id": "app-abc", "name": "My App"}
|
||||
|
||||
emit(TelemetryCase.APP_CREATED, context, payload)
|
||||
|
||||
mock_delay.assert_called_once()
|
||||
envelope_json = mock_delay.call_args[0][0]
|
||||
envelope = TelemetryEnvelope.model_validate_json(envelope_json)
|
||||
assert envelope.case == TelemetryCase.APP_CREATED
|
||||
assert envelope.tenant_id == "tenant-123"
|
||||
assert envelope.payload["app_id"] == "app-abc"
|
||||
|
||||
@pytest.mark.usefixtures("mock_ops_trace_manager")
|
||||
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
|
||||
def test_tool_execution_trace_routed(
|
||||
self,
|
||||
mock_ee_enabled: MagicMock,
|
||||
) -> None:
|
||||
mock_trace_manager = MagicMock()
|
||||
context = {"tenant_id": "tenant-123", "app_id": "app-123"}
|
||||
payload = {"tool_name": "test_tool", "tool_inputs": {}, "tool_outputs": "result"}
|
||||
|
||||
emit(TelemetryCase.TOOL_EXECUTION, context, payload, mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_called_once()
|
||||
|
||||
@pytest.mark.usefixtures("mock_ops_trace_manager")
|
||||
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
|
||||
def test_moderation_check_trace_routed(
|
||||
self,
|
||||
mock_ee_enabled: MagicMock,
|
||||
) -> None:
|
||||
mock_trace_manager = MagicMock()
|
||||
context = {"tenant_id": "tenant-123", "app_id": "app-123"}
|
||||
payload = {"message_id": "msg-123", "moderation_result": {"flagged": False}}
|
||||
|
||||
emit(TelemetryCase.MODERATION_CHECK, context, payload, mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_called_once()
|
||||
|
||||
|
||||
class TestGatewayIntegrationCEEligibility:
|
||||
@pytest.fixture
|
||||
def mock_trace_manager(self) -> MagicMock:
|
||||
return MagicMock()
|
||||
|
||||
@pytest.mark.usefixtures("mock_ops_trace_manager")
|
||||
def test_workflow_run_is_ce_eligible(
|
||||
self,
|
||||
mock_trace_manager: MagicMock,
|
||||
) -> None:
|
||||
with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False):
|
||||
context = {"app_id": "app-123", "user_id": "user-456"}
|
||||
payload = {"workflow_run_id": "run-abc"}
|
||||
|
||||
emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_called_once()
|
||||
|
||||
@pytest.mark.usefixtures("mock_ops_trace_manager")
|
||||
def test_message_run_is_ce_eligible(
|
||||
self,
|
||||
mock_trace_manager: MagicMock,
|
||||
) -> None:
|
||||
with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False):
|
||||
context = {"app_id": "app-123", "user_id": "user-456"}
|
||||
payload = {"message_id": "msg-abc", "conversation_id": "conv-123"}
|
||||
|
||||
emit(TelemetryCase.MESSAGE_RUN, context, payload, mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_called_once()
|
||||
|
||||
@pytest.mark.usefixtures("mock_ops_trace_manager")
|
||||
def test_node_execution_not_ce_eligible(
|
||||
self,
|
||||
mock_trace_manager: MagicMock,
|
||||
) -> None:
|
||||
with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False):
|
||||
context = {"app_id": "app-123", "user_id": "user-456"}
|
||||
payload = {"node_id": "node-abc"}
|
||||
|
||||
emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_not_called()
|
||||
|
||||
@pytest.mark.usefixtures("mock_ops_trace_manager")
|
||||
def test_draft_node_execution_not_ce_eligible(
|
||||
self,
|
||||
mock_trace_manager: MagicMock,
|
||||
) -> None:
|
||||
with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False):
|
||||
context = {"app_id": "app-123", "user_id": "user-456"}
|
||||
payload = {"node_execution_data": {}}
|
||||
|
||||
emit(TelemetryCase.DRAFT_NODE_EXECUTION, context, payload, mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_not_called()
|
||||
|
||||
@pytest.mark.usefixtures("mock_ops_trace_manager")
|
||||
def test_prompt_generation_not_ce_eligible(
|
||||
self,
|
||||
mock_trace_manager: MagicMock,
|
||||
) -> None:
|
||||
with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False):
|
||||
context = {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"}
|
||||
payload = {"operation_type": "generate", "instruction": "test"}
|
||||
|
||||
emit(TelemetryCase.PROMPT_GENERATION, context, payload, mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_not_called()
|
||||
|
||||
|
||||
class TestIsEnterpriseTelemetryEnabled:
|
||||
def test_returns_false_when_exporter_import_fails(self) -> None:
|
||||
with patch.dict(sys.modules, {"enterprise.telemetry.exporter": None}):
|
||||
result = is_enterprise_telemetry_enabled()
|
||||
assert result is False
|
||||
|
||||
def test_function_is_callable(self) -> None:
|
||||
assert callable(is_enterprise_telemetry_enabled)
|
||||
230
api/tests/unit_tests/enterprise/telemetry/test_contracts.py
Normal file
230
api/tests/unit_tests/enterprise/telemetry/test_contracts.py
Normal file
@ -0,0 +1,230 @@
|
||||
"""Unit tests for telemetry gateway contracts."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.telemetry.gateway import CASE_ROUTING
|
||||
from enterprise.telemetry.contracts import CaseRoute, SignalType, TelemetryCase, TelemetryEnvelope
|
||||
|
||||
|
||||
class TestTelemetryCase:
|
||||
"""Tests for TelemetryCase enum."""
|
||||
|
||||
def test_all_cases_defined(self) -> None:
|
||||
"""Verify all 14 telemetry cases are defined."""
|
||||
expected_cases = {
|
||||
"WORKFLOW_RUN",
|
||||
"NODE_EXECUTION",
|
||||
"DRAFT_NODE_EXECUTION",
|
||||
"MESSAGE_RUN",
|
||||
"TOOL_EXECUTION",
|
||||
"MODERATION_CHECK",
|
||||
"SUGGESTED_QUESTION",
|
||||
"DATASET_RETRIEVAL",
|
||||
"GENERATE_NAME",
|
||||
"PROMPT_GENERATION",
|
||||
"APP_CREATED",
|
||||
"APP_UPDATED",
|
||||
"APP_DELETED",
|
||||
"FEEDBACK_CREATED",
|
||||
}
|
||||
actual_cases = {case.name for case in TelemetryCase}
|
||||
assert actual_cases == expected_cases
|
||||
|
||||
def test_case_values(self) -> None:
|
||||
"""Verify case enum values are correct."""
|
||||
assert TelemetryCase.WORKFLOW_RUN.value == "workflow_run"
|
||||
assert TelemetryCase.NODE_EXECUTION.value == "node_execution"
|
||||
assert TelemetryCase.DRAFT_NODE_EXECUTION.value == "draft_node_execution"
|
||||
assert TelemetryCase.MESSAGE_RUN.value == "message_run"
|
||||
assert TelemetryCase.TOOL_EXECUTION.value == "tool_execution"
|
||||
assert TelemetryCase.MODERATION_CHECK.value == "moderation_check"
|
||||
assert TelemetryCase.SUGGESTED_QUESTION.value == "suggested_question"
|
||||
assert TelemetryCase.DATASET_RETRIEVAL.value == "dataset_retrieval"
|
||||
assert TelemetryCase.GENERATE_NAME.value == "generate_name"
|
||||
assert TelemetryCase.PROMPT_GENERATION.value == "prompt_generation"
|
||||
assert TelemetryCase.APP_CREATED.value == "app_created"
|
||||
assert TelemetryCase.APP_UPDATED.value == "app_updated"
|
||||
assert TelemetryCase.APP_DELETED.value == "app_deleted"
|
||||
assert TelemetryCase.FEEDBACK_CREATED.value == "feedback_created"
|
||||
|
||||
|
||||
class TestCaseRoute:
|
||||
"""Tests for CaseRoute model."""
|
||||
|
||||
def test_valid_trace_route(self) -> None:
|
||||
"""Verify valid trace route creation."""
|
||||
route = CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True)
|
||||
assert route.signal_type == SignalType.TRACE
|
||||
assert route.ce_eligible is True
|
||||
|
||||
def test_valid_metric_log_route(self) -> None:
|
||||
"""Verify valid metric_log route creation."""
|
||||
route = CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False)
|
||||
assert route.signal_type == SignalType.METRIC_LOG
|
||||
assert route.ce_eligible is False
|
||||
|
||||
def test_invalid_signal_type(self) -> None:
|
||||
"""Verify invalid signal_type is rejected."""
|
||||
with pytest.raises(ValidationError):
|
||||
CaseRoute(signal_type="invalid", ce_eligible=True)
|
||||
|
||||
|
||||
class TestTelemetryEnvelope:
|
||||
"""Tests for TelemetryEnvelope model."""
|
||||
|
||||
def test_valid_envelope_minimal(self) -> None:
|
||||
"""Verify valid minimal envelope creation."""
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.WORKFLOW_RUN,
|
||||
tenant_id="tenant-123",
|
||||
event_id="event-456",
|
||||
payload={"key": "value"},
|
||||
)
|
||||
assert envelope.case == TelemetryCase.WORKFLOW_RUN
|
||||
assert envelope.tenant_id == "tenant-123"
|
||||
assert envelope.event_id == "event-456"
|
||||
assert envelope.payload == {"key": "value"}
|
||||
assert envelope.metadata is None
|
||||
|
||||
def test_valid_envelope_full(self) -> None:
|
||||
"""Verify valid envelope with all fields."""
|
||||
metadata = {"payload_ref": "telemetry/tenant-789/event-012.json"}
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.MESSAGE_RUN,
|
||||
tenant_id="tenant-789",
|
||||
event_id="event-012",
|
||||
payload={"message": "hello"},
|
||||
metadata=metadata,
|
||||
)
|
||||
assert envelope.case == TelemetryCase.MESSAGE_RUN
|
||||
assert envelope.tenant_id == "tenant-789"
|
||||
assert envelope.event_id == "event-012"
|
||||
assert envelope.payload == {"message": "hello"}
|
||||
assert envelope.metadata == metadata
|
||||
|
||||
def test_missing_required_case(self) -> None:
|
||||
"""Verify missing case field is rejected."""
|
||||
with pytest.raises(ValidationError):
|
||||
TelemetryEnvelope(
|
||||
tenant_id="tenant-123",
|
||||
event_id="event-456",
|
||||
payload={"key": "value"},
|
||||
)
|
||||
|
||||
def test_missing_required_tenant_id(self) -> None:
|
||||
"""Verify missing tenant_id field is rejected."""
|
||||
with pytest.raises(ValidationError):
|
||||
TelemetryEnvelope(
|
||||
case=TelemetryCase.WORKFLOW_RUN,
|
||||
event_id="event-456",
|
||||
payload={"key": "value"},
|
||||
)
|
||||
|
||||
def test_missing_required_event_id(self) -> None:
|
||||
"""Verify missing event_id field is rejected."""
|
||||
with pytest.raises(ValidationError):
|
||||
TelemetryEnvelope(
|
||||
case=TelemetryCase.WORKFLOW_RUN,
|
||||
tenant_id="tenant-123",
|
||||
payload={"key": "value"},
|
||||
)
|
||||
|
||||
def test_missing_required_payload(self) -> None:
|
||||
"""Verify missing payload field is rejected."""
|
||||
with pytest.raises(ValidationError):
|
||||
TelemetryEnvelope(
|
||||
case=TelemetryCase.WORKFLOW_RUN,
|
||||
tenant_id="tenant-123",
|
||||
event_id="event-456",
|
||||
)
|
||||
|
||||
def test_metadata_none(self) -> None:
|
||||
"""Verify metadata can be None."""
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.WORKFLOW_RUN,
|
||||
tenant_id="tenant-123",
|
||||
event_id="event-456",
|
||||
payload={"key": "value"},
|
||||
metadata=None,
|
||||
)
|
||||
assert envelope.metadata is None
|
||||
|
||||
|
||||
class TestCaseRouting:
|
||||
"""Tests for CASE_ROUTING table."""
|
||||
|
||||
def test_all_cases_routed(self) -> None:
|
||||
"""Verify all 14 cases have routing entries."""
|
||||
assert len(CASE_ROUTING) == 14
|
||||
for case in TelemetryCase:
|
||||
assert case in CASE_ROUTING
|
||||
|
||||
def test_trace_ce_eligible_cases(self) -> None:
|
||||
"""Verify trace cases with CE eligibility."""
|
||||
ce_eligible_trace_cases = {
|
||||
TelemetryCase.WORKFLOW_RUN,
|
||||
TelemetryCase.MESSAGE_RUN,
|
||||
}
|
||||
for case in ce_eligible_trace_cases:
|
||||
route = CASE_ROUTING[case]
|
||||
assert route.signal_type == SignalType.TRACE
|
||||
assert route.ce_eligible is True
|
||||
|
||||
def test_trace_enterprise_only_cases(self) -> None:
|
||||
"""Verify trace cases that are enterprise-only."""
|
||||
enterprise_only_trace_cases = {
|
||||
TelemetryCase.NODE_EXECUTION,
|
||||
TelemetryCase.DRAFT_NODE_EXECUTION,
|
||||
TelemetryCase.PROMPT_GENERATION,
|
||||
}
|
||||
for case in enterprise_only_trace_cases:
|
||||
route = CASE_ROUTING[case]
|
||||
assert route.signal_type == SignalType.TRACE
|
||||
assert route.ce_eligible is False
|
||||
|
||||
def test_metric_log_cases(self) -> None:
|
||||
"""Verify metric/log-only cases."""
|
||||
metric_log_cases = {
|
||||
TelemetryCase.APP_CREATED,
|
||||
TelemetryCase.APP_UPDATED,
|
||||
TelemetryCase.APP_DELETED,
|
||||
TelemetryCase.FEEDBACK_CREATED,
|
||||
}
|
||||
for case in metric_log_cases:
|
||||
route = CASE_ROUTING[case]
|
||||
assert route.signal_type == SignalType.METRIC_LOG
|
||||
assert route.ce_eligible is False
|
||||
|
||||
def test_routing_table_completeness(self) -> None:
|
||||
"""Verify routing table covers all cases with correct types."""
|
||||
trace_cases = {
|
||||
TelemetryCase.WORKFLOW_RUN,
|
||||
TelemetryCase.MESSAGE_RUN,
|
||||
TelemetryCase.NODE_EXECUTION,
|
||||
TelemetryCase.DRAFT_NODE_EXECUTION,
|
||||
TelemetryCase.PROMPT_GENERATION,
|
||||
TelemetryCase.TOOL_EXECUTION,
|
||||
TelemetryCase.MODERATION_CHECK,
|
||||
TelemetryCase.SUGGESTED_QUESTION,
|
||||
TelemetryCase.DATASET_RETRIEVAL,
|
||||
TelemetryCase.GENERATE_NAME,
|
||||
}
|
||||
metric_log_cases = {
|
||||
TelemetryCase.APP_CREATED,
|
||||
TelemetryCase.APP_UPDATED,
|
||||
TelemetryCase.APP_DELETED,
|
||||
TelemetryCase.FEEDBACK_CREATED,
|
||||
}
|
||||
|
||||
all_cases = trace_cases | metric_log_cases
|
||||
assert len(all_cases) == 14
|
||||
assert all_cases == set(TelemetryCase)
|
||||
|
||||
for case in trace_cases:
|
||||
assert CASE_ROUTING[case].signal_type == SignalType.TRACE
|
||||
|
||||
for case in metric_log_cases:
|
||||
assert CASE_ROUTING[case].signal_type == SignalType.METRIC_LOG
|
||||
519
api/tests/unit_tests/enterprise/telemetry/test_draft_trace.py
Normal file
519
api/tests/unit_tests/enterprise/telemetry/test_draft_trace.py
Normal file
@ -0,0 +1,519 @@
|
||||
"""Unit tests for enterprise/telemetry/draft_trace.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from graphon.enums import WorkflowNodeExecutionMetadataKey
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_execution(**overrides) -> MagicMock:
|
||||
"""Return a minimal WorkflowNodeExecutionModel mock."""
|
||||
execution = MagicMock()
|
||||
execution.tenant_id = overrides.get("tenant_id", "tenant-1")
|
||||
execution.app_id = overrides.get("app_id", "app-1")
|
||||
execution.workflow_id = overrides.get("workflow_id", "wf-1")
|
||||
execution.id = overrides.get("id", "exec-1")
|
||||
execution.node_id = overrides.get("node_id", "node-1")
|
||||
execution.node_type = overrides.get("node_type", "llm")
|
||||
execution.title = overrides.get("title", "My LLM Node")
|
||||
execution.status = overrides.get("status", "succeeded")
|
||||
execution.error = overrides.get("error")
|
||||
execution.elapsed_time = overrides.get("elapsed_time", 1.5)
|
||||
execution.index = overrides.get("index", 1)
|
||||
execution.predecessor_node_id = overrides.get("predecessor_node_id")
|
||||
execution.created_at = overrides.get("created_at", datetime(2024, 1, 1, tzinfo=UTC))
|
||||
execution.finished_at = overrides.get("finished_at", datetime(2024, 1, 1, 0, 0, 5, tzinfo=UTC))
|
||||
execution.workflow_run_id = overrides.get("workflow_run_id", "run-1")
|
||||
execution.inputs_dict = overrides.get("inputs_dict", {"prompt": "hello"})
|
||||
execution.outputs_dict = overrides.get("outputs_dict", {"answer": "world"})
|
||||
execution.process_data_dict = overrides.get("process_data_dict", {})
|
||||
execution.execution_metadata_dict = overrides.get("execution_metadata_dict", {})
|
||||
return execution
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_node_execution_data
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildNodeExecutionData:
|
||||
def test_basic_fields_populated(self) -> None:
|
||||
from enterprise.telemetry.draft_trace import _build_node_execution_data
|
||||
|
||||
execution = _make_execution()
|
||||
result = _build_node_execution_data(
|
||||
execution=execution,
|
||||
outputs=None,
|
||||
workflow_execution_id="run-override",
|
||||
)
|
||||
|
||||
assert result["workflow_id"] == "wf-1"
|
||||
assert result["tenant_id"] == "tenant-1"
|
||||
assert result["app_id"] == "app-1"
|
||||
assert result["node_execution_id"] == "exec-1"
|
||||
assert result["node_id"] == "node-1"
|
||||
assert result["node_type"] == "llm"
|
||||
assert result["title"] == "My LLM Node"
|
||||
assert result["status"] == "succeeded"
|
||||
assert result["error"] is None
|
||||
assert result["elapsed_time"] == 1.5
|
||||
assert result["index"] == 1
|
||||
|
||||
def test_workflow_execution_id_prefers_parameter(self) -> None:
|
||||
from enterprise.telemetry.draft_trace import _build_node_execution_data
|
||||
|
||||
execution = _make_execution(workflow_run_id="run-from-model")
|
||||
result = _build_node_execution_data(
|
||||
execution=execution,
|
||||
outputs=None,
|
||||
workflow_execution_id="explicit-run",
|
||||
)
|
||||
assert result["workflow_execution_id"] == "explicit-run"
|
||||
|
||||
def test_workflow_execution_id_falls_back_to_run_id(self) -> None:
|
||||
from enterprise.telemetry.draft_trace import _build_node_execution_data
|
||||
|
||||
execution = _make_execution(workflow_run_id="run-from-model")
|
||||
result = _build_node_execution_data(
|
||||
execution=execution,
|
||||
outputs=None,
|
||||
workflow_execution_id=None,
|
||||
)
|
||||
assert result["workflow_execution_id"] == "run-from-model"
|
||||
|
||||
def test_workflow_execution_id_falls_back_to_execution_id(self) -> None:
|
||||
from enterprise.telemetry.draft_trace import _build_node_execution_data
|
||||
|
||||
execution = _make_execution(workflow_run_id=None, id="exec-fallback")
|
||||
result = _build_node_execution_data(
|
||||
execution=execution,
|
||||
outputs=None,
|
||||
workflow_execution_id=None,
|
||||
)
|
||||
assert result["workflow_execution_id"] == "exec-fallback"
|
||||
|
||||
def test_outputs_param_overrides_execution_outputs(self) -> None:
|
||||
from enterprise.telemetry.draft_trace import _build_node_execution_data
|
||||
|
||||
execution = _make_execution(outputs_dict={"from_model": True})
|
||||
result = _build_node_execution_data(
|
||||
execution=execution,
|
||||
outputs={"from_param": True},
|
||||
workflow_execution_id=None,
|
||||
)
|
||||
assert result["node_outputs"] == {"from_param": True}
|
||||
|
||||
def test_outputs_none_uses_execution_outputs_dict(self) -> None:
|
||||
from enterprise.telemetry.draft_trace import _build_node_execution_data
|
||||
|
||||
execution = _make_execution(outputs_dict={"from_model": True})
|
||||
result = _build_node_execution_data(
|
||||
execution=execution,
|
||||
outputs=None,
|
||||
workflow_execution_id=None,
|
||||
)
|
||||
assert result["node_outputs"] == {"from_model": True}
|
||||
|
||||
def test_metadata_token_fields_default_to_zero(self) -> None:
|
||||
from enterprise.telemetry.draft_trace import _build_node_execution_data
|
||||
|
||||
execution = _make_execution(execution_metadata_dict={})
|
||||
result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None)
|
||||
|
||||
assert result["total_tokens"] == 0
|
||||
assert result["total_price"] == 0.0
|
||||
assert result["currency"] is None
|
||||
|
||||
def test_metadata_token_fields_populated_from_metadata(self) -> None:
|
||||
from enterprise.telemetry.draft_trace import _build_node_execution_data
|
||||
|
||||
metadata = {
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 200,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: 0.05,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: "USD",
|
||||
}
|
||||
execution = _make_execution(execution_metadata_dict=metadata)
|
||||
result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None)
|
||||
|
||||
assert result["total_tokens"] == 200
|
||||
assert result["total_price"] == 0.05
|
||||
assert result["currency"] == "USD"
|
||||
|
||||
def test_tool_name_extracted_from_tool_info_dict(self) -> None:
|
||||
from enterprise.telemetry.draft_trace import _build_node_execution_data
|
||||
|
||||
metadata = {
|
||||
WorkflowNodeExecutionMetadataKey.TOOL_INFO: {"tool_name": "web_search"},
|
||||
}
|
||||
execution = _make_execution(execution_metadata_dict=metadata)
|
||||
result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None)
|
||||
|
||||
assert result["tool_name"] == "web_search"
|
||||
|
||||
def test_tool_name_is_none_when_tool_info_not_dict(self) -> None:
|
||||
from enterprise.telemetry.draft_trace import _build_node_execution_data
|
||||
|
||||
metadata = {WorkflowNodeExecutionMetadataKey.TOOL_INFO: "not-a-dict"}
|
||||
execution = _make_execution(execution_metadata_dict=metadata)
|
||||
result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None)
|
||||
|
||||
assert result["tool_name"] is None
|
||||
|
||||
def test_tool_name_is_none_when_tool_info_absent(self) -> None:
|
||||
from enterprise.telemetry.draft_trace import _build_node_execution_data
|
||||
|
||||
execution = _make_execution(execution_metadata_dict={})
|
||||
result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None)
|
||||
|
||||
assert result["tool_name"] is None
|
||||
|
||||
def test_iteration_and_loop_fields(self) -> None:
|
||||
from enterprise.telemetry.draft_trace import _build_node_execution_data
|
||||
|
||||
metadata = {
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_ID: "iter-1",
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_INDEX: 3,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_ID: "loop-1",
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_INDEX: 2,
|
||||
WorkflowNodeExecutionMetadataKey.PARALLEL_ID: "par-1",
|
||||
}
|
||||
execution = _make_execution(execution_metadata_dict=metadata)
|
||||
result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None)
|
||||
|
||||
assert result["iteration_id"] == "iter-1"
|
||||
assert result["iteration_index"] == 3
|
||||
assert result["loop_id"] == "loop-1"
|
||||
assert result["loop_index"] == 2
|
||||
assert result["parallel_id"] == "par-1"
|
||||
|
||||
def test_node_inputs_and_process_data_included(self) -> None:
|
||||
from enterprise.telemetry.draft_trace import _build_node_execution_data
|
||||
|
||||
execution = _make_execution(
|
||||
inputs_dict={"q": "test"},
|
||||
process_data_dict={"step": 1},
|
||||
)
|
||||
result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None)
|
||||
|
||||
assert result["node_inputs"] == {"q": "test"}
|
||||
assert result["process_data"] == {"step": 1}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# enqueue_draft_node_execution_trace
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnqueueDraftNodeExecutionTrace:
|
||||
@patch("enterprise.telemetry.draft_trace.telemetry_emit")
|
||||
def test_emits_telemetry_event(self, mock_emit: MagicMock) -> None:
|
||||
from core.telemetry import TelemetryEvent, TraceTaskName
|
||||
from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace
|
||||
|
||||
execution = _make_execution()
|
||||
enqueue_draft_node_execution_trace(
|
||||
execution=execution,
|
||||
outputs={"result": "ok"},
|
||||
workflow_execution_id="run-x",
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
mock_emit.assert_called_once()
|
||||
event: TelemetryEvent = mock_emit.call_args[0][0]
|
||||
assert event.name == TraceTaskName.DRAFT_NODE_EXECUTION_TRACE
|
||||
assert event.context.tenant_id == "tenant-1"
|
||||
assert event.context.user_id == "user-1"
|
||||
assert event.context.app_id == "app-1"
|
||||
|
||||
@patch("enterprise.telemetry.draft_trace.telemetry_emit")
|
||||
def test_payload_contains_node_execution_data(self, mock_emit: MagicMock) -> None:
|
||||
from core.telemetry import TelemetryEvent
|
||||
from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace
|
||||
|
||||
execution = _make_execution()
|
||||
enqueue_draft_node_execution_trace(
|
||||
execution=execution,
|
||||
outputs=None,
|
||||
workflow_execution_id=None,
|
||||
user_id="user-2",
|
||||
)
|
||||
|
||||
event: TelemetryEvent = mock_emit.call_args[0][0]
|
||||
node_data = event.payload["node_execution_data"]
|
||||
assert node_data["workflow_id"] == "wf-1"
|
||||
assert node_data["node_type"] == "llm"
|
||||
assert node_data["status"] == "succeeded"
|
||||
|
||||
@patch("enterprise.telemetry.draft_trace.telemetry_emit")
|
||||
def test_outputs_forwarded_to_build(self, mock_emit: MagicMock) -> None:
|
||||
from core.telemetry import TelemetryEvent
|
||||
from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace
|
||||
|
||||
execution = _make_execution(outputs_dict={"default": True})
|
||||
enqueue_draft_node_execution_trace(
|
||||
execution=execution,
|
||||
outputs={"explicit": True},
|
||||
workflow_execution_id=None,
|
||||
user_id="user-3",
|
||||
)
|
||||
|
||||
event: TelemetryEvent = mock_emit.call_args[0][0]
|
||||
assert event.payload["node_execution_data"]["node_outputs"] == {"explicit": True}
|
||||
|
||||
@patch("enterprise.telemetry.draft_trace.telemetry_emit")
|
||||
def test_none_outputs_uses_execution_outputs(self, mock_emit: MagicMock) -> None:
|
||||
from core.telemetry import TelemetryEvent
|
||||
from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace
|
||||
|
||||
execution = _make_execution(outputs_dict={"from_model": "yes"})
|
||||
enqueue_draft_node_execution_trace(
|
||||
execution=execution,
|
||||
outputs=None,
|
||||
workflow_execution_id=None,
|
||||
user_id="user-4",
|
||||
)
|
||||
|
||||
event: TelemetryEvent = mock_emit.call_args[0][0]
|
||||
assert event.payload["node_execution_data"]["node_outputs"] == {"from_model": "yes"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# End-to-end token/model data flow: _build_node_execution_data →
|
||||
# ops_trace_manager.draft_node_execution_trace → DraftNodeExecutionTrace
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_llm_execution() -> MagicMock:
|
||||
"""Return a WorkflowNodeExecutionModel mock that mimics a real LLM node.
|
||||
|
||||
The field values match what graphon/nodes/llm/node.py produces:
|
||||
- process_data_dict contains model_provider, model_name, and usage
|
||||
- outputs_dict contains usage with prompt/completion breakdown
|
||||
- execution_metadata_dict contains total_tokens/total_price/currency
|
||||
"""
|
||||
return _make_execution(
|
||||
tenant_id="tenant-flow",
|
||||
app_id="app-flow",
|
||||
workflow_id="wf-flow",
|
||||
id="exec-flow",
|
||||
node_id="node-llm",
|
||||
node_type="llm",
|
||||
title="GPT-4o Node",
|
||||
status="succeeded",
|
||||
elapsed_time=2.3,
|
||||
workflow_run_id=None,
|
||||
process_data_dict={
|
||||
"model_mode": "chat",
|
||||
"model_provider": "openai",
|
||||
"model_name": "gpt-4o",
|
||||
"prompts": [{"role": "user", "text": "hello"}],
|
||||
"usage": {
|
||||
"prompt_tokens": 50,
|
||||
"prompt_unit_price": 0.00001,
|
||||
"prompt_price_unit": 0.001,
|
||||
"prompt_price": 0.0005,
|
||||
"completion_tokens": 30,
|
||||
"completion_unit_price": 0.00003,
|
||||
"completion_price_unit": 0.001,
|
||||
"completion_price": 0.0009,
|
||||
"total_tokens": 80,
|
||||
"total_price": 0.0014,
|
||||
"currency": "USD",
|
||||
"latency": 2.3,
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
outputs_dict={
|
||||
"text": "world",
|
||||
"usage": {
|
||||
"prompt_tokens": 50,
|
||||
"completion_tokens": 30,
|
||||
"total_tokens": 80,
|
||||
"total_price": 0.0014,
|
||||
"currency": "USD",
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
execution_metadata_dict={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 80,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: 0.0014,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: "USD",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class TestDraftTraceTokenDataFlow:
|
||||
"""End-to-end test: verify all token and model fields survive from
|
||||
_build_node_execution_data through ops_trace_manager.draft_node_execution_trace
|
||||
to the DraftNodeExecutionTrace that enterprise_trace.py consumes.
|
||||
"""
|
||||
|
||||
def test_all_token_and_model_fields_reach_trace_info(self) -> None:
|
||||
"""Simulate the full draft trace data flow for an LLM node and
|
||||
assert every token/model field that enterprise_trace._emit_node_execution_trace
|
||||
reads is populated correctly on the resulting DraftNodeExecutionTrace."""
|
||||
from enterprise.telemetry.draft_trace import _build_node_execution_data
|
||||
|
||||
execution = _make_llm_execution()
|
||||
node_data = _build_node_execution_data(
|
||||
execution=execution,
|
||||
outputs=None,
|
||||
workflow_execution_id="run-flow",
|
||||
)
|
||||
|
||||
# Simulate what ops_trace_manager.draft_node_execution_trace does:
|
||||
# it calls node_execution_trace(node_execution_data=node_data) which
|
||||
# reads top-level keys from node_data. Verify all expected keys exist.
|
||||
expected_keys = {
|
||||
# Token fields — read by enterprise_trace._emit_node_execution_trace
|
||||
"total_tokens",
|
||||
"total_price",
|
||||
"currency",
|
||||
"prompt_tokens",
|
||||
"completion_tokens",
|
||||
# Model fields — read for span attrs and metric labels
|
||||
"model_provider",
|
||||
"model_name",
|
||||
# Node identity — read for span attrs
|
||||
"node_type",
|
||||
"node_execution_id",
|
||||
"node_id",
|
||||
"title",
|
||||
"status",
|
||||
"error",
|
||||
"elapsed_time",
|
||||
# Workflow context
|
||||
"workflow_id",
|
||||
"workflow_execution_id",
|
||||
"tenant_id",
|
||||
"app_id",
|
||||
# Structure fields
|
||||
"index",
|
||||
"predecessor_node_id",
|
||||
"iteration_id",
|
||||
"iteration_index",
|
||||
"loop_id",
|
||||
"loop_index",
|
||||
"parallel_id",
|
||||
# Tool field
|
||||
"tool_name",
|
||||
# Content fields
|
||||
"node_inputs",
|
||||
"node_outputs",
|
||||
"process_data",
|
||||
# Timestamps
|
||||
"created_at",
|
||||
"finished_at",
|
||||
}
|
||||
assert set(node_data.keys()) == expected_keys
|
||||
|
||||
# Verify token/model values are correct (not None/zero when data exists)
|
||||
assert node_data["total_tokens"] == 80
|
||||
assert node_data["total_price"] == 0.0014
|
||||
assert node_data["currency"] == "USD"
|
||||
assert node_data["prompt_tokens"] == 50
|
||||
assert node_data["completion_tokens"] == 30
|
||||
assert node_data["model_provider"] == "openai"
|
||||
assert node_data["model_name"] == "gpt-4o"
|
||||
assert node_data["node_type"] == "llm"
|
||||
|
||||
def test_non_llm_node_has_none_for_model_and_token_breakdown(self) -> None:
|
||||
"""For non-LLM nodes (e.g. code, IF), model and token breakdown
|
||||
should be None, but total_tokens from metadata should still work."""
|
||||
from enterprise.telemetry.draft_trace import _build_node_execution_data
|
||||
|
||||
execution = _make_execution(
|
||||
node_type="code",
|
||||
process_data_dict={"code": "print('hi')"},
|
||||
outputs_dict={"result": "hi"},
|
||||
execution_metadata_dict={},
|
||||
)
|
||||
result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None)
|
||||
|
||||
assert result["model_provider"] is None
|
||||
assert result["model_name"] is None
|
||||
assert result["prompt_tokens"] is None
|
||||
assert result["completion_tokens"] is None
|
||||
assert result["total_tokens"] == 0
|
||||
|
||||
def test_none_process_data_and_none_outputs(self) -> None:
|
||||
"""Both process_data_dict and outputs_dict are None — exercises
|
||||
the `or {}` fallback and isinstance guard together."""
|
||||
from enterprise.telemetry.draft_trace import _build_node_execution_data
|
||||
|
||||
execution = _make_execution(process_data_dict=None, outputs_dict=None)
|
||||
result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None)
|
||||
|
||||
assert result["model_provider"] is None
|
||||
assert result["model_name"] is None
|
||||
assert result["prompt_tokens"] is None
|
||||
assert result["completion_tokens"] is None
|
||||
|
||||
def test_node_data_feeds_into_draft_node_execution_trace(self) -> None:
|
||||
"""Verify the node_data dict can be consumed by
|
||||
ops_trace_manager.draft_node_execution_trace without error and
|
||||
produces a DraftNodeExecutionTrace with correct token/model fields."""
|
||||
|
||||
from enterprise.telemetry.draft_trace import _build_node_execution_data
|
||||
|
||||
execution = _make_llm_execution()
|
||||
node_data = _build_node_execution_data(
|
||||
execution=execution,
|
||||
outputs=None,
|
||||
workflow_execution_id="run-e2e",
|
||||
)
|
||||
|
||||
# Directly construct DraftNodeExecutionTrace the way
|
||||
# ops_trace_manager.node_execution_trace does (lines 1315-1350),
|
||||
# skipping DB lookups by providing minimal metadata.
|
||||
from core.ops.entities.trace_entity import DraftNodeExecutionTrace
|
||||
|
||||
trace_info = DraftNodeExecutionTrace(
|
||||
workflow_id=node_data.get("workflow_id", ""),
|
||||
workflow_run_id=node_data.get("workflow_execution_id", ""),
|
||||
tenant_id=node_data.get("tenant_id", ""),
|
||||
node_execution_id=node_data.get("node_execution_id", ""),
|
||||
node_id=node_data.get("node_id", ""),
|
||||
node_type=node_data.get("node_type", ""),
|
||||
title=node_data.get("title", ""),
|
||||
status=node_data.get("status", ""),
|
||||
error=node_data.get("error"),
|
||||
elapsed_time=node_data.get("elapsed_time", 0.0),
|
||||
index=node_data.get("index", 0),
|
||||
predecessor_node_id=node_data.get("predecessor_node_id"),
|
||||
total_tokens=node_data.get("total_tokens", 0),
|
||||
total_price=node_data.get("total_price", 0.0),
|
||||
currency=node_data.get("currency"),
|
||||
model_provider=node_data.get("model_provider"),
|
||||
model_name=node_data.get("model_name"),
|
||||
prompt_tokens=node_data.get("prompt_tokens"),
|
||||
completion_tokens=node_data.get("completion_tokens"),
|
||||
tool_name=node_data.get("tool_name"),
|
||||
iteration_id=node_data.get("iteration_id"),
|
||||
iteration_index=node_data.get("iteration_index"),
|
||||
loop_id=node_data.get("loop_id"),
|
||||
loop_index=node_data.get("loop_index"),
|
||||
parallel_id=node_data.get("parallel_id"),
|
||||
node_inputs=node_data.get("node_inputs"),
|
||||
node_outputs=node_data.get("node_outputs"),
|
||||
process_data=node_data.get("process_data"),
|
||||
start_time=node_data.get("created_at"),
|
||||
end_time=node_data.get("finished_at"),
|
||||
metadata={},
|
||||
)
|
||||
|
||||
# These are the fields enterprise_trace._emit_node_execution_trace reads
|
||||
assert trace_info.total_tokens == 80
|
||||
assert trace_info.prompt_tokens == 50
|
||||
assert trace_info.completion_tokens == 30
|
||||
assert trace_info.model_provider == "openai"
|
||||
assert trace_info.model_name == "gpt-4o"
|
||||
assert trace_info.node_type == "llm"
|
||||
assert trace_info.total_price == 0.0014
|
||||
assert trace_info.currency == "USD"
|
||||
1327
api/tests/unit_tests/enterprise/telemetry/test_enterprise_trace.py
Normal file
1327
api/tests/unit_tests/enterprise/telemetry/test_enterprise_trace.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,54 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from enterprise.telemetry import event_handlers
|
||||
from enterprise.telemetry.contracts import TelemetryCase
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_gateway_emit():
|
||||
with patch("core.telemetry.gateway.emit") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
def test_handle_app_created_calls_task(mock_gateway_emit):
|
||||
sender = MagicMock()
|
||||
sender.id = "app-123"
|
||||
sender.tenant_id = "tenant-456"
|
||||
sender.mode = "chat"
|
||||
|
||||
event_handlers._handle_app_created(sender)
|
||||
|
||||
mock_gateway_emit.assert_called_once_with(
|
||||
case=TelemetryCase.APP_CREATED,
|
||||
context={"tenant_id": "tenant-456"},
|
||||
payload={"app_id": "app-123", "mode": "chat"},
|
||||
)
|
||||
|
||||
|
||||
def test_handle_app_created_no_exporter(mock_gateway_emit):
|
||||
"""Gateway handles exporter availability internally; handler always calls gateway."""
|
||||
sender = MagicMock()
|
||||
sender.id = "app-123"
|
||||
sender.tenant_id = "tenant-456"
|
||||
|
||||
event_handlers._handle_app_created(sender)
|
||||
|
||||
mock_gateway_emit.assert_called_once()
|
||||
|
||||
|
||||
def test_handlers_create_valid_envelopes(mock_gateway_emit):
|
||||
"""Verify handlers pass correct TelemetryCase and payload structure."""
|
||||
sender = MagicMock()
|
||||
sender.id = "app-123"
|
||||
sender.tenant_id = "tenant-456"
|
||||
sender.mode = "chat"
|
||||
|
||||
event_handlers._handle_app_created(sender)
|
||||
|
||||
call_kwargs = mock_gateway_emit.call_args[1]
|
||||
assert call_kwargs["case"] == TelemetryCase.APP_CREATED
|
||||
assert call_kwargs["context"]["tenant_id"] == "tenant-456"
|
||||
assert call_kwargs["payload"]["app_id"] == "app-123"
|
||||
assert call_kwargs["payload"]["mode"] == "chat"
|
||||
628
api/tests/unit_tests/enterprise/telemetry/test_exporter.py
Normal file
628
api/tests/unit_tests/enterprise/telemetry/test_exporter.py
Normal file
@ -0,0 +1,628 @@
|
||||
"""Unit tests for EnterpriseExporter and _ExporterFactory."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from configs.enterprise import EnterpriseTelemetryConfig
|
||||
from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryHistogram
|
||||
from enterprise.telemetry.exporter import EnterpriseExporter, _datetime_to_ns, _parse_otlp_headers
|
||||
|
||||
|
||||
def test_config_api_key_default_empty():
|
||||
"""Test that ENTERPRISE_OTLP_API_KEY defaults to empty string."""
|
||||
config = EnterpriseTelemetryConfig()
|
||||
assert config.ENTERPRISE_OTLP_API_KEY == ""
|
||||
|
||||
|
||||
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
|
||||
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
|
||||
def test_api_key_only_injects_bearer_header(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None:
|
||||
"""Test that API key alone injects Bearer authorization header."""
|
||||
mock_config = SimpleNamespace(
|
||||
ENTERPRISE_OTLP_ENDPOINT="https://collector.example.com",
|
||||
ENTERPRISE_OTLP_HEADERS="",
|
||||
ENTERPRISE_OTLP_PROTOCOL="grpc",
|
||||
ENTERPRISE_SERVICE_NAME="dify",
|
||||
ENTERPRISE_OTEL_SAMPLING_RATE=1.0,
|
||||
ENTERPRISE_INCLUDE_CONTENT=True,
|
||||
ENTERPRISE_OTLP_API_KEY="test-secret-key",
|
||||
)
|
||||
|
||||
EnterpriseExporter(mock_config)
|
||||
|
||||
# Verify span exporter was called with Bearer header
|
||||
assert mock_span_exporter.call_args is not None
|
||||
headers = mock_span_exporter.call_args.kwargs.get("headers")
|
||||
assert headers is not None
|
||||
assert ("authorization", "Bearer test-secret-key") in headers
|
||||
|
||||
|
||||
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
|
||||
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
|
||||
def test_empty_api_key_no_auth_header(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None:
|
||||
"""Test that empty API key does not inject authorization header."""
|
||||
mock_config = SimpleNamespace(
|
||||
ENTERPRISE_OTLP_ENDPOINT="https://collector.example.com",
|
||||
ENTERPRISE_OTLP_HEADERS="",
|
||||
ENTERPRISE_OTLP_PROTOCOL="grpc",
|
||||
ENTERPRISE_SERVICE_NAME="dify",
|
||||
ENTERPRISE_OTEL_SAMPLING_RATE=1.0,
|
||||
ENTERPRISE_INCLUDE_CONTENT=True,
|
||||
ENTERPRISE_OTLP_API_KEY="",
|
||||
)
|
||||
|
||||
EnterpriseExporter(mock_config)
|
||||
|
||||
# Verify span exporter was called without authorization header
|
||||
assert mock_span_exporter.call_args is not None
|
||||
headers = mock_span_exporter.call_args.kwargs.get("headers")
|
||||
# Headers should be None or not contain authorization
|
||||
if headers is not None:
|
||||
assert not any(key == "authorization" for key, _ in headers)
|
||||
|
||||
|
||||
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
|
||||
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
|
||||
def test_api_key_and_custom_headers_merge(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None:
|
||||
"""Test that API key and custom headers are merged correctly."""
|
||||
mock_config = SimpleNamespace(
|
||||
ENTERPRISE_OTLP_ENDPOINT="https://collector.example.com",
|
||||
ENTERPRISE_OTLP_HEADERS="x-custom=foo",
|
||||
ENTERPRISE_OTLP_PROTOCOL="grpc",
|
||||
ENTERPRISE_SERVICE_NAME="dify",
|
||||
ENTERPRISE_OTEL_SAMPLING_RATE=1.0,
|
||||
ENTERPRISE_INCLUDE_CONTENT=True,
|
||||
ENTERPRISE_OTLP_API_KEY="test-key",
|
||||
)
|
||||
|
||||
EnterpriseExporter(mock_config)
|
||||
|
||||
# Verify both headers are present
|
||||
assert mock_span_exporter.call_args is not None
|
||||
headers = mock_span_exporter.call_args.kwargs.get("headers")
|
||||
assert headers is not None
|
||||
assert ("authorization", "Bearer test-key") in headers
|
||||
assert ("x-custom", "foo") in headers
|
||||
|
||||
|
||||
@patch("enterprise.telemetry.exporter.logger")
|
||||
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
|
||||
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
|
||||
def test_api_key_overrides_conflicting_header(
|
||||
mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock, mock_logger: MagicMock
|
||||
) -> None:
|
||||
"""Test that API key overrides conflicting authorization header and logs warning."""
|
||||
mock_config = SimpleNamespace(
|
||||
ENTERPRISE_OTLP_ENDPOINT="https://collector.example.com",
|
||||
ENTERPRISE_OTLP_HEADERS="authorization=Basic+old",
|
||||
ENTERPRISE_OTLP_PROTOCOL="grpc",
|
||||
ENTERPRISE_SERVICE_NAME="dify",
|
||||
ENTERPRISE_OTEL_SAMPLING_RATE=1.0,
|
||||
ENTERPRISE_INCLUDE_CONTENT=True,
|
||||
ENTERPRISE_OTLP_API_KEY="test-key",
|
||||
)
|
||||
|
||||
EnterpriseExporter(mock_config)
|
||||
|
||||
# Verify Bearer header takes precedence
|
||||
assert mock_span_exporter.call_args is not None
|
||||
headers = mock_span_exporter.call_args.kwargs.get("headers")
|
||||
assert headers is not None
|
||||
assert ("authorization", "Bearer test-key") in headers
|
||||
# Verify old authorization header is not present
|
||||
assert ("authorization", "Basic old") not in headers
|
||||
|
||||
# Verify warning was logged
|
||||
mock_logger.warning.assert_called_once()
|
||||
assert mock_logger.warning.call_args is not None
|
||||
warning_message = mock_logger.warning.call_args[0][0]
|
||||
assert "ENTERPRISE_OTLP_API_KEY is set" in warning_message
|
||||
assert "authorization" in warning_message
|
||||
|
||||
|
||||
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
|
||||
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
|
||||
def test_https_endpoint_uses_secure_grpc(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None:
|
||||
"""Test that https:// endpoint enables TLS (insecure=False) for gRPC."""
|
||||
mock_config = SimpleNamespace(
|
||||
ENTERPRISE_OTLP_ENDPOINT="https://collector.example.com",
|
||||
ENTERPRISE_OTLP_HEADERS="",
|
||||
ENTERPRISE_OTLP_PROTOCOL="grpc",
|
||||
ENTERPRISE_SERVICE_NAME="dify",
|
||||
ENTERPRISE_OTEL_SAMPLING_RATE=1.0,
|
||||
ENTERPRISE_INCLUDE_CONTENT=True,
|
||||
ENTERPRISE_OTLP_API_KEY="test-key",
|
||||
)
|
||||
|
||||
EnterpriseExporter(mock_config)
|
||||
|
||||
# Verify insecure=False for both exporters (https:// scheme)
|
||||
assert mock_span_exporter.call_args is not None
|
||||
assert mock_span_exporter.call_args.kwargs["insecure"] is False
|
||||
|
||||
assert mock_metric_exporter.call_args is not None
|
||||
assert mock_metric_exporter.call_args.kwargs["insecure"] is False
|
||||
|
||||
|
||||
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
|
||||
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
|
||||
def test_http_endpoint_uses_insecure_grpc(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None:
|
||||
"""Test that http:// endpoint uses insecure gRPC (insecure=True)."""
|
||||
mock_config = SimpleNamespace(
|
||||
ENTERPRISE_OTLP_ENDPOINT="http://collector.example.com",
|
||||
ENTERPRISE_OTLP_HEADERS="",
|
||||
ENTERPRISE_OTLP_PROTOCOL="grpc",
|
||||
ENTERPRISE_SERVICE_NAME="dify",
|
||||
ENTERPRISE_OTEL_SAMPLING_RATE=1.0,
|
||||
ENTERPRISE_INCLUDE_CONTENT=True,
|
||||
ENTERPRISE_OTLP_API_KEY="",
|
||||
)
|
||||
|
||||
EnterpriseExporter(mock_config)
|
||||
|
||||
# Verify insecure=True for both exporters (http:// scheme)
|
||||
assert mock_span_exporter.call_args is not None
|
||||
assert mock_span_exporter.call_args.kwargs["insecure"] is True
|
||||
|
||||
assert mock_metric_exporter.call_args is not None
|
||||
assert mock_metric_exporter.call_args.kwargs["insecure"] is True
|
||||
|
||||
|
||||
@patch("enterprise.telemetry.exporter.HTTPSpanExporter")
|
||||
@patch("enterprise.telemetry.exporter.HTTPMetricExporter")
|
||||
def test_insecure_not_passed_to_http_exporters(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None:
|
||||
"""Test that insecure parameter is not passed to HTTP exporters."""
|
||||
mock_config = SimpleNamespace(
|
||||
ENTERPRISE_OTLP_ENDPOINT="http://collector.example.com",
|
||||
ENTERPRISE_OTLP_HEADERS="",
|
||||
ENTERPRISE_OTLP_PROTOCOL="http",
|
||||
ENTERPRISE_SERVICE_NAME="dify",
|
||||
ENTERPRISE_OTEL_SAMPLING_RATE=1.0,
|
||||
ENTERPRISE_INCLUDE_CONTENT=True,
|
||||
ENTERPRISE_OTLP_API_KEY="test-key",
|
||||
)
|
||||
|
||||
EnterpriseExporter(mock_config)
|
||||
|
||||
# Verify insecure kwarg is NOT in HTTP exporter calls
|
||||
assert mock_span_exporter.call_args is not None
|
||||
assert "insecure" not in mock_span_exporter.call_args.kwargs
|
||||
|
||||
assert mock_metric_exporter.call_args is not None
|
||||
assert "insecure" not in mock_metric_exporter.call_args.kwargs
|
||||
|
||||
|
||||
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
|
||||
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
|
||||
def test_api_key_with_special_chars_preserved(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None:
|
||||
"""Test that API key with special characters is preserved without mangling."""
|
||||
special_key = "abc+def/ghi=jkl=="
|
||||
mock_config = SimpleNamespace(
|
||||
ENTERPRISE_OTLP_ENDPOINT="https://collector.example.com",
|
||||
ENTERPRISE_OTLP_HEADERS="",
|
||||
ENTERPRISE_OTLP_PROTOCOL="grpc",
|
||||
ENTERPRISE_SERVICE_NAME="dify",
|
||||
ENTERPRISE_OTEL_SAMPLING_RATE=1.0,
|
||||
ENTERPRISE_INCLUDE_CONTENT=True,
|
||||
ENTERPRISE_OTLP_API_KEY=special_key,
|
||||
)
|
||||
|
||||
EnterpriseExporter(mock_config)
|
||||
|
||||
# Verify special characters are preserved in Bearer header
|
||||
assert mock_span_exporter.call_args is not None
|
||||
headers = mock_span_exporter.call_args.kwargs.get("headers")
|
||||
assert headers is not None
|
||||
assert ("authorization", f"Bearer {special_key}") in headers
|
||||
|
||||
|
||||
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
|
||||
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
|
||||
def test_no_scheme_localhost_uses_insecure(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None:
|
||||
"""Test that endpoint without scheme defaults to insecure for localhost."""
|
||||
mock_config = SimpleNamespace(
|
||||
ENTERPRISE_OTLP_ENDPOINT="localhost:4317",
|
||||
ENTERPRISE_OTLP_HEADERS="",
|
||||
ENTERPRISE_OTLP_PROTOCOL="grpc",
|
||||
ENTERPRISE_SERVICE_NAME="dify",
|
||||
ENTERPRISE_OTEL_SAMPLING_RATE=1.0,
|
||||
ENTERPRISE_INCLUDE_CONTENT=True,
|
||||
ENTERPRISE_OTLP_API_KEY="",
|
||||
)
|
||||
|
||||
EnterpriseExporter(mock_config)
|
||||
|
||||
# Verify insecure=True for localhost without scheme
|
||||
assert mock_span_exporter.call_args is not None
|
||||
assert mock_span_exporter.call_args.kwargs["insecure"] is True
|
||||
|
||||
assert mock_metric_exporter.call_args is not None
|
||||
assert mock_metric_exporter.call_args.kwargs["insecure"] is True
|
||||
|
||||
|
||||
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
|
||||
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
|
||||
def test_no_scheme_production_uses_insecure(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None:
|
||||
"""Test that endpoint without scheme defaults to insecure (not https://)."""
|
||||
mock_config = SimpleNamespace(
|
||||
ENTERPRISE_OTLP_ENDPOINT="collector.example.com:4317",
|
||||
ENTERPRISE_OTLP_HEADERS="",
|
||||
ENTERPRISE_OTLP_PROTOCOL="grpc",
|
||||
ENTERPRISE_SERVICE_NAME="dify",
|
||||
ENTERPRISE_OTEL_SAMPLING_RATE=1.0,
|
||||
ENTERPRISE_INCLUDE_CONTENT=True,
|
||||
ENTERPRISE_OTLP_API_KEY="",
|
||||
)
|
||||
|
||||
EnterpriseExporter(mock_config)
|
||||
|
||||
# Verify insecure=True for any endpoint without https:// scheme
|
||||
assert mock_span_exporter.call_args is not None
|
||||
assert mock_span_exporter.call_args.kwargs["insecure"] is True
|
||||
|
||||
assert mock_metric_exporter.call_args is not None
|
||||
assert mock_metric_exporter.call_args.kwargs["insecure"] is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _parse_otlp_headers (line 55 — pair without "=" is skipped)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_parse_otlp_headers_empty_returns_empty_dict() -> None:
|
||||
assert _parse_otlp_headers("") == {}
|
||||
|
||||
|
||||
def test_parse_otlp_headers_value_may_contain_equals() -> None:
|
||||
result = _parse_otlp_headers("token=abc=def==")
|
||||
assert result == {"token": "abc=def=="}
|
||||
|
||||
|
||||
def test_parse_otlp_headers_url_encoded() -> None:
|
||||
result = _parse_otlp_headers("key=%E4%BD%A0%E5%A5%BD")
|
||||
|
||||
assert result == {"key": "你好"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _datetime_to_ns (lines 64-68)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_datetime_to_ns_naive_treated_as_utc() -> None:
|
||||
"""Naive datetime must be interpreted as UTC (line 64-65)."""
|
||||
naive = datetime(2024, 1, 1, 0, 0, 0) # no tzinfo
|
||||
aware_utc = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC)
|
||||
assert _datetime_to_ns(naive) == _datetime_to_ns(aware_utc)
|
||||
|
||||
|
||||
def test_datetime_to_ns_tz_aware_converted_to_utc() -> None:
|
||||
"""Timezone-aware datetime must be converted to UTC before computing ns (line 66-67)."""
|
||||
import zoneinfo
|
||||
|
||||
eastern = zoneinfo.ZoneInfo("America/New_York")
|
||||
dt_east = datetime(2024, 6, 1, 12, 0, 0, tzinfo=eastern) # UTC-4 in summer
|
||||
dt_utc = dt_east.astimezone(UTC)
|
||||
assert _datetime_to_ns(dt_east) == _datetime_to_ns(dt_utc)
|
||||
|
||||
|
||||
def test_datetime_to_ns_returns_integer_nanoseconds() -> None:
|
||||
dt = datetime(2024, 1, 1, 0, 0, 1, tzinfo=UTC)
|
||||
result = _datetime_to_ns(dt)
|
||||
# 2024-01-01 00:00:01 UTC = epoch + some_seconds; result should be in nanoseconds
|
||||
assert isinstance(result, int)
|
||||
# 1 second past epoch start of 2024 — should be > 1_700_000_000_000_000_000 (rough lower bound)
|
||||
assert result > 1_700_000_000_000_000_000
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EnterpriseExporter constructor — include_content property (line 115 / 288-289)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_grpc_config(**overrides) -> SimpleNamespace:
|
||||
defaults = {
|
||||
"ENTERPRISE_OTLP_ENDPOINT": "https://collector.example.com",
|
||||
"ENTERPRISE_OTLP_HEADERS": "",
|
||||
"ENTERPRISE_OTLP_PROTOCOL": "grpc",
|
||||
"ENTERPRISE_SERVICE_NAME": "dify",
|
||||
"ENTERPRISE_OTEL_SAMPLING_RATE": 1.0,
|
||||
"ENTERPRISE_INCLUDE_CONTENT": True,
|
||||
"ENTERPRISE_OTLP_API_KEY": "",
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return SimpleNamespace(**defaults)
|
||||
|
||||
|
||||
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
|
||||
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
|
||||
def test_include_content_true_stored_on_exporter(
|
||||
mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock
|
||||
) -> None:
|
||||
"""include_content=True is stored as a public attribute (line 115)."""
|
||||
exporter = EnterpriseExporter(_make_grpc_config(ENTERPRISE_INCLUDE_CONTENT=True))
|
||||
assert exporter.include_content is True
|
||||
|
||||
|
||||
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
|
||||
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
|
||||
def test_include_content_false_stored_on_exporter(
|
||||
mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock
|
||||
) -> None:
|
||||
"""include_content=False is preserved (lines 288-289 path exercised by callers)."""
|
||||
exporter = EnterpriseExporter(_make_grpc_config(ENTERPRISE_INCLUDE_CONTENT=False))
|
||||
assert exporter.include_content is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EnterpriseExporter constructor — gRPC setup (lines 64-68 exporter-init path)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
|
||||
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
|
||||
def test_grpc_exporter_created_with_correct_endpoint(
|
||||
mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock
|
||||
) -> None:
|
||||
"""GRPCSpanExporter and GRPCMetricExporter receive the configured endpoint."""
|
||||
EnterpriseExporter(_make_grpc_config(ENTERPRISE_OTLP_ENDPOINT="https://my-collector:4317"))
|
||||
|
||||
assert mock_span_exporter.call_args.kwargs["endpoint"] == "https://my-collector:4317"
|
||||
assert mock_metric_exporter.call_args.kwargs["endpoint"] == "https://my-collector:4317"
|
||||
|
||||
|
||||
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
|
||||
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
|
||||
def test_grpc_exporter_empty_endpoint_passes_none(
|
||||
mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock
|
||||
) -> None:
|
||||
"""Empty string endpoint is normalised to None for both gRPC exporters."""
|
||||
EnterpriseExporter(_make_grpc_config(ENTERPRISE_OTLP_ENDPOINT=""))
|
||||
|
||||
assert mock_span_exporter.call_args.kwargs["endpoint"] is None
|
||||
assert mock_metric_exporter.call_args.kwargs["endpoint"] is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EnterpriseExporter.export_span (lines 204-271)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_exporter_with_mock_tracer() -> tuple[EnterpriseExporter, MagicMock, MagicMock]:
|
||||
"""Return (exporter, mock_tracer, mock_span) with OTEL internals fully mocked."""
|
||||
mock_span = MagicMock()
|
||||
mock_span.__enter__ = MagicMock(return_value=mock_span)
|
||||
mock_span.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mock_tracer = MagicMock()
|
||||
mock_tracer.start_as_current_span.return_value = mock_span
|
||||
|
||||
with (
|
||||
patch("enterprise.telemetry.exporter.GRPCSpanExporter"),
|
||||
patch("enterprise.telemetry.exporter.GRPCMetricExporter"),
|
||||
):
|
||||
exporter = EnterpriseExporter(_make_grpc_config())
|
||||
|
||||
exporter._tracer = mock_tracer
|
||||
return exporter, mock_tracer, mock_span
|
||||
|
||||
|
||||
@patch("enterprise.telemetry.exporter.set_correlation_id")
|
||||
@patch("enterprise.telemetry.exporter.set_span_id_source")
|
||||
def test_export_span_sets_and_clears_context(mock_set_span: MagicMock, mock_set_corr: MagicMock) -> None:
|
||||
"""export_span sets correlation/span context before the span and clears them in finally."""
|
||||
exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer()
|
||||
|
||||
exporter.export_span(
|
||||
name="test.span",
|
||||
attributes={"k": "v"},
|
||||
correlation_id="corr-1",
|
||||
span_id_source="span-src-1",
|
||||
)
|
||||
|
||||
# Context was set at the start of the call
|
||||
mock_set_corr.assert_any_call("corr-1")
|
||||
mock_set_span.assert_any_call("span-src-1")
|
||||
# Context was cleared in finally
|
||||
mock_set_corr.assert_called_with(None)
|
||||
mock_set_span.assert_called_with(None)
|
||||
|
||||
|
||||
def test_export_span_sets_attributes_on_span() -> None:
|
||||
"""All non-None attribute values are set on the span via set_attribute."""
|
||||
exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer()
|
||||
|
||||
exporter.export_span(
|
||||
name="test.span",
|
||||
attributes={"key1": "value1", "key2": None, "key3": 42},
|
||||
)
|
||||
|
||||
# set_attribute should be called for non-None values only
|
||||
calls = list(mock_span.set_attribute.call_args_list)
|
||||
keys_set = {c[0][0] for c in calls}
|
||||
assert "key1" in keys_set
|
||||
assert "key3" in keys_set
|
||||
assert "key2" not in keys_set
|
||||
|
||||
|
||||
def test_export_span_no_end_time_uses_end_on_exit() -> None:
|
||||
"""When end_time is None, end_on_exit=True is passed to start_as_current_span."""
|
||||
exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer()
|
||||
|
||||
exporter.export_span(name="test.span", attributes={})
|
||||
|
||||
_, kwargs = mock_tracer.start_as_current_span.call_args
|
||||
assert kwargs["end_on_exit"] is True
|
||||
|
||||
|
||||
def test_export_span_with_end_time_calls_span_end() -> None:
|
||||
"""When end_time is provided, span.end() is called with the converted ns timestamp."""
|
||||
exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer()
|
||||
|
||||
start = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC)
|
||||
end = datetime(2024, 1, 1, 0, 0, 5, tzinfo=UTC)
|
||||
|
||||
exporter.export_span(name="test.span", attributes={}, start_time=start, end_time=end)
|
||||
|
||||
mock_span.end.assert_called_once()
|
||||
end_ns = mock_span.end.call_args.kwargs["end_time"]
|
||||
assert end_ns == _datetime_to_ns(end)
|
||||
|
||||
|
||||
def test_export_span_with_start_time_passed_to_start_as_current_span() -> None:
|
||||
"""When start_time is provided it is converted to ns and passed to start_as_current_span."""
|
||||
exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer()
|
||||
|
||||
start = datetime(2024, 3, 1, 12, 0, 0, tzinfo=UTC)
|
||||
exporter.export_span(name="test.span", attributes={}, start_time=start)
|
||||
|
||||
_, kwargs = mock_tracer.start_as_current_span.call_args
|
||||
assert kwargs["start_time"] == _datetime_to_ns(start)
|
||||
|
||||
|
||||
def test_export_span_root_span_no_parent_context() -> None:
|
||||
"""When span_id_source == correlation_id the span is root — no parent context."""
|
||||
exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer()
|
||||
|
||||
uid = "123e4567-e89b-12d3-a456-426614174000"
|
||||
exporter.export_span(
|
||||
name="root.span",
|
||||
attributes={},
|
||||
correlation_id=uid,
|
||||
span_id_source=uid,
|
||||
)
|
||||
|
||||
_, kwargs = mock_tracer.start_as_current_span.call_args
|
||||
assert kwargs["context"] is None
|
||||
|
||||
|
||||
def test_export_span_child_span_has_parent_context() -> None:
|
||||
"""When correlation_id != span_id_source the child span gets a parent context."""
|
||||
exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer()
|
||||
|
||||
corr_uid = "123e4567-e89b-12d3-a456-426614174000"
|
||||
node_uid = "987fbc97-4bed-5078-9f07-9141ba07c9f3"
|
||||
|
||||
exporter.export_span(
|
||||
name="child.span",
|
||||
attributes={},
|
||||
correlation_id=corr_uid,
|
||||
span_id_source=node_uid,
|
||||
)
|
||||
|
||||
_, kwargs = mock_tracer.start_as_current_span.call_args
|
||||
assert kwargs["context"] is not None
|
||||
|
||||
|
||||
def test_export_span_cross_workflow_parent_context() -> None:
|
||||
"""When parent_span_id_source is set, the cross-workflow parent context is built."""
|
||||
exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer()
|
||||
|
||||
corr_uid = "123e4567-e89b-12d3-a456-426614174000"
|
||||
parent_uid = "987fbc97-4bed-5078-9f07-9141ba07c9f3"
|
||||
|
||||
exporter.export_span(
|
||||
name="cross.span",
|
||||
attributes={},
|
||||
correlation_id=corr_uid,
|
||||
parent_span_id_source=parent_uid,
|
||||
)
|
||||
|
||||
_, kwargs = mock_tracer.start_as_current_span.call_args
|
||||
assert kwargs["context"] is not None
|
||||
|
||||
|
||||
@patch("enterprise.telemetry.exporter.logger")
|
||||
def test_export_span_logs_exception_on_error(mock_logger: MagicMock) -> None:
|
||||
"""If the span block raises, the exception is logged and context is still cleared."""
|
||||
exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer()
|
||||
|
||||
mock_tracer.start_as_current_span.side_effect = RuntimeError("boom")
|
||||
|
||||
exporter.export_span(name="bad.span", attributes={}) # must not raise
|
||||
|
||||
mock_logger.exception.assert_called_once()
|
||||
assert "bad.span" in mock_logger.exception.call_args[0][1]
|
||||
|
||||
|
||||
@patch("enterprise.telemetry.exporter.logger")
|
||||
def test_export_span_invalid_trace_correlation_logs_warning(mock_logger: MagicMock) -> None:
|
||||
"""Invalid UUID for trace_correlation_override triggers a warning log."""
|
||||
exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer()
|
||||
|
||||
parent_uid = "987fbc97-4bed-5078-9f07-9141ba07c9f3"
|
||||
exporter.export_span(
|
||||
name="link.span",
|
||||
attributes={},
|
||||
correlation_id="not-a-valid-uuid",
|
||||
parent_span_id_source=parent_uid,
|
||||
)
|
||||
|
||||
mock_logger.warning.assert_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EnterpriseExporter.increment_counter (lines 276-278)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
|
||||
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
|
||||
def test_increment_counter_calls_add_on_counter(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None:
|
||||
"""increment_counter calls .add() on the matching counter instrument."""
|
||||
exporter = EnterpriseExporter(_make_grpc_config())
|
||||
|
||||
mock_counter = MagicMock()
|
||||
exporter._counters[EnterpriseTelemetryCounter.TOKENS] = mock_counter
|
||||
|
||||
labels = {"tenant_id": "t1", "app_id": "app-1"}
|
||||
exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, 50, labels)
|
||||
|
||||
mock_counter.add.assert_called_once_with(50, labels)
|
||||
|
||||
|
||||
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
|
||||
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
|
||||
def test_increment_counter_unknown_name_is_noop(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None:
|
||||
"""increment_counter silently does nothing when the counter is not found."""
|
||||
exporter = EnterpriseExporter(_make_grpc_config())
|
||||
exporter._counters.clear()
|
||||
|
||||
# Should not raise
|
||||
exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, 5, {})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EnterpriseExporter.record_histogram (lines 283-285)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
|
||||
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
|
||||
def test_record_histogram_calls_record_on_histogram(
|
||||
mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock
|
||||
) -> None:
|
||||
"""record_histogram calls .record() on the matching histogram instrument."""
|
||||
exporter = EnterpriseExporter(_make_grpc_config())
|
||||
|
||||
mock_histogram = MagicMock()
|
||||
exporter._histograms[EnterpriseTelemetryHistogram.WORKFLOW_DURATION] = mock_histogram
|
||||
|
||||
labels = {"tenant_id": "t1"}
|
||||
exporter.record_histogram(EnterpriseTelemetryHistogram.WORKFLOW_DURATION, 3.14, labels)
|
||||
|
||||
mock_histogram.record.assert_called_once_with(3.14, labels)
|
||||
|
||||
|
||||
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
|
||||
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
|
||||
def test_record_histogram_unknown_name_is_noop(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None:
|
||||
"""record_histogram silently does nothing when the histogram is not found."""
|
||||
exporter = EnterpriseExporter(_make_grpc_config())
|
||||
exporter._histograms.clear()
|
||||
|
||||
# Should not raise
|
||||
exporter.record_histogram(EnterpriseTelemetryHistogram.WORKFLOW_DURATION, 1.0, {})
|
||||
272
api/tests/unit_tests/enterprise/telemetry/test_gateway.py
Normal file
272
api/tests/unit_tests/enterprise/telemetry/test_gateway.py
Normal file
@ -0,0 +1,272 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.telemetry.gateway import (
|
||||
CASE_ROUTING,
|
||||
CASE_TO_TRACE_TASK,
|
||||
PAYLOAD_SIZE_THRESHOLD_BYTES,
|
||||
emit,
|
||||
)
|
||||
from enterprise.telemetry.contracts import SignalType, TelemetryCase, TelemetryEnvelope
|
||||
|
||||
|
||||
class TestCaseRoutingTable:
|
||||
def test_all_cases_have_routing(self) -> None:
|
||||
for case in TelemetryCase:
|
||||
assert case in CASE_ROUTING, f"Missing routing for {case}"
|
||||
|
||||
def test_trace_cases(self) -> None:
|
||||
trace_cases = [
|
||||
TelemetryCase.WORKFLOW_RUN,
|
||||
TelemetryCase.MESSAGE_RUN,
|
||||
TelemetryCase.NODE_EXECUTION,
|
||||
TelemetryCase.DRAFT_NODE_EXECUTION,
|
||||
TelemetryCase.PROMPT_GENERATION,
|
||||
]
|
||||
for case in trace_cases:
|
||||
assert CASE_ROUTING[case].signal_type is SignalType.TRACE, f"{case} should be trace"
|
||||
|
||||
def test_metric_log_cases(self) -> None:
|
||||
metric_log_cases = [
|
||||
TelemetryCase.APP_CREATED,
|
||||
TelemetryCase.APP_UPDATED,
|
||||
TelemetryCase.APP_DELETED,
|
||||
TelemetryCase.FEEDBACK_CREATED,
|
||||
]
|
||||
for case in metric_log_cases:
|
||||
assert CASE_ROUTING[case].signal_type is SignalType.METRIC_LOG, f"{case} should be metric_log"
|
||||
|
||||
def test_ce_eligible_cases(self) -> None:
|
||||
ce_eligible_cases = [
|
||||
TelemetryCase.WORKFLOW_RUN,
|
||||
TelemetryCase.MESSAGE_RUN,
|
||||
TelemetryCase.TOOL_EXECUTION,
|
||||
TelemetryCase.MODERATION_CHECK,
|
||||
TelemetryCase.SUGGESTED_QUESTION,
|
||||
TelemetryCase.DATASET_RETRIEVAL,
|
||||
TelemetryCase.GENERATE_NAME,
|
||||
]
|
||||
for case in ce_eligible_cases:
|
||||
assert CASE_ROUTING[case].ce_eligible is True, f"{case} should be CE eligible"
|
||||
|
||||
def test_enterprise_only_cases(self) -> None:
|
||||
enterprise_only_cases = [
|
||||
TelemetryCase.NODE_EXECUTION,
|
||||
TelemetryCase.DRAFT_NODE_EXECUTION,
|
||||
TelemetryCase.PROMPT_GENERATION,
|
||||
]
|
||||
for case in enterprise_only_cases:
|
||||
assert CASE_ROUTING[case].ce_eligible is False, f"{case} should be enterprise-only"
|
||||
|
||||
def test_trace_cases_have_task_name_mapping(self) -> None:
|
||||
trace_cases = [c for c in TelemetryCase if CASE_ROUTING[c].signal_type is SignalType.TRACE]
|
||||
for case in trace_cases:
|
||||
assert case in CASE_TO_TRACE_TASK, f"Missing TraceTaskName mapping for {case}"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ops_trace_manager():
|
||||
mock_module = MagicMock()
|
||||
mock_trace_task_class = MagicMock()
|
||||
mock_trace_task_class.return_value = MagicMock()
|
||||
mock_module.TraceTask = mock_trace_task_class
|
||||
mock_module.TraceQueueManager = MagicMock()
|
||||
|
||||
mock_trace_entity = MagicMock()
|
||||
mock_trace_task_name = MagicMock()
|
||||
mock_trace_task_name.return_value = "workflow"
|
||||
mock_trace_entity.TraceTaskName = mock_trace_task_name
|
||||
|
||||
with (
|
||||
patch.dict(sys.modules, {"core.ops.ops_trace_manager": mock_module}),
|
||||
patch.dict(sys.modules, {"core.ops.entities.trace_entity": mock_trace_entity}),
|
||||
):
|
||||
yield mock_module, mock_trace_entity
|
||||
|
||||
|
||||
class TestGatewayTraceRouting:
|
||||
@pytest.fixture
|
||||
def mock_trace_manager(self) -> MagicMock:
|
||||
return MagicMock()
|
||||
|
||||
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
|
||||
def test_trace_case_routes_to_trace_manager(
|
||||
self,
|
||||
mock_ee_enabled: MagicMock,
|
||||
mock_trace_manager: MagicMock,
|
||||
mock_ops_trace_manager: tuple[MagicMock, MagicMock],
|
||||
) -> None:
|
||||
context = {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"}
|
||||
payload = {"workflow_run_id": "run-abc"}
|
||||
|
||||
emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_called_once()
|
||||
|
||||
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False)
|
||||
def test_ce_eligible_trace_enqueued_when_ee_disabled(
|
||||
self,
|
||||
mock_ee_enabled: MagicMock,
|
||||
mock_trace_manager: MagicMock,
|
||||
mock_ops_trace_manager: tuple[MagicMock, MagicMock],
|
||||
) -> None:
|
||||
context = {"app_id": "app-123", "user_id": "user-456"}
|
||||
payload = {"workflow_run_id": "run-abc"}
|
||||
|
||||
emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_called_once()
|
||||
|
||||
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False)
|
||||
def test_enterprise_only_trace_dropped_when_ee_disabled(
|
||||
self,
|
||||
mock_ee_enabled: MagicMock,
|
||||
mock_trace_manager: MagicMock,
|
||||
mock_ops_trace_manager: tuple[MagicMock, MagicMock],
|
||||
) -> None:
|
||||
context = {"app_id": "app-123", "user_id": "user-456"}
|
||||
payload = {"node_id": "node-abc"}
|
||||
|
||||
emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_not_called()
|
||||
|
||||
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
|
||||
def test_enterprise_only_trace_enqueued_when_ee_enabled(
|
||||
self,
|
||||
mock_ee_enabled: MagicMock,
|
||||
mock_trace_manager: MagicMock,
|
||||
mock_ops_trace_manager: tuple[MagicMock, MagicMock],
|
||||
) -> None:
|
||||
context = {"app_id": "app-123", "user_id": "user-456"}
|
||||
payload = {"node_id": "node-abc"}
|
||||
|
||||
emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_called_once()
|
||||
|
||||
|
||||
class TestGatewayMetricLogRouting:
|
||||
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
|
||||
@patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay")
|
||||
def test_metric_case_routes_to_celery_task(
|
||||
self,
|
||||
mock_delay: MagicMock,
|
||||
mock_ee_enabled: MagicMock,
|
||||
) -> None:
|
||||
context = {"tenant_id": "tenant-123"}
|
||||
payload = {"app_id": "app-abc", "name": "My App"}
|
||||
|
||||
emit(TelemetryCase.APP_CREATED, context, payload)
|
||||
|
||||
mock_delay.assert_called_once()
|
||||
envelope_json = mock_delay.call_args[0][0]
|
||||
envelope = TelemetryEnvelope.model_validate_json(envelope_json)
|
||||
assert envelope.case == TelemetryCase.APP_CREATED
|
||||
assert envelope.tenant_id == "tenant-123"
|
||||
assert envelope.payload["app_id"] == "app-abc"
|
||||
|
||||
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
|
||||
@patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay")
|
||||
def test_envelope_has_unique_event_id(
|
||||
self,
|
||||
mock_delay: MagicMock,
|
||||
mock_ee_enabled: MagicMock,
|
||||
) -> None:
|
||||
context = {"tenant_id": "tenant-123"}
|
||||
payload = {"app_id": "app-abc"}
|
||||
|
||||
emit(TelemetryCase.APP_CREATED, context, payload)
|
||||
emit(TelemetryCase.APP_CREATED, context, payload)
|
||||
|
||||
assert mock_delay.call_count == 2
|
||||
envelope1 = TelemetryEnvelope.model_validate_json(mock_delay.call_args_list[0][0][0])
|
||||
envelope2 = TelemetryEnvelope.model_validate_json(mock_delay.call_args_list[1][0][0])
|
||||
assert envelope1.event_id != envelope2.event_id
|
||||
|
||||
|
||||
class TestGatewayPayloadSizing:
|
||||
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
|
||||
@patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay")
|
||||
def test_small_payload_inlined(
|
||||
self,
|
||||
mock_delay: MagicMock,
|
||||
mock_ee_enabled: MagicMock,
|
||||
) -> None:
|
||||
context = {"tenant_id": "tenant-123"}
|
||||
payload = {"key": "small_value"}
|
||||
|
||||
emit(TelemetryCase.APP_CREATED, context, payload)
|
||||
|
||||
envelope_json = mock_delay.call_args[0][0]
|
||||
envelope = TelemetryEnvelope.model_validate_json(envelope_json)
|
||||
assert envelope.payload == payload
|
||||
assert envelope.metadata is None
|
||||
|
||||
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
|
||||
@patch("core.telemetry.gateway.storage")
|
||||
@patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay")
|
||||
def test_large_payload_stored(
|
||||
self,
|
||||
mock_delay: MagicMock,
|
||||
mock_storage: MagicMock,
|
||||
mock_ee_enabled: MagicMock,
|
||||
) -> None:
|
||||
context = {"tenant_id": "tenant-123"}
|
||||
large_value = "x" * (PAYLOAD_SIZE_THRESHOLD_BYTES + 1000)
|
||||
payload = {"key": large_value}
|
||||
|
||||
emit(TelemetryCase.APP_CREATED, context, payload)
|
||||
|
||||
mock_storage.save.assert_called_once()
|
||||
storage_key = mock_storage.save.call_args[0][0]
|
||||
assert storage_key.startswith("telemetry/tenant-123/")
|
||||
|
||||
envelope_json = mock_delay.call_args[0][0]
|
||||
envelope = TelemetryEnvelope.model_validate_json(envelope_json)
|
||||
assert envelope.payload == {}
|
||||
assert envelope.metadata is not None
|
||||
assert envelope.metadata["payload_ref"] == storage_key
|
||||
|
||||
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
|
||||
@patch("core.telemetry.gateway.storage")
|
||||
@patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay")
|
||||
def test_large_payload_fallback_on_storage_error(
|
||||
self,
|
||||
mock_delay: MagicMock,
|
||||
mock_storage: MagicMock,
|
||||
mock_ee_enabled: MagicMock,
|
||||
) -> None:
|
||||
mock_storage.save.side_effect = Exception("Storage failure")
|
||||
context = {"tenant_id": "tenant-123"}
|
||||
large_value = "x" * (PAYLOAD_SIZE_THRESHOLD_BYTES + 1000)
|
||||
payload = {"key": large_value}
|
||||
|
||||
emit(TelemetryCase.APP_CREATED, context, payload)
|
||||
|
||||
envelope_json = mock_delay.call_args[0][0]
|
||||
envelope = TelemetryEnvelope.model_validate_json(envelope_json)
|
||||
assert envelope.payload == payload
|
||||
assert envelope.metadata is None
|
||||
|
||||
|
||||
class TestTraceTaskNameMapping:
|
||||
def test_workflow_run_mapping(self) -> None:
|
||||
assert CASE_TO_TRACE_TASK[TelemetryCase.WORKFLOW_RUN] is TraceTaskName.WORKFLOW_TRACE
|
||||
|
||||
def test_message_run_mapping(self) -> None:
|
||||
assert CASE_TO_TRACE_TASK[TelemetryCase.MESSAGE_RUN] is TraceTaskName.MESSAGE_TRACE
|
||||
|
||||
def test_node_execution_mapping(self) -> None:
|
||||
assert CASE_TO_TRACE_TASK[TelemetryCase.NODE_EXECUTION] is TraceTaskName.NODE_EXECUTION_TRACE
|
||||
|
||||
def test_draft_node_execution_mapping(self) -> None:
|
||||
assert CASE_TO_TRACE_TASK[TelemetryCase.DRAFT_NODE_EXECUTION] is TraceTaskName.DRAFT_NODE_EXECUTION_TRACE
|
||||
|
||||
def test_prompt_generation_mapping(self) -> None:
|
||||
assert CASE_TO_TRACE_TASK[TelemetryCase.PROMPT_GENERATION] is TraceTaskName.PROMPT_GENERATION_TRACE
|
||||
201
api/tests/unit_tests/enterprise/telemetry/test_id_generator.py
Normal file
201
api/tests/unit_tests/enterprise/telemetry/test_id_generator.py
Normal file
@ -0,0 +1,201 @@
|
||||
"""Unit tests for enterprise/telemetry/id_generator.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# compute_deterministic_span_id
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestComputeDeterministicSpanId:
|
||||
def test_returns_lower_64_bits_of_uuid(self) -> None:
|
||||
from enterprise.telemetry.id_generator import compute_deterministic_span_id
|
||||
|
||||
uid = "123e4567-e89b-12d3-a456-426614174000"
|
||||
expected = uuid.UUID(uid).int & ((1 << 64) - 1)
|
||||
assert compute_deterministic_span_id(uid) == expected
|
||||
|
||||
def test_non_zero_result_returned_unchanged(self) -> None:
|
||||
from enterprise.telemetry.id_generator import compute_deterministic_span_id
|
||||
|
||||
# This UUID has non-zero lower 64 bits
|
||||
uid = "123e4567-e89b-12d3-a456-426614174000"
|
||||
result = compute_deterministic_span_id(uid)
|
||||
assert result != 0
|
||||
|
||||
def test_zero_lower_bits_returns_one(self) -> None:
|
||||
"""When the lower 64 bits of the UUID int are 0, the function must return 1 (OTEL requirement)."""
|
||||
from enterprise.telemetry.id_generator import compute_deterministic_span_id
|
||||
|
||||
# Craft a UUID whose lower 64 bits are 0: upper 64 bits are 1, lower 64 bits are 0.
|
||||
# int = (1 << 64), UUID fields constructed to produce this integer.
|
||||
target_int = 1 << 64 # lower 64 bits are 0x0000000000000000
|
||||
crafted_uuid = uuid.UUID(int=target_int)
|
||||
result = compute_deterministic_span_id(str(crafted_uuid))
|
||||
assert result == 1
|
||||
|
||||
def test_raises_on_invalid_uuid(self) -> None:
|
||||
import pytest
|
||||
|
||||
from enterprise.telemetry.id_generator import compute_deterministic_span_id
|
||||
|
||||
with pytest.raises((ValueError, AttributeError)):
|
||||
compute_deterministic_span_id("not-a-uuid")
|
||||
|
||||
def test_different_uuids_produce_different_span_ids(self) -> None:
|
||||
from enterprise.telemetry.id_generator import compute_deterministic_span_id
|
||||
|
||||
uid1 = "123e4567-e89b-12d3-a456-426614174000"
|
||||
uid2 = "987fbc97-4bed-5078-9f07-9141ba07c9f3"
|
||||
assert compute_deterministic_span_id(uid1) != compute_deterministic_span_id(uid2)
|
||||
|
||||
def test_deterministic_same_input_same_output(self) -> None:
|
||||
from enterprise.telemetry.id_generator import compute_deterministic_span_id
|
||||
|
||||
uid = "123e4567-e89b-12d3-a456-426614174000"
|
||||
assert compute_deterministic_span_id(uid) == compute_deterministic_span_id(uid)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Context variable helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestContextVariableHelpers:
|
||||
def test_set_and_get_correlation_id(self) -> None:
|
||||
from enterprise.telemetry.id_generator import get_correlation_id, set_correlation_id
|
||||
|
||||
set_correlation_id("corr-123")
|
||||
assert get_correlation_id() == "corr-123"
|
||||
|
||||
def test_clear_correlation_id(self) -> None:
|
||||
from enterprise.telemetry.id_generator import get_correlation_id, set_correlation_id
|
||||
|
||||
set_correlation_id("corr-abc")
|
||||
set_correlation_id(None)
|
||||
assert get_correlation_id() is None
|
||||
|
||||
def test_correlation_id_default_is_none(self) -> None:
|
||||
from enterprise.telemetry.id_generator import get_correlation_id, set_correlation_id
|
||||
|
||||
set_correlation_id(None)
|
||||
assert get_correlation_id() is None
|
||||
|
||||
def test_set_span_id_source_stored_in_context(self) -> None:
|
||||
from enterprise.telemetry.id_generator import _span_id_source_context, set_span_id_source
|
||||
|
||||
set_span_id_source("span-src-1")
|
||||
assert _span_id_source_context.get() == "span-src-1"
|
||||
|
||||
def test_clear_span_id_source(self) -> None:
|
||||
from enterprise.telemetry.id_generator import _span_id_source_context, set_span_id_source
|
||||
|
||||
set_span_id_source("span-src-1")
|
||||
set_span_id_source(None)
|
||||
assert _span_id_source_context.get() is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CorrelationIdGenerator.generate_trace_id
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCorrelationIdGeneratorGenerateTraceId:
|
||||
def setup_method(self) -> None:
|
||||
from enterprise.telemetry.id_generator import set_correlation_id
|
||||
|
||||
set_correlation_id(None)
|
||||
|
||||
def test_returns_uuid_int_when_correlation_id_set(self) -> None:
|
||||
from enterprise.telemetry.id_generator import CorrelationIdGenerator, set_correlation_id
|
||||
|
||||
uid = "123e4567-e89b-12d3-a456-426614174000"
|
||||
set_correlation_id(uid)
|
||||
gen = CorrelationIdGenerator()
|
||||
trace_id = gen.generate_trace_id()
|
||||
assert trace_id == uuid.UUID(uid).int
|
||||
|
||||
def test_returns_random_when_no_correlation_id(self) -> None:
|
||||
from enterprise.telemetry.id_generator import CorrelationIdGenerator, set_correlation_id
|
||||
|
||||
set_correlation_id(None)
|
||||
gen = CorrelationIdGenerator()
|
||||
# Should return a non-zero int without raising
|
||||
trace_id = gen.generate_trace_id()
|
||||
assert isinstance(trace_id, int)
|
||||
assert trace_id > 0
|
||||
|
||||
def test_returns_random_when_correlation_id_is_invalid_uuid(self) -> None:
|
||||
from enterprise.telemetry.id_generator import CorrelationIdGenerator, set_correlation_id
|
||||
|
||||
set_correlation_id("not-a-valid-uuid")
|
||||
gen = CorrelationIdGenerator()
|
||||
with patch("enterprise.telemetry.id_generator.random.getrandbits", return_value=42) as mock_rng:
|
||||
trace_id = gen.generate_trace_id()
|
||||
mock_rng.assert_called_once_with(128)
|
||||
assert trace_id == 42
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CorrelationIdGenerator.generate_span_id
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCorrelationIdGeneratorGenerateSpanId:
|
||||
def setup_method(self) -> None:
|
||||
from enterprise.telemetry.id_generator import set_span_id_source
|
||||
|
||||
set_span_id_source(None)
|
||||
|
||||
def test_uses_deterministic_span_id_when_source_set(self) -> None:
|
||||
from enterprise.telemetry.id_generator import (
|
||||
CorrelationIdGenerator,
|
||||
compute_deterministic_span_id,
|
||||
set_span_id_source,
|
||||
)
|
||||
|
||||
uid = "123e4567-e89b-12d3-a456-426614174000"
|
||||
set_span_id_source(uid)
|
||||
gen = CorrelationIdGenerator()
|
||||
span_id = gen.generate_span_id()
|
||||
assert span_id == compute_deterministic_span_id(uid)
|
||||
|
||||
def test_returns_random_when_no_source(self) -> None:
|
||||
from enterprise.telemetry.id_generator import CorrelationIdGenerator, set_span_id_source
|
||||
|
||||
set_span_id_source(None)
|
||||
gen = CorrelationIdGenerator()
|
||||
span_id = gen.generate_span_id()
|
||||
assert isinstance(span_id, int)
|
||||
assert span_id != 0
|
||||
|
||||
def test_returns_random_when_source_is_invalid_uuid(self) -> None:
|
||||
from enterprise.telemetry.id_generator import CorrelationIdGenerator, set_span_id_source
|
||||
|
||||
set_span_id_source("not-a-uuid")
|
||||
gen = CorrelationIdGenerator()
|
||||
with patch("enterprise.telemetry.id_generator.random.getrandbits", return_value=7) as mock_rng:
|
||||
span_id = gen.generate_span_id()
|
||||
assert span_id == 7
|
||||
|
||||
def test_random_span_id_retried_if_zero(self) -> None:
|
||||
"""generate_span_id must never return 0 — it retries until non-zero."""
|
||||
from enterprise.telemetry.id_generator import CorrelationIdGenerator, set_span_id_source
|
||||
|
||||
set_span_id_source(None)
|
||||
gen = CorrelationIdGenerator()
|
||||
# First call returns 0 (invalid), second returns 99
|
||||
with patch("enterprise.telemetry.id_generator.random.getrandbits", side_effect=[0, 99]):
|
||||
span_id = gen.generate_span_id()
|
||||
assert span_id == 99
|
||||
|
||||
def test_generate_span_id_always_non_zero(self) -> None:
|
||||
from enterprise.telemetry.id_generator import CorrelationIdGenerator, set_span_id_source
|
||||
|
||||
set_span_id_source(None)
|
||||
gen = CorrelationIdGenerator()
|
||||
for _ in range(20):
|
||||
assert gen.generate_span_id() != 0
|
||||
511
api/tests/unit_tests/enterprise/telemetry/test_metric_handler.py
Normal file
511
api/tests/unit_tests/enterprise/telemetry/test_metric_handler.py
Normal file
@ -0,0 +1,511 @@
|
||||
"""Unit tests for EnterpriseMetricHandler."""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope
|
||||
from enterprise.telemetry.metric_handler import EnterpriseMetricHandler
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis():
|
||||
with patch("enterprise.telemetry.metric_handler.redis_client") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_envelope():
|
||||
return TelemetryEnvelope(
|
||||
case=TelemetryCase.APP_CREATED,
|
||||
tenant_id="test-tenant",
|
||||
event_id="test-event-123",
|
||||
payload={"app_id": "app-123", "name": "Test App"},
|
||||
)
|
||||
|
||||
|
||||
def test_dispatch_app_created(sample_envelope, mock_redis):
|
||||
mock_redis.set.return_value = True
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with patch.object(handler, "_on_app_created") as mock_handler:
|
||||
handler.handle(sample_envelope)
|
||||
mock_handler.assert_called_once_with(sample_envelope)
|
||||
|
||||
|
||||
def test_dispatch_app_updated(mock_redis):
|
||||
mock_redis.set.return_value = True
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.APP_UPDATED,
|
||||
tenant_id="test-tenant",
|
||||
event_id="test-event-456",
|
||||
payload={},
|
||||
)
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with patch.object(handler, "_on_app_updated") as mock_handler:
|
||||
handler.handle(envelope)
|
||||
mock_handler.assert_called_once_with(envelope)
|
||||
|
||||
|
||||
def test_dispatch_app_deleted(mock_redis):
|
||||
mock_redis.set.return_value = True
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.APP_DELETED,
|
||||
tenant_id="test-tenant",
|
||||
event_id="test-event-789",
|
||||
payload={},
|
||||
)
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with patch.object(handler, "_on_app_deleted") as mock_handler:
|
||||
handler.handle(envelope)
|
||||
mock_handler.assert_called_once_with(envelope)
|
||||
|
||||
|
||||
def test_dispatch_feedback_created(mock_redis):
|
||||
mock_redis.set.return_value = True
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.FEEDBACK_CREATED,
|
||||
tenant_id="test-tenant",
|
||||
event_id="test-event-abc",
|
||||
payload={},
|
||||
)
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with patch.object(handler, "_on_feedback_created") as mock_handler:
|
||||
handler.handle(envelope)
|
||||
mock_handler.assert_called_once_with(envelope)
|
||||
|
||||
|
||||
def test_dispatch_message_run(mock_redis):
|
||||
mock_redis.set.return_value = True
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.MESSAGE_RUN,
|
||||
tenant_id="test-tenant",
|
||||
event_id="test-event-msg",
|
||||
payload={},
|
||||
)
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with patch.object(handler, "_on_message_run") as mock_handler:
|
||||
handler.handle(envelope)
|
||||
mock_handler.assert_called_once_with(envelope)
|
||||
|
||||
|
||||
def test_dispatch_tool_execution(mock_redis):
|
||||
mock_redis.set.return_value = True
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.TOOL_EXECUTION,
|
||||
tenant_id="test-tenant",
|
||||
event_id="test-event-tool",
|
||||
payload={},
|
||||
)
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with patch.object(handler, "_on_tool_execution") as mock_handler:
|
||||
handler.handle(envelope)
|
||||
mock_handler.assert_called_once_with(envelope)
|
||||
|
||||
|
||||
def test_dispatch_moderation_check(mock_redis):
|
||||
mock_redis.set.return_value = True
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.MODERATION_CHECK,
|
||||
tenant_id="test-tenant",
|
||||
event_id="test-event-mod",
|
||||
payload={},
|
||||
)
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with patch.object(handler, "_on_moderation_check") as mock_handler:
|
||||
handler.handle(envelope)
|
||||
mock_handler.assert_called_once_with(envelope)
|
||||
|
||||
|
||||
def test_dispatch_suggested_question(mock_redis):
|
||||
mock_redis.set.return_value = True
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.SUGGESTED_QUESTION,
|
||||
tenant_id="test-tenant",
|
||||
event_id="test-event-sq",
|
||||
payload={},
|
||||
)
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with patch.object(handler, "_on_suggested_question") as mock_handler:
|
||||
handler.handle(envelope)
|
||||
mock_handler.assert_called_once_with(envelope)
|
||||
|
||||
|
||||
def test_dispatch_dataset_retrieval(mock_redis):
|
||||
mock_redis.set.return_value = True
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.DATASET_RETRIEVAL,
|
||||
tenant_id="test-tenant",
|
||||
event_id="test-event-ds",
|
||||
payload={},
|
||||
)
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with patch.object(handler, "_on_dataset_retrieval") as mock_handler:
|
||||
handler.handle(envelope)
|
||||
mock_handler.assert_called_once_with(envelope)
|
||||
|
||||
|
||||
def test_dispatch_generate_name(mock_redis):
|
||||
mock_redis.set.return_value = True
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.GENERATE_NAME,
|
||||
tenant_id="test-tenant",
|
||||
event_id="test-event-gn",
|
||||
payload={},
|
||||
)
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with patch.object(handler, "_on_generate_name") as mock_handler:
|
||||
handler.handle(envelope)
|
||||
mock_handler.assert_called_once_with(envelope)
|
||||
|
||||
|
||||
def test_dispatch_prompt_generation(mock_redis):
|
||||
mock_redis.set.return_value = True
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.PROMPT_GENERATION,
|
||||
tenant_id="test-tenant",
|
||||
event_id="test-event-pg",
|
||||
payload={},
|
||||
)
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with patch.object(handler, "_on_prompt_generation") as mock_handler:
|
||||
handler.handle(envelope)
|
||||
mock_handler.assert_called_once_with(envelope)
|
||||
|
||||
|
||||
def test_all_known_cases_have_handlers(mock_redis):
|
||||
mock_redis.set.return_value = True
|
||||
handler = EnterpriseMetricHandler()
|
||||
|
||||
for case in TelemetryCase:
|
||||
envelope = TelemetryEnvelope(
|
||||
case=case,
|
||||
tenant_id="test-tenant",
|
||||
event_id=f"test-{case.value}",
|
||||
payload={},
|
||||
)
|
||||
handler.handle(envelope)
|
||||
|
||||
|
||||
def test_idempotency_duplicate(sample_envelope, mock_redis):
|
||||
mock_redis.set.return_value = None
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with patch.object(handler, "_on_app_created") as mock_handler:
|
||||
handler.handle(sample_envelope)
|
||||
mock_handler.assert_not_called()
|
||||
|
||||
|
||||
def test_idempotency_first_seen(sample_envelope, mock_redis):
|
||||
mock_redis.set.return_value = True
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
is_dup = handler._is_duplicate(sample_envelope)
|
||||
|
||||
assert is_dup is False
|
||||
mock_redis.set.assert_called_once_with(
|
||||
"telemetry:dedup:test-tenant:test-event-123",
|
||||
b"1",
|
||||
nx=True,
|
||||
ex=3600,
|
||||
)
|
||||
|
||||
|
||||
def test_idempotency_redis_failure_fails_open(sample_envelope, mock_redis, caplog):
|
||||
mock_redis.set.side_effect = Exception("Redis unavailable")
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
is_dup = handler._is_duplicate(sample_envelope)
|
||||
|
||||
assert is_dup is False
|
||||
assert "Redis unavailable for deduplication check" in caplog.text
|
||||
|
||||
|
||||
def test_rehydration_uses_payload(sample_envelope):
|
||||
handler = EnterpriseMetricHandler()
|
||||
payload = handler._rehydrate(sample_envelope)
|
||||
|
||||
assert payload == {"app_id": "app-123", "name": "Test App"}
|
||||
|
||||
|
||||
def test_rehydration_from_storage():
|
||||
"""Verify _rehydrate loads payload from object storage via payload_ref."""
|
||||
stored_data = {"app_id": "app-stored", "mode": "workflow"}
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.APP_CREATED,
|
||||
tenant_id="test-tenant",
|
||||
event_id="test-event-fb",
|
||||
payload={},
|
||||
metadata={"payload_ref": "telemetry/test-tenant/test-event-fb.json"},
|
||||
)
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with patch("enterprise.telemetry.metric_handler.storage") as mock_storage:
|
||||
mock_storage.load.return_value = json.dumps(stored_data).encode("utf-8")
|
||||
payload = handler._rehydrate(envelope)
|
||||
|
||||
assert payload == stored_data
|
||||
mock_storage.load.assert_called_once_with("telemetry/test-tenant/test-event-fb.json")
|
||||
|
||||
|
||||
def test_rehydration_storage_failure_emits_degraded_event():
|
||||
"""Verify _rehydrate emits degraded event when storage load fails."""
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.APP_CREATED,
|
||||
tenant_id="test-tenant",
|
||||
event_id="test-event-fail",
|
||||
payload={},
|
||||
metadata={"payload_ref": "telemetry/test-tenant/test-event-fail.json"},
|
||||
)
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with (
|
||||
patch("enterprise.telemetry.metric_handler.storage") as mock_storage,
|
||||
patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit,
|
||||
):
|
||||
mock_storage.load.side_effect = Exception("Storage unavailable")
|
||||
payload = handler._rehydrate(envelope)
|
||||
|
||||
from enterprise.telemetry.entities import EnterpriseTelemetryEvent
|
||||
|
||||
assert payload == {}
|
||||
mock_emit.assert_called_once()
|
||||
call_args = mock_emit.call_args
|
||||
assert call_args[1]["event_name"] == EnterpriseTelemetryEvent.REHYDRATION_FAILED
|
||||
assert "dify.telemetry.error" in call_args[1]["attributes"]
|
||||
|
||||
|
||||
def test_rehydration_emits_degraded_event_on_empty_payload():
|
||||
"""Verify _rehydrate emits degraded event when payload is empty and no ref exists."""
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.APP_CREATED,
|
||||
tenant_id="test-tenant",
|
||||
event_id="test-event-empty",
|
||||
payload={},
|
||||
)
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit:
|
||||
payload = handler._rehydrate(envelope)
|
||||
|
||||
from enterprise.telemetry.entities import EnterpriseTelemetryEvent
|
||||
|
||||
assert payload == {}
|
||||
mock_emit.assert_called_once()
|
||||
call_args = mock_emit.call_args
|
||||
assert call_args[1]["event_name"] == EnterpriseTelemetryEvent.REHYDRATION_FAILED
|
||||
assert "dify.telemetry.error" in call_args[1]["attributes"]
|
||||
|
||||
|
||||
def test_on_app_created_emits_correct_event(mock_redis):
|
||||
mock_redis.set.return_value = True
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.APP_CREATED,
|
||||
tenant_id="tenant-123",
|
||||
event_id="event-456",
|
||||
payload={"app_id": "app-789", "mode": "chat"},
|
||||
)
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with (
|
||||
patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter,
|
||||
patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit,
|
||||
):
|
||||
mock_exporter = MagicMock()
|
||||
mock_get_exporter.return_value = mock_exporter
|
||||
|
||||
handler._on_app_created(envelope)
|
||||
|
||||
from enterprise.telemetry.entities import EnterpriseTelemetryEvent
|
||||
|
||||
mock_emit.assert_called_once()
|
||||
call_args = mock_emit.call_args
|
||||
assert call_args[1]["event_name"] == EnterpriseTelemetryEvent.APP_CREATED
|
||||
assert call_args[1]["tenant_id"] == "tenant-123"
|
||||
attrs = call_args[1]["attributes"]
|
||||
assert attrs["dify.app_id"] == "app-789"
|
||||
assert attrs["dify.tenant_id"] == "tenant-123"
|
||||
assert attrs["dify.event.id"] == "event-456"
|
||||
assert attrs["dify.app.mode"] == "chat"
|
||||
assert "dify.app.created_at" in attrs
|
||||
|
||||
from enterprise.telemetry.entities import EnterpriseTelemetryCounter
|
||||
|
||||
mock_exporter.increment_counter.assert_called_once()
|
||||
counter_call = mock_exporter.increment_counter.call_args
|
||||
assert counter_call[0][0] == EnterpriseTelemetryCounter.APP_CREATED
|
||||
assert counter_call[0][1] == 1
|
||||
assert counter_call[0][2]["tenant_id"] == "tenant-123"
|
||||
assert counter_call[0][2]["app_id"] == "app-789"
|
||||
assert counter_call[0][2]["mode"] == "chat"
|
||||
|
||||
|
||||
def test_on_app_updated_emits_correct_event(mock_redis):
|
||||
mock_redis.set.return_value = True
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.APP_UPDATED,
|
||||
tenant_id="tenant-123",
|
||||
event_id="event-456",
|
||||
payload={"app_id": "app-789"},
|
||||
)
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with (
|
||||
patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter,
|
||||
patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit,
|
||||
):
|
||||
mock_exporter = MagicMock()
|
||||
mock_get_exporter.return_value = mock_exporter
|
||||
|
||||
handler._on_app_updated(envelope)
|
||||
|
||||
from enterprise.telemetry.entities import EnterpriseTelemetryEvent
|
||||
|
||||
mock_emit.assert_called_once()
|
||||
call_args = mock_emit.call_args
|
||||
assert call_args[1]["event_name"] == EnterpriseTelemetryEvent.APP_UPDATED
|
||||
assert call_args[1]["tenant_id"] == "tenant-123"
|
||||
attrs = call_args[1]["attributes"]
|
||||
assert attrs["dify.app_id"] == "app-789"
|
||||
assert attrs["dify.tenant_id"] == "tenant-123"
|
||||
assert attrs["dify.event.id"] == "event-456"
|
||||
assert "dify.app.updated_at" in attrs
|
||||
|
||||
from enterprise.telemetry.entities import EnterpriseTelemetryCounter
|
||||
|
||||
mock_exporter.increment_counter.assert_called_once()
|
||||
counter_call = mock_exporter.increment_counter.call_args
|
||||
assert counter_call[0][0] == EnterpriseTelemetryCounter.APP_UPDATED
|
||||
assert counter_call[0][1] == 1
|
||||
assert counter_call[0][2]["tenant_id"] == "tenant-123"
|
||||
assert counter_call[0][2]["app_id"] == "app-789"
|
||||
|
||||
|
||||
def test_on_app_deleted_emits_correct_event(mock_redis):
|
||||
mock_redis.set.return_value = True
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.APP_DELETED,
|
||||
tenant_id="tenant-123",
|
||||
event_id="event-456",
|
||||
payload={"app_id": "app-789"},
|
||||
)
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with (
|
||||
patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter,
|
||||
patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit,
|
||||
):
|
||||
mock_exporter = MagicMock()
|
||||
mock_get_exporter.return_value = mock_exporter
|
||||
|
||||
handler._on_app_deleted(envelope)
|
||||
|
||||
from enterprise.telemetry.entities import EnterpriseTelemetryEvent
|
||||
|
||||
mock_emit.assert_called_once()
|
||||
call_args = mock_emit.call_args
|
||||
assert call_args[1]["event_name"] == EnterpriseTelemetryEvent.APP_DELETED
|
||||
assert call_args[1]["tenant_id"] == "tenant-123"
|
||||
attrs = call_args[1]["attributes"]
|
||||
assert attrs["dify.app_id"] == "app-789"
|
||||
assert attrs["dify.tenant_id"] == "tenant-123"
|
||||
assert attrs["dify.event.id"] == "event-456"
|
||||
assert "dify.app.deleted_at" in attrs
|
||||
|
||||
from enterprise.telemetry.entities import EnterpriseTelemetryCounter
|
||||
|
||||
mock_exporter.increment_counter.assert_called_once()
|
||||
counter_call = mock_exporter.increment_counter.call_args
|
||||
assert counter_call[0][0] == EnterpriseTelemetryCounter.APP_DELETED
|
||||
assert counter_call[0][1] == 1
|
||||
assert counter_call[0][2]["tenant_id"] == "tenant-123"
|
||||
assert counter_call[0][2]["app_id"] == "app-789"
|
||||
|
||||
|
||||
def test_on_feedback_created_emits_correct_event(mock_redis):
|
||||
mock_redis.set.return_value = True
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.FEEDBACK_CREATED,
|
||||
tenant_id="tenant-123",
|
||||
event_id="event-456",
|
||||
payload={
|
||||
"message_id": "msg-001",
|
||||
"app_id": "app-789",
|
||||
"conversation_id": "conv-123",
|
||||
"from_end_user_id": "user-456",
|
||||
"from_account_id": None,
|
||||
"rating": "like",
|
||||
"from_source": "api",
|
||||
"content": "Great!",
|
||||
},
|
||||
)
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with (
|
||||
patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter,
|
||||
patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit,
|
||||
):
|
||||
mock_exporter = MagicMock()
|
||||
mock_exporter.include_content = True
|
||||
mock_get_exporter.return_value = mock_exporter
|
||||
|
||||
handler._on_feedback_created(envelope)
|
||||
|
||||
mock_emit.assert_called_once()
|
||||
call_args = mock_emit.call_args
|
||||
assert call_args[1]["event_name"] == "dify.feedback.created"
|
||||
assert call_args[1]["attributes"]["dify.message.id"] == "msg-001"
|
||||
assert call_args[1]["attributes"]["dify.feedback.content"] == "Great!"
|
||||
assert "dify.feedback.created_at" in call_args[1]["attributes"]
|
||||
assert call_args[1]["tenant_id"] == "tenant-123"
|
||||
assert call_args[1]["user_id"] == "user-456"
|
||||
|
||||
mock_exporter.increment_counter.assert_called_once()
|
||||
counter_args = mock_exporter.increment_counter.call_args
|
||||
assert counter_args[0][2]["app_id"] == "app-789"
|
||||
assert counter_args[0][2]["rating"] == "like"
|
||||
|
||||
|
||||
def test_on_feedback_created_without_content(mock_redis):
|
||||
mock_redis.set.return_value = True
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.FEEDBACK_CREATED,
|
||||
tenant_id="tenant-123",
|
||||
event_id="event-456",
|
||||
payload={
|
||||
"message_id": "msg-001",
|
||||
"app_id": "app-789",
|
||||
"conversation_id": "conv-123",
|
||||
"from_end_user_id": "user-456",
|
||||
"from_account_id": None,
|
||||
"rating": "like",
|
||||
"from_source": "api",
|
||||
"content": "Great!",
|
||||
},
|
||||
)
|
||||
|
||||
handler = EnterpriseMetricHandler()
|
||||
with (
|
||||
patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter,
|
||||
patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit,
|
||||
):
|
||||
mock_exporter = MagicMock()
|
||||
mock_exporter.include_content = False
|
||||
mock_get_exporter.return_value = mock_exporter
|
||||
|
||||
handler._on_feedback_created(envelope)
|
||||
|
||||
mock_emit.assert_called_once()
|
||||
call_args = mock_emit.call_args
|
||||
assert "dify.feedback.content" not in call_args[1]["attributes"]
|
||||
327
api/tests/unit_tests/enterprise/telemetry/test_telemetry_log.py
Normal file
327
api/tests/unit_tests/enterprise/telemetry/test_telemetry_log.py
Normal file
@ -0,0 +1,327 @@
|
||||
"""Unit tests for enterprise/telemetry/telemetry_log.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# compute_trace_id_hex
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestComputeTraceIdHex:
|
||||
def setup_method(self) -> None:
|
||||
# Clear lru_cache between tests to avoid cross-test pollution
|
||||
from enterprise.telemetry.telemetry_log import compute_trace_id_hex
|
||||
|
||||
compute_trace_id_hex.cache_clear()
|
||||
|
||||
def test_none_returns_empty(self) -> None:
|
||||
from enterprise.telemetry.telemetry_log import compute_trace_id_hex
|
||||
|
||||
assert compute_trace_id_hex(None) == ""
|
||||
|
||||
def test_empty_string_returns_empty(self) -> None:
|
||||
from enterprise.telemetry.telemetry_log import compute_trace_id_hex
|
||||
|
||||
assert compute_trace_id_hex("") == ""
|
||||
|
||||
def test_already_32_hex_chars_returned_as_is(self) -> None:
|
||||
from enterprise.telemetry.telemetry_log import compute_trace_id_hex
|
||||
|
||||
hex_id = "a" * 32
|
||||
assert compute_trace_id_hex(hex_id) == hex_id
|
||||
|
||||
def test_valid_uuid_string_converted_to_32_hex(self) -> None:
|
||||
from enterprise.telemetry.telemetry_log import compute_trace_id_hex
|
||||
|
||||
uid = "123e4567-e89b-12d3-a456-426614174000"
|
||||
result = compute_trace_id_hex(uid)
|
||||
assert len(result) == 32
|
||||
assert all(ch in "0123456789abcdef" for ch in result)
|
||||
# Round-trip: int of the UUID should equal the int parsed from result
|
||||
assert int(result, 16) == uuid.UUID(uid).int
|
||||
|
||||
def test_invalid_string_returns_empty(self) -> None:
|
||||
from enterprise.telemetry.telemetry_log import compute_trace_id_hex
|
||||
|
||||
assert compute_trace_id_hex("not-a-uuid") == ""
|
||||
|
||||
def test_whitespace_stripped(self) -> None:
|
||||
from enterprise.telemetry.telemetry_log import compute_trace_id_hex
|
||||
|
||||
uid = " 123e4567-e89b-12d3-a456-426614174000 "
|
||||
result = compute_trace_id_hex(uid)
|
||||
assert len(result) == 32
|
||||
|
||||
def test_uppercase_uuid_accepted(self) -> None:
|
||||
from enterprise.telemetry.telemetry_log import compute_trace_id_hex
|
||||
|
||||
uid = "123E4567-E89B-12D3-A456-426614174000"
|
||||
result = compute_trace_id_hex(uid)
|
||||
assert len(result) == 32
|
||||
|
||||
def test_result_is_cached(self) -> None:
|
||||
from enterprise.telemetry.telemetry_log import compute_trace_id_hex
|
||||
|
||||
uid = "123e4567-e89b-12d3-a456-426614174000"
|
||||
r1 = compute_trace_id_hex(uid)
|
||||
r2 = compute_trace_id_hex(uid)
|
||||
assert r1 == r2
|
||||
info = compute_trace_id_hex.cache_info()
|
||||
assert info.hits >= 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# compute_span_id_hex
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestComputeSpanIdHex:
|
||||
def setup_method(self) -> None:
|
||||
from enterprise.telemetry.telemetry_log import compute_span_id_hex
|
||||
|
||||
compute_span_id_hex.cache_clear()
|
||||
|
||||
def test_none_returns_empty(self) -> None:
|
||||
from enterprise.telemetry.telemetry_log import compute_span_id_hex
|
||||
|
||||
assert compute_span_id_hex(None) == ""
|
||||
|
||||
def test_empty_string_returns_empty(self) -> None:
|
||||
from enterprise.telemetry.telemetry_log import compute_span_id_hex
|
||||
|
||||
assert compute_span_id_hex("") == ""
|
||||
|
||||
def test_already_16_hex_chars_returned_as_is(self) -> None:
|
||||
from enterprise.telemetry.telemetry_log import compute_span_id_hex
|
||||
|
||||
hex_id = "abcdef0123456789"
|
||||
assert compute_span_id_hex(hex_id) == hex_id
|
||||
|
||||
def test_valid_uuid_produces_16_hex_span_id(self) -> None:
|
||||
from enterprise.telemetry.telemetry_log import compute_span_id_hex
|
||||
|
||||
uid = "123e4567-e89b-12d3-a456-426614174000"
|
||||
result = compute_span_id_hex(uid)
|
||||
assert len(result) == 16
|
||||
assert all(ch in "0123456789abcdef" for ch in result)
|
||||
|
||||
def test_invalid_string_returns_empty(self) -> None:
|
||||
from enterprise.telemetry.telemetry_log import compute_span_id_hex
|
||||
|
||||
assert compute_span_id_hex("not-a-uuid-at-all!") == ""
|
||||
|
||||
def test_result_is_cached(self) -> None:
|
||||
from enterprise.telemetry.telemetry_log import compute_span_id_hex
|
||||
|
||||
uid = "123e4567-e89b-12d3-a456-426614174000"
|
||||
compute_span_id_hex(uid)
|
||||
compute_span_id_hex(uid)
|
||||
info = compute_span_id_hex.cache_info()
|
||||
assert info.hits >= 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# emit_telemetry_log
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitTelemetryLog:
|
||||
def setup_method(self) -> None:
|
||||
from enterprise.telemetry.telemetry_log import compute_span_id_hex, compute_trace_id_hex
|
||||
|
||||
compute_trace_id_hex.cache_clear()
|
||||
compute_span_id_hex.cache_clear()
|
||||
|
||||
@patch("enterprise.telemetry.telemetry_log.logger")
|
||||
def test_logs_info_with_event_name_and_signal(self, mock_logger: MagicMock) -> None:
|
||||
from enterprise.telemetry.telemetry_log import emit_telemetry_log
|
||||
|
||||
mock_logger.isEnabledFor.return_value = True
|
||||
|
||||
emit_telemetry_log(
|
||||
event_name="dify.workflow.run",
|
||||
attributes={"tenant_id": "t1"},
|
||||
signal="metric_only",
|
||||
)
|
||||
|
||||
mock_logger.info.assert_called_once()
|
||||
args, kwargs = mock_logger.info.call_args
|
||||
assert args[0] == "telemetry.%s"
|
||||
assert args[1] == "metric_only"
|
||||
extra = kwargs["extra"]
|
||||
assert extra["attributes"]["dify.event.name"] == "dify.workflow.run"
|
||||
assert extra["attributes"]["dify.event.signal"] == "metric_only"
|
||||
|
||||
@patch("enterprise.telemetry.telemetry_log.logger")
|
||||
def test_no_log_when_info_disabled(self, mock_logger: MagicMock) -> None:
|
||||
from enterprise.telemetry.telemetry_log import emit_telemetry_log
|
||||
|
||||
mock_logger.isEnabledFor.return_value = False
|
||||
|
||||
emit_telemetry_log(event_name="dify.workflow.run", attributes={})
|
||||
|
||||
mock_logger.info.assert_not_called()
|
||||
|
||||
@patch("enterprise.telemetry.telemetry_log.logger")
|
||||
def test_trace_id_added_to_extra_when_valid_uuid(self, mock_logger: MagicMock) -> None:
|
||||
from enterprise.telemetry.telemetry_log import emit_telemetry_log
|
||||
|
||||
mock_logger.isEnabledFor.return_value = True
|
||||
uid = "123e4567-e89b-12d3-a456-426614174000"
|
||||
|
||||
emit_telemetry_log(event_name="test.event", attributes={}, trace_id_source=uid)
|
||||
|
||||
extra = mock_logger.info.call_args.kwargs["extra"]
|
||||
assert "trace_id" in extra
|
||||
assert len(extra["trace_id"]) == 32
|
||||
|
||||
@patch("enterprise.telemetry.telemetry_log.logger")
|
||||
def test_trace_id_absent_when_invalid_source(self, mock_logger: MagicMock) -> None:
|
||||
from enterprise.telemetry.telemetry_log import emit_telemetry_log
|
||||
|
||||
mock_logger.isEnabledFor.return_value = True
|
||||
|
||||
emit_telemetry_log(event_name="test.event", attributes={}, trace_id_source="bad-id")
|
||||
|
||||
extra = mock_logger.info.call_args.kwargs["extra"]
|
||||
assert "trace_id" not in extra
|
||||
|
||||
@patch("enterprise.telemetry.telemetry_log.logger")
|
||||
def test_span_id_added_to_extra_when_valid_uuid(self, mock_logger: MagicMock) -> None:
|
||||
from enterprise.telemetry.telemetry_log import emit_telemetry_log
|
||||
|
||||
mock_logger.isEnabledFor.return_value = True
|
||||
uid = "123e4567-e89b-12d3-a456-426614174000"
|
||||
|
||||
emit_telemetry_log(event_name="test.event", attributes={}, span_id_source=uid)
|
||||
|
||||
extra = mock_logger.info.call_args.kwargs["extra"]
|
||||
assert "span_id" in extra
|
||||
assert len(extra["span_id"]) == 16
|
||||
|
||||
@patch("enterprise.telemetry.telemetry_log.logger")
|
||||
def test_tenant_id_added_when_provided(self, mock_logger: MagicMock) -> None:
|
||||
from enterprise.telemetry.telemetry_log import emit_telemetry_log
|
||||
|
||||
mock_logger.isEnabledFor.return_value = True
|
||||
|
||||
emit_telemetry_log(event_name="test.event", attributes={}, tenant_id="tenant-99")
|
||||
|
||||
extra = mock_logger.info.call_args.kwargs["extra"]
|
||||
assert extra["tenant_id"] == "tenant-99"
|
||||
|
||||
@patch("enterprise.telemetry.telemetry_log.logger")
|
||||
def test_user_id_added_when_provided(self, mock_logger: MagicMock) -> None:
|
||||
from enterprise.telemetry.telemetry_log import emit_telemetry_log
|
||||
|
||||
mock_logger.isEnabledFor.return_value = True
|
||||
|
||||
emit_telemetry_log(event_name="test.event", attributes={}, user_id="user-42")
|
||||
|
||||
extra = mock_logger.info.call_args.kwargs["extra"]
|
||||
assert extra["user_id"] == "user-42"
|
||||
|
||||
@patch("enterprise.telemetry.telemetry_log.logger")
|
||||
def test_tenant_and_user_id_absent_when_not_provided(self, mock_logger: MagicMock) -> None:
|
||||
from enterprise.telemetry.telemetry_log import emit_telemetry_log
|
||||
|
||||
mock_logger.isEnabledFor.return_value = True
|
||||
|
||||
emit_telemetry_log(event_name="test.event", attributes={})
|
||||
|
||||
extra = mock_logger.info.call_args.kwargs["extra"]
|
||||
assert "tenant_id" not in extra
|
||||
assert "user_id" not in extra
|
||||
|
||||
@patch("enterprise.telemetry.telemetry_log.logger")
|
||||
def test_caller_attributes_merged_into_attrs(self, mock_logger: MagicMock) -> None:
|
||||
from enterprise.telemetry.telemetry_log import emit_telemetry_log
|
||||
|
||||
mock_logger.isEnabledFor.return_value = True
|
||||
|
||||
emit_telemetry_log(
|
||||
event_name="dify.node.run",
|
||||
attributes={"node_type": "code", "elapsed": 0.5},
|
||||
)
|
||||
|
||||
extra = mock_logger.info.call_args.kwargs["extra"]
|
||||
assert extra["attributes"]["node_type"] == "code"
|
||||
assert extra["attributes"]["elapsed"] == 0.5
|
||||
|
||||
@patch("enterprise.telemetry.telemetry_log.logger")
|
||||
def test_signal_span_detail_forwarded(self, mock_logger: MagicMock) -> None:
|
||||
from enterprise.telemetry.telemetry_log import emit_telemetry_log
|
||||
|
||||
mock_logger.isEnabledFor.return_value = True
|
||||
|
||||
emit_telemetry_log(event_name="test.event", attributes={}, signal="span_detail")
|
||||
|
||||
args = mock_logger.info.call_args[0]
|
||||
assert args[1] == "span_detail"
|
||||
extra = mock_logger.info.call_args.kwargs["extra"]
|
||||
assert extra["attributes"]["dify.event.signal"] == "span_detail"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# emit_metric_only_event
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitMetricOnlyEvent:
|
||||
def setup_method(self) -> None:
|
||||
from enterprise.telemetry.telemetry_log import compute_span_id_hex, compute_trace_id_hex
|
||||
|
||||
compute_trace_id_hex.cache_clear()
|
||||
compute_span_id_hex.cache_clear()
|
||||
|
||||
@patch("enterprise.telemetry.telemetry_log.logger")
|
||||
def test_delegates_to_emit_telemetry_log_with_metric_only_signal(self, mock_logger: MagicMock) -> None:
|
||||
from enterprise.telemetry.telemetry_log import emit_metric_only_event
|
||||
|
||||
mock_logger.isEnabledFor.return_value = True
|
||||
|
||||
emit_metric_only_event(
|
||||
event_name="dify.app.created",
|
||||
attributes={"app_id": "app-1"},
|
||||
tenant_id="t1",
|
||||
user_id="u1",
|
||||
)
|
||||
|
||||
mock_logger.info.assert_called_once()
|
||||
extra = mock_logger.info.call_args.kwargs["extra"]
|
||||
assert extra["attributes"]["dify.event.signal"] == "metric_only"
|
||||
assert extra["attributes"]["dify.event.name"] == "dify.app.created"
|
||||
assert extra["attributes"]["app_id"] == "app-1"
|
||||
assert extra["tenant_id"] == "t1"
|
||||
assert extra["user_id"] == "u1"
|
||||
|
||||
@patch("enterprise.telemetry.telemetry_log.logger")
|
||||
def test_trace_and_span_ids_passed_through(self, mock_logger: MagicMock) -> None:
|
||||
from enterprise.telemetry.telemetry_log import emit_metric_only_event
|
||||
|
||||
mock_logger.isEnabledFor.return_value = True
|
||||
uid = "123e4567-e89b-12d3-a456-426614174000"
|
||||
|
||||
emit_metric_only_event(
|
||||
event_name="dify.workflow.run",
|
||||
attributes={},
|
||||
trace_id_source=uid,
|
||||
span_id_source=uid,
|
||||
)
|
||||
|
||||
extra = mock_logger.info.call_args.kwargs["extra"]
|
||||
assert "trace_id" in extra
|
||||
assert "span_id" in extra
|
||||
|
||||
@patch("enterprise.telemetry.telemetry_log.logger")
|
||||
def test_no_log_emitted_when_logger_disabled(self, mock_logger: MagicMock) -> None:
|
||||
from enterprise.telemetry.telemetry_log import emit_metric_only_event
|
||||
|
||||
mock_logger.isEnabledFor.return_value = False
|
||||
|
||||
emit_metric_only_event(event_name="dify.workflow.run", attributes={})
|
||||
|
||||
mock_logger.info.assert_not_called()
|
||||
206
api/tests/unit_tests/events/test_app_event_signals.py
Normal file
206
api/tests/unit_tests/events/test_app_event_signals.py
Normal file
@ -0,0 +1,206 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db():
|
||||
with patch("services.app_service.db") as mock_db:
|
||||
mock_db.session = MagicMock()
|
||||
yield mock_db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def _mock_deps():
|
||||
with (
|
||||
patch("services.app_service.BillingService"),
|
||||
patch("services.app_service.FeatureService"),
|
||||
patch("services.app_service.EnterpriseService"),
|
||||
patch("services.app_service.remove_app_and_related_data_task"),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app_model():
|
||||
app = MagicMock()
|
||||
app.id = "app-123"
|
||||
app.tenant_id = "tenant-456"
|
||||
app.name = "Old Name"
|
||||
app.icon_type = "emoji"
|
||||
app.icon = "🤖"
|
||||
app.icon_background = "#fff"
|
||||
app.enable_site = False
|
||||
app.enable_api = False
|
||||
return app
|
||||
|
||||
|
||||
def _make_collector(target: list):
|
||||
def handler(sender, **kw):
|
||||
target.append(sender)
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_db", "_mock_deps")
|
||||
class TestAppWasDeletedSignal:
|
||||
def test_sends_signal(self, app_model):
|
||||
from events.app_event import app_was_deleted
|
||||
from services.app_service import AppService
|
||||
|
||||
received = []
|
||||
handler = _make_collector(received)
|
||||
app_was_deleted.connect(handler)
|
||||
try:
|
||||
AppService().delete_app(app_model)
|
||||
finally:
|
||||
app_was_deleted.disconnect(handler)
|
||||
|
||||
assert received == [app_model]
|
||||
|
||||
def test_signal_fires_before_db_delete(self, app_model, mock_db):
|
||||
from events.app_event import app_was_deleted
|
||||
from services.app_service import AppService
|
||||
|
||||
call_order: list[str] = []
|
||||
|
||||
def handler(sender, **kw):
|
||||
call_order.append("signal")
|
||||
|
||||
app_was_deleted.connect(handler)
|
||||
mock_db.session.delete.side_effect = lambda _: call_order.append("db_delete")
|
||||
|
||||
try:
|
||||
AppService().delete_app(app_model)
|
||||
finally:
|
||||
app_was_deleted.disconnect(handler)
|
||||
|
||||
assert call_order.index("signal") < call_order.index("db_delete")
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_db")
|
||||
class TestAppWasUpdatedSignal:
|
||||
def test_update_app(self, app_model):
|
||||
from events.app_event import app_was_updated
|
||||
from services.app_service import AppService
|
||||
|
||||
received = []
|
||||
handler = _make_collector(received)
|
||||
app_was_updated.connect(handler)
|
||||
|
||||
with patch("services.app_service.current_user", MagicMock(id="user-1")):
|
||||
try:
|
||||
AppService().update_app(
|
||||
app_model,
|
||||
{
|
||||
"name": "New",
|
||||
"description": "Desc",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🤖",
|
||||
"icon_background": "#fff",
|
||||
"use_icon_as_answer_icon": False,
|
||||
"max_active_requests": 0,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
app_was_updated.disconnect(handler)
|
||||
|
||||
assert received == [app_model]
|
||||
|
||||
def test_update_app_name(self, app_model):
|
||||
from events.app_event import app_was_updated
|
||||
from services.app_service import AppService
|
||||
|
||||
received = []
|
||||
handler = _make_collector(received)
|
||||
app_was_updated.connect(handler)
|
||||
|
||||
with patch("services.app_service.current_user", MagicMock(id="user-1")):
|
||||
try:
|
||||
AppService().update_app_name(app_model, "New Name")
|
||||
finally:
|
||||
app_was_updated.disconnect(handler)
|
||||
|
||||
assert received == [app_model]
|
||||
|
||||
def test_update_app_icon(self, app_model):
|
||||
from events.app_event import app_was_updated
|
||||
from services.app_service import AppService
|
||||
|
||||
received = []
|
||||
handler = _make_collector(received)
|
||||
app_was_updated.connect(handler)
|
||||
|
||||
with patch("services.app_service.current_user", MagicMock(id="user-1")):
|
||||
try:
|
||||
AppService().update_app_icon(app_model, "🎉", "#000")
|
||||
finally:
|
||||
app_was_updated.disconnect(handler)
|
||||
|
||||
assert received == [app_model]
|
||||
|
||||
def test_update_app_site_status_sends_when_changed(self, app_model):
|
||||
from events.app_event import app_was_updated
|
||||
from services.app_service import AppService
|
||||
|
||||
received = []
|
||||
handler = _make_collector(received)
|
||||
app_was_updated.connect(handler)
|
||||
|
||||
with patch("services.app_service.current_user", MagicMock(id="user-1")):
|
||||
try:
|
||||
app_model.enable_site = False
|
||||
AppService().update_app_site_status(app_model, True)
|
||||
finally:
|
||||
app_was_updated.disconnect(handler)
|
||||
|
||||
assert received == [app_model]
|
||||
|
||||
def test_update_app_site_status_skips_when_unchanged(self, app_model):
|
||||
from events.app_event import app_was_updated
|
||||
from services.app_service import AppService
|
||||
|
||||
received = []
|
||||
handler = _make_collector(received)
|
||||
app_was_updated.connect(handler)
|
||||
|
||||
try:
|
||||
app_model.enable_site = True
|
||||
AppService().update_app_site_status(app_model, True)
|
||||
finally:
|
||||
app_was_updated.disconnect(handler)
|
||||
|
||||
assert received == []
|
||||
|
||||
def test_update_app_api_status_sends_when_changed(self, app_model):
|
||||
from events.app_event import app_was_updated
|
||||
from services.app_service import AppService
|
||||
|
||||
received = []
|
||||
handler = _make_collector(received)
|
||||
app_was_updated.connect(handler)
|
||||
|
||||
with patch("services.app_service.current_user", MagicMock(id="user-1")):
|
||||
try:
|
||||
app_model.enable_api = False
|
||||
AppService().update_app_api_status(app_model, True)
|
||||
finally:
|
||||
app_was_updated.disconnect(handler)
|
||||
|
||||
assert received == [app_model]
|
||||
|
||||
def test_update_app_api_status_skips_when_unchanged(self, app_model):
|
||||
from events.app_event import app_was_updated
|
||||
from services.app_service import AppService
|
||||
|
||||
received = []
|
||||
handler = _make_collector(received)
|
||||
app_was_updated.connect(handler)
|
||||
|
||||
try:
|
||||
app_model.enable_api = True
|
||||
AppService().update_app_api_status(app_model, True)
|
||||
finally:
|
||||
app_was_updated.disconnect(handler)
|
||||
|
||||
assert received == []
|
||||
69
api/tests/unit_tests/tasks/test_enterprise_telemetry_task.py
Normal file
69
api/tests/unit_tests/tasks/test_enterprise_telemetry_task.py
Normal file
@ -0,0 +1,69 @@
|
||||
"""Unit tests for enterprise telemetry Celery task."""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope
|
||||
from tasks.enterprise_telemetry_task import process_enterprise_telemetry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_envelope_json():
|
||||
envelope = TelemetryEnvelope(
|
||||
case=TelemetryCase.APP_CREATED,
|
||||
tenant_id="test-tenant",
|
||||
event_id="test-event-123",
|
||||
payload={"app_id": "app-123"},
|
||||
)
|
||||
return envelope.model_dump_json()
|
||||
|
||||
|
||||
def test_process_enterprise_telemetry_success(sample_envelope_json):
|
||||
with patch("tasks.enterprise_telemetry_task.EnterpriseMetricHandler") as mock_handler_class:
|
||||
mock_handler = MagicMock()
|
||||
mock_handler_class.return_value = mock_handler
|
||||
|
||||
process_enterprise_telemetry(sample_envelope_json)
|
||||
|
||||
mock_handler.handle.assert_called_once()
|
||||
call_args = mock_handler.handle.call_args[0][0]
|
||||
assert isinstance(call_args, TelemetryEnvelope)
|
||||
assert call_args.case == TelemetryCase.APP_CREATED
|
||||
assert call_args.tenant_id == "test-tenant"
|
||||
assert call_args.event_id == "test-event-123"
|
||||
|
||||
|
||||
def test_process_enterprise_telemetry_invalid_json(caplog):
|
||||
invalid_json = "not valid json"
|
||||
|
||||
process_enterprise_telemetry(invalid_json)
|
||||
|
||||
assert "Failed to process enterprise telemetry envelope" in caplog.text
|
||||
|
||||
|
||||
def test_process_enterprise_telemetry_handler_exception(sample_envelope_json, caplog):
|
||||
with patch("tasks.enterprise_telemetry_task.EnterpriseMetricHandler") as mock_handler_class:
|
||||
mock_handler = MagicMock()
|
||||
mock_handler.handle.side_effect = Exception("Handler error")
|
||||
mock_handler_class.return_value = mock_handler
|
||||
|
||||
process_enterprise_telemetry(sample_envelope_json)
|
||||
|
||||
assert "Failed to process enterprise telemetry envelope" in caplog.text
|
||||
|
||||
|
||||
def test_process_enterprise_telemetry_validation_error(caplog):
|
||||
invalid_envelope = json.dumps(
|
||||
{
|
||||
"case": "INVALID_CASE",
|
||||
"tenant_id": "test-tenant",
|
||||
"event_id": "test-event",
|
||||
"payload": {},
|
||||
}
|
||||
)
|
||||
|
||||
process_enterprise_telemetry(invalid_envelope)
|
||||
|
||||
assert "Failed to process enterprise telemetry envelope" in caplog.text
|
||||
Reference in New Issue
Block a user