mirror of
https://github.com/langgenius/dify.git
synced 2026-05-02 16:38:04 +08:00
Merge branch 'main' into feat/agent-node-v2
This commit is contained in:
@ -287,7 +287,7 @@ def test_validate_inputs_optional_file_with_empty_string():
|
||||
|
||||
|
||||
def test_validate_inputs_optional_file_list_with_empty_list():
|
||||
"""Test that optional FILE_LIST variable with empty list returns None"""
|
||||
"""Test that optional FILE_LIST variable with empty list returns empty list (not None)"""
|
||||
base_app_generator = BaseAppGenerator()
|
||||
|
||||
var_file_list = VariableEntity(
|
||||
@ -302,6 +302,28 @@ def test_validate_inputs_optional_file_list_with_empty_list():
|
||||
value=[],
|
||||
)
|
||||
|
||||
# Empty list should be preserved, not converted to None
|
||||
# This allows downstream components like document_extractor to handle empty lists properly
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_validate_inputs_optional_file_list_with_empty_string():
    """An empty string given for an optional FILE_LIST variable validates to ``None``."""
    optional_file_list = VariableEntity(
        variable="test_file_list",
        label="test_file_list",
        type=VariableEntityType.FILE_LIST,
        required=False,
    )
    generator = BaseAppGenerator()

    validated = generator._validate_inputs(variable_entity=optional_file_list, value="")

    # "" counts as "not provided" for this optional variable, so None is expected back.
    assert validated is None
|
||||
|
||||
|
||||
|
||||
@ -0,0 +1,420 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import ANY, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAgentMessageEvent,
|
||||
QueueErrorEvent,
|
||||
QueueLLMChunkEvent,
|
||||
QueueMessageEndEvent,
|
||||
QueueMessageFileEvent,
|
||||
QueuePingEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
EasyUITaskState,
|
||||
ErrorStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
MessageFileStreamResponse,
|
||||
MessageReplaceStreamResponse,
|
||||
MessageStreamResponse,
|
||||
PingStreamResponse,
|
||||
StreamEvent,
|
||||
)
|
||||
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
|
||||
from core.base.tts import AppGeneratorTTSPublisher
|
||||
from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult
|
||||
from core.model_runtime.entities.message_entities import TextPromptMessageContent
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class TestEasyUIBasedGenerateTaskPipelineProcessStreamResponse:
    """Tests for ``EasyUIBasedGenerateTaskPipeline._process_stream_response``.

    Every test feeds one or more mocked queue messages through the pipeline and
    checks which collaborator was asked to build the resulting stream response.
    """

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _wrap(event):
        """Return a mock queue message whose ``event`` attribute is *event*."""
        return Mock(event=event)

    @staticmethod
    def _llm_chunk_event(content):
        """Build a mocked ``QueueLLMChunkEvent`` carrying *content* and no prompt messages."""
        chunk = Mock()
        chunk.delta.message.content = content
        chunk.prompt_messages = []
        event = Mock(spec=QueueLLMChunkEvent)
        event.chunk = chunk
        return event

    @staticmethod
    def _db_patch():
        """Return a patcher swapping the pipeline module's ``db`` for a stub with a mocked ``engine``.

        The pipeline uses ``db.engine`` for session creation on the message-end
        and error paths, so those tests run inside this patch.
        """
        return patch(
            "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock())
        )

    @staticmethod
    def _run(pipeline, publisher=None, trace_manager=None):
        """Exhaust ``_process_stream_response`` and return everything it yielded as a list."""
        return list(pipeline._process_stream_response(publisher=publisher, trace_manager=trace_manager))

    # ----------------------------------------------------------------- fixtures

    @pytest.fixture
    def mock_application_generate_entity(self):
        """Provide a ``ChatAppGenerateEntity``-spec'd mock with minimal config attached."""
        generate_entity = Mock(spec=ChatAppGenerateEntity)
        generate_entity.task_id = "test-task-id"
        generate_entity.app_id = "test-app-id"
        # Only the app_config fields the pipeline internals touch.
        generate_entity.app_config = SimpleNamespace(
            tenant_id="test-tenant-id",
            app_id="test-app-id",
            app_mode=AppMode.CHAT,
            app_model_config_dict={},
            additional_features=None,
            sensitive_word_avoidance=None,
        )
        # Only the model_conf fields needed to initialise the LLM result.
        generate_entity.model_conf = SimpleNamespace(
            model="test-model",
            provider_model_bundle=SimpleNamespace(model_type_instance=Mock()),
            credentials={},
        )
        return generate_entity

    @pytest.fixture
    def mock_queue_manager(self):
        """Provide an ``AppQueueManager``-spec'd mock."""
        return Mock(spec=AppQueueManager)

    @pytest.fixture
    def mock_message_cycle_manager(self):
        """Provide a message cycle manager mock with canned stream responses."""
        cycle_manager = Mock()
        cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE
        cycle_manager.message_to_stream_response.return_value = Mock(spec=MessageStreamResponse)
        cycle_manager.message_file_to_stream_response.return_value = Mock(spec=MessageFileStreamResponse)
        cycle_manager.message_replace_to_stream_response.return_value = Mock(spec=MessageReplaceStreamResponse)
        cycle_manager.handle_retriever_resources = Mock()
        cycle_manager.handle_annotation_reply.return_value = None
        return cycle_manager

    @pytest.fixture
    def mock_conversation(self):
        """Provide a conversation mock in chat mode."""
        return Mock(id="test-conversation-id", mode="chat")

    @pytest.fixture
    def mock_message(self):
        """Provide a message mock with a fixed id and creation timestamp."""
        created_at = Mock()
        created_at.timestamp.return_value = 1234567890
        return Mock(id="test-message-id", created_at=created_at)

    @pytest.fixture
    def mock_task_state(self):
        """Provide an ``EasyUITaskState``-spec'd mock holding an empty LLM result."""
        empty_result = Mock(spec=RuntimeLLMResult)
        empty_result.prompt_messages = []
        empty_result.message = Mock()
        empty_result.message.content = ""

        state = Mock(spec=EasyUITaskState)
        state.llm_result = empty_result
        state.answer = ""
        return state

    @pytest.fixture
    def pipeline(
        self,
        mock_application_generate_entity,
        mock_queue_manager,
        mock_conversation,
        mock_message,
        mock_message_cycle_manager,
        mock_task_state,
    ):
        """Build the pipeline under test with every collaborator replaced by a mock."""
        # EasyUITaskState is patched only while the constructor runs.
        with patch(
            "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.EasyUITaskState", return_value=mock_task_state
        ):
            instance = EasyUIBasedGenerateTaskPipeline(
                application_generate_entity=mock_application_generate_entity,
                queue_manager=mock_queue_manager,
                conversation=mock_conversation,
                message=mock_message,
                stream=True,
            )
        # Overwrite the collaborators so the tests can assert on them directly.
        instance._message_cycle_manager = mock_message_cycle_manager
        instance._task_state = mock_task_state
        return instance

    # -------------------------------------------------------------------- tests

    def test_get_message_event_type_called_once_when_first_llm_chunk_arrives(
        self, pipeline, mock_message_cycle_manager
    ):
        """A single LLM chunk triggers exactly one event-type lookup for the message id."""
        pipeline.queue_manager.listen.return_value = [self._wrap(self._llm_chunk_event("hi"))]

        self._run(pipeline)

        mock_message_cycle_manager.get_message_event_type.assert_called_once_with(message_id="test-message-id")

    def test_llm_chunk_event_with_text_content(self, pipeline, mock_message_cycle_manager, mock_task_state):
        """A plain-string chunk is forwarded as the answer and stored on the task state."""
        pipeline.queue_manager.listen.return_value = [self._wrap(self._llm_chunk_event("Hello, world!"))]
        mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE

        yielded = self._run(pipeline)

        assert len(yielded) == 1
        mock_message_cycle_manager.message_to_stream_response.assert_called_once_with(
            answer="Hello, world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE
        )
        assert mock_task_state.llm_result.message.content == "Hello, world!"

    def test_llm_chunk_event_with_list_content(self, pipeline, mock_message_cycle_manager, mock_task_state):
        """List content mixing a text part and a bare string is joined into one answer."""
        text_part = Mock(spec=TextPromptMessageContent)
        text_part.data = "Hello"
        pipeline.queue_manager.listen.return_value = [
            self._wrap(self._llm_chunk_event([text_part, " world!"])),
        ]
        mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE

        yielded = self._run(pipeline)

        assert len(yielded) == 1
        mock_message_cycle_manager.message_to_stream_response.assert_called_once_with(
            answer="Hello world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE
        )
        assert mock_task_state.llm_result.message.content == "Hello world!"

    def test_agent_message_event(self, pipeline, mock_message_cycle_manager, mock_task_state):
        """Agent message events are routed through ``_agent_message_to_stream_response``."""
        agent_chunk = Mock()
        agent_chunk.delta.message.content = "Agent response"
        agent_event = Mock(spec=QueueAgentMessageEvent)
        agent_event.chunk = agent_chunk
        pipeline.queue_manager.listen.return_value = [self._wrap(agent_event)]

        # Swap the converter for a mock so the call can be observed.
        pipeline._agent_message_to_stream_response = Mock(return_value=Mock())

        yielded = self._run(pipeline)

        assert len(yielded) == 1
        pipeline._agent_message_to_stream_response.assert_called_once_with(
            answer="Agent response", message_id="test-message-id"
        )

    def test_message_end_event(self, pipeline, mock_message_cycle_manager, mock_task_state):
        """A message-end event stores its LLM result, saves the message and emits one end response."""
        final_result = Mock(spec=RuntimeLLMResult)
        final_result.message = Mock()
        final_result.message.content = "Final response"
        end_event = Mock(spec=QueueMessageEndEvent)
        end_event.llm_result = final_result
        pipeline.queue_manager.listen.return_value = [self._wrap(end_event)]

        pipeline._save_message = Mock()
        pipeline._message_end_to_stream_response = Mock(return_value=Mock(spec=MessageEndStreamResponse))

        with self._db_patch():
            yielded = self._run(pipeline)

        assert len(yielded) == 1
        assert mock_task_state.llm_result == final_result
        pipeline._save_message.assert_called_once()
        pipeline._message_end_to_stream_response.assert_called_once()

    def test_error_event(self, pipeline):
        """An error event is passed to ``handle_error`` and converted into one error response."""
        failure_event = Mock(spec=QueueErrorEvent)
        failure_event.error = Exception("Test error")
        pipeline.queue_manager.listen.return_value = [self._wrap(failure_event)]

        pipeline.handle_error = Mock(return_value=Exception("Test error"))
        pipeline.error_to_stream_response = Mock(return_value=Mock(spec=ErrorStreamResponse))

        with self._db_patch():
            yielded = self._run(pipeline)

        assert len(yielded) == 1
        pipeline.handle_error.assert_called_once()
        pipeline.error_to_stream_response.assert_called_once()

    def test_ping_event(self, pipeline):
        """A ping event yields exactly one ping response."""
        pipeline.queue_manager.listen.return_value = [self._wrap(Mock(spec=QueuePingEvent))]
        pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse))

        yielded = self._run(pipeline)

        assert len(yielded) == 1
        pipeline.ping_stream_response.assert_called_once()

    def test_file_event(self, pipeline, mock_message_cycle_manager):
        """A message-file event yields the response built by the message cycle manager."""
        attachment_event = Mock(spec=QueueMessageFileEvent)
        attachment_event.message_file_id = "file-id"
        pipeline.queue_manager.listen.return_value = [self._wrap(attachment_event)]

        expected_response = Mock(spec=MessageFileStreamResponse)
        mock_message_cycle_manager.message_file_to_stream_response.return_value = expected_response

        yielded = self._run(pipeline)

        assert len(yielded) == 1
        assert yielded[0] == expected_response
        mock_message_cycle_manager.message_file_to_stream_response.assert_called_once_with(attachment_event)

    def test_publisher_is_called_with_messages(self, pipeline):
        """A supplied publisher receives each queue message and then a final ``None``."""
        tts_publisher = Mock(spec=AppGeneratorTTSPublisher)
        queued = self._wrap(Mock(spec=QueuePingEvent))
        pipeline.queue_manager.listen.return_value = [queued]
        pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse))

        self._run(pipeline, publisher=tts_publisher)

        # One publish for the message itself, one with None once the stream is done.
        assert tts_publisher.publish.call_count == 2
        tts_publisher.publish.assert_any_call(queued)
        tts_publisher.publish.assert_any_call(None)

    def test_trace_manager_passed_to_save_message(self, pipeline):
        """The trace manager given to the pipeline is handed on to ``_save_message``."""
        tracer = Mock(spec=TraceQueueManager)
        end_event = Mock(spec=QueueMessageEndEvent)
        end_event.llm_result = None
        pipeline.queue_manager.listen.return_value = [self._wrap(end_event)]

        pipeline._save_message = Mock()
        pipeline._message_end_to_stream_response = Mock(return_value=Mock(spec=MessageEndStreamResponse))

        with self._db_patch():
            self._run(pipeline, trace_manager=tracer)

        pipeline._save_message.assert_called_once_with(session=ANY, trace_manager=tracer)

    def test_multiple_events_sequence(self, pipeline, mock_message_cycle_manager, mock_task_state):
        """Two LLM chunks separated by a ping yield three responses and accumulate the text."""
        first_chunk_event = self._llm_chunk_event("Hello")
        ping_event = Mock(spec=QueuePingEvent)
        second_chunk_event = self._llm_chunk_event(" world!")
        pipeline.queue_manager.listen.return_value = [
            self._wrap(first_chunk_event),
            self._wrap(ping_event),
            self._wrap(second_chunk_event),
        ]

        mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE
        pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse))

        yielded = self._run(pipeline)

        assert len(yielded) == 3
        assert mock_task_state.llm_result.message.content == "Hello world!"

        # Each chunk must have produced its own message response.
        to_stream_response = mock_message_cycle_manager.message_to_stream_response
        assert to_stream_response.call_count == 2
        to_stream_response.assert_any_call(
            answer="Hello", message_id="test-message-id", event_type=StreamEvent.MESSAGE
        )
        to_stream_response.assert_any_call(
            answer=" world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE
        )
|
||||
@ -0,0 +1,166 @@
|
||||
"""Unit tests for the message cycle manager optimization."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import ANY, Mock, patch
|
||||
|
||||
import pytest
|
||||
from flask import current_app
|
||||
|
||||
from core.app.entities.task_entities import MessageStreamResponse, StreamEvent
|
||||
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
|
||||
|
||||
|
||||
class TestMessageCycleManagerOptimization:
    """Tests for the event-type precomputation that avoids a database query per streamed chunk."""

    @staticmethod
    def _session_and_db_patches():
        """Return patchers for the manager module's ``Session`` class and ``db`` object, in that order."""
        return (
            patch("core.app.task_pipeline.message_cycle_manager.Session"),
            patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
        )

    @staticmethod
    def _install_session(session_cls, scalar_result):
        """Make ``session_cls()`` act as a context manager whose query ``.scalar()`` returns *scalar_result*."""
        session = Mock()
        session_cls.return_value.__enter__.return_value = session
        # The current implementation reads the result via session.query(...).scalar().
        session.query.return_value.scalar.return_value = scalar_result
        return session

    @pytest.fixture
    def mock_application_generate_entity(self):
        """Provide a bare generate-entity mock with a task id."""
        generate_entity = Mock()
        generate_entity.task_id = "test-task-id"
        return generate_entity

    @pytest.fixture
    def message_cycle_manager(self, mock_application_generate_entity):
        """Provide a ``MessageCycleManager`` built on mocked entity and task state."""
        return MessageCycleManager(
            application_generate_entity=mock_application_generate_entity,
            task_state=Mock(),
        )

    def test_get_message_event_type_with_message_file(self, message_cycle_manager):
        """``get_message_event_type`` yields MESSAGE_FILE when the query finds a message file."""
        session_patch, db_patch = self._session_and_db_patches()
        with session_patch as session_cls, db_patch:
            session = self._install_session(session_cls, Mock())

            with current_app.app_context():
                event_type = message_cycle_manager.get_message_event_type("test-message-id")

            assert event_type == StreamEvent.MESSAGE_FILE
            session.query.return_value.scalar.assert_called_once()

    def test_get_message_event_type_without_message_file(self, message_cycle_manager):
        """``get_message_event_type`` yields MESSAGE when the query finds nothing."""
        session_patch, db_patch = self._session_and_db_patches()
        with session_patch as session_cls, db_patch:
            session = self._install_session(session_cls, None)

            with current_app.app_context():
                event_type = message_cycle_manager.get_message_event_type("test-message-id")

            assert event_type == StreamEvent.MESSAGE
            session.query.return_value.scalar.assert_called_once()

    def test_message_to_stream_response_with_precomputed_event_type(self, message_cycle_manager):
        """MessageCycleManager.message_to_stream_response expects a valid event_type; callers should precompute it."""
        session_patch, db_patch = self._session_and_db_patches()
        with session_patch as session_cls, db_patch:
            session = self._install_session(session_cls, Mock())

            # Look the event type up once, then hand it to the response builder.
            with current_app.app_context():
                precomputed = message_cycle_manager.get_message_event_type("test-message-id")
                response = message_cycle_manager.message_to_stream_response(
                    answer="Hello world", message_id="test-message-id", event_type=precomputed
                )

            assert isinstance(response, MessageStreamResponse)
            assert response.answer == "Hello world"
            assert response.id == "test-message-id"
            assert response.event == StreamEvent.MESSAGE_FILE
            session.query.return_value.scalar.assert_called_once()

    def test_message_to_stream_response_with_event_type_skips_query(self, message_cycle_manager):
        """No session is opened when ``event_type`` is passed to ``message_to_stream_response``."""
        with patch("core.app.task_pipeline.message_cycle_manager.Session") as session_cls:
            response = message_cycle_manager.message_to_stream_response(
                answer="Hello world", message_id="test-message-id", event_type=StreamEvent.MESSAGE
            )

            assert isinstance(response, MessageStreamResponse)
            assert response.answer == "Hello world"
            assert response.id == "test-message-id"
            assert response.event == StreamEvent.MESSAGE
            # Supplying the event type must make the database lookup unnecessary.
            session_cls.assert_not_called()

    def test_message_to_stream_response_with_from_variable_selector(self, message_cycle_manager):
        """``from_variable_selector`` is carried through onto the stream response."""
        response = message_cycle_manager.message_to_stream_response(
            answer="Hello world",
            message_id="test-message-id",
            from_variable_selector=["var1", "var2"],
            event_type=StreamEvent.MESSAGE,
        )

        assert isinstance(response, MessageStreamResponse)
        assert response.answer == "Hello world"
        assert response.id == "test-message-id"
        assert response.from_variable_selector == ["var1", "var2"]
        assert response.event == StreamEvent.MESSAGE

    def test_optimization_usage_example(self, message_cycle_manager):
        """Demonstrate the intended caller pattern: look up the event type once, reuse it per chunk."""
        # Step 1: one lookup, which opens a session.
        session_patch, db_patch = self._session_and_db_patches()
        with session_patch as session_cls, db_patch:
            self._install_session(session_cls, None)  # No files
            with current_app.app_context():
                event_type = message_cycle_manager.get_message_event_type("test-message-id")

            session_cls.assert_called_once_with(ANY, expire_on_commit=False)
            assert event_type == StreamEvent.MESSAGE

        # Step 2: reuse the event type for several chunks without opening another session.
        with patch("core.app.task_pipeline.message_cycle_manager.Session") as session_cls:
            session_cls.return_value.__enter__.return_value = Mock()

            first_response = message_cycle_manager.message_to_stream_response(
                answer="Chunk 1", message_id="test-message-id", event_type=event_type
            )
            second_response = message_cycle_manager.message_to_stream_response(
                answer="Chunk 2", message_id="test-message-id", event_type=event_type
            )

            session_cls.assert_not_called()

            assert first_response.event == StreamEvent.MESSAGE
            assert second_response.event == StreamEvent.MESSAGE
            assert first_response.answer == "Chunk 1"
            assert second_response.answer == "Chunk 2"
|
||||
@ -96,7 +96,7 @@ class TestNotionExtractorAuthentication:
|
||||
def test_init_with_integration_token_fallback(self, mock_get_token, mock_config, mock_document_model):
|
||||
"""Test NotionExtractor falls back to integration token when credential not found."""
|
||||
# Arrange
|
||||
mock_get_token.return_value = None
|
||||
mock_get_token.side_effect = Exception("No credential id found")
|
||||
mock_config.NOTION_INTEGRATION_TOKEN = "integration-token-fallback"
|
||||
|
||||
# Act
|
||||
@ -105,7 +105,7 @@ class TestNotionExtractorAuthentication:
|
||||
notion_obj_id="page-456",
|
||||
notion_page_type="page",
|
||||
tenant_id="tenant-789",
|
||||
credential_id="cred-123",
|
||||
credential_id=None,
|
||||
document_model=mock_document_model,
|
||||
)
|
||||
|
||||
@ -117,7 +117,7 @@ class TestNotionExtractorAuthentication:
|
||||
def test_init_missing_credentials_raises_error(self, mock_get_token, mock_config, mock_document_model):
|
||||
"""Test NotionExtractor raises error when no credentials available."""
|
||||
# Arrange
|
||||
mock_get_token.return_value = None
|
||||
mock_get_token.side_effect = Exception("No credential id found")
|
||||
mock_config.NOTION_INTEGRATION_TOKEN = None
|
||||
|
||||
# Act & Assert
|
||||
@ -127,7 +127,7 @@ class TestNotionExtractorAuthentication:
|
||||
notion_obj_id="page-456",
|
||||
notion_page_type="page",
|
||||
tenant_id="tenant-789",
|
||||
credential_id="cred-123",
|
||||
credential_id=None,
|
||||
document_model=mock_document_model,
|
||||
)
|
||||
assert "Must specify `integration_token`" in str(exc_info.value)
|
||||
|
||||
@ -1,52 +1,109 @@
|
||||
import secrets
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, make_request
|
||||
from core.helper.ssrf_proxy import (
|
||||
SSRF_DEFAULT_MAX_RETRIES,
|
||||
_get_user_provided_host_header,
|
||||
make_request,
|
||||
)
|
||||
|
||||
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_successful_request(mock_get_client):
    """A 200 response from the SSRF client is returned unchanged by ``make_request``.

    The client factory is patched, matching the other tests in this module,
    so no real HTTP client is created.
    """
    mock_response = MagicMock()
    mock_response.status_code = 200

    # Stub both client entry points so the test does not depend on which one
    # make_request uses to dispatch the call.
    mock_client = MagicMock()
    mock_client.send.return_value = mock_response
    mock_client.request.return_value = mock_response
    mock_get_client.return_value = mock_client

    response = make_request("GET", "http://example.com")
    assert response.status_code == 200
|
||||
|
||||
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_retry_exceed_max_retries(mock_get_client):
    """``make_request`` raises once every attempt has returned a 500 response.

    The expected error message names the retry limit and the URL.
    """
    mock_response = MagicMock()
    mock_response.status_code = 500

    # Every call returns the same failing response, whichever client entry
    # point make_request uses.
    mock_client = MagicMock()
    mock_client.send.return_value = mock_response
    mock_client.request.return_value = mock_response
    mock_get_client.return_value = mock_client

    with pytest.raises(Exception) as e:
        make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES - 1)
    assert str(e.value) == f"Reached maximum retries ({SSRF_DEFAULT_MAX_RETRIES - 1}) for URL http://example.com"
|
||||
|
||||
|
||||
@patch("httpx.Client.request")
|
||||
def test_retry_logic_success(mock_request):
|
||||
side_effects = []
|
||||
class TestGetUserProvidedHostHeader:
    """Tests for _get_user_provided_host_header function."""

    def test_returns_none_when_headers_is_none(self):
        """``None`` headers produce ``None``."""
        assert _get_user_provided_host_header(None) is None

    def test_returns_none_when_headers_is_empty(self):
        """An empty mapping produces ``None``."""
        assert _get_user_provided_host_header({}) is None

    def test_returns_none_when_host_header_not_present(self):
        """Headers without any Host entry produce ``None``."""
        headers = {"Content-Type": "application/json", "Authorization": "Bearer token"}
        assert _get_user_provided_host_header(headers) is None

    def test_returns_host_header_lowercase(self):
        """A lower-case ``host`` key is recognised."""
        headers = {"host": "example.com"}
        assert _get_user_provided_host_header(headers) == "example.com"

    def test_returns_host_header_uppercase(self):
        """An upper-case ``HOST`` key is recognised."""
        headers = {"HOST": "example.com"}
        assert _get_user_provided_host_header(headers) == "example.com"

    def test_returns_host_header_mixed_case(self):
        """A mixed-case ``HoSt`` key is recognised."""
        headers = {"HoSt": "example.com"}
        assert _get_user_provided_host_header(headers) == "example.com"

    def test_returns_host_header_from_multiple_headers(self):
        """The Host value is picked out from among unrelated headers."""
        headers = {"Content-Type": "application/json", "Host": "api.example.com", "Authorization": "Bearer token"}
        assert _get_user_provided_host_header(headers) == "api.example.com"

    def test_returns_first_host_header_when_duplicates(self):
        """With two differently-cased Host keys, one of the supplied values is returned."""
        headers = {"host": "first.com", "Host": "second.com"}
        # Deliberately permissive: either value is accepted, so this test does
        # not pin which of the duplicate keys wins.
        result = _get_user_provided_host_header(headers)
        assert result in ("first.com", "second.com")
|
||||
|
||||
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_host_header_preservation_without_user_header(mock_get_client):
    """Test that when no Host header is provided, the default behavior is maintained."""
    ok_response = MagicMock()
    ok_response.status_code = 200

    # The patched client factory hands back a client whose send/request both succeed.
    client = MagicMock()
    client.send.return_value = ok_response
    client.request.return_value = ok_response
    mock_get_client.return_value = client

    # NOTE(review): this request stub is never handed to the client or to make_request
    # in this test, so the Host assertion below only inspects this local dict.
    request_stub = MagicMock()
    request_stub.headers = {}

    response = make_request("GET", "http://example.com")

    assert response.status_code == 200
    # Host should not be set if not provided by user
    host_is_absent = "Host" not in request_stub.headers
    assert host_is_absent or request_stub.headers.get("Host") is None
|
||||
|
||||
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_host_header_preservation_with_user_header(mock_get_client):
    """Test that user-provided Host header is preserved in the request."""
    ok_response = MagicMock()
    ok_response.status_code = 200

    # The patched client factory hands back a client whose send/request both succeed.
    client = MagicMock()
    client.send.return_value = ok_response
    client.request.return_value = ok_response
    mock_get_client.return_value = client

    # NOTE(review): this stub is created but not connected to the client; the test
    # currently checks only the response status, not the header that was sent.
    request_stub = MagicMock()
    request_stub.headers = {}

    custom_host = "custom.example.com:8080"
    user_headers = {"Host": custom_host}
    response = make_request("GET", "http://example.com", headers=user_headers)

    assert response.status_code == 200
|
||||
assert mock_request.call_count == SSRF_DEFAULT_MAX_RETRIES + 1
|
||||
assert mock_request.call_args_list[0][1].get("method") == "GET"
|
||||
|
||||
@ -1,129 +0,0 @@
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from core.helper.tool_provider_cache import ToolProviderListCache
|
||||
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
|
||||
|
||||
|
||||
@pytest.fixture
def mock_redis_client():
    """Yield a mock that replaces ``redis_client`` in ``core.helper.tool_provider_cache``.

    The patch is active for the duration of the test and undone afterwards.
    """
    redis_patcher = patch("core.helper.tool_provider_cache.redis_client")
    fake_redis = redis_patcher.start()
    try:
        yield fake_redis
    finally:
        redis_patcher.stop()
|
||||
|
||||
|
||||
class TestToolProviderListCache:
    """Tests for ToolProviderListCache: key generation, reads, writes, invalidation and Redis errors."""

    def test_generate_cache_key(self):
        """Cache keys embed the tenant id and provider type; an omitted type becomes "all"."""
        tenant_id = "tenant_123"

        # Explicit provider type (a valid literal value)
        typ: ToolProviderTypeApiLiteral = "builtin"
        key_for_type = ToolProviderListCache._generate_cache_key(tenant_id, typ)
        assert key_for_type == f"tool_providers:tenant_id:{tenant_id}:type:{typ}"

        # Provider type omitted
        key_for_all = ToolProviderListCache._generate_cache_key(tenant_id)
        assert key_for_all == f"tool_providers:tenant_id:{tenant_id}:type:all"

    def test_get_cached_providers_hit(self, mock_redis_client):
        """A stored UTF-8 JSON payload is decoded and returned on a cache hit."""
        tenant_id = "tenant_123"
        typ: ToolProviderTypeApiLiteral = "api"
        providers = [{"id": "tool", "name": "test_provider"}]
        mock_redis_client.get.return_value = json.dumps(providers).encode("utf-8")

        cached = ToolProviderListCache.get_cached_providers(tenant_id, typ)

        expected_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
        mock_redis_client.get.assert_called_once_with(expected_key)
        assert cached == providers

    def test_get_cached_providers_decode_error(self, mock_redis_client):
        """A cached value that is not valid JSON results in ``None``."""
        mock_redis_client.get.return_value = b"invalid_json_data"

        cached = ToolProviderListCache.get_cached_providers("tenant_123")

        assert cached is None
        mock_redis_client.get.assert_called_once()

    def test_get_cached_providers_miss(self, mock_redis_client):
        """A cache miss (Redis returns ``None``) results in ``None``."""
        mock_redis_client.get.return_value = None

        cached = ToolProviderListCache.get_cached_providers("tenant_123")

        assert cached is None
        mock_redis_client.get.assert_called_once()

    def test_set_cached_providers(self, mock_redis_client):
        """Providers are written with ``setex`` under the generated key, the class TTL and a JSON body."""
        tenant_id = "tenant_123"
        typ: ToolProviderTypeApiLiteral = "builtin"
        providers = [{"id": "tool", "name": "test_provider"}]
        expected_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)

        ToolProviderListCache.set_cached_providers(tenant_id, typ, providers)

        expected_payload = json.dumps(providers)
        mock_redis_client.setex.assert_called_once_with(
            expected_key, ToolProviderListCache.CACHE_TTL, expected_payload
        )

    def test_invalidate_cache_specific_type(self, mock_redis_client):
        """Invalidating one provider type deletes exactly that type's key."""
        tenant_id = "tenant_123"
        typ: ToolProviderTypeApiLiteral = "workflow"
        expected_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)

        ToolProviderListCache.invalidate_cache(tenant_id, typ)

        mock_redis_client.delete.assert_called_once_with(expected_key)

    def test_invalidate_cache_all_types(self, mock_redis_client):
        """Invalidating without a type scans the tenant's key pattern and deletes every match."""
        tenant_id = "tenant_123"
        tenant_keys = [
            b"tool_providers:tenant_id:tenant_123:type:all",
            b"tool_providers:tenant_id:tenant_123:type:builtin",
        ]
        mock_redis_client.scan_iter.return_value = tenant_keys

        ToolProviderListCache.invalidate_cache(tenant_id)

        expected_pattern = f"tool_providers:tenant_id:{tenant_id}:*"
        mock_redis_client.scan_iter.assert_called_once_with(expected_pattern)
        mock_redis_client.delete.assert_called_once_with(*tenant_keys)

    def test_invalidate_cache_no_keys(self, mock_redis_client):
        """When the scan finds no keys for the tenant, ``delete`` is never called."""
        mock_redis_client.scan_iter.return_value = []

        ToolProviderListCache.invalidate_cache("tenant_123")

        mock_redis_client.delete.assert_not_called()

    def test_redis_fallback_default_return(self, mock_redis_client):
        """Test redis_fallback decorator - a RedisError on read yields the default ``None``."""
        mock_redis_client.get.side_effect = RedisError("Redis connection error")

        cached = ToolProviderListCache.get_cached_providers("tenant_123")

        assert cached is None
        mock_redis_client.get.assert_called_once()

    def test_redis_fallback_no_default(self, mock_redis_client):
        """Test redis_fallback decorator - a RedisError on write must not reach the caller."""
        mock_redis_client.setex.side_effect = RedisError("Redis connection error")

        try:
            ToolProviderListCache.set_cached_providers("tenant_123", "mcp", [])
        except RedisError:
            pytest.fail("set_cached_providers should not raise RedisError (handled by fallback)")

        mock_redis_client.setex.assert_called_once()
|
||||
@ -1,6 +1,9 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from core.ops.utils import validate_project_name, validate_url, validate_url_with_path
|
||||
from core.ops.utils import generate_dotted_order, validate_project_name, validate_url, validate_url_with_path
|
||||
|
||||
|
||||
class TestValidateUrl:
|
||||
@ -136,3 +139,51 @@ class TestValidateProjectName:
|
||||
"""Test custom default name"""
|
||||
result = validate_project_name("", "Custom Default")
|
||||
assert result == "Custom Default"
|
||||
|
||||
|
||||
class TestGenerateDottedOrder:
    """Test cases for generate_dotted_order function"""

    def test_dotted_order_has_6_digit_microseconds(self):
        """The timestamp part carries a full 6-digit microsecond field.

        LangSmith API expects timestamps in format: YYYYMMDDTHHMMSSffffffZ (6-digit microseconds).
        A 3-digit (truncated) fraction previously caused API errors such as
        'cannot parse .111 as .000000'.
        """
        started_at = datetime(2025, 12, 23, 4, 19, 55, 111000)
        dotted = generate_dotted_order("test-run-id", started_at)

        # Split the leading timestamp into its date/time part and its fractional digits
        parsed = re.match(r"^(\d{8}T\d{6})(\d+)Z", dotted)
        assert parsed is not None, "Timestamp format should match YYYYMMDDTHHMMSSffffffZ"

        microseconds = parsed.group(2)
        assert len(microseconds) == 6, f"Microseconds should be 6 digits, got {len(microseconds)}: {microseconds}"

    def test_dotted_order_format_matches_langsmith_expected(self):
        """The full value is YYYYMMDDTHHMMSSffffffZ immediately followed by the run id."""
        started_at = datetime(2025, 1, 15, 10, 30, 45, 123456)

        dotted = generate_dotted_order("abc123", started_at)

        assert dotted == "20250115T103045123456Zabc123"

    def test_dotted_order_with_parent(self):
        """A parent order is prefixed to the child's segment, joined by a single dot."""
        started_at = datetime(2025, 12, 23, 4, 19, 55, 111000)
        parent_order = "20251223T041955000000Zparent-run-id"

        dotted = generate_dotted_order("child-run-id", started_at, parent_order)

        expected = "20251223T041955000000Zparent-run-id.20251223T041955111000Zchild-run-id"
        assert dotted == expected

    def test_dotted_order_without_parent_has_no_dot(self):
        """Passing ``None`` as the parent order produces a value without any dot."""
        started_at = datetime(2025, 12, 23, 4, 19, 55, 111000)

        dotted = generate_dotted_order("test-run-id", started_at, None)

        assert "." not in dotted
|
||||
|
||||
@ -0,0 +1,327 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.datasource.vdb.pgvector.pgvector import (
|
||||
PGVector,
|
||||
PGVectorConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestPGVector(unittest.TestCase):
    """Unit tests for PGVector with the psycopg2 connection pool and Redis client mocked out."""

    def setUp(self):
        self.config = self._make_config(pg_bigm=False)
        self.collection_name = "test_collection"

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _make_config(*, pg_bigm):
        """Build a valid PGVectorConfig, varying only the ``pg_bigm`` flag."""
        return PGVectorConfig(
            host="localhost",
            port=5432,
            user="test_user",
            password="test_password",
            database="test_db",
            min_connection=1,
            max_connection=5,
            pg_bigm=pg_bigm,
        )

    @staticmethod
    def _stub_redis(mock_redis, *, cached=None):
        """Configure the patched ``redis_client`` and return the lock mock it hands out.

        ``cached`` is what ``redis_client.get`` returns (``None`` means "no cache entry").
        """
        mock_lock = MagicMock()
        mock_lock.__enter__ = MagicMock()
        # __exit__ must return a falsy value: a bare MagicMock() returns another
        # (truthy) mock, which would make the ``with`` block silently swallow any
        # exception raised by the code under test.
        mock_lock.__exit__ = MagicMock(return_value=False)
        mock_redis.lock.return_value = mock_lock
        mock_redis.get.return_value = cached
        mock_redis.set.return_value = None
        return mock_lock

    @staticmethod
    def _stub_pool(mock_pool_class):
        """Wire pool -> connection -> cursor mocks and return them as a tuple."""
        mock_pool = MagicMock()
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool_class.return_value = mock_pool
        mock_pool.getconn.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        return mock_pool, mock_conn, mock_cursor

    @staticmethod
    def _executed(mock_cursor, *fragments):
        """Return the recorded ``execute`` calls whose string form contains every fragment."""
        return [
            call
            for call in mock_cursor.execute.call_args_list
            if all(fragment in str(call) for fragment in fragments)
        ]

    # -------------------------------------------------------------------- tests

    @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
    def test_init(self, mock_pool_class):
        """Test PGVector initialization."""
        mock_pool_class.return_value = MagicMock()

        pgvector = PGVector(self.collection_name, self.config)

        assert pgvector._collection_name == self.collection_name
        assert pgvector.table_name == f"embedding_{self.collection_name}"
        assert pgvector.get_type() == "pgvector"
        assert pgvector.pool is not None
        assert pgvector.pg_bigm is False
        assert pgvector.index_hash is not None

    @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
    def test_init_with_pg_bigm(self, mock_pool_class):
        """Test PGVector initialization with pg_bigm enabled."""
        mock_pool_class.return_value = MagicMock()

        pgvector = PGVector(self.collection_name, self._make_config(pg_bigm=True))

        assert pgvector.pg_bigm is True

    @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
    @patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
    def test_create_collection_basic(self, mock_redis, mock_pool_class):
        """Test basic collection creation."""
        self._stub_redis(mock_redis)
        _, _, mock_cursor = self._stub_pool(mock_pool_class)
        mock_cursor.fetchone.return_value = [1]  # vector extension exists

        pgvector = PGVector(self.collection_name, self.config)
        pgvector._create_collection(1536)

        # Verify SQL execution calls
        assert mock_cursor.execute.called

        # CREATE TABLE must be issued once, with the requested dimension
        create_table_calls = self._executed(mock_cursor, "CREATE TABLE")
        assert len(create_table_calls) == 1
        assert "vector(1536)" in create_table_calls[0][0][0]

        # The HNSW index is expected for this dimension (<= 2000)
        assert len(self._executed(mock_cursor, "CREATE INDEX", "hnsw")) == 1

        # Verify Redis cache was set
        mock_redis.set.assert_called_once()

    @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
    @patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
    def test_create_collection_with_large_dimension(self, mock_redis, mock_pool_class):
        """Test collection creation with dimension > 2000 (no HNSW index)."""
        self._stub_redis(mock_redis)
        _, _, mock_cursor = self._stub_pool(mock_pool_class)
        mock_cursor.fetchone.return_value = [1]  # vector extension exists

        pgvector = PGVector(self.collection_name, self.config)
        pgvector._create_collection(3072)  # Dimension > 2000

        create_table_calls = self._executed(mock_cursor, "CREATE TABLE")
        assert len(create_table_calls) == 1
        assert "vector(3072)" in create_table_calls[0][0][0]

        # No statement mentioning hnsw may have been executed
        assert len(self._executed(mock_cursor, "hnsw")) == 0

    @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
    @patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
    def test_create_collection_with_pg_bigm(self, mock_redis, mock_pool_class):
        """Test collection creation with pg_bigm enabled."""
        self._stub_redis(mock_redis)
        _, _, mock_cursor = self._stub_pool(mock_pool_class)
        mock_cursor.fetchone.return_value = [1]  # vector extension exists

        pgvector = PGVector(self.collection_name, self._make_config(pg_bigm=True))
        pgvector._create_collection(1536)

        # Check that pg_bigm index was created
        assert len(self._executed(mock_cursor, "gin_bigm_ops")) == 1

    @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
    @patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
    def test_create_collection_creates_vector_extension(self, mock_redis, mock_pool_class):
        """Test that vector extension is created if it doesn't exist."""
        self._stub_redis(mock_redis)
        _, _, mock_cursor = self._stub_pool(mock_pool_class)
        mock_cursor.fetchone.return_value = None  # vector extension doesn't exist

        pgvector = PGVector(self.collection_name, self.config)
        pgvector._create_collection(1536)

        # Check that CREATE EXTENSION was called
        assert len(self._executed(mock_cursor, "CREATE EXTENSION vector")) == 1

    @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
    @patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
    def test_create_collection_with_cache_hit(self, mock_redis, mock_pool_class):
        """Test that collection creation is skipped when cache exists."""
        self._stub_redis(mock_redis, cached=1)  # Cache exists
        _, _, mock_cursor = self._stub_pool(mock_pool_class)

        pgvector = PGVector(self.collection_name, self.config)
        pgvector._create_collection(1536)

        # Check that no SQL was executed (early return due to cache)
        assert mock_cursor.execute.call_count == 0

    @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
    @patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
    def test_create_collection_with_redis_lock(self, mock_redis, mock_pool_class):
        """Test that Redis lock is used during collection creation."""
        mock_lock = self._stub_redis(mock_redis)
        _, _, mock_cursor = self._stub_pool(mock_pool_class)
        mock_cursor.fetchone.return_value = [1]  # vector extension exists

        pgvector = PGVector(self.collection_name, self.config)
        pgvector._create_collection(1536)

        # Verify Redis lock was acquired with correct lock name
        mock_redis.lock.assert_called_once_with("vector_indexing_test_collection_lock", timeout=20)

        # Verify lock context manager was entered and exited
        mock_lock.__enter__.assert_called_once()
        mock_lock.__exit__.assert_called_once()

    @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
    def test_get_cursor_context_manager(self, mock_pool_class):
        """Test that _get_cursor properly manages connection lifecycle."""
        mock_pool, mock_conn, mock_cursor = self._stub_pool(mock_pool_class)

        pgvector = PGVector(self.collection_name, self.config)

        with pgvector._get_cursor() as cur:
            assert cur == mock_cursor

        # Verify connection lifecycle methods were called
        mock_pool.getconn.assert_called_once()
        mock_cursor.close.assert_called_once()
        mock_conn.commit.assert_called_once()
        mock_pool.putconn.assert_called_once_with(mock_conn)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "invalid_config_override",
    [
        {"host": ""},  # empty host
        {"port": 0},  # invalid port
        {"user": ""},  # empty user
        {"password": ""},  # empty password
        {"database": ""},  # empty database
        {"min_connection": 0},  # invalid min_connection
        {"max_connection": 0},  # invalid max_connection
        {"min_connection": 10, "max_connection": 5},  # min > max
    ],
)
def test_config_validation_parametrized(invalid_config_override):
    """Each invalid override, applied on top of a valid baseline, must make PGVectorConfig raise ValueError."""
    valid_baseline = {
        "host": "localhost",
        "port": 5432,
        "user": "test_user",
        "password": "test_password",
        "database": "test_db",
        "min_connection": 1,
        "max_connection": 5,
    }
    # Later keys win, so the override replaces the matching baseline entries.
    config = {**valid_baseline, **invalid_config_override}

    with pytest.raises(ValueError):
        PGVectorConfig(**config)
|
||||
|
||||
|
||||
# Allow running this test module directly as a script via the unittest runner.
if __name__ == "__main__":
    unittest.main()
|
||||
@ -1,5 +1,7 @@
|
||||
import os
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
|
||||
@ -25,3 +27,35 @@ def test_firecrawl_web_extractor_crawl_mode(mocker: MockerFixture):
|
||||
|
||||
assert job_id is not None
|
||||
assert isinstance(job_id, str)
|
||||
|
||||
|
||||
def test_build_url_normalizes_slashes_for_crawl(mocker: MockerFixture):
    """The crawl endpoint URL is the same whether or not the base URL ends with a slash."""
    api_key = "fc-"
    for base in ("https://custom.firecrawl.dev", "https://custom.firecrawl.dev/"):
        app = FirecrawlApp(api_key=api_key, base_url=base)

        # Successful response carrying a job id
        fake_response = MagicMock()
        fake_response.status_code = 200
        fake_response.json.return_value = {"id": "job123"}
        mock_post = mocker.patch("httpx.post", return_value=fake_response)

        app.crawl_url("https://example.com", params=None)

        # The URL is the first positional argument of the recorded httpx.post call
        positional_args = mock_post.call_args[0]
        assert positional_args[0] == "https://custom.firecrawl.dev/v2/crawl"
|
||||
|
||||
|
||||
def test_error_handler_handles_non_json_error_bodies(mocker: MockerFixture):
    """A 404 whose body cannot be parsed as JSON surfaces as a status-code-only scrape error."""
    plain_text_404 = MagicMock()
    plain_text_404.status_code = 404
    plain_text_404.text = "Not Found"
    plain_text_404.json.side_effect = Exception("Not JSON")

    app = FirecrawlApp(api_key="fc-", base_url="https://custom.firecrawl.dev/")
    mocker.patch("httpx.post", return_value=plain_text_404)

    with pytest.raises(Exception) as excinfo:
        app.scrape_url("https://example.com")

    # Should not raise a JSONDecodeError; current behavior reports status code only
    message = str(excinfo.value)
    assert message == "Failed to scrape URL. Status code: 404"
|
||||
|
||||
@ -132,3 +132,36 @@ def test_extract_images_from_docx(monkeypatch):
|
||||
# DB interactions should be recorded
|
||||
assert len(db_stub.session.added) == 2
|
||||
assert db_stub.session.committed is True
|
||||
|
||||
|
||||
def test_extract_images_from_docx_uses_internal_files_url():
    """Test that INTERNAL_FILES_URL takes precedence over FILES_URL for plugin access.

    Only the URL-selection expression (``INTERNAL_FILES_URL or FILES_URL``) is exercised,
    against temporarily overridden config values. Both settings are always restored
    afterwards, including when their original value was ``None``, so the overrides
    cannot leak into other tests.
    """
    # Test the URL generation logic directly
    from configs import dify_config

    missing = object()  # sentinel: distinguishes "attribute absent" from a stored None
    setting_names = ("FILES_URL", "INTERNAL_FILES_URL")
    saved = {name: getattr(dify_config, name, missing) for name in setting_names}

    try:
        # Set both URLs - INTERNAL should take precedence
        dify_config.FILES_URL = "http://external.example.com"
        dify_config.INTERNAL_FILES_URL = "http://internal.docker:5001"

        # Test the URL generation logic (same as in word_extractor.py)
        upload_file_id = "test_file_id"

        # This is the pattern we fixed in the word extractor
        base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
        generated_url = f"{base_url}/files/{upload_file_id}/file-preview"

        # Verify that INTERNAL_FILES_URL is used instead of FILES_URL
        assert "http://internal.docker:5001" in generated_url, f"Expected internal URL, got: {generated_url}"
        assert "http://external.example.com" not in generated_url, f"Should not use external URL, got: {generated_url}"

    finally:
        # Restore original values unconditionally; skipping the restore when the
        # original was None would leave the test value in place for later tests.
        for name, original in saved.items():
            if original is not missing:
                setattr(dify_config, name, original)
            elif hasattr(dify_config, name):
                # The attribute did not exist before the test; remove what was added.
                delattr(dify_config, name)
|
||||
|
||||
@ -421,7 +421,18 @@ class TestRetrievalService:
|
||||
# In real code, this waits for all futures to complete
|
||||
# In tests, futures complete immediately, so wait is a no-op
|
||||
with patch("core.rag.datasource.retrieval_service.concurrent.futures.wait"):
|
||||
yield mock_executor
|
||||
# Mock concurrent.futures.as_completed for early error propagation
|
||||
# In real code, this yields futures as they complete
|
||||
# In tests, we yield all futures immediately since they're already done
|
||||
def mock_as_completed(futures_list, timeout=None):
|
||||
"""Mock as_completed that yields futures immediately."""
|
||||
yield from futures_list
|
||||
|
||||
with patch(
|
||||
"core.rag.datasource.retrieval_service.concurrent.futures.as_completed",
|
||||
side_effect=mock_as_completed,
|
||||
):
|
||||
yield mock_executor
|
||||
|
||||
# ==================== Vector Search Tests ====================
|
||||
|
||||
|
||||
@ -0,0 +1,873 @@
|
||||
"""
|
||||
Unit tests for DatasetRetrieval.process_metadata_filter_func.
|
||||
|
||||
This module provides comprehensive test coverage for the process_metadata_filter_func
|
||||
method in the DatasetRetrieval class, which is responsible for building SQLAlchemy
|
||||
filter expressions based on metadata filtering conditions.
|
||||
|
||||
Conditions Tested:
|
||||
==================
|
||||
1. **String Conditions**: contains, not contains, start with, end with
|
||||
2. **Equality Conditions**: is / =, is not / ≠
|
||||
3. **Null Conditions**: empty, not empty
|
||||
4. **Numeric Comparisons**: before / <, after / >, ≤ / <=, ≥ / >=
|
||||
5. **List Conditions**: in
|
||||
6. **Edge Cases**: None values, different data types (str, int, float)
|
||||
|
||||
Test Architecture:
|
||||
==================
|
||||
- Direct instantiation of DatasetRetrieval
|
||||
- Mocking of DatasetDocument model attributes
|
||||
- Verification of SQLAlchemy filter expressions
|
||||
- Follows Arrange-Act-Assert (AAA) pattern
|
||||
|
||||
Running Tests:
|
||||
==============
|
||||
# Run all tests in this module
|
||||
uv run --project api pytest \
|
||||
api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py -v
|
||||
|
||||
# Run a specific test
|
||||
uv run --project api pytest \
|
||||
api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py::\
|
||||
TestProcessMetadataFilterFunc::test_contains_condition -v
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
|
||||
|
||||
class TestProcessMetadataFilterFunc:
|
||||
"""
|
||||
Comprehensive test suite for process_metadata_filter_func method.
|
||||
|
||||
This test class validates all metadata filtering conditions supported by
|
||||
the DatasetRetrieval class, including string operations, numeric comparisons,
|
||||
null checks, and list operations.
|
||||
|
||||
Method Signature:
|
||||
==================
|
||||
def process_metadata_filter_func(
|
||||
self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
|
||||
) -> list:
|
||||
|
||||
The method builds SQLAlchemy filter expressions by:
|
||||
1. Validating value is not None (except for empty/not empty conditions)
|
||||
2. Using DatasetDocument.doc_metadata JSON field operations
|
||||
3. Adding appropriate SQLAlchemy expressions to the filters list
|
||||
4. Returning the updated filters list
|
||||
|
||||
Mocking Strategy:
|
||||
==================
|
||||
- Mock DatasetDocument.doc_metadata to avoid database dependencies
|
||||
- Verify filter expressions are created correctly
|
||||
- Test with various data types (str, int, float, list)
|
||||
"""
|
||||
|
||||
@pytest.fixture
def retrieval(self):
    """
    Provide a fresh DatasetRetrieval for each test.

    Returns:
        DatasetRetrieval: the object whose process_metadata_filter_func is under test
    """
    instance = DatasetRetrieval()
    return instance
|
||||
|
||||
@pytest.fixture
def mock_doc_metadata(self):
    """
    Build a MagicMock standing in for the DatasetDocument.doc_metadata JSON field.

    Indexing the mock (``mock[name]``) returns a float-style accessor for
    "year", "price" and "rating", and a string-style accessor for every other
    name. ``as_string()`` / ``as_float()`` on the field return the same two
    accessors. Both accessors also expose ``is_`` / ``isnot`` for null checks.

    Returns:
        Mock: Mocked doc_metadata attribute
    """
    mock_metadata_field = MagicMock()

    # Accessor used for string operations and equality checks
    mock_string_access = MagicMock()
    mock_string_access.like = MagicMock()
    mock_string_access.notlike = MagicMock()
    mock_string_access.__eq__ = MagicMock(return_value=MagicMock())
    mock_string_access.__ne__ = MagicMock(return_value=MagicMock())
    mock_string_access.in_ = MagicMock(return_value=MagicMock())

    # Accessor used for numeric comparisons
    mock_float_access = MagicMock()
    mock_float_access.__eq__ = MagicMock(return_value=MagicMock())
    mock_float_access.__ne__ = MagicMock(return_value=MagicMock())
    mock_float_access.__lt__ = MagicMock(return_value=MagicMock())
    mock_float_access.__gt__ = MagicMock(return_value=MagicMock())
    mock_float_access.__le__ = MagicMock(return_value=MagicMock())
    mock_float_access.__ge__ = MagicMock(return_value=MagicMock())

    # Shared mocks for null checks
    mock_null_access = MagicMock()
    mock_null_access.is_ = MagicMock(return_value=MagicMock())
    mock_null_access.isnot = MagicMock(return_value=MagicMock())

    # Attach the null-check mocks to the objects that item access returns.
    # (Indexing the field with an undefined ``metadata_name`` here would raise
    # NameError as soon as the fixture is requested.)
    for item_access in (mock_string_access, mock_float_access):
        item_access.is_ = mock_null_access.is_
        item_access.isnot = mock_null_access.isnot

    numeric_names = ("year", "price", "rating")

    def getitem_side_effect(name):
        # Numeric metadata names get the float accessor; everything else is a string.
        return mock_float_access if name in numeric_names else mock_string_access

    mock_metadata_field.__getitem__ = MagicMock(side_effect=getitem_side_effect)
    mock_metadata_field.as_string.return_value = mock_string_access
    mock_metadata_field.as_float.return_value = mock_float_access

    return mock_metadata_field
|
||||
|
||||
# ==================== String Condition Tests ====================
|
||||
|
||||
def test_contains_condition_string_value(self, retrieval):
    """
    'contains' with a string value appends one filter.

    Checks that the returned value equals the filters list and that the list
    now holds one entry. The entry is intended to be a LIKE '%value%'
    expression; its contents are not inspected here.
    """
    collected = []

    returned = retrieval.process_metadata_filter_func(0, "contains", "author", "John", collected)

    assert returned == collected
    assert len(collected) == 1
|
||||
|
||||
def test_not_contains_condition(self, retrieval):
    """
    'not contains' appends one filter.

    Checks that the returned value equals the filters list and that the list
    now holds one entry. The entry is intended to be a NOT LIKE '%value%'
    expression; its contents are not inspected here.
    """
    collected = []

    returned = retrieval.process_metadata_filter_func(0, "not contains", "title", "banned", collected)

    assert returned == collected
    assert len(collected) == 1
|
||||
|
||||
def test_start_with_condition(self, retrieval):
    """
    'start with' appends one filter.

    Checks that the returned value equals the filters list and that the list
    now holds one entry. The entry is intended to be a LIKE 'value%'
    expression; its contents are not inspected here.
    """
    collected = []

    returned = retrieval.process_metadata_filter_func(0, "start with", "category", "tech", collected)

    assert returned == collected
    assert len(collected) == 1
|
||||
|
||||
def test_end_with_condition(self, retrieval):
    """
    'end with' appends one filter.

    Checks that the returned value equals the filters list and that the list
    now holds one entry. The entry is intended to be a LIKE '%value'
    expression; its contents are not inspected here.
    """
    collected = []

    returned = retrieval.process_metadata_filter_func(0, "end with", "filename", ".pdf", collected)

    assert returned == collected
    assert len(collected) == 1
|
||||
|
||||
# ==================== Equality Condition Tests ====================
|
||||
|
||||
def test_is_condition_with_string_value(self, retrieval):
    """
    'is' with a string value appends one filter.

    Checks that the returned value equals the filters list and that the list
    now holds one entry (intended to be a string equality expression; not
    inspected here).
    """
    collected = []

    returned = retrieval.process_metadata_filter_func(0, "is", "author", "Jane Doe", collected)

    assert returned == collected
    assert len(collected) == 1
|
||||
|
||||
def test_equals_condition_with_string_value(self, retrieval):
    """
    '=' with a string value appends one filter, like 'is'.

    Checks that the returned value equals the filters list and that the list
    now holds one entry (intended to be a string equality expression; not
    inspected here).
    """
    collected = []

    returned = retrieval.process_metadata_filter_func(0, "=", "category", "technology", collected)

    assert returned == collected
    assert len(collected) == 1
|
||||
|
||||
def test_is_condition_with_int_value(self, retrieval):
    """
    'is' with an integer value appends one filter.

    Checks that the returned value equals the filters list and that the list
    now holds one entry (intended to be a numeric comparison; the expression
    is not inspected here).
    """
    collected = []

    returned = retrieval.process_metadata_filter_func(0, "is", "year", 2023, collected)

    assert returned == collected
    assert len(collected) == 1
|
||||
|
||||
def test_is_condition_with_float_value(self, retrieval):
    """
    'is' with a float value appends one filter.

    Checks that the returned value equals the filters list and that the list
    now holds one entry (intended to be a numeric comparison; the expression
    is not inspected here).
    """
    collected = []

    returned = retrieval.process_metadata_filter_func(0, "is", "price", 19.99, collected)

    assert returned == collected
    assert len(collected) == 1
|
||||
|
||||
def test_is_not_condition_with_string_value(self, retrieval):
    """
    'is not' with a string value appends one filter.

    Checks that the returned value equals the filters list and that the list
    now holds one entry (intended to be a string inequality expression; not
    inspected here).
    """
    collected = []

    returned = retrieval.process_metadata_filter_func(0, "is not", "author", "Unknown", collected)

    assert returned == collected
    assert len(collected) == 1
|
||||
|
||||
def test_not_equals_condition(self, retrieval):
    """
    '≠' with a string value appends one filter, like 'is not'.

    Checks that the returned value equals the filters list and that the list
    now holds one entry (intended to be an inequality expression; not
    inspected here).
    """
    collected = []

    returned = retrieval.process_metadata_filter_func(0, "≠", "category", "archived", collected)

    assert returned == collected
    assert len(collected) == 1
|
||||
|
||||
def test_is_not_condition_with_numeric_value(self, retrieval):
    """
    'is not' with a numeric value appends one filter.

    Checks that the returned value equals the filters list and that the list
    now holds one entry (intended to be a numeric inequality; the expression
    is not inspected here).
    """
    collected = []

    returned = retrieval.process_metadata_filter_func(0, "is not", "year", 2000, collected)

    assert returned == collected
    assert len(collected) == 1
|
||||
|
||||
# ==================== Null Condition Tests ====================
|
||||
|
||||
def test_empty_condition(self, retrieval):
    """
    'empty' appends one filter even though the value is None.

    Checks that the returned value equals the filters list and that the list
    now holds one entry (intended to be an IS NULL check; not inspected
    here).
    """
    collected = []

    returned = retrieval.process_metadata_filter_func(0, "empty", "author", None, collected)

    assert returned == collected
    assert len(collected) == 1
|
||||
|
||||
def test_not_empty_condition(self, retrieval):
    """
    'not empty' appends one filter even though the value is None.

    Checks that the returned value equals the filters list and that the list
    now holds one entry (intended to be an IS NOT NULL check; not inspected
    here).
    """
    collected = []

    returned = retrieval.process_metadata_filter_func(0, "not empty", "description", None, collected)

    assert returned == collected
    assert len(collected) == 1
|
||||
|
||||
# ==================== Numeric Comparison Tests ====================
|
||||
|
||||
def test_before_condition(self, retrieval):
    """
    'before' with a numeric value appends one filter.

    Checks that the returned value equals the filters list and that the list
    now holds one entry (intended to be a less-than comparison; not
    inspected here).
    """
    collected = []

    returned = retrieval.process_metadata_filter_func(0, "before", "year", 2020, collected)

    assert returned == collected
    assert len(collected) == 1
|
||||
|
||||
def test_less_than_condition(self, retrieval):
    """
    '<' with a numeric value appends one filter, like 'before'.

    Checks that the returned value equals the filters list and that the list
    now holds one entry (intended to be a less-than comparison; not
    inspected here).
    """
    collected = []

    returned = retrieval.process_metadata_filter_func(0, "<", "price", 100.0, collected)

    assert returned == collected
    assert len(collected) == 1
|
||||
|
||||
def test_after_condition(self, retrieval):
    """
    'after' with a numeric value appends one filter.

    Checks that the returned value equals the filters list and that the list
    now holds one entry (intended to be a greater-than comparison; not
    inspected here).
    """
    collected = []

    returned = retrieval.process_metadata_filter_func(0, "after", "year", 2020, collected)

    assert returned == collected
    assert len(collected) == 1
|
||||
|
||||
def test_greater_than_condition(self, retrieval):
    """
    '>' with a numeric value appends one filter, like 'after'.

    Checks that the returned value equals the filters list and that the list
    now holds one entry (intended to be a greater-than comparison; not
    inspected here).
    """
    collected = []

    returned = retrieval.process_metadata_filter_func(0, ">", "rating", 4.5, collected)

    assert returned == collected
    assert len(collected) == 1
|
||||
|
||||
def test_less_than_or_equal_condition_unicode(self, retrieval):
    """
    '≤' with a numeric value appends one filter.

    Checks that the returned value equals the filters list and that the list
    now holds one entry (intended to be a less-than-or-equal comparison; not
    inspected here).
    """
    collected = []

    returned = retrieval.process_metadata_filter_func(0, "≤", "price", 50.0, collected)

    assert returned == collected
    assert len(collected) == 1
|
||||
|
||||
def test_less_than_or_equal_condition_ascii(self, retrieval):
    """
    '<=' with a numeric value appends one filter, like '≤'.

    Checks that the returned value equals the filters list and that the list
    now holds one entry (intended to be a less-than-or-equal comparison; not
    inspected here).
    """
    collected = []

    returned = retrieval.process_metadata_filter_func(0, "<=", "year", 2023, collected)

    assert returned == collected
    assert len(collected) == 1
|
||||
|
||||
def test_greater_than_or_equal_condition_unicode(self, retrieval):
    """
    '≥' with a numeric value appends one filter.

    Checks that the returned value equals the filters list and that the list
    now holds one entry (intended to be a greater-than-or-equal comparison;
    not inspected here).
    """
    collected = []

    returned = retrieval.process_metadata_filter_func(0, "≥", "rating", 3.5, collected)

    assert returned == collected
    assert len(collected) == 1
|
||||
|
||||
def test_greater_than_or_equal_condition_ascii(self, retrieval):
    """
    '>=' with a numeric value appends one filter, like '≥'.

    Checks that the returned value equals the filters list and that the list
    now holds one entry (intended to be a greater-than-or-equal comparison;
    not inspected here).
    """
    collected = []

    returned = retrieval.process_metadata_filter_func(0, ">=", "year", 2000, collected)

    assert returned == collected
    assert len(collected) == 1
|
||||
|
||||
# ==================== List/In Condition Tests ====================
|
||||
|
||||
def test_in_condition_with_comma_separated_string(self, retrieval):
    """
    'in' with a comma-separated string appends one filter.

    The value carries surrounding whitespace that the implementation is
    expected to split and trim; this test only checks that the returned
    value equals the filters list and that one entry was added.
    """
    collected = []

    returned = retrieval.process_metadata_filter_func(0, "in", "category", "tech, science, AI ", collected)

    assert returned == collected
    assert len(collected) == 1
|
||||
|
||||
def test_in_condition_with_list_value(self, retrieval):
    """
    'in' with a list value (including a None item) appends one filter.

    The None item is expected to be dropped by the implementation; this test
    only checks that the returned value equals the filters list and that one
    entry was added.
    """
    collected = []
    candidates = ["python", "javascript", None, "golang"]

    returned = retrieval.process_metadata_filter_func(0, "in", "tags", candidates, collected)

    assert returned == collected
    assert len(collected) == 1
|
||||
|
||||
def test_in_condition_with_tuple_value(self, retrieval):
    """
    'in' with a tuple value appends one filter.

    Checks that the returned value equals the filters list and that one
    entry was added (intended to be an IN expression; not inspected here).
    """
    collected = []
    candidates = ("tech", "science", "ai")

    returned = retrieval.process_metadata_filter_func(0, "in", "category", candidates, collected)

    assert returned == collected
    assert len(collected) == 1
|
||||
|
||||
def test_in_condition_with_empty_string(self, retrieval):
    """
    'in' with an empty string still appends one filter.

    Checks that the returned value equals the filters list and that one
    entry was added. That entry is expected to be a literal(False)
    expression since there is nothing to match, but this is not asserted.
    """
    collected = []

    returned = retrieval.process_metadata_filter_func(0, "in", "category", "", collected)

    assert returned == collected
    assert len(collected) == 1
    # The appended entry should be literal(False); verifying that would
    # need access to the actual expression, so it is left unchecked.
|
||||
|
||||
def test_in_condition_with_only_whitespace(self, retrieval):
    """
    'in' with a whitespace-and-commas string still appends one filter.

    Checks that the returned value equals the filters list and that one
    entry was added. That entry is expected to be a literal(False)
    expression once every item is stripped away, but this is not asserted.
    """
    collected = []

    returned = retrieval.process_metadata_filter_func(0, "in", "category", " , , ", collected)

    assert returned == collected
    assert len(collected) == 1
|
||||
|
||||
def test_in_condition_with_single_string(self, retrieval):
    """
    'in' with a single comma-free string appends one filter.

    Checks that the returned value equals the filters list and that one
    entry was added (intended to be an IN expression with one value; not
    inspected here).
    """
    collected = []

    returned = retrieval.process_metadata_filter_func(0, "in", "category", "technology", collected)

    assert returned == collected
    assert len(collected) == 1
|
||||
|
||||
# ==================== Edge Case Tests ====================
|
||||
|
||||
def test_none_value_with_non_empty_condition(self, retrieval):
    """
    'contains' with a None value adds nothing.

    Checks that the returned value equals the filters list and that the
    list is still empty afterwards.
    """
    collected = []

    returned = retrieval.process_metadata_filter_func(0, "contains", "author", None, collected)

    assert returned == collected
    # Nothing may have been appended for a None value.
    assert len(collected) == 0
|
||||
|
||||
def test_none_value_with_equals_condition(self, retrieval):
    """
    'is' with a None value adds nothing.

    Checks that the returned value equals the filters list and that the
    list is still empty afterwards.
    """
    collected = []

    returned = retrieval.process_metadata_filter_func(0, "is", "author", None, collected)

    assert returned == collected
    assert len(collected) == 0
|
||||
|
||||
def test_none_value_with_numeric_condition(self, retrieval):
    """
    '>' with a None value adds nothing.

    Checks that the returned value equals the filters list and that the
    list is still empty afterwards.
    """
    collected = []

    returned = retrieval.process_metadata_filter_func(0, ">", "year", None, collected)

    assert returned == collected
    assert len(collected) == 0
|
||||
|
||||
def test_existing_filters_preserved(self, retrieval):
    """
    A filter already in the list survives a further call.

    Starts with one placeholder entry, adds a 'contains' condition, and
    checks that the list has two entries with the placeholder still first.
    """
    placeholder = MagicMock()
    collected = [placeholder]

    returned = retrieval.process_metadata_filter_func(0, "contains", "author", "test", collected)

    assert returned == collected
    assert len(collected) == 2
    assert collected[0] == placeholder
|
||||
|
||||
def test_multiple_filters_accumulated(self, retrieval):
    """
    Successive calls each add one filter to the same list.

    Three conditions are applied in order against one shared list; after
    each call the list length must equal the number of calls made so far.
    """
    collected = []
    calls = [
        (0, "contains", "author", "John"),
        (1, ">", "year", 2020),
        (2, "is", "category", "tech"),
    ]

    for expected_count, (sequence, condition, metadata_name, value) in enumerate(calls, start=1):
        retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, collected)
        assert len(collected) == expected_count
|
||||
|
||||
def test_unknown_condition(self, retrieval):
    """
    An unrecognised condition name adds nothing.

    Checks that the returned value equals the filters list and that the
    list is still empty afterwards.
    """
    collected = []

    returned = retrieval.process_metadata_filter_func(0, "unknown_condition", "author", "test", collected)

    assert returned == collected
    assert len(collected) == 0
|
||||
|
||||
def test_empty_string_value_with_contains(self, retrieval):
    """
    'contains' with an empty string still appends one filter.

    Checks that the returned value equals the filters list and that one
    entry was added (intended to be a LIKE expression; not inspected here).
    """
    collected = []

    returned = retrieval.process_metadata_filter_func(0, "contains", "author", "", collected)

    assert returned == collected
    assert len(collected) == 1
|
||||
|
||||
def test_special_characters_in_value(self, retrieval):
    """
    'contains' with punctuation in the value appends one filter.

    The value includes '+', '&' and an apostrophe. Checks that the returned
    value equals the filters list and that one entry was added; how the
    characters end up in the expression is not inspected here.
    """
    collected = []

    returned = retrieval.process_metadata_filter_func(
        0, "contains", "title", "C++ & Python's features", collected
    )

    assert returned == collected
    assert len(collected) == 1
|
||||
|
||||
def test_zero_value_with_numeric_condition(self, retrieval):
    """
    '>' with a value of zero appends one filter.

    Zero must count as a real value rather than as missing. Checks that the
    returned value equals the filters list and that one entry was added.
    """
    collected = []

    returned = retrieval.process_metadata_filter_func(0, ">", "price", 0, collected)

    assert returned == collected
    assert len(collected) == 1
|
||||
|
||||
def test_negative_value_with_numeric_condition(self, retrieval):
    """
    '<' with a negative value appends one filter.

    Checks that the returned value equals the filters list and that one
    entry was added (intended to be a numeric comparison; not inspected
    here).
    """
    collected = []

    returned = retrieval.process_metadata_filter_func(0, "<", "temperature", -10.5, collected)

    assert returned == collected
    assert len(collected) == 1
|
||||
|
||||
def test_float_value_with_integer_comparison(self, retrieval):
    """
    '>=' with a float value appends one filter.

    Checks that the returned value equals the filters list and that one
    entry was added (intended to be a numeric comparison; not inspected
    here).
    """
    collected = []

    returned = retrieval.process_metadata_filter_func(0, ">=", "rating", 4.5, collected)

    assert returned == collected
    assert len(collected) == 1
|
||||
@ -901,6 +901,13 @@ class TestFixedRecursiveCharacterTextSplitter:
|
||||
# Verify no empty chunks
|
||||
assert all(len(chunk) > 0 for chunk in result)
|
||||
|
||||
def test_double_slash_n(self):
    """
    A separator written with escaped newlines splits on real newlines.

    The fixed separator is given as the literal characters backslash-n
    (not newline characters), while the text contains actual newlines
    around '---'. The split must yield the two chunks on either side.
    """
    text = "chunk 1\n\nsubchunk 1.\nsubchunk 2.\n\n---\n\nchunk 2\n\nsubchunk 1\nsubchunk 2."
    expected = ["chunk 1\n\nsubchunk 1.\nsubchunk 2.", "chunk 2\n\nsubchunk 1\nsubchunk 2."]

    text_splitter = FixedRecursiveCharacterTextSplitter(fixed_separator="\\n\\n---\\n\\n")

    assert text_splitter.split_text(text) == expected
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test Metadata Preservation
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
import pytest
|
||||
|
||||
from core.workflow.nodes.http_request import (
|
||||
BodyData,
|
||||
HttpRequestNodeAuthorization,
|
||||
@ -5,6 +7,7 @@ from core.workflow.nodes.http_request import (
|
||||
HttpRequestNodeData,
|
||||
)
|
||||
from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout
|
||||
from core.workflow.nodes.http_request.exc import AuthorizationConfigError
|
||||
from core.workflow.nodes.http_request.executor import Executor
|
||||
from core.workflow.runtime import VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
@ -348,3 +351,127 @@ def test_init_params():
|
||||
executor = create_executor("key1:value1\n\nkey2:value2\n\n")
|
||||
executor._init_params()
|
||||
assert executor.params == [("key1", "value1"), ("key2", "value2")]
|
||||
|
||||
|
||||
def test_empty_api_key_raises_error_bearer():
    """Constructing an Executor with bearer auth and an empty API key raises AuthorizationConfigError."""
    pool = VariablePool(system_variables=SystemVariable.empty())
    auth = HttpRequestNodeAuthorization(
        type="api-key",
        config={"type": "bearer", "api_key": ""},
    )
    request_data = HttpRequestNodeData(
        title="test",
        method="get",
        url="http://example.com",
        headers="",
        params="",
        authorization=auth,
    )
    request_timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)

    # The failure is expected from the constructor itself, before any request is made.
    with pytest.raises(AuthorizationConfigError, match="API key is required"):
        Executor(node_data=request_data, timeout=request_timeout, variable_pool=pool)
|
||||
|
||||
|
||||
def test_empty_api_key_raises_error_basic():
    """Constructing an Executor with basic auth and an empty API key raises AuthorizationConfigError."""
    pool = VariablePool(system_variables=SystemVariable.empty())
    auth = HttpRequestNodeAuthorization(
        type="api-key",
        config={"type": "basic", "api_key": ""},
    )
    request_data = HttpRequestNodeData(
        title="test",
        method="get",
        url="http://example.com",
        headers="",
        params="",
        authorization=auth,
    )
    request_timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)

    # The failure is expected from the constructor itself, before any request is made.
    with pytest.raises(AuthorizationConfigError, match="API key is required"):
        Executor(node_data=request_data, timeout=request_timeout, variable_pool=pool)
|
||||
|
||||
|
||||
def test_empty_api_key_raises_error_custom():
    """Constructing an Executor with custom-header auth and an empty API key raises AuthorizationConfigError."""
    pool = VariablePool(system_variables=SystemVariable.empty())
    auth = HttpRequestNodeAuthorization(
        type="api-key",
        config={"type": "custom", "api_key": "", "header": "X-Custom-Auth"},
    )
    request_data = HttpRequestNodeData(
        title="test",
        method="get",
        url="http://example.com",
        headers="",
        params="",
        authorization=auth,
    )
    request_timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)

    # The failure is expected from the constructor itself, before any request is made.
    with pytest.raises(AuthorizationConfigError, match="API key is required"):
        Executor(node_data=request_data, timeout=request_timeout, variable_pool=pool)
|
||||
|
||||
|
||||
def test_whitespace_only_api_key_raises_error():
    """Constructing an Executor with a whitespace-only bearer API key raises AuthorizationConfigError."""
    pool = VariablePool(system_variables=SystemVariable.empty())
    auth = HttpRequestNodeAuthorization(
        type="api-key",
        config={"type": "bearer", "api_key": " "},
    )
    request_data = HttpRequestNodeData(
        title="test",
        method="get",
        url="http://example.com",
        headers="",
        params="",
        authorization=auth,
    )
    request_timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)

    # A key made only of whitespace must be rejected just like an empty one.
    with pytest.raises(AuthorizationConfigError, match="API key is required"):
        Executor(node_data=request_data, timeout=request_timeout, variable_pool=pool)
|
||||
|
||||
|
||||
def test_valid_api_key_works():
    """A non-empty bearer API key is accepted and ends up in the Authorization header."""
    pool = VariablePool(system_variables=SystemVariable.empty())
    auth = HttpRequestNodeAuthorization(
        type="api-key",
        config={"type": "bearer", "api_key": "valid-api-key-123"},
    )
    request_data = HttpRequestNodeData(
        title="test",
        method="get",
        url="http://example.com",
        headers="",
        params="",
        authorization=auth,
    )
    request_timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)

    # Construction must succeed for a valid key.
    http_executor = Executor(node_data=request_data, timeout=request_timeout, variable_pool=pool)

    assembled = http_executor._assembling_headers()
    assert "Authorization" in assembled
    assert assembled["Authorization"] == "Bearer valid-api-key-123"
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import json
|
||||
import time
|
||||
|
||||
import pytest
|
||||
@ -46,14 +47,16 @@ def make_start_node(user_inputs, variables):
|
||||
|
||||
|
||||
def test_json_object_valid_schema():
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"age": {"type": "number"},
|
||||
"name": {"type": "string"},
|
||||
},
|
||||
"required": ["age"],
|
||||
}
|
||||
schema = json.dumps(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"age": {"type": "number"},
|
||||
"name": {"type": "string"},
|
||||
},
|
||||
"required": ["age"],
|
||||
}
|
||||
)
|
||||
|
||||
variables = [
|
||||
VariableEntity(
|
||||
@ -65,7 +68,7 @@ def test_json_object_valid_schema():
|
||||
)
|
||||
]
|
||||
|
||||
user_inputs = {"profile": {"age": 20, "name": "Tom"}}
|
||||
user_inputs = {"profile": json.dumps({"age": 20, "name": "Tom"})}
|
||||
|
||||
node = make_start_node(user_inputs, variables)
|
||||
result = node._run()
|
||||
@ -74,12 +77,23 @@ def test_json_object_valid_schema():
|
||||
|
||||
|
||||
def test_json_object_invalid_json_string():
|
||||
schema = json.dumps(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"age": {"type": "number"},
|
||||
"name": {"type": "string"},
|
||||
},
|
||||
"required": ["age", "name"],
|
||||
}
|
||||
)
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="profile",
|
||||
label="profile",
|
||||
type=VariableEntityType.JSON_OBJECT,
|
||||
required=True,
|
||||
json_schema=schema,
|
||||
)
|
||||
]
|
||||
|
||||
@ -88,38 +102,21 @@ def test_json_object_invalid_json_string():
|
||||
|
||||
node = make_start_node(user_inputs, variables)
|
||||
|
||||
with pytest.raises(ValueError, match="profile must be a JSON object"):
|
||||
node._run()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("value", ["[1, 2, 3]", "123"])
|
||||
def test_json_object_valid_json_but_not_object(value):
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="profile",
|
||||
label="profile",
|
||||
type=VariableEntityType.JSON_OBJECT,
|
||||
required=True,
|
||||
)
|
||||
]
|
||||
|
||||
user_inputs = {"profile": value}
|
||||
|
||||
node = make_start_node(user_inputs, variables)
|
||||
|
||||
with pytest.raises(ValueError, match="profile must be a JSON object"):
|
||||
with pytest.raises(ValueError, match='{"age": 20, "name": "Tom" must be a valid JSON object'):
|
||||
node._run()
|
||||
|
||||
|
||||
def test_json_object_does_not_match_schema():
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"age": {"type": "number"},
|
||||
"name": {"type": "string"},
|
||||
},
|
||||
"required": ["age", "name"],
|
||||
}
|
||||
schema = json.dumps(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"age": {"type": "number"},
|
||||
"name": {"type": "string"},
|
||||
},
|
||||
"required": ["age", "name"],
|
||||
}
|
||||
)
|
||||
|
||||
variables = [
|
||||
VariableEntity(
|
||||
@ -132,7 +129,7 @@ def test_json_object_does_not_match_schema():
|
||||
]
|
||||
|
||||
# age is a string, which violates the schema (expects number)
|
||||
user_inputs = {"profile": {"age": "twenty", "name": "Tom"}}
|
||||
user_inputs = {"profile": json.dumps({"age": "twenty", "name": "Tom"})}
|
||||
|
||||
node = make_start_node(user_inputs, variables)
|
||||
|
||||
@ -141,14 +138,16 @@ def test_json_object_does_not_match_schema():
|
||||
|
||||
|
||||
def test_json_object_missing_required_schema_field():
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"age": {"type": "number"},
|
||||
"name": {"type": "string"},
|
||||
},
|
||||
"required": ["age", "name"],
|
||||
}
|
||||
schema = json.dumps(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"age": {"type": "number"},
|
||||
"name": {"type": "string"},
|
||||
},
|
||||
"required": ["age", "name"],
|
||||
}
|
||||
)
|
||||
|
||||
variables = [
|
||||
VariableEntity(
|
||||
@ -161,7 +160,7 @@ def test_json_object_missing_required_schema_field():
|
||||
]
|
||||
|
||||
# Missing required field "name"
|
||||
user_inputs = {"profile": {"age": 20}}
|
||||
user_inputs = {"profile": json.dumps({"age": 20})}
|
||||
|
||||
node = make_start_node(user_inputs, variables)
|
||||
|
||||
@ -214,7 +213,7 @@ def test_json_object_optional_variable_not_provided():
|
||||
variable="profile",
|
||||
label="profile",
|
||||
type=VariableEntityType.JSON_OBJECT,
|
||||
required=False,
|
||||
required=True,
|
||||
)
|
||||
]
|
||||
|
||||
@ -223,5 +222,5 @@ def test_json_object_optional_variable_not_provided():
|
||||
node = make_start_node(user_inputs, variables)
|
||||
|
||||
# Current implementation raises a validation error even when the variable is optional
|
||||
with pytest.raises(ValueError, match="profile must be a JSON object"):
|
||||
with pytest.raises(ValueError, match="profile is required in input form"):
|
||||
node._run()
|
||||
|
||||
Reference in New Issue
Block a user