mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 10:28:10 +08:00
feat(refactoring): Support Structured Logging (JSON) (#30170)
This commit is contained in:
@ -14,12 +14,12 @@ def test_successful_request(mock_get_client):
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_client.send.return_value = mock_response
|
||||
mock_client.request.return_value = mock_response
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
response = make_request("GET", "http://example.com")
|
||||
assert response.status_code == 200
|
||||
mock_client.request.assert_called_once()
|
||||
|
||||
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||
@ -27,7 +27,6 @@ def test_retry_exceed_max_retries(mock_get_client):
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_client.send.return_value = mock_response
|
||||
mock_client.request.return_value = mock_response
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
@ -72,34 +71,12 @@ class TestGetUserProvidedHostHeader:
|
||||
assert result in ("first.com", "second.com")
|
||||
|
||||
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||
def test_host_header_preservation_without_user_header(mock_get_client):
|
||||
"""Test that when no Host header is provided, the default behavior is maintained."""
|
||||
mock_client = MagicMock()
|
||||
mock_request = MagicMock()
|
||||
mock_request.headers = {}
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_client.send.return_value = mock_response
|
||||
mock_client.request.return_value = mock_response
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
response = make_request("GET", "http://example.com")
|
||||
|
||||
assert response.status_code == 200
|
||||
# Host should not be set if not provided by user
|
||||
assert "Host" not in mock_request.headers or mock_request.headers.get("Host") is None
|
||||
|
||||
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||
def test_host_header_preservation_with_user_header(mock_get_client):
|
||||
"""Test that user-provided Host header is preserved in the request."""
|
||||
mock_client = MagicMock()
|
||||
mock_request = MagicMock()
|
||||
mock_request.headers = {}
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_client.send.return_value = mock_response
|
||||
mock_client.request.return_value = mock_response
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
@ -107,3 +84,93 @@ def test_host_header_preservation_with_user_header(mock_get_client):
|
||||
response = make_request("GET", "http://example.com", headers={"Host": custom_host})
|
||||
|
||||
assert response.status_code == 200
|
||||
# Verify client.request was called with the host header preserved (lowercase)
|
||||
call_kwargs = mock_client.request.call_args.kwargs
|
||||
assert call_kwargs["headers"]["host"] == custom_host
|
||||
|
||||
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||
@pytest.mark.parametrize("host_key", ["host", "HOST", "Host"])
|
||||
def test_host_header_preservation_case_insensitive(mock_get_client, host_key):
|
||||
"""Test that Host header is preserved regardless of case."""
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_client.request.return_value = mock_response
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
response = make_request("GET", "http://example.com", headers={host_key: "api.example.com"})
|
||||
|
||||
assert response.status_code == 200
|
||||
# Host header should be normalized to lowercase "host"
|
||||
call_kwargs = mock_client.request.call_args.kwargs
|
||||
assert call_kwargs["headers"]["host"] == "api.example.com"
|
||||
|
||||
|
||||
class TestFollowRedirectsParameter:
|
||||
"""Tests for follow_redirects parameter handling.
|
||||
|
||||
These tests verify that follow_redirects is correctly passed to client.request().
|
||||
"""
|
||||
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||
def test_follow_redirects_passed_to_request(self, mock_get_client):
|
||||
"""Verify follow_redirects IS passed to client.request()."""
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_client.request.return_value = mock_response
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
make_request("GET", "http://example.com", follow_redirects=True)
|
||||
|
||||
# Verify follow_redirects was passed to request
|
||||
call_kwargs = mock_client.request.call_args.kwargs
|
||||
assert call_kwargs.get("follow_redirects") is True
|
||||
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||
def test_allow_redirects_converted_to_follow_redirects(self, mock_get_client):
|
||||
"""Verify allow_redirects (requests-style) is converted to follow_redirects (httpx-style)."""
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_client.request.return_value = mock_response
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
# Use allow_redirects (requests-style parameter)
|
||||
make_request("GET", "http://example.com", allow_redirects=True)
|
||||
|
||||
# Verify it was converted to follow_redirects
|
||||
call_kwargs = mock_client.request.call_args.kwargs
|
||||
assert call_kwargs.get("follow_redirects") is True
|
||||
assert "allow_redirects" not in call_kwargs
|
||||
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||
def test_follow_redirects_not_set_when_not_specified(self, mock_get_client):
|
||||
"""Verify follow_redirects is not in kwargs when not specified (httpx default behavior)."""
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_client.request.return_value = mock_response
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
make_request("GET", "http://example.com")
|
||||
|
||||
# follow_redirects should not be in kwargs, letting httpx use its default
|
||||
call_kwargs = mock_client.request.call_args.kwargs
|
||||
assert "follow_redirects" not in call_kwargs
|
||||
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||
def test_follow_redirects_takes_precedence_over_allow_redirects(self, mock_get_client):
|
||||
"""Verify follow_redirects takes precedence when both are specified."""
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_client.request.return_value = mock_response
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
# Both specified - follow_redirects should take precedence
|
||||
make_request("GET", "http://example.com", allow_redirects=False, follow_redirects=True)
|
||||
|
||||
call_kwargs = mock_client.request.call_args.kwargs
|
||||
assert call_kwargs.get("follow_redirects") is True
|
||||
|
||||
0
api/tests/unit_tests/core/logging/__init__.py
Normal file
0
api/tests/unit_tests/core/logging/__init__.py
Normal file
79
api/tests/unit_tests/core/logging/test_context.py
Normal file
79
api/tests/unit_tests/core/logging/test_context.py
Normal file
@ -0,0 +1,79 @@
|
||||
"""Tests for logging context module."""
|
||||
|
||||
import uuid
|
||||
|
||||
from core.logging.context import (
|
||||
clear_request_context,
|
||||
get_request_id,
|
||||
get_trace_id,
|
||||
init_request_context,
|
||||
)
|
||||
|
||||
|
||||
class TestLoggingContext:
|
||||
"""Tests for the logging context functions."""
|
||||
|
||||
def test_init_creates_request_id(self):
|
||||
"""init_request_context should create a 10-char request ID."""
|
||||
init_request_context()
|
||||
request_id = get_request_id()
|
||||
assert len(request_id) == 10
|
||||
assert all(c in "0123456789abcdef" for c in request_id)
|
||||
|
||||
def test_init_creates_trace_id(self):
|
||||
"""init_request_context should create a 32-char trace ID."""
|
||||
init_request_context()
|
||||
trace_id = get_trace_id()
|
||||
assert len(trace_id) == 32
|
||||
assert all(c in "0123456789abcdef" for c in trace_id)
|
||||
|
||||
def test_trace_id_derived_from_request_id(self):
|
||||
"""trace_id should be deterministically derived from request_id."""
|
||||
init_request_context()
|
||||
request_id = get_request_id()
|
||||
trace_id = get_trace_id()
|
||||
|
||||
# Verify trace_id is derived using uuid5
|
||||
expected_trace = uuid.uuid5(uuid.NAMESPACE_DNS, request_id).hex
|
||||
assert trace_id == expected_trace
|
||||
|
||||
def test_clear_resets_context(self):
|
||||
"""clear_request_context should reset both IDs to empty strings."""
|
||||
init_request_context()
|
||||
assert get_request_id() != ""
|
||||
assert get_trace_id() != ""
|
||||
|
||||
clear_request_context()
|
||||
assert get_request_id() == ""
|
||||
assert get_trace_id() == ""
|
||||
|
||||
def test_default_values_are_empty(self):
|
||||
"""Default values should be empty strings before init."""
|
||||
clear_request_context()
|
||||
assert get_request_id() == ""
|
||||
assert get_trace_id() == ""
|
||||
|
||||
def test_multiple_inits_create_different_ids(self):
|
||||
"""Each init should create new unique IDs."""
|
||||
init_request_context()
|
||||
first_request_id = get_request_id()
|
||||
first_trace_id = get_trace_id()
|
||||
|
||||
init_request_context()
|
||||
second_request_id = get_request_id()
|
||||
second_trace_id = get_trace_id()
|
||||
|
||||
assert first_request_id != second_request_id
|
||||
assert first_trace_id != second_trace_id
|
||||
|
||||
def test_context_isolation(self):
|
||||
"""Context should be isolated per-call (no thread leakage in same thread)."""
|
||||
init_request_context()
|
||||
id1 = get_request_id()
|
||||
|
||||
# Simulate another request
|
||||
init_request_context()
|
||||
id2 = get_request_id()
|
||||
|
||||
# IDs should be different
|
||||
assert id1 != id2
|
||||
114
api/tests/unit_tests/core/logging/test_filters.py
Normal file
114
api/tests/unit_tests/core/logging/test_filters.py
Normal file
@ -0,0 +1,114 @@
|
||||
"""Tests for logging filters."""
|
||||
|
||||
import logging
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def log_record():
|
||||
return logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.INFO,
|
||||
pathname="",
|
||||
lineno=0,
|
||||
msg="test",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
|
||||
|
||||
class TestTraceContextFilter:
|
||||
def test_sets_empty_trace_id_without_context(self, log_record):
|
||||
from core.logging.context import clear_request_context
|
||||
from core.logging.filters import TraceContextFilter
|
||||
|
||||
# Ensure no context is set
|
||||
clear_request_context()
|
||||
|
||||
filter = TraceContextFilter()
|
||||
result = filter.filter(log_record)
|
||||
|
||||
assert result is True
|
||||
assert hasattr(log_record, "trace_id")
|
||||
assert hasattr(log_record, "span_id")
|
||||
assert hasattr(log_record, "req_id")
|
||||
# Without context, IDs should be empty
|
||||
assert log_record.trace_id == ""
|
||||
assert log_record.req_id == ""
|
||||
|
||||
def test_sets_trace_id_from_context(self, log_record):
|
||||
"""Test that trace_id and req_id are set from ContextVar when initialized."""
|
||||
from core.logging.context import init_request_context
|
||||
from core.logging.filters import TraceContextFilter
|
||||
|
||||
# Initialize context (no Flask needed!)
|
||||
init_request_context()
|
||||
|
||||
filter = TraceContextFilter()
|
||||
filter.filter(log_record)
|
||||
|
||||
# With context initialized, IDs should be set
|
||||
assert log_record.trace_id != ""
|
||||
assert len(log_record.trace_id) == 32
|
||||
assert log_record.req_id != ""
|
||||
assert len(log_record.req_id) == 10
|
||||
|
||||
def test_filter_always_returns_true(self, log_record):
|
||||
from core.logging.filters import TraceContextFilter
|
||||
|
||||
filter = TraceContextFilter()
|
||||
result = filter.filter(log_record)
|
||||
assert result is True
|
||||
|
||||
def test_sets_trace_id_from_otel_when_available(self, log_record):
|
||||
from core.logging.filters import TraceContextFilter
|
||||
|
||||
mock_span = mock.MagicMock()
|
||||
mock_context = mock.MagicMock()
|
||||
mock_context.trace_id = 0x5B8AA5A2D2C872E8321CF37308D69DF2
|
||||
mock_context.span_id = 0x051581BF3BB55C45
|
||||
mock_span.get_span_context.return_value = mock_context
|
||||
|
||||
with (
|
||||
mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span),
|
||||
mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0),
|
||||
mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0),
|
||||
):
|
||||
filter = TraceContextFilter()
|
||||
filter.filter(log_record)
|
||||
|
||||
assert log_record.trace_id == "5b8aa5a2d2c872e8321cf37308d69df2"
|
||||
assert log_record.span_id == "051581bf3bb55c45"
|
||||
|
||||
|
||||
class TestIdentityContextFilter:
|
||||
def test_sets_empty_identity_without_request_context(self, log_record):
|
||||
from core.logging.filters import IdentityContextFilter
|
||||
|
||||
filter = IdentityContextFilter()
|
||||
result = filter.filter(log_record)
|
||||
|
||||
assert result is True
|
||||
assert log_record.tenant_id == ""
|
||||
assert log_record.user_id == ""
|
||||
assert log_record.user_type == ""
|
||||
|
||||
def test_filter_always_returns_true(self, log_record):
|
||||
from core.logging.filters import IdentityContextFilter
|
||||
|
||||
filter = IdentityContextFilter()
|
||||
result = filter.filter(log_record)
|
||||
assert result is True
|
||||
|
||||
def test_handles_exception_gracefully(self, log_record):
|
||||
from core.logging.filters import IdentityContextFilter
|
||||
|
||||
filter = IdentityContextFilter()
|
||||
|
||||
# Should not raise even if something goes wrong
|
||||
with mock.patch("core.logging.filters.flask.has_request_context", side_effect=Exception("Test error")):
|
||||
result = filter.filter(log_record)
|
||||
assert result is True
|
||||
assert log_record.tenant_id == ""
|
||||
267
api/tests/unit_tests/core/logging/test_structured_formatter.py
Normal file
267
api/tests/unit_tests/core/logging/test_structured_formatter.py
Normal file
@ -0,0 +1,267 @@
|
||||
"""Tests for structured JSON formatter."""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import orjson
|
||||
|
||||
|
||||
class TestStructuredJSONFormatter:
|
||||
def test_basic_log_format(self):
|
||||
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||
|
||||
formatter = StructuredJSONFormatter(service_name="test-service")
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.INFO,
|
||||
pathname="test.py",
|
||||
lineno=42,
|
||||
msg="Test message",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
|
||||
output = formatter.format(record)
|
||||
log_dict = orjson.loads(output)
|
||||
|
||||
assert log_dict["severity"] == "INFO"
|
||||
assert log_dict["service"] == "test-service"
|
||||
assert log_dict["caller"] == "test.py:42"
|
||||
assert log_dict["message"] == "Test message"
|
||||
assert "ts" in log_dict
|
||||
assert log_dict["ts"].endswith("Z")
|
||||
|
||||
def test_severity_mapping(self):
|
||||
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||
|
||||
formatter = StructuredJSONFormatter()
|
||||
|
||||
test_cases = [
|
||||
(logging.DEBUG, "DEBUG"),
|
||||
(logging.INFO, "INFO"),
|
||||
(logging.WARNING, "WARN"),
|
||||
(logging.ERROR, "ERROR"),
|
||||
(logging.CRITICAL, "ERROR"),
|
||||
]
|
||||
|
||||
for level, expected_severity in test_cases:
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=level,
|
||||
pathname="test.py",
|
||||
lineno=1,
|
||||
msg="Test",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
output = formatter.format(record)
|
||||
log_dict = orjson.loads(output)
|
||||
assert log_dict["severity"] == expected_severity, f"Level {level} should map to {expected_severity}"
|
||||
|
||||
def test_error_with_stack_trace(self):
|
||||
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||
|
||||
formatter = StructuredJSONFormatter()
|
||||
|
||||
try:
|
||||
raise ValueError("Test error")
|
||||
except ValueError:
|
||||
exc_info = sys.exc_info()
|
||||
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.ERROR,
|
||||
pathname="test.py",
|
||||
lineno=10,
|
||||
msg="Error occurred",
|
||||
args=(),
|
||||
exc_info=exc_info,
|
||||
)
|
||||
|
||||
output = formatter.format(record)
|
||||
log_dict = orjson.loads(output)
|
||||
|
||||
assert log_dict["severity"] == "ERROR"
|
||||
assert "stack_trace" in log_dict
|
||||
assert "ValueError: Test error" in log_dict["stack_trace"]
|
||||
|
||||
def test_no_stack_trace_for_info(self):
|
||||
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||
|
||||
formatter = StructuredJSONFormatter()
|
||||
|
||||
try:
|
||||
raise ValueError("Test error")
|
||||
except ValueError:
|
||||
exc_info = sys.exc_info()
|
||||
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.INFO,
|
||||
pathname="test.py",
|
||||
lineno=10,
|
||||
msg="Info message",
|
||||
args=(),
|
||||
exc_info=exc_info,
|
||||
)
|
||||
|
||||
output = formatter.format(record)
|
||||
log_dict = orjson.loads(output)
|
||||
|
||||
assert "stack_trace" not in log_dict
|
||||
|
||||
def test_trace_context_included(self):
|
||||
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||
|
||||
formatter = StructuredJSONFormatter()
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.INFO,
|
||||
pathname="test.py",
|
||||
lineno=1,
|
||||
msg="Test",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
record.trace_id = "5b8aa5a2d2c872e8321cf37308d69df2"
|
||||
record.span_id = "051581bf3bb55c45"
|
||||
|
||||
output = formatter.format(record)
|
||||
log_dict = orjson.loads(output)
|
||||
|
||||
assert log_dict["trace_id"] == "5b8aa5a2d2c872e8321cf37308d69df2"
|
||||
assert log_dict["span_id"] == "051581bf3bb55c45"
|
||||
|
||||
def test_identity_context_included(self):
|
||||
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||
|
||||
formatter = StructuredJSONFormatter()
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.INFO,
|
||||
pathname="test.py",
|
||||
lineno=1,
|
||||
msg="Test",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
record.tenant_id = "t-global-corp"
|
||||
record.user_id = "u-admin-007"
|
||||
record.user_type = "admin"
|
||||
|
||||
output = formatter.format(record)
|
||||
log_dict = orjson.loads(output)
|
||||
|
||||
assert "identity" in log_dict
|
||||
assert log_dict["identity"]["tenant_id"] == "t-global-corp"
|
||||
assert log_dict["identity"]["user_id"] == "u-admin-007"
|
||||
assert log_dict["identity"]["user_type"] == "admin"
|
||||
|
||||
def test_no_identity_when_empty(self):
|
||||
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||
|
||||
formatter = StructuredJSONFormatter()
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.INFO,
|
||||
pathname="test.py",
|
||||
lineno=1,
|
||||
msg="Test",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
|
||||
output = formatter.format(record)
|
||||
log_dict = orjson.loads(output)
|
||||
|
||||
assert "identity" not in log_dict
|
||||
|
||||
def test_attributes_included(self):
|
||||
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||
|
||||
formatter = StructuredJSONFormatter()
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.INFO,
|
||||
pathname="test.py",
|
||||
lineno=1,
|
||||
msg="Test",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
record.attributes = {"order_id": "ord-123", "amount": 99.99}
|
||||
|
||||
output = formatter.format(record)
|
||||
log_dict = orjson.loads(output)
|
||||
|
||||
assert log_dict["attributes"]["order_id"] == "ord-123"
|
||||
assert log_dict["attributes"]["amount"] == 99.99
|
||||
|
||||
def test_message_with_args(self):
|
||||
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||
|
||||
formatter = StructuredJSONFormatter()
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.INFO,
|
||||
pathname="test.py",
|
||||
lineno=1,
|
||||
msg="User %s logged in from %s",
|
||||
args=("john", "192.168.1.1"),
|
||||
exc_info=None,
|
||||
)
|
||||
|
||||
output = formatter.format(record)
|
||||
log_dict = orjson.loads(output)
|
||||
|
||||
assert log_dict["message"] == "User john logged in from 192.168.1.1"
|
||||
|
||||
def test_timestamp_format(self):
|
||||
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||
|
||||
formatter = StructuredJSONFormatter()
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.INFO,
|
||||
pathname="test.py",
|
||||
lineno=1,
|
||||
msg="Test",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
|
||||
output = formatter.format(record)
|
||||
log_dict = orjson.loads(output)
|
||||
|
||||
# Verify ISO 8601 format with Z suffix
|
||||
ts = log_dict["ts"]
|
||||
assert ts.endswith("Z")
|
||||
assert "T" in ts
|
||||
# Should have milliseconds
|
||||
assert "." in ts
|
||||
|
||||
def test_fallback_for_non_serializable_attributes(self):
|
||||
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||
|
||||
formatter = StructuredJSONFormatter()
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.INFO,
|
||||
pathname="test.py",
|
||||
lineno=1,
|
||||
msg="Test with non-serializable",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
# Set is not serializable by orjson
|
||||
record.attributes = {"items": {1, 2, 3}, "custom": object()}
|
||||
|
||||
# Should not raise, fallback to json.dumps with default=str
|
||||
output = formatter.format(record)
|
||||
|
||||
# Verify it's valid JSON (parsed by stdlib json since orjson may fail)
|
||||
import json
|
||||
|
||||
log_dict = json.loads(output)
|
||||
assert log_dict["message"] == "Test with non-serializable"
|
||||
assert "attributes" in log_dict
|
||||
102
api/tests/unit_tests/core/logging/test_trace_helpers.py
Normal file
102
api/tests/unit_tests/core/logging/test_trace_helpers.py
Normal file
@ -0,0 +1,102 @@
|
||||
"""Tests for trace helper functions."""
|
||||
|
||||
import re
|
||||
from unittest import mock
|
||||
|
||||
|
||||
class TestGetSpanIdFromOtelContext:
|
||||
def test_returns_none_without_span(self):
|
||||
from core.helper.trace_id_helper import get_span_id_from_otel_context
|
||||
|
||||
with mock.patch("opentelemetry.trace.get_current_span", return_value=None):
|
||||
result = get_span_id_from_otel_context()
|
||||
assert result is None
|
||||
|
||||
def test_returns_span_id_when_available(self):
|
||||
from core.helper.trace_id_helper import get_span_id_from_otel_context
|
||||
|
||||
mock_span = mock.MagicMock()
|
||||
mock_context = mock.MagicMock()
|
||||
mock_context.span_id = 0x051581BF3BB55C45
|
||||
mock_span.get_span_context.return_value = mock_context
|
||||
|
||||
with mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span):
|
||||
with mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0):
|
||||
result = get_span_id_from_otel_context()
|
||||
assert result == "051581bf3bb55c45"
|
||||
|
||||
def test_returns_none_on_exception(self):
|
||||
from core.helper.trace_id_helper import get_span_id_from_otel_context
|
||||
|
||||
with mock.patch("opentelemetry.trace.get_current_span", side_effect=Exception("Test error")):
|
||||
result = get_span_id_from_otel_context()
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGenerateTraceparentHeader:
|
||||
def test_generates_valid_format(self):
|
||||
from core.helper.trace_id_helper import generate_traceparent_header
|
||||
|
||||
with mock.patch("opentelemetry.trace.get_current_span", return_value=None):
|
||||
result = generate_traceparent_header()
|
||||
|
||||
assert result is not None
|
||||
# Format: 00-{trace_id}-{span_id}-01
|
||||
parts = result.split("-")
|
||||
assert len(parts) == 4
|
||||
assert parts[0] == "00" # version
|
||||
assert len(parts[1]) == 32 # trace_id (32 hex chars)
|
||||
assert len(parts[2]) == 16 # span_id (16 hex chars)
|
||||
assert parts[3] == "01" # flags
|
||||
|
||||
def test_uses_otel_context_when_available(self):
|
||||
from core.helper.trace_id_helper import generate_traceparent_header
|
||||
|
||||
mock_span = mock.MagicMock()
|
||||
mock_context = mock.MagicMock()
|
||||
mock_context.trace_id = 0x5B8AA5A2D2C872E8321CF37308D69DF2
|
||||
mock_context.span_id = 0x051581BF3BB55C45
|
||||
mock_span.get_span_context.return_value = mock_context
|
||||
|
||||
with mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span):
|
||||
with (
|
||||
mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0),
|
||||
mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0),
|
||||
):
|
||||
result = generate_traceparent_header()
|
||||
|
||||
assert result == "00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01"
|
||||
|
||||
def test_generates_hex_only_values(self):
|
||||
from core.helper.trace_id_helper import generate_traceparent_header
|
||||
|
||||
with mock.patch("opentelemetry.trace.get_current_span", return_value=None):
|
||||
result = generate_traceparent_header()
|
||||
|
||||
parts = result.split("-")
|
||||
# All parts should be valid hex
|
||||
assert re.match(r"^[0-9a-f]+$", parts[1])
|
||||
assert re.match(r"^[0-9a-f]+$", parts[2])
|
||||
|
||||
|
||||
class TestParseTraceparentHeader:
|
||||
def test_parses_valid_traceparent(self):
|
||||
from core.helper.trace_id_helper import parse_traceparent_header
|
||||
|
||||
traceparent = "00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01"
|
||||
result = parse_traceparent_header(traceparent)
|
||||
|
||||
assert result == "5b8aa5a2d2c872e8321cf37308d69df2"
|
||||
|
||||
def test_returns_none_for_invalid_format(self):
|
||||
from core.helper.trace_id_helper import parse_traceparent_header
|
||||
|
||||
# Wrong number of parts
|
||||
assert parse_traceparent_header("00-abc-def") is None
|
||||
# Wrong trace_id length
|
||||
assert parse_traceparent_header("00-abc-def-01") is None
|
||||
|
||||
def test_returns_none_for_empty_string(self):
|
||||
from core.helper.trace_id_helper import parse_traceparent_header
|
||||
|
||||
assert parse_traceparent_header("") is None
|
||||
@ -99,29 +99,20 @@ def test_external_api_json_message_and_bad_request_rewrite():
|
||||
assert res.get_json()["message"] == "Invalid JSON payload received or JSON payload is empty."
|
||||
|
||||
|
||||
def test_external_api_param_mapping_and_quota_and_exc_info_none():
|
||||
# Force exc_info() to return (None,None,None) only during request
|
||||
import libs.external_api as ext
|
||||
def test_external_api_param_mapping_and_quota():
|
||||
app = _create_api_app()
|
||||
client = app.test_client()
|
||||
|
||||
orig_exc_info = ext.sys.exc_info
|
||||
try:
|
||||
ext.sys.exc_info = lambda: (None, None, None)
|
||||
# Param errors mapping payload path
|
||||
res = client.get("/api/param-errors")
|
||||
assert res.status_code == 400
|
||||
data = res.get_json()
|
||||
assert data["code"] == "invalid_param"
|
||||
assert data["params"] == "field"
|
||||
|
||||
app = _create_api_app()
|
||||
client = app.test_client()
|
||||
|
||||
# Param errors mapping payload path
|
||||
res = client.get("/api/param-errors")
|
||||
assert res.status_code == 400
|
||||
data = res.get_json()
|
||||
assert data["code"] == "invalid_param"
|
||||
assert data["params"] == "field"
|
||||
|
||||
# Quota path — depending on Flask-RESTX internals it may be handled
|
||||
res = client.get("/api/quota")
|
||||
assert res.status_code in (400, 429)
|
||||
finally:
|
||||
ext.sys.exc_info = orig_exc_info # type: ignore[assignment]
|
||||
# Quota path — depending on Flask-RESTX internals it may be handled
|
||||
res = client.get("/api/quota")
|
||||
assert res.status_code in (400, 429)
|
||||
|
||||
|
||||
def test_unauthorized_and_force_logout_clears_cookies():
|
||||
|
||||
Reference in New Issue
Block a user