mirror of
https://github.com/langgenius/dify.git
synced 2026-04-28 22:48:07 +08:00
Merge branch 'main' into chore/ssrf-config
This commit is contained in:
@ -1,6 +1,5 @@
|
||||
import unittest
|
||||
from datetime import UTC, datetime
|
||||
from typing import Optional
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
@ -42,7 +41,7 @@ class TestStorageKeyLoader(unittest.TestCase):
|
||||
self.session.rollback()
|
||||
|
||||
def _create_upload_file(
|
||||
self, file_id: Optional[str] = None, storage_key: Optional[str] = None, tenant_id: Optional[str] = None
|
||||
self, file_id: str | None = None, storage_key: str | None = None, tenant_id: str | None = None
|
||||
) -> UploadFile:
|
||||
"""Helper method to create an UploadFile record for testing."""
|
||||
if file_id is None:
|
||||
@ -74,7 +73,7 @@ class TestStorageKeyLoader(unittest.TestCase):
|
||||
return upload_file
|
||||
|
||||
def _create_tool_file(
|
||||
self, file_id: Optional[str] = None, file_key: Optional[str] = None, tenant_id: Optional[str] = None
|
||||
self, file_id: str | None = None, file_key: str | None = None, tenant_id: str | None = None
|
||||
) -> ToolFile:
|
||||
"""Helper method to create a ToolFile record for testing."""
|
||||
if file_id is None:
|
||||
@ -101,9 +100,7 @@ class TestStorageKeyLoader(unittest.TestCase):
|
||||
|
||||
return tool_file
|
||||
|
||||
def _create_file(
|
||||
self, related_id: str, transfer_method: FileTransferMethod, tenant_id: Optional[str] = None
|
||||
) -> File:
|
||||
def _create_file(self, related_id: str, transfer_method: FileTransferMethod, tenant_id: str | None = None) -> File:
|
||||
"""Helper method to create a File object for testing."""
|
||||
if tenant_id is None:
|
||||
tenant_id = self.tenant_id
|
||||
|
||||
@ -5,8 +5,6 @@ from decimal import Decimal
|
||||
from json import dumps
|
||||
|
||||
# import monkeypatch
|
||||
from typing import Optional
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, PromptMessageTool
|
||||
@ -113,8 +111,8 @@ class MockModelClass(PluginModelClient):
|
||||
|
||||
@staticmethod
|
||||
def generate_function_call(
|
||||
tools: Optional[list[PromptMessageTool]],
|
||||
) -> Optional[AssistantPromptMessage.ToolCall]:
|
||||
tools: list[PromptMessageTool] | None,
|
||||
) -> AssistantPromptMessage.ToolCall | None:
|
||||
if not tools or len(tools) == 0:
|
||||
return None
|
||||
function: PromptMessageTool = tools[0]
|
||||
@ -157,7 +155,7 @@ class MockModelClass(PluginModelClient):
|
||||
def mocked_chat_create_sync(
|
||||
model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
) -> LLMResult:
|
||||
tool_call = MockModelClass.generate_function_call(tools=tools)
|
||||
|
||||
@ -186,7 +184,7 @@ class MockModelClass(PluginModelClient):
|
||||
def mocked_chat_create_stream(
|
||||
model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
) -> Generator[LLMResultChunk, None, None]:
|
||||
tool_call = MockModelClass.generate_function_call(tools=tools)
|
||||
|
||||
@ -241,9 +239,9 @@ class MockModelClass(PluginModelClient):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: Optional[dict] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
model_parameters: dict | None = None,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = True,
|
||||
):
|
||||
return MockModelClass.mocked_chat_create_stream(model=model, prompt_messages=prompt_messages, tools=tools)
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
@ -60,8 +61,7 @@ class TestClickZettaVolumeStorage(unittest.TestCase):
|
||||
# Test download
|
||||
with tempfile.NamedTemporaryFile() as temp_file:
|
||||
storage.download(test_filename, temp_file.name)
|
||||
with open(temp_file.name, "rb") as f:
|
||||
downloaded_content = f.read()
|
||||
downloaded_content = Path(temp_file.name).read_bytes()
|
||||
assert downloaded_content == test_content
|
||||
|
||||
# Test scan
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import os
|
||||
from collections import UserDict
|
||||
from typing import Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
@ -22,7 +21,7 @@ class MockBaiduVectorDBClass:
|
||||
def mock_vector_db_client(
|
||||
self,
|
||||
config=None,
|
||||
adapter: Optional[HTTPAdapter] = None,
|
||||
adapter: HTTPAdapter | None = None,
|
||||
):
|
||||
self.conn = MagicMock()
|
||||
self._config = MagicMock()
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
from typing import Union
|
||||
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
@ -23,16 +23,16 @@ class MockTcvectordbClass:
|
||||
key="",
|
||||
read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY,
|
||||
timeout=10,
|
||||
adapter: Optional[HTTPAdapter] = None,
|
||||
adapter: HTTPAdapter | None = None,
|
||||
pool_size: int = 2,
|
||||
proxies: Optional[dict] = None,
|
||||
password: Optional[str] = None,
|
||||
proxies: dict | None = None,
|
||||
password: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
self._conn = None
|
||||
self._read_consistency = read_consistency
|
||||
|
||||
def create_database_if_not_exists(self, database_name: str, timeout: Optional[float] = None) -> RPCDatabase:
|
||||
def create_database_if_not_exists(self, database_name: str, timeout: float | None = None) -> RPCDatabase:
|
||||
return RPCDatabase(
|
||||
name="dify",
|
||||
read_consistency=self._read_consistency,
|
||||
@ -42,7 +42,7 @@ class MockTcvectordbClass:
|
||||
return True
|
||||
|
||||
def describe_collection(
|
||||
self, database_name: str, collection_name: str, timeout: Optional[float] = None
|
||||
self, database_name: str, collection_name: str, timeout: float | None = None
|
||||
) -> RPCCollection:
|
||||
index = Index(
|
||||
FilterIndex("id", enum.FieldType.String, enum.IndexType.PRIMARY_KEY),
|
||||
@ -71,13 +71,13 @@ class MockTcvectordbClass:
|
||||
collection_name: str,
|
||||
shard: int,
|
||||
replicas: int,
|
||||
description: Optional[str] = None,
|
||||
index: Optional[Index] = None,
|
||||
embedding: Optional[Embedding] = None,
|
||||
timeout: Optional[float] = None,
|
||||
ttl_config: Optional[dict] = None,
|
||||
filter_index_config: Optional[FilterIndexConfig] = None,
|
||||
indexes: Optional[list[IndexField]] = None,
|
||||
description: str | None = None,
|
||||
index: Index | None = None,
|
||||
embedding: Embedding | None = None,
|
||||
timeout: float | None = None,
|
||||
ttl_config: dict | None = None,
|
||||
filter_index_config: FilterIndexConfig | None = None,
|
||||
indexes: list[IndexField] | None = None,
|
||||
) -> RPCCollection:
|
||||
return RPCCollection(
|
||||
RPCDatabase(
|
||||
@ -102,7 +102,7 @@ class MockTcvectordbClass:
|
||||
database_name: str,
|
||||
collection_name: str,
|
||||
documents: list[Union[Document, dict]],
|
||||
timeout: Optional[float] = None,
|
||||
timeout: float | None = None,
|
||||
build_index: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
@ -113,12 +113,12 @@ class MockTcvectordbClass:
|
||||
database_name: str,
|
||||
collection_name: str,
|
||||
vectors: list[list[float]],
|
||||
filter: Optional[Filter] = None,
|
||||
filter: Filter | None = None,
|
||||
params=None,
|
||||
retrieve_vector: bool = False,
|
||||
limit: int = 10,
|
||||
output_fields: Optional[list[str]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
output_fields: list[str] | None = None,
|
||||
timeout: float | None = None,
|
||||
) -> list[list[dict]]:
|
||||
return [[{"metadata": {"doc_id": "foo1"}, "text": "text", "doc_id": "foo1", "score": 0.1}]]
|
||||
|
||||
@ -126,14 +126,14 @@ class MockTcvectordbClass:
|
||||
self,
|
||||
database_name: str,
|
||||
collection_name: str,
|
||||
ann: Optional[Union[list[AnnSearch], AnnSearch]] = None,
|
||||
match: Optional[Union[list[KeywordSearch], KeywordSearch]] = None,
|
||||
filter: Optional[Union[Filter, str]] = None,
|
||||
rerank: Optional[Rerank] = None,
|
||||
retrieve_vector: Optional[bool] = None,
|
||||
output_fields: Optional[list[str]] = None,
|
||||
limit: Optional[int] = None,
|
||||
timeout: Optional[float] = None,
|
||||
ann: Union[list[AnnSearch], AnnSearch] | None = None,
|
||||
match: Union[list[KeywordSearch], KeywordSearch] | None = None,
|
||||
filter: Union[Filter, str] | None = None,
|
||||
rerank: Rerank | None = None,
|
||||
retrieve_vector: bool | None = None,
|
||||
output_fields: list[str] | None = None,
|
||||
limit: int | None = None,
|
||||
timeout: float | None = None,
|
||||
return_pd_object=False,
|
||||
**kwargs,
|
||||
) -> list[list[dict]]:
|
||||
@ -143,13 +143,13 @@ class MockTcvectordbClass:
|
||||
self,
|
||||
database_name: str,
|
||||
collection_name: str,
|
||||
document_ids: Optional[list] = None,
|
||||
document_ids: list | None = None,
|
||||
retrieve_vector: bool = False,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
filter: Optional[Filter] = None,
|
||||
output_fields: Optional[list[str]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
filter: Filter | None = None,
|
||||
output_fields: list[str] | None = None,
|
||||
timeout: float | None = None,
|
||||
):
|
||||
return [{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}]
|
||||
|
||||
@ -157,13 +157,13 @@ class MockTcvectordbClass:
|
||||
self,
|
||||
database_name: str,
|
||||
collection_name: str,
|
||||
document_ids: Optional[list[str]] = None,
|
||||
filter: Optional[Filter] = None,
|
||||
timeout: Optional[float] = None,
|
||||
document_ids: list[str] | None = None,
|
||||
filter: Filter | None = None,
|
||||
timeout: float | None = None,
|
||||
):
|
||||
return {"code": 0, "msg": "operation success"}
|
||||
|
||||
def drop_collection(self, database_name: str, collection_name: str, timeout: Optional[float] = None):
|
||||
def drop_collection(self, database_name: str, collection_name: str, timeout: float | None = None):
|
||||
return {"code": 0, "msg": "operation success"}
|
||||
|
||||
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import os
|
||||
from collections import UserDict
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
@ -34,7 +33,7 @@ class MockIndex:
|
||||
include_vectors: bool = False,
|
||||
include_metadata: bool = False,
|
||||
filter: str = "",
|
||||
data: Optional[str] = None,
|
||||
data: str | None = None,
|
||||
namespace: str = "",
|
||||
include_data: bool = False,
|
||||
):
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -29,7 +28,7 @@ def get_mocked_fetch_memory(memory_text: str):
|
||||
human_prefix: str = "Human",
|
||||
ai_prefix: str = "Assistant",
|
||||
max_token_limit: int = 2000,
|
||||
message_limit: Optional[int] = None,
|
||||
message_limit: int | None = None,
|
||||
):
|
||||
return memory_text
|
||||
|
||||
|
||||
@ -11,7 +11,6 @@ import logging
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
@ -42,10 +41,10 @@ class DifyTestContainers:
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize container management with default configurations."""
|
||||
self.postgres: Optional[PostgresContainer] = None
|
||||
self.redis: Optional[RedisContainer] = None
|
||||
self.dify_sandbox: Optional[DockerContainer] = None
|
||||
self.dify_plugin_daemon: Optional[DockerContainer] = None
|
||||
self.postgres: PostgresContainer | None = None
|
||||
self.redis: RedisContainer | None = None
|
||||
self.dify_sandbox: DockerContainer | None = None
|
||||
self.dify_plugin_daemon: DockerContainer | None = None
|
||||
self._containers_started = False
|
||||
logger.info("DifyTestContainers initialized - ready to manage test containers")
|
||||
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import unittest
|
||||
from datetime import UTC, datetime
|
||||
from typing import Optional
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
@ -42,7 +41,7 @@ class TestStorageKeyLoader(unittest.TestCase):
|
||||
self.session.rollback()
|
||||
|
||||
def _create_upload_file(
|
||||
self, file_id: Optional[str] = None, storage_key: Optional[str] = None, tenant_id: Optional[str] = None
|
||||
self, file_id: str | None = None, storage_key: str | None = None, tenant_id: str | None = None
|
||||
) -> UploadFile:
|
||||
"""Helper method to create an UploadFile record for testing."""
|
||||
if file_id is None:
|
||||
@ -74,7 +73,7 @@ class TestStorageKeyLoader(unittest.TestCase):
|
||||
return upload_file
|
||||
|
||||
def _create_tool_file(
|
||||
self, file_id: Optional[str] = None, file_key: Optional[str] = None, tenant_id: Optional[str] = None
|
||||
self, file_id: str | None = None, file_key: str | None = None, tenant_id: str | None = None
|
||||
) -> ToolFile:
|
||||
"""Helper method to create a ToolFile record for testing."""
|
||||
if file_id is None:
|
||||
@ -102,9 +101,7 @@ class TestStorageKeyLoader(unittest.TestCase):
|
||||
|
||||
return tool_file
|
||||
|
||||
def _create_file(
|
||||
self, related_id: str, transfer_method: FileTransferMethod, tenant_id: Optional[str] = None
|
||||
) -> File:
|
||||
def _create_file(self, related_id: str, transfer_method: FileTransferMethod, tenant_id: str | None = None) -> File:
|
||||
"""Helper method to create a File object for testing."""
|
||||
if tenant_id is None:
|
||||
tenant_id = self.tenant_id
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
import time
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@ -248,9 +250,15 @@ class TestWebAppAuthService:
|
||||
- Proper error handling for non-existent accounts
|
||||
- Correct exception type and message
|
||||
"""
|
||||
# Arrange: Use non-existent email
|
||||
fake = Faker()
|
||||
non_existent_email = fake.email()
|
||||
# Arrange: Generate a guaranteed non-existent email
|
||||
# Use UUID and timestamp to ensure uniqueness
|
||||
unique_id = str(uuid.uuid4()).replace("-", "")
|
||||
timestamp = str(int(time.time() * 1000000)) # microseconds
|
||||
non_existent_email = f"nonexistent_{unique_id}_{timestamp}@test-domain-that-never-exists.invalid"
|
||||
|
||||
# Double-check this email doesn't exist in the database
|
||||
existing_account = db_session_with_containers.query(Account).filter_by(email=non_existent_email).first()
|
||||
assert existing_account is None, f"Test email {non_existent_email} already exists in database"
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(AccountNotFoundError):
|
||||
|
||||
@ -12,6 +12,7 @@ and realistic testing scenarios with actual PostgreSQL and Redis instances.
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@ -276,8 +277,7 @@ class TestBatchCreateSegmentToIndexTask:
|
||||
mock_storage = mock_external_service_dependencies["storage"]
|
||||
|
||||
def mock_download(key, file_path):
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
f.write(csv_content)
|
||||
Path(file_path).write_text(csv_content, encoding="utf-8")
|
||||
|
||||
mock_storage.download.side_effect = mock_download
|
||||
|
||||
@ -505,7 +505,7 @@ class TestBatchCreateSegmentToIndexTask:
|
||||
db.session.commit()
|
||||
|
||||
# Test each unavailable document
|
||||
for i, document in enumerate(test_cases):
|
||||
for document in test_cases:
|
||||
job_id = str(uuid.uuid4())
|
||||
batch_create_segment_to_index_task(
|
||||
job_id=job_id,
|
||||
@ -601,8 +601,7 @@ class TestBatchCreateSegmentToIndexTask:
|
||||
mock_storage = mock_external_service_dependencies["storage"]
|
||||
|
||||
def mock_download(key, file_path):
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
f.write(empty_csv_content)
|
||||
Path(file_path).write_text(empty_csv_content, encoding="utf-8")
|
||||
|
||||
mock_storage.download.side_effect = mock_download
|
||||
|
||||
@ -684,8 +683,7 @@ class TestBatchCreateSegmentToIndexTask:
|
||||
mock_storage = mock_external_service_dependencies["storage"]
|
||||
|
||||
def mock_download(key, file_path):
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
f.write(csv_content)
|
||||
Path(file_path).write_text(csv_content, encoding="utf-8")
|
||||
|
||||
mock_storage.download.side_effect = mock_download
|
||||
|
||||
|
||||
@ -362,7 +362,7 @@ class TestCleanDatasetTask:
|
||||
|
||||
# Create segments for each document
|
||||
segments = []
|
||||
for i, document in enumerate(documents):
|
||||
for document in documents:
|
||||
segment = self._create_test_segment(db_session_with_containers, account, tenant, dataset, document)
|
||||
segments.append(segment)
|
||||
|
||||
|
||||
@ -0,0 +1,729 @@
|
||||
"""
|
||||
TestContainers-based integration tests for disable_segments_from_index_task.
|
||||
|
||||
This module provides comprehensive integration testing for the disable_segments_from_index_task
|
||||
using TestContainers to ensure realistic database interactions and proper isolation.
|
||||
The task is responsible for removing document segments from the search index when they are disabled.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from faker import Faker
|
||||
|
||||
from models import Account, Dataset, DocumentSegment
|
||||
from models import Document as DatasetDocument
|
||||
from models.dataset import DatasetProcessRule
|
||||
from tasks.disable_segments_from_index_task import disable_segments_from_index_task
|
||||
|
||||
|
||||
class TestDisableSegmentsFromIndexTask:
|
||||
"""
|
||||
Comprehensive integration tests for disable_segments_from_index_task using testcontainers.
|
||||
|
||||
This test class covers all major functionality of the disable_segments_from_index_task:
|
||||
- Successful segment disabling with proper index cleanup
|
||||
- Error handling for various edge cases
|
||||
- Database state validation after task execution
|
||||
- Redis cache cleanup verification
|
||||
- Index processor integration testing
|
||||
|
||||
All tests use the testcontainers infrastructure to ensure proper database isolation
|
||||
and realistic testing environment with actual database interactions.
|
||||
"""
|
||||
|
||||
def _create_test_account(self, db_session_with_containers, fake=None):
|
||||
"""
|
||||
Helper method to create a test account with realistic data.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
fake: Faker instance for generating test data
|
||||
|
||||
Returns:
|
||||
Account: Created test account instance
|
||||
"""
|
||||
fake = fake or Faker()
|
||||
account = Account()
|
||||
account.id = fake.uuid4()
|
||||
account.email = fake.email()
|
||||
account.name = fake.name()
|
||||
account.avatar_url = fake.url()
|
||||
account.tenant_id = fake.uuid4()
|
||||
account.status = "active"
|
||||
account.type = "normal"
|
||||
account.role = "owner"
|
||||
account.interface_language = "en-US"
|
||||
account.created_at = fake.date_time_this_year()
|
||||
account.updated_at = account.created_at
|
||||
|
||||
# Create a tenant for the account
|
||||
from models.account import Tenant
|
||||
|
||||
tenant = Tenant()
|
||||
tenant.id = account.tenant_id
|
||||
tenant.name = f"Test Tenant {fake.company()}"
|
||||
tenant.plan = "basic"
|
||||
tenant.status = "active"
|
||||
tenant.created_at = fake.date_time_this_year()
|
||||
tenant.updated_at = tenant.created_at
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(tenant)
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
|
||||
# Set the current tenant for the account
|
||||
account.current_tenant = tenant
|
||||
|
||||
return account
|
||||
|
||||
def _create_test_dataset(self, db_session_with_containers, account, fake=None):
|
||||
"""
|
||||
Helper method to create a test dataset with realistic data.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
account: The account creating the dataset
|
||||
fake: Faker instance for generating test data
|
||||
|
||||
Returns:
|
||||
Dataset: Created test dataset instance
|
||||
"""
|
||||
fake = fake or Faker()
|
||||
dataset = Dataset()
|
||||
dataset.id = fake.uuid4()
|
||||
dataset.tenant_id = account.tenant_id
|
||||
dataset.name = f"Test Dataset {fake.word()}"
|
||||
dataset.description = fake.text(max_nb_chars=200)
|
||||
dataset.provider = "vendor"
|
||||
dataset.permission = "only_me"
|
||||
dataset.data_source_type = "upload_file"
|
||||
dataset.indexing_technique = "high_quality"
|
||||
dataset.created_by = account.id
|
||||
dataset.updated_by = account.id
|
||||
dataset.embedding_model = "text-embedding-ada-002"
|
||||
dataset.embedding_model_provider = "openai"
|
||||
dataset.built_in_field_enabled = False
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
|
||||
return dataset
|
||||
|
||||
def _create_test_document(self, db_session_with_containers, dataset, account, fake=None):
|
||||
"""
|
||||
Helper method to create a test document with realistic data.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
dataset: The dataset containing the document
|
||||
account: The account creating the document
|
||||
fake: Faker instance for generating test data
|
||||
|
||||
Returns:
|
||||
DatasetDocument: Created test document instance
|
||||
"""
|
||||
fake = fake or Faker()
|
||||
document = DatasetDocument()
|
||||
document.id = fake.uuid4()
|
||||
document.tenant_id = dataset.tenant_id
|
||||
document.dataset_id = dataset.id
|
||||
document.position = 1
|
||||
document.data_source_type = "upload_file"
|
||||
document.data_source_info = '{"upload_file_id": "test_file_id"}'
|
||||
document.batch = fake.uuid4()
|
||||
document.name = f"Test Document {fake.word()}.txt"
|
||||
document.created_from = "upload_file"
|
||||
document.created_by = account.id
|
||||
document.created_api_request_id = fake.uuid4()
|
||||
document.processing_started_at = fake.date_time_this_year()
|
||||
document.file_id = fake.uuid4()
|
||||
document.word_count = fake.random_int(min=100, max=1000)
|
||||
document.parsing_completed_at = fake.date_time_this_year()
|
||||
document.cleaning_completed_at = fake.date_time_this_year()
|
||||
document.splitting_completed_at = fake.date_time_this_year()
|
||||
document.tokens = fake.random_int(min=50, max=500)
|
||||
document.indexing_started_at = fake.date_time_this_year()
|
||||
document.indexing_completed_at = fake.date_time_this_year()
|
||||
document.indexing_status = "completed"
|
||||
document.enabled = True
|
||||
document.archived = False
|
||||
document.doc_form = "text_model" # Use text_model form for testing
|
||||
document.doc_language = "en"
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
||||
return document
|
||||
|
||||
def _create_test_segments(self, db_session_with_containers, document, dataset, account, count=3, fake=None):
|
||||
"""
|
||||
Helper method to create test document segments with realistic data.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
document: The document containing the segments
|
||||
dataset: The dataset containing the document
|
||||
account: The account creating the segments
|
||||
count: Number of segments to create
|
||||
fake: Faker instance for generating test data
|
||||
|
||||
Returns:
|
||||
List[DocumentSegment]: Created test segment instances
|
||||
"""
|
||||
fake = fake or Faker()
|
||||
segments = []
|
||||
|
||||
for i in range(count):
|
||||
segment = DocumentSegment()
|
||||
segment.id = fake.uuid4()
|
||||
segment.tenant_id = dataset.tenant_id
|
||||
segment.dataset_id = dataset.id
|
||||
segment.document_id = document.id
|
||||
segment.position = i + 1
|
||||
segment.content = f"Test segment content {i + 1}: {fake.text(max_nb_chars=200)}"
|
||||
segment.answer = f"Test answer {i + 1}" if i % 2 == 0 else None
|
||||
segment.word_count = fake.random_int(min=10, max=100)
|
||||
segment.tokens = fake.random_int(min=5, max=50)
|
||||
segment.keywords = [fake.word() for _ in range(3)]
|
||||
segment.index_node_id = f"node_{segment.id}"
|
||||
segment.index_node_hash = fake.sha256()
|
||||
segment.hit_count = 0
|
||||
segment.enabled = True
|
||||
segment.disabled_at = None
|
||||
segment.disabled_by = None
|
||||
segment.status = "completed"
|
||||
segment.created_by = account.id
|
||||
segment.updated_by = account.id
|
||||
segment.indexing_at = fake.date_time_this_year()
|
||||
segment.completed_at = fake.date_time_this_year()
|
||||
segment.error = None
|
||||
segment.stopped_at = None
|
||||
|
||||
segments.append(segment)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
for segment in segments:
|
||||
db.session.add(segment)
|
||||
db.session.commit()
|
||||
|
||||
return segments
|
||||
|
||||
def _create_dataset_process_rule(self, db_session_with_containers, dataset, fake=None):
|
||||
"""
|
||||
Helper method to create a dataset process rule.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
dataset: The dataset for the process rule
|
||||
fake: Faker instance for generating test data
|
||||
|
||||
Returns:
|
||||
DatasetProcessRule: Created process rule instance
|
||||
"""
|
||||
fake = fake or Faker()
|
||||
process_rule = DatasetProcessRule()
|
||||
process_rule.id = fake.uuid4()
|
||||
process_rule.tenant_id = dataset.tenant_id
|
||||
process_rule.dataset_id = dataset.id
|
||||
process_rule.mode = "automatic"
|
||||
process_rule.rules = (
|
||||
"{"
|
||||
'"mode": "automatic", '
|
||||
'"rules": {'
|
||||
'"pre_processing_rules": [], "segmentation": '
|
||||
'{"separator": "\\n\\n", "max_tokens": 1000, "chunk_overlap": 50}}'
|
||||
"}"
|
||||
)
|
||||
process_rule.created_by = dataset.created_by
|
||||
process_rule.updated_by = dataset.updated_by
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(process_rule)
|
||||
db.session.commit()
|
||||
|
||||
return process_rule
|
||||
|
||||
def test_disable_segments_success(self, db_session_with_containers):
|
||||
"""
|
||||
Test successful disabling of segments from index.
|
||||
|
||||
This test verifies that the task can correctly disable segments from the index
|
||||
when all conditions are met, including proper index cleanup and database state updates.
|
||||
"""
|
||||
# Arrange
|
||||
fake = Faker()
|
||||
account = self._create_test_account(db_session_with_containers, fake)
|
||||
dataset = self._create_test_dataset(db_session_with_containers, account, fake)
|
||||
document = self._create_test_document(db_session_with_containers, dataset, account, fake)
|
||||
segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 3, fake)
|
||||
self._create_dataset_process_rule(db_session_with_containers, dataset, fake)
|
||||
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
|
||||
# Mock the index processor to avoid external dependencies
|
||||
with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory:
|
||||
mock_processor = MagicMock()
|
||||
mock_factory.return_value.init_index_processor.return_value = mock_processor
|
||||
|
||||
# Mock Redis client
|
||||
with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis:
|
||||
mock_redis.delete.return_value = True
|
||||
|
||||
# Act
|
||||
result = disable_segments_from_index_task(segment_ids, dataset.id, document.id)
|
||||
|
||||
# Assert
|
||||
assert result is None # Task should complete without returning a value
|
||||
|
||||
# Verify index processor was called correctly
|
||||
mock_factory.assert_called_once_with(document.doc_form)
|
||||
mock_processor.clean.assert_called_once()
|
||||
|
||||
# Verify the call arguments (checking by attributes rather than object identity)
|
||||
call_args = mock_processor.clean.call_args
|
||||
assert call_args[0][0].id == dataset.id # First argument should be the dataset
|
||||
assert sorted(call_args[0][1]) == sorted(
|
||||
[segment.index_node_id for segment in segments]
|
||||
) # Compare sorted lists to handle any order while preserving duplicates
|
||||
assert call_args[1]["with_keywords"] is True
|
||||
assert call_args[1]["delete_child_chunks"] is False
|
||||
|
||||
# Verify Redis cache cleanup was called for each segment
|
||||
assert mock_redis.delete.call_count == len(segments)
|
||||
for segment in segments:
|
||||
expected_key = f"segment_{segment.id}_indexing"
|
||||
mock_redis.delete.assert_any_call(expected_key)
|
||||
|
||||
def test_disable_segments_dataset_not_found(self, db_session_with_containers):
|
||||
"""
|
||||
Test handling when dataset is not found.
|
||||
|
||||
This test ensures that the task correctly handles cases where the specified
|
||||
dataset doesn't exist, logging appropriate messages and returning early.
|
||||
"""
|
||||
# Arrange
|
||||
fake = Faker()
|
||||
non_existent_dataset_id = fake.uuid4()
|
||||
non_existent_document_id = fake.uuid4()
|
||||
segment_ids = [fake.uuid4()]
|
||||
|
||||
# Mock Redis client
|
||||
with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis:
|
||||
# Act
|
||||
result = disable_segments_from_index_task(segment_ids, non_existent_dataset_id, non_existent_document_id)
|
||||
|
||||
# Assert
|
||||
assert result is None # Task should complete without returning a value
|
||||
# Redis should not be called when dataset is not found
|
||||
mock_redis.delete.assert_not_called()
|
||||
|
||||
def test_disable_segments_document_not_found(self, db_session_with_containers):
|
||||
"""
|
||||
Test handling when document is not found.
|
||||
|
||||
This test ensures that the task correctly handles cases where the specified
|
||||
document doesn't exist, logging appropriate messages and returning early.
|
||||
"""
|
||||
# Arrange
|
||||
fake = Faker()
|
||||
account = self._create_test_account(db_session_with_containers, fake)
|
||||
dataset = self._create_test_dataset(db_session_with_containers, account, fake)
|
||||
non_existent_document_id = fake.uuid4()
|
||||
segment_ids = [fake.uuid4()]
|
||||
|
||||
# Mock Redis client
|
||||
with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis:
|
||||
# Act
|
||||
result = disable_segments_from_index_task(segment_ids, dataset.id, non_existent_document_id)
|
||||
|
||||
# Assert
|
||||
assert result is None # Task should complete without returning a value
|
||||
# Redis should not be called when document is not found
|
||||
mock_redis.delete.assert_not_called()
|
||||
|
||||
def test_disable_segments_document_invalid_status(self, db_session_with_containers):
    """
    Test handling when the document is in a state that forbids disabling.

    Exercises all three guard conditions in the task: the document must be
    enabled, must not be archived, and must have finished indexing. Each
    invalid state should cause an early return with no Redis activity.
    """
    # Arrange: dataset, document and two segments to attempt to disable.
    fake = Faker()
    account = self._create_test_account(db_session_with_containers, fake)
    dataset = self._create_test_dataset(db_session_with_containers, account, fake)
    document = self._create_test_document(db_session_with_containers, dataset, account, fake)
    segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 2, fake)

    # Case 1: document is disabled.
    document.enabled = False
    from extensions.ext_database import db

    db.session.commit()

    segment_ids = [segment.id for segment in segments]

    with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis:
        result = disable_segments_from_index_task(segment_ids, dataset.id, document.id)

        # Early return: nothing returned, Redis untouched.
        assert result is None
        mock_redis.delete.assert_not_called()

    # Case 2: document is archived.
    document.enabled = True
    document.archived = True
    db.session.commit()

    with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis:
        result = disable_segments_from_index_task(segment_ids, dataset.id, document.id)

        assert result is None
        mock_redis.delete.assert_not_called()

    # Case 3: indexing has not completed yet.
    document.enabled = True
    document.archived = False
    document.indexing_status = "indexing"
    db.session.commit()

    with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis:
        result = disable_segments_from_index_task(segment_ids, dataset.id, document.id)

        assert result is None
        mock_redis.delete.assert_not_called()
|
||||
def test_disable_segments_no_segments_found(self, db_session_with_containers):
    """
    Test handling when no segments match the supplied IDs.

    Verifies that the task handles segment ids that do not exist (or do
    not belong to the dataset/document) by finishing quietly with no
    Redis activity.
    """
    # Arrange: valid dataset/document, but segment ids that match nothing.
    fake = Faker()
    account = self._create_test_account(db_session_with_containers, fake)
    dataset = self._create_test_dataset(db_session_with_containers, account, fake)
    document = self._create_test_document(db_session_with_containers, dataset, account, fake)
    self._create_dataset_process_rule(db_session_with_containers, dataset, fake)

    missing_segment_ids = [fake.uuid4() for _ in range(3)]

    # Act: run the task with the Redis client patched out.
    with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis:
        result = disable_segments_from_index_task(missing_segment_ids, dataset.id, document.id)

        # Assert: task finishes without a return value and never touches Redis.
        assert result is None
        mock_redis.delete.assert_not_called()
|
||||
def test_disable_segments_index_processor_error(self, db_session_with_containers):
    """
    Test error recovery when the index processor fails.

    Verifies that a failure inside ``clean`` causes the task to roll the
    segments back to an enabled state while still clearing the per-segment
    indexing cache entries in Redis.
    """
    # Arrange: dataset, document, two segments and a process rule.
    fake = Faker()
    account = self._create_test_account(db_session_with_containers, fake)
    dataset = self._create_test_dataset(db_session_with_containers, account, fake)
    document = self._create_test_document(db_session_with_containers, dataset, account, fake)
    segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 2, fake)
    self._create_dataset_process_rule(db_session_with_containers, dataset, fake)

    segment_ids = [segment.id for segment in segments]

    # Force the index processor's clean() call to blow up.
    with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory:
        mock_processor = MagicMock()
        mock_processor.clean.side_effect = Exception("Index processor error")
        mock_factory.return_value.init_index_processor.return_value = mock_processor

        with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis:
            mock_redis.delete.return_value = True

            # Act
            result = disable_segments_from_index_task(segment_ids, dataset.id, document.id)

            # Assert: task completes without returning a value.
            assert result is None

            # Reload the segments to observe the rollback.
            from extensions.ext_database import db

            db.session.refresh(segments[0])
            db.session.refresh(segments[1])

            # After the failure every segment must be re-enabled with the
            # disable metadata cleared.
            updated_segments = db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).all()

            for segment in updated_segments:
                assert segment.enabled is True
                assert segment.disabled_at is None
                assert segment.disabled_by is None

            # Cache cleanup still happens once per segment despite the error.
            assert mock_redis.delete.call_count == len(segments)
|
||||
def test_disable_segments_with_different_doc_forms(self, db_session_with_containers):
    """
    Test disabling segments across every supported document form.

    Verifies that the task constructs the index processor factory with the
    document's current ``doc_form`` for each of the supported forms.
    """
    # Arrange: dataset, document, two segments and a process rule.
    fake = Faker()
    account = self._create_test_account(db_session_with_containers, fake)
    dataset = self._create_test_dataset(db_session_with_containers, account, fake)
    document = self._create_test_document(db_session_with_containers, dataset, account, fake)
    segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 2, fake)
    self._create_dataset_process_rule(db_session_with_containers, dataset, fake)

    segment_ids = [segment.id for segment in segments]

    # Every doc_form the task must cope with.
    doc_forms = ["text_model", "qa_model", "hierarchical_model"]

    for doc_form in doc_forms:
        # Switch the document over to the form under test.
        document.doc_form = doc_form
        from extensions.ext_database import db

        db.session.commit()

        with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory:
            mock_processor = MagicMock()
            mock_factory.return_value.init_index_processor.return_value = mock_processor

            with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis:
                mock_redis.delete.return_value = True

                # Act
                result = disable_segments_from_index_task(segment_ids, dataset.id, document.id)

                # Assert: silent completion and factory keyed on the form.
                assert result is None
                mock_factory.assert_called_with(doc_form)
||||
def test_disable_segments_performance_timing(self, db_session_with_containers):
    """
    Test that execution latency is measured and logged.

    Pins ``time.perf_counter`` to a 0.5 second delta and asserts that a
    latency line containing that figure reaches the task's logger.
    """
    # Arrange: dataset, document, three segments and a process rule.
    fake = Faker()
    account = self._create_test_account(db_session_with_containers, fake)
    dataset = self._create_test_dataset(db_session_with_containers, account, fake)
    document = self._create_test_document(db_session_with_containers, dataset, account, fake)
    segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 3, fake)
    self._create_dataset_process_rule(db_session_with_containers, dataset, fake)

    segment_ids = [segment.id for segment in segments]

    with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory:
        mock_processor = MagicMock()
        mock_factory.return_value.init_index_processor.return_value = mock_processor

        with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis:
            mock_redis.delete.return_value = True

            # Freeze timing: start at 1000.0, end at 1000.5 => 0.5s latency.
            with patch("tasks.disable_segments_from_index_task.time.perf_counter") as mock_perf_counter:
                mock_perf_counter.side_effect = [1000.0, 1000.5]

                # Capture what the task logs.
                with patch("tasks.disable_segments_from_index_task.logger") as mock_logger:
                    # Act
                    result = disable_segments_from_index_task(segment_ids, dataset.id, document.id)

                    # Assert: silent completion.
                    assert result is None

                    # A latency log line must exist and carry the 0.5s figure.
                    mock_logger.info.assert_called()
                    log_calls = [call[0][0] for call in mock_logger.info.call_args_list]
                    performance_log = next((call for call in log_calls if "latency" in call), None)
                    assert performance_log is not None
                    assert "0.5" in performance_log
||||
def test_disable_segments_redis_cache_cleanup(self, db_session_with_containers):
    """
    Test that the per-segment indexing cache keys are removed from Redis.

    Verifies one ``delete`` per segment and that each call used the
    ``segment_<id>_indexing`` key convention.
    """
    # Arrange: dataset, document, five segments and a process rule.
    fake = Faker()
    account = self._create_test_account(db_session_with_containers, fake)
    dataset = self._create_test_dataset(db_session_with_containers, account, fake)
    document = self._create_test_document(db_session_with_containers, dataset, account, fake)
    segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 5, fake)
    self._create_dataset_process_rule(db_session_with_containers, dataset, fake)

    segment_ids = [segment.id for segment in segments]

    with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory:
        mock_processor = MagicMock()
        mock_factory.return_value.init_index_processor.return_value = mock_processor

        # Track every delete issued against Redis.
        with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis:
            mock_redis.delete.return_value = True

            # Act
            result = disable_segments_from_index_task(segment_ids, dataset.id, document.id)

            # Assert: silent completion.
            assert result is None

            # Exactly one cache delete per segment.
            assert mock_redis.delete.call_count == len(segments)

            # Every expected cache key must appear among the delete calls.
            expected_keys = [f"segment_{segment.id}_indexing" for segment in segments]
            actual_calls = [call[0][0] for call in mock_redis.delete.call_args_list]

            for expected_key in expected_keys:
                assert expected_key in actual_calls
||||
def test_disable_segments_database_session_cleanup(self, db_session_with_containers):
    """
    Test that the database session is closed when the task finishes.

    Guards against connection leaks by asserting ``db.session.close`` is
    invoked on the happy path.
    """
    # Arrange: dataset, document, two segments and a process rule.
    fake = Faker()
    account = self._create_test_account(db_session_with_containers, fake)
    dataset = self._create_test_dataset(db_session_with_containers, account, fake)
    document = self._create_test_document(db_session_with_containers, dataset, account, fake)
    segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 2, fake)
    self._create_dataset_process_rule(db_session_with_containers, dataset, fake)

    segment_ids = [segment.id for segment in segments]

    with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory:
        mock_processor = MagicMock()
        mock_factory.return_value.init_index_processor.return_value = mock_processor

        with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis:
            mock_redis.delete.return_value = True

            # Spy on session close to confirm the cleanup happens.
            with patch("tasks.disable_segments_from_index_task.db.session.close") as mock_close:
                # Act
                result = disable_segments_from_index_task(segment_ids, dataset.id, document.id)

                # Assert: silent completion and session closed.
                assert result is None
                mock_close.assert_called()
||||
def test_disable_segments_empty_segment_ids(self, db_session_with_containers):
    """
    Test the edge case of an empty segment id list.

    The task should finish quietly without touching Redis when there is
    nothing to disable.
    """
    # Arrange: a valid dataset/document pair but no segment ids at all.
    fake = Faker()
    account = self._create_test_account(db_session_with_containers, fake)
    dataset = self._create_test_dataset(db_session_with_containers, account, fake)
    document = self._create_test_document(db_session_with_containers, dataset, account, fake)
    self._create_dataset_process_rule(db_session_with_containers, dataset, fake)

    empty_segment_ids = []

    # Act: run the task with the Redis client patched out.
    with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis:
        result = disable_segments_from_index_task(empty_segment_ids, dataset.id, document.id)

        # Assert: task finishes without a return value and never touches Redis.
        assert result is None
        mock_redis.delete.assert_not_called()
|
||||
def test_disable_segments_mixed_valid_invalid_ids(self, db_session_with_containers):
    """
    Test a mix of real and bogus segment IDs.

    Verifies that only the segments that actually exist are passed to the
    index processor and have their Redis cache entries cleared; the bogus
    ids are silently ignored.
    """
    # Arrange: two real segments plus two fabricated ids.
    fake = Faker()
    account = self._create_test_account(db_session_with_containers, fake)
    dataset = self._create_test_dataset(db_session_with_containers, account, fake)
    document = self._create_test_document(db_session_with_containers, dataset, account, fake)
    segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 2, fake)
    self._create_dataset_process_rule(db_session_with_containers, dataset, fake)

    valid_segment_ids = [segment.id for segment in segments]
    invalid_segment_ids = [fake.uuid4() for _ in range(2)]
    mixed_segment_ids = valid_segment_ids + invalid_segment_ids

    with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory:
        mock_processor = MagicMock()
        mock_factory.return_value.init_index_processor.return_value = mock_processor

        with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis:
            mock_redis.delete.return_value = True

            # Act
            result = disable_segments_from_index_task(mixed_segment_ids, dataset.id, document.id)

            # Assert: silent completion.
            assert result is None

            # clean() is invoked once, with only the real segments' node ids.
            expected_node_ids = [segment.index_node_id for segment in segments]
            mock_processor.clean.assert_called_once()

            # Inspect the positional and keyword arguments of that call.
            call_args = mock_processor.clean.call_args
            assert call_args[0][0].id == dataset.id  # dataset is the first positional arg
            assert sorted(call_args[0][1]) == sorted(
                expected_node_ids
            )  # order-insensitive comparison that still preserves duplicates
            assert call_args[1]["with_keywords"] is True
            assert call_args[1]["delete_child_chunks"] is False

            # Cache cleanup only for the segments that exist.
            assert mock_redis.delete.call_count == len(segments)
@ -0,0 +1,554 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document
|
||||
from tasks.document_indexing_task import document_indexing_task
|
||||
|
||||
|
||||
class TestDocumentIndexingTask:
|
||||
"""Integration tests for document_indexing_task using testcontainers."""
|
||||
|
||||
@pytest.fixture
def mock_external_service_dependencies(self):
    """Patch the indexing runner and feature service used by the task.

    Yields a dict exposing the patched factory, the runner instance it
    produces, the feature service, and the features object (billing is
    disabled by default; individual tests flip it on as needed).
    """
    with (
        patch("tasks.document_indexing_task.IndexingRunner") as mock_indexing_runner,
        patch("tasks.document_indexing_task.FeatureService") as mock_feature_service,
    ):
        # The runner class returns a shared mock instance.
        mock_runner_instance = MagicMock()
        mock_indexing_runner.return_value = mock_runner_instance

        # Feature lookups report billing disabled unless a test overrides it.
        mock_features = MagicMock()
        mock_features.billing.enabled = False
        mock_feature_service.get_features.return_value = mock_features

        yield {
            "indexing_runner": mock_indexing_runner,
            "indexing_runner_instance": mock_runner_instance,
            "feature_service": mock_feature_service,
            "features": mock_features,
        }
||||
|
||||
def _create_test_dataset_and_documents(
    self, db_session_with_containers, mock_external_service_dependencies, document_count=3
):
    """
    Build an account/tenant, a dataset, and ``document_count`` documents.

    Args:
        db_session_with_containers: Database session from testcontainers infrastructure
        mock_external_service_dependencies: Mock dependencies
        document_count: Number of documents to create

    Returns:
        tuple: (dataset, documents) - Created dataset and document instances
    """
    fake = Faker()

    # Owner account for everything created below.
    account = Account(
        email=fake.email(),
        name=fake.name(),
        interface_language="en-US",
        status="active",
    )
    db.session.add(account)
    db.session.commit()

    # Tenant the dataset will live under.
    tenant = Tenant(
        name=fake.company(),
        status="normal",
    )
    db.session.add(tenant)
    db.session.commit()

    # Link the account to the tenant as its owner.
    tenant_join = TenantAccountJoin(
        tenant_id=tenant.id,
        account_id=account.id,
        role=TenantAccountRole.OWNER.value,
        current=True,
    )
    db.session.add(tenant_join)
    db.session.commit()

    # The dataset the task will index into.
    dataset = Dataset(
        id=fake.uuid4(),
        tenant_id=tenant.id,
        name=fake.company(),
        description=fake.text(max_nb_chars=100),
        data_source_type="upload_file",
        indexing_technique="high_quality",
        created_by=account.id,
    )
    db.session.add(dataset)
    db.session.commit()

    # Documents start in the "waiting" state so the task will pick them up.
    documents = []
    for position in range(document_count):
        doc = Document(
            id=fake.uuid4(),
            tenant_id=tenant.id,
            dataset_id=dataset.id,
            position=position,
            data_source_type="upload_file",
            batch="test_batch",
            name=fake.file_name(),
            created_from="upload_file",
            created_by=account.id,
            indexing_status="waiting",
            enabled=True,
        )
        db.session.add(doc)
        documents.append(doc)

    db.session.commit()

    # Make sure the dataset reflects the committed state.
    db.session.refresh(dataset)

    return dataset, documents
||||
|
||||
def _create_test_dataset_with_billing_features(
    self, db_session_with_containers, mock_external_service_dependencies, billing_enabled=True
):
    """
    Build a dataset with three documents and configure billing mocks.

    Args:
        db_session_with_containers: Database session from testcontainers infrastructure
        mock_external_service_dependencies: Mock dependencies
        billing_enabled: Whether billing is enabled

    Returns:
        tuple: (dataset, documents) - Created dataset and document instances
    """
    fake = Faker()

    # Owner account for everything created below.
    account = Account(
        email=fake.email(),
        name=fake.name(),
        interface_language="en-US",
        status="active",
    )
    db.session.add(account)
    db.session.commit()

    # Tenant the dataset will live under.
    tenant = Tenant(
        name=fake.company(),
        status="normal",
    )
    db.session.add(tenant)
    db.session.commit()

    # Link the account to the tenant as its owner.
    tenant_join = TenantAccountJoin(
        tenant_id=tenant.id,
        account_id=account.id,
        role=TenantAccountRole.OWNER.value,
        current=True,
    )
    db.session.add(tenant_join)
    db.session.commit()

    # The dataset the task will index into.
    dataset = Dataset(
        id=fake.uuid4(),
        tenant_id=tenant.id,
        name=fake.company(),
        description=fake.text(max_nb_chars=100),
        data_source_type="upload_file",
        indexing_technique="high_quality",
        created_by=account.id,
    )
    db.session.add(dataset)
    db.session.commit()

    # Three documents in the "waiting" state.
    documents = []
    for position in range(3):
        doc = Document(
            id=fake.uuid4(),
            tenant_id=tenant.id,
            dataset_id=dataset.id,
            position=position,
            data_source_type="upload_file",
            batch="test_batch",
            name=fake.file_name(),
            created_from="upload_file",
            created_by=account.id,
            indexing_status="waiting",
            enabled=True,
        )
        db.session.add(doc)
        documents.append(doc)

    db.session.commit()

    # Point the feature mocks at the requested billing configuration.
    mock_external_service_dependencies["features"].billing.enabled = billing_enabled
    if billing_enabled:
        mock_external_service_dependencies["features"].billing.subscription.plan = "sandbox"
        mock_external_service_dependencies["features"].vector_space.limit = 100
        mock_external_service_dependencies["features"].vector_space.size = 50

    # Make sure the dataset reflects the committed state.
    db.session.refresh(dataset)

    return dataset, documents
||||
|
||||
def test_document_indexing_task_success(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test the happy path with several documents.

    Verifies:
    - the dataset is fetched from the database
    - every document transitions to "parsing" with a start timestamp
    - the IndexingRunner is constructed and run exactly once
    - the runner receives all three documents
    """
    # Arrange: dataset with three waiting documents.
    dataset, documents = self._create_test_dataset_and_documents(
        db_session_with_containers, mock_external_service_dependencies, document_count=3
    )
    document_ids = [doc.id for doc in documents]

    # Act
    document_indexing_task(dataset.id, document_ids)

    # Assert: runner built once, run once.
    mock_external_service_dependencies["indexing_runner"].assert_called_once()
    mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()

    # Every document now shows parsing with a processing start time.
    for document in documents:
        db.session.refresh(document)
        assert document.indexing_status == "parsing"
        assert document.processing_started_at is not None

    # The runner got all three documents as its first positional argument.
    call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
    assert call_args is not None
    processed_documents = call_args[0][0]
    assert len(processed_documents) == 3
|
||||
def test_document_indexing_task_dataset_not_found(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test the missing-dataset path.

    Verifies the task returns early — no IndexingRunner is constructed or
    run — when the dataset id matches no database row.
    """
    # Arrange: ids that match nothing in the database.
    fake = Faker()
    missing_dataset_id = fake.uuid4()
    document_ids = [fake.uuid4() for _ in range(3)]

    # Act
    document_indexing_task(missing_dataset_id, document_ids)

    # Assert: no indexing work was attempted at all.
    mock_external_service_dependencies["indexing_runner"].assert_not_called()
    mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called()
|
||||
def test_document_indexing_task_document_not_found_in_dataset(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test a mix of existing and unknown document IDs.

    Verifies the unknown ids are ignored: only the real documents are
    updated to "parsing" and handed to the indexing runner.
    """
    # Arrange: two real documents.
    dataset, documents = self._create_test_dataset_and_documents(
        db_session_with_containers, mock_external_service_dependencies, document_count=2
    )

    # Append two fabricated ids that match nothing.
    fake = Faker()
    existing_document_ids = [doc.id for doc in documents]
    non_existent_document_ids = [fake.uuid4() for _ in range(2)]
    all_document_ids = existing_document_ids + non_existent_document_ids

    # Act
    document_indexing_task(dataset.id, all_document_ids)

    # Assert: runner built once, run once.
    mock_external_service_dependencies["indexing_runner"].assert_called_once()
    mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()

    # The real documents transitioned to parsing.
    for document in documents:
        db.session.refresh(document)
        assert document.indexing_status == "parsing"
        assert document.processing_started_at is not None

    # Only the two real documents reached the runner.
    call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
    assert call_args is not None
    processed_documents = call_args[0][0]
    assert len(processed_documents) == 2
|
||||
def test_document_indexing_task_indexing_runner_exception(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test that IndexingRunner failures are swallowed by the task.

    Verifies the task does not propagate the runner's exception and that
    the documents were already moved to "parsing" before the failure.
    """
    # Arrange: two waiting documents.
    dataset, documents = self._create_test_dataset_and_documents(
        db_session_with_containers, mock_external_service_dependencies, document_count=2
    )
    document_ids = [doc.id for doc in documents]

    # Make the runner blow up when invoked.
    mock_external_service_dependencies["indexing_runner_instance"].run.side_effect = Exception(
        "Indexing runner failed"
    )

    # Act: must not raise despite the runner failure.
    document_indexing_task(dataset.id, document_ids)

    # Assert: the runner was still constructed and invoked once.
    mock_external_service_dependencies["indexing_runner"].assert_called_once()
    mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()

    # Status updates happened before the exception fired.
    for document in documents:
        db.session.refresh(document)
        assert document.indexing_status == "parsing"
        assert document.processing_started_at is not None
|
||||
def test_document_indexing_task_mixed_document_states(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test documents with heterogeneous initial states.

    Adds one already-completed document and one disabled document to two
    waiting ones and verifies that the task still moves all of them to
    "parsing" and forwards all four to the indexing runner.
    """
    # Arrange: two waiting documents from the shared helper.
    dataset, base_documents = self._create_test_dataset_and_documents(
        db_session_with_containers, mock_external_service_dependencies, document_count=2
    )

    fake = Faker()
    extra_documents = []

    # A document that has already finished indexing.
    doc1 = Document(
        id=fake.uuid4(),
        tenant_id=dataset.tenant_id,
        dataset_id=dataset.id,
        position=2,
        data_source_type="upload_file",
        batch="test_batch",
        name=fake.file_name(),
        created_from="upload_file",
        created_by=dataset.created_by,
        indexing_status="completed",  # Already completed
        enabled=True,
    )
    db.session.add(doc1)
    extra_documents.append(doc1)

    # A document that is disabled.
    doc2 = Document(
        id=fake.uuid4(),
        tenant_id=dataset.tenant_id,
        dataset_id=dataset.id,
        position=3,
        data_source_type="upload_file",
        batch="test_batch",
        name=fake.file_name(),
        created_from="upload_file",
        created_by=dataset.created_by,
        indexing_status="waiting",
        enabled=False,  # Disabled
    )
    db.session.add(doc2)
    extra_documents.append(doc2)

    db.session.commit()

    all_documents = base_documents + extra_documents
    document_ids = [doc.id for doc in all_documents]

    # Act
    document_indexing_task(dataset.id, document_ids)

    # Assert: runner built once, run once.
    mock_external_service_dependencies["indexing_runner"].assert_called_once()
    mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()

    # Every document — regardless of its prior state — is now parsing.
    for document in all_documents:
        db.session.refresh(document)
        assert document.indexing_status == "parsing"
        assert document.processing_started_at is not None

    # All four documents reached the runner.
    call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
    assert call_args is not None
    processed_documents = call_args[0][0]
    assert len(processed_documents) == 4
|
||||
def test_document_indexing_task_billing_sandbox_plan_batch_limit(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test billing validation for sandbox plan batch upload limit.
|
||||
|
||||
This test verifies:
|
||||
- Sandbox plan batch upload limit enforcement
|
||||
- Error handling for batch upload limit exceeded
|
||||
- Document status updates to error state
|
||||
- Proper error message recording
|
||||
"""
|
||||
# Arrange: Create test data with billing enabled
|
||||
dataset, documents = self._create_test_dataset_with_billing_features(
|
||||
db_session_with_containers, mock_external_service_dependencies, billing_enabled=True
|
||||
)
|
||||
|
||||
# Configure sandbox plan with batch limit
|
||||
mock_external_service_dependencies["features"].billing.subscription.plan = "sandbox"
|
||||
|
||||
# Create more documents than sandbox plan allows (limit is 1)
|
||||
fake = Faker()
|
||||
extra_documents = []
|
||||
for i in range(2): # Total will be 5 documents (3 existing + 2 new)
|
||||
document = Document(
|
||||
id=fake.uuid4(),
|
||||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
position=i + 3,
|
||||
data_source_type="upload_file",
|
||||
batch="test_batch",
|
||||
name=fake.file_name(),
|
||||
created_from="upload_file",
|
||||
created_by=dataset.created_by,
|
||||
indexing_status="waiting",
|
||||
enabled=True,
|
||||
)
|
||||
db.session.add(document)
|
||||
extra_documents.append(document)
|
||||
|
||||
db.session.commit()
|
||||
all_documents = documents + extra_documents
|
||||
document_ids = [doc.id for doc in all_documents]
|
||||
|
||||
# Act: Execute the task with too many documents for sandbox plan
|
||||
document_indexing_task(dataset.id, document_ids)
|
||||
|
||||
# Assert: Verify error handling
|
||||
for document in all_documents:
|
||||
db.session.refresh(document)
|
||||
assert document.indexing_status == "error"
|
||||
assert document.error is not None
|
||||
assert "batch upload" in document.error
|
||||
assert document.stopped_at is not None
|
||||
|
||||
# Verify no indexing runner was called
|
||||
mock_external_service_dependencies["indexing_runner"].assert_not_called()
|
||||
|
||||
def test_document_indexing_task_billing_disabled_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful processing when billing is disabled.
|
||||
|
||||
This test verifies:
|
||||
- Processing continues normally when billing is disabled
|
||||
- No billing validation occurs
|
||||
- Documents are processed successfully
|
||||
- IndexingRunner is called correctly
|
||||
"""
|
||||
# Arrange: Create test data with billing disabled
|
||||
dataset, documents = self._create_test_dataset_with_billing_features(
|
||||
db_session_with_containers, mock_external_service_dependencies, billing_enabled=False
|
||||
)
|
||||
|
||||
document_ids = [doc.id for doc in documents]
|
||||
|
||||
# Act: Execute the task with billing disabled
|
||||
document_indexing_task(dataset.id, document_ids)
|
||||
|
||||
# Assert: Verify successful processing
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
# Verify documents were updated to parsing status
|
||||
for document in documents:
|
||||
db.session.refresh(document)
|
||||
assert document.indexing_status == "parsing"
|
||||
assert document.processing_started_at is not None
|
||||
|
||||
def test_document_indexing_task_document_is_paused_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test handling of DocumentIsPausedError from IndexingRunner.
|
||||
|
||||
This test verifies:
|
||||
- DocumentIsPausedError is properly caught and handled
|
||||
- Task completes without raising exceptions
|
||||
- Appropriate logging occurs
|
||||
- Database session is properly closed
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
dataset, documents = self._create_test_dataset_and_documents(
|
||||
db_session_with_containers, mock_external_service_dependencies, document_count=2
|
||||
)
|
||||
document_ids = [doc.id for doc in documents]
|
||||
|
||||
# Mock IndexingRunner to raise DocumentIsPausedError
|
||||
from core.indexing_runner import DocumentIsPausedError
|
||||
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.side_effect = DocumentIsPausedError(
|
||||
"Document indexing is paused"
|
||||
)
|
||||
|
||||
# Act: Execute the task
|
||||
document_indexing_task(dataset.id, document_ids)
|
||||
|
||||
# Assert: Verify exception was handled gracefully
|
||||
# The task should complete without raising exceptions
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
# Verify documents were still updated to parsing status before the exception
|
||||
for document in documents:
|
||||
db.session.refresh(document)
|
||||
assert document.indexing_status == "parsing"
|
||||
assert document.processing_started_at is not None
|
||||
@ -15,7 +15,7 @@ class FakeResponse:
|
||||
self.status_code = status_code
|
||||
self.headers = headers or {}
|
||||
self.content = content
|
||||
self.text = text if text else content.decode("utf-8", errors="ignore")
|
||||
self.text = text or content.decode("utf-8", errors="ignore")
|
||||
|
||||
|
||||
# ---------------------------
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import base64
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
@ -47,7 +46,7 @@ class MockTokenBufferMemory:
|
||||
self.history_messages = history_messages or []
|
||||
|
||||
def get_history_prompt_messages(
|
||||
self, max_token_limit: int = 2000, message_limit: Optional[int] = None
|
||||
self, max_token_limit: int = 2000, message_limit: int | None = None
|
||||
) -> Sequence[PromptMessage]:
|
||||
if message_limit is not None:
|
||||
return self.history_messages[-message_limit * 2 :]
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import contextvars
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
@ -29,7 +28,7 @@ def login_app(app: Flask) -> Flask:
|
||||
login_manager.init_app(app)
|
||||
|
||||
@login_manager.user_loader
|
||||
def load_user(user_id: str) -> Optional[User]:
|
||||
def load_user(user_id: str) -> User | None:
|
||||
if user_id == "test_user":
|
||||
return User("test_user")
|
||||
return None
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import datetime
|
||||
from typing import Optional
|
||||
|
||||
# Mock redis_client before importing dataset_service
|
||||
from unittest.mock import Mock, call, patch
|
||||
@ -37,7 +36,7 @@ class DocumentBatchUpdateTestDataFactory:
|
||||
enabled: bool = True,
|
||||
archived: bool = False,
|
||||
indexing_status: str = "completed",
|
||||
completed_at: Optional[datetime.datetime] = None,
|
||||
completed_at: datetime.datetime | None = None,
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock document with specified attributes."""
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import datetime
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
# Mock redis_client before importing dataset_service
|
||||
from unittest.mock import Mock, create_autospec, patch
|
||||
@ -24,9 +24,9 @@ class DatasetUpdateTestDataFactory:
|
||||
description: str = "old_description",
|
||||
indexing_technique: str = "high_quality",
|
||||
retrieval_model: str = "old_model",
|
||||
embedding_model_provider: Optional[str] = None,
|
||||
embedding_model: Optional[str] = None,
|
||||
collection_binding_id: Optional[str] = None,
|
||||
embedding_model_provider: str | None = None,
|
||||
embedding_model: str | None = None,
|
||||
collection_binding_id: str | None = None,
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock dataset with specified attributes."""
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, create_autospec, patch
|
||||
|
||||
import pytest
|
||||
@ -146,19 +147,17 @@ class TestMetadataBugCompleteValidation:
|
||||
# Console API create
|
||||
console_create_file = "api/controllers/console/datasets/metadata.py"
|
||||
if os.path.exists(console_create_file):
|
||||
with open(console_create_file) as f:
|
||||
content = f.read()
|
||||
# Should contain nullable=False, not nullable=True
|
||||
assert "nullable=True" not in content.split("class DatasetMetadataCreateApi")[1].split("class")[0]
|
||||
content = Path(console_create_file).read_text()
|
||||
# Should contain nullable=False, not nullable=True
|
||||
assert "nullable=True" not in content.split("class DatasetMetadataCreateApi")[1].split("class")[0]
|
||||
|
||||
# Service API create
|
||||
service_create_file = "api/controllers/service_api/dataset/metadata.py"
|
||||
if os.path.exists(service_create_file):
|
||||
with open(service_create_file) as f:
|
||||
content = f.read()
|
||||
# Should contain nullable=False, not nullable=True
|
||||
create_api_section = content.split("class DatasetMetadataCreateServiceApi")[1].split("class")[0]
|
||||
assert "nullable=True" not in create_api_section
|
||||
content = Path(service_create_file).read_text()
|
||||
# Should contain nullable=False, not nullable=True
|
||||
create_api_section = content.split("class DatasetMetadataCreateServiceApi")[1].split("class")[0]
|
||||
assert "nullable=True" not in create_api_section
|
||||
|
||||
|
||||
class TestMetadataValidationSummary:
|
||||
|
||||
Reference in New Issue
Block a user