mirror of
https://github.com/langgenius/dify.git
synced 2026-03-21 06:18:27 +08:00
Merge branch 'feat/model-plugins-implementing' into deploy/dev
This commit is contained in:
@ -150,8 +150,9 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
"_DummyTraceQueueManager",
|
||||
(TraceQueueManager,),
|
||||
{
|
||||
"__init__": lambda self, app_id=None, user_id=None: setattr(self, "app_id", app_id)
|
||||
or setattr(self, "user_id", user_id)
|
||||
"__init__": lambda self, app_id=None, user_id=None: (
|
||||
setattr(self, "app_id", app_id) or setattr(self, "user_id", user_id)
|
||||
)
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.TraceQueueManager", DummyTraceQueueManager)
|
||||
@ -1124,8 +1125,9 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
"_DummyTraceQueueManager",
|
||||
(TraceQueueManager,),
|
||||
{
|
||||
"__init__": lambda self, app_id=None, user_id=None: setattr(self, "app_id", app_id)
|
||||
or setattr(self, "user_id", user_id)
|
||||
"__init__": lambda self, app_id=None, user_id=None: (
|
||||
setattr(self, "app_id", app_id) or setattr(self, "user_id", user_id)
|
||||
)
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
@ -1202,8 +1204,9 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
"_DummyTraceQueueManager",
|
||||
(TraceQueueManager,),
|
||||
{
|
||||
"__init__": lambda self, app_id=None, user_id=None: setattr(self, "app_id", app_id)
|
||||
or setattr(self, "user_id", user_id)
|
||||
"__init__": lambda self, app_id=None, user_id=None: (
|
||||
setattr(self, "app_id", app_id) or setattr(self, "user_id", user_id)
|
||||
)
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
|
||||
@ -240,12 +240,12 @@ class TestAdvancedChatGenerateTaskPipeline:
|
||||
def test_iteration_and_loop_handlers(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_run_id = "run-id"
|
||||
pipeline._workflow_response_converter.workflow_iteration_start_to_stream_response = (
|
||||
lambda **kwargs: "iter_start"
|
||||
pipeline._workflow_response_converter.workflow_iteration_start_to_stream_response = lambda **kwargs: (
|
||||
"iter_start"
|
||||
)
|
||||
pipeline._workflow_response_converter.workflow_iteration_next_to_stream_response = lambda **kwargs: "iter_next"
|
||||
pipeline._workflow_response_converter.workflow_iteration_completed_to_stream_response = (
|
||||
lambda **kwargs: "iter_done"
|
||||
pipeline._workflow_response_converter.workflow_iteration_completed_to_stream_response = lambda **kwargs: (
|
||||
"iter_done"
|
||||
)
|
||||
pipeline._workflow_response_converter.workflow_loop_start_to_stream_response = lambda **kwargs: "loop_start"
|
||||
pipeline._workflow_response_converter.workflow_loop_next_to_stream_response = lambda **kwargs: "loop_next"
|
||||
|
||||
@ -144,8 +144,9 @@ class TestWorkflowAppGeneratorGenerate:
|
||||
"_DummyTraceQueueManager",
|
||||
(TraceQueueManager,),
|
||||
{
|
||||
"__init__": lambda self, app_id=None, user_id=None: setattr(self, "app_id", app_id)
|
||||
or setattr(self, "user_id", user_id)
|
||||
"__init__": lambda self, app_id=None, user_id=None: (
|
||||
setattr(self, "app_id", app_id) or setattr(self, "user_id", user_id)
|
||||
)
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
|
||||
@ -0,0 +1,326 @@
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from opentelemetry.sdk.trace import ReadableSpan
|
||||
from opentelemetry.trace import SpanKind, Status, StatusCode
|
||||
|
||||
from core.ops.aliyun_trace.data_exporter.traceclient import (
|
||||
INVALID_SPAN_ID,
|
||||
SpanBuilder,
|
||||
TraceClient,
|
||||
build_endpoint,
|
||||
convert_datetime_to_nanoseconds,
|
||||
convert_string_to_id,
|
||||
convert_to_span_id,
|
||||
convert_to_trace_id,
|
||||
create_link,
|
||||
generate_span_id,
|
||||
)
|
||||
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def trace_client_factory():
|
||||
"""Factory fixture for creating TraceClient instances with automatic cleanup."""
|
||||
clients_to_shutdown = []
|
||||
|
||||
def _factory(**kwargs):
|
||||
client = TraceClient(**kwargs)
|
||||
clients_to_shutdown.append(client)
|
||||
return client
|
||||
|
||||
yield _factory
|
||||
|
||||
# Cleanup: shutdown all created clients
|
||||
for client in clients_to_shutdown:
|
||||
client.shutdown()
|
||||
|
||||
|
||||
class TestTraceClient:
|
||||
@patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
|
||||
@patch("core.ops.aliyun_trace.data_exporter.traceclient.socket.gethostname")
|
||||
def test_init(self, mock_gethostname, mock_exporter_class, trace_client_factory):
|
||||
mock_gethostname.return_value = "test-host"
|
||||
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
|
||||
|
||||
assert client.endpoint == "http://test-endpoint"
|
||||
assert client.max_queue_size == 1000
|
||||
assert client.schedule_delay_sec == 5
|
||||
assert client.done is False
|
||||
assert client.worker_thread.is_alive()
|
||||
|
||||
client.shutdown()
|
||||
assert client.done is True
|
||||
|
||||
@patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
|
||||
def test_export(self, mock_exporter_class, trace_client_factory):
|
||||
mock_exporter = mock_exporter_class.return_value
|
||||
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
|
||||
spans = [MagicMock(spec=ReadableSpan)]
|
||||
client.export(spans)
|
||||
mock_exporter.export.assert_called_once_with(spans)
|
||||
|
||||
@patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head")
|
||||
@patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
|
||||
def test_api_check_success(self, mock_exporter_class, mock_head, trace_client_factory):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 405
|
||||
mock_head.return_value = mock_response
|
||||
|
||||
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
|
||||
assert client.api_check() is True
|
||||
|
||||
@patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head")
|
||||
@patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
|
||||
def test_api_check_failure_status(self, mock_exporter_class, mock_head, trace_client_factory):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_head.return_value = mock_response
|
||||
|
||||
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
|
||||
assert client.api_check() is False
|
||||
|
||||
@patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head")
|
||||
@patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
|
||||
def test_api_check_exception(self, mock_exporter_class, mock_head, trace_client_factory):
|
||||
mock_head.side_effect = httpx.RequestError("Connection error")
|
||||
|
||||
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
|
||||
with pytest.raises(ValueError, match="AliyunTrace API check failed: Connection error"):
|
||||
client.api_check()
|
||||
|
||||
@patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
|
||||
def test_get_project_url(self, mock_exporter_class, trace_client_factory):
|
||||
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
|
||||
assert client.get_project_url() == "https://arms.console.aliyun.com/#/llm"
|
||||
|
||||
@patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
|
||||
def test_add_span(self, mock_exporter_class, trace_client_factory):
|
||||
client = trace_client_factory(
|
||||
service_name="test-service",
|
||||
endpoint="http://test-endpoint",
|
||||
max_export_batch_size=2,
|
||||
)
|
||||
|
||||
# Test add None
|
||||
client.add_span(None)
|
||||
assert len(client.queue) == 0
|
||||
|
||||
# Test add valid SpanData
|
||||
span_data = SpanData(
|
||||
name="test-span",
|
||||
trace_id=123,
|
||||
span_id=456,
|
||||
parent_span_id=None,
|
||||
start_time=1000,
|
||||
end_time=2000,
|
||||
status=Status(StatusCode.OK),
|
||||
span_kind=SpanKind.INTERNAL,
|
||||
)
|
||||
|
||||
mock_span = MagicMock(spec=ReadableSpan)
|
||||
client.span_builder.build_span = MagicMock(return_value=mock_span)
|
||||
|
||||
with patch.object(client.condition, "notify") as mock_notify:
|
||||
client.add_span(span_data)
|
||||
assert len(client.queue) == 1
|
||||
mock_notify.assert_not_called()
|
||||
|
||||
client.add_span(span_data)
|
||||
assert len(client.queue) == 2
|
||||
mock_notify.assert_called_once()
|
||||
|
||||
@patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
|
||||
@patch("core.ops.aliyun_trace.data_exporter.traceclient.logger")
|
||||
def test_add_span_queue_full(self, mock_logger, mock_exporter_class, trace_client_factory):
|
||||
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint", max_queue_size=1)
|
||||
|
||||
span_data = SpanData(
|
||||
name="test-span",
|
||||
trace_id=123,
|
||||
span_id=456,
|
||||
parent_span_id=None,
|
||||
start_time=1000,
|
||||
end_time=2000,
|
||||
status=Status(StatusCode.OK),
|
||||
span_kind=SpanKind.INTERNAL,
|
||||
)
|
||||
mock_span = MagicMock(spec=ReadableSpan)
|
||||
client.span_builder.build_span = MagicMock(return_value=mock_span)
|
||||
|
||||
client.add_span(span_data)
|
||||
assert len(client.queue) == 1
|
||||
|
||||
client.add_span(span_data)
|
||||
assert len(client.queue) == 1
|
||||
mock_logger.warning.assert_called_with("Queue is full, likely spans will be dropped.")
|
||||
|
||||
@patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
|
||||
def test_export_batch_error(self, mock_exporter_class, trace_client_factory):
|
||||
mock_exporter = mock_exporter_class.return_value
|
||||
mock_exporter.export.side_effect = Exception("Export failed")
|
||||
|
||||
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
|
||||
mock_span = MagicMock(spec=ReadableSpan)
|
||||
client.queue.append(mock_span)
|
||||
|
||||
with patch("core.ops.aliyun_trace.data_exporter.traceclient.logger") as mock_logger:
|
||||
client._export_batch()
|
||||
mock_logger.warning.assert_called()
|
||||
|
||||
@patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
|
||||
def test_worker_loop(self, mock_exporter_class, trace_client_factory):
|
||||
# We need to test the wait timeout in _worker
|
||||
# But _worker runs in a thread. Let's mock condition.wait.
|
||||
client = trace_client_factory(
|
||||
service_name="test-service",
|
||||
endpoint="http://test-endpoint",
|
||||
schedule_delay_sec=0.1,
|
||||
)
|
||||
|
||||
with patch.object(client.condition, "wait") as mock_wait:
|
||||
# Let it run for a bit then shut down
|
||||
time.sleep(0.2)
|
||||
client.shutdown()
|
||||
# mock_wait might have been called
|
||||
assert mock_wait.called or client.done
|
||||
|
||||
@patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
|
||||
def test_shutdown_flushes(self, mock_exporter_class, trace_client_factory):
|
||||
mock_exporter = mock_exporter_class.return_value
|
||||
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
|
||||
|
||||
mock_span = MagicMock(spec=ReadableSpan)
|
||||
client.queue.append(mock_span)
|
||||
|
||||
client.shutdown()
|
||||
# Should have called export twice (once in worker/export_batch, once in shutdown)
|
||||
# or at least once if worker was waiting
|
||||
assert mock_exporter.export.called
|
||||
assert mock_exporter.shutdown.called
|
||||
|
||||
|
||||
class TestSpanBuilder:
|
||||
def test_build_span(self):
|
||||
resource = MagicMock()
|
||||
builder = SpanBuilder(resource)
|
||||
|
||||
span_data = SpanData(
|
||||
name="test-span",
|
||||
trace_id=123,
|
||||
span_id=456,
|
||||
parent_span_id=789,
|
||||
start_time=1000,
|
||||
end_time=2000,
|
||||
status=Status(StatusCode.OK),
|
||||
span_kind=SpanKind.INTERNAL,
|
||||
attributes={"attr1": "val1"},
|
||||
events=[],
|
||||
links=[],
|
||||
)
|
||||
|
||||
span = builder.build_span(span_data)
|
||||
assert isinstance(span, ReadableSpan)
|
||||
assert span.name == "test-span"
|
||||
assert span.context.trace_id == 123
|
||||
assert span.context.span_id == 456
|
||||
assert span.parent.span_id == 789
|
||||
assert span.resource == resource
|
||||
assert span.attributes == {"attr1": "val1"}
|
||||
|
||||
def test_build_span_no_parent(self):
|
||||
resource = MagicMock()
|
||||
builder = SpanBuilder(resource)
|
||||
|
||||
span_data = SpanData(
|
||||
name="test-span",
|
||||
trace_id=123,
|
||||
span_id=456,
|
||||
parent_span_id=None,
|
||||
start_time=1000,
|
||||
end_time=2000,
|
||||
status=Status(StatusCode.OK),
|
||||
span_kind=SpanKind.INTERNAL,
|
||||
)
|
||||
|
||||
span = builder.build_span(span_data)
|
||||
assert span.parent is None
|
||||
|
||||
|
||||
def test_create_link():
|
||||
trace_id_str = "0123456789abcdef0123456789abcdef"
|
||||
link = create_link(trace_id_str)
|
||||
assert link.context.trace_id == int(trace_id_str, 16)
|
||||
assert link.context.span_id == INVALID_SPAN_ID
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid trace ID format"):
|
||||
create_link("invalid-hex")
|
||||
|
||||
|
||||
def test_generate_span_id():
|
||||
# Test normal generation
|
||||
span_id = generate_span_id()
|
||||
assert isinstance(span_id, int)
|
||||
assert span_id != INVALID_SPAN_ID
|
||||
|
||||
# Test retry loop
|
||||
with patch("core.ops.aliyun_trace.data_exporter.traceclient.random.getrandbits") as mock_rand:
|
||||
mock_rand.side_effect = [INVALID_SPAN_ID, 999]
|
||||
span_id = generate_span_id()
|
||||
assert span_id == 999
|
||||
assert mock_rand.call_count == 2
|
||||
|
||||
|
||||
def test_convert_to_trace_id():
|
||||
uid = str(uuid.uuid4())
|
||||
trace_id = convert_to_trace_id(uid)
|
||||
assert trace_id == uuid.UUID(uid).int
|
||||
|
||||
with pytest.raises(ValueError, match="UUID cannot be None"):
|
||||
convert_to_trace_id(None)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid UUID input"):
|
||||
convert_to_trace_id("not-a-uuid")
|
||||
|
||||
|
||||
def test_convert_string_to_id():
|
||||
assert convert_string_to_id("test") > 0
|
||||
# Test with None string
|
||||
with patch("core.ops.aliyun_trace.data_exporter.traceclient.generate_span_id") as mock_gen:
|
||||
mock_gen.return_value = 12345
|
||||
assert convert_string_to_id(None) == 12345
|
||||
|
||||
|
||||
def test_convert_to_span_id():
|
||||
uid = str(uuid.uuid4())
|
||||
span_id = convert_to_span_id(uid, "test-type")
|
||||
assert isinstance(span_id, int)
|
||||
|
||||
with pytest.raises(ValueError, match="UUID cannot be None"):
|
||||
convert_to_span_id(None, "test")
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid UUID input"):
|
||||
convert_to_span_id("not-a-uuid", "test")
|
||||
|
||||
|
||||
def test_convert_datetime_to_nanoseconds():
|
||||
dt = datetime(2023, 1, 1, 12, 0, 0)
|
||||
ns = convert_datetime_to_nanoseconds(dt)
|
||||
assert ns == int(dt.timestamp() * 1e9)
|
||||
assert convert_datetime_to_nanoseconds(None) is None
|
||||
|
||||
|
||||
def test_build_endpoint():
|
||||
license_key = "abc"
|
||||
|
||||
# CMS 2.0 endpoint
|
||||
url1 = "https://log.aliyuncs.com"
|
||||
assert build_endpoint(url1, license_key) == "https://log.aliyuncs.com/adapt_abc/api/v1/traces"
|
||||
|
||||
# XTrace endpoint
|
||||
url2 = "https://example.com"
|
||||
assert build_endpoint(url2, license_key) == "https://example.com/adapt_abc/api/otlp/traces"
|
||||
@ -0,0 +1,88 @@
|
||||
import pytest
|
||||
from opentelemetry import trace as trace_api
|
||||
from opentelemetry.sdk.trace import Event
|
||||
from opentelemetry.trace import SpanKind, Status, StatusCode
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata
|
||||
|
||||
|
||||
class TestTraceMetadata:
|
||||
def test_trace_metadata_init(self):
|
||||
links = [trace_api.Link(context=trace_api.SpanContext(0, 0, False))]
|
||||
metadata = TraceMetadata(
|
||||
trace_id=123, workflow_span_id=456, session_id="session_1", user_id="user_1", links=links
|
||||
)
|
||||
assert metadata.trace_id == 123
|
||||
assert metadata.workflow_span_id == 456
|
||||
assert metadata.session_id == "session_1"
|
||||
assert metadata.user_id == "user_1"
|
||||
assert metadata.links == links
|
||||
|
||||
|
||||
class TestSpanData:
|
||||
def test_span_data_init_required_fields(self):
|
||||
span_data = SpanData(trace_id=123, span_id=456, name="test_span", start_time=1000, end_time=2000)
|
||||
assert span_data.trace_id == 123
|
||||
assert span_data.span_id == 456
|
||||
assert span_data.name == "test_span"
|
||||
assert span_data.start_time == 1000
|
||||
assert span_data.end_time == 2000
|
||||
|
||||
# Check defaults
|
||||
assert span_data.parent_span_id is None
|
||||
assert span_data.attributes == {}
|
||||
assert span_data.events == []
|
||||
assert span_data.links == []
|
||||
assert span_data.status.status_code == StatusCode.UNSET
|
||||
assert span_data.span_kind == SpanKind.INTERNAL
|
||||
|
||||
def test_span_data_with_optional_fields(self):
|
||||
event = Event(name="event_1", timestamp=1500)
|
||||
link = trace_api.Link(context=trace_api.SpanContext(0, 0, False))
|
||||
status = Status(StatusCode.OK)
|
||||
|
||||
span_data = SpanData(
|
||||
trace_id=123,
|
||||
parent_span_id=111,
|
||||
span_id=456,
|
||||
name="test_span",
|
||||
attributes={"key": "value"},
|
||||
events=[event],
|
||||
links=[link],
|
||||
status=status,
|
||||
start_time=1000,
|
||||
end_time=2000,
|
||||
span_kind=SpanKind.SERVER,
|
||||
)
|
||||
|
||||
assert span_data.parent_span_id == 111
|
||||
assert span_data.attributes == {"key": "value"}
|
||||
assert span_data.events == [event]
|
||||
assert span_data.links == [link]
|
||||
assert span_data.status.status_code == status.status_code
|
||||
assert span_data.span_kind == SpanKind.SERVER
|
||||
|
||||
def test_span_data_missing_required_fields(self):
|
||||
with pytest.raises(ValidationError):
|
||||
SpanData(
|
||||
trace_id=123,
|
||||
# span_id missing
|
||||
name="test_span",
|
||||
start_time=1000,
|
||||
end_time=2000,
|
||||
)
|
||||
|
||||
def test_span_data_arbitrary_types_allowed(self):
|
||||
# opentelemetry.trace.Status and Event are "arbitrary types" for Pydantic
|
||||
# This test ensures they are accepted thanks to model_config
|
||||
status = Status(StatusCode.ERROR, description="error occurred")
|
||||
event = Event(name="exception", timestamp=1234, attributes={"exception.type": "ValueError"})
|
||||
|
||||
span_data = SpanData(
|
||||
trace_id=123, span_id=456, name="test_span", status=status, events=[event], start_time=1000, end_time=2000
|
||||
)
|
||||
|
||||
assert span_data.status.status_code == status.status_code
|
||||
assert span_data.status.description == status.description
|
||||
assert span_data.events == [event]
|
||||
@ -0,0 +1,68 @@
|
||||
from core.ops.aliyun_trace.entities.semconv import (
|
||||
ACS_ARMS_SERVICE_FEATURE,
|
||||
GEN_AI_COMPLETION,
|
||||
GEN_AI_FRAMEWORK,
|
||||
GEN_AI_INPUT_MESSAGE,
|
||||
GEN_AI_OUTPUT_MESSAGE,
|
||||
GEN_AI_PROMPT,
|
||||
GEN_AI_PROVIDER_NAME,
|
||||
GEN_AI_REQUEST_MODEL,
|
||||
GEN_AI_RESPONSE_FINISH_REASON,
|
||||
GEN_AI_SESSION_ID,
|
||||
GEN_AI_SPAN_KIND,
|
||||
GEN_AI_USAGE_INPUT_TOKENS,
|
||||
GEN_AI_USAGE_OUTPUT_TOKENS,
|
||||
GEN_AI_USAGE_TOTAL_TOKENS,
|
||||
GEN_AI_USER_ID,
|
||||
GEN_AI_USER_NAME,
|
||||
INPUT_VALUE,
|
||||
OUTPUT_VALUE,
|
||||
RETRIEVAL_DOCUMENT,
|
||||
RETRIEVAL_QUERY,
|
||||
TOOL_DESCRIPTION,
|
||||
TOOL_NAME,
|
||||
TOOL_PARAMETERS,
|
||||
GenAISpanKind,
|
||||
)
|
||||
|
||||
|
||||
def test_constants():
|
||||
assert ACS_ARMS_SERVICE_FEATURE == "acs.arms.service.feature"
|
||||
assert GEN_AI_SESSION_ID == "gen_ai.session.id"
|
||||
assert GEN_AI_USER_ID == "gen_ai.user.id"
|
||||
assert GEN_AI_USER_NAME == "gen_ai.user.name"
|
||||
assert GEN_AI_SPAN_KIND == "gen_ai.span.kind"
|
||||
assert GEN_AI_FRAMEWORK == "gen_ai.framework"
|
||||
assert INPUT_VALUE == "input.value"
|
||||
assert OUTPUT_VALUE == "output.value"
|
||||
assert RETRIEVAL_QUERY == "retrieval.query"
|
||||
assert RETRIEVAL_DOCUMENT == "retrieval.document"
|
||||
assert GEN_AI_REQUEST_MODEL == "gen_ai.request.model"
|
||||
assert GEN_AI_PROVIDER_NAME == "gen_ai.provider.name"
|
||||
assert GEN_AI_USAGE_INPUT_TOKENS == "gen_ai.usage.input_tokens"
|
||||
assert GEN_AI_USAGE_OUTPUT_TOKENS == "gen_ai.usage.output_tokens"
|
||||
assert GEN_AI_USAGE_TOTAL_TOKENS == "gen_ai.usage.total_tokens"
|
||||
assert GEN_AI_PROMPT == "gen_ai.prompt"
|
||||
assert GEN_AI_COMPLETION == "gen_ai.completion"
|
||||
assert GEN_AI_RESPONSE_FINISH_REASON == "gen_ai.response.finish_reason"
|
||||
assert GEN_AI_INPUT_MESSAGE == "gen_ai.input.messages"
|
||||
assert GEN_AI_OUTPUT_MESSAGE == "gen_ai.output.messages"
|
||||
assert TOOL_NAME == "tool.name"
|
||||
assert TOOL_DESCRIPTION == "tool.description"
|
||||
assert TOOL_PARAMETERS == "tool.parameters"
|
||||
|
||||
|
||||
def test_gen_ai_span_kind_enum():
|
||||
assert GenAISpanKind.CHAIN == "CHAIN"
|
||||
assert GenAISpanKind.RETRIEVER == "RETRIEVER"
|
||||
assert GenAISpanKind.RERANKER == "RERANKER"
|
||||
assert GenAISpanKind.LLM == "LLM"
|
||||
assert GenAISpanKind.EMBEDDING == "EMBEDDING"
|
||||
assert GenAISpanKind.TOOL == "TOOL"
|
||||
assert GenAISpanKind.AGENT == "AGENT"
|
||||
assert GenAISpanKind.TASK == "TASK"
|
||||
|
||||
# Verify iteration works (covers the class definition)
|
||||
kinds = list(GenAISpanKind)
|
||||
assert len(kinds) == 8
|
||||
assert "LLM" in kinds
|
||||
647
api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py
Normal file
647
api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py
Normal file
@ -0,0 +1,647 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from opentelemetry.trace import Link, SpanContext, SpanKind, Status, StatusCode, TraceFlags
|
||||
|
||||
import core.ops.aliyun_trace.aliyun_trace as aliyun_trace_module
|
||||
from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace
|
||||
from core.ops.aliyun_trace.entities.semconv import (
|
||||
GEN_AI_COMPLETION,
|
||||
GEN_AI_INPUT_MESSAGE,
|
||||
GEN_AI_OUTPUT_MESSAGE,
|
||||
GEN_AI_PROMPT,
|
||||
GEN_AI_REQUEST_MODEL,
|
||||
GEN_AI_RESPONSE_FINISH_REASON,
|
||||
GEN_AI_USAGE_TOTAL_TOKENS,
|
||||
RETRIEVAL_DOCUMENT,
|
||||
RETRIEVAL_QUERY,
|
||||
TOOL_DESCRIPTION,
|
||||
TOOL_NAME,
|
||||
TOOL_PARAMETERS,
|
||||
GenAISpanKind,
|
||||
)
|
||||
from core.ops.entities.config_entity import AliyunConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from dify_graph.entities import WorkflowNodeExecution
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey
|
||||
|
||||
|
||||
class RecordingTraceClient:
|
||||
def __init__(self, service_name: str = "service", endpoint: str = "endpoint"):
|
||||
self.service_name = service_name
|
||||
self.endpoint = endpoint
|
||||
self.added_spans: list[object] = []
|
||||
|
||||
def add_span(self, span) -> None:
|
||||
self.added_spans.append(span)
|
||||
|
||||
def api_check(self) -> bool:
|
||||
return True
|
||||
|
||||
def get_project_url(self) -> str:
|
||||
return "project-url"
|
||||
|
||||
|
||||
def _dt() -> datetime:
|
||||
return datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC)
|
||||
|
||||
|
||||
def _make_link(trace_id: int = 1, span_id: int = 2) -> Link:
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags.SAMPLED,
|
||||
)
|
||||
return Link(context)
|
||||
|
||||
|
||||
def _make_workflow_trace_info(**overrides) -> WorkflowTraceInfo:
|
||||
defaults = {
|
||||
"workflow_id": "workflow-id",
|
||||
"tenant_id": "tenant-id",
|
||||
"workflow_run_id": "00000000-0000-0000-0000-000000000001",
|
||||
"workflow_run_elapsed_time": 1.0,
|
||||
"workflow_run_status": "succeeded",
|
||||
"workflow_run_inputs": {"sys.query": "hello"},
|
||||
"workflow_run_outputs": {"answer": "world"},
|
||||
"workflow_run_version": "v1",
|
||||
"total_tokens": 1,
|
||||
"file_list": [],
|
||||
"query": "hello",
|
||||
"metadata": {"conversation_id": "conv", "user_id": "u", "app_id": "app"},
|
||||
"message_id": None,
|
||||
"start_time": _dt(),
|
||||
"end_time": _dt(),
|
||||
"trace_id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return WorkflowTraceInfo(**defaults)
|
||||
|
||||
|
||||
def _make_message_trace_info(**overrides) -> MessageTraceInfo:
|
||||
defaults = {
|
||||
"conversation_model": "chat",
|
||||
"message_tokens": 1,
|
||||
"answer_tokens": 2,
|
||||
"total_tokens": 3,
|
||||
"conversation_mode": "chat",
|
||||
"metadata": {"conversation_id": "conv", "ls_model_name": "m", "ls_provider": "p"},
|
||||
"message_id": "00000000-0000-0000-0000-000000000002",
|
||||
"message_data": SimpleNamespace(from_account_id="acc", from_end_user_id=None),
|
||||
"inputs": {"prompt": "hi"},
|
||||
"outputs": "ok",
|
||||
"start_time": _dt(),
|
||||
"end_time": _dt(),
|
||||
"error": None,
|
||||
"trace_id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return MessageTraceInfo(**defaults)
|
||||
|
||||
|
||||
def _make_dataset_retrieval_trace_info(**overrides) -> DatasetRetrievalTraceInfo:
|
||||
defaults = {
|
||||
"metadata": {"conversation_id": "conv", "user_id": "u"},
|
||||
"message_id": "00000000-0000-0000-0000-000000000003",
|
||||
"message_data": SimpleNamespace(),
|
||||
"inputs": "q",
|
||||
"documents": [SimpleNamespace()],
|
||||
"start_time": _dt(),
|
||||
"end_time": _dt(),
|
||||
"trace_id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return DatasetRetrievalTraceInfo(**defaults)
|
||||
|
||||
|
||||
def _make_tool_trace_info(**overrides) -> ToolTraceInfo:
|
||||
defaults = {
|
||||
"tool_name": "tool",
|
||||
"tool_inputs": {"x": 1},
|
||||
"tool_outputs": "out",
|
||||
"tool_config": {"desc": "d"},
|
||||
"tool_parameters": {},
|
||||
"time_cost": 0.1,
|
||||
"metadata": {"conversation_id": "conv", "user_id": "u"},
|
||||
"message_id": "00000000-0000-0000-0000-000000000004",
|
||||
"message_data": SimpleNamespace(),
|
||||
"inputs": {"i": "v"},
|
||||
"outputs": {"o": "v"},
|
||||
"start_time": _dt(),
|
||||
"end_time": _dt(),
|
||||
"error": None,
|
||||
"trace_id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return ToolTraceInfo(**defaults)
|
||||
|
||||
|
||||
def _make_suggested_question_trace_info(**overrides) -> SuggestedQuestionTraceInfo:
|
||||
defaults = {
|
||||
"suggested_question": ["q1", "q2"],
|
||||
"level": "info",
|
||||
"total_tokens": 1,
|
||||
"metadata": {"conversation_id": "conv", "user_id": "u", "ls_model_name": "m", "ls_provider": "p"},
|
||||
"message_id": "00000000-0000-0000-0000-000000000005",
|
||||
"inputs": {"i": 1},
|
||||
"start_time": _dt(),
|
||||
"end_time": _dt(),
|
||||
"error": None,
|
||||
"trace_id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return SuggestedQuestionTraceInfo(**defaults)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def trace_instance(monkeypatch: pytest.MonkeyPatch) -> AliyunDataTrace:
|
||||
monkeypatch.setattr(aliyun_trace_module, "build_endpoint", lambda base_url, license_key: "built-endpoint")
|
||||
monkeypatch.setattr(aliyun_trace_module, "TraceClient", RecordingTraceClient)
|
||||
# Mock get_service_account_with_tenant to avoid DB errors
|
||||
monkeypatch.setattr(AliyunDataTrace, "get_service_account_with_tenant", lambda self, app_id: MagicMock())
|
||||
|
||||
config = AliyunConfig(app_name="app", license_key="k", endpoint="https://example.com")
|
||||
trace = AliyunDataTrace(config)
|
||||
return trace
|
||||
|
||||
|
||||
def test_init_builds_endpoint_and_client(monkeypatch: pytest.MonkeyPatch):
|
||||
build_endpoint = MagicMock(return_value="built")
|
||||
trace_client_cls = MagicMock()
|
||||
monkeypatch.setattr(aliyun_trace_module, "build_endpoint", build_endpoint)
|
||||
monkeypatch.setattr(aliyun_trace_module, "TraceClient", trace_client_cls)
|
||||
|
||||
config = AliyunConfig(app_name="my-app", license_key="license", endpoint="https://example.com")
|
||||
trace = AliyunDataTrace(config)
|
||||
|
||||
build_endpoint.assert_called_once_with("https://example.com", "license")
|
||||
trace_client_cls.assert_called_once_with(service_name="my-app", endpoint="built")
|
||||
assert trace.trace_config == config
|
||||
|
||||
|
||||
def test_trace_dispatches_to_correct_methods(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
|
||||
workflow_trace = MagicMock()
|
||||
message_trace = MagicMock()
|
||||
suggested_question_trace = MagicMock()
|
||||
dataset_retrieval_trace = MagicMock()
|
||||
tool_trace = MagicMock()
|
||||
monkeypatch.setattr(trace_instance, "workflow_trace", workflow_trace)
|
||||
monkeypatch.setattr(trace_instance, "message_trace", message_trace)
|
||||
monkeypatch.setattr(trace_instance, "suggested_question_trace", suggested_question_trace)
|
||||
monkeypatch.setattr(trace_instance, "dataset_retrieval_trace", dataset_retrieval_trace)
|
||||
monkeypatch.setattr(trace_instance, "tool_trace", tool_trace)
|
||||
|
||||
trace_instance.trace(_make_workflow_trace_info())
|
||||
workflow_trace.assert_called_once()
|
||||
|
||||
trace_instance.trace(_make_message_trace_info())
|
||||
message_trace.assert_called_once()
|
||||
|
||||
trace_instance.trace(_make_suggested_question_trace_info())
|
||||
suggested_question_trace.assert_called_once()
|
||||
|
||||
trace_instance.trace(_make_dataset_retrieval_trace_info())
|
||||
dataset_retrieval_trace.assert_called_once()
|
||||
|
||||
trace_instance.trace(_make_tool_trace_info())
|
||||
tool_trace.assert_called_once()
|
||||
|
||||
# Branches that do nothing but should be covered
|
||||
trace_instance.trace(ModerationTraceInfo(flagged=False, action="allow", preset_response="", query="", metadata={}))
|
||||
trace_instance.trace(GenerateNameTraceInfo(tenant_id="t", metadata={}))
|
||||
|
||||
|
||||
def test_api_check_delegates(trace_instance: AliyunDataTrace):
|
||||
trace_instance.trace_client.api_check = MagicMock(return_value=False)
|
||||
assert trace_instance.api_check() is False
|
||||
|
||||
|
||||
def test_get_project_url_success(trace_instance: AliyunDataTrace):
|
||||
assert trace_instance.get_project_url() == "project-url"
|
||||
|
||||
|
||||
def test_get_project_url_error(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(trace_instance.trace_client, "get_project_url", MagicMock(side_effect=Exception("boom")))
|
||||
logger_mock = MagicMock()
|
||||
monkeypatch.setattr(aliyun_trace_module, "logger", logger_mock)
|
||||
|
||||
with pytest.raises(ValueError, match=r"Aliyun get project url failed: boom"):
|
||||
trace_instance.get_project_url()
|
||||
logger_mock.info.assert_called()
|
||||
|
||||
|
||||
def test_workflow_trace_adds_workflow_and_node_spans(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(aliyun_trace_module, "convert_to_trace_id", lambda _: 111)
|
||||
monkeypatch.setattr(
|
||||
aliyun_trace_module, "convert_to_span_id", lambda _, span_type: {"workflow": 222}.get(span_type, 0)
|
||||
)
|
||||
monkeypatch.setattr(aliyun_trace_module, "create_links_from_trace_id", lambda _: [])
|
||||
|
||||
add_workflow_span = MagicMock()
|
||||
get_workflow_node_executions = MagicMock(return_value=[MagicMock(), MagicMock()])
|
||||
build_workflow_node_span = MagicMock(side_effect=["span-1", "span-2"])
|
||||
monkeypatch.setattr(trace_instance, "add_workflow_span", add_workflow_span)
|
||||
monkeypatch.setattr(trace_instance, "get_workflow_node_executions", get_workflow_node_executions)
|
||||
monkeypatch.setattr(trace_instance, "build_workflow_node_span", build_workflow_node_span)
|
||||
|
||||
trace_info = _make_workflow_trace_info(
|
||||
trace_id="abcd", metadata={"conversation_id": "c", "user_id": "u", "app_id": "app"}
|
||||
)
|
||||
trace_instance.workflow_trace(trace_info)
|
||||
|
||||
add_workflow_span.assert_called_once()
|
||||
passed_trace_metadata = add_workflow_span.call_args.args[1]
|
||||
assert passed_trace_metadata.trace_id == 111
|
||||
assert passed_trace_metadata.workflow_span_id == 222
|
||||
assert passed_trace_metadata.session_id == "c"
|
||||
assert passed_trace_metadata.user_id == "u"
|
||||
assert passed_trace_metadata.links == []
|
||||
|
||||
assert trace_instance.trace_client.added_spans == ["span-1", "span-2"]
|
||||
|
||||
|
||||
def test_message_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
|
||||
trace_info = _make_message_trace_info(message_data=None)
|
||||
trace_instance.message_trace(trace_info)
|
||||
assert trace_instance.trace_client.added_spans == []
|
||||
|
||||
|
||||
def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(aliyun_trace_module, "convert_to_trace_id", lambda _: 10)
|
||||
monkeypatch.setattr(
|
||||
aliyun_trace_module,
|
||||
"convert_to_span_id",
|
||||
lambda _, span_type: {"message": 20, "llm": 30}.get(span_type, 0),
|
||||
)
|
||||
monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123)
|
||||
monkeypatch.setattr(aliyun_trace_module, "get_user_id_from_message_data", lambda _: "user")
|
||||
monkeypatch.setattr(aliyun_trace_module, "create_links_from_trace_id", lambda _: [])
|
||||
|
||||
status = Status(StatusCode.OK)
|
||||
monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status)
|
||||
|
||||
trace_info = _make_message_trace_info(
|
||||
metadata={"conversation_id": "conv", "ls_model_name": "model", "ls_provider": "provider"},
|
||||
message_tokens=7,
|
||||
answer_tokens=11,
|
||||
total_tokens=18,
|
||||
outputs="completion",
|
||||
)
|
||||
trace_instance.message_trace(trace_info)
|
||||
|
||||
assert len(trace_instance.trace_client.added_spans) == 2
|
||||
message_span, llm_span = trace_instance.trace_client.added_spans
|
||||
|
||||
assert message_span.name == "message"
|
||||
assert message_span.trace_id == 10
|
||||
assert message_span.parent_span_id is None
|
||||
assert message_span.span_id == 20
|
||||
assert message_span.span_kind == SpanKind.SERVER
|
||||
assert message_span.status == status
|
||||
assert message_span.attributes["gen_ai.span.kind"] == GenAISpanKind.CHAIN
|
||||
|
||||
assert llm_span.name == "llm"
|
||||
assert llm_span.parent_span_id == 20
|
||||
assert llm_span.span_id == 30
|
||||
assert llm_span.status == status
|
||||
assert llm_span.attributes[GEN_AI_REQUEST_MODEL] == "model"
|
||||
assert llm_span.attributes[GEN_AI_USAGE_TOTAL_TOKENS] == "18"
|
||||
|
||||
|
||||
def test_dataset_retrieval_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
|
||||
trace_info = _make_dataset_retrieval_trace_info(message_data=None)
|
||||
trace_instance.dataset_retrieval_trace(trace_info)
|
||||
assert trace_instance.trace_client.added_spans == []
|
||||
|
||||
|
||||
def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(aliyun_trace_module, "convert_to_trace_id", lambda _: 1)
|
||||
monkeypatch.setattr(
|
||||
aliyun_trace_module, "convert_to_span_id", lambda _, span_type: {"message": 2}.get(span_type, 0)
|
||||
)
|
||||
monkeypatch.setattr(aliyun_trace_module, "generate_span_id", lambda: 3)
|
||||
monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123)
|
||||
monkeypatch.setattr(aliyun_trace_module, "create_links_from_trace_id", lambda _: [])
|
||||
monkeypatch.setattr(aliyun_trace_module, "extract_retrieval_documents", lambda _: [{"doc": "d"}])
|
||||
|
||||
trace_instance.dataset_retrieval_trace(_make_dataset_retrieval_trace_info(inputs="query"))
|
||||
assert len(trace_instance.trace_client.added_spans) == 1
|
||||
span = trace_instance.trace_client.added_spans[0]
|
||||
assert span.name == "dataset_retrieval"
|
||||
assert span.attributes[RETRIEVAL_QUERY] == "query"
|
||||
assert span.attributes[RETRIEVAL_DOCUMENT] == '[{"doc": "d"}]'
|
||||
|
||||
|
||||
def test_tool_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
|
||||
trace_info = _make_tool_trace_info(message_data=None)
|
||||
trace_instance.tool_trace(trace_info)
|
||||
assert trace_instance.trace_client.added_spans == []
|
||||
|
||||
|
||||
def test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(aliyun_trace_module, "convert_to_trace_id", lambda _: 10)
|
||||
monkeypatch.setattr(
|
||||
aliyun_trace_module, "convert_to_span_id", lambda _, span_type: {"message": 20}.get(span_type, 0)
|
||||
)
|
||||
monkeypatch.setattr(aliyun_trace_module, "generate_span_id", lambda: 30)
|
||||
monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123)
|
||||
monkeypatch.setattr(aliyun_trace_module, "create_links_from_trace_id", lambda _: [])
|
||||
status = Status(StatusCode.OK)
|
||||
monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status)
|
||||
|
||||
trace_instance.tool_trace(
|
||||
_make_tool_trace_info(
|
||||
tool_name="my-tool",
|
||||
tool_inputs={"a": 1},
|
||||
tool_config={"description": "x"},
|
||||
inputs={"i": 1},
|
||||
)
|
||||
)
|
||||
|
||||
assert len(trace_instance.trace_client.added_spans) == 1
|
||||
span = trace_instance.trace_client.added_spans[0]
|
||||
assert span.name == "my-tool"
|
||||
assert span.status == status
|
||||
assert span.attributes[TOOL_NAME] == "my-tool"
|
||||
assert span.attributes[TOOL_DESCRIPTION] == '{"description": "x"}'
|
||||
|
||||
|
||||
def test_get_workflow_node_executions_requires_app_id(trace_instance: AliyunDataTrace):
|
||||
trace_info = _make_workflow_trace_info(metadata={"conversation_id": "c"})
|
||||
with pytest.raises(ValueError, match="No app_id found in trace_info metadata"):
|
||||
trace_instance.get_workflow_node_executions(trace_info)
|
||||
|
||||
|
||||
def test_get_workflow_node_executions_builds_repo_and_fetches(
|
||||
trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
trace_info = _make_workflow_trace_info(metadata={"app_id": "app", "conversation_id": "c", "user_id": "u"})
|
||||
|
||||
account = object()
|
||||
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", MagicMock(return_value=account))
|
||||
monkeypatch.setattr(aliyun_trace_module, "sessionmaker", MagicMock())
|
||||
monkeypatch.setattr(aliyun_trace_module, "db", SimpleNamespace(engine="engine"))
|
||||
|
||||
repo = MagicMock()
|
||||
repo.get_by_workflow_run.return_value = ["node1"]
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.create_workflow_node_execution_repository.return_value = repo
|
||||
monkeypatch.setattr(aliyun_trace_module, "DifyCoreRepositoryFactory", mock_factory)
|
||||
|
||||
result = trace_instance.get_workflow_node_executions(trace_info)
|
||||
assert result == ["node1"]
|
||||
repo.get_by_workflow_run.assert_called_once_with(workflow_run_id=trace_info.workflow_run_id)
|
||||
|
||||
|
||||
def test_build_workflow_node_span_routes_llm_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
|
||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
trace_info = _make_workflow_trace_info()
|
||||
trace_metadata = MagicMock()
|
||||
|
||||
monkeypatch.setattr(trace_instance, "build_workflow_llm_span", MagicMock(return_value="llm"))
|
||||
|
||||
node_execution.node_type = NodeType.LLM
|
||||
assert trace_instance.build_workflow_node_span(node_execution, trace_info, trace_metadata) == "llm"
|
||||
|
||||
|
||||
def test_build_workflow_node_span_routes_knowledge_retrieval_type(
|
||||
trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
trace_info = _make_workflow_trace_info()
|
||||
trace_metadata = MagicMock()
|
||||
|
||||
monkeypatch.setattr(trace_instance, "build_workflow_retrieval_span", MagicMock(return_value="retrieval"))
|
||||
|
||||
node_execution.node_type = NodeType.KNOWLEDGE_RETRIEVAL
|
||||
assert trace_instance.build_workflow_node_span(node_execution, trace_info, trace_metadata) == "retrieval"
|
||||
|
||||
|
||||
def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
|
||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
trace_info = _make_workflow_trace_info()
|
||||
trace_metadata = MagicMock()
|
||||
|
||||
monkeypatch.setattr(trace_instance, "build_workflow_tool_span", MagicMock(return_value="tool"))
|
||||
|
||||
node_execution.node_type = NodeType.TOOL
|
||||
assert trace_instance.build_workflow_node_span(node_execution, trace_info, trace_metadata) == "tool"
|
||||
|
||||
|
||||
def test_build_workflow_node_span_routes_code_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
|
||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
trace_info = _make_workflow_trace_info()
|
||||
trace_metadata = MagicMock()
|
||||
|
||||
monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(return_value="task"))
|
||||
|
||||
node_execution.node_type = NodeType.CODE
|
||||
assert trace_instance.build_workflow_node_span(node_execution, trace_info, trace_metadata) == "task"
|
||||
|
||||
|
||||
def test_build_workflow_node_span_handles_errors(
|
||||
trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
|
||||
):
|
||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
trace_info = _make_workflow_trace_info()
|
||||
trace_metadata = MagicMock()
|
||||
|
||||
monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(side_effect=RuntimeError("boom")))
|
||||
node_execution.node_type = NodeType.CODE
|
||||
|
||||
assert trace_instance.build_workflow_node_span(node_execution, trace_info, trace_metadata) is None
|
||||
assert "Error occurred in build_workflow_node_span" in caplog.text
|
||||
|
||||
|
||||
def test_build_workflow_task_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(aliyun_trace_module, "convert_to_span_id", lambda _, __: 9)
|
||||
monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123)
|
||||
status = Status(StatusCode.OK)
|
||||
monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status)
|
||||
|
||||
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
|
||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
node_execution.id = "node-id"
|
||||
node_execution.title = "title"
|
||||
node_execution.inputs = {"a": 1}
|
||||
node_execution.outputs = {"b": 2}
|
||||
node_execution.created_at = _dt()
|
||||
node_execution.finished_at = _dt()
|
||||
|
||||
span = trace_instance.build_workflow_task_span(_make_workflow_trace_info(), node_execution, trace_metadata)
|
||||
assert span.trace_id == 1
|
||||
assert span.span_id == 9
|
||||
assert span.status.status_code == StatusCode.OK
|
||||
assert span.attributes["gen_ai.span.kind"] == GenAISpanKind.TASK
|
||||
|
||||
|
||||
def test_build_workflow_tool_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(aliyun_trace_module, "convert_to_span_id", lambda _, __: 9)
|
||||
monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123)
|
||||
status = Status(StatusCode.OK)
|
||||
monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status)
|
||||
|
||||
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[_make_link()])
|
||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
node_execution.id = "node-id"
|
||||
node_execution.title = "my-tool"
|
||||
node_execution.inputs = {"a": 1}
|
||||
node_execution.outputs = {"b": 2}
|
||||
node_execution.created_at = _dt()
|
||||
node_execution.finished_at = _dt()
|
||||
node_execution.metadata = {WorkflowNodeExecutionMetadataKey.TOOL_INFO: {"k": "v"}}
|
||||
|
||||
span = trace_instance.build_workflow_tool_span(_make_workflow_trace_info(), node_execution, trace_metadata)
|
||||
assert span.attributes[TOOL_NAME] == "my-tool"
|
||||
assert span.attributes[TOOL_DESCRIPTION] == '{"k": "v"}'
|
||||
assert span.attributes[TOOL_PARAMETERS] == '{"a": 1}'
|
||||
assert span.status.status_code == StatusCode.OK
|
||||
|
||||
# Cover metadata is None and inputs is None
|
||||
node_execution.metadata = None
|
||||
node_execution.inputs = None
|
||||
span2 = trace_instance.build_workflow_tool_span(_make_workflow_trace_info(), node_execution, trace_metadata)
|
||||
assert span2.attributes[TOOL_DESCRIPTION] == "{}"
|
||||
assert span2.attributes[TOOL_PARAMETERS] == "{}"
|
||||
|
||||
|
||||
def test_build_workflow_retrieval_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(aliyun_trace_module, "convert_to_span_id", lambda _, __: 9)
|
||||
monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123)
|
||||
status = Status(StatusCode.OK)
|
||||
monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status)
|
||||
monkeypatch.setattr(
|
||||
aliyun_trace_module, "format_retrieval_documents", lambda docs: [{"formatted": True}] if docs else []
|
||||
)
|
||||
|
||||
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
|
||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
node_execution.id = "node-id"
|
||||
node_execution.title = "retrieval"
|
||||
node_execution.inputs = {"query": "q"}
|
||||
node_execution.outputs = {"result": [{"doc": "d"}]}
|
||||
node_execution.created_at = _dt()
|
||||
node_execution.finished_at = _dt()
|
||||
|
||||
span = trace_instance.build_workflow_retrieval_span(_make_workflow_trace_info(), node_execution, trace_metadata)
|
||||
assert span.attributes[RETRIEVAL_QUERY] == "q"
|
||||
assert span.attributes[RETRIEVAL_DOCUMENT] == '[{"formatted": true}]'
|
||||
|
||||
# Cover empty inputs/outputs
|
||||
node_execution.inputs = None
|
||||
node_execution.outputs = None
|
||||
span2 = trace_instance.build_workflow_retrieval_span(_make_workflow_trace_info(), node_execution, trace_metadata)
|
||||
assert span2.attributes[RETRIEVAL_QUERY] == ""
|
||||
assert span2.attributes[RETRIEVAL_DOCUMENT] == "[]"
|
||||
|
||||
|
||||
def test_build_workflow_llm_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(aliyun_trace_module, "convert_to_span_id", lambda _, __: 9)
|
||||
monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123)
|
||||
status = Status(StatusCode.OK)
|
||||
monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status)
|
||||
monkeypatch.setattr(aliyun_trace_module, "format_input_messages", lambda _: "in")
|
||||
monkeypatch.setattr(aliyun_trace_module, "format_output_messages", lambda _: "out")
|
||||
|
||||
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
|
||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
node_execution.id = "node-id"
|
||||
node_execution.title = "llm"
|
||||
node_execution.process_data = {
|
||||
"usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3},
|
||||
"prompts": ["p"],
|
||||
"model_name": "m",
|
||||
"model_provider": "p1",
|
||||
}
|
||||
node_execution.outputs = {"text": "t", "finish_reason": "stop"}
|
||||
node_execution.created_at = _dt()
|
||||
node_execution.finished_at = _dt()
|
||||
|
||||
span = trace_instance.build_workflow_llm_span(_make_workflow_trace_info(), node_execution, trace_metadata)
|
||||
assert span.attributes[GEN_AI_USAGE_TOTAL_TOKENS] == "3"
|
||||
assert span.attributes[GEN_AI_REQUEST_MODEL] == "m"
|
||||
assert span.attributes[GEN_AI_PROMPT] == '["p"]'
|
||||
assert span.attributes[GEN_AI_COMPLETION] == "t"
|
||||
assert span.attributes[GEN_AI_RESPONSE_FINISH_REASON] == "stop"
|
||||
assert span.attributes[GEN_AI_INPUT_MESSAGE] == "in"
|
||||
assert span.attributes[GEN_AI_OUTPUT_MESSAGE] == "out"
|
||||
|
||||
# Cover usage from outputs if not in process_data
|
||||
node_execution.process_data = {"prompts": []}
|
||||
node_execution.outputs = {"usage": {"total_tokens": 10}, "text": ""}
|
||||
span2 = trace_instance.build_workflow_llm_span(_make_workflow_trace_info(), node_execution, trace_metadata)
|
||||
assert span2.attributes[GEN_AI_USAGE_TOTAL_TOKENS] == "10"
|
||||
|
||||
|
||||
def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(
|
||||
aliyun_trace_module, "convert_to_span_id", lambda _, span_type: {"message": 20}.get(span_type, 0)
|
||||
)
|
||||
monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123)
|
||||
status = Status(StatusCode.OK)
|
||||
monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status)
|
||||
|
||||
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
|
||||
|
||||
# CASE 1: With message_id
|
||||
trace_info = _make_workflow_trace_info(
|
||||
message_id="msg-1", workflow_run_inputs={"sys.query": "hi"}, workflow_run_outputs={"ans": "ok"}
|
||||
)
|
||||
trace_instance.add_workflow_span(trace_info, trace_metadata)
|
||||
|
||||
assert len(trace_instance.trace_client.added_spans) == 2
|
||||
message_span = trace_instance.trace_client.added_spans[0]
|
||||
workflow_span = trace_instance.trace_client.added_spans[1]
|
||||
|
||||
assert message_span.name == "message"
|
||||
assert message_span.span_kind == SpanKind.SERVER
|
||||
assert message_span.parent_span_id is None
|
||||
|
||||
assert workflow_span.name == "workflow"
|
||||
assert workflow_span.span_kind == SpanKind.INTERNAL
|
||||
assert workflow_span.parent_span_id == 20
|
||||
|
||||
trace_instance.trace_client.added_spans.clear()
|
||||
|
||||
# CASE 2: Without message_id
|
||||
trace_info_no_msg = _make_workflow_trace_info(message_id=None)
|
||||
trace_instance.add_workflow_span(trace_info_no_msg, trace_metadata)
|
||||
assert len(trace_instance.trace_client.added_spans) == 1
|
||||
span = trace_instance.trace_client.added_spans[0]
|
||||
assert span.name == "workflow"
|
||||
assert span.span_kind == SpanKind.SERVER
|
||||
assert span.parent_span_id is None
|
||||
|
||||
|
||||
def test_suggested_question_trace(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(aliyun_trace_module, "convert_to_trace_id", lambda _: 10)
|
||||
monkeypatch.setattr(
|
||||
aliyun_trace_module,
|
||||
"convert_to_span_id",
|
||||
lambda _, span_type: {"message": 20, "suggested_question": 21}.get(span_type, 0),
|
||||
)
|
||||
monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123)
|
||||
monkeypatch.setattr(aliyun_trace_module, "create_links_from_trace_id", lambda _: [])
|
||||
status = Status(StatusCode.OK)
|
||||
monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status)
|
||||
|
||||
trace_info = _make_suggested_question_trace_info(suggested_question=["how?"])
|
||||
trace_instance.suggested_question_trace(trace_info)
|
||||
|
||||
assert len(trace_instance.trace_client.added_spans) == 1
|
||||
span = trace_instance.trace_client.added_spans[0]
|
||||
assert span.name == "suggested_question"
|
||||
assert span.attributes[GEN_AI_COMPLETION] == '["how?"]'
|
||||
@ -0,0 +1,275 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from opentelemetry.trace import Link, StatusCode
|
||||
|
||||
from core.ops.aliyun_trace.entities.semconv import (
|
||||
GEN_AI_FRAMEWORK,
|
||||
GEN_AI_SESSION_ID,
|
||||
GEN_AI_SPAN_KIND,
|
||||
GEN_AI_USER_ID,
|
||||
INPUT_VALUE,
|
||||
OUTPUT_VALUE,
|
||||
)
|
||||
from core.ops.aliyun_trace.utils import (
|
||||
create_common_span_attributes,
|
||||
create_links_from_trace_id,
|
||||
create_status_from_error,
|
||||
extract_retrieval_documents,
|
||||
format_input_messages,
|
||||
format_output_messages,
|
||||
format_retrieval_documents,
|
||||
get_user_id_from_message_data,
|
||||
get_workflow_node_status,
|
||||
serialize_json_data,
|
||||
)
|
||||
from core.rag.models.document import Document
|
||||
from dify_graph.entities import WorkflowNodeExecution
|
||||
from dify_graph.enums import WorkflowNodeExecutionStatus
|
||||
from models import EndUser
|
||||
|
||||
|
||||
def test_get_user_id_from_message_data_no_end_user(monkeypatch):
|
||||
message_data = MagicMock()
|
||||
message_data.from_account_id = "account_id"
|
||||
message_data.from_end_user_id = None
|
||||
|
||||
assert get_user_id_from_message_data(message_data) == "account_id"
|
||||
|
||||
|
||||
def test_get_user_id_from_message_data_with_end_user(monkeypatch):
|
||||
message_data = MagicMock()
|
||||
message_data.from_account_id = "account_id"
|
||||
message_data.from_end_user_id = "end_user_id"
|
||||
|
||||
end_user_data = MagicMock(spec=EndUser)
|
||||
end_user_data.session_id = "session_id"
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.where.return_value.first.return_value = end_user_data
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
|
||||
from core.ops.aliyun_trace.utils import db
|
||||
|
||||
monkeypatch.setattr(db, "session", mock_session)
|
||||
|
||||
assert get_user_id_from_message_data(message_data) == "session_id"
|
||||
|
||||
|
||||
def test_get_user_id_from_message_data_end_user_not_found(monkeypatch):
|
||||
message_data = MagicMock()
|
||||
message_data.from_account_id = "account_id"
|
||||
message_data.from_end_user_id = "end_user_id"
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.where.return_value.first.return_value = None
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
|
||||
from core.ops.aliyun_trace.utils import db
|
||||
|
||||
monkeypatch.setattr(db, "session", mock_session)
|
||||
|
||||
assert get_user_id_from_message_data(message_data) == "account_id"
|
||||
|
||||
|
||||
def test_create_status_from_error():
|
||||
# Case OK
|
||||
status_ok = create_status_from_error(None)
|
||||
assert status_ok.status_code == StatusCode.OK
|
||||
|
||||
# Case Error
|
||||
status_err = create_status_from_error("some error")
|
||||
assert status_err.status_code == StatusCode.ERROR
|
||||
assert status_err.description == "some error"
|
||||
|
||||
|
||||
def test_get_workflow_node_status():
|
||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
|
||||
# SUCCEEDED
|
||||
node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
status = get_workflow_node_status(node_execution)
|
||||
assert status.status_code == StatusCode.OK
|
||||
|
||||
# FAILED
|
||||
node_execution.status = WorkflowNodeExecutionStatus.FAILED
|
||||
node_execution.error = "node fail"
|
||||
status = get_workflow_node_status(node_execution)
|
||||
assert status.status_code == StatusCode.ERROR
|
||||
assert status.description == "node fail"
|
||||
|
||||
# EXCEPTION
|
||||
node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION
|
||||
node_execution.error = "node exception"
|
||||
status = get_workflow_node_status(node_execution)
|
||||
assert status.status_code == StatusCode.ERROR
|
||||
assert status.description == "node exception"
|
||||
|
||||
# UNSET/OTHER
|
||||
node_execution.status = WorkflowNodeExecutionStatus.RUNNING
|
||||
status = get_workflow_node_status(node_execution)
|
||||
assert status.status_code == StatusCode.UNSET
|
||||
|
||||
|
||||
def test_create_links_from_trace_id(monkeypatch):
|
||||
# Mock create_link
|
||||
mock_link = MagicMock(spec=Link)
|
||||
import core.ops.aliyun_trace.data_exporter.traceclient
|
||||
|
||||
monkeypatch.setattr(core.ops.aliyun_trace.data_exporter.traceclient, "create_link", lambda trace_id_str: mock_link)
|
||||
|
||||
# Trace ID None
|
||||
assert create_links_from_trace_id(None) == []
|
||||
|
||||
# Trace ID Present
|
||||
links = create_links_from_trace_id("trace_id")
|
||||
assert len(links) == 1
|
||||
assert links[0] == mock_link
|
||||
|
||||
|
||||
def test_extract_retrieval_documents():
|
||||
doc1 = MagicMock(spec=Document)
|
||||
doc1.page_content = "content1"
|
||||
doc1.metadata = {"dataset_id": "ds1", "doc_id": "di1", "document_id": "dd1", "score": 0.9}
|
||||
|
||||
doc2 = MagicMock(spec=Document)
|
||||
doc2.page_content = "content2"
|
||||
doc2.metadata = {"dataset_id": "ds2"} # Missing some keys
|
||||
|
||||
documents = [doc1, doc2]
|
||||
extracted = extract_retrieval_documents(documents)
|
||||
|
||||
assert len(extracted) == 2
|
||||
assert extracted[0]["content"] == "content1"
|
||||
assert extracted[0]["metadata"]["dataset_id"] == "ds1"
|
||||
assert extracted[0]["score"] == 0.9
|
||||
|
||||
assert extracted[1]["content"] == "content2"
|
||||
assert extracted[1]["metadata"]["dataset_id"] == "ds2"
|
||||
assert extracted[1]["metadata"]["doc_id"] is None
|
||||
assert extracted[1]["score"] is None
|
||||
|
||||
|
||||
def test_serialize_json_data():
|
||||
data = {"a": 1}
|
||||
# Test ensure_ascii default (False)
|
||||
assert serialize_json_data(data) == json.dumps(data, ensure_ascii=False)
|
||||
# Test ensure_ascii True
|
||||
assert serialize_json_data(data, ensure_ascii=True) == json.dumps(data, ensure_ascii=True)
|
||||
|
||||
|
||||
def test_create_common_span_attributes():
|
||||
attrs = create_common_span_attributes(
|
||||
session_id="s1", user_id="u1", span_kind="kind1", framework="fw1", inputs="in1", outputs="out1"
|
||||
)
|
||||
assert attrs[GEN_AI_SESSION_ID] == "s1"
|
||||
assert attrs[GEN_AI_USER_ID] == "u1"
|
||||
assert attrs[GEN_AI_SPAN_KIND] == "kind1"
|
||||
assert attrs[GEN_AI_FRAMEWORK] == "fw1"
|
||||
assert attrs[INPUT_VALUE] == "in1"
|
||||
assert attrs[OUTPUT_VALUE] == "out1"
|
||||
|
||||
|
||||
def test_format_retrieval_documents():
|
||||
# Not a list
|
||||
assert format_retrieval_documents("not a list") == []
|
||||
|
||||
# Valid list
|
||||
docs = [
|
||||
{"metadata": {"score": 0.8, "document_id": "doc1", "source": "src1"}, "content": "c1", "title": "t1"},
|
||||
{
|
||||
"metadata": {"_source": "src2", "doc_metadata": {"extra": "val"}},
|
||||
"content": "c2",
|
||||
# Missing title
|
||||
},
|
||||
"not a dict", # Should be skipped
|
||||
]
|
||||
formatted = format_retrieval_documents(docs)
|
||||
|
||||
assert len(formatted) == 2
|
||||
assert formatted[0]["document"]["content"] == "c1"
|
||||
assert formatted[0]["document"]["metadata"]["title"] == "t1"
|
||||
assert formatted[0]["document"]["metadata"]["source"] == "src1"
|
||||
assert formatted[0]["document"]["score"] == 0.8
|
||||
assert formatted[0]["document"]["id"] == "doc1"
|
||||
|
||||
assert formatted[1]["document"]["content"] == "c2"
|
||||
assert formatted[1]["document"]["metadata"]["source"] == "src2"
|
||||
assert formatted[1]["document"]["metadata"]["extra"] == "val"
|
||||
assert "title" not in formatted[1]["document"]["metadata"]
|
||||
assert formatted[1]["document"]["score"] == 0.0 # Default
|
||||
|
||||
# Exception handling
|
||||
# We can trigger an exception by passing something that causes an error in the loop logic,
|
||||
# but the try/except covers the whole function.
|
||||
# Passing a list that contains something that throws when calling .get() - though dicts won't.
|
||||
# Let's mock a dict that raises on get.
|
||||
class BadDict:
|
||||
def get(self, *args, **kwargs):
|
||||
raise Exception("boom")
|
||||
|
||||
assert format_retrieval_documents([BadDict()]) == []
|
||||
|
||||
|
||||
def test_format_input_messages():
|
||||
# Not a dict
|
||||
assert format_input_messages(None) == serialize_json_data([])
|
||||
|
||||
# No prompts
|
||||
assert format_input_messages({}) == serialize_json_data([])
|
||||
|
||||
# Valid prompts
|
||||
process_data = {
|
||||
"prompts": [
|
||||
{"role": "user", "text": "hello"},
|
||||
{"role": "assistant", "text": "hi"},
|
||||
{"role": "system", "text": "be helpful"},
|
||||
{"role": "tool", "text": "result"},
|
||||
{"role": "invalid", "text": "skip me"},
|
||||
"not a dict",
|
||||
{"role": "user", "text": ""}, # Empty text, should be skipped? Code says `if text: message = ...`
|
||||
]
|
||||
}
|
||||
result = format_input_messages(process_data)
|
||||
result_list = json.loads(result)
|
||||
|
||||
assert len(result_list) == 4
|
||||
assert result_list[0]["role"] == "user"
|
||||
assert result_list[0]["parts"][0]["content"] == "hello"
|
||||
assert result_list[1]["role"] == "assistant"
|
||||
assert result_list[2]["role"] == "system"
|
||||
assert result_list[3]["role"] == "tool"
|
||||
|
||||
# Exception path
|
||||
assert format_input_messages({"prompts": [None]}) == serialize_json_data([])
|
||||
|
||||
|
||||
def test_format_output_messages():
|
||||
# Not a dict
|
||||
assert format_output_messages(None) == serialize_json_data([])
|
||||
|
||||
# No text
|
||||
assert format_output_messages({"finish_reason": "stop"}) == serialize_json_data([])
|
||||
|
||||
# Valid
|
||||
outputs = {"text": "done", "finish_reason": "length"}
|
||||
result = format_output_messages(outputs)
|
||||
result_list = json.loads(result)
|
||||
assert len(result_list) == 1
|
||||
assert result_list[0]["role"] == "assistant"
|
||||
assert result_list[0]["parts"][0]["content"] == "done"
|
||||
assert result_list[0]["finish_reason"] == "length"
|
||||
|
||||
# Invalid finish reason
|
||||
outputs2 = {"text": "done", "finish_reason": "unknown"}
|
||||
result2 = format_output_messages(outputs2)
|
||||
result_list2 = json.loads(result2)
|
||||
assert result_list2[0]["finish_reason"] == "stop"
|
||||
|
||||
# Exception path
|
||||
# Trigger exception in serialize_json_data by passing non-serializable
|
||||
assert format_output_messages({"text": MagicMock()}) == serialize_json_data([])
|
||||
@ -0,0 +1,398 @@
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from opentelemetry.sdk.trace import Tracer
|
||||
from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes
|
||||
from opentelemetry.trace import StatusCode
|
||||
|
||||
from core.ops.arize_phoenix_trace.arize_phoenix_trace import (
|
||||
ArizePhoenixDataTrace,
|
||||
datetime_to_nanos,
|
||||
error_to_string,
|
||||
safe_json_dumps,
|
||||
set_span_status,
|
||||
setup_tracer,
|
||||
wrap_span_metadata,
|
||||
)
|
||||
from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
|
||||
# --- Helpers ---
|
||||
|
||||
|
||||
def _dt():
|
||||
return datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC)
|
||||
|
||||
|
||||
def _make_workflow_info(**kwargs):
|
||||
defaults = {
|
||||
"workflow_id": "w1",
|
||||
"tenant_id": "t1",
|
||||
"workflow_run_id": "r1",
|
||||
"workflow_run_elapsed_time": 1.0,
|
||||
"workflow_run_status": "succeeded",
|
||||
"workflow_run_inputs": {"in": "val"},
|
||||
"workflow_run_outputs": {"out": "val"},
|
||||
"workflow_run_version": "1.0",
|
||||
"total_tokens": 10,
|
||||
"file_list": ["f1"],
|
||||
"query": "hi",
|
||||
"metadata": {"app_id": "app1"},
|
||||
"start_time": _dt(),
|
||||
"end_time": _dt() + timedelta(seconds=1),
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return WorkflowTraceInfo(**defaults)
|
||||
|
||||
|
||||
def _make_message_info(**kwargs):
|
||||
defaults = {
|
||||
"conversation_model": "chat",
|
||||
"message_tokens": 5,
|
||||
"answer_tokens": 5,
|
||||
"total_tokens": 10,
|
||||
"conversation_mode": "chat",
|
||||
"metadata": {"app_id": "app1"},
|
||||
"inputs": {"in": "val"},
|
||||
"outputs": "val",
|
||||
"start_time": _dt(),
|
||||
"end_time": _dt(),
|
||||
"message_id": "m1",
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return MessageTraceInfo(**defaults)
|
||||
|
||||
|
||||
# --- Utility Function Tests ---
|
||||
|
||||
|
||||
def test_datetime_to_nanos():
|
||||
dt = _dt()
|
||||
expected = int(dt.timestamp() * 1_000_000_000)
|
||||
assert datetime_to_nanos(dt) == expected
|
||||
|
||||
with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.datetime") as mock_dt:
|
||||
mock_now = MagicMock()
|
||||
mock_now.timestamp.return_value = 1704110400.0
|
||||
mock_dt.now.return_value = mock_now
|
||||
assert datetime_to_nanos(None) == 1704110400000000000
|
||||
|
||||
|
||||
def test_error_to_string():
|
||||
try:
|
||||
raise ValueError("boom")
|
||||
except ValueError as e:
|
||||
err = e
|
||||
|
||||
res = error_to_string(err)
|
||||
assert "ValueError: boom" in res
|
||||
assert "traceback" in res.lower() or "line" in res.lower()
|
||||
|
||||
assert error_to_string("str error") == "str error"
|
||||
assert error_to_string(None) == "Empty Stack Trace"
|
||||
|
||||
|
||||
def test_set_span_status():
|
||||
span = MagicMock()
|
||||
# OK
|
||||
set_span_status(span, None)
|
||||
span.set_status.assert_called()
|
||||
assert span.set_status.call_args[0][0].status_code == StatusCode.OK
|
||||
|
||||
# Error Exception
|
||||
span.reset_mock()
|
||||
set_span_status(span, ValueError("fail"))
|
||||
assert span.set_status.call_args[0][0].status_code == StatusCode.ERROR
|
||||
span.record_exception.assert_called()
|
||||
|
||||
# Error String
|
||||
span.reset_mock()
|
||||
set_span_status(span, "fail-str")
|
||||
assert span.set_status.call_args[0][0].status_code == StatusCode.ERROR
|
||||
span.add_event.assert_called()
|
||||
|
||||
# repr branch
|
||||
class SilentError:
|
||||
def __str__(self):
|
||||
return ""
|
||||
|
||||
def __repr__(self):
|
||||
return "SilentErrorRepr"
|
||||
|
||||
span.reset_mock()
|
||||
set_span_status(span, SilentError())
|
||||
assert span.add_event.call_args[1]["attributes"][OTELSpanAttributes.EXCEPTION_MESSAGE] == "SilentErrorRepr"
|
||||
|
||||
|
||||
def test_safe_json_dumps():
|
||||
assert safe_json_dumps({"a": _dt()}) == '{"a": "2024-01-01 00:00:00+00:00"}'
|
||||
|
||||
|
||||
def test_wrap_span_metadata():
|
||||
res = wrap_span_metadata({"a": 1}, b=2)
|
||||
assert res == {"a": 1, "b": 2, "created_from": "Dify"}
|
||||
|
||||
|
||||
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.GrpcOTLPSpanExporter")
|
||||
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.trace_sdk.TracerProvider")
|
||||
def test_setup_tracer_arize(mock_provider, mock_exporter):
|
||||
config = ArizeConfig(endpoint="http://a.com", api_key="k", space_id="s", project="p")
|
||||
setup_tracer(config)
|
||||
mock_exporter.assert_called_once()
|
||||
assert mock_exporter.call_args[1]["endpoint"] == "http://a.com/v1"
|
||||
|
||||
|
||||
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.HttpOTLPSpanExporter")
|
||||
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.trace_sdk.TracerProvider")
|
||||
def test_setup_tracer_phoenix(mock_provider, mock_exporter):
|
||||
config = PhoenixConfig(endpoint="http://p.com", project="p")
|
||||
setup_tracer(config)
|
||||
mock_exporter.assert_called_once()
|
||||
assert mock_exporter.call_args[1]["endpoint"] == "http://p.com/v1/traces"
|
||||
|
||||
|
||||
def test_setup_tracer_exception():
|
||||
config = ArizeConfig(endpoint="http://a.com", project="p")
|
||||
with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.urlparse", side_effect=Exception("boom")):
|
||||
with pytest.raises(Exception, match="boom"):
|
||||
setup_tracer(config)
|
||||
|
||||
|
||||
# --- ArizePhoenixDataTrace Class Tests ---
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def trace_instance():
|
||||
with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.setup_tracer") as mock_setup:
|
||||
mock_tracer = MagicMock(spec=Tracer)
|
||||
mock_processor = MagicMock()
|
||||
mock_setup.return_value = (mock_tracer, mock_processor)
|
||||
config = ArizeConfig(endpoint="http://a.com", api_key="k", space_id="s", project="p")
|
||||
return ArizePhoenixDataTrace(config)
|
||||
|
||||
|
||||
def test_trace_dispatch(trace_instance):
|
||||
with (
|
||||
patch.object(trace_instance, "workflow_trace") as m1,
|
||||
patch.object(trace_instance, "message_trace") as m2,
|
||||
patch.object(trace_instance, "moderation_trace") as m3,
|
||||
patch.object(trace_instance, "suggested_question_trace") as m4,
|
||||
patch.object(trace_instance, "dataset_retrieval_trace") as m5,
|
||||
patch.object(trace_instance, "tool_trace") as m6,
|
||||
patch.object(trace_instance, "generate_name_trace") as m7,
|
||||
):
|
||||
trace_instance.trace(_make_workflow_info())
|
||||
m1.assert_called()
|
||||
|
||||
trace_instance.trace(_make_message_info())
|
||||
m2.assert_called()
|
||||
|
||||
trace_instance.trace(ModerationTraceInfo(flagged=True, action="a", preset_response="p", query="q", metadata={}))
|
||||
m3.assert_called()
|
||||
|
||||
trace_instance.trace(SuggestedQuestionTraceInfo(suggested_question=[], total_tokens=0, level="i", metadata={}))
|
||||
m4.assert_called()
|
||||
|
||||
trace_instance.trace(DatasetRetrievalTraceInfo(metadata={}))
|
||||
m5.assert_called()
|
||||
|
||||
trace_instance.trace(
|
||||
ToolTraceInfo(
|
||||
tool_name="t",
|
||||
tool_inputs={},
|
||||
tool_outputs="o",
|
||||
metadata={},
|
||||
tool_config={},
|
||||
time_cost=1,
|
||||
tool_parameters={},
|
||||
)
|
||||
)
|
||||
m6.assert_called()
|
||||
|
||||
trace_instance.trace(GenerateNameTraceInfo(tenant_id="t", metadata={}))
|
||||
m7.assert_called()
|
||||
|
||||
|
||||
def test_trace_exception(trace_instance):
|
||||
with patch.object(trace_instance, "workflow_trace", side_effect=RuntimeError("fail")):
|
||||
with pytest.raises(RuntimeError):
|
||||
trace_instance.trace(_make_workflow_info())
|
||||
|
||||
|
||||
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.sessionmaker")
|
||||
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.DifyCoreRepositoryFactory")
|
||||
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db")
|
||||
def test_workflow_trace_full(mock_db, mock_repo_factory, mock_sessionmaker, trace_instance):
|
||||
mock_db.engine = MagicMock()
|
||||
info = _make_workflow_info()
|
||||
repo = MagicMock()
|
||||
mock_repo_factory.create_workflow_node_execution_repository.return_value = repo
|
||||
|
||||
node1 = MagicMock()
|
||||
node1.node_type = "llm"
|
||||
node1.status = "succeeded"
|
||||
node1.inputs = {"q": "hi"}
|
||||
node1.outputs = {"a": "bye", "usage": {"total_tokens": 5}}
|
||||
node1.created_at = _dt()
|
||||
node1.elapsed_time = 1.0
|
||||
node1.process_data = {
|
||||
"prompts": [{"role": "user", "content": "hi"}],
|
||||
"model_provider": "openai",
|
||||
"model_name": "gpt-4",
|
||||
}
|
||||
node1.metadata = {"k": "v"}
|
||||
node1.title = "title"
|
||||
node1.id = "n1"
|
||||
node1.error = None
|
||||
|
||||
repo.get_by_workflow_run.return_value = [node1]
|
||||
|
||||
with patch.object(trace_instance, "get_service_account_with_tenant"):
|
||||
trace_instance.workflow_trace(info)
|
||||
|
||||
assert trace_instance.tracer.start_span.call_count >= 2
|
||||
|
||||
|
||||
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db")
|
||||
def test_workflow_trace_no_app_id(mock_db, trace_instance):
|
||||
mock_db.engine = MagicMock()
|
||||
info = _make_workflow_info()
|
||||
info.metadata = {}
|
||||
with pytest.raises(ValueError, match="No app_id found in trace_info metadata"):
|
||||
trace_instance.workflow_trace(info)
|
||||
|
||||
|
||||
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db")
|
||||
def test_message_trace_success(mock_db, trace_instance):
|
||||
mock_db.engine = MagicMock()
|
||||
info = _make_message_info()
|
||||
info.message_data = MagicMock()
|
||||
info.message_data.from_account_id = "acc1"
|
||||
info.message_data.from_end_user_id = None
|
||||
info.message_data.query = "q"
|
||||
info.message_data.answer = "a"
|
||||
info.message_data.status = "s"
|
||||
info.message_data.model_id = "m"
|
||||
info.message_data.model_provider = "p"
|
||||
info.message_data.message_metadata = "{}"
|
||||
info.message_data.error = None
|
||||
info.error = None
|
||||
|
||||
trace_instance.message_trace(info)
|
||||
assert trace_instance.tracer.start_span.call_count >= 1
|
||||
|
||||
|
||||
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db")
|
||||
def test_message_trace_with_error(mock_db, trace_instance):
|
||||
mock_db.engine = MagicMock()
|
||||
info = _make_message_info()
|
||||
info.message_data = MagicMock()
|
||||
info.message_data.from_account_id = "acc1"
|
||||
info.message_data.from_end_user_id = None
|
||||
info.message_data.query = "q"
|
||||
info.message_data.answer = "a"
|
||||
info.message_data.status = "s"
|
||||
info.message_data.model_id = "m"
|
||||
info.message_data.model_provider = "p"
|
||||
info.message_data.message_metadata = "{}"
|
||||
info.message_data.error = "processing failed"
|
||||
info.error = "message error"
|
||||
|
||||
trace_instance.message_trace(info)
|
||||
assert trace_instance.tracer.start_span.call_count >= 1
|
||||
|
||||
|
||||
def test_trace_methods_return_early_with_no_message_data(trace_instance):
|
||||
info = MagicMock()
|
||||
info.message_data = None
|
||||
|
||||
trace_instance.moderation_trace(info)
|
||||
trace_instance.suggested_question_trace(info)
|
||||
trace_instance.dataset_retrieval_trace(info)
|
||||
trace_instance.tool_trace(info)
|
||||
trace_instance.generate_name_trace(info)
|
||||
|
||||
assert trace_instance.tracer.start_span.call_count == 0
|
||||
|
||||
|
||||
def test_moderation_trace_ok(trace_instance):
|
||||
info = ModerationTraceInfo(flagged=True, action="a", preset_response="p", query="q", metadata={})
|
||||
info.message_data = MagicMock()
|
||||
info.message_data.error = None
|
||||
trace_instance.moderation_trace(info)
|
||||
# root span (1) + moderation span (1) = 2
|
||||
assert trace_instance.tracer.start_span.call_count >= 1
|
||||
|
||||
|
||||
def test_suggested_question_trace_ok(trace_instance):
|
||||
info = SuggestedQuestionTraceInfo(suggested_question=["?"], total_tokens=1, level="i", metadata={})
|
||||
info.message_data = MagicMock()
|
||||
info.error = None
|
||||
trace_instance.suggested_question_trace(info)
|
||||
assert trace_instance.tracer.start_span.call_count >= 1
|
||||
|
||||
|
||||
def test_dataset_retrieval_trace_ok(trace_instance):
|
||||
info = DatasetRetrievalTraceInfo(documents=[], metadata={})
|
||||
info.message_data = MagicMock()
|
||||
info.error = None
|
||||
trace_instance.dataset_retrieval_trace(info)
|
||||
assert trace_instance.tracer.start_span.call_count >= 1
|
||||
|
||||
|
||||
def test_tool_trace_ok(trace_instance):
|
||||
info = ToolTraceInfo(
|
||||
tool_name="t", tool_inputs={}, tool_outputs="o", metadata={}, tool_config={}, time_cost=1, tool_parameters={}
|
||||
)
|
||||
info.message_data = MagicMock()
|
||||
info.error = None
|
||||
trace_instance.tool_trace(info)
|
||||
assert trace_instance.tracer.start_span.call_count >= 1
|
||||
|
||||
|
||||
def test_generate_name_trace_ok(trace_instance):
|
||||
info = GenerateNameTraceInfo(tenant_id="t", metadata={})
|
||||
info.message_data = MagicMock()
|
||||
info.message_data.error = None
|
||||
trace_instance.generate_name_trace(info)
|
||||
assert trace_instance.tracer.start_span.call_count >= 1
|
||||
|
||||
|
||||
def test_get_project_url_phoenix(trace_instance):
|
||||
trace_instance.arize_phoenix_config = PhoenixConfig(endpoint="http://p.com", project="p")
|
||||
assert "p.com/projects/" in trace_instance.get_project_url()
|
||||
|
||||
|
||||
def test_set_attribute_none_logic(trace_instance):
|
||||
# Test role can be None
|
||||
attrs = trace_instance._construct_llm_attributes([{"role": None, "content": "hi"}])
|
||||
assert "llm.input_messages.0.message.role" not in attrs
|
||||
|
||||
# Test tool call id can be None
|
||||
tool_call_none_id = {"id": None, "function": {"name": "f1"}}
|
||||
attrs = trace_instance._construct_llm_attributes([{"role": "assistant", "tool_calls": [tool_call_none_id]}])
|
||||
assert "llm.input_messages.0.message.tool_calls.0.tool_call.id" not in attrs
|
||||
|
||||
|
||||
def test_construct_llm_attributes_dict_branch(trace_instance):
|
||||
attrs = trace_instance._construct_llm_attributes({"prompt": "hi"})
|
||||
assert '"prompt": "hi"' in attrs["llm.input_messages.0.message.content"]
|
||||
assert attrs["llm.input_messages.0.message.role"] == "user"
|
||||
|
||||
|
||||
def test_api_check_success(trace_instance):
|
||||
assert trace_instance.api_check() is True
|
||||
|
||||
|
||||
def test_ensure_root_span_basic(trace_instance):
|
||||
trace_instance.ensure_root_span("tid")
|
||||
assert "tid" in trace_instance.dify_trace_ids
|
||||
@ -0,0 +1,698 @@
|
||||
import collections
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.ops.entities.config_entity import LangfuseConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
TraceTaskName,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
|
||||
LangfuseGeneration,
|
||||
LangfuseSpan,
|
||||
LangfuseTrace,
|
||||
LevelEnum,
|
||||
UnitEnum,
|
||||
)
|
||||
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
|
||||
from dify_graph.enums import NodeType
|
||||
from models import EndUser
|
||||
from models.enums import MessageStatus
|
||||
|
||||
|
||||
def _dt() -> datetime:
|
||||
return datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def langfuse_config():
|
||||
return LangfuseConfig(public_key="pk-123", secret_key="sk-123", host="https://cloud.langfuse.com")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def trace_instance(langfuse_config, monkeypatch):
|
||||
# Mock Langfuse client to avoid network calls
|
||||
mock_client = MagicMock()
|
||||
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.Langfuse", lambda **kwargs: mock_client)
|
||||
|
||||
instance = LangFuseDataTrace(langfuse_config)
|
||||
return instance
|
||||
|
||||
|
||||
def test_init(langfuse_config, monkeypatch):
|
||||
mock_langfuse = MagicMock()
|
||||
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.Langfuse", mock_langfuse)
|
||||
monkeypatch.setenv("FILES_URL", "http://test.url")
|
||||
|
||||
instance = LangFuseDataTrace(langfuse_config)
|
||||
|
||||
mock_langfuse.assert_called_once_with(
|
||||
public_key=langfuse_config.public_key,
|
||||
secret_key=langfuse_config.secret_key,
|
||||
host=langfuse_config.host,
|
||||
)
|
||||
assert instance.file_base_url == "http://test.url"
|
||||
|
||||
|
||||
def test_trace_dispatch(trace_instance, monkeypatch):
|
||||
methods = [
|
||||
"workflow_trace",
|
||||
"message_trace",
|
||||
"moderation_trace",
|
||||
"suggested_question_trace",
|
||||
"dataset_retrieval_trace",
|
||||
"tool_trace",
|
||||
"generate_name_trace",
|
||||
]
|
||||
mocks = {method: MagicMock() for method in methods}
|
||||
for method, m in mocks.items():
|
||||
monkeypatch.setattr(trace_instance, method, m)
|
||||
|
||||
# WorkflowTraceInfo
|
||||
info = MagicMock(spec=WorkflowTraceInfo)
|
||||
trace_instance.trace(info)
|
||||
mocks["workflow_trace"].assert_called_once_with(info)
|
||||
|
||||
# MessageTraceInfo
|
||||
info = MagicMock(spec=MessageTraceInfo)
|
||||
trace_instance.trace(info)
|
||||
mocks["message_trace"].assert_called_once_with(info)
|
||||
|
||||
# ModerationTraceInfo
|
||||
info = MagicMock(spec=ModerationTraceInfo)
|
||||
trace_instance.trace(info)
|
||||
mocks["moderation_trace"].assert_called_once_with(info)
|
||||
|
||||
# SuggestedQuestionTraceInfo
|
||||
info = MagicMock(spec=SuggestedQuestionTraceInfo)
|
||||
trace_instance.trace(info)
|
||||
mocks["suggested_question_trace"].assert_called_once_with(info)
|
||||
|
||||
# DatasetRetrievalTraceInfo
|
||||
info = MagicMock(spec=DatasetRetrievalTraceInfo)
|
||||
trace_instance.trace(info)
|
||||
mocks["dataset_retrieval_trace"].assert_called_once_with(info)
|
||||
|
||||
# ToolTraceInfo
|
||||
info = MagicMock(spec=ToolTraceInfo)
|
||||
trace_instance.trace(info)
|
||||
mocks["tool_trace"].assert_called_once_with(info)
|
||||
|
||||
# GenerateNameTraceInfo
|
||||
info = MagicMock(spec=GenerateNameTraceInfo)
|
||||
trace_instance.trace(info)
|
||||
mocks["generate_name_trace"].assert_called_once_with(info)
|
||||
|
||||
|
||||
def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
|
||||
# Setup trace info
|
||||
trace_info = WorkflowTraceInfo(
|
||||
workflow_id="wf-1",
|
||||
tenant_id="tenant-1",
|
||||
workflow_run_id="run-1",
|
||||
workflow_run_elapsed_time=1.0,
|
||||
workflow_run_status="succeeded",
|
||||
workflow_run_inputs={"input": "hi"},
|
||||
workflow_run_outputs={"output": "hello"},
|
||||
workflow_run_version="1.0",
|
||||
message_id="msg-1",
|
||||
conversation_id="conv-1",
|
||||
total_tokens=100,
|
||||
file_list=[],
|
||||
query="hi",
|
||||
start_time=_dt(),
|
||||
end_time=_dt() + timedelta(seconds=1),
|
||||
trace_id="trace-1",
|
||||
metadata={"app_id": "app-1", "user_id": "user-1"},
|
||||
workflow_app_log_id="log-1",
|
||||
error="",
|
||||
)
|
||||
|
||||
# Mock DB and Repositories
|
||||
mock_session = MagicMock()
|
||||
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: mock_session)
|
||||
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine"))
|
||||
|
||||
# Mock node executions
|
||||
node_llm = MagicMock()
|
||||
node_llm.id = "node-llm"
|
||||
node_llm.title = "LLM Node"
|
||||
node_llm.node_type = NodeType.LLM
|
||||
node_llm.status = "succeeded"
|
||||
node_llm.process_data = {
|
||||
"model_mode": "chat",
|
||||
"model_name": "gpt-4",
|
||||
"model_provider": "openai",
|
||||
"usage": {"prompt_tokens": 10, "completion_tokens": 20},
|
||||
}
|
||||
node_llm.inputs = {"prompts": "p"}
|
||||
node_llm.outputs = {"text": "t"}
|
||||
node_llm.created_at = _dt()
|
||||
node_llm.elapsed_time = 0.5
|
||||
node_llm.metadata = {"foo": "bar"}
|
||||
|
||||
node_other = MagicMock()
|
||||
node_other.id = "node-other"
|
||||
node_other.title = "Other Node"
|
||||
node_other.node_type = NodeType.CODE
|
||||
node_other.status = "failed"
|
||||
node_other.process_data = None
|
||||
node_other.inputs = {"code": "print"}
|
||||
node_other.outputs = {"result": "ok"}
|
||||
node_other.created_at = None # Trigger datetime.now() branch
|
||||
node_other.elapsed_time = 0.2
|
||||
node_other.metadata = None
|
||||
|
||||
repo = MagicMock()
|
||||
repo.get_by_workflow_run.return_value = [node_llm, node_other]
|
||||
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.create_workflow_node_execution_repository.return_value = repo
|
||||
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)
|
||||
|
||||
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
|
||||
|
||||
# Track calls to add_trace, add_span, add_generation
|
||||
trace_instance.add_trace = MagicMock()
|
||||
trace_instance.add_span = MagicMock()
|
||||
trace_instance.add_generation = MagicMock()
|
||||
|
||||
trace_instance.workflow_trace(trace_info)
|
||||
|
||||
# Verify add_trace (Workflow Level)
|
||||
trace_instance.add_trace.assert_called_once()
|
||||
trace_data = trace_instance.add_trace.call_args[1]["langfuse_trace_data"]
|
||||
assert trace_data.id == "trace-1"
|
||||
assert trace_data.name == TraceTaskName.MESSAGE_TRACE
|
||||
assert "message" in trace_data.tags
|
||||
assert "workflow" in trace_data.tags
|
||||
|
||||
# Verify add_span (Workflow Run Span)
|
||||
assert trace_instance.add_span.call_count >= 1
|
||||
# First span should be workflow run span because message_id is present
|
||||
workflow_span = trace_instance.add_span.call_args_list[0][1]["langfuse_span_data"]
|
||||
assert workflow_span.id == "run-1"
|
||||
assert workflow_span.name == TraceTaskName.WORKFLOW_TRACE
|
||||
|
||||
# Verify Generation for LLM node
|
||||
trace_instance.add_generation.assert_called_once()
|
||||
gen_data = trace_instance.add_generation.call_args[1]["langfuse_generation_data"]
|
||||
assert gen_data.id == "node-llm"
|
||||
assert gen_data.usage.input == 10
|
||||
assert gen_data.usage.output == 20
|
||||
|
||||
# Verify normal span for Other node
|
||||
# Second add_span call
|
||||
other_span = trace_instance.add_span.call_args_list[1][1]["langfuse_span_data"]
|
||||
assert other_span.id == "node-other"
|
||||
assert other_span.level == LevelEnum.ERROR
|
||||
|
||||
|
||||
def test_workflow_trace_no_message_id(trace_instance, monkeypatch):
|
||||
trace_info = WorkflowTraceInfo(
|
||||
workflow_id="wf-1",
|
||||
tenant_id="tenant-1",
|
||||
workflow_run_id="run-1",
|
||||
workflow_run_elapsed_time=1.0,
|
||||
workflow_run_status="succeeded",
|
||||
workflow_run_inputs={},
|
||||
workflow_run_outputs={},
|
||||
workflow_run_version="1.0",
|
||||
total_tokens=0,
|
||||
file_list=[],
|
||||
query="",
|
||||
message_id=None,
|
||||
conversation_id="conv-1",
|
||||
start_time=_dt(),
|
||||
end_time=_dt(),
|
||||
trace_id=None, # Should fallback to workflow_run_id
|
||||
metadata={"app_id": "app-1"},
|
||||
workflow_app_log_id="log-1",
|
||||
error="",
|
||||
)
|
||||
|
||||
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
|
||||
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine"))
|
||||
repo = MagicMock()
|
||||
repo.get_by_workflow_run.return_value = []
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.create_workflow_node_execution_repository.return_value = repo
|
||||
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)
|
||||
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
|
||||
|
||||
trace_instance.add_trace = MagicMock()
|
||||
trace_instance.workflow_trace(trace_info)
|
||||
|
||||
trace_instance.add_trace.assert_called_once()
|
||||
trace_data = trace_instance.add_trace.call_args[1]["langfuse_trace_data"]
|
||||
assert trace_data.id == "run-1"
|
||||
assert trace_data.name == TraceTaskName.WORKFLOW_TRACE
|
||||
|
||||
|
||||
def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
|
||||
trace_info = WorkflowTraceInfo(
|
||||
workflow_id="wf-1",
|
||||
tenant_id="tenant-1",
|
||||
workflow_run_id="run-1",
|
||||
workflow_run_elapsed_time=1.0,
|
||||
workflow_run_status="succeeded",
|
||||
workflow_run_inputs={},
|
||||
workflow_run_outputs={},
|
||||
workflow_run_version="1.0",
|
||||
total_tokens=0,
|
||||
file_list=[],
|
||||
query="",
|
||||
message_id=None,
|
||||
conversation_id="conv-1",
|
||||
start_time=_dt(),
|
||||
end_time=_dt(),
|
||||
metadata={}, # Missing app_id
|
||||
workflow_app_log_id="log-1",
|
||||
error="",
|
||||
)
|
||||
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
|
||||
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine"))
|
||||
|
||||
with pytest.raises(ValueError, match="No app_id found in trace_info metadata"):
|
||||
trace_instance.workflow_trace(trace_info)
|
||||
|
||||
|
||||
def test_message_trace_basic(trace_instance, monkeypatch):
|
||||
message_data = MagicMock()
|
||||
message_data.id = "msg-1"
|
||||
message_data.from_account_id = "acc-1"
|
||||
message_data.from_end_user_id = None
|
||||
message_data.provider_response_latency = 0.5
|
||||
message_data.conversation_id = "conv-1"
|
||||
message_data.total_price = 0.01
|
||||
message_data.model_id = "gpt-4"
|
||||
message_data.answer = "hello"
|
||||
message_data.status = MessageStatus.NORMAL
|
||||
message_data.error = None
|
||||
|
||||
trace_info = MessageTraceInfo(
|
||||
message_id="msg-1",
|
||||
message_data=message_data,
|
||||
inputs={"query": "hi"},
|
||||
outputs={"answer": "hello"},
|
||||
message_tokens=10,
|
||||
answer_tokens=20,
|
||||
total_tokens=30,
|
||||
start_time=_dt(),
|
||||
end_time=_dt() + timedelta(seconds=1),
|
||||
trace_id="trace-1",
|
||||
metadata={"foo": "bar"},
|
||||
conversation_mode="chat",
|
||||
conversation_model="gpt-4",
|
||||
file_list=[],
|
||||
error=None,
|
||||
)
|
||||
|
||||
trace_instance.add_trace = MagicMock()
|
||||
trace_instance.add_generation = MagicMock()
|
||||
|
||||
trace_instance.message_trace(trace_info)
|
||||
|
||||
trace_instance.add_trace.assert_called_once()
|
||||
trace_instance.add_generation.assert_called_once()
|
||||
|
||||
gen_data = trace_instance.add_generation.call_args[0][0]
|
||||
assert gen_data.name == "llm"
|
||||
assert gen_data.usage.total == 30
|
||||
|
||||
|
||||
def test_message_trace_with_end_user(trace_instance, monkeypatch):
|
||||
message_data = MagicMock()
|
||||
message_data.id = "msg-1"
|
||||
message_data.from_account_id = "acc-1"
|
||||
message_data.from_end_user_id = "end-user-1"
|
||||
message_data.conversation_id = "conv-1"
|
||||
message_data.status = MessageStatus.NORMAL
|
||||
message_data.model_id = "gpt-4"
|
||||
message_data.error = ""
|
||||
message_data.answer = "hello"
|
||||
message_data.total_price = 0.0
|
||||
message_data.provider_response_latency = 0.1
|
||||
|
||||
trace_info = MessageTraceInfo(
|
||||
message_id="msg-1",
|
||||
message_data=message_data,
|
||||
inputs={},
|
||||
outputs={},
|
||||
message_tokens=0,
|
||||
answer_tokens=0,
|
||||
total_tokens=0,
|
||||
start_time=_dt(),
|
||||
end_time=_dt(),
|
||||
metadata={},
|
||||
conversation_mode="chat",
|
||||
conversation_model="gpt-4",
|
||||
file_list=[],
|
||||
error=None,
|
||||
)
|
||||
|
||||
# Mock DB session for EndUser lookup
|
||||
mock_end_user = MagicMock(spec=EndUser)
|
||||
mock_end_user.session_id = "session-id-123"
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.where.return_value.first.return_value = mock_end_user
|
||||
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db.session.query", lambda model: mock_query)
|
||||
|
||||
trace_instance.add_trace = MagicMock()
|
||||
trace_instance.add_generation = MagicMock()
|
||||
|
||||
trace_instance.message_trace(trace_info)
|
||||
|
||||
trace_data = trace_instance.add_trace.call_args[1]["langfuse_trace_data"]
|
||||
assert trace_data.user_id == "session-id-123"
|
||||
assert trace_data.metadata["user_id"] == "session-id-123"
|
||||
|
||||
|
||||
def test_message_trace_none_data(trace_instance):
|
||||
trace_info = SimpleNamespace(message_data=None, file_list=[], metadata={})
|
||||
trace_instance.add_trace = MagicMock()
|
||||
trace_instance.message_trace(trace_info)
|
||||
trace_instance.add_trace.assert_not_called()
|
||||
|
||||
|
||||
def test_moderation_trace(trace_instance):
|
||||
message_data = MagicMock()
|
||||
message_data.created_at = _dt()
|
||||
|
||||
trace_info = ModerationTraceInfo(
|
||||
message_id="msg-1",
|
||||
message_data=message_data,
|
||||
inputs={"q": "hi"},
|
||||
action="stop",
|
||||
flagged=True,
|
||||
preset_response="blocked",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
metadata={"foo": "bar"},
|
||||
trace_id="trace-1",
|
||||
query="hi",
|
||||
)
|
||||
|
||||
trace_instance.add_span = MagicMock()
|
||||
trace_instance.moderation_trace(trace_info)
|
||||
|
||||
trace_instance.add_span.assert_called_once()
|
||||
span_data = trace_instance.add_span.call_args[1]["langfuse_span_data"]
|
||||
assert span_data.name == TraceTaskName.MODERATION_TRACE
|
||||
assert span_data.output["flagged"] is True
|
||||
|
||||
|
||||
def test_suggested_question_trace(trace_instance):
|
||||
message_data = MagicMock()
|
||||
message_data.status = MessageStatus.NORMAL
|
||||
message_data.error = None
|
||||
|
||||
trace_info = SuggestedQuestionTraceInfo(
|
||||
message_id="msg-1",
|
||||
message_data=message_data,
|
||||
inputs="hi",
|
||||
suggested_question=["q1"],
|
||||
total_tokens=10,
|
||||
level="info",
|
||||
start_time=_dt(),
|
||||
end_time=_dt(),
|
||||
metadata={},
|
||||
trace_id="trace-1",
|
||||
)
|
||||
|
||||
trace_instance.add_generation = MagicMock()
|
||||
trace_instance.suggested_question_trace(trace_info)
|
||||
|
||||
trace_instance.add_generation.assert_called_once()
|
||||
gen_data = trace_instance.add_generation.call_args[1]["langfuse_generation_data"]
|
||||
assert gen_data.name == TraceTaskName.SUGGESTED_QUESTION_TRACE
|
||||
assert gen_data.usage.unit == UnitEnum.CHARACTERS
|
||||
|
||||
|
||||
def test_dataset_retrieval_trace(trace_instance):
|
||||
message_data = MagicMock()
|
||||
message_data.created_at = _dt()
|
||||
message_data.updated_at = _dt()
|
||||
|
||||
trace_info = DatasetRetrievalTraceInfo(
|
||||
message_id="msg-1",
|
||||
message_data=message_data,
|
||||
inputs="query",
|
||||
documents=[{"id": "doc1"}],
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
metadata={},
|
||||
trace_id="trace-1",
|
||||
)
|
||||
|
||||
trace_instance.add_span = MagicMock()
|
||||
trace_instance.dataset_retrieval_trace(trace_info)
|
||||
|
||||
trace_instance.add_span.assert_called_once()
|
||||
span_data = trace_instance.add_span.call_args[1]["langfuse_span_data"]
|
||||
assert span_data.name == TraceTaskName.DATASET_RETRIEVAL_TRACE
|
||||
assert span_data.output["documents"] == [{"id": "doc1"}]
|
||||
|
||||
|
||||
def test_tool_trace(trace_instance):
|
||||
trace_info = ToolTraceInfo(
|
||||
message_id="msg-1",
|
||||
message_data=MagicMock(),
|
||||
inputs={},
|
||||
outputs={},
|
||||
tool_name="my_tool",
|
||||
tool_inputs={"a": 1},
|
||||
tool_outputs="result_string",
|
||||
time_cost=0.1,
|
||||
start_time=_dt(),
|
||||
end_time=_dt(),
|
||||
metadata={},
|
||||
trace_id="trace-1",
|
||||
tool_config={},
|
||||
tool_parameters={},
|
||||
error="some error",
|
||||
)
|
||||
|
||||
trace_instance.add_span = MagicMock()
|
||||
trace_instance.tool_trace(trace_info)
|
||||
|
||||
trace_instance.add_span.assert_called_once()
|
||||
span_data = trace_instance.add_span.call_args[1]["langfuse_span_data"]
|
||||
assert span_data.name == "my_tool"
|
||||
assert span_data.level == LevelEnum.ERROR
|
||||
|
||||
|
||||
def test_generate_name_trace(trace_instance):
|
||||
trace_info = GenerateNameTraceInfo(
|
||||
inputs={"q": "hi"},
|
||||
outputs={"name": "new"},
|
||||
tenant_id="tenant-1",
|
||||
conversation_id="conv-1",
|
||||
start_time=_dt(),
|
||||
end_time=_dt(),
|
||||
metadata={"m": 1},
|
||||
)
|
||||
|
||||
trace_instance.add_trace = MagicMock()
|
||||
trace_instance.add_span = MagicMock()
|
||||
|
||||
trace_instance.generate_name_trace(trace_info)
|
||||
|
||||
trace_instance.add_trace.assert_called_once()
|
||||
trace_instance.add_span.assert_called_once()
|
||||
|
||||
trace_data = trace_instance.add_trace.call_args[1]["langfuse_trace_data"]
|
||||
assert trace_data.name == TraceTaskName.GENERATE_NAME_TRACE
|
||||
assert trace_data.user_id == "tenant-1"
|
||||
|
||||
span_data = trace_instance.add_span.call_args[1]["langfuse_span_data"]
|
||||
assert span_data.trace_id == "conv-1"
|
||||
|
||||
|
||||
def test_add_trace_success(trace_instance):
|
||||
data = LangfuseTrace(id="t1", name="trace")
|
||||
trace_instance.add_trace(data)
|
||||
trace_instance.langfuse_client.trace.assert_called_once()
|
||||
|
||||
|
||||
def test_add_trace_error(trace_instance):
|
||||
trace_instance.langfuse_client.trace.side_effect = Exception("error")
|
||||
data = LangfuseTrace(id="t1", name="trace")
|
||||
with pytest.raises(ValueError, match="LangFuse Failed to create trace: error"):
|
||||
trace_instance.add_trace(data)
|
||||
|
||||
|
||||
def test_add_span_success(trace_instance):
|
||||
data = LangfuseSpan(id="s1", name="span", trace_id="t1")
|
||||
trace_instance.add_span(data)
|
||||
trace_instance.langfuse_client.span.assert_called_once()
|
||||
|
||||
|
||||
def test_add_span_error(trace_instance):
|
||||
trace_instance.langfuse_client.span.side_effect = Exception("error")
|
||||
data = LangfuseSpan(id="s1", name="span", trace_id="t1")
|
||||
with pytest.raises(ValueError, match="LangFuse Failed to create span: error"):
|
||||
trace_instance.add_span(data)
|
||||
|
||||
|
||||
def test_update_span(trace_instance):
|
||||
span = MagicMock()
|
||||
data = LangfuseSpan(id="s1", name="span", trace_id="t1")
|
||||
trace_instance.update_span(span, data)
|
||||
span.end.assert_called_once()
|
||||
|
||||
|
||||
def test_add_generation_success(trace_instance):
|
||||
data = LangfuseGeneration(id="g1", name="gen", trace_id="t1")
|
||||
trace_instance.add_generation(data)
|
||||
trace_instance.langfuse_client.generation.assert_called_once()
|
||||
|
||||
|
||||
def test_add_generation_error(trace_instance):
|
||||
trace_instance.langfuse_client.generation.side_effect = Exception("error")
|
||||
data = LangfuseGeneration(id="g1", name="gen", trace_id="t1")
|
||||
with pytest.raises(ValueError, match="LangFuse Failed to create generation: error"):
|
||||
trace_instance.add_generation(data)
|
||||
|
||||
|
||||
def test_update_generation(trace_instance):
|
||||
gen = MagicMock()
|
||||
data = LangfuseGeneration(id="g1", name="gen", trace_id="t1")
|
||||
trace_instance.update_generation(gen, data)
|
||||
gen.end.assert_called_once()
|
||||
|
||||
|
||||
def test_api_check_success(trace_instance):
|
||||
trace_instance.langfuse_client.auth_check.return_value = True
|
||||
assert trace_instance.api_check() is True
|
||||
|
||||
|
||||
def test_api_check_error(trace_instance):
|
||||
trace_instance.langfuse_client.auth_check.side_effect = Exception("fail")
|
||||
with pytest.raises(ValueError, match="LangFuse API check failed: fail"):
|
||||
trace_instance.api_check()
|
||||
|
||||
|
||||
def test_get_project_key_success(trace_instance):
|
||||
mock_data = MagicMock()
|
||||
mock_data.id = "proj-1"
|
||||
trace_instance.langfuse_client.client.projects.get.return_value = MagicMock(data=[mock_data])
|
||||
assert trace_instance.get_project_key() == "proj-1"
|
||||
|
||||
|
||||
def test_get_project_key_error(trace_instance):
|
||||
trace_instance.langfuse_client.client.projects.get.side_effect = Exception("fail")
|
||||
with pytest.raises(ValueError, match="LangFuse get project key failed: fail"):
|
||||
trace_instance.get_project_key()
|
||||
|
||||
|
||||
def test_moderation_trace_none(trace_instance):
|
||||
trace_info = ModerationTraceInfo(
|
||||
message_id="m",
|
||||
message_data=None,
|
||||
inputs={},
|
||||
action="s",
|
||||
flagged=False,
|
||||
preset_response="",
|
||||
query="",
|
||||
metadata={},
|
||||
)
|
||||
trace_instance.add_span = MagicMock()
|
||||
trace_instance.moderation_trace(trace_info)
|
||||
trace_instance.add_span.assert_not_called()
|
||||
|
||||
|
||||
def test_suggested_question_trace_none(trace_instance):
|
||||
trace_info = SuggestedQuestionTraceInfo(
|
||||
message_id="m", message_data=None, inputs={}, suggested_question=[], total_tokens=0, level="i", metadata={}
|
||||
)
|
||||
trace_instance.add_generation = MagicMock()
|
||||
trace_instance.suggested_question_trace(trace_info)
|
||||
trace_instance.add_generation.assert_not_called()
|
||||
|
||||
|
||||
def test_dataset_retrieval_trace_none(trace_instance):
|
||||
trace_info = DatasetRetrievalTraceInfo(message_id="m", message_data=None, inputs={}, documents=[], metadata={})
|
||||
trace_instance.add_span = MagicMock()
|
||||
trace_instance.dataset_retrieval_trace(trace_info)
|
||||
trace_instance.add_span.assert_not_called()
|
||||
|
||||
|
||||
def test_langfuse_trace_entity_with_list_dict_input():
|
||||
# To cover lines 29-31 in langfuse_trace_entity.py
|
||||
# We need to mock replace_text_with_content or just check if it works
|
||||
# Actually replace_text_with_content is imported from core.ops.utils
|
||||
data = LangfuseTrace(id="t1", name="n", input=[{"text": "hello"}])
|
||||
assert isinstance(data.input, list)
|
||||
assert data.input[0]["content"] == "hello"
|
||||
|
||||
|
||||
def test_workflow_trace_handles_usage_extraction_error(trace_instance, monkeypatch, caplog):
|
||||
# Setup trace info to trigger LLM node usage extraction
|
||||
trace_info = WorkflowTraceInfo(
|
||||
workflow_id="wf-1",
|
||||
tenant_id="t",
|
||||
workflow_run_id="r",
|
||||
workflow_run_elapsed_time=1.0,
|
||||
workflow_run_status="s",
|
||||
workflow_run_inputs={},
|
||||
workflow_run_outputs={},
|
||||
workflow_run_version="1",
|
||||
total_tokens=0,
|
||||
file_list=[],
|
||||
query="",
|
||||
message_id=None,
|
||||
conversation_id="c",
|
||||
start_time=_dt(),
|
||||
end_time=_dt(),
|
||||
metadata={"app_id": "app-1"},
|
||||
workflow_app_log_id="l",
|
||||
error="",
|
||||
)
|
||||
|
||||
node = MagicMock()
|
||||
node.id = "n1"
|
||||
node.title = "LLM Node"
|
||||
node.node_type = NodeType.LLM
|
||||
node.status = "succeeded"
|
||||
|
||||
class BadDict(collections.UserDict):
|
||||
def get(self, key, default=None):
|
||||
if key == "usage":
|
||||
raise Exception("Usage extraction failed")
|
||||
return super().get(key, default)
|
||||
|
||||
node.process_data = BadDict({"model_mode": "chat", "model_name": "gpt-4", "usage": True, "prompts": ["p"]})
|
||||
node.created_at = _dt()
|
||||
node.elapsed_time = 0.1
|
||||
node.metadata = {}
|
||||
node.outputs = {}
|
||||
|
||||
repo = MagicMock()
|
||||
repo.get_by_workflow_run.return_value = [node]
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.create_workflow_node_execution_repository.return_value = repo
|
||||
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)
|
||||
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
|
||||
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine"))
|
||||
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
|
||||
|
||||
trace_instance.add_trace = MagicMock()
|
||||
trace_instance.add_generation = MagicMock()
|
||||
|
||||
with caplog.at_level(logging.ERROR):
|
||||
trace_instance.workflow_trace(trace_info)
|
||||
|
||||
assert "Failed to extract usage" in caplog.text
|
||||
trace_instance.add_generation.assert_called_once()
|
||||
@ -0,0 +1,608 @@
|
||||
import collections
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.ops.entities.config_entity import LangSmithConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
TraceTaskName,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
|
||||
LangSmithRunModel,
|
||||
LangSmithRunType,
|
||||
LangSmithRunUpdateModel,
|
||||
)
|
||||
from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey
|
||||
from models import EndUser
|
||||
|
||||
|
||||
def _dt() -> datetime:
|
||||
return datetime(2024, 1, 1, 0, 0, 0)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def langsmith_config():
|
||||
return LangSmithConfig(api_key="ls-123", project="default", endpoint="https://api.smith.langchain.com")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def trace_instance(langsmith_config, monkeypatch):
|
||||
# Mock LangSmith client
|
||||
mock_client = MagicMock()
|
||||
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.Client", lambda **kwargs: mock_client)
|
||||
|
||||
instance = LangSmithDataTrace(langsmith_config)
|
||||
return instance
|
||||
|
||||
|
||||
def test_init(langsmith_config, monkeypatch):
|
||||
mock_client_class = MagicMock()
|
||||
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.Client", mock_client_class)
|
||||
monkeypatch.setenv("FILES_URL", "http://test.url")
|
||||
|
||||
instance = LangSmithDataTrace(langsmith_config)
|
||||
|
||||
mock_client_class.assert_called_once_with(api_key=langsmith_config.api_key, api_url=langsmith_config.endpoint)
|
||||
assert instance.langsmith_key == langsmith_config.api_key
|
||||
assert instance.project_name == langsmith_config.project
|
||||
assert instance.file_base_url == "http://test.url"
|
||||
|
||||
|
||||
def test_trace_dispatch(trace_instance, monkeypatch):
|
||||
methods = [
|
||||
"workflow_trace",
|
||||
"message_trace",
|
||||
"moderation_trace",
|
||||
"suggested_question_trace",
|
||||
"dataset_retrieval_trace",
|
||||
"tool_trace",
|
||||
"generate_name_trace",
|
||||
]
|
||||
mocks = {method: MagicMock() for method in methods}
|
||||
for method, m in mocks.items():
|
||||
monkeypatch.setattr(trace_instance, method, m)
|
||||
|
||||
# WorkflowTraceInfo
|
||||
info = MagicMock(spec=WorkflowTraceInfo)
|
||||
trace_instance.trace(info)
|
||||
mocks["workflow_trace"].assert_called_once_with(info)
|
||||
|
||||
# MessageTraceInfo
|
||||
info = MagicMock(spec=MessageTraceInfo)
|
||||
trace_instance.trace(info)
|
||||
mocks["message_trace"].assert_called_once_with(info)
|
||||
|
||||
# ModerationTraceInfo
|
||||
info = MagicMock(spec=ModerationTraceInfo)
|
||||
trace_instance.trace(info)
|
||||
mocks["moderation_trace"].assert_called_once_with(info)
|
||||
|
||||
# SuggestedQuestionTraceInfo
|
||||
info = MagicMock(spec=SuggestedQuestionTraceInfo)
|
||||
trace_instance.trace(info)
|
||||
mocks["suggested_question_trace"].assert_called_once_with(info)
|
||||
|
||||
# DatasetRetrievalTraceInfo
|
||||
info = MagicMock(spec=DatasetRetrievalTraceInfo)
|
||||
trace_instance.trace(info)
|
||||
mocks["dataset_retrieval_trace"].assert_called_once_with(info)
|
||||
|
||||
# ToolTraceInfo
|
||||
info = MagicMock(spec=ToolTraceInfo)
|
||||
trace_instance.trace(info)
|
||||
mocks["tool_trace"].assert_called_once_with(info)
|
||||
|
||||
# GenerateNameTraceInfo
|
||||
info = MagicMock(spec=GenerateNameTraceInfo)
|
||||
trace_instance.trace(info)
|
||||
mocks["generate_name_trace"].assert_called_once_with(info)
|
||||
|
||||
|
||||
def test_workflow_trace(trace_instance, monkeypatch):
|
||||
# Setup trace info
|
||||
workflow_data = MagicMock()
|
||||
workflow_data.created_at = _dt()
|
||||
workflow_data.finished_at = _dt() + timedelta(seconds=1)
|
||||
|
||||
trace_info = WorkflowTraceInfo(
|
||||
tenant_id="tenant-1",
|
||||
workflow_id="wf-1",
|
||||
workflow_run_id="run-1",
|
||||
workflow_run_inputs={"input": "hi"},
|
||||
workflow_run_outputs={"output": "hello"},
|
||||
workflow_run_status="succeeded",
|
||||
workflow_run_version="1.0",
|
||||
workflow_run_elapsed_time=1.0,
|
||||
total_tokens=100,
|
||||
file_list=[],
|
||||
query="hi",
|
||||
message_id="msg-1",
|
||||
conversation_id="conv-1",
|
||||
start_time=_dt(),
|
||||
end_time=_dt() + timedelta(seconds=1),
|
||||
trace_id="trace-1",
|
||||
metadata={"app_id": "app-1"},
|
||||
workflow_app_log_id="log-1",
|
||||
error="",
|
||||
workflow_data=workflow_data,
|
||||
)
|
||||
|
||||
# Mock dependencies
|
||||
mock_session = MagicMock()
|
||||
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
|
||||
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine"))
|
||||
|
||||
# Mock node executions
|
||||
node_llm = MagicMock()
|
||||
node_llm.id = "node-llm"
|
||||
node_llm.title = "LLM Node"
|
||||
node_llm.node_type = NodeType.LLM
|
||||
node_llm.status = "succeeded"
|
||||
node_llm.process_data = {
|
||||
"model_mode": "chat",
|
||||
"model_name": "gpt-4",
|
||||
"model_provider": "openai",
|
||||
"usage": {"prompt_tokens": 10, "completion_tokens": 20},
|
||||
}
|
||||
node_llm.inputs = {"prompts": "p"}
|
||||
node_llm.outputs = {"text": "t"}
|
||||
node_llm.created_at = _dt()
|
||||
node_llm.elapsed_time = 0.5
|
||||
node_llm.metadata = {WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 30}
|
||||
|
||||
node_other = MagicMock()
|
||||
node_other.id = "node-other"
|
||||
node_other.title = "Tool Node"
|
||||
node_other.node_type = NodeType.TOOL
|
||||
node_other.status = "succeeded"
|
||||
node_other.process_data = None
|
||||
node_other.inputs = {"tool_input": "val"}
|
||||
node_other.outputs = {"tool_output": "val"}
|
||||
node_other.created_at = None # Trigger datetime.now()
|
||||
node_other.elapsed_time = 0.2
|
||||
node_other.metadata = {}
|
||||
|
||||
node_retrieval = MagicMock()
|
||||
node_retrieval.id = "node-retrieval"
|
||||
node_retrieval.title = "Retrieval Node"
|
||||
node_retrieval.node_type = NodeType.KNOWLEDGE_RETRIEVAL
|
||||
node_retrieval.status = "succeeded"
|
||||
node_retrieval.process_data = None
|
||||
node_retrieval.inputs = {"query": "val"}
|
||||
node_retrieval.outputs = {"results": "val"}
|
||||
node_retrieval.created_at = _dt()
|
||||
node_retrieval.elapsed_time = 0.2
|
||||
node_retrieval.metadata = {}
|
||||
|
||||
repo = MagicMock()
|
||||
repo.get_by_workflow_run.return_value = [node_llm, node_other, node_retrieval]
|
||||
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.create_workflow_node_execution_repository.return_value = repo
|
||||
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
|
||||
|
||||
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
|
||||
|
||||
trace_instance.add_run = MagicMock()
|
||||
|
||||
trace_instance.workflow_trace(trace_info)
|
||||
|
||||
# Verify add_run calls
|
||||
# 1. message run (id="msg-1")
|
||||
# 2. workflow run (id="run-1")
|
||||
# 3. node llm run (id="node-llm")
|
||||
# 4. node other run (id="node-other")
|
||||
# 5. node retrieval run (id="node-retrieval")
|
||||
assert trace_instance.add_run.call_count == 5
|
||||
|
||||
call_args = [call[0][0] for call in trace_instance.add_run.call_args_list]
|
||||
|
||||
assert call_args[0].id == "msg-1"
|
||||
assert call_args[0].name == TraceTaskName.MESSAGE_TRACE
|
||||
|
||||
assert call_args[1].id == "run-1"
|
||||
assert call_args[1].name == TraceTaskName.WORKFLOW_TRACE
|
||||
assert call_args[1].parent_run_id == "msg-1"
|
||||
|
||||
assert call_args[2].id == "node-llm"
|
||||
assert call_args[2].run_type == LangSmithRunType.llm
|
||||
|
||||
assert call_args[3].id == "node-other"
|
||||
assert call_args[3].run_type == LangSmithRunType.tool
|
||||
|
||||
assert call_args[4].id == "node-retrieval"
|
||||
assert call_args[4].run_type == LangSmithRunType.retriever
|
||||
|
||||
|
||||
def test_workflow_trace_no_start_time(trace_instance, monkeypatch):
|
||||
workflow_data = MagicMock()
|
||||
workflow_data.created_at = _dt()
|
||||
workflow_data.finished_at = _dt() + timedelta(seconds=1)
|
||||
|
||||
trace_info = WorkflowTraceInfo(
|
||||
tenant_id="tenant-1",
|
||||
workflow_id="wf-1",
|
||||
workflow_run_id="run-1",
|
||||
workflow_run_inputs={},
|
||||
workflow_run_outputs={},
|
||||
workflow_run_status="succeeded",
|
||||
workflow_run_version="1.0",
|
||||
workflow_run_elapsed_time=1.0,
|
||||
total_tokens=10,
|
||||
file_list=[],
|
||||
query="hi",
|
||||
message_id="msg-1",
|
||||
conversation_id="conv-1",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
trace_id="trace-1",
|
||||
metadata={"app_id": "app-1"},
|
||||
workflow_app_log_id="log-1",
|
||||
error="",
|
||||
workflow_data=workflow_data,
|
||||
)
|
||||
|
||||
mock_session = MagicMock()
|
||||
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
|
||||
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine"))
|
||||
repo = MagicMock()
|
||||
repo.get_by_workflow_run.return_value = []
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.create_workflow_node_execution_repository.return_value = repo
|
||||
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
|
||||
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
|
||||
|
||||
trace_instance.add_run = MagicMock()
|
||||
trace_instance.workflow_trace(trace_info)
|
||||
assert trace_instance.add_run.called
|
||||
|
||||
|
||||
def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
trace_info.trace_id = "trace-1"
|
||||
trace_info.message_id = None
|
||||
trace_info.workflow_run_id = "run-1"
|
||||
trace_info.start_time = None
|
||||
trace_info.workflow_data = MagicMock()
|
||||
trace_info.workflow_data.created_at = _dt()
|
||||
trace_info.metadata = {} # Empty metadata
|
||||
trace_info.workflow_app_log_id = "log-1"
|
||||
trace_info.file_list = []
|
||||
trace_info.total_tokens = 0
|
||||
trace_info.workflow_run_inputs = {}
|
||||
trace_info.workflow_run_outputs = {}
|
||||
trace_info.error = ""
|
||||
|
||||
mock_session = MagicMock()
|
||||
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
|
||||
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine"))
|
||||
|
||||
with pytest.raises(ValueError, match="No app_id found in trace_info metadata"):
|
||||
trace_instance.workflow_trace(trace_info)
|
||||
|
||||
|
||||
def test_message_trace(trace_instance, monkeypatch):
|
||||
message_data = MagicMock()
|
||||
message_data.id = "msg-1"
|
||||
message_data.from_account_id = "acc-1"
|
||||
message_data.from_end_user_id = "end-user-1"
|
||||
message_data.answer = "hello answer"
|
||||
|
||||
trace_info = MessageTraceInfo(
|
||||
message_id="msg-1",
|
||||
message_data=message_data,
|
||||
inputs={"input": "hi"},
|
||||
outputs={"answer": "hello"},
|
||||
message_tokens=10,
|
||||
answer_tokens=20,
|
||||
total_tokens=30,
|
||||
start_time=_dt(),
|
||||
end_time=_dt() + timedelta(seconds=1),
|
||||
trace_id="trace-1",
|
||||
metadata={"foo": "bar"},
|
||||
conversation_mode="chat",
|
||||
conversation_model="gpt-4",
|
||||
file_list=[],
|
||||
error=None,
|
||||
message_file_data=MagicMock(url="file-url"),
|
||||
)
|
||||
|
||||
# Mock EndUser lookup
|
||||
mock_end_user = MagicMock(spec=EndUser)
|
||||
mock_end_user.session_id = "session-id-123"
|
||||
mock_query = MagicMock()
|
||||
mock_query.where.return_value.first.return_value = mock_end_user
|
||||
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db.session.query", lambda model: mock_query)
|
||||
|
||||
trace_instance.add_run = MagicMock()
|
||||
|
||||
trace_instance.message_trace(trace_info)
|
||||
|
||||
# 1. message run
|
||||
# 2. llm run
|
||||
assert trace_instance.add_run.call_count == 2
|
||||
|
||||
call_args = [call[0][0] for call in trace_instance.add_run.call_args_list]
|
||||
assert call_args[0].id == "msg-1"
|
||||
assert call_args[0].extra["metadata"]["end_user_id"] == "session-id-123"
|
||||
assert call_args[1].parent_run_id == "msg-1"
|
||||
assert call_args[1].name == "llm"
|
||||
|
||||
|
||||
def test_message_trace_no_data(trace_instance):
|
||||
trace_info = MagicMock(spec=MessageTraceInfo)
|
||||
trace_info.message_data = None
|
||||
trace_info.file_list = []
|
||||
trace_info.message_file_data = None
|
||||
trace_info.metadata = {}
|
||||
trace_instance.add_run = MagicMock()
|
||||
trace_instance.message_trace(trace_info)
|
||||
trace_instance.add_run.assert_not_called()
|
||||
|
||||
|
||||
def test_moderation_trace_no_data(trace_instance):
|
||||
trace_info = MagicMock(spec=ModerationTraceInfo)
|
||||
trace_info.message_data = None
|
||||
trace_instance.add_run = MagicMock()
|
||||
trace_instance.moderation_trace(trace_info)
|
||||
trace_instance.add_run.assert_not_called()
|
||||
|
||||
|
||||
def test_suggested_question_trace_no_data(trace_instance):
|
||||
trace_info = MagicMock(spec=SuggestedQuestionTraceInfo)
|
||||
trace_info.message_data = None
|
||||
trace_instance.add_run = MagicMock()
|
||||
trace_instance.suggested_question_trace(trace_info)
|
||||
trace_instance.add_run.assert_not_called()
|
||||
|
||||
|
||||
def test_dataset_retrieval_trace_no_data(trace_instance):
|
||||
trace_info = MagicMock(spec=DatasetRetrievalTraceInfo)
|
||||
trace_info.message_data = None
|
||||
trace_instance.add_run = MagicMock()
|
||||
trace_instance.dataset_retrieval_trace(trace_info)
|
||||
trace_instance.add_run.assert_not_called()
|
||||
|
||||
|
||||
def test_moderation_trace(trace_instance):
|
||||
message_data = MagicMock()
|
||||
message_data.created_at = _dt()
|
||||
message_data.updated_at = _dt()
|
||||
|
||||
trace_info = ModerationTraceInfo(
|
||||
message_id="msg-1",
|
||||
message_data=message_data,
|
||||
inputs={"q": "hi"},
|
||||
action="stop",
|
||||
flagged=True,
|
||||
preset_response="blocked",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
metadata={},
|
||||
trace_id="trace-1",
|
||||
query="hi",
|
||||
)
|
||||
|
||||
trace_instance.add_run = MagicMock()
|
||||
trace_instance.moderation_trace(trace_info)
|
||||
trace_instance.add_run.assert_called_once()
|
||||
assert trace_instance.add_run.call_args[0][0].name == TraceTaskName.MODERATION_TRACE
|
||||
|
||||
|
||||
def test_suggested_question_trace(trace_instance):
|
||||
message_data = MagicMock()
|
||||
message_data.created_at = _dt()
|
||||
message_data.updated_at = _dt()
|
||||
|
||||
trace_info = SuggestedQuestionTraceInfo(
|
||||
message_id="msg-1",
|
||||
message_data=message_data,
|
||||
inputs="hi",
|
||||
suggested_question=["q1"],
|
||||
total_tokens=10,
|
||||
level="info",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
metadata={},
|
||||
trace_id="trace-1",
|
||||
)
|
||||
|
||||
trace_instance.add_run = MagicMock()
|
||||
trace_instance.suggested_question_trace(trace_info)
|
||||
trace_instance.add_run.assert_called_once()
|
||||
assert trace_instance.add_run.call_args[0][0].name == TraceTaskName.SUGGESTED_QUESTION_TRACE
|
||||
|
||||
|
||||
def test_dataset_retrieval_trace(trace_instance):
|
||||
message_data = MagicMock()
|
||||
message_data.created_at = _dt()
|
||||
message_data.updated_at = _dt()
|
||||
|
||||
trace_info = DatasetRetrievalTraceInfo(
|
||||
message_id="msg-1",
|
||||
message_data=message_data,
|
||||
inputs="query",
|
||||
documents=[{"id": "doc1"}],
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
metadata={},
|
||||
trace_id="trace-1",
|
||||
)
|
||||
|
||||
trace_instance.add_run = MagicMock()
|
||||
trace_instance.dataset_retrieval_trace(trace_info)
|
||||
trace_instance.add_run.assert_called_once()
|
||||
assert trace_instance.add_run.call_args[0][0].name == TraceTaskName.DATASET_RETRIEVAL_TRACE
|
||||
|
||||
|
||||
def test_tool_trace(trace_instance):
|
||||
trace_info = ToolTraceInfo(
|
||||
message_id="msg-1",
|
||||
message_data=MagicMock(),
|
||||
inputs={},
|
||||
outputs={},
|
||||
tool_name="my_tool",
|
||||
tool_inputs={"a": 1},
|
||||
tool_outputs="result",
|
||||
time_cost=0.1,
|
||||
start_time=_dt(),
|
||||
end_time=_dt(),
|
||||
metadata={},
|
||||
trace_id="trace-1",
|
||||
tool_config={},
|
||||
tool_parameters={},
|
||||
file_url="http://file",
|
||||
)
|
||||
|
||||
trace_instance.add_run = MagicMock()
|
||||
trace_instance.tool_trace(trace_info)
|
||||
trace_instance.add_run.assert_called_once()
|
||||
assert trace_instance.add_run.call_args[0][0].name == "my_tool"
|
||||
|
||||
|
||||
def test_generate_name_trace(trace_instance):
|
||||
trace_info = GenerateNameTraceInfo(
|
||||
inputs={"q": "hi"},
|
||||
outputs={"name": "new"},
|
||||
tenant_id="tenant-1",
|
||||
conversation_id="conv-1",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
metadata={},
|
||||
trace_id="trace-1",
|
||||
)
|
||||
|
||||
trace_instance.add_run = MagicMock()
|
||||
trace_instance.generate_name_trace(trace_info)
|
||||
trace_instance.add_run.assert_called_once()
|
||||
assert trace_instance.add_run.call_args[0][0].name == TraceTaskName.GENERATE_NAME_TRACE
|
||||
|
||||
|
||||
def test_add_run_success(trace_instance):
|
||||
run_data = LangSmithRunModel(
|
||||
id="run-1", name="test", inputs={}, outputs={}, run_type=LangSmithRunType.tool, start_time=_dt()
|
||||
)
|
||||
trace_instance.project_id = "proj-1"
|
||||
trace_instance.add_run(run_data)
|
||||
trace_instance.langsmith_client.create_run.assert_called_once()
|
||||
args, kwargs = trace_instance.langsmith_client.create_run.call_args
|
||||
assert kwargs["session_id"] == "proj-1"
|
||||
|
||||
|
||||
def test_add_run_error(trace_instance):
|
||||
run_data = LangSmithRunModel(id="run-1", name="test", run_type=LangSmithRunType.tool, start_time=_dt())
|
||||
trace_instance.langsmith_client.create_run.side_effect = Exception("failed")
|
||||
with pytest.raises(ValueError, match="LangSmith Failed to create run: failed"):
|
||||
trace_instance.add_run(run_data)
|
||||
|
||||
|
||||
def test_update_run_success(trace_instance):
|
||||
update_data = LangSmithRunUpdateModel(run_id="run-1", outputs={"out": "val"})
|
||||
trace_instance.update_run(update_data)
|
||||
trace_instance.langsmith_client.update_run.assert_called_once()
|
||||
|
||||
|
||||
def test_update_run_error(trace_instance):
|
||||
update_data = LangSmithRunUpdateModel(run_id="run-1")
|
||||
trace_instance.langsmith_client.update_run.side_effect = Exception("failed")
|
||||
with pytest.raises(ValueError, match="LangSmith Failed to update run: failed"):
|
||||
trace_instance.update_run(update_data)
|
||||
|
||||
|
||||
def test_workflow_trace_usage_extraction_error(trace_instance, monkeypatch, caplog):
|
||||
workflow_data = MagicMock()
|
||||
workflow_data.created_at = _dt()
|
||||
workflow_data.finished_at = _dt() + timedelta(seconds=1)
|
||||
|
||||
trace_info = WorkflowTraceInfo(
|
||||
tenant_id="tenant-1",
|
||||
workflow_id="wf-1",
|
||||
workflow_run_id="run-1",
|
||||
workflow_run_inputs={},
|
||||
workflow_run_outputs={},
|
||||
workflow_run_status="succeeded",
|
||||
workflow_run_version="1.0",
|
||||
workflow_run_elapsed_time=1.0,
|
||||
total_tokens=100,
|
||||
file_list=[],
|
||||
query="hi",
|
||||
message_id="msg-1",
|
||||
conversation_id="conv-1",
|
||||
start_time=_dt(),
|
||||
end_time=_dt(),
|
||||
trace_id="trace-1",
|
||||
metadata={"app_id": "app-1"},
|
||||
workflow_app_log_id="log-1",
|
||||
error="",
|
||||
workflow_data=workflow_data,
|
||||
)
|
||||
|
||||
class BadDict(collections.UserDict):
|
||||
def get(self, key, default=None):
|
||||
if key == "usage":
|
||||
raise Exception("Usage extraction failed")
|
||||
return super().get(key, default)
|
||||
|
||||
node_llm = MagicMock()
|
||||
node_llm.id = "node-llm"
|
||||
node_llm.title = "LLM Node"
|
||||
node_llm.node_type = NodeType.LLM
|
||||
node_llm.status = "succeeded"
|
||||
node_llm.process_data = BadDict({"model_mode": "chat", "model_name": "gpt-4", "usage": True, "prompts": ["p"]})
|
||||
node_llm.inputs = {}
|
||||
node_llm.outputs = {}
|
||||
node_llm.created_at = _dt()
|
||||
node_llm.elapsed_time = 0.5
|
||||
node_llm.metadata = {}
|
||||
|
||||
repo = MagicMock()
|
||||
repo.get_by_workflow_run.return_value = [node_llm]
|
||||
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.create_workflow_node_execution_repository.return_value = repo
|
||||
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
|
||||
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: MagicMock())
|
||||
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine"))
|
||||
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
|
||||
|
||||
trace_instance.add_run = MagicMock()
|
||||
|
||||
import logging
|
||||
|
||||
with caplog.at_level(logging.ERROR):
|
||||
trace_instance.workflow_trace(trace_info)
|
||||
|
||||
assert "Failed to extract usage" in caplog.text
|
||||
|
||||
|
||||
def test_api_check_success(trace_instance):
|
||||
assert trace_instance.api_check() is True
|
||||
assert trace_instance.langsmith_client.create_project.called
|
||||
assert trace_instance.langsmith_client.delete_project.called
|
||||
|
||||
|
||||
def test_api_check_error(trace_instance):
|
||||
trace_instance.langsmith_client.create_project.side_effect = Exception("error")
|
||||
with pytest.raises(ValueError, match="LangSmith API check failed: error"):
|
||||
trace_instance.api_check()
|
||||
|
||||
|
||||
def test_get_project_url_success(trace_instance):
|
||||
trace_instance.langsmith_client.get_run_url.return_value = "https://smith.langchain.com/o/org/p/proj/r/run"
|
||||
url = trace_instance.get_project_url()
|
||||
assert url == "https://smith.langchain.com/o/org/p/proj"
|
||||
|
||||
|
||||
def test_get_project_url_error(trace_instance):
|
||||
trace_instance.langsmith_client.get_run_url.side_effect = Exception("error")
|
||||
with pytest.raises(ValueError, match="LangSmith get run url failed: error"):
|
||||
trace_instance.get_project_url()
|
||||
1019
api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py
Normal file
1019
api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py
Normal file
File diff suppressed because it is too large
Load Diff
678
api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py
Normal file
678
api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py
Normal file
@ -0,0 +1,678 @@
|
||||
import collections
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.ops.entities.config_entity import OpikConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
TraceTaskName,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.opik_trace.opik_trace import OpikDataTrace, prepare_opik_uuid, wrap_dict, wrap_metadata
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey
|
||||
from models import EndUser
|
||||
from models.enums import MessageStatus
|
||||
|
||||
|
||||
def _dt() -> datetime:
|
||||
return datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def opik_config():
|
||||
return OpikConfig(
|
||||
project="test-project", workspace="test-workspace", url="https://cloud.opik.com/api/", api_key="api-key-123"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def trace_instance(opik_config, monkeypatch):
|
||||
mock_client = MagicMock()
|
||||
monkeypatch.setattr("core.ops.opik_trace.opik_trace.Opik", lambda **kwargs: mock_client)
|
||||
|
||||
instance = OpikDataTrace(opik_config)
|
||||
return instance
|
||||
|
||||
|
||||
def test_wrap_dict():
|
||||
assert wrap_dict("input", {"a": 1}) == {"a": 1}
|
||||
assert wrap_dict("input", "hello") == {"input": "hello"}
|
||||
|
||||
|
||||
def test_wrap_metadata():
|
||||
assert wrap_metadata({"a": 1}, b=2) == {"a": 1, "b": 2, "created_from": "dify"}
|
||||
|
||||
|
||||
def test_prepare_opik_uuid():
|
||||
# Test with valid datetime and uuid string
|
||||
dt = datetime(2024, 1, 1)
|
||||
uuid_str = "b3e8e918-472e-4b69-8051-12502c34fc07"
|
||||
result = prepare_opik_uuid(dt, uuid_str)
|
||||
assert result is not None
|
||||
# We won't test the exact uuid7 value but just that it returns a string id
|
||||
|
||||
# Test with None dt and uuid_str
|
||||
result = prepare_opik_uuid(None, None)
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_init(opik_config, monkeypatch):
|
||||
mock_opik = MagicMock()
|
||||
monkeypatch.setattr("core.ops.opik_trace.opik_trace.Opik", mock_opik)
|
||||
monkeypatch.setenv("FILES_URL", "http://test.url")
|
||||
|
||||
instance = OpikDataTrace(opik_config)
|
||||
|
||||
mock_opik.assert_called_once_with(
|
||||
project_name=opik_config.project,
|
||||
workspace=opik_config.workspace,
|
||||
host=opik_config.url,
|
||||
api_key=opik_config.api_key,
|
||||
)
|
||||
assert instance.file_base_url == "http://test.url"
|
||||
assert instance.project == opik_config.project
|
||||
|
||||
|
||||
def test_trace_dispatch(trace_instance, monkeypatch):
|
||||
methods = [
|
||||
"workflow_trace",
|
||||
"message_trace",
|
||||
"moderation_trace",
|
||||
"suggested_question_trace",
|
||||
"dataset_retrieval_trace",
|
||||
"tool_trace",
|
||||
"generate_name_trace",
|
||||
]
|
||||
mocks = {method: MagicMock() for method in methods}
|
||||
for method, m in mocks.items():
|
||||
monkeypatch.setattr(trace_instance, method, m)
|
||||
|
||||
# WorkflowTraceInfo
|
||||
info = MagicMock(spec=WorkflowTraceInfo)
|
||||
trace_instance.trace(info)
|
||||
mocks["workflow_trace"].assert_called_once_with(info)
|
||||
|
||||
# MessageTraceInfo
|
||||
info = MagicMock(spec=MessageTraceInfo)
|
||||
trace_instance.trace(info)
|
||||
mocks["message_trace"].assert_called_once_with(info)
|
||||
|
||||
# ModerationTraceInfo
|
||||
info = MagicMock(spec=ModerationTraceInfo)
|
||||
trace_instance.trace(info)
|
||||
mocks["moderation_trace"].assert_called_once_with(info)
|
||||
|
||||
# SuggestedQuestionTraceInfo
|
||||
info = MagicMock(spec=SuggestedQuestionTraceInfo)
|
||||
trace_instance.trace(info)
|
||||
mocks["suggested_question_trace"].assert_called_once_with(info)
|
||||
|
||||
# DatasetRetrievalTraceInfo
|
||||
info = MagicMock(spec=DatasetRetrievalTraceInfo)
|
||||
trace_instance.trace(info)
|
||||
mocks["dataset_retrieval_trace"].assert_called_once_with(info)
|
||||
|
||||
# ToolTraceInfo
|
||||
info = MagicMock(spec=ToolTraceInfo)
|
||||
trace_instance.trace(info)
|
||||
mocks["tool_trace"].assert_called_once_with(info)
|
||||
|
||||
# GenerateNameTraceInfo
|
||||
info = MagicMock(spec=GenerateNameTraceInfo)
|
||||
trace_instance.trace(info)
|
||||
mocks["generate_name_trace"].assert_called_once_with(info)
|
||||
|
||||
|
||||
def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
|
||||
# Define constants for better readability
|
||||
WORKFLOW_ID = "fb05c7cd-6cec-4add-8a84-df03a408b4ce"
|
||||
WORKFLOW_RUN_ID = "33c67568-7a8a-450e-8916-a5f135baeaef"
|
||||
MESSAGE_ID = "04ec3956-85f3-488a-8539-1017251dc8c6"
|
||||
CONVERSATION_ID = "d3d01066-23ae-4830-9ce4-eb5640b42a7e"
|
||||
TRACE_ID = "bf26d929-6f15-4c2f-9abc-761c217056f3"
|
||||
WORKFLOW_APP_LOG_ID = "ca0e018e-edd4-43fb-a05a-ea001ca8ef4b"
|
||||
LLM_NODE_ID = "80d7dfa8-08f4-4ab7-aa37-0ca7d27207e3"
|
||||
CODE_NODE_ID = "b9cd9a7b-c534-4aa9-b5da-efd454140900"
|
||||
|
||||
trace_info = WorkflowTraceInfo(
|
||||
workflow_id=WORKFLOW_ID,
|
||||
tenant_id="tenant-1",
|
||||
workflow_run_id=WORKFLOW_RUN_ID,
|
||||
workflow_run_elapsed_time=1.0,
|
||||
workflow_run_status="succeeded",
|
||||
workflow_run_inputs={"input": "hi"},
|
||||
workflow_run_outputs={"output": "hello"},
|
||||
workflow_run_version="1.0",
|
||||
message_id=MESSAGE_ID,
|
||||
conversation_id=CONVERSATION_ID,
|
||||
total_tokens=100,
|
||||
file_list=[],
|
||||
query="hi",
|
||||
start_time=_dt(),
|
||||
end_time=_dt() + timedelta(seconds=1),
|
||||
trace_id=TRACE_ID,
|
||||
metadata={"app_id": "app-1", "user_id": "user-1"},
|
||||
workflow_app_log_id=WORKFLOW_APP_LOG_ID,
|
||||
error="",
|
||||
)
|
||||
|
||||
mock_session = MagicMock()
|
||||
monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: mock_session)
|
||||
monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine"))
|
||||
|
||||
node_llm = MagicMock()
|
||||
node_llm.id = LLM_NODE_ID
|
||||
node_llm.title = "LLM Node"
|
||||
node_llm.node_type = NodeType.LLM
|
||||
node_llm.status = "succeeded"
|
||||
node_llm.process_data = {
|
||||
"model_mode": "chat",
|
||||
"model_name": "gpt-4",
|
||||
"model_provider": "openai",
|
||||
"usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
|
||||
}
|
||||
node_llm.inputs = {"prompts": "p"}
|
||||
node_llm.outputs = {"text": "t"}
|
||||
node_llm.created_at = _dt()
|
||||
node_llm.elapsed_time = 0.5
|
||||
node_llm.metadata = {"foo": "bar"}
|
||||
|
||||
node_other = MagicMock()
|
||||
node_other.id = CODE_NODE_ID
|
||||
node_other.title = "Other Node"
|
||||
node_other.node_type = NodeType.CODE
|
||||
node_other.status = "failed"
|
||||
node_other.process_data = None
|
||||
node_other.inputs = {"code": "print"}
|
||||
node_other.outputs = {"result": "ok"}
|
||||
node_other.created_at = None
|
||||
node_other.elapsed_time = 0.2
|
||||
node_other.metadata = {WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS.value: 10}
|
||||
|
||||
repo = MagicMock()
|
||||
repo.get_by_workflow_run.return_value = [node_llm, node_other]
|
||||
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.create_workflow_node_execution_repository.return_value = repo
|
||||
monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory)
|
||||
|
||||
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
|
||||
|
||||
trace_instance.add_trace = MagicMock()
|
||||
trace_instance.add_span = MagicMock()
|
||||
|
||||
trace_instance.workflow_trace(trace_info)
|
||||
|
||||
trace_instance.add_trace.assert_called_once()
|
||||
trace_data = trace_instance.add_trace.call_args[1].get("opik_trace_data", trace_instance.add_trace.call_args[0][0])
|
||||
assert trace_data["name"] == TraceTaskName.MESSAGE_TRACE
|
||||
assert "message" in trace_data["tags"]
|
||||
assert "workflow" in trace_data["tags"]
|
||||
|
||||
assert trace_instance.add_span.call_count >= 1
|
||||
|
||||
|
||||
def test_workflow_trace_no_message_id(trace_instance, monkeypatch):
|
||||
# Define constants for better readability
|
||||
WORKFLOW_ID = "f0708b36-b1d7-42b3-a876-1d01b7d8f1a3"
|
||||
WORKFLOW_RUN_ID = "d42ec285-c2fd-4248-8866-5c9386b101ac"
|
||||
CONVERSATION_ID = "88a17f2e-9436-4472-bab9-4b1601d5af3c"
|
||||
WORKFLOW_APP_LOG_ID = "41780d0d-ffba-4220-bc0c-401e4c89cdfb"
|
||||
|
||||
trace_info = WorkflowTraceInfo(
|
||||
workflow_id=WORKFLOW_ID,
|
||||
tenant_id="tenant-1",
|
||||
workflow_run_id=WORKFLOW_RUN_ID,
|
||||
workflow_run_elapsed_time=1.0,
|
||||
workflow_run_status="succeeded",
|
||||
workflow_run_inputs={},
|
||||
workflow_run_outputs={},
|
||||
workflow_run_version="1.0",
|
||||
total_tokens=0,
|
||||
file_list=[],
|
||||
query="",
|
||||
message_id=None,
|
||||
conversation_id=CONVERSATION_ID,
|
||||
start_time=_dt(),
|
||||
end_time=_dt(),
|
||||
trace_id=None,
|
||||
metadata={"app_id": "app-1"},
|
||||
workflow_app_log_id=WORKFLOW_APP_LOG_ID,
|
||||
error="",
|
||||
)
|
||||
|
||||
monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock())
|
||||
monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine"))
|
||||
repo = MagicMock()
|
||||
repo.get_by_workflow_run.return_value = []
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.create_workflow_node_execution_repository.return_value = repo
|
||||
monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory)
|
||||
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
|
||||
|
||||
trace_instance.add_trace = MagicMock()
|
||||
trace_instance.workflow_trace(trace_info)
|
||||
|
||||
trace_instance.add_trace.assert_called_once()
|
||||
|
||||
|
||||
def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
|
||||
trace_info = WorkflowTraceInfo(
|
||||
workflow_id="5745f1b8-f8e6-4859-8110-996acb6c8d6a",
|
||||
tenant_id="tenant-1",
|
||||
workflow_run_id="46f53304-1659-464b-bee5-116585f0bec8",
|
||||
workflow_run_elapsed_time=1.0,
|
||||
workflow_run_status="succeeded",
|
||||
workflow_run_inputs={},
|
||||
workflow_run_outputs={},
|
||||
workflow_run_version="1.0",
|
||||
total_tokens=0,
|
||||
file_list=[],
|
||||
query="",
|
||||
message_id=None,
|
||||
conversation_id="83f86b89-caef-4de8-a0f9-f164eddae1ea",
|
||||
start_time=_dt(),
|
||||
end_time=_dt(),
|
||||
metadata={},
|
||||
workflow_app_log_id="339760b2-4b94-4532-8c81-133a97e4680e",
|
||||
error="",
|
||||
)
|
||||
monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock())
|
||||
monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine"))
|
||||
|
||||
with pytest.raises(ValueError, match="No app_id found in trace_info metadata"):
|
||||
trace_instance.workflow_trace(trace_info)
|
||||
|
||||
|
||||
def test_message_trace_basic(trace_instance, monkeypatch):
|
||||
# Define constants for better readability
|
||||
MESSAGE_DATA_ID = "e3a26712-8cac-4a25-94a4-a3bff21ee3ab"
|
||||
CONVERSATION_ID = "9d3f3751-7521-4c19-9307-20e3cf6789a3"
|
||||
MESSAGE_TRACE_ID = "710ace2f-bca8-41be-858c-54da42742a77"
|
||||
OPIT_TRACE_ID = "f7dfd978-0d10-4549-8abf-00f2cbc49d2c"
|
||||
|
||||
message_data = MagicMock()
|
||||
message_data.id = MESSAGE_DATA_ID
|
||||
message_data.from_account_id = "acc-1"
|
||||
message_data.from_end_user_id = None
|
||||
message_data.provider_response_latency = 0.5
|
||||
message_data.conversation_id = CONVERSATION_ID
|
||||
message_data.total_price = 0.01
|
||||
message_data.model_id = "gpt-4"
|
||||
message_data.answer = "hello"
|
||||
message_data.status = MessageStatus.NORMAL
|
||||
message_data.error = None
|
||||
|
||||
trace_info = MessageTraceInfo(
|
||||
message_id=MESSAGE_TRACE_ID,
|
||||
message_data=message_data,
|
||||
inputs={"query": "hi"},
|
||||
outputs={"answer": "hello"},
|
||||
message_tokens=10,
|
||||
answer_tokens=20,
|
||||
total_tokens=30,
|
||||
start_time=_dt(),
|
||||
end_time=_dt() + timedelta(seconds=1),
|
||||
trace_id=OPIT_TRACE_ID,
|
||||
metadata={"foo": "bar"},
|
||||
conversation_mode="chat",
|
||||
conversation_model="gpt-4",
|
||||
file_list=[],
|
||||
error=None,
|
||||
message_file_data=MagicMock(url="test.png"),
|
||||
)
|
||||
|
||||
trace_instance.add_trace = MagicMock(return_value=MagicMock(id="trace_id_1"))
|
||||
trace_instance.add_span = MagicMock()
|
||||
|
||||
trace_instance.message_trace(trace_info)
|
||||
|
||||
trace_instance.add_trace.assert_called_once()
|
||||
trace_instance.add_span.assert_called_once()
|
||||
|
||||
|
||||
def test_message_trace_with_end_user(trace_instance, monkeypatch):
|
||||
message_data = MagicMock()
|
||||
message_data.id = "85411059-79fb-4deb-a76c-c2e215f1b97e"
|
||||
message_data.from_account_id = "acc-1"
|
||||
message_data.from_end_user_id = "end-user-1"
|
||||
message_data.conversation_id = "7d9f96d8-3be2-4e93-9c0e-922ff98dccc6"
|
||||
message_data.status = MessageStatus.NORMAL
|
||||
message_data.model_id = "gpt-4"
|
||||
message_data.error = ""
|
||||
message_data.answer = "hello"
|
||||
message_data.total_price = 0.0
|
||||
message_data.provider_response_latency = 0.1
|
||||
|
||||
trace_info = MessageTraceInfo(
|
||||
message_id="6bff35c7-33b7-4acb-ba21-44569a0327d0",
|
||||
message_data=message_data,
|
||||
inputs={},
|
||||
outputs={},
|
||||
message_tokens=0,
|
||||
answer_tokens=0,
|
||||
total_tokens=0,
|
||||
start_time=_dt(),
|
||||
end_time=_dt(),
|
||||
metadata={},
|
||||
conversation_mode="chat",
|
||||
conversation_model="gpt-4",
|
||||
file_list=["url1"],
|
||||
error=None,
|
||||
)
|
||||
|
||||
mock_end_user = MagicMock(spec=EndUser)
|
||||
mock_end_user.session_id = "session-id-123"
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.where.return_value.first.return_value = mock_end_user
|
||||
monkeypatch.setattr("core.ops.opik_trace.opik_trace.db.session.query", lambda model: mock_query)
|
||||
|
||||
trace_instance.add_trace = MagicMock(return_value=MagicMock(id="trace_id_2"))
|
||||
trace_instance.add_span = MagicMock()
|
||||
|
||||
trace_instance.message_trace(trace_info)
|
||||
|
||||
trace_data = trace_instance.add_trace.call_args[0][0]
|
||||
assert trace_data["metadata"]["user_id"] == "acc-1"
|
||||
assert trace_data["metadata"]["end_user_id"] == "session-id-123"
|
||||
|
||||
|
||||
def test_message_trace_none_data(trace_instance):
|
||||
trace_info = SimpleNamespace(message_data=None, file_list=[], message_file_data=None, metadata={})
|
||||
trace_instance.add_trace = MagicMock()
|
||||
trace_instance.message_trace(trace_info)
|
||||
trace_instance.add_trace.assert_not_called()
|
||||
|
||||
|
||||
def test_moderation_trace(trace_instance):
|
||||
message_data = MagicMock()
|
||||
message_data.created_at = _dt()
|
||||
message_data.updated_at = _dt()
|
||||
|
||||
trace_info = ModerationTraceInfo(
|
||||
message_id="489d0dfd-065c-4106-8f9c-daded296c92d",
|
||||
message_data=message_data,
|
||||
inputs={"q": "hi"},
|
||||
action="stop",
|
||||
flagged=True,
|
||||
preset_response="blocked",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
metadata={"foo": "bar"},
|
||||
trace_id="6f16cf18-9f4b-4955-8b6b-43cfa10978fc",
|
||||
query="hi",
|
||||
)
|
||||
|
||||
trace_instance.add_span = MagicMock()
|
||||
trace_instance.moderation_trace(trace_info)
|
||||
|
||||
trace_instance.add_span.assert_called_once()
|
||||
span_data = trace_instance.add_span.call_args[0][0]
|
||||
assert span_data["name"] == TraceTaskName.MODERATION_TRACE
|
||||
assert span_data["output"]["flagged"] is True
|
||||
|
||||
|
||||
def test_moderation_trace_none(trace_instance):
|
||||
trace_info = ModerationTraceInfo(
|
||||
message_id="cd732e4e-37f1-4c7e-8c64-820308bedcbf",
|
||||
message_data=None,
|
||||
inputs={},
|
||||
action="s",
|
||||
flagged=False,
|
||||
preset_response="",
|
||||
query="",
|
||||
metadata={},
|
||||
)
|
||||
trace_instance.add_span = MagicMock()
|
||||
trace_instance.moderation_trace(trace_info)
|
||||
trace_instance.add_span.assert_not_called()
|
||||
|
||||
|
||||
def test_suggested_question_trace(trace_instance):
|
||||
message_data = MagicMock()
|
||||
message_data.created_at = _dt()
|
||||
message_data.updated_at = _dt()
|
||||
|
||||
trace_info = SuggestedQuestionTraceInfo(
|
||||
message_id="7de55bda-a91d-477e-98ab-85c53c438469",
|
||||
message_data=message_data,
|
||||
inputs="hi",
|
||||
suggested_question=["q1"],
|
||||
total_tokens=10,
|
||||
level="info",
|
||||
start_time=_dt(),
|
||||
end_time=_dt(),
|
||||
metadata={},
|
||||
trace_id="a6687292-68c7-42ba-ae51-285579944d7b",
|
||||
)
|
||||
|
||||
trace_instance.add_span = MagicMock()
|
||||
trace_instance.suggested_question_trace(trace_info)
|
||||
|
||||
trace_instance.add_span.assert_called_once()
|
||||
span_data = trace_instance.add_span.call_args[0][0]
|
||||
assert span_data["name"] == TraceTaskName.SUGGESTED_QUESTION_TRACE
|
||||
|
||||
|
||||
def test_suggested_question_trace_none(trace_instance):
|
||||
trace_info = SuggestedQuestionTraceInfo(
|
||||
message_id="23696fc5-7e7f-46ec-bce8-1adc3c7f297d",
|
||||
message_data=None,
|
||||
inputs={},
|
||||
suggested_question=[],
|
||||
total_tokens=0,
|
||||
level="i",
|
||||
metadata={},
|
||||
)
|
||||
trace_instance.add_span = MagicMock()
|
||||
trace_instance.suggested_question_trace(trace_info)
|
||||
trace_instance.add_span.assert_not_called()
|
||||
|
||||
|
||||
def test_dataset_retrieval_trace(trace_instance):
|
||||
message_data = MagicMock()
|
||||
message_data.created_at = _dt()
|
||||
message_data.updated_at = _dt()
|
||||
|
||||
trace_info = DatasetRetrievalTraceInfo(
|
||||
message_id="3e1a819f-c391-4950-adfd-96f82e5419a1",
|
||||
message_data=message_data,
|
||||
inputs="query",
|
||||
documents=[{"id": "doc1"}],
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
metadata={},
|
||||
trace_id="41361000-e9be-4d11-b5e4-ab27ce0817d6",
|
||||
)
|
||||
|
||||
trace_instance.add_span = MagicMock()
|
||||
trace_instance.dataset_retrieval_trace(trace_info)
|
||||
|
||||
trace_instance.add_span.assert_called_once()
|
||||
span_data = trace_instance.add_span.call_args[0][0]
|
||||
assert span_data["name"] == TraceTaskName.DATASET_RETRIEVAL_TRACE
|
||||
|
||||
|
||||
def test_dataset_retrieval_trace_none(trace_instance):
|
||||
trace_info = DatasetRetrievalTraceInfo(
|
||||
message_id="35d6d44c-bccb-4e6e-8bd8-859257723ea8", message_data=None, inputs={}, documents=[], metadata={}
|
||||
)
|
||||
trace_instance.add_span = MagicMock()
|
||||
trace_instance.dataset_retrieval_trace(trace_info)
|
||||
trace_instance.add_span.assert_not_called()
|
||||
|
||||
|
||||
def test_tool_trace(trace_instance):
|
||||
trace_info = ToolTraceInfo(
|
||||
message_id="99db92c4-2254-496a-b5cc-18153315ce35",
|
||||
message_data=MagicMock(),
|
||||
inputs={},
|
||||
outputs={},
|
||||
tool_name="my_tool",
|
||||
tool_inputs={"a": 1},
|
||||
tool_outputs="result_string",
|
||||
time_cost=0.1,
|
||||
start_time=_dt(),
|
||||
end_time=_dt(),
|
||||
metadata={},
|
||||
trace_id="a15a5fcb-7ffd-4458-8330-208f4cb1f796",
|
||||
tool_config={},
|
||||
tool_parameters={},
|
||||
error="some error",
|
||||
)
|
||||
|
||||
trace_instance.add_span = MagicMock()
|
||||
trace_instance.tool_trace(trace_info)
|
||||
|
||||
trace_instance.add_span.assert_called_once()
|
||||
span_data = trace_instance.add_span.call_args[0][0]
|
||||
assert span_data["name"] == "my_tool"
|
||||
|
||||
|
||||
def test_generate_name_trace(trace_instance):
|
||||
trace_info = GenerateNameTraceInfo(
|
||||
inputs={"q": "hi"},
|
||||
outputs={"name": "new"},
|
||||
tenant_id="tenant-1",
|
||||
conversation_id="271fe28f-6b86-416b-8d6b-bbbbfa9db791",
|
||||
start_time=_dt(),
|
||||
end_time=_dt(),
|
||||
metadata={"921f010e-6878-4831-ae6b-271bf68c56fb": 1},
|
||||
)
|
||||
|
||||
trace_instance.add_trace = MagicMock(return_value=MagicMock(id="trace_id_3"))
|
||||
trace_instance.add_span = MagicMock()
|
||||
|
||||
trace_instance.generate_name_trace(trace_info)
|
||||
|
||||
trace_instance.add_trace.assert_called_once()
|
||||
trace_instance.add_span.assert_called_once()
|
||||
|
||||
trace_data = trace_instance.add_trace.call_args[0][0]
|
||||
assert trace_data["name"] == TraceTaskName.GENERATE_NAME_TRACE
|
||||
|
||||
span_data = trace_instance.add_span.call_args[0][0]
|
||||
assert span_data["trace_id"] == "trace_id_3"
|
||||
|
||||
|
||||
def test_add_trace_success(trace_instance):
|
||||
trace_data = {"id": "t1", "name": "trace"}
|
||||
trace_instance.opik_client.trace.return_value = MagicMock(id="t1")
|
||||
trace = trace_instance.add_trace(trace_data)
|
||||
trace_instance.opik_client.trace.assert_called_once()
|
||||
assert trace.id == "t1"
|
||||
|
||||
|
||||
def test_add_trace_error(trace_instance):
|
||||
trace_instance.opik_client.trace.side_effect = Exception("error")
|
||||
trace_data = {"id": "t1", "name": "trace"}
|
||||
with pytest.raises(ValueError, match="Opik Failed to create trace: error"):
|
||||
trace_instance.add_trace(trace_data)
|
||||
|
||||
|
||||
def test_add_span_success(trace_instance):
|
||||
span_data = {"id": "s1", "name": "span", "trace_id": "t1"}
|
||||
trace_instance.add_span(span_data)
|
||||
trace_instance.opik_client.span.assert_called_once()
|
||||
|
||||
|
||||
def test_add_span_error(trace_instance):
|
||||
trace_instance.opik_client.span.side_effect = Exception("error")
|
||||
span_data = {"id": "s1", "name": "span", "trace_id": "t1"}
|
||||
with pytest.raises(ValueError, match="Opik Failed to create span: error"):
|
||||
trace_instance.add_span(span_data)
|
||||
|
||||
|
||||
def test_api_check_success(trace_instance):
|
||||
trace_instance.opik_client.auth_check.return_value = True
|
||||
assert trace_instance.api_check() is True
|
||||
|
||||
|
||||
def test_api_check_error(trace_instance):
|
||||
trace_instance.opik_client.auth_check.side_effect = Exception("fail")
|
||||
with pytest.raises(ValueError, match="Opik API check failed: fail"):
|
||||
trace_instance.api_check()
|
||||
|
||||
|
||||
def test_get_project_url_success(trace_instance):
|
||||
trace_instance.opik_client.get_project_url.return_value = "http://project.url"
|
||||
assert trace_instance.get_project_url() == "http://project.url"
|
||||
trace_instance.opik_client.get_project_url.assert_called_once_with(project_name=trace_instance.project)
|
||||
|
||||
|
||||
def test_get_project_url_error(trace_instance):
|
||||
trace_instance.opik_client.get_project_url.side_effect = Exception("fail")
|
||||
with pytest.raises(ValueError, match="Opik get run url failed: fail"):
|
||||
trace_instance.get_project_url()
|
||||
|
||||
|
||||
def test_workflow_trace_usage_extraction_error_fixed(trace_instance, monkeypatch, caplog):
|
||||
trace_info = WorkflowTraceInfo(
|
||||
workflow_id="86a52565-4a6b-4a1b-9bfd-98e4595e70de",
|
||||
tenant_id="66e8e918-472e-4b69-8051-12502c34fc07",
|
||||
workflow_run_id="8403965c-3344-4d22-a8fe-d8d55cee64d9",
|
||||
workflow_run_elapsed_time=1.0,
|
||||
workflow_run_status="s",
|
||||
workflow_run_inputs={},
|
||||
workflow_run_outputs={},
|
||||
workflow_run_version="1",
|
||||
total_tokens=0,
|
||||
file_list=[],
|
||||
query="",
|
||||
message_id=None,
|
||||
conversation_id="7a02cb9d-6949-4c59-a89d-f25bbc881e0e",
|
||||
start_time=_dt(),
|
||||
end_time=_dt(),
|
||||
metadata={"app_id": "77e8e918-472e-4b69-8051-12502c34fc07"},
|
||||
workflow_app_log_id="82268424-e193-476c-a6db-f473388ee5fe",
|
||||
error="",
|
||||
)
|
||||
|
||||
node = MagicMock()
|
||||
node.id = "88e8e918-472e-4b69-8051-12502c34fc07"
|
||||
node.title = "LLM Node"
|
||||
node.node_type = NodeType.LLM
|
||||
node.status = "succeeded"
|
||||
|
||||
class BadDict(collections.UserDict):
|
||||
def get(self, key, default=None):
|
||||
if key == "usage":
|
||||
raise Exception("Usage extraction failed")
|
||||
return super().get(key, default)
|
||||
|
||||
node.process_data = BadDict({"model_mode": "chat", "model_name": "gpt-4", "usage": True, "prompts": ["p"]})
|
||||
node.created_at = _dt()
|
||||
node.elapsed_time = 0.1
|
||||
node.metadata = {}
|
||||
node.outputs = {}
|
||||
|
||||
repo = MagicMock()
|
||||
repo.get_by_workflow_run.return_value = [node]
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.create_workflow_node_execution_repository.return_value = repo
|
||||
monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory)
|
||||
monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock())
|
||||
monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine"))
|
||||
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
|
||||
|
||||
trace_instance.add_trace = MagicMock()
|
||||
trace_instance.add_span = MagicMock()
|
||||
|
||||
with caplog.at_level(logging.ERROR):
|
||||
trace_instance.workflow_trace(trace_info)
|
||||
|
||||
assert "Failed to extract usage" in caplog.text
|
||||
assert trace_instance.add_span.call_count >= 1
|
||||
# Verify that at least one of the spans is for the LLM Node
|
||||
span_names = [call.args[0]["name"] for call in trace_instance.add_span.call_args_list]
|
||||
assert "LLM Node" in span_names
|
||||
583
api/tests/unit_tests/core/ops/tencent_trace/test_client.py
Normal file
583
api/tests/unit_tests/core/ops/tencent_trace/test_client.py
Normal file
@ -0,0 +1,583 @@
|
||||
"""Tests for the TencentTraceClient helpers that drive tracing and metrics."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from opentelemetry.sdk.trace import Event
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
|
||||
from core.ops.tencent_trace import client as client_module
|
||||
from core.ops.tencent_trace.client import TencentTraceClient, _get_opentelemetry_sdk_version
|
||||
from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData
|
||||
|
||||
metric_reader_instances: list[DummyMetricReader] = []
|
||||
meter_provider_instances: list[DummyMeterProvider] = []
|
||||
|
||||
|
||||
class DummyHistogram:
|
||||
"""Placeholder histogram type used by the stubbed metric stack."""
|
||||
|
||||
|
||||
class AggregationTemporality:
|
||||
DELTA = "delta"
|
||||
|
||||
|
||||
class DummyMeter:
|
||||
def __init__(self) -> None:
|
||||
self.created: list[tuple[dict[str, object], MagicMock]] = []
|
||||
|
||||
def create_histogram(self, **kwargs: object) -> MagicMock:
|
||||
hist = MagicMock(name=f"hist-{kwargs.get('name')}")
|
||||
self.created.append((kwargs, hist))
|
||||
return hist
|
||||
|
||||
|
||||
class DummyMeterProvider:
|
||||
def __init__(self, resource: object, metric_readers: list[object]) -> None:
|
||||
self.resource = resource
|
||||
self.metric_readers = metric_readers
|
||||
self.meter = DummyMeter()
|
||||
self.shutdown = MagicMock(name="meter_provider_shutdown")
|
||||
meter_provider_instances.append(self)
|
||||
|
||||
def get_meter(self, name: str, version: str) -> DummyMeter:
|
||||
return self.meter
|
||||
|
||||
|
||||
class DummyMetricReader:
|
||||
def __init__(self, exporter: object, export_interval_millis: int) -> None:
|
||||
self.exporter = exporter
|
||||
self.export_interval_millis = export_interval_millis
|
||||
self.shutdown = MagicMock(name="metric_reader_shutdown")
|
||||
metric_reader_instances.append(self)
|
||||
|
||||
|
||||
class DummyGrpcMetricExporter:
|
||||
def __init__(self, **kwargs: object) -> None:
|
||||
self.kwargs = kwargs
|
||||
|
||||
|
||||
class DummyHttpMetricExporter:
|
||||
def __init__(self, **kwargs: object) -> None:
|
||||
self.kwargs = kwargs
|
||||
|
||||
|
||||
class DummyJsonMetricExporter:
|
||||
def __init__(self, **kwargs: object) -> None:
|
||||
self.kwargs = kwargs
|
||||
|
||||
|
||||
class DummyJsonMetricExporterNoTemporality:
|
||||
"""Exporter that rejects preferred_temporality to exercise fallback."""
|
||||
|
||||
def __init__(self, **kwargs: object) -> None:
|
||||
if "preferred_temporality" in kwargs:
|
||||
raise RuntimeError("unsupported preferred_temporality")
|
||||
self.kwargs = kwargs
|
||||
|
||||
|
||||
def _add_stub_modules(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Drop fake metric modules into sys.modules so the client imports resolve."""
|
||||
|
||||
metrics_module = types.ModuleType("opentelemetry.sdk.metrics")
|
||||
metrics_module.Histogram = DummyHistogram
|
||||
metrics_module.MeterProvider = DummyMeterProvider
|
||||
monkeypatch.setitem(sys.modules, "opentelemetry.sdk.metrics", metrics_module)
|
||||
|
||||
metrics_export_module = types.ModuleType("opentelemetry.sdk.metrics.export")
|
||||
metrics_export_module.AggregationTemporality = AggregationTemporality
|
||||
metrics_export_module.PeriodicExportingMetricReader = DummyMetricReader
|
||||
monkeypatch.setitem(sys.modules, "opentelemetry.sdk.metrics.export", metrics_export_module)
|
||||
|
||||
grpc_module = types.ModuleType("opentelemetry.exporter.otlp.proto.grpc.metric_exporter")
|
||||
grpc_module.OTLPMetricExporter = DummyGrpcMetricExporter
|
||||
monkeypatch.setitem(sys.modules, "opentelemetry.exporter.otlp.proto.grpc.metric_exporter", grpc_module)
|
||||
|
||||
http_module = types.ModuleType("opentelemetry.exporter.otlp.proto.http.metric_exporter")
|
||||
http_module.OTLPMetricExporter = DummyHttpMetricExporter
|
||||
monkeypatch.setitem(sys.modules, "opentelemetry.exporter.otlp.proto.http.metric_exporter", http_module)
|
||||
|
||||
http_json_module = types.ModuleType("opentelemetry.exporter.otlp.http.json.metric_exporter")
|
||||
http_json_module.OTLPMetricExporter = DummyJsonMetricExporter
|
||||
monkeypatch.setitem(sys.modules, "opentelemetry.exporter.otlp.http.json.metric_exporter", http_json_module)
|
||||
|
||||
legacy_json_module = types.ModuleType("opentelemetry.exporter.otlp.json.metric_exporter")
|
||||
legacy_json_module.OTLPMetricExporter = DummyJsonMetricExporter
|
||||
monkeypatch.setitem(sys.modules, "opentelemetry.exporter.otlp.json.metric_exporter", legacy_json_module)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def stub_metric_modules(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
metric_reader_instances.clear()
|
||||
meter_provider_instances.clear()
|
||||
_add_stub_modules(monkeypatch)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> dict[str, object]:
|
||||
span_exporter = MagicMock(name="span_exporter")
|
||||
monkeypatch.setattr(client_module, "OTLPSpanExporter", MagicMock(return_value=span_exporter))
|
||||
|
||||
span_processor = MagicMock(name="span_processor")
|
||||
monkeypatch.setattr(client_module, "BatchSpanProcessor", MagicMock(return_value=span_processor))
|
||||
|
||||
tracer = MagicMock(name="tracer")
|
||||
span = MagicMock(name="span")
|
||||
tracer.start_span.return_value = span
|
||||
|
||||
tracer_provider = MagicMock(name="tracer_provider")
|
||||
tracer_provider.get_tracer.return_value = tracer
|
||||
tracer_provider.shutdown = MagicMock(name="tracer_provider_shutdown")
|
||||
monkeypatch.setattr(client_module, "TracerProvider", MagicMock(return_value=tracer_provider))
|
||||
|
||||
resource = MagicMock(name="resource")
|
||||
monkeypatch.setattr(client_module, "Resource", MagicMock(return_value=resource))
|
||||
|
||||
logger_mock = MagicMock(name="tencent_logger")
|
||||
monkeypatch.setattr(client_module, "logger", logger_mock)
|
||||
|
||||
trace_api_stub = SimpleNamespace(
|
||||
set_span_in_context=MagicMock(name="set_span_in_context", return_value="trace-context"),
|
||||
NonRecordingSpan=MagicMock(name="non_recording_span", side_effect=lambda ctx: f"non-{ctx}"),
|
||||
)
|
||||
monkeypatch.setattr(client_module, "trace_api", trace_api_stub)
|
||||
|
||||
fake_config = SimpleNamespace(
|
||||
project=SimpleNamespace(version="test"),
|
||||
COMMIT_SHA="sha",
|
||||
DEPLOY_ENV="dev",
|
||||
EDITION="cloud",
|
||||
)
|
||||
monkeypatch.setattr(client_module, "dify_config", fake_config)
|
||||
|
||||
monkeypatch.setattr(client_module.socket, "gethostname", lambda: "fake-host")
|
||||
monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "")
|
||||
|
||||
return {
|
||||
"span_exporter": span_exporter,
|
||||
"span_processor": span_processor,
|
||||
"tracer": tracer,
|
||||
"span": span,
|
||||
"tracer_provider": tracer_provider,
|
||||
"logger": logger_mock,
|
||||
"trace_api": trace_api_stub,
|
||||
}
|
||||
|
||||
|
||||
def _build_client() -> TencentTraceClient:
|
||||
return TencentTraceClient(
|
||||
service_name="service",
|
||||
endpoint="https://trace.example.com:4317",
|
||||
token="token",
|
||||
)
|
||||
|
||||
|
||||
def test_get_opentelemetry_sdk_version_reads_install(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(client_module, "version", lambda pkg: "2.0.0")
|
||||
assert _get_opentelemetry_sdk_version() == "2.0.0"
|
||||
|
||||
|
||||
def test_get_opentelemetry_sdk_version_falls_back(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(client_module, "version", MagicMock(side_effect=RuntimeError("boom")))
|
||||
assert _get_opentelemetry_sdk_version() == "1.27.0"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("endpoint", "expected"),
|
||||
[
|
||||
(
|
||||
"https://example.com:9090",
|
||||
("example.com:9090", False, "example.com", 9090),
|
||||
),
|
||||
(
|
||||
"http://localhost",
|
||||
("localhost:4317", True, "localhost", 4317),
|
||||
),
|
||||
(
|
||||
"example.com:bad",
|
||||
("example.com:4317", False, "example.com", 4317),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_resolve_grpc_target_parsable_variants(endpoint: str, expected: tuple[str, bool, str, int]) -> None:
|
||||
assert TencentTraceClient._resolve_grpc_target(endpoint) == expected
|
||||
|
||||
|
||||
def test_resolve_grpc_target_handles_errors() -> None:
|
||||
assert TencentTraceClient._resolve_grpc_target(123) == ("localhost:4317", True, "localhost", 4317)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("method", "attr_name", "args"),
|
||||
[
|
||||
("record_llm_duration", "hist_llm_duration", (0.3, {"foo": object()})),
|
||||
("record_token_usage", "hist_token_usage", (5, "input", "chat", "gpt", "gpt", "addr", "dify")),
|
||||
("record_time_to_first_token", "hist_time_to_first_token", (0.4, "dify", "gpt")),
|
||||
("record_time_to_generate", "hist_time_to_generate", (0.6, "dify", "gpt")),
|
||||
("record_trace_duration", "hist_trace_duration", (1.0, {"meta": object()})),
|
||||
],
|
||||
)
|
||||
def test_record_methods_call_histograms(method: str, attr_name: str, args: tuple[object, ...]) -> None:
|
||||
client = _build_client()
|
||||
hist_mock = MagicMock(name=attr_name)
|
||||
setattr(client, attr_name, hist_mock)
|
||||
|
||||
getattr(client, method)(*args)
|
||||
hist_mock.record.assert_called_once()
|
||||
|
||||
|
||||
def test_record_methods_skip_when_histogram_missing() -> None:
|
||||
client = _build_client()
|
||||
client.hist_llm_duration = None
|
||||
client.record_llm_duration(0.1)
|
||||
|
||||
client.hist_token_usage = None
|
||||
client.record_token_usage(1, "go", "chat", "model", "model", "addr", "provider")
|
||||
|
||||
client.hist_time_to_first_token = None
|
||||
client.record_time_to_first_token(0.2, "prov", "model")
|
||||
|
||||
client.hist_time_to_generate = None
|
||||
client.record_time_to_generate(0.3, "prov", "model")
|
||||
|
||||
client.hist_trace_duration = None
|
||||
client.record_trace_duration(0.5)
|
||||
|
||||
|
||||
def test_record_llm_duration_handles_exceptions(patch_core_components: dict[str, object]) -> None:
|
||||
client = _build_client()
|
||||
client.hist_llm_duration = MagicMock(name="hist_llm_duration")
|
||||
client.hist_llm_duration.record.side_effect = RuntimeError("boom")
|
||||
|
||||
client.record_llm_duration(0.2)
|
||||
logger = patch_core_components["logger"]
|
||||
logger.debug.assert_called()
|
||||
|
||||
|
||||
def test_create_and_export_span_sets_attributes(patch_core_components: dict[str, object]) -> None:
|
||||
client = _build_client()
|
||||
span = patch_core_components["span"]
|
||||
span.get_span_context.return_value = "ctx"
|
||||
|
||||
data = SpanData(
|
||||
trace_id=1,
|
||||
parent_span_id=None,
|
||||
span_id=2,
|
||||
name="span",
|
||||
attributes={"key": "value"},
|
||||
events=[Event(name="evt", attributes={"k": "v"}, timestamp=123)],
|
||||
status=Status(StatusCode.OK),
|
||||
start_time=10,
|
||||
end_time=20,
|
||||
)
|
||||
|
||||
client._create_and_export_span(data)
|
||||
span.set_attributes.assert_called_once()
|
||||
span.add_event.assert_called_once()
|
||||
span.set_status.assert_called_once()
|
||||
span.end.assert_called_once_with(end_time=20)
|
||||
assert client.span_contexts[2] == "ctx"
|
||||
|
||||
|
||||
def test_create_and_export_span_uses_parent_context(patch_core_components: dict[str, object]) -> None:
|
||||
client = _build_client()
|
||||
client.span_contexts[10] = "existing"
|
||||
span = patch_core_components["span"]
|
||||
span.get_span_context.return_value = "child"
|
||||
|
||||
data = SpanData(
|
||||
trace_id=1,
|
||||
parent_span_id=10,
|
||||
span_id=11,
|
||||
name="span",
|
||||
attributes={},
|
||||
events=[],
|
||||
start_time=0,
|
||||
end_time=1,
|
||||
)
|
||||
|
||||
client._create_and_export_span(data)
|
||||
trace_api = patch_core_components["trace_api"]
|
||||
trace_api.NonRecordingSpan.assert_called_once_with("existing")
|
||||
trace_api.set_span_in_context.assert_called_once()
|
||||
|
||||
|
||||
def test_create_and_export_span_exception_logs_error(patch_core_components: dict[str, object]) -> None:
|
||||
client = _build_client()
|
||||
span = patch_core_components["span"]
|
||||
span.get_span_context.return_value = "ctx"
|
||||
client.tracer.start_span.side_effect = RuntimeError("boom")
|
||||
|
||||
client._create_and_export_span(
|
||||
SpanData(
|
||||
trace_id=1,
|
||||
parent_span_id=None,
|
||||
span_id=2,
|
||||
name="span",
|
||||
attributes={},
|
||||
events=[],
|
||||
start_time=0,
|
||||
end_time=1,
|
||||
)
|
||||
)
|
||||
logger = patch_core_components["logger"]
|
||||
logger.exception.assert_called_once()
|
||||
|
||||
|
||||
def test_api_check_connects_successfully(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
client = _build_client()
|
||||
|
||||
monkeypatch.setattr(
|
||||
TencentTraceClient,
|
||||
"_resolve_grpc_target",
|
||||
MagicMock(return_value=("host:123", False, "host", 123)),
|
||||
)
|
||||
|
||||
socket_mock = MagicMock()
|
||||
socket_instance = MagicMock()
|
||||
socket_instance.connect_ex.return_value = 0
|
||||
socket_mock.return_value = socket_instance
|
||||
monkeypatch.setattr(client_module.socket, "socket", socket_mock)
|
||||
|
||||
assert client.api_check()
|
||||
socket_instance.connect_ex.assert_called_once()
|
||||
|
||||
|
||||
def test_api_check_returns_false_and_handles_local(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
client = _build_client()
|
||||
|
||||
monkeypatch.setattr(
|
||||
TencentTraceClient,
|
||||
"_resolve_grpc_target",
|
||||
MagicMock(return_value=("host:123", False, "host", 123)),
|
||||
)
|
||||
|
||||
socket_mock = MagicMock()
|
||||
socket_instance = MagicMock()
|
||||
socket_instance.connect_ex.return_value = 1
|
||||
socket_mock.return_value = socket_instance
|
||||
monkeypatch.setattr(client_module.socket, "socket", socket_mock)
|
||||
|
||||
assert not client.api_check()
|
||||
|
||||
monkeypatch.setattr(
|
||||
TencentTraceClient,
|
||||
"_resolve_grpc_target",
|
||||
MagicMock(return_value=("localhost:4317", True, "localhost", 4317)),
|
||||
)
|
||||
socket_instance.connect_ex.return_value = 1
|
||||
assert client.api_check()
|
||||
|
||||
|
||||
def test_api_check_handles_exceptions(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
client = TencentTraceClient("svc", "https://localhost", "token")
|
||||
|
||||
monkeypatch.setattr(client_module.socket, "socket", MagicMock(side_effect=RuntimeError("boom")))
|
||||
assert client.api_check()
|
||||
|
||||
|
||||
def test_get_project_url() -> None:
|
||||
client = _build_client()
|
||||
assert client.get_project_url() == "https://console.cloud.tencent.com/apm"
|
||||
|
||||
|
||||
def test_shutdown_flushes_all_components(patch_core_components: dict[str, object]) -> None:
|
||||
client = _build_client()
|
||||
span_processor = patch_core_components["span_processor"]
|
||||
tracer_provider = patch_core_components["tracer_provider"]
|
||||
|
||||
client.shutdown()
|
||||
span_processor.force_flush.assert_called_once()
|
||||
span_processor.shutdown.assert_called_once()
|
||||
tracer_provider.shutdown.assert_called_once()
|
||||
|
||||
meter_provider = meter_provider_instances[-1]
|
||||
metric_reader = metric_reader_instances[-1]
|
||||
meter_provider.shutdown.assert_called_once()
|
||||
metric_reader.shutdown.assert_called_once()
|
||||
|
||||
|
||||
def test_shutdown_logs_when_meter_provider_fails(patch_core_components: dict[str, object]) -> None:
|
||||
client = _build_client()
|
||||
meter_provider = meter_provider_instances[-1]
|
||||
meter_provider.shutdown.side_effect = RuntimeError("boom")
|
||||
client.metric_reader.shutdown.side_effect = RuntimeError("boom")
|
||||
|
||||
client.shutdown()
|
||||
logger = patch_core_components["logger"]
|
||||
logger.debug.assert_any_call(
|
||||
"[Tencent APM] Error shutting down meter provider",
|
||||
exc_info=True,
|
||||
)
|
||||
logger.debug.assert_any_call(
|
||||
"[Tencent APM] Error shutting down metric reader",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
def test_metrics_initialization_failure_sets_histogram_attributes(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(DummyMeterProvider, "__init__", MagicMock(side_effect=RuntimeError("err")))
|
||||
client = _build_client()
|
||||
|
||||
assert client.meter is None
|
||||
assert client.meter_provider is None
|
||||
assert client.hist_llm_duration is None
|
||||
assert client.hist_token_usage is None
|
||||
assert client.hist_time_to_first_token is None
|
||||
assert client.hist_time_to_generate is None
|
||||
assert client.hist_trace_duration is None
|
||||
assert client.metric_reader is None
|
||||
|
||||
|
||||
def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_components: dict[str, object]) -> None:
|
||||
client = _build_client()
|
||||
monkeypatch.setattr(client, "_create_and_export_span", MagicMock(side_effect=RuntimeError("boom")))
|
||||
|
||||
client.add_span(
|
||||
SpanData(
|
||||
trace_id=1,
|
||||
parent_span_id=None,
|
||||
span_id=2,
|
||||
name="span",
|
||||
attributes={},
|
||||
events=[],
|
||||
start_time=0,
|
||||
end_time=1,
|
||||
)
|
||||
)
|
||||
|
||||
logger = patch_core_components["logger"]
|
||||
logger.exception.assert_called_once()
|
||||
|
||||
|
||||
def test_create_and_export_span_converts_attribute_types(patch_core_components: dict[str, object]) -> None:
|
||||
client = _build_client()
|
||||
span = patch_core_components["span"]
|
||||
span.get_span_context.return_value = "ctx"
|
||||
|
||||
data = SpanData.model_construct(
|
||||
trace_id=1,
|
||||
parent_span_id=None,
|
||||
span_id=2,
|
||||
name="span",
|
||||
attributes={"num": 5, "flag": True, "pi": 3.14, "text": "value"},
|
||||
events=[],
|
||||
links=[],
|
||||
status=Status(StatusCode.OK),
|
||||
start_time=0,
|
||||
end_time=1,
|
||||
)
|
||||
|
||||
client._create_and_export_span(data)
|
||||
(attrs,) = span.set_attributes.call_args.args
|
||||
assert attrs["num"] == 5
|
||||
assert attrs["flag"] is True
|
||||
assert attrs["pi"] == 3.14
|
||||
assert attrs["text"] == "value"
|
||||
|
||||
|
||||
def test_record_llm_duration_converts_attributes() -> None:
|
||||
client = _build_client()
|
||||
hist_mock = MagicMock(name="hist_llm_duration")
|
||||
client.hist_llm_duration = hist_mock
|
||||
|
||||
client.record_llm_duration(0.3, {"foo": object(), "bar": 2})
|
||||
_, attrs = hist_mock.record.call_args.args
|
||||
assert isinstance(attrs["foo"], str)
|
||||
assert attrs["bar"] == 2
|
||||
|
||||
|
||||
def test_record_trace_duration_converts_attributes() -> None:
|
||||
client = _build_client()
|
||||
hist_mock = MagicMock(name="hist_trace_duration")
|
||||
client.hist_trace_duration = hist_mock
|
||||
|
||||
client.record_trace_duration(1.0, {"meta": object(), "ok": True})
|
||||
_, attrs = hist_mock.record.call_args.args
|
||||
assert isinstance(attrs["meta"], str)
|
||||
assert attrs["ok"] is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("method", "attr_name", "args"),
|
||||
[
|
||||
("record_token_usage", "hist_token_usage", (5, "input", "chat", "gpt", "gpt", "addr", "dify")),
|
||||
("record_time_to_first_token", "hist_time_to_first_token", (0.4, "dify", "gpt")),
|
||||
("record_time_to_generate", "hist_time_to_generate", (0.6, "dify", "gpt")),
|
||||
("record_trace_duration", "hist_trace_duration", (1.0, {"meta": object()})),
|
||||
],
|
||||
)
|
||||
def test_record_methods_handle_exceptions(
|
||||
method: str, attr_name: str, args: tuple[object, ...], patch_core_components: dict[str, object]
|
||||
) -> None:
|
||||
client = _build_client()
|
||||
hist_mock = MagicMock(name=attr_name)
|
||||
hist_mock.record.side_effect = RuntimeError("boom")
|
||||
setattr(client, attr_name, hist_mock)
|
||||
|
||||
getattr(client, method)(*args)
|
||||
logger = patch_core_components["logger"]
|
||||
logger.debug.assert_called()
|
||||
|
||||
|
||||
def test_metrics_initializes_grpc_metric_exporter() -> None:
|
||||
client = _build_client()
|
||||
metric_reader = metric_reader_instances[-1]
|
||||
|
||||
assert isinstance(metric_reader.exporter, DummyGrpcMetricExporter)
|
||||
assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
|
||||
assert metric_reader.exporter.kwargs["endpoint"] == "trace.example.com:4317"
|
||||
assert metric_reader.exporter.kwargs["insecure"] is False
|
||||
assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
|
||||
|
||||
|
||||
def test_metrics_initializes_http_protobuf_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/protobuf")
|
||||
client = _build_client()
|
||||
metric_reader = metric_reader_instances[-1]
|
||||
|
||||
assert isinstance(metric_reader.exporter, DummyHttpMetricExporter)
|
||||
assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
|
||||
assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint
|
||||
assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
|
||||
|
||||
|
||||
def test_metrics_initializes_http_json_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/json")
|
||||
client = _build_client()
|
||||
metric_reader = metric_reader_instances[-1]
|
||||
|
||||
assert isinstance(metric_reader.exporter, DummyJsonMetricExporter)
|
||||
assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
|
||||
assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint
|
||||
assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
|
||||
assert "preferred_temporality" in metric_reader.exporter.kwargs
|
||||
|
||||
|
||||
def test_metrics_http_json_metric_exporter_falls_back_without_temporality(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/json")
|
||||
exporter_module = sys.modules["opentelemetry.exporter.otlp.http.json.metric_exporter"]
|
||||
monkeypatch.setattr(exporter_module, "OTLPMetricExporter", DummyJsonMetricExporterNoTemporality)
|
||||
_ = _build_client()
|
||||
metric_reader = metric_reader_instances[-1]
|
||||
|
||||
assert isinstance(metric_reader.exporter, DummyJsonMetricExporterNoTemporality)
|
||||
assert "preferred_temporality" not in metric_reader.exporter.kwargs
|
||||
|
||||
|
||||
def test_metrics_http_json_uses_http_fallback_when_no_json_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/json")
|
||||
|
||||
def _fail_import(mod_path: str) -> types.ModuleType:
|
||||
raise ModuleNotFoundError(mod_path)
|
||||
|
||||
monkeypatch.setattr(client_module.importlib, "import_module", _fail_import)
|
||||
|
||||
_ = _build_client()
|
||||
metric_reader = metric_reader_instances[-1]
|
||||
assert isinstance(metric_reader.exporter, DummyHttpMetricExporter)
|
||||
359
api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py
Normal file
359
api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py
Normal file
@ -0,0 +1,359 @@
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from opentelemetry.trace import StatusCode
|
||||
|
||||
from core.ops.entities.trace_entity import (
|
||||
DatasetRetrievalTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ToolTraceInfo,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.tencent_trace.entities.semconv import (
|
||||
GEN_AI_IS_ENTRY,
|
||||
GEN_AI_IS_STREAMING_REQUEST,
|
||||
GEN_AI_MODEL_NAME,
|
||||
GEN_AI_SPAN_KIND,
|
||||
GEN_AI_USAGE_INPUT_TOKENS,
|
||||
INPUT_VALUE,
|
||||
RETRIEVAL_DOCUMENT,
|
||||
RETRIEVAL_QUERY,
|
||||
TOOL_DESCRIPTION,
|
||||
TOOL_NAME,
|
||||
TOOL_PARAMETERS,
|
||||
GenAISpanKind,
|
||||
)
|
||||
from core.ops.tencent_trace.span_builder import TencentSpanBuilder
|
||||
from core.rag.models.document import Document
|
||||
from dify_graph.entities import WorkflowNodeExecution
|
||||
from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class TestTencentSpanBuilder:
|
||||
def test_get_time_nanoseconds(self):
|
||||
with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_datetime_to_nanoseconds") as mock_convert:
|
||||
mock_convert.return_value = 123456789
|
||||
dt = datetime.now()
|
||||
result = TencentSpanBuilder._get_time_nanoseconds(dt)
|
||||
assert result == 123456789
|
||||
mock_convert.assert_called_once_with(dt)
|
||||
|
||||
def test_build_workflow_spans(self):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
trace_info.workflow_run_id = "run_id"
|
||||
trace_info.error = None
|
||||
trace_info.start_time = datetime.now()
|
||||
trace_info.end_time = datetime.now()
|
||||
trace_info.workflow_run_inputs = {"sys.query": "hello"}
|
||||
trace_info.workflow_run_outputs = {"answer": "world"}
|
||||
trace_info.metadata = {"conversation_id": "conv_id"}
|
||||
|
||||
with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
|
||||
mock_convert_id.side_effect = [1, 2] # workflow_span_id, message_span_id
|
||||
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
|
||||
spans = TencentSpanBuilder.build_workflow_spans(trace_info, 123, "user_1")
|
||||
|
||||
assert len(spans) == 2
|
||||
assert spans[0].name == "message"
|
||||
assert spans[0].span_id == 2
|
||||
assert spans[1].name == "workflow"
|
||||
assert spans[1].span_id == 1
|
||||
assert spans[1].parent_span_id == 2
|
||||
|
||||
def test_build_workflow_spans_no_message(self):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
trace_info.workflow_run_id = "run_id"
|
||||
trace_info.error = "some error"
|
||||
trace_info.start_time = datetime.now()
|
||||
trace_info.end_time = datetime.now()
|
||||
trace_info.workflow_run_inputs = {}
|
||||
trace_info.workflow_run_outputs = {}
|
||||
trace_info.metadata = {} # No conversation_id
|
||||
|
||||
with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
|
||||
mock_convert_id.return_value = 1
|
||||
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
|
||||
spans = TencentSpanBuilder.build_workflow_spans(trace_info, 123, "user_1")
|
||||
|
||||
assert len(spans) == 1
|
||||
assert spans[0].name == "workflow"
|
||||
assert spans[0].status.status_code == StatusCode.ERROR
|
||||
assert spans[0].status.description == "some error"
|
||||
assert spans[0].attributes[GEN_AI_IS_ENTRY] == "true"
|
||||
|
||||
def test_build_workflow_llm_span(self):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
trace_info.metadata = {"conversation_id": "conv_id"}
|
||||
|
||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
node_execution.id = "node_id"
|
||||
node_execution.created_at = datetime.now()
|
||||
node_execution.finished_at = datetime.now()
|
||||
node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
node_execution.process_data = {
|
||||
"model_name": "gpt-4",
|
||||
"model_provider": "openai",
|
||||
"usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30, "time_to_first_token": 0.5},
|
||||
"prompts": ["hello"],
|
||||
}
|
||||
node_execution.outputs = {"text": "world"}
|
||||
|
||||
with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
|
||||
mock_convert_id.return_value = 456
|
||||
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
|
||||
span = TencentSpanBuilder.build_workflow_llm_span(123, 1, trace_info, node_execution)
|
||||
|
||||
assert span.name == "GENERATION"
|
||||
assert span.attributes[GEN_AI_MODEL_NAME] == "gpt-4"
|
||||
assert span.attributes[GEN_AI_IS_STREAMING_REQUEST] == "true"
|
||||
assert span.attributes[GEN_AI_USAGE_INPUT_TOKENS] == "10"
|
||||
|
||||
def test_build_workflow_llm_span_usage_in_outputs(self):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
trace_info.metadata = {}
|
||||
|
||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
node_execution.id = "node_id"
|
||||
node_execution.created_at = datetime.now()
|
||||
node_execution.finished_at = datetime.now()
|
||||
node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
node_execution.process_data = {}
|
||||
node_execution.outputs = {
|
||||
"text": "world",
|
||||
"usage": {"prompt_tokens": 15, "completion_tokens": 25, "total_tokens": 40},
|
||||
}
|
||||
|
||||
with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
|
||||
mock_convert_id.return_value = 456
|
||||
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
|
||||
span = TencentSpanBuilder.build_workflow_llm_span(123, 1, trace_info, node_execution)
|
||||
|
||||
assert span.attributes[GEN_AI_USAGE_INPUT_TOKENS] == "15"
|
||||
assert GEN_AI_IS_STREAMING_REQUEST not in span.attributes
|
||||
|
||||
def test_build_message_span_standalone(self):
|
||||
trace_info = MagicMock(spec=MessageTraceInfo)
|
||||
trace_info.message_id = "msg_id"
|
||||
trace_info.error = None
|
||||
trace_info.start_time = datetime.now()
|
||||
trace_info.end_time = datetime.now()
|
||||
trace_info.inputs = {"q": "hi"}
|
||||
trace_info.outputs = "hello"
|
||||
trace_info.metadata = {"conversation_id": "conv_id"}
|
||||
trace_info.is_streaming_request = True
|
||||
|
||||
with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
|
||||
mock_convert_id.return_value = 789
|
||||
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
|
||||
span = TencentSpanBuilder.build_message_span(trace_info, 123, "user_1")
|
||||
|
||||
assert span.name == "message"
|
||||
assert span.attributes[GEN_AI_IS_STREAMING_REQUEST] == "true"
|
||||
assert span.attributes[INPUT_VALUE] == str(trace_info.inputs)
|
||||
|
||||
def test_build_message_span_standalone_with_error(self):
|
||||
trace_info = MagicMock(spec=MessageTraceInfo)
|
||||
trace_info.message_id = "msg_id"
|
||||
trace_info.error = "some error"
|
||||
trace_info.start_time = datetime.now()
|
||||
trace_info.end_time = datetime.now()
|
||||
trace_info.inputs = None
|
||||
trace_info.outputs = None
|
||||
trace_info.metadata = {}
|
||||
trace_info.is_streaming_request = False
|
||||
|
||||
with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
|
||||
mock_convert_id.return_value = 789
|
||||
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
|
||||
span = TencentSpanBuilder.build_message_span(trace_info, 123, "user_1")
|
||||
|
||||
assert span.status.status_code == StatusCode.ERROR
|
||||
assert span.status.description == "some error"
|
||||
assert span.attributes[INPUT_VALUE] == ""
|
||||
|
||||
def test_build_tool_span(self):
|
||||
trace_info = MagicMock(spec=ToolTraceInfo)
|
||||
trace_info.message_id = "msg_id"
|
||||
trace_info.tool_name = "search"
|
||||
trace_info.error = "tool error"
|
||||
trace_info.start_time = datetime.now()
|
||||
trace_info.end_time = datetime.now()
|
||||
trace_info.tool_parameters = {"p": 1}
|
||||
trace_info.tool_inputs = {"i": 2}
|
||||
trace_info.tool_outputs = "result"
|
||||
|
||||
with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
|
||||
mock_convert_id.return_value = 101
|
||||
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
|
||||
span = TencentSpanBuilder.build_tool_span(trace_info, 123, 1)
|
||||
|
||||
assert span.name == "search"
|
||||
assert span.status.status_code == StatusCode.ERROR
|
||||
assert span.attributes[TOOL_NAME] == "search"
|
||||
|
||||
def test_build_retrieval_span(self):
|
||||
trace_info = MagicMock(spec=DatasetRetrievalTraceInfo)
|
||||
trace_info.message_id = "msg_id"
|
||||
trace_info.inputs = "query"
|
||||
trace_info.error = None
|
||||
trace_info.start_time = datetime.now()
|
||||
trace_info.end_time = datetime.now()
|
||||
|
||||
doc = Document(
|
||||
page_content="content", metadata={"dataset_id": "d1", "doc_id": "di1", "document_id": "du1", "score": 0.9}
|
||||
)
|
||||
trace_info.documents = [doc]
|
||||
|
||||
with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
|
||||
mock_convert_id.return_value = 202
|
||||
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
|
||||
span = TencentSpanBuilder.build_retrieval_span(trace_info, 123, 1)
|
||||
|
||||
assert span.name == "retrieval"
|
||||
assert span.attributes[RETRIEVAL_QUERY] == "query"
|
||||
assert "content" in span.attributes[RETRIEVAL_DOCUMENT]
|
||||
|
||||
def test_build_retrieval_span_with_error(self):
|
||||
trace_info = MagicMock(spec=DatasetRetrievalTraceInfo)
|
||||
trace_info.message_id = "msg_id"
|
||||
trace_info.inputs = ""
|
||||
trace_info.error = "retrieval failed"
|
||||
trace_info.start_time = datetime.now()
|
||||
trace_info.end_time = datetime.now()
|
||||
trace_info.documents = []
|
||||
|
||||
with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
|
||||
mock_convert_id.return_value = 202
|
||||
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
|
||||
span = TencentSpanBuilder.build_retrieval_span(trace_info, 123, 1)
|
||||
|
||||
assert span.status.status_code == StatusCode.ERROR
|
||||
assert span.status.description == "retrieval failed"
|
||||
|
||||
def test_get_workflow_node_status(self):
|
||||
node = MagicMock(spec=WorkflowNodeExecution)
|
||||
|
||||
node.status = WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert TencentSpanBuilder._get_workflow_node_status(node).status_code == StatusCode.OK
|
||||
|
||||
node.status = WorkflowNodeExecutionStatus.FAILED
|
||||
node.error = "fail"
|
||||
status = TencentSpanBuilder._get_workflow_node_status(node)
|
||||
assert status.status_code == StatusCode.ERROR
|
||||
assert status.description == "fail"
|
||||
|
||||
node.status = WorkflowNodeExecutionStatus.EXCEPTION
|
||||
node.error = "exc"
|
||||
status = TencentSpanBuilder._get_workflow_node_status(node)
|
||||
assert status.status_code == StatusCode.ERROR
|
||||
assert status.description == "exc"
|
||||
|
||||
node.status = WorkflowNodeExecutionStatus.RUNNING
|
||||
assert TencentSpanBuilder._get_workflow_node_status(node).status_code == StatusCode.UNSET
|
||||
|
||||
def test_build_workflow_retrieval_span(self):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
trace_info.metadata = {"conversation_id": "conv_id"}
|
||||
|
||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
node_execution.id = "node_id"
|
||||
node_execution.title = "my retrieval"
|
||||
node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
node_execution.inputs = {"query": "q1"}
|
||||
node_execution.outputs = {"result": [{"content": "c1"}]}
|
||||
node_execution.created_at = datetime.now()
|
||||
node_execution.finished_at = datetime.now()
|
||||
|
||||
with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
|
||||
mock_convert_id.return_value = 303
|
||||
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
|
||||
span = TencentSpanBuilder.build_workflow_retrieval_span(123, 1, trace_info, node_execution)
|
||||
|
||||
assert span.name == "my retrieval"
|
||||
assert span.attributes[RETRIEVAL_QUERY] == "q1"
|
||||
assert "c1" in span.attributes[RETRIEVAL_DOCUMENT]
|
||||
|
||||
def test_build_workflow_retrieval_span_empty(self):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
trace_info.metadata = {}
|
||||
|
||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
node_execution.id = "node_id"
|
||||
node_execution.title = "my retrieval"
|
||||
node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
node_execution.inputs = {}
|
||||
node_execution.outputs = {}
|
||||
node_execution.created_at = datetime.now()
|
||||
node_execution.finished_at = datetime.now()
|
||||
|
||||
with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
|
||||
mock_convert_id.return_value = 303
|
||||
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
|
||||
span = TencentSpanBuilder.build_workflow_retrieval_span(123, 1, trace_info, node_execution)
|
||||
|
||||
assert span.attributes[RETRIEVAL_QUERY] == ""
|
||||
assert span.attributes[RETRIEVAL_DOCUMENT] == ""
|
||||
|
||||
def test_build_workflow_tool_span(self):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
|
||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
node_execution.id = "node_id"
|
||||
node_execution.title = "my tool"
|
||||
node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
node_execution.metadata = {WorkflowNodeExecutionMetadataKey.TOOL_INFO: {"info": "some"}}
|
||||
node_execution.inputs = {"param": "val"}
|
||||
node_execution.outputs = {"res": "ok"}
|
||||
node_execution.created_at = datetime.now()
|
||||
node_execution.finished_at = datetime.now()
|
||||
|
||||
with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
|
||||
mock_convert_id.return_value = 404
|
||||
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
|
||||
span = TencentSpanBuilder.build_workflow_tool_span(123, 1, trace_info, node_execution)
|
||||
|
||||
assert span.name == "my tool"
|
||||
assert span.attributes[TOOL_NAME] == "my tool"
|
||||
assert "some" in span.attributes[TOOL_DESCRIPTION]
|
||||
|
||||
def test_build_workflow_tool_span_no_metadata(self):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
|
||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
node_execution.id = "node_id"
|
||||
node_execution.title = "my tool"
|
||||
node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
node_execution.metadata = None
|
||||
node_execution.inputs = None
|
||||
node_execution.outputs = {"res": "ok"}
|
||||
node_execution.created_at = datetime.now()
|
||||
node_execution.finished_at = datetime.now()
|
||||
|
||||
with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
|
||||
mock_convert_id.return_value = 404
|
||||
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
|
||||
span = TencentSpanBuilder.build_workflow_tool_span(123, 1, trace_info, node_execution)
|
||||
|
||||
assert span.attributes[TOOL_DESCRIPTION] == "{}"
|
||||
assert span.attributes[TOOL_PARAMETERS] == "{}"
|
||||
|
||||
def test_build_workflow_task_span(self):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
trace_info.metadata = {"conversation_id": "conv_id"}
|
||||
|
||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
node_execution.id = "node_id"
|
||||
node_execution.title = "my task"
|
||||
node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
node_execution.inputs = {"in": 1}
|
||||
node_execution.outputs = {"out": 2}
|
||||
node_execution.created_at = datetime.now()
|
||||
node_execution.finished_at = datetime.now()
|
||||
|
||||
with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
|
||||
mock_convert_id.return_value = 505
|
||||
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
|
||||
span = TencentSpanBuilder.build_workflow_task_span(123, 1, trace_info, node_execution)
|
||||
|
||||
assert span.name == "my task"
|
||||
assert span.attributes[GEN_AI_SPAN_KIND] == GenAISpanKind.TASK.value
|
||||
@ -0,0 +1,647 @@
|
||||
import logging
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.ops.entities.config_entity import TencentConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.tencent_trace.tencent_trace import TencentDataTrace
|
||||
from dify_graph.entities import WorkflowNodeExecution
|
||||
from dify_graph.enums import NodeType
|
||||
from models import Account, App, TenantAccountJoin
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tencent_config():
|
||||
return TencentConfig(service_name="test-service", endpoint="https://test-endpoint", token="test-token")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_trace_client():
|
||||
with patch("core.ops.tencent_trace.tencent_trace.TencentTraceClient") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_span_builder():
|
||||
with patch("core.ops.tencent_trace.tencent_trace.TencentSpanBuilder") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_trace_utils():
|
||||
with patch("core.ops.tencent_trace.tencent_trace.TencentTraceUtils") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tencent_data_trace(tencent_config, mock_trace_client):
|
||||
return TencentDataTrace(tencent_config)
|
||||
|
||||
|
||||
class TestTencentDataTrace:
|
||||
def test_init(self, tencent_config, mock_trace_client):
|
||||
trace = TencentDataTrace(tencent_config)
|
||||
mock_trace_client.assert_called_once_with(
|
||||
service_name=tencent_config.service_name,
|
||||
endpoint=tencent_config.endpoint,
|
||||
token=tencent_config.token,
|
||||
metrics_export_interval_sec=5,
|
||||
)
|
||||
assert trace.trace_client == mock_trace_client.return_value
|
||||
|
||||
def test_trace_dispatch(self, tencent_data_trace):
|
||||
methods = [
|
||||
(
|
||||
WorkflowTraceInfo(
|
||||
workflow_id="wf",
|
||||
tenant_id="t",
|
||||
workflow_run_id="run",
|
||||
workflow_run_elapsed_time=1.0,
|
||||
workflow_run_status="s",
|
||||
workflow_run_inputs={},
|
||||
workflow_run_outputs={},
|
||||
workflow_run_version="v",
|
||||
total_tokens=0,
|
||||
file_list=[],
|
||||
query="",
|
||||
metadata={},
|
||||
),
|
||||
"workflow_trace",
|
||||
),
|
||||
(
|
||||
MessageTraceInfo(
|
||||
message_id="msg",
|
||||
message_data={},
|
||||
inputs={},
|
||||
outputs={},
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
conversation_mode="chat",
|
||||
conversation_model="gpt-3.5-turbo",
|
||||
message_tokens=0,
|
||||
answer_tokens=0,
|
||||
total_tokens=0,
|
||||
metadata={},
|
||||
),
|
||||
"message_trace",
|
||||
),
|
||||
(
|
||||
ModerationTraceInfo(
|
||||
flagged=False, action="a", preset_response="p", query="q", metadata={}, message_id="m"
|
||||
),
|
||||
None,
|
||||
), # Pass
|
||||
(
|
||||
SuggestedQuestionTraceInfo(
|
||||
suggested_question=[],
|
||||
level="l",
|
||||
total_tokens=0,
|
||||
metadata={},
|
||||
message_id="m",
|
||||
message_data={},
|
||||
inputs={},
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
),
|
||||
"suggested_question_trace",
|
||||
),
|
||||
(
|
||||
DatasetRetrievalTraceInfo(
|
||||
metadata={},
|
||||
message_id="m",
|
||||
message_data={},
|
||||
inputs={},
|
||||
documents=[],
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
),
|
||||
"dataset_retrieval_trace",
|
||||
),
|
||||
(
|
||||
ToolTraceInfo(
|
||||
tool_name="t",
|
||||
tool_inputs={},
|
||||
tool_outputs="",
|
||||
tool_config={},
|
||||
tool_parameters={},
|
||||
time_cost=0,
|
||||
metadata={},
|
||||
message_id="m",
|
||||
inputs={},
|
||||
outputs={},
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
),
|
||||
"tool_trace",
|
||||
),
|
||||
(
|
||||
GenerateNameTraceInfo(
|
||||
tenant_id="t", metadata={}, message_id="m", inputs={}, outputs={}, start_time=None, end_time=None
|
||||
),
|
||||
None,
|
||||
), # Pass
|
||||
]
|
||||
|
||||
for trace_info, method_name in methods:
|
||||
if method_name:
|
||||
with patch.object(tencent_data_trace, method_name) as mock_method:
|
||||
tencent_data_trace.trace(trace_info)
|
||||
mock_method.assert_called_once_with(trace_info)
|
||||
else:
|
||||
tencent_data_trace.trace(trace_info)
|
||||
|
||||
def test_api_check(self, tencent_data_trace):
|
||||
tencent_data_trace.trace_client.api_check.return_value = True
|
||||
assert tencent_data_trace.api_check() is True
|
||||
tencent_data_trace.trace_client.api_check.assert_called_once()
|
||||
|
||||
def test_get_project_url(self, tencent_data_trace):
|
||||
tencent_data_trace.trace_client.get_project_url.return_value = "http://url"
|
||||
assert tencent_data_trace.get_project_url() == "http://url"
|
||||
tencent_data_trace.trace_client.get_project_url.assert_called_once()
|
||||
|
||||
def test_workflow_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
trace_info.workflow_run_id = "run-id"
|
||||
trace_info.trace_id = "parent-trace-id"
|
||||
|
||||
mock_trace_utils.convert_to_trace_id.return_value = 123
|
||||
mock_trace_utils.create_link.return_value = "link"
|
||||
|
||||
with patch.object(tencent_data_trace, "_get_user_id", return_value="user-1"):
|
||||
with patch.object(tencent_data_trace, "_process_workflow_nodes") as mock_proc:
|
||||
with patch.object(tencent_data_trace, "_record_workflow_trace_duration") as mock_dur:
|
||||
mock_span_builder.build_workflow_spans.return_value = [MagicMock(), MagicMock()]
|
||||
|
||||
tencent_data_trace.workflow_trace(trace_info)
|
||||
|
||||
mock_trace_utils.convert_to_trace_id.assert_called_once_with("run-id")
|
||||
mock_trace_utils.create_link.assert_called_once_with("parent-trace-id")
|
||||
mock_span_builder.build_workflow_spans.assert_called_once()
|
||||
assert tencent_data_trace.trace_client.add_span.call_count == 2
|
||||
mock_proc.assert_called_once_with(trace_info, 123)
|
||||
mock_dur.assert_called_once_with(trace_info)
|
||||
|
||||
def test_workflow_trace_exception(self, tencent_data_trace):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
trace_info.workflow_run_id = "run-id"
|
||||
|
||||
with patch(
|
||||
"core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error")
|
||||
):
|
||||
with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
|
||||
tencent_data_trace.workflow_trace(trace_info)
|
||||
mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow trace")
|
||||
|
||||
def test_message_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder):
|
||||
trace_info = MagicMock(spec=MessageTraceInfo)
|
||||
trace_info.message_id = "msg-id"
|
||||
trace_info.trace_id = "parent-trace-id"
|
||||
|
||||
mock_trace_utils.convert_to_trace_id.return_value = 123
|
||||
mock_trace_utils.create_link.return_value = "link"
|
||||
|
||||
with patch.object(tencent_data_trace, "_get_user_id", return_value="user-1"):
|
||||
with patch.object(tencent_data_trace, "_record_message_llm_metrics") as mock_metrics:
|
||||
with patch.object(tencent_data_trace, "_record_message_trace_duration") as mock_dur:
|
||||
mock_span_builder.build_message_span.return_value = MagicMock()
|
||||
|
||||
tencent_data_trace.message_trace(trace_info)
|
||||
|
||||
mock_trace_utils.convert_to_trace_id.assert_called_once_with("msg-id")
|
||||
mock_trace_utils.create_link.assert_called_once_with("parent-trace-id")
|
||||
mock_span_builder.build_message_span.assert_called_once()
|
||||
tencent_data_trace.trace_client.add_span.assert_called_once()
|
||||
mock_metrics.assert_called_once_with(trace_info)
|
||||
mock_dur.assert_called_once_with(trace_info)
|
||||
|
||||
def test_message_trace_exception(self, tencent_data_trace):
|
||||
trace_info = MagicMock(spec=MessageTraceInfo)
|
||||
|
||||
with patch(
|
||||
"core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error")
|
||||
):
|
||||
with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
|
||||
tencent_data_trace.message_trace(trace_info)
|
||||
mock_log.assert_called_once_with("[Tencent APM] Failed to process message trace")
|
||||
|
||||
def test_tool_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder):
|
||||
trace_info = MagicMock(spec=ToolTraceInfo)
|
||||
trace_info.message_id = "msg-id"
|
||||
|
||||
mock_trace_utils.convert_to_span_id.return_value = 456
|
||||
mock_trace_utils.convert_to_trace_id.return_value = 123
|
||||
|
||||
tencent_data_trace.tool_trace(trace_info)
|
||||
|
||||
mock_trace_utils.convert_to_span_id.assert_called_once_with("msg-id", "message")
|
||||
mock_trace_utils.convert_to_trace_id.assert_called_once_with("msg-id")
|
||||
mock_span_builder.build_tool_span.assert_called_once_with(trace_info, 123, 456)
|
||||
tencent_data_trace.trace_client.add_span.assert_called_once()
|
||||
|
||||
def test_tool_trace_no_msg_id(self, tencent_data_trace):
|
||||
trace_info = MagicMock(spec=ToolTraceInfo)
|
||||
trace_info.message_id = None
|
||||
|
||||
tencent_data_trace.tool_trace(trace_info)
|
||||
tencent_data_trace.trace_client.add_span.assert_not_called()
|
||||
|
||||
def test_tool_trace_exception(self, tencent_data_trace):
|
||||
trace_info = MagicMock(spec=ToolTraceInfo)
|
||||
trace_info.message_id = "msg-id"
|
||||
|
||||
with patch(
|
||||
"core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error")
|
||||
):
|
||||
with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
|
||||
tencent_data_trace.tool_trace(trace_info)
|
||||
mock_log.assert_called_once_with("[Tencent APM] Failed to process tool trace")
|
||||
|
||||
def test_dataset_retrieval_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder):
|
||||
trace_info = MagicMock(spec=DatasetRetrievalTraceInfo)
|
||||
trace_info.message_id = "msg-id"
|
||||
|
||||
mock_trace_utils.convert_to_span_id.return_value = 456
|
||||
mock_trace_utils.convert_to_trace_id.return_value = 123
|
||||
|
||||
tencent_data_trace.dataset_retrieval_trace(trace_info)
|
||||
|
||||
mock_trace_utils.convert_to_span_id.assert_called_once_with("msg-id", "message")
|
||||
mock_trace_utils.convert_to_trace_id.assert_called_once_with("msg-id")
|
||||
mock_span_builder.build_retrieval_span.assert_called_once_with(trace_info, 123, 456)
|
||||
tencent_data_trace.trace_client.add_span.assert_called_once()
|
||||
|
||||
def test_dataset_retrieval_trace_no_msg_id(self, tencent_data_trace):
|
||||
trace_info = MagicMock(spec=DatasetRetrievalTraceInfo)
|
||||
trace_info.message_id = None
|
||||
|
||||
tencent_data_trace.dataset_retrieval_trace(trace_info)
|
||||
tencent_data_trace.trace_client.add_span.assert_not_called()
|
||||
|
||||
def test_dataset_retrieval_trace_exception(self, tencent_data_trace):
|
||||
trace_info = MagicMock(spec=DatasetRetrievalTraceInfo)
|
||||
trace_info.message_id = "msg-id"
|
||||
|
||||
with patch(
|
||||
"core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error")
|
||||
):
|
||||
with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
|
||||
tencent_data_trace.dataset_retrieval_trace(trace_info)
|
||||
mock_log.assert_called_once_with("[Tencent APM] Failed to process dataset retrieval trace")
|
||||
|
||||
def test_suggested_question_trace(self, tencent_data_trace):
|
||||
trace_info = MagicMock(spec=SuggestedQuestionTraceInfo)
|
||||
with patch("core.ops.tencent_trace.tencent_trace.logger.info") as mock_log:
|
||||
tencent_data_trace.suggested_question_trace(trace_info)
|
||||
mock_log.assert_called_once_with("[Tencent APM] Processing suggested question trace")
|
||||
|
||||
def test_suggested_question_trace_exception(self, tencent_data_trace):
|
||||
trace_info = MagicMock(spec=SuggestedQuestionTraceInfo)
|
||||
with patch("core.ops.tencent_trace.tencent_trace.logger.info", side_effect=Exception("error")):
|
||||
with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
|
||||
tencent_data_trace.suggested_question_trace(trace_info)
|
||||
mock_log.assert_called_once_with("[Tencent APM] Failed to process suggested question trace")
|
||||
|
||||
def test_process_workflow_nodes(self, tencent_data_trace, mock_trace_utils):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
trace_info.workflow_run_id = "run-id"
|
||||
mock_trace_utils.convert_to_span_id.return_value = 111
|
||||
|
||||
node1 = MagicMock(spec=WorkflowNodeExecution)
|
||||
node1.id = "n1"
|
||||
node1.node_type = NodeType.LLM
|
||||
node2 = MagicMock(spec=WorkflowNodeExecution)
|
||||
node2.id = "n2"
|
||||
node2.node_type = NodeType.TOOL
|
||||
|
||||
with patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node1, node2]):
|
||||
with patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=["span1", "span2"]):
|
||||
with patch.object(tencent_data_trace, "_record_llm_metrics") as mock_metrics:
|
||||
tencent_data_trace._process_workflow_nodes(trace_info, 123)
|
||||
|
||||
assert tencent_data_trace.trace_client.add_span.call_count == 2
|
||||
mock_metrics.assert_called_once_with(node1)
|
||||
|
||||
def test_process_workflow_nodes_node_exception(self, tencent_data_trace, mock_trace_utils):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
mock_trace_utils.convert_to_span_id.return_value = 111
|
||||
|
||||
node = MagicMock(spec=WorkflowNodeExecution)
|
||||
node.id = "n1"
|
||||
|
||||
with patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node]):
|
||||
with patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=Exception("node error")):
|
||||
with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
|
||||
tencent_data_trace._process_workflow_nodes(trace_info, 123)
|
||||
# The exception should be caught by the outer handler since convert_to_span_id is called first
|
||||
mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes")
|
||||
|
||||
def test_process_workflow_nodes_exception(self, tencent_data_trace, mock_trace_utils):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
mock_trace_utils.convert_to_span_id.side_effect = Exception("outer error")
|
||||
|
||||
with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
|
||||
tencent_data_trace._process_workflow_nodes(trace_info, 123)
|
||||
mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes")
|
||||
|
||||
def test_build_workflow_node_span(self, tencent_data_trace, mock_span_builder):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
|
||||
nodes = [
|
||||
(NodeType.LLM, mock_span_builder.build_workflow_llm_span),
|
||||
(NodeType.KNOWLEDGE_RETRIEVAL, mock_span_builder.build_workflow_retrieval_span),
|
||||
(NodeType.TOOL, mock_span_builder.build_workflow_tool_span),
|
||||
(NodeType.CODE, mock_span_builder.build_workflow_task_span),
|
||||
]
|
||||
|
||||
for node_type, builder_method in nodes:
|
||||
node = MagicMock(spec=WorkflowNodeExecution)
|
||||
node.node_type = node_type
|
||||
builder_method.return_value = "span"
|
||||
|
||||
result = tencent_data_trace._build_workflow_node_span(node, 123, trace_info, 456)
|
||||
|
||||
assert result == "span"
|
||||
builder_method.assert_called_once_with(123, 456, trace_info, node)
|
||||
|
||||
def test_build_workflow_node_span_exception(self, tencent_data_trace, mock_span_builder):
|
||||
node = MagicMock(spec=WorkflowNodeExecution)
|
||||
node.node_type = NodeType.LLM
|
||||
node.id = "n1"
|
||||
mock_span_builder.build_workflow_llm_span.side_effect = Exception("error")
|
||||
|
||||
with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log:
|
||||
result = tencent_data_trace._build_workflow_node_span(node, 123, MagicMock(), 456)
|
||||
assert result is None
|
||||
mock_log.assert_called_once()
|
||||
|
||||
def test_get_workflow_node_executions(self, tencent_data_trace):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
trace_info.metadata = {"app_id": "app-1"}
|
||||
trace_info.workflow_run_id = "run-1"
|
||||
|
||||
app = MagicMock(spec=App)
|
||||
app.id = "app-1"
|
||||
app.created_by = "user-1"
|
||||
|
||||
account = MagicMock(spec=Account)
|
||||
account.id = "user-1"
|
||||
|
||||
tenant_join = MagicMock(spec=TenantAccountJoin)
|
||||
tenant_join.tenant_id = "tenant-1"
|
||||
|
||||
mock_executions = [MagicMock()]
|
||||
|
||||
with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db:
|
||||
mock_db.engine = "engine"
|
||||
with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx:
|
||||
session = mock_session_ctx.return_value.__enter__.return_value
|
||||
session.scalar.side_effect = [app, account]
|
||||
session.query.return_value.filter_by.return_value.first.return_value = tenant_join
|
||||
|
||||
with patch(
|
||||
"core.ops.tencent_trace.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository"
|
||||
) as mock_repo:
|
||||
mock_repo.return_value.get_by_workflow_run.return_value = mock_executions
|
||||
|
||||
results = tencent_data_trace._get_workflow_node_executions(trace_info)
|
||||
|
||||
assert results == mock_executions
|
||||
account.set_tenant_id.assert_called_once_with("tenant-1")
|
||||
|
||||
def test_get_workflow_node_executions_no_app_id(self, tencent_data_trace):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
trace_info.metadata = {}
|
||||
|
||||
with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
|
||||
results = tencent_data_trace._get_workflow_node_executions(trace_info)
|
||||
assert results == []
|
||||
mock_log.assert_called_once()
|
||||
|
||||
def test_get_workflow_node_executions_app_not_found(self, tencent_data_trace):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
trace_info.metadata = {"app_id": "app-1"}
|
||||
|
||||
with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db:
|
||||
mock_db.init_app = MagicMock() # Ensure init_app is mocked
|
||||
mock_db.engine = "engine"
|
||||
with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx:
|
||||
session = mock_session_ctx.return_value.__enter__.return_value
|
||||
session.scalar.return_value = None
|
||||
|
||||
with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
|
||||
results = tencent_data_trace._get_workflow_node_executions(trace_info)
|
||||
assert results == []
|
||||
mock_log.assert_called_once()
|
||||
|
||||
def test_get_user_id_workflow(self, tencent_data_trace):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
trace_info.tenant_id = "tenant-1"
|
||||
trace_info.metadata = {"user_id": "user-1"}
|
||||
|
||||
with patch("core.ops.tencent_trace.tencent_trace.sessionmaker", side_effect=Exception("Database error")):
|
||||
with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db:
|
||||
mock_db.init_app = MagicMock()
|
||||
mock_db.engine = MagicMock()
|
||||
|
||||
user_id = tencent_data_trace._get_user_id(trace_info)
|
||||
assert user_id == "unknown"
|
||||
|
||||
def test_get_user_id_only_user_id(self, tencent_data_trace):
|
||||
trace_info = MagicMock(spec=MessageTraceInfo)
|
||||
trace_info.metadata = {"user_id": "user-1"}
|
||||
|
||||
user_id = tencent_data_trace._get_user_id(trace_info)
|
||||
assert user_id == "user-1"
|
||||
|
||||
def test_get_user_id_anonymous(self, tencent_data_trace):
|
||||
trace_info = MagicMock(spec=MessageTraceInfo)
|
||||
trace_info.metadata = {}
|
||||
|
||||
user_id = tencent_data_trace._get_user_id(trace_info)
|
||||
assert user_id == "anonymous"
|
||||
|
||||
def test_get_user_id_exception(self, tencent_data_trace):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
trace_info.tenant_id = "t"
|
||||
trace_info.metadata = {"user_id": "u"}
|
||||
|
||||
with patch("core.ops.tencent_trace.tencent_trace.sessionmaker", side_effect=Exception("error")):
|
||||
with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
|
||||
user_id = tencent_data_trace._get_user_id(trace_info)
|
||||
assert user_id == "unknown"
|
||||
mock_log.assert_called_once_with("[Tencent APM] Failed to get user ID")
|
||||
|
||||
def test_record_llm_metrics_usage_in_process_data(self, tencent_data_trace):
|
||||
node = MagicMock(spec=WorkflowNodeExecution)
|
||||
node.process_data = {
|
||||
"usage": {
|
||||
"latency": 2.5,
|
||||
"time_to_first_token": 0.5,
|
||||
"time_to_generate": 2.0,
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 20,
|
||||
},
|
||||
"model_provider": "openai",
|
||||
"model_name": "gpt-4",
|
||||
"model_mode": "chat",
|
||||
}
|
||||
node.outputs = {}
|
||||
|
||||
tencent_data_trace._record_llm_metrics(node)
|
||||
|
||||
tencent_data_trace.trace_client.record_llm_duration.assert_called_once()
|
||||
tencent_data_trace.trace_client.record_time_to_first_token.assert_called_once()
|
||||
tencent_data_trace.trace_client.record_time_to_generate.assert_called_once()
|
||||
assert tencent_data_trace.trace_client.record_token_usage.call_count == 2
|
||||
|
||||
def test_record_llm_metrics_usage_in_outputs(self, tencent_data_trace):
|
||||
node = MagicMock(spec=WorkflowNodeExecution)
|
||||
node.process_data = {}
|
||||
node.outputs = {"usage": {"latency": 1.0, "prompt_tokens": 5}}
|
||||
|
||||
tencent_data_trace._record_llm_metrics(node)
|
||||
tencent_data_trace.trace_client.record_llm_duration.assert_called_once()
|
||||
tencent_data_trace.trace_client.record_token_usage.assert_called_once()
|
||||
|
||||
def test_record_llm_metrics_exception(self, tencent_data_trace):
|
||||
node = MagicMock(spec=WorkflowNodeExecution)
|
||||
node.process_data = None
|
||||
node.outputs = None
|
||||
|
||||
with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log:
|
||||
tencent_data_trace._record_llm_metrics(node)
|
||||
# Should not crash
|
||||
|
||||
def test_record_message_llm_metrics(self, tencent_data_trace):
|
||||
trace_info = MagicMock(spec=MessageTraceInfo)
|
||||
trace_info.metadata = {"ls_provider": "openai", "ls_model_name": "gpt-4"}
|
||||
trace_info.message_data = {"provider_response_latency": 1.1}
|
||||
trace_info.is_streaming_request = True
|
||||
trace_info.gen_ai_server_time_to_first_token = 0.2
|
||||
trace_info.llm_streaming_time_to_generate = 0.9
|
||||
trace_info.message_tokens = 15
|
||||
trace_info.answer_tokens = 25
|
||||
|
||||
tencent_data_trace._record_message_llm_metrics(trace_info)
|
||||
|
||||
tencent_data_trace.trace_client.record_llm_duration.assert_called_once()
|
||||
tencent_data_trace.trace_client.record_time_to_first_token.assert_called_once()
|
||||
tencent_data_trace.trace_client.record_time_to_generate.assert_called_once()
|
||||
assert tencent_data_trace.trace_client.record_token_usage.call_count == 2
|
||||
|
||||
def test_record_message_llm_metrics_object_data(self, tencent_data_trace):
|
||||
trace_info = MagicMock(spec=MessageTraceInfo)
|
||||
trace_info.metadata = {}
|
||||
msg_data = MagicMock()
|
||||
msg_data.provider_response_latency = 1.1
|
||||
msg_data.model_provider = "anthropic"
|
||||
msg_data.model_id = "claude"
|
||||
trace_info.message_data = msg_data
|
||||
trace_info.is_streaming_request = False
|
||||
|
||||
tencent_data_trace._record_message_llm_metrics(trace_info)
|
||||
tencent_data_trace.trace_client.record_llm_duration.assert_called_once()
|
||||
|
||||
def test_record_message_llm_metrics_exception(self, tencent_data_trace):
|
||||
trace_info = MagicMock(spec=MessageTraceInfo)
|
||||
trace_info.metadata = None
|
||||
|
||||
with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log:
|
||||
tencent_data_trace._record_message_llm_metrics(trace_info)
|
||||
# Should not crash
|
||||
|
||||
def test_record_workflow_trace_duration(self, tencent_data_trace):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
now = datetime.now()
|
||||
trace_info.start_time = now
|
||||
trace_info.end_time = now + timedelta(seconds=3)
|
||||
trace_info.workflow_run_status = "succeeded"
|
||||
trace_info.conversation_id = "conv-1"
|
||||
|
||||
# Mock the record_trace_duration method to capture arguments
|
||||
with patch.object(tencent_data_trace.trace_client, "record_trace_duration") as mock_record:
|
||||
tencent_data_trace._record_workflow_trace_duration(trace_info)
|
||||
|
||||
# Assert the method was called once
|
||||
mock_record.assert_called_once()
|
||||
|
||||
# Extract arguments passed to the method
|
||||
args, kwargs = mock_record.call_args
|
||||
|
||||
# Validate the duration argument
|
||||
assert args[0] == 3.0
|
||||
|
||||
# Validate the attributes dict in kwargs
|
||||
attributes = kwargs["attributes"] if "attributes" in kwargs else args[1] if len(args) > 1 else {}
|
||||
assert attributes["conversation_mode"] == "workflow"
|
||||
assert attributes["has_conversation"] == "true"
|
||||
|
||||
def test_record_workflow_trace_duration_fallback(self, tencent_data_trace):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
trace_info.start_time = None
|
||||
trace_info.workflow_run_elapsed_time = 4.5
|
||||
trace_info.workflow_run_status = "failed"
|
||||
trace_info.conversation_id = None
|
||||
|
||||
with patch.object(tencent_data_trace.trace_client, "record_trace_duration") as mock_record:
|
||||
tencent_data_trace._record_workflow_trace_duration(trace_info)
|
||||
mock_record.assert_called_once()
|
||||
args, kwargs = mock_record.call_args
|
||||
assert args[0] == 4.5
|
||||
# Check attributes dict (either in kwargs or as second positional arg)
|
||||
attributes = kwargs["attributes"] if "attributes" in kwargs else args[1] if len(args) > 1 else {}
|
||||
assert attributes["has_conversation"] == "false"
|
||||
|
||||
def test_record_workflow_trace_duration_exception(self, tencent_data_trace):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
trace_info.start_time = MagicMock() # This might cause total_seconds() to fail if not mocked right
|
||||
|
||||
with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log:
|
||||
tencent_data_trace._record_workflow_trace_duration(trace_info)
|
||||
|
||||
def test_record_message_trace_duration(self, tencent_data_trace):
|
||||
trace_info = MagicMock(spec=MessageTraceInfo)
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
now = datetime.now()
|
||||
trace_info.start_time = now
|
||||
trace_info.end_time = now + timedelta(seconds=2)
|
||||
trace_info.conversation_mode = "chat"
|
||||
trace_info.is_streaming_request = True
|
||||
|
||||
tencent_data_trace._record_message_trace_duration(trace_info)
|
||||
tencent_data_trace.trace_client.record_trace_duration.assert_called_once_with(
|
||||
2.0, {"conversation_mode": "chat", "stream": "true"}
|
||||
)
|
||||
|
||||
def test_record_message_trace_duration_exception(self, tencent_data_trace):
|
||||
trace_info = MagicMock(spec=MessageTraceInfo)
|
||||
trace_info.start_time = None
|
||||
|
||||
with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log:
|
||||
tencent_data_trace._record_message_trace_duration(trace_info)
|
||||
|
||||
def test_del(self, tencent_data_trace):
|
||||
client = tencent_data_trace.trace_client
|
||||
tencent_data_trace.__del__()
|
||||
client.shutdown.assert_called_once()
|
||||
|
||||
def test_del_exception(self, tencent_data_trace):
|
||||
tencent_data_trace.trace_client.shutdown.side_effect = Exception("error")
|
||||
with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
|
||||
tencent_data_trace.__del__()
|
||||
mock_log.assert_called_once_with("[Tencent APM] Failed to shutdown trace client during cleanup")
|
||||
@ -0,0 +1,106 @@
|
||||
"""Unit tests for Tencent APM tracing utilities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from opentelemetry.trace import Link, TraceFlags
|
||||
|
||||
from core.ops.tencent_trace.utils import TencentTraceUtils
|
||||
|
||||
|
||||
def test_convert_to_trace_id_with_valid_uuid() -> None:
|
||||
uuid_str = "12345678-1234-5678-1234-567812345678"
|
||||
assert TencentTraceUtils.convert_to_trace_id(uuid_str) == uuid.UUID(uuid_str).int
|
||||
|
||||
|
||||
def test_convert_to_trace_id_uses_uuid4_when_none() -> None:
|
||||
expected_uuid = uuid.UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa")
|
||||
with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock:
|
||||
assert TencentTraceUtils.convert_to_trace_id(None) == expected_uuid.int
|
||||
uuid4_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_convert_to_trace_id_raises_value_error_for_invalid_uuid() -> None:
|
||||
with pytest.raises(ValueError, match=r"^Invalid UUID input:"):
|
||||
TencentTraceUtils.convert_to_trace_id("not-a-uuid")
|
||||
|
||||
|
||||
def test_convert_to_span_id_is_deterministic_and_sensitive_to_type() -> None:
|
||||
uuid_str = "12345678-1234-5678-1234-567812345678"
|
||||
span_type = "llm"
|
||||
|
||||
uuid_obj = uuid.UUID(uuid_str)
|
||||
combined_key = f"{uuid_obj.hex}-{span_type}"
|
||||
hash_bytes = hashlib.sha256(combined_key.encode("utf-8")).digest()
|
||||
expected = int.from_bytes(hash_bytes[:8], byteorder="big", signed=False)
|
||||
|
||||
assert TencentTraceUtils.convert_to_span_id(uuid_str, span_type) == expected
|
||||
assert TencentTraceUtils.convert_to_span_id(uuid_str, "other") != expected
|
||||
|
||||
|
||||
def test_convert_to_span_id_uses_uuid4_when_none() -> None:
|
||||
expected_uuid = uuid.UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb")
|
||||
with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock:
|
||||
span_id = TencentTraceUtils.convert_to_span_id(None, "workflow")
|
||||
assert isinstance(span_id, int)
|
||||
uuid4_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_convert_to_span_id_raises_value_error_for_invalid_uuid() -> None:
|
||||
with pytest.raises(ValueError, match=r"^Invalid UUID input:"):
|
||||
TencentTraceUtils.convert_to_span_id("bad-uuid", "span")
|
||||
|
||||
|
||||
def test_generate_span_id_skips_invalid_span_id() -> None:
|
||||
with patch(
|
||||
"core.ops.tencent_trace.utils.random.getrandbits",
|
||||
side_effect=[TencentTraceUtils.INVALID_SPAN_ID, 42],
|
||||
) as bits_mock:
|
||||
assert TencentTraceUtils.generate_span_id() == 42
|
||||
assert bits_mock.call_count == 2
|
||||
|
||||
|
||||
def test_convert_datetime_to_nanoseconds_accepts_datetime() -> None:
|
||||
start_time = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC)
|
||||
expected = int(start_time.timestamp() * 1e9)
|
||||
assert TencentTraceUtils.convert_datetime_to_nanoseconds(start_time) == expected
|
||||
|
||||
|
||||
def test_convert_datetime_to_nanoseconds_uses_now_when_none() -> None:
|
||||
fixed = datetime(2024, 1, 2, 3, 4, 5, tzinfo=UTC)
|
||||
expected = int(fixed.timestamp() * 1e9)
|
||||
|
||||
with patch("core.ops.tencent_trace.utils.datetime") as datetime_mock:
|
||||
datetime_mock.now.return_value = fixed
|
||||
assert TencentTraceUtils.convert_datetime_to_nanoseconds(None) == expected
|
||||
datetime_mock.now.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("trace_id_str", "expected_trace_id"),
|
||||
[
|
||||
("0" * 31 + "1", int("0" * 31 + "1", 16)),
|
||||
(str(uuid.UUID("cccccccc-cccc-cccc-cccc-cccccccccccc")), uuid.UUID("cccccccc-cccc-cccc-cccc-cccccccccccc").int),
|
||||
],
|
||||
)
|
||||
def test_create_link_accepts_hex_or_uuid(trace_id_str: str, expected_trace_id: int) -> None:
|
||||
link = TencentTraceUtils.create_link(trace_id_str)
|
||||
assert isinstance(link, Link)
|
||||
assert link.context.trace_id == expected_trace_id
|
||||
assert link.context.span_id == TencentTraceUtils.INVALID_SPAN_ID
|
||||
assert link.context.is_remote is False
|
||||
assert link.context.trace_flags == TraceFlags(TraceFlags.SAMPLED)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("trace_id_str", ["g" * 32, "not-a-uuid", None])
|
||||
def test_create_link_falls_back_to_uuid4(trace_id_str: object) -> None:
|
||||
fallback_uuid = uuid.UUID("dddddddd-dddd-dddd-dddd-dddddddddddd")
|
||||
with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=fallback_uuid) as uuid4_mock:
|
||||
link = TencentTraceUtils.create_link(trace_id_str) # type: ignore[arg-type]
|
||||
assert link.context.trace_id == fallback_uuid.int
|
||||
uuid4_mock.assert_called_once()
|
||||
112
api/tests/unit_tests/core/ops/test_base_trace_instance.py
Normal file
112
api/tests/unit_tests/core/ops/test_base_trace_instance.py
Normal file
@ -0,0 +1,112 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import BaseTracingConfig
|
||||
from core.ops.entities.trace_entity import BaseTraceInfo
|
||||
from models import Account, App, TenantAccountJoin
|
||||
|
||||
|
||||
class ConcreteTraceInstance(BaseTraceInstance):
|
||||
def __init__(self, trace_config: BaseTracingConfig):
|
||||
super().__init__(trace_config)
|
||||
|
||||
def trace(self, trace_info: BaseTraceInfo):
|
||||
super().trace(trace_info)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session(monkeypatch):
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_session.__enter__.return_value = mock_session
|
||||
mock_session.__exit__.return_value = None
|
||||
|
||||
mock_session_class = MagicMock(return_value=mock_session)
|
||||
|
||||
monkeypatch.setattr("core.ops.base_trace_instance.Session", mock_session_class)
|
||||
monkeypatch.setattr("core.ops.base_trace_instance.db", MagicMock())
|
||||
return mock_session
|
||||
|
||||
|
||||
def test_get_service_account_with_tenant_app_not_found(mock_db_session):
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
config = MagicMock(spec=BaseTracingConfig)
|
||||
instance = ConcreteTraceInstance(config)
|
||||
|
||||
with pytest.raises(ValueError, match="App with id some_app_id not found"):
|
||||
instance.get_service_account_with_tenant("some_app_id")
|
||||
|
||||
|
||||
def test_get_service_account_with_tenant_no_creator(mock_db_session):
|
||||
mock_app = MagicMock(spec=App)
|
||||
mock_app.id = "some_app_id"
|
||||
mock_app.created_by = None
|
||||
mock_db_session.scalar.return_value = mock_app
|
||||
|
||||
config = MagicMock(spec=BaseTracingConfig)
|
||||
instance = ConcreteTraceInstance(config)
|
||||
|
||||
with pytest.raises(ValueError, match="App with id some_app_id has no creator"):
|
||||
instance.get_service_account_with_tenant("some_app_id")
|
||||
|
||||
|
||||
def test_get_service_account_with_tenant_creator_not_found(mock_db_session):
|
||||
mock_app = MagicMock(spec=App)
|
||||
mock_app.id = "some_app_id"
|
||||
mock_app.created_by = "creator_id"
|
||||
|
||||
# First call to scalar returns app, second returns None (for account)
|
||||
mock_db_session.scalar.side_effect = [mock_app, None]
|
||||
|
||||
config = MagicMock(spec=BaseTracingConfig)
|
||||
instance = ConcreteTraceInstance(config)
|
||||
|
||||
with pytest.raises(ValueError, match="Creator account with id creator_id not found for app some_app_id"):
|
||||
instance.get_service_account_with_tenant("some_app_id")
|
||||
|
||||
|
||||
def test_get_service_account_with_tenant_tenant_not_found(mock_db_session):
|
||||
mock_app = MagicMock(spec=App)
|
||||
mock_app.id = "some_app_id"
|
||||
mock_app.created_by = "creator_id"
|
||||
|
||||
mock_account = MagicMock(spec=Account)
|
||||
mock_account.id = "creator_id"
|
||||
|
||||
mock_db_session.scalar.side_effect = [mock_app, mock_account]
|
||||
|
||||
# session.query(TenantAccountJoin).filter_by(...).first() returns None
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
config = MagicMock(spec=BaseTracingConfig)
|
||||
instance = ConcreteTraceInstance(config)
|
||||
|
||||
with pytest.raises(ValueError, match="Current tenant not found for account creator_id"):
|
||||
instance.get_service_account_with_tenant("some_app_id")
|
||||
|
||||
|
||||
def test_get_service_account_with_tenant_success(mock_db_session):
|
||||
mock_app = MagicMock(spec=App)
|
||||
mock_app.id = "some_app_id"
|
||||
mock_app.created_by = "creator_id"
|
||||
|
||||
mock_account = MagicMock(spec=Account)
|
||||
mock_account.id = "creator_id"
|
||||
mock_account.set_tenant_id = MagicMock()
|
||||
|
||||
mock_db_session.scalar.side_effect = [mock_app, mock_account]
|
||||
|
||||
mock_tenant_join = MagicMock(spec=TenantAccountJoin)
|
||||
mock_tenant_join.tenant_id = "tenant_id"
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_tenant_join
|
||||
|
||||
config = MagicMock(spec=BaseTracingConfig)
|
||||
instance = ConcreteTraceInstance(config)
|
||||
|
||||
result = instance.get_service_account_with_tenant("some_app_id")
|
||||
|
||||
assert result == mock_account
|
||||
mock_account.set_tenant_id.assert_called_once_with("tenant_id")
|
||||
576
api/tests/unit_tests/core/ops/test_ops_trace_manager.py
Normal file
576
api/tests/unit_tests/core/ops/test_ops_trace_manager.py
Normal file
@ -0,0 +1,576 @@
|
||||
import contextlib
|
||||
import json
|
||||
import queue
|
||||
from datetime import datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.ops.ops_trace_manager import (
|
||||
OpsTraceManager,
|
||||
TraceQueueManager,
|
||||
TraceTask,
|
||||
TraceTaskName,
|
||||
)
|
||||
|
||||
|
||||
class DummyConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self._data = kwargs
|
||||
|
||||
def model_dump(self):
|
||||
return dict(self._data)
|
||||
|
||||
|
||||
class DummyTraceInstance:
|
||||
instances: list["DummyTraceInstance"] = []
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
DummyTraceInstance.instances.append(self)
|
||||
|
||||
def api_check(self):
|
||||
return True
|
||||
|
||||
def get_project_key(self):
|
||||
return "fake-key"
|
||||
|
||||
def get_project_url(self):
|
||||
return "https://project.fake"
|
||||
|
||||
|
||||
FAKE_PROVIDER_ENTRY = {
|
||||
"config_class": DummyConfig,
|
||||
"secret_keys": ["secret_value"],
|
||||
"other_keys": ["other_value"],
|
||||
"trace_instance": DummyTraceInstance,
|
||||
}
|
||||
|
||||
|
||||
class FakeProviderMap:
|
||||
def __init__(self, data):
|
||||
self._data = data
|
||||
|
||||
def __getitem__(self, key):
|
||||
if key in self._data:
|
||||
return self._data[key]
|
||||
raise KeyError(f"Unsupported tracing provider: {key}")
|
||||
|
||||
|
||||
class DummyTimer:
|
||||
def __init__(self, interval, function):
|
||||
self.interval = interval
|
||||
self.function = function
|
||||
self.name = ""
|
||||
self.daemon = False
|
||||
self.started = False
|
||||
|
||||
def start(self):
|
||||
self.started = True
|
||||
|
||||
def is_alive(self):
|
||||
return False
|
||||
|
||||
|
||||
class FakeMessageFile:
|
||||
def __init__(self):
|
||||
self.url = "path/to/file"
|
||||
self.id = "file-id"
|
||||
self.type = "document"
|
||||
self.created_by_role = "role"
|
||||
self.created_by = "user"
|
||||
|
||||
|
||||
def make_message_data(**overrides):
|
||||
created_at = datetime(2025, 2, 20, 12, 0, 0)
|
||||
base = {
|
||||
"id": "msg-id",
|
||||
"conversation_id": "conv-id",
|
||||
"created_at": created_at,
|
||||
"updated_at": created_at + timedelta(seconds=3),
|
||||
"message": "hello",
|
||||
"provider_response_latency": 1,
|
||||
"message_tokens": 5,
|
||||
"answer_tokens": 7,
|
||||
"answer": "world",
|
||||
"error": "",
|
||||
"status": "complete",
|
||||
"model_provider": "provider",
|
||||
"model_id": "model",
|
||||
"from_end_user_id": "end-user",
|
||||
"from_account_id": "account",
|
||||
"agent_based": False,
|
||||
"workflow_run_id": "workflow-run",
|
||||
"from_source": "source",
|
||||
"message_metadata": json.dumps({"usage": {"time_to_first_token": 1, "time_to_generate": 2}}),
|
||||
"agent_thoughts": [],
|
||||
"query": "sample-query",
|
||||
"inputs": "sample-input",
|
||||
}
|
||||
base.update(overrides)
|
||||
|
||||
class MessageData:
|
||||
def __init__(self, data):
|
||||
self.__dict__.update(data)
|
||||
|
||||
def to_dict(self):
|
||||
return dict(self.__dict__)
|
||||
|
||||
return MessageData(base)
|
||||
|
||||
|
||||
def make_agent_thought(tool_name, created_at):
|
||||
return SimpleNamespace(
|
||||
tools=[tool_name],
|
||||
created_at=created_at,
|
||||
tool_meta={
|
||||
tool_name: {
|
||||
"tool_config": {"foo": "bar"},
|
||||
"time_cost": 5,
|
||||
"error": "",
|
||||
"tool_parameters": {"x": 1},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def make_workflow_run():
|
||||
return SimpleNamespace(
|
||||
workflow_id="wf-1",
|
||||
tenant_id="tenant",
|
||||
id="run-id",
|
||||
elapsed_time=10,
|
||||
status="finished",
|
||||
inputs_dict={"sys.file": ["f1"], "query": "search"},
|
||||
outputs_dict={"out": "value"},
|
||||
version="3",
|
||||
error=None,
|
||||
total_tokens=12,
|
||||
workflow_run_id="run-id",
|
||||
created_at=datetime(2025, 2, 20, 10, 0, 0),
|
||||
finished_at=datetime(2025, 2, 20, 10, 0, 5),
|
||||
triggered_from="user",
|
||||
app_id="app-id",
|
||||
to_dict=lambda self=None: {"run": "value"},
|
||||
)
|
||||
|
||||
|
||||
def configure_db_query(session, *, message_file=None, workflow_app_log=None):
|
||||
def _side_effect(model):
|
||||
query = MagicMock()
|
||||
query.filter_by.return_value.first.return_value = None
|
||||
if message_file and model.__name__ == "MessageFile":
|
||||
query.filter_by.return_value.first.return_value = message_file
|
||||
if workflow_app_log and model.__name__ == "WorkflowAppLog":
|
||||
query.filter_by.return_value.first.return_value = workflow_app_log
|
||||
return query
|
||||
|
||||
session.query.side_effect = _side_effect
|
||||
|
||||
|
||||
class DummySessionContext:
|
||||
scalar_values = []
|
||||
|
||||
def __init__(self, engine):
|
||||
self._values = list(self.scalar_values)
|
||||
self._index = 0
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
return False
|
||||
|
||||
def scalar(self, *args, **kwargs):
|
||||
if self._index >= len(self._values):
|
||||
return None
|
||||
value = self._values[self._index]
|
||||
self._index += 1
|
||||
return value
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_provider_map(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"core.ops.ops_trace_manager.provider_config_map", FakeProviderMap({"dummy": FAKE_PROVIDER_ENTRY})
|
||||
)
|
||||
OpsTraceManager.ops_trace_instances_cache.clear()
|
||||
OpsTraceManager.decrypted_configs_cache.clear()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_timer_and_current_app(monkeypatch):
|
||||
monkeypatch.setattr("core.ops.ops_trace_manager.threading.Timer", DummyTimer)
|
||||
monkeypatch.setattr("core.ops.ops_trace_manager.trace_manager_queue", queue.Queue())
|
||||
monkeypatch.setattr("core.ops.ops_trace_manager.trace_manager_timer", None)
|
||||
|
||||
class FakeApp:
|
||||
def app_context(self):
|
||||
return contextlib.nullcontext()
|
||||
|
||||
fake_current = MagicMock()
|
||||
fake_current._get_current_object.return_value = FakeApp()
|
||||
monkeypatch.setattr("core.ops.ops_trace_manager.current_app", fake_current)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_sqlalchemy_session(monkeypatch):
|
||||
monkeypatch.setattr("core.ops.ops_trace_manager.Session", DummySessionContext)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def encryption_mocks(monkeypatch):
|
||||
encrypt_mock = MagicMock(side_effect=lambda tenant, value: f"enc-{value}")
|
||||
batch_decrypt_mock = MagicMock(side_effect=lambda tenant, values: [f"dec-{value}" for value in values])
|
||||
obfuscate_mock = MagicMock(side_effect=lambda value: f"ob-{value}")
|
||||
monkeypatch.setattr("core.ops.ops_trace_manager.encrypt_token", encrypt_mock)
|
||||
monkeypatch.setattr("core.ops.ops_trace_manager.batch_decrypt_token", batch_decrypt_mock)
|
||||
monkeypatch.setattr("core.ops.ops_trace_manager.obfuscated_token", obfuscate_mock)
|
||||
return encrypt_mock, batch_decrypt_mock, obfuscate_mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db(monkeypatch):
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = ["chat"]
|
||||
db_mock = MagicMock()
|
||||
db_mock.session = session
|
||||
db_mock.engine = MagicMock()
|
||||
monkeypatch.setattr("core.ops.ops_trace_manager.db", db_mock)
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def workflow_repo_fixture(monkeypatch):
|
||||
repo = MagicMock()
|
||||
repo.get_workflow_run_by_id_without_tenant.return_value = make_workflow_run()
|
||||
monkeypatch.setattr(TraceTask, "_get_workflow_run_repo", classmethod(lambda cls: repo))
|
||||
return repo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def trace_task_message(monkeypatch, mock_db):
|
||||
message_data = make_message_data()
|
||||
monkeypatch.setattr("core.ops.ops_trace_manager.get_message_data", lambda msg_id: message_data)
|
||||
configure_db_query(mock_db, message_file=FakeMessageFile(), workflow_app_log=SimpleNamespace(id="log-id"))
|
||||
return message_data
|
||||
|
||||
|
||||
def test_encrypt_tracing_config_handles_star_and_encrypt(encryption_mocks):
|
||||
encrypted = OpsTraceManager.encrypt_tracing_config(
|
||||
"tenant",
|
||||
"dummy",
|
||||
{"secret_value": "value", "other_value": "info"},
|
||||
current_trace_config={"secret_value": "keep"},
|
||||
)
|
||||
assert encrypted["secret_value"] == "enc-value"
|
||||
assert encrypted["other_value"] == "info"
|
||||
|
||||
|
||||
def test_encrypt_tracing_config_preserves_star(encryption_mocks):
|
||||
encrypted = OpsTraceManager.encrypt_tracing_config(
|
||||
"tenant",
|
||||
"dummy",
|
||||
{"secret_value": "*", "other_value": "info"},
|
||||
current_trace_config={"secret_value": "keep"},
|
||||
)
|
||||
assert encrypted["secret_value"] == "keep"
|
||||
|
||||
|
||||
def test_decrypt_tracing_config_caches(encryption_mocks):
|
||||
_, decrypt_mock, _ = encryption_mocks
|
||||
payload = {"secret_value": "enc", "other_value": "info"}
|
||||
first = OpsTraceManager.decrypt_tracing_config("tenant", "dummy", payload)
|
||||
second = OpsTraceManager.decrypt_tracing_config("tenant", "dummy", payload)
|
||||
assert first == second
|
||||
assert decrypt_mock.call_count == 1
|
||||
|
||||
|
||||
def test_obfuscated_decrypt_token(encryption_mocks):
|
||||
_, _, obfuscate_mock = encryption_mocks
|
||||
result = OpsTraceManager.obfuscated_decrypt_token("dummy", {"secret_value": "value", "other_value": "info"})
|
||||
assert "secret_value" in result
|
||||
assert result["secret_value"] == "ob-value"
|
||||
obfuscate_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_get_decrypted_tracing_config_returns_config(encryption_mocks, mock_db):
|
||||
trace_config_data = SimpleNamespace(tracing_config={"secret_value": "enc", "other_value": "info"})
|
||||
mock_db.query.return_value.where.return_value.first.return_value = trace_config_data
|
||||
app = SimpleNamespace(id="app-id", tenant_id="tenant")
|
||||
mock_db.scalar.return_value = app
|
||||
|
||||
decrypted = OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy")
|
||||
assert decrypted["other_value"] == "info"
|
||||
|
||||
|
||||
def test_get_decrypted_tracing_config_missing_trace_config(mock_db):
|
||||
mock_db.query.return_value.where.return_value.first.return_value = None
|
||||
assert OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy") is None
|
||||
|
||||
|
||||
def test_get_decrypted_tracing_config_raises_for_missing_app(mock_db):
|
||||
trace_config_data = SimpleNamespace(tracing_config={"secret_value": "enc"})
|
||||
mock_db.query.return_value.where.return_value.first.return_value = trace_config_data
|
||||
mock_db.scalar.return_value = None
|
||||
with pytest.raises(ValueError, match="App not found"):
|
||||
OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy")
|
||||
|
||||
|
||||
def test_get_decrypted_tracing_config_raises_for_none_config(mock_db):
|
||||
trace_config_data = SimpleNamespace(tracing_config=None)
|
||||
mock_db.query.return_value.where.return_value.first.return_value = trace_config_data
|
||||
mock_db.scalar.return_value = SimpleNamespace(tenant_id="tenant")
|
||||
with pytest.raises(ValueError, match="Tracing config cannot be None"):
|
||||
OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy")
|
||||
|
||||
|
||||
def test_get_ops_trace_instance_handles_none_app(mock_db):
|
||||
mock_db.query.return_value.where.return_value.first.return_value = None
|
||||
assert OpsTraceManager.get_ops_trace_instance("app-id") is None
|
||||
|
||||
|
||||
def test_get_ops_trace_instance_returns_none_when_disabled(mock_db, monkeypatch):
|
||||
app = SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": False}))
|
||||
mock_db.query.return_value.where.return_value.first.return_value = app
|
||||
assert OpsTraceManager.get_ops_trace_instance("app-id") is None
|
||||
|
||||
|
||||
def test_get_ops_trace_instance_invalid_provider(mock_db, monkeypatch):
|
||||
app = SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": True, "tracing_provider": "missing"}))
|
||||
mock_db.query.return_value.where.return_value.first.return_value = app
|
||||
monkeypatch.setattr("core.ops.ops_trace_manager.provider_config_map", FakeProviderMap({}))
|
||||
assert OpsTraceManager.get_ops_trace_instance("app-id") is None
|
||||
|
||||
|
||||
def test_get_ops_trace_instance_success(monkeypatch, mock_db):
|
||||
app = SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": True, "tracing_provider": "dummy"}))
|
||||
mock_db.query.return_value.where.return_value.first.return_value = app
|
||||
monkeypatch.setattr(
|
||||
"core.ops.ops_trace_manager.OpsTraceManager.get_decrypted_tracing_config",
|
||||
classmethod(lambda cls, aid, provider: {"secret_value": "decrypted", "other_value": "info"}),
|
||||
)
|
||||
instance = OpsTraceManager.get_ops_trace_instance("app-id")
|
||||
assert instance is not None
|
||||
cached_instance = OpsTraceManager.get_ops_trace_instance("app-id")
|
||||
assert instance is cached_instance
|
||||
|
||||
|
||||
def test_get_app_config_through_message_id_returns_none(mock_db):
|
||||
mock_db.scalar.return_value = None
|
||||
assert OpsTraceManager.get_app_config_through_message_id("m") is None
|
||||
|
||||
|
||||
def test_get_app_config_through_message_id_prefers_override(mock_db):
|
||||
message = SimpleNamespace(conversation_id="conv")
|
||||
conversation = SimpleNamespace(app_model_config_id=None, override_model_configs={"foo": "bar"})
|
||||
app_config = SimpleNamespace(id="config-id")
|
||||
mock_db.scalar.side_effect = [message, conversation]
|
||||
result = OpsTraceManager.get_app_config_through_message_id("m")
|
||||
assert result == {"foo": "bar"}
|
||||
|
||||
|
||||
def test_get_app_config_through_message_id_app_model_config(mock_db):
|
||||
message = SimpleNamespace(conversation_id="conv")
|
||||
conversation = SimpleNamespace(app_model_config_id="cfg", override_model_configs=None)
|
||||
mock_db.scalar.side_effect = [message, conversation, SimpleNamespace(id="cfg")]
|
||||
result = OpsTraceManager.get_app_config_through_message_id("m")
|
||||
assert result.id == "cfg"
|
||||
|
||||
|
||||
def test_update_app_tracing_config_invalid_provider(mock_db, monkeypatch):
|
||||
mock_db.query.return_value.where.return_value.first.return_value = None
|
||||
with pytest.raises(ValueError, match="Invalid tracing provider"):
|
||||
OpsTraceManager.update_app_tracing_config("app", True, "bad")
|
||||
with pytest.raises(ValueError, match="App not found"):
|
||||
OpsTraceManager.update_app_tracing_config("app", True, None)
|
||||
|
||||
|
||||
def test_update_app_tracing_config_success(mock_db):
|
||||
app = SimpleNamespace(id="app-id", tracing="{}")
|
||||
mock_db.query.return_value.where.return_value.first.return_value = app
|
||||
OpsTraceManager.update_app_tracing_config("app-id", True, "dummy")
|
||||
assert app.tracing is not None
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_get_app_tracing_config_errors_when_missing(mock_db):
|
||||
mock_db.query.return_value.where.return_value.first.return_value = None
|
||||
with pytest.raises(ValueError, match="App not found"):
|
||||
OpsTraceManager.get_app_tracing_config("app")
|
||||
|
||||
|
||||
def test_get_app_tracing_config_returns_defaults(mock_db):
|
||||
mock_db.query.return_value.where.return_value.first.return_value = SimpleNamespace(tracing=None)
|
||||
assert OpsTraceManager.get_app_tracing_config("app-id") == {"enabled": False, "tracing_provider": None}
|
||||
|
||||
|
||||
def test_get_app_tracing_config_returns_payload(mock_db):
|
||||
payload = {"enabled": True, "tracing_provider": "dummy"}
|
||||
mock_db.query.return_value.where.return_value.first.return_value = SimpleNamespace(tracing=json.dumps(payload))
|
||||
assert OpsTraceManager.get_app_tracing_config("app-id") == payload
|
||||
|
||||
|
||||
def test_check_and_project_helpers(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"core.ops.ops_trace_manager.provider_config_map",
|
||||
FakeProviderMap(
|
||||
{
|
||||
"dummy": {
|
||||
"config_class": DummyConfig,
|
||||
"trace_instance": type(
|
||||
"Trace",
|
||||
(),
|
||||
{
|
||||
"__init__": lambda self, cfg: None,
|
||||
"api_check": lambda self: True,
|
||||
"get_project_key": lambda self: "key",
|
||||
"get_project_url": lambda self: "url",
|
||||
},
|
||||
),
|
||||
"secret_keys": [],
|
||||
"other_keys": [],
|
||||
}
|
||||
}
|
||||
),
|
||||
)
|
||||
assert OpsTraceManager.check_trace_config_is_effective({}, "dummy")
|
||||
assert OpsTraceManager.get_trace_config_project_key({}, "dummy") == "key"
|
||||
assert OpsTraceManager.get_trace_config_project_url({}, "dummy") == "url"
|
||||
|
||||
|
||||
def test_trace_task_conversation_and_extract(monkeypatch):
|
||||
task = TraceTask(trace_type=TraceTaskName.CONVERSATION_TRACE, message_id="msg")
|
||||
assert task.conversation_trace(foo="bar") == {"foo": "bar"}
|
||||
assert task._extract_streaming_metrics(make_message_data(message_metadata="not json")) == {}
|
||||
|
||||
|
||||
def test_trace_task_message_trace(trace_task_message, mock_db):
|
||||
task = TraceTask(trace_type=TraceTaskName.MESSAGE_TRACE, message_id="msg-id")
|
||||
result = task.message_trace("msg-id")
|
||||
assert result.message_id == "msg-id"
|
||||
|
||||
|
||||
def test_trace_task_workflow_trace(workflow_repo_fixture, mock_db):
|
||||
DummySessionContext.scalar_values = ["wf-app-log", "message-ref"]
|
||||
execution = SimpleNamespace(id_="run-id")
|
||||
task = TraceTask(
|
||||
trace_type=TraceTaskName.WORKFLOW_TRACE, workflow_execution=execution, conversation_id="conv", user_id="user"
|
||||
)
|
||||
result = task.workflow_trace(workflow_run_id="run-id", conversation_id="conv", user_id="user")
|
||||
assert result.workflow_run_id == "run-id"
|
||||
assert result.workflow_id == "wf-1"
|
||||
|
||||
|
||||
def test_trace_task_moderation_trace(trace_task_message):
|
||||
task = TraceTask(trace_type=TraceTaskName.MODERATION_TRACE, message_id="msg-id")
|
||||
moderation_result = SimpleNamespace(action="block", preset_response="no", query="q", flagged=True)
|
||||
timer = {"start": 1, "end": 2}
|
||||
result = task.moderation_trace("msg-id", timer, moderation_result=moderation_result, inputs={"src": "payload"})
|
||||
assert result.flagged is True
|
||||
assert result.message_id == "log-id"
|
||||
|
||||
|
||||
def test_trace_task_suggested_question_trace(trace_task_message):
|
||||
task = TraceTask(trace_type=TraceTaskName.SUGGESTED_QUESTION_TRACE, message_id="msg-id")
|
||||
timer = {"start": 1, "end": 2}
|
||||
result = task.suggested_question_trace("msg-id", timer, suggested_question=["q1"])
|
||||
assert result.message_id == "log-id"
|
||||
assert "suggested_question" in result.__dict__
|
||||
|
||||
|
||||
def test_trace_task_dataset_retrieval_trace(trace_task_message):
|
||||
task = TraceTask(trace_type=TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id="msg-id")
|
||||
timer = {"start": 1, "end": 2}
|
||||
mock_doc = SimpleNamespace(model_dump=lambda: {"doc": "value"})
|
||||
result = task.dataset_retrieval_trace("msg-id", timer, documents=[mock_doc])
|
||||
assert result.documents == [{"doc": "value"}]
|
||||
|
||||
|
||||
def test_trace_task_tool_trace(monkeypatch, mock_db):
|
||||
custom_message = make_message_data(agent_thoughts=[make_agent_thought("tool-a", datetime(2025, 2, 20, 12, 1, 0))])
|
||||
monkeypatch.setattr("core.ops.ops_trace_manager.get_message_data", lambda _: custom_message)
|
||||
configure_db_query(mock_db, message_file=FakeMessageFile())
|
||||
task = TraceTask(trace_type=TraceTaskName.TOOL_TRACE, message_id="msg-id")
|
||||
timer = {"start": 1, "end": 5}
|
||||
result = task.tool_trace("msg-id", timer, tool_name="tool-a", tool_inputs={"foo": 1}, tool_outputs="result")
|
||||
assert result.tool_name == "tool-a"
|
||||
assert result.time_cost == 5
|
||||
|
||||
|
||||
def test_trace_task_generate_name_trace():
|
||||
task = TraceTask(trace_type=TraceTaskName.GENERATE_NAME_TRACE, conversation_id="conv-id")
|
||||
timer = {"start": 1, "end": 2}
|
||||
assert task.generate_name_trace("conv-id", timer, tenant_id=None) == {}
|
||||
result = task.generate_name_trace(
|
||||
"conv-id", timer, tenant_id="tenant", generate_conversation_name="name", inputs="q"
|
||||
)
|
||||
assert result.outputs == "name"
|
||||
assert result.tenant_id == "tenant"
|
||||
|
||||
|
||||
def test_extract_streaming_metrics_invalid_json():
|
||||
task = TraceTask(trace_type=TraceTaskName.MESSAGE_TRACE, message_id="msg-id")
|
||||
fake_message = make_message_data(message_metadata="invalid")
|
||||
assert task._extract_streaming_metrics(fake_message) == {}
|
||||
|
||||
|
||||
def test_trace_queue_manager_add_and_collect(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", classmethod(lambda cls, aid: True)
|
||||
)
|
||||
manager = TraceQueueManager(app_id="app-id", user_id="user")
|
||||
task = TraceTask(trace_type=TraceTaskName.CONVERSATION_TRACE)
|
||||
manager.add_trace_task(task)
|
||||
tasks = manager.collect_tasks()
|
||||
assert tasks == [task]
|
||||
|
||||
|
||||
def test_trace_queue_manager_run_invokes_send(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", classmethod(lambda cls, aid: True)
|
||||
)
|
||||
manager = TraceQueueManager(app_id="app-id", user_id="user")
|
||||
task = TraceTask(trace_type=TraceTaskName.CONVERSATION_TRACE)
|
||||
called = {}
|
||||
|
||||
def fake_collect():
|
||||
return [task]
|
||||
|
||||
def fake_send(tasks):
|
||||
called["tasks"] = tasks
|
||||
|
||||
monkeypatch.setattr(TraceQueueManager, "collect_tasks", lambda self: fake_collect())
|
||||
monkeypatch.setattr(TraceQueueManager, "send_to_celery", lambda self, t: fake_send(t))
|
||||
manager.run()
|
||||
assert called["tasks"] == [task]
|
||||
|
||||
|
||||
def test_trace_queue_manager_send_to_celery(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", classmethod(lambda cls, aid: True)
|
||||
)
|
||||
storage_save = MagicMock()
|
||||
process_delay = MagicMock()
|
||||
monkeypatch.setattr("core.ops.ops_trace_manager.storage.save", storage_save)
|
||||
monkeypatch.setattr("core.ops.ops_trace_manager.process_trace_tasks.delay", process_delay)
|
||||
monkeypatch.setattr("core.ops.ops_trace_manager.uuid4", MagicMock(return_value=SimpleNamespace(hex="file-123")))
|
||||
|
||||
manager = TraceQueueManager(app_id="app-id", user_id="user")
|
||||
|
||||
class DummyTraceInfo:
|
||||
def model_dump(self):
|
||||
return {"trace": "info"}
|
||||
|
||||
class DummyTask:
|
||||
def __init__(self):
|
||||
self.app_id = "app-id"
|
||||
|
||||
def execute(self):
|
||||
return DummyTraceInfo()
|
||||
|
||||
task = DummyTask()
|
||||
manager.send_to_celery([task])
|
||||
storage_save.assert_called_once()
|
||||
process_delay.assert_called_once_with({"file_id": "file-123", "app_id": "app-id"})
|
||||
@ -1,9 +1,20 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.ops.utils import generate_dotted_order, validate_project_name, validate_url, validate_url_with_path
|
||||
from core.ops.utils import (
|
||||
filter_none_values,
|
||||
generate_dotted_order,
|
||||
get_message_data,
|
||||
measure_time,
|
||||
replace_text_with_content,
|
||||
validate_integer_id,
|
||||
validate_project_name,
|
||||
validate_url,
|
||||
validate_url_with_path,
|
||||
)
|
||||
|
||||
|
||||
class TestValidateUrl:
|
||||
@ -187,3 +198,92 @@ class TestGenerateDottedOrder:
|
||||
result = generate_dotted_order(run_id, start_time, None)
|
||||
|
||||
assert "." not in result
|
||||
|
||||
def test_dotted_order_with_string_start_time(self):
|
||||
"""Test dotted_order generation with string start_time."""
|
||||
start_time = "2025-12-23T04:19:55.111000"
|
||||
run_id = "test-run-id"
|
||||
result = generate_dotted_order(run_id, start_time)
|
||||
|
||||
assert result == "20251223T041955111000Ztest-run-id"
|
||||
|
||||
|
||||
class TestFilterNoneValues:
|
||||
"""Test cases for filter_none_values function"""
|
||||
|
||||
def test_filter_none_values(self):
|
||||
data = {"a": 1, "b": None, "c": "test", "d": datetime(2025, 1, 1, 12, 0, 0)}
|
||||
result = filter_none_values(data)
|
||||
assert result == {"a": 1, "c": "test", "d": "2025-01-01T12:00:00"}
|
||||
|
||||
def test_filter_none_values_empty(self):
|
||||
assert filter_none_values({}) == {}
|
||||
|
||||
|
||||
class TestGetMessageData:
|
||||
"""Test cases for get_message_data function"""
|
||||
|
||||
@patch("core.ops.utils.db")
|
||||
@patch("core.ops.utils.Message")
|
||||
@patch("core.ops.utils.select")
|
||||
def test_get_message_data(self, mock_select, mock_message, mock_db):
|
||||
mock_scalar = mock_db.session.scalar
|
||||
mock_msg_instance = MagicMock()
|
||||
mock_scalar.return_value = mock_msg_instance
|
||||
|
||||
result = get_message_data("message-id")
|
||||
|
||||
assert result == mock_msg_instance
|
||||
mock_select.assert_called_once()
|
||||
mock_scalar.assert_called_once()
|
||||
|
||||
|
||||
class TestMeasureTime:
|
||||
"""Test cases for measure_time function"""
|
||||
|
||||
def test_measure_time(self):
|
||||
with measure_time() as timing_info:
|
||||
assert "start" in timing_info
|
||||
assert isinstance(timing_info["start"], datetime)
|
||||
assert timing_info["end"] is None
|
||||
|
||||
assert timing_info["end"] is not None
|
||||
assert isinstance(timing_info["end"], datetime)
|
||||
assert timing_info["end"] >= timing_info["start"]
|
||||
|
||||
|
||||
class TestReplaceTextWithContent:
|
||||
"""Test cases for replace_text_with_content function"""
|
||||
|
||||
def test_replace_text_with_content_dict(self):
|
||||
data = {"text": "hello", "other": "world"}
|
||||
assert replace_text_with_content(data) == {"content": "hello", "other": "world"}
|
||||
|
||||
def test_replace_text_with_content_nested(self):
|
||||
data = {"text": "v1", "nested": {"text": "v2", "list": [{"text": "v3"}]}}
|
||||
expected = {"content": "v1", "nested": {"content": "v2", "list": [{"content": "v3"}]}}
|
||||
assert replace_text_with_content(data) == expected
|
||||
|
||||
def test_replace_text_with_content_list(self):
|
||||
data = [{"text": "v1"}, "v2"]
|
||||
assert replace_text_with_content(data) == [{"content": "v1"}, "v2"]
|
||||
|
||||
def test_replace_text_with_content_primitive(self):
|
||||
assert replace_text_with_content(123) == 123
|
||||
assert replace_text_with_content("text") == "text"
|
||||
|
||||
|
||||
class TestValidateIntegerId:
|
||||
"""Test cases for validate_integer_id function"""
|
||||
|
||||
def test_valid_integer_id(self):
|
||||
assert validate_integer_id("123") == "123"
|
||||
assert validate_integer_id(" 456 ") == "456"
|
||||
|
||||
def test_invalid_integer_id_raises_error(self):
|
||||
with pytest.raises(ValueError, match="ID must be a valid integer"):
|
||||
validate_integer_id("abc")
|
||||
|
||||
def test_empty_integer_id_raises_error(self):
|
||||
with pytest.raises(ValueError, match="ID must be a valid integer"):
|
||||
validate_integer_id("")
|
||||
|
||||
1196
api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py
Normal file
1196
api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -1,82 +1,603 @@
|
||||
from __future__ import annotations
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch, sentinel
|
||||
|
||||
from typing import Any
|
||||
import pytest
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from dify_graph.nodes.llm.entities import LLMNodeData
|
||||
from dify_graph.nodes.llm.node import LLMNode
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from tests.workflow_test_utils import build_test_graph_init_params
|
||||
from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom
|
||||
from core.workflow import node_factory
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
|
||||
from dify_graph.enums import NodeType, SystemVariableKey
|
||||
from dify_graph.nodes.code.entities import CodeLanguage
|
||||
from dify_graph.variables.segments import StringSegment
|
||||
|
||||
|
||||
def _build_factory(graph_config: dict[str, Any]) -> DifyNodeFactory:
|
||||
graph_init_params = build_test_graph_init_params(
|
||||
workflow_id="workflow",
|
||||
graph_config=graph_config,
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
user_id="user",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
def _assert_typed_node_config(config, *, node_id: str, node_type: NodeType, version: str = "1") -> None:
|
||||
assert config["id"] == node_id
|
||||
assert isinstance(config["data"], BaseNodeData)
|
||||
assert config["data"].type == node_type
|
||||
assert config["data"].version == version
|
||||
|
||||
|
||||
class TestFetchMemory:
|
||||
@pytest.mark.parametrize(
|
||||
("conversation_id", "memory_config"),
|
||||
[
|
||||
(None, object()),
|
||||
("conversation-id", None),
|
||||
],
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
),
|
||||
start_at=0.0,
|
||||
def test_returns_none_when_memory_or_conversation_is_missing(self, conversation_id, memory_config):
|
||||
result = node_factory.fetch_memory(
|
||||
conversation_id=conversation_id,
|
||||
app_id="app-id",
|
||||
node_data_memory=memory_config,
|
||||
model_instance=sentinel.model_instance,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_when_conversation_does_not_exist(self, monkeypatch):
|
||||
class FakeSelect:
|
||||
def where(self, *_args):
|
||||
return self
|
||||
|
||||
class FakeSession:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *_args):
|
||||
return False
|
||||
|
||||
def scalar(self, _stmt):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(node_factory, "db", SimpleNamespace(engine=sentinel.engine))
|
||||
monkeypatch.setattr(node_factory, "select", MagicMock(return_value=FakeSelect()))
|
||||
monkeypatch.setattr(node_factory, "Session", FakeSession)
|
||||
|
||||
result = node_factory.fetch_memory(
|
||||
conversation_id="conversation-id",
|
||||
app_id="app-id",
|
||||
node_data_memory=object(),
|
||||
model_instance=sentinel.model_instance,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_builds_token_buffer_memory_for_existing_conversation(self, monkeypatch):
|
||||
conversation = sentinel.conversation
|
||||
memory = sentinel.memory
|
||||
|
||||
class FakeSelect:
|
||||
def where(self, *_args):
|
||||
return self
|
||||
|
||||
class FakeSession:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *_args):
|
||||
return False
|
||||
|
||||
def scalar(self, _stmt):
|
||||
return conversation
|
||||
|
||||
token_buffer_memory = MagicMock(return_value=memory)
|
||||
monkeypatch.setattr(node_factory, "db", SimpleNamespace(engine=sentinel.engine))
|
||||
monkeypatch.setattr(node_factory, "select", MagicMock(return_value=FakeSelect()))
|
||||
monkeypatch.setattr(node_factory, "Session", FakeSession)
|
||||
monkeypatch.setattr(node_factory, "TokenBufferMemory", token_buffer_memory)
|
||||
|
||||
result = node_factory.fetch_memory(
|
||||
conversation_id="conversation-id",
|
||||
app_id="app-id",
|
||||
node_data_memory=object(),
|
||||
model_instance=sentinel.model_instance,
|
||||
)
|
||||
|
||||
assert result is memory
|
||||
token_buffer_memory.assert_called_once_with(
|
||||
conversation=conversation,
|
||||
model_instance=sentinel.model_instance,
|
||||
)
|
||||
|
||||
|
||||
class TestDefaultWorkflowCodeExecutor:
|
||||
def test_execute_delegates_to_code_executor(self, monkeypatch):
|
||||
executor = node_factory.DefaultWorkflowCodeExecutor()
|
||||
execute_workflow_code_template = MagicMock(return_value={"answer": "ok"})
|
||||
monkeypatch.setattr(
|
||||
node_factory.CodeExecutor,
|
||||
"execute_workflow_code_template",
|
||||
execute_workflow_code_template,
|
||||
)
|
||||
|
||||
result = executor.execute(
|
||||
language=CodeLanguage.PYTHON3,
|
||||
code="print('ok')",
|
||||
inputs={"name": "workflow"},
|
||||
)
|
||||
|
||||
assert result == {"answer": "ok"}
|
||||
execute_workflow_code_template.assert_called_once_with(
|
||||
language=CodeLanguage.PYTHON3,
|
||||
code="print('ok')",
|
||||
inputs={"name": "workflow"},
|
||||
)
|
||||
|
||||
def test_is_execution_error_checks_code_execution_error_type(self):
|
||||
executor = node_factory.DefaultWorkflowCodeExecutor()
|
||||
|
||||
assert executor.is_execution_error(node_factory.CodeExecutionError("boom")) is True
|
||||
assert executor.is_execution_error(RuntimeError("boom")) is False
|
||||
|
||||
|
||||
class TestDifyNodeFactoryInit:
|
||||
def test_init_builds_default_dependencies(self):
|
||||
graph_init_params = SimpleNamespace(run_context={"context": "value"})
|
||||
graph_runtime_state = sentinel.graph_runtime_state
|
||||
dify_context = SimpleNamespace(tenant_id="tenant-id")
|
||||
template_renderer = sentinel.template_renderer
|
||||
rag_retrieval = sentinel.rag_retrieval
|
||||
unstructured_api_config = sentinel.unstructured_api_config
|
||||
http_request_config = sentinel.http_request_config
|
||||
credentials_provider = sentinel.credentials_provider
|
||||
model_factory = sentinel.model_factory
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
node_factory.DifyNodeFactory,
|
||||
"_resolve_dify_context",
|
||||
return_value=dify_context,
|
||||
) as resolve_dify_context,
|
||||
patch.object(
|
||||
node_factory,
|
||||
"CodeExecutorJinja2TemplateRenderer",
|
||||
return_value=template_renderer,
|
||||
) as renderer_factory,
|
||||
patch.object(node_factory, "DatasetRetrieval", return_value=rag_retrieval),
|
||||
patch.object(
|
||||
node_factory,
|
||||
"UnstructuredApiConfig",
|
||||
return_value=unstructured_api_config,
|
||||
),
|
||||
patch.object(
|
||||
node_factory,
|
||||
"build_http_request_config",
|
||||
return_value=http_request_config,
|
||||
),
|
||||
patch.object(
|
||||
node_factory,
|
||||
"build_dify_model_access",
|
||||
return_value=(credentials_provider, model_factory),
|
||||
) as build_dify_model_access,
|
||||
):
|
||||
factory = node_factory.DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
resolve_dify_context.assert_called_once_with(graph_init_params.run_context)
|
||||
build_dify_model_access.assert_called_once_with("tenant-id")
|
||||
renderer_factory.assert_called_once()
|
||||
assert renderer_factory.call_args.kwargs["code_executor"] is factory._code_executor
|
||||
assert factory.graph_init_params is graph_init_params
|
||||
assert factory.graph_runtime_state is graph_runtime_state
|
||||
assert factory._dify_context is dify_context
|
||||
assert factory._template_renderer is template_renderer
|
||||
assert factory._rag_retrieval is rag_retrieval
|
||||
assert factory._document_extractor_unstructured_api_config is unstructured_api_config
|
||||
assert factory._http_request_config is http_request_config
|
||||
assert factory._llm_credentials_provider is credentials_provider
|
||||
assert factory._llm_model_factory is model_factory
|
||||
|
||||
|
||||
class TestDifyNodeFactoryResolveContext:
|
||||
def test_requires_reserved_context_key(self):
|
||||
with pytest.raises(ValueError, match=DIFY_RUN_CONTEXT_KEY):
|
||||
node_factory.DifyNodeFactory._resolve_dify_context({})
|
||||
|
||||
def test_returns_existing_dify_context(self):
|
||||
dify_context = DifyRunContext(
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
user_id="user-id",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
)
|
||||
|
||||
result = node_factory.DifyNodeFactory._resolve_dify_context({DIFY_RUN_CONTEXT_KEY: dify_context})
|
||||
|
||||
assert result is dify_context
|
||||
|
||||
def test_validates_mapping_context(self):
|
||||
raw_context = {
|
||||
DIFY_RUN_CONTEXT_KEY: {
|
||||
"tenant_id": "tenant-id",
|
||||
"app_id": "app-id",
|
||||
"user_id": "user-id",
|
||||
"user_from": UserFrom.ACCOUNT,
|
||||
"invoke_from": InvokeFrom.DEBUGGER,
|
||||
}
|
||||
}
|
||||
|
||||
result = node_factory.DifyNodeFactory._resolve_dify_context(raw_context)
|
||||
|
||||
assert isinstance(result, DifyRunContext)
|
||||
assert result.tenant_id == "tenant-id"
|
||||
|
||||
|
||||
class TestDifyNodeFactoryCreateNode:
|
||||
@pytest.fixture
|
||||
def factory(self):
|
||||
factory = object.__new__(node_factory.DifyNodeFactory)
|
||||
factory.graph_init_params = sentinel.graph_init_params
|
||||
factory.graph_runtime_state = sentinel.graph_runtime_state
|
||||
factory._dify_context = SimpleNamespace(tenant_id="tenant-id", app_id="app-id")
|
||||
factory._code_executor = sentinel.code_executor
|
||||
factory._code_limits = sentinel.code_limits
|
||||
factory._template_renderer = sentinel.template_renderer
|
||||
factory._template_transform_max_output_length = 2048
|
||||
factory._http_request_http_client = sentinel.http_client
|
||||
factory._http_request_tool_file_manager_factory = sentinel.tool_file_manager_factory
|
||||
factory._http_request_file_manager = sentinel.file_manager
|
||||
factory._rag_retrieval = sentinel.rag_retrieval
|
||||
factory._document_extractor_unstructured_api_config = sentinel.unstructured_api_config
|
||||
factory._http_request_config = sentinel.http_request_config
|
||||
factory._llm_credentials_provider = sentinel.credentials_provider
|
||||
factory._llm_model_factory = sentinel.model_factory
|
||||
return factory
|
||||
|
||||
def test_rejects_unknown_node_type(self, factory):
|
||||
with pytest.raises(ValueError, match="Input should be"):
|
||||
factory.create_node({"id": "node-id", "data": {"type": "missing"}})
|
||||
|
||||
def test_rejects_missing_class_mapping(self, monkeypatch, factory):
|
||||
monkeypatch.setattr(node_factory, "NODE_TYPE_CLASSES_MAPPING", {})
|
||||
|
||||
with pytest.raises(ValueError, match="No class mapping found for node type: start"):
|
||||
factory.create_node({"id": "node-id", "data": {"type": NodeType.START.value}})
|
||||
|
||||
def test_rejects_missing_latest_class(self, monkeypatch, factory):
|
||||
monkeypatch.setattr(
|
||||
node_factory,
|
||||
"NODE_TYPE_CLASSES_MAPPING",
|
||||
{NodeType.START: {node_factory.LATEST_VERSION: None}},
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="No latest version class found for node type: start"):
|
||||
factory.create_node({"id": "node-id", "data": {"type": NodeType.START.value}})
|
||||
|
||||
def test_uses_version_specific_class_when_available(self, monkeypatch, factory):
|
||||
matched_node = sentinel.matched_node
|
||||
latest_node_class = MagicMock(return_value=sentinel.latest_node)
|
||||
matched_node_class = MagicMock(return_value=matched_node)
|
||||
monkeypatch.setattr(
|
||||
node_factory,
|
||||
"NODE_TYPE_CLASSES_MAPPING",
|
||||
{
|
||||
NodeType.START: {
|
||||
node_factory.LATEST_VERSION: latest_node_class,
|
||||
"9": matched_node_class,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
result = factory.create_node({"id": "node-id", "data": {"type": NodeType.START.value, "version": "9"}})
|
||||
|
||||
assert result is matched_node
|
||||
matched_node_class.assert_called_once()
|
||||
kwargs = matched_node_class.call_args.kwargs
|
||||
assert kwargs["id"] == "node-id"
|
||||
_assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=NodeType.START, version="9")
|
||||
assert kwargs["graph_init_params"] is sentinel.graph_init_params
|
||||
assert kwargs["graph_runtime_state"] is sentinel.graph_runtime_state
|
||||
latest_node_class.assert_not_called()
|
||||
|
||||
def test_falls_back_to_latest_class_when_version_specific_mapping_is_missing(self, monkeypatch, factory):
|
||||
latest_node = sentinel.latest_node
|
||||
latest_node_class = MagicMock(return_value=latest_node)
|
||||
monkeypatch.setattr(
|
||||
node_factory,
|
||||
"NODE_TYPE_CLASSES_MAPPING",
|
||||
{NodeType.START: {node_factory.LATEST_VERSION: latest_node_class}},
|
||||
)
|
||||
|
||||
result = factory.create_node({"id": "node-id", "data": {"type": NodeType.START.value, "version": "9"}})
|
||||
|
||||
assert result is latest_node
|
||||
latest_node_class.assert_called_once()
|
||||
kwargs = latest_node_class.call_args.kwargs
|
||||
assert kwargs["id"] == "node-id"
|
||||
_assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=NodeType.START, version="9")
|
||||
assert kwargs["graph_init_params"] is sentinel.graph_init_params
|
||||
assert kwargs["graph_runtime_state"] is sentinel.graph_runtime_state
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("node_type", "constructor_name"),
|
||||
[
|
||||
(NodeType.CODE, "CodeNode"),
|
||||
(NodeType.TEMPLATE_TRANSFORM, "TemplateTransformNode"),
|
||||
(NodeType.HTTP_REQUEST, "HttpRequestNode"),
|
||||
(NodeType.HUMAN_INPUT, "HumanInputNode"),
|
||||
(NodeType.KNOWLEDGE_INDEX, "KnowledgeIndexNode"),
|
||||
(NodeType.DATASOURCE, "DatasourceNode"),
|
||||
(NodeType.KNOWLEDGE_RETRIEVAL, "KnowledgeRetrievalNode"),
|
||||
(NodeType.DOCUMENT_EXTRACTOR, "DocumentExtractorNode"),
|
||||
],
|
||||
)
|
||||
return DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state)
|
||||
def test_creates_specialized_nodes(self, monkeypatch, factory, node_type, constructor_name):
|
||||
created_node = object()
|
||||
constructor = MagicMock(name=constructor_name, return_value=created_node)
|
||||
monkeypatch.setattr(
|
||||
node_factory,
|
||||
"NODE_TYPE_CLASSES_MAPPING",
|
||||
{node_type: {node_factory.LATEST_VERSION: constructor}},
|
||||
)
|
||||
|
||||
if constructor_name == "HumanInputNode":
|
||||
form_repository = sentinel.form_repository
|
||||
form_repository_impl = MagicMock(return_value=form_repository)
|
||||
monkeypatch.setattr(
|
||||
node_factory,
|
||||
"HumanInputFormRepositoryImpl",
|
||||
form_repository_impl,
|
||||
)
|
||||
elif constructor_name == "KnowledgeIndexNode":
|
||||
index_processor = sentinel.index_processor
|
||||
summary_index = sentinel.summary_index
|
||||
monkeypatch.setattr(node_factory, "IndexProcessor", MagicMock(return_value=index_processor))
|
||||
monkeypatch.setattr(node_factory, "SummaryIndex", MagicMock(return_value=summary_index))
|
||||
|
||||
node_config = {"id": "node-id", "data": {"type": node_type.value}}
|
||||
result = factory.create_node(node_config)
|
||||
|
||||
assert result is created_node
|
||||
kwargs = constructor.call_args.kwargs
|
||||
assert kwargs["id"] == "node-id"
|
||||
_assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=node_type)
|
||||
assert kwargs["graph_init_params"] is sentinel.graph_init_params
|
||||
assert kwargs["graph_runtime_state"] is sentinel.graph_runtime_state
|
||||
|
||||
if constructor_name == "CodeNode":
|
||||
assert kwargs["code_executor"] is sentinel.code_executor
|
||||
assert kwargs["code_limits"] is sentinel.code_limits
|
||||
elif constructor_name == "TemplateTransformNode":
|
||||
assert kwargs["template_renderer"] is sentinel.template_renderer
|
||||
assert kwargs["max_output_length"] == 2048
|
||||
elif constructor_name == "HttpRequestNode":
|
||||
assert kwargs["http_request_config"] is sentinel.http_request_config
|
||||
assert kwargs["http_client"] is sentinel.http_client
|
||||
assert kwargs["tool_file_manager_factory"] is sentinel.tool_file_manager_factory
|
||||
assert kwargs["file_manager"] is sentinel.file_manager
|
||||
elif constructor_name == "HumanInputNode":
|
||||
assert kwargs["form_repository"] is form_repository
|
||||
form_repository_impl.assert_called_once_with(tenant_id="tenant-id")
|
||||
elif constructor_name == "KnowledgeIndexNode":
|
||||
assert kwargs["index_processor"] is index_processor
|
||||
assert kwargs["summary_index_service"] is summary_index
|
||||
elif constructor_name == "DatasourceNode":
|
||||
assert kwargs["datasource_manager"] is node_factory.DatasourceManager
|
||||
elif constructor_name == "KnowledgeRetrievalNode":
|
||||
assert kwargs["rag_retrieval"] is sentinel.rag_retrieval
|
||||
elif constructor_name == "DocumentExtractorNode":
|
||||
assert kwargs["unstructured_api_config"] is sentinel.unstructured_api_config
|
||||
assert kwargs["http_client"] is sentinel.http_client
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("node_type", "constructor_name", "expected_extra_kwargs"),
|
||||
[
|
||||
(NodeType.LLM, "LLMNode", {"http_client": sentinel.http_client}),
|
||||
(NodeType.QUESTION_CLASSIFIER, "QuestionClassifierNode", {"http_client": sentinel.http_client}),
|
||||
(NodeType.PARAMETER_EXTRACTOR, "ParameterExtractorNode", {}),
|
||||
],
|
||||
)
|
||||
def test_creates_model_backed_nodes(
|
||||
self,
|
||||
monkeypatch,
|
||||
factory,
|
||||
node_type,
|
||||
constructor_name,
|
||||
expected_extra_kwargs,
|
||||
):
|
||||
created_node = object()
|
||||
constructor = MagicMock(name=constructor_name, return_value=created_node)
|
||||
monkeypatch.setattr(
|
||||
node_factory,
|
||||
"NODE_TYPE_CLASSES_MAPPING",
|
||||
{node_type: {node_factory.LATEST_VERSION: constructor}},
|
||||
)
|
||||
llm_init_kwargs = {
|
||||
"credentials_provider": sentinel.credentials_provider,
|
||||
"model_factory": sentinel.model_factory,
|
||||
"model_instance": sentinel.model_instance,
|
||||
"memory": sentinel.memory,
|
||||
**expected_extra_kwargs,
|
||||
}
|
||||
build_llm_init_kwargs = MagicMock(return_value=llm_init_kwargs)
|
||||
factory._build_llm_compatible_node_init_kwargs = build_llm_init_kwargs
|
||||
|
||||
node_config = {"id": "node-id", "data": {"type": node_type.value}}
|
||||
result = factory.create_node(node_config)
|
||||
|
||||
assert result is created_node
|
||||
build_llm_init_kwargs.assert_called_once()
|
||||
helper_kwargs = build_llm_init_kwargs.call_args.kwargs
|
||||
assert helper_kwargs["node_class"] is constructor
|
||||
assert isinstance(helper_kwargs["node_data"], BaseNodeData)
|
||||
assert helper_kwargs["node_data"].type == node_type
|
||||
assert helper_kwargs["include_http_client"] is (node_type != NodeType.PARAMETER_EXTRACTOR)
|
||||
|
||||
constructor_kwargs = constructor.call_args.kwargs
|
||||
assert constructor_kwargs["id"] == "node-id"
|
||||
_assert_typed_node_config(constructor_kwargs["config"], node_id="node-id", node_type=node_type)
|
||||
assert constructor_kwargs["graph_init_params"] is sentinel.graph_init_params
|
||||
assert constructor_kwargs["graph_runtime_state"] is sentinel.graph_runtime_state
|
||||
assert constructor_kwargs["credentials_provider"] is sentinel.credentials_provider
|
||||
assert constructor_kwargs["model_factory"] is sentinel.model_factory
|
||||
assert constructor_kwargs["model_instance"] is sentinel.model_instance
|
||||
assert constructor_kwargs["memory"] is sentinel.memory
|
||||
for key, value in expected_extra_kwargs.items():
|
||||
assert constructor_kwargs[key] is value
|
||||
|
||||
|
||||
def test_create_node_uses_declared_node_data_type_for_llm_validation(monkeypatch):
|
||||
class _FactoryLLMNodeData(LLMNodeData):
|
||||
pass
|
||||
class TestDifyNodeFactoryModelInstance:
|
||||
@pytest.fixture
|
||||
def factory(self):
|
||||
factory = object.__new__(node_factory.DifyNodeFactory)
|
||||
factory._llm_credentials_provider = MagicMock()
|
||||
factory._llm_model_factory = MagicMock()
|
||||
return factory
|
||||
|
||||
llm_node_config = {
|
||||
"id": "llm-node",
|
||||
"data": {
|
||||
"type": "llm",
|
||||
"title": "LLM",
|
||||
"model": {
|
||||
"provider": "openai",
|
||||
"name": "gpt-4o-mini",
|
||||
"mode": "chat",
|
||||
"completion_params": {},
|
||||
},
|
||||
"prompt_template": [],
|
||||
"context": {
|
||||
"enabled": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
graph_config = {"nodes": [llm_node_config], "edges": []}
|
||||
factory = _build_factory(graph_config)
|
||||
captured: dict[str, object] = {}
|
||||
@pytest.fixture
|
||||
def llm_model_setup(self, factory):
|
||||
def _configure(
|
||||
*,
|
||||
completion_params=None,
|
||||
has_provider_model=True,
|
||||
model_schema=sentinel.model_schema,
|
||||
):
|
||||
credentials = {"api_key": "secret"}
|
||||
node_data_model = SimpleNamespace(
|
||||
provider="provider",
|
||||
name="model",
|
||||
mode="chat",
|
||||
completion_params=completion_params or {},
|
||||
)
|
||||
node_data = SimpleNamespace(model=node_data_model)
|
||||
provider_model = MagicMock() if has_provider_model else None
|
||||
provider_model_bundle = SimpleNamespace(
|
||||
configuration=SimpleNamespace(get_provider_model=MagicMock(return_value=provider_model))
|
||||
)
|
||||
model_type_instance = MagicMock()
|
||||
model_type_instance.get_model_schema.return_value = model_schema
|
||||
model_instance = SimpleNamespace(
|
||||
provider_model_bundle=provider_model_bundle,
|
||||
model_type_instance=model_type_instance,
|
||||
provider=None,
|
||||
model_name=None,
|
||||
credentials=None,
|
||||
parameters=None,
|
||||
stop=None,
|
||||
)
|
||||
factory._llm_credentials_provider.fetch.return_value = credentials
|
||||
factory._llm_model_factory.init_model_instance.return_value = model_instance
|
||||
return SimpleNamespace(
|
||||
node_data=node_data,
|
||||
credentials=credentials,
|
||||
provider_model=provider_model,
|
||||
model_type_instance=model_type_instance,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(LLMNode, "_node_data_type", _FactoryLLMNodeData)
|
||||
return _configure
|
||||
|
||||
def _capture_model_instance(self: DifyNodeFactory, node_data: object) -> ModelInstance:
|
||||
captured["node_data"] = node_data
|
||||
return object() # type: ignore[return-value]
|
||||
def test_requires_llm_mode(self, factory):
|
||||
node_data = SimpleNamespace(
|
||||
model=SimpleNamespace(
|
||||
provider="provider",
|
||||
name="model",
|
||||
mode="",
|
||||
completion_params={},
|
||||
)
|
||||
)
|
||||
|
||||
def _capture_memory(
|
||||
self: DifyNodeFactory,
|
||||
*,
|
||||
node_data: object,
|
||||
model_instance: ModelInstance,
|
||||
) -> None:
|
||||
captured["memory_node_data"] = node_data
|
||||
with pytest.raises(node_factory.LLMModeRequiredError, match="LLM mode is required"):
|
||||
factory._build_model_instance_for_llm_node(node_data)
|
||||
|
||||
monkeypatch.setattr(DifyNodeFactory, "_build_model_instance_for_llm_node", _capture_model_instance)
|
||||
monkeypatch.setattr(DifyNodeFactory, "_build_memory_for_llm_node", _capture_memory)
|
||||
def test_raises_when_provider_model_is_missing(self, factory, llm_model_setup):
|
||||
setup = llm_model_setup(has_provider_model=False)
|
||||
|
||||
node = factory.create_node(llm_node_config)
|
||||
with pytest.raises(node_factory.ModelNotExistError, match="Model model not exist"):
|
||||
factory._build_model_instance_for_llm_node(setup.node_data)
|
||||
|
||||
assert isinstance(captured["node_data"], _FactoryLLMNodeData)
|
||||
assert isinstance(captured["memory_node_data"], _FactoryLLMNodeData)
|
||||
assert isinstance(node.node_data, _FactoryLLMNodeData)
|
||||
def test_raises_when_model_schema_is_missing(self, factory, llm_model_setup):
|
||||
setup = llm_model_setup(model_schema=None)
|
||||
|
||||
with pytest.raises(node_factory.ModelNotExistError, match="Model model not exist"):
|
||||
factory._build_model_instance_for_llm_node(setup.node_data)
|
||||
|
||||
setup.provider_model.raise_for_status.assert_called_once()
|
||||
|
||||
def test_builds_model_instance_and_normalizes_stop_tokens(self, factory, llm_model_setup):
|
||||
setup = llm_model_setup(
|
||||
completion_params={"temperature": 0.3, "stop": "not-a-list"},
|
||||
model_schema={"schema": "value"},
|
||||
)
|
||||
|
||||
result = factory._build_model_instance_for_llm_node(setup.node_data)
|
||||
|
||||
assert result is setup.model_instance
|
||||
assert result.provider == "provider"
|
||||
assert result.model_name == "model"
|
||||
assert result.credentials == setup.credentials
|
||||
assert result.parameters == {"temperature": 0.3}
|
||||
assert result.stop == ()
|
||||
assert result.model_type_instance is setup.model_type_instance
|
||||
setup.provider_model.raise_for_status.assert_called_once()
|
||||
|
||||
|
||||
class TestDifyNodeFactoryMemory:
|
||||
@pytest.fixture
|
||||
def factory(self):
|
||||
factory = object.__new__(node_factory.DifyNodeFactory)
|
||||
factory._dify_context = SimpleNamespace(app_id="app-id")
|
||||
factory.graph_runtime_state = SimpleNamespace(variable_pool=MagicMock())
|
||||
return factory
|
||||
|
||||
def test_returns_none_when_memory_is_not_configured(self, factory):
|
||||
result = factory._build_memory_for_llm_node(
|
||||
node_data=SimpleNamespace(memory=None),
|
||||
model_instance=sentinel.model_instance,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
factory.graph_runtime_state.variable_pool.get.assert_not_called()
|
||||
|
||||
def test_uses_string_segment_conversation_id(self, monkeypatch, factory):
|
||||
memory_config = sentinel.memory_config
|
||||
factory.graph_runtime_state.variable_pool.get.return_value = StringSegment(value="conversation-id")
|
||||
fetch_memory = MagicMock(return_value=sentinel.memory)
|
||||
monkeypatch.setattr(node_factory, "fetch_memory", fetch_memory)
|
||||
|
||||
result = factory._build_memory_for_llm_node(
|
||||
node_data=SimpleNamespace(memory=memory_config),
|
||||
model_instance=sentinel.model_instance,
|
||||
)
|
||||
|
||||
assert result is sentinel.memory
|
||||
factory.graph_runtime_state.variable_pool.get.assert_called_once_with(
|
||||
["sys", SystemVariableKey.CONVERSATION_ID]
|
||||
)
|
||||
fetch_memory.assert_called_once_with(
|
||||
conversation_id="conversation-id",
|
||||
app_id="app-id",
|
||||
node_data_memory=memory_config,
|
||||
model_instance=sentinel.model_instance,
|
||||
)
|
||||
|
||||
def test_ignores_non_string_segment_conversation_ids(self, monkeypatch, factory):
|
||||
memory_config = sentinel.memory_config
|
||||
factory.graph_runtime_state.variable_pool.get.return_value = sentinel.segment
|
||||
fetch_memory = MagicMock(return_value=sentinel.memory)
|
||||
monkeypatch.setattr(node_factory, "fetch_memory", fetch_memory)
|
||||
|
||||
result = factory._build_memory_for_llm_node(
|
||||
node_data=SimpleNamespace(memory=memory_config),
|
||||
model_instance=sentinel.model_instance,
|
||||
)
|
||||
|
||||
assert result is sentinel.memory
|
||||
fetch_memory.assert_called_once_with(
|
||||
conversation_id=None,
|
||||
app_id="app-id",
|
||||
node_data_memory=memory_config,
|
||||
model_instance=sentinel.model_instance,
|
||||
)
|
||||
|
||||
@ -0,0 +1,656 @@
|
||||
from collections import UserString
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch, sentinel
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
|
||||
from core.workflow import workflow_entry
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
from dify_graph.errors import WorkflowNodeRunFailedError
|
||||
from dify_graph.file.enums import FileTransferMethod, FileType
|
||||
from dify_graph.file.models import File
|
||||
from dify_graph.graph_events import GraphRunFailedEvent
|
||||
from dify_graph.nodes import NodeType
|
||||
from dify_graph.runtime import ChildGraphNotFoundError
|
||||
|
||||
|
||||
def _build_typed_node_config(node_type: NodeType):
|
||||
return NodeConfigDictAdapter.validate_python({"id": "node-id", "data": {"type": node_type}})
|
||||
|
||||
|
||||
class TestWorkflowChildEngineBuilder:
|
||||
@pytest.mark.parametrize(
|
||||
("graph_config", "node_id", "expected"),
|
||||
[
|
||||
({"nodes": [{"id": "root"}]}, "root", True),
|
||||
({"nodes": [{"id": "root"}]}, "other", False),
|
||||
({"nodes": "invalid"}, "root", None),
|
||||
({"nodes": ["invalid"]}, "root", None),
|
||||
],
|
||||
)
|
||||
def test_has_node_id(self, graph_config, node_id, expected):
|
||||
result = workflow_entry._WorkflowChildEngineBuilder._has_node_id(graph_config, node_id)
|
||||
|
||||
assert result is expected
|
||||
|
||||
def test_build_child_engine_raises_when_root_node_is_missing(self):
|
||||
builder = workflow_entry._WorkflowChildEngineBuilder()
|
||||
|
||||
with patch.object(workflow_entry, "DifyNodeFactory", return_value=sentinel.factory):
|
||||
with pytest.raises(ChildGraphNotFoundError, match="child graph root node 'missing' not found"):
|
||||
builder.build_child_engine(
|
||||
workflow_id="workflow-id",
|
||||
graph_init_params=sentinel.graph_init_params,
|
||||
graph_runtime_state=sentinel.graph_runtime_state,
|
||||
graph_config={"nodes": []},
|
||||
root_node_id="missing",
|
||||
)
|
||||
|
||||
def test_build_child_engine_constructs_graph_engine_and_layers(self):
|
||||
builder = workflow_entry._WorkflowChildEngineBuilder()
|
||||
child_graph = sentinel.child_graph
|
||||
child_engine = MagicMock()
|
||||
quota_layer = sentinel.quota_layer
|
||||
additional_layers = [sentinel.layer_one, sentinel.layer_two]
|
||||
|
||||
with (
|
||||
patch.object(workflow_entry, "DifyNodeFactory", return_value=sentinel.factory) as dify_node_factory,
|
||||
patch.object(workflow_entry.Graph, "init", return_value=child_graph) as graph_init,
|
||||
patch.object(workflow_entry, "GraphEngine", return_value=child_engine) as graph_engine_cls,
|
||||
patch.object(workflow_entry, "GraphEngineConfig", return_value=sentinel.graph_engine_config),
|
||||
patch.object(workflow_entry, "InMemoryChannel", return_value=sentinel.command_channel),
|
||||
patch.object(workflow_entry, "LLMQuotaLayer", return_value=quota_layer),
|
||||
):
|
||||
result = builder.build_child_engine(
|
||||
workflow_id="workflow-id",
|
||||
graph_init_params=sentinel.graph_init_params,
|
||||
graph_runtime_state=sentinel.graph_runtime_state,
|
||||
graph_config={"nodes": [{"id": "root"}]},
|
||||
root_node_id="root",
|
||||
layers=additional_layers,
|
||||
)
|
||||
|
||||
assert result is child_engine
|
||||
dify_node_factory.assert_called_once_with(
|
||||
graph_init_params=sentinel.graph_init_params,
|
||||
graph_runtime_state=sentinel.graph_runtime_state,
|
||||
)
|
||||
graph_init.assert_called_once_with(
|
||||
graph_config={"nodes": [{"id": "root"}]},
|
||||
node_factory=sentinel.factory,
|
||||
root_node_id="root",
|
||||
)
|
||||
graph_engine_cls.assert_called_once_with(
|
||||
workflow_id="workflow-id",
|
||||
graph=child_graph,
|
||||
graph_runtime_state=sentinel.graph_runtime_state,
|
||||
command_channel=sentinel.command_channel,
|
||||
config=sentinel.graph_engine_config,
|
||||
child_engine_builder=builder,
|
||||
)
|
||||
assert child_engine.layer.call_args_list == [
|
||||
((quota_layer,), {}),
|
||||
((sentinel.layer_one,), {}),
|
||||
((sentinel.layer_two,), {}),
|
||||
]
|
||||
|
||||
|
||||
class TestWorkflowEntryInit:
|
||||
def test_rejects_call_depth_above_limit(self):
|
||||
call_depth = workflow_entry.dify_config.WORKFLOW_CALL_MAX_DEPTH + 1
|
||||
|
||||
with pytest.raises(ValueError, match="Max workflow call depth"):
|
||||
workflow_entry.WorkflowEntry(
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
workflow_id="workflow-id",
|
||||
graph_config={"nodes": [], "edges": []},
|
||||
graph=sentinel.graph,
|
||||
user_id="user-id",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=call_depth,
|
||||
variable_pool=sentinel.variable_pool,
|
||||
graph_runtime_state=sentinel.graph_runtime_state,
|
||||
)
|
||||
|
||||
def test_applies_debug_and_observability_layers(self):
|
||||
graph_engine = MagicMock()
|
||||
debug_layer = sentinel.debug_layer
|
||||
execution_limits_layer = sentinel.execution_limits_layer
|
||||
llm_quota_layer = sentinel.llm_quota_layer
|
||||
observability_layer = sentinel.observability_layer
|
||||
|
||||
with (
|
||||
patch.object(workflow_entry.dify_config, "DEBUG", True),
|
||||
patch.object(workflow_entry.dify_config, "ENABLE_OTEL", False),
|
||||
patch.object(workflow_entry, "is_instrument_flag_enabled", return_value=True),
|
||||
patch.object(workflow_entry, "GraphEngine", return_value=graph_engine) as graph_engine_cls,
|
||||
patch.object(workflow_entry, "GraphEngineConfig", return_value=sentinel.graph_engine_config),
|
||||
patch.object(workflow_entry, "InMemoryChannel", return_value=sentinel.command_channel),
|
||||
patch.object(workflow_entry, "DebugLoggingLayer", return_value=debug_layer) as debug_logging_layer,
|
||||
patch.object(
|
||||
workflow_entry,
|
||||
"ExecutionLimitsLayer",
|
||||
return_value=execution_limits_layer,
|
||||
) as execution_limits_layer_cls,
|
||||
patch.object(workflow_entry, "LLMQuotaLayer", return_value=llm_quota_layer),
|
||||
patch.object(workflow_entry, "ObservabilityLayer", return_value=observability_layer),
|
||||
):
|
||||
entry = workflow_entry.WorkflowEntry(
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
workflow_id="workflow-id-123456",
|
||||
graph_config={"nodes": [], "edges": []},
|
||||
graph=sentinel.graph,
|
||||
user_id="user-id",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
variable_pool=sentinel.variable_pool,
|
||||
graph_runtime_state=sentinel.graph_runtime_state,
|
||||
command_channel=None,
|
||||
)
|
||||
|
||||
assert entry.command_channel is sentinel.command_channel
|
||||
graph_engine_cls.assert_called_once_with(
|
||||
workflow_id="workflow-id-123456",
|
||||
graph=sentinel.graph,
|
||||
graph_runtime_state=sentinel.graph_runtime_state,
|
||||
command_channel=sentinel.command_channel,
|
||||
config=sentinel.graph_engine_config,
|
||||
child_engine_builder=entry._child_engine_builder,
|
||||
)
|
||||
debug_logging_layer.assert_called_once_with(
|
||||
level="DEBUG",
|
||||
include_inputs=True,
|
||||
include_outputs=True,
|
||||
include_process_data=False,
|
||||
logger_name="GraphEngine.Debug.workflow",
|
||||
)
|
||||
execution_limits_layer_cls.assert_called_once_with(
|
||||
max_steps=workflow_entry.dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
|
||||
max_time=workflow_entry.dify_config.WORKFLOW_MAX_EXECUTION_TIME,
|
||||
)
|
||||
assert graph_engine.layer.call_args_list == [
|
||||
((debug_layer,), {}),
|
||||
((execution_limits_layer,), {}),
|
||||
((llm_quota_layer,), {}),
|
||||
((observability_layer,), {}),
|
||||
]
|
||||
|
||||
|
||||
class TestWorkflowEntryRun:
|
||||
def test_run_swallows_generate_task_stopped_errors(self):
|
||||
entry = object.__new__(workflow_entry.WorkflowEntry)
|
||||
entry.graph_engine = MagicMock()
|
||||
entry.graph_engine.run.side_effect = GenerateTaskStoppedError()
|
||||
|
||||
assert list(entry.run()) == []
|
||||
|
||||
def test_run_emits_failed_event_for_unexpected_errors(self):
|
||||
entry = object.__new__(workflow_entry.WorkflowEntry)
|
||||
entry.graph_engine = MagicMock()
|
||||
entry.graph_engine.run.side_effect = RuntimeError("boom")
|
||||
|
||||
events = list(entry.run())
|
||||
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], GraphRunFailedEvent)
|
||||
assert events[0].error == "boom"
|
||||
|
||||
|
||||
class TestWorkflowEntrySingleStepRun:
|
||||
def test_uses_empty_mapping_when_selector_extraction_is_not_implemented(self):
|
||||
class FakeNode:
|
||||
id = "node-id"
|
||||
title = "Node Title"
|
||||
node_type = "fake"
|
||||
|
||||
@staticmethod
|
||||
def version():
|
||||
return "1"
|
||||
|
||||
@staticmethod
|
||||
def extract_variable_selector_to_variable_mapping(**_kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
with (
|
||||
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
|
||||
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
|
||||
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
|
||||
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
|
||||
patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory,
|
||||
patch.object(workflow_entry, "load_into_variable_pool") as load_into_variable_pool,
|
||||
patch.object(
|
||||
workflow_entry.WorkflowEntry,
|
||||
"mapping_user_inputs_to_variable_pool",
|
||||
) as mapping_user_inputs_to_variable_pool,
|
||||
patch.object(
|
||||
workflow_entry.WorkflowEntry,
|
||||
"_traced_node_run",
|
||||
return_value=iter(["event"]),
|
||||
),
|
||||
):
|
||||
dify_node_factory.return_value.create_node.return_value = FakeNode()
|
||||
workflow = SimpleNamespace(
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
id="workflow-id",
|
||||
graph_dict={"nodes": [], "edges": []},
|
||||
get_node_config_by_id=lambda _node_id: _build_typed_node_config(NodeType.START),
|
||||
)
|
||||
|
||||
node, generator = workflow_entry.WorkflowEntry.single_step_run(
|
||||
workflow=workflow,
|
||||
node_id="node-id",
|
||||
user_id="user-id",
|
||||
user_inputs={"question": "hello"},
|
||||
variable_pool=sentinel.variable_pool,
|
||||
)
|
||||
|
||||
assert node.id == "node-id"
|
||||
assert list(generator) == ["event"]
|
||||
load_into_variable_pool.assert_called_once_with(
|
||||
variable_loader=workflow_entry.DUMMY_VARIABLE_LOADER,
|
||||
variable_pool=sentinel.variable_pool,
|
||||
variable_mapping={},
|
||||
user_inputs={"question": "hello"},
|
||||
)
|
||||
mapping_user_inputs_to_variable_pool.assert_called_once_with(
|
||||
variable_mapping={},
|
||||
user_inputs={"question": "hello"},
|
||||
variable_pool=sentinel.variable_pool,
|
||||
tenant_id="tenant-id",
|
||||
)
|
||||
|
||||
def test_skips_user_input_mapping_for_datasource_nodes(self):
|
||||
class FakeDatasourceNode:
|
||||
id = "node-id"
|
||||
node_type = "datasource"
|
||||
|
||||
@staticmethod
|
||||
def version():
|
||||
return "1"
|
||||
|
||||
@staticmethod
|
||||
def extract_variable_selector_to_variable_mapping(**_kwargs):
|
||||
return {"question": ["node", "question"]}
|
||||
|
||||
with (
|
||||
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
|
||||
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
|
||||
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
|
||||
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
|
||||
patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory,
|
||||
patch.object(workflow_entry, "load_into_variable_pool") as load_into_variable_pool,
|
||||
patch.object(
|
||||
workflow_entry.WorkflowEntry,
|
||||
"mapping_user_inputs_to_variable_pool",
|
||||
) as mapping_user_inputs_to_variable_pool,
|
||||
patch.object(
|
||||
workflow_entry.WorkflowEntry,
|
||||
"_traced_node_run",
|
||||
return_value=iter(["event"]),
|
||||
),
|
||||
):
|
||||
dify_node_factory.return_value.create_node.return_value = FakeDatasourceNode()
|
||||
workflow = SimpleNamespace(
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
id="workflow-id",
|
||||
graph_dict={"nodes": [], "edges": []},
|
||||
get_node_config_by_id=lambda _node_id: _build_typed_node_config(NodeType.DATASOURCE),
|
||||
)
|
||||
|
||||
node, generator = workflow_entry.WorkflowEntry.single_step_run(
|
||||
workflow=workflow,
|
||||
node_id="node-id",
|
||||
user_id="user-id",
|
||||
user_inputs={"question": "hello"},
|
||||
variable_pool=sentinel.variable_pool,
|
||||
)
|
||||
|
||||
assert node.id == "node-id"
|
||||
assert list(generator) == ["event"]
|
||||
load_into_variable_pool.assert_called_once()
|
||||
mapping_user_inputs_to_variable_pool.assert_not_called()
|
||||
|
||||
def test_wraps_traced_node_run_failures(self):
|
||||
class FakeNode:
|
||||
id = "node-id"
|
||||
title = "Node Title"
|
||||
node_type = "fake"
|
||||
|
||||
@staticmethod
|
||||
def extract_variable_selector_to_variable_mapping(**_kwargs):
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def version():
|
||||
return "1"
|
||||
|
||||
with (
|
||||
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
|
||||
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
|
||||
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
|
||||
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
|
||||
patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory,
|
||||
patch.object(workflow_entry, "load_into_variable_pool"),
|
||||
patch.object(workflow_entry.WorkflowEntry, "mapping_user_inputs_to_variable_pool"),
|
||||
patch.object(
|
||||
workflow_entry.WorkflowEntry,
|
||||
"_traced_node_run",
|
||||
side_effect=RuntimeError("boom"),
|
||||
),
|
||||
):
|
||||
dify_node_factory.return_value.create_node.return_value = FakeNode()
|
||||
workflow = SimpleNamespace(
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
id="workflow-id",
|
||||
graph_dict={"nodes": [], "edges": []},
|
||||
get_node_config_by_id=lambda _node_id: _build_typed_node_config(NodeType.START),
|
||||
)
|
||||
|
||||
with pytest.raises(WorkflowNodeRunFailedError):
|
||||
workflow_entry.WorkflowEntry.single_step_run(
|
||||
workflow=workflow,
|
||||
node_id="node-id",
|
||||
user_id="user-id",
|
||||
user_inputs={},
|
||||
variable_pool=sentinel.variable_pool,
|
||||
)
|
||||
|
||||
|
||||
class TestWorkflowEntryHelpers:
|
||||
def test_create_single_node_graph_builds_start_edge(self):
|
||||
graph = workflow_entry.WorkflowEntry._create_single_node_graph(
|
||||
node_id="target-node",
|
||||
node_data={"type": NodeType.PARAMETER_EXTRACTOR},
|
||||
node_width=320,
|
||||
node_height=180,
|
||||
)
|
||||
|
||||
assert graph["nodes"][0]["id"] == "start"
|
||||
assert graph["nodes"][1]["id"] == "target-node"
|
||||
assert graph["nodes"][1]["width"] == 320
|
||||
assert graph["nodes"][1]["height"] == 180
|
||||
assert graph["edges"] == [
|
||||
{
|
||||
"source": "start",
|
||||
"target": "target-node",
|
||||
"sourceHandle": "source",
|
||||
"targetHandle": "target",
|
||||
}
|
||||
]
|
||||
|
||||
def test_run_free_node_rejects_unsupported_types(self):
|
||||
with pytest.raises(ValueError, match="Node type start not supported"):
|
||||
workflow_entry.WorkflowEntry.run_free_node(
|
||||
node_data={"type": NodeType.START.value},
|
||||
node_id="node-id",
|
||||
tenant_id="tenant-id",
|
||||
user_id="user-id",
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
def test_run_free_node_rejects_missing_node_class(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
workflow_entry,
|
||||
"NODE_TYPE_CLASSES_MAPPING",
|
||||
{NodeType.PARAMETER_EXTRACTOR: {"1": None}},
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Node class not found for node type parameter-extractor"):
|
||||
workflow_entry.WorkflowEntry.run_free_node(
|
||||
node_data={"type": NodeType.PARAMETER_EXTRACTOR.value},
|
||||
node_id="node-id",
|
||||
tenant_id="tenant-id",
|
||||
user_id="user-id",
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
def test_run_free_node_uses_empty_mapping_when_selector_extraction_is_not_implemented(self, monkeypatch):
|
||||
class FakeNodeClass:
|
||||
@staticmethod
|
||||
def extract_variable_selector_to_variable_mapping(**_kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
class FakeNode:
|
||||
id = "node-id"
|
||||
title = "Node Title"
|
||||
node_type = "parameter-extractor"
|
||||
|
||||
@staticmethod
|
||||
def version():
|
||||
return "1"
|
||||
|
||||
dify_node_factory = MagicMock()
|
||||
dify_node_factory.create_node.return_value = FakeNode()
|
||||
monkeypatch.setattr(
|
||||
workflow_entry,
|
||||
"NODE_TYPE_CLASSES_MAPPING",
|
||||
{NodeType.PARAMETER_EXTRACTOR: {"1": FakeNodeClass}},
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(workflow_entry.SystemVariable, "default", return_value=sentinel.system_variables),
|
||||
patch.object(workflow_entry, "VariablePool", return_value=sentinel.variable_pool) as variable_pool_cls,
|
||||
patch.object(
|
||||
workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params
|
||||
) as graph_init_params,
|
||||
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
|
||||
patch.object(
|
||||
workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}
|
||||
) as build_dify_run_context,
|
||||
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
|
||||
patch.object(workflow_entry, "DifyNodeFactory", return_value=dify_node_factory) as dify_node_factory_cls,
|
||||
patch.object(
|
||||
workflow_entry.WorkflowEntry,
|
||||
"mapping_user_inputs_to_variable_pool",
|
||||
) as mapping_user_inputs_to_variable_pool,
|
||||
patch.object(
|
||||
workflow_entry.WorkflowEntry,
|
||||
"_traced_node_run",
|
||||
return_value=iter(["event"]),
|
||||
),
|
||||
):
|
||||
node, generator = workflow_entry.WorkflowEntry.run_free_node(
|
||||
node_data={"type": NodeType.PARAMETER_EXTRACTOR.value, "title": "Node"},
|
||||
node_id="node-id",
|
||||
tenant_id="tenant-id",
|
||||
user_id="user-id",
|
||||
user_inputs={"question": "hello"},
|
||||
)
|
||||
|
||||
assert node.id == "node-id"
|
||||
assert list(generator) == ["event"]
|
||||
variable_pool_cls.assert_called_once_with(
|
||||
system_variables=sentinel.system_variables,
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
)
|
||||
build_dify_run_context.assert_called_once_with(
|
||||
tenant_id="tenant-id",
|
||||
app_id="",
|
||||
user_id="user-id",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
)
|
||||
graph_init_params.assert_called_once_with(
|
||||
workflow_id="",
|
||||
graph_config=workflow_entry.WorkflowEntry._create_single_node_graph(
|
||||
"node-id", {"type": NodeType.PARAMETER_EXTRACTOR.value, "title": "Node"}
|
||||
),
|
||||
run_context={"_dify": "context"},
|
||||
call_depth=0,
|
||||
)
|
||||
dify_node_factory_cls.assert_called_once_with(
|
||||
graph_init_params=sentinel.graph_init_params,
|
||||
graph_runtime_state=sentinel.graph_runtime_state,
|
||||
)
|
||||
mapping_user_inputs_to_variable_pool.assert_called_once_with(
|
||||
variable_mapping={},
|
||||
user_inputs={"question": "hello"},
|
||||
variable_pool=sentinel.variable_pool,
|
||||
tenant_id="tenant-id",
|
||||
)
|
||||
|
||||
def test_run_free_node_wraps_execution_failures(self, monkeypatch):
|
||||
class FakeNodeClass:
|
||||
@staticmethod
|
||||
def extract_variable_selector_to_variable_mapping(**_kwargs):
|
||||
return {}
|
||||
|
||||
class FakeNode:
|
||||
id = "node-id"
|
||||
title = "Node Title"
|
||||
node_type = "parameter-extractor"
|
||||
|
||||
@staticmethod
|
||||
def version():
|
||||
return "1"
|
||||
|
||||
dify_node_factory = MagicMock()
|
||||
dify_node_factory.create_node.return_value = FakeNode()
|
||||
monkeypatch.setattr(
|
||||
workflow_entry,
|
||||
"NODE_TYPE_CLASSES_MAPPING",
|
||||
{NodeType.PARAMETER_EXTRACTOR: {"1": FakeNodeClass}},
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(workflow_entry.SystemVariable, "default", return_value=sentinel.system_variables),
|
||||
patch.object(workflow_entry, "VariablePool", return_value=sentinel.variable_pool),
|
||||
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
|
||||
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
|
||||
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
|
||||
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
|
||||
patch.object(workflow_entry, "DifyNodeFactory", return_value=dify_node_factory),
|
||||
patch.object(
|
||||
workflow_entry.WorkflowEntry,
|
||||
"mapping_user_inputs_to_variable_pool",
|
||||
side_effect=RuntimeError("boom"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(WorkflowNodeRunFailedError, match="Node Title run failed: boom"):
|
||||
workflow_entry.WorkflowEntry.run_free_node(
|
||||
node_data={"type": NodeType.PARAMETER_EXTRACTOR.value, "title": "Node"},
|
||||
node_id="node-id",
|
||||
tenant_id="tenant-id",
|
||||
user_id="user-id",
|
||||
user_inputs={"question": "hello"},
|
||||
)
|
||||
|
||||
def test_handle_special_values_serializes_nested_files(self):
|
||||
file = File(
|
||||
tenant_id="tenant-id",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/image.png",
|
||||
filename="image.png",
|
||||
extension=".png",
|
||||
)
|
||||
|
||||
result = workflow_entry.WorkflowEntry.handle_special_values({"file": file, "nested": {"files": [file]}})
|
||||
|
||||
assert result == {
|
||||
"file": file.to_dict(),
|
||||
"nested": {"files": [file.to_dict()]},
|
||||
}
|
||||
|
||||
def test_handle_special_values_returns_none_for_none(self):
|
||||
assert workflow_entry.WorkflowEntry._handle_special_values(None) is None
|
||||
|
||||
def test_handle_special_values_returns_scalar_as_is(self):
|
||||
assert workflow_entry.WorkflowEntry._handle_special_values("plain-text") == "plain-text"
|
||||
|
||||
|
||||
class TestMappingUserInputsBranches:
|
||||
def test_rejects_invalid_node_variable_key(self):
|
||||
class EmptySplitKey(UserString):
|
||||
def split(self, _sep=None):
|
||||
return []
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid node variable broken"):
|
||||
workflow_entry.WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping={EmptySplitKey("broken"): ["node", "input"]},
|
||||
user_inputs={},
|
||||
variable_pool=MagicMock(),
|
||||
tenant_id="tenant-id",
|
||||
)
|
||||
|
||||
def test_skips_none_user_input_when_variable_already_exists(self):
|
||||
variable_pool = MagicMock()
|
||||
variable_pool.get.return_value = None
|
||||
|
||||
workflow_entry.WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping={"node.input": ["target", "input"]},
|
||||
user_inputs={"node.input": None},
|
||||
variable_pool=variable_pool,
|
||||
tenant_id="tenant-id",
|
||||
)
|
||||
|
||||
variable_pool.add.assert_not_called()
|
||||
|
||||
def test_merges_structured_output_values(self):
|
||||
variable_pool = MagicMock()
|
||||
variable_pool.get.side_effect = [
|
||||
None,
|
||||
SimpleNamespace(value={"existing": "value"}),
|
||||
]
|
||||
|
||||
workflow_entry.WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping={"node.answer": ["target", "structured_output", "answer"]},
|
||||
user_inputs={"node.answer": "new-value"},
|
||||
variable_pool=variable_pool,
|
||||
tenant_id="tenant-id",
|
||||
)
|
||||
|
||||
variable_pool.add.assert_called_once_with(
|
||||
["target", "structured_output"],
|
||||
{"existing": "value", "answer": "new-value"},
|
||||
)
|
||||
|
||||
|
||||
class TestWorkflowEntryTracing:
|
||||
def test_traced_node_run_reports_success(self):
|
||||
layer = MagicMock()
|
||||
|
||||
class FakeNode:
|
||||
def ensure_execution_id(self):
|
||||
return None
|
||||
|
||||
def run(self):
|
||||
yield "event"
|
||||
|
||||
with patch.object(workflow_entry, "ObservabilityLayer", return_value=layer):
|
||||
events = list(workflow_entry.WorkflowEntry._traced_node_run(FakeNode()))
|
||||
|
||||
assert events == ["event"]
|
||||
layer.on_graph_start.assert_called_once_with()
|
||||
layer.on_node_run_start.assert_called_once()
|
||||
layer.on_node_run_end.assert_called_once_with(
|
||||
layer.on_node_run_start.call_args.args[0],
|
||||
None,
|
||||
)
|
||||
|
||||
def test_traced_node_run_reports_errors(self):
|
||||
layer = MagicMock()
|
||||
|
||||
class FakeNode:
|
||||
def ensure_execution_id(self):
|
||||
return None
|
||||
|
||||
def run(self):
|
||||
raise RuntimeError("boom")
|
||||
yield
|
||||
|
||||
with patch.object(workflow_entry, "ObservabilityLayer", return_value=layer):
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
list(workflow_entry.WorkflowEntry._traced_node_run(FakeNode()))
|
||||
|
||||
assert isinstance(layer.on_node_run_end.call_args.args[1], RuntimeError)
|
||||
@ -0,0 +1,157 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def jina_module() -> ModuleType:
|
||||
"""
|
||||
Load `api/services/auth/jina.py` as a standalone module.
|
||||
|
||||
This repo contains both `services/auth/jina.py` and a package at
|
||||
`services/auth/jina/`, so importing `services.auth.jina` can be ambiguous.
|
||||
"""
|
||||
|
||||
module_path = Path(__file__).resolve().parents[4] / "services" / "auth" / "jina.py"
|
||||
# Use a stable module name so pytest-cov can target it with `--cov=services.auth.jina_file`.
|
||||
spec = importlib.util.spec_from_file_location("services.auth.jina_file", module_path)
|
||||
assert spec is not None
|
||||
assert spec.loader is not None
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[spec.name] = module
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
def _credentials(api_key: str | None = "test_api_key_123", auth_type: str = "bearer") -> dict:
|
||||
config: dict = {} if api_key is None else {"api_key": api_key}
|
||||
return {"auth_type": auth_type, "config": config}
|
||||
|
||||
|
||||
def test_init_valid_bearer_credentials(jina_module: ModuleType) -> None:
|
||||
auth = jina_module.JinaAuth(_credentials())
|
||||
assert auth.api_key == "test_api_key_123"
|
||||
assert auth.credentials["auth_type"] == "bearer"
|
||||
|
||||
|
||||
def test_init_rejects_invalid_auth_type(jina_module: ModuleType) -> None:
|
||||
with pytest.raises(ValueError, match="Invalid auth type.*Bearer"):
|
||||
jina_module.JinaAuth(_credentials(auth_type="basic"))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("credentials", [{"auth_type": "bearer", "config": {}}, {"auth_type": "bearer"}])
|
||||
def test_init_requires_api_key(jina_module: ModuleType, credentials: dict) -> None:
|
||||
with pytest.raises(ValueError, match="No API key provided"):
|
||||
jina_module.JinaAuth(credentials)
|
||||
|
||||
|
||||
def test_prepare_headers_includes_bearer_api_key(jina_module: ModuleType) -> None:
|
||||
auth = jina_module.JinaAuth(_credentials(api_key="k"))
|
||||
assert auth._prepare_headers() == {"Content-Type": "application/json", "Authorization": "Bearer k"}
|
||||
|
||||
|
||||
def test_post_request_calls_httpx(jina_module: ModuleType, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
auth = jina_module.JinaAuth(_credentials(api_key="k"))
|
||||
post_mock = MagicMock(name="httpx.post")
|
||||
monkeypatch.setattr(jina_module.httpx, "post", post_mock)
|
||||
|
||||
auth._post_request("https://r.jina.ai", {"url": "https://example.com"}, {"h": "v"})
|
||||
post_mock.assert_called_once_with("https://r.jina.ai", headers={"h": "v"}, json={"url": "https://example.com"})
|
||||
|
||||
|
||||
def test_validate_credentials_success(jina_module: ModuleType, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
auth = jina_module.JinaAuth(_credentials(api_key="k"))
|
||||
|
||||
response = MagicMock()
|
||||
response.status_code = 200
|
||||
post_mock = MagicMock(return_value=response)
|
||||
monkeypatch.setattr(jina_module.httpx, "post", post_mock)
|
||||
|
||||
assert auth.validate_credentials() is True
|
||||
post_mock.assert_called_once_with(
|
||||
"https://r.jina.ai",
|
||||
headers={"Content-Type": "application/json", "Authorization": "Bearer k"},
|
||||
json={"url": "https://example.com"},
|
||||
)
|
||||
|
||||
|
||||
def test_validate_credentials_non_200_raises_via_handle_error(
|
||||
jina_module: ModuleType, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
auth = jina_module.JinaAuth(_credentials(api_key="k"))
|
||||
|
||||
response = MagicMock()
|
||||
response.status_code = 402
|
||||
response.json.return_value = {"error": "Payment required"}
|
||||
monkeypatch.setattr(jina_module.httpx, "post", MagicMock(return_value=response))
|
||||
|
||||
with pytest.raises(Exception, match="Status code: 402.*Payment required"):
|
||||
auth.validate_credentials()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("status_code", [402, 409, 500])
|
||||
def test_handle_error_statuses_use_response_json(jina_module: ModuleType, status_code: int) -> None:
|
||||
auth = jina_module.JinaAuth(_credentials(api_key="k"))
|
||||
response = MagicMock()
|
||||
response.status_code = status_code
|
||||
response.json.return_value = {"error": "boom"}
|
||||
|
||||
with pytest.raises(Exception, match=f"Status code: {status_code}.*boom"):
|
||||
auth._handle_error(response)
|
||||
|
||||
|
||||
def test_handle_error_statuses_default_unknown_error(jina_module: ModuleType) -> None:
|
||||
auth = jina_module.JinaAuth(_credentials(api_key="k"))
|
||||
response = MagicMock()
|
||||
response.status_code = 402
|
||||
response.json.return_value = {}
|
||||
|
||||
with pytest.raises(Exception, match="Unknown error occurred"):
|
||||
auth._handle_error(response)
|
||||
|
||||
|
||||
def test_handle_error_with_text_json_body(jina_module: ModuleType) -> None:
|
||||
auth = jina_module.JinaAuth(_credentials(api_key="k"))
|
||||
response = MagicMock()
|
||||
response.status_code = 403
|
||||
response.text = '{"error": "Forbidden"}'
|
||||
|
||||
with pytest.raises(Exception, match="Status code: 403.*Forbidden"):
|
||||
auth._handle_error(response)
|
||||
|
||||
|
||||
def test_handle_error_with_text_json_body_missing_error(jina_module: ModuleType) -> None:
|
||||
auth = jina_module.JinaAuth(_credentials(api_key="k"))
|
||||
response = MagicMock()
|
||||
response.status_code = 403
|
||||
response.text = "{}"
|
||||
|
||||
with pytest.raises(Exception, match="Unknown error occurred"):
|
||||
auth._handle_error(response)
|
||||
|
||||
|
||||
def test_handle_error_without_text_raises_unexpected(jina_module: ModuleType) -> None:
|
||||
auth = jina_module.JinaAuth(_credentials(api_key="k"))
|
||||
response = MagicMock()
|
||||
response.status_code = 404
|
||||
response.text = ""
|
||||
|
||||
with pytest.raises(Exception, match="Unexpected error occurred.*404"):
|
||||
auth._handle_error(response)
|
||||
|
||||
|
||||
def test_validate_credentials_propagates_network_errors(
|
||||
jina_module: ModuleType, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
auth = jina_module.JinaAuth(_credentials(api_key="k"))
|
||||
monkeypatch.setattr(jina_module.httpx, "post", MagicMock(side_effect=httpx.ConnectError("boom")))
|
||||
|
||||
with pytest.raises(httpx.ConnectError, match="boom"):
|
||||
auth.validate_credentials()
|
||||
381
api/tests/unit_tests/services/test_ops_service.py
Normal file
381
api/tests/unit_tests/services/test_ops_service.py
Normal file
@ -0,0 +1,381 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.ops.entities.config_entity import TracingProviderEnum
|
||||
from models.model import App, TraceAppConfig
|
||||
from services.ops_service import OpsService
|
||||
|
||||
|
||||
class TestOpsService:
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_get_tracing_app_config_no_config(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Act
|
||||
result = OpsService.get_tracing_app_config("app_id", "arize")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
mock_db.session.query.assert_called_with(TraceAppConfig)
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_get_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
trace_config = MagicMock(spec=TraceAppConfig)
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, None]
|
||||
|
||||
# Act
|
||||
result = OpsService.get_tracing_app_config("app_id", "arize")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
assert mock_db.session.query.call_count == 2
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_get_tracing_app_config_none_config(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
trace_config = MagicMock(spec=TraceAppConfig)
|
||||
trace_config.tracing_config = None
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app]
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Tracing config cannot be None."):
|
||||
OpsService.get_tracing_app_config("app_id", "arize")
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
@pytest.mark.parametrize(
|
||||
("provider", "default_url"),
|
||||
[
|
||||
("arize", "https://app.arize.com/"),
|
||||
("phoenix", "https://app.phoenix.arize.com/projects/"),
|
||||
("langsmith", "https://smith.langchain.com/"),
|
||||
("opik", "https://www.comet.com/opik/"),
|
||||
("weave", "https://wandb.ai/"),
|
||||
("aliyun", "https://arms.console.aliyun.com/"),
|
||||
("tencent", "https://console.cloud.tencent.com/apm"),
|
||||
("mlflow", "http://localhost:5000/"),
|
||||
("databricks", "https://www.databricks.com/"),
|
||||
],
|
||||
)
|
||||
def test_get_tracing_app_config_providers_exception(self, mock_ops_trace_manager, mock_db, provider, default_url):
|
||||
# Arrange
|
||||
trace_config = MagicMock(spec=TraceAppConfig)
|
||||
trace_config.tracing_config = {"some": "config"}
|
||||
trace_config.to_dict.return_value = {"tracing_config": {"project_url": default_url}}
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app]
|
||||
|
||||
mock_ops_trace_manager.decrypt_tracing_config.return_value = {}
|
||||
mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {}
|
||||
mock_ops_trace_manager.get_trace_config_project_url.side_effect = Exception("error")
|
||||
mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error")
|
||||
|
||||
# Act
|
||||
result = OpsService.get_tracing_app_config("app_id", provider)
|
||||
|
||||
# Assert
|
||||
assert result["tracing_config"]["project_url"] == default_url
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
@pytest.mark.parametrize(
|
||||
"provider", ["arize", "phoenix", "langsmith", "opik", "weave", "aliyun", "tencent", "mlflow", "databricks"]
|
||||
)
|
||||
def test_get_tracing_app_config_providers_success(self, mock_ops_trace_manager, mock_db, provider):
|
||||
# Arrange
|
||||
trace_config = MagicMock(spec=TraceAppConfig)
|
||||
trace_config.tracing_config = {"some": "config"}
|
||||
trace_config.to_dict.return_value = {"tracing_config": {"project_url": "success_url"}}
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app]
|
||||
|
||||
mock_ops_trace_manager.decrypt_tracing_config.return_value = {}
|
||||
mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {}
|
||||
mock_ops_trace_manager.get_trace_config_project_url.return_value = "success_url"
|
||||
|
||||
# Act
|
||||
result = OpsService.get_tracing_app_config("app_id", provider)
|
||||
|
||||
# Assert
|
||||
assert result["tracing_config"]["project_url"] == "success_url"
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_get_tracing_app_config_langfuse_success(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
trace_config = MagicMock(spec=TraceAppConfig)
|
||||
trace_config.tracing_config = {"some": "config"}
|
||||
trace_config.to_dict.return_value = {"tracing_config": {"project_url": "https://api.langfuse.com/project/key"}}
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app]
|
||||
|
||||
mock_ops_trace_manager.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"}
|
||||
mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"}
|
||||
mock_ops_trace_manager.get_trace_config_project_key.return_value = "key"
|
||||
|
||||
# Act
|
||||
result = OpsService.get_tracing_app_config("app_id", "langfuse")
|
||||
|
||||
# Assert
|
||||
assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/project/key"
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_get_tracing_app_config_langfuse_exception(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
trace_config = MagicMock(spec=TraceAppConfig)
|
||||
trace_config.tracing_config = {"some": "config"}
|
||||
trace_config.to_dict.return_value = {"tracing_config": {"project_url": "https://api.langfuse.com/"}}
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app]
|
||||
|
||||
mock_ops_trace_manager.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"}
|
||||
mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"}
|
||||
mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error")
|
||||
|
||||
# Act
|
||||
result = OpsService.get_tracing_app_config("app_id", "langfuse")
|
||||
|
||||
# Assert
|
||||
assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/"
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_create_tracing_app_config_invalid_provider(self, mock_ops_trace_manager, mock_db):
|
||||
# Act
|
||||
result = OpsService.create_tracing_app_config("app_id", "invalid_provider", {})
|
||||
|
||||
# Assert
|
||||
assert result == {"error": "Invalid tracing provider: invalid_provider"}
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_create_tracing_app_config_invalid_credentials(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.LANGFUSE
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = False
|
||||
|
||||
# Act
|
||||
result = OpsService.create_tracing_app_config("app_id", provider, {"public_key": "p", "secret_key": "s"})
|
||||
|
||||
# Assert
|
||||
assert result == {"error": "Invalid Credentials"}
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
@pytest.mark.parametrize(
|
||||
("provider", "config"),
|
||||
[
|
||||
(TracingProviderEnum.ARIZE, {}),
|
||||
(TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"}),
|
||||
(TracingProviderEnum.LANGSMITH, {"api_key": "k", "project": "p"}),
|
||||
(TracingProviderEnum.ALIYUN, {"license_key": "k", "endpoint": "https://aliyun.com"}),
|
||||
],
|
||||
)
|
||||
def test_create_tracing_app_config_project_url_exception(self, mock_ops_trace_manager, mock_db, provider, config):
|
||||
# Arrange
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
||||
mock_ops_trace_manager.get_trace_config_project_url.side_effect = Exception("error")
|
||||
mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error")
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock(spec=TraceAppConfig)
|
||||
|
||||
# Act
|
||||
result = OpsService.create_tracing_app_config("app_id", provider, config)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_create_tracing_app_config_langfuse_success(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.LANGFUSE
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
||||
mock_ops_trace_manager.get_trace_config_project_key.return_value = "key"
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [None, app]
|
||||
mock_ops_trace_manager.encrypt_tracing_config.return_value = {}
|
||||
|
||||
# Act
|
||||
result = OpsService.create_tracing_app_config(
|
||||
"app_id", provider, {"public_key": "p", "secret_key": "s", "host": "https://api.langfuse.com"}
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == {"result": "success"}
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_create_tracing_app_config_already_exists(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.ARIZE
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock(spec=TraceAppConfig)
|
||||
|
||||
# Act
|
||||
result = OpsService.create_tracing_app_config("app_id", provider, {})
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_create_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.ARIZE
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [None, None]
|
||||
|
||||
# Act
|
||||
result = OpsService.create_tracing_app_config("app_id", provider, {})
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_create_tracing_app_config_with_empty_other_keys(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.ARIZE
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [None, app]
|
||||
mock_ops_trace_manager.encrypt_tracing_config.return_value = {}
|
||||
|
||||
# Act
|
||||
# 'project' is in other_keys for Arize
|
||||
# provide an empty string for the project in the tracing_config
|
||||
# create_tracing_app_config will replace it with the default from the model
|
||||
result = OpsService.create_tracing_app_config("app_id", provider, {"project": ""})
|
||||
|
||||
# Assert
|
||||
assert result == {"result": "success"}
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_create_tracing_app_config_success(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.ARIZE
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
||||
mock_ops_trace_manager.get_trace_config_project_url.return_value = "http://project_url"
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [None, app]
|
||||
mock_ops_trace_manager.encrypt_tracing_config.return_value = {"encrypted": "config"}
|
||||
|
||||
# Act
|
||||
result = OpsService.create_tracing_app_config("app_id", provider, {})
|
||||
|
||||
# Assert
|
||||
assert result == {"result": "success"}
|
||||
mock_db.session.add.assert_called()
|
||||
mock_db.session.commit.assert_called()
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_update_tracing_app_config_invalid_provider(self, mock_ops_trace_manager, mock_db):
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Invalid tracing provider: invalid_provider"):
|
||||
OpsService.update_tracing_app_config("app_id", "invalid_provider", {})
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_update_tracing_app_config_no_config(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.ARIZE
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Act
|
||||
result = OpsService.update_tracing_app_config("app_id", provider, {})
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_update_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.ARIZE
|
||||
current_config = MagicMock(spec=TraceAppConfig)
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [current_config, None]
|
||||
|
||||
# Act
|
||||
result = OpsService.update_tracing_app_config("app_id", provider, {})
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_update_tracing_app_config_invalid_credentials(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.ARIZE
|
||||
current_config = MagicMock(spec=TraceAppConfig)
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [current_config, app]
|
||||
mock_ops_trace_manager.decrypt_tracing_config.return_value = {}
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = False
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Invalid Credentials"):
|
||||
OpsService.update_tracing_app_config("app_id", provider, {})
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_update_tracing_app_config_success(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.ARIZE
|
||||
current_config = MagicMock(spec=TraceAppConfig)
|
||||
current_config.to_dict.return_value = {"some": "data"}
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [current_config, app]
|
||||
mock_ops_trace_manager.decrypt_tracing_config.return_value = {}
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
||||
|
||||
# Act
|
||||
result = OpsService.update_tracing_app_config("app_id", provider, {})
|
||||
|
||||
# Assert
|
||||
assert result == {"some": "data"}
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
def test_delete_tracing_app_config_no_config(self, mock_db):
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Act
|
||||
result = OpsService.delete_tracing_app_config("app_id", "arize")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
def test_delete_tracing_app_config_success(self, mock_db):
|
||||
# Arrange
|
||||
trace_config = MagicMock(spec=TraceAppConfig)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = trace_config
|
||||
|
||||
# Act
|
||||
result = OpsService.delete_tracing_app_config("app_id", "arize")
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_db.session.delete.assert_called_with(trace_config)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
1329
api/tests/unit_tests/services/test_summary_index_service.py
Normal file
1329
api/tests/unit_tests/services/test_summary_index_service.py
Normal file
File diff suppressed because it is too large
Load Diff
704
api/tests/unit_tests/services/test_vector_service.py
Normal file
704
api/tests/unit_tests/services/test_vector_service.py
Normal file
@ -0,0 +1,704 @@
|
||||
"""Unit tests for `api/services/vector_service.py`."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import services.vector_service as vector_service_module
|
||||
from services.vector_service import VectorService
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _UploadFileStub:
|
||||
id: str
|
||||
name: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _ChildDocStub:
|
||||
page_content: str
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ParentDocStub:
|
||||
children: list[_ChildDocStub]
|
||||
|
||||
|
||||
def _make_dataset(
|
||||
*,
|
||||
indexing_technique: str = "high_quality",
|
||||
doc_form: str = "text_model",
|
||||
tenant_id: str = "tenant-1",
|
||||
dataset_id: str = "dataset-1",
|
||||
is_multimodal: bool = False,
|
||||
embedding_model_provider: str | None = "openai",
|
||||
embedding_model: str = "text-embedding",
|
||||
) -> MagicMock:
|
||||
dataset = MagicMock(name="dataset")
|
||||
dataset.id = dataset_id
|
||||
dataset.tenant_id = tenant_id
|
||||
dataset.doc_form = doc_form
|
||||
dataset.indexing_technique = indexing_technique
|
||||
dataset.is_multimodal = is_multimodal
|
||||
dataset.embedding_model_provider = embedding_model_provider
|
||||
dataset.embedding_model = embedding_model
|
||||
return dataset
|
||||
|
||||
|
||||
def _make_segment(
|
||||
*,
|
||||
segment_id: str = "seg-1",
|
||||
tenant_id: str = "tenant-1",
|
||||
dataset_id: str = "dataset-1",
|
||||
document_id: str = "doc-1",
|
||||
content: str = "hello",
|
||||
index_node_id: str = "node-1",
|
||||
index_node_hash: str = "hash-1",
|
||||
attachments: list[dict[str, str]] | None = None,
|
||||
) -> MagicMock:
|
||||
segment = MagicMock(name="segment")
|
||||
segment.id = segment_id
|
||||
segment.tenant_id = tenant_id
|
||||
segment.dataset_id = dataset_id
|
||||
segment.document_id = document_id
|
||||
segment.content = content
|
||||
segment.index_node_id = index_node_id
|
||||
segment.index_node_hash = index_node_hash
|
||||
segment.attachments = attachments or []
|
||||
return segment
|
||||
|
||||
|
||||
def _mock_db_session_for_update_multimodel(*, upload_files: list[_UploadFileStub] | None) -> MagicMock:
|
||||
session = MagicMock(name="session")
|
||||
|
||||
binding_query = MagicMock(name="binding_query")
|
||||
binding_query.where.return_value = binding_query
|
||||
binding_query.delete.return_value = 1
|
||||
|
||||
upload_query = MagicMock(name="upload_query")
|
||||
upload_query.where.return_value = upload_query
|
||||
upload_query.all.return_value = upload_files or []
|
||||
|
||||
def query_side_effect(model: object) -> MagicMock:
|
||||
if model is vector_service_module.SegmentAttachmentBinding:
|
||||
return binding_query
|
||||
if model is vector_service_module.UploadFile:
|
||||
return upload_query
|
||||
return MagicMock(name=f"query({model})")
|
||||
|
||||
session.query.side_effect = query_side_effect
|
||||
db_mock = MagicMock(name="db")
|
||||
db_mock.session = session
|
||||
return db_mock
|
||||
|
||||
|
||||
def test_create_segments_vector_regular_indexing_loads_documents_and_keywords(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
dataset = _make_dataset(is_multimodal=False)
|
||||
segment = _make_segment()
|
||||
|
||||
index_processor = MagicMock(name="index_processor")
|
||||
factory_instance = MagicMock(name="IndexProcessorFactory-instance")
|
||||
factory_instance.init_index_processor.return_value = index_processor
|
||||
monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance))
|
||||
|
||||
VectorService.create_segments_vector([["k1"]], [segment], dataset, "text_model")
|
||||
|
||||
index_processor.load.assert_called_once()
|
||||
args, kwargs = index_processor.load.call_args
|
||||
assert args[0] == dataset
|
||||
assert len(args[1]) == 1
|
||||
assert args[2] is None
|
||||
assert kwargs["with_keywords"] is True
|
||||
assert kwargs["keywords_list"] == [["k1"]]
|
||||
|
||||
|
||||
def test_create_segments_vector_regular_indexing_loads_multimodal_documents(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
dataset = _make_dataset(is_multimodal=True)
|
||||
segment = _make_segment(
|
||||
attachments=[
|
||||
{"id": "img-1", "name": "a.png"},
|
||||
{"id": "img-2", "name": "b.png"},
|
||||
]
|
||||
)
|
||||
|
||||
index_processor = MagicMock(name="index_processor")
|
||||
factory_instance = MagicMock(name="IndexProcessorFactory-instance")
|
||||
factory_instance.init_index_processor.return_value = index_processor
|
||||
monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance))
|
||||
|
||||
VectorService.create_segments_vector([["k1"]], [segment], dataset, "text_model")
|
||||
|
||||
assert index_processor.load.call_count == 2
|
||||
first_args, first_kwargs = index_processor.load.call_args_list[0]
|
||||
assert first_args[0] == dataset
|
||||
assert len(first_args[1]) == 1
|
||||
assert first_kwargs["with_keywords"] is True
|
||||
|
||||
second_args, second_kwargs = index_processor.load.call_args_list[1]
|
||||
assert second_args[0] == dataset
|
||||
assert second_args[1] == []
|
||||
assert len(second_args[2]) == 2
|
||||
assert second_kwargs["with_keywords"] is False
|
||||
|
||||
|
||||
def test_create_segments_vector_with_no_segments_does_not_load(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
dataset = _make_dataset()
|
||||
index_processor = MagicMock(name="index_processor")
|
||||
factory_instance = MagicMock()
|
||||
factory_instance.init_index_processor.return_value = index_processor
|
||||
monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance))
|
||||
|
||||
VectorService.create_segments_vector(None, [], dataset, "text_model")
|
||||
index_processor.load.assert_not_called()
|
||||
|
||||
|
||||
def _mock_parent_child_queries(
|
||||
*,
|
||||
dataset_document: object | None,
|
||||
processing_rule: object | None,
|
||||
) -> MagicMock:
|
||||
session = MagicMock(name="session")
|
||||
|
||||
doc_query = MagicMock(name="doc_query")
|
||||
doc_query.filter_by.return_value = doc_query
|
||||
doc_query.first.return_value = dataset_document
|
||||
|
||||
rule_query = MagicMock(name="rule_query")
|
||||
rule_query.where.return_value = rule_query
|
||||
rule_query.first.return_value = processing_rule
|
||||
|
||||
def query_side_effect(model: object) -> MagicMock:
|
||||
if model is vector_service_module.DatasetDocument:
|
||||
return doc_query
|
||||
if model is vector_service_module.DatasetProcessRule:
|
||||
return rule_query
|
||||
return MagicMock(name=f"query({model})")
|
||||
|
||||
session.query.side_effect = query_side_effect
|
||||
db_mock = MagicMock(name="db")
|
||||
db_mock.session = session
|
||||
return db_mock
|
||||
|
||||
|
||||
def test_create_segments_vector_parent_child_calls_generate_child_chunks_with_explicit_model(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
dataset = _make_dataset(
|
||||
doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX,
|
||||
embedding_model_provider="openai",
|
||||
indexing_technique="high_quality",
|
||||
)
|
||||
segment = _make_segment()
|
||||
|
||||
dataset_document = MagicMock(name="dataset_document")
|
||||
dataset_document.id = segment.document_id
|
||||
dataset_document.dataset_process_rule_id = "rule-1"
|
||||
dataset_document.doc_language = "en"
|
||||
dataset_document.created_by = "user-1"
|
||||
|
||||
processing_rule = MagicMock(name="processing_rule")
|
||||
processing_rule.to_dict.return_value = {"rules": {}}
|
||||
|
||||
monkeypatch.setattr(
|
||||
vector_service_module,
|
||||
"db",
|
||||
_mock_parent_child_queries(dataset_document=dataset_document, processing_rule=processing_rule),
|
||||
)
|
||||
|
||||
embedding_model_instance = MagicMock(name="embedding_model_instance")
|
||||
model_manager_instance = MagicMock(name="model_manager_instance")
|
||||
model_manager_instance.get_model_instance.return_value = embedding_model_instance
|
||||
monkeypatch.setattr(vector_service_module, "ModelManager", MagicMock(return_value=model_manager_instance))
|
||||
|
||||
generate_child_chunks_mock = MagicMock()
|
||||
monkeypatch.setattr(VectorService, "generate_child_chunks", generate_child_chunks_mock)
|
||||
|
||||
index_processor = MagicMock()
|
||||
factory_instance = MagicMock()
|
||||
factory_instance.init_index_processor.return_value = index_processor
|
||||
monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance))
|
||||
|
||||
VectorService.create_segments_vector(
|
||||
None, [segment], dataset, vector_service_module.IndexStructureType.PARENT_CHILD_INDEX
|
||||
)
|
||||
|
||||
model_manager_instance.get_model_instance.assert_called_once()
|
||||
generate_child_chunks_mock.assert_called_once_with(
|
||||
segment, dataset_document, dataset, embedding_model_instance, processing_rule, False
|
||||
)
|
||||
index_processor.load.assert_not_called()
|
||||
|
||||
|
||||
def test_create_segments_vector_parent_child_uses_default_embedding_model_when_provider_missing(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
dataset = _make_dataset(
|
||||
doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX,
|
||||
embedding_model_provider=None,
|
||||
indexing_technique="high_quality",
|
||||
)
|
||||
segment = _make_segment()
|
||||
|
||||
dataset_document = MagicMock()
|
||||
dataset_document.dataset_process_rule_id = "rule-1"
|
||||
dataset_document.doc_language = "en"
|
||||
dataset_document.created_by = "user-1"
|
||||
|
||||
processing_rule = MagicMock()
|
||||
processing_rule.to_dict.return_value = {"rules": {}}
|
||||
|
||||
monkeypatch.setattr(
|
||||
vector_service_module,
|
||||
"db",
|
||||
_mock_parent_child_queries(dataset_document=dataset_document, processing_rule=processing_rule),
|
||||
)
|
||||
|
||||
embedding_model_instance = MagicMock()
|
||||
model_manager_instance = MagicMock()
|
||||
model_manager_instance.get_default_model_instance.return_value = embedding_model_instance
|
||||
monkeypatch.setattr(vector_service_module, "ModelManager", MagicMock(return_value=model_manager_instance))
|
||||
|
||||
generate_child_chunks_mock = MagicMock()
|
||||
monkeypatch.setattr(VectorService, "generate_child_chunks", generate_child_chunks_mock)
|
||||
|
||||
index_processor = MagicMock()
|
||||
factory_instance = MagicMock()
|
||||
factory_instance.init_index_processor.return_value = index_processor
|
||||
monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance))
|
||||
|
||||
VectorService.create_segments_vector(
|
||||
None, [segment], dataset, vector_service_module.IndexStructureType.PARENT_CHILD_INDEX
|
||||
)
|
||||
|
||||
model_manager_instance.get_default_model_instance.assert_called_once()
|
||||
generate_child_chunks_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_create_segments_vector_parent_child_missing_document_logs_warning_and_continues(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
dataset = _make_dataset(doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX)
|
||||
segment = _make_segment()
|
||||
|
||||
processing_rule = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
vector_service_module,
|
||||
"db",
|
||||
_mock_parent_child_queries(dataset_document=None, processing_rule=processing_rule),
|
||||
)
|
||||
|
||||
logger_mock = MagicMock()
|
||||
monkeypatch.setattr(vector_service_module, "logger", logger_mock)
|
||||
|
||||
index_processor = MagicMock()
|
||||
factory_instance = MagicMock()
|
||||
factory_instance.init_index_processor.return_value = index_processor
|
||||
monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance))
|
||||
|
||||
VectorService.create_segments_vector(
|
||||
None, [segment], dataset, vector_service_module.IndexStructureType.PARENT_CHILD_INDEX
|
||||
)
|
||||
logger_mock.warning.assert_called_once()
|
||||
index_processor.load.assert_not_called()
|
||||
|
||||
|
||||
def test_create_segments_vector_parent_child_missing_processing_rule_raises(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
dataset = _make_dataset(doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX)
|
||||
segment = _make_segment()
|
||||
|
||||
dataset_document = MagicMock()
|
||||
dataset_document.dataset_process_rule_id = "rule-1"
|
||||
monkeypatch.setattr(
|
||||
vector_service_module,
|
||||
"db",
|
||||
_mock_parent_child_queries(dataset_document=dataset_document, processing_rule=None),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="No processing rule found"):
|
||||
VectorService.create_segments_vector(
|
||||
None, [segment], dataset, vector_service_module.IndexStructureType.PARENT_CHILD_INDEX
|
||||
)
|
||||
|
||||
|
||||
def test_create_segments_vector_parent_child_non_high_quality_raises(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
dataset = _make_dataset(
|
||||
doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX,
|
||||
indexing_technique="economy",
|
||||
)
|
||||
segment = _make_segment()
|
||||
dataset_document = MagicMock()
|
||||
dataset_document.dataset_process_rule_id = "rule-1"
|
||||
processing_rule = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
vector_service_module,
|
||||
"db",
|
||||
_mock_parent_child_queries(dataset_document=dataset_document, processing_rule=processing_rule),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="not high quality"):
|
||||
VectorService.create_segments_vector(
|
||||
None, [segment], dataset, vector_service_module.IndexStructureType.PARENT_CHILD_INDEX
|
||||
)
|
||||
|
||||
|
||||
def test_update_segment_vector_high_quality_uses_vector(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
dataset = _make_dataset(indexing_technique="high_quality")
|
||||
segment = _make_segment()
|
||||
|
||||
vector_instance = MagicMock()
|
||||
monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance))
|
||||
|
||||
VectorService.update_segment_vector(["k"], segment, dataset)
|
||||
|
||||
vector_instance.delete_by_ids.assert_called_once_with([segment.index_node_id])
|
||||
vector_instance.add_texts.assert_called_once()
|
||||
add_args, add_kwargs = vector_instance.add_texts.call_args
|
||||
assert len(add_args[0]) == 1
|
||||
assert add_kwargs["duplicate_check"] is True
|
||||
|
||||
|
||||
def test_update_segment_vector_economy_uses_keyword_with_keywords_list(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
dataset = _make_dataset(indexing_technique="economy")
|
||||
segment = _make_segment()
|
||||
|
||||
keyword_instance = MagicMock()
|
||||
monkeypatch.setattr(vector_service_module, "Keyword", MagicMock(return_value=keyword_instance))
|
||||
|
||||
VectorService.update_segment_vector(["a", "b"], segment, dataset)
|
||||
|
||||
keyword_instance.delete_by_ids.assert_called_once_with([segment.index_node_id])
|
||||
keyword_instance.add_texts.assert_called_once()
|
||||
args, kwargs = keyword_instance.add_texts.call_args
|
||||
assert len(args[0]) == 1
|
||||
assert kwargs["keywords_list"] == [["a", "b"]]
|
||||
|
||||
|
||||
def test_update_segment_vector_economy_uses_keyword_without_keywords_list(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
dataset = _make_dataset(indexing_technique="economy")
|
||||
segment = _make_segment()
|
||||
|
||||
keyword_instance = MagicMock()
|
||||
monkeypatch.setattr(vector_service_module, "Keyword", MagicMock(return_value=keyword_instance))
|
||||
|
||||
VectorService.update_segment_vector(None, segment, dataset)
|
||||
keyword_instance.add_texts.assert_called_once()
|
||||
_, kwargs = keyword_instance.add_texts.call_args
|
||||
assert "keywords_list" not in kwargs
|
||||
|
||||
|
||||
def test_generate_child_chunks_regenerate_cleans_then_saves_children(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
dataset = _make_dataset(doc_form="text_model", tenant_id="tenant-1", dataset_id="dataset-1")
|
||||
segment = _make_segment(segment_id="seg-1")
|
||||
|
||||
dataset_document = MagicMock()
|
||||
dataset_document.id = segment.document_id
|
||||
dataset_document.doc_language = "en"
|
||||
dataset_document.created_by = "user-1"
|
||||
|
||||
processing_rule = MagicMock()
|
||||
processing_rule.to_dict.return_value = {"rules": {}}
|
||||
|
||||
child1 = _ChildDocStub(page_content="c1", metadata={"doc_id": "c1-id", "doc_hash": "c1-h"})
|
||||
child2 = _ChildDocStub(page_content="c2", metadata={"doc_id": "c2-id", "doc_hash": "c2-h"})
|
||||
transformed = [_ParentDocStub(children=[child1, child2])]
|
||||
|
||||
index_processor = MagicMock()
|
||||
index_processor.transform.return_value = transformed
|
||||
factory_instance = MagicMock()
|
||||
factory_instance.init_index_processor.return_value = index_processor
|
||||
monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance))
|
||||
|
||||
child_chunk_ctor = MagicMock(side_effect=lambda **kwargs: kwargs)
|
||||
monkeypatch.setattr(vector_service_module, "ChildChunk", child_chunk_ctor)
|
||||
|
||||
db_mock = MagicMock()
|
||||
db_mock.session.add = MagicMock()
|
||||
db_mock.session.commit = MagicMock()
|
||||
monkeypatch.setattr(vector_service_module, "db", db_mock)
|
||||
|
||||
VectorService.generate_child_chunks(
|
||||
segment=segment,
|
||||
dataset_document=dataset_document,
|
||||
dataset=dataset,
|
||||
embedding_model_instance=MagicMock(),
|
||||
processing_rule=processing_rule,
|
||||
regenerate=True,
|
||||
)
|
||||
|
||||
index_processor.clean.assert_called_once()
|
||||
_, transform_kwargs = index_processor.transform.call_args
|
||||
assert transform_kwargs["process_rule"]["rules"]["parent_mode"] == vector_service_module.ParentMode.FULL_DOC
|
||||
index_processor.load.assert_called_once()
|
||||
assert db_mock.session.add.call_count == 2
|
||||
db_mock.session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_generate_child_chunks_commits_even_when_no_children(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
dataset = _make_dataset(doc_form="text_model")
|
||||
segment = _make_segment()
|
||||
dataset_document = MagicMock()
|
||||
dataset_document.doc_language = "en"
|
||||
dataset_document.created_by = "user-1"
|
||||
|
||||
processing_rule = MagicMock()
|
||||
processing_rule.to_dict.return_value = {"rules": {}}
|
||||
|
||||
index_processor = MagicMock()
|
||||
index_processor.transform.return_value = [_ParentDocStub(children=[])]
|
||||
factory_instance = MagicMock()
|
||||
factory_instance.init_index_processor.return_value = index_processor
|
||||
monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance))
|
||||
|
||||
db_mock = MagicMock()
|
||||
monkeypatch.setattr(vector_service_module, "db", db_mock)
|
||||
|
||||
VectorService.generate_child_chunks(
|
||||
segment=segment,
|
||||
dataset_document=dataset_document,
|
||||
dataset=dataset,
|
||||
embedding_model_instance=MagicMock(),
|
||||
processing_rule=processing_rule,
|
||||
regenerate=False,
|
||||
)
|
||||
|
||||
index_processor.load.assert_not_called()
|
||||
db_mock.session.add.assert_not_called()
|
||||
db_mock.session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_create_child_chunk_vector_high_quality_adds_texts(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
dataset = _make_dataset(indexing_technique="high_quality")
|
||||
child_chunk = MagicMock()
|
||||
child_chunk.content = "child"
|
||||
child_chunk.index_node_id = "id"
|
||||
child_chunk.index_node_hash = "h"
|
||||
child_chunk.document_id = "doc-1"
|
||||
child_chunk.dataset_id = "dataset-1"
|
||||
|
||||
vector_instance = MagicMock()
|
||||
monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance))
|
||||
|
||||
VectorService.create_child_chunk_vector(child_chunk, dataset)
|
||||
vector_instance.add_texts.assert_called_once()
|
||||
|
||||
|
||||
def test_create_child_chunk_vector_economy_noop(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
dataset = _make_dataset(indexing_technique="economy")
|
||||
vector_cls = MagicMock()
|
||||
monkeypatch.setattr(vector_service_module, "Vector", vector_cls)
|
||||
|
||||
child_chunk = MagicMock()
|
||||
child_chunk.content = "child"
|
||||
child_chunk.index_node_id = "id"
|
||||
child_chunk.index_node_hash = "h"
|
||||
child_chunk.document_id = "doc-1"
|
||||
child_chunk.dataset_id = "dataset-1"
|
||||
|
||||
VectorService.create_child_chunk_vector(child_chunk, dataset)
|
||||
vector_cls.assert_not_called()
|
||||
|
||||
|
||||
def test_update_child_chunk_vector_high_quality_updates_vector(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
dataset = _make_dataset(indexing_technique="high_quality")
|
||||
|
||||
new_chunk = MagicMock()
|
||||
new_chunk.content = "n"
|
||||
new_chunk.index_node_id = "nid"
|
||||
new_chunk.index_node_hash = "nh"
|
||||
new_chunk.document_id = "d"
|
||||
new_chunk.dataset_id = "ds"
|
||||
|
||||
upd_chunk = MagicMock()
|
||||
upd_chunk.content = "u"
|
||||
upd_chunk.index_node_id = "uid"
|
||||
upd_chunk.index_node_hash = "uh"
|
||||
upd_chunk.document_id = "d"
|
||||
upd_chunk.dataset_id = "ds"
|
||||
|
||||
del_chunk = MagicMock()
|
||||
del_chunk.index_node_id = "did"
|
||||
|
||||
vector_instance = MagicMock()
|
||||
monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance))
|
||||
|
||||
VectorService.update_child_chunk_vector([new_chunk], [upd_chunk], [del_chunk], dataset)
|
||||
|
||||
vector_instance.delete_by_ids.assert_called_once_with(["uid", "did"])
|
||||
vector_instance.add_texts.assert_called_once()
|
||||
docs = vector_instance.add_texts.call_args.args[0]
|
||||
assert len(docs) == 2
|
||||
|
||||
|
||||
def test_update_child_chunk_vector_economy_noop(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
dataset = _make_dataset(indexing_technique="economy")
|
||||
vector_cls = MagicMock()
|
||||
monkeypatch.setattr(vector_service_module, "Vector", vector_cls)
|
||||
VectorService.update_child_chunk_vector([], [], [], dataset)
|
||||
vector_cls.assert_not_called()
|
||||
|
||||
|
||||
def test_delete_child_chunk_vector_deletes_by_id(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
dataset = _make_dataset()
|
||||
child_chunk = MagicMock()
|
||||
child_chunk.index_node_id = "cid"
|
||||
|
||||
vector_instance = MagicMock()
|
||||
monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance))
|
||||
|
||||
VectorService.delete_child_chunk_vector(child_chunk, dataset)
|
||||
vector_instance.delete_by_ids.assert_called_once_with(["cid"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# update_multimodel_vector (missing coverage in previous suites)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_update_multimodel_vector_returns_when_not_high_quality(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
dataset = _make_dataset(indexing_technique="economy", is_multimodal=True)
|
||||
segment = _make_segment(tenant_id="t", attachments=[{"id": "a"}])
|
||||
|
||||
vector_cls = MagicMock()
|
||||
db_mock = _mock_db_session_for_update_multimodel(upload_files=[])
|
||||
monkeypatch.setattr(vector_service_module, "Vector", vector_cls)
|
||||
monkeypatch.setattr(vector_service_module, "db", db_mock)
|
||||
|
||||
VectorService.update_multimodel_vector(segment=segment, attachment_ids=["a"], dataset=dataset)
|
||||
vector_cls.assert_not_called()
|
||||
db_mock.session.query.assert_not_called()
|
||||
|
||||
|
||||
def test_update_multimodel_vector_returns_when_no_actual_change(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True)
|
||||
segment = _make_segment(tenant_id="t", attachments=[{"id": "a"}, {"id": "b"}])
|
||||
|
||||
vector_cls = MagicMock()
|
||||
db_mock = _mock_db_session_for_update_multimodel(upload_files=[])
|
||||
monkeypatch.setattr(vector_service_module, "Vector", vector_cls)
|
||||
monkeypatch.setattr(vector_service_module, "db", db_mock)
|
||||
|
||||
VectorService.update_multimodel_vector(segment=segment, attachment_ids=["b", "a"], dataset=dataset)
|
||||
vector_cls.assert_not_called()
|
||||
db_mock.session.query.assert_not_called()
|
||||
|
||||
|
||||
def test_update_multimodel_vector_deletes_bindings_and_commits_on_empty_new_ids(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True)
|
||||
segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}, {"id": "old-2"}])
|
||||
|
||||
vector_instance = MagicMock(name="vector_instance")
|
||||
vector_cls = MagicMock(return_value=vector_instance)
|
||||
db_mock = _mock_db_session_for_update_multimodel(upload_files=[])
|
||||
|
||||
monkeypatch.setattr(vector_service_module, "Vector", vector_cls)
|
||||
monkeypatch.setattr(vector_service_module, "db", db_mock)
|
||||
|
||||
VectorService.update_multimodel_vector(segment=segment, attachment_ids=[], dataset=dataset)
|
||||
|
||||
vector_cls.assert_called_once_with(dataset=dataset)
|
||||
vector_instance.delete_by_ids.assert_called_once_with(["old-1", "old-2"])
|
||||
db_mock.session.query.assert_called_once_with(vector_service_module.SegmentAttachmentBinding)
|
||||
db_mock.session.commit.assert_called_once()
|
||||
db_mock.session.add_all.assert_not_called()
|
||||
vector_instance.add_texts.assert_not_called()
|
||||
|
||||
|
||||
def test_update_multimodel_vector_commits_when_no_upload_files_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True)
|
||||
segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}])
|
||||
|
||||
vector_instance = MagicMock()
|
||||
monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance))
|
||||
db_mock = _mock_db_session_for_update_multimodel(upload_files=[])
|
||||
monkeypatch.setattr(vector_service_module, "db", db_mock)
|
||||
|
||||
VectorService.update_multimodel_vector(segment=segment, attachment_ids=["new-1"], dataset=dataset)
|
||||
|
||||
db_mock.session.commit.assert_called_once()
|
||||
db_mock.session.add_all.assert_not_called()
|
||||
vector_instance.add_texts.assert_not_called()
|
||||
|
||||
|
||||
def test_update_multimodel_vector_adds_bindings_and_vectors_and_skips_missing_upload_files(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True)
|
||||
segment = _make_segment(segment_id="seg-1", tenant_id="tenant-1", attachments=[{"id": "old-1"}])
|
||||
|
||||
vector_instance = MagicMock()
|
||||
monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance))
|
||||
db_mock = _mock_db_session_for_update_multimodel(upload_files=[_UploadFileStub(id="file-1", name="img.png")])
|
||||
monkeypatch.setattr(vector_service_module, "db", db_mock)
|
||||
|
||||
binding_ctor = MagicMock(side_effect=lambda **kwargs: kwargs)
|
||||
monkeypatch.setattr(vector_service_module, "SegmentAttachmentBinding", binding_ctor)
|
||||
|
||||
logger_mock = MagicMock()
|
||||
monkeypatch.setattr(vector_service_module, "logger", logger_mock)
|
||||
|
||||
VectorService.update_multimodel_vector(segment=segment, attachment_ids=["file-1", "missing"], dataset=dataset)
|
||||
|
||||
logger_mock.warning.assert_called_once()
|
||||
db_mock.session.add_all.assert_called_once()
|
||||
bindings = db_mock.session.add_all.call_args.args[0]
|
||||
assert len(bindings) == 1
|
||||
assert bindings[0]["attachment_id"] == "file-1"
|
||||
|
||||
vector_instance.add_texts.assert_called_once()
|
||||
documents = vector_instance.add_texts.call_args.args[0]
|
||||
assert len(documents) == 1
|
||||
assert documents[0].page_content == "img.png"
|
||||
assert documents[0].metadata["doc_id"] == "file-1"
|
||||
db_mock.session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_update_multimodel_vector_updates_bindings_without_multimodal_vector_ops(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=False)
|
||||
segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}])
|
||||
|
||||
vector_instance = MagicMock()
|
||||
monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance))
|
||||
db_mock = _mock_db_session_for_update_multimodel(upload_files=[_UploadFileStub(id="file-1", name="img.png")])
|
||||
monkeypatch.setattr(vector_service_module, "db", db_mock)
|
||||
monkeypatch.setattr(
|
||||
vector_service_module, "SegmentAttachmentBinding", MagicMock(side_effect=lambda **kwargs: kwargs)
|
||||
)
|
||||
|
||||
VectorService.update_multimodel_vector(segment=segment, attachment_ids=["file-1"], dataset=dataset)
|
||||
|
||||
vector_instance.delete_by_ids.assert_not_called()
|
||||
vector_instance.add_texts.assert_not_called()
|
||||
db_mock.session.add_all.assert_called_once()
|
||||
db_mock.session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_update_multimodel_vector_rolls_back_and_reraises_on_error(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True)
|
||||
segment = _make_segment(segment_id="seg-1", tenant_id="tenant-1", attachments=[{"id": "old-1"}])
|
||||
|
||||
vector_instance = MagicMock()
|
||||
monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance))
|
||||
db_mock = _mock_db_session_for_update_multimodel(upload_files=[_UploadFileStub(id="file-1", name="img.png")])
|
||||
db_mock.session.commit.side_effect = RuntimeError("boom")
|
||||
monkeypatch.setattr(vector_service_module, "db", db_mock)
|
||||
monkeypatch.setattr(
|
||||
vector_service_module, "SegmentAttachmentBinding", MagicMock(side_effect=lambda **kwargs: kwargs)
|
||||
)
|
||||
|
||||
logger_mock = MagicMock()
|
||||
monkeypatch.setattr(vector_service_module, "logger", logger_mock)
|
||||
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
VectorService.update_multimodel_vector(segment=segment, attachment_ids=["file-1"], dataset=dataset)
|
||||
|
||||
logger_mock.exception.assert_called_once()
|
||||
db_mock.session.rollback.assert_called_once()
|
||||
718
api/tests/unit_tests/services/test_website_service.py
Normal file
718
api/tests/unit_tests/services/test_website_service.py
Normal file
@ -0,0 +1,718 @@
|
||||
"""Unit tests for services.website_service.
|
||||
|
||||
Focuses on provider dispatching, argument validation, and provider-specific branches
|
||||
without making any real network/storage/redis calls.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import services.website_service as website_service_module
|
||||
from services.website_service import (
|
||||
CrawlOptions,
|
||||
WebsiteCrawlApiRequest,
|
||||
WebsiteCrawlStatusApiRequest,
|
||||
WebsiteService,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _DummyHttpxResponse:
|
||||
payload: dict[str, Any]
|
||||
|
||||
def json(self) -> dict[str, Any]:
|
||||
return self.payload
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def stub_current_user(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
website_service_module,
|
||||
"current_user",
|
||||
type("User", (), {"current_tenant_id": "tenant-1"})(),
|
||||
)
|
||||
|
||||
|
||||
def test_crawl_options_include_exclude_paths() -> None:
|
||||
options = CrawlOptions(includes="a,b", excludes="x,y")
|
||||
assert options.get_include_paths() == ["a", "b"]
|
||||
assert options.get_exclude_paths() == ["x", "y"]
|
||||
|
||||
empty = CrawlOptions(includes=None, excludes=None)
|
||||
assert empty.get_include_paths() == []
|
||||
assert empty.get_exclude_paths() == []
|
||||
|
||||
|
||||
def test_website_crawl_api_request_from_args_valid_and_to_crawl_request() -> None:
|
||||
args = {
|
||||
"provider": "firecrawl",
|
||||
"url": "https://example.com",
|
||||
"options": {
|
||||
"limit": 2,
|
||||
"crawl_sub_pages": True,
|
||||
"only_main_content": True,
|
||||
"includes": "a,b",
|
||||
"excludes": "x",
|
||||
"prompt": "hi",
|
||||
"max_depth": 3,
|
||||
"use_sitemap": False,
|
||||
},
|
||||
}
|
||||
|
||||
api_req = WebsiteCrawlApiRequest.from_args(args)
|
||||
crawl_req = api_req.to_crawl_request()
|
||||
|
||||
assert crawl_req.provider == "firecrawl"
|
||||
assert crawl_req.url == "https://example.com"
|
||||
assert crawl_req.options.limit == 2
|
||||
assert crawl_req.options.crawl_sub_pages is True
|
||||
assert crawl_req.options.only_main_content is True
|
||||
assert crawl_req.options.get_include_paths() == ["a", "b"]
|
||||
assert crawl_req.options.get_exclude_paths() == ["x"]
|
||||
assert crawl_req.options.prompt == "hi"
|
||||
assert crawl_req.options.max_depth == 3
|
||||
assert crawl_req.options.use_sitemap is False
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("args", "missing_msg"),
|
||||
[
|
||||
({}, "Provider is required"),
|
||||
({"provider": "firecrawl"}, "URL is required"),
|
||||
({"provider": "firecrawl", "url": "https://example.com"}, "Options are required"),
|
||||
],
|
||||
)
|
||||
def test_website_crawl_api_request_from_args_requires_fields(args: dict, missing_msg: str) -> None:
|
||||
with pytest.raises(ValueError, match=missing_msg):
|
||||
WebsiteCrawlApiRequest.from_args(args)
|
||||
|
||||
|
||||
def test_website_crawl_status_api_request_from_args_requires_fields() -> None:
|
||||
with pytest.raises(ValueError, match="Provider is required"):
|
||||
WebsiteCrawlStatusApiRequest.from_args({}, job_id="job-1")
|
||||
|
||||
with pytest.raises(ValueError, match="Job ID is required"):
|
||||
WebsiteCrawlStatusApiRequest.from_args({"provider": "firecrawl"}, job_id="")
|
||||
|
||||
req = WebsiteCrawlStatusApiRequest.from_args({"provider": "firecrawl"}, job_id="job-1")
|
||||
assert req.provider == "firecrawl"
|
||||
assert req.job_id == "job-1"
|
||||
|
||||
|
||||
def test_get_credentials_and_config_selects_plugin_id_and_key_firecrawl(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
service_instance = MagicMock(name="DatasourceProviderService-instance")
|
||||
service_instance.get_datasource_credentials.return_value = {"firecrawl_api_key": "k", "base_url": "b"}
|
||||
monkeypatch.setattr(website_service_module, "DatasourceProviderService", MagicMock(return_value=service_instance))
|
||||
|
||||
api_key, config = WebsiteService._get_credentials_and_config("tenant-1", "firecrawl")
|
||||
assert api_key == "k"
|
||||
assert config["base_url"] == "b"
|
||||
|
||||
service_instance.get_datasource_credentials.assert_called_once_with(
|
||||
tenant_id="tenant-1",
|
||||
provider="firecrawl",
|
||||
plugin_id="langgenius/firecrawl_datasource",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("provider", "plugin_id"),
|
||||
[
|
||||
("watercrawl", "langgenius/watercrawl_datasource"),
|
||||
("jinareader", "langgenius/jina_datasource"),
|
||||
],
|
||||
)
|
||||
def test_get_credentials_and_config_selects_plugin_id_and_key_api_key(
|
||||
monkeypatch: pytest.MonkeyPatch, provider: str, plugin_id: str
|
||||
) -> None:
|
||||
service_instance = MagicMock(name="DatasourceProviderService-instance")
|
||||
service_instance.get_datasource_credentials.return_value = {"api_key": "enc-key", "base_url": "b"}
|
||||
monkeypatch.setattr(website_service_module, "DatasourceProviderService", MagicMock(return_value=service_instance))
|
||||
|
||||
api_key, config = WebsiteService._get_credentials_and_config("tenant-1", provider)
|
||||
assert api_key == "enc-key"
|
||||
assert config["base_url"] == "b"
|
||||
|
||||
service_instance.get_datasource_credentials.assert_called_once_with(
|
||||
tenant_id="tenant-1",
|
||||
provider=provider,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
|
||||
def test_get_credentials_and_config_rejects_invalid_provider() -> None:
|
||||
with pytest.raises(ValueError, match="Invalid provider"):
|
||||
WebsiteService._get_credentials_and_config("tenant-1", "unknown")
|
||||
|
||||
|
||||
def test_get_credentials_and_config_hits_unreachable_guard_branch(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
class FlakyProvider:
|
||||
def __init__(self) -> None:
|
||||
self._eq_calls = 0
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return 1
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if other == "firecrawl":
|
||||
self._eq_calls += 1
|
||||
return self._eq_calls == 1
|
||||
return False
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "FlakyProvider()"
|
||||
|
||||
service_instance = MagicMock(name="DatasourceProviderService-instance")
|
||||
service_instance.get_datasource_credentials.return_value = {"firecrawl_api_key": "k"}
|
||||
monkeypatch.setattr(website_service_module, "DatasourceProviderService", MagicMock(return_value=service_instance))
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid provider"):
|
||||
WebsiteService._get_credentials_and_config("tenant-1", FlakyProvider()) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_get_decrypted_api_key_requires_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(website_service_module.encrypter, "decrypt_token", MagicMock())
|
||||
with pytest.raises(ValueError, match="API key not found in configuration"):
|
||||
WebsiteService._get_decrypted_api_key("tenant-1", {})
|
||||
|
||||
|
||||
def test_get_decrypted_api_key_decrypts(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
decrypt_mock = MagicMock(return_value="plain")
|
||||
monkeypatch.setattr(website_service_module.encrypter, "decrypt_token", decrypt_mock)
|
||||
|
||||
assert WebsiteService._get_decrypted_api_key("tenant-1", {"api_key": "enc"}) == "plain"
|
||||
decrypt_mock.assert_called_once_with(tenant_id="tenant-1", token="enc")
|
||||
|
||||
|
||||
def test_document_create_args_validate_wraps_error_message() -> None:
|
||||
with pytest.raises(ValueError, match=r"^Invalid arguments: Provider is required$"):
|
||||
WebsiteService.document_create_args_validate({})
|
||||
|
||||
|
||||
def test_crawl_url_dispatches_by_provider(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api_request = WebsiteCrawlApiRequest(provider="firecrawl", url="https://example.com", options={"limit": 1})
|
||||
crawl_request = api_request.to_crawl_request()
|
||||
|
||||
monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {"base_url": "b"})))
|
||||
firecrawl_mock = MagicMock(return_value={"status": "active", "job_id": "j1"})
|
||||
monkeypatch.setattr(WebsiteService, "_crawl_with_firecrawl", firecrawl_mock)
|
||||
|
||||
result = WebsiteService.crawl_url(api_request)
|
||||
|
||||
assert result == {"status": "active", "job_id": "j1"}
|
||||
firecrawl_mock.assert_called_once()
|
||||
assert firecrawl_mock.call_args.kwargs["request"] == crawl_request
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("provider", "method_name"),
|
||||
[
|
||||
("watercrawl", "_crawl_with_watercrawl"),
|
||||
("jinareader", "_crawl_with_jinareader"),
|
||||
],
|
||||
)
|
||||
def test_crawl_url_dispatches_other_providers(monkeypatch: pytest.MonkeyPatch, provider: str, method_name: str) -> None:
|
||||
api_request = WebsiteCrawlApiRequest(provider=provider, url="https://example.com", options={"limit": 1})
|
||||
monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {"base_url": "b"})))
|
||||
|
||||
impl_mock = MagicMock(return_value={"status": "active"})
|
||||
monkeypatch.setattr(WebsiteService, method_name, impl_mock)
|
||||
|
||||
assert WebsiteService.crawl_url(api_request) == {"status": "active"}
|
||||
impl_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_crawl_url_rejects_invalid_provider(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api_request = WebsiteCrawlApiRequest(provider="bad", url="https://example.com", options={"limit": 1})
|
||||
monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {})))
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid provider"):
|
||||
WebsiteService.crawl_url(api_request)
|
||||
|
||||
|
||||
def test_crawl_with_firecrawl_builds_params_single_page_and_sets_redis(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
firecrawl_instance = MagicMock(name="FirecrawlApp-instance")
|
||||
firecrawl_instance.crawl_url.return_value = "job-1"
|
||||
firecrawl_cls = MagicMock(return_value=firecrawl_instance)
|
||||
monkeypatch.setattr(website_service_module, "FirecrawlApp", firecrawl_cls)
|
||||
|
||||
redis_mock = MagicMock()
|
||||
monkeypatch.setattr(website_service_module, "redis_client", redis_mock)
|
||||
|
||||
fixed_now = datetime(2024, 1, 1, tzinfo=UTC)
|
||||
with patch.object(website_service_module.datetime, "datetime") as datetime_mock:
|
||||
datetime_mock.now.return_value = fixed_now
|
||||
|
||||
req = WebsiteCrawlApiRequest(
|
||||
provider="firecrawl", url="https://example.com", options={"limit": 5}
|
||||
).to_crawl_request()
|
||||
req.options.crawl_sub_pages = False
|
||||
req.options.only_main_content = True
|
||||
|
||||
result = WebsiteService._crawl_with_firecrawl(request=req, api_key="k", config={"base_url": "b"})
|
||||
|
||||
assert result == {"status": "active", "job_id": "job-1"}
|
||||
|
||||
firecrawl_cls.assert_called_once_with(api_key="k", base_url="b")
|
||||
firecrawl_instance.crawl_url.assert_called_once()
|
||||
_, params = firecrawl_instance.crawl_url.call_args.args
|
||||
assert params["limit"] == 1
|
||||
assert params["includePaths"] == []
|
||||
assert params["excludePaths"] == []
|
||||
assert params["scrapeOptions"] == {"onlyMainContent": True}
|
||||
|
||||
redis_mock.setex.assert_called_once()
|
||||
key, ttl, value = redis_mock.setex.call_args.args
|
||||
assert key == "website_crawl_job-1"
|
||||
assert ttl == 3600
|
||||
assert float(value) == pytest.approx(fixed_now.timestamp(), rel=0, abs=1e-6)
|
||||
|
||||
|
||||
def test_crawl_with_firecrawl_builds_params_multi_page_including_prompt(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
firecrawl_instance = MagicMock(name="FirecrawlApp-instance")
|
||||
firecrawl_instance.crawl_url.return_value = "job-2"
|
||||
monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=firecrawl_instance))
|
||||
monkeypatch.setattr(website_service_module, "redis_client", MagicMock())
|
||||
|
||||
req = WebsiteCrawlApiRequest(
|
||||
provider="firecrawl",
|
||||
url="https://example.com",
|
||||
options={
|
||||
"crawl_sub_pages": True,
|
||||
"limit": 3,
|
||||
"only_main_content": False,
|
||||
"includes": "a,b",
|
||||
"excludes": "x",
|
||||
"prompt": "use this",
|
||||
},
|
||||
).to_crawl_request()
|
||||
|
||||
WebsiteService._crawl_with_firecrawl(request=req, api_key="k", config={"base_url": None})
|
||||
_, params = firecrawl_instance.crawl_url.call_args.args
|
||||
assert params["includePaths"] == ["a", "b"]
|
||||
assert params["excludePaths"] == ["x"]
|
||||
assert params["limit"] == 3
|
||||
assert params["scrapeOptions"] == {"onlyMainContent": False}
|
||||
assert params["prompt"] == "use this"
|
||||
|
||||
|
||||
def test_crawl_with_watercrawl_passes_options_dict(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
provider_instance = MagicMock()
|
||||
provider_instance.crawl_url.return_value = {"status": "active", "job_id": "w1"}
|
||||
provider_cls = MagicMock(return_value=provider_instance)
|
||||
monkeypatch.setattr(website_service_module, "WaterCrawlProvider", provider_cls)
|
||||
|
||||
req = WebsiteCrawlApiRequest(
|
||||
provider="watercrawl",
|
||||
url="https://example.com",
|
||||
options={
|
||||
"limit": 2,
|
||||
"crawl_sub_pages": True,
|
||||
"only_main_content": True,
|
||||
"includes": "a",
|
||||
"excludes": None,
|
||||
"max_depth": 5,
|
||||
"use_sitemap": False,
|
||||
},
|
||||
).to_crawl_request()
|
||||
|
||||
result = WebsiteService._crawl_with_watercrawl(request=req, api_key="k", config={"base_url": "b"})
|
||||
assert result == {"status": "active", "job_id": "w1"}
|
||||
|
||||
provider_cls.assert_called_once_with(api_key="k", base_url="b")
|
||||
provider_instance.crawl_url.assert_called_once_with(
|
||||
url="https://example.com",
|
||||
options={
|
||||
"limit": 2,
|
||||
"crawl_sub_pages": True,
|
||||
"only_main_content": True,
|
||||
"includes": "a",
|
||||
"excludes": None,
|
||||
"max_depth": 5,
|
||||
"use_sitemap": False,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_crawl_with_jinareader_single_page_success(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
get_mock = MagicMock(return_value=_DummyHttpxResponse({"code": 200, "data": {"title": "t"}}))
|
||||
monkeypatch.setattr(website_service_module.httpx, "get", get_mock)
|
||||
|
||||
req = WebsiteCrawlApiRequest(
|
||||
provider="jinareader", url="https://example.com", options={"crawl_sub_pages": False}
|
||||
).to_crawl_request()
|
||||
req.options.crawl_sub_pages = False
|
||||
|
||||
result = WebsiteService._crawl_with_jinareader(request=req, api_key="k")
|
||||
assert result == {"status": "active", "data": {"title": "t"}}
|
||||
get_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_crawl_with_jinareader_single_page_failure(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(website_service_module.httpx, "get", MagicMock(return_value=_DummyHttpxResponse({"code": 500})))
|
||||
req = WebsiteCrawlApiRequest(
|
||||
provider="jinareader", url="https://example.com", options={"crawl_sub_pages": False}
|
||||
).to_crawl_request()
|
||||
req.options.crawl_sub_pages = False
|
||||
|
||||
with pytest.raises(ValueError, match="Failed to crawl:"):
|
||||
WebsiteService._crawl_with_jinareader(request=req, api_key="k")
|
||||
|
||||
|
||||
def test_crawl_with_jinareader_multi_page_success(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
post_mock = MagicMock(return_value=_DummyHttpxResponse({"code": 200, "data": {"taskId": "t1"}}))
|
||||
monkeypatch.setattr(website_service_module.httpx, "post", post_mock)
|
||||
|
||||
req = WebsiteCrawlApiRequest(
|
||||
provider="jinareader",
|
||||
url="https://example.com",
|
||||
options={"crawl_sub_pages": True, "limit": 5, "use_sitemap": True},
|
||||
).to_crawl_request()
|
||||
req.options.crawl_sub_pages = True
|
||||
|
||||
result = WebsiteService._crawl_with_jinareader(request=req, api_key="k")
|
||||
assert result == {"status": "active", "job_id": "t1"}
|
||||
post_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_crawl_with_jinareader_multi_page_failure(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
website_service_module.httpx, "post", MagicMock(return_value=_DummyHttpxResponse({"code": 400}))
|
||||
)
|
||||
req = WebsiteCrawlApiRequest(
|
||||
provider="jinareader",
|
||||
url="https://example.com",
|
||||
options={"crawl_sub_pages": True, "limit": 2, "use_sitemap": False},
|
||||
).to_crawl_request()
|
||||
req.options.crawl_sub_pages = True
|
||||
|
||||
with pytest.raises(ValueError, match="Failed to crawl$"):
|
||||
WebsiteService._crawl_with_jinareader(request=req, api_key="k")
|
||||
|
||||
|
||||
def test_get_crawl_status_dispatches(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {"base_url": "b"})))
|
||||
firecrawl_status = MagicMock(return_value={"status": "active"})
|
||||
monkeypatch.setattr(WebsiteService, "_get_firecrawl_status", firecrawl_status)
|
||||
|
||||
result = WebsiteService.get_crawl_status("job-1", "firecrawl")
|
||||
assert result == {"status": "active"}
|
||||
firecrawl_status.assert_called_once_with("job-1", "k", {"base_url": "b"})
|
||||
|
||||
watercrawl_status = MagicMock(return_value={"status": "active", "job_id": "w"})
|
||||
monkeypatch.setattr(WebsiteService, "_get_watercrawl_status", watercrawl_status)
|
||||
assert WebsiteService.get_crawl_status("job-2", "watercrawl") == {"status": "active", "job_id": "w"}
|
||||
watercrawl_status.assert_called_once_with("job-2", "k", {"base_url": "b"})
|
||||
|
||||
jinareader_status = MagicMock(return_value={"status": "active", "job_id": "j"})
|
||||
monkeypatch.setattr(WebsiteService, "_get_jinareader_status", jinareader_status)
|
||||
assert WebsiteService.get_crawl_status("job-3", "jinareader") == {"status": "active", "job_id": "j"}
|
||||
jinareader_status.assert_called_once_with("job-3", "k")
|
||||
|
||||
|
||||
def test_get_crawl_status_typed_rejects_invalid_provider(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {})))
|
||||
with pytest.raises(ValueError, match="Invalid provider"):
|
||||
WebsiteService.get_crawl_status_typed(WebsiteCrawlStatusApiRequest(provider="bad", job_id="j"))
|
||||
|
||||
|
||||
def test_get_firecrawl_status_adds_time_consuming_when_completed_and_cached(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
firecrawl_instance = MagicMock()
|
||||
firecrawl_instance.check_crawl_status.return_value = {"status": "completed", "total": 2, "current": 2, "data": []}
|
||||
monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=firecrawl_instance))
|
||||
|
||||
redis_mock = MagicMock()
|
||||
redis_mock.get.return_value = b"100.0"
|
||||
monkeypatch.setattr(website_service_module, "redis_client", redis_mock)
|
||||
|
||||
with patch.object(website_service_module.datetime, "datetime") as datetime_mock:
|
||||
datetime_mock.now.return_value = datetime.fromtimestamp(105.0, tz=UTC)
|
||||
result = WebsiteService._get_firecrawl_status(job_id="job-1", api_key="k", config={"base_url": "b"})
|
||||
|
||||
assert result["status"] == "completed"
|
||||
assert result["time_consuming"] == "5.00"
|
||||
redis_mock.delete.assert_called_once_with("website_crawl_job-1")
|
||||
|
||||
|
||||
def test_get_firecrawl_status_completed_without_cache_does_not_add_time(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
firecrawl_instance = MagicMock()
|
||||
firecrawl_instance.check_crawl_status.return_value = {"status": "completed"}
|
||||
monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=firecrawl_instance))
|
||||
|
||||
redis_mock = MagicMock()
|
||||
redis_mock.get.return_value = None
|
||||
monkeypatch.setattr(website_service_module, "redis_client", redis_mock)
|
||||
|
||||
result = WebsiteService._get_firecrawl_status(job_id="job-1", api_key="k", config={"base_url": None})
|
||||
assert result["status"] == "completed"
|
||||
assert "time_consuming" not in result
|
||||
redis_mock.delete.assert_not_called()
|
||||
|
||||
|
||||
def test_get_watercrawl_status_delegates(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
provider_instance = MagicMock()
|
||||
provider_instance.get_crawl_status.return_value = {"status": "active", "job_id": "w1"}
|
||||
monkeypatch.setattr(website_service_module, "WaterCrawlProvider", MagicMock(return_value=provider_instance))
|
||||
|
||||
assert WebsiteService._get_watercrawl_status("job-1", "k", {"base_url": "b"}) == {
|
||||
"status": "active",
|
||||
"job_id": "w1",
|
||||
}
|
||||
provider_instance.get_crawl_status.assert_called_once_with("job-1")
|
||||
|
||||
|
||||
def test_get_jinareader_status_active(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
post_mock = MagicMock(
|
||||
return_value=_DummyHttpxResponse(
|
||||
{
|
||||
"data": {
|
||||
"status": "active",
|
||||
"urls": ["a", "b"],
|
||||
"processed": {"a": {}},
|
||||
"failed": {"b": {}},
|
||||
"duration": 3000,
|
||||
}
|
||||
}
|
||||
)
|
||||
)
|
||||
monkeypatch.setattr(website_service_module.httpx, "post", post_mock)
|
||||
|
||||
result = WebsiteService._get_jinareader_status("job-1", "k")
|
||||
assert result["status"] == "active"
|
||||
assert result["total"] == 2
|
||||
assert result["current"] == 2
|
||||
assert result["time_consuming"] == 3.0
|
||||
assert result["data"] == []
|
||||
post_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_get_jinareader_status_completed_formats_processed_items(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
status_payload = {
|
||||
"data": {
|
||||
"status": "completed",
|
||||
"urls": ["u1"],
|
||||
"processed": {"u1": {}},
|
||||
"failed": {},
|
||||
"duration": 1000,
|
||||
}
|
||||
}
|
||||
processed_payload = {
|
||||
"data": {
|
||||
"processed": {
|
||||
"u1": {
|
||||
"data": {
|
||||
"title": "t",
|
||||
"url": "u1",
|
||||
"description": "d",
|
||||
"content": "md",
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
post_mock = MagicMock(side_effect=[_DummyHttpxResponse(status_payload), _DummyHttpxResponse(processed_payload)])
|
||||
monkeypatch.setattr(website_service_module.httpx, "post", post_mock)
|
||||
|
||||
result = WebsiteService._get_jinareader_status("job-1", "k")
|
||||
assert result["status"] == "completed"
|
||||
assert result["data"] == [{"title": "t", "source_url": "u1", "description": "d", "markdown": "md"}]
|
||||
assert post_mock.call_count == 2
|
||||
|
||||
|
||||
def test_get_crawl_url_data_dispatches_invalid_provider() -> None:
|
||||
with pytest.raises(ValueError, match="Invalid provider"):
|
||||
WebsiteService.get_crawl_url_data("job-1", "bad", "https://example.com", "tenant-1")
|
||||
|
||||
|
||||
def test_get_crawl_url_data_hits_invalid_provider_branch_when_credentials_stubbed(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {})))
|
||||
with pytest.raises(ValueError, match="Invalid provider"):
|
||||
WebsiteService.get_crawl_url_data("job-1", object(), "u", "tenant-1") # type: ignore[arg-type]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("provider", "method_name"),
|
||||
[
|
||||
("firecrawl", "_get_firecrawl_url_data"),
|
||||
("watercrawl", "_get_watercrawl_url_data"),
|
||||
("jinareader", "_get_jinareader_url_data"),
|
||||
],
|
||||
)
|
||||
def test_get_crawl_url_data_dispatches(monkeypatch: pytest.MonkeyPatch, provider: str, method_name: str) -> None:
|
||||
monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {"base_url": "b"})))
|
||||
impl_mock = MagicMock(return_value={"ok": True})
|
||||
monkeypatch.setattr(WebsiteService, method_name, impl_mock)
|
||||
|
||||
result = WebsiteService.get_crawl_url_data("job-1", provider, "u", "tenant-1")
|
||||
assert result == {"ok": True}
|
||||
impl_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_get_firecrawl_url_data_reads_from_storage_when_present(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
stored_list = [{"source_url": "https://example.com", "title": "t"}]
|
||||
stored = json.dumps(stored_list).encode("utf-8")
|
||||
|
||||
storage_mock = MagicMock()
|
||||
storage_mock.exists.return_value = True
|
||||
storage_mock.load_once.return_value = stored
|
||||
monkeypatch.setattr(website_service_module, "storage", storage_mock)
|
||||
|
||||
monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock())
|
||||
|
||||
result = WebsiteService._get_firecrawl_url_data("job-1", "https://example.com", "k", {"base_url": "b"})
|
||||
assert result == {"source_url": "https://example.com", "title": "t"}
|
||||
assert result is not stored_list[0]
|
||||
|
||||
|
||||
def test_get_firecrawl_url_data_returns_none_when_storage_empty(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
storage_mock = MagicMock()
|
||||
storage_mock.exists.return_value = True
|
||||
storage_mock.load_once.return_value = b""
|
||||
monkeypatch.setattr(website_service_module, "storage", storage_mock)
|
||||
|
||||
assert WebsiteService._get_firecrawl_url_data("job-1", "https://example.com", "k", {}) is None
|
||||
|
||||
|
||||
def test_get_firecrawl_url_data_raises_when_job_not_completed(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
storage_mock = MagicMock()
|
||||
storage_mock.exists.return_value = False
|
||||
monkeypatch.setattr(website_service_module, "storage", storage_mock)
|
||||
|
||||
firecrawl_instance = MagicMock()
|
||||
firecrawl_instance.check_crawl_status.return_value = {"status": "active"}
|
||||
monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=firecrawl_instance))
|
||||
|
||||
with pytest.raises(ValueError, match="Crawl job is not completed"):
|
||||
WebsiteService._get_firecrawl_url_data("job-1", "https://example.com", "k", {"base_url": None})
|
||||
|
||||
|
||||
def test_get_firecrawl_url_data_returns_none_when_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
storage_mock = MagicMock()
|
||||
storage_mock.exists.return_value = False
|
||||
monkeypatch.setattr(website_service_module, "storage", storage_mock)
|
||||
|
||||
firecrawl_instance = MagicMock()
|
||||
firecrawl_instance.check_crawl_status.return_value = {"status": "completed", "data": [{"source_url": "x"}]}
|
||||
monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=firecrawl_instance))
|
||||
|
||||
assert WebsiteService._get_firecrawl_url_data("job-1", "https://example.com", "k", {"base_url": "b"}) is None
|
||||
|
||||
|
||||
def test_get_watercrawl_url_data_delegates(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
provider_instance = MagicMock()
|
||||
provider_instance.get_crawl_url_data.return_value = {"source_url": "u"}
|
||||
monkeypatch.setattr(website_service_module, "WaterCrawlProvider", MagicMock(return_value=provider_instance))
|
||||
|
||||
result = WebsiteService._get_watercrawl_url_data("job-1", "u", "k", {"base_url": "b"})
|
||||
assert result == {"source_url": "u"}
|
||||
provider_instance.get_crawl_url_data.assert_called_once_with("job-1", "u")
|
||||
|
||||
|
||||
def test_get_jinareader_url_data_without_job_id_success(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
website_service_module.httpx,
|
||||
"get",
|
||||
MagicMock(return_value=_DummyHttpxResponse({"code": 200, "data": {"url": "u"}})),
|
||||
)
|
||||
assert WebsiteService._get_jinareader_url_data("", "u", "k") == {"url": "u"}
|
||||
|
||||
|
||||
def test_get_jinareader_url_data_without_job_id_failure(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(website_service_module.httpx, "get", MagicMock(return_value=_DummyHttpxResponse({"code": 500})))
|
||||
with pytest.raises(ValueError, match="Failed to crawl$"):
|
||||
WebsiteService._get_jinareader_url_data("", "u", "k")
|
||||
|
||||
|
||||
def test_get_jinareader_url_data_with_job_id_completed_returns_matching_item(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
status_payload = {"data": {"status": "completed", "processed": {"u1": {}}}}
|
||||
processed_payload = {"data": {"processed": {"u1": {"data": {"url": "u", "title": "t"}}}}}
|
||||
|
||||
post_mock = MagicMock(side_effect=[_DummyHttpxResponse(status_payload), _DummyHttpxResponse(processed_payload)])
|
||||
monkeypatch.setattr(website_service_module.httpx, "post", post_mock)
|
||||
|
||||
assert WebsiteService._get_jinareader_url_data("job-1", "u", "k") == {"url": "u", "title": "t"}
|
||||
assert post_mock.call_count == 2
|
||||
|
||||
|
||||
def test_get_jinareader_url_data_with_job_id_not_completed_raises(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
post_mock = MagicMock(return_value=_DummyHttpxResponse({"data": {"status": "active"}}))
|
||||
monkeypatch.setattr(website_service_module.httpx, "post", post_mock)
|
||||
|
||||
with pytest.raises(ValueError, match=r"Crawl job is no\s*t completed"):
|
||||
WebsiteService._get_jinareader_url_data("job-1", "u", "k")
|
||||
|
||||
|
||||
def test_get_jinareader_url_data_with_job_id_completed_but_not_found_returns_none(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
status_payload = {"data": {"status": "completed", "processed": {"u1": {}}}}
|
||||
processed_payload = {"data": {"processed": {"u1": {"data": {"url": "other"}}}}}
|
||||
|
||||
post_mock = MagicMock(side_effect=[_DummyHttpxResponse(status_payload), _DummyHttpxResponse(processed_payload)])
|
||||
monkeypatch.setattr(website_service_module.httpx, "post", post_mock)
|
||||
|
||||
assert WebsiteService._get_jinareader_url_data("job-1", "u", "k") is None
|
||||
|
||||
|
||||
def test_get_scrape_url_data_dispatches_and_rejects_invalid_provider(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {"base_url": "b"})))
|
||||
|
||||
scrape_mock = MagicMock(return_value={"data": "x"})
|
||||
monkeypatch.setattr(WebsiteService, "_scrape_with_firecrawl", scrape_mock)
|
||||
assert WebsiteService.get_scrape_url_data("firecrawl", "u", "tenant-1", True) == {"data": "x"}
|
||||
scrape_mock.assert_called_once()
|
||||
|
||||
watercrawl_mock = MagicMock(return_value={"data": "y"})
|
||||
monkeypatch.setattr(WebsiteService, "_scrape_with_watercrawl", watercrawl_mock)
|
||||
assert WebsiteService.get_scrape_url_data("watercrawl", "u", "tenant-1", False) == {"data": "y"}
|
||||
watercrawl_mock.assert_called_once()
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid provider"):
|
||||
WebsiteService.get_scrape_url_data("jinareader", "u", "tenant-1", True)
|
||||
|
||||
|
||||
def test_scrape_with_firecrawl_calls_app(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
firecrawl_instance = MagicMock()
|
||||
firecrawl_instance.scrape_url.return_value = {"markdown": "m"}
|
||||
monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=firecrawl_instance))
|
||||
|
||||
result = WebsiteService._scrape_with_firecrawl(
|
||||
request=website_service_module.ScrapeRequest(
|
||||
provider="firecrawl",
|
||||
url="u",
|
||||
tenant_id="tenant-1",
|
||||
only_main_content=True,
|
||||
),
|
||||
api_key="k",
|
||||
config={"base_url": "b"},
|
||||
)
|
||||
assert result == {"markdown": "m"}
|
||||
firecrawl_instance.scrape_url.assert_called_once_with(url="u", params={"onlyMainContent": True})
|
||||
|
||||
|
||||
def test_scrape_with_watercrawl_calls_provider(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
provider_instance = MagicMock()
|
||||
provider_instance.scrape_url.return_value = {"markdown": "m"}
|
||||
monkeypatch.setattr(website_service_module, "WaterCrawlProvider", MagicMock(return_value=provider_instance))
|
||||
|
||||
result = WebsiteService._scrape_with_watercrawl(
|
||||
request=website_service_module.ScrapeRequest(
|
||||
provider="watercrawl",
|
||||
url="u",
|
||||
tenant_id="tenant-1",
|
||||
only_main_content=False,
|
||||
),
|
||||
api_key="k",
|
||||
config={"base_url": "b"},
|
||||
)
|
||||
assert result == {"markdown": "m"}
|
||||
provider_instance.scrape_url.assert_called_once_with("u")
|
||||
@ -311,7 +311,9 @@ class TestWorkflowService:
|
||||
mock_workflow.conversation_variables = []
|
||||
|
||||
# Mock node config
|
||||
mock_workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": "llm"}}
|
||||
mock_workflow.get_node_config_by_id.return_value = NodeConfigDictAdapter.validate_python(
|
||||
{"id": "node-1", "data": {"type": NodeType.LLM.value}}
|
||||
)
|
||||
mock_workflow.get_enclosing_node_type_and_id.return_value = None
|
||||
|
||||
# Mock class methods
|
||||
@ -376,7 +378,9 @@ class TestWorkflowService:
|
||||
mock_workflow.tenant_id = "tenant-1"
|
||||
mock_workflow.environment_variables = []
|
||||
mock_workflow.conversation_variables = []
|
||||
mock_workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": "llm"}}
|
||||
mock_workflow.get_node_config_by_id.return_value = NodeConfigDictAdapter.validate_python(
|
||||
{"id": "node-1", "data": {"type": NodeType.LLM.value}}
|
||||
)
|
||||
mock_workflow.get_enclosing_node_type_and_id.return_value = None
|
||||
|
||||
monkeypatch.setattr(workflow_service_module, "WorkflowDraftVariableService", MagicMock())
|
||||
|
||||
@ -8,6 +8,8 @@
|
||||
import { cleanup, render, screen } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
let mockTheme = 'light'
|
||||
|
||||
vi.mock('#i18n', () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => key,
|
||||
@ -19,16 +21,16 @@ vi.mock('@/context/i18n', () => ({
|
||||
}))
|
||||
|
||||
vi.mock('@/hooks/use-theme', () => ({
|
||||
default: () => ({ theme: 'light' }),
|
||||
default: () => ({ theme: mockTheme }),
|
||||
}))
|
||||
|
||||
vi.mock('@/i18n-config', () => ({
|
||||
renderI18nObject: (obj: Record<string, string>, locale: string) => obj[locale] || obj.en_US || '',
|
||||
}))
|
||||
|
||||
vi.mock('@/types/app', () => ({
|
||||
Theme: { dark: 'dark', light: 'light' },
|
||||
}))
|
||||
vi.mock('@/types/app', async () => {
|
||||
return vi.importActual<typeof import('@/types/app')>('@/types/app')
|
||||
})
|
||||
|
||||
vi.mock('@/utils/classnames', () => ({
|
||||
cn: (...args: unknown[]) => args.filter(a => typeof a === 'string' && a).join(' '),
|
||||
@ -100,6 +102,7 @@ type CardPayload = Parameters<typeof Card>[0]['payload']
|
||||
describe('Plugin Card Rendering Integration', () => {
|
||||
beforeEach(() => {
|
||||
cleanup()
|
||||
mockTheme = 'light'
|
||||
})
|
||||
|
||||
const makePayload = (overrides = {}) => ({
|
||||
@ -194,9 +197,7 @@ describe('Plugin Card Rendering Integration', () => {
|
||||
})
|
||||
|
||||
it('uses dark icon when theme is dark and icon_dark is provided', () => {
|
||||
vi.doMock('@/hooks/use-theme', () => ({
|
||||
default: () => ({ theme: 'dark' }),
|
||||
}))
|
||||
mockTheme = 'dark'
|
||||
|
||||
const payload = makePayload({
|
||||
icon: 'https://example.com/icon-light.png',
|
||||
@ -204,7 +205,7 @@ describe('Plugin Card Rendering Integration', () => {
|
||||
})
|
||||
|
||||
render(<Card payload={payload} />)
|
||||
expect(screen.getByTestId('card-icon')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('card-icon')).toHaveTextContent('https://example.com/icon-dark.png')
|
||||
})
|
||||
|
||||
it('shows loading placeholder when isLoading is true', () => {
|
||||
|
||||
@ -2,6 +2,7 @@ import type { ComponentProps } from 'react'
|
||||
import type { IChatItem } from '@/app/components/base/chat/chat/type'
|
||||
import type { AgentLogDetailResponse } from '@/models/log'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { useStore as useAppStore } from '@/app/components/app/store'
|
||||
import { ToastContext } from '@/app/components/base/toast/context'
|
||||
import { fetchAgentLogDetail } from '@/service/log'
|
||||
import AgentLogDetail from '../detail'
|
||||
@ -104,7 +105,7 @@ describe('AgentLogDetail', () => {
|
||||
|
||||
describe('Rendering', () => {
|
||||
it('should show loading indicator while fetching data', async () => {
|
||||
vi.mocked(fetchAgentLogDetail).mockReturnValue(new Promise(() => {}))
|
||||
vi.mocked(fetchAgentLogDetail).mockReturnValue(new Promise(() => { }))
|
||||
|
||||
renderComponent()
|
||||
|
||||
@ -193,6 +194,18 @@ describe('AgentLogDetail', () => {
|
||||
})
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
it('should not fetch data when app detail is unavailable', async () => {
|
||||
vi.mocked(useAppStore).mockImplementationOnce(selector => selector({ appDetail: undefined } as never))
|
||||
vi.mocked(fetchAgentLogDetail).mockResolvedValue(createMockResponse())
|
||||
|
||||
renderComponent()
|
||||
|
||||
await waitFor(() => {
|
||||
expect(fetchAgentLogDetail).not.toHaveBeenCalled()
|
||||
})
|
||||
expect(screen.getByRole('status')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should notify on API error', async () => {
|
||||
vi.mocked(fetchAgentLogDetail).mockRejectedValue(new Error('API Error'))
|
||||
|
||||
|
||||
@ -139,4 +139,23 @@ describe('AgentLogModal', () => {
|
||||
|
||||
expect(mockProps.onCancel).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should ignore click-away before mounted state is set', () => {
|
||||
vi.mocked(fetchAgentLogDetail).mockReturnValue(new Promise(() => {}))
|
||||
let invoked = false
|
||||
vi.mocked(useClickAway).mockImplementation((callback) => {
|
||||
if (!invoked) {
|
||||
invoked = true
|
||||
callback(new Event('click'))
|
||||
}
|
||||
})
|
||||
|
||||
render(
|
||||
<ToastContext.Provider value={{ notify: vi.fn(), close: vi.fn() } as React.ComponentProps<typeof ToastContext.Provider>['value']}>
|
||||
<AgentLogModal {...mockProps} />
|
||||
</ToastContext.Provider>,
|
||||
)
|
||||
|
||||
expect(mockProps.onCancel).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
@ -82,4 +82,9 @@ describe('ResultPanel', () => {
|
||||
render(<ResultPanel {...mockProps} agentMode="react" />)
|
||||
expect(screen.getByText('appDebug.agent.agentModeType.ReACT')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should fallback to zero tokens when total_tokens is undefined', () => {
|
||||
render(<ResultPanel {...mockProps} total_tokens={undefined} />)
|
||||
expect(screen.getByText('0 Tokens')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@ -2,6 +2,7 @@ import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import * as React from 'react'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import { BlockEnum } from '@/app/components/workflow/types'
|
||||
import { useLocale } from '@/context/i18n'
|
||||
import ToolCallItem from '../tool-call'
|
||||
|
||||
vi.mock('@/app/components/workflow/nodes/_base/components/editor/code-editor', () => ({
|
||||
@ -17,6 +18,10 @@ vi.mock('@/app/components/workflow/block-icon', () => ({
|
||||
default: ({ type }: { type: BlockEnum }) => <div data-testid="block-icon" data-type={type} />,
|
||||
}))
|
||||
|
||||
vi.mock('@/context/i18n', () => ({
|
||||
useLocale: vi.fn(() => 'en'),
|
||||
}))
|
||||
|
||||
const mockToolCall = {
|
||||
status: 'success',
|
||||
error: null,
|
||||
@ -41,6 +46,17 @@ describe('ToolCallItem', () => {
|
||||
expect(screen.getByTestId('block-icon')).toHaveAttribute('data-type', BlockEnum.Tool)
|
||||
})
|
||||
|
||||
it('should fallback to locale key with underscores when hyphenated key is missing', () => {
|
||||
vi.mocked(useLocale).mockReturnValueOnce('en-US')
|
||||
const fallbackLocaleToolCall = {
|
||||
...mockToolCall,
|
||||
tool_label: { en_US: 'Fallback Label' },
|
||||
}
|
||||
|
||||
render(<ToolCallItem toolCall={fallbackLocaleToolCall} isLLM={false} />)
|
||||
expect(screen.getByText('Fallback Label')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should format time correctly', () => {
|
||||
render(<ToolCallItem toolCall={mockToolCall} isLLM={false} />)
|
||||
expect(screen.getByText('1.500 s')).toBeInTheDocument()
|
||||
@ -54,13 +70,17 @@ describe('ToolCallItem', () => {
|
||||
expect(screen.getByText('1 m 5.000 s')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should format token count correctly', () => {
|
||||
it('should format token count in K units', () => {
|
||||
render(<ToolCallItem toolCall={mockToolCall} isLLM={true} tokens={1200} />)
|
||||
expect(screen.getByText('1.2K tokens')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should format token count without unit for small values', () => {
|
||||
render(<ToolCallItem toolCall={mockToolCall} isLLM={true} tokens={800} />)
|
||||
expect(screen.getByText('800 tokens')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should format token count in M units', () => {
|
||||
render(<ToolCallItem toolCall={mockToolCall} isLLM={true} tokens={1200000} />)
|
||||
expect(screen.getByText('1.2M tokens')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
@ -45,6 +45,7 @@ const pageNameEnrichmentPlugin = (): amplitude.Types.EnrichmentPlugin => {
|
||||
execute: async (event: amplitude.Types.Event) => {
|
||||
// Only modify page view events
|
||||
if (event.event_type === '[Amplitude] Page Viewed' && event.event_properties) {
|
||||
/* v8 ignore next @preserve */
|
||||
const pathname = typeof window !== 'undefined' ? window.location.pathname : ''
|
||||
event.event_properties['[Amplitude] Page Title'] = getEnglishPageName(pathname)
|
||||
}
|
||||
|
||||
@ -42,6 +42,7 @@ const ImageInput: FC<UploaderProps> = ({
|
||||
const [zoom, setZoom] = useState(1)
|
||||
|
||||
const onCropComplete = async (_: Area, croppedAreaPixels: Area) => {
|
||||
/* v8 ignore next -- unreachable guard when Cropper is rendered @preserve */
|
||||
if (!inputImage)
|
||||
return
|
||||
onImageInput?.(true, inputImage.url, croppedAreaPixels, inputImage.file.name)
|
||||
|
||||
@ -151,6 +151,43 @@ describe('BlockInput', () => {
|
||||
|
||||
expect(screen.queryByRole('textbox')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle change when onConfirm is not provided', async () => {
|
||||
render(<BlockInput value="Hello" />)
|
||||
|
||||
const contentArea = screen.getByText('Hello')
|
||||
fireEvent.click(contentArea)
|
||||
|
||||
const textarea = await screen.findByRole('textbox')
|
||||
fireEvent.change(textarea, { target: { value: 'Hello World' } })
|
||||
|
||||
expect(textarea).toHaveValue('Hello World')
|
||||
})
|
||||
|
||||
it('should enter edit mode when clicked with empty value', async () => {
|
||||
render(<BlockInput value="" />)
|
||||
const contentArea = screen.getByTestId('block-input').firstChild as Element
|
||||
fireEvent.click(contentArea)
|
||||
|
||||
const textarea = await screen.findByRole('textbox')
|
||||
expect(textarea).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should exit edit mode on blur', async () => {
|
||||
render(<BlockInput value="Hello" />)
|
||||
|
||||
const contentArea = screen.getByText('Hello')
|
||||
fireEvent.click(contentArea)
|
||||
|
||||
const textarea = await screen.findByRole('textbox')
|
||||
expect(textarea).toBeInTheDocument()
|
||||
|
||||
fireEvent.blur(textarea)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByRole('textbox')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
@ -168,8 +205,9 @@ describe('BlockInput', () => {
|
||||
})
|
||||
|
||||
it('should handle newlines in value', () => {
|
||||
render(<BlockInput value="line1\nline2" />)
|
||||
const { container } = render(<BlockInput value={`line1\nline2`} />)
|
||||
expect(screen.getByText(/line1/)).toBeInTheDocument()
|
||||
expect(container.querySelector('br')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle multiple same variables', () => {
|
||||
|
||||
@ -40,7 +40,7 @@ const createMockEmblaApi = (): MockEmblaApi => ({
|
||||
canScrollPrev: vi.fn(() => mockCanScrollPrev),
|
||||
canScrollNext: vi.fn(() => mockCanScrollNext),
|
||||
slideNodes: vi.fn(() =>
|
||||
Array.from({ length: mockSlideCount }, () => document.createElement('div')),
|
||||
Array.from({ length: mockSlideCount }).fill(document.createElement('div')),
|
||||
),
|
||||
on: vi.fn((event: EmblaEventName, callback: EmblaListener) => {
|
||||
listeners[event].push(callback)
|
||||
@ -50,12 +50,13 @@ const createMockEmblaApi = (): MockEmblaApi => ({
|
||||
}),
|
||||
})
|
||||
|
||||
const emitEmblaEvent = (event: EmblaEventName, api: MockEmblaApi | undefined = mockApi) => {
|
||||
function emitEmblaEvent(event: EmblaEventName, api?: MockEmblaApi) {
|
||||
const resolvedApi = arguments.length === 1 ? mockApi : api
|
||||
|
||||
listeners[event].forEach((callback) => {
|
||||
callback(api)
|
||||
callback(resolvedApi)
|
||||
})
|
||||
}
|
||||
|
||||
const renderCarouselWithControls = (orientation: 'horizontal' | 'vertical' = 'horizontal') => {
|
||||
return render(
|
||||
<Carousel orientation={orientation}>
|
||||
@ -133,6 +134,24 @@ describe('Carousel', () => {
|
||||
})
|
||||
})
|
||||
|
||||
// Ref API exposes embla and controls.
|
||||
describe('Ref API', () => {
|
||||
it('should expose carousel API and controls via ref', () => {
|
||||
type CarouselRef = { api: unknown, selectedIndex: number }
|
||||
const ref = { current: null as CarouselRef | null }
|
||||
|
||||
render(
|
||||
<Carousel ref={(r) => { ref.current = r as unknown as CarouselRef }}>
|
||||
<Carousel.Content />
|
||||
</Carousel>,
|
||||
)
|
||||
|
||||
expect(ref.current).toBeDefined()
|
||||
expect(ref.current?.api).toBe(mockApi)
|
||||
expect(ref.current?.selectedIndex).toBe(0)
|
||||
})
|
||||
})
|
||||
|
||||
// Users can move slides through previous and next controls.
|
||||
describe('User interactions', () => {
|
||||
it('should call scroll handlers when previous and next buttons are clicked', () => {
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import type { ChatWithHistoryContextValue } from '../../context'
|
||||
import type { AppData, ConversationItem } from '@/models/share'
|
||||
import { render, screen, waitFor } from '@testing-library/react'
|
||||
import { act, render, screen, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { useChatWithHistoryContext } from '../../context'
|
||||
@ -237,7 +237,9 @@ describe('Header Component', () => {
|
||||
expect(handleRenameConversation).toHaveBeenCalledWith('conv-1', 'New Name', expect.any(Object))
|
||||
|
||||
const successCallback = handleRenameConversation.mock.calls[0][2].onSuccess
|
||||
successCallback()
|
||||
await act(async () => {
|
||||
successCallback()
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByText('common.chat.renameConversation')).not.toBeInTheDocument()
|
||||
@ -268,7 +270,9 @@ describe('Header Component', () => {
|
||||
expect(handleDeleteConversation).toHaveBeenCalledWith('conv-1', expect.any(Object))
|
||||
|
||||
const successCallback = handleDeleteConversation.mock.calls[0][1].onSuccess
|
||||
successCallback()
|
||||
await act(async () => {
|
||||
successCallback()
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByText('share.chat.deleteConversation.title')).not.toBeInTheDocument()
|
||||
@ -295,6 +299,20 @@ describe('Header Component', () => {
|
||||
expect(screen.queryByText('share.chat.deleteConversation.title')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle empty translated delete content via fallback', async () => {
|
||||
const mockConv = { id: 'conv-1', name: 'My Chat' } as ConversationItem
|
||||
setup({
|
||||
currentConversationId: 'conv-1',
|
||||
currentConversationItem: mockConv,
|
||||
sidebarCollapseState: true,
|
||||
})
|
||||
|
||||
await userEvent.click(screen.getByText('My Chat'))
|
||||
await userEvent.click(await screen.findByText('explore.sidebar.action.delete'))
|
||||
|
||||
expect(await screen.findByText('share.chat.deleteConversation.title')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
@ -317,6 +335,64 @@ describe('Header Component', () => {
|
||||
expect(titleEl).toHaveClass('system-md-semibold')
|
||||
})
|
||||
|
||||
it('should render app icon from URL when icon_url is provided', () => {
|
||||
setup({
|
||||
appData: {
|
||||
...mockAppData,
|
||||
site: {
|
||||
...mockAppData.site,
|
||||
icon_type: 'image',
|
||||
icon_url: 'https://example.com/icon.png',
|
||||
},
|
||||
},
|
||||
})
|
||||
const img = screen.getByAltText('app icon')
|
||||
expect(img).toHaveAttribute('src', 'https://example.com/icon.png')
|
||||
})
|
||||
|
||||
it('should handle undefined appData gracefully (optional chaining)', () => {
|
||||
setup({ appData: null as unknown as AppData })
|
||||
// Just verify it doesn't crash and renders the basic structure
|
||||
expect(screen.getAllByRole('button').length).toBeGreaterThan(0)
|
||||
})
|
||||
|
||||
it('should handle missing name in conversation item', () => {
|
||||
const mockConv = { id: 'conv-1', name: '' } as ConversationItem
|
||||
setup({
|
||||
currentConversationId: 'conv-1',
|
||||
currentConversationItem: mockConv,
|
||||
sidebarCollapseState: true,
|
||||
})
|
||||
// The separator is just a div with text content '/'
|
||||
expect(screen.getByText('/')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle New Chat button state when currentConversationId is present but isResponding is true', () => {
|
||||
setup({
|
||||
isResponding: true,
|
||||
sidebarCollapseState: true,
|
||||
currentConversationId: 'conv-1',
|
||||
})
|
||||
|
||||
const buttons = screen.getAllByRole('button')
|
||||
// Sidebar, NewChat, ResetChat (3)
|
||||
const newChatBtn = buttons[1]
|
||||
expect(newChatBtn).toBeDisabled()
|
||||
})
|
||||
|
||||
it('should handle New Chat button state when currentConversationId is missing and isResponding is false', () => {
|
||||
setup({
|
||||
isResponding: false,
|
||||
sidebarCollapseState: true,
|
||||
currentConversationId: '',
|
||||
})
|
||||
|
||||
const buttons = screen.getAllByRole('button')
|
||||
// Sidebar, NewChat (2)
|
||||
const newChatBtn = buttons[1]
|
||||
expect(newChatBtn).toBeDisabled()
|
||||
})
|
||||
|
||||
it('should not render operation menu if conversation id is missing', () => {
|
||||
setup({ currentConversationId: '', sidebarCollapseState: true })
|
||||
expect(screen.queryByText('My Chat')).not.toBeInTheDocument()
|
||||
@ -332,17 +408,20 @@ describe('Header Component', () => {
|
||||
expect(screen.queryByText('My Chat')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle New Chat button disabled state when responding', () => {
|
||||
setup({
|
||||
isResponding: true,
|
||||
it('should pass empty rename value when conversation name is undefined', async () => {
|
||||
const mockConv = { id: 'conv-1' } as ConversationItem
|
||||
const { container } = setup({
|
||||
currentConversationId: 'conv-1',
|
||||
currentConversationItem: mockConv,
|
||||
sidebarCollapseState: true,
|
||||
currentConversationId: undefined,
|
||||
})
|
||||
|
||||
const buttons = screen.getAllByRole('button')
|
||||
// Sidebar(1) + NewChat(1) = 2
|
||||
const newChatBtn = buttons[1]
|
||||
expect(newChatBtn).toBeDisabled()
|
||||
const operationTrigger = container.querySelector('.flex.cursor-pointer.items-center.rounded-lg.p-1\\.5.pl-2.text-text-secondary.hover\\:bg-state-base-hover') as HTMLElement
|
||||
await userEvent.click(operationTrigger)
|
||||
await userEvent.click(await screen.findByText('explore.sidebar.action.rename'))
|
||||
|
||||
const input = screen.getByRole('textbox') as HTMLInputElement
|
||||
expect(input.value).toBe('')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@ -59,6 +59,7 @@ const Header = () => {
|
||||
setShowConfirm(null)
|
||||
}, [])
|
||||
const handleDelete = useCallback(() => {
|
||||
/* v8 ignore next -- defensive guard; onConfirm is only reachable when showConfirm is truthy. @preserve */
|
||||
if (showConfirm)
|
||||
handleDeleteConversation(showConfirm.id, { onSuccess: handleCancelConfirm })
|
||||
}, [showConfirm, handleDeleteConversation, handleCancelConfirm])
|
||||
@ -66,6 +67,7 @@ const Header = () => {
|
||||
setShowRename(null)
|
||||
}, [])
|
||||
const handleRename = useCallback((newName: string) => {
|
||||
/* v8 ignore next -- defensive guard; onSave is only reachable when showRename is truthy. @preserve */
|
||||
if (showRename)
|
||||
handleRenameConversation(showRename.id, newName, { onSuccess: handleCancelRename })
|
||||
}, [showRename, handleRenameConversation, handleCancelRename])
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,18 +1,18 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import type { ConversationItem } from '@/models/share'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import * as React from 'react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import Item from '../item'
|
||||
|
||||
// Mock Operation to verify its usage
|
||||
vi.mock('@/app/components/base/chat/chat-with-history/sidebar/operation', () => ({
|
||||
default: ({ togglePin, onRenameConversation, onDelete, isItemHovering, isActive }: { togglePin: () => void, onRenameConversation: () => void, onDelete: () => void, isItemHovering: boolean, isActive: boolean }) => (
|
||||
default: ({ togglePin, onRenameConversation, onDelete, isItemHovering, isActive, isPinned }: { togglePin: () => void, onRenameConversation: () => void, onDelete: () => void, isItemHovering: boolean, isActive: boolean, isPinned: boolean }) => (
|
||||
<div data-testid="mock-operation">
|
||||
<button onClick={togglePin}>Pin</button>
|
||||
<button onClick={onRenameConversation}>Rename</button>
|
||||
<button onClick={onDelete}>Delete</button>
|
||||
<span data-hovering={isItemHovering}>Hovering</span>
|
||||
<span data-active={isActive}>Active</span>
|
||||
<button onClick={togglePin} data-testid="pin-button">Pin</button>
|
||||
<button onClick={onRenameConversation} data-testid="rename-button">Rename</button>
|
||||
<button onClick={onDelete} data-testid="delete-button">Delete</button>
|
||||
<span data-hovering={isItemHovering} data-testid="hover-indicator">Hovering</span>
|
||||
<span data-active={isActive} data-testid="active-indicator">Active</span>
|
||||
<span data-pinned={isPinned} data-testid="pinned-indicator">Pinned</span>
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
@ -36,47 +36,525 @@ describe('Item', () => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('should render conversation name', () => {
|
||||
render(<Item {...defaultProps} />)
|
||||
expect(screen.getByText('Test Conversation')).toBeInTheDocument()
|
||||
describe('Rendering', () => {
|
||||
it('should render conversation name', () => {
|
||||
render(<Item {...defaultProps} />)
|
||||
expect(screen.getByText('Test Conversation')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render with title attribute for truncated text', () => {
|
||||
render(<Item {...defaultProps} />)
|
||||
const nameDiv = screen.getByText('Test Conversation')
|
||||
expect(nameDiv).toHaveAttribute('title', 'Test Conversation')
|
||||
})
|
||||
|
||||
it('should render with different names', () => {
|
||||
const item = { ...mockItem, name: 'Different Conversation' }
|
||||
render(<Item {...defaultProps} item={item} />)
|
||||
expect(screen.getByText('Different Conversation')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render with very long name', () => {
|
||||
const longName = 'A'.repeat(500)
|
||||
const item = { ...mockItem, name: longName }
|
||||
render(<Item {...defaultProps} item={item} />)
|
||||
expect(screen.getByText(longName)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render with special characters in name', () => {
|
||||
const item = { ...mockItem, name: 'Chat @#$% 中文' }
|
||||
render(<Item {...defaultProps} item={item} />)
|
||||
expect(screen.getByText('Chat @#$% 中文')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render with empty name', () => {
|
||||
const item = { ...mockItem, name: '' }
|
||||
render(<Item {...defaultProps} item={item} />)
|
||||
expect(screen.getByTestId('mock-operation')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render with whitespace-only name', () => {
|
||||
const item = { ...mockItem, name: ' ' }
|
||||
render(<Item {...defaultProps} item={item} />)
|
||||
const nameElement = screen.getByText((_, element) => element?.getAttribute('title') === ' ')
|
||||
expect(nameElement).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should call onChangeConversation when clicked', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(<Item {...defaultProps} />)
|
||||
describe('Active State', () => {
|
||||
it('should show active state when selected', () => {
|
||||
const { container } = render(<Item {...defaultProps} currentConversationId="1" />)
|
||||
const itemDiv = container.firstChild as HTMLElement
|
||||
expect(itemDiv).toHaveClass('bg-state-accent-active')
|
||||
expect(itemDiv).toHaveClass('text-text-accent')
|
||||
|
||||
await user.click(screen.getByText('Test Conversation'))
|
||||
expect(defaultProps.onChangeConversation).toHaveBeenCalledWith('1')
|
||||
const activeIndicator = screen.getByTestId('active-indicator')
|
||||
expect(activeIndicator).toHaveAttribute('data-active', 'true')
|
||||
})
|
||||
|
||||
it('should not show active state when not selected', () => {
|
||||
const { container } = render(<Item {...defaultProps} currentConversationId="0" />)
|
||||
const itemDiv = container.firstChild as HTMLElement
|
||||
expect(itemDiv).not.toHaveClass('bg-state-accent-active')
|
||||
|
||||
const activeIndicator = screen.getByTestId('active-indicator')
|
||||
expect(activeIndicator).toHaveAttribute('data-active', 'false')
|
||||
})
|
||||
|
||||
it('should toggle active state when currentConversationId changes', () => {
|
||||
const { rerender, container } = render(<Item {...defaultProps} currentConversationId="0" />)
|
||||
expect(container.firstChild).not.toHaveClass('bg-state-accent-active')
|
||||
|
||||
rerender(<Item {...defaultProps} currentConversationId="1" />)
|
||||
expect(container.firstChild).toHaveClass('bg-state-accent-active')
|
||||
|
||||
rerender(<Item {...defaultProps} currentConversationId="0" />)
|
||||
expect(container.firstChild).not.toHaveClass('bg-state-accent-active')
|
||||
})
|
||||
})
|
||||
|
||||
it('should show active state when selected', () => {
|
||||
const { container } = render(<Item {...defaultProps} currentConversationId="1" />)
|
||||
const itemDiv = container.firstChild as HTMLElement
|
||||
expect(itemDiv).toHaveClass('bg-state-accent-active')
|
||||
describe('Pin State', () => {
|
||||
it('should render with isPin true', () => {
|
||||
render(<Item {...defaultProps} isPin={true} />)
|
||||
const pinnedIndicator = screen.getByTestId('pinned-indicator')
|
||||
expect(pinnedIndicator).toHaveAttribute('data-pinned', 'true')
|
||||
})
|
||||
|
||||
const activeIndicator = screen.getByText('Active')
|
||||
expect(activeIndicator).toHaveAttribute('data-active', 'true')
|
||||
it('should render with isPin false', () => {
|
||||
render(<Item {...defaultProps} isPin={false} />)
|
||||
const pinnedIndicator = screen.getByTestId('pinned-indicator')
|
||||
expect(pinnedIndicator).toHaveAttribute('data-pinned', 'false')
|
||||
})
|
||||
|
||||
it('should render with isPin undefined', () => {
|
||||
render(<Item {...defaultProps} />)
|
||||
const pinnedIndicator = screen.getByTestId('pinned-indicator')
|
||||
expect(pinnedIndicator).toHaveAttribute('data-pinned', 'false')
|
||||
})
|
||||
|
||||
it('should call onOperate with unpin when isPinned is true', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onOperate = vi.fn()
|
||||
render(<Item {...defaultProps} onOperate={onOperate} isPin={true} />)
|
||||
|
||||
await user.click(screen.getByTestId('pin-button'))
|
||||
expect(onOperate).toHaveBeenCalledWith('unpin', mockItem)
|
||||
})
|
||||
|
||||
it('should call onOperate with pin when isPinned is false', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onOperate = vi.fn()
|
||||
render(<Item {...defaultProps} onOperate={onOperate} isPin={false} />)
|
||||
|
||||
await user.click(screen.getByTestId('pin-button'))
|
||||
expect(onOperate).toHaveBeenCalledWith('pin', mockItem)
|
||||
})
|
||||
|
||||
it('should call onOperate with pin when isPin is undefined', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onOperate = vi.fn()
|
||||
render(<Item {...defaultProps} onOperate={onOperate} />)
|
||||
|
||||
await user.click(screen.getByTestId('pin-button'))
|
||||
expect(onOperate).toHaveBeenCalledWith('pin', mockItem)
|
||||
})
|
||||
})
|
||||
|
||||
it('should pass correct props to Operation', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(<Item {...defaultProps} isPin={true} />)
|
||||
describe('Item ID Handling', () => {
|
||||
it('should show Operation for non-empty id', () => {
|
||||
render(<Item {...defaultProps} item={{ ...mockItem, id: '123' }} />)
|
||||
expect(screen.getByTestId('mock-operation')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
const operation = screen.getByTestId('mock-operation')
|
||||
expect(operation).toBeInTheDocument()
|
||||
it('should not show Operation for empty id', () => {
|
||||
render(<Item {...defaultProps} item={{ ...mockItem, id: '' }} />)
|
||||
expect(screen.queryByTestId('mock-operation')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
await user.click(screen.getByText('Pin'))
|
||||
expect(defaultProps.onOperate).toHaveBeenCalledWith('unpin', mockItem)
|
||||
it('should show Operation for id with special characters', () => {
|
||||
render(<Item {...defaultProps} item={{ ...mockItem, id: 'abc-123_xyz' }} />)
|
||||
expect(screen.getByTestId('mock-operation')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
await user.click(screen.getByText('Rename'))
|
||||
expect(defaultProps.onOperate).toHaveBeenCalledWith('rename', mockItem)
|
||||
it('should show Operation for numeric id', () => {
|
||||
render(<Item {...defaultProps} item={{ ...mockItem, id: '999' }} />)
|
||||
expect(screen.getByTestId('mock-operation')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
await user.click(screen.getByText('Delete'))
|
||||
expect(defaultProps.onOperate).toHaveBeenCalledWith('delete', mockItem)
|
||||
it('should show Operation for uuid-like id', () => {
|
||||
const uuid = '123e4567-e89b-12d3-a456-426614174000'
|
||||
render(<Item {...defaultProps} item={{ ...mockItem, id: uuid }} />)
|
||||
expect(screen.getByTestId('mock-operation')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should not show Operation for empty id items', () => {
|
||||
render(<Item {...defaultProps} item={{ ...mockItem, id: '' }} />)
|
||||
expect(screen.queryByTestId('mock-operation')).not.toBeInTheDocument()
|
||||
describe('Click Interactions', () => {
|
||||
it('should call onChangeConversation when clicked', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onChangeConversation = vi.fn()
|
||||
render(<Item {...defaultProps} onChangeConversation={onChangeConversation} />)
|
||||
|
||||
await user.click(screen.getByText('Test Conversation'))
|
||||
expect(onChangeConversation).toHaveBeenCalledWith('1')
|
||||
})
|
||||
|
||||
it('should call onChangeConversation with correct id', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onChangeConversation = vi.fn()
|
||||
const item = { ...mockItem, id: 'custom-id' }
|
||||
render(<Item {...defaultProps} item={item} onChangeConversation={onChangeConversation} />)
|
||||
|
||||
await user.click(screen.getByText('Test Conversation'))
|
||||
expect(onChangeConversation).toHaveBeenCalledWith('custom-id')
|
||||
})
|
||||
|
||||
it('should not propagate click to parent when Operation button is clicked', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onChangeConversation = vi.fn()
|
||||
render(<Item {...defaultProps} onChangeConversation={onChangeConversation} />)
|
||||
|
||||
const deleteButton = screen.getByTestId('delete-button')
|
||||
await user.click(deleteButton)
|
||||
|
||||
// onChangeConversation should not be called when Operation button is clicked
|
||||
expect(onChangeConversation).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should call onOperate with delete when delete button clicked', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onOperate = vi.fn()
|
||||
render(<Item {...defaultProps} onOperate={onOperate} />)
|
||||
|
||||
await user.click(screen.getByTestId('delete-button'))
|
||||
expect(onOperate).toHaveBeenCalledWith('delete', mockItem)
|
||||
})
|
||||
|
||||
it('should call onOperate with rename when rename button clicked', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onOperate = vi.fn()
|
||||
render(<Item {...defaultProps} onOperate={onOperate} />)
|
||||
|
||||
await user.click(screen.getByTestId('rename-button'))
|
||||
expect(onOperate).toHaveBeenCalledWith('rename', mockItem)
|
||||
})
|
||||
|
||||
it('should handle multiple rapid clicks on different operations', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onOperate = vi.fn()
|
||||
render(<Item {...defaultProps} onOperate={onOperate} />)
|
||||
|
||||
await user.click(screen.getByTestId('rename-button'))
|
||||
await user.click(screen.getByTestId('pin-button'))
|
||||
await user.click(screen.getByTestId('delete-button'))
|
||||
|
||||
expect(onOperate).toHaveBeenCalledTimes(3)
|
||||
})
|
||||
|
||||
it('should call onChangeConversation only once on single click', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onChangeConversation = vi.fn()
|
||||
render(<Item {...defaultProps} onChangeConversation={onChangeConversation} />)
|
||||
|
||||
await user.click(screen.getByText('Test Conversation'))
|
||||
expect(onChangeConversation).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should call onChangeConversation multiple times on multiple clicks', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onChangeConversation = vi.fn()
|
||||
render(<Item {...defaultProps} onChangeConversation={onChangeConversation} />)
|
||||
|
||||
await user.click(screen.getByText('Test Conversation'))
|
||||
await user.click(screen.getByText('Test Conversation'))
|
||||
await user.click(screen.getByText('Test Conversation'))
|
||||
|
||||
expect(onChangeConversation).toHaveBeenCalledTimes(3)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Operation Buttons', () => {
|
||||
it('should show Operation when item.id is not empty', () => {
|
||||
render(<Item {...defaultProps} />)
|
||||
expect(screen.getByTestId('mock-operation')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should pass correct props to Operation', async () => {
|
||||
render(<Item {...defaultProps} isPin={true} currentConversationId="1" />)
|
||||
|
||||
const operation = screen.getByTestId('mock-operation')
|
||||
expect(operation).toBeInTheDocument()
|
||||
|
||||
const activeIndicator = screen.getByTestId('active-indicator')
|
||||
expect(activeIndicator).toHaveAttribute('data-active', 'true')
|
||||
|
||||
const pinnedIndicator = screen.getByTestId('pinned-indicator')
|
||||
expect(pinnedIndicator).toHaveAttribute('data-pinned', 'true')
|
||||
})
|
||||
|
||||
it('should handle all three operation types sequentially', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onOperate = vi.fn()
|
||||
render(<Item {...defaultProps} onOperate={onOperate} />)
|
||||
|
||||
await user.click(screen.getByTestId('rename-button'))
|
||||
expect(onOperate).toHaveBeenNthCalledWith(1, 'rename', mockItem)
|
||||
|
||||
await user.click(screen.getByTestId('pin-button'))
|
||||
expect(onOperate).toHaveBeenNthCalledWith(2, 'pin', mockItem)
|
||||
|
||||
await user.click(screen.getByTestId('delete-button'))
|
||||
expect(onOperate).toHaveBeenNthCalledWith(3, 'delete', mockItem)
|
||||
})
|
||||
|
||||
it('should handle pin toggle between pin and unpin', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onOperate = vi.fn()
|
||||
|
||||
const { rerender } = render(
|
||||
<Item {...defaultProps} onOperate={onOperate} isPin={false} />,
|
||||
)
|
||||
|
||||
await user.click(screen.getByTestId('pin-button'))
|
||||
expect(onOperate).toHaveBeenCalledWith('pin', mockItem)
|
||||
|
||||
rerender(<Item {...defaultProps} onOperate={onOperate} isPin={true} />)
|
||||
|
||||
await user.click(screen.getByTestId('pin-button'))
|
||||
expect(onOperate).toHaveBeenCalledWith('unpin', mockItem)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Styling', () => {
|
||||
it('should have base classes on container', () => {
|
||||
const { container } = render(<Item {...defaultProps} />)
|
||||
const itemDiv = container.firstChild as HTMLElement
|
||||
|
||||
expect(itemDiv).toHaveClass('group')
|
||||
expect(itemDiv).toHaveClass('flex')
|
||||
expect(itemDiv).toHaveClass('cursor-pointer')
|
||||
expect(itemDiv).toHaveClass('rounded-lg')
|
||||
})
|
||||
|
||||
it('should apply active state classes when selected', () => {
|
||||
const { container } = render(<Item {...defaultProps} currentConversationId="1" />)
|
||||
const itemDiv = container.firstChild as HTMLElement
|
||||
|
||||
expect(itemDiv).toHaveClass('bg-state-accent-active')
|
||||
expect(itemDiv).toHaveClass('text-text-accent')
|
||||
})
|
||||
|
||||
it('should apply hover classes', () => {
|
||||
const { container } = render(<Item {...defaultProps} />)
|
||||
const itemDiv = container.firstChild as HTMLElement
|
||||
|
||||
expect(itemDiv).toHaveClass('hover:bg-state-base-hover')
|
||||
})
|
||||
|
||||
it('should maintain hover classes when active', () => {
|
||||
const { container } = render(<Item {...defaultProps} currentConversationId="1" />)
|
||||
const itemDiv = container.firstChild as HTMLElement
|
||||
|
||||
expect(itemDiv).toHaveClass('hover:bg-state-accent-active')
|
||||
})
|
||||
|
||||
it('should apply truncate class to text container', () => {
|
||||
const { container } = render(<Item {...defaultProps} />)
|
||||
const textDiv = container.querySelector('.grow.truncate')
|
||||
|
||||
expect(textDiv).toHaveClass('truncate')
|
||||
expect(textDiv).toHaveClass('grow')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Props Updates', () => {
|
||||
it('should update when item prop changes', () => {
|
||||
const { rerender } = render(<Item {...defaultProps} item={mockItem} />)
|
||||
|
||||
expect(screen.getByText('Test Conversation')).toBeInTheDocument()
|
||||
|
||||
const newItem = { ...mockItem, name: 'Updated Conversation' }
|
||||
rerender(<Item {...defaultProps} item={newItem} />)
|
||||
|
||||
expect(screen.getByText('Updated Conversation')).toBeInTheDocument()
|
||||
expect(screen.queryByText('Test Conversation')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should update when currentConversationId changes', () => {
|
||||
const { container, rerender } = render(
|
||||
<Item {...defaultProps} currentConversationId="0" />,
|
||||
)
|
||||
|
||||
expect(container.firstChild).not.toHaveClass('bg-state-accent-active')
|
||||
|
||||
rerender(<Item {...defaultProps} currentConversationId="1" />)
|
||||
|
||||
expect(container.firstChild).toHaveClass('bg-state-accent-active')
|
||||
})
|
||||
|
||||
it('should update when isPin changes', () => {
|
||||
const { rerender } = render(<Item {...defaultProps} isPin={false} />)
|
||||
|
||||
let pinnedIndicator = screen.getByTestId('pinned-indicator')
|
||||
expect(pinnedIndicator).toHaveAttribute('data-pinned', 'false')
|
||||
|
||||
rerender(<Item {...defaultProps} isPin={true} />)
|
||||
|
||||
pinnedIndicator = screen.getByTestId('pinned-indicator')
|
||||
expect(pinnedIndicator).toHaveAttribute('data-pinned', 'true')
|
||||
})
|
||||
|
||||
it('should update when callbacks change', async () => {
|
||||
const user = userEvent.setup()
|
||||
const oldOnOperate = vi.fn()
|
||||
const newOnOperate = vi.fn()
|
||||
|
||||
const { rerender } = render(<Item {...defaultProps} onOperate={oldOnOperate} />)
|
||||
|
||||
rerender(<Item {...defaultProps} onOperate={newOnOperate} />)
|
||||
|
||||
await user.click(screen.getByTestId('delete-button'))
|
||||
|
||||
expect(newOnOperate).toHaveBeenCalledWith('delete', mockItem)
|
||||
expect(oldOnOperate).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should update when multiple props change together', () => {
|
||||
const { rerender } = render(
|
||||
<Item
|
||||
{...defaultProps}
|
||||
item={mockItem}
|
||||
currentConversationId="0"
|
||||
isPin={false}
|
||||
/>,
|
||||
)
|
||||
|
||||
const newItem = { ...mockItem, name: 'New Name', id: '2' }
|
||||
rerender(
|
||||
<Item
|
||||
{...defaultProps}
|
||||
item={newItem}
|
||||
currentConversationId="2"
|
||||
isPin={true}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('New Name')).toBeInTheDocument()
|
||||
|
||||
const activeIndicator = screen.getByTestId('active-indicator')
|
||||
expect(activeIndicator).toHaveAttribute('data-active', 'true')
|
||||
|
||||
const pinnedIndicator = screen.getByTestId('pinned-indicator')
|
||||
expect(pinnedIndicator).toHaveAttribute('data-pinned', 'true')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Item with Different Data', () => {
|
||||
it('should handle item with all properties', () => {
|
||||
const item = {
|
||||
id: 'full-item',
|
||||
name: 'Full Item Name',
|
||||
inputs: { key: 'value' },
|
||||
introduction: 'Some introduction',
|
||||
}
|
||||
render(<Item {...defaultProps} item={item} />)
|
||||
|
||||
expect(screen.getByText('Full Item Name')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle item with minimal properties', () => {
|
||||
const item = {
|
||||
id: '1',
|
||||
name: 'Minimal',
|
||||
} as unknown as ConversationItem
|
||||
render(<Item {...defaultProps} item={item} />)
|
||||
|
||||
expect(screen.getByText('Minimal')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle multiple items rendered separately', () => {
|
||||
const item1 = { ...mockItem, id: '1', name: 'First' }
|
||||
const item2 = { ...mockItem, id: '2', name: 'Second' }
|
||||
|
||||
const { rerender } = render(<Item {...defaultProps} item={item1} />)
|
||||
expect(screen.getByText('First')).toBeInTheDocument()
|
||||
|
||||
rerender(<Item {...defaultProps} item={item2} />)
|
||||
expect(screen.getByText('Second')).toBeInTheDocument()
|
||||
expect(screen.queryByText('First')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Hover State', () => {
|
||||
it('should pass hover state to Operation when hovering', async () => {
|
||||
const { container } = render(<Item {...defaultProps} />)
|
||||
const row = container.firstChild as HTMLElement
|
||||
const hoverIndicator = screen.getByTestId('hover-indicator')
|
||||
|
||||
expect(hoverIndicator.getAttribute('data-hovering')).toBe('false')
|
||||
|
||||
fireEvent.mouseEnter(row)
|
||||
expect(hoverIndicator.getAttribute('data-hovering')).toBe('true')
|
||||
|
||||
fireEvent.mouseLeave(row)
|
||||
expect(hoverIndicator.getAttribute('data-hovering')).toBe('false')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle item with unicode name', () => {
|
||||
const item = { ...mockItem, name: '🎉 Celebration Chat 中文版' }
|
||||
render(<Item {...defaultProps} item={item} />)
|
||||
expect(screen.getByText('🎉 Celebration Chat 中文版')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle item with numeric id as string', () => {
|
||||
const item = { ...mockItem, id: '12345' }
|
||||
render(<Item {...defaultProps} item={item} />)
|
||||
expect(screen.getByTestId('mock-operation')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle rapid isPin prop changes', () => {
|
||||
const { rerender } = render(<Item {...defaultProps} isPin={true} />)
|
||||
|
||||
for (let i = 0; i < 5; i++) {
|
||||
rerender(<Item {...defaultProps} isPin={i % 2 === 0} />)
|
||||
}
|
||||
|
||||
const pinnedIndicator = screen.getByTestId('pinned-indicator')
|
||||
expect(pinnedIndicator).toHaveAttribute('data-pinned', 'true')
|
||||
})
|
||||
|
||||
it('should handle item name with HTML-like content', () => {
|
||||
const item = { ...mockItem, name: '<script>alert("xss")</script>' }
|
||||
render(<Item {...defaultProps} item={item} />)
|
||||
// Should render as text, not execute
|
||||
expect(screen.getByText('<script>alert("xss")</script>')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle very long item id', () => {
|
||||
const longId = 'a'.repeat(1000)
|
||||
const item = { ...mockItem, id: longId }
|
||||
render(<Item {...defaultProps} item={item} />)
|
||||
expect(screen.getByTestId('mock-operation')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Memoization', () => {
|
||||
it('should not re-render when same props are passed', () => {
|
||||
const { rerender } = render(<Item {...defaultProps} />)
|
||||
const element = screen.getByText('Test Conversation')
|
||||
|
||||
rerender(<Item {...defaultProps} />)
|
||||
expect(screen.getByText('Test Conversation')).toBe(element)
|
||||
})
|
||||
|
||||
it('should re-render when item changes', () => {
|
||||
const { rerender } = render(<Item {...defaultProps} item={mockItem} />)
|
||||
|
||||
const newItem = { ...mockItem, name: 'Changed' }
|
||||
rerender(<Item {...defaultProps} item={newItem} />)
|
||||
|
||||
expect(screen.getByText('Changed')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@ -1,9 +1,30 @@
|
||||
import type { ReactNode } from 'react'
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import * as React from 'react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import * as ReactI18next from 'react-i18next'
|
||||
import RenameModal from '../rename-modal'
|
||||
|
||||
vi.mock('@/app/components/base/modal', () => ({
|
||||
default: ({
|
||||
title,
|
||||
isShow,
|
||||
children,
|
||||
}: {
|
||||
title: ReactNode
|
||||
isShow: boolean
|
||||
children: ReactNode
|
||||
}) => {
|
||||
if (!isShow)
|
||||
return null
|
||||
return (
|
||||
<div role="dialog">
|
||||
<h2>{title}</h2>
|
||||
{children}
|
||||
</div>
|
||||
)
|
||||
},
|
||||
}))
|
||||
|
||||
describe('RenameModal', () => {
|
||||
const defaultProps = {
|
||||
isShow: true,
|
||||
@ -17,58 +38,106 @@ describe('RenameModal', () => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('should render with initial name', () => {
|
||||
it('renders title, label, input and action buttons', () => {
|
||||
render(<RenameModal {...defaultProps} />)
|
||||
|
||||
expect(screen.getByText('common.chat.renameConversation')).toBeInTheDocument()
|
||||
expect(screen.getByDisplayValue('Original Name')).toBeInTheDocument()
|
||||
expect(screen.getByPlaceholderText('common.chat.conversationNamePlaceholder')).toBeInTheDocument()
|
||||
expect(screen.getByText('common.chat.conversationName')).toBeInTheDocument()
|
||||
expect(screen.getByPlaceholderText('common.chat.conversationNamePlaceholder')).toHaveValue('Original Name')
|
||||
expect(screen.getByText('common.operation.cancel')).toBeInTheDocument()
|
||||
expect(screen.getByText('common.operation.save')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should update text when typing', async () => {
|
||||
it('does not render when isShow is false', () => {
|
||||
render(<RenameModal {...defaultProps} isShow={false} />)
|
||||
expect(screen.queryByText('common.chat.renameConversation')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('calls onClose when cancel is clicked', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(<RenameModal {...defaultProps} />)
|
||||
|
||||
const input = screen.getByDisplayValue('Original Name')
|
||||
await user.clear(input)
|
||||
await user.type(input, 'New Name')
|
||||
|
||||
expect(input).toHaveValue('New Name')
|
||||
await user.click(screen.getByText('common.operation.cancel'))
|
||||
expect(defaultProps.onClose).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should call onSave with new name when save button is clicked', async () => {
|
||||
it('calls onSave with updated name', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(<RenameModal {...defaultProps} />)
|
||||
|
||||
const input = screen.getByDisplayValue('Original Name')
|
||||
const input = screen.getByRole('textbox')
|
||||
await user.clear(input)
|
||||
await user.type(input, 'Updated Name')
|
||||
|
||||
const saveButton = screen.getByText('common.operation.save')
|
||||
await user.click(saveButton)
|
||||
await user.click(screen.getByText('common.operation.save'))
|
||||
|
||||
expect(defaultProps.onSave).toHaveBeenCalledWith('Updated Name')
|
||||
})
|
||||
|
||||
it('should call onClose when cancel button is clicked', async () => {
|
||||
it('calls onSave with initial name when unchanged', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(<RenameModal {...defaultProps} />)
|
||||
|
||||
const cancelButton = screen.getByText('common.operation.cancel')
|
||||
await user.click(cancelButton)
|
||||
|
||||
expect(defaultProps.onClose).toHaveBeenCalled()
|
||||
await user.click(screen.getByText('common.operation.save'))
|
||||
expect(defaultProps.onSave).toHaveBeenCalledWith('Original Name')
|
||||
})
|
||||
|
||||
it('should show loading state on save button', () => {
|
||||
render(<RenameModal {...defaultProps} saveLoading={true} />)
|
||||
|
||||
// The Button component with loading=true renders a status role (spinner)
|
||||
it('shows loading state when saveLoading is true', () => {
|
||||
render(<RenameModal {...defaultProps} saveLoading />)
|
||||
expect(screen.getByRole('status')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render when isShow is false', () => {
|
||||
const { queryByText } = render(<RenameModal {...defaultProps} isShow={false} />)
|
||||
expect(queryByText('common.chat.renameConversation')).not.toBeInTheDocument()
|
||||
it('hides loading state when saveLoading is false', () => {
|
||||
render(<RenameModal {...defaultProps} saveLoading={false} />)
|
||||
expect(screen.queryByRole('status')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('keeps edited name when parent rerenders with different name prop', async () => {
|
||||
const user = userEvent.setup()
|
||||
const { rerender } = render(<RenameModal {...defaultProps} name="First" />)
|
||||
|
||||
const input = screen.getByRole('textbox')
|
||||
await user.clear(input)
|
||||
await user.type(input, 'Edited')
|
||||
|
||||
rerender(<RenameModal {...defaultProps} name="Second" />)
|
||||
expect(screen.getByRole('textbox')).toHaveValue('Edited')
|
||||
})
|
||||
|
||||
it('retains typed state after isShow false then true on same component instance', async () => {
|
||||
const user = userEvent.setup()
|
||||
const { rerender } = render(<RenameModal {...defaultProps} isShow />)
|
||||
|
||||
const input = screen.getByRole('textbox')
|
||||
await user.clear(input)
|
||||
await user.type(input, 'Changed')
|
||||
|
||||
rerender(<RenameModal {...defaultProps} isShow={false} />)
|
||||
rerender(<RenameModal {...defaultProps} isShow />)
|
||||
|
||||
expect(screen.getByRole('textbox')).toHaveValue('Changed')
|
||||
})
|
||||
|
||||
it('uses empty placeholder fallback when translation returns empty string', () => {
|
||||
const originalUseTranslation = ReactI18next.useTranslation
|
||||
const useTranslationSpy = vi.spyOn(ReactI18next, 'useTranslation').mockImplementation((...args) => {
|
||||
const translation = originalUseTranslation(...args)
|
||||
return {
|
||||
...translation,
|
||||
t: ((key: string, options?: Record<string, unknown>) => {
|
||||
if (key === 'chat.conversationNamePlaceholder')
|
||||
return ''
|
||||
const ns = options?.ns as string | undefined
|
||||
return ns ? `${ns}.${key}` : key
|
||||
}) as typeof translation.t,
|
||||
}
|
||||
})
|
||||
|
||||
try {
|
||||
render(<RenameModal {...defaultProps} />)
|
||||
expect(screen.getByPlaceholderText('')).toBeInTheDocument()
|
||||
}
|
||||
finally {
|
||||
useTranslationSpy.mockRestore()
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
@ -78,6 +78,8 @@ const Sidebar = ({ isPanel, panelVisible }: Props) => {
|
||||
if (showRename)
|
||||
handleRenameConversation(showRename.id, newName, { onSuccess: handleCancelRename })
|
||||
}, [showRename, handleRenameConversation, handleCancelRename])
|
||||
const pinnedTitle = t('chat.pinnedTitle', { ns: 'share' }) || ''
|
||||
const deleteConversationContent = t('chat.deleteConversation.content', { ns: 'share' }) || ''
|
||||
|
||||
return (
|
||||
<div className={cn(
|
||||
@ -122,7 +124,7 @@ const Sidebar = ({ isPanel, panelVisible }: Props) => {
|
||||
<div className="mb-4">
|
||||
<List
|
||||
isPin
|
||||
title={t('chat.pinnedTitle', { ns: 'share' }) || ''}
|
||||
title={pinnedTitle}
|
||||
list={pinnedConversationList}
|
||||
onChangeConversation={handleChangeConversation}
|
||||
onOperate={handleOperate}
|
||||
@ -168,7 +170,7 @@ const Sidebar = ({ isPanel, panelVisible }: Props) => {
|
||||
{!!showConfirm && (
|
||||
<Confirm
|
||||
title={t('chat.deleteConversation.title', { ns: 'share' })}
|
||||
content={t('chat.deleteConversation.content', { ns: 'share' }) || ''}
|
||||
content={deleteConversationContent}
|
||||
isShow
|
||||
onCancel={handleCancelConfirm}
|
||||
onConfirm={handleDelete}
|
||||
|
||||
@ -24,6 +24,7 @@ const RenameModal: FC<IRenameModalProps> = ({
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const [tempName, setTempName] = useState(name)
|
||||
const conversationNamePlaceholder = t('chat.conversationNamePlaceholder', { ns: 'common' }) || ''
|
||||
|
||||
return (
|
||||
<Modal
|
||||
@ -36,7 +37,7 @@ const RenameModal: FC<IRenameModalProps> = ({
|
||||
className="mt-2 h-10 w-full"
|
||||
value={tempName}
|
||||
onChange={e => setTempName(e.target.value)}
|
||||
placeholder={t('chat.conversationNamePlaceholder', { ns: 'common' }) || ''}
|
||||
placeholder={conversationNamePlaceholder}
|
||||
/>
|
||||
|
||||
<div className="mt-10 flex justify-end">
|
||||
|
||||
@ -2,6 +2,7 @@ import type { ChatConfig, ChatItemInTree } from '../../types'
|
||||
import type { FileEntity } from '@/app/components/base/file-uploader/types'
|
||||
import { act, renderHook } from '@testing-library/react'
|
||||
import { useParams, usePathname } from 'next/navigation'
|
||||
import { WorkflowRunningStatus } from '@/app/components/workflow/types'
|
||||
import { sseGet, ssePost } from '@/service/base'
|
||||
import { useChat } from '../hooks'
|
||||
|
||||
@ -1378,22 +1379,884 @@ describe('useChat', () => {
|
||||
}]
|
||||
|
||||
const { result } = renderHook(() => useChat(undefined, undefined, nestedTree as ChatItemInTree[]))
|
||||
|
||||
act(() => {
|
||||
result.current.handleSwitchSibling('a-deep', { isPublicAPI: true })
|
||||
})
|
||||
|
||||
expect(sseGet).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should do nothing when switching to a sibling message that does not exist', () => {
|
||||
const { result } = renderHook(() => useChat(undefined, undefined, prevChatTree as ChatItemInTree[]))
|
||||
|
||||
act(() => {
|
||||
result.current.handleSwitchSibling('missing-message-id', { isPublicAPI: true })
|
||||
})
|
||||
|
||||
expect(sseGet).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Uncovered edge cases', () => {
|
||||
it('should handle onFile fallbacks for audio, video, bin types', () => {
|
||||
let callbacks: HookCallbacks
|
||||
vi.mocked(ssePost).mockImplementation(async (_url, _params, options) => {
|
||||
callbacks = options as HookCallbacks
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useChat())
|
||||
act(() => {
|
||||
result.current.handleSend('url', { query: 'file types' }, {})
|
||||
})
|
||||
|
||||
act(() => {
|
||||
callbacks.onWorkflowStarted({ workflow_run_id: 'wr-1', task_id: 't-1', message_id: 'm-files' })
|
||||
|
||||
// No transferMethod, type: video
|
||||
callbacks.onFile({ id: 'f-vid', type: 'video', url: 'vid.mp4' })
|
||||
// No transferMethod, type: audio
|
||||
callbacks.onFile({ id: 'f-aud', type: 'audio', url: 'aud.mp3' })
|
||||
// No transferMethod, type: bin
|
||||
callbacks.onFile({ id: 'f-bin', type: 'bin', url: 'file.bin' })
|
||||
})
|
||||
|
||||
const lastResponse = result.current.chatList[1]
|
||||
expect(lastResponse.message_files).toHaveLength(3)
|
||||
expect(lastResponse.message_files![0].type).toBe('video/mp4')
|
||||
expect(lastResponse.message_files![0].supportFileType).toBe('video')
|
||||
expect(lastResponse.message_files![1].type).toBe('audio/mpeg')
|
||||
expect(lastResponse.message_files![1].supportFileType).toBe('audio')
|
||||
expect(lastResponse.message_files![2].type).toBe('application/octet-stream')
|
||||
expect(lastResponse.message_files![2].supportFileType).toBe('document')
|
||||
})
|
||||
|
||||
it('should handle onMessageEnd empty citation and empty processed files fallbacks', () => {
|
||||
let callbacks: HookCallbacks
|
||||
vi.mocked(ssePost).mockImplementation(async (_url, _params, options) => {
|
||||
callbacks = options as HookCallbacks
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useChat())
|
||||
act(() => {
|
||||
result.current.handleSend('url', { query: 'citations' }, {})
|
||||
})
|
||||
|
||||
act(() => {
|
||||
callbacks.onWorkflowStarted({ workflow_run_id: 'wr-1', task_id: 't-1', message_id: 'm-cite' })
|
||||
callbacks.onMessageEnd({ id: 'm-cite', metadata: {} }) // No retriever_resources or annotation_reply
|
||||
})
|
||||
|
||||
const lastResponse = result.current.chatList[1]
|
||||
expect(lastResponse.citation).toEqual([])
|
||||
})
|
||||
|
||||
it('should handle iteration and loop tracing edge cases (lazy arrays, node finish index -1)', () => {
|
||||
let callbacks: HookCallbacks
|
||||
vi.mocked(sseGet).mockImplementation(async (_url, _params, options) => {
|
||||
callbacks = options as HookCallbacks
|
||||
})
|
||||
|
||||
const prevChatTree = [{
|
||||
id: 'q-trace',
|
||||
content: 'query',
|
||||
isAnswer: false,
|
||||
children: [{
|
||||
id: 'm-trace',
|
||||
content: 'initial',
|
||||
isAnswer: true,
|
||||
siblingIndex: 0,
|
||||
workflowProcess: { status: WorkflowRunningStatus.Running }, // Omit tracing array to test fallback
|
||||
}],
|
||||
}]
|
||||
|
||||
const { result } = renderHook(() => useChat(undefined, undefined, prevChatTree as ChatItemInTree[]))
|
||||
act(() => {
|
||||
result.current.handleResume('m-trace', 'wr-trace', { isPublicAPI: true })
|
||||
})
|
||||
|
||||
act(() => {
|
||||
// onIterationStart should create the tracing array
|
||||
callbacks.onIterationStart({ data: { node_id: 'iter-1' } })
|
||||
})
|
||||
|
||||
const prevChatTree2 = [{
|
||||
id: 'q-trace2',
|
||||
content: 'query',
|
||||
isAnswer: false,
|
||||
children: [{
|
||||
id: 'm-trace',
|
||||
content: 'initial',
|
||||
isAnswer: true,
|
||||
siblingIndex: 0,
|
||||
workflowProcess: { status: WorkflowRunningStatus.Running }, // Omit tracing array to test fallback
|
||||
}],
|
||||
}]
|
||||
|
||||
const { result: result2 } = renderHook(() => useChat(undefined, undefined, prevChatTree2 as ChatItemInTree[]))
|
||||
act(() => {
|
||||
result2.current.handleResume('m-trace', 'wr-trace2', { isPublicAPI: true })
|
||||
})
|
||||
|
||||
act(() => {
|
||||
// onNodeStarted should create the tracing array
|
||||
callbacks.onNodeStarted({ data: { node_id: 'n-1', id: 'n-1' } })
|
||||
})
|
||||
|
||||
const prevChatTree3 = [{
|
||||
id: 'q-trace3',
|
||||
content: 'query',
|
||||
isAnswer: false,
|
||||
children: [{
|
||||
id: 'm-trace',
|
||||
content: 'initial',
|
||||
isAnswer: true,
|
||||
siblingIndex: 0,
|
||||
workflowProcess: { status: WorkflowRunningStatus.Running }, // Omit tracing array to test fallback
|
||||
}],
|
||||
}]
|
||||
|
||||
const { result: result3 } = renderHook(() => useChat(undefined, undefined, prevChatTree3 as ChatItemInTree[]))
|
||||
act(() => {
|
||||
result3.current.handleResume('m-trace', 'wr-trace3', { isPublicAPI: true })
|
||||
})
|
||||
|
||||
act(() => {
|
||||
// onLoopStart should create the tracing array
|
||||
callbacks.onLoopStart({ data: { node_id: 'loop-1' } })
|
||||
})
|
||||
|
||||
// Ensure the tracing array exists and holds the loop item
|
||||
const lastResponse = result3.current.chatList[1]
|
||||
expect(lastResponse.workflowProcess?.tracing).toBeDefined()
|
||||
expect(lastResponse.workflowProcess?.tracing).toHaveLength(1)
|
||||
expect(lastResponse.workflowProcess?.tracing![0].node_id).toBe('loop-1')
|
||||
})
|
||||
|
||||
it('should handle onCompleted fallback to answer when agent thought does not match and provider latency is 0', async () => {
|
||||
let callbacks: HookCallbacks
|
||||
vi.mocked(ssePost).mockImplementation(async (_url, _params, options) => {
|
||||
callbacks = options as HookCallbacks
|
||||
})
|
||||
|
||||
const onGetConversationMessages = vi.fn().mockResolvedValue({
|
||||
data: [{
|
||||
id: 'm-completed',
|
||||
answer: 'final answer',
|
||||
message: [{ role: 'user', text: 'hi' }],
|
||||
agent_thoughts: [{ thought: 'thinking different from answer' }],
|
||||
created_at: Date.now(),
|
||||
answer_tokens: 10,
|
||||
message_tokens: 5,
|
||||
provider_response_latency: 0,
|
||||
inputs: {},
|
||||
query: 'hi',
|
||||
}],
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useChat())
|
||||
act(() => {
|
||||
result.current.handleSend('test-url', { query: 'fetch test latency zero' }, {
|
||||
onGetConversationMessages,
|
||||
})
|
||||
})
|
||||
|
||||
await act(async () => {
|
||||
callbacks.onData(' data', true, { messageId: 'm-completed', conversationId: 'c-latency' })
|
||||
await callbacks.onCompleted()
|
||||
})
|
||||
|
||||
const lastResponse = result.current.chatList[1]
|
||||
expect(lastResponse.content).toBe('final answer')
|
||||
expect(lastResponse.more?.latency).toBe('0.00')
|
||||
expect(lastResponse.more?.tokens_per_second).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should handle onCompleted using agent thought when thought matches answer', async () => {
|
||||
let callbacks: HookCallbacks
|
||||
vi.mocked(ssePost).mockImplementation(async (_url, _params, options) => {
|
||||
callbacks = options as HookCallbacks
|
||||
})
|
||||
|
||||
const onGetConversationMessages = vi.fn().mockResolvedValue({
|
||||
data: [{
|
||||
id: 'm-matched',
|
||||
answer: 'matched thought',
|
||||
message: [{ role: 'user', text: 'hi' }],
|
||||
agent_thoughts: [{ thought: 'matched thought' }],
|
||||
created_at: Date.now(),
|
||||
answer_tokens: 10,
|
||||
message_tokens: 5,
|
||||
provider_response_latency: 0.5,
|
||||
inputs: {},
|
||||
query: 'hi',
|
||||
}],
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useChat())
|
||||
act(() => {
|
||||
result.current.handleSend('test-url', { query: 'fetch test match thought' }, {
|
||||
onGetConversationMessages,
|
||||
})
|
||||
})
|
||||
|
||||
await act(async () => {
|
||||
callbacks.onData(' data', true, { messageId: 'm-matched', conversationId: 'c-matched' })
|
||||
await callbacks.onCompleted()
|
||||
})
|
||||
|
||||
const lastResponse = result.current.chatList[1]
|
||||
expect(lastResponse.content).toBe('') // isUseAgentThought sets content to empty string
|
||||
})
|
||||
|
||||
it('should cover pausedStateRef reset on workflowFinished and missing tracing arrays in node finish / human input', () => {
|
||||
let callbacks: HookCallbacks
|
||||
vi.mocked(sseGet).mockImplementation(async (_url, _params, options) => {
|
||||
callbacks = options as HookCallbacks
|
||||
})
|
||||
|
||||
const prevChatTree = [{
|
||||
id: 'q-pause',
|
||||
content: 'query',
|
||||
isAnswer: false,
|
||||
children: [{
|
||||
id: 'm-pause',
|
||||
content: 'initial',
|
||||
isAnswer: true,
|
||||
siblingIndex: 0,
|
||||
workflowProcess: { status: WorkflowRunningStatus.Running }, // Omit tracing
|
||||
}],
|
||||
}]
|
||||
|
||||
// Setup test for workflow paused + finished
|
||||
const { result } = renderHook(() => useChat(undefined, undefined, prevChatTree as ChatItemInTree[]))
|
||||
act(() => {
|
||||
result.current.handleResume('m-pause', 'wr-1', { isPublicAPI: true })
|
||||
})
|
||||
|
||||
act(() => {
|
||||
// Trigger a pause to set pausedStateRef = true
|
||||
callbacks.onWorkflowPaused({ data: { workflow_run_id: 'wr-1' } })
|
||||
|
||||
// workflowFinished should reset pausedStateRef to false
|
||||
callbacks.onWorkflowFinished({ data: { status: 'succeeded' } })
|
||||
|
||||
// Missing tracing array onNodeFinished early return
|
||||
callbacks.onNodeFinished({ data: { id: 'n-none' } })
|
||||
|
||||
// Missing tracing array fallback for human input
|
||||
callbacks.onHumanInputRequired({ data: { node_id: 'h-1' } })
|
||||
})
|
||||
|
||||
const lastResponse = result.current.chatList[1]
|
||||
expect(lastResponse.workflowProcess?.status).toBe('succeeded')
|
||||
})
|
||||
|
||||
it('should cover onThought creating tracing and appending message correctly when isAgentMode=true', () => {
|
||||
let callbacks: HookCallbacks
|
||||
vi.mocked(ssePost).mockImplementation(async (_url, _params, options) => {
|
||||
callbacks = options as HookCallbacks
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useChat())
|
||||
act(() => {
|
||||
result.current.handleSend('url', { query: 'agent onThought' }, {})
|
||||
})
|
||||
|
||||
act(() => {
|
||||
callbacks.onWorkflowStarted({ workflow_run_id: 'wr-1', task_id: 't-1' })
|
||||
|
||||
// onThought when array is implicitly empty
|
||||
callbacks.onThought({ id: 'th-1', thought: 'initial thought' })
|
||||
|
||||
// onData which appends to last thought
|
||||
callbacks.onData(' appended', false, { messageId: 'm-thought' })
|
||||
})
|
||||
|
||||
const lastResponse = result.current.chatList[result.current.chatList.length - 1]
|
||||
expect(lastResponse.agent_thoughts).toHaveLength(1)
|
||||
expect(lastResponse.agent_thoughts![0].thought).toBe('initial thought appended')
|
||||
})
|
||||
})
|
||||
|
||||
it('should cover produceChatTreeNode traversing deeply nested child nodes to find the target item', () => {
|
||||
vi.mocked(sseGet).mockImplementation(async (_url, _params, _options) => { })
|
||||
|
||||
const nestedTree = [{
|
||||
id: 'q-root',
|
||||
content: 'query',
|
||||
isAnswer: false,
|
||||
children: [{
|
||||
id: 'a-root',
|
||||
content: 'answer root',
|
||||
isAnswer: true,
|
||||
siblingIndex: 0,
|
||||
children: [{
|
||||
id: 'q-deep',
|
||||
content: 'deep question',
|
||||
isAnswer: false,
|
||||
children: [{
|
||||
id: 'a-deep',
|
||||
content: 'deep answer to find',
|
||||
isAnswer: true,
|
||||
siblingIndex: 0,
|
||||
}],
|
||||
}],
|
||||
}],
|
||||
}]
|
||||
|
||||
// Render the chat with the nested tree
|
||||
const { result } = renderHook(() => useChat(undefined, undefined, nestedTree as ChatItemInTree[]))
|
||||
|
||||
// Setting TargetNodeId triggers state update using produceChatTreeNode internally
|
||||
act(() => {
|
||||
// AnnotationEdited uses produceChatTreeNode to find target Question/Answer nodes
|
||||
result.current.handleAnnotationRemoved(3)
|
||||
})
|
||||
|
||||
// We just care that the tree traversal didn't crash
|
||||
expect(result.current.chatList).toHaveLength(4)
|
||||
})
|
||||
|
||||
it('should cover baseFile with transferMethod and without file type in handleResume and handleSend', () => {
|
||||
let resumeCallbacks: HookCallbacks
|
||||
vi.mocked(sseGet).mockImplementation(async (_url, _params, options) => {
|
||||
resumeCallbacks = options as HookCallbacks
|
||||
})
|
||||
let sendCallbacks: HookCallbacks
|
||||
vi.mocked(ssePost).mockImplementation(async (_url, _params, options) => {
|
||||
sendCallbacks = options as HookCallbacks
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useChat())
|
||||
|
||||
act(() => {
|
||||
result.current.handleSend('url', { query: 'test base file' }, {})
|
||||
})
|
||||
|
||||
const prevChatTree = [{
|
||||
id: 'q-resume',
|
||||
content: 'query',
|
||||
isAnswer: false,
|
||||
children: [{
|
||||
id: 'm-resume',
|
||||
content: 'initial',
|
||||
isAnswer: true,
|
||||
siblingIndex: 0,
|
||||
}],
|
||||
}]
|
||||
const { result: resumeResult } = renderHook(() => useChat(undefined, undefined, prevChatTree as ChatItemInTree[]))
|
||||
act(() => {
|
||||
resumeResult.current.handleResume('m-resume', 'wr-1', { isPublicAPI: true })
|
||||
})
|
||||
|
||||
act(() => {
|
||||
const fileWithMethodAndNoType = {
|
||||
id: 'f-1',
|
||||
transferMethod: 'remote_url',
|
||||
type: undefined,
|
||||
name: 'uploaded.png',
|
||||
}
|
||||
sendCallbacks.onFile(fileWithMethodAndNoType)
|
||||
resumeCallbacks.onFile(fileWithMethodAndNoType)
|
||||
|
||||
// Test the inner condition in handleSend `!isAgentMode` where we also push to current files
|
||||
sendCallbacks.onFile(fileWithMethodAndNoType)
|
||||
})
|
||||
|
||||
const lastSendResponse = result.current.chatList[1]
|
||||
expect(lastSendResponse.message_files).toHaveLength(2)
|
||||
|
||||
const lastResumeResponse = resumeResult.current.chatList[1]
|
||||
expect(lastResumeResponse.message_files).toHaveLength(1)
|
||||
})
|
||||
|
||||
it('should cover parallel_id tracing matches in iteration and loop finish', () => {
|
||||
let sendCallbacks: HookCallbacks
|
||||
vi.mocked(ssePost).mockImplementation(async (_url, _params, options) => {
|
||||
sendCallbacks = options as HookCallbacks
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useChat())
|
||||
act(() => {
|
||||
result.current.handleSend('url', { query: 'test parallel_id' }, {})
|
||||
})
|
||||
|
||||
act(() => {
|
||||
sendCallbacks.onWorkflowStarted({ workflow_run_id: 'wr-1', task_id: 't-1' })
|
||||
|
||||
// parallel_id in execution_metadata
|
||||
sendCallbacks.onIterationStart({ data: { node_id: 'iter-1', execution_metadata: { parallel_id: 'pid-1' } } })
|
||||
sendCallbacks.onIterationFinish({ data: { node_id: 'iter-1', execution_metadata: { parallel_id: 'pid-1' }, status: 'succeeded' } })
|
||||
|
||||
// no parallel_id
|
||||
sendCallbacks.onLoopStart({ data: { node_id: 'loop-1' } })
|
||||
sendCallbacks.onLoopFinish({ data: { node_id: 'loop-1', status: 'succeeded' } })
|
||||
|
||||
// parallel_id in root item but finish has it in execution_metadata
|
||||
sendCallbacks.onNodeStarted({ data: { node_id: 'n-1', id: 'n-1', parallel_id: 'pid-2' } })
|
||||
sendCallbacks.onNodeFinished({ data: { node_id: 'n-1', id: 'n-1', execution_metadata: { parallel_id: 'pid-2' } } })
|
||||
})
|
||||
|
||||
const lastResponse = result.current.chatList[1]
|
||||
const tracing = lastResponse.workflowProcess!.tracing!
|
||||
expect(tracing).toHaveLength(3)
|
||||
expect(tracing[0].status).toBe('succeeded')
|
||||
expect(tracing[1].status).toBe('succeeded')
|
||||
})
|
||||
|
||||
it('should cover baseFile with ALL fields, avoiding all fallbacks', () => {
|
||||
let sendCallbacks: HookCallbacks
|
||||
vi.mocked(ssePost).mockImplementation(async (_url, _params, options) => {
|
||||
sendCallbacks = options as HookCallbacks
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useChat())
|
||||
|
||||
act(() => {
|
||||
result.current.handleSend('url', { query: 'test exact file' }, {})
|
||||
})
|
||||
|
||||
act(() => {
|
||||
sendCallbacks.onWorkflowStarted({ workflow_run_id: 'wr-1', task_id: 't-1' })
|
||||
sendCallbacks.onFile({
|
||||
id: 'exact-1',
|
||||
type: 'custom/mime',
|
||||
transferMethod: 'local_file',
|
||||
url: 'exact.url',
|
||||
supportFileType: 'blob',
|
||||
progress: 50,
|
||||
name: 'exact.name',
|
||||
size: 1024,
|
||||
})
|
||||
})
|
||||
|
||||
const lastResponse = result.current.chatList[result.current.chatList.length - 1]
|
||||
expect(lastResponse.message_files).toHaveLength(1)
|
||||
expect(lastResponse.message_files![0].type).toBe('custom/mime')
|
||||
expect(lastResponse.message_files![0].size).toBe(1024)
|
||||
})
|
||||
|
||||
it('should cover handleResume missing branches for onMessageEnd, onFile fallbacks, and workflow edges', () => {
|
||||
let resumeCallbacks: HookCallbacks
|
||||
vi.mocked(sseGet).mockImplementation(async (_url, _params, options) => {
|
||||
resumeCallbacks = options as HookCallbacks
|
||||
})
|
||||
|
||||
const prevChatTree = [{
|
||||
id: 'q-data',
|
||||
content: 'query',
|
||||
isAnswer: false,
|
||||
children: [{
|
||||
id: 'm-data',
|
||||
content: 'initial',
|
||||
isAnswer: true,
|
||||
siblingIndex: 0,
|
||||
}],
|
||||
}]
|
||||
const { result } = renderHook(() => useChat(undefined, undefined, prevChatTree as ChatItemInTree[]))
|
||||
act(() => {
|
||||
result.current.handleResume('m-data', 'wr-1', { isPublicAPI: true })
|
||||
})
|
||||
|
||||
act(() => {
|
||||
// messageId undefined
|
||||
resumeCallbacks.onData(' more data', false, { conversationId: 'c-1', taskId: 't-1' })
|
||||
|
||||
// onFile audio video bin fallbacks
|
||||
resumeCallbacks.onFile({ id: 'f-vid', type: 'video', url: 'vid.mp4' })
|
||||
resumeCallbacks.onFile({ id: 'f-aud', type: 'audio', url: 'aud.mp3' })
|
||||
resumeCallbacks.onFile({ id: 'f-bin', type: 'bin', url: 'file.bin' })
|
||||
|
||||
// onMessageEnd missing annotation and citation
|
||||
resumeCallbacks.onMessageEnd({ id: 'm-end', metadata: {} } as Record<string, unknown>)
|
||||
|
||||
// onThought fallback missing message_id
|
||||
resumeCallbacks.onThought({ thought: 'missing message id', message_files: [] } as Record<string, unknown>)
|
||||
|
||||
// onHumanInputFormTimeout missing length
|
||||
resumeCallbacks.onHumanInputFormTimeout({ data: { node_id: 'timeout-id' } })
|
||||
|
||||
// Empty file list
|
||||
result.current.chatList[1].message_files = undefined
|
||||
// Call onFile while agent_thoughts is empty/undefined to hit the `else` fallback branch
|
||||
resumeCallbacks.onFile({ id: 'f-agent', type: 'image', url: 'agent.png' })
|
||||
})
|
||||
|
||||
const lastResponse = result.current.chatList[1]
|
||||
expect(lastResponse.message_files![0]).toBeDefined()
|
||||
})
|
||||
|
||||
it('should cover edge case where node_id is missing or index is -1 in handleResume onNodeFinished and onLoopFinish', () => {
|
||||
let resumeCallbacks: HookCallbacks
|
||||
vi.mocked(sseGet).mockImplementation(async (_url, _params, options) => {
|
||||
resumeCallbacks = options as HookCallbacks
|
||||
})
|
||||
|
||||
const prevChatTree = [{
|
||||
id: 'q-index',
|
||||
content: 'query',
|
||||
isAnswer: false,
|
||||
children: [{
|
||||
id: 'm-index',
|
||||
content: 'initial',
|
||||
isAnswer: true,
|
||||
siblingIndex: 0,
|
||||
workflowProcess: { status: WorkflowRunningStatus.Running, tracing: [] },
|
||||
}],
|
||||
}]
|
||||
const { result } = renderHook(() => useChat(undefined, undefined, prevChatTree as ChatItemInTree[]))
|
||||
act(() => {
|
||||
result.current.handleResume('m-index', 'wr-1', { isPublicAPI: true })
|
||||
})
|
||||
|
||||
act(() => {
|
||||
// ID doesn't exist in tracing
|
||||
resumeCallbacks.onNodeFinished({ data: { id: 'missing', execution_metadata: { parallel_id: 'missing-pid' } } })
|
||||
|
||||
// Node ID doesn't exist in tracing
|
||||
resumeCallbacks.onLoopFinish({ data: { node_id: 'missing-loop', status: 'succeeded' } })
|
||||
|
||||
// Parallel ID doesn't match
|
||||
resumeCallbacks.onIterationFinish({ data: { node_id: 'missing-iter', execution_metadata: { parallel_id: 'missing-pid' }, status: 'succeeded' } })
|
||||
})
|
||||
|
||||
const lastResponse = result.current.chatList[1]
|
||||
expect(lastResponse.workflowProcess?.tracing).toHaveLength(0) // None were updated
|
||||
})
|
||||
|
||||
it('should cover TTS chunks branching where audio is empty', () => {
|
||||
let sendCallbacks: HookCallbacks
|
||||
vi.mocked(ssePost).mockImplementation(async (_url, _params, options) => {
|
||||
sendCallbacks = options as HookCallbacks
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useChat())
|
||||
act(() => {
|
||||
result.current.handleSend('url', { query: 'test text to speech' }, {})
|
||||
})
|
||||
|
||||
act(() => {
|
||||
sendCallbacks.onTTSChunk('msg-1', '') // Missing audio string
|
||||
})
|
||||
// If it didn't crash, we achieved coverage for the empty audio string fast return
|
||||
expect(true).toBe(true)
|
||||
})
|
||||
|
||||
it('should cover handleSend identical missing branches, null states, and undefined tracking arrays', () => {
|
||||
let sendCallbacks: HookCallbacks
|
||||
vi.mocked(ssePost).mockImplementation(async (_url, _params, options) => {
|
||||
sendCallbacks = options as HookCallbacks
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useChat())
|
||||
|
||||
act(() => {
|
||||
result.current.handleSend('url', { query: 'test exact file send' }, {})
|
||||
})
|
||||
|
||||
act(() => {
|
||||
// missing task ID in onData
|
||||
sendCallbacks.onData(' append', false, { conversationId: 'c-1' } as Record<string, unknown>)
|
||||
|
||||
// Empty message files fallback
|
||||
result.current.chatList[1].message_files = undefined
|
||||
sendCallbacks.onFile({ id: 'f-send', type: 'image', url: 'img.png' })
|
||||
|
||||
// Empty message files passing to processing fallback
|
||||
sendCallbacks.onMessageEnd({ id: 'm-send' } as Record<string, unknown>)
|
||||
|
||||
// node finished missing arrays
|
||||
sendCallbacks.onWorkflowStarted({ workflow_run_id: 'wr', task_id: 't' })
|
||||
sendCallbacks.onNodeStarted({ data: { node_id: 'n-new', id: 'n-new' } }) // adds tracing
|
||||
sendCallbacks.onNodeFinished({ data: { id: 'missing-idx' } } as Record<string, unknown>)
|
||||
|
||||
// onIterationFinish parallel_id matching
|
||||
sendCallbacks.onIterationFinish({ data: { node_id: 'missing-iter', status: 'succeeded' } } as Record<string, unknown>)
|
||||
|
||||
// onLoopFinish parallel_id matching
|
||||
sendCallbacks.onLoopFinish({ data: { node_id: 'missing-loop', status: 'succeeded' } } as Record<string, unknown>)
|
||||
|
||||
// Timeout missing form data
|
||||
sendCallbacks.onHumanInputFormTimeout({ data: { node_id: 'timeout' } } as Record<string, unknown>)
|
||||
})
|
||||
|
||||
expect(result.current.chatList[1].message_files).toBeDefined()
|
||||
})
|
||||
|
||||
it('should cover handleSwitchSibling target message not found early returns', () => {
|
||||
const { result } = renderHook(() => useChat())
|
||||
act(() => {
|
||||
result.current.handleSwitchSibling('missing-id', { isPublicAPI: true })
|
||||
})
|
||||
// Should early return and not crash
|
||||
expect(result.current.chatList).toHaveLength(0)
|
||||
})
|
||||
|
||||
it('should cover handleSend onNodeStarted missing workflowProcess early returns', () => {
|
||||
let sendCallbacks: HookCallbacks
|
||||
vi.mocked(ssePost).mockImplementation(async (_url, _params, options) => {
|
||||
sendCallbacks = options as HookCallbacks
|
||||
})
|
||||
const { result } = renderHook(() => useChat())
|
||||
act(() => {
|
||||
result.current.handleSend('url', { query: 'test' }, {})
|
||||
})
|
||||
act(() => {
|
||||
sendCallbacks.onNodeStarted({ data: { node_id: 'n-new', id: 'n-new' } })
|
||||
})
|
||||
expect(result.current.chatList[1].workflowProcess).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should cover handleSend onNodeStarted missing tracing in workflowProcess (L969)', () => {
|
||||
let sendCallbacks: HookCallbacks
|
||||
vi.mocked(ssePost).mockImplementation(async (_url, _params, options) => {
|
||||
sendCallbacks = options as HookCallbacks
|
||||
})
|
||||
const { result } = renderHook(() => useChat())
|
||||
act(() => {
|
||||
result.current.handleSend('url', { query: 'test' }, {})
|
||||
})
|
||||
act(() => {
|
||||
sendCallbacks.onWorkflowStarted({ workflow_run_id: 'wr-1', task_id: 't-1' })
|
||||
})
|
||||
// Get the shared reference from the tree to mutate the local closed-over responseItem's workflowProcess
|
||||
act(() => {
|
||||
const response = result.current.chatList[1]
|
||||
if (response.workflowProcess) {
|
||||
// @ts-expect-error deliberately removing tracing to cover the fallback branch
|
||||
delete response.workflowProcess.tracing
|
||||
}
|
||||
sendCallbacks.onNodeStarted({ data: { node_id: 'n-new', id: 'n-new' } })
|
||||
})
|
||||
expect(result.current.chatList[1].workflowProcess?.tracing).toBeDefined()
|
||||
expect(result.current.chatList[1].workflowProcess?.tracing?.length).toBe(1)
|
||||
})
|
||||
|
||||
it('should cover handleSend onTTSChunk and onTTSEnd truthy audio strings', () => {
|
||||
let sendCallbacks: HookCallbacks
|
||||
vi.mocked(ssePost).mockImplementation(async (_url, _params, options) => {
|
||||
sendCallbacks = options as HookCallbacks
|
||||
})
|
||||
const { result } = renderHook(() => useChat())
|
||||
act(() => {
|
||||
result.current.handleSend('url', { query: 'test' }, {})
|
||||
})
|
||||
act(() => {
|
||||
sendCallbacks.onTTSChunk('msg-1', 'audio-chunk')
|
||||
sendCallbacks.onTTSEnd('msg-1', 'audio-end')
|
||||
})
|
||||
expect(result.current.chatList).toHaveLength(2)
|
||||
})
|
||||
|
||||
it('should cover onGetSuggestedQuestions success and error branches in handleResume onCompleted', async () => {
|
||||
let resumeCallbacks: HookCallbacks
|
||||
vi.mocked(sseGet).mockImplementation(async (_url, _params, options) => {
|
||||
resumeCallbacks = options as HookCallbacks
|
||||
})
|
||||
|
||||
const onGetSuggestedQuestions = vi.fn()
|
||||
.mockImplementationOnce((_id, getAbort) => {
|
||||
if (getAbort) {
|
||||
getAbort({ abort: vi.fn() } as unknown as AbortController)
|
||||
}
|
||||
return Promise.resolve({ data: ['Suggested 1', 'Suggested 2'] })
|
||||
})
|
||||
.mockImplementationOnce((_id, getAbort) => {
|
||||
if (getAbort) {
|
||||
getAbort({ abort: vi.fn() } as unknown as AbortController)
|
||||
}
|
||||
return Promise.reject(new Error('error'))
|
||||
})
|
||||
|
||||
const config = {
|
||||
suggested_questions_after_answer: { enabled: true },
|
||||
}
|
||||
|
||||
const prevChatTree = [{
|
||||
id: 'q',
|
||||
content: 'query',
|
||||
isAnswer: false,
|
||||
children: [{ id: 'm-1', content: 'initial', isAnswer: true, siblingIndex: 0 }],
|
||||
}]
|
||||
|
||||
// Success branch
|
||||
const { result } = renderHook(() => useChat(config as ChatConfig, undefined, prevChatTree as ChatItemInTree[]))
|
||||
act(() => {
|
||||
result.current.handleResume('m-1', 'wr-1', { isPublicAPI: true, onGetSuggestedQuestions })
|
||||
})
|
||||
|
||||
await act(async () => {
|
||||
await resumeCallbacks.onCompleted()
|
||||
})
|
||||
expect(result.current.suggestedQuestions).toEqual(['Suggested 1', 'Suggested 2'])
|
||||
|
||||
// Error branch (catch block 271-273)
|
||||
await act(async () => {
|
||||
await resumeCallbacks.onCompleted()
|
||||
})
|
||||
expect(result.current.suggestedQuestions).toHaveLength(0)
|
||||
})
|
||||
|
||||
it('should cover handleSend onNodeStarted/onWorkflowStarted branches for tracing 908, 969', () => {
|
||||
let sendCallbacks: HookCallbacks
|
||||
vi.mocked(ssePost).mockImplementation(async (_url, _params, options) => {
|
||||
sendCallbacks = options as HookCallbacks
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useChat())
|
||||
act(() => {
|
||||
result.current.handleSend('url', { query: 'test' }, {})
|
||||
})
|
||||
|
||||
act(() => {
|
||||
// Initialize workflowProcess (hits else branch of 910)
|
||||
sendCallbacks.onWorkflowStarted({ workflow_run_id: 'wr-1', task_id: 't-1' })
|
||||
|
||||
// Hit L969: onNodeStarted (this hits 968-969 if we find a way to make tracing null, but it's init to [] above)
|
||||
// Actually, to hit 969, workflowProcess must exist but tracing be falsy.
|
||||
// We can't easily force this in handleSend since it's local.
|
||||
// But we can hit 908 by calling onWorkflowStarted again after some trace.
|
||||
sendCallbacks.onNodeStarted({ data: { node_id: 'n-1', id: 'n-1' } })
|
||||
|
||||
// Now tracing.length > 0
|
||||
// Hit L908: onWorkflowStarted again
|
||||
sendCallbacks.onWorkflowStarted({ workflow_run_id: 'wr-1', task_id: 't-1' })
|
||||
})
|
||||
|
||||
expect(result.current.chatList[1].workflowProcess!.tracing).toHaveLength(1)
|
||||
})
|
||||
|
||||
it('should cover handleResume onHumanInputFormFilled splicing and onHumanInputFormTimeout updating', () => {
|
||||
let resumeCallbacks: HookCallbacks
|
||||
vi.mocked(sseGet).mockImplementation(async (_url, _params, options) => {
|
||||
resumeCallbacks = options as HookCallbacks
|
||||
})
|
||||
|
||||
const prevChatTree = [{
|
||||
id: 'q',
|
||||
content: 'query',
|
||||
isAnswer: false,
|
||||
children: [{
|
||||
id: 'm-1',
|
||||
content: 'initial',
|
||||
isAnswer: true,
|
||||
siblingIndex: 0,
|
||||
humanInputFormDataList: [{ node_id: 'n-1', expiration_time: 100 }],
|
||||
}],
|
||||
}]
|
||||
|
||||
const { result } = renderHook(() => useChat(undefined, undefined, prevChatTree as ChatItemInTree[]))
|
||||
act(() => {
|
||||
result.current.handleResume('m-1', 'wr-1', { isPublicAPI: true })
|
||||
})
|
||||
|
||||
act(() => {
|
||||
// Hit L535-537: onHumanInputFormTimeout (update)
|
||||
resumeCallbacks.onHumanInputFormTimeout({ data: { node_id: 'n-1', expiration_time: 200 } })
|
||||
|
||||
// Hit L519-522: onHumanInputFormFilled (splice)
|
||||
resumeCallbacks.onHumanInputFormFilled({ data: { node_id: 'n-1' } })
|
||||
})
|
||||
|
||||
const lastResponse = result.current.chatList[1]
|
||||
expect(lastResponse.humanInputFormDataList).toHaveLength(0)
|
||||
expect(lastResponse.humanInputFilledFormDataList).toHaveLength(1)
|
||||
})
|
||||
|
||||
it('should cover handleResume branches where workflowProcess exists but tracing is missing (L386, L414, L472)', () => {
|
||||
let resumeCallbacks: HookCallbacks
|
||||
vi.mocked(sseGet).mockImplementation(async (_url, _params, options) => {
|
||||
resumeCallbacks = options as HookCallbacks
|
||||
})
|
||||
|
||||
const prevChatTree = [{
|
||||
id: 'q',
|
||||
content: 'query',
|
||||
isAnswer: false,
|
||||
children: [{
|
||||
id: 'm-1',
|
||||
content: 'initial',
|
||||
isAnswer: true,
|
||||
siblingIndex: 0,
|
||||
workflowProcess: {
|
||||
status: WorkflowRunningStatus.Running,
|
||||
// tracing: undefined
|
||||
},
|
||||
}],
|
||||
}]
|
||||
|
||||
const { result } = renderHook(() => useChat(undefined, undefined, prevChatTree as ChatItemInTree[]))
|
||||
act(() => {
|
||||
result.current.handleResume('m-1', 'wr-1', { isPublicAPI: true })
|
||||
})
|
||||
|
||||
act(() => {
|
||||
// Hit L386: onIterationStart
|
||||
resumeCallbacks.onIterationStart({ data: { node_id: 'i-1' } })
|
||||
// Hit L414: onNodeStarted
|
||||
resumeCallbacks.onNodeStarted({ data: { node_id: 'n-1', id: 'n-1' } })
|
||||
// Hit L472: onLoopStart
|
||||
resumeCallbacks.onLoopStart({ data: { node_id: 'l-1' } })
|
||||
})
|
||||
|
||||
const lastResponse = result.current.chatList[1]
|
||||
expect(lastResponse.workflowProcess?.tracing).toHaveLength(3)
|
||||
})
|
||||
|
||||
it('should cover handleRestart with and without callback', () => {
|
||||
const { result } = renderHook(() => useChat())
|
||||
const callback = vi.fn()
|
||||
act(() => {
|
||||
result.current.handleRestart(callback)
|
||||
})
|
||||
expect(callback).toHaveBeenCalled()
|
||||
|
||||
act(() => {
|
||||
result.current.handleRestart()
|
||||
})
|
||||
// Should not crash
|
||||
expect(result.current.chatList).toHaveLength(0)
|
||||
})
|
||||
|
||||
it('should cover handleAnnotationAdded updating node', async () => {
|
||||
const prevChatTree = [{
|
||||
id: 'q-1',
|
||||
content: 'q',
|
||||
isAnswer: false,
|
||||
children: [{ id: 'a-1', content: 'a', isAnswer: true, siblingIndex: 0 }],
|
||||
}]
|
||||
const { result } = renderHook(() => useChat(undefined, undefined, prevChatTree as ChatItemInTree[]))
|
||||
await act(async () => {
|
||||
// (annotationId, authorName, query, answer, index)
|
||||
result.current.handleAnnotationAdded('anno-id', 'author', 'q-new', 'a-new', 1)
|
||||
})
|
||||
expect(result.current.chatList[0].content).toBe('q-new')
|
||||
expect(result.current.chatList[1].content).toBe('a')
|
||||
expect(result.current.chatList[1].annotation?.logAnnotation?.content).toBe('a-new')
|
||||
expect(result.current.chatList[1].annotation?.id).toBe('anno-id')
|
||||
})
|
||||
|
||||
it('should cover handleAnnotationEdited updating node', async () => {
|
||||
const prevChatTree = [{
|
||||
id: 'q-1',
|
||||
content: 'q',
|
||||
isAnswer: false,
|
||||
children: [{ id: 'a-1', content: 'a', isAnswer: true, siblingIndex: 0 }],
|
||||
}]
|
||||
const { result } = renderHook(() => useChat(undefined, undefined, prevChatTree as ChatItemInTree[]))
|
||||
await act(async () => {
|
||||
// (query, answer, index)
|
||||
result.current.handleAnnotationEdited('q-edit', 'a-edit', 1)
|
||||
})
|
||||
expect(result.current.chatList[0].content).toBe('q-edit')
|
||||
expect(result.current.chatList[1].content).toBe('a-edit')
|
||||
})
|
||||
|
||||
it('should cover handleAnnotationRemoved updating node', () => {
|
||||
const prevChatTree = [{
|
||||
id: 'q-1',
|
||||
content: 'q',
|
||||
isAnswer: false,
|
||||
children: [{
|
||||
id: 'a-1',
|
||||
content: 'a',
|
||||
isAnswer: true,
|
||||
siblingIndex: 0,
|
||||
annotation: { id: 'anno-old' },
|
||||
}],
|
||||
}]
|
||||
const { result } = renderHook(() => useChat(undefined, undefined, prevChatTree as ChatItemInTree[]))
|
||||
act(() => {
|
||||
result.current.handleAnnotationRemoved(1)
|
||||
})
|
||||
expect(result.current.chatList[1].annotation?.id).toBe('')
|
||||
})
|
||||
})
|
||||
|
||||
@ -2,7 +2,6 @@ import type { ChatConfig, ChatItem, OnSend } from '../../types'
|
||||
import type { ChatProps } from '../index'
|
||||
import { act, render, screen } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { useStore as useAppStore } from '@/app/components/app/store'
|
||||
import Chat from '../index'
|
||||
|
||||
@ -603,4 +602,553 @@ describe('Chat', () => {
|
||||
expect(screen.getByTestId('agent-log-modal')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Question Rendering with Config', () => {
|
||||
it('should pass questionEditEnable from config to Question component', () => {
|
||||
renderChat({
|
||||
config: { questionEditEnable: true } as ChatConfig,
|
||||
chatList: [makeChatItem({ id: 'q1', isAnswer: false })],
|
||||
})
|
||||
expect(screen.getByTestId('question-item')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should pass undefined questionEditEnable to Question when config has no questionEditEnable', () => {
|
||||
renderChat({
|
||||
config: {} as ChatConfig,
|
||||
chatList: [makeChatItem({ id: 'q1', isAnswer: false })],
|
||||
})
|
||||
expect(screen.getByTestId('question-item')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should pass theme from themeBuilder to Question', () => {
|
||||
const mockTheme = { chatBubbleColorStyle: 'test' }
|
||||
const themeBuilder = { theme: mockTheme }
|
||||
|
||||
renderChat({
|
||||
themeBuilder: themeBuilder as unknown as ChatProps['themeBuilder'],
|
||||
chatList: [makeChatItem({ id: 'q1', isAnswer: false })],
|
||||
})
|
||||
expect(screen.getByTestId('question-item')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should pass switchSibling to Question component', () => {
|
||||
const switchSibling = vi.fn()
|
||||
renderChat({
|
||||
switchSibling,
|
||||
chatList: [makeChatItem({ id: 'q1', isAnswer: false })],
|
||||
})
|
||||
expect(screen.getByTestId('question-item')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should pass hideAvatar to Question component', () => {
|
||||
renderChat({
|
||||
hideAvatar: true,
|
||||
chatList: [makeChatItem({ id: 'q1', isAnswer: false })],
|
||||
})
|
||||
expect(screen.getByTestId('question-item')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Answer Rendering with Config and Props', () => {
|
||||
it('should pass appData to Answer component', () => {
|
||||
const appData = { site: { title: 'Test App' } }
|
||||
renderChat({
|
||||
appData: appData as unknown as ChatProps['appData'],
|
||||
chatList: [
|
||||
makeChatItem({ id: 'q1', isAnswer: false }),
|
||||
makeChatItem({ id: 'a1', isAnswer: true }),
|
||||
],
|
||||
})
|
||||
expect(screen.getByTestId('answer-item')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should pass config to Answer component', () => {
|
||||
const config = { someOption: true }
|
||||
renderChat({
|
||||
config: config as unknown as ChatConfig,
|
||||
chatList: [
|
||||
makeChatItem({ id: 'q1', isAnswer: false }),
|
||||
makeChatItem({ id: 'a1', isAnswer: true }),
|
||||
],
|
||||
})
|
||||
expect(screen.getByTestId('answer-item')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should pass answerIcon to Answer component', () => {
|
||||
renderChat({
|
||||
answerIcon: <div data-testid="test-answer-icon">Icon</div>,
|
||||
chatList: [
|
||||
makeChatItem({ id: 'q1', isAnswer: false }),
|
||||
makeChatItem({ id: 'a1', isAnswer: true }),
|
||||
],
|
||||
})
|
||||
expect(screen.getByTestId('answer-item')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should pass showPromptLog to Answer component', () => {
|
||||
renderChat({
|
||||
showPromptLog: true,
|
||||
chatList: [
|
||||
makeChatItem({ id: 'q1', isAnswer: false }),
|
||||
makeChatItem({ id: 'a1', isAnswer: true }),
|
||||
],
|
||||
})
|
||||
expect(screen.getByTestId('answer-item')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should pass chatAnswerContainerInner className to Answer', () => {
|
||||
renderChat({
|
||||
chatAnswerContainerInner: 'custom-class',
|
||||
chatList: [
|
||||
makeChatItem({ id: 'q1', isAnswer: false }),
|
||||
makeChatItem({ id: 'a1', isAnswer: true }),
|
||||
],
|
||||
})
|
||||
expect(screen.getByTestId('answer-item')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should pass hideProcessDetail to Answer component', () => {
|
||||
renderChat({
|
||||
hideProcessDetail: true,
|
||||
chatList: [
|
||||
makeChatItem({ id: 'q1', isAnswer: false }),
|
||||
makeChatItem({ id: 'a1', isAnswer: true }),
|
||||
],
|
||||
})
|
||||
expect(screen.getByTestId('answer-item')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should pass noChatInput to Answer component', () => {
|
||||
renderChat({
|
||||
noChatInput: true,
|
||||
chatList: [
|
||||
makeChatItem({ id: 'q1', isAnswer: false }),
|
||||
makeChatItem({ id: 'a1', isAnswer: true }),
|
||||
],
|
||||
})
|
||||
expect(screen.getByTestId('answer-item')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should pass onHumanInputFormSubmit to Answer component', () => {
|
||||
const onHumanInputFormSubmit = vi.fn()
|
||||
renderChat({
|
||||
onHumanInputFormSubmit,
|
||||
chatList: [
|
||||
makeChatItem({ id: 'q1', isAnswer: false }),
|
||||
makeChatItem({ id: 'a1', isAnswer: true }),
|
||||
],
|
||||
})
|
||||
expect(screen.getByTestId('answer-item')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('TryToAsk Conditions', () => {
|
||||
const tryToAskConfig: ChatConfig = {
|
||||
suggested_questions_after_answer: { enabled: true },
|
||||
} as ChatConfig
|
||||
|
||||
it('should not render TryToAsk when all required fields are present', () => {
|
||||
renderChat({
|
||||
config: tryToAskConfig,
|
||||
suggestedQuestions: [],
|
||||
onSend: vi.fn() as unknown as OnSend,
|
||||
})
|
||||
expect(screen.queryByText(/tryToAsk/i)).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render TryToAsk with one suggested question', () => {
|
||||
renderChat({
|
||||
config: tryToAskConfig,
|
||||
suggestedQuestions: ['Single question'],
|
||||
onSend: vi.fn() as unknown as OnSend,
|
||||
})
|
||||
expect(screen.getByText(/tryToAsk/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render TryToAsk with multiple suggested questions', () => {
|
||||
renderChat({
|
||||
config: tryToAskConfig,
|
||||
suggestedQuestions: ['Q1', 'Q2', 'Q3'],
|
||||
onSend: vi.fn() as unknown as OnSend,
|
||||
})
|
||||
expect(screen.getByText(/tryToAsk/i)).toBeInTheDocument()
|
||||
expect(screen.getByText('Q1')).toBeInTheDocument()
|
||||
expect(screen.getByText('Q2')).toBeInTheDocument()
|
||||
expect(screen.getByText('Q3')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render TryToAsk when suggested_questions_after_answer?.enabled is false', () => {
|
||||
renderChat({
|
||||
config: { suggested_questions_after_answer: { enabled: false } } as ChatConfig,
|
||||
suggestedQuestions: ['q1', 'q2'],
|
||||
onSend: vi.fn() as unknown as OnSend,
|
||||
})
|
||||
expect(screen.queryByText(/tryToAsk/i)).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render TryToAsk when suggested_questions_after_answer is undefined', () => {
|
||||
renderChat({
|
||||
config: {} as ChatConfig,
|
||||
suggestedQuestions: ['q1'],
|
||||
onSend: vi.fn() as unknown as OnSend,
|
||||
})
|
||||
expect(screen.queryByText(/tryToAsk/i)).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render TryToAsk when onSend callback is not provided even with config and questions', () => {
|
||||
renderChat({
|
||||
config: tryToAskConfig,
|
||||
suggestedQuestions: ['q1', 'q2'],
|
||||
onSend: undefined,
|
||||
})
|
||||
expect(screen.queryByText(/tryToAsk/i)).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('ChatInputArea Configuration', () => {
|
||||
it('should pass all config options to ChatInputArea', () => {
|
||||
const config: ChatConfig = {
|
||||
file_upload: { enabled: true },
|
||||
speech_to_text: { enabled: true },
|
||||
} as unknown as ChatConfig
|
||||
|
||||
renderChat({
|
||||
noChatInput: false,
|
||||
config,
|
||||
})
|
||||
|
||||
expect(screen.getByTestId('chat-input-area')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should pass appData.site.title as botName to ChatInputArea', () => {
|
||||
renderChat({
|
||||
appData: { site: { title: 'MyBot' } } as unknown as ChatProps['appData'],
|
||||
noChatInput: false,
|
||||
})
|
||||
expect(screen.getByTestId('chat-input-area')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should pass Bot as default botName when appData.site.title is missing', () => {
|
||||
renderChat({
|
||||
appData: {} as unknown as ChatProps['appData'],
|
||||
noChatInput: false,
|
||||
})
|
||||
expect(screen.getByTestId('chat-input-area')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should pass showFeatureBar to ChatInputArea', () => {
|
||||
renderChat({
|
||||
noChatInput: false,
|
||||
showFeatureBar: true,
|
||||
})
|
||||
expect(screen.getByTestId('chat-input-area')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should pass showFileUpload to ChatInputArea', () => {
|
||||
renderChat({
|
||||
noChatInput: false,
|
||||
showFileUpload: true,
|
||||
})
|
||||
expect(screen.getByTestId('chat-input-area')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should pass featureBarDisabled based on isResponding', () => {
|
||||
const { rerender } = renderChat({
|
||||
noChatInput: false,
|
||||
isResponding: false,
|
||||
})
|
||||
|
||||
rerender(<Chat chatList={[]} noChatInput={false} isResponding={true} />)
|
||||
expect(screen.getByTestId('chat-input-area')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should pass onFeatureBarClick callback to ChatInputArea', () => {
|
||||
const onFeatureBarClick = vi.fn()
|
||||
renderChat({
|
||||
noChatInput: false,
|
||||
onFeatureBarClick,
|
||||
})
|
||||
expect(screen.getByTestId('chat-input-area')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should pass inputs and inputsForm to ChatInputArea', () => {
|
||||
const inputs = { field1: 'value1' }
|
||||
const inputsForm = [{ key: 'field1', type: 'text' }]
|
||||
|
||||
renderChat({
|
||||
noChatInput: false,
|
||||
inputs,
|
||||
inputsForm: inputsForm as unknown as ChatProps['inputsForm'],
|
||||
})
|
||||
expect(screen.getByTestId('chat-input-area')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should pass theme from themeBuilder to ChatInputArea', () => {
|
||||
const mockTheme = { someThemeProperty: true }
|
||||
const themeBuilder = { theme: mockTheme }
|
||||
|
||||
renderChat({
|
||||
noChatInput: false,
|
||||
themeBuilder: themeBuilder as unknown as ChatProps['themeBuilder'],
|
||||
})
|
||||
expect(screen.getByTestId('chat-input-area')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Footer Visibility Logic', () => {
|
||||
it('should show footer when hasTryToAsk is true', () => {
|
||||
renderChat({
|
||||
config: { suggested_questions_after_answer: { enabled: true } } as ChatConfig,
|
||||
suggestedQuestions: ['q1'],
|
||||
onSend: vi.fn() as unknown as OnSend,
|
||||
})
|
||||
expect(screen.getByTestId('chat-footer')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show footer when hasTryToAsk is false but noChatInput is false', () => {
|
||||
renderChat({
|
||||
noChatInput: false,
|
||||
})
|
||||
expect(screen.getByTestId('chat-footer')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show footer when hasTryToAsk is false and noChatInput is false', () => {
|
||||
renderChat({
|
||||
config: { suggested_questions_after_answer: { enabled: false } } as ChatConfig,
|
||||
noChatInput: false,
|
||||
})
|
||||
expect(screen.getByTestId('chat-footer')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show footer when isResponding and noStopResponding is false', () => {
|
||||
renderChat({
|
||||
isResponding: true,
|
||||
noStopResponding: false,
|
||||
noChatInput: true,
|
||||
})
|
||||
expect(screen.getByTestId('chat-footer')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show footer when any footer content condition is true', () => {
|
||||
renderChat({
|
||||
isResponding: true,
|
||||
noStopResponding: false,
|
||||
noChatInput: true,
|
||||
})
|
||||
expect(screen.getByTestId('chat-footer')).toHaveClass('bg-chat-input-mask')
|
||||
})
|
||||
|
||||
it('should apply chatFooterClassName when footer has content', () => {
|
||||
renderChat({
|
||||
noChatInput: false,
|
||||
chatFooterClassName: 'my-footer-class',
|
||||
})
|
||||
expect(screen.getByTestId('chat-footer')).toHaveClass('my-footer-class')
|
||||
})
|
||||
|
||||
it('should apply chatFooterInnerClassName to footer inner div', () => {
|
||||
renderChat({
|
||||
noChatInput: false,
|
||||
chatFooterInnerClassName: 'my-inner-class',
|
||||
})
|
||||
const innerDivs = screen.getByTestId('chat-footer').querySelectorAll('div')
|
||||
expect(innerDivs.length).toBeGreaterThan(0)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Container and Spacing Variations', () => {
|
||||
it('should apply both px-0 and px-8 when isTryApp is true and noSpacing is false', () => {
|
||||
renderChat({
|
||||
isTryApp: true,
|
||||
noSpacing: false,
|
||||
})
|
||||
expect(screen.getByTestId('chat-container')).toHaveClass('h-0', 'grow')
|
||||
})
|
||||
|
||||
it('should apply px-0 when isTryApp is true', () => {
|
||||
renderChat({
|
||||
isTryApp: true,
|
||||
chatContainerInnerClassName: 'test-class',
|
||||
})
|
||||
expect(screen.getByTestId('chat-container')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not apply h-0 grow when isTryApp is false', () => {
|
||||
renderChat({
|
||||
isTryApp: false,
|
||||
})
|
||||
expect(screen.getByTestId('chat-container')).not.toHaveClass('h-0', 'grow')
|
||||
})
|
||||
|
||||
it('should apply footer classList combination correctly', () => {
|
||||
renderChat({
|
||||
noChatInput: false,
|
||||
chatFooterClassName: 'custom-footer',
|
||||
})
|
||||
const footer = screen.getByTestId('chat-footer')
|
||||
expect(footer).toHaveClass('custom-footer')
|
||||
expect(footer).toHaveClass('bg-chat-input-mask')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Multiple Items and Index Handling', () => {
|
||||
it('should correctly identify last answer in a 10-item chat list', () => {
|
||||
const chatList = Array.from({ length: 10 }, (_, i) =>
|
||||
makeChatItem({ id: `item-${i}`, isAnswer: i % 2 === 1 }))
|
||||
renderChat({ isResponding: true, chatList })
|
||||
const answers = screen.getAllByTestId('answer-item')
|
||||
expect(answers[answers.length - 1]).toHaveAttribute('data-responding', 'true')
|
||||
})
|
||||
|
||||
it('should pass correct question content to Answer', () => {
|
||||
const q1 = makeChatItem({ id: 'q1', isAnswer: false, content: 'First question' })
|
||||
const a1 = makeChatItem({ id: 'a1', isAnswer: true, content: 'First answer' })
|
||||
renderChat({ chatList: [q1, a1] })
|
||||
expect(screen.getByTestId('answer-item')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle answer without preceding question (edge case)', () => {
|
||||
renderChat({
|
||||
chatList: [makeChatItem({ id: 'a1', isAnswer: true })],
|
||||
})
|
||||
expect(screen.getByTestId('answer-item')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should correctly calculate index for each item in chatList', () => {
|
||||
const chatList = [
|
||||
makeChatItem({ id: 'q1', isAnswer: false }),
|
||||
makeChatItem({ id: 'a1', isAnswer: true }),
|
||||
makeChatItem({ id: 'q2', isAnswer: false }),
|
||||
makeChatItem({ id: 'a2', isAnswer: true }),
|
||||
]
|
||||
renderChat({ chatList })
|
||||
|
||||
const answers = screen.getAllByTestId('answer-item')
|
||||
expect(answers).toHaveLength(2)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Sidebar Collapse Multiple Transitions', () => {
|
||||
it('should trigger resize when sidebarCollapseState transitions from true to false multiple times', () => {
|
||||
vi.useFakeTimers()
|
||||
const { rerender } = renderChat({ sidebarCollapseState: true })
|
||||
|
||||
rerender(<Chat chatList={[]} sidebarCollapseState={false} />)
|
||||
vi.advanceTimersByTime(200)
|
||||
|
||||
rerender(<Chat chatList={[]} sidebarCollapseState={true} />)
|
||||
|
||||
rerender(<Chat chatList={[]} sidebarCollapseState={false} />)
|
||||
vi.advanceTimersByTime(200)
|
||||
|
||||
expect(() => vi.runAllTimers()).not.toThrow()
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
it('should not trigger resize when sidebarCollapseState stays at false', () => {
|
||||
vi.useFakeTimers()
|
||||
const { rerender } = renderChat({ sidebarCollapseState: false })
|
||||
|
||||
rerender(<Chat chatList={[]} sidebarCollapseState={false} />)
|
||||
|
||||
expect(() => vi.runAllTimers()).not.toThrow()
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
it('should handle undefined sidebarCollapseState', () => {
|
||||
renderChat({ sidebarCollapseState: undefined })
|
||||
expect(screen.getByTestId('chat-root')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Scroll Behavior Edge Cases', () => {
|
||||
it('should handle rapid scroll events', () => {
|
||||
renderChat({ chatList: [makeChatItem({ id: 'q1' }), makeChatItem({ id: 'q2' })] })
|
||||
const container = screen.getByTestId('chat-container')
|
||||
|
||||
for (let i = 0; i < 10; i++) {
|
||||
expect(() => container.dispatchEvent(new Event('scroll'))).not.toThrow()
|
||||
}
|
||||
})
|
||||
|
||||
it('should handle scroll when chatList changes', () => {
|
||||
const { rerender } = renderChat({ chatList: [makeChatItem({ id: 'q1' })] })
|
||||
|
||||
rerender(<Chat chatList={[makeChatItem({ id: 'q1' }), makeChatItem({ id: 'q2' })]} />)
|
||||
|
||||
expect(() =>
|
||||
screen.getByTestId('chat-container').dispatchEvent(new Event('scroll')),
|
||||
).not.toThrow()
|
||||
})
|
||||
|
||||
it('should handle resize event multiple times', () => {
|
||||
renderChat()
|
||||
for (let i = 0; i < 5; i++) {
|
||||
expect(() => window.dispatchEvent(new Event('resize'))).not.toThrow()
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe('Responsive Behavior', () => {
|
||||
it('should handle different chat container heights', () => {
|
||||
renderChat({
|
||||
chatList: [makeChatItem({ id: 'q1' }), makeChatItem({ id: 'q2' })],
|
||||
})
|
||||
const container = screen.getByTestId('chat-container')
|
||||
Object.defineProperty(container, 'clientHeight', { value: 800, configurable: true })
|
||||
expect(() => container.dispatchEvent(new Event('scroll'))).not.toThrow()
|
||||
})
|
||||
|
||||
it('should handle body width changes on resize', () => {
|
||||
renderChat()
|
||||
Object.defineProperty(document.body, 'clientWidth', { value: 1920, configurable: true })
|
||||
expect(() => window.dispatchEvent(new Event('resize'))).not.toThrow()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Modal Interaction Paths', () => {
|
||||
it('should handle prompt log cancel and subsequent reopen', async () => {
|
||||
const user = userEvent.setup()
|
||||
useAppStore.setState({ ...baseStoreState, showPromptLogModal: true })
|
||||
const { rerender } = renderChat({ hideLogModal: false })
|
||||
|
||||
await user.click(screen.getByTestId('prompt-log-cancel'))
|
||||
|
||||
expect(mockSetShowPromptLogModal).toHaveBeenCalledWith(false)
|
||||
|
||||
// Reopen modal
|
||||
useAppStore.setState({ ...baseStoreState, showPromptLogModal: true })
|
||||
rerender(<Chat chatList={[]} hideLogModal={false} />)
|
||||
|
||||
expect(screen.getByTestId('prompt-log-modal')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle agent log cancel and subsequent reopen', async () => {
|
||||
const user = userEvent.setup()
|
||||
useAppStore.setState({ ...baseStoreState, showAgentLogModal: true })
|
||||
const { rerender } = renderChat({ hideLogModal: false })
|
||||
|
||||
await user.click(screen.getByTestId('agent-log-cancel'))
|
||||
|
||||
expect(mockSetShowAgentLogModal).toHaveBeenCalledWith(false)
|
||||
|
||||
// Reopen modal
|
||||
useAppStore.setState({ ...baseStoreState, showAgentLogModal: true })
|
||||
rerender(<Chat chatList={[]} hideLogModal={false} />)
|
||||
|
||||
expect(screen.getByTestId('agent-log-modal')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle hideLogModal preventing both modals from showing', () => {
|
||||
useAppStore.setState({
|
||||
...baseStoreState,
|
||||
showPromptLogModal: true,
|
||||
showAgentLogModal: true,
|
||||
})
|
||||
renderChat({ hideLogModal: true })
|
||||
|
||||
expect(screen.queryByTestId('prompt-log-modal')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTestId('agent-log-modal')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@ -5,7 +5,6 @@ import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import copy from 'copy-to-clipboard'
|
||||
import * as React from 'react'
|
||||
|
||||
import Toast from '../../../toast'
|
||||
import { ThemeBuilder } from '../../embedded-chatbot/theme/theme-context'
|
||||
import { ChatContextProvider } from '../context-provider'
|
||||
@ -15,7 +14,43 @@ import Question from '../question'
|
||||
vi.mock('@react-aria/interactions', () => ({
|
||||
useFocusVisible: () => ({ isFocusVisible: false }),
|
||||
}))
|
||||
vi.mock('../content-switch', () => ({
|
||||
default: ({ count, currentIndex, switchSibling, prevDisabled, nextDisabled }: {
|
||||
count?: number
|
||||
currentIndex?: number
|
||||
switchSibling: (direction: 'prev' | 'next') => void
|
||||
prevDisabled: boolean
|
||||
nextDisabled: boolean
|
||||
}) => {
|
||||
if (!(count && count > 1 && currentIndex !== undefined))
|
||||
return null
|
||||
|
||||
return (
|
||||
<div data-testid="content-switch">
|
||||
<button
|
||||
type="button"
|
||||
aria-label="Previous"
|
||||
onClick={() => switchSibling('prev')}
|
||||
disabled={prevDisabled}
|
||||
>
|
||||
Previous
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
aria-label="Next"
|
||||
onClick={() => switchSibling('next')}
|
||||
disabled={nextDisabled}
|
||||
>
|
||||
Next
|
||||
</button>
|
||||
</div>
|
||||
)
|
||||
},
|
||||
}))
|
||||
vi.mock('copy-to-clipboard', () => ({ default: vi.fn() }))
|
||||
vi.mock('@/app/components/base/markdown', () => ({
|
||||
Markdown: ({ content }: { content: string }) => <div className="markdown-body">{content}</div>,
|
||||
}))
|
||||
|
||||
// Mock ResizeObserver and capture lifecycle for targeted coverage
|
||||
const observeMock = vi.fn()
|
||||
@ -414,8 +449,8 @@ describe('Question component', () => {
|
||||
const textbox = await screen.findByRole('textbox')
|
||||
|
||||
// Create an event with nativeEvent.isComposing = true
|
||||
const event = new KeyboardEvent('keydown', { key: 'Enter', code: 'Enter' })
|
||||
Object.defineProperty(event, 'isComposing', { value: true })
|
||||
const event = new KeyboardEvent('keydown', { key: 'Enter', code: 'Enter', bubbles: true })
|
||||
Object.defineProperty(event, 'isComposing', { value: true, configurable: true })
|
||||
|
||||
fireEvent(textbox, event)
|
||||
expect(onRegenerate).not.toHaveBeenCalled()
|
||||
@ -465,4 +500,480 @@ describe('Question component', () => {
|
||||
|
||||
expect(onRegenerate).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should render custom questionIcon when provided', () => {
|
||||
const { container } = renderWithProvider(
|
||||
makeItem(),
|
||||
vi.fn() as unknown as OnRegenerate,
|
||||
{ questionIcon: <div data-testid="custom-question-icon">CustomIcon</div> },
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('custom-question-icon')).toBeInTheDocument()
|
||||
const defaultIcon = container.querySelector('.i-custom-public-avatar-user')
|
||||
expect(defaultIcon).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should call switchSibling with next sibling ID when next button clicked and nextSibling exists', async () => {
|
||||
const user = userEvent.setup()
|
||||
const switchSibling = vi.fn()
|
||||
const item = makeItem({ prevSibling: 'q-0', nextSibling: 'q-2', siblingIndex: 1, siblingCount: 3 })
|
||||
|
||||
renderWithProvider(item, vi.fn() as unknown as OnRegenerate, { switchSibling })
|
||||
|
||||
const nextBtn = screen.getByRole('button', { name: /next/i })
|
||||
await user.click(nextBtn)
|
||||
|
||||
expect(switchSibling).toHaveBeenCalledWith('q-2')
|
||||
expect(switchSibling).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should not call switchSibling when next button clicked but nextSibling is null', async () => {
|
||||
const user = userEvent.setup()
|
||||
const switchSibling = vi.fn()
|
||||
const item = makeItem({ prevSibling: 'q-0', nextSibling: undefined, siblingIndex: 2, siblingCount: 3 })
|
||||
|
||||
renderWithProvider(item, vi.fn() as unknown as OnRegenerate, { switchSibling })
|
||||
|
||||
const nextBtn = screen.getByRole('button', { name: /next/i })
|
||||
await user.click(nextBtn)
|
||||
|
||||
expect(switchSibling).not.toHaveBeenCalled()
|
||||
expect(nextBtn).toBeDisabled()
|
||||
})
|
||||
|
||||
it('should not call switchSibling when prev button clicked but prevSibling is null', async () => {
|
||||
const user = userEvent.setup()
|
||||
const switchSibling = vi.fn()
|
||||
const item = makeItem({ prevSibling: undefined, nextSibling: 'q-2', siblingIndex: 0, siblingCount: 3 })
|
||||
|
||||
renderWithProvider(item, vi.fn() as unknown as OnRegenerate, { switchSibling })
|
||||
|
||||
const prevBtn = screen.getByRole('button', { name: /previous/i })
|
||||
await user.click(prevBtn)
|
||||
|
||||
expect(switchSibling).not.toHaveBeenCalled()
|
||||
expect(prevBtn).toBeDisabled()
|
||||
})
|
||||
|
||||
it('should render next button disabled when nextSibling is null', () => {
|
||||
const item = makeItem({ prevSibling: 'q-0', nextSibling: undefined, siblingIndex: 2, siblingCount: 3 })
|
||||
renderWithProvider(item, vi.fn() as unknown as OnRegenerate)
|
||||
|
||||
const nextBtn = screen.getByRole('button', { name: /next/i })
|
||||
expect(nextBtn).toBeDisabled()
|
||||
})
|
||||
|
||||
it('should handle both prev and next siblings being null (only one message)', () => {
|
||||
const item = makeItem({ prevSibling: undefined, nextSibling: undefined, siblingIndex: 0, siblingCount: 1 })
|
||||
renderWithProvider(item, vi.fn() as unknown as OnRegenerate)
|
||||
|
||||
const prevBtn = screen.queryByRole('button', { name: /previous/i })
|
||||
const nextBtn = screen.queryByRole('button', { name: /next/i })
|
||||
|
||||
expect(prevBtn).not.toBeInTheDocument()
|
||||
expect(nextBtn).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render with empty message_files array (no file list)', () => {
|
||||
const { container } = renderWithProvider(makeItem({ message_files: [] }))
|
||||
|
||||
expect(container.querySelector('[class*="FileList"]')).not.toBeInTheDocument()
|
||||
// Content should still be visible
|
||||
expect(screen.getByText('This is the question content')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render with message_files having multiple files', () => {
|
||||
const files = [
|
||||
{
|
||||
name: 'document.pdf',
|
||||
url: 'https://example.com/doc.pdf',
|
||||
type: 'application/pdf',
|
||||
previewUrl: 'https://example.com/doc.pdf',
|
||||
size: 5000,
|
||||
} as unknown as FileEntity,
|
||||
{
|
||||
name: 'image.png',
|
||||
url: 'https://example.com/img.png',
|
||||
type: 'image/png',
|
||||
previewUrl: 'https://example.com/img.png',
|
||||
size: 3000,
|
||||
} as unknown as FileEntity,
|
||||
]
|
||||
|
||||
renderWithProvider(makeItem({ message_files: files }))
|
||||
|
||||
expect(screen.getByText(/document.pdf/i)).toBeInTheDocument()
|
||||
expect(screen.getByText(/image.png/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should apply correct contentWidth positioning to action container', () => {
|
||||
vi.useFakeTimers()
|
||||
|
||||
try {
|
||||
renderWithProvider(makeItem())
|
||||
|
||||
// Mock clientWidth at different values
|
||||
const originalClientWidth = Object.getOwnPropertyDescriptor(HTMLElement.prototype, 'clientWidth')
|
||||
Object.defineProperty(HTMLElement.prototype, 'clientWidth', { configurable: true, value: 300 })
|
||||
|
||||
act(() => {
|
||||
if (resizeCallback) {
|
||||
resizeCallback([], {} as ResizeObserver)
|
||||
}
|
||||
})
|
||||
|
||||
const actionContainer = screen.getByTestId('action-container')
|
||||
// 300 width + 8 offset = 308px
|
||||
expect(actionContainer).toHaveStyle({ right: '308px' })
|
||||
|
||||
// Change width and trigger resize again
|
||||
Object.defineProperty(HTMLElement.prototype, 'clientWidth', { configurable: true, value: 250 })
|
||||
|
||||
act(() => {
|
||||
if (resizeCallback) {
|
||||
resizeCallback([], {} as ResizeObserver)
|
||||
}
|
||||
})
|
||||
|
||||
// 250 width + 8 offset = 258px
|
||||
expect(actionContainer).toHaveStyle({ right: '258px' })
|
||||
|
||||
// Restore original
|
||||
if (originalClientWidth) {
|
||||
Object.defineProperty(HTMLElement.prototype, 'clientWidth', originalClientWidth)
|
||||
}
|
||||
}
|
||||
finally {
|
||||
vi.useRealTimers()
|
||||
}
|
||||
})
|
||||
|
||||
it('should hide edit button when enableEdit is explicitly true', () => {
|
||||
renderWithProvider(makeItem(), vi.fn() as unknown as OnRegenerate, { enableEdit: true })
|
||||
|
||||
expect(screen.getByTestId('edit-btn')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('copy-btn')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show copy button always regardless of enableEdit setting', () => {
|
||||
renderWithProvider(makeItem(), vi.fn() as unknown as OnRegenerate, { enableEdit: false })
|
||||
|
||||
expect(screen.getByTestId('copy-btn')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render content switch when no siblings exist', () => {
|
||||
const item = makeItem({ siblingCount: 1, siblingIndex: 0, prevSibling: undefined, nextSibling: undefined })
|
||||
renderWithProvider(item)
|
||||
|
||||
// ContentSwitch should not render when count is 1
|
||||
const prevBtn = screen.queryByRole('button', { name: /previous/i })
|
||||
const nextBtn = screen.queryByRole('button', { name: /next/i })
|
||||
|
||||
expect(prevBtn).not.toBeInTheDocument()
|
||||
expect(nextBtn).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should update edited content as user types', async () => {
|
||||
const user = userEvent.setup()
|
||||
renderWithProvider(makeItem())
|
||||
|
||||
await user.click(screen.getByTestId('edit-btn'))
|
||||
const textbox = await screen.findByRole('textbox')
|
||||
|
||||
expect(textbox).toHaveValue('This is the question content')
|
||||
|
||||
await user.clear(textbox)
|
||||
expect(textbox).toHaveValue('')
|
||||
|
||||
await user.type(textbox, 'New content')
|
||||
expect(textbox).toHaveValue('New content')
|
||||
})
|
||||
|
||||
it('should maintain file list in edit mode with margin adjustment', async () => {
|
||||
const user = userEvent.setup()
|
||||
const files = [
|
||||
{
|
||||
name: 'test.txt',
|
||||
url: 'https://example.com/test.txt',
|
||||
type: 'text/plain',
|
||||
previewUrl: 'https://example.com/test.txt',
|
||||
size: 100,
|
||||
} as unknown as FileEntity,
|
||||
]
|
||||
|
||||
const { container } = renderWithProvider(makeItem({ message_files: files }))
|
||||
|
||||
await user.click(screen.getByTestId('edit-btn'))
|
||||
|
||||
// FileList should be visible in edit mode with mb-3 margin
|
||||
expect(screen.getByText(/test.txt/i)).toBeInTheDocument()
|
||||
// Target the FileList container directly (it's the first ancestor with FileList-related class)
|
||||
const fileListParent = container.querySelector('[class*="flex flex-wrap gap-2"]')
|
||||
expect(fileListParent).toHaveClass('mb-3')
|
||||
})
|
||||
|
||||
it('should render theme styles only in non-edit mode', () => {
|
||||
const themeBuilder = new ThemeBuilder()
|
||||
themeBuilder.buildTheme('#00ff00', true)
|
||||
const theme = themeBuilder.theme
|
||||
|
||||
renderWithProvider(makeItem(), vi.fn() as unknown as OnRegenerate, { theme })
|
||||
|
||||
const contentContainer = screen.getByTestId('question-content')
|
||||
const styleAttr = contentContainer.getAttribute('style')
|
||||
|
||||
// In non-edit mode, theme styles should be applied
|
||||
expect(styleAttr).not.toBeNull()
|
||||
})
|
||||
|
||||
it('should handle siblings at boundaries (first, middle, last)', async () => {
|
||||
const switchSibling = vi.fn()
|
||||
|
||||
// Test first message
|
||||
const firstItem = makeItem({ prevSibling: undefined, nextSibling: 'q-2', siblingIndex: 0, siblingCount: 3 })
|
||||
const { unmount: unmount1 } = renderWithProvider(firstItem, vi.fn() as unknown as OnRegenerate, { switchSibling })
|
||||
|
||||
let prevBtn = screen.getByRole('button', { name: /previous/i })
|
||||
let nextBtn = screen.getByRole('button', { name: /next/i })
|
||||
|
||||
expect(prevBtn).toBeDisabled()
|
||||
expect(nextBtn).not.toBeDisabled()
|
||||
|
||||
unmount1()
|
||||
vi.clearAllMocks()
|
||||
|
||||
// Test last message
|
||||
const lastItem = makeItem({ prevSibling: 'q-0', nextSibling: undefined, siblingIndex: 2, siblingCount: 3 })
|
||||
const { unmount: unmount2 } = renderWithProvider(lastItem, vi.fn() as unknown as OnRegenerate, { switchSibling })
|
||||
|
||||
prevBtn = screen.getByRole('button', { name: /previous/i })
|
||||
nextBtn = screen.getByRole('button', { name: /next/i })
|
||||
|
||||
expect(prevBtn).not.toBeDisabled()
|
||||
expect(nextBtn).toBeDisabled()
|
||||
|
||||
unmount2()
|
||||
})
|
||||
|
||||
it('should handle rapid composition start/end cycles', async () => {
|
||||
const onRegenerate = vi.fn() as unknown as OnRegenerate
|
||||
renderWithProvider(makeItem(), onRegenerate)
|
||||
|
||||
await userEvent.click(screen.getByTestId('edit-btn'))
|
||||
const textbox = await screen.findByRole('textbox')
|
||||
|
||||
// Rapid composition cycles
|
||||
fireEvent.compositionStart(textbox)
|
||||
fireEvent.compositionEnd(textbox)
|
||||
fireEvent.compositionStart(textbox)
|
||||
fireEvent.compositionEnd(textbox)
|
||||
fireEvent.compositionStart(textbox)
|
||||
fireEvent.compositionEnd(textbox)
|
||||
|
||||
// Press Enter after final composition end
|
||||
await new Promise(r => setTimeout(r, 60))
|
||||
fireEvent.keyDown(textbox, { key: 'Enter', code: 'Enter' })
|
||||
|
||||
expect(onRegenerate).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle Enter key with only whitespace edited content', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onRegenerate = vi.fn() as unknown as OnRegenerate
|
||||
renderWithProvider(makeItem(), onRegenerate)
|
||||
|
||||
await user.click(screen.getByTestId('edit-btn'))
|
||||
const textbox = await screen.findByRole('textbox')
|
||||
|
||||
await user.clear(textbox)
|
||||
await user.type(textbox, ' ')
|
||||
|
||||
fireEvent.keyDown(textbox, { key: 'Enter', code: 'Enter' })
|
||||
|
||||
await waitFor(() => {
|
||||
expect(onRegenerate).toHaveBeenCalledWith(makeItem(), { message: ' ', files: [] })
|
||||
})
|
||||
})
|
||||
|
||||
it('should trigger onRegenerate with actual message_files in item', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onRegenerate = vi.fn() as unknown as OnRegenerate
|
||||
const files = [
|
||||
{
|
||||
name: 'edit-file.txt',
|
||||
url: 'https://example.com/edit-file.txt',
|
||||
type: 'text/plain',
|
||||
previewUrl: 'https://example.com/edit-file.txt',
|
||||
size: 200,
|
||||
} as unknown as FileEntity,
|
||||
]
|
||||
|
||||
const item = makeItem({ message_files: files })
|
||||
renderWithProvider(item, onRegenerate)
|
||||
|
||||
await user.click(screen.getByTestId('edit-btn'))
|
||||
const textbox = await screen.findByRole('textbox')
|
||||
|
||||
await user.clear(textbox)
|
||||
await user.type(textbox, 'Modified with files')
|
||||
|
||||
fireEvent.keyDown(textbox, { key: 'Enter', code: 'Enter' })
|
||||
|
||||
await waitFor(() => {
|
||||
expect(onRegenerate).toHaveBeenCalledWith(
|
||||
item,
|
||||
{ message: 'Modified with files', files },
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
it('should clear composition timer when switching editing mode multiple times', async () => {
|
||||
const user = userEvent.setup()
|
||||
renderWithProvider(makeItem())
|
||||
|
||||
// First edit cycle
|
||||
await user.click(screen.getByTestId('edit-btn'))
|
||||
let textbox = await screen.findByRole('textbox')
|
||||
fireEvent.compositionStart(textbox)
|
||||
fireEvent.compositionEnd(textbox)
|
||||
|
||||
// Cancel and re-edit
|
||||
let cancelBtn = await screen.findByTestId('cancel-edit-btn')
|
||||
await user.click(cancelBtn)
|
||||
|
||||
// Second edit cycle
|
||||
await user.click(screen.getByTestId('edit-btn'))
|
||||
textbox = await screen.findByRole('textbox')
|
||||
expect(textbox).toHaveValue('This is the question content')
|
||||
|
||||
fireEvent.compositionStart(textbox)
|
||||
fireEvent.compositionEnd(textbox)
|
||||
|
||||
cancelBtn = await screen.findByTestId('cancel-edit-btn')
|
||||
await user.click(cancelBtn)
|
||||
|
||||
expect(screen.queryByRole('textbox')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should apply correct CSS classes in edit vs view mode', async () => {
|
||||
const user = userEvent.setup()
|
||||
renderWithProvider(makeItem())
|
||||
|
||||
const contentContainer = screen.getByTestId('question-content')
|
||||
|
||||
// View mode classes
|
||||
expect(contentContainer).toHaveClass('rounded-2xl')
|
||||
expect(contentContainer).toHaveClass('bg-background-gradient-bg-fill-chat-bubble-bg-3')
|
||||
|
||||
await user.click(screen.getByTestId('edit-btn'))
|
||||
|
||||
// Edit mode classes
|
||||
expect(contentContainer).toHaveClass('rounded-[24px]')
|
||||
expect(contentContainer).toHaveClass('border-[3px]')
|
||||
})
|
||||
|
||||
it('should handle all sibling combinations with switchSibling callback', async () => {
|
||||
const user = userEvent.setup()
|
||||
const switchSibling = vi.fn()
|
||||
|
||||
// Test with all siblings
|
||||
const allItem = makeItem({ prevSibling: 'q-0', nextSibling: 'q-2', siblingIndex: 1, siblingCount: 3 })
|
||||
renderWithProvider(allItem, vi.fn() as unknown as OnRegenerate, { switchSibling })
|
||||
|
||||
await user.click(screen.getByRole('button', { name: /previous/i }))
|
||||
expect(switchSibling).toHaveBeenCalledWith('q-0')
|
||||
|
||||
await user.click(screen.getByRole('button', { name: /next/i }))
|
||||
expect(switchSibling).toHaveBeenCalledWith('q-2')
|
||||
})
|
||||
|
||||
it('should handle undefined onRegenerate in handleResend', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(
|
||||
<ChatContextProvider
|
||||
config={{} as unknown as ChatConfig}
|
||||
isResponding={false}
|
||||
chatList={[]}
|
||||
showPromptLog={false}
|
||||
onSend={vi.fn()}
|
||||
onRegenerate={undefined as unknown as OnRegenerate}
|
||||
onAnnotationEdited={vi.fn()}
|
||||
onAnnotationAdded={vi.fn()}
|
||||
onAnnotationRemoved={vi.fn()}
|
||||
disableFeedback={false}
|
||||
onFeedback={vi.fn()}
|
||||
getHumanInputNodeData={vi.fn()}
|
||||
>
|
||||
<Question item={makeItem()} theme={null} />
|
||||
</ChatContextProvider>,
|
||||
)
|
||||
|
||||
await user.click(screen.getByTestId('edit-btn'))
|
||||
await user.click(screen.getByTestId('save-edit-btn'))
|
||||
// Should not throw
|
||||
})
|
||||
|
||||
it('should handle missing switchSibling prop', async () => {
|
||||
const user = userEvent.setup()
|
||||
const item = makeItem({ prevSibling: 'prev', nextSibling: 'next', siblingIndex: 1, siblingCount: 3 })
|
||||
renderWithProvider(item, vi.fn() as unknown as OnRegenerate, { switchSibling: undefined })
|
||||
|
||||
const prevBtn = screen.getByRole('button', { name: /previous/i })
|
||||
await user.click(prevBtn)
|
||||
// Should not throw
|
||||
|
||||
const nextBtn = screen.getByRole('button', { name: /next/i })
|
||||
await user.click(nextBtn)
|
||||
// Should not throw
|
||||
})
|
||||
|
||||
it('should handle theme without chatBubbleColorStyle', () => {
|
||||
const theme = { chatBubbleColorStyle: undefined } as unknown as Theme
|
||||
renderWithProvider(makeItem(), vi.fn() as unknown as OnRegenerate, { theme })
|
||||
const content = screen.getByTestId('question-content')
|
||||
expect(content.getAttribute('style')).toBeNull()
|
||||
})
|
||||
|
||||
it('should handle undefined message_files', () => {
|
||||
const item = makeItem({ message_files: undefined as unknown as FileEntity[] })
|
||||
const { container } = renderWithProvider(item)
|
||||
expect(container.querySelector('[class*="FileList"]')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle handleSwitchSibling call when siblings are missing', async () => {
|
||||
const user = userEvent.setup()
|
||||
const switchSibling = vi.fn()
|
||||
const item = makeItem({ prevSibling: undefined, nextSibling: undefined, siblingIndex: 0, siblingCount: 2 })
|
||||
renderWithProvider(item, vi.fn() as unknown as OnRegenerate, { switchSibling })
|
||||
|
||||
const prevBtn = screen.getByRole('button', { name: /previous/i })
|
||||
const nextBtn = screen.getByRole('button', { name: /next/i })
|
||||
|
||||
// These will now call switchSibling because of the mock, hit the falsy checks in Question
|
||||
await user.click(prevBtn)
|
||||
await user.click(nextBtn)
|
||||
|
||||
expect(switchSibling).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should clear timer on unmount when timer is active', async () => {
|
||||
const user = userEvent.setup()
|
||||
const { unmount } = renderWithProvider(makeItem())
|
||||
await user.click(screen.getByTestId('edit-btn'))
|
||||
const textbox = await screen.findByRole('textbox')
|
||||
fireEvent.compositionStart(textbox)
|
||||
fireEvent.compositionEnd(textbox) // starts timer
|
||||
unmount()
|
||||
// Should not throw and branch should be hit
|
||||
})
|
||||
|
||||
it('should handle handleSwitchSibling with no siblings and missing switchSibling prop', async () => {
|
||||
const user = userEvent.setup()
|
||||
const item = makeItem({ prevSibling: undefined, nextSibling: undefined, siblingIndex: 0, siblingCount: 2 })
|
||||
renderWithProvider(item, vi.fn() as unknown as OnRegenerate, { switchSibling: undefined })
|
||||
|
||||
const prevBtn = screen.getByRole('button', { name: /previous/i })
|
||||
await user.click(prevBtn)
|
||||
expect(screen.queryByRole('alert')).not.toBeInTheDocument() // No crash
|
||||
})
|
||||
})
|
||||
|
||||
@ -54,6 +54,26 @@ describe('AgentContent', () => {
|
||||
expect(screen.getByTestId('agent-content-markdown')).toHaveTextContent('Log Annotation Content')
|
||||
})
|
||||
|
||||
it('renders empty string if logAnnotation content is missing', () => {
|
||||
const itemWithEmptyAnnotation = {
|
||||
...mockItem,
|
||||
annotation: {
|
||||
logAnnotation: { content: '' },
|
||||
},
|
||||
}
|
||||
const { rerender } = render(<AgentContent item={itemWithEmptyAnnotation as ChatItem} />)
|
||||
expect(screen.getByTestId('agent-content-markdown')).toHaveAttribute('data-content', '')
|
||||
|
||||
const itemWithUndefinedAnnotation = {
|
||||
...mockItem,
|
||||
annotation: {
|
||||
logAnnotation: {},
|
||||
},
|
||||
}
|
||||
rerender(<AgentContent item={itemWithUndefinedAnnotation as ChatItem} />)
|
||||
expect(screen.getByTestId('agent-content-markdown')).toHaveAttribute('data-content', '')
|
||||
})
|
||||
|
||||
it('renders content prop if provided and no annotation', () => {
|
||||
render(<AgentContent item={mockItem} content="Direct Content" />)
|
||||
expect(screen.getByTestId('agent-content-markdown')).toHaveTextContent('Direct Content')
|
||||
|
||||
@ -39,6 +39,28 @@ describe('BasicContent', () => {
|
||||
expect(markdown).toHaveAttribute('data-content', 'Annotated Content')
|
||||
})
|
||||
|
||||
it('renders empty string if logAnnotation content is missing', () => {
|
||||
const itemWithEmptyAnnotation = {
|
||||
...mockItem,
|
||||
annotation: {
|
||||
logAnnotation: {
|
||||
content: '',
|
||||
},
|
||||
},
|
||||
}
|
||||
const { rerender } = render(<BasicContent item={itemWithEmptyAnnotation as ChatItem} />)
|
||||
expect(screen.getByTestId('basic-content-markdown')).toHaveAttribute('data-content', '')
|
||||
|
||||
const itemWithUndefinedAnnotation = {
|
||||
...mockItem,
|
||||
annotation: {
|
||||
logAnnotation: {},
|
||||
},
|
||||
}
|
||||
rerender(<BasicContent item={itemWithUndefinedAnnotation as ChatItem} />)
|
||||
expect(screen.getByTestId('basic-content-markdown')).toHaveAttribute('data-content', '')
|
||||
})
|
||||
|
||||
it('wraps Windows UNC paths in backticks', () => {
|
||||
const itemWithUNC = {
|
||||
...mockItem,
|
||||
|
||||
@ -0,0 +1,376 @@
|
||||
import type { ChatItem } from '../../../types'
|
||||
import type { AppData } from '@/models/share'
|
||||
import { act, fireEvent, render, screen } from '@testing-library/react'
|
||||
import Answer from '../index'
|
||||
|
||||
// Mock the chat context
|
||||
vi.mock('../context', () => ({
|
||||
useChatContext: vi.fn(() => ({
|
||||
getHumanInputNodeData: vi.fn(),
|
||||
})),
|
||||
}))
|
||||
|
||||
describe('Answer Component', () => {
|
||||
const defaultProps = {
|
||||
item: {
|
||||
id: 'msg-1',
|
||||
content: 'Test response',
|
||||
isAnswer: true,
|
||||
} as unknown as ChatItem,
|
||||
question: 'Hello?',
|
||||
index: 0,
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
Object.defineProperty(HTMLElement.prototype, 'clientWidth', {
|
||||
configurable: true,
|
||||
value: 500,
|
||||
})
|
||||
})
|
||||
|
||||
describe('Rendering', () => {
|
||||
it('should render basic content correctly', async () => {
|
||||
render(<Answer {...defaultProps} />)
|
||||
expect(screen.getByTestId('markdown-body')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render loading animation when responding and content is empty', () => {
|
||||
const { container } = render(
|
||||
<Answer
|
||||
{...defaultProps}
|
||||
item={{ id: '1', content: '', isAnswer: true } as unknown as ChatItem}
|
||||
responding={true}
|
||||
/>,
|
||||
)
|
||||
expect(container).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Component Blocks', () => {
|
||||
it('should render workflow process', () => {
|
||||
render(
|
||||
<Answer
|
||||
{...defaultProps}
|
||||
item={{
|
||||
...defaultProps.item,
|
||||
workflowProcess: { status: 'running', tracing: [], steps: [] },
|
||||
} as unknown as ChatItem}
|
||||
/>,
|
||||
)
|
||||
expect(screen.getByTestId('chat-answer-container')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render agent thoughts', () => {
|
||||
const { container } = render(
|
||||
<Answer
|
||||
{...defaultProps}
|
||||
item={{
|
||||
...defaultProps.item,
|
||||
agent_thoughts: [{ id: '1', thought: 'Thinking...' }],
|
||||
} as unknown as ChatItem}
|
||||
/>,
|
||||
)
|
||||
expect(container.querySelector('.group')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render file lists', () => {
|
||||
render(
|
||||
<Answer
|
||||
{...defaultProps}
|
||||
item={{
|
||||
...defaultProps.item,
|
||||
allFiles: [{ id: 'f1', type: 'image', name: 'test.png' }],
|
||||
message_files: [{ id: 'f2', type: 'document', name: 'doc.pdf' }],
|
||||
} as unknown as ChatItem}
|
||||
/>,
|
||||
)
|
||||
expect(screen.getAllByTestId('file-list')).toHaveLength(2)
|
||||
})
|
||||
|
||||
it('should render annotation edit title', async () => {
|
||||
render(
|
||||
<Answer
|
||||
{...defaultProps}
|
||||
item={{
|
||||
...defaultProps.item,
|
||||
annotation: { id: 'a1', authorName: 'John Doe' },
|
||||
} as unknown as ChatItem}
|
||||
/>,
|
||||
)
|
||||
expect(await screen.findByText(/John Doe/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render citations', () => {
|
||||
render(
|
||||
<Answer
|
||||
{...defaultProps}
|
||||
item={{
|
||||
...defaultProps.item,
|
||||
citation: [{ id: 'c1', title: 'Source 1' }],
|
||||
} as unknown as ChatItem}
|
||||
/>,
|
||||
)
|
||||
expect(screen.getByTestId('citation-title')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Human Inputs Layout', () => {
|
||||
it('should render human input form data list', () => {
|
||||
render(
|
||||
<Answer
|
||||
{...defaultProps}
|
||||
item={{
|
||||
...defaultProps.item,
|
||||
humanInputFormDataList: [{ id: 'form1' }],
|
||||
} as unknown as ChatItem}
|
||||
/>,
|
||||
)
|
||||
expect(screen.getByTestId('chat-answer-container')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render human input filled form data list', () => {
|
||||
render(
|
||||
<Answer
|
||||
{...defaultProps}
|
||||
item={{
|
||||
...defaultProps.item,
|
||||
humanInputFilledFormDataList: [{ id: 'form1_filled' }],
|
||||
} as unknown as ChatItem}
|
||||
/>,
|
||||
)
|
||||
expect(screen.getByTestId('chat-answer-container')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Interactions', () => {
|
||||
it('should handle switch sibling', () => {
|
||||
const mockSwitchSibling = vi.fn()
|
||||
render(
|
||||
<Answer
|
||||
{...defaultProps}
|
||||
item={{
|
||||
...defaultProps.item,
|
||||
siblingCount: 3,
|
||||
siblingIndex: 1,
|
||||
prevSibling: 'msg-0',
|
||||
nextSibling: 'msg-2',
|
||||
} as unknown as ChatItem}
|
||||
switchSibling={mockSwitchSibling}
|
||||
/>,
|
||||
)
|
||||
|
||||
const prevBtn = screen.getByRole('button', { name: 'Previous' })
|
||||
fireEvent.click(prevBtn)
|
||||
expect(mockSwitchSibling).toHaveBeenCalledWith('msg-0')
|
||||
|
||||
// reset mock for next sibling click
|
||||
const nextBtn = screen.getByRole('button', { name: 'Next' })
|
||||
fireEvent.click(nextBtn)
|
||||
expect(mockSwitchSibling).toHaveBeenCalledWith('msg-2')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge Cases and Props', () => {
|
||||
it('should handle hideAvatar properly', () => {
|
||||
render(<Answer {...defaultProps} hideAvatar={true} />)
|
||||
expect(screen.queryByTestId('emoji')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render custom answerIcon', () => {
|
||||
render(
|
||||
<Answer
|
||||
{...defaultProps}
|
||||
answerIcon={<div data-testid="custom-answer-icon">Custom Icon</div>}
|
||||
/>,
|
||||
)
|
||||
expect(screen.getByTestId('custom-answer-icon')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle hideProcessDetail with appData', () => {
|
||||
render(
|
||||
<Answer
|
||||
{...defaultProps}
|
||||
hideProcessDetail={true}
|
||||
appData={{ site: { show_workflow_steps: false } } as unknown as AppData}
|
||||
item={{
|
||||
...defaultProps.item,
|
||||
workflowProcess: { status: 'running', tracing: [], steps: [] },
|
||||
} as unknown as ChatItem}
|
||||
/>,
|
||||
)
|
||||
expect(screen.getByTestId('chat-answer-container')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render More component', () => {
|
||||
render(
|
||||
<Answer
|
||||
{...defaultProps}
|
||||
item={{
|
||||
...defaultProps.item,
|
||||
more: { messages: [{ text: 'more content' }] },
|
||||
} as unknown as ChatItem}
|
||||
/>,
|
||||
)
|
||||
expect(screen.getByTestId('more-container')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render content with hasHumanInput but contentIsEmpty and no agent_thoughts', () => {
|
||||
render(
|
||||
<Answer
|
||||
{...defaultProps}
|
||||
item={{
|
||||
...defaultProps.item,
|
||||
content: '',
|
||||
humanInputFormDataList: [{ id: 'form1' }],
|
||||
} as unknown as ChatItem}
|
||||
/>,
|
||||
)
|
||||
expect(screen.getByTestId('chat-answer-container-humaninput')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render content switch within hasHumanInput but contentIsEmpty', () => {
|
||||
render(
|
||||
<Answer
|
||||
{...defaultProps}
|
||||
item={{
|
||||
...defaultProps.item,
|
||||
content: '',
|
||||
siblingCount: 2,
|
||||
siblingIndex: 1,
|
||||
prevSibling: 'msg-0',
|
||||
humanInputFormDataList: [{ id: 'form1' }],
|
||||
} as unknown as ChatItem}
|
||||
/>,
|
||||
)
|
||||
expect(screen.getByTestId('chat-answer-container-humaninput')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle responding=true in human inputs layout block 2', () => {
|
||||
const { container } = render(
|
||||
<Answer
|
||||
{...defaultProps}
|
||||
responding={true}
|
||||
item={{
|
||||
...defaultProps.item,
|
||||
content: '',
|
||||
humanInputFormDataList: [{ id: 'form1' }],
|
||||
} as unknown as ChatItem}
|
||||
/>,
|
||||
)
|
||||
expect(container).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle ResizeObserver callback', () => {
|
||||
const originalResizeObserver = globalThis.ResizeObserver
|
||||
let triggerResize = () => { }
|
||||
globalThis.ResizeObserver = class ResizeObserver {
|
||||
constructor(callback: unknown) {
|
||||
triggerResize = callback as () => void
|
||||
}
|
||||
|
||||
observe() { }
|
||||
unobserve() { }
|
||||
disconnect() { }
|
||||
} as unknown as typeof ResizeObserver
|
||||
|
||||
render(<Answer {...defaultProps} />)
|
||||
|
||||
// Trigger the callback to cover getContentWidth and getHumanInputFormContainerWidth
|
||||
act(() => {
|
||||
triggerResize()
|
||||
})
|
||||
|
||||
globalThis.ResizeObserver = originalResizeObserver
|
||||
// Verify component still renders correctly after resize callback
|
||||
expect(screen.getByTestId('chat-answer-container')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render all component blocks within human inputs layout to cover missing branches', () => {
|
||||
const { container } = render(
|
||||
<Answer
|
||||
{...defaultProps}
|
||||
item={{
|
||||
...defaultProps.item,
|
||||
humanInputFilledFormDataList: [{ id: 'form1' } as unknown as Record<string, unknown>],
|
||||
humanInputFormDataList: [], // hits length > 0 false branch
|
||||
agent_thoughts: [{ id: 'thought1', thought: 'thinking' }],
|
||||
allFiles: [{ _id: 'file1', name: 'file1.txt', type: 'document' } as unknown as Record<string, unknown>],
|
||||
message_files: [{ id: 'file2', url: 'http://test.com', type: 'image/png' } as unknown as Record<string, unknown>],
|
||||
annotation: { id: 'anno1', authorName: 'Author' } as unknown as Record<string, unknown>,
|
||||
citation: [{ item: { title: 'cite 1' } }] as unknown as Record<string, unknown>[],
|
||||
siblingCount: 2,
|
||||
siblingIndex: 1,
|
||||
prevSibling: 'msg-0',
|
||||
nextSibling: 'msg-2',
|
||||
more: { messages: [{ text: 'more content' }] },
|
||||
} as unknown as ChatItem}
|
||||
/>,
|
||||
)
|
||||
expect(container).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle hideProcessDetail with NO appData', () => {
|
||||
render(
|
||||
<Answer
|
||||
{...defaultProps}
|
||||
hideProcessDetail={true}
|
||||
appData={undefined}
|
||||
item={{
|
||||
...defaultProps.item,
|
||||
workflowProcess: { status: 'running', tracing: [], steps: [] },
|
||||
} as unknown as ChatItem}
|
||||
/>,
|
||||
)
|
||||
expect(screen.getByTestId('chat-answer-container')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle hideProcessDetail branches in human inputs layout', () => {
|
||||
// Branch: hideProcessDetail=true, appData=undefined
|
||||
const { container: c1 } = render(
|
||||
<Answer
|
||||
{...defaultProps}
|
||||
hideProcessDetail={true}
|
||||
appData={undefined}
|
||||
item={{
|
||||
...defaultProps.item,
|
||||
workflowProcess: { status: 'running', tracing: [], steps: [] },
|
||||
humanInputFormDataList: [{ id: 'form1' } as unknown as Record<string, unknown>],
|
||||
} as unknown as ChatItem}
|
||||
/>,
|
||||
)
|
||||
|
||||
// Branch: hideProcessDetail=true, appData provided
|
||||
const { container: c2 } = render(
|
||||
<Answer
|
||||
{...defaultProps}
|
||||
hideProcessDetail={true}
|
||||
appData={{ site: { show_workflow_steps: false } } as unknown as AppData}
|
||||
item={{
|
||||
...defaultProps.item,
|
||||
workflowProcess: { status: 'running', tracing: [], steps: [] },
|
||||
humanInputFormDataList: [{ id: 'form1' } as unknown as Record<string, unknown>],
|
||||
} as unknown as ChatItem}
|
||||
/>,
|
||||
)
|
||||
|
||||
// Branch: hideProcessDetail=false
|
||||
const { container: c3 } = render(
|
||||
<Answer
|
||||
{...defaultProps}
|
||||
hideProcessDetail={false}
|
||||
appData={{ site: { show_workflow_steps: true } } as unknown as AppData}
|
||||
item={{
|
||||
...defaultProps.item,
|
||||
workflowProcess: { status: 'running', tracing: [], steps: [] },
|
||||
humanInputFormDataList: [{ id: 'form1' } as unknown as Record<string, unknown>],
|
||||
} as unknown as ChatItem}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(c1).toBeInTheDocument()
|
||||
expect(c2).toBeInTheDocument()
|
||||
expect(c3).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -3,8 +3,6 @@ import type { ChatContextValue } from '../../context'
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import copy from 'copy-to-clipboard'
|
||||
import * as React from 'react'
|
||||
import { vi } from 'vitest'
|
||||
import { useModalContext } from '@/context/modal-context'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import Operation from '../operation'
|
||||
@ -98,12 +96,8 @@ vi.mock('@/app/components/base/features/new-feature-panel/annotation-reply/annot
|
||||
return (
|
||||
<div data-testid="annotation-ctrl">
|
||||
{cached
|
||||
? (
|
||||
<button data-testid="annotation-edit-btn" onClick={onEdit}>Edit</button>
|
||||
)
|
||||
: (
|
||||
<button data-testid="annotation-add-btn" onClick={handleAdd}>Add</button>
|
||||
)}
|
||||
? (<button data-testid="annotation-edit-btn" onClick={onEdit}>Edit</button>)
|
||||
: (<button data-testid="annotation-add-btn" onClick={handleAdd}>Add</button>)}
|
||||
</div>
|
||||
)
|
||||
},
|
||||
@ -440,6 +434,17 @@ describe('Operation', () => {
|
||||
const bar = screen.getByTestId('operation-bar')
|
||||
expect(bar.querySelectorAll('.i-ri-thumb-up-line').length).toBe(0)
|
||||
})
|
||||
|
||||
it('should test feedback modal translation fallbacks', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockT.mockImplementation((_key: string): string => '')
|
||||
renderOperation()
|
||||
const thumbDown = screen.getByTestId('operation-bar').querySelector('.i-ri-thumb-down-line')!.closest('button')!
|
||||
await user.click(thumbDown)
|
||||
// Check if modal title/labels fallback works
|
||||
expect(screen.getByRole('tooltip')).toBeInTheDocument()
|
||||
mockT.mockImplementation(key => key)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Admin feedback (with annotation support)', () => {
|
||||
@ -538,6 +543,19 @@ describe('Operation', () => {
|
||||
renderOperation({ ...baseProps, item })
|
||||
expect(screen.getByTestId('operation-bar').querySelectorAll('.i-ri-thumb-up-line').length).toBe(0)
|
||||
})
|
||||
|
||||
it('should render action buttons with Default state when feedback rating is undefined', () => {
|
||||
// Setting a malformed feedback object with no rating but triggers the wrapper to see undefined fallbacks
|
||||
const item = {
|
||||
...baseItem,
|
||||
feedback: {} as unknown as Record<string, unknown>,
|
||||
adminFeedback: {} as unknown as Record<string, unknown>,
|
||||
} as ChatItem
|
||||
renderOperation({ ...baseProps, item })
|
||||
// Since it renders the 'else' block for hasAdminFeedback (which is false due to !)
|
||||
// the like/dislike regular ActionButtons should hit the Default state
|
||||
expect(screen.getByTestId('operation-bar')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Positioning and layout', () => {
|
||||
@ -595,6 +613,60 @@ describe('Operation', () => {
|
||||
// Reset to default behavior
|
||||
mockT.mockImplementation(key => key)
|
||||
})
|
||||
|
||||
it('should handle buildFeedbackTooltip with empty translation fallbacks', () => {
|
||||
// Mock t to return empty string for 'like' and 'dislike' to hit fallback branches:
|
||||
mockT.mockImplementation((key: string): string => {
|
||||
if (key.includes('operation.like'))
|
||||
return ''
|
||||
if (key.includes('operation.dislike'))
|
||||
return ''
|
||||
return key
|
||||
})
|
||||
const itemLike = { ...baseItem, feedback: { rating: 'like' as const, content: 'test content' } }
|
||||
const { rerender } = renderOperation({ ...baseProps, item: itemLike })
|
||||
expect(screen.getByTestId('operation-bar')).toBeInTheDocument()
|
||||
|
||||
const itemDislike = { ...baseItem, feedback: { rating: 'dislike' as const, content: 'test content' } }
|
||||
rerender(
|
||||
<div className="group">
|
||||
<Operation {...baseProps} item={itemDislike} />
|
||||
</div>,
|
||||
)
|
||||
expect(screen.getByTestId('operation-bar')).toBeInTheDocument()
|
||||
|
||||
mockT.mockImplementation(key => key)
|
||||
})
|
||||
|
||||
it('should handle buildFeedbackTooltip without rating', () => {
|
||||
// Mock tooltip display without rating to hit: 'if (!feedbackData?.rating) return label'
|
||||
const item = { ...baseItem, feedback: { rating: null } as unknown as Record<string, unknown> } as unknown as ChatItem
|
||||
renderOperation({ ...baseProps, item })
|
||||
const bar = screen.getByTestId('operation-bar')
|
||||
expect(bar).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle missing onFeedback gracefully in handleFeedback', async () => {
|
||||
const user = userEvent.setup()
|
||||
// First, render with feedback enabled to get the DOM node
|
||||
mockContextValue.config = makeChatConfig({ supportFeedback: true })
|
||||
mockContextValue.onFeedback = vi.fn()
|
||||
const { rerender } = renderOperation()
|
||||
|
||||
const thumbUp = screen.getByTestId('operation-bar').querySelector('.i-ri-thumb-up-line')!.closest('button')!
|
||||
|
||||
// Then, disable the context callback to hit the `if (!onFeedback) return` early exit internally upon rerender/click
|
||||
mockContextValue.onFeedback = undefined
|
||||
// Rerender to ensure the component closure gets the updated undefined value from the mock context
|
||||
rerender(
|
||||
<div className="group">
|
||||
<Operation {...baseProps} />
|
||||
</div>,
|
||||
)
|
||||
|
||||
await user.click(thumbUp)
|
||||
expect(mockContextValue.onFeedback).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Annotation integration', () => {
|
||||
@ -722,5 +794,53 @@ describe('Operation', () => {
|
||||
await user.click(screen.getByTestId('copy-btn'))
|
||||
expect(copy).toHaveBeenCalledWith('Hello world')
|
||||
})
|
||||
|
||||
it('should handle editing annotation missing onAnnotationEdited gracefully', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockContextValue.config = makeChatConfig({
|
||||
supportAnnotation: true,
|
||||
annotation_reply: { id: 'ar-1', score_threshold: 0.5, embedding_model: { embedding_provider_name: '', embedding_model_name: '' }, enabled: true },
|
||||
appId: 'test-app',
|
||||
})
|
||||
mockContextValue.onAnnotationEdited = undefined
|
||||
const item = { ...baseItem, annotation: { id: 'ann-1', created_at: 123, authorName: 'test author' } as unknown as Record<string, unknown> } as unknown as ChatItem
|
||||
renderOperation({ ...baseProps, item })
|
||||
const editBtn = screen.getByTestId('annotation-edit-btn')
|
||||
await user.click(editBtn)
|
||||
await user.click(screen.getByTestId('modal-edit'))
|
||||
expect(mockContextValue.onAnnotationEdited).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should handle adding annotation missing onAnnotationAdded gracefully', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockContextValue.config = makeChatConfig({
|
||||
supportAnnotation: true,
|
||||
annotation_reply: { id: 'ar-1', score_threshold: 0.5, embedding_model: { embedding_provider_name: '', embedding_model_name: '' }, enabled: true },
|
||||
appId: 'test-app',
|
||||
})
|
||||
mockContextValue.onAnnotationAdded = undefined
|
||||
const item = { ...baseItem, annotation: { id: 'ann-1', created_at: 123, authorName: 'test author' } as unknown as Record<string, unknown> } as unknown as ChatItem
|
||||
renderOperation({ ...baseProps, item })
|
||||
const editBtn = screen.getByTestId('annotation-edit-btn')
|
||||
await user.click(editBtn)
|
||||
await user.click(screen.getByTestId('modal-add'))
|
||||
expect(mockContextValue.onAnnotationAdded).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should handle removing annotation missing onAnnotationRemoved gracefully', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockContextValue.config = makeChatConfig({
|
||||
supportAnnotation: true,
|
||||
annotation_reply: { id: 'ar-1', score_threshold: 0.5, embedding_model: { embedding_provider_name: '', embedding_model_name: '' }, enabled: true },
|
||||
appId: 'test-app',
|
||||
})
|
||||
mockContextValue.onAnnotationRemoved = undefined
|
||||
const item = { ...baseItem, annotation: { id: 'ann-1', created_at: 123, authorName: 'test author' } as unknown as Record<string, unknown> } as unknown as ChatItem
|
||||
renderOperation({ ...baseProps, item })
|
||||
const editBtn = screen.getByTestId('annotation-edit-btn')
|
||||
await user.click(editBtn)
|
||||
await user.click(screen.getByTestId('modal-remove'))
|
||||
expect(mockContextValue.onAnnotationRemoved).toBeUndefined()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@ -0,0 +1,120 @@
|
||||
import type { FormInputItem } from '@/app/components/workflow/nodes/human-input/types'
|
||||
import type { Locale } from '@/i18n-config/language'
|
||||
import { UserActionButtonType } from '@/app/components/workflow/nodes/human-input/types'
|
||||
import { InputVarType } from '@/app/components/workflow/types'
|
||||
import {
|
||||
getButtonStyle,
|
||||
getRelativeTime,
|
||||
initializeInputs,
|
||||
isRelativeTimeSameOrAfter,
|
||||
splitByOutputVar,
|
||||
} from '../utils'
|
||||
|
||||
const createInput = (overrides: Partial<FormInputItem>): FormInputItem => ({
|
||||
label: 'field',
|
||||
variable: 'field',
|
||||
required: false,
|
||||
max_length: 128,
|
||||
type: InputVarType.textInput,
|
||||
default: {
|
||||
type: 'constant' as const,
|
||||
value: '',
|
||||
selector: [], // Dummy selector
|
||||
},
|
||||
output_variable_name: 'field',
|
||||
...overrides,
|
||||
} as unknown as FormInputItem)
|
||||
|
||||
describe('human-input utils', () => {
|
||||
describe('getButtonStyle', () => {
|
||||
it('should map all supported button styles', () => {
|
||||
expect(getButtonStyle(UserActionButtonType.Primary)).toBe('primary')
|
||||
expect(getButtonStyle(UserActionButtonType.Default)).toBe('secondary')
|
||||
expect(getButtonStyle(UserActionButtonType.Accent)).toBe('secondary-accent')
|
||||
expect(getButtonStyle(UserActionButtonType.Ghost)).toBe('ghost')
|
||||
})
|
||||
|
||||
it('should return undefined for unsupported style values', () => {
|
||||
expect(getButtonStyle('unknown' as UserActionButtonType)).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('splitByOutputVar', () => {
|
||||
it('should split content around output variable placeholders', () => {
|
||||
expect(splitByOutputVar('Hello {{#$output.user_name#}}!')).toEqual([
|
||||
'Hello ',
|
||||
'{{#$output.user_name#}}',
|
||||
'!',
|
||||
])
|
||||
})
|
||||
|
||||
it('should return original content when no placeholders exist', () => {
|
||||
expect(splitByOutputVar('no placeholders')).toEqual(['no placeholders'])
|
||||
})
|
||||
})
|
||||
|
||||
describe('initializeInputs', () => {
|
||||
it('should initialize text fields with constants and variable defaults', () => {
|
||||
const formInputs = [
|
||||
createInput({
|
||||
type: InputVarType.textInput,
|
||||
output_variable_name: 'name',
|
||||
default: { type: 'constant', value: 'John', selector: [] },
|
||||
}),
|
||||
createInput({
|
||||
type: InputVarType.paragraph,
|
||||
output_variable_name: 'bio',
|
||||
default: { type: 'variable', value: '', selector: [] },
|
||||
}),
|
||||
]
|
||||
|
||||
expect(initializeInputs(formInputs, { bio: 'Lives in Berlin' })).toEqual({
|
||||
name: 'John',
|
||||
bio: 'Lives in Berlin',
|
||||
})
|
||||
})
|
||||
|
||||
it('should set non text-like inputs to undefined', () => {
|
||||
const formInputs = [
|
||||
createInput({
|
||||
type: InputVarType.select,
|
||||
output_variable_name: 'role',
|
||||
}),
|
||||
]
|
||||
|
||||
expect(initializeInputs(formInputs)).toEqual({
|
||||
role: undefined,
|
||||
})
|
||||
})
|
||||
|
||||
it('should fallback to empty string when variable default is missing', () => {
|
||||
const formInputs = [
|
||||
createInput({
|
||||
type: InputVarType.textInput,
|
||||
output_variable_name: 'summary',
|
||||
default: { type: 'variable', value: '', selector: [] },
|
||||
}),
|
||||
]
|
||||
|
||||
expect(initializeInputs(formInputs, {})).toEqual({
|
||||
summary: '',
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('time helpers', () => {
|
||||
it('should format relative time for supported and fallback locales', () => {
|
||||
const now = Date.now()
|
||||
const twoMinutesAgo = now - 2 * 60 * 1000
|
||||
|
||||
expect(getRelativeTime(twoMinutesAgo, 'en-US')).toMatch(/ago/i)
|
||||
expect(getRelativeTime(twoMinutesAgo, 'es-ES' as Locale)).toMatch(/ago/i)
|
||||
})
|
||||
|
||||
it('should compare utc timestamp against current time', () => {
|
||||
const now = Date.now()
|
||||
expect(isRelativeTimeSameOrAfter(now + 60_000)).toBe(true)
|
||||
expect(isRelativeTimeSameOrAfter(now - 60_000)).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -152,10 +152,10 @@ const Answer: FC<AnswerProps> = ({
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
<div className="chat-answer-container group ml-4 w-0 grow pb-4" ref={containerRef}>
|
||||
<div className="chat-answer-container group ml-4 w-0 grow pb-4" ref={containerRef} data-testid="chat-answer-container">
|
||||
{/* Block 1: Workflow Process + Human Input Forms */}
|
||||
{hasHumanInputs && (
|
||||
<div className={cn('group relative pr-10', chatAnswerContainerInner)}>
|
||||
<div className={cn('group relative pr-10', chatAnswerContainerInner)} data-testid="chat-answer-container-humaninput">
|
||||
<div
|
||||
ref={humanInputFormContainerRef}
|
||||
className={cn('relative inline-block w-full max-w-full rounded-2xl bg-chat-bubble-bg px-4 py-3 text-text-primary body-lg-regular')}
|
||||
@ -319,7 +319,7 @@ const Answer: FC<AnswerProps> = ({
|
||||
|
||||
{/* Original single block layout (when no human inputs) */}
|
||||
{!hasHumanInputs && (
|
||||
<div className={cn('group relative pr-10', chatAnswerContainerInner)}>
|
||||
<div className={cn('group relative pr-10', chatAnswerContainerInner)} data-testid="chat-answer-container-inner">
|
||||
<div
|
||||
ref={contentRef}
|
||||
className={cn('relative inline-block max-w-full rounded-2xl bg-chat-bubble-bg px-4 py-3 text-text-primary body-lg-regular', workflowProcess && 'w-full')}
|
||||
|
||||
@ -1,46 +1,145 @@
|
||||
import type { FileUpload } from '@/app/components/base/features/types'
|
||||
import type { FileEntity } from '@/app/components/base/file-uploader/types'
|
||||
import type { TransferMethod } from '@/types/app'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import * as React from 'react'
|
||||
import { vi } from 'vitest'
|
||||
import { TransferMethod } from '@/types/app'
|
||||
import ChatInputArea from '../index'
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Hoist shared mock references so they are available inside vi.mock factories
|
||||
// ---------------------------------------------------------------------------
|
||||
const { mockGetPermission, mockNotify } = vi.hoisted(() => ({
|
||||
mockGetPermission: vi.fn().mockResolvedValue(undefined),
|
||||
mockNotify: vi.fn(),
|
||||
}))
|
||||
vi.setConfig({ testTimeout: 60000 })
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// External dependency mocks
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Track whether getPermission should reject
|
||||
const { mockGetPermissionConfig } = vi.hoisted(() => ({
|
||||
mockGetPermissionConfig: { shouldReject: false },
|
||||
}))
|
||||
|
||||
vi.mock('js-audio-recorder', () => ({
|
||||
default: class {
|
||||
static getPermission = mockGetPermission
|
||||
start = vi.fn()
|
||||
default: class MockRecorder {
|
||||
static getPermission = vi.fn().mockImplementation(() => {
|
||||
if (mockGetPermissionConfig.shouldReject) {
|
||||
return Promise.reject(new Error('Permission denied'))
|
||||
}
|
||||
return Promise.resolve(undefined)
|
||||
})
|
||||
|
||||
start = vi.fn().mockResolvedValue(undefined)
|
||||
stop = vi.fn()
|
||||
getWAVBlob = vi.fn().mockReturnValue(new Blob([''], { type: 'audio/wav' }))
|
||||
getRecordAnalyseData = vi.fn().mockReturnValue(new Uint8Array(128))
|
||||
getChannelData = vi.fn().mockReturnValue({ left: new Float32Array(0), right: new Float32Array(0) })
|
||||
getWAV = vi.fn().mockReturnValue(new ArrayBuffer(0))
|
||||
destroy = vi.fn()
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/voice-input/utils', () => ({
|
||||
convertToMp3: vi.fn().mockReturnValue(new Blob([''], { type: 'audio/mp3' })),
|
||||
}))
|
||||
|
||||
// Mock VoiceInput component - simplified version
|
||||
vi.mock('@/app/components/base/voice-input', () => {
|
||||
const VoiceInputMock = ({
|
||||
onCancel,
|
||||
onConverted,
|
||||
}: {
|
||||
onCancel: () => void
|
||||
onConverted: (text: string) => void
|
||||
}) => {
|
||||
// Use module-level state for simplicity
|
||||
const [showStop, setShowStop] = React.useState(true)
|
||||
|
||||
const handleStop = () => {
|
||||
setShowStop(false)
|
||||
// Simulate async conversion
|
||||
setTimeout(() => {
|
||||
onConverted('Converted voice text')
|
||||
setShowStop(true)
|
||||
}, 100)
|
||||
}
|
||||
|
||||
return (
|
||||
<div data-testid="voice-input-mock">
|
||||
<div data-testid="voice-input-speaking">voiceInput.speaking</div>
|
||||
<div data-testid="voice-input-converting-text">voiceInput.converting</div>
|
||||
{showStop && (
|
||||
<button data-testid="voice-input-stop" onClick={handleStop}>
|
||||
Stop
|
||||
</button>
|
||||
)}
|
||||
<button data-testid="voice-input-cancel" onClick={onCancel}>
|
||||
Cancel
|
||||
</button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
return {
|
||||
default: VoiceInputMock,
|
||||
}
|
||||
})
|
||||
|
||||
vi.stubGlobal('requestAnimationFrame', (cb: FrameRequestCallback) => setTimeout(() => cb(Date.now()), 16))
|
||||
vi.stubGlobal('cancelAnimationFrame', (id: number) => clearTimeout(id))
|
||||
vi.stubGlobal('devicePixelRatio', 1)
|
||||
|
||||
// Mock Canvas
|
||||
HTMLCanvasElement.prototype.getContext = vi.fn().mockReturnValue({
|
||||
scale: vi.fn(),
|
||||
beginPath: vi.fn(),
|
||||
moveTo: vi.fn(),
|
||||
rect: vi.fn(),
|
||||
fill: vi.fn(),
|
||||
closePath: vi.fn(),
|
||||
clearRect: vi.fn(),
|
||||
roundRect: vi.fn(),
|
||||
})
|
||||
HTMLCanvasElement.prototype.getBoundingClientRect = vi.fn().mockReturnValue({
|
||||
width: 100,
|
||||
height: 50,
|
||||
})
|
||||
|
||||
vi.mock('@/service/share', () => ({
|
||||
audioToText: vi.fn().mockResolvedValue({ text: 'Converted text' }),
|
||||
audioToText: vi.fn().mockResolvedValue({ text: 'Converted voice text' }),
|
||||
AppSourceType: { webApp: 'webApp', installedApp: 'installedApp' },
|
||||
}))
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// File-uploader store – shared mutable state so individual tests can mutate it
|
||||
// File-uploader store
|
||||
// ---------------------------------------------------------------------------
|
||||
const mockFileStore: { files: FileEntity[], setFiles: ReturnType<typeof vi.fn> } = {
|
||||
files: [],
|
||||
setFiles: vi.fn(),
|
||||
}
|
||||
const {
|
||||
mockFileStore,
|
||||
mockIsDragActive,
|
||||
mockFeaturesState,
|
||||
mockNotify,
|
||||
mockIsMultipleLine,
|
||||
mockCheckInputsFormResult,
|
||||
} = vi.hoisted(() => ({
|
||||
mockFileStore: {
|
||||
files: [] as FileEntity[],
|
||||
setFiles: vi.fn(),
|
||||
},
|
||||
mockIsDragActive: { value: false },
|
||||
mockIsMultipleLine: { value: false },
|
||||
mockFeaturesState: {
|
||||
features: {
|
||||
moreLikeThis: { enabled: false },
|
||||
opening: { enabled: false },
|
||||
moderation: { enabled: false },
|
||||
speech2text: { enabled: false },
|
||||
text2speech: { enabled: false },
|
||||
file: { enabled: false },
|
||||
suggested: { enabled: false },
|
||||
citation: { enabled: false },
|
||||
annotationReply: { enabled: false },
|
||||
},
|
||||
},
|
||||
mockNotify: vi.fn(),
|
||||
mockCheckInputsFormResult: { value: true },
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/file-uploader/store', () => ({
|
||||
useFileStore: () => ({ getState: () => mockFileStore }),
|
||||
@ -50,9 +149,8 @@ vi.mock('@/app/components/base/file-uploader/store', () => ({
|
||||
}))
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// File-uploader hooks – provide stable drag/drop handlers
|
||||
// File-uploader hooks
|
||||
// ---------------------------------------------------------------------------
|
||||
let mockIsDragActive = false
|
||||
|
||||
vi.mock('@/app/components/base/file-uploader/hooks', () => ({
|
||||
useFile: () => ({
|
||||
@ -61,29 +159,13 @@ vi.mock('@/app/components/base/file-uploader/hooks', () => ({
|
||||
handleDragFileOver: vi.fn(),
|
||||
handleDropFile: vi.fn(),
|
||||
handleClipboardPasteFile: vi.fn(),
|
||||
isDragActive: mockIsDragActive,
|
||||
isDragActive: mockIsDragActive.value,
|
||||
}),
|
||||
}))
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Features context hook – avoids needing FeaturesContext.Provider in the tree
|
||||
// Features context mock
|
||||
// ---------------------------------------------------------------------------
|
||||
// FeatureBar calls: useFeatures(s => s.features)
|
||||
// So the selector receives the store state object; we must nest the features
|
||||
// under a `features` key to match what the real store exposes.
|
||||
const mockFeaturesState = {
|
||||
features: {
|
||||
moreLikeThis: { enabled: false },
|
||||
opening: { enabled: false },
|
||||
moderation: { enabled: false },
|
||||
speech2text: { enabled: false },
|
||||
text2speech: { enabled: false },
|
||||
file: { enabled: false },
|
||||
suggested: { enabled: false },
|
||||
citation: { enabled: false },
|
||||
annotationReply: { enabled: false },
|
||||
},
|
||||
}
|
||||
|
||||
vi.mock('@/app/components/base/features/hooks', () => ({
|
||||
useFeatures: (selector: (s: typeof mockFeaturesState) => unknown) =>
|
||||
@ -98,9 +180,8 @@ vi.mock('@/app/components/base/toast/context', () => ({
|
||||
}))
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Internal layout hook – controls single/multi-line textarea mode
|
||||
// Internal layout hook
|
||||
// ---------------------------------------------------------------------------
|
||||
let mockIsMultipleLine = false
|
||||
|
||||
vi.mock('../hooks', () => ({
|
||||
useTextAreaHeight: () => ({
|
||||
@ -110,17 +191,17 @@ vi.mock('../hooks', () => ({
|
||||
holdSpaceRef: { current: document.createElement('div') },
|
||||
handleTextareaResize: vi.fn(),
|
||||
get isMultipleLine() {
|
||||
return mockIsMultipleLine
|
||||
return mockIsMultipleLine.value
|
||||
},
|
||||
}),
|
||||
}))
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Input-forms validation hook – always passes by default
|
||||
// Input-forms validation hook
|
||||
// ---------------------------------------------------------------------------
|
||||
vi.mock('../../check-input-forms-hooks', () => ({
|
||||
useCheckInputsForms: () => ({
|
||||
checkInputsForm: vi.fn().mockReturnValue(true),
|
||||
checkInputsForm: vi.fn().mockImplementation(() => mockCheckInputsFormResult.value),
|
||||
}),
|
||||
}))
|
||||
|
||||
@ -134,28 +215,10 @@ vi.mock('next/navigation', () => ({
|
||||
}))
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Shared fixture – typed as FileUpload to avoid implicit any
|
||||
// Shared fixture
|
||||
// ---------------------------------------------------------------------------
|
||||
// const mockVisionConfig: FileUpload = {
|
||||
// fileUploadConfig: {
|
||||
// image_file_size_limit: 10,
|
||||
// file_size_limit: 10,
|
||||
// audio_file_size_limit: 10,
|
||||
// video_file_size_limit: 10,
|
||||
// workflow_file_upload_limit: 10,
|
||||
// },
|
||||
// allowed_file_types: [],
|
||||
// allowed_file_extensions: [],
|
||||
// enabled: true,
|
||||
// number_limits: 3,
|
||||
// transfer_methods: ['local_file', 'remote_url'],
|
||||
// } as FileUpload
|
||||
|
||||
const mockVisionConfig: FileUpload = {
|
||||
// Required because of '& EnabledOrDisabled' at the end of your type
|
||||
enabled: true,
|
||||
|
||||
// The nested config object
|
||||
fileUploadConfig: {
|
||||
image_file_size_limit: 10,
|
||||
file_size_limit: 10,
|
||||
@ -168,34 +231,24 @@ const mockVisionConfig: FileUpload = {
|
||||
attachment_image_file_size_limit: 0,
|
||||
file_upload_limit: 0,
|
||||
},
|
||||
|
||||
// These match the keys in your FileUpload type
|
||||
allowed_file_types: [],
|
||||
allowed_file_extensions: [],
|
||||
number_limits: 3,
|
||||
|
||||
// NOTE: Your type defines 'allowed_file_upload_methods',
|
||||
// not 'transfer_methods' at the top level.
|
||||
allowed_file_upload_methods: ['local_file', 'remote_url'] as TransferMethod[],
|
||||
|
||||
// If you wanted to define specific image/video behavior:
|
||||
allowed_file_upload_methods: [TransferMethod.local_file, TransferMethod.remote_url],
|
||||
image: {
|
||||
enabled: true,
|
||||
number_limits: 3,
|
||||
transfer_methods: ['local_file', 'remote_url'] as TransferMethod[],
|
||||
transfer_methods: [TransferMethod.local_file, TransferMethod.remote_url],
|
||||
},
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Minimal valid FileEntity fixture – avoids undefined `type` crash in FileItem
|
||||
// ---------------------------------------------------------------------------
|
||||
const makeFile = (overrides: Partial<FileEntity> = {}): FileEntity => ({
|
||||
id: 'file-1',
|
||||
name: 'photo.png',
|
||||
type: 'image/png', // required: FileItem calls type.split('/')[0]
|
||||
type: 'image/png',
|
||||
size: 1024,
|
||||
progress: 100,
|
||||
transferMethod: 'local_file',
|
||||
transferMethod: TransferMethod.local_file,
|
||||
uploadedId: 'uploaded-ok',
|
||||
...overrides,
|
||||
} as FileEntity)
|
||||
@ -203,7 +256,10 @@ const makeFile = (overrides: Partial<FileEntity> = {}): FileEntity => ({
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
const getTextarea = () => screen.getByPlaceholderText(/inputPlaceholder/i)
|
||||
const getTextarea = () => (
|
||||
screen.queryByPlaceholderText(/inputPlaceholder/i)
|
||||
|| screen.queryByPlaceholderText(/inputDisabledPlaceholder/i)
|
||||
) as HTMLTextAreaElement | null
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
@ -212,15 +268,16 @@ describe('ChatInputArea', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockFileStore.files = []
|
||||
mockIsDragActive = false
|
||||
mockIsMultipleLine = false
|
||||
mockIsDragActive.value = false
|
||||
mockIsMultipleLine.value = false
|
||||
mockCheckInputsFormResult.value = true
|
||||
})
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
describe('Rendering', () => {
|
||||
it('should render the textarea with default placeholder', () => {
|
||||
render(<ChatInputArea visionConfig={mockVisionConfig} />)
|
||||
expect(getTextarea()).toBeInTheDocument()
|
||||
expect(getTextarea()!).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render the readonly placeholder when readonly prop is set', () => {
|
||||
@ -228,206 +285,152 @@ describe('ChatInputArea', () => {
|
||||
expect(screen.getByPlaceholderText(/inputDisabledPlaceholder/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render the send button', () => {
|
||||
render(<ChatInputArea visionConfig={mockVisionConfig} />)
|
||||
expect(screen.getByTestId('send-button')).toBeInTheDocument()
|
||||
it('should include botName in placeholder text if provided', () => {
|
||||
render(<ChatInputArea visionConfig={mockVisionConfig} botName="TestBot" />)
|
||||
// The i18n pattern shows interpolation: namespace.key:{"botName":"TestBot"}
|
||||
expect(getTextarea()!).toHaveAttribute('placeholder', expect.stringContaining('botName'))
|
||||
})
|
||||
|
||||
it('should apply disabled styles when the disabled prop is true', () => {
|
||||
const { container } = render(<ChatInputArea visionConfig={mockVisionConfig} disabled />)
|
||||
const disabledWrapper = container.querySelector('.pointer-events-none')
|
||||
expect(disabledWrapper).toBeInTheDocument()
|
||||
expect(container.firstChild).toHaveClass('opacity-50')
|
||||
})
|
||||
|
||||
it('should apply drag-active styles when a file is being dragged over the input', () => {
|
||||
mockIsDragActive = true
|
||||
it('should apply drag-active styles when a file is being dragged over', () => {
|
||||
mockIsDragActive.value = true
|
||||
const { container } = render(<ChatInputArea visionConfig={mockVisionConfig} />)
|
||||
expect(container.querySelector('.border-dashed')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render the operation section inline when single-line', () => {
|
||||
// mockIsMultipleLine is false by default
|
||||
render(<ChatInputArea visionConfig={mockVisionConfig} />)
|
||||
expect(screen.getByTestId('send-button')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render the operation section below the textarea when multi-line', () => {
|
||||
mockIsMultipleLine = true
|
||||
it('should render the send button', () => {
|
||||
render(<ChatInputArea visionConfig={mockVisionConfig} />)
|
||||
expect(screen.getByTestId('send-button')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
describe('Typing', () => {
|
||||
describe('User Interaction', () => {
|
||||
it('should update textarea value as the user types', async () => {
|
||||
const user = userEvent.setup()
|
||||
const user = userEvent.setup({ delay: null })
|
||||
render(<ChatInputArea visionConfig={mockVisionConfig} />)
|
||||
const textarea = getTextarea()!
|
||||
|
||||
await user.type(getTextarea(), 'Hello world')
|
||||
|
||||
expect(getTextarea()).toHaveValue('Hello world')
|
||||
await user.type(textarea, 'Hello world')
|
||||
expect(textarea).toHaveValue('Hello world')
|
||||
})
|
||||
|
||||
it('should clear the textarea after a message is successfully sent', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(<ChatInputArea onSend={vi.fn()} visionConfig={mockVisionConfig} />)
|
||||
|
||||
await user.type(getTextarea(), 'Hello world')
|
||||
await user.click(screen.getByTestId('send-button'))
|
||||
|
||||
expect(getTextarea()).toHaveValue('')
|
||||
})
|
||||
})
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
describe('Sending Messages', () => {
|
||||
it('should call onSend with query and files when clicking the send button', async () => {
|
||||
const user = userEvent.setup()
|
||||
it('should clear the textarea after a message is sent', async () => {
|
||||
const user = userEvent.setup({ delay: null })
|
||||
const onSend = vi.fn()
|
||||
render(<ChatInputArea onSend={onSend} visionConfig={mockVisionConfig} />)
|
||||
const textarea = getTextarea()!
|
||||
|
||||
await user.type(getTextarea(), 'Hello world')
|
||||
await user.type(textarea, 'Hello world')
|
||||
await user.click(screen.getByTestId('send-button'))
|
||||
|
||||
expect(onSend).toHaveBeenCalledTimes(1)
|
||||
expect(onSend).toHaveBeenCalledWith('Hello world', [])
|
||||
expect(onSend).toHaveBeenCalled()
|
||||
expect(textarea).toHaveValue('')
|
||||
})
|
||||
|
||||
it('should call onSend and reset the input when pressing Enter', async () => {
|
||||
const user = userEvent.setup()
|
||||
const user = userEvent.setup({ delay: null })
|
||||
const onSend = vi.fn()
|
||||
render(<ChatInputArea onSend={onSend} visionConfig={mockVisionConfig} />)
|
||||
const textarea = getTextarea()!
|
||||
|
||||
await user.type(getTextarea(), 'Hello world{Enter}')
|
||||
await user.type(textarea, 'Hello world')
|
||||
fireEvent.keyDown(textarea, { key: 'Enter', code: 'Enter', nativeEvent: { isComposing: false } })
|
||||
|
||||
expect(onSend).toHaveBeenCalledWith('Hello world', [])
|
||||
expect(getTextarea()).toHaveValue('')
|
||||
expect(textarea).toHaveValue('')
|
||||
})
|
||||
|
||||
it('should NOT call onSend when pressing Shift+Enter (inserts newline instead)', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onSend = vi.fn()
|
||||
render(<ChatInputArea onSend={onSend} visionConfig={mockVisionConfig} />)
|
||||
it('should handle pasted text', async () => {
|
||||
const user = userEvent.setup({ delay: null })
|
||||
render(<ChatInputArea visionConfig={mockVisionConfig} />)
|
||||
const textarea = getTextarea()!
|
||||
|
||||
await user.type(getTextarea(), 'Hello world{Shift>}{Enter}{/Shift}')
|
||||
await user.click(textarea)
|
||||
await user.paste('Pasted text')
|
||||
|
||||
expect(onSend).not.toHaveBeenCalled()
|
||||
expect(getTextarea()).toHaveValue('Hello world\n')
|
||||
})
|
||||
|
||||
it('should NOT call onSend in readonly mode', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onSend = vi.fn()
|
||||
render(<ChatInputArea onSend={onSend} visionConfig={mockVisionConfig} readonly />)
|
||||
|
||||
await user.click(screen.getByTestId('send-button'))
|
||||
|
||||
expect(onSend).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should pass already-uploaded files to onSend', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onSend = vi.fn()
|
||||
|
||||
// makeFile ensures `type` is always a proper MIME string
|
||||
const uploadedFile = makeFile({ id: 'file-1', name: 'photo.png', uploadedId: 'uploaded-123' })
|
||||
mockFileStore.files = [uploadedFile]
|
||||
|
||||
render(<ChatInputArea onSend={onSend} visionConfig={mockVisionConfig} />)
|
||||
await user.type(getTextarea(), 'With attachment')
|
||||
await user.click(screen.getByTestId('send-button'))
|
||||
|
||||
expect(onSend).toHaveBeenCalledWith('With attachment', [uploadedFile])
|
||||
})
|
||||
|
||||
it('should not send on Enter while IME composition is active, then send after composition ends', () => {
|
||||
vi.useFakeTimers()
|
||||
try {
|
||||
const onSend = vi.fn()
|
||||
render(<ChatInputArea onSend={onSend} visionConfig={mockVisionConfig} />)
|
||||
const textarea = getTextarea()
|
||||
|
||||
fireEvent.change(textarea, { target: { value: 'Composed text' } })
|
||||
fireEvent.compositionStart(textarea)
|
||||
fireEvent.keyDown(textarea, { key: 'Enter' })
|
||||
|
||||
expect(onSend).not.toHaveBeenCalled()
|
||||
|
||||
fireEvent.compositionEnd(textarea)
|
||||
vi.advanceTimersByTime(60)
|
||||
fireEvent.keyDown(textarea, { key: 'Enter' })
|
||||
|
||||
expect(onSend).toHaveBeenCalledWith('Composed text', [])
|
||||
}
|
||||
finally {
|
||||
vi.useRealTimers()
|
||||
}
|
||||
expect(textarea).toHaveValue('Pasted text')
|
||||
})
|
||||
})
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
describe('History Navigation', () => {
|
||||
it('should restore the last sent message when pressing Cmd+ArrowUp once', async () => {
|
||||
const user = userEvent.setup()
|
||||
it('should navigate back in history with Meta+ArrowUp', async () => {
|
||||
const user = userEvent.setup({ delay: null })
|
||||
render(<ChatInputArea onSend={vi.fn()} visionConfig={mockVisionConfig} />)
|
||||
const textarea = getTextarea()
|
||||
const textarea = getTextarea()!
|
||||
|
||||
await user.type(textarea, 'First{Enter}')
|
||||
await user.type(textarea, 'Second{Enter}')
|
||||
await user.type(textarea, '{Meta>}{ArrowUp}{/Meta}')
|
||||
|
||||
await user.type(textarea, '{Meta>}{ArrowUp}{/Meta}')
|
||||
expect(textarea).toHaveValue('Second')
|
||||
})
|
||||
|
||||
it('should go further back in history with repeated Cmd+ArrowUp', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(<ChatInputArea onSend={vi.fn()} visionConfig={mockVisionConfig} />)
|
||||
const textarea = getTextarea()
|
||||
|
||||
await user.type(textarea, 'First{Enter}')
|
||||
await user.type(textarea, 'Second{Enter}')
|
||||
await user.type(textarea, '{Meta>}{ArrowUp}{/Meta}')
|
||||
await user.type(textarea, '{Meta>}{ArrowUp}{/Meta}')
|
||||
|
||||
expect(textarea).toHaveValue('First')
|
||||
})
|
||||
|
||||
it('should move forward in history when pressing Cmd+ArrowDown', async () => {
|
||||
const user = userEvent.setup()
|
||||
it('should navigate forward in history with Meta+ArrowDown', async () => {
|
||||
const user = userEvent.setup({ delay: null })
|
||||
render(<ChatInputArea onSend={vi.fn()} visionConfig={mockVisionConfig} />)
|
||||
const textarea = getTextarea()
|
||||
const textarea = getTextarea()!
|
||||
|
||||
await user.type(textarea, 'First{Enter}')
|
||||
await user.type(textarea, 'Second{Enter}')
|
||||
await user.type(textarea, '{Meta>}{ArrowUp}{/Meta}') // → Second
|
||||
await user.type(textarea, '{Meta>}{ArrowUp}{/Meta}') // → First
|
||||
await user.type(textarea, '{Meta>}{ArrowDown}{/Meta}') // → Second
|
||||
|
||||
await user.type(textarea, '{Meta>}{ArrowUp}{/Meta}') // Second
|
||||
await user.type(textarea, '{Meta>}{ArrowUp}{/Meta}') // First
|
||||
await user.type(textarea, '{Meta>}{ArrowDown}{/Meta}') // Second
|
||||
|
||||
expect(textarea).toHaveValue('Second')
|
||||
})
|
||||
|
||||
it('should clear the input when navigating past the most recent history entry', async () => {
|
||||
const user = userEvent.setup()
|
||||
it('should clear input when navigating past the end of history', async () => {
|
||||
const user = userEvent.setup({ delay: null })
|
||||
render(<ChatInputArea onSend={vi.fn()} visionConfig={mockVisionConfig} />)
|
||||
const textarea = getTextarea()
|
||||
const textarea = getTextarea()!
|
||||
|
||||
await user.type(textarea, 'First{Enter}')
|
||||
await user.type(textarea, '{Meta>}{ArrowUp}{/Meta}') // → First
|
||||
await user.type(textarea, '{Meta>}{ArrowDown}{/Meta}') // → past end → ''
|
||||
await user.type(textarea, '{Meta>}{ArrowUp}{/Meta}') // First
|
||||
await user.type(textarea, '{Meta>}{ArrowDown}{/Meta}') // empty
|
||||
|
||||
expect(textarea).toHaveValue('')
|
||||
})
|
||||
|
||||
it('should not go below the start of history when pressing Cmd+ArrowUp at the boundary', async () => {
|
||||
const user = userEvent.setup()
|
||||
it('should NOT navigate history when typing regular text and pressing ArrowUp', async () => {
|
||||
const user = userEvent.setup({ delay: null })
|
||||
render(<ChatInputArea onSend={vi.fn()} visionConfig={mockVisionConfig} />)
|
||||
const textarea = getTextarea()
|
||||
const textarea = getTextarea()!
|
||||
|
||||
await user.type(textarea, 'Only{Enter}')
|
||||
await user.type(textarea, '{Meta>}{ArrowUp}{/Meta}') // → Only
|
||||
await user.type(textarea, '{Meta>}{ArrowUp}{/Meta}') // → '' (seed at index 0)
|
||||
await user.type(textarea, '{Meta>}{ArrowUp}{/Meta}') // boundary – should stay at ''
|
||||
await user.type(textarea, 'First{Enter}')
|
||||
await user.type(textarea, 'Some text')
|
||||
await user.keyboard('{ArrowUp}')
|
||||
|
||||
expect(textarea).toHaveValue('Some text')
|
||||
})
|
||||
|
||||
it('should handle ArrowUp when history is empty', async () => {
|
||||
const user = userEvent.setup({ delay: null })
|
||||
render(<ChatInputArea visionConfig={mockVisionConfig} />)
|
||||
const textarea = getTextarea()!
|
||||
|
||||
await user.keyboard('{Meta>}{ArrowUp}{/Meta}')
|
||||
expect(textarea).toHaveValue('')
|
||||
})
|
||||
|
||||
it('should handle ArrowDown at history boundary', async () => {
|
||||
const user = userEvent.setup({ delay: null })
|
||||
render(<ChatInputArea onSend={vi.fn()} visionConfig={mockVisionConfig} />)
|
||||
const textarea = getTextarea()!
|
||||
|
||||
await user.type(textarea, 'First{Enter}')
|
||||
await user.type(textarea, '{Meta>}{ArrowUp}{/Meta}') // First
|
||||
await user.type(textarea, '{Meta>}{ArrowDown}{/Meta}') // empty
|
||||
await user.type(textarea, '{Meta>}{ArrowDown}{/Meta}') // still empty
|
||||
|
||||
expect(textarea).toHaveValue('')
|
||||
})
|
||||
@ -435,160 +438,270 @@ describe('ChatInputArea', () => {
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
describe('Voice Input', () => {
|
||||
it('should render the voice input button when speech-to-text is enabled', () => {
|
||||
it('should render the voice input button when enabled', () => {
|
||||
render(<ChatInputArea speechToTextConfig={{ enabled: true }} visionConfig={mockVisionConfig} />)
|
||||
expect(screen.getByTestId('voice-input-button')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('voice-input-button')).toBeTruthy()
|
||||
})
|
||||
|
||||
it('should NOT render the voice input button when speech-to-text is disabled', () => {
|
||||
render(<ChatInputArea speechToTextConfig={{ enabled: false }} visionConfig={mockVisionConfig} />)
|
||||
expect(screen.queryByTestId('voice-input-button')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should request microphone permission when the voice button is clicked', async () => {
|
||||
const user = userEvent.setup()
|
||||
it('should handle stop recording in VoiceInput', async () => {
|
||||
const user = userEvent.setup({ delay: null })
|
||||
render(<ChatInputArea speechToTextConfig={{ enabled: true }} visionConfig={mockVisionConfig} />)
|
||||
|
||||
await user.click(screen.getByTestId('voice-input-button'))
|
||||
// Wait for VoiceInput to show speaking
|
||||
await screen.findByText(/voiceInput.speaking/i)
|
||||
const stopBtn = screen.getByTestId('voice-input-stop')
|
||||
await user.click(stopBtn)
|
||||
|
||||
expect(mockGetPermission).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should notify with an error when microphone permission is denied', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockGetPermission.mockRejectedValueOnce(new Error('Permission denied'))
|
||||
render(<ChatInputArea speechToTextConfig={{ enabled: true }} visionConfig={mockVisionConfig} />)
|
||||
|
||||
await user.click(screen.getByTestId('voice-input-button'))
|
||||
// Converting should show up
|
||||
await screen.findByText(/voiceInput.converting/i)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ type: 'error' }))
|
||||
expect(getTextarea()!).toHaveValue('Converted voice text')
|
||||
})
|
||||
})
|
||||
|
||||
it('should NOT invoke onSend while voice input is being activated', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onSend = vi.fn()
|
||||
render(
|
||||
<ChatInputArea
|
||||
onSend={onSend}
|
||||
speechToTextConfig={{ enabled: true }}
|
||||
visionConfig={mockVisionConfig}
|
||||
/>,
|
||||
)
|
||||
it('should handle cancel in VoiceInput', async () => {
|
||||
const user = userEvent.setup({ delay: null })
|
||||
render(<ChatInputArea speechToTextConfig={{ enabled: true }} visionConfig={mockVisionConfig} />)
|
||||
|
||||
await user.click(screen.getByTestId('voice-input-button'))
|
||||
await screen.findByText(/voiceInput.speaking/i)
|
||||
const stopBtn = screen.getByTestId('voice-input-stop')
|
||||
await user.click(stopBtn)
|
||||
|
||||
// Wait for converting and cancel button
|
||||
const cancelBtn = await screen.findByTestId('voice-input-cancel')
|
||||
await user.click(cancelBtn)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByTestId('voice-input-stop')).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
it('should show error toast when voice permission is denied', async () => {
|
||||
const user = userEvent.setup({ delay: null })
|
||||
mockGetPermissionConfig.shouldReject = true
|
||||
|
||||
render(<ChatInputArea speechToTextConfig={{ enabled: true }} visionConfig={mockVisionConfig} />)
|
||||
|
||||
await user.click(screen.getByTestId('voice-input-button'))
|
||||
|
||||
expect(onSend).not.toHaveBeenCalled()
|
||||
// Permission denied should trigger error toast
|
||||
await waitFor(() => {
|
||||
expect(mockNotify).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ type: 'error' }),
|
||||
)
|
||||
})
|
||||
|
||||
mockGetPermissionConfig.shouldReject = false
|
||||
})
|
||||
|
||||
it('should handle empty converted text in VoiceInput', async () => {
|
||||
const user = userEvent.setup({ delay: null })
|
||||
// Mock failure or empty result
|
||||
const { audioToText } = await import('@/service/share')
|
||||
vi.mocked(audioToText).mockResolvedValueOnce({ text: '' })
|
||||
|
||||
render(<ChatInputArea speechToTextConfig={{ enabled: true }} visionConfig={mockVisionConfig} />)
|
||||
|
||||
await user.click(screen.getByTestId('voice-input-button'))
|
||||
await screen.findByText(/voiceInput.speaking/i)
|
||||
const stopBtn = screen.getByTestId('voice-input-stop')
|
||||
await user.click(stopBtn)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByTestId('voice-input-stop')).toBeNull()
|
||||
})
|
||||
expect(getTextarea()!).toHaveValue('')
|
||||
})
|
||||
})
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
describe('Validation', () => {
|
||||
it('should notify and NOT call onSend when the query is blank', async () => {
|
||||
const user = userEvent.setup()
|
||||
describe('Validation & Constraints', () => {
|
||||
it('should notify and NOT send when query is blank', async () => {
|
||||
const user = userEvent.setup({ delay: null })
|
||||
const onSend = vi.fn()
|
||||
render(<ChatInputArea onSend={onSend} visionConfig={mockVisionConfig} />)
|
||||
|
||||
await user.click(screen.getByTestId('send-button'))
|
||||
|
||||
expect(onSend).not.toHaveBeenCalled()
|
||||
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ type: 'info' }))
|
||||
})
|
||||
|
||||
it('should notify and NOT call onSend when the query contains only whitespace', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onSend = vi.fn()
|
||||
render(<ChatInputArea onSend={onSend} visionConfig={mockVisionConfig} />)
|
||||
|
||||
await user.type(getTextarea(), ' ')
|
||||
await user.click(screen.getByTestId('send-button'))
|
||||
|
||||
expect(onSend).not.toHaveBeenCalled()
|
||||
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ type: 'info' }))
|
||||
})
|
||||
|
||||
it('should notify and NOT call onSend while the bot is already responding', async () => {
|
||||
const user = userEvent.setup()
|
||||
it('should notify and NOT send while bot is responding', async () => {
|
||||
const user = userEvent.setup({ delay: null })
|
||||
const onSend = vi.fn()
|
||||
render(<ChatInputArea onSend={onSend} isResponding visionConfig={mockVisionConfig} />)
|
||||
|
||||
await user.type(getTextarea(), 'Hello')
|
||||
await user.type(getTextarea()!, 'Hello')
|
||||
await user.click(screen.getByTestId('send-button'))
|
||||
expect(onSend).not.toHaveBeenCalled()
|
||||
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ type: 'info' }))
|
||||
})
|
||||
|
||||
it('should NOT send while file upload is in progress', async () => {
|
||||
const user = userEvent.setup({ delay: null })
|
||||
const onSend = vi.fn()
|
||||
mockFileStore.files = [makeFile({ uploadedId: '', progress: 50 })]
|
||||
|
||||
render(<ChatInputArea onSend={onSend} visionConfig={mockVisionConfig} />)
|
||||
await user.type(getTextarea()!, 'Hello')
|
||||
await user.click(screen.getByTestId('send-button'))
|
||||
|
||||
expect(onSend).not.toHaveBeenCalled()
|
||||
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ type: 'info' }))
|
||||
})
|
||||
|
||||
it('should notify and NOT call onSend while a file upload is still in progress', async () => {
|
||||
const user = userEvent.setup()
|
||||
it('should send successfully with completed file uploads', async () => {
|
||||
const user = userEvent.setup({ delay: null })
|
||||
const onSend = vi.fn()
|
||||
|
||||
// uploadedId is empty string → upload not yet finished
|
||||
mockFileStore.files = [
|
||||
makeFile({ id: 'file-upload', uploadedId: '', progress: 50 }),
|
||||
]
|
||||
const completedFile = makeFile()
|
||||
mockFileStore.files = [completedFile]
|
||||
|
||||
render(<ChatInputArea onSend={onSend} visionConfig={mockVisionConfig} />)
|
||||
await user.type(getTextarea(), 'Hello')
|
||||
await user.type(getTextarea()!, 'Hello')
|
||||
await user.click(screen.getByTestId('send-button'))
|
||||
|
||||
expect(onSend).toHaveBeenCalledWith('Hello', [completedFile])
|
||||
})
|
||||
|
||||
it('should handle mixed transfer methods correctly', async () => {
|
||||
const user = userEvent.setup({ delay: null })
|
||||
const onSend = vi.fn()
|
||||
const remoteFile = makeFile({
|
||||
id: 'remote',
|
||||
transferMethod: TransferMethod.remote_url,
|
||||
uploadedId: 'remote-id',
|
||||
})
|
||||
mockFileStore.files = [remoteFile]
|
||||
|
||||
render(<ChatInputArea onSend={onSend} visionConfig={mockVisionConfig} />)
|
||||
await user.type(getTextarea()!, 'Remote test')
|
||||
await user.click(screen.getByTestId('send-button'))
|
||||
|
||||
expect(onSend).toHaveBeenCalledWith('Remote test', [remoteFile])
|
||||
})
|
||||
|
||||
it('should NOT call onSend if checkInputsForm fails', async () => {
|
||||
const user = userEvent.setup({ delay: null })
|
||||
const onSend = vi.fn()
|
||||
mockCheckInputsFormResult.value = false
|
||||
render(<ChatInputArea onSend={onSend} visionConfig={mockVisionConfig} />)
|
||||
|
||||
await user.type(getTextarea()!, 'Validation fail')
|
||||
await user.click(screen.getByTestId('send-button'))
|
||||
|
||||
expect(onSend).not.toHaveBeenCalled()
|
||||
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ type: 'info' }))
|
||||
})
|
||||
|
||||
it('should call onSend normally when all uploaded files have completed', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onSend = vi.fn()
|
||||
it('should work when onSend prop is missing', async () => {
|
||||
const user = userEvent.setup({ delay: null })
|
||||
render(<ChatInputArea visionConfig={mockVisionConfig} />)
|
||||
|
||||
// uploadedId is present → upload finished
|
||||
mockFileStore.files = [makeFile({ uploadedId: 'uploaded-ok' })]
|
||||
|
||||
render(<ChatInputArea onSend={onSend} visionConfig={mockVisionConfig} />)
|
||||
await user.type(getTextarea(), 'With completed file')
|
||||
await user.type(getTextarea()!, 'No onSend')
|
||||
await user.click(screen.getByTestId('send-button'))
|
||||
// Should not throw
|
||||
})
|
||||
})
|
||||
|
||||
expect(onSend).toHaveBeenCalledTimes(1)
|
||||
// -------------------------------------------------------------------------
|
||||
describe('Special Keyboard & Composition Events', () => {
|
||||
it('should NOT send on Enter if Shift is pressed', async () => {
|
||||
const user = userEvent.setup({ delay: null })
|
||||
const onSend = vi.fn()
|
||||
render(<ChatInputArea onSend={onSend} visionConfig={mockVisionConfig} />)
|
||||
const textarea = getTextarea()!
|
||||
|
||||
await user.type(textarea, 'Hello')
|
||||
await user.keyboard('{Shift>}{Enter}{/Shift}')
|
||||
expect(onSend).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should block Enter key during composition', async () => {
|
||||
vi.useFakeTimers()
|
||||
const onSend = vi.fn()
|
||||
render(<ChatInputArea onSend={onSend} visionConfig={mockVisionConfig} />)
|
||||
const textarea = getTextarea()!
|
||||
|
||||
fireEvent.compositionStart(textarea)
|
||||
fireEvent.change(textarea, { target: { value: 'Composing' } })
|
||||
fireEvent.keyDown(textarea, { key: 'Enter', code: 'Enter', nativeEvent: { isComposing: true } })
|
||||
|
||||
expect(onSend).not.toHaveBeenCalled()
|
||||
|
||||
fireEvent.compositionEnd(textarea)
|
||||
// Wait for the 50ms delay in handleCompositionEnd
|
||||
vi.advanceTimersByTime(60)
|
||||
|
||||
fireEvent.keyDown(textarea, { key: 'Enter', code: 'Enter', nativeEvent: { isComposing: false } })
|
||||
|
||||
expect(onSend).toHaveBeenCalled()
|
||||
vi.useRealTimers()
|
||||
})
|
||||
})
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
describe('Layout & Styles', () => {
|
||||
it('should toggle opacity class based on disabled prop', () => {
|
||||
const { container, rerender } = render(<ChatInputArea visionConfig={mockVisionConfig} disabled={false} />)
|
||||
expect(container.firstChild).not.toHaveClass('opacity-50')
|
||||
|
||||
rerender(<ChatInputArea visionConfig={mockVisionConfig} disabled={true} />)
|
||||
expect(container.firstChild).toHaveClass('opacity-50')
|
||||
})
|
||||
|
||||
it('should handle multi-line layout correctly', () => {
|
||||
mockIsMultipleLine.value = true
|
||||
render(<ChatInputArea visionConfig={mockVisionConfig} />)
|
||||
// Send button should still be present
|
||||
expect(screen.getByTestId('send-button')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle drag enter event on textarea', () => {
|
||||
render(<ChatInputArea visionConfig={mockVisionConfig} />)
|
||||
const textarea = getTextarea()!
|
||||
fireEvent.dragOver(textarea, { dataTransfer: { types: ['Files'] } })
|
||||
// Verify no crash and textarea stays
|
||||
expect(textarea).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
describe('Feature Bar', () => {
|
||||
it('should render the FeatureBar section when showFeatureBar is true', () => {
|
||||
const { container } = render(
|
||||
<ChatInputArea visionConfig={mockVisionConfig} showFeatureBar />,
|
||||
)
|
||||
// FeatureBar renders a rounded-bottom container beneath the input
|
||||
expect(container.querySelector('[class*="rounded-b"]')).toBeInTheDocument()
|
||||
it('should render feature bar when showFeatureBar is true', () => {
|
||||
render(<ChatInputArea visionConfig={mockVisionConfig} showFeatureBar />)
|
||||
expect(screen.getByText(/feature.bar.empty/i)).toBeTruthy()
|
||||
})
|
||||
|
||||
it('should NOT render the FeatureBar when showFeatureBar is false', () => {
|
||||
const { container } = render(
|
||||
<ChatInputArea visionConfig={mockVisionConfig} showFeatureBar={false} />,
|
||||
)
|
||||
expect(container.querySelector('[class*="rounded-b"]')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not invoke onFeatureBarClick when the component is in readonly mode', async () => {
|
||||
const user = userEvent.setup()
|
||||
it('should call onFeatureBarClick when clicked', async () => {
|
||||
const user = userEvent.setup({ delay: null })
|
||||
const onFeatureBarClick = vi.fn()
|
||||
render(
|
||||
<ChatInputArea
|
||||
visionConfig={mockVisionConfig}
|
||||
showFeatureBar
|
||||
readonly
|
||||
onFeatureBarClick={onFeatureBarClick}
|
||||
/>,
|
||||
)
|
||||
|
||||
// In readonly mode the FeatureBar receives `noop` as its click handler.
|
||||
// Click every button that is not a named test-id button to exercise the guard.
|
||||
const buttons = screen.queryAllByRole('button')
|
||||
for (const btn of buttons) {
|
||||
if (!btn.dataset.testid)
|
||||
await user.click(btn)
|
||||
}
|
||||
await user.click(screen.getByText(/feature.bar.empty/i))
|
||||
expect(onFeatureBarClick).toHaveBeenCalledWith(true)
|
||||
})
|
||||
|
||||
it('should NOT call onFeatureBarClick when readonly', async () => {
|
||||
const user = userEvent.setup({ delay: null })
|
||||
const onFeatureBarClick = vi.fn()
|
||||
render(
|
||||
<ChatInputArea
|
||||
visionConfig={mockVisionConfig}
|
||||
showFeatureBar
|
||||
onFeatureBarClick={onFeatureBarClick}
|
||||
readonly
|
||||
/>,
|
||||
)
|
||||
|
||||
await user.click(screen.getByText(/feature.bar.empty/i))
|
||||
expect(onFeatureBarClick).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import type { Resources } from '../index'
|
||||
import { render, screen, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { useDocumentDownload } from '@/service/knowledge/use-document'
|
||||
|
||||
import { downloadUrl } from '@/utils/download'
|
||||
@ -605,5 +604,113 @@ describe('Popup', () => {
|
||||
const tooltips = screen.getAllByTestId('citation-tooltip')
|
||||
expect(tooltips[2]).toBeInTheDocument()
|
||||
})
|
||||
|
||||
describe('Item Key Generation (Branch Coverage)', () => {
|
||||
it('should use index_node_hash when document_id is missing', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(
|
||||
<Popup
|
||||
data={makeData({
|
||||
sources: [makeSource({ document_id: '', index_node_hash: 'hash-123' })],
|
||||
})}
|
||||
/>,
|
||||
)
|
||||
await openPopup(user)
|
||||
// Verify it renders without key collision (no console error expected, though not explicitly checked here)
|
||||
expect(screen.getByTestId('popup-source-item')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should use data.documentId when both source ids are missing', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(
|
||||
<Popup
|
||||
data={makeData({
|
||||
documentId: 'parent-doc-id',
|
||||
sources: [makeSource({ document_id: undefined, index_node_hash: undefined })],
|
||||
})}
|
||||
/>,
|
||||
)
|
||||
await openPopup(user)
|
||||
expect(screen.getByTestId('popup-source-item')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should fallback to \'doc\' when all ids are missing', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(
|
||||
<Popup
|
||||
data={makeData({
|
||||
documentId: undefined,
|
||||
sources: [makeSource({ document_id: undefined, index_node_hash: undefined })],
|
||||
})}
|
||||
/>,
|
||||
)
|
||||
await openPopup(user)
|
||||
expect(screen.getByTestId('popup-source-item')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should fallback to index when segment_position is missing', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(
|
||||
<Popup
|
||||
data={makeData({
|
||||
sources: [makeSource({ document_id: 'doc-1', segment_position: undefined })],
|
||||
})}
|
||||
/>,
|
||||
)
|
||||
await openPopup(user)
|
||||
expect(screen.getByTestId('popup-segment-position')).toHaveTextContent('1')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Download Logic Edge Cases (Branch Coverage)', () => {
|
||||
it('should return early if datasetId is missing', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(
|
||||
<Popup
|
||||
data={makeData({
|
||||
dataSourceType: 'upload_file',
|
||||
sources: [makeSource({ dataset_id: '' })],
|
||||
})}
|
||||
/>,
|
||||
)
|
||||
await openPopup(user)
|
||||
// Even if the button is rendered (it shouldn't be based on line 71),
|
||||
// we check the handler directly if possible, or just the button absence.
|
||||
expect(screen.queryByTestId('popup-download-btn')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should return early if both documentIds are missing', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(
|
||||
<Popup
|
||||
data={makeData({
|
||||
documentId: '',
|
||||
dataSourceType: 'upload_file',
|
||||
sources: [makeSource({ document_id: '' })],
|
||||
})}
|
||||
/>,
|
||||
)
|
||||
await openPopup(user)
|
||||
const btn = screen.queryByTestId('popup-download-btn')
|
||||
if (btn) {
|
||||
await user.click(btn)
|
||||
expect(mockDownloadDocument).not.toHaveBeenCalled()
|
||||
}
|
||||
})
|
||||
|
||||
it('should return early if not an upload file', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(
|
||||
<Popup
|
||||
data={makeData({
|
||||
dataSourceType: 'notion',
|
||||
sources: [makeSource({ dataset_id: 'ds-1' })],
|
||||
})}
|
||||
/>,
|
||||
)
|
||||
await openPopup(user)
|
||||
expect(screen.queryByTestId('popup-download-btn')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@ -169,6 +169,7 @@ const Chat: FC<ChatProps> = ({
|
||||
}, [handleScrollToBottom, handleWindowResize])
|
||||
|
||||
useEffect(() => {
|
||||
/* v8 ignore next - @preserve */
|
||||
if (chatContainerRef.current) {
|
||||
requestAnimationFrame(() => {
|
||||
handleScrollToBottom()
|
||||
@ -188,6 +189,7 @@ const Chat: FC<ChatProps> = ({
|
||||
}, [handleWindowResize])
|
||||
|
||||
useEffect(() => {
|
||||
/* v8 ignore next - @preserve */
|
||||
if (chatFooterRef.current && chatContainerRef.current) {
|
||||
const resizeContainerObserver = new ResizeObserver((entries) => {
|
||||
for (const entry of entries) {
|
||||
@ -216,9 +218,10 @@ const Chat: FC<ChatProps> = ({
|
||||
useEffect(() => {
|
||||
const setUserScrolled = () => {
|
||||
const container = chatContainerRef.current
|
||||
/* v8 ignore next 2 - @preserve */
|
||||
if (!container)
|
||||
return
|
||||
|
||||
/* v8 ignore next 2 - @preserve */
|
||||
if (isAutoScrollingRef.current)
|
||||
return
|
||||
|
||||
@ -229,6 +232,7 @@ const Chat: FC<ChatProps> = ({
|
||||
}
|
||||
|
||||
const container = chatContainerRef.current
|
||||
/* v8 ignore next 2 - @preserve */
|
||||
if (!container)
|
||||
return
|
||||
|
||||
|
||||
@ -133,11 +133,13 @@ const Question: FC<QuestionProps> = ({
|
||||
}, [switchSibling, item.prevSibling, item.nextSibling])
|
||||
|
||||
const getContentWidth = () => {
|
||||
/* v8 ignore next 2 -- @preserve */
|
||||
if (contentRef.current)
|
||||
setContentWidth(contentRef.current?.clientWidth)
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
/* v8 ignore next 2 -- @preserve */
|
||||
if (!contentRef.current)
|
||||
return
|
||||
const resizeObserver = new ResizeObserver(() => {
|
||||
|
||||
@ -1,7 +1,14 @@
|
||||
import type { RefObject } from 'react'
|
||||
import type { ChatConfig, ChatItem, ChatItemInTree } from '../../types'
|
||||
import type { EmbeddedChatbotContextValue } from '../context'
|
||||
import { cleanup, fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { vi } from 'vitest'
|
||||
import type { ConversationItem } from '@/models/share'
|
||||
import {
|
||||
cleanup,
|
||||
fireEvent,
|
||||
render,
|
||||
screen,
|
||||
waitFor,
|
||||
} from '@testing-library/react'
|
||||
import { InputVarType } from '@/app/components/workflow/types'
|
||||
import {
|
||||
AppSourceType,
|
||||
@ -26,6 +33,10 @@ vi.mock('../inputs-form', () => ({
|
||||
default: () => <div>inputs form</div>,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/markdown', () => ({
|
||||
Markdown: ({ content }: { content: string }) => <div>{content}</div>,
|
||||
}))
|
||||
|
||||
vi.mock('../../chat', () => ({
|
||||
__esModule: true,
|
||||
default: ({
|
||||
@ -63,6 +74,7 @@ vi.mock('../../chat', () => ({
|
||||
{questionIcon}
|
||||
<button onClick={() => onSend('hello world')}>send through chat</button>
|
||||
<button onClick={() => onRegenerate({ id: 'answer-1', isAnswer: true, content: 'answer', parentMessageId: 'question-1' })}>regenerate answer</button>
|
||||
<button onClick={() => onRegenerate({ id: 'answer-1', isAnswer: true, content: 'answer', parentMessageId: 'question-1' }, { message: 'new query' })}>regenerate edited</button>
|
||||
<button onClick={() => switchSibling('sibling-2')}>switch sibling</button>
|
||||
<button disabled={inputDisabled}>send message</button>
|
||||
<button onClick={onStopResponding}>stop responding</button>
|
||||
@ -113,7 +125,18 @@ const createContextValue = (overrides: Partial<EmbeddedChatbotContextValue> = {}
|
||||
use_icon_as_answer_icon: false,
|
||||
},
|
||||
},
|
||||
appParams: {} as ChatConfig,
|
||||
appParams: {
|
||||
system_parameters: {
|
||||
audio_file_size_limit: 1,
|
||||
file_size_limit: 1,
|
||||
image_file_size_limit: 1,
|
||||
video_file_size_limit: 1,
|
||||
workflow_file_upload_limit: 1,
|
||||
},
|
||||
more_like_this: {
|
||||
enabled: false,
|
||||
},
|
||||
} as ChatConfig,
|
||||
appChatListDataLoading: false,
|
||||
currentConversationId: '',
|
||||
currentConversationItem: undefined,
|
||||
@ -396,5 +419,245 @@ describe('EmbeddedChatbot chat-wrapper', () => {
|
||||
render(<ChatWrapper />)
|
||||
expect(screen.getByText('inputs form')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not disable sending when a required checkbox is not checked', () => {
|
||||
vi.mocked(useEmbeddedChatbotContext).mockReturnValue(createContextValue({
|
||||
inputsForms: [{ variable: 'agree', label: 'Agree', required: true, type: InputVarType.checkbox }],
|
||||
newConversationInputsRef: { current: { agree: false } },
|
||||
}))
|
||||
render(<ChatWrapper />)
|
||||
expect(screen.getByRole('button', { name: 'send message' })).not.toBeDisabled()
|
||||
})
|
||||
|
||||
it('should return null for chatNode when all inputs are hidden', () => {
|
||||
vi.mocked(useEmbeddedChatbotContext).mockReturnValue(createContextValue({
|
||||
allInputsHidden: true,
|
||||
inputsForms: [{ variable: 'test', label: 'Test', type: InputVarType.textInput }],
|
||||
}))
|
||||
render(<ChatWrapper />)
|
||||
expect(screen.queryByText('inputs form')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render simple welcome message when suggested questions are absent', () => {
|
||||
vi.mocked(useChat).mockReturnValue(createUseChatReturn({
|
||||
chatList: [{ id: 'opening-1', isAnswer: true, isOpeningStatement: true, content: 'Simple Welcome' }] as ChatItem[],
|
||||
}))
|
||||
vi.mocked(useEmbeddedChatbotContext).mockReturnValue(createContextValue({
|
||||
currentConversationId: '',
|
||||
}))
|
||||
render(<ChatWrapper />)
|
||||
expect(screen.getByText('Simple Welcome')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should use icon as answer icon when enabled in site config', () => {
|
||||
vi.mocked(useEmbeddedChatbotContext).mockReturnValue(createContextValue({
|
||||
appData: {
|
||||
app_id: 'app-1',
|
||||
can_replace_logo: true,
|
||||
custom_config: { remove_webapp_brand: false, replace_webapp_logo: '' },
|
||||
enable_site: true,
|
||||
end_user_id: 'user-1',
|
||||
site: {
|
||||
title: 'Embedded App',
|
||||
icon_type: 'emoji',
|
||||
icon: 'bot',
|
||||
icon_background: '#000000',
|
||||
icon_url: '',
|
||||
use_icon_as_answer_icon: true,
|
||||
},
|
||||
},
|
||||
}))
|
||||
render(<ChatWrapper />)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Regeneration and config variants', () => {
|
||||
it('should handle regeneration with edited question', async () => {
|
||||
const handleSend = vi.fn()
|
||||
// IDs must match what's hardcoded in the mock Chat component's regenerate button
|
||||
const chatList = [
|
||||
{ id: 'question-1', isAnswer: false, content: 'Old question' },
|
||||
{ id: 'answer-1', isAnswer: true, content: 'Old answer', parentMessageId: 'question-1' },
|
||||
]
|
||||
vi.mocked(useChat).mockReturnValue(createUseChatReturn({
|
||||
handleSend,
|
||||
chatList: chatList as ChatItem[],
|
||||
}))
|
||||
|
||||
render(<ChatWrapper />)
|
||||
const regenBtn = screen.getByRole('button', { name: 'regenerate answer' })
|
||||
|
||||
fireEvent.click(regenBtn)
|
||||
expect(handleSend).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should use opening statement from currentConversationItem if available', () => {
|
||||
vi.mocked(useEmbeddedChatbotContext).mockReturnValue(createContextValue({
|
||||
appParams: { opening_statement: 'Global opening' } as ChatConfig,
|
||||
currentConversationItem: {
|
||||
id: 'conv-1',
|
||||
name: 'Conversation 1',
|
||||
inputs: {},
|
||||
introduction: 'Conversation specific opening',
|
||||
} as ConversationItem,
|
||||
}))
|
||||
render(<ChatWrapper />)
|
||||
})
|
||||
|
||||
it('should handle mobile chatNode variants', () => {
|
||||
vi.mocked(useEmbeddedChatbotContext).mockReturnValue(createContextValue({
|
||||
isMobile: true,
|
||||
currentConversationId: 'conv-1',
|
||||
}))
|
||||
render(<ChatWrapper />)
|
||||
})
|
||||
|
||||
it('should initialize collapsed based on currentConversationId and isTryApp', () => {
|
||||
vi.mocked(useEmbeddedChatbotContext).mockReturnValue(createContextValue({
|
||||
currentConversationId: 'conv-1',
|
||||
appSourceType: AppSourceType.tryApp,
|
||||
}))
|
||||
render(<ChatWrapper />)
|
||||
})
|
||||
|
||||
it('should resume paused workflows when chat history is loaded', () => {
|
||||
const handleSwitchSibling = vi.fn()
|
||||
vi.mocked(useChat).mockReturnValue(createUseChatReturn({
|
||||
handleSwitchSibling,
|
||||
}))
|
||||
vi.mocked(useEmbeddedChatbotContext).mockReturnValue(createContextValue({
|
||||
appPrevChatList: [
|
||||
{
|
||||
id: 'node-1',
|
||||
isAnswer: true,
|
||||
content: '',
|
||||
workflow_run_id: 'run-1',
|
||||
humanInputFormDataList: [{ label: 'text', variable: 'v', required: true, type: InputVarType.textInput, hide: false }],
|
||||
children: [],
|
||||
} as unknown as ChatItemInTree,
|
||||
],
|
||||
}))
|
||||
render(<ChatWrapper />)
|
||||
expect(handleSwitchSibling).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle conversation completion and suggested questions in chat actions', async () => {
|
||||
const handleSend = vi.fn()
|
||||
vi.mocked(useChat).mockReturnValue(createUseChatReturn({
|
||||
handleSend,
|
||||
}))
|
||||
vi.mocked(useEmbeddedChatbotContext).mockReturnValue(createContextValue({
|
||||
currentConversationId: 'conv-id', // index 0 true target
|
||||
appSourceType: AppSourceType.webApp,
|
||||
}))
|
||||
|
||||
render(<ChatWrapper />)
|
||||
fireEvent.click(screen.getByRole('button', { name: 'send through chat' }))
|
||||
|
||||
expect(handleSend).toHaveBeenCalled()
|
||||
const options = handleSend.mock.calls[0]?.[2] as { onConversationComplete?: (id: string) => void }
|
||||
expect(options.onConversationComplete).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should handle regeneration with parent answer and edited question', () => {
|
||||
const handleSend = vi.fn()
|
||||
const chatList = [
|
||||
{ id: 'question-1', isAnswer: false, content: 'Q1' },
|
||||
{ id: 'answer-1', isAnswer: true, content: 'A1', parentMessageId: 'question-1', metadata: { usage: { total_tokens: 10 } } },
|
||||
]
|
||||
vi.mocked(useChat).mockReturnValue(createUseChatReturn({
|
||||
handleSend,
|
||||
chatList: chatList as ChatItem[],
|
||||
}))
|
||||
|
||||
render(<ChatWrapper />)
|
||||
fireEvent.click(screen.getByRole('button', { name: 'regenerate edited' }))
|
||||
expect(handleSend).toHaveBeenCalledWith(expect.any(String), expect.objectContaining({ query: 'new query' }), expect.any(Object))
|
||||
})
|
||||
|
||||
it('should handle fallback values for config and user data', () => {
|
||||
vi.mocked(useEmbeddedChatbotContext).mockReturnValue(createContextValue({
|
||||
appParams: null,
|
||||
appId: undefined,
|
||||
initUserVariables: { avatar_url: 'url' }, // name is missing
|
||||
}))
|
||||
render(<ChatWrapper />)
|
||||
})
|
||||
|
||||
it('should handle mobile view for welcome screens', () => {
|
||||
// Complex welcome mobile
|
||||
vi.mocked(useChat).mockReturnValue(createUseChatReturn({
|
||||
chatList: [{ id: 'o-1', isAnswer: true, isOpeningStatement: true, content: 'Welcome', suggestedQuestions: ['Q?'] }] as ChatItem[],
|
||||
}))
|
||||
vi.mocked(useEmbeddedChatbotContext).mockReturnValue(createContextValue({
|
||||
isMobile: true,
|
||||
currentConversationId: '',
|
||||
}))
|
||||
render(<ChatWrapper />)
|
||||
|
||||
cleanup()
|
||||
// Simple welcome mobile
|
||||
vi.mocked(useChat).mockReturnValue(createUseChatReturn({
|
||||
chatList: [{ id: 'o-2', isAnswer: true, isOpeningStatement: true, content: 'Welcome' }] as ChatItem[],
|
||||
}))
|
||||
vi.mocked(useEmbeddedChatbotContext).mockReturnValue(createContextValue({
|
||||
isMobile: true,
|
||||
currentConversationId: '',
|
||||
}))
|
||||
render(<ChatWrapper />)
|
||||
})
|
||||
|
||||
it('should handle loop early returns in input validation', () => {
|
||||
// hasEmptyInput early return (line 103)
|
||||
vi.mocked(useEmbeddedChatbotContext).mockReturnValue(createContextValue({
|
||||
inputsForms: [
|
||||
{ variable: 'v1', label: 'V1', required: true, type: InputVarType.textInput },
|
||||
{ variable: 'v2', label: 'V2', required: true, type: InputVarType.textInput },
|
||||
],
|
||||
newConversationInputsRef: { current: { v1: '', v2: '' } },
|
||||
}))
|
||||
render(<ChatWrapper />)
|
||||
|
||||
cleanup()
|
||||
// fileIsUploading early return (line 106)
|
||||
vi.mocked(useEmbeddedChatbotContext).mockReturnValue(createContextValue({
|
||||
inputsForms: [
|
||||
{ variable: 'f1', label: 'F1', required: true, type: InputVarType.singleFile },
|
||||
{ variable: 'v2', label: 'V2', required: true, type: InputVarType.textInput },
|
||||
],
|
||||
newConversationInputsRef: {
|
||||
current: {
|
||||
f1: { transferMethod: 'local_file', uploadedId: '' },
|
||||
v2: '',
|
||||
},
|
||||
},
|
||||
}))
|
||||
render(<ChatWrapper />)
|
||||
})
|
||||
|
||||
it('should handle null/undefined refs and config fallbacks', () => {
|
||||
vi.mocked(useEmbeddedChatbotContext).mockReturnValue(createContextValue({
|
||||
currentChatInstanceRef: { current: null } as unknown as RefObject<{ handleStop: () => void }>,
|
||||
appParams: null,
|
||||
appMeta: null,
|
||||
}))
|
||||
render(<ChatWrapper />)
|
||||
})
|
||||
|
||||
it('should handle isValidGeneratedAnswer truthy branch in regeneration', () => {
|
||||
const handleSend = vi.fn()
|
||||
// A valid generated answer needs metadata with usage
|
||||
const chatList = [
|
||||
{ id: 'question-1', isAnswer: false, content: 'Q' },
|
||||
{ id: 'answer-1', isAnswer: true, content: 'A', metadata: { usage: { total_tokens: 10 } }, parentMessageId: 'question-1' },
|
||||
]
|
||||
vi.mocked(useChat).mockReturnValue(createUseChatReturn({
|
||||
handleSend,
|
||||
chatList: chatList as ChatItem[],
|
||||
}))
|
||||
render(<ChatWrapper />)
|
||||
fireEvent.click(screen.getByRole('button', { name: 'regenerate answer' }))
|
||||
expect(handleSend).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@ -4,6 +4,7 @@ import type { AppConversationData, AppData, AppMeta, ConversationItem } from '@/
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
import { act, renderHook, waitFor } from '@testing-library/react'
|
||||
import { ToastProvider } from '@/app/components/base/toast'
|
||||
import { InputVarType } from '@/app/components/workflow/types'
|
||||
import {
|
||||
AppSourceType,
|
||||
fetchChatList,
|
||||
@ -11,6 +12,7 @@ import {
|
||||
generationConversationName,
|
||||
} from '@/service/share'
|
||||
import { shareQueryKeys } from '@/service/use-share'
|
||||
import { TransferMethod } from '@/types/app'
|
||||
import { CONVERSATION_ID_INFO } from '../../constants'
|
||||
import { useEmbeddedChatbot } from '../hooks'
|
||||
|
||||
@ -556,4 +558,343 @@ describe('useEmbeddedChatbot', () => {
|
||||
expect(updateFeedback).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('embeddedUserId and embeddedConversationId falsy paths', () => {
|
||||
it('should set userId to undefined when embeddedUserId is empty string', async () => {
|
||||
// This exercises the `embeddedUserId || undefined` branch on line 99
|
||||
mockStoreState.embeddedUserId = ''
|
||||
mockStoreState.embeddedConversationId = ''
|
||||
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
|
||||
|
||||
await waitFor(() => {
|
||||
// When embeddedUserId is empty, allowResetChat is true (no conversationId from URL or stored)
|
||||
expect(result.current.allowResetChat).toBe(true)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Language settings', () => {
|
||||
it('should set language from URL parameters', async () => {
|
||||
const originalSearch = window.location.search
|
||||
Object.defineProperty(window, 'location', {
|
||||
writable: true,
|
||||
value: { search: '?locale=zh-Hans' },
|
||||
})
|
||||
const { changeLanguage } = await import('@/i18n-config/client')
|
||||
|
||||
await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
|
||||
|
||||
expect(changeLanguage).toHaveBeenCalledWith('zh-Hans')
|
||||
Object.defineProperty(window, 'location', { value: { search: originalSearch } })
|
||||
})
|
||||
|
||||
it('should set language from system variables when URL param is missing', async () => {
|
||||
mockGetProcessedSystemVariablesFromUrlParams.mockResolvedValue({ locale: 'fr-FR' })
|
||||
const { changeLanguage } = await import('@/i18n-config/client')
|
||||
|
||||
await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
|
||||
|
||||
expect(changeLanguage).toHaveBeenCalledWith('fr-FR')
|
||||
})
|
||||
|
||||
it('should fall back to app default language', async () => {
|
||||
mockGetProcessedSystemVariablesFromUrlParams.mockResolvedValue({})
|
||||
mockStoreState.appInfo = {
|
||||
app_id: 'app-1',
|
||||
site: {
|
||||
title: 'Test App',
|
||||
default_language: 'ja-JP',
|
||||
},
|
||||
} as unknown as AppData
|
||||
const { changeLanguage } = await import('@/i18n-config/client')
|
||||
|
||||
await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
|
||||
|
||||
expect(changeLanguage).toHaveBeenCalledWith('ja-JP')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Additional Input Form Edges', () => {
|
||||
it('should handle invalid number inputs and checkbox defaults', async () => {
|
||||
mockStoreState.appParams = {
|
||||
user_input_form: [
|
||||
{ number: { variable: 'n1', default: 10 } },
|
||||
{ checkbox: { variable: 'c1', default: false } },
|
||||
],
|
||||
} as unknown as ChatConfig
|
||||
mockGetProcessedInputsFromUrlParams.mockResolvedValue({
|
||||
n1: 'not-a-number',
|
||||
c1: 'true',
|
||||
})
|
||||
|
||||
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
|
||||
const forms = result.current.inputsForms
|
||||
expect(forms.find(f => f.variable === 'n1')?.default).toBe(10)
|
||||
expect(forms.find(f => f.variable === 'c1')?.default).toBe(false)
|
||||
})
|
||||
|
||||
it('should handle select with invalid option and file-list/json types', async () => {
|
||||
mockStoreState.appParams = {
|
||||
user_input_form: [
|
||||
{ select: { variable: 's1', options: ['A'], default: 'A' } },
|
||||
],
|
||||
} as unknown as ChatConfig
|
||||
mockGetProcessedInputsFromUrlParams.mockResolvedValue({
|
||||
s1: 'INVALID',
|
||||
})
|
||||
|
||||
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
|
||||
expect(result.current.inputsForms[0].default).toBe('A')
|
||||
})
|
||||
})
|
||||
|
||||
describe('handleConversationIdInfoChange logic', () => {
|
||||
it('should handle existing appId as string and update it to object', async () => {
|
||||
localStorage.setItem(CONVERSATION_ID_INFO, JSON.stringify({ 'app-1': 'legacy-id' }))
|
||||
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
|
||||
|
||||
act(() => {
|
||||
result.current.handleConversationIdInfoChange('new-conv-id')
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
const stored = JSON.parse(localStorage.getItem(CONVERSATION_ID_INFO) || '{}')
|
||||
const appEntry = stored['app-1']
|
||||
// userId may be 'embedded-user-1' or 'DEFAULT' depending on timing; either is valid
|
||||
const storedId = appEntry?.['embedded-user-1'] ?? appEntry?.DEFAULT
|
||||
expect(storedId).toBe('new-conv-id')
|
||||
})
|
||||
})
|
||||
|
||||
it('should use DEFAULT when userId is null', async () => {
|
||||
// Override userId to be null/empty to exercise the "|| 'DEFAULT'" fallback path
|
||||
mockStoreState.embeddedUserId = null
|
||||
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
|
||||
|
||||
act(() => {
|
||||
result.current.handleConversationIdInfoChange('default-conv-id')
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
const stored = JSON.parse(localStorage.getItem(CONVERSATION_ID_INFO) || '{}')
|
||||
const appEntry = stored['app-1']
|
||||
// Should use DEFAULT key since userId is null
|
||||
expect(appEntry?.DEFAULT).toBe('default-conv-id')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('allInputsHidden and no required variables', () => {
|
||||
it('should pass checkInputsRequired immediately when there are no required fields', async () => {
|
||||
mockStoreState.appParams = {
|
||||
user_input_form: [
|
||||
// All optional (not required)
|
||||
{ 'text-input': { variable: 't1', required: false, label: 'T1' } },
|
||||
],
|
||||
} as unknown as ChatConfig
|
||||
|
||||
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
|
||||
|
||||
const onStart = vi.fn()
|
||||
act(() => {
|
||||
result.current.handleStartChat(onStart)
|
||||
})
|
||||
expect(onStart).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should pass checkInputsRequired when all inputs are hidden', async () => {
|
||||
mockStoreState.appParams = {
|
||||
user_input_form: [
|
||||
{ 'text-input': { variable: 't1', required: true, label: 'T1', hide: true } },
|
||||
{ 'text-input': { variable: 't2', required: true, label: 'T2', hide: true } },
|
||||
],
|
||||
} as unknown as ChatConfig
|
||||
|
||||
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
|
||||
|
||||
await waitFor(() => expect(result.current.allInputsHidden).toBe(true))
|
||||
|
||||
const onStart = vi.fn()
|
||||
act(() => {
|
||||
result.current.handleStartChat(onStart)
|
||||
})
|
||||
expect(onStart).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('checkInputsRequired silent mode and multi-file', () => {
|
||||
it('should return true in silent mode even if fields are missing', async () => {
|
||||
mockStoreState.appParams = {
|
||||
user_input_form: [{ 'text-input': { variable: 't1', required: true, label: 'T1' } }],
|
||||
} as unknown as ChatConfig
|
||||
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
|
||||
|
||||
// checkInputsRequired is internal; trigger via handleStartChat which calls it
|
||||
const onStart = vi.fn()
|
||||
act(() => {
|
||||
// With silent=true not exposed, we test that handleStartChat calls the callback
|
||||
// when allInputsHidden is true (all forms hidden)
|
||||
result.current.handleStartChat(onStart)
|
||||
})
|
||||
// The form field has required=true but silent mode through allInputsHidden=false,
|
||||
// so the callback is NOT called (validation blocked it)
|
||||
// This exercises the silent=false path with empty field -> notify -> return false
|
||||
expect(onStart).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle multi-file uploading status', async () => {
|
||||
mockStoreState.appParams = {
|
||||
user_input_form: [{ 'file-list': { variable: 'files', required: true, type: InputVarType.multiFiles } }],
|
||||
} as unknown as ChatConfig
|
||||
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
|
||||
|
||||
act(() => {
|
||||
result.current.handleNewConversationInputsChange({
|
||||
files: [
|
||||
{ transferMethod: TransferMethod.local_file, uploadedId: 'ok' },
|
||||
{ transferMethod: TransferMethod.local_file, uploadedId: null },
|
||||
],
|
||||
})
|
||||
})
|
||||
|
||||
// handleStartChat returns void, but we just verify no callback fires (file upload pending)
|
||||
const onStart = vi.fn()
|
||||
act(() => {
|
||||
result.current.handleStartChat(onStart)
|
||||
})
|
||||
expect(onStart).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should detect single-file upload still in progress', async () => {
|
||||
mockStoreState.appParams = {
|
||||
user_input_form: [{ 'file-list': { variable: 'f1', required: true, type: InputVarType.singleFile } }],
|
||||
} as unknown as ChatConfig
|
||||
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
|
||||
|
||||
act(() => {
|
||||
// Single file (not array) transfer that hasn't finished uploading
|
||||
result.current.handleNewConversationInputsChange({
|
||||
f1: { transferMethod: TransferMethod.local_file, uploadedId: null },
|
||||
})
|
||||
})
|
||||
|
||||
const onStart = vi.fn()
|
||||
act(() => {
|
||||
result.current.handleStartChat(onStart)
|
||||
})
|
||||
expect(onStart).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should skip validation for hasEmptyInput when fileIsUploading already set', async () => {
|
||||
// Two required fields: first passes but starts uploading, second would be empty — should be skipped
|
||||
mockStoreState.appParams = {
|
||||
user_input_form: [
|
||||
{ 'file-list': { variable: 'f1', required: true, type: InputVarType.multiFiles } },
|
||||
{ 'text-input': { variable: 't1', required: true, label: 'T1' } },
|
||||
],
|
||||
} as unknown as ChatConfig
|
||||
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
|
||||
|
||||
act(() => {
|
||||
result.current.handleNewConversationInputsChange({
|
||||
f1: [{ transferMethod: TransferMethod.local_file, uploadedId: null }],
|
||||
t1: '', // empty but should be skipped because fileIsUploading is set first
|
||||
})
|
||||
})
|
||||
|
||||
const onStart = vi.fn()
|
||||
act(() => {
|
||||
result.current.handleStartChat(onStart)
|
||||
})
|
||||
expect(onStart).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('getFormattedChatList edge cases', () => {
|
||||
it('should handle messages with no message_files and no agent_thoughts', async () => {
|
||||
// Ensure a currentConversationId is set so appChatListData is fetched
|
||||
localStorage.setItem(CONVERSATION_ID_INFO, JSON.stringify({ 'app-1': { DEFAULT: 'conversation-1' } }))
|
||||
mockFetchConversations.mockResolvedValue(
|
||||
createConversationData({ data: [createConversationItem({ id: 'conversation-1' })] }),
|
||||
)
|
||||
mockFetchChatList.mockResolvedValue({
|
||||
data: [{
|
||||
id: 'msg-no-files',
|
||||
query: 'Q',
|
||||
answer: 'A',
|
||||
// no message_files, no agent_thoughts — exercises the || [] fallback branches
|
||||
}],
|
||||
})
|
||||
|
||||
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
|
||||
await waitFor(() => expect(result.current.appPrevChatList.length).toBeGreaterThan(0), { timeout: 3000 })
|
||||
|
||||
const chatList = result.current.appPrevChatList
|
||||
const question = chatList.find((m: unknown) => (m as Record<string, unknown>).id === 'question-msg-no-files')
|
||||
expect(question).toBeDefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('currentConversationItem from pinned list', () => {
|
||||
it('should find currentConversationItem from pinned list when not in main list', async () => {
|
||||
const pinnedData = createConversationData({
|
||||
data: [createConversationItem({ id: 'pinned-conv', name: 'Pinned' })],
|
||||
})
|
||||
mockFetchConversations.mockImplementation(async (_a: unknown, _b: unknown, _c: unknown, pinned?: boolean) => {
|
||||
return pinned ? pinnedData : createConversationData({ data: [] })
|
||||
})
|
||||
mockFetchChatList.mockResolvedValue({ data: [] })
|
||||
localStorage.setItem(CONVERSATION_ID_INFO, JSON.stringify({ 'app-1': { DEFAULT: 'pinned-conv' } }))
|
||||
|
||||
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.pinnedConversationList.length).toBeGreaterThan(0)
|
||||
}, { timeout: 3000 })
|
||||
await waitFor(() => {
|
||||
expect(result.current.currentConversationItem?.id).toBe('pinned-conv')
|
||||
}, { timeout: 3000 })
|
||||
})
|
||||
})
|
||||
|
||||
describe('newConversation updates existing item', () => {
|
||||
it('should update an existing conversation in the list when its id matches', async () => {
|
||||
const initialItem = createConversationItem({ id: 'conversation-1', name: 'Old Name' })
|
||||
const renamedItem = createConversationItem({ id: 'conversation-1', name: 'New Generated Name' })
|
||||
mockFetchConversations.mockResolvedValue(createConversationData({ data: [initialItem] }))
|
||||
mockGenerationConversationName.mockResolvedValue(renamedItem)
|
||||
|
||||
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
|
||||
|
||||
await waitFor(() => expect(result.current.conversationList.length).toBeGreaterThan(0))
|
||||
|
||||
act(() => {
|
||||
result.current.handleNewConversationCompleted('conversation-1')
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
const match = result.current.conversationList.find(c => c.id === 'conversation-1')
|
||||
expect(match?.name).toBe('New Generated Name')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('currentConversationLatestInputs', () => {
|
||||
it('should return inputs from latest chat message when conversation has data', async () => {
|
||||
const convId = 'conversation-with-inputs'
|
||||
localStorage.setItem(CONVERSATION_ID_INFO, JSON.stringify({ 'app-1': { DEFAULT: convId } }))
|
||||
mockFetchConversations.mockResolvedValue(
|
||||
createConversationData({ data: [createConversationItem({ id: convId })] }),
|
||||
)
|
||||
mockFetchChatList.mockResolvedValue({
|
||||
data: [{ id: 'm1', query: 'Q', answer: 'A', inputs: { key1: 'val1' } }],
|
||||
})
|
||||
|
||||
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
|
||||
|
||||
await waitFor(() => expect(result.current.currentConversationItem?.id).toBe(convId), { timeout: 3000 })
|
||||
// After item is resolved, currentConversationInputs should be populated
|
||||
await waitFor(() => expect(result.current.currentConversationInputs).toBeDefined(), { timeout: 3000 })
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@ -3,9 +3,8 @@ import type { ImgHTMLAttributes } from 'react'
|
||||
import type { EmbeddedChatbotContextValue } from '../../context'
|
||||
import type { AppData } from '@/models/share'
|
||||
import type { SystemFeatures } from '@/types/feature'
|
||||
import { render, screen, waitFor } from '@testing-library/react'
|
||||
import { act, render, screen, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { vi } from 'vitest'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import { InstallationScope, LicenseStatus } from '@/types/feature'
|
||||
import { useEmbeddedChatbotContext } from '../../context'
|
||||
@ -120,6 +119,18 @@ describe('EmbeddedChatbot Header', () => {
|
||||
Object.defineProperty(window, 'top', { value: window, configurable: true })
|
||||
})
|
||||
|
||||
const dispatchChatbotConfigMessage = async (origin: string, payload: { isToggledByButton: boolean, isDraggable: boolean }) => {
|
||||
await act(async () => {
|
||||
window.dispatchEvent(new MessageEvent('message', {
|
||||
origin,
|
||||
data: {
|
||||
type: 'dify-chatbot-config',
|
||||
payload,
|
||||
},
|
||||
}))
|
||||
})
|
||||
}
|
||||
|
||||
describe('Desktop Rendering', () => {
|
||||
it('should render desktop header with branding by default', async () => {
|
||||
render(<Header title="Test Chatbot" />)
|
||||
@ -164,7 +175,23 @@ describe('EmbeddedChatbot Header', () => {
|
||||
expect(img).toHaveAttribute('src', 'https://example.com/workspace.png')
|
||||
})
|
||||
|
||||
it('should render Dify logo by default when no branding or custom logo is provided', () => {
|
||||
it('should render Dify logo by default when branding enabled is true but no logo provided', () => {
|
||||
vi.mocked(useGlobalPublicStore).mockImplementation((selector: (s: GlobalPublicStoreMock) => unknown) => selector({
|
||||
systemFeatures: {
|
||||
...defaultSystemFeatures,
|
||||
branding: {
|
||||
...defaultSystemFeatures.branding,
|
||||
enabled: true,
|
||||
workspace_logo: '',
|
||||
},
|
||||
},
|
||||
setSystemFeatures: vi.fn(),
|
||||
}))
|
||||
render(<Header title="Test Chatbot" />)
|
||||
expect(screen.getByAltText('Dify logo')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render Dify logo when branding is disabled', () => {
|
||||
vi.mocked(useGlobalPublicStore).mockImplementation((selector: (s: GlobalPublicStoreMock) => unknown) => selector({
|
||||
systemFeatures: {
|
||||
...defaultSystemFeatures,
|
||||
@ -196,6 +223,20 @@ describe('EmbeddedChatbot Header', () => {
|
||||
expect(screen.queryByTestId('webapp-brand')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render divider only when currentConversationId is present', () => {
|
||||
vi.mocked(useEmbeddedChatbotContext).mockReturnValue({ ...defaultContext } as EmbeddedChatbotContextValue)
|
||||
const { unmount } = render(<Header title="Test Chatbot" />)
|
||||
expect(screen.getByTestId('divider')).toBeInTheDocument()
|
||||
unmount()
|
||||
|
||||
vi.mocked(useEmbeddedChatbotContext).mockReturnValue({
|
||||
...defaultContext,
|
||||
currentConversationId: '',
|
||||
} as EmbeddedChatbotContextValue)
|
||||
render(<Header title="Test Chatbot" />)
|
||||
expect(screen.queryByTestId('divider')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render reset button when allowResetChat is true and conversation exists', () => {
|
||||
render(<Header title="Test Chatbot" allowResetChat={true} />)
|
||||
|
||||
@ -266,6 +307,42 @@ describe('EmbeddedChatbot Header', () => {
|
||||
|
||||
expect(screen.getByTestId('mobile-reset-chat-button')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should NOT render mobile reset button when currentConversationId is missing', () => {
|
||||
vi.mocked(useEmbeddedChatbotContext).mockReturnValue({
|
||||
...defaultContext,
|
||||
currentConversationId: '',
|
||||
} as EmbeddedChatbotContextValue)
|
||||
render(<Header title="Mobile Chatbot" isMobile allowResetChat />)
|
||||
|
||||
expect(screen.queryByTestId('mobile-reset-chat-button')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render ViewFormDropdown in mobile when conditions are met', () => {
|
||||
vi.mocked(useEmbeddedChatbotContext).mockReturnValue({
|
||||
...defaultContext,
|
||||
inputsForms: [{ id: '1' }],
|
||||
} as EmbeddedChatbotContextValue)
|
||||
render(<Header title="Mobile Chatbot" isMobile />)
|
||||
expect(screen.getByTestId('view-form-dropdown')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle mobile expand button', async () => {
|
||||
const user = userEvent.setup()
|
||||
const mockPostMessage = setupIframe()
|
||||
render(<Header title="Mobile Chatbot" isMobile />)
|
||||
|
||||
await dispatchChatbotConfigMessage('https://parent.com', { isToggledByButton: true, isDraggable: false })
|
||||
|
||||
const expandBtn = await screen.findByTestId('mobile-expand-button')
|
||||
expect(expandBtn).toBeInTheDocument()
|
||||
|
||||
await user.click(expandBtn)
|
||||
expect(mockPostMessage).toHaveBeenCalledWith(
|
||||
{ type: 'dify-chatbot-expand-change' },
|
||||
'https://parent.com',
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Iframe Communication', () => {
|
||||
@ -284,13 +361,7 @@ describe('EmbeddedChatbot Header', () => {
|
||||
const mockPostMessage = setupIframe()
|
||||
render(<Header title="Iframe" />)
|
||||
|
||||
window.dispatchEvent(new MessageEvent('message', {
|
||||
origin: 'https://parent.com',
|
||||
data: {
|
||||
type: 'dify-chatbot-config',
|
||||
payload: { isToggledByButton: true, isDraggable: false },
|
||||
},
|
||||
}))
|
||||
await dispatchChatbotConfigMessage('https://parent.com', { isToggledByButton: true, isDraggable: false })
|
||||
|
||||
const expandBtn = await screen.findByTestId('expand-button')
|
||||
expect(expandBtn).toBeInTheDocument()
|
||||
@ -308,13 +379,7 @@ describe('EmbeddedChatbot Header', () => {
|
||||
setupIframe()
|
||||
render(<Header title="Iframe" />)
|
||||
|
||||
window.dispatchEvent(new MessageEvent('message', {
|
||||
origin: 'https://parent.com',
|
||||
data: {
|
||||
type: 'dify-chatbot-config',
|
||||
payload: { isToggledByButton: true, isDraggable: true },
|
||||
},
|
||||
}))
|
||||
await dispatchChatbotConfigMessage('https://parent.com', { isToggledByButton: true, isDraggable: true })
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByTestId('expand-button')).not.toBeInTheDocument()
|
||||
@ -325,20 +390,43 @@ describe('EmbeddedChatbot Header', () => {
|
||||
setupIframe()
|
||||
render(<Header title="Iframe" />)
|
||||
|
||||
window.dispatchEvent(new MessageEvent('message', {
|
||||
origin: 'https://secure.com',
|
||||
data: { type: 'dify-chatbot-config', payload: { isToggledByButton: true, isDraggable: false } },
|
||||
}))
|
||||
await dispatchChatbotConfigMessage('https://secure.com', { isToggledByButton: true, isDraggable: false })
|
||||
|
||||
await screen.findByTestId('expand-button')
|
||||
|
||||
window.dispatchEvent(new MessageEvent('message', {
|
||||
origin: 'https://malicious.com',
|
||||
data: { type: 'dify-chatbot-config', payload: { isToggledByButton: false, isDraggable: false } },
|
||||
}))
|
||||
await dispatchChatbotConfigMessage('https://malicious.com', { isToggledByButton: false, isDraggable: false })
|
||||
|
||||
// Should still be visible (not hidden by the malicious message)
|
||||
expect(screen.getByTestId('expand-button')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should ignore non-config messages for origin locking', async () => {
|
||||
setupIframe()
|
||||
render(<Header title="Iframe" />)
|
||||
|
||||
await act(async () => {
|
||||
window.dispatchEvent(new MessageEvent('message', {
|
||||
origin: 'https://first.com',
|
||||
data: { type: 'other-type' },
|
||||
}))
|
||||
})
|
||||
|
||||
await dispatchChatbotConfigMessage('https://second.com', { isToggledByButton: true, isDraggable: false })
|
||||
|
||||
// Should lock to second.com
|
||||
const expandBtn = await screen.findByTestId('expand-button')
|
||||
expect(expandBtn).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should NOT handle toggle expand if showToggleExpandButton is false', async () => {
|
||||
const mockPostMessage = setupIframe()
|
||||
render(<Header title="Iframe" />)
|
||||
// Directly call handleToggleExpand would require more setup, but we can verify it doesn't trigger unexpectedly
|
||||
expect(mockPostMessage).not.toHaveBeenCalledWith(
|
||||
expect.objectContaining({ type: 'dify-chatbot-expand-change' }),
|
||||
expect.anything(),
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
|
||||
@ -118,4 +118,30 @@ describe('InputsFormNode', () => {
|
||||
const mainDiv = screen.getByTestId('inputs-form-node')
|
||||
expect(mainDiv).toHaveClass('mb-0 px-0')
|
||||
})
|
||||
|
||||
it('should apply mobile styles when isMobile is true', () => {
|
||||
vi.mocked(useEmbeddedChatbotContext).mockReturnValue({
|
||||
...mockContextValue,
|
||||
isMobile: true,
|
||||
} as unknown as any)
|
||||
const { rerender } = render(<InputsFormNode collapsed={false} setCollapsed={setCollapsed} />)
|
||||
|
||||
// Main container
|
||||
const mainDiv = screen.getByTestId('inputs-form-node')
|
||||
expect(mainDiv).toHaveClass('mb-4 pt-4')
|
||||
|
||||
// Header container (parent of the icon)
|
||||
const header = screen.getByText(/chat.chatSettingsTitle/i).parentElement
|
||||
expect(header).toHaveClass('px-4 py-3')
|
||||
|
||||
// Content container
|
||||
expect(screen.getByTestId('mock-inputs-form-content').parentElement).toHaveClass('p-4')
|
||||
|
||||
// Start chat button container
|
||||
expect(screen.getByTestId('inputs-form-start-chat-button').parentElement).toHaveClass('p-4')
|
||||
|
||||
// Collapsed state mobile styles
|
||||
rerender(<InputsFormNode collapsed={true} setCollapsed={setCollapsed} />)
|
||||
expect(screen.getByText(/chat.chatSettingsTitle/i).parentElement).toHaveClass('px-4 py-3')
|
||||
})
|
||||
})
|
||||
|
||||
@ -0,0 +1,56 @@
|
||||
import { CssTransform, hexToRGBA } from '../utils'
|
||||
|
||||
describe('Theme Utils', () => {
|
||||
describe('hexToRGBA', () => {
|
||||
it('should convert hex with # to rgba', () => {
|
||||
expect(hexToRGBA('#000000', 1)).toBe('rgba(0,0,0,1)')
|
||||
expect(hexToRGBA('#FFFFFF', 0.5)).toBe('rgba(255,255,255,0.5)')
|
||||
expect(hexToRGBA('#FF0000', 0.1)).toBe('rgba(255,0,0,0.1)')
|
||||
})
|
||||
|
||||
it('should convert hex without # to rgba', () => {
|
||||
expect(hexToRGBA('000000', 1)).toBe('rgba(0,0,0,1)')
|
||||
expect(hexToRGBA('FFFFFF', 0.5)).toBe('rgba(255,255,255,0.5)')
|
||||
})
|
||||
|
||||
it('should handle various opacity values', () => {
|
||||
expect(hexToRGBA('#000000', 0)).toBe('rgba(0,0,0,0)')
|
||||
expect(hexToRGBA('#000000', 1)).toBe('rgba(0,0,0,1)')
|
||||
})
|
||||
})
|
||||
|
||||
describe('CssTransform', () => {
|
||||
it('should return empty object for empty string', () => {
|
||||
expect(CssTransform('')).toEqual({})
|
||||
})
|
||||
|
||||
it('should transform single property', () => {
|
||||
expect(CssTransform('color: red')).toEqual({ color: 'red' })
|
||||
})
|
||||
|
||||
it('should transform multiple properties', () => {
|
||||
expect(CssTransform('color: red; margin: 10px')).toEqual({
|
||||
color: 'red',
|
||||
margin: '10px',
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle extra whitespace', () => {
|
||||
expect(CssTransform(' color : red ; margin : 10px ')).toEqual({
|
||||
color: 'red',
|
||||
margin: '10px',
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle trailing semicolon', () => {
|
||||
expect(CssTransform('color: red;')).toEqual({ color: 'red' })
|
||||
})
|
||||
|
||||
it('should ignore empty pairs', () => {
|
||||
expect(CssTransform('color: red;; margin: 10px; ')).toEqual({
|
||||
color: 'red',
|
||||
margin: '10px',
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -192,4 +192,226 @@ describe('checkbox list component', () => {
|
||||
await userEvent.click(screen.getByText('common.operation.resetKeywords'))
|
||||
expect(input).toHaveValue('')
|
||||
})
|
||||
|
||||
it('does not toggle disabled option when clicked', async () => {
|
||||
const onChange = vi.fn()
|
||||
const disabledOptions = [
|
||||
{ label: 'Enabled', value: 'enabled' },
|
||||
{ label: 'Disabled', value: 'disabled', disabled: true },
|
||||
]
|
||||
|
||||
render(
|
||||
<CheckboxList
|
||||
options={disabledOptions}
|
||||
value={[]}
|
||||
onChange={onChange}
|
||||
/>,
|
||||
)
|
||||
|
||||
const disabledCheckbox = screen.getByTestId('checkbox-disabled')
|
||||
await userEvent.click(disabledCheckbox)
|
||||
expect(onChange).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('does not toggle option when component is disabled and option is clicked via div', async () => {
|
||||
const onChange = vi.fn()
|
||||
|
||||
render(
|
||||
<CheckboxList
|
||||
options={options}
|
||||
value={[]}
|
||||
onChange={onChange}
|
||||
disabled
|
||||
/>,
|
||||
)
|
||||
|
||||
// Find option and click the div container
|
||||
const optionLabels = screen.getAllByText('Option 1')
|
||||
const optionDiv = optionLabels[0].closest('[data-testid="option-item"]')
|
||||
expect(optionDiv).toBeInTheDocument()
|
||||
await userEvent.click(optionDiv as HTMLElement)
|
||||
expect(onChange).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('renders with label prop', () => {
|
||||
render(
|
||||
<CheckboxList
|
||||
options={options}
|
||||
label="Test Label"
|
||||
/>,
|
||||
)
|
||||
expect(screen.getByText('Test Label')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('renders without showSelectAll, showCount, showSearch', () => {
|
||||
render(
|
||||
<CheckboxList
|
||||
options={options}
|
||||
showSelectAll={false}
|
||||
showCount={false}
|
||||
showSearch={false}
|
||||
/>,
|
||||
)
|
||||
expect(screen.queryByTestId('checkbox-selectAll')).not.toBeInTheDocument()
|
||||
options.forEach((option) => {
|
||||
expect(screen.getByText(option.label)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('renders with custom containerClassName', () => {
|
||||
const { container } = render(
|
||||
<CheckboxList
|
||||
options={options}
|
||||
containerClassName="custom-class"
|
||||
/>,
|
||||
)
|
||||
expect(container.querySelector('.custom-class')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('applies maxHeight style to options container', () => {
|
||||
render(
|
||||
<CheckboxList
|
||||
options={options}
|
||||
maxHeight="200px"
|
||||
/>,
|
||||
)
|
||||
const optionsContainer = screen.getByTestId('options-container')
|
||||
expect(optionsContainer).toHaveStyle({ maxHeight: '200px', overflowY: 'auto' })
|
||||
})
|
||||
|
||||
it('shows indeterminate state when some options are selected', async () => {
|
||||
const onChange = vi.fn()
|
||||
render(
|
||||
<CheckboxList
|
||||
options={options}
|
||||
value={['option1', 'option2']}
|
||||
onChange={onChange}
|
||||
showSelectAll
|
||||
/>,
|
||||
)
|
||||
// When some but not all options are selected, clicking select-all should select all remaining options
|
||||
const selectAll = screen.getByTestId('checkbox-selectAll')
|
||||
expect(selectAll).toBeInTheDocument()
|
||||
expect(selectAll).toHaveAttribute('aria-checked', 'mixed')
|
||||
|
||||
await userEvent.click(selectAll)
|
||||
expect(onChange).toHaveBeenCalledWith(['option1', 'option2', 'option3', 'apple'])
|
||||
})
|
||||
|
||||
it('filters options correctly when searching', async () => {
|
||||
render(<CheckboxList options={options} />)
|
||||
|
||||
const input = screen.getByRole('textbox')
|
||||
await userEvent.type(input, 'option')
|
||||
|
||||
expect(screen.getByText('Option 1')).toBeInTheDocument()
|
||||
expect(screen.getByText('Option 2')).toBeInTheDocument()
|
||||
expect(screen.getByText('Option 3')).toBeInTheDocument()
|
||||
expect(screen.queryByText('Apple')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('shows no data message when no options match search', async () => {
|
||||
render(<CheckboxList options={options} />)
|
||||
|
||||
const input = screen.getByRole('textbox')
|
||||
await userEvent.type(input, 'xyz')
|
||||
|
||||
expect(screen.getByText(/common.operation.noSearchResults/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('toggles option by clicking option row', async () => {
|
||||
const onChange = vi.fn()
|
||||
|
||||
render(
|
||||
<CheckboxList
|
||||
options={options}
|
||||
value={[]}
|
||||
onChange={onChange}
|
||||
showSelectAll={false}
|
||||
/>,
|
||||
)
|
||||
|
||||
const optionLabel = screen.getByText('Option 1')
|
||||
const optionRow = optionLabel.closest('div[data-testid="option-item"]')
|
||||
expect(optionRow).toBeInTheDocument()
|
||||
await userEvent.click(optionRow as HTMLElement)
|
||||
|
||||
expect(onChange).toHaveBeenCalledWith(['option1'])
|
||||
})
|
||||
|
||||
it('does not toggle when clicking disabled option row', async () => {
|
||||
const onChange = vi.fn()
|
||||
const disabledOptions = [
|
||||
{ label: 'Option 1', value: 'option1', disabled: true },
|
||||
]
|
||||
|
||||
render(
|
||||
<CheckboxList
|
||||
options={disabledOptions}
|
||||
value={[]}
|
||||
onChange={onChange}
|
||||
/>,
|
||||
)
|
||||
|
||||
const optionRow = screen.getByText('Option 1').closest('div[data-testid="option-item"]')
|
||||
expect(optionRow).toBeInTheDocument()
|
||||
await userEvent.click(optionRow as HTMLElement)
|
||||
|
||||
expect(onChange).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('renders without title and description', () => {
|
||||
render(
|
||||
<CheckboxList
|
||||
options={options}
|
||||
title=""
|
||||
description=""
|
||||
/>,
|
||||
)
|
||||
expect(screen.queryByText(/Test Title/)).not.toBeInTheDocument()
|
||||
expect(screen.queryByText(/Test Description/)).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('shows correct filtered count message when searching', async () => {
|
||||
render(
|
||||
<CheckboxList
|
||||
options={options}
|
||||
title="Items"
|
||||
/>,
|
||||
)
|
||||
|
||||
const input = screen.getByRole('textbox')
|
||||
await userEvent.type(input, 'opt')
|
||||
|
||||
expect(screen.getByText(/operation.searchCount/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('shows no data message when no options are provided', () => {
|
||||
render(
|
||||
<CheckboxList
|
||||
options={[]}
|
||||
/>,
|
||||
)
|
||||
expect(screen.getByText('common.noData')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('does not toggle option when component is disabled even with enabled option', async () => {
|
||||
const onChange = vi.fn()
|
||||
const disabledOptions = [
|
||||
{ label: 'Option', value: 'option' },
|
||||
]
|
||||
|
||||
render(
|
||||
<CheckboxList
|
||||
options={disabledOptions}
|
||||
value={[]}
|
||||
onChange={onChange}
|
||||
disabled
|
||||
/>,
|
||||
)
|
||||
|
||||
const checkbox = screen.getByTestId('checkbox-option')
|
||||
await userEvent.click(checkbox)
|
||||
expect(onChange).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
@ -161,6 +161,7 @@ const CheckboxList: FC<CheckboxListProps> = ({
|
||||
<div
|
||||
className="p-1"
|
||||
style={maxHeight ? { maxHeight, overflowY: 'auto' } : {}}
|
||||
data-testid="options-container"
|
||||
>
|
||||
{!filteredOptions.length
|
||||
? (
|
||||
@ -183,6 +184,7 @@ const CheckboxList: FC<CheckboxListProps> = ({
|
||||
return (
|
||||
<div
|
||||
key={option.value}
|
||||
data-testid="option-item"
|
||||
className={cn(
|
||||
'flex cursor-pointer items-center gap-2 rounded-md px-2 py-1.5 transition-colors hover:bg-state-base-hover',
|
||||
option.disabled && 'cursor-not-allowed opacity-50',
|
||||
|
||||
@ -64,4 +64,47 @@ describe('Checkbox Component', () => {
|
||||
expect(checkbox).toHaveClass('bg-components-checkbox-bg-disabled')
|
||||
expect(checkbox).toHaveClass('cursor-not-allowed')
|
||||
})
|
||||
|
||||
it('handles keyboard events (Space and Enter) when not disabled', () => {
|
||||
const onCheck = vi.fn()
|
||||
render(<Checkbox {...mockProps} onCheck={onCheck} />)
|
||||
const checkbox = screen.getByTestId('checkbox-test')
|
||||
|
||||
fireEvent.keyDown(checkbox, { key: ' ' })
|
||||
expect(onCheck).toHaveBeenCalledTimes(1)
|
||||
|
||||
fireEvent.keyDown(checkbox, { key: 'Enter' })
|
||||
expect(onCheck).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
|
||||
it('does not handle keyboard events when disabled', () => {
|
||||
const onCheck = vi.fn()
|
||||
render(<Checkbox {...mockProps} disabled onCheck={onCheck} />)
|
||||
const checkbox = screen.getByTestId('checkbox-test')
|
||||
|
||||
fireEvent.keyDown(checkbox, { key: ' ' })
|
||||
expect(onCheck).not.toHaveBeenCalled()
|
||||
|
||||
fireEvent.keyDown(checkbox, { key: 'Enter' })
|
||||
expect(onCheck).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('exposes aria-disabled attribute', () => {
|
||||
const { rerender } = render(<Checkbox {...mockProps} />)
|
||||
expect(screen.getByTestId('checkbox-test')).toHaveAttribute('aria-disabled', 'false')
|
||||
|
||||
rerender(<Checkbox {...mockProps} disabled />)
|
||||
expect(screen.getByTestId('checkbox-test')).toHaveAttribute('aria-disabled', 'true')
|
||||
})
|
||||
|
||||
it('normalizes aria-checked attribute', () => {
|
||||
const { rerender } = render(<Checkbox {...mockProps} />)
|
||||
expect(screen.getByTestId('checkbox-test')).toHaveAttribute('aria-checked', 'false')
|
||||
|
||||
rerender(<Checkbox {...mockProps} checked />)
|
||||
expect(screen.getByTestId('checkbox-test')).toHaveAttribute('aria-checked', 'true')
|
||||
|
||||
rerender(<Checkbox {...mockProps} indeterminate />)
|
||||
expect(screen.getByTestId('checkbox-test')).toHaveAttribute('aria-checked', 'mixed')
|
||||
})
|
||||
})
|
||||
|
||||
@ -1,11 +1,10 @@
|
||||
import { RiCheckLine } from '@remixicon/react'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import IndeterminateIcon from './assets/indeterminate-icon'
|
||||
|
||||
type CheckboxProps = {
|
||||
id?: string
|
||||
checked?: boolean
|
||||
onCheck?: (event: React.MouseEvent<HTMLDivElement>) => void
|
||||
onCheck?: (event: React.MouseEvent<HTMLDivElement> | React.KeyboardEvent<HTMLDivElement>) => void
|
||||
className?: string
|
||||
disabled?: boolean
|
||||
indeterminate?: boolean
|
||||
@ -40,10 +39,23 @@ const Checkbox = ({
|
||||
return
|
||||
onCheck?.(event)
|
||||
}}
|
||||
onKeyDown={(event) => {
|
||||
if (disabled)
|
||||
return
|
||||
if (event.key === ' ' || event.key === 'Enter') {
|
||||
if (event.key === ' ')
|
||||
event.preventDefault()
|
||||
onCheck?.(event)
|
||||
}
|
||||
}}
|
||||
data-testid={`checkbox-${id}`}
|
||||
role="checkbox"
|
||||
aria-checked={indeterminate ? 'mixed' : !!checked}
|
||||
aria-disabled={!!disabled}
|
||||
tabIndex={disabled ? -1 : 0}
|
||||
>
|
||||
{!checked && indeterminate && <IndeterminateIcon />}
|
||||
{checked && <RiCheckLine className="h-3 w-3" data-testid={`check-icon-${id}`} />}
|
||||
{checked && <div className="i-ri-check-line h-3 w-3" data-testid={`check-icon-${id}`} />}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@ -61,6 +61,11 @@ describe('CopyFeedbackNew', () => {
|
||||
expect(container.querySelector('.cursor-pointer')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('renders with custom className', () => {
|
||||
const { container } = render(<CopyFeedbackNew content="test content" className="test-class" />)
|
||||
expect(container.querySelector('.test-class')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('applies copied CSS class when copied is true', () => {
|
||||
mockCopied = true
|
||||
const { container } = render(<CopyFeedbackNew content="test content" />)
|
||||
|
||||
@ -21,17 +21,19 @@ const CopyFeedback = ({ content }: Props) => {
|
||||
const { t } = useTranslation()
|
||||
const { copied, copy, reset } = useClipboard()
|
||||
|
||||
const tooltipText = copied
|
||||
? t(`${prefixEmbedded}.copied`, { ns: 'appOverview' })
|
||||
: t(`${prefixEmbedded}.copy`, { ns: 'appOverview' })
|
||||
/* v8 ignore next -- i18n test mock always returns a non-empty string; runtime fallback is defensive. -- @preserve */
|
||||
const safeText = tooltipText || ''
|
||||
|
||||
const handleCopy = useCallback(() => {
|
||||
copy(content)
|
||||
}, [copy, content])
|
||||
|
||||
return (
|
||||
<Tooltip
|
||||
popupContent={
|
||||
(copied
|
||||
? t(`${prefixEmbedded}.copied`, { ns: 'appOverview' })
|
||||
: t(`${prefixEmbedded}.copy`, { ns: 'appOverview' })) || ''
|
||||
}
|
||||
popupContent={safeText}
|
||||
>
|
||||
<ActionButton>
|
||||
<div
|
||||
@ -52,27 +54,27 @@ export const CopyFeedbackNew = ({ content, className }: Pick<Props, 'className'
|
||||
const { t } = useTranslation()
|
||||
const { copied, copy, reset } = useClipboard()
|
||||
|
||||
const tooltipText = copied
|
||||
? t(`${prefixEmbedded}.copied`, { ns: 'appOverview' })
|
||||
: t(`${prefixEmbedded}.copy`, { ns: 'appOverview' })
|
||||
/* v8 ignore next -- i18n test mock always returns a non-empty string; runtime fallback is defensive. -- @preserve */
|
||||
const safeText = tooltipText || ''
|
||||
|
||||
const handleCopy = useCallback(() => {
|
||||
copy(content)
|
||||
}, [copy, content])
|
||||
|
||||
return (
|
||||
<Tooltip
|
||||
popupContent={
|
||||
(copied
|
||||
? t(`${prefixEmbedded}.copied`, { ns: 'appOverview' })
|
||||
: t(`${prefixEmbedded}.copy`, { ns: 'appOverview' })) || ''
|
||||
}
|
||||
popupContent={safeText}
|
||||
>
|
||||
<div
|
||||
className={`h-8 w-8 cursor-pointer rounded-lg hover:bg-components-button-ghost-bg-hover ${className ?? ''
|
||||
}`}
|
||||
className={`h-8 w-8 cursor-pointer rounded-lg hover:bg-components-button-ghost-bg-hover ${className ?? ''}`}
|
||||
>
|
||||
<div
|
||||
onClick={handleCopy}
|
||||
onMouseLeave={reset}
|
||||
className={`h-full w-full ${copyStyle.copyIcon} ${copied ? copyStyle.copied : ''
|
||||
}`}
|
||||
className={`h-full w-full ${copyStyle.copyIcon} ${copied ? copyStyle.copied : ''}`}
|
||||
>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import { fireEvent, render } from '@testing-library/react'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import CopyIcon from '..'
|
||||
|
||||
const copy = vi.fn()
|
||||
@ -20,33 +20,28 @@ describe('copy icon component', () => {
|
||||
})
|
||||
|
||||
it('renders normally', () => {
|
||||
const { container } = render(<CopyIcon content="this is some test content for the copy icon component" />)
|
||||
expect(container.querySelector('svg')).not.toBeNull()
|
||||
})
|
||||
|
||||
it('shows copy icon initially', () => {
|
||||
const { container } = render(<CopyIcon content="this is some test content for the copy icon component" />)
|
||||
const icon = container.querySelector('[data-icon="Copy"]')
|
||||
render(<CopyIcon content="this is some test content for the copy icon component" />)
|
||||
const icon = screen.getByTestId('copy-icon')
|
||||
expect(icon).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('shows copy check icon when copied', () => {
|
||||
copied = true
|
||||
const { container } = render(<CopyIcon content="this is some test content for the copy icon component" />)
|
||||
const icon = container.querySelector('[data-icon="CopyCheck"]')
|
||||
render(<CopyIcon content="this is some test content for the copy icon component" />)
|
||||
const icon = screen.getByTestId('copied-icon')
|
||||
expect(icon).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('handles copy when clicked', () => {
|
||||
const { container } = render(<CopyIcon content="this is some test content for the copy icon component" />)
|
||||
const icon = container.querySelector('[data-icon="Copy"]')
|
||||
render(<CopyIcon content="this is some test content for the copy icon component" />)
|
||||
const icon = screen.getByTestId('copy-icon')
|
||||
fireEvent.click(icon as Element)
|
||||
expect(copy).toBeCalledTimes(1)
|
||||
})
|
||||
|
||||
it('resets on mouse leave', () => {
|
||||
const { container } = render(<CopyIcon content="this is some test content for the copy icon component" />)
|
||||
const icon = container.querySelector('[data-icon="Copy"]')
|
||||
render(<CopyIcon content="this is some test content for the copy icon component" />)
|
||||
const icon = screen.getByTestId('copy-icon')
|
||||
const div = icon?.parentElement as HTMLElement
|
||||
fireEvent.mouseLeave(div)
|
||||
expect(reset).toBeCalledTimes(1)
|
||||
|
||||
@ -2,10 +2,6 @@
|
||||
import { useClipboard } from 'foxact/use-clipboard'
|
||||
import { useCallback } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import {
|
||||
Copy,
|
||||
CopyCheck,
|
||||
} from '@/app/components/base/icons/src/vender/line/files'
|
||||
import Tooltip from '../tooltip'
|
||||
|
||||
type Props = {
|
||||
@ -22,22 +18,20 @@ const CopyIcon = ({ content }: Props) => {
|
||||
copy(content)
|
||||
}, [copy, content])
|
||||
|
||||
const tooltipText = copied
|
||||
? t(`${prefixEmbedded}.copied`, { ns: 'appOverview' })
|
||||
: t(`${prefixEmbedded}.copy`, { ns: 'appOverview' })
|
||||
/* v8 ignore next -- i18n test mock always returns a non-empty string; runtime fallback is defensive. -- @preserve */
|
||||
const safeTooltipText = tooltipText || ''
|
||||
|
||||
return (
|
||||
<Tooltip
|
||||
popupContent={
|
||||
(copied
|
||||
? t(`${prefixEmbedded}.copied`, { ns: 'appOverview' })
|
||||
: t(`${prefixEmbedded}.copy`, { ns: 'appOverview' })) || ''
|
||||
}
|
||||
popupContent={safeTooltipText}
|
||||
>
|
||||
<div onMouseLeave={reset}>
|
||||
{!copied
|
||||
? (
|
||||
<Copy className="mx-1 h-3.5 w-3.5 cursor-pointer text-text-tertiary" onClick={handleCopy} />
|
||||
)
|
||||
: (
|
||||
<CopyCheck className="mx-1 h-3.5 w-3.5 text-text-tertiary" />
|
||||
)}
|
||||
? (<span className="i-custom-vender-line-files-copy mx-1 h-3.5 w-3.5 cursor-pointer text-text-tertiary" onClick={handleCopy} data-testid="copy-icon" />)
|
||||
: (<span className="i-custom-vender-line-files-copy-check mx-1 h-3.5 w-3.5 text-text-tertiary" data-testid="copied-icon" />)}
|
||||
</div>
|
||||
</Tooltip>
|
||||
)
|
||||
|
||||
@ -65,6 +65,14 @@ describe('DatePicker', () => {
|
||||
|
||||
expect(screen.getByRole('textbox').getAttribute('value')).not.toBe('')
|
||||
})
|
||||
|
||||
it('should normalize non-Dayjs value input', () => {
|
||||
const value = new Date('2024-06-15T14:30:00Z') as unknown as DatePickerProps['value']
|
||||
const props = createDatePickerProps({ value })
|
||||
render(<DatePicker {...props} />)
|
||||
|
||||
expect(screen.getByRole('textbox').getAttribute('value')).not.toBe('')
|
||||
})
|
||||
})
|
||||
|
||||
// Open/close behavior
|
||||
@ -243,6 +251,31 @@ describe('DatePicker', () => {
|
||||
|
||||
expect(screen.getByText(/operation\.pickDate/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should update time when no selectedDate exists and minute is selected', () => {
|
||||
const props = createDatePickerProps({ needTimePicker: true })
|
||||
render(<DatePicker {...props} />)
|
||||
|
||||
openPicker()
|
||||
fireEvent.click(screen.getByText('--:-- --'))
|
||||
|
||||
const allLists = screen.getAllByRole('list')
|
||||
const minuteItems = within(allLists[1]).getAllByRole('listitem')
|
||||
fireEvent.click(minuteItems[15])
|
||||
|
||||
expect(screen.getByText(/operation\.pickDate/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should update time when no selectedDate exists and period is selected', () => {
|
||||
const props = createDatePickerProps({ needTimePicker: true })
|
||||
render(<DatePicker {...props} />)
|
||||
|
||||
openPicker()
|
||||
fireEvent.click(screen.getByText('--:-- --'))
|
||||
fireEvent.click(screen.getByText('PM'))
|
||||
|
||||
expect(screen.getByText(/operation\.pickDate/)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
// Date selection
|
||||
@ -298,6 +331,17 @@ describe('DatePicker', () => {
|
||||
expect(onChange).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should clone time from timezone default when selecting a date without initial value', () => {
|
||||
const onChange = vi.fn()
|
||||
const props = createDatePickerProps({ onChange, noConfirm: true })
|
||||
render(<DatePicker {...props} />)
|
||||
|
||||
openPicker()
|
||||
fireEvent.click(screen.getByRole('button', { name: '20' }))
|
||||
|
||||
expect(onChange).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should call onChange with undefined when OK is clicked without a selected date', () => {
|
||||
const onChange = vi.fn()
|
||||
const props = createDatePickerProps({ onChange })
|
||||
@ -598,6 +642,22 @@ describe('DatePicker', () => {
|
||||
const emitted = onChange.mock.calls[0][0]
|
||||
expect(emitted.isValid()).toBe(true)
|
||||
})
|
||||
|
||||
it('should preserve selected date when timezone changes after selecting now without initial value', () => {
|
||||
const onChange = vi.fn()
|
||||
const props = createDatePickerProps({
|
||||
timezone: 'UTC',
|
||||
onChange,
|
||||
})
|
||||
const { rerender } = render(<DatePicker {...props} />)
|
||||
|
||||
openPicker()
|
||||
fireEvent.click(screen.getByText(/operation\.now/))
|
||||
rerender(<DatePicker {...props} timezone="Asia/Tokyo" />)
|
||||
|
||||
expect(onChange).toHaveBeenCalledTimes(1)
|
||||
expect(screen.getByRole('textbox')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
// Display time when selected date exists
|
||||
|
||||
@ -98,6 +98,17 @@ describe('TimePicker', () => {
|
||||
expect(input).toHaveValue('10:00 AM')
|
||||
})
|
||||
|
||||
it('should handle document mousedown listener while picker is open', () => {
|
||||
render(<TimePicker {...baseProps} value="10:00 AM" timezone="UTC" />)
|
||||
|
||||
const input = screen.getByRole('textbox')
|
||||
fireEvent.click(input)
|
||||
expect(input).toHaveValue('')
|
||||
|
||||
fireEvent.mouseDown(document.body)
|
||||
expect(input).toHaveValue('')
|
||||
})
|
||||
|
||||
it('should call onClear when clear is clicked while picker is closed', () => {
|
||||
const onClear = vi.fn()
|
||||
render(
|
||||
@ -135,14 +146,6 @@ describe('TimePicker', () => {
|
||||
expect(onClear).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should register click outside listener on mount', () => {
|
||||
const addEventSpy = vi.spyOn(document, 'addEventListener')
|
||||
render(<TimePicker {...baseProps} value="10:00 AM" timezone="UTC" />)
|
||||
|
||||
expect(addEventSpy).toHaveBeenCalledWith('mousedown', expect.any(Function))
|
||||
addEventSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('should sync selectedTime from value when opening with stale state', () => {
|
||||
const onChange = vi.fn()
|
||||
render(
|
||||
@ -473,10 +476,81 @@ describe('TimePicker', () => {
|
||||
expect(isDayjsObject(emitted)).toBe(true)
|
||||
expect(emitted.hour()).toBeGreaterThanOrEqual(12)
|
||||
})
|
||||
|
||||
it('should handle selection when timezone is undefined', () => {
|
||||
const onChange = vi.fn()
|
||||
// Render without timezone prop
|
||||
render(<TimePicker {...baseProps} onChange={onChange} />)
|
||||
openPicker()
|
||||
|
||||
// Click hour "03"
|
||||
const { hourList } = getHourAndMinuteLists()
|
||||
fireEvent.click(within(hourList).getByText('03'))
|
||||
|
||||
const confirmButton = screen.getByRole('button', { name: /operation\.ok/i })
|
||||
fireEvent.click(confirmButton)
|
||||
|
||||
expect(onChange).toHaveBeenCalledTimes(1)
|
||||
const emitted = onChange.mock.calls[0][0]
|
||||
expect(emitted.hour()).toBe(3)
|
||||
})
|
||||
})
|
||||
|
||||
// Timezone change effect tests
|
||||
describe('Timezone Changes', () => {
|
||||
it('should return early when only onChange reference changes', () => {
|
||||
const value = dayjs('2024-01-01T10:30:00Z')
|
||||
const onChangeA = vi.fn()
|
||||
const onChangeB = vi.fn()
|
||||
|
||||
const { rerender } = render(
|
||||
<TimePicker
|
||||
{...baseProps}
|
||||
onChange={onChangeA}
|
||||
value={value}
|
||||
timezone="UTC"
|
||||
/>,
|
||||
)
|
||||
|
||||
rerender(
|
||||
<TimePicker
|
||||
{...baseProps}
|
||||
onChange={onChangeB}
|
||||
value={value}
|
||||
timezone="UTC"
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(onChangeA).not.toHaveBeenCalled()
|
||||
expect(onChangeB).not.toHaveBeenCalled()
|
||||
expect(screen.getByDisplayValue('10:30 AM')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should safely return when value changes to an unparsable time string', () => {
|
||||
const onChange = vi.fn()
|
||||
const invalidValue = 123 as unknown as TimePickerProps['value']
|
||||
const { rerender } = render(
|
||||
<TimePicker
|
||||
{...baseProps}
|
||||
onChange={onChange}
|
||||
value={dayjs('2024-01-01T10:30:00Z')}
|
||||
timezone="UTC"
|
||||
/>,
|
||||
)
|
||||
|
||||
rerender(
|
||||
<TimePicker
|
||||
{...baseProps}
|
||||
onChange={onChange}
|
||||
value={invalidValue}
|
||||
timezone="UTC"
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(onChange).not.toHaveBeenCalled()
|
||||
expect(screen.getByRole('textbox')).toHaveValue('')
|
||||
})
|
||||
|
||||
it('should call onChange when timezone changes with an existing value', () => {
|
||||
const onChange = vi.fn()
|
||||
const value = dayjs('2024-01-01T10:30:00Z')
|
||||
@ -584,6 +658,36 @@ describe('TimePicker', () => {
|
||||
expect(onChange).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should preserve selected time when value is removed and timezone is undefined', () => {
|
||||
const onChange = vi.fn()
|
||||
const { rerender } = render(
|
||||
<TimePicker
|
||||
{...baseProps}
|
||||
onChange={onChange}
|
||||
value={dayjs('2024-01-01T10:30:00Z')}
|
||||
timezone="UTC"
|
||||
/>,
|
||||
)
|
||||
|
||||
rerender(
|
||||
<TimePicker
|
||||
{...baseProps}
|
||||
onChange={onChange}
|
||||
value={undefined}
|
||||
timezone={undefined}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByRole('textbox'))
|
||||
fireEvent.click(screen.getByRole('button', { name: /operation\.ok/i }))
|
||||
|
||||
expect(onChange).toHaveBeenCalledTimes(1)
|
||||
const emitted = onChange.mock.calls[0][0]
|
||||
expect(isDayjsObject(emitted)).toBe(true)
|
||||
expect(emitted.hour()).toBe(10)
|
||||
expect(emitted.minute()).toBe(30)
|
||||
})
|
||||
|
||||
it('should not update when neither timezone nor value changes', () => {
|
||||
const onChange = vi.fn()
|
||||
const value = dayjs('2024-01-01T10:30:00Z')
|
||||
@ -669,6 +773,19 @@ describe('TimePicker', () => {
|
||||
|
||||
expect(screen.getByDisplayValue('09:15 AM')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should return empty display value for an unparsable truthy string', () => {
|
||||
const invalidValue = 123 as unknown as TimePickerProps['value']
|
||||
render(
|
||||
<TimePicker
|
||||
{...baseProps}
|
||||
value={invalidValue}
|
||||
timezone="UTC"
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByRole('textbox')).toHaveValue('')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Timezone Label Integration', () => {
|
||||
|
||||
@ -53,6 +53,7 @@ const TimePicker = ({
|
||||
|
||||
useEffect(() => {
|
||||
const handleClickOutside = (event: MouseEvent) => {
|
||||
/* v8 ignore next 2 -- outside-click closing is handled by PortalToFollowElem; this local ref guard is a defensive fallback. */
|
||||
if (containerRef.current && !containerRef.current.contains(event.target as Node))
|
||||
setIsOpen(false)
|
||||
}
|
||||
|
||||
@ -1,112 +1,353 @@
|
||||
import dayjs, {
|
||||
import dayjs from 'dayjs'
|
||||
import timezone from 'dayjs/plugin/timezone'
|
||||
import utc from 'dayjs/plugin/utc'
|
||||
import {
|
||||
clearMonthMapCache,
|
||||
cloneTime,
|
||||
convertTimezoneToOffsetStr,
|
||||
formatDateForOutput,
|
||||
getDateWithTimezone,
|
||||
getDaysInMonth,
|
||||
getHourIn12Hour,
|
||||
isDayjsObject,
|
||||
parseDateWithFormat,
|
||||
toDayjs,
|
||||
} from '../dayjs'
|
||||
|
||||
describe('dayjs utilities', () => {
|
||||
const timezone = 'UTC'
|
||||
dayjs.extend(utc)
|
||||
dayjs.extend(timezone)
|
||||
|
||||
it('toDayjs parses time-only strings with timezone support', () => {
|
||||
const result = toDayjs('18:45', { timezone })
|
||||
expect(result).toBeDefined()
|
||||
expect(result?.format('HH:mm')).toBe('18:45')
|
||||
expect(result?.utcOffset()).toBe(getDateWithTimezone({ timezone }).utcOffset())
|
||||
// ── cloneTime ──────────────────────────────────────────────────────────────
|
||||
describe('cloneTime', () => {
|
||||
it('copies hour and minute from source to target, preserving target date', () => {
|
||||
const target = dayjs('2024-03-15')
|
||||
const source = dayjs('2020-01-01T09:30:00')
|
||||
const result = cloneTime(target, source)
|
||||
expect(result.format('YYYY-MM-DD')).toBe('2024-03-15')
|
||||
expect(result.hour()).toBe(9)
|
||||
expect(result.minute()).toBe(30)
|
||||
})
|
||||
})
|
||||
|
||||
// ── getDaysInMonth ─────────────────────────────────────────────────────────
|
||||
describe('getDaysInMonth', () => {
|
||||
beforeEach(() => clearMonthMapCache())
|
||||
|
||||
it('returns cells for a typical month view', () => {
|
||||
const date = dayjs('2024-01-01')
|
||||
const days = getDaysInMonth(date)
|
||||
expect(days.length).toBeGreaterThanOrEqual(28)
|
||||
expect(days.some(d => d.isCurrentMonth)).toBe(true)
|
||||
expect(days.some(d => !d.isCurrentMonth)).toBe(true)
|
||||
})
|
||||
|
||||
it('toDayjs parses 12-hour time strings', () => {
|
||||
const tz = 'America/New_York'
|
||||
const result = toDayjs('07:15 PM', { timezone: tz })
|
||||
expect(result).toBeDefined()
|
||||
expect(result?.format('HH:mm')).toBe('19:15')
|
||||
expect(result?.utcOffset()).toBe(getDateWithTimezone({ timezone: tz }).startOf('day').utcOffset())
|
||||
it('returns cached result on second call', () => {
|
||||
const date = dayjs('2024-02-01')
|
||||
const first = getDaysInMonth(date)
|
||||
const second = getDaysInMonth(date)
|
||||
expect(first).toBe(second) // same reference
|
||||
})
|
||||
|
||||
it('isDayjsObject detects dayjs instances', () => {
|
||||
const date = dayjs()
|
||||
expect(isDayjsObject(date)).toBe(true)
|
||||
expect(isDayjsObject(getDateWithTimezone({ timezone }))).toBe(true)
|
||||
it('clears cache properly', () => {
|
||||
const date = dayjs('2024-03-01')
|
||||
const first = getDaysInMonth(date)
|
||||
clearMonthMapCache()
|
||||
const second = getDaysInMonth(date)
|
||||
expect(first).not.toBe(second) // different reference after clearing
|
||||
})
|
||||
})
|
||||
|
||||
// ── getHourIn12Hour ─────────────────────────────────────────────────────────
|
||||
describe('getHourIn12Hour', () => {
|
||||
it('returns 12 for midnight (hour=0)', () => {
|
||||
expect(getHourIn12Hour(dayjs().set('hour', 0))).toBe(12)
|
||||
})
|
||||
|
||||
it('returns hour-12 for hours >= 12', () => {
|
||||
expect(getHourIn12Hour(dayjs().set('hour', 12))).toBe(0)
|
||||
expect(getHourIn12Hour(dayjs().set('hour', 15))).toBe(3)
|
||||
expect(getHourIn12Hour(dayjs().set('hour', 23))).toBe(11)
|
||||
})
|
||||
|
||||
it('returns hour as-is for AM hours (1-11)', () => {
|
||||
expect(getHourIn12Hour(dayjs().set('hour', 1))).toBe(1)
|
||||
expect(getHourIn12Hour(dayjs().set('hour', 11))).toBe(11)
|
||||
})
|
||||
})
|
||||
|
||||
// ── getDateWithTimezone ─────────────────────────────────────────────────────
|
||||
describe('getDateWithTimezone', () => {
|
||||
it('returns a clone of now when neither date nor timezone given', () => {
|
||||
const result = getDateWithTimezone({})
|
||||
expect(dayjs.isDayjs(result)).toBe(true)
|
||||
})
|
||||
|
||||
it('returns current tz date when only timezone given', () => {
|
||||
const result = getDateWithTimezone({ timezone: 'UTC' })
|
||||
expect(dayjs.isDayjs(result)).toBe(true)
|
||||
expect(result.utcOffset()).toBe(0)
|
||||
})
|
||||
|
||||
it('returns date in given timezone when both date and timezone given', () => {
|
||||
const date = dayjs.utc('2024-06-01T12:00:00Z')
|
||||
const result = getDateWithTimezone({ date, timezone: 'UTC' })
|
||||
expect(result.hour()).toBe(12)
|
||||
})
|
||||
|
||||
it('returns clone of given date when no timezone given', () => {
|
||||
const date = dayjs('2024-01-15T08:30:00')
|
||||
const result = getDateWithTimezone({ date })
|
||||
expect(result.isSame(date)).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
// ── isDayjsObject ───────────────────────────────────────────────────────────
|
||||
describe('isDayjsObject', () => {
|
||||
it('detects dayjs instances', () => {
|
||||
expect(isDayjsObject(dayjs())).toBe(true)
|
||||
expect(isDayjsObject(getDateWithTimezone({ timezone: 'UTC' }))).toBe(true)
|
||||
expect(isDayjsObject('2024-01-01')).toBe(false)
|
||||
expect(isDayjsObject({})).toBe(false)
|
||||
expect(isDayjsObject(null)).toBe(false)
|
||||
expect(isDayjsObject(undefined)).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
// ── toDayjs ────────────────────────────────────────────────────────────────
|
||||
describe('toDayjs', () => {
|
||||
const tz = 'UTC'
|
||||
|
||||
it('returns undefined for undefined value', () => {
|
||||
expect(toDayjs(undefined)).toBeUndefined()
|
||||
})
|
||||
|
||||
it('toDayjs parses datetime strings in target timezone', () => {
|
||||
const value = '2024-05-01 12:00:00'
|
||||
const tz = 'America/New_York'
|
||||
|
||||
const result = toDayjs(value, { timezone: tz })
|
||||
|
||||
expect(result).toBeDefined()
|
||||
expect(result?.hour()).toBe(12)
|
||||
expect(result?.format('YYYY-MM-DD HH:mm')).toBe('2024-05-01 12:00')
|
||||
it('returns undefined for empty string', () => {
|
||||
expect(toDayjs('')).toBeUndefined()
|
||||
})
|
||||
|
||||
it('toDayjs parses ISO datetime strings in target timezone', () => {
|
||||
const value = '2024-05-01T14:30:00'
|
||||
const tz = 'Europe/London'
|
||||
it('applies timezone to an existing Dayjs object', () => {
|
||||
const date = dayjs('2024-06-01T12:00:00')
|
||||
const result = toDayjs(date, { timezone: 'UTC' })
|
||||
expect(dayjs.isDayjs(result)).toBe(true)
|
||||
})
|
||||
|
||||
const result = toDayjs(value, { timezone: tz })
|
||||
it('returns the Dayjs object as-is when no timezone given', () => {
|
||||
const date = dayjs('2024-06-01')
|
||||
const result = toDayjs(date)
|
||||
expect(dayjs.isDayjs(result)).toBe(true)
|
||||
})
|
||||
|
||||
expect(result).toBeDefined()
|
||||
expect(result?.hour()).toBe(14)
|
||||
it('returns undefined for non-string non-Dayjs value', () => {
|
||||
// @ts-expect-error testing invalid input
|
||||
expect(toDayjs(12345)).toBeUndefined()
|
||||
})
|
||||
|
||||
it('parses 24h time-only strings', () => {
|
||||
const result = toDayjs('18:45', { timezone: tz })
|
||||
expect(result?.format('HH:mm')).toBe('18:45')
|
||||
})
|
||||
|
||||
it('parses time-only strings with seconds', () => {
|
||||
const result = toDayjs('09:30:45', { timezone: tz })
|
||||
expect(result?.hour()).toBe(9)
|
||||
expect(result?.minute()).toBe(30)
|
||||
expect(result?.second()).toBe(45)
|
||||
})
|
||||
|
||||
it('toDayjs handles dates without time component', () => {
|
||||
const value = '2024-05-01'
|
||||
const tz = 'America/Los_Angeles'
|
||||
|
||||
const result = toDayjs(value, { timezone: tz })
|
||||
it('parses time-only strings with 3-digit milliseconds', () => {
|
||||
const result = toDayjs('08:00:00.500', { timezone: tz })
|
||||
expect(result?.millisecond()).toBe(500)
|
||||
})
|
||||
|
||||
it('parses time-only strings with 3-digit ms - normalizeMillisecond exact branch', () => {
|
||||
// normalizeMillisecond: length === 3 → Number('567') = 567
|
||||
const result = toDayjs('08:00:00.567', { timezone: tz })
|
||||
expect(result).toBeDefined()
|
||||
expect(result?.hour()).toBe(8)
|
||||
expect(result?.second()).toBe(0)
|
||||
})
|
||||
|
||||
it('parses time-only strings with <3-digit milliseconds (pads)', () => {
|
||||
const result = toDayjs('08:00:00.5', { timezone: tz })
|
||||
expect(result?.millisecond()).toBe(500)
|
||||
})
|
||||
|
||||
it('parses 12-hour time strings (PM)', () => {
|
||||
const result = toDayjs('07:15 PM', { timezone: 'America/New_York' })
|
||||
expect(result?.format('HH:mm')).toBe('19:15')
|
||||
})
|
||||
|
||||
it('parses 12-hour time strings (AM)', () => {
|
||||
const result = toDayjs('12:00 AM', { timezone: tz })
|
||||
expect(result?.hour()).toBe(0)
|
||||
})
|
||||
|
||||
it('parses 12-hour time strings with seconds', () => {
|
||||
const result = toDayjs('03:30:15 PM', { timezone: tz })
|
||||
expect(result?.hour()).toBe(15)
|
||||
expect(result?.second()).toBe(15)
|
||||
})
|
||||
|
||||
it('parses datetime strings via common formats', () => {
|
||||
const result = toDayjs('2024-05-01 12:00:00', { timezone: tz })
|
||||
expect(result?.format('YYYY-MM-DD')).toBe('2024-05-01')
|
||||
})
|
||||
|
||||
it('parses ISO datetime strings', () => {
|
||||
const result = toDayjs('2024-05-01T14:30:00', { timezone: 'Europe/London' })
|
||||
expect(result?.hour()).toBe(14)
|
||||
})
|
||||
|
||||
it('parses dates with an explicit format option', () => {
|
||||
// Use unambiguous format: YYYY/MM/DD + value 2024/05/01
|
||||
const result = toDayjs('2024/05/01', { format: 'YYYY/MM/DD', timezone: tz })
|
||||
expect(result?.format('YYYY-MM-DD')).toBe('2024-05-01')
|
||||
})
|
||||
|
||||
it('falls through to other formats when explicit format fails', () => {
|
||||
// '2024-05-01' doesn't match 'DD/MM/YYYY' but will match common formats
|
||||
const result = toDayjs('2024-05-01', { format: 'DD/MM/YYYY', timezone: tz })
|
||||
expect(result?.format('YYYY-MM-DD')).toBe('2024-05-01')
|
||||
})
|
||||
|
||||
it('falls through to common formats when explicit format fails without timezone', () => {
|
||||
const result = toDayjs('2024-05-01', { format: 'DD/MM/YYYY' })
|
||||
expect(result?.format('YYYY-MM-DD')).toBe('2024-05-01')
|
||||
})
|
||||
|
||||
it('returns undefined when explicit format parsing fails and no fallback matches', () => {
|
||||
const result = toDayjs('not-a-date-value', { format: 'YYYY/MM/DD' })
|
||||
expect(result).toBeUndefined()
|
||||
})
|
||||
|
||||
it('uses custom formats array', () => {
|
||||
const result = toDayjs('2024/05/01', { formats: ['YYYY/MM/DD'] })
|
||||
expect(result?.format('YYYY-MM-DD')).toBe('2024-05-01')
|
||||
})
|
||||
|
||||
it('returns undefined for completely invalid string', () => {
|
||||
const result = toDayjs('not-a-valid-date-at-all!!!')
|
||||
expect(result).toBeUndefined()
|
||||
})
|
||||
|
||||
it('parses date-only strings without time', () => {
|
||||
const result = toDayjs('2024-05-01', { timezone: 'America/Los_Angeles' })
|
||||
expect(result?.format('YYYY-MM-DD')).toBe('2024-05-01')
|
||||
expect(result?.hour()).toBe(0)
|
||||
expect(result?.minute()).toBe(0)
|
||||
})
|
||||
|
||||
it('uses timezone fallback parser for non-standard datetime strings', () => {
|
||||
const result = toDayjs('May 1, 2024 2:30 PM', { timezone: 'America/New_York' })
|
||||
expect(result?.isValid()).toBe(true)
|
||||
expect(result?.year()).toBe(2024)
|
||||
expect(result?.month()).toBe(4)
|
||||
expect(result?.date()).toBe(1)
|
||||
expect(result?.utcOffset()).toBe(dayjs.tz('2024-05-01', 'America/New_York').utcOffset())
|
||||
})
|
||||
|
||||
it('uses timezone fallback parser when custom formats are empty', () => {
|
||||
const result = toDayjs('2024-05-01T14:30:00Z', {
|
||||
timezone: 'America/New_York',
|
||||
formats: [],
|
||||
})
|
||||
expect(result?.isValid()).toBe(true)
|
||||
expect(result?.utcOffset()).toBe(dayjs.tz('2024-05-01', 'America/New_York').utcOffset())
|
||||
})
|
||||
})
|
||||
|
||||
// ── parseDateWithFormat ────────────────────────────────────────────────────
|
||||
describe('parseDateWithFormat', () => {
|
||||
it('returns null for empty string', () => {
|
||||
expect(parseDateWithFormat('')).toBeNull()
|
||||
})
|
||||
|
||||
it('parses with explicit format', () => {
|
||||
// Use YYYY/MM/DD which is unambiguous
|
||||
const result = parseDateWithFormat('2024/05/01', 'YYYY/MM/DD')
|
||||
expect(result?.format('YYYY-MM-DD')).toBe('2024-05-01')
|
||||
})
|
||||
|
||||
it('returns null for invalid string with explicit format', () => {
|
||||
expect(parseDateWithFormat('not-a-date', 'YYYY-MM-DD')).toBeNull()
|
||||
})
|
||||
|
||||
it('parses using common formats (YYYY-MM-DD)', () => {
|
||||
const result = parseDateWithFormat('2024-05-01')
|
||||
expect(result?.format('YYYY-MM-DD')).toBe('2024-05-01')
|
||||
})
|
||||
|
||||
it('parses using common formats (YYYY/MM/DD)', () => {
|
||||
const result = parseDateWithFormat('2024/05/01')
|
||||
expect(result?.format('YYYY-MM-DD')).toBe('2024-05-01')
|
||||
})
|
||||
|
||||
it('parses ISO datetime strings via common formats', () => {
|
||||
const result = parseDateWithFormat('2024-05-01T14:30:00')
|
||||
expect(result?.hour()).toBe(14)
|
||||
})
|
||||
|
||||
it('returns null for completely unparseable string', () => {
|
||||
expect(parseDateWithFormat('ZZZZ-ZZ-ZZ')).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
// ── formatDateForOutput ────────────────────────────────────────────────────
|
||||
describe('formatDateForOutput', () => {
|
||||
it('returns empty string for invalid date', () => {
|
||||
expect(formatDateForOutput(dayjs('invalid'))).toBe('')
|
||||
})
|
||||
|
||||
it('returns date-only format by default (includeTime=false)', () => {
|
||||
const date = dayjs('2024-05-01T12:30:00')
|
||||
expect(formatDateForOutput(date)).toBe('2024-05-01')
|
||||
})
|
||||
|
||||
it('returns ISO datetime string when includeTime=true', () => {
|
||||
const date = dayjs('2024-05-01T12:30:00')
|
||||
const result = formatDateForOutput(date, true)
|
||||
expect(result).toMatch(/^2024-05-01T12:30:00/)
|
||||
})
|
||||
})
|
||||
|
||||
// ── convertTimezoneToOffsetStr ─────────────────────────────────────────────
|
||||
describe('convertTimezoneToOffsetStr', () => {
|
||||
it('should return default UTC+0 for undefined timezone', () => {
|
||||
it('returns default UTC+0 for undefined timezone', () => {
|
||||
expect(convertTimezoneToOffsetStr(undefined)).toBe('UTC+0')
|
||||
})
|
||||
|
||||
it('should return default UTC+0 for invalid timezone', () => {
|
||||
it('returns default UTC+0 for invalid timezone', () => {
|
||||
expect(convertTimezoneToOffsetStr('Invalid/Timezone')).toBe('UTC+0')
|
||||
})
|
||||
|
||||
it('should handle whole hour positive offsets without leading zeros', () => {
|
||||
it('handles positive whole-hour offsets', () => {
|
||||
expect(convertTimezoneToOffsetStr('Asia/Shanghai')).toBe('UTC+8')
|
||||
expect(convertTimezoneToOffsetStr('Pacific/Auckland')).toBe('UTC+12')
|
||||
expect(convertTimezoneToOffsetStr('Pacific/Apia')).toBe('UTC+13')
|
||||
})
|
||||
|
||||
it('should handle whole hour negative offsets without leading zeros', () => {
|
||||
it('handles negative whole-hour offsets', () => {
|
||||
expect(convertTimezoneToOffsetStr('Pacific/Niue')).toBe('UTC-11')
|
||||
expect(convertTimezoneToOffsetStr('Pacific/Honolulu')).toBe('UTC-10')
|
||||
expect(convertTimezoneToOffsetStr('America/New_York')).toBe('UTC-5')
|
||||
})
|
||||
|
||||
it('should handle zero offset', () => {
|
||||
it('handles zero offset', () => {
|
||||
expect(convertTimezoneToOffsetStr('Europe/London')).toBe('UTC+0')
|
||||
expect(convertTimezoneToOffsetStr('UTC')).toBe('UTC+0')
|
||||
})
|
||||
|
||||
it('should handle half-hour offsets (30 minutes)', () => {
|
||||
// India Standard Time: UTC+5:30
|
||||
it('handles half-hour offsets', () => {
|
||||
expect(convertTimezoneToOffsetStr('Asia/Kolkata')).toBe('UTC+5:30')
|
||||
// Australian Central Time: UTC+9:30
|
||||
expect(convertTimezoneToOffsetStr('Australia/Adelaide')).toBe('UTC+9:30')
|
||||
expect(convertTimezoneToOffsetStr('Australia/Darwin')).toBe('UTC+9:30')
|
||||
})
|
||||
|
||||
it('should handle 45-minute offsets', () => {
|
||||
// Chatham Time: UTC+12:45
|
||||
it('handles 45-minute offsets', () => {
|
||||
expect(convertTimezoneToOffsetStr('Pacific/Chatham')).toBe('UTC+12:45')
|
||||
})
|
||||
|
||||
it('should preserve leading zeros in minute part for non-zero minutes', () => {
|
||||
// Ensure +05:30 is displayed as "UTC+5:30", not "UTC+5:3"
|
||||
it('preserves leading zeros in minute part', () => {
|
||||
const result = convertTimezoneToOffsetStr('Asia/Kolkata')
|
||||
expect(result).toMatch(/UTC[+-]\d+:30/)
|
||||
expect(result).not.toMatch(/UTC[+-]\d+:3[^0]/)
|
||||
|
||||
@ -112,6 +112,7 @@ export const convertTimezoneToOffsetStr = (timezone?: string) => {
|
||||
// Extract offset from name format like "-11:00 Niue Time" or "+05:30 India Time"
|
||||
// Name format is always "{offset}:{minutes} {timezone name}"
|
||||
const offsetMatch = /^([+-]?\d{1,2}):(\d{2})/.exec(tzItem.name)
|
||||
/* v8 ignore next 2 -- timezone.json entries are normalized to "{offset} {name}"; this protects against malformed data only. */
|
||||
if (!offsetMatch)
|
||||
return DEFAULT_OFFSET_STR
|
||||
// Parse hours and minutes separately
|
||||
@ -141,6 +142,7 @@ const normalizeMillisecond = (value: string | undefined) => {
|
||||
return 0
|
||||
if (value.length === 3)
|
||||
return Number(value)
|
||||
/* v8 ignore next 2 -- TIME_ONLY_REGEX allows at most 3 fractional digits, so >3 can only occur after future regex changes. */
|
||||
if (value.length > 3)
|
||||
return Number(value.slice(0, 3))
|
||||
return Number(value.padEnd(3, '0'))
|
||||
|
||||
@ -59,6 +59,7 @@ const EmojiPickerInner: FC<IEmojiPickerInnerProps> = ({
|
||||
React.useEffect(() => {
|
||||
if (selectedEmoji) {
|
||||
setShowStyleColors(true)
|
||||
/* v8 ignore next 2 - @preserve */
|
||||
if (selectedBackground)
|
||||
onSelect?.(selectedEmoji, selectedBackground)
|
||||
}
|
||||
|
||||
@ -238,6 +238,32 @@ describe('ErrorBoundary', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('should not reset when resetKeys reference changes but values are identical', async () => {
|
||||
const onReset = vi.fn()
|
||||
|
||||
const StableKeysHarness = () => {
|
||||
const [keys, setKeys] = React.useState<Array<string | number>>([1, 2])
|
||||
return (
|
||||
<>
|
||||
<button onClick={() => setKeys([1, 2])}>Update keys same values</button>
|
||||
<ErrorBoundary resetKeys={keys} onReset={onReset}>
|
||||
<ThrowOnRender shouldThrow={true} />
|
||||
</ErrorBoundary>
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
render(<StableKeysHarness />)
|
||||
await screen.findByText('Something went wrong')
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'Update keys same values' }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('Something went wrong')).toBeInTheDocument()
|
||||
})
|
||||
expect(onReset).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should reset after children change when resetOnPropsChange is true', async () => {
|
||||
const ResetOnPropsHarness = () => {
|
||||
const [shouldThrow, setShouldThrow] = React.useState(true)
|
||||
@ -269,6 +295,24 @@ describe('ErrorBoundary', () => {
|
||||
expect(screen.getByText('second child')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should call window.location.reload when Reload Page is clicked', async () => {
|
||||
const reloadSpy = vi.fn()
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: { ...window.location, reload: reloadSpy },
|
||||
writable: true,
|
||||
})
|
||||
|
||||
render(
|
||||
<ErrorBoundary>
|
||||
<ThrowOnRender shouldThrow={true} />
|
||||
</ErrorBoundary>,
|
||||
)
|
||||
|
||||
fireEvent.click(await screen.findByRole('button', { name: 'Reload Page' }))
|
||||
|
||||
expect(reloadSpy).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@ -358,6 +402,16 @@ describe('ErrorBoundary utility exports', () => {
|
||||
|
||||
expect(Wrapped.displayName).toBe('withErrorBoundary(NamedComponent)')
|
||||
})
|
||||
|
||||
it('should fallback displayName to Component when wrapped component has no displayName and empty name', () => {
|
||||
const Nameless = (() => <div>nameless</div>) as React.FC
|
||||
Object.defineProperty(Nameless, 'displayName', { value: undefined, configurable: true })
|
||||
Object.defineProperty(Nameless, 'name', { value: '', configurable: true })
|
||||
|
||||
const Wrapped = withErrorBoundary(Nameless)
|
||||
|
||||
expect(Wrapped.displayName).toBe('withErrorBoundary(Component)')
|
||||
})
|
||||
})
|
||||
|
||||
// Validate simple fallback helper component.
|
||||
|
||||
7
web/app/components/base/features/__tests__/index.spec.ts
Normal file
7
web/app/components/base/features/__tests__/index.spec.ts
Normal file
@ -0,0 +1,7 @@
|
||||
import { FeaturesProvider } from '../index'
|
||||
|
||||
describe('features index exports', () => {
|
||||
it('should export FeaturesProvider from the barrel file', () => {
|
||||
expect(FeaturesProvider).toBeDefined()
|
||||
})
|
||||
})
|
||||
@ -146,4 +146,30 @@ describe('AnnotationCtrlButton', () => {
|
||||
expect(mockSetShowAnnotationFullModal).toHaveBeenCalled()
|
||||
expect(mockAddAnnotation).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should fallback author name to empty string when account name is missing', async () => {
|
||||
const onAdded = vi.fn()
|
||||
mockAddAnnotation.mockResolvedValueOnce({
|
||||
id: 'annotation-2',
|
||||
account: undefined,
|
||||
})
|
||||
|
||||
render(
|
||||
<AnnotationCtrlButton
|
||||
appId="test-app"
|
||||
messageId="msg-2"
|
||||
cached={false}
|
||||
query="test query"
|
||||
answer="test answer"
|
||||
onAdded={onAdded}
|
||||
onEdit={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(onAdded).toHaveBeenCalledWith('annotation-2', '')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@ -39,6 +39,19 @@ vi.mock('@/config', () => ({
|
||||
ANNOTATION_DEFAULT: { score_threshold: 0.9 },
|
||||
}))
|
||||
|
||||
vi.mock('../score-slider', () => ({
|
||||
default: ({ value, onChange }: { value: number, onChange: (value: number) => void }) => (
|
||||
<input
|
||||
role="slider"
|
||||
type="range"
|
||||
min={80}
|
||||
max={100}
|
||||
value={value}
|
||||
onChange={e => onChange(Number((e.target as HTMLInputElement).value))}
|
||||
/>
|
||||
),
|
||||
}))
|
||||
|
||||
const defaultAnnotationConfig = {
|
||||
id: 'test-id',
|
||||
enabled: false,
|
||||
@ -158,7 +171,7 @@ describe('ConfigParamModal', () => {
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('0.90')).toBeInTheDocument()
|
||||
expect(screen.getByRole('slider')).toHaveValue('90')
|
||||
})
|
||||
|
||||
it('should render configConfirmBtn when isInit is false', () => {
|
||||
@ -262,9 +275,9 @@ describe('ConfigParamModal', () => {
|
||||
)
|
||||
|
||||
const slider = screen.getByRole('slider')
|
||||
expect(slider).toHaveAttribute('aria-valuemin', '80')
|
||||
expect(slider).toHaveAttribute('aria-valuemax', '100')
|
||||
expect(slider).toHaveAttribute('aria-valuenow', '90')
|
||||
expect(slider).toHaveAttribute('min', '80')
|
||||
expect(slider).toHaveAttribute('max', '100')
|
||||
expect(slider).toHaveValue('90')
|
||||
})
|
||||
|
||||
it('should update embedding model when model selector is used', () => {
|
||||
@ -377,7 +390,7 @@ describe('ConfigParamModal', () => {
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByRole('slider')).toHaveAttribute('aria-valuenow', '90')
|
||||
expect(screen.getByRole('slider')).toHaveValue('90')
|
||||
})
|
||||
|
||||
it('should set loading state while saving', async () => {
|
||||
@ -412,4 +425,30 @@ describe('ConfigParamModal', () => {
|
||||
expect(onSave).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
it('should save updated score after slider changes', async () => {
|
||||
const onSave = vi.fn().mockResolvedValue(undefined)
|
||||
render(
|
||||
<ConfigParamModal
|
||||
appId="test-app"
|
||||
isShow={true}
|
||||
onHide={vi.fn()}
|
||||
onSave={onSave}
|
||||
annotationConfig={defaultAnnotationConfig}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.change(screen.getByRole('slider'), { target: { value: '96' } })
|
||||
|
||||
const buttons = screen.getAllByRole('button')
|
||||
const saveBtn = buttons.find(b => b.textContent?.includes('initSetup'))
|
||||
fireEvent.click(saveBtn!)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(onSave).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ embedding_provider_name: 'openai' }),
|
||||
0.96,
|
||||
)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@ -1,13 +1,15 @@
|
||||
import type { Features } from '../../../types'
|
||||
import type { OnFeaturesChange } from '@/app/components/base/features/types'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { act, fireEvent, render, screen } from '@testing-library/react'
|
||||
import { FeaturesProvider } from '../../../context'
|
||||
import AnnotationReply from '../index'
|
||||
|
||||
const originalConsoleError = console.error
|
||||
const mockPush = vi.fn()
|
||||
let mockPathname = '/app/test-app-id/configuration'
|
||||
vi.mock('next/navigation', () => ({
|
||||
useRouter: () => ({ push: mockPush }),
|
||||
usePathname: () => '/app/test-app-id/configuration',
|
||||
usePathname: () => mockPathname,
|
||||
}))
|
||||
|
||||
let mockIsShowAnnotationConfigInit = false
|
||||
@ -100,6 +102,15 @@ const renderWithProvider = (
|
||||
describe('AnnotationReply', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.spyOn(console, 'error').mockImplementation((...args: unknown[]) => {
|
||||
const message = args.map(arg => String(arg)).join(' ')
|
||||
if (message.includes('A props object containing a "key" prop is being spread into JSX')
|
||||
|| message.includes('React keys must be passed directly to JSX without using spread')) {
|
||||
return
|
||||
}
|
||||
originalConsoleError(...args as Parameters<typeof console.error>)
|
||||
})
|
||||
mockPathname = '/app/test-app-id/configuration'
|
||||
mockIsShowAnnotationConfigInit = false
|
||||
mockIsShowAnnotationFullModal = false
|
||||
capturedSetAnnotationConfig = null
|
||||
@ -235,18 +246,47 @@ describe('AnnotationReply', () => {
|
||||
expect(mockPush).toHaveBeenCalledWith('/app/test-app-id/annotations')
|
||||
})
|
||||
|
||||
it('should show config param modal when isShowAnnotationConfigInit is true', () => {
|
||||
it('should fallback appId to empty string when pathname does not match', () => {
|
||||
mockPathname = '/apps/no-match'
|
||||
renderWithProvider({}, {
|
||||
annotationReply: {
|
||||
enabled: true,
|
||||
score_threshold: 0.9,
|
||||
embedding_model: {
|
||||
embedding_provider_name: 'openai',
|
||||
embedding_model_name: 'text-embedding-ada-002',
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
const card = screen.getByText(/feature\.annotation\.title/).closest('[class]')!
|
||||
fireEvent.mouseEnter(card)
|
||||
fireEvent.click(screen.getByText(/feature\.annotation\.cacheManagement/))
|
||||
|
||||
expect(mockPush).toHaveBeenCalledWith('/app//annotations')
|
||||
})
|
||||
|
||||
it('should show config param modal when isShowAnnotationConfigInit is true', async () => {
|
||||
mockIsShowAnnotationConfigInit = true
|
||||
renderWithProvider()
|
||||
await act(async () => {
|
||||
renderWithProvider()
|
||||
await Promise.resolve()
|
||||
})
|
||||
|
||||
expect(screen.getByText(/initSetup\.title/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should hide config modal when hide is clicked', () => {
|
||||
it('should hide config modal when hide is clicked', async () => {
|
||||
mockIsShowAnnotationConfigInit = true
|
||||
renderWithProvider()
|
||||
await act(async () => {
|
||||
renderWithProvider()
|
||||
await Promise.resolve()
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: /operation\.cancel/ }))
|
||||
await act(async () => {
|
||||
fireEvent.click(screen.getByRole('button', { name: /operation\.cancel/ }))
|
||||
await Promise.resolve()
|
||||
})
|
||||
|
||||
expect(mockSetIsShowAnnotationConfigInit).toHaveBeenCalledWith(false)
|
||||
})
|
||||
@ -264,7 +304,10 @@ describe('AnnotationReply', () => {
|
||||
},
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByText(/initSetup\.confirmBtn/))
|
||||
await act(async () => {
|
||||
fireEvent.click(screen.getByText(/initSetup\.confirmBtn/))
|
||||
await Promise.resolve()
|
||||
})
|
||||
|
||||
expect(mockHandleEnableAnnotation).toHaveBeenCalled()
|
||||
})
|
||||
@ -298,7 +341,10 @@ describe('AnnotationReply', () => {
|
||||
},
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByText(/initSetup\.confirmBtn/))
|
||||
await act(async () => {
|
||||
fireEvent.click(screen.getByText(/initSetup\.confirmBtn/))
|
||||
await Promise.resolve()
|
||||
})
|
||||
|
||||
// handleEnableAnnotation should be called with embedding model and score
|
||||
expect(mockHandleEnableAnnotation).toHaveBeenCalledWith(
|
||||
@ -327,13 +373,15 @@ describe('AnnotationReply', () => {
|
||||
|
||||
// The captured setAnnotationConfig is the component's updateAnnotationReply callback
|
||||
expect(capturedSetAnnotationConfig).not.toBeNull()
|
||||
capturedSetAnnotationConfig!({
|
||||
enabled: true,
|
||||
score_threshold: 0.8,
|
||||
embedding_model: {
|
||||
embedding_provider_name: 'openai',
|
||||
embedding_model_name: 'new-model',
|
||||
},
|
||||
act(() => {
|
||||
capturedSetAnnotationConfig!({
|
||||
enabled: true,
|
||||
score_threshold: 0.8,
|
||||
embedding_model: {
|
||||
embedding_provider_name: 'openai',
|
||||
embedding_model_name: 'new-model',
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
expect(onChange).toHaveBeenCalled()
|
||||
@ -353,12 +401,12 @@ describe('AnnotationReply', () => {
|
||||
|
||||
// Should not throw when onChange is not provided
|
||||
expect(capturedSetAnnotationConfig).not.toBeNull()
|
||||
expect(() => {
|
||||
expect(() => act(() => {
|
||||
capturedSetAnnotationConfig!({
|
||||
enabled: true,
|
||||
score_threshold: 0.7,
|
||||
})
|
||||
}).not.toThrow()
|
||||
})).not.toThrow()
|
||||
})
|
||||
|
||||
it('should hide info display when hovering over enabled feature', () => {
|
||||
@ -403,9 +451,12 @@ describe('AnnotationReply', () => {
|
||||
expect(screen.getByText('0.9')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should pass isInit prop to ConfigParamModal', () => {
|
||||
it('should pass isInit prop to ConfigParamModal', async () => {
|
||||
mockIsShowAnnotationConfigInit = true
|
||||
renderWithProvider()
|
||||
await act(async () => {
|
||||
renderWithProvider()
|
||||
await Promise.resolve()
|
||||
})
|
||||
|
||||
expect(screen.getByText(/initSetup\.confirmBtn/)).toBeInTheDocument()
|
||||
expect(screen.queryByText(/initSetup\.configConfirmBtn/)).not.toBeInTheDocument()
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
import type { AnnotationReplyConfig } from '@/models/debug'
|
||||
import { act, renderHook } from '@testing-library/react'
|
||||
import { queryAnnotationJobStatus } from '@/service/annotation'
|
||||
import { sleep } from '@/utils'
|
||||
import useAnnotationConfig from '../use-annotation-config'
|
||||
|
||||
let mockIsAnnotationFull = false
|
||||
@ -238,4 +240,31 @@ describe('useAnnotationConfig', () => {
|
||||
expect(updatedConfig.enabled).toBe(true)
|
||||
expect(updatedConfig.score_threshold).toBeDefined()
|
||||
})
|
||||
|
||||
it('should poll job status until completed when enabling annotation', async () => {
|
||||
const setAnnotationConfig = vi.fn()
|
||||
const queryJobStatusMock = vi.mocked(queryAnnotationJobStatus)
|
||||
const sleepMock = vi.mocked(sleep)
|
||||
|
||||
queryJobStatusMock
|
||||
.mockResolvedValueOnce({ job_status: 'pending' } as unknown as Awaited<ReturnType<typeof queryAnnotationJobStatus>>)
|
||||
.mockResolvedValueOnce({ job_status: 'completed' } as unknown as Awaited<ReturnType<typeof queryAnnotationJobStatus>>)
|
||||
|
||||
const { result } = renderHook(() => useAnnotationConfig({
|
||||
appId: 'test-app',
|
||||
annotationConfig: defaultConfig,
|
||||
setAnnotationConfig,
|
||||
}))
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleEnableAnnotation({
|
||||
embedding_provider_name: 'openai',
|
||||
embedding_model_name: 'text-embedding-3-small',
|
||||
}, 0.95)
|
||||
})
|
||||
|
||||
expect(queryJobStatusMock).toHaveBeenCalledTimes(2)
|
||||
expect(sleepMock).toHaveBeenCalledWith(2000)
|
||||
expect(setAnnotationConfig).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
@ -93,6 +93,7 @@ const ConfigParamModal: FC<Props> = ({
|
||||
className="mt-1"
|
||||
value={(annotationConfig.score_threshold || ANNOTATION_DEFAULT.score_threshold) * 100}
|
||||
onChange={(val) => {
|
||||
/* v8 ignore next -- callback dispatch depends on react-slider drag mechanics that are flaky in jsdom. @preserve */
|
||||
setAnnotationConfig({
|
||||
...annotationConfig,
|
||||
score_threshold: val / 100,
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import type { Features } from '../../../types'
|
||||
import type { OnFeaturesChange } from '@/app/components/base/features/types'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { act, fireEvent, render, screen } from '@testing-library/react'
|
||||
import { FeaturesProvider } from '../../../context'
|
||||
import ConversationOpener from '../index'
|
||||
|
||||
@ -144,7 +144,9 @@ describe('ConversationOpener', () => {
|
||||
fireEvent.click(screen.getByText(/openingStatement\.writeOpener/))
|
||||
|
||||
const modalCall = mockSetShowOpeningModal.mock.calls[0][0]
|
||||
modalCall.onSaveCallback({ enabled: true, opening_statement: 'Updated' })
|
||||
act(() => {
|
||||
modalCall.onSaveCallback({ enabled: true, opening_statement: 'Updated' })
|
||||
})
|
||||
|
||||
expect(onChange).toHaveBeenCalled()
|
||||
})
|
||||
@ -184,4 +186,41 @@ describe('ConversationOpener', () => {
|
||||
// After leave, statement visible again
|
||||
expect(screen.getByText('Welcome!')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should return early from opener handler when disabled and hovered', () => {
|
||||
renderWithProvider({ disabled: true }, {
|
||||
opening: { enabled: true, opening_statement: 'Hello' },
|
||||
})
|
||||
|
||||
const card = screen.getByText(/feature\.conversationOpener\.title/).closest('[class]')!
|
||||
fireEvent.mouseEnter(card)
|
||||
fireEvent.click(screen.getByText(/openingStatement\.writeOpener/))
|
||||
|
||||
expect(mockSetShowOpeningModal).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should run save and cancel callbacks without onChange', () => {
|
||||
renderWithProvider({}, {
|
||||
opening: { enabled: true, opening_statement: 'Hello' },
|
||||
})
|
||||
|
||||
const card = screen.getByText(/feature\.conversationOpener\.title/).closest('[class]')!
|
||||
fireEvent.mouseEnter(card)
|
||||
fireEvent.click(screen.getByText(/openingStatement\.writeOpener/))
|
||||
|
||||
const modalCall = mockSetShowOpeningModal.mock.calls[0][0]
|
||||
act(() => {
|
||||
modalCall.onSaveCallback({ enabled: true, opening_statement: 'Updated without callback' })
|
||||
modalCall.onCancelCallback()
|
||||
})
|
||||
|
||||
expect(mockSetShowOpeningModal).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should toggle feature switch without onChange callback', () => {
|
||||
renderWithProvider()
|
||||
|
||||
fireEvent.click(screen.getByRole('switch'))
|
||||
expect(screen.getByRole('switch')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@ -31,7 +31,25 @@ vi.mock('@/app/components/app/configuration/config-prompt/confirm-add-var', () =
|
||||
}))
|
||||
|
||||
vi.mock('react-sortablejs', () => ({
|
||||
ReactSortable: ({ children }: { children: React.ReactNode }) => <div>{children}</div>,
|
||||
ReactSortable: ({
|
||||
children,
|
||||
list,
|
||||
setList,
|
||||
}: {
|
||||
children: React.ReactNode
|
||||
list: Array<{ id: number, name: string }>
|
||||
setList: (list: Array<{ id: number, name: string }>) => void
|
||||
}) => (
|
||||
<div>
|
||||
<button
|
||||
data-testid="mock-sortable-apply"
|
||||
onClick={() => setList([...list].reverse())}
|
||||
>
|
||||
Apply Sort
|
||||
</button>
|
||||
{children}
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
const defaultData: OpeningStatement = {
|
||||
@ -168,6 +186,23 @@ describe('OpeningSettingModal', () => {
|
||||
expect(onCancel).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should not call onCancel when close icon receives non-action key', async () => {
|
||||
const onCancel = vi.fn()
|
||||
await render(
|
||||
<OpeningSettingModal
|
||||
data={defaultData}
|
||||
onSave={vi.fn()}
|
||||
onCancel={onCancel}
|
||||
/>,
|
||||
)
|
||||
|
||||
const closeButton = screen.getByTestId('close-modal')
|
||||
closeButton.focus()
|
||||
fireEvent.keyDown(closeButton, { key: 'Escape' })
|
||||
|
||||
expect(onCancel).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should call onSave with updated data when save is clicked', async () => {
|
||||
const onSave = vi.fn()
|
||||
await render(
|
||||
@ -507,4 +542,73 @@ describe('OpeningSettingModal', () => {
|
||||
expect(editor.textContent?.trim()).toBe('')
|
||||
expect(screen.getByText('appDebug.openingStatement.placeholder')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render with empty suggested questions when field is missing', async () => {
|
||||
await render(
|
||||
<OpeningSettingModal
|
||||
data={{ ...defaultData, suggested_questions: undefined } as unknown as OpeningStatement}
|
||||
onSave={vi.fn()}
|
||||
onCancel={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByDisplayValue('Question 1')).not.toBeInTheDocument()
|
||||
expect(screen.queryByDisplayValue('Question 2')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render prompt variable fallback key when name is empty', async () => {
|
||||
await render(
|
||||
<OpeningSettingModal
|
||||
data={defaultData}
|
||||
onSave={vi.fn()}
|
||||
onCancel={vi.fn()}
|
||||
promptVariables={[{ key: 'account_id', name: '', type: 'string', required: true }]}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(getPromptEditor()).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should save reordered suggested questions after sortable setList', async () => {
|
||||
const onSave = vi.fn()
|
||||
await render(
|
||||
<OpeningSettingModal
|
||||
data={defaultData}
|
||||
onSave={onSave}
|
||||
onCancel={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
await userEvent.click(screen.getByTestId('mock-sortable-apply'))
|
||||
await userEvent.click(screen.getByText(/operation\.save/))
|
||||
|
||||
expect(onSave).toHaveBeenCalledWith(expect.objectContaining({
|
||||
suggested_questions: ['Question 2', 'Question 1'],
|
||||
}))
|
||||
})
|
||||
|
||||
it('should not save when confirm dialog action runs with empty opening statement', async () => {
|
||||
const onSave = vi.fn()
|
||||
const view = await render(
|
||||
<OpeningSettingModal
|
||||
data={{ ...defaultData, opening_statement: 'Hello {{name}}' }}
|
||||
onSave={onSave}
|
||||
onCancel={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
await userEvent.click(screen.getByText(/operation\.save/))
|
||||
expect(screen.getByTestId('confirm-add-var')).toBeInTheDocument()
|
||||
|
||||
view.rerender(
|
||||
<OpeningSettingModal
|
||||
data={{ ...defaultData, opening_statement: ' ' }}
|
||||
onSave={onSave}
|
||||
onCancel={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
await userEvent.click(screen.getByTestId('cancel-add'))
|
||||
expect(onSave).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
@ -34,6 +34,7 @@ const ConversationOpener = ({
|
||||
const featuresStore = useFeaturesStore()
|
||||
const [isHovering, setIsHovering] = useState(false)
|
||||
const handleOpenOpeningModal = useCallback(() => {
|
||||
/* v8 ignore next -- guarded path is not reachable in tests with a real disabled button because click is prevented at DOM level. @preserve */
|
||||
if (disabled)
|
||||
return
|
||||
const {
|
||||
|
||||
@ -64,6 +64,14 @@ describe('FileUpload', () => {
|
||||
expect(onChange).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should toggle without onChange callback', () => {
|
||||
renderWithProvider()
|
||||
|
||||
expect(() => {
|
||||
fireEvent.click(screen.getByRole('switch'))
|
||||
}).not.toThrow()
|
||||
})
|
||||
|
||||
it('should show supported types when enabled', () => {
|
||||
renderWithProvider({}, {
|
||||
file: {
|
||||
|
||||
@ -150,6 +150,17 @@ describe('SettingContent', () => {
|
||||
expect(onClose).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should not call onClose when close icon receives non-action key', () => {
|
||||
const onClose = vi.fn()
|
||||
renderWithProvider({ onClose })
|
||||
|
||||
const closeIconButton = screen.getByTestId('close-setting-modal')
|
||||
closeIconButton.focus()
|
||||
fireEvent.keyDown(closeIconButton, { key: 'Escape' })
|
||||
|
||||
expect(onClose).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should call onClose when cancel button is clicked to close', () => {
|
||||
const onClose = vi.fn()
|
||||
renderWithProvider({ onClose })
|
||||
|
||||
@ -70,6 +70,14 @@ describe('ImageUpload', () => {
|
||||
expect(onChange).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should toggle without onChange callback', () => {
|
||||
renderWithProvider()
|
||||
|
||||
expect(() => {
|
||||
fireEvent.click(screen.getByRole('switch'))
|
||||
}).not.toThrow()
|
||||
})
|
||||
|
||||
it('should show supported types when enabled', () => {
|
||||
renderWithProvider({}, {
|
||||
file: {
|
||||
|
||||
@ -3,6 +3,12 @@ import type { CodeBasedExtensionForm } from '@/models/common'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import FormGeneration from '../form-generation'
|
||||
|
||||
const { mockLocale } = vi.hoisted(() => ({ mockLocale: { value: 'en-US' } }))
|
||||
|
||||
vi.mock('@/context/i18n', () => ({
|
||||
useLocale: () => mockLocale.value,
|
||||
}))
|
||||
|
||||
const i18n = (en: string, zh = en): I18nText =>
|
||||
({ 'en-US': en, 'zh-Hans': zh }) as unknown as I18nText
|
||||
|
||||
@ -21,6 +27,7 @@ const createForm = (overrides: Partial<CodeBasedExtensionForm> = {}): CodeBasedE
|
||||
describe('FormGeneration', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockLocale.value = 'en-US'
|
||||
})
|
||||
|
||||
it('should render text-input form fields', () => {
|
||||
@ -130,4 +137,22 @@ describe('FormGeneration', () => {
|
||||
|
||||
expect(onChange).toHaveBeenCalledWith({ model: 'gpt-4' })
|
||||
})
|
||||
|
||||
it('should render zh-Hans labels for select field and options', () => {
|
||||
mockLocale.value = 'zh-Hans'
|
||||
const form = createForm({
|
||||
type: 'select',
|
||||
variable: 'model',
|
||||
label: i18n('Model', '模型'),
|
||||
options: [
|
||||
{ label: i18n('GPT-4', '智谱-4'), value: 'gpt-4' },
|
||||
{ label: i18n('GPT-3.5', '智谱-3.5'), value: 'gpt-3.5' },
|
||||
],
|
||||
})
|
||||
render(<FormGeneration forms={[form]} value={{}} onChange={vi.fn()} />)
|
||||
|
||||
expect(screen.getByText('模型')).toBeInTheDocument()
|
||||
fireEvent.click(screen.getByText(/placeholder\.select/))
|
||||
expect(screen.getByText('智谱-4')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@ -4,6 +4,10 @@ import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { FeaturesProvider } from '../../../context'
|
||||
import Moderation from '../index'
|
||||
|
||||
const { mockCodeBasedExtensionData } = vi.hoisted(() => ({
|
||||
mockCodeBasedExtensionData: [] as Array<{ name: string, label: Record<string, string> }>,
|
||||
}))
|
||||
|
||||
const mockSetShowModerationSettingModal = vi.fn()
|
||||
vi.mock('@/context/modal-context', () => ({
|
||||
useModalContext: () => ({
|
||||
@ -16,7 +20,7 @@ vi.mock('@/context/i18n', () => ({
|
||||
}))
|
||||
|
||||
vi.mock('@/service/use-common', () => ({
|
||||
useCodeBasedExtensions: () => ({ data: { data: [] } }),
|
||||
useCodeBasedExtensions: () => ({ data: { data: mockCodeBasedExtensionData } }),
|
||||
}))
|
||||
|
||||
const defaultFeatures: Features = {
|
||||
@ -46,6 +50,7 @@ const renderWithProvider = (
|
||||
describe('Moderation', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockCodeBasedExtensionData.length = 0
|
||||
})
|
||||
|
||||
it('should render the moderation title', () => {
|
||||
@ -282,6 +287,25 @@ describe('Moderation', () => {
|
||||
expect(onChange).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should invoke onCancelCallback from settings modal without onChange', () => {
|
||||
renderWithProvider({}, {
|
||||
moderation: {
|
||||
enabled: true,
|
||||
type: 'keywords',
|
||||
config: {
|
||||
inputs_config: { enabled: true, preset_response: '' },
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
const card = screen.getByText(/feature\.moderation\.title/).closest('[class]')!
|
||||
fireEvent.mouseEnter(card)
|
||||
fireEvent.click(screen.getByText(/operation\.settings/))
|
||||
|
||||
const modalCall = mockSetShowModerationSettingModal.mock.calls[0][0]
|
||||
expect(() => modalCall.onCancelCallback()).not.toThrow()
|
||||
})
|
||||
|
||||
it('should invoke onSaveCallback from settings modal', () => {
|
||||
const onChange = vi.fn()
|
||||
renderWithProvider({ onChange }, {
|
||||
@ -304,6 +328,25 @@ describe('Moderation', () => {
|
||||
expect(onChange).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should invoke onSaveCallback from settings modal without onChange', () => {
|
||||
renderWithProvider({}, {
|
||||
moderation: {
|
||||
enabled: true,
|
||||
type: 'keywords',
|
||||
config: {
|
||||
inputs_config: { enabled: true, preset_response: '' },
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
const card = screen.getByText(/feature\.moderation\.title/).closest('[class]')!
|
||||
fireEvent.mouseEnter(card)
|
||||
fireEvent.click(screen.getByText(/operation\.settings/))
|
||||
|
||||
const modalCall = mockSetShowModerationSettingModal.mock.calls[0][0]
|
||||
expect(() => modalCall.onSaveCallback({ enabled: true, type: 'keywords', config: {} })).not.toThrow()
|
||||
})
|
||||
|
||||
it('should show code-based extension label for custom type', () => {
|
||||
renderWithProvider({}, {
|
||||
moderation: {
|
||||
@ -319,6 +362,41 @@ describe('Moderation', () => {
|
||||
expect(screen.getByText('-')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show code-based extension label when custom type is configured', () => {
|
||||
mockCodeBasedExtensionData.push({
|
||||
name: 'custom-ext',
|
||||
label: { 'en-US': 'Custom Moderation', 'zh-Hans': '自定义审核' },
|
||||
})
|
||||
|
||||
renderWithProvider({}, {
|
||||
moderation: {
|
||||
enabled: true,
|
||||
type: 'custom-ext',
|
||||
config: {
|
||||
inputs_config: { enabled: true, preset_response: '' },
|
||||
outputs_config: { enabled: false, preset_response: '' },
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
expect(screen.getByText('Custom Moderation')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show enable content text when both input and output moderation are disabled', () => {
|
||||
renderWithProvider({}, {
|
||||
moderation: {
|
||||
enabled: true,
|
||||
type: 'keywords',
|
||||
config: {
|
||||
inputs_config: { enabled: false, preset_response: '' },
|
||||
outputs_config: { enabled: false, preset_response: '' },
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
expect(screen.queryByText(/feature\.moderation\.(allEnabled|inputEnabled|outputEnabled)/)).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not open setting modal when clicking settings button while disabled', () => {
|
||||
renderWithProvider({ disabled: true }, {
|
||||
moderation: {
|
||||
@ -351,6 +429,15 @@ describe('Moderation', () => {
|
||||
expect(onChange).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should invoke onSaveCallback from enable modal without onChange', () => {
|
||||
renderWithProvider()
|
||||
|
||||
fireEvent.click(screen.getByRole('switch'))
|
||||
|
||||
const modalCall = mockSetShowModerationSettingModal.mock.calls[0][0]
|
||||
expect(() => modalCall.onSaveCallback({ enabled: true, type: 'keywords', config: {} })).not.toThrow()
|
||||
})
|
||||
|
||||
it('should invoke onCancelCallback from enable modal and set enabled false', () => {
|
||||
const onChange = vi.fn()
|
||||
renderWithProvider({ onChange })
|
||||
@ -364,6 +451,31 @@ describe('Moderation', () => {
|
||||
expect(onChange).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should invoke onCancelCallback from enable modal without onChange', () => {
|
||||
renderWithProvider()
|
||||
|
||||
fireEvent.click(screen.getByRole('switch'))
|
||||
|
||||
const modalCall = mockSetShowModerationSettingModal.mock.calls[0][0]
|
||||
expect(() => modalCall.onCancelCallback()).not.toThrow()
|
||||
})
|
||||
|
||||
it('should disable moderation when toggled off without onChange', () => {
|
||||
renderWithProvider({}, {
|
||||
moderation: {
|
||||
enabled: true,
|
||||
type: 'keywords',
|
||||
config: {
|
||||
inputs_config: { enabled: true, preset_response: '' },
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
expect(() => {
|
||||
fireEvent.click(screen.getByRole('switch'))
|
||||
}).not.toThrow()
|
||||
})
|
||||
|
||||
it('should not show modal when enabling with existing type', () => {
|
||||
renderWithProvider({}, {
|
||||
moderation: {
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import type { ModerationContentConfig } from '@/models/debug'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import * as i18n from 'react-i18next'
|
||||
import ModerationContent from '../moderation-content'
|
||||
|
||||
const defaultConfig: ModerationContentConfig = {
|
||||
@ -124,4 +125,19 @@ describe('ModerationContent', () => {
|
||||
expect(screen.getByText('5')).toBeInTheDocument()
|
||||
expect(screen.getByText('100')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should fallback to empty placeholder when translation is empty', () => {
|
||||
const useTranslationSpy = vi.spyOn(i18n, 'useTranslation').mockReturnValue({
|
||||
t: (key: string) => key === 'feature.moderation.modal.content.placeholder' ? '' : key,
|
||||
i18n: { language: 'en-US' },
|
||||
} as unknown as ReturnType<typeof i18n.useTranslation>)
|
||||
|
||||
renderComponent({
|
||||
config: { enabled: true, preset_response: '' },
|
||||
showPreset: true,
|
||||
})
|
||||
|
||||
expect(screen.getByRole('textbox')).toHaveAttribute('placeholder', '')
|
||||
useTranslationSpy.mockRestore()
|
||||
})
|
||||
})
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import type { ModerationConfig } from '@/models/debug'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { act, fireEvent, render, screen } from '@testing-library/react'
|
||||
import * as i18n from 'react-i18next'
|
||||
import ModerationSettingModal from '../moderation-setting-modal'
|
||||
|
||||
const mockNotify = vi.fn()
|
||||
@ -68,6 +69,13 @@ const defaultData: ModerationConfig = {
|
||||
|
||||
describe('ModerationSettingModal', () => {
|
||||
const onSave = vi.fn()
|
||||
const renderModal = async (ui: React.ReactNode) => {
|
||||
await act(async () => {
|
||||
render(ui)
|
||||
await Promise.resolve()
|
||||
})
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockCodeBasedExtensions = { data: { data: [] } }
|
||||
@ -93,7 +101,7 @@ describe('ModerationSettingModal', () => {
|
||||
})
|
||||
|
||||
it('should render the modal title', async () => {
|
||||
await render(
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={defaultData}
|
||||
onCancel={vi.fn()}
|
||||
@ -105,7 +113,7 @@ describe('ModerationSettingModal', () => {
|
||||
})
|
||||
|
||||
it('should render provider options', async () => {
|
||||
await render(
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={defaultData}
|
||||
onCancel={vi.fn()}
|
||||
@ -120,7 +128,7 @@ describe('ModerationSettingModal', () => {
|
||||
})
|
||||
|
||||
it('should show keywords textarea when keywords type is selected', async () => {
|
||||
await render(
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={defaultData}
|
||||
onCancel={vi.fn()}
|
||||
@ -134,7 +142,7 @@ describe('ModerationSettingModal', () => {
|
||||
})
|
||||
|
||||
it('should render cancel and save buttons', async () => {
|
||||
await render(
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={defaultData}
|
||||
onCancel={vi.fn()}
|
||||
@ -148,7 +156,7 @@ describe('ModerationSettingModal', () => {
|
||||
|
||||
it('should call onCancel when cancel is clicked', async () => {
|
||||
const onCancel = vi.fn()
|
||||
await render(
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={defaultData}
|
||||
onCancel={onCancel}
|
||||
@ -161,6 +169,60 @@ describe('ModerationSettingModal', () => {
|
||||
expect(onCancel).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should call onCancel when close icon receives Enter key', async () => {
|
||||
const onCancel = vi.fn()
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={defaultData}
|
||||
onCancel={onCancel}
|
||||
onSave={onSave}
|
||||
/>,
|
||||
)
|
||||
|
||||
const closeButton = document.querySelector('div[role="button"][tabindex="0"]') as HTMLElement
|
||||
expect(closeButton).toBeInTheDocument()
|
||||
closeButton.focus()
|
||||
fireEvent.keyDown(closeButton, { key: 'Enter' })
|
||||
|
||||
expect(onCancel).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should call onCancel when close icon receives Space key', async () => {
|
||||
const onCancel = vi.fn()
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={defaultData}
|
||||
onCancel={onCancel}
|
||||
onSave={onSave}
|
||||
/>,
|
||||
)
|
||||
|
||||
const closeButton = document.querySelector('div[role="button"][tabindex="0"]') as HTMLElement
|
||||
expect(closeButton).toBeInTheDocument()
|
||||
closeButton.focus()
|
||||
fireEvent.keyDown(closeButton, { key: ' ' })
|
||||
|
||||
expect(onCancel).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should not call onCancel when close icon receives non-action key', async () => {
|
||||
const onCancel = vi.fn()
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={defaultData}
|
||||
onCancel={onCancel}
|
||||
onSave={onSave}
|
||||
/>,
|
||||
)
|
||||
|
||||
const closeButton = document.querySelector('div[role="button"][tabindex="0"]') as HTMLElement
|
||||
expect(closeButton).toBeInTheDocument()
|
||||
closeButton.focus()
|
||||
fireEvent.keyDown(closeButton, { key: 'Escape' })
|
||||
|
||||
expect(onCancel).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should show error when saving without inputs or outputs enabled', async () => {
|
||||
const data: ModerationConfig = {
|
||||
...defaultData,
|
||||
@ -170,7 +232,7 @@ describe('ModerationSettingModal', () => {
|
||||
outputs_config: { enabled: false, preset_response: '' },
|
||||
},
|
||||
}
|
||||
await render(
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={data}
|
||||
onCancel={vi.fn()}
|
||||
@ -194,7 +256,7 @@ describe('ModerationSettingModal', () => {
|
||||
outputs_config: { enabled: false, preset_response: '' },
|
||||
},
|
||||
}
|
||||
await render(
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={data}
|
||||
onCancel={vi.fn()}
|
||||
@ -218,7 +280,7 @@ describe('ModerationSettingModal', () => {
|
||||
outputs_config: { enabled: false, preset_response: '' },
|
||||
},
|
||||
}
|
||||
await render(
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={data}
|
||||
onCancel={vi.fn()}
|
||||
@ -239,7 +301,7 @@ describe('ModerationSettingModal', () => {
|
||||
})
|
||||
|
||||
it('should show api selector when api type is selected', async () => {
|
||||
await render(
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={{ ...defaultData, type: 'api', config: { inputs_config: { enabled: true, preset_response: '' } } }}
|
||||
onCancel={vi.fn()}
|
||||
@ -251,7 +313,7 @@ describe('ModerationSettingModal', () => {
|
||||
})
|
||||
|
||||
it('should switch provider type when clicked', async () => {
|
||||
await render(
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={defaultData}
|
||||
onCancel={vi.fn()}
|
||||
@ -267,7 +329,7 @@ describe('ModerationSettingModal', () => {
|
||||
})
|
||||
|
||||
it('should update keywords on textarea change', async () => {
|
||||
await render(
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={defaultData}
|
||||
onCancel={vi.fn()}
|
||||
@ -282,7 +344,7 @@ describe('ModerationSettingModal', () => {
|
||||
})
|
||||
|
||||
it('should render moderation content sections', async () => {
|
||||
await render(
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={defaultData}
|
||||
onCancel={vi.fn()}
|
||||
@ -303,7 +365,7 @@ describe('ModerationSettingModal', () => {
|
||||
outputs_config: { enabled: false, preset_response: '' },
|
||||
},
|
||||
}
|
||||
await render(
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={data}
|
||||
onCancel={vi.fn()}
|
||||
@ -327,7 +389,7 @@ describe('ModerationSettingModal', () => {
|
||||
outputs_config: { enabled: false, preset_response: '' },
|
||||
},
|
||||
}
|
||||
await render(
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={data}
|
||||
onCancel={vi.fn()}
|
||||
@ -352,7 +414,7 @@ describe('ModerationSettingModal', () => {
|
||||
outputs_config: { enabled: false, preset_response: '' },
|
||||
},
|
||||
}
|
||||
await render(
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={data}
|
||||
onCancel={vi.fn()}
|
||||
@ -380,7 +442,7 @@ describe('ModerationSettingModal', () => {
|
||||
outputs_config: { enabled: true, preset_response: '' },
|
||||
},
|
||||
}
|
||||
await render(
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={data}
|
||||
onCancel={vi.fn()}
|
||||
@ -396,7 +458,7 @@ describe('ModerationSettingModal', () => {
|
||||
})
|
||||
|
||||
it('should toggle input moderation content', async () => {
|
||||
await render(
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={defaultData}
|
||||
onCancel={vi.fn()}
|
||||
@ -413,7 +475,7 @@ describe('ModerationSettingModal', () => {
|
||||
})
|
||||
|
||||
it('should toggle output moderation content', async () => {
|
||||
await render(
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={defaultData}
|
||||
onCancel={vi.fn()}
|
||||
@ -430,7 +492,7 @@ describe('ModerationSettingModal', () => {
|
||||
})
|
||||
|
||||
it('should select api extension via api selector', async () => {
|
||||
await render(
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={{ ...defaultData, type: 'api', config: { inputs_config: { enabled: true, preset_response: '' } } }}
|
||||
onCancel={vi.fn()}
|
||||
@ -450,7 +512,7 @@ describe('ModerationSettingModal', () => {
|
||||
})
|
||||
|
||||
it('should save with openai_moderation type when configured', async () => {
|
||||
await render(
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={{
|
||||
enabled: true,
|
||||
@ -473,7 +535,7 @@ describe('ModerationSettingModal', () => {
|
||||
})
|
||||
|
||||
it('should handle keyword truncation to 100 chars per line and 100 lines', async () => {
|
||||
await render(
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={defaultData}
|
||||
onCancel={vi.fn()}
|
||||
@ -499,7 +561,7 @@ describe('ModerationSettingModal', () => {
|
||||
outputs_config: { enabled: true, preset_response: 'output blocked' },
|
||||
},
|
||||
}
|
||||
await render(
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={data}
|
||||
onCancel={vi.fn()}
|
||||
@ -518,7 +580,7 @@ describe('ModerationSettingModal', () => {
|
||||
})
|
||||
|
||||
it('should switch from keywords to api type', async () => {
|
||||
await render(
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={defaultData}
|
||||
onCancel={vi.fn()}
|
||||
@ -535,7 +597,7 @@ describe('ModerationSettingModal', () => {
|
||||
})
|
||||
|
||||
it('should handle empty lines in keywords', async () => {
|
||||
await render(
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={defaultData}
|
||||
onCancel={vi.fn()}
|
||||
@ -566,7 +628,7 @@ describe('ModerationSettingModal', () => {
|
||||
refetch: vi.fn(),
|
||||
}
|
||||
|
||||
await render(
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={{ ...defaultData, type: 'openai_moderation', config: { inputs_config: { enabled: true, preset_response: '' } } }}
|
||||
onCancel={vi.fn()}
|
||||
@ -594,7 +656,7 @@ describe('ModerationSettingModal', () => {
|
||||
refetch: vi.fn(),
|
||||
}
|
||||
|
||||
await render(
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={{ ...defaultData, type: 'openai_moderation', config: { inputs_config: { enabled: true, preset_response: '' } } }}
|
||||
onCancel={vi.fn()}
|
||||
@ -605,6 +667,10 @@ describe('ModerationSettingModal', () => {
|
||||
fireEvent.click(screen.getByText(/settings\.provider/))
|
||||
|
||||
expect(mockSetShowAccountSettingModal).toHaveBeenCalled()
|
||||
|
||||
const modalCall = mockSetShowAccountSettingModal.mock.calls[0][0]
|
||||
modalCall.onCancelCallback()
|
||||
expect(mockModelProvidersData.refetch).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should not save when OpenAI type is selected but not configured', async () => {
|
||||
@ -624,7 +690,7 @@ describe('ModerationSettingModal', () => {
|
||||
refetch: vi.fn(),
|
||||
}
|
||||
|
||||
await render(
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={{ ...defaultData, type: 'openai_moderation', config: { inputs_config: { enabled: true, preset_response: 'blocked' }, outputs_config: { enabled: false, preset_response: '' } } }}
|
||||
onCancel={vi.fn()}
|
||||
@ -650,7 +716,7 @@ describe('ModerationSettingModal', () => {
|
||||
},
|
||||
}
|
||||
|
||||
await render(
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={defaultData}
|
||||
onCancel={vi.fn()}
|
||||
@ -674,7 +740,7 @@ describe('ModerationSettingModal', () => {
|
||||
},
|
||||
}
|
||||
|
||||
await render(
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={{ ...defaultData, type: 'custom-ext', config: { inputs_config: { enabled: true, preset_response: '' } } }}
|
||||
onCancel={vi.fn()}
|
||||
@ -699,7 +765,7 @@ describe('ModerationSettingModal', () => {
|
||||
},
|
||||
}
|
||||
|
||||
await render(
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={defaultData}
|
||||
onCancel={vi.fn()}
|
||||
@ -727,7 +793,7 @@ describe('ModerationSettingModal', () => {
|
||||
},
|
||||
}
|
||||
|
||||
await render(
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={{ ...defaultData, type: 'custom-ext', config: { inputs_config: { enabled: true, preset_response: 'blocked' } } }}
|
||||
onCancel={vi.fn()}
|
||||
@ -755,7 +821,7 @@ describe('ModerationSettingModal', () => {
|
||||
},
|
||||
}
|
||||
|
||||
await render(
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={{ ...defaultData, type: 'custom-ext', config: { api_url: 'https://example.com', inputs_config: { enabled: true, preset_response: 'blocked' }, outputs_config: { enabled: false, preset_response: '' } } }}
|
||||
onCancel={vi.fn()}
|
||||
@ -773,8 +839,40 @@ describe('ModerationSettingModal', () => {
|
||||
}))
|
||||
})
|
||||
|
||||
it('should update code-based extension form value and save updated config', async () => {
|
||||
mockCodeBasedExtensions = {
|
||||
data: {
|
||||
data: [{
|
||||
name: 'custom-ext',
|
||||
label: { 'en-US': 'Custom Extension', 'zh-Hans': '自定义扩展' },
|
||||
form_schema: [
|
||||
{ variable: 'api_url', label: { 'en-US': 'API URL', 'zh-Hans': 'API 地址' }, type: 'text-input', required: true, default: '', placeholder: 'Enter URL', options: [], max_length: 200 },
|
||||
],
|
||||
}],
|
||||
},
|
||||
}
|
||||
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={{ ...defaultData, type: 'custom-ext', config: { inputs_config: { enabled: true, preset_response: 'blocked' }, outputs_config: { enabled: false, preset_response: '' } } }}
|
||||
onCancel={vi.fn()}
|
||||
onSave={onSave}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.change(screen.getByPlaceholderText('Enter URL'), { target: { value: 'https://changed.com' } })
|
||||
fireEvent.click(screen.getByText(/operation\.save/))
|
||||
|
||||
expect(onSave).toHaveBeenCalledWith(expect.objectContaining({
|
||||
type: 'custom-ext',
|
||||
config: expect.objectContaining({
|
||||
api_url: 'https://changed.com',
|
||||
}),
|
||||
}))
|
||||
})
|
||||
|
||||
it('should show doc link for api type', async () => {
|
||||
await render(
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={{ ...defaultData, type: 'api', config: { inputs_config: { enabled: true, preset_response: '' } } }}
|
||||
onCancel={vi.fn()}
|
||||
@ -784,4 +882,56 @@ describe('ModerationSettingModal', () => {
|
||||
|
||||
expect(screen.getByText(/apiBasedExtension\.link/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should fallback missing inputs_config to disabled in formatted save data', async () => {
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={{
|
||||
enabled: true,
|
||||
type: 'api',
|
||||
config: {
|
||||
api_based_extension_id: 'ext-fallback',
|
||||
outputs_config: { enabled: true, preset_response: '' },
|
||||
},
|
||||
}}
|
||||
onCancel={vi.fn()}
|
||||
onSave={onSave}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByText(/operation\.save/))
|
||||
|
||||
expect(onSave).toHaveBeenCalledWith(expect.objectContaining({
|
||||
type: 'api',
|
||||
config: expect.objectContaining({
|
||||
inputs_config: expect.objectContaining({ enabled: false }),
|
||||
outputs_config: expect.objectContaining({ enabled: true }),
|
||||
}),
|
||||
}))
|
||||
})
|
||||
|
||||
it('should fallback to empty translated strings for optional placeholders and titles', async () => {
|
||||
const useTranslationSpy = vi.spyOn(i18n, 'useTranslation').mockReturnValue({
|
||||
t: (key: string) => [
|
||||
'feature.moderation.modal.keywords.placeholder',
|
||||
'feature.moderation.modal.content.input',
|
||||
'feature.moderation.modal.content.output',
|
||||
].includes(key)
|
||||
? ''
|
||||
: key,
|
||||
i18n: { language: 'en-US' },
|
||||
} as unknown as ReturnType<typeof i18n.useTranslation>)
|
||||
|
||||
await renderModal(
|
||||
<ModerationSettingModal
|
||||
data={defaultData}
|
||||
onCancel={vi.fn()}
|
||||
onSave={onSave}
|
||||
/>,
|
||||
)
|
||||
|
||||
const textarea = screen.getAllByRole('textbox')[0]
|
||||
expect(textarea).toHaveAttribute('placeholder', '')
|
||||
useTranslationSpy.mockRestore()
|
||||
})
|
||||
})
|
||||
|
||||
@ -30,6 +30,7 @@ const Moderation = ({
|
||||
const [isHovering, setIsHovering] = useState(false)
|
||||
|
||||
const handleOpenModerationSettingModal = () => {
|
||||
/* v8 ignore next -- guarded path is not reachable in tests with a real disabled button because click is prevented at DOM level. @preserve */
|
||||
if (disabled)
|
||||
return
|
||||
|
||||
|
||||
@ -185,6 +185,7 @@ const ModerationSettingModal: FC<ModerationSettingModalProps> = ({
|
||||
}
|
||||
|
||||
const handleSave = () => {
|
||||
/* v8 ignore next -- UI-invariant guard: same condition is used in Save button disabled logic, so when true handleSave has no user-triggerable invocation path. @preserve */
|
||||
if (localeData.type === 'openai_moderation' && !isOpenAIProviderConfigured)
|
||||
return
|
||||
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import type { ReactNode } from 'react'
|
||||
import type { Features } from '../../../types'
|
||||
import type { OnFeaturesChange } from '@/app/components/base/features/types'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
@ -12,6 +13,23 @@ vi.mock('@/i18n-config/language', () => ({
|
||||
],
|
||||
}))
|
||||
|
||||
vi.mock('../voice-settings', () => ({
|
||||
default: ({
|
||||
open,
|
||||
onOpen,
|
||||
children,
|
||||
}: {
|
||||
open: boolean
|
||||
onOpen: (open: boolean) => void
|
||||
children: ReactNode
|
||||
}) => (
|
||||
<div data-testid="voice-settings" data-open={open ? 'true' : 'false'}>
|
||||
<button data-testid="open-voice-settings" onClick={() => onOpen(true)}>open-voice-settings</button>
|
||||
{children}
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
const defaultFeatures: Features = {
|
||||
moreLikeThis: { enabled: false },
|
||||
opening: { enabled: false },
|
||||
@ -68,6 +86,12 @@ describe('TextToSpeech', () => {
|
||||
expect(onChange).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should toggle without onChange callback', () => {
|
||||
renderWithProvider()
|
||||
fireEvent.click(screen.getByRole('switch'))
|
||||
expect(screen.getByRole('switch')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show language and voice info when enabled and not hovering', () => {
|
||||
renderWithProvider({}, {
|
||||
text2speech: { enabled: true, language: 'en-US', voice: 'alloy' },
|
||||
@ -97,6 +121,19 @@ describe('TextToSpeech', () => {
|
||||
expect(screen.getByText(/voice\.voiceSettings\.title/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should hide voice settings button after mouse leave', () => {
|
||||
renderWithProvider({}, {
|
||||
text2speech: { enabled: true },
|
||||
})
|
||||
|
||||
const card = screen.getByText(/feature\.textToSpeech\.title/).closest('[class]')!
|
||||
fireEvent.mouseEnter(card)
|
||||
expect(screen.getByText(/voice\.voiceSettings\.title/)).toBeInTheDocument()
|
||||
|
||||
fireEvent.mouseLeave(card)
|
||||
expect(screen.queryByText(/voice\.voiceSettings\.title/)).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show autoPlay enabled text when autoPlay is enabled', () => {
|
||||
renderWithProvider({}, {
|
||||
text2speech: { enabled: true, language: 'en-US', autoPlay: TtsAutoPlay.enabled },
|
||||
@ -112,4 +149,16 @@ describe('TextToSpeech', () => {
|
||||
|
||||
expect(screen.getByText(/voice\.voiceSettings\.autoPlayDisabled/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should pass open false to voice settings when disabled and modal is opened', () => {
|
||||
renderWithProvider({ disabled: true }, {
|
||||
text2speech: { enabled: true },
|
||||
})
|
||||
|
||||
const card = screen.getByText(/feature\.textToSpeech\.title/).closest('[class]')!
|
||||
fireEvent.mouseEnter(card)
|
||||
fireEvent.click(screen.getByTestId('open-voice-settings'))
|
||||
|
||||
expect(screen.getByTestId('voice-settings')).toHaveAttribute('data-open', 'false')
|
||||
})
|
||||
})
|
||||
|
||||
@ -3,6 +3,38 @@ import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { FeaturesProvider } from '../../../context'
|
||||
import VoiceSettings from '../voice-settings'
|
||||
|
||||
vi.mock('@/app/components/base/portal-to-follow-elem', () => ({
|
||||
PortalToFollowElem: ({
|
||||
children,
|
||||
placement,
|
||||
offset,
|
||||
}: {
|
||||
children: React.ReactNode
|
||||
placement?: string
|
||||
offset?: { mainAxis?: number }
|
||||
}) => (
|
||||
<div
|
||||
data-testid="voice-settings-portal"
|
||||
data-placement={placement}
|
||||
data-main-axis={offset?.mainAxis}
|
||||
>
|
||||
{children}
|
||||
</div>
|
||||
),
|
||||
PortalToFollowElemTrigger: ({
|
||||
children,
|
||||
onClick,
|
||||
}: {
|
||||
children: React.ReactNode
|
||||
onClick?: () => void
|
||||
}) => (
|
||||
<div data-testid="voice-settings-trigger" onClick={onClick}>
|
||||
{children}
|
||||
</div>
|
||||
),
|
||||
PortalToFollowElemContent: ({ children }: { children: React.ReactNode }) => <div>{children}</div>,
|
||||
}))
|
||||
|
||||
vi.mock('next/navigation', () => ({
|
||||
usePathname: () => '/app/test-app-id/configuration',
|
||||
useParams: () => ({ appId: 'test-app-id' }),
|
||||
@ -102,4 +134,19 @@ describe('VoiceSettings', () => {
|
||||
|
||||
expect(onOpen).toHaveBeenCalledWith(false)
|
||||
})
|
||||
|
||||
it('should use top placement and mainAxis 4 when placementLeft is false', () => {
|
||||
renderWithProvider(
|
||||
<VoiceSettings open={false} onOpen={vi.fn()} placementLeft={false}>
|
||||
<button>Settings</button>
|
||||
</VoiceSettings>,
|
||||
)
|
||||
|
||||
const portal = screen.getAllByTestId('voice-settings-portal')
|
||||
.find(item => item.hasAttribute('data-main-axis'))
|
||||
|
||||
expect(portal).toBeDefined()
|
||||
expect(portal).toHaveAttribute('data-placement', 'top')
|
||||
expect(portal).toHaveAttribute('data-main-axis', '4')
|
||||
})
|
||||
})
|
||||
|
||||
@ -25,6 +25,11 @@ describe('createFileStore', () => {
|
||||
expect(store.getState().files).toEqual([])
|
||||
})
|
||||
|
||||
it('should create a store with empty array when value is null', () => {
|
||||
const store = createFileStore(null as unknown as FileEntity[])
|
||||
expect(store.getState().files).toEqual([])
|
||||
})
|
||||
|
||||
it('should create a store with initial files', () => {
|
||||
const files = [createMockFile()]
|
||||
const store = createFileStore(files)
|
||||
@ -96,6 +101,11 @@ describe('useFileStore', () => {
|
||||
|
||||
expect(result.current).toBe(store)
|
||||
})
|
||||
|
||||
it('should return null when no provider exists', () => {
|
||||
const { result } = renderHook(() => useFileStore())
|
||||
expect(result.current).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('FileContextProvider', () => {
|
||||
|
||||
@ -126,13 +126,11 @@ describe('FileFromLinkOrLocal', () => {
|
||||
expect(input).toBeDisabled()
|
||||
})
|
||||
|
||||
it('should not submit when url is empty', () => {
|
||||
it('should have disabled OK button when url is empty', () => {
|
||||
renderAndOpen({ showFromLink: true })
|
||||
|
||||
const okButton = screen.getByText(/operation\.ok/)
|
||||
fireEvent.click(okButton)
|
||||
|
||||
expect(screen.queryByText(/fileUploader\.pasteFileLinkInvalid/)).not.toBeInTheDocument()
|
||||
const okButton = screen.getByRole('button', { name: /operation\.ok/ })
|
||||
expect(okButton).toBeDisabled()
|
||||
})
|
||||
|
||||
it('should call handleLoadFileFromLink when valid URL is submitted', () => {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user