Merge branch 'feat/queue-based-graph-engine' into chore/merge-graph-engine

This commit is contained in:
-LAN-
2025-09-08 14:25:10 +08:00
824 changed files with 7235 additions and 2941 deletions

View File

@ -15,7 +15,7 @@ from services.account_service import AccountService, RegisterService
# Loading the .env file if it exists
def _load_env() -> None:
def _load_env():
current_file_path = pathlib.Path(__file__).absolute()
# Items later in the list have higher precedence.
files_to_load = [".env", "vdb.env"]

View File

@ -84,17 +84,17 @@ class TestStorageKeyLoader(unittest.TestCase):
if tenant_id is None:
tenant_id = self.tenant_id
tool_file = ToolFile()
tool_file = ToolFile(
user_id=self.user_id,
tenant_id=tenant_id,
conversation_id=self.conversation_id,
file_key=file_key,
mimetype="text/plain",
original_url="http://example.com/file.txt",
name="test_tool_file.txt",
size=2048,
)
tool_file.id = file_id
tool_file.user_id = self.user_id
tool_file.tenant_id = tenant_id
tool_file.conversation_id = self.conversation_id
tool_file.file_key = file_key
tool_file.mimetype = "text/plain"
tool_file.original_url = "http://example.com/file.txt"
tool_file.name = "test_tool_file.txt"
tool_file.size = 2048
self.session.add(tool_file)
self.session.flush()
self.test_tool_files.append(tool_file)

View File

@ -17,7 +17,7 @@ def mock_plugin_daemon(
:return: unpatch function
"""
def unpatch() -> None:
def unpatch():
monkeypatch.undo()
monkeypatch.setattr(PluginModelClient, "invoke_llm", MockModelClass.invoke_llm)

View File

@ -150,7 +150,7 @@ class MockTcvectordbClass:
filter: Optional[Filter] = None,
output_fields: Optional[list[str]] = None,
timeout: Optional[float] = None,
) -> list[dict]:
):
return [{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}]
def collection_delete(
@ -163,7 +163,7 @@ class MockTcvectordbClass:
):
return {"code": 0, "msg": "operation success"}
def drop_collection(self, database_name: str, collection_name: str, timeout: Optional[float] = None) -> dict:
def drop_collection(self, database_name: str, collection_name: str, timeout: Optional[float] = None):
return {"code": 0, "msg": "operation success"}

View File

@ -26,7 +26,7 @@ def get_example_document(doc_id: str) -> Document:
@pytest.fixture
def setup_mock_redis() -> None:
def setup_mock_redis():
# get
ext_redis.redis_client.get = MagicMock(return_value=None)
@ -48,7 +48,7 @@ class AbstractVectorTest:
self.example_doc_id = str(uuid.uuid4())
self.example_embedding = [1.001 * i for i in range(128)]
def create_vector(self) -> None:
def create_vector(self):
self.vector.create(
texts=[get_example_document(doc_id=self.example_doc_id)],
embeddings=[self.example_embedding],

View File

@ -12,7 +12,7 @@ MOCK = os.getenv("MOCK_SWITCH", "false") == "true"
class MockedCodeExecutor:
@classmethod
def invoke(cls, language: Literal["python3", "javascript", "jinja2"], code: str, inputs: dict) -> dict:
def invoke(cls, language: Literal["python3", "javascript", "jinja2"], code: str, inputs: dict):
# invoke directly
match language:
case CodeLanguage.PYTHON3:

View File

@ -78,7 +78,7 @@ def init_code_node(code_config: dict):
@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
def test_execute_code(setup_code_executor_mock):
code = """
def main(args1: int, args2: int) -> dict:
def main(args1: int, args2: int):
return {
"result": args1 + args2,
}
@ -124,7 +124,7 @@ def test_execute_code(setup_code_executor_mock):
@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
def test_execute_code_output_validator(setup_code_executor_mock):
code = """
def main(args1: int, args2: int) -> dict:
def main(args1: int, args2: int):
return {
"result": args1 + args2,
}
@ -167,7 +167,7 @@ def test_execute_code_output_validator(setup_code_executor_mock):
def test_execute_code_output_validator_depth():
code = """
def main(args1: int, args2: int) -> dict:
def main(args1: int, args2: int):
return {
"result": {
"result": args1 + args2,
@ -285,7 +285,7 @@ def test_execute_code_output_validator_depth():
def test_execute_code_output_object_list():
code = """
def main(args1: int, args2: int) -> dict:
def main(args1: int, args2: int):
return {
"result": {
"result": args1 + args2,
@ -360,7 +360,7 @@ def test_execute_code_output_object_list():
def test_execute_code_scientific_notation():
code = """
def main() -> dict:
def main():
return {
"result": -8.0E-5
}

View File

@ -49,7 +49,7 @@ class DifyTestContainers:
self._containers_started = False
logger.info("DifyTestContainers initialized - ready to manage test containers")
def start_containers_with_env(self) -> None:
def start_containers_with_env(self):
"""
Start all required containers for integration testing.
@ -230,7 +230,7 @@ class DifyTestContainers:
self._containers_started = True
logger.info("All test containers started successfully")
def stop_containers(self) -> None:
def stop_containers(self):
"""
Stop and clean up all test containers.

View File

@ -84,16 +84,17 @@ class TestStorageKeyLoader(unittest.TestCase):
if tenant_id is None:
tenant_id = self.tenant_id
tool_file = ToolFile()
tool_file = ToolFile(
user_id=self.user_id,
tenant_id=tenant_id,
conversation_id=self.conversation_id,
file_key=file_key,
mimetype="text/plain",
original_url="http://example.com/file.txt",
name="test_tool_file.txt",
size=2048,
)
tool_file.id = file_id
tool_file.user_id = self.user_id
tool_file.tenant_id = tenant_id
tool_file.conversation_id = self.conversation_id
tool_file.file_key = file_key
tool_file.mimetype = "text/plain"
tool_file.original_url = "http://example.com/file.txt"
tool_file.name = "test_tool_file.txt"
tool_file.size = 2048
self.session.add(tool_file)
self.session.flush()

View File

@ -1,10 +1,11 @@
import json
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, create_autospec, patch
import pytest
from faker import Faker
from core.plugin.impl.exc import PluginDaemonClientSideError
from models.account import Account
from models.model import AppModelConfig, Conversation, EndUser, Message, MessageAgentThought
from services.account_service import AccountService, TenantService
from services.agent_service import AgentService
@ -21,7 +22,7 @@ class TestAgentService:
patch("services.agent_service.PluginAgentClient") as mock_plugin_agent_client,
patch("services.agent_service.ToolManager") as mock_tool_manager,
patch("services.agent_service.AgentConfigManager") as mock_agent_config_manager,
patch("services.agent_service.current_user") as mock_current_user,
patch("services.agent_service.current_user", create_autospec(Account, instance=True)) as mock_current_user,
patch("services.app_service.FeatureService") as mock_feature_service,
patch("services.app_service.EnterpriseService") as mock_enterprise_service,
patch("services.app_service.ModelManager") as mock_model_manager,

View File

@ -1,9 +1,10 @@
from unittest.mock import patch
from unittest.mock import create_autospec, patch
import pytest
from faker import Faker
from werkzeug.exceptions import NotFound
from models.account import Account
from models.model import MessageAnnotation
from services.annotation_service import AppAnnotationService
from services.app_service import AppService
@ -24,7 +25,9 @@ class TestAnnotationService:
patch("services.annotation_service.enable_annotation_reply_task") as mock_enable_task,
patch("services.annotation_service.disable_annotation_reply_task") as mock_disable_task,
patch("services.annotation_service.batch_import_annotations_task") as mock_batch_import_task,
patch("services.annotation_service.current_user") as mock_current_user,
patch(
"services.annotation_service.current_user", create_autospec(Account, instance=True)
) as mock_current_user,
):
# Setup default mock returns
mock_account_feature_service.get_features.return_value.billing.enabled = False

View File

@ -1,9 +1,10 @@
from unittest.mock import patch
from unittest.mock import create_autospec, patch
import pytest
from faker import Faker
from constants.model_template import default_app_templates
from models.account import Account
from models.model import App, Site
from services.account_service import AccountService, TenantService
from services.app_service import AppService
@ -161,8 +162,13 @@ class TestAppService:
app_service = AppService()
created_app = app_service.create_app(tenant.id, app_args, account)
# Get app using the service
retrieved_app = app_service.get_app(created_app)
# Get app using the service - needs current_user mock
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.id = account.id
mock_current_user.current_tenant_id = account.current_tenant_id
with patch("services.app_service.current_user", mock_current_user):
retrieved_app = app_service.get_app(created_app)
# Verify retrieved app matches created app
assert retrieved_app.id == created_app.id
@ -406,7 +412,11 @@ class TestAppService:
"use_icon_as_answer_icon": True,
}
with patch("flask_login.utils._get_user", return_value=account):
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.id = account.id
mock_current_user.current_tenant_id = account.current_tenant_id
with patch("services.app_service.current_user", mock_current_user):
updated_app = app_service.update_app(app, update_args)
# Verify updated fields
@ -456,7 +466,11 @@ class TestAppService:
# Update app name
new_name = "New App Name"
with patch("flask_login.utils._get_user", return_value=account):
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.id = account.id
mock_current_user.current_tenant_id = account.current_tenant_id
with patch("services.app_service.current_user", mock_current_user):
updated_app = app_service.update_app_name(app, new_name)
assert updated_app.name == new_name
@ -504,7 +518,11 @@ class TestAppService:
# Update app icon
new_icon = "🌟"
new_icon_background = "#FFD93D"
with patch("flask_login.utils._get_user", return_value=account):
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.id = account.id
mock_current_user.current_tenant_id = account.current_tenant_id
with patch("services.app_service.current_user", mock_current_user):
updated_app = app_service.update_app_icon(app, new_icon, new_icon_background)
assert updated_app.icon == new_icon
@ -551,13 +569,17 @@ class TestAppService:
original_site_status = app.enable_site
# Update site status to disabled
with patch("flask_login.utils._get_user", return_value=account):
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.id = account.id
mock_current_user.current_tenant_id = account.current_tenant_id
with patch("services.app_service.current_user", mock_current_user):
updated_app = app_service.update_app_site_status(app, False)
assert updated_app.enable_site is False
assert updated_app.updated_by == account.id
# Update site status back to enabled
with patch("flask_login.utils._get_user", return_value=account):
with patch("services.app_service.current_user", mock_current_user):
updated_app = app_service.update_app_site_status(updated_app, True)
assert updated_app.enable_site is True
assert updated_app.updated_by == account.id
@ -602,13 +624,17 @@ class TestAppService:
original_api_status = app.enable_api
# Update API status to disabled
with patch("flask_login.utils._get_user", return_value=account):
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.id = account.id
mock_current_user.current_tenant_id = account.current_tenant_id
with patch("services.app_service.current_user", mock_current_user):
updated_app = app_service.update_app_api_status(app, False)
assert updated_app.enable_api is False
assert updated_app.updated_by == account.id
# Update API status back to enabled
with patch("flask_login.utils._get_user", return_value=account):
with patch("services.app_service.current_user", mock_current_user):
updated_app = app_service.update_app_api_status(updated_app, True)
assert updated_app.enable_api is True
assert updated_app.updated_by == account.id

View File

@ -1,6 +1,6 @@
import hashlib
from io import BytesIO
from unittest.mock import patch
from unittest.mock import create_autospec, patch
import pytest
from faker import Faker
@ -417,11 +417,12 @@ class TestFileService:
text = "This is a test text content"
text_name = "test_text.txt"
# Mock current_user
with patch("services.file_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = str(fake.uuid4())
mock_current_user.id = str(fake.uuid4())
# Mock current_user using create_autospec
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = str(fake.uuid4())
mock_current_user.id = str(fake.uuid4())
with patch("services.file_service.current_user", mock_current_user):
upload_file = FileService.upload_text(text=text, text_name=text_name)
assert upload_file is not None
@ -443,11 +444,12 @@ class TestFileService:
text = "test content"
long_name = "a" * 250 # Longer than 200 characters
# Mock current_user
with patch("services.file_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = str(fake.uuid4())
mock_current_user.id = str(fake.uuid4())
# Mock current_user using create_autospec
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = str(fake.uuid4())
mock_current_user.id = str(fake.uuid4())
with patch("services.file_service.current_user", mock_current_user):
upload_file = FileService.upload_text(text=text, text_name=long_name)
# Verify name was truncated
@ -846,11 +848,12 @@ class TestFileService:
text = ""
text_name = "empty.txt"
# Mock current_user
with patch("services.file_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = str(fake.uuid4())
mock_current_user.id = str(fake.uuid4())
# Mock current_user using create_autospec
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = str(fake.uuid4())
mock_current_user.id = str(fake.uuid4())
with patch("services.file_service.current_user", mock_current_user):
upload_file = FileService.upload_text(text=text, text_name=text_name)
assert upload_file is not None

View File

@ -1,4 +1,4 @@
from unittest.mock import patch
from unittest.mock import create_autospec, patch
import pytest
from faker import Faker
@ -17,7 +17,9 @@ class TestMetadataService:
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("services.metadata_service.current_user") as mock_current_user,
patch(
"services.metadata_service.current_user", create_autospec(Account, instance=True)
) as mock_current_user,
patch("services.metadata_service.redis_client") as mock_redis_client,
patch("services.dataset_service.DocumentService") as mock_document_service,
):

View File

@ -1,4 +1,4 @@
from unittest.mock import patch
from unittest.mock import create_autospec, patch
import pytest
from faker import Faker
@ -17,7 +17,7 @@ class TestTagService:
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("services.tag_service.current_user") as mock_current_user,
patch("services.tag_service.current_user", create_autospec(Account, instance=True)) as mock_current_user,
):
# Setup default mock returns
mock_current_user.current_tenant_id = "test-tenant-id"

View File

@ -57,10 +57,12 @@ class TestWebAppAuthService:
tuple: (account, tenant) - Created account and tenant instances
"""
fake = Faker()
import uuid
# Create account
# Create account with unique email to avoid collisions
unique_email = f"test_{uuid.uuid4().hex[:8]}@example.com"
account = Account(
email=fake.email(),
email=unique_email,
name=fake.name(),
interface_language="en-US",
status="active",
@ -109,8 +111,11 @@ class TestWebAppAuthService:
password = fake.password(length=12)
# Create account with password
import uuid
unique_email = f"test_{uuid.uuid4().hex[:8]}@example.com"
account = Account(
email=fake.email(),
email=unique_email,
name=fake.name(),
interface_language="en-US",
status="active",
@ -322,9 +327,12 @@ class TestWebAppAuthService:
"""
# Arrange: Create account without password
fake = Faker()
import uuid
unique_email = f"test_{uuid.uuid4().hex[:8]}@example.com"
account = Account(
email=fake.email(),
email=unique_email,
name=fake.name(),
interface_language="en-US",
status="active",
@ -431,9 +439,12 @@ class TestWebAppAuthService:
"""
# Arrange: Create banned account
fake = Faker()
import uuid
unique_email = f"test_{uuid.uuid4().hex[:8]}@example.com"
account = Account(
email=fake.email(),
email=unique_email,
name=fake.name(),
interface_language="en-US",
status=AccountStatus.BANNED.value,

View File

@ -1,5 +1,5 @@
from datetime import datetime
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, create_autospec, patch
import pytest
from faker import Faker
@ -231,9 +231,10 @@ class TestWebsiteService:
fake = Faker()
# Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = account.current_tenant.id
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Create API request
api_request = WebsiteCrawlApiRequest(
provider="firecrawl",
@ -285,9 +286,10 @@ class TestWebsiteService:
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = account.current_tenant.id
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Create API request
api_request = WebsiteCrawlApiRequest(
provider="watercrawl",
@ -336,9 +338,10 @@ class TestWebsiteService:
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = account.current_tenant.id
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Create API request for single page crawling
api_request = WebsiteCrawlApiRequest(
provider="jinareader",
@ -389,9 +392,10 @@ class TestWebsiteService:
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = account.current_tenant.id
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Create API request with invalid provider
api_request = WebsiteCrawlApiRequest(
provider="invalid_provider",
@ -419,9 +423,10 @@ class TestWebsiteService:
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = account.current_tenant.id
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Create API request
api_request = WebsiteCrawlStatusApiRequest(provider="firecrawl", job_id="test_job_id_123")
@ -463,9 +468,10 @@ class TestWebsiteService:
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = account.current_tenant.id
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Create API request
api_request = WebsiteCrawlStatusApiRequest(provider="watercrawl", job_id="watercrawl_job_123")
@ -502,9 +508,10 @@ class TestWebsiteService:
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = account.current_tenant.id
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Create API request
api_request = WebsiteCrawlStatusApiRequest(provider="jinareader", job_id="jina_job_123")
@ -544,9 +551,10 @@ class TestWebsiteService:
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = account.current_tenant.id
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Create API request with invalid provider
api_request = WebsiteCrawlStatusApiRequest(provider="invalid_provider", job_id="test_job_id_123")
@ -569,9 +577,10 @@ class TestWebsiteService:
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = account.current_tenant.id
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Mock missing credentials
mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.return_value = None
@ -597,9 +606,10 @@ class TestWebsiteService:
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = account.current_tenant.id
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Mock missing API key in config
mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.return_value = {
"config": {"base_url": "https://api.example.com"}
@ -995,9 +1005,10 @@ class TestWebsiteService:
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = account.current_tenant.id
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Create API request for sub-page crawling
api_request = WebsiteCrawlApiRequest(
provider="jinareader",
@ -1054,9 +1065,10 @@ class TestWebsiteService:
mock_external_service_dependencies["requests"].get.return_value = mock_failed_response
# Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = account.current_tenant.id
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Create API request
api_request = WebsiteCrawlApiRequest(
provider="jinareader",
@ -1096,9 +1108,10 @@ class TestWebsiteService:
mock_external_service_dependencies["firecrawl_app"].return_value = mock_firecrawl_instance
# Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = account.current_tenant.id
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Create API request
api_request = WebsiteCrawlStatusApiRequest(provider="firecrawl", job_id="active_job_123")

View File

@ -0,0 +1,786 @@
from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from core.rag.index_processor.constant.index_type import IndexType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, DatasetAutoDisableLog, Document, DocumentSegment
from tasks.add_document_to_index_task import add_document_to_index_task
class TestAddDocumentToIndexTask:
"""Integration tests for add_document_to_index_task using testcontainers."""
@pytest.fixture
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("tasks.add_document_to_index_task.IndexProcessorFactory") as mock_index_processor_factory,
):
# Setup mock index processor
mock_processor = MagicMock()
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
yield {
"index_processor_factory": mock_index_processor_factory,
"index_processor": mock_processor,
}
def _create_test_dataset_and_document(self, db_session_with_containers, mock_external_service_dependencies):
"""
Helper method to create a test dataset and document for testing.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
mock_external_service_dependencies: Mock dependencies
Returns:
tuple: (dataset, document) - Created dataset and document instances
"""
fake = Faker()
# Create account and tenant
account = Account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.commit()
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
# Create tenant-account join
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
current=True,
)
db.session.add(join)
db.session.commit()
# Create dataset
dataset = Dataset(
id=fake.uuid4(),
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="upload_file",
indexing_technique="high_quality",
created_by=account.id,
)
db.session.add(dataset)
db.session.commit()
# Create document
document = Document(
id=fake.uuid4(),
tenant_id=tenant.id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
batch="test_batch",
name=fake.file_name(),
created_from="upload_file",
created_by=account.id,
indexing_status="completed",
enabled=True,
doc_form=IndexType.PARAGRAPH_INDEX,
)
db.session.add(document)
db.session.commit()
# Refresh dataset to ensure doc_form property works correctly
db.session.refresh(dataset)
return dataset, document
def _create_test_segments(self, db_session_with_containers, document, dataset):
"""
Helper method to create test document segments.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
document: Document instance
dataset: Dataset instance
Returns:
list: List of created DocumentSegment instances
"""
fake = Faker()
segments = []
for i in range(3):
segment = DocumentSegment(
id=fake.uuid4(),
tenant_id=document.tenant_id,
dataset_id=dataset.id,
document_id=document.id,
position=i,
content=fake.text(max_nb_chars=200),
word_count=len(fake.text(max_nb_chars=200).split()),
tokens=len(fake.text(max_nb_chars=200).split()) * 2,
index_node_id=f"node_{i}",
index_node_hash=f"hash_{i}",
enabled=False,
status="completed",
created_by=document.created_by,
)
db.session.add(segment)
segments.append(segment)
db.session.commit()
return segments
def test_add_document_to_index_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful document indexing with paragraph index type.
This test verifies:
- Proper document retrieval from database
- Correct segment processing and document creation
- Index processor integration
- Database state updates
- Segment status changes
- Redis cache key deletion
"""
# Arrange: Create test data
dataset, document = self._create_test_dataset_and_document(
db_session_with_containers, mock_external_service_dependencies
)
segments = self._create_test_segments(db_session_with_containers, document, dataset)
# Set up Redis cache key to simulate indexing in progress
indexing_cache_key = f"document_{document.id}_indexing"
redis_client.set(indexing_cache_key, "processing", ex=300) # 5 minutes expiry
# Verify cache key exists
assert redis_client.exists(indexing_cache_key) == 1
# Act: Execute the task
add_document_to_index_task(document.id)
# Assert: Verify the expected outcomes
# Verify index processor was called correctly
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX)
mock_external_service_dependencies["index_processor"].load.assert_called_once()
# Verify database state changes
db.session.refresh(document)
for segment in segments:
db.session.refresh(segment)
assert segment.enabled is True
assert segment.disabled_at is None
assert segment.disabled_by is None
# Verify Redis cache key was deleted
assert redis_client.exists(indexing_cache_key) == 0
def test_add_document_to_index_with_different_index_type(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test document indexing with different index types.
This test verifies:
- Proper handling of different index types
- Index processor factory integration
- Document processing with various configurations
- Redis cache key deletion
"""
# Arrange: Create test data with different index type
dataset, document = self._create_test_dataset_and_document(
db_session_with_containers, mock_external_service_dependencies
)
# Update document to use different index type
document.doc_form = IndexType.QA_INDEX
db.session.commit()
# Refresh dataset to ensure doc_form property reflects the updated document
db.session.refresh(dataset)
# Create segments
segments = self._create_test_segments(db_session_with_containers, document, dataset)
# Set up Redis cache key
indexing_cache_key = f"document_{document.id}_indexing"
redis_client.set(indexing_cache_key, "processing", ex=300)
# Act: Execute the task
add_document_to_index_task(document.id)
# Assert: Verify different index type handling
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.QA_INDEX)
mock_external_service_dependencies["index_processor"].load.assert_called_once()
# Verify the load method was called with correct parameters
call_args = mock_external_service_dependencies["index_processor"].load.call_args
assert call_args is not None
documents = call_args[0][1] # Second argument should be documents list
assert len(documents) == 3
# Verify database state changes
db.session.refresh(document)
for segment in segments:
db.session.refresh(segment)
assert segment.enabled is True
assert segment.disabled_at is None
assert segment.disabled_by is None
# Verify Redis cache key was deleted
assert redis_client.exists(indexing_cache_key) == 0
def test_add_document_to_index_document_not_found(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Verify the task is a no-op for a document id that does not exist.

    Neither the index-processor factory nor the processor's load method may
    be invoked; the task returns early. No Redis key assertion is made
    because the cache key is only derived after a document is found.
    """
    mocks = mock_external_service_dependencies
    faker = Faker()
    missing_document_id = faker.uuid4()

    # Act: run the task against an id with no backing row.
    add_document_to_index_task(missing_document_id)

    # Assert: no index processing was attempted.
    mocks["index_processor_factory"].assert_not_called()
    mocks["index_processor"].load.assert_not_called()
def test_add_document_to_index_invalid_indexing_status(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Verify the task exits early for a document whose indexing_status is not
    "completed": no index-processor calls may be made.
    """
    mocks = mock_external_service_dependencies

    # Arrange: healthy dataset/document pair…
    dataset, document = self._create_test_dataset_and_document(
        db_session_with_containers, mocks
    )
    # …then put the document into a state that is not ready for indexing.
    document.indexing_status = "processing"
    db.session.commit()

    # Act
    add_document_to_index_task(document.id)

    # Assert: no index processing happened at all.
    mocks["index_processor_factory"].assert_not_called()
    mocks["index_processor"].load.assert_not_called()
def test_add_document_to_index_dataset_not_found(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Verify error handling when the document's dataset is missing.

    The document must be disabled, its status set to "error" with an error
    message recorded, no index processing may run, and the Redis indexing
    cache key must still be cleared on this error path.
    """
    mocks = mock_external_service_dependencies
    dataset, document = self._create_test_dataset_and_document(
        db_session_with_containers, mocks
    )

    cache_key = f"document_{document.id}_indexing"
    redis_client.set(cache_key, "processing", ex=300)

    # Remove the dataset so the task hits its missing-dataset branch.
    db.session.delete(dataset)
    db.session.commit()

    # Act
    add_document_to_index_task(document.id)

    # Assert: the document is disabled and flagged as errored.
    db.session.refresh(document)
    assert document.enabled is False
    assert document.indexing_status == "error"
    assert document.error is not None
    assert "doesn't exist" in document.error
    assert document.disabled_at is not None

    # No index processing may have happened.
    mocks["index_processor_factory"].assert_not_called()
    mocks["index_processor"].load.assert_not_called()

    # The cache key is cleared even on the error path.
    assert redis_client.exists(cache_key) == 0
def test_add_document_to_index_with_parent_child_structure(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test document indexing with parent-child structure.

    This test verifies:
    - Proper handling of PARENT_CHILD_INDEX type
    - Child document creation from segments
    - Correct document structure for parent-child indexing
    - Index processor receives properly structured documents
    - Redis cache key deletion
    """
    # Arrange: create test data, then switch the document to the
    # parent-child index form.
    dataset, document = self._create_test_dataset_and_document(
        db_session_with_containers, mock_external_service_dependencies
    )
    document.doc_form = IndexType.PARENT_CHILD_INDEX
    db.session.commit()
    # dataset.doc_form is derived from its documents, so refresh after commit.
    db.session.refresh(dataset)

    # Create segments; their child chunks are mocked below.
    segments = self._create_test_segments(db_session_with_containers, document, dataset)

    # Set up Redis cache key the task is expected to delete.
    indexing_cache_key = f"document_{document.id}_indexing"
    redis_client.set(indexing_cache_key, "processing", ex=300)

    # Patch get_child_chunks so that every segment reports two child chunks.
    with patch.object(DocumentSegment, "get_child_chunks") as mock_get_child_chunks:
        mock_child_chunks = []
        for i in range(2):  # Each segment has 2 child chunks
            mock_child = MagicMock()
            mock_child.content = f"child_content_{i}"
            mock_child.index_node_id = f"child_node_{i}"
            mock_child.index_node_hash = f"child_hash_{i}"
            mock_child_chunks.append(mock_child)
        mock_get_child_chunks.return_value = mock_child_chunks

        # Act: Execute the task while the patch is active.
        add_document_to_index_task(document.id)

        # Assert: the factory was built for the parent-child index type.
        mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
            IndexType.PARENT_CHILD_INDEX
        )
        mock_external_service_dependencies["index_processor"].load.assert_called_once()

        # Verify the load method was called with correct parameters.
        call_args = mock_external_service_dependencies["index_processor"].load.call_args
        assert call_args is not None
        documents = call_args[0][1]  # Second argument should be documents list
        assert len(documents) == 3  # 3 segments

        # Every indexed document must carry the mocked child chunks.
        for doc in documents:
            assert hasattr(doc, "children")
            assert len(doc.children) == 2  # Each document has 2 children

        # Verify database state changes: segments are enabled with no
        # disable metadata.
        db.session.refresh(document)
        for segment in segments:
            db.session.refresh(segment)
            assert segment.enabled is True
            assert segment.disabled_at is None
            assert segment.disabled_by is None

        # Verify redis cache was cleared.
        assert redis_client.exists(indexing_cache_key) == 0
def test_add_document_to_index_with_no_segments_to_process(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test document indexing when no segments need processing.

    This test verifies:
    - Proper handling when all segments are already enabled
    - Index processing still occurs but with an empty documents list
    - Redis cache is cleared

    Fix over the previous version: each segment's content is generated once
    and word_count/tokens are derived from that same text; previously three
    independent fake.text() calls produced mutually inconsistent values.
    """
    # Arrange: Create test data
    dataset, document = self._create_test_dataset_and_document(
        db_session_with_containers, mock_external_service_dependencies
    )

    # Create segments that are already enabled (nothing left to index).
    fake = Faker()
    for i in range(3):
        content = fake.text(max_nb_chars=200)
        word_count = len(content.split())
        segment = DocumentSegment(
            id=fake.uuid4(),
            tenant_id=document.tenant_id,
            dataset_id=dataset.id,
            document_id=document.id,
            position=i,
            content=content,
            word_count=word_count,
            tokens=word_count * 2,
            index_node_id=f"node_{i}",
            index_node_hash=f"hash_{i}",
            enabled=True,  # Already enabled -> excluded from indexing
            status="completed",
            created_by=document.created_by,
        )
        db.session.add(segment)
    db.session.commit()

    # Set up Redis cache key
    indexing_cache_key = f"document_{document.id}_indexing"
    redis_client.set(indexing_cache_key, "processing", ex=300)

    # Act: Execute the task
    add_document_to_index_task(document.id)

    # Assert: processing ran, but with an empty documents list.
    mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX)
    mock_external_service_dependencies["index_processor"].load.assert_called_once()
    call_args = mock_external_service_dependencies["index_processor"].load.call_args
    assert call_args is not None
    documents = call_args[0][1]  # Second argument should be documents list
    assert len(documents) == 0  # No segments to process

    # Verify redis cache was cleared
    assert redis_client.exists(indexing_cache_key) == 0
def test_add_document_to_index_auto_disable_log_deletion(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test that auto disable logs are properly deleted during indexing.

    This test verifies:
    - Auto disable log entries are deleted for the document
    - Index processing continues normally
    - Segments are enabled afterwards
    - Redis cache key deletion

    Cleanup over the previous version: the unused loop index and the unused
    accumulator list of log entries were removed.
    """
    # Arrange: Create test data
    dataset, document = self._create_test_dataset_and_document(
        db_session_with_containers, mock_external_service_dependencies
    )
    segments = self._create_test_segments(db_session_with_containers, document, dataset)

    # Create auto disable log entries that the task must delete.
    fake = Faker()
    for _ in range(2):
        db.session.add(
            DatasetAutoDisableLog(
                id=fake.uuid4(),
                tenant_id=document.tenant_id,
                dataset_id=dataset.id,
                document_id=document.id,
            )
        )
    db.session.commit()

    # Set up Redis cache key
    indexing_cache_key = f"document_{document.id}_indexing"
    redis_client.set(indexing_cache_key, "processing", ex=300)

    # Sanity check: the log entries exist before processing.
    existing_logs = (
        db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == document.id).all()
    )
    assert len(existing_logs) == 2

    # Act: Execute the task
    add_document_to_index_task(document.id)

    # Assert: the auto disable logs are gone.
    remaining_logs = (
        db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == document.id).all()
    )
    assert len(remaining_logs) == 0

    # Index processing occurred normally and segments were enabled.
    mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX)
    mock_external_service_dependencies["index_processor"].load.assert_called_once()
    for segment in segments:
        db.session.refresh(segment)
        assert segment.enabled is True

    # Verify redis cache was cleared
    assert redis_client.exists(indexing_cache_key) == 0
def test_add_document_to_index_general_exception_handling(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Verify error handling when the index processor raises during load.

    The document must be disabled and marked as errored with the exception
    message recorded, segments must stay disabled, and the Redis indexing
    cache key must still be removed.
    """
    mocks = mock_external_service_dependencies
    dataset, document = self._create_test_dataset_and_document(
        db_session_with_containers, mocks
    )
    created_segments = self._create_test_segments(db_session_with_containers, document, dataset)

    cache_key = f"document_{document.id}_indexing"
    redis_client.set(cache_key, "processing", ex=300)

    # Force the processor to fail while loading documents.
    mocks["index_processor"].load.side_effect = Exception("Index processing failed")

    # Act
    add_document_to_index_task(document.id)

    # Assert: the document is flagged as errored and disabled.
    db.session.refresh(document)
    assert document.enabled is False
    assert document.indexing_status == "error"
    assert document.error is not None
    assert "Index processing failed" in document.error
    assert document.disabled_at is not None

    # Segments were never enabled because indexing failed.
    for seg in created_segments:
        db.session.refresh(seg)
        assert seg.enabled is False

    # The cache key is cleared even on the error path.
    assert redis_client.exists(cache_key) == 0
def test_add_document_to_index_segment_filtering_edge_cases(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test segment filtering with various edge cases.

    This test verifies:
    - Only segments with enabled=False and status="completed" are indexed
    - Segments are passed to the processor ordered by position
    - Mixed segment states are handled properly
    - Redis cache key deletion

    Improvements over the previous version: the four near-identical segment
    literals are built by a local helper, and each segment's word_count and
    tokens are derived from the same generated content (previously separate
    fake.text() calls produced inconsistent values).
    """
    # Arrange: Create test data
    dataset, document = self._create_test_dataset_and_document(
        db_session_with_containers, mock_external_service_dependencies
    )
    fake = Faker()

    def make_segment(position: int, *, enabled: bool, status: str) -> DocumentSegment:
        # Build one segment; word_count/tokens derive from the same content.
        content = fake.text(max_nb_chars=200)
        word_count = len(content.split())
        return DocumentSegment(
            id=fake.uuid4(),
            tenant_id=document.tenant_id,
            dataset_id=dataset.id,
            document_id=document.id,
            position=position,
            content=content,
            word_count=word_count,
            tokens=word_count * 2,
            index_node_id=f"node_{position}",
            index_node_hash=f"hash_{position}",
            enabled=enabled,
            status=status,
            created_by=document.created_by,
        )

    segment1 = make_segment(0, enabled=False, status="completed")  # should be processed
    segment2 = make_segment(1, enabled=True, status="completed")  # skipped: already enabled
    segment3 = make_segment(2, enabled=False, status="processing")  # skipped: not completed
    segment4 = make_segment(3, enabled=False, status="completed")  # should be processed
    db.session.add_all([segment1, segment2, segment3, segment4])
    db.session.commit()

    # Set up Redis cache key
    indexing_cache_key = f"document_{document.id}_indexing"
    redis_client.set(indexing_cache_key, "processing", ex=300)

    # Act: Execute the task
    add_document_to_index_task(document.id)

    # Assert: Verify only eligible segments were processed
    mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX)
    mock_external_service_dependencies["index_processor"].load.assert_called_once()
    call_args = mock_external_service_dependencies["index_processor"].load.call_args
    assert call_args is not None
    documents = call_args[0][1]  # Second argument should be documents list
    assert len(documents) == 2  # Only 2 segments should be processed

    # Eligible segments arrive ordered by position.
    assert documents[0].metadata["doc_id"] == "node_0"  # position 0
    assert documents[1].metadata["doc_id"] == "node_3"  # position 3

    # The task enables ALL segments of the document, not only the indexed ones.
    db.session.refresh(document)
    for segment in (segment1, segment2, segment3, segment4):
        db.session.refresh(segment)
        assert segment.enabled is True

    # Verify redis cache was cleared
    assert redis_client.exists(indexing_cache_key) == 0
def test_add_document_to_index_comprehensive_error_scenarios(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test comprehensive error scenarios and recovery.

    This test verifies:
    - Multiple types of exceptions are handled properly
    - Error state is consistently managed
    - Resource cleanup occurs in all error cases
    - Database session management is robust
    - Redis cache key deletion in all scenarios
    """
    # Arrange: Create test data
    dataset, document = self._create_test_dataset_and_document(
        db_session_with_containers, mock_external_service_dependencies
    )
    segments = self._create_test_segments(db_session_with_containers, document, dataset)

    # Each tuple is (scenario label, exception the index processor raises).
    test_exceptions = [
        ("Database connection error", Exception("Database connection failed")),
        ("Index processor error", RuntimeError("Index processor initialization failed")),
        ("Memory error", MemoryError("Out of memory")),
        ("Value error", ValueError("Invalid index type")),
    ]
    for error_name, exception in test_exceptions:
        # Reset mocks for each scenario: load now raises this exception.
        mock_external_service_dependencies["index_processor"].load.side_effect = exception

        # Reset the document so each scenario starts from a healthy state.
        document.enabled = True
        document.indexing_status = "completed"
        document.error = None
        document.disabled_at = None
        db.session.commit()

        # Set up Redis cache key
        indexing_cache_key = f"document_{document.id}_indexing"
        redis_client.set(indexing_cache_key, "processing", ex=300)

        # Act: Execute the task
        add_document_to_index_task(document.id)

        # Assert: Verify consistent error handling across exception types.
        db.session.refresh(document)
        assert document.enabled is False, f"Document should be disabled for {error_name}"
        assert document.indexing_status == "error", f"Document status should be error for {error_name}"
        assert document.error is not None, f"Error should be recorded for {error_name}"
        assert str(exception) in document.error, f"Error message should contain exception for {error_name}"
        assert document.disabled_at is not None, f"Disabled timestamp should be set for {error_name}"

        # Verify segments remain disabled due to error
        for segment in segments:
            db.session.refresh(segment)
            assert segment.enabled is False, f"Segments should remain disabled for {error_name}"

        # Verify redis cache was still cleared despite error
        assert redis_client.exists(indexing_cache_key) == 0, f"Redis cache should be cleared for {error_name}"

View File

@ -0,0 +1,720 @@
"""
Integration tests for batch_clean_document_task using testcontainers.
This module tests the batch document cleaning functionality with real database
and storage containers to ensure proper cleanup of documents, segments, and files.
"""
import json
import uuid
from unittest.mock import Mock, patch
import pytest
from faker import Faker
from extensions.ext_database import db
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from models.model import UploadFile
from tasks.batch_clean_document_task import batch_clean_document_task
class TestBatchCleanDocumentTask:
"""Integration tests for batch_clean_document_task using testcontainers."""
@pytest.fixture
def mock_external_service_dependencies(self):
    """
    Patch storage, the index-processor factory, and image-id extraction.

    Yields a dict of the active mocks so tests can tweak side effects and
    assert on calls. The patches are undone when the fixture exits.
    """
    with (
        patch("extensions.ext_storage.storage") as storage_mock,
        patch("core.rag.index_processor.index_processor_factory.IndexProcessorFactory") as factory_mock,
        patch("core.tools.utils.web_reader_tool.get_image_upload_file_ids") as image_ids_mock,
    ):
        # Storage deletes succeed silently by default.
        storage_mock.delete.return_value = None

        # The factory hands out a processor whose clean() is a no-op.
        processor_mock = Mock()
        processor_mock.clean.return_value = None
        factory_mock.return_value.init_index_processor.return_value = processor_mock

        # No image references are found unless a test overrides this.
        image_ids_mock.return_value = []

        yield {
            "storage": storage_mock,
            "index_factory": factory_mock,
            "index_processor": processor_mock,
            "get_image_ids": image_ids_mock,
        }
def _create_test_account(self, db_session_with_containers):
    """
    Create and persist an account together with an owning tenant.

    Args:
        db_session_with_containers: Database session from testcontainers infrastructure

    Returns:
        Account: account with ``current_tenant`` set to its owner tenant
    """
    faker = Faker()

    # Account first, so the tenant join below can reference it.
    new_account = Account(
        email=faker.email(),
        name=faker.name(),
        interface_language="en-US",
        status="active",
    )
    db.session.add(new_account)
    db.session.commit()

    # Tenant that the account will own.
    new_tenant = Tenant(
        name=faker.company(),
        status="normal",
    )
    db.session.add(new_tenant)
    db.session.commit()

    # Link the two with an OWNER role marked as the current tenant.
    tenant_join = TenantAccountJoin(
        tenant_id=new_tenant.id,
        account_id=new_account.id,
        role=TenantAccountRole.OWNER.value,
        current=True,
    )
    db.session.add(tenant_join)
    db.session.commit()

    new_account.current_tenant = new_tenant
    return new_account
def _create_test_dataset(self, db_session_with_containers, account):
    """
    Create and persist a dataset owned by *account*'s current tenant.

    Args:
        db_session_with_containers: Database session from testcontainers infrastructure
        account: Account recorded as the dataset creator

    Returns:
        Dataset: persisted dataset instance
    """
    faker = Faker()

    new_dataset = Dataset(
        id=str(uuid.uuid4()),
        tenant_id=account.current_tenant.id,
        name=faker.word(),
        description=faker.sentence(),
        data_source_type="upload_file",
        created_by=account.id,
        embedding_model="text-embedding-ada-002",
        embedding_model_provider="openai",
    )
    db.session.add(new_dataset)
    db.session.commit()
    return new_dataset
def _create_test_document(self, db_session_with_containers, dataset, account):
    """
    Create and persist a completed upload-file document in *dataset*.

    Args:
        db_session_with_containers: Database session from testcontainers infrastructure
        dataset: Dataset the document belongs to
        account: Account recorded as the document creator

    Returns:
        Document: persisted document instance
    """
    faker = Faker()

    new_document = Document(
        id=str(uuid.uuid4()),
        tenant_id=account.current_tenant.id,
        dataset_id=dataset.id,
        position=0,
        name=faker.word(),
        data_source_type="upload_file",
        # Placeholder file id; tests overwrite this when a real upload exists.
        data_source_info=json.dumps({"upload_file_id": str(uuid.uuid4())}),
        batch="test_batch",
        created_from="test",
        created_by=account.id,
        indexing_status="completed",
        doc_form="text_model",
    )
    db.session.add(new_document)
    db.session.commit()
    return new_document
def _create_test_document_segment(self, db_session_with_containers, document, account):
    """
    Create and persist one completed segment for *document*.

    Args:
        db_session_with_containers: Database session from testcontainers infrastructure
        document: Document the segment belongs to
        account: Account recorded as the segment creator

    Returns:
        DocumentSegment: persisted segment instance
    """
    faker = Faker()

    new_segment = DocumentSegment(
        id=str(uuid.uuid4()),
        tenant_id=account.current_tenant.id,
        dataset_id=document.dataset_id,
        document_id=document.id,
        position=0,
        content=faker.text(),
        word_count=100,
        tokens=50,
        index_node_id=str(uuid.uuid4()),
        created_by=account.id,
        status="completed",
    )
    db.session.add(new_segment)
    db.session.commit()
    return new_segment
def _create_test_upload_file(self, db_session_with_containers, account):
    """
    Create and persist an upload-file row for *account*'s current tenant.

    Args:
        db_session_with_containers: Database session from testcontainers infrastructure
        account: Account recorded as the uploader

    Returns:
        UploadFile: persisted upload file instance
    """
    from datetime import datetime

    from models.enums import CreatorUserRole

    faker = Faker()

    new_upload = UploadFile(
        tenant_id=account.current_tenant.id,
        storage_type="local",
        key=f"test_files/{faker.file_name()}",
        name=faker.file_name(),
        size=1024,
        extension="txt",
        mime_type="text/plain",
        created_by_role=CreatorUserRole.ACCOUNT,
        created_by=account.id,
        # NOTE(review): naive UTC timestamp — presumably the model column
        # stores naive datetimes; confirm before switching to aware UTC.
        created_at=datetime.utcnow(),
        used=False,
    )
    db.session.add(new_upload)
    db.session.commit()
    return new_upload
def test_batch_clean_document_task_successful_cleanup(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Verify a full cleanup pass.

    After the task runs, both the document's segment row and the referenced
    upload-file row must be removed from the database.
    """
    # Arrange: one document with a segment and a referenced upload file.
    account = self._create_test_account(db_session_with_containers)
    dataset = self._create_test_dataset(db_session_with_containers, account)
    document = self._create_test_document(db_session_with_containers, dataset, account)
    segment = self._create_test_document_segment(db_session_with_containers, document, account)
    upload_file = self._create_test_upload_file(db_session_with_containers, account)

    # Point the document at the real upload file.
    document.data_source_info = json.dumps({"upload_file_id": upload_file.id})
    db.session.commit()

    # Capture ids before the rows get deleted.
    document_id, segment_id, file_id = document.id, segment.id, upload_file.id

    # Act
    batch_clean_document_task(
        document_ids=[document_id], dataset_id=dataset.id, doc_form=dataset.doc_form, file_ids=[file_id]
    )

    # Assert: both the segment and the upload file are gone.
    db.session.commit()  # Ensure all changes are committed
    assert db.session.query(DocumentSegment).filter_by(id=segment_id).first() is None
    assert db.session.query(UploadFile).filter_by(id=file_id).first() is None
def test_batch_clean_document_task_with_image_files(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Verify cleanup of a document whose segment holds plain text content
    (no embedded image references): the segment row must be removed.
    """
    # Arrange
    account = self._create_test_account(db_session_with_containers)
    dataset = self._create_test_dataset(db_session_with_containers, account)
    document = self._create_test_document(db_session_with_containers, dataset, account)

    # Segment with simple text content and no image markup.
    text_segment = DocumentSegment(
        id=str(uuid.uuid4()),
        tenant_id=account.current_tenant.id,
        dataset_id=document.dataset_id,
        document_id=document.id,
        position=0,
        content="Simple text content without images",
        word_count=100,
        tokens=50,
        index_node_id=str(uuid.uuid4()),
        created_by=account.id,
        status="completed",
    )
    db.session.add(text_segment)
    db.session.commit()

    # Capture ids before the rows get deleted.
    segment_id = text_segment.id
    document_id = document.id

    # Act
    batch_clean_document_task(
        document_ids=[document_id], dataset_id=dataset.id, doc_form=dataset.doc_form, file_ids=[]
    )

    # Assert: the segment row was deleted.
    db.session.commit()
    assert db.session.query(DocumentSegment).filter_by(id=segment_id).first() is None
def test_batch_clean_document_task_no_segments(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test cleanup when document has no segments.

    This test verifies that the task handles documents without segments
    gracefully and still cleans up associated files.

    Fix over the previous version: the verification block (commit, query,
    assert) was duplicated verbatim; the redundant copy has been removed.
    """
    # Create test data without segments
    account = self._create_test_account(db_session_with_containers)
    dataset = self._create_test_dataset(db_session_with_containers, account)
    document = self._create_test_document(db_session_with_containers, dataset, account)
    upload_file = self._create_test_upload_file(db_session_with_containers, account)

    # Update document to reference the upload file
    document.data_source_info = json.dumps({"upload_file_id": upload_file.id})
    db.session.commit()

    # Store original IDs for verification
    document_id = document.id
    file_id = upload_file.id

    # Execute the task
    batch_clean_document_task(
        document_ids=[document_id], dataset_id=dataset.id, doc_form=dataset.doc_form, file_ids=[file_id]
    )

    # Verify database cleanup: the upload file must be deleted even though
    # the document had no segments.
    db.session.commit()
    deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first()
    assert deleted_file is None
def test_batch_clean_document_task_dataset_not_found(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Verify the task aborts cleanly when the dataset row is missing:
    no index cleanup, no storage deletes, and the document row survives.
    """
    mocks = mock_external_service_dependencies

    # Arrange
    account = self._create_test_account(db_session_with_containers)
    dataset = self._create_test_dataset(db_session_with_containers, account)
    document = self._create_test_document(db_session_with_containers, dataset, account)

    document_id = document.id
    dataset_id = dataset.id

    # Drop the dataset so the task cannot find it.
    db.session.delete(dataset)
    db.session.commit()

    # Act
    batch_clean_document_task(document_ids=[document_id], dataset_id=dataset_id, doc_form="text_model", file_ids=[])

    # Assert: neither the index nor storage was touched.
    mocks["index_processor"].clean.assert_not_called()
    mocks["storage"].delete.assert_not_called()

    # The document row is untouched because cleanup never ran.
    db.session.commit()
    assert db.session.query(Document).filter_by(id=document_id).first() is not None
def test_batch_clean_document_task_storage_cleanup_failure(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Verify database cleanup still happens when storage deletion fails:
    segment and upload-file rows must be removed despite the raised error.
    """
    mocks = mock_external_service_dependencies

    # Arrange: one document with a segment and a referenced upload file.
    account = self._create_test_account(db_session_with_containers)
    dataset = self._create_test_dataset(db_session_with_containers, account)
    document = self._create_test_document(db_session_with_containers, dataset, account)
    segment = self._create_test_document_segment(db_session_with_containers, document, account)
    upload_file = self._create_test_upload_file(db_session_with_containers, account)

    document.data_source_info = json.dumps({"upload_file_id": upload_file.id})
    db.session.commit()

    # Capture ids before the rows get deleted.
    document_id, segment_id, file_id = document.id, segment.id, upload_file.id

    # Every storage delete raises; the task must tolerate this.
    mocks["storage"].delete.side_effect = Exception("Storage error")

    # Act
    batch_clean_document_task(
        document_ids=[document_id], dataset_id=dataset.id, doc_form=dataset.doc_form, file_ids=[file_id]
    )

    # Assert: database rows are removed despite the storage failure.
    db.session.commit()
    assert db.session.query(DocumentSegment).filter_by(id=segment_id).first() is None
    assert db.session.query(UploadFile).filter_by(id=file_id).first() is None
def test_batch_clean_document_task_multiple_documents(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test cleanup of multiple documents in a single batch operation.

    Verifies that the task handles several documents at once and removes
    every associated segment and upload-file row from the database.
    """
    account = self._create_test_account(db_session_with_containers)
    dataset = self._create_test_dataset(db_session_with_containers, account)

    documents = []
    segments = []
    upload_files = []

    # Create 3 documents, each with one segment and one upload file.
    # The loop index is unused, so it is named `_` per convention.
    for _ in range(3):
        document = self._create_test_document(db_session_with_containers, dataset, account)
        segment = self._create_test_document_segment(db_session_with_containers, document, account)
        upload_file = self._create_test_upload_file(db_session_with_containers, account)
        # Link the document to its upload file so the task cleans it up too.
        document.data_source_info = json.dumps({"upload_file_id": upload_file.id})
        documents.append(document)
        segments.append(segment)
        upload_files.append(upload_file)
    db.session.commit()

    # Capture IDs before the task deletes the ORM rows.
    document_ids = [doc.id for doc in documents]
    segment_ids = [seg.id for seg in segments]
    file_ids = [uf.id for uf in upload_files]

    # Execute the task over the whole batch in one call.
    batch_clean_document_task(
        document_ids=document_ids, dataset_id=dataset.id, doc_form=dataset.doc_form, file_ids=file_ids
    )

    # All segments and upload files must be gone from the database.
    db.session.commit()
    for segment_id in segment_ids:
        assert db.session.query(DocumentSegment).filter_by(id=segment_id).first() is None
    for file_id in file_ids:
        assert db.session.query(UploadFile).filter_by(id=file_id).first() is None
def test_batch_clean_document_task_different_doc_forms(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test cleanup with different document form types.

    Verifies the task handles each doc_form and creates the appropriate
    index processor. External services (e.g. the plugin daemon) may be
    unavailable in test environments, so a task failure before deletion
    is tolerated; when the task succeeds the segment must be deleted.
    """
    account = self._create_test_account(db_session_with_containers)

    # Exercise each supported document form.
    doc_forms = ["text_model", "qa_model", "hierarchical_model"]
    for doc_form in doc_forms:
        dataset = self._create_test_dataset(db_session_with_containers, account)
        db.session.commit()
        document = self._create_test_document(db_session_with_containers, dataset, account)
        # Override the form under test.
        document.doc_form = doc_form
        db.session.commit()
        segment = self._create_test_document_segment(db_session_with_containers, document, account)
        # Capture the ID before the task deletes the row.
        segment_id = segment.id
        try:
            batch_clean_document_task(
                document_ids=[document.id], dataset_id=dataset.id, doc_form=doc_form, file_ids=[]
            )
            # Task succeeded: the segment must have been removed.
            db.session.commit()
            assert db.session.query(DocumentSegment).filter_by(id=segment_id).first() is None
        except Exception:
            # Best-effort tolerance for environments where external services
            # (e.g. the plugin daemon) are unavailable: the task may fail
            # before deleting the segment, and either database state is
            # acceptable here. We only keep the session consistent.
            db.session.commit()
def test_batch_clean_document_task_large_batch_performance(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test cleanup performance with a large batch of documents.

    Builds 10 documents (each with a segment and an upload file), runs the
    task once over the whole batch, then checks a loose wall-clock bound
    and complete database cleanup.
    """
    import time

    account = self._create_test_account(db_session_with_containers)
    dataset = self._create_test_dataset(db_session_with_containers, account)

    documents = []
    segments = []
    upload_files = []

    # Create a larger batch; the loop index is unused, so name it `_`.
    batch_size = 10
    for _ in range(batch_size):
        document = self._create_test_document(db_session_with_containers, dataset, account)
        segment = self._create_test_document_segment(db_session_with_containers, document, account)
        upload_file = self._create_test_upload_file(db_session_with_containers, account)
        # Link the document to its upload file so the task cleans it up too.
        document.data_source_info = json.dumps({"upload_file_id": upload_file.id})
        documents.append(document)
        segments.append(segment)
        upload_files.append(upload_file)
    db.session.commit()

    # Capture IDs before the task deletes the ORM rows.
    document_ids = [doc.id for doc in documents]
    segment_ids = [seg.id for seg in segments]
    file_ids = [uf.id for uf in upload_files]

    # Time only the task invocation itself.
    start_time = time.perf_counter()
    batch_clean_document_task(
        document_ids=document_ids, dataset_id=dataset.id, doc_form=dataset.doc_form, file_ids=file_ids
    )
    execution_time = time.perf_counter() - start_time

    # Loose upper bound; guards against pathological slowdowns only.
    assert execution_time < 5.0

    # All segments and upload files must be gone from the database.
    db.session.commit()
    for segment_id in segment_ids:
        assert db.session.query(DocumentSegment).filter_by(id=segment_id).first() is None
    for file_id in file_ids:
        assert db.session.query(UploadFile).filter_by(id=file_id).first() is None
def test_batch_clean_document_task_integration_with_real_database(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test full integration with real database operations.

    Runs the task against a document carrying several hand-built segments
    plus an upload file, and verifies the database ends up fully cleaned
    and consistent.
    """
    account = self._create_test_account(db_session_with_containers)
    dataset = self._create_test_dataset(db_session_with_containers, account)
    document = self._create_test_document(db_session_with_containers, dataset, account)

    # Build three segments by hand so the document has a non-trivial shape.
    segments = [
        DocumentSegment(
            id=str(uuid.uuid4()),
            tenant_id=account.current_tenant.id,
            dataset_id=document.dataset_id,
            document_id=document.id,
            position=i,
            content=f"Segment content {i} with some text",
            word_count=50 + i * 10,
            tokens=25 + i * 5,
            index_node_id=str(uuid.uuid4()),
            created_by=account.id,
            status="completed",
        )
        for i in range(3)
    ]
    upload_file = self._create_test_upload_file(db_session_with_containers, account)
    # Reference the upload file from the document's data source info.
    document.data_source_info = json.dumps({"upload_file_id": upload_file.id})
    db.session.add_all(segments)
    db.session.commit()

    # Sanity-check the initial state before running the task.
    assert db.session.query(DocumentSegment).filter_by(document_id=document.id).count() == 3
    assert db.session.query(UploadFile).filter_by(id=upload_file.id).first() is not None

    # Keep the primary keys; the ORM rows are removed by the task.
    document_id = document.id
    segment_ids = [seg.id for seg in segments]
    file_id = upload_file.id

    batch_clean_document_task(
        document_ids=[document_id], dataset_id=dataset.id, doc_form=dataset.doc_form, file_ids=[file_id]
    )

    # Every segment and the upload file must be gone.
    db.session.commit()
    for segment_id in segment_ids:
        assert db.session.query(DocumentSegment).filter_by(id=segment_id).first() is None
    assert db.session.query(UploadFile).filter_by(id=file_id).first() is None

    # Final consistency checks on the overall database state.
    assert db.session.query(DocumentSegment).filter_by(document_id=document_id).count() == 0
    assert db.session.query(UploadFile).filter_by(id=file_id).first() is None

File diff suppressed because it is too large Load Diff

View File

@ -23,7 +23,7 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue:
storage_key="storage_key_123",
)
def create_file_dict(self, file_id: str = "test_file_dict") -> dict:
def create_file_dict(self, file_id: str = "test_file_dict"):
"""Create a file dictionary with correct dify_model_identity"""
return {
"dify_model_identity": FILE_MODEL_IDENTITY,

View File

@ -83,7 +83,7 @@ def test_client_session_initialize():
# Create message handler
def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
):
if isinstance(message, Exception):
raise message

View File

@ -26,14 +26,13 @@ def _gen_id():
class TestFileSaverImpl:
def test_save_binary_string(self, monkeypatch):
def test_save_binary_string(self, monkeypatch: pytest.MonkeyPatch):
user_id = _gen_id()
tenant_id = _gen_id()
file_type = FileType.IMAGE
mime_type = "image/png"
mock_signed_url = "https://example.com/image.png"
mock_tool_file = ToolFile(
id=_gen_id(),
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
@ -43,6 +42,7 @@ class TestFileSaverImpl:
name=f"{_gen_id()}.png",
size=len(_PNG_DATA),
)
mock_tool_file.id = _gen_id()
mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager)
mocked_engine = mock.MagicMock(spec=Engine)
@ -80,7 +80,7 @@ class TestFileSaverImpl:
)
mocked_sign_file.assert_called_once_with(mock_tool_file.id, ".png")
def test_save_remote_url_request_failed(self, monkeypatch):
def test_save_remote_url_request_failed(self, monkeypatch: pytest.MonkeyPatch):
_TEST_URL = "https://example.com/image.png"
mock_request = httpx.Request("GET", _TEST_URL)
mock_response = httpx.Response(
@ -99,7 +99,7 @@ class TestFileSaverImpl:
mock_get.assert_called_once_with(_TEST_URL)
assert exc.value.response.status_code == 401
def test_save_remote_url_success(self, monkeypatch):
def test_save_remote_url_success(self, monkeypatch: pytest.MonkeyPatch):
_TEST_URL = "https://example.com/image.png"
mime_type = "image/png"
user_id = _gen_id()
@ -115,7 +115,6 @@ class TestFileSaverImpl:
file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id)
mock_tool_file = ToolFile(
id=_gen_id(),
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
@ -125,6 +124,7 @@ class TestFileSaverImpl:
name=f"{_gen_id()}.png",
size=len(_PNG_DATA),
)
mock_tool_file.id = _gen_id()
mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
mock_save_binary_string = mock.MagicMock(spec=file_saver.save_binary_string, return_value=mock_tool_file)

View File

@ -66,6 +66,7 @@ def llm_node_data() -> LLMNodeData:
detail=ImagePromptMessageContent.DETAIL.HIGH,
),
),
reasoning_format="tagged",
)
@ -676,3 +677,66 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
assert list(gen) == []
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()
class TestReasoningFormat:
"""Test cases for reasoning_format functionality"""
def test_split_reasoning_separated_mode(self):
"""Test separated mode: tags are removed and content is extracted"""
text_with_think = """
<think>I need to explain what Dify is. It's an open source AI platform.
</think>Dify is an open source AI platform.
"""
clean_text, reasoning_content = LLMNode._split_reasoning(text_with_think, "separated")
assert clean_text == "Dify is an open source AI platform."
assert reasoning_content == "I need to explain what Dify is. It's an open source AI platform."
def test_split_reasoning_tagged_mode(self):
"""Test tagged mode: original text is preserved"""
text_with_think = """
<think>I need to explain what Dify is. It's an open source AI platform.
</think>Dify is an open source AI platform.
"""
clean_text, reasoning_content = LLMNode._split_reasoning(text_with_think, "tagged")
# Original text unchanged
assert clean_text == text_with_think
# Empty reasoning content in tagged mode
assert reasoning_content == ""
def test_split_reasoning_no_think_blocks(self):
"""Test behavior when no <think> tags are present"""
text_without_think = "This is a simple answer without any thinking blocks."
clean_text, reasoning_content = LLMNode._split_reasoning(text_without_think, "separated")
assert clean_text == text_without_think
assert reasoning_content == ""
def test_reasoning_format_default_value(self):
"""Test that reasoning_format defaults to 'tagged' for backward compatibility"""
node_data = LLMNodeData(
title="Test LLM",
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
prompt_template=[],
context=ContextConfig(enabled=False),
)
assert node_data.reasoning_format == "tagged"
text_with_think = """
<think>I need to explain what Dify is. It's an open source AI platform.
</think>Dify is an open source AI platform.
"""
clean_text, reasoning_content = LLMNode._split_reasoning(text_with_think, node_data.reasoning_format)
assert clean_text == text_with_think
assert reasoning_content == ""

View File

@ -274,7 +274,7 @@ def test_array_file_contains_file_name():
assert result.outputs["result"] is True
def _get_test_conditions() -> list:
def _get_test_conditions():
conditions = [
# Test boolean "is" operator
{"comparison_operator": "is", "variable_selector": ["start", "bool_true"], "value": "true"},

View File

@ -362,7 +362,7 @@ class TestVariablePoolSerialization:
self._assert_pools_equal(reconstructed_dict, reconstructed_json)
# TODO: assert the data for file object...
def _assert_pools_equal(self, pool1: VariablePool, pool2: VariablePool) -> None:
def _assert_pools_equal(self, pool1: VariablePool, pool2: VariablePool):
"""Assert that two VariablePools contain equivalent data"""
# Compare system variables

View File

@ -41,6 +41,7 @@ class TestWorkflowEntryRedisChannel:
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
variable_pool=mock_variable_pool,
graph_runtime_state=mock_graph_runtime_state,
command_channel=redis_channel, # Provide Redis channel
)
@ -81,6 +82,7 @@ class TestWorkflowEntryRedisChannel:
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
variable_pool=mock_variable_pool,
graph_runtime_state=mock_graph_runtime_state,
command_channel=None, # No channel provided
)
@ -128,6 +130,7 @@ class TestWorkflowEntryRedisChannel:
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
variable_pool=mock_variable_pool,
graph_runtime_state=mock_graph_runtime_state,
command_channel=redis_channel,
)

View File

@ -27,7 +27,7 @@ from services.feature_service import BrandingModel
class MockEmailRenderer:
"""Mock implementation of EmailRenderer protocol"""
def __init__(self) -> None:
def __init__(self):
self.rendered_templates: list[tuple[str, dict[str, Any]]] = []
def render_template(self, template_path: str, **context: Any) -> str:
@ -39,7 +39,7 @@ class MockEmailRenderer:
class MockBrandingService:
"""Mock implementation of BrandingService protocol"""
def __init__(self, enabled: bool = False, application_title: str = "Dify") -> None:
def __init__(self, enabled: bool = False, application_title: str = "Dify"):
self.enabled = enabled
self.application_title = application_title
@ -54,10 +54,10 @@ class MockBrandingService:
class MockEmailSender:
"""Mock implementation of EmailSender protocol"""
def __init__(self) -> None:
def __init__(self):
self.sent_emails: list[dict[str, str]] = []
def send_email(self, to: str, subject: str, html_content: str) -> None:
def send_email(self, to: str, subject: str, html_content: str):
"""Mock send_email that records sent emails"""
self.sent_emails.append(
{
@ -134,7 +134,7 @@ class TestEmailI18nService:
email_service: EmailI18nService,
mock_renderer: MockEmailRenderer,
mock_sender: MockEmailSender,
) -> None:
):
"""Test sending email with English language"""
email_service.send_email(
email_type=EmailType.RESET_PASSWORD,
@ -162,7 +162,7 @@ class TestEmailI18nService:
self,
email_service: EmailI18nService,
mock_sender: MockEmailSender,
) -> None:
):
"""Test sending email with Chinese language"""
email_service.send_email(
email_type=EmailType.RESET_PASSWORD,
@ -181,7 +181,7 @@ class TestEmailI18nService:
email_config: EmailI18nConfig,
mock_renderer: MockEmailRenderer,
mock_sender: MockEmailSender,
) -> None:
):
"""Test sending email with branding enabled"""
# Create branding service with branding enabled
branding_service = MockBrandingService(enabled=True, application_title="MyApp")
@ -215,7 +215,7 @@ class TestEmailI18nService:
self,
email_service: EmailI18nService,
mock_sender: MockEmailSender,
) -> None:
):
"""Test language fallback to English when requested language not available"""
# Request invite member in Chinese (not configured)
email_service.send_email(
@ -233,7 +233,7 @@ class TestEmailI18nService:
self,
email_service: EmailI18nService,
mock_sender: MockEmailSender,
) -> None:
):
"""Test unknown language code falls back to English"""
email_service.send_email(
email_type=EmailType.RESET_PASSWORD,
@ -252,7 +252,7 @@ class TestEmailI18nService:
mock_renderer: MockEmailRenderer,
mock_sender: MockEmailSender,
mock_branding_service: MockBrandingService,
) -> None:
):
"""Test sending change email for old email verification"""
# Add change email templates to config
email_config.templates[EmailType.CHANGE_EMAIL_OLD] = {
@ -290,7 +290,7 @@ class TestEmailI18nService:
mock_renderer: MockEmailRenderer,
mock_sender: MockEmailSender,
mock_branding_service: MockBrandingService,
) -> None:
):
"""Test sending change email for new email verification"""
# Add change email templates to config
email_config.templates[EmailType.CHANGE_EMAIL_NEW] = {
@ -325,7 +325,7 @@ class TestEmailI18nService:
def test_send_change_email_invalid_phase(
self,
email_service: EmailI18nService,
) -> None:
):
"""Test sending change email with invalid phase raises error"""
with pytest.raises(ValueError, match="Invalid phase: invalid_phase"):
email_service.send_change_email(
@ -339,7 +339,7 @@ class TestEmailI18nService:
self,
email_service: EmailI18nService,
mock_sender: MockEmailSender,
) -> None:
):
"""Test sending raw email to single recipient"""
email_service.send_raw_email(
to="test@example.com",
@ -357,7 +357,7 @@ class TestEmailI18nService:
self,
email_service: EmailI18nService,
mock_sender: MockEmailSender,
) -> None:
):
"""Test sending raw email to multiple recipients"""
recipients = ["user1@example.com", "user2@example.com", "user3@example.com"]
@ -378,7 +378,7 @@ class TestEmailI18nService:
def test_get_template_missing_email_type(
self,
email_config: EmailI18nConfig,
) -> None:
):
"""Test getting template for missing email type raises error"""
with pytest.raises(ValueError, match="No templates configured for email type"):
email_config.get_template(EmailType.EMAIL_CODE_LOGIN, EmailLanguage.EN_US)
@ -386,7 +386,7 @@ class TestEmailI18nService:
def test_get_template_missing_language_and_english(
self,
email_config: EmailI18nConfig,
) -> None:
):
"""Test error when neither requested language nor English fallback exists"""
# Add template without English fallback
email_config.templates[EmailType.EMAIL_CODE_LOGIN] = {
@ -407,7 +407,7 @@ class TestEmailI18nService:
mock_renderer: MockEmailRenderer,
mock_sender: MockEmailSender,
mock_branding_service: MockBrandingService,
) -> None:
):
"""Test subject templating with custom variables"""
# Add template with variable in subject
email_config.templates[EmailType.OWNER_TRANSFER_NEW_NOTIFY] = {
@ -437,7 +437,7 @@ class TestEmailI18nService:
sent_email = mock_sender.sent_emails[0]
assert sent_email["subject"] == "You are now the owner of My Workspace"
def test_email_language_from_language_code(self) -> None:
def test_email_language_from_language_code(self):
"""Test EmailLanguage.from_language_code method"""
assert EmailLanguage.from_language_code("zh-Hans") == EmailLanguage.ZH_HANS
assert EmailLanguage.from_language_code("en-US") == EmailLanguage.EN_US
@ -448,7 +448,7 @@ class TestEmailI18nService:
class TestEmailI18nIntegration:
"""Integration tests for email i18n components"""
def test_create_default_email_config(self) -> None:
def test_create_default_email_config(self):
"""Test creating default email configuration"""
config = create_default_email_config()
@ -476,7 +476,7 @@ class TestEmailI18nIntegration:
assert EmailLanguage.ZH_HANS in config.templates[EmailType.RESET_PASSWORD]
assert EmailLanguage.ZH_HANS in config.templates[EmailType.INVITE_MEMBER]
def test_get_email_i18n_service(self) -> None:
def test_get_email_i18n_service(self):
"""Test getting global email i18n service instance"""
service1 = get_email_i18n_service()
service2 = get_email_i18n_service()
@ -484,7 +484,7 @@ class TestEmailI18nIntegration:
# Should return the same instance
assert service1 is service2
def test_flask_email_renderer(self) -> None:
def test_flask_email_renderer(self):
"""Test FlaskEmailRenderer implementation"""
renderer = FlaskEmailRenderer()
@ -494,7 +494,7 @@ class TestEmailI18nIntegration:
with pytest.raises(TemplateNotFound):
renderer.render_template("test.html", foo="bar")
def test_flask_mail_sender_not_initialized(self) -> None:
def test_flask_mail_sender_not_initialized(self):
"""Test FlaskMailSender when mail is not initialized"""
sender = FlaskMailSender()
@ -514,7 +514,7 @@ class TestEmailI18nIntegration:
# Restore original mail
libs.email_i18n.mail = original_mail
def test_flask_mail_sender_initialized(self) -> None:
def test_flask_mail_sender_initialized(self):
"""Test FlaskMailSender when mail is initialized"""
sender = FlaskMailSender()

View File

@ -4,7 +4,7 @@ from Crypto.PublicKey import RSA
from libs import gmpy2_pkcs10aep_cipher
def test_gmpy2_pkcs10aep_cipher() -> None:
def test_gmpy2_pkcs10aep_cipher():
rsa_key_pair = pyrsa.newkeys(2048)
public_key = rsa_key_pair[0].save_pkcs1()
private_key = rsa_key_pair[1].save_pkcs1()

View File

@ -1,7 +1,7 @@
from models.account import TenantAccountRole
def test_account_is_privileged_role() -> None:
def test_account_is_privileged_role():
assert TenantAccountRole.ADMIN == "admin"
assert TenantAccountRole.OWNER == "owner"
assert TenantAccountRole.EDITOR == "editor"

View File

@ -154,7 +154,7 @@ class TestEnumText:
TestCase(
name="session insert with invalid type",
action=lambda s: _session_insert_with_value(s, 1),
exc_type=TypeError,
exc_type=ValueError,
),
TestCase(
name="insert with invalid value",
@ -164,7 +164,7 @@ class TestEnumText:
TestCase(
name="insert with invalid type",
action=lambda s: _insert_with_user(s, 1),
exc_type=TypeError,
exc_type=ValueError,
),
]
for idx, c in enumerate(cases, 1):

View File

@ -2,11 +2,12 @@ import datetime
from typing import Any, Optional
# Mock redis_client before importing dataset_service
from unittest.mock import Mock, patch
from unittest.mock import Mock, create_autospec, patch
import pytest
from core.model_runtime.entities.model_entities import ModelType
from models.account import Account
from models.dataset import Dataset, ExternalKnowledgeBindings
from services.dataset_service import DatasetService
from services.errors.account import NoPermissionError
@ -78,7 +79,7 @@ class DatasetUpdateTestDataFactory:
@staticmethod
def create_current_user_mock(tenant_id: str = "tenant-123") -> Mock:
"""Create a mock current user."""
current_user = Mock()
current_user = create_autospec(Account, instance=True)
current_user.current_tenant_id = tenant_id
return current_user
@ -135,7 +136,9 @@ class TestDatasetServiceUpdateDataset:
"services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding"
) as mock_get_binding,
patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task,
patch("services.dataset_service.current_user") as mock_current_user,
patch(
"services.dataset_service.current_user", create_autospec(Account, instance=True)
) as mock_current_user,
):
mock_current_user.current_tenant_id = "tenant-123"
yield {

View File

@ -1,9 +1,10 @@
from unittest.mock import Mock, patch
from unittest.mock import Mock, create_autospec, patch
import pytest
from flask_restx import reqparse
from werkzeug.exceptions import BadRequest
from models.account import Account
from services.entities.knowledge_entities.knowledge_entities import MetadataArgs
from services.metadata_service import MetadataService
@ -35,19 +36,21 @@ class TestMetadataBugCompleteValidation:
mock_metadata_args.name = None
mock_metadata_args.type = "string"
with patch("services.metadata_service.current_user") as mock_user:
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
mock_user = create_autospec(Account, instance=True)
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
with patch("services.metadata_service.current_user", mock_user):
# Should crash with TypeError
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.create_metadata("dataset-123", mock_metadata_args)
# Test update method as well
with patch("services.metadata_service.current_user") as mock_user:
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
mock_user = create_autospec(Account, instance=True)
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
with patch("services.metadata_service.current_user", mock_user):
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.update_metadata_name("dataset-123", "metadata-456", None)

View File

@ -1,8 +1,9 @@
from unittest.mock import Mock, patch
from unittest.mock import Mock, create_autospec, patch
import pytest
from flask_restx import reqparse
from models.account import Account
from services.entities.knowledge_entities.knowledge_entities import MetadataArgs
from services.metadata_service import MetadataService
@ -24,20 +25,22 @@ class TestMetadataNullableBug:
mock_metadata_args.name = None # This will cause len() to crash
mock_metadata_args.type = "string"
with patch("services.metadata_service.current_user") as mock_user:
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
mock_user = create_autospec(Account, instance=True)
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
with patch("services.metadata_service.current_user", mock_user):
# This should crash with TypeError when calling len(None)
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.create_metadata("dataset-123", mock_metadata_args)
def test_metadata_service_update_with_none_name_crashes(self):
"""Test that MetadataService.update_metadata_name crashes when name is None."""
with patch("services.metadata_service.current_user") as mock_user:
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
mock_user = create_autospec(Account, instance=True)
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
with patch("services.metadata_service.current_user", mock_user):
# This should crash with TypeError when calling len(None)
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.update_metadata_name("dataset-123", "metadata-456", None)
@ -81,10 +84,11 @@ class TestMetadataNullableBug:
mock_metadata_args.name = None # From args["name"]
mock_metadata_args.type = None # From args["type"]
with patch("services.metadata_service.current_user") as mock_user:
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
mock_user = create_autospec(Account, instance=True)
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
with patch("services.metadata_service.current_user", mock_user):
# Step 4: Service layer crashes on len(None)
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.create_metadata("dataset-123", mock_metadata_args)