Merge branch 'main' into chore/ssrf-config

This commit is contained in:
-LAN-
2025-09-17 13:02:04 +08:00
committed by GitHub
512 changed files with 5611 additions and 3949 deletions

View File

@ -1,6 +1,5 @@
import unittest
from datetime import UTC, datetime
from typing import Optional
from unittest.mock import patch
from uuid import uuid4
@ -42,7 +41,7 @@ class TestStorageKeyLoader(unittest.TestCase):
self.session.rollback()
def _create_upload_file(
self, file_id: Optional[str] = None, storage_key: Optional[str] = None, tenant_id: Optional[str] = None
self, file_id: str | None = None, storage_key: str | None = None, tenant_id: str | None = None
) -> UploadFile:
"""Helper method to create an UploadFile record for testing."""
if file_id is None:
@ -74,7 +73,7 @@ class TestStorageKeyLoader(unittest.TestCase):
return upload_file
def _create_tool_file(
self, file_id: Optional[str] = None, file_key: Optional[str] = None, tenant_id: Optional[str] = None
self, file_id: str | None = None, file_key: str | None = None, tenant_id: str | None = None
) -> ToolFile:
"""Helper method to create a ToolFile record for testing."""
if file_id is None:
@ -101,9 +100,7 @@ class TestStorageKeyLoader(unittest.TestCase):
return tool_file
def _create_file(
self, related_id: str, transfer_method: FileTransferMethod, tenant_id: Optional[str] = None
) -> File:
def _create_file(self, related_id: str, transfer_method: FileTransferMethod, tenant_id: str | None = None) -> File:
"""Helper method to create a File object for testing."""
if tenant_id is None:
tenant_id = self.tenant_id

View File

@ -5,8 +5,6 @@ from decimal import Decimal
from json import dumps
# import monkeypatch
from typing import Optional
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, PromptMessageTool
@ -113,8 +111,8 @@ class MockModelClass(PluginModelClient):
@staticmethod
def generate_function_call(
tools: Optional[list[PromptMessageTool]],
) -> Optional[AssistantPromptMessage.ToolCall]:
tools: list[PromptMessageTool] | None,
) -> AssistantPromptMessage.ToolCall | None:
if not tools or len(tools) == 0:
return None
function: PromptMessageTool = tools[0]
@ -157,7 +155,7 @@ class MockModelClass(PluginModelClient):
def mocked_chat_create_sync(
model: str,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
tools: list[PromptMessageTool] | None = None,
) -> LLMResult:
tool_call = MockModelClass.generate_function_call(tools=tools)
@ -186,7 +184,7 @@ class MockModelClass(PluginModelClient):
def mocked_chat_create_stream(
model: str,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
tools: list[PromptMessageTool] | None = None,
) -> Generator[LLMResultChunk, None, None]:
tool_call = MockModelClass.generate_function_call(tools=tools)
@ -241,9 +239,9 @@ class MockModelClass(PluginModelClient):
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
model_parameters: dict | None = None,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
):
return MockModelClass.mocked_chat_create_stream(model=model, prompt_messages=prompt_messages, tools=tools)

View File

@ -3,6 +3,7 @@
import os
import tempfile
import unittest
from pathlib import Path
import pytest
@ -60,8 +61,7 @@ class TestClickZettaVolumeStorage(unittest.TestCase):
# Test download
with tempfile.NamedTemporaryFile() as temp_file:
storage.download(test_filename, temp_file.name)
with open(temp_file.name, "rb") as f:
downloaded_content = f.read()
downloaded_content = Path(temp_file.name).read_bytes()
assert downloaded_content == test_content
# Test scan

View File

@ -1,6 +1,5 @@
import os
from collections import UserDict
from typing import Optional
from unittest.mock import MagicMock
import pytest
@ -22,7 +21,7 @@ class MockBaiduVectorDBClass:
def mock_vector_db_client(
self,
config=None,
adapter: Optional[HTTPAdapter] = None,
adapter: HTTPAdapter | None = None,
):
self.conn = MagicMock()
self._config = MagicMock()

View File

@ -1,5 +1,5 @@
import os
from typing import Optional, Union
from typing import Union
import pytest
from _pytest.monkeypatch import MonkeyPatch
@ -23,16 +23,16 @@ class MockTcvectordbClass:
key="",
read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY,
timeout=10,
adapter: Optional[HTTPAdapter] = None,
adapter: HTTPAdapter | None = None,
pool_size: int = 2,
proxies: Optional[dict] = None,
password: Optional[str] = None,
proxies: dict | None = None,
password: str | None = None,
**kwargs,
):
self._conn = None
self._read_consistency = read_consistency
def create_database_if_not_exists(self, database_name: str, timeout: Optional[float] = None) -> RPCDatabase:
def create_database_if_not_exists(self, database_name: str, timeout: float | None = None) -> RPCDatabase:
return RPCDatabase(
name="dify",
read_consistency=self._read_consistency,
@ -42,7 +42,7 @@ class MockTcvectordbClass:
return True
def describe_collection(
self, database_name: str, collection_name: str, timeout: Optional[float] = None
self, database_name: str, collection_name: str, timeout: float | None = None
) -> RPCCollection:
index = Index(
FilterIndex("id", enum.FieldType.String, enum.IndexType.PRIMARY_KEY),
@ -71,13 +71,13 @@ class MockTcvectordbClass:
collection_name: str,
shard: int,
replicas: int,
description: Optional[str] = None,
index: Optional[Index] = None,
embedding: Optional[Embedding] = None,
timeout: Optional[float] = None,
ttl_config: Optional[dict] = None,
filter_index_config: Optional[FilterIndexConfig] = None,
indexes: Optional[list[IndexField]] = None,
description: str | None = None,
index: Index | None = None,
embedding: Embedding | None = None,
timeout: float | None = None,
ttl_config: dict | None = None,
filter_index_config: FilterIndexConfig | None = None,
indexes: list[IndexField] | None = None,
) -> RPCCollection:
return RPCCollection(
RPCDatabase(
@ -102,7 +102,7 @@ class MockTcvectordbClass:
database_name: str,
collection_name: str,
documents: list[Union[Document, dict]],
timeout: Optional[float] = None,
timeout: float | None = None,
build_index: bool = True,
**kwargs,
):
@ -113,12 +113,12 @@ class MockTcvectordbClass:
database_name: str,
collection_name: str,
vectors: list[list[float]],
filter: Optional[Filter] = None,
filter: Filter | None = None,
params=None,
retrieve_vector: bool = False,
limit: int = 10,
output_fields: Optional[list[str]] = None,
timeout: Optional[float] = None,
output_fields: list[str] | None = None,
timeout: float | None = None,
) -> list[list[dict]]:
return [[{"metadata": {"doc_id": "foo1"}, "text": "text", "doc_id": "foo1", "score": 0.1}]]
@ -126,14 +126,14 @@ class MockTcvectordbClass:
self,
database_name: str,
collection_name: str,
ann: Optional[Union[list[AnnSearch], AnnSearch]] = None,
match: Optional[Union[list[KeywordSearch], KeywordSearch]] = None,
filter: Optional[Union[Filter, str]] = None,
rerank: Optional[Rerank] = None,
retrieve_vector: Optional[bool] = None,
output_fields: Optional[list[str]] = None,
limit: Optional[int] = None,
timeout: Optional[float] = None,
ann: Union[list[AnnSearch], AnnSearch] | None = None,
match: Union[list[KeywordSearch], KeywordSearch] | None = None,
filter: Union[Filter, str] | None = None,
rerank: Rerank | None = None,
retrieve_vector: bool | None = None,
output_fields: list[str] | None = None,
limit: int | None = None,
timeout: float | None = None,
return_pd_object=False,
**kwargs,
) -> list[list[dict]]:
@ -143,13 +143,13 @@ class MockTcvectordbClass:
self,
database_name: str,
collection_name: str,
document_ids: Optional[list] = None,
document_ids: list | None = None,
retrieve_vector: bool = False,
limit: Optional[int] = None,
offset: Optional[int] = None,
filter: Optional[Filter] = None,
output_fields: Optional[list[str]] = None,
timeout: Optional[float] = None,
limit: int | None = None,
offset: int | None = None,
filter: Filter | None = None,
output_fields: list[str] | None = None,
timeout: float | None = None,
):
return [{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}]
@ -157,13 +157,13 @@ class MockTcvectordbClass:
self,
database_name: str,
collection_name: str,
document_ids: Optional[list[str]] = None,
filter: Optional[Filter] = None,
timeout: Optional[float] = None,
document_ids: list[str] | None = None,
filter: Filter | None = None,
timeout: float | None = None,
):
return {"code": 0, "msg": "operation success"}
def drop_collection(self, database_name: str, collection_name: str, timeout: Optional[float] = None):
def drop_collection(self, database_name: str, collection_name: str, timeout: float | None = None):
return {"code": 0, "msg": "operation success"}

View File

@ -1,6 +1,5 @@
import os
from collections import UserDict
from typing import Optional
import pytest
from _pytest.monkeypatch import MonkeyPatch
@ -34,7 +33,7 @@ class MockIndex:
include_vectors: bool = False,
include_metadata: bool = False,
filter: str = "",
data: Optional[str] = None,
data: str | None = None,
namespace: str = "",
include_data: bool = False,
):

View File

@ -1,7 +1,6 @@
import os
import time
import uuid
from typing import Optional
from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
@ -29,7 +28,7 @@ def get_mocked_fetch_memory(memory_text: str):
human_prefix: str = "Human",
ai_prefix: str = "Assistant",
max_token_limit: int = 2000,
message_limit: Optional[int] = None,
message_limit: int | None = None,
):
return memory_text

View File

@ -11,7 +11,6 @@ import logging
import os
from collections.abc import Generator
from pathlib import Path
from typing import Optional
import pytest
from flask import Flask
@ -42,10 +41,10 @@ class DifyTestContainers:
def __init__(self):
"""Initialize container management with default configurations."""
self.postgres: Optional[PostgresContainer] = None
self.redis: Optional[RedisContainer] = None
self.dify_sandbox: Optional[DockerContainer] = None
self.dify_plugin_daemon: Optional[DockerContainer] = None
self.postgres: PostgresContainer | None = None
self.redis: RedisContainer | None = None
self.dify_sandbox: DockerContainer | None = None
self.dify_plugin_daemon: DockerContainer | None = None
self._containers_started = False
logger.info("DifyTestContainers initialized - ready to manage test containers")

View File

@ -1,6 +1,5 @@
import unittest
from datetime import UTC, datetime
from typing import Optional
from unittest.mock import patch
from uuid import uuid4
@ -42,7 +41,7 @@ class TestStorageKeyLoader(unittest.TestCase):
self.session.rollback()
def _create_upload_file(
self, file_id: Optional[str] = None, storage_key: Optional[str] = None, tenant_id: Optional[str] = None
self, file_id: str | None = None, storage_key: str | None = None, tenant_id: str | None = None
) -> UploadFile:
"""Helper method to create an UploadFile record for testing."""
if file_id is None:
@ -74,7 +73,7 @@ class TestStorageKeyLoader(unittest.TestCase):
return upload_file
def _create_tool_file(
self, file_id: Optional[str] = None, file_key: Optional[str] = None, tenant_id: Optional[str] = None
self, file_id: str | None = None, file_key: str | None = None, tenant_id: str | None = None
) -> ToolFile:
"""Helper method to create a ToolFile record for testing."""
if file_id is None:
@ -102,9 +101,7 @@ class TestStorageKeyLoader(unittest.TestCase):
return tool_file
def _create_file(
self, related_id: str, transfer_method: FileTransferMethod, tenant_id: Optional[str] = None
) -> File:
def _create_file(self, related_id: str, transfer_method: FileTransferMethod, tenant_id: str | None = None) -> File:
"""Helper method to create a File object for testing."""
if tenant_id is None:
tenant_id = self.tenant_id

View File

@ -1,3 +1,5 @@
import time
import uuid
from unittest.mock import patch
import pytest
@ -248,9 +250,15 @@ class TestWebAppAuthService:
- Proper error handling for non-existent accounts
- Correct exception type and message
"""
# Arrange: Use non-existent email
fake = Faker()
non_existent_email = fake.email()
# Arrange: Generate a guaranteed non-existent email
# Use UUID and timestamp to ensure uniqueness
unique_id = str(uuid.uuid4()).replace("-", "")
timestamp = str(int(time.time() * 1000000)) # microseconds
non_existent_email = f"nonexistent_{unique_id}_{timestamp}@test-domain-that-never-exists.invalid"
# Double-check this email doesn't exist in the database
existing_account = db_session_with_containers.query(Account).filter_by(email=non_existent_email).first()
assert existing_account is None, f"Test email {non_existent_email} already exists in database"
# Act & Assert: Verify proper error handling
with pytest.raises(AccountNotFoundError):

View File

@ -12,6 +12,7 @@ and realistic testing scenarios with actual PostgreSQL and Redis instances.
import uuid
from datetime import datetime
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
@ -276,8 +277,7 @@ class TestBatchCreateSegmentToIndexTask:
mock_storage = mock_external_service_dependencies["storage"]
def mock_download(key, file_path):
with open(file_path, "w", encoding="utf-8") as f:
f.write(csv_content)
Path(file_path).write_text(csv_content, encoding="utf-8")
mock_storage.download.side_effect = mock_download
@ -505,7 +505,7 @@ class TestBatchCreateSegmentToIndexTask:
db.session.commit()
# Test each unavailable document
for i, document in enumerate(test_cases):
for document in test_cases:
job_id = str(uuid.uuid4())
batch_create_segment_to_index_task(
job_id=job_id,
@ -601,8 +601,7 @@ class TestBatchCreateSegmentToIndexTask:
mock_storage = mock_external_service_dependencies["storage"]
def mock_download(key, file_path):
with open(file_path, "w", encoding="utf-8") as f:
f.write(empty_csv_content)
Path(file_path).write_text(empty_csv_content, encoding="utf-8")
mock_storage.download.side_effect = mock_download
@ -684,8 +683,7 @@ class TestBatchCreateSegmentToIndexTask:
mock_storage = mock_external_service_dependencies["storage"]
def mock_download(key, file_path):
with open(file_path, "w", encoding="utf-8") as f:
f.write(csv_content)
Path(file_path).write_text(csv_content, encoding="utf-8")
mock_storage.download.side_effect = mock_download

View File

@ -362,7 +362,7 @@ class TestCleanDatasetTask:
# Create segments for each document
segments = []
for i, document in enumerate(documents):
for document in documents:
segment = self._create_test_segment(db_session_with_containers, account, tenant, dataset, document)
segments.append(segment)

View File

@ -0,0 +1,729 @@
"""
TestContainers-based integration tests for disable_segments_from_index_task.
This module provides comprehensive integration testing for the disable_segments_from_index_task
using TestContainers to ensure realistic database interactions and proper isolation.
The task is responsible for removing document segments from the search index when they are disabled.
"""
from unittest.mock import MagicMock, patch
from faker import Faker
from models import Account, Dataset, DocumentSegment
from models import Document as DatasetDocument
from models.dataset import DatasetProcessRule
from tasks.disable_segments_from_index_task import disable_segments_from_index_task
class TestDisableSegmentsFromIndexTask:
"""
Comprehensive integration tests for disable_segments_from_index_task using testcontainers.
This test class covers all major functionality of the disable_segments_from_index_task:
- Successful segment disabling with proper index cleanup
- Error handling for various edge cases
- Database state validation after task execution
- Redis cache cleanup verification
- Index processor integration testing
All tests use the testcontainers infrastructure to ensure proper database isolation
and realistic testing environment with actual database interactions.
"""
def _create_test_account(self, db_session_with_containers, fake=None):
"""
Helper method to create a test account with realistic data.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
fake: Faker instance for generating test data
Returns:
Account: Created test account instance
"""
fake = fake or Faker()
account = Account()
account.id = fake.uuid4()
account.email = fake.email()
account.name = fake.name()
account.avatar_url = fake.url()
account.tenant_id = fake.uuid4()
account.status = "active"
account.type = "normal"
account.role = "owner"
account.interface_language = "en-US"
account.created_at = fake.date_time_this_year()
account.updated_at = account.created_at
# Create a tenant for the account
from models.account import Tenant
tenant = Tenant()
tenant.id = account.tenant_id
tenant.name = f"Test Tenant {fake.company()}"
tenant.plan = "basic"
tenant.status = "active"
tenant.created_at = fake.date_time_this_year()
tenant.updated_at = tenant.created_at
from extensions.ext_database import db
db.session.add(tenant)
db.session.add(account)
db.session.commit()
# Set the current tenant for the account
account.current_tenant = tenant
return account
def _create_test_dataset(self, db_session_with_containers, account, fake=None):
"""
Helper method to create a test dataset with realistic data.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
account: The account creating the dataset
fake: Faker instance for generating test data
Returns:
Dataset: Created test dataset instance
"""
fake = fake or Faker()
dataset = Dataset()
dataset.id = fake.uuid4()
dataset.tenant_id = account.tenant_id
dataset.name = f"Test Dataset {fake.word()}"
dataset.description = fake.text(max_nb_chars=200)
dataset.provider = "vendor"
dataset.permission = "only_me"
dataset.data_source_type = "upload_file"
dataset.indexing_technique = "high_quality"
dataset.created_by = account.id
dataset.updated_by = account.id
dataset.embedding_model = "text-embedding-ada-002"
dataset.embedding_model_provider = "openai"
dataset.built_in_field_enabled = False
from extensions.ext_database import db
db.session.add(dataset)
db.session.commit()
return dataset
def _create_test_document(self, db_session_with_containers, dataset, account, fake=None):
"""
Helper method to create a test document with realistic data.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
dataset: The dataset containing the document
account: The account creating the document
fake: Faker instance for generating test data
Returns:
DatasetDocument: Created test document instance
"""
fake = fake or Faker()
document = DatasetDocument()
document.id = fake.uuid4()
document.tenant_id = dataset.tenant_id
document.dataset_id = dataset.id
document.position = 1
document.data_source_type = "upload_file"
document.data_source_info = '{"upload_file_id": "test_file_id"}'
document.batch = fake.uuid4()
document.name = f"Test Document {fake.word()}.txt"
document.created_from = "upload_file"
document.created_by = account.id
document.created_api_request_id = fake.uuid4()
document.processing_started_at = fake.date_time_this_year()
document.file_id = fake.uuid4()
document.word_count = fake.random_int(min=100, max=1000)
document.parsing_completed_at = fake.date_time_this_year()
document.cleaning_completed_at = fake.date_time_this_year()
document.splitting_completed_at = fake.date_time_this_year()
document.tokens = fake.random_int(min=50, max=500)
document.indexing_started_at = fake.date_time_this_year()
document.indexing_completed_at = fake.date_time_this_year()
document.indexing_status = "completed"
document.enabled = True
document.archived = False
document.doc_form = "text_model" # Use text_model form for testing
document.doc_language = "en"
from extensions.ext_database import db
db.session.add(document)
db.session.commit()
return document
def _create_test_segments(self, db_session_with_containers, document, dataset, account, count=3, fake=None):
"""
Helper method to create test document segments with realistic data.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
document: The document containing the segments
dataset: The dataset containing the document
account: The account creating the segments
count: Number of segments to create
fake: Faker instance for generating test data
Returns:
List[DocumentSegment]: Created test segment instances
"""
fake = fake or Faker()
segments = []
for i in range(count):
segment = DocumentSegment()
segment.id = fake.uuid4()
segment.tenant_id = dataset.tenant_id
segment.dataset_id = dataset.id
segment.document_id = document.id
segment.position = i + 1
segment.content = f"Test segment content {i + 1}: {fake.text(max_nb_chars=200)}"
segment.answer = f"Test answer {i + 1}" if i % 2 == 0 else None
segment.word_count = fake.random_int(min=10, max=100)
segment.tokens = fake.random_int(min=5, max=50)
segment.keywords = [fake.word() for _ in range(3)]
segment.index_node_id = f"node_{segment.id}"
segment.index_node_hash = fake.sha256()
segment.hit_count = 0
segment.enabled = True
segment.disabled_at = None
segment.disabled_by = None
segment.status = "completed"
segment.created_by = account.id
segment.updated_by = account.id
segment.indexing_at = fake.date_time_this_year()
segment.completed_at = fake.date_time_this_year()
segment.error = None
segment.stopped_at = None
segments.append(segment)
from extensions.ext_database import db
for segment in segments:
db.session.add(segment)
db.session.commit()
return segments
def _create_dataset_process_rule(self, db_session_with_containers, dataset, fake=None):
"""
Helper method to create a dataset process rule.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
dataset: The dataset for the process rule
fake: Faker instance for generating test data
Returns:
DatasetProcessRule: Created process rule instance
"""
fake = fake or Faker()
process_rule = DatasetProcessRule()
process_rule.id = fake.uuid4()
process_rule.tenant_id = dataset.tenant_id
process_rule.dataset_id = dataset.id
process_rule.mode = "automatic"
process_rule.rules = (
"{"
'"mode": "automatic", '
'"rules": {'
'"pre_processing_rules": [], "segmentation": '
'{"separator": "\\n\\n", "max_tokens": 1000, "chunk_overlap": 50}}'
"}"
)
process_rule.created_by = dataset.created_by
process_rule.updated_by = dataset.updated_by
from extensions.ext_database import db
db.session.add(process_rule)
db.session.commit()
return process_rule
def test_disable_segments_success(self, db_session_with_containers):
"""
Test successful disabling of segments from index.
This test verifies that the task can correctly disable segments from the index
when all conditions are met, including proper index cleanup and database state updates.
"""
# Arrange
fake = Faker()
account = self._create_test_account(db_session_with_containers, fake)
dataset = self._create_test_dataset(db_session_with_containers, account, fake)
document = self._create_test_document(db_session_with_containers, dataset, account, fake)
segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 3, fake)
self._create_dataset_process_rule(db_session_with_containers, dataset, fake)
segment_ids = [segment.id for segment in segments]
# Mock the index processor to avoid external dependencies
with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory:
mock_processor = MagicMock()
mock_factory.return_value.init_index_processor.return_value = mock_processor
# Mock Redis client
with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis:
mock_redis.delete.return_value = True
# Act
result = disable_segments_from_index_task(segment_ids, dataset.id, document.id)
# Assert
assert result is None # Task should complete without returning a value
# Verify index processor was called correctly
mock_factory.assert_called_once_with(document.doc_form)
mock_processor.clean.assert_called_once()
# Verify the call arguments (checking by attributes rather than object identity)
call_args = mock_processor.clean.call_args
assert call_args[0][0].id == dataset.id # First argument should be the dataset
assert sorted(call_args[0][1]) == sorted(
[segment.index_node_id for segment in segments]
) # Compare sorted lists to handle any order while preserving duplicates
assert call_args[1]["with_keywords"] is True
assert call_args[1]["delete_child_chunks"] is False
# Verify Redis cache cleanup was called for each segment
assert mock_redis.delete.call_count == len(segments)
for segment in segments:
expected_key = f"segment_{segment.id}_indexing"
mock_redis.delete.assert_any_call(expected_key)
def test_disable_segments_dataset_not_found(self, db_session_with_containers):
"""
Test handling when dataset is not found.
This test ensures that the task correctly handles cases where the specified
dataset doesn't exist, logging appropriate messages and returning early.
"""
# Arrange
fake = Faker()
non_existent_dataset_id = fake.uuid4()
non_existent_document_id = fake.uuid4()
segment_ids = [fake.uuid4()]
# Mock Redis client
with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis:
# Act
result = disable_segments_from_index_task(segment_ids, non_existent_dataset_id, non_existent_document_id)
# Assert
assert result is None # Task should complete without returning a value
# Redis should not be called when dataset is not found
mock_redis.delete.assert_not_called()
def test_disable_segments_document_not_found(self, db_session_with_containers):
"""
Test handling when document is not found.
This test ensures that the task correctly handles cases where the specified
document doesn't exist, logging appropriate messages and returning early.
"""
# Arrange
fake = Faker()
account = self._create_test_account(db_session_with_containers, fake)
dataset = self._create_test_dataset(db_session_with_containers, account, fake)
non_existent_document_id = fake.uuid4()
segment_ids = [fake.uuid4()]
# Mock Redis client
with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis:
# Act
result = disable_segments_from_index_task(segment_ids, dataset.id, non_existent_document_id)
# Assert
assert result is None # Task should complete without returning a value
# Redis should not be called when document is not found
mock_redis.delete.assert_not_called()
def test_disable_segments_document_invalid_status(self, db_session_with_containers):
"""
Test handling when document has invalid status for disabling.
This test ensures that the task correctly handles cases where the document
is not enabled, archived, or not completed, preventing invalid operations.
"""
# Arrange
fake = Faker()
account = self._create_test_account(db_session_with_containers, fake)
dataset = self._create_test_dataset(db_session_with_containers, account, fake)
document = self._create_test_document(db_session_with_containers, dataset, account, fake)
segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 2, fake)
# Test case 1: Document not enabled
document.enabled = False
from extensions.ext_database import db
db.session.commit()
segment_ids = [segment.id for segment in segments]
# Mock Redis client
with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis:
# Act
result = disable_segments_from_index_task(segment_ids, dataset.id, document.id)
# Assert
assert result is None # Task should complete without returning a value
# Redis should not be called when document status is invalid
mock_redis.delete.assert_not_called()
# Test case 2: Document archived
document.enabled = True
document.archived = True
db.session.commit()
with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis:
# Act
result = disable_segments_from_index_task(segment_ids, dataset.id, document.id)
# Assert
assert result is None # Task should complete without returning a value
mock_redis.delete.assert_not_called()
# Test case 3: Document indexing not completed
document.enabled = True
document.archived = False
document.indexing_status = "indexing"
db.session.commit()
with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis:
# Act
result = disable_segments_from_index_task(segment_ids, dataset.id, document.id)
# Assert
assert result is None # Task should complete without returning a value
mock_redis.delete.assert_not_called()
def test_disable_segments_no_segments_found(self, db_session_with_containers):
"""
Test handling when no segments are found for the given IDs.
This test ensures that the task correctly handles cases where the specified
segment IDs don't exist or don't match the dataset/document criteria.
"""
# Arrange
fake = Faker()
account = self._create_test_account(db_session_with_containers, fake)
dataset = self._create_test_dataset(db_session_with_containers, account, fake)
document = self._create_test_document(db_session_with_containers, dataset, account, fake)
self._create_dataset_process_rule(db_session_with_containers, dataset, fake)
# Use non-existent segment IDs
non_existent_segment_ids = [fake.uuid4() for _ in range(3)]
# Mock Redis client
with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis:
# Act
result = disable_segments_from_index_task(non_existent_segment_ids, dataset.id, document.id)
# Assert
assert result is None # Task should complete without returning a value
# Redis should not be called when no segments are found
mock_redis.delete.assert_not_called()
def test_disable_segments_index_processor_error(self, db_session_with_containers):
"""
Test handling when index processor encounters an error.
This test verifies that the task correctly handles index processor errors
by rolling back segment states and ensuring proper cleanup.
"""
# Arrange
fake = Faker()
account = self._create_test_account(db_session_with_containers, fake)
dataset = self._create_test_dataset(db_session_with_containers, account, fake)
document = self._create_test_document(db_session_with_containers, dataset, account, fake)
segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 2, fake)
self._create_dataset_process_rule(db_session_with_containers, dataset, fake)
segment_ids = [segment.id for segment in segments]
# Mock the index processor to raise an exception
with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory:
mock_processor = MagicMock()
mock_processor.clean.side_effect = Exception("Index processor error")
mock_factory.return_value.init_index_processor.return_value = mock_processor
# Mock Redis client
with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis:
mock_redis.delete.return_value = True
# Act
result = disable_segments_from_index_task(segment_ids, dataset.id, document.id)
# Assert
assert result is None # Task should complete without returning a value
# Verify segments were rolled back to enabled state
from extensions.ext_database import db
db.session.refresh(segments[0])
db.session.refresh(segments[1])
# Check that segments are re-enabled after error
updated_segments = db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).all()
for segment in updated_segments:
assert segment.enabled is True
assert segment.disabled_at is None
assert segment.disabled_by is None
# Verify Redis cache cleanup was still called
assert mock_redis.delete.call_count == len(segments)
def test_disable_segments_with_different_doc_forms(self, db_session_with_containers):
    """
    Test disabling segments across the supported document forms.

    Verifies that the task runs cleanly for each doc_form value
    ("text_model", "qa_model", "hierarchical_model") and that the
    index processor factory is initialized with the matching form.
    """
    # Arrange: build a dataset/document pair with two segments.
    fake = Faker()
    account = self._create_test_account(db_session_with_containers, fake)
    dataset = self._create_test_dataset(db_session_with_containers, account, fake)
    document = self._create_test_document(db_session_with_containers, dataset, account, fake)
    segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 2, fake)
    self._create_dataset_process_rule(db_session_with_containers, dataset, fake)
    segment_ids = [segment.id for segment in segments]
    from extensions.ext_database import db

    # Exercise every document form the task must support.
    for doc_form in ("text_model", "qa_model", "hierarchical_model"):
        document.doc_form = doc_form
        db.session.commit()
        with (
            patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as factory_mock,
            patch("tasks.disable_segments_from_index_task.redis_client") as redis_mock,
        ):
            processor_mock = MagicMock()
            factory_mock.return_value.init_index_processor.return_value = processor_mock
            redis_mock.delete.return_value = True

            # Act
            result = disable_segments_from_index_task(segment_ids, dataset.id, document.id)

            # Assert: task returns nothing and the factory saw the right form.
            assert result is None
            factory_mock.assert_called_with(doc_form)
def test_disable_segments_performance_timing(self, db_session_with_containers):
    """
    Test that the task properly measures and logs performance timing.
    This test verifies that the task correctly measures execution time
    and logs performance metrics for monitoring purposes.
    """
    # Arrange
    fake = Faker()
    account = self._create_test_account(db_session_with_containers, fake)
    dataset = self._create_test_dataset(db_session_with_containers, account, fake)
    document = self._create_test_document(db_session_with_containers, dataset, account, fake)
    segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 3, fake)
    self._create_dataset_process_rule(db_session_with_containers, dataset, fake)
    segment_ids = [segment.id for segment in segments]
    # Mock the index processor so no real index work happens
    with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory:
        mock_processor = MagicMock()
        mock_factory.return_value.init_index_processor.return_value = mock_processor
        # Mock Redis client
        with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis:
            mock_redis.delete.return_value = True
            # Mock time.perf_counter to control timing.
            # NOTE(review): side_effect supplies exactly two readings, so this
            # raises StopIteration if the task ever samples the clock more than
            # twice — confirm the task calls perf_counter exactly twice.
            with patch("tasks.disable_segments_from_index_task.time.perf_counter") as mock_perf_counter:
                mock_perf_counter.side_effect = [1000.0, 1000.5]  # 0.5 seconds execution time
                # Mock logger to capture log messages
                with patch("tasks.disable_segments_from_index_task.logger") as mock_logger:
                    # Act
                    result = disable_segments_from_index_task(segment_ids, dataset.id, document.id)
                    # Assert
                    assert result is None  # Task should complete without returning a value
                    # Verify performance logging: pick the info() call whose
                    # first positional argument mentions "latency".
                    mock_logger.info.assert_called()
                    log_calls = [call[0][0] for call in mock_logger.info.call_args_list]
                    performance_log = next((call for call in log_calls if "latency" in call), None)
                    assert performance_log is not None
                    # NOTE(review): assumes the task embeds the latency in the
                    # first logger argument (e.g. via f-string) rather than lazy
                    # %-args — confirm against the task implementation.
                    assert "0.5" in performance_log  # Should log the execution time
def test_disable_segments_redis_cache_cleanup(self, db_session_with_containers):
    """
    Test that Redis cache is properly cleaned up for all segments.

    The task must drop the "segment_<id>_indexing" cache entry for every
    processed segment so no stale indexing flags remain.
    """
    # Arrange: five segments so the cleanup fan-out is non-trivial.
    fake = Faker()
    account = self._create_test_account(db_session_with_containers, fake)
    dataset = self._create_test_dataset(db_session_with_containers, account, fake)
    document = self._create_test_document(db_session_with_containers, dataset, account, fake)
    segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 5, fake)
    self._create_dataset_process_rule(db_session_with_containers, dataset, fake)
    segment_ids = [segment.id for segment in segments]
    with (
        patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as factory_mock,
        patch("tasks.disable_segments_from_index_task.redis_client") as redis_mock,
    ):
        factory_mock.return_value.init_index_processor.return_value = MagicMock()
        redis_mock.delete.return_value = True

        # Act
        result = disable_segments_from_index_task(segment_ids, dataset.id, document.id)

        # Assert: no return value, one delete per segment, correct keys.
        assert result is None
        assert redis_mock.delete.call_count == len(segments)
        deleted_keys = [call[0][0] for call in redis_mock.delete.call_args_list]
        for segment in segments:
            assert f"segment_{segment.id}_indexing" in deleted_keys
def test_disable_segments_database_session_cleanup(self, db_session_with_containers):
    """
    Test that the database session is closed after task execution.

    Guards against connection leaks by asserting the task calls
    db.session.close() on its way out.
    """
    # Arrange
    fake = Faker()
    account = self._create_test_account(db_session_with_containers, fake)
    dataset = self._create_test_dataset(db_session_with_containers, account, fake)
    document = self._create_test_document(db_session_with_containers, dataset, account, fake)
    segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 2, fake)
    self._create_dataset_process_rule(db_session_with_containers, dataset, fake)
    segment_ids = [segment.id for segment in segments]
    with (
        patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as factory_mock,
        patch("tasks.disable_segments_from_index_task.redis_client") as redis_mock,
        patch("tasks.disable_segments_from_index_task.db.session.close") as close_mock,
    ):
        factory_mock.return_value.init_index_processor.return_value = MagicMock()
        redis_mock.delete.return_value = True

        # Act
        result = disable_segments_from_index_task(segment_ids, dataset.id, document.id)

        # Assert: task returns nothing and released its session.
        assert result is None
        close_mock.assert_called()
def test_disable_segments_empty_segment_ids(self, db_session_with_containers):
    """
    Test handling of an empty segment-ID list.

    With nothing to disable, the task should finish without touching
    the Redis indexing cache.
    """
    # Arrange: dataset/document exist, but no segments are targeted.
    fake = Faker()
    account = self._create_test_account(db_session_with_containers, fake)
    dataset = self._create_test_dataset(db_session_with_containers, account, fake)
    document = self._create_test_document(db_session_with_containers, dataset, account, fake)
    self._create_dataset_process_rule(db_session_with_containers, dataset, fake)
    with patch("tasks.disable_segments_from_index_task.redis_client") as redis_mock:
        # Act
        result = disable_segments_from_index_task([], dataset.id, document.id)

        # Assert: nothing returned and no cache keys were deleted.
        assert result is None
        redis_mock.delete.assert_not_called()
def test_disable_segments_mixed_valid_invalid_ids(self, db_session_with_containers):
    """
    Test a mix of valid and unknown segment IDs.

    Only segments that actually exist in the dataset should reach the
    index processor and have their Redis cache entries cleared.
    """
    # Arrange
    fake = Faker()
    account = self._create_test_account(db_session_with_containers, fake)
    dataset = self._create_test_dataset(db_session_with_containers, account, fake)
    document = self._create_test_document(db_session_with_containers, dataset, account, fake)
    segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 2, fake)
    self._create_dataset_process_rule(db_session_with_containers, dataset, fake)
    # Two real IDs plus two random UUIDs that match no segment.
    known_ids = [segment.id for segment in segments]
    unknown_ids = [fake.uuid4() for _ in range(2)]
    with (
        patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as factory_mock,
        patch("tasks.disable_segments_from_index_task.redis_client") as redis_mock,
    ):
        processor_mock = MagicMock()
        factory_mock.return_value.init_index_processor.return_value = processor_mock
        redis_mock.delete.return_value = True

        # Act
        result = disable_segments_from_index_task(known_ids + unknown_ids, dataset.id, document.id)

        # Assert
        assert result is None
        processor_mock.clean.assert_called_once()
        args, kwargs = processor_mock.clean.call_args
        # Dataset is passed first, then the node IDs of the valid segments
        # (order-insensitive comparison that still preserves duplicates).
        assert args[0].id == dataset.id
        assert sorted(args[1]) == sorted(segment.index_node_id for segment in segments)
        assert kwargs["with_keywords"] is True
        assert kwargs["delete_child_chunks"] is False
        # Cache cleanup happens only for the valid segments.
        assert redis_mock.delete.call_count == len(segments)

View File

@ -0,0 +1,554 @@
from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from extensions.ext_database import db
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document
from tasks.document_indexing_task import document_indexing_task
class TestDocumentIndexingTask:
"""Integration tests for document_indexing_task using testcontainers."""
@pytest.fixture
def mock_external_service_dependencies(self):
    """Patch IndexingRunner and FeatureService for the duration of a test.

    Yields a dict exposing the patched classes, the shared runner
    instance, and the feature-flags object (billing disabled by default).
    """
    with (
        patch("tasks.document_indexing_task.IndexingRunner") as runner_cls,
        patch("tasks.document_indexing_task.FeatureService") as feature_service,
    ):
        # Every IndexingRunner() constructed by the task is this one instance.
        runner_instance = MagicMock()
        runner_cls.return_value = runner_instance
        # Feature flags default to billing off; individual tests flip them.
        features = MagicMock()
        features.billing.enabled = False
        feature_service.get_features.return_value = features
        yield {
            "indexing_runner": runner_cls,
            "indexing_runner_instance": runner_instance,
            "feature_service": feature_service,
            "features": features,
        }
def _create_test_dataset_and_documents(
    self, db_session_with_containers, mock_external_service_dependencies, document_count=3
):
    """
    Build an account/tenant pair, one dataset, and ``document_count``
    waiting documents for indexing tests.

    Args:
        db_session_with_containers: Database session from testcontainers infrastructure
        mock_external_service_dependencies: Mock dependencies
        document_count: Number of documents to create

    Returns:
        tuple: (dataset, documents) - Created dataset and document instances
    """
    fake = Faker()
    # Owner account plus its tenant, linked via a current OWNER join.
    account = Account(
        email=fake.email(),
        name=fake.name(),
        interface_language="en-US",
        status="active",
    )
    db.session.add(account)
    db.session.commit()
    tenant = Tenant(
        name=fake.company(),
        status="normal",
    )
    db.session.add(tenant)
    db.session.commit()
    db.session.add(
        TenantAccountJoin(
            tenant_id=tenant.id,
            account_id=account.id,
            role=TenantAccountRole.OWNER.value,
            current=True,
        )
    )
    db.session.commit()
    # Dataset that the documents below belong to.
    dataset = Dataset(
        id=fake.uuid4(),
        tenant_id=tenant.id,
        name=fake.company(),
        description=fake.text(max_nb_chars=100),
        data_source_type="upload_file",
        indexing_technique="high_quality",
        created_by=account.id,
    )
    db.session.add(dataset)
    db.session.commit()
    # Documents start in "waiting" so the task can move them to "parsing".
    documents = [
        Document(
            id=fake.uuid4(),
            tenant_id=tenant.id,
            dataset_id=dataset.id,
            position=position,
            data_source_type="upload_file",
            batch="test_batch",
            name=fake.file_name(),
            created_from="upload_file",
            created_by=account.id,
            indexing_status="waiting",
            enabled=True,
        )
        for position in range(document_count)
    ]
    db.session.add_all(documents)
    db.session.commit()
    # Reload so attributes reflect the committed state.
    db.session.refresh(dataset)
    return dataset, documents
def _create_test_dataset_with_billing_features(
    self, db_session_with_containers, mock_external_service_dependencies, billing_enabled=True
):
    """
    Helper method to create a test dataset with billing features configured.

    Delegates the account/tenant/dataset/document setup to
    ``_create_test_dataset_and_documents`` (3 documents, identical field
    values), removing the previous wholesale duplication of that helper,
    and then configures the mocked billing feature flags.

    Args:
        db_session_with_containers: Database session from testcontainers infrastructure
        mock_external_service_dependencies: Mock dependencies
        billing_enabled: Whether billing is enabled

    Returns:
        tuple: (dataset, documents) - Created dataset and document instances
    """
    dataset, documents = self._create_test_dataset_and_documents(
        db_session_with_containers, mock_external_service_dependencies, document_count=3
    )
    # Configure billing features on the shared mock.
    features = mock_external_service_dependencies["features"]
    features.billing.enabled = billing_enabled
    if billing_enabled:
        # Sandbox plan with headroom in the vector-space quota.
        features.billing.subscription.plan = "sandbox"
        features.vector_space.limit = 100
        features.vector_space.size = 50
    return dataset, documents
def test_document_indexing_task_success(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test successful document indexing with multiple documents.

    Verifies dataset retrieval, the status transition to "parsing",
    IndexingRunner invocation, and that all documents are handed to run().
    """
    # Arrange
    dataset, documents = self._create_test_dataset_and_documents(
        db_session_with_containers, mock_external_service_dependencies, document_count=3
    )

    # Act
    document_indexing_task(dataset.id, [doc.id for doc in documents])

    # Assert: runner constructed and run exactly once.
    mocks = mock_external_service_dependencies
    mocks["indexing_runner"].assert_called_once()
    mocks["indexing_runner_instance"].run.assert_called_once()
    # Each document moved to "parsing" with a start timestamp.
    for document in documents:
        db.session.refresh(document)
        assert document.indexing_status == "parsing"
        assert document.processing_started_at is not None
    # run() received the full document list as its first argument.
    run_args = mocks["indexing_runner_instance"].run.call_args
    assert run_args is not None
    assert len(run_args[0][0]) == 3
def test_document_indexing_task_dataset_not_found(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test handling of a non-existent dataset.

    The task should return early: no IndexingRunner is constructed and
    nothing is run.
    """
    # Arrange: IDs that match no rows.
    fake = Faker()
    missing_dataset_id = fake.uuid4()
    missing_document_ids = [fake.uuid4() for _ in range(3)]

    # Act
    document_indexing_task(missing_dataset_id, missing_document_ids)

    # Assert: no indexing work happened.
    mock_external_service_dependencies["indexing_runner"].assert_not_called()
    mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called()
def test_document_indexing_task_document_not_found_in_dataset(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test handling when some documents don't exist in the dataset.

    Unknown IDs are ignored; only the real documents are processed and
    handed to the indexing runner.
    """
    # Arrange: two real documents plus two fabricated IDs.
    dataset, documents = self._create_test_dataset_and_documents(
        db_session_with_containers, mock_external_service_dependencies, document_count=2
    )
    fake = Faker()
    mixed_ids = [doc.id for doc in documents] + [fake.uuid4() for _ in range(2)]

    # Act
    document_indexing_task(dataset.id, mixed_ids)

    # Assert: one runner, one run() call.
    mock_external_service_dependencies["indexing_runner"].assert_called_once()
    mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
    # Real documents transitioned to "parsing".
    for document in documents:
        db.session.refresh(document)
        assert document.indexing_status == "parsing"
        assert document.processing_started_at is not None
    # Only the two existing documents were passed to run().
    run_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
    assert run_args is not None
    assert len(run_args[0][0]) == 2
def test_document_indexing_task_indexing_runner_exception(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test handling of IndexingRunner exceptions.

    A failure inside run() must be swallowed by the task; documents keep
    the "parsing" status they were given before the failure.
    """
    # Arrange
    dataset, documents = self._create_test_dataset_and_documents(
        db_session_with_containers, mock_external_service_dependencies, document_count=2
    )
    runner = mock_external_service_dependencies["indexing_runner_instance"]
    runner.run.side_effect = Exception("Indexing runner failed")

    # Act: must not raise despite the runner blowing up.
    document_indexing_task(dataset.id, [doc.id for doc in documents])

    # Assert
    mock_external_service_dependencies["indexing_runner"].assert_called_once()
    runner.run.assert_called_once()
    # Status updates were committed before the exception surfaced.
    for document in documents:
        db.session.refresh(document)
        assert document.indexing_status == "parsing"
        assert document.processing_started_at is not None
def test_document_indexing_task_mixed_document_states(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test processing documents with mixed initial states.
    This test verifies:
    - Documents with different initial states are handled correctly
    - Only valid documents are processed
    - Database state updates are consistent
    - IndexingRunner receives correct documents
    """
    # Arrange: Create test data (two baseline "waiting" documents)
    dataset, base_documents = self._create_test_dataset_and_documents(
        db_session_with_containers, mock_external_service_dependencies, document_count=2
    )
    # Create additional documents with different states
    fake = Faker()
    extra_documents = []
    # Document with different indexing status
    doc1 = Document(
        id=fake.uuid4(),
        tenant_id=dataset.tenant_id,
        dataset_id=dataset.id,
        position=2,
        data_source_type="upload_file",
        batch="test_batch",
        name=fake.file_name(),
        created_from="upload_file",
        created_by=dataset.created_by,
        indexing_status="completed",  # Already completed
        enabled=True,
    )
    db.session.add(doc1)
    extra_documents.append(doc1)
    # Document with disabled status
    doc2 = Document(
        id=fake.uuid4(),
        tenant_id=dataset.tenant_id,
        dataset_id=dataset.id,
        position=3,
        data_source_type="upload_file",
        batch="test_batch",
        name=fake.file_name(),
        created_from="upload_file",
        created_by=dataset.created_by,
        indexing_status="waiting",
        enabled=False,  # Disabled
    )
    db.session.add(doc2)
    extra_documents.append(doc2)
    db.session.commit()
    all_documents = base_documents + extra_documents
    document_ids = [doc.id for doc in all_documents]
    # Act: Execute the task with mixed document states
    document_indexing_task(dataset.id, document_ids)
    # Assert: Verify processing
    mock_external_service_dependencies["indexing_runner"].assert_called_once()
    mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
    # Verify all documents were updated to parsing status.
    # NOTE(review): even the "completed" and disabled documents are expected
    # to be reset to "parsing" here — confirm the task intentionally does not
    # filter on prior status or the enabled flag.
    for document in all_documents:
        db.session.refresh(document)
        assert document.indexing_status == "parsing"
        assert document.processing_started_at is not None
    # Verify the run method was called with all documents
    call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
    assert call_args is not None
    processed_documents = call_args[0][0]  # First argument should be documents list
    assert len(processed_documents) == 4
def test_document_indexing_task_billing_sandbox_plan_batch_limit(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test billing validation for sandbox plan batch upload limit.
    This test verifies:
    - Sandbox plan batch upload limit enforcement
    - Error handling for batch upload limit exceeded
    - Document status updates to error state
    - Proper error message recording
    """
    # Arrange: Create test data with billing enabled (3 documents)
    dataset, documents = self._create_test_dataset_with_billing_features(
        db_session_with_containers, mock_external_service_dependencies, billing_enabled=True
    )
    # Configure sandbox plan with batch limit
    mock_external_service_dependencies["features"].billing.subscription.plan = "sandbox"
    # Create more documents than the sandbox plan allows.
    # NOTE(review): assumes the sandbox batch-upload limit is below 5 —
    # confirm the actual limit in the task/FeatureService.
    fake = Faker()
    extra_documents = []
    for i in range(2):  # Total will be 5 documents (3 existing + 2 new)
        document = Document(
            id=fake.uuid4(),
            tenant_id=dataset.tenant_id,
            dataset_id=dataset.id,
            position=i + 3,
            data_source_type="upload_file",
            batch="test_batch",
            name=fake.file_name(),
            created_from="upload_file",
            created_by=dataset.created_by,
            indexing_status="waiting",
            enabled=True,
        )
        db.session.add(document)
        extra_documents.append(document)
    db.session.commit()
    all_documents = documents + extra_documents
    document_ids = [doc.id for doc in all_documents]
    # Act: Execute the task with too many documents for sandbox plan
    document_indexing_task(dataset.id, document_ids)
    # Assert: every document is flagged as errored with a batch-upload message
    for document in all_documents:
        db.session.refresh(document)
        assert document.indexing_status == "error"
        assert document.error is not None
        assert "batch upload" in document.error
        assert document.stopped_at is not None
    # Verify no indexing runner was called
    mock_external_service_dependencies["indexing_runner"].assert_not_called()
def test_document_indexing_task_billing_disabled_success(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test successful processing when billing is disabled.

    With billing off there is no quota validation, so indexing proceeds
    exactly as in the plain success case.
    """
    # Arrange
    dataset, documents = self._create_test_dataset_with_billing_features(
        db_session_with_containers, mock_external_service_dependencies, billing_enabled=False
    )

    # Act
    document_indexing_task(dataset.id, [doc.id for doc in documents])

    # Assert: indexing ran once and documents advanced to "parsing".
    mock_external_service_dependencies["indexing_runner"].assert_called_once()
    mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
    for document in documents:
        db.session.refresh(document)
        assert document.indexing_status == "parsing"
        assert document.processing_started_at is not None
def test_document_indexing_task_document_is_paused_error(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test handling of DocumentIsPausedError from IndexingRunner.

    The pause signal must be caught by the task (no exception escapes),
    leaving documents in the "parsing" state they entered beforehand.
    """
    from core.indexing_runner import DocumentIsPausedError

    # Arrange
    dataset, documents = self._create_test_dataset_and_documents(
        db_session_with_containers, mock_external_service_dependencies, document_count=2
    )
    runner = mock_external_service_dependencies["indexing_runner_instance"]
    runner.run.side_effect = DocumentIsPausedError("Document indexing is paused")

    # Act: must complete without raising.
    document_indexing_task(dataset.id, [doc.id for doc in documents])

    # Assert
    mock_external_service_dependencies["indexing_runner"].assert_called_once()
    runner.run.assert_called_once()
    # Status was updated before the pause error surfaced.
    for document in documents:
        db.session.refresh(document)
        assert document.indexing_status == "parsing"
        assert document.processing_started_at is not None

View File

@ -15,7 +15,7 @@ class FakeResponse:
self.status_code = status_code
self.headers = headers or {}
self.content = content
self.text = text if text else content.decode("utf-8", errors="ignore")
self.text = text or content.decode("utf-8", errors="ignore")
# ---------------------------

View File

@ -1,7 +1,6 @@
import base64
import uuid
from collections.abc import Sequence
from typing import Optional
from unittest import mock
import pytest
@ -47,7 +46,7 @@ class MockTokenBufferMemory:
self.history_messages = history_messages or []
def get_history_prompt_messages(
self, max_token_limit: int = 2000, message_limit: Optional[int] = None
self, max_token_limit: int = 2000, message_limit: int | None = None
) -> Sequence[PromptMessage]:
if message_limit is not None:
return self.history_messages[-message_limit * 2 :]

View File

@ -1,6 +1,5 @@
import contextvars
import threading
from typing import Optional
import pytest
from flask import Flask
@ -29,7 +28,7 @@ def login_app(app: Flask) -> Flask:
login_manager.init_app(app)
@login_manager.user_loader
def load_user(user_id: str) -> Optional[User]:
def load_user(user_id: str) -> User | None:
if user_id == "test_user":
return User("test_user")
return None

View File

@ -1,5 +1,4 @@
import datetime
from typing import Optional
# Mock redis_client before importing dataset_service
from unittest.mock import Mock, call, patch
@ -37,7 +36,7 @@ class DocumentBatchUpdateTestDataFactory:
enabled: bool = True,
archived: bool = False,
indexing_status: str = "completed",
completed_at: Optional[datetime.datetime] = None,
completed_at: datetime.datetime | None = None,
**kwargs,
) -> Mock:
"""Create a mock document with specified attributes."""

View File

@ -1,5 +1,5 @@
import datetime
from typing import Any, Optional
from typing import Any
# Mock redis_client before importing dataset_service
from unittest.mock import Mock, create_autospec, patch
@ -24,9 +24,9 @@ class DatasetUpdateTestDataFactory:
description: str = "old_description",
indexing_technique: str = "high_quality",
retrieval_model: str = "old_model",
embedding_model_provider: Optional[str] = None,
embedding_model: Optional[str] = None,
collection_binding_id: Optional[str] = None,
embedding_model_provider: str | None = None,
embedding_model: str | None = None,
collection_binding_id: str | None = None,
**kwargs,
) -> Mock:
"""Create a mock dataset with specified attributes."""

View File

@ -1,3 +1,4 @@
from pathlib import Path
from unittest.mock import Mock, create_autospec, patch
import pytest
@ -146,19 +147,17 @@ class TestMetadataBugCompleteValidation:
# Console API create
console_create_file = "api/controllers/console/datasets/metadata.py"
if os.path.exists(console_create_file):
with open(console_create_file) as f:
content = f.read()
# Should contain nullable=False, not nullable=True
assert "nullable=True" not in content.split("class DatasetMetadataCreateApi")[1].split("class")[0]
content = Path(console_create_file).read_text()
# Should contain nullable=False, not nullable=True
assert "nullable=True" not in content.split("class DatasetMetadataCreateApi")[1].split("class")[0]
# Service API create
service_create_file = "api/controllers/service_api/dataset/metadata.py"
if os.path.exists(service_create_file):
with open(service_create_file) as f:
content = f.read()
# Should contain nullable=False, not nullable=True
create_api_section = content.split("class DatasetMetadataCreateServiceApi")[1].split("class")[0]
assert "nullable=True" not in create_api_section
content = Path(service_create_file).read_text()
# Should contain nullable=False, not nullable=True
create_api_section = content.split("class DatasetMetadataCreateServiceApi")[1].split("class")[0]
assert "nullable=True" not in create_api_section
class TestMetadataValidationSummary: