Merge branch 'main' into feat/agent-node-v2

This commit is contained in:
Novice
2025-12-15 15:26:48 +08:00
694 changed files with 37577 additions and 16560 deletions

View File

@ -26,16 +26,29 @@ redis_mock.hgetall = MagicMock(return_value={})
# Default return values for the hash/counter Redis commands on the shared mock.
redis_mock.hdel = MagicMock()
redis_mock.incr = MagicMock(return_value=1)
# Ensure OpenDAL fs writes to tmp to avoid polluting workspace
os.environ.setdefault("OPENDAL_SCHEME", "fs")
os.environ.setdefault("OPENDAL_FS_ROOT", "/tmp/dify-storage")
os.environ.setdefault("STORAGE_TYPE", "opendal")
# Add the API directory to Python path to ensure proper imports
import sys
sys.path.insert(0, PROJECT_DIR)
# apply the mock to the Redis client in the Flask app
from extensions import ext_redis
# NOTE(review): this patcher is started at import time and never stopped —
# presumably deliberate so the mock lives for the whole test session; confirm.
redis_patcher = patch.object(ext_redis, "redis_client", redis_mock)
redis_patcher.start()
def _patch_redis_clients_on_loaded_modules():
    """Point every already-imported module's ``redis_client`` attribute at the shared redis_mock."""
    import sys

    loaded = [mod for mod in sys.modules.values() if mod is not None]
    for mod in loaded:
        if hasattr(mod, "redis_client"):
            mod.redis_client = redis_mock
@pytest.fixture
@ -49,6 +62,15 @@ def _provide_app_context(app: Flask):
yield
@pytest.fixture(autouse=True)
def _patch_redis_clients():
    """Swap ``ext_redis.redis_client`` for the shared MagicMock for the duration of each unit test."""
    patcher = patch.object(ext_redis, "redis_client", redis_mock)
    patcher.start()
    try:
        _patch_redis_clients_on_loaded_modules()
        yield
    finally:
        patcher.stop()
@pytest.fixture(autouse=True)
def reset_redis_mock():
"""reset the Redis mock before each test"""
@ -63,3 +85,20 @@ def reset_redis_mock():
redis_mock.hgetall.return_value = {}
redis_mock.hdel.return_value = None
redis_mock.incr.return_value = 1
# Keep any imported modules pointing at the mock between tests
_patch_redis_clients_on_loaded_modules()
@pytest.fixture(autouse=True)
def reset_secret_key():
    """Blank out ``dify_config.SECRET_KEY`` for each test, restoring the prior value afterwards."""
    from configs import dify_config

    saved_value = dify_config.SECRET_KEY
    dify_config.SECRET_KEY = ""
    try:
        yield
    finally:
        # Restore even if the test body raised.
        dify_config.SECRET_KEY = saved_value

View File

@ -0,0 +1,344 @@
"""
Unit tests for annotation import security features.
Tests rate limiting, concurrency control, file validation, and other
security features added to prevent DoS attacks on the annotation import endpoint.
"""
import io
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.datastructures import FileStorage
from configs import dify_config
class TestAnnotationImportRateLimiting:
    """Exercise the per-minute and per-hour rate limits on annotation imports."""

    @pytest.fixture
    def mock_redis(self):
        """Patch the Redis client consulted by the rate-limit decorator."""
        with patch("controllers.console.wraps.redis_client") as redis_stub:
            yield redis_stub

    @pytest.fixture
    def mock_current_account(self):
        """Patch the account/tenant resolution used by the decorator."""
        with patch("controllers.console.wraps.current_account_with_tenant") as account_stub:
            account_stub.return_value = (MagicMock(id="user_id"), "test_tenant_id")
            yield account_stub

    def test_rate_limit_per_minute_enforced(self, mock_redis, mock_current_account):
        """Crossing the per-minute quota aborts the request with 429."""
        from controllers.console.wraps import annotation_import_rate_limit

        # zcard is consulted twice: minute window first, hour window second.
        mock_redis.zcard.side_effect = [
            dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE + 1,
            10,
        ]

        @annotation_import_rate_limit
        def dummy_view():
            return "success"

        with pytest.raises(Exception) as exc_info:
            dummy_view()
        assert "429" in str(exc_info.value) or "Too many" in str(exc_info.value)

    def test_rate_limit_per_hour_enforced(self, mock_redis, mock_current_account):
        """Crossing the per-hour quota aborts the request with 429."""
        from controllers.console.wraps import annotation_import_rate_limit

        mock_redis.zcard.side_effect = [
            3,  # minute window: under the limit
            dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR + 1,  # hour window: over
        ]

        @annotation_import_rate_limit
        def dummy_view():
            return "success"

        with pytest.raises(Exception) as exc_info:
            dummy_view()
        assert "429" in str(exc_info.value) or "Too many" in str(exc_info.value)

    def test_rate_limit_within_limits_passes(self, mock_redis, mock_current_account):
        """Requests under both quotas run the wrapped view and record usage."""
        from controllers.console.wraps import annotation_import_rate_limit

        mock_redis.zcard.return_value = 2  # under both windows

        @annotation_import_rate_limit
        def dummy_view():
            return "success"

        assert dummy_view() == "success"
        # The decorator should both record this request and prune old entries.
        assert mock_redis.zadd.called
        assert mock_redis.zremrangebyscore.called
class TestAnnotationImportConcurrencyControl:
    """Exercise the concurrent-task cap on annotation imports."""

    @pytest.fixture
    def mock_redis(self):
        """Patch the Redis client consulted by the concurrency decorator."""
        with patch("controllers.console.wraps.redis_client") as redis_stub:
            yield redis_stub

    @pytest.fixture
    def mock_current_account(self):
        """Patch the account/tenant resolution used by the decorator."""
        with patch("controllers.console.wraps.current_account_with_tenant") as account_stub:
            account_stub.return_value = (MagicMock(id="user_id"), "test_tenant_id")
            yield account_stub

    def test_concurrency_limit_enforced(self, mock_redis, mock_current_account):
        """Hitting the concurrent-task cap aborts the request with 429."""
        from controllers.console.wraps import annotation_import_concurrency_limit

        mock_redis.zcard.return_value = dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT

        @annotation_import_concurrency_limit
        def dummy_view():
            return "success"

        with pytest.raises(Exception) as exc_info:
            dummy_view()
        assert "429" in str(exc_info.value) or "concurrent" in str(exc_info.value).lower()

    def test_concurrency_within_limit_passes(self, mock_redis, mock_current_account):
        """Requests below the cap run the wrapped view normally."""
        from controllers.console.wraps import annotation_import_concurrency_limit

        mock_redis.zcard.return_value = 1

        @annotation_import_concurrency_limit
        def dummy_view():
            return "success"

        assert dummy_view() == "success"

    def test_stale_jobs_are_cleaned_up(self, mock_redis, mock_current_account):
        """The decorator prunes stale job entries even when nothing is running."""
        from controllers.console.wraps import annotation_import_concurrency_limit

        mock_redis.zcard.return_value = 0

        @annotation_import_concurrency_limit
        def dummy_view():
            return "success"

        dummy_view()
        assert mock_redis.zremrangebyscore.called
class TestAnnotationImportFileValidation:
    """Sanity checks on the upload fixtures used to exercise file validation.

    The actual rejection paths live in the controller and are covered by
    integration tests; previously these tests built fixtures and asserted
    nothing, so they passed vacuously. They now at least verify each fixture
    really violates (or satisfies) the documented limit it is meant to probe.
    """

    def test_file_size_limit_enforced(self):
        """A fixture built to exceed ANNOTATION_IMPORT_FILE_SIZE_LIMIT really is oversized."""
        # Limit is configured in MB; convert to bytes.
        max_size = dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT * 1024 * 1024
        large_content = b"x" * (max_size + 1024)  # exceed the cap by 1 KiB
        file = FileStorage(stream=io.BytesIO(large_content), filename="test.csv", content_type="text/csv")
        # The fixture must genuinely exceed the configured byte limit,
        # otherwise the integration test it supports proves nothing.
        assert len(large_content) > max_size
        assert file.content_type == "text/csv"

    def test_empty_file_rejected(self):
        """An empty upload carries zero payload bytes."""
        file = FileStorage(stream=io.BytesIO(b""), filename="test.csv", content_type="text/csv")
        assert file.stream.read() == b""

    def test_non_csv_file_rejected(self):
        """A .txt upload does not satisfy a CSV extension check."""
        file = FileStorage(stream=io.BytesIO(b"test"), filename="test.txt", content_type="text/plain")
        assert not file.filename.lower().endswith(".csv")
class TestAnnotationImportServiceValidation:
    """Service-layer validation in AppAnnotationService.batch_import_app_annotations."""

    @pytest.fixture
    def mock_app(self):
        """A minimal application stand-in exposing only an id."""
        app = MagicMock()
        app.id = "app_id"
        return app

    @pytest.fixture
    def mock_db_session(self):
        """Patch the SQLAlchemy session used by the annotation service."""
        with patch("services.annotation_service.db.session") as session_stub:
            yield session_stub

    def test_max_records_limit_enforced(self, mock_app, mock_db_session):
        """A CSV above ANNOTATION_IMPORT_MAX_RECORDS rows yields a 'too many/maximum' error."""
        from services.annotation_service import AppAnnotationService

        max_records = dify_config.ANNOTATION_IMPORT_MAX_RECORDS
        rows = [f"Question {i},Answer {i}" for i in range(max_records + 100)]
        csv_content = "question,answer\n" + "\n".join(rows) + "\n"
        file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
        mock_db_session.query.return_value.where.return_value.first.return_value = mock_app

        with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
            mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")
            with patch("services.annotation_service.FeatureService") as mock_features:
                mock_features.get_features.return_value.billing.enabled = False
                result = AppAnnotationService.batch_import_app_annotations("app_id", file)

        assert "error_msg" in result
        assert "too many" in result["error_msg"].lower() or "maximum" in result["error_msg"].lower()

    def test_min_records_limit_enforced(self, mock_app, mock_db_session):
        """A CSV with no data rows yields an 'at least/minimum' error."""
        from services.annotation_service import AppAnnotationService

        csv_content = "question,answer\n"  # header only, zero data rows
        file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
        mock_db_session.query.return_value.where.return_value.first.return_value = mock_app

        with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
            mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")
            result = AppAnnotationService.batch_import_app_annotations("app_id", file)

        assert "error_msg" in result
        assert "at least" in result["error_msg"].lower() or "minimum" in result["error_msg"].lower()

    def test_invalid_csv_format_handled(self, mock_app, mock_db_session):
        """Malformed CSV content produces an error message instead of raising."""
        from services.annotation_service import AppAnnotationService

        csv_content = 'invalid,csv,format\nwith,unbalanced,quotes,and"stuff'
        file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
        mock_db_session.query.return_value.where.return_value.first.return_value = mock_app

        with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
            mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")
            result = AppAnnotationService.batch_import_app_annotations("app_id", file)

        assert "error_msg" in result

    def test_valid_import_succeeds(self, mock_app, mock_db_session):
        """A well-formed two-row CSV yields a queued job descriptor."""
        from services.annotation_service import AppAnnotationService

        csv_content = "question,answer\nWhat is AI?,Artificial Intelligence\nWhat is ML?,Machine Learning\n"
        file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
        mock_db_session.query.return_value.where.return_value.first.return_value = mock_app

        with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
            mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")
            with patch("services.annotation_service.FeatureService") as mock_features:
                mock_features.get_features.return_value.billing.enabled = False
                with patch("services.annotation_service.batch_import_annotations_task"):
                    with patch("services.annotation_service.redis_client"):
                        result = AppAnnotationService.batch_import_app_annotations("app_id", file)

        # The service should hand back a waiting job with the parsed row count.
        assert "job_id" in result
        assert "job_status" in result
        assert result["job_status"] == "waiting"
        assert "record_count" in result
        assert result["record_count"] == 2
class TestAnnotationImportTaskOptimization:
    """Checks on the Celery task configuration for batch imports."""

    def test_task_has_timeout_configured(self):
        """The batch import task must expose hard and soft time-limit attributes."""
        from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task

        # The concrete values (e.g. 360s hard / 300s soft) come from the
        # Celery configuration, so only the attributes' presence is asserted.
        for attribute in ("time_limit", "soft_time_limit"):
            assert hasattr(batch_import_annotations_task, attribute)
class TestConfigurationValues:
    """Verify the security-related annotation-import settings exist and are sane."""

    def test_rate_limit_configs_exist(self):
        """Both rate-limit knobs are defined and positive."""
        for setting in ("ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE", "ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR"):
            assert hasattr(dify_config, setting)
            assert getattr(dify_config, setting) > 0

    def test_file_size_limit_config_exists(self):
        """File size limit is defined, positive, and capped at a reasonable 10 MB."""
        assert hasattr(dify_config, "ANNOTATION_IMPORT_FILE_SIZE_LIMIT")
        assert 0 < dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT <= 10

    def test_record_limit_configs_exist(self):
        """Record-count bounds are defined, positive, and correctly ordered."""
        for setting in ("ANNOTATION_IMPORT_MAX_RECORDS", "ANNOTATION_IMPORT_MIN_RECORDS"):
            assert hasattr(dify_config, setting)
            assert getattr(dify_config, setting) > 0
        assert dify_config.ANNOTATION_IMPORT_MIN_RECORDS < dify_config.ANNOTATION_IMPORT_MAX_RECORDS

    def test_concurrency_limit_config_exists(self):
        """Concurrency cap is defined and within a reasonable upper bound."""
        assert hasattr(dify_config, "ANNOTATION_IMPORT_MAX_CONCURRENT")
        assert 0 < dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT <= 10

View File

@ -125,7 +125,7 @@ class TestPartnerTenants:
resource = PartnerTenants()
# Act & Assert
# reqparse will raise BadRequest for missing required field
# Validation should raise BadRequest for missing required field
with pytest.raises(BadRequest):
resource.put(partner_key_encoded)

View File

@ -0,0 +1,407 @@
"""Final working unit tests for admin endpoints - tests business logic directly."""
import uuid
from unittest.mock import Mock, patch
import pytest
from werkzeug.exceptions import NotFound, Unauthorized
from controllers.console.admin import InsertExploreAppPayload
from models.model import App, RecommendedApp
class TestInsertExploreAppPayload:
    """Validation behaviour of InsertExploreAppPayload."""

    def test_valid_payload(self):
        """Every supplied field survives a round-trip through model_validate."""
        data = {
            "app_id": str(uuid.uuid4()),
            "desc": "Test app description",
            "copyright": "© 2024 Test Company",
            "privacy_policy": "https://example.com/privacy",
            "custom_disclaimer": "Custom disclaimer text",
            "language": "en-US",
            "category": "Productivity",
            "position": 1,
        }
        payload = InsertExploreAppPayload.model_validate(data)
        for field, expected in data.items():
            assert getattr(payload, field) == expected

    def test_minimal_payload(self):
        """Optional text fields default to None when only required fields are given."""
        data = {
            "app_id": str(uuid.uuid4()),
            "language": "en-US",
            "category": "Productivity",
            "position": 1,
        }
        payload = InsertExploreAppPayload.model_validate(data)
        for field, expected in data.items():
            assert getattr(payload, field) == expected
        for optional_field in ("desc", "copyright", "privacy_policy", "custom_disclaimer"):
            assert getattr(payload, optional_field) is None

    def test_invalid_language(self):
        """An unsupported language code is rejected with a descriptive error."""
        data = {
            "app_id": str(uuid.uuid4()),
            "language": "invalid-lang",
            "category": "Productivity",
            "position": 1,
        }
        with pytest.raises(ValueError, match="invalid-lang is not a valid language"):
            InsertExploreAppPayload.model_validate(data)
class TestAdminRequiredDecorator:
    """Authentication outcomes of the admin_required decorator."""

    def setup_method(self):
        """Patch the admin config and the token extractor before each test."""
        self.dify_config_patcher = patch("controllers.console.admin.dify_config")
        self.mock_dify_config = self.dify_config_patcher.start()
        self.mock_dify_config.ADMIN_API_KEY = "test-admin-key"
        self.token_patcher = patch("controllers.console.admin.extract_access_token")
        self.mock_extract_token = self.token_patcher.start()

    def teardown_method(self):
        """Stop every patcher started in setup_method."""
        for patcher in (self.dify_config_patcher, self.token_patcher):
            patcher.stop()

    def test_admin_required_success(self):
        """A token matching ADMIN_API_KEY lets the wrapped view run."""
        from controllers.console.admin import admin_required

        @admin_required
        def test_view():
            return {"success": True}

        self.mock_extract_token.return_value = "test-admin-key"
        assert test_view()["success"] is True

    def test_admin_required_invalid_token(self):
        """A non-matching token raises Unauthorized."""
        from controllers.console.admin import admin_required

        @admin_required
        def test_view():
            return {"success": True}

        self.mock_extract_token.return_value = "wrong-key"
        with pytest.raises(Unauthorized, match="API key is invalid"):
            test_view()

    def test_admin_required_no_api_key_configured(self):
        """An unset ADMIN_API_KEY rejects every request."""
        from controllers.console.admin import admin_required

        # Clear the key before building the view, mirroring a server that
        # never had an admin key configured.
        self.mock_dify_config.ADMIN_API_KEY = None

        @admin_required
        def test_view():
            return {"success": True}

        with pytest.raises(Unauthorized, match="API key is invalid"):
            test_view()

    def test_admin_required_missing_authorization_header(self):
        """A missing Authorization header raises a dedicated Unauthorized error."""
        from controllers.console.admin import admin_required

        @admin_required
        def test_view():
            return {"success": True}

        self.mock_extract_token.return_value = None
        with pytest.raises(Unauthorized, match="Authorization header is missing"):
            test_view()
class TestExploreAppBusinessLogicDirect:
"""Test the core business logic of explore app management directly."""
def test_data_fusion_logic(self):
"""Test the data fusion logic between payload and site data."""
# Test cases for different data scenarios
test_cases = [
{
"name": "site_data_overrides_payload",
"payload": {"desc": "Payload desc", "copyright": "Payload copyright"},
"site": {"description": "Site desc", "copyright": "Site copyright"},
"expected": {
"desc": "Site desc",
"copyright": "Site copyright",
"privacy_policy": "",
"custom_disclaimer": "",
},
},
{
"name": "payload_used_when_no_site",
"payload": {"desc": "Payload desc", "copyright": "Payload copyright"},
"site": None,
"expected": {
"desc": "Payload desc",
"copyright": "Payload copyright",
"privacy_policy": "",
"custom_disclaimer": "",
},
},
{
"name": "empty_defaults_when_no_data",
"payload": {},
"site": None,
"expected": {"desc": "", "copyright": "", "privacy_policy": "", "custom_disclaimer": ""},
},
]
for case in test_cases:
# Simulate the data fusion logic
payload_desc = case["payload"].get("desc")
payload_copyright = case["payload"].get("copyright")
payload_privacy_policy = case["payload"].get("privacy_policy")
payload_custom_disclaimer = case["payload"].get("custom_disclaimer")
if case["site"]:
site_desc = case["site"].get("description")
site_copyright = case["site"].get("copyright")
site_privacy_policy = case["site"].get("privacy_policy")
site_custom_disclaimer = case["site"].get("custom_disclaimer")
# Site data takes precedence
desc = site_desc or payload_desc or ""
copyright = site_copyright or payload_copyright or ""
privacy_policy = site_privacy_policy or payload_privacy_policy or ""
custom_disclaimer = site_custom_disclaimer or payload_custom_disclaimer or ""
else:
# Use payload data or empty defaults
desc = payload_desc or ""
copyright = payload_copyright or ""
privacy_policy = payload_privacy_policy or ""
custom_disclaimer = payload_custom_disclaimer or ""
result = {
"desc": desc,
"copyright": copyright,
"privacy_policy": privacy_policy,
"custom_disclaimer": custom_disclaimer,
}
assert result == case["expected"], f"Failed test case: {case['name']}"
def test_app_visibility_logic(self):
"""Test that apps are made public when added to explore list."""
# Create a mock app
mock_app = Mock(spec=App)
mock_app.is_public = False
# Simulate the business logic
mock_app.is_public = True
assert mock_app.is_public is True
def test_recommended_app_creation_logic(self):
"""Test the creation of RecommendedApp objects."""
app_id = str(uuid.uuid4())
payload_data = {
"app_id": app_id,
"desc": "Test app description",
"copyright": "© 2024 Test Company",
"privacy_policy": "https://example.com/privacy",
"custom_disclaimer": "Custom disclaimer",
"language": "en-US",
"category": "Productivity",
"position": 1,
}
# Simulate the creation logic
recommended_app = Mock(spec=RecommendedApp)
recommended_app.app_id = payload_data["app_id"]
recommended_app.description = payload_data["desc"]
recommended_app.copyright = payload_data["copyright"]
recommended_app.privacy_policy = payload_data["privacy_policy"]
recommended_app.custom_disclaimer = payload_data["custom_disclaimer"]
recommended_app.language = payload_data["language"]
recommended_app.category = payload_data["category"]
recommended_app.position = payload_data["position"]
# Verify the data
assert recommended_app.app_id == app_id
assert recommended_app.description == "Test app description"
assert recommended_app.copyright == "© 2024 Test Company"
assert recommended_app.privacy_policy == "https://example.com/privacy"
assert recommended_app.custom_disclaimer == "Custom disclaimer"
assert recommended_app.language == "en-US"
assert recommended_app.category == "Productivity"
assert recommended_app.position == 1
def test_recommended_app_update_logic(self):
"""Test the update logic for existing RecommendedApp objects."""
mock_recommended_app = Mock(spec=RecommendedApp)
update_data = {
"desc": "Updated description",
"copyright": "© 2024 Updated",
"language": "fr-FR",
"category": "Tools",
"position": 2,
}
# Simulate the update logic
mock_recommended_app.description = update_data["desc"]
mock_recommended_app.copyright = update_data["copyright"]
mock_recommended_app.language = update_data["language"]
mock_recommended_app.category = update_data["category"]
mock_recommended_app.position = update_data["position"]
# Verify the updates
assert mock_recommended_app.description == "Updated description"
assert mock_recommended_app.copyright == "© 2024 Updated"
assert mock_recommended_app.language == "fr-FR"
assert mock_recommended_app.category == "Tools"
assert mock_recommended_app.position == 2
def test_app_not_found_error_logic(self):
"""Test error handling when app is not found."""
app_id = str(uuid.uuid4())
# Simulate app lookup returning None
found_app = None
# Test the error condition
if not found_app:
with pytest.raises(NotFound, match=f"App '{app_id}' is not found"):
raise NotFound(f"App '{app_id}' is not found")
def test_recommended_app_not_found_error_logic(self):
"""Test error handling when recommended app is not found for deletion."""
app_id = str(uuid.uuid4())
# Simulate recommended app lookup returning None
found_recommended_app = None
# Test the error condition
if not found_recommended_app:
with pytest.raises(NotFound, match=f"App '{app_id}' is not found in the explore list"):
raise NotFound(f"App '{app_id}' is not found in the explore list")
def test_database_session_usage_patterns(self):
"""Test the expected database session usage patterns."""
# Mock session usage patterns
mock_session = Mock()
# Test session.add pattern
mock_recommended_app = Mock(spec=RecommendedApp)
mock_session.add(mock_recommended_app)
mock_session.commit()
# Verify session was used correctly
mock_session.add.assert_called_once_with(mock_recommended_app)
mock_session.commit.assert_called_once()
# Test session.delete pattern
mock_recommended_app_to_delete = Mock(spec=RecommendedApp)
mock_session.delete(mock_recommended_app_to_delete)
mock_session.commit()
# Verify delete pattern
mock_session.delete.assert_called_once_with(mock_recommended_app_to_delete)
def test_payload_validation_integration(self):
"""Test payload validation in the context of the business logic."""
# Test valid payload
valid_payload_data = {
"app_id": str(uuid.uuid4()),
"desc": "Test app description",
"language": "en-US",
"category": "Productivity",
"position": 1,
}
# This should succeed
payload = InsertExploreAppPayload.model_validate(valid_payload_data)
assert payload.app_id == valid_payload_data["app_id"]
# Test invalid payload
invalid_payload_data = {
"app_id": str(uuid.uuid4()),
"language": "invalid-lang", # This should fail validation
"category": "Productivity",
"position": 1,
}
# This should raise an exception
with pytest.raises(ValueError, match="invalid-lang is not a valid language"):
InsertExploreAppPayload.model_validate(invalid_payload_data)
class TestExploreAppDataHandling:
"""Test specific data handling scenarios."""
def test_uuid_validation(self):
"""Test UUID validation and handling."""
# Test valid UUID
valid_uuid = str(uuid.uuid4())
# This should be a valid UUID
assert uuid.UUID(valid_uuid) is not None
# Test invalid UUID
invalid_uuid = "not-a-valid-uuid"
# This should raise a ValueError
with pytest.raises(ValueError):
uuid.UUID(invalid_uuid)
def test_language_validation(self):
"""Test language validation against supported languages."""
from constants.languages import supported_language
# Test supported language
assert supported_language("en-US") == "en-US"
assert supported_language("fr-FR") == "fr-FR"
# Test unsupported language
with pytest.raises(ValueError, match="invalid-lang is not a valid language"):
supported_language("invalid-lang")
def test_response_formatting(self):
"""Test API response formatting."""
# Test success responses
create_response = {"result": "success"}
update_response = {"result": "success"}
delete_response = None # 204 No Content returns None
assert create_response["result"] == "success"
assert update_response["result"] == "success"
assert delete_response is None
# Test status codes
create_status = 201 # Created
update_status = 200 # OK
delete_status = 204 # No Content
assert create_status == 201
assert update_status == 200
assert delete_status == 204

View File

@ -0,0 +1,25 @@
import uuid
import pytest
from pydantic import ValidationError
from controllers.service_api.app.completion import ChatRequestPayload
def test_chat_request_payload_accepts_blank_conversation_id():
    """An empty conversation_id string is normalized to None on validation."""
    raw = {"inputs": {}, "query": "hello", "conversation_id": ""}
    assert ChatRequestPayload.model_validate(raw).conversation_id is None
def test_chat_request_payload_validates_uuid():
    """A well-formed UUID string is accepted and preserved verbatim."""
    conversation_id = str(uuid.uuid4())
    raw = {"inputs": {}, "query": "hello", "conversation_id": conversation_id}
    assert ChatRequestPayload.model_validate(raw).conversation_id == conversation_id
def test_chat_request_payload_rejects_invalid_uuid():
    """A conversation_id that is not a UUID fails validation."""
    raw = {"inputs": {}, "query": "hello", "conversation_id": "invalid"}
    with pytest.raises(ValidationError):
        ChatRequestPayload.model_validate(raw)

View File

@ -256,24 +256,18 @@ class TestFilePreviewApi:
mock_app, # App query for tenant validation
]
with patch("controllers.service_api.app.file_preview.reqparse") as mock_reqparse:
# Mock request parsing
mock_parser = Mock()
mock_parser.parse_args.return_value = {"as_attachment": False}
mock_reqparse.RequestParser.return_value = mock_parser
# Test the core logic directly without Flask decorators
# Validate file ownership
result_message_file, result_upload_file = file_preview_api._validate_file_ownership(file_id, app_id)
assert result_message_file == mock_message_file
assert result_upload_file == mock_upload_file
# Test the core logic directly without Flask decorators
# Validate file ownership
result_message_file, result_upload_file = file_preview_api._validate_file_ownership(file_id, app_id)
assert result_message_file == mock_message_file
assert result_upload_file == mock_upload_file
# Test file response building
response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False)
assert response is not None
# Test file response building
response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False)
assert response is not None
# Verify storage was called correctly
mock_storage.load.assert_not_called() # Since we're testing components separately
# Verify storage was called correctly
mock_storage.load.assert_not_called() # Since we're testing components separately
@patch("controllers.service_api.app.file_preview.storage")
def test_storage_error_handling(

View File

@ -0,0 +1,20 @@
import pytest
from pydantic import ValidationError
from controllers.console.explore.conversation import ConversationRenamePayload as ConsolePayload
from controllers.service_api.app.conversation import ConversationRenamePayload as ServicePayload
@pytest.mark.parametrize("payload_cls", [ConsolePayload, ServicePayload])
def test_payload_allows_auto_generate_without_name(payload_cls):
    """With auto_generate=True, the name field may be omitted entirely."""
    validated = payload_cls.model_validate({"auto_generate": True})
    assert validated.auto_generate is True
    assert validated.name is None
@pytest.mark.parametrize("payload_cls", [ConsolePayload, ServicePayload])
@pytest.mark.parametrize("value", [None, "", " "])
def test_payload_requires_name_when_not_auto_generate(payload_cls, value):
    """Without auto_generate, a missing or blank name must fail validation."""
    with pytest.raises(ValidationError):
        payload_cls.model_validate({"name": value, "auto_generate": False})

View File

@ -0,0 +1,129 @@
import json
from unittest.mock import patch
import pytest
from redis.exceptions import RedisError
from core.helper.tool_provider_cache import ToolProviderListCache
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
@pytest.fixture
def mock_redis_client():
    """Yield a patched Redis client for core.helper.tool_provider_cache."""
    patcher = patch("core.helper.tool_provider_cache.redis_client")
    redis_stub = patcher.start()
    try:
        yield redis_stub
    finally:
        patcher.stop()
class TestToolProviderListCache:
"""Test class for ToolProviderListCache"""
def test_generate_cache_key(self):
    """Cache keys embed the tenant id and provider type, defaulting the type to 'all'."""
    tenant_id = "tenant_123"
    provider_type: ToolProviderTypeApiLiteral = "builtin"
    assert (
        ToolProviderListCache._generate_cache_key(tenant_id, provider_type)
        == f"tool_providers:tenant_id:{tenant_id}:type:{provider_type}"
    )
    # Omitting the type falls back to the "all" bucket.
    assert ToolProviderListCache._generate_cache_key(tenant_id) == f"tool_providers:tenant_id:{tenant_id}:type:all"
def test_get_cached_providers_hit(self, mock_redis_client):
    """A cache hit decodes the stored JSON and returns the provider list."""
    tenant_id = "tenant_123"
    provider_type: ToolProviderTypeApiLiteral = "api"
    providers = [{"id": "tool", "name": "test_provider"}]
    mock_redis_client.get.return_value = json.dumps(providers).encode("utf-8")

    fetched = ToolProviderListCache.get_cached_providers(tenant_id, provider_type)

    mock_redis_client.get.assert_called_once_with(
        ToolProviderListCache._generate_cache_key(tenant_id, provider_type)
    )
    assert fetched == providers
def test_get_cached_providers_decode_error(self, mock_redis_client):
    """Undecodable cache content is treated as a miss (returns None)."""
    mock_redis_client.get.return_value = b"invalid_json_data"
    assert ToolProviderListCache.get_cached_providers("tenant_123") is None
    mock_redis_client.get.assert_called_once()
def test_get_cached_providers_miss(self, mock_redis_client):
    """An absent cache key yields None."""
    mock_redis_client.get.return_value = None
    assert ToolProviderListCache.get_cached_providers("tenant_123") is None
    mock_redis_client.get.assert_called_once()
def test_set_cached_providers(self, mock_redis_client):
"""Test set cached providers"""
tenant_id = "tenant_123"
typ: ToolProviderTypeApiLiteral = "builtin"
mock_providers = [{"id": "tool", "name": "test_provider"}]
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
ToolProviderListCache.set_cached_providers(tenant_id, typ, mock_providers)
mock_redis_client.setex.assert_called_once_with(
cache_key, ToolProviderListCache.CACHE_TTL, json.dumps(mock_providers)
)
def test_invalidate_cache_specific_type(self, mock_redis_client):
"""Test invalidate cache - specific type"""
tenant_id = "tenant_123"
typ: ToolProviderTypeApiLiteral = "workflow"
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
ToolProviderListCache.invalidate_cache(tenant_id, typ)
mock_redis_client.delete.assert_called_once_with(cache_key)
def test_invalidate_cache_all_types(self, mock_redis_client):
"""Test invalidate cache - clear all tenant cache"""
tenant_id = "tenant_123"
mock_keys = [
b"tool_providers:tenant_id:tenant_123:type:all",
b"tool_providers:tenant_id:tenant_123:type:builtin",
]
mock_redis_client.scan_iter.return_value = mock_keys
ToolProviderListCache.invalidate_cache(tenant_id)
mock_redis_client.scan_iter.assert_called_once_with(f"tool_providers:tenant_id:{tenant_id}:*")
mock_redis_client.delete.assert_called_once_with(*mock_keys)
def test_invalidate_cache_no_keys(self, mock_redis_client):
"""Test invalidate cache - no cache keys for tenant"""
tenant_id = "tenant_123"
mock_redis_client.scan_iter.return_value = []
ToolProviderListCache.invalidate_cache(tenant_id)
mock_redis_client.delete.assert_not_called()
def test_redis_fallback_default_return(self, mock_redis_client):
"""Test redis_fallback decorator - default return value (Redis error)"""
mock_redis_client.get.side_effect = RedisError("Redis connection error")
result = ToolProviderListCache.get_cached_providers("tenant_123")
assert result is None
mock_redis_client.get.assert_called_once()
def test_redis_fallback_no_default(self, mock_redis_client):
"""Test redis_fallback decorator - no default return value (Redis error)"""
mock_redis_client.setex.side_effect = RedisError("Redis connection error")
try:
ToolProviderListCache.set_cached_providers("tenant_123", "mcp", [])
except RedisError:
pytest.fail("set_cached_providers should not raise RedisError (handled by fallback)")
mock_redis_client.setex.assert_called_once()

View File

@ -53,7 +53,7 @@ from sqlalchemy.exc import IntegrityError
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeConnectionError,
@ -99,10 +99,10 @@ class TestCacheEmbeddingDocuments:
@pytest.fixture
def sample_embedding_result(self):
"""Create a sample TextEmbeddingResult for testing.
"""Create a sample EmbeddingResult for testing.
Returns:
TextEmbeddingResult: Mock embedding result with proper structure
EmbeddingResult: Mock embedding result with proper structure
"""
# Create normalized embedding vectors (dimension 1536 for ada-002)
embedding_vector = np.random.randn(1536)
@ -118,7 +118,7 @@ class TestCacheEmbeddingDocuments:
latency=0.5,
)
return TextEmbeddingResult(
return EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized_vector],
usage=usage,
@ -197,7 +197,7 @@ class TestCacheEmbeddingDocuments:
latency=0.8,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@ -296,7 +296,7 @@ class TestCacheEmbeddingDocuments:
latency=0.6,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=new_embeddings,
usage=usage,
@ -386,7 +386,7 @@ class TestCacheEmbeddingDocuments:
latency=0.5,
)
return TextEmbeddingResult(
return EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@ -449,7 +449,7 @@ class TestCacheEmbeddingDocuments:
latency=0.5,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[valid_vector.tolist(), nan_vector],
usage=usage,
@ -629,7 +629,7 @@ class TestCacheEmbeddingQuery:
latency=0.3,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized],
usage=usage,
@ -728,7 +728,7 @@ class TestCacheEmbeddingQuery:
latency=0.3,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[nan_vector],
usage=usage,
@ -793,7 +793,7 @@ class TestCacheEmbeddingQuery:
latency=0.3,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized],
usage=usage,
@ -873,13 +873,13 @@ class TestEmbeddingModelSwitching:
latency=0.3,
)
result_ada = TextEmbeddingResult(
result_ada = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized_ada],
usage=usage,
)
result_3_small = TextEmbeddingResult(
result_3_small = EmbeddingResult(
model="text-embedding-3-small",
embeddings=[normalized_3_small],
usage=usage,
@ -953,13 +953,13 @@ class TestEmbeddingModelSwitching:
latency=0.4,
)
result_openai = TextEmbeddingResult(
result_openai = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized_openai],
usage=usage_openai,
)
result_cohere = TextEmbeddingResult(
result_cohere = EmbeddingResult(
model="embed-english-v3.0",
embeddings=[normalized_cohere],
usage=usage_cohere,
@ -1042,7 +1042,7 @@ class TestEmbeddingDimensionValidation:
latency=0.7,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@ -1095,7 +1095,7 @@ class TestEmbeddingDimensionValidation:
latency=0.5,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@ -1148,7 +1148,7 @@ class TestEmbeddingDimensionValidation:
latency=0.3,
)
result_ada = TextEmbeddingResult(
result_ada = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized_ada],
usage=usage_ada,
@ -1181,7 +1181,7 @@ class TestEmbeddingDimensionValidation:
latency=0.4,
)
result_cohere = TextEmbeddingResult(
result_cohere = EmbeddingResult(
model="embed-english-v3.0",
embeddings=[normalized_cohere],
usage=usage_cohere,
@ -1279,7 +1279,7 @@ class TestEmbeddingEdgeCases:
latency=0.1,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized],
usage=usage,
@ -1322,7 +1322,7 @@ class TestEmbeddingEdgeCases:
latency=1.5,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized],
usage=usage,
@ -1370,7 +1370,7 @@ class TestEmbeddingEdgeCases:
latency=0.5,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@ -1422,7 +1422,7 @@ class TestEmbeddingEdgeCases:
latency=0.2,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@ -1478,7 +1478,7 @@ class TestEmbeddingEdgeCases:
)
# Model returns embeddings for all texts
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@ -1546,7 +1546,7 @@ class TestEmbeddingEdgeCases:
latency=0.8,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@ -1603,7 +1603,7 @@ class TestEmbeddingEdgeCases:
latency=0.3,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized],
usage=usage,
@ -1657,7 +1657,7 @@ class TestEmbeddingEdgeCases:
latency=0.5,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@ -1757,7 +1757,7 @@ class TestEmbeddingCachePerformance:
latency=0.3,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized],
usage=usage,
@ -1826,7 +1826,7 @@ class TestEmbeddingCachePerformance:
latency=0.5,
)
return TextEmbeddingResult(
return EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@ -1888,7 +1888,7 @@ class TestEmbeddingCachePerformance:
latency=0.3,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized],
usage=usage,

View File

@ -1,7 +1,10 @@
"""Primarily used for testing merged cell scenarios"""
from types import SimpleNamespace
from docx import Document
import core.rag.extractor.word_extractor as we
from core.rag.extractor.word_extractor import WordExtractor
@ -47,3 +50,85 @@ def test_parse_row():
extractor = object.__new__(WordExtractor)
for idx, row in enumerate(table.rows):
assert extractor._parse_row(row, {}, 3) == gt[idx]
def test_extract_images_from_docx(monkeypatch):
    """Exercise WordExtractor._extract_images_from_docx with all collaborators stubbed out.

    Covers both an external image relationship (fetched over HTTP) and an
    internal one (read from the docx part blob), and checks storage writes,
    DB bookkeeping, and the returned markdown map.
    """
    external_bytes = b"ext-bytes"
    internal_bytes = b"int-bytes"

    # Record storage writes instead of hitting real storage.
    saved: list[tuple[str, bytes]] = []
    monkeypatch.setattr(
        we, "storage", SimpleNamespace(save=lambda key, data: saved.append((key, data)))
    )

    # Minimal db.session stand-in that records adds and the commit flag.
    class RecordingSession:
        def __init__(self):
            self.added = []
            self.committed = False

        def add(self, obj):
            self.added.append(obj)

        def commit(self):
            self.committed = True

    fake_db = SimpleNamespace(session=RecordingSession())
    monkeypatch.setattr(we, "db", fake_db)

    # Config values consumed for URL composition and storage-type selection.
    monkeypatch.setattr(we.dify_config, "FILES_URL", "http://files.local", raising=False)
    monkeypatch.setattr(we.dify_config, "STORAGE_TYPE", "local", raising=False)

    # Stand-in for the UploadFile DB model; hands out sequential ids u1, u2, ...
    class StubUploadFile:
        _i = 0

        def __init__(self, **kwargs):  # kwargs match the real signature fields
            type(self)._i += 1
            self.id = f"u{self._i}"

    monkeypatch.setattr(we, "UploadFile", StubUploadFile)

    # External image fetcher: only the expected URL should ever be requested.
    def stub_get(url: str):
        assert url == "https://example.com/image.png"
        return SimpleNamespace(status_code=200, headers={"Content-Type": "image/png"}, content=external_bytes)

    monkeypatch.setattr(we, "ssrf_proxy", SimpleNamespace(get=stub_get))

    # Hashable stand-in for a docx image part so it can serve as a dict key.
    class StubPart:
        def __init__(self, blob: bytes):
            self.blob = blob

        def __hash__(self) -> int:
            return id(self)

    # Minimal document object exposing one external and one internal image rel.
    embedded_part = StubPart(blob=internal_bytes)
    external_rel = SimpleNamespace(is_external=True, target_ref="https://example.com/image.png")
    internal_rel = SimpleNamespace(
        is_external=False, target_ref="word/media/image1.png", target_part=embedded_part
    )
    doc = SimpleNamespace(part=SimpleNamespace(rels={"rId1": external_rel, "rId2": internal_rel}))

    extractor = object.__new__(WordExtractor)
    extractor.tenant_id = "t1"
    extractor.user_id = "u1"

    image_map = extractor._extract_images_from_docx(doc)

    # External entries are keyed by relationship id, internal ones by their target part.
    assert set(image_map.keys()) == {"rId1", embedded_part}
    assert all(v.startswith("![image](") and v.endswith("/file-preview)") for v in image_map.values())

    # Both image payloads must have reached storage.
    stored_payloads = {data for _, data in saved}
    assert external_bytes in stored_payloads
    assert internal_bytes in stored_payloads

    # One UploadFile row per image, followed by a commit.
    assert len(fake_db.session.added) == 2
    assert fake_db.session.committed is True

View File

@ -62,7 +62,7 @@ from core.indexing_runner import (
IndexingRunner,
)
from core.model_runtime.entities.model_entities import ModelType
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.models.document import ChildDocument, Document
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, DatasetProcessRule
@ -112,7 +112,7 @@ def create_mock_dataset_document(
document_id: str | None = None,
dataset_id: str | None = None,
tenant_id: str | None = None,
doc_form: str = IndexType.PARAGRAPH_INDEX,
doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
data_source_type: str = "upload_file",
doc_language: str = "English",
) -> Mock:
@ -133,8 +133,8 @@ def create_mock_dataset_document(
Mock: A configured mock DatasetDocument object with all required attributes.
Example:
>>> doc = create_mock_dataset_document(doc_form=IndexType.QA_INDEX)
>>> assert doc.doc_form == IndexType.QA_INDEX
>>> doc = create_mock_dataset_document(doc_form=IndexStructureType.QA_INDEX)
>>> assert doc.doc_form == IndexStructureType.QA_INDEX
"""
doc = Mock(spec=DatasetDocument)
doc.id = document_id or str(uuid.uuid4())
@ -276,7 +276,7 @@ class TestIndexingRunnerExtract:
doc.id = str(uuid.uuid4())
doc.dataset_id = str(uuid.uuid4())
doc.tenant_id = str(uuid.uuid4())
doc.doc_form = IndexType.PARAGRAPH_INDEX
doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
doc.data_source_type = "upload_file"
doc.data_source_info_dict = {"upload_file_id": str(uuid.uuid4())}
return doc
@ -616,7 +616,7 @@ class TestIndexingRunnerLoad:
doc = Mock(spec=DatasetDocument)
doc.id = str(uuid.uuid4())
doc.dataset_id = str(uuid.uuid4())
doc.doc_form = IndexType.PARAGRAPH_INDEX
doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
return doc
@pytest.fixture
@ -700,7 +700,7 @@ class TestIndexingRunnerLoad:
"""Test loading with parent-child index structure."""
# Arrange
runner = IndexingRunner()
sample_dataset_document.doc_form = IndexType.PARENT_CHILD_INDEX
sample_dataset_document.doc_form = IndexStructureType.PARENT_CHILD_INDEX
sample_dataset.indexing_technique = "high_quality"
# Add child documents
@ -775,7 +775,7 @@ class TestIndexingRunnerRun:
doc.id = str(uuid.uuid4())
doc.dataset_id = str(uuid.uuid4())
doc.tenant_id = str(uuid.uuid4())
doc.doc_form = IndexType.PARAGRAPH_INDEX
doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
doc.doc_language = "English"
doc.data_source_type = "upload_file"
doc.data_source_info_dict = {"upload_file_id": str(uuid.uuid4())}
@ -802,6 +802,21 @@ class TestIndexingRunnerRun:
mock_process_rule.to_dict.return_value = {"mode": "automatic", "rules": {}}
mock_dependencies["db"].session.scalar.return_value = mock_process_rule
# Mock current_user (Account) for _transform
mock_current_user = MagicMock()
mock_current_user.set_tenant_id = MagicMock()
# Setup db.session.query to return different results based on the model
def mock_query_side_effect(model):
mock_query_result = MagicMock()
if model.__name__ == "Dataset":
mock_query_result.filter_by.return_value.first.return_value = mock_dataset
elif model.__name__ == "Account":
mock_query_result.filter_by.return_value.first.return_value = mock_current_user
return mock_query_result
mock_dependencies["db"].session.query.side_effect = mock_query_side_effect
# Mock processor
mock_processor = MagicMock()
mock_dependencies["factory"].return_value.init_index_processor.return_value = mock_processor
@ -1268,7 +1283,7 @@ class TestIndexingRunnerLoadSegments:
doc.id = str(uuid.uuid4())
doc.dataset_id = str(uuid.uuid4())
doc.created_by = str(uuid.uuid4())
doc.doc_form = IndexType.PARAGRAPH_INDEX
doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
return doc
@pytest.fixture
@ -1316,7 +1331,7 @@ class TestIndexingRunnerLoadSegments:
"""Test loading segments for parent-child index."""
# Arrange
runner = IndexingRunner()
sample_dataset_document.doc_form = IndexType.PARENT_CHILD_INDEX
sample_dataset_document.doc_form = IndexStructureType.PARENT_CHILD_INDEX
# Add child documents
for doc in sample_documents:
@ -1413,7 +1428,7 @@ class TestIndexingRunnerEstimate:
tenant_id=tenant_id,
extract_settings=extract_settings,
tmp_processing_rule={"mode": "automatic", "rules": {}},
doc_form=IndexType.PARAGRAPH_INDEX,
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)

View File

@ -26,6 +26,18 @@ from core.rag.rerank.rerank_type import RerankMode
from core.rag.rerank.weight_rerank import WeightRerankRunner
def create_mock_model_instance():
    """Return a Mock ModelInstance pre-wired for reranking tests.

    The provider_model_bundle.configuration chain is populated because
    check_model_support_vision traverses it.
    """
    instance = Mock(spec=ModelInstance)
    instance.provider_model_bundle = Mock()
    instance.provider_model_bundle.configuration = Mock()
    instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id"
    instance.provider = "test-provider"
    instance.model = "test-model"
    return instance
class TestRerankModelRunner:
"""Unit tests for RerankModelRunner.
@ -37,10 +49,23 @@ class TestRerankModelRunner:
- Metadata preservation and score injection
"""
@pytest.fixture(autouse=True)
def mock_model_manager(self):
"""Auto-use fixture to patch ModelManager for all tests in this class."""
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
mock_mm.return_value.check_model_support_vision.return_value = False
yield mock_mm
@pytest.fixture
def mock_model_instance(self):
"""Create a mock ModelInstance for reranking."""
mock_instance = Mock(spec=ModelInstance)
# Setup provider_model_bundle chain for check_model_support_vision
mock_instance.provider_model_bundle = Mock()
mock_instance.provider_model_bundle.configuration = Mock()
mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id"
mock_instance.provider = "test-provider"
mock_instance.model = "test-model"
return mock_instance
@pytest.fixture
@ -803,7 +828,7 @@ class TestRerankRunnerFactory:
- Parameters are forwarded to runner constructor
"""
# Arrange: Mock model instance
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
# Act: Create runner via factory
runner = RerankRunnerFactory.create_rerank_runner(
@ -865,7 +890,7 @@ class TestRerankRunnerFactory:
- String values are properly matched
"""
# Arrange: Mock model instance
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
# Act: Create runner using enum value
runner = RerankRunnerFactory.create_rerank_runner(
@ -886,6 +911,13 @@ class TestRerankIntegration:
- Real-world usage scenarios
"""
@pytest.fixture(autouse=True)
def mock_model_manager(self):
"""Auto-use fixture to patch ModelManager for all tests in this class."""
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
mock_mm.return_value.check_model_support_vision.return_value = False
yield mock_mm
def test_model_reranking_full_workflow(self):
"""Test complete model-based reranking workflow.
@ -895,7 +927,7 @@ class TestRerankIntegration:
- Top results are returned correctly
"""
# Arrange: Create mock model and documents
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_rerank_result = RerankResult(
model="bge-reranker-base",
docs=[
@ -951,7 +983,7 @@ class TestRerankIntegration:
- Normalization is consistent
"""
# Arrange: Create mock model with various scores
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_rerank_result = RerankResult(
model="bge-reranker-base",
docs=[
@ -991,6 +1023,13 @@ class TestRerankEdgeCases:
- Concurrent reranking scenarios
"""
@pytest.fixture(autouse=True)
def mock_model_manager(self):
"""Auto-use fixture to patch ModelManager for all tests in this class."""
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
mock_mm.return_value.check_model_support_vision.return_value = False
yield mock_mm
def test_rerank_with_empty_metadata(self):
"""Test reranking when documents have empty metadata.
@ -1000,7 +1039,7 @@ class TestRerankEdgeCases:
- Empty metadata documents are processed correctly
"""
# Arrange: Create documents with empty metadata
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_rerank_result = RerankResult(
model="bge-reranker-base",
docs=[
@ -1046,7 +1085,7 @@ class TestRerankEdgeCases:
- Score comparison logic works at boundary
"""
# Arrange: Create mock with various scores including negatives
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_rerank_result = RerankResult(
model="bge-reranker-base",
docs=[
@ -1082,7 +1121,7 @@ class TestRerankEdgeCases:
- No overflow or precision issues
"""
# Arrange: All documents with perfect scores
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_rerank_result = RerankResult(
model="bge-reranker-base",
docs=[
@ -1117,7 +1156,7 @@ class TestRerankEdgeCases:
- Content encoding is preserved
"""
# Arrange: Documents with special characters
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_rerank_result = RerankResult(
model="bge-reranker-base",
docs=[
@ -1159,7 +1198,7 @@ class TestRerankEdgeCases:
- Content is not truncated unexpectedly
"""
# Arrange: Documents with very long content
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
long_content = "This is a very long document. " * 1000 # ~30,000 characters
mock_rerank_result = RerankResult(
@ -1196,7 +1235,7 @@ class TestRerankEdgeCases:
- All documents are processed correctly
"""
# Arrange: Create 100 documents
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
num_docs = 100
# Create rerank results for all documents
@ -1287,7 +1326,7 @@ class TestRerankEdgeCases:
- Documents can still be ranked
"""
# Arrange: Empty query
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_rerank_result = RerankResult(
model="bge-reranker-base",
docs=[
@ -1325,6 +1364,13 @@ class TestRerankPerformance:
- Score calculation optimization
"""
@pytest.fixture(autouse=True)
def mock_model_manager(self):
"""Auto-use fixture to patch ModelManager for all tests in this class."""
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
mock_mm.return_value.check_model_support_vision.return_value = False
yield mock_mm
def test_rerank_batch_processing(self):
"""Test that documents are processed in a single batch.
@ -1334,7 +1380,7 @@ class TestRerankPerformance:
- Efficient batch processing
"""
# Arrange: Multiple documents
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_rerank_result = RerankResult(
model="bge-reranker-base",
docs=[RerankDocument(index=i, text=f"Doc {i}", score=0.9 - i * 0.1) for i in range(5)],
@ -1435,6 +1481,13 @@ class TestRerankErrorHandling:
- Error propagation
"""
@pytest.fixture(autouse=True)
def mock_model_manager(self):
"""Auto-use fixture to patch ModelManager for all tests in this class."""
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
mock_mm.return_value.check_model_support_vision.return_value = False
yield mock_mm
def test_rerank_model_invocation_error(self):
"""Test handling of model invocation errors.
@ -1444,7 +1497,7 @@ class TestRerankErrorHandling:
- Error context is preserved
"""
# Arrange: Mock model that raises exception
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_model_instance.invoke_rerank.side_effect = RuntimeError("Model invocation failed")
documents = [
@ -1470,7 +1523,7 @@ class TestRerankErrorHandling:
- Invalid results don't corrupt output
"""
# Arrange: Rerank result with invalid index
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_rerank_result = RerankResult(
model="bge-reranker-base",
docs=[

View File

@ -425,15 +425,15 @@ class TestRetrievalService:
# ==================== Vector Search Tests ====================
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_vector_search_basic(self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents):
def test_vector_search_basic(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents):
"""
Test basic vector/semantic search functionality.
This test validates the core vector search flow:
1. Dataset is retrieved from database
2. embedding_search is called via ThreadPoolExecutor
2. _retrieve is called via ThreadPoolExecutor
3. Documents are added to shared all_documents list
4. Results are returned to caller
@ -447,28 +447,28 @@ class TestRetrievalService:
# Set up the mock dataset that will be "retrieved" from database
mock_get_dataset.return_value = mock_dataset
# Create a side effect function that simulates embedding_search behavior
# In the real implementation, embedding_search:
# 1. Gets the dataset
# 2. Creates a Vector instance
# 3. Calls search_by_vector with embeddings
# 4. Extends all_documents with results
def side_effect_embedding_search(
# Create a side effect function that simulates _retrieve behavior
# _retrieve modifies the all_documents list in place
def side_effect_retrieve(
flask_app,
dataset_id,
query,
top_k,
score_threshold,
reranking_model,
all_documents,
retrieval_method,
exceptions,
dataset,
query=None,
top_k=4,
score_threshold=None,
reranking_model=None,
reranking_mode="reranking_model",
weights=None,
document_ids_filter=None,
attachment_id=None,
all_documents=None,
exceptions=None,
):
"""Simulate embedding_search adding documents to the shared list."""
all_documents.extend(sample_documents)
"""Simulate _retrieve adding documents to the shared list."""
if all_documents is not None:
all_documents.extend(sample_documents)
mock_embedding_search.side_effect = side_effect_embedding_search
mock_retrieve.side_effect = side_effect_retrieve
# Define test parameters
query = "What is Python?" # Natural language query
@ -481,7 +481,7 @@ class TestRetrievalService:
# 1. Check if query is empty (early return if so)
# 2. Get the dataset using _get_dataset
# 3. Create ThreadPoolExecutor
# 4. Submit embedding_search task
# 4. Submit _retrieve task
# 5. Wait for completion
# 6. Return all_documents list
results = RetrievalService.retrieve(
@ -502,15 +502,13 @@ class TestRetrievalService:
# Verify documents maintain their scores (highest score first in sample_documents)
assert results[0].metadata["score"] == 0.95, "First document should have highest score from sample_documents"
# Verify embedding_search was called exactly once
# Verify _retrieve was called exactly once
# This confirms the search method was invoked by ThreadPoolExecutor
mock_embedding_search.assert_called_once()
mock_retrieve.assert_called_once()
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_vector_search_with_document_filter(
self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents
):
def test_vector_search_with_document_filter(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents):
"""
Test vector search with document ID filtering.
@ -522,21 +520,25 @@ class TestRetrievalService:
mock_get_dataset.return_value = mock_dataset
filtered_docs = [sample_documents[0]]
def side_effect_embedding_search(
def side_effect_retrieve(
flask_app,
dataset_id,
query,
top_k,
score_threshold,
reranking_model,
all_documents,
retrieval_method,
exceptions,
dataset,
query=None,
top_k=4,
score_threshold=None,
reranking_model=None,
reranking_mode="reranking_model",
weights=None,
document_ids_filter=None,
attachment_id=None,
all_documents=None,
exceptions=None,
):
all_documents.extend(filtered_docs)
if all_documents is not None:
all_documents.extend(filtered_docs)
mock_embedding_search.side_effect = side_effect_embedding_search
mock_retrieve.side_effect = side_effect_retrieve
document_ids_filter = [sample_documents[0].metadata["document_id"]]
# Act
@ -552,12 +554,12 @@ class TestRetrievalService:
assert len(results) == 1
assert results[0].metadata["doc_id"] == "doc1"
# Verify document_ids_filter was passed
call_kwargs = mock_embedding_search.call_args.kwargs
call_kwargs = mock_retrieve.call_args.kwargs
assert call_kwargs["document_ids_filter"] == document_ids_filter
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_vector_search_empty_results(self, mock_get_dataset, mock_embedding_search, mock_dataset):
def test_vector_search_empty_results(self, mock_get_dataset, mock_retrieve, mock_dataset):
"""
Test vector search when no results match the query.
@ -567,8 +569,8 @@ class TestRetrievalService:
"""
# Arrange
mock_get_dataset.return_value = mock_dataset
# embedding_search doesn't add anything to all_documents
mock_embedding_search.side_effect = lambda *args, **kwargs: None
# _retrieve doesn't add anything to all_documents
mock_retrieve.side_effect = lambda *args, **kwargs: None
# Act
results = RetrievalService.retrieve(
@ -583,9 +585,9 @@ class TestRetrievalService:
# ==================== Keyword Search Tests ====================
@patch("core.rag.datasource.retrieval_service.RetrievalService.keyword_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_keyword_search_basic(self, mock_get_dataset, mock_keyword_search, mock_dataset, sample_documents):
def test_keyword_search_basic(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents):
"""
Test basic keyword search functionality.
@ -597,12 +599,25 @@ class TestRetrievalService:
# Arrange
mock_get_dataset.return_value = mock_dataset
def side_effect_keyword_search(
flask_app, dataset_id, query, top_k, all_documents, exceptions, document_ids_filter=None
def side_effect_retrieve(
flask_app,
retrieval_method,
dataset,
query=None,
top_k=4,
score_threshold=None,
reranking_model=None,
reranking_mode="reranking_model",
weights=None,
document_ids_filter=None,
attachment_id=None,
all_documents=None,
exceptions=None,
):
all_documents.extend(sample_documents)
if all_documents is not None:
all_documents.extend(sample_documents)
mock_keyword_search.side_effect = side_effect_keyword_search
mock_retrieve.side_effect = side_effect_retrieve
query = "Python programming"
top_k = 3
@ -618,7 +633,7 @@ class TestRetrievalService:
# Assert
assert len(results) == 3
assert all(isinstance(doc, Document) for doc in results)
mock_keyword_search.assert_called_once()
mock_retrieve.assert_called_once()
@patch("core.rag.datasource.retrieval_service.RetrievalService.keyword_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
@ -1147,11 +1162,9 @@ class TestRetrievalService:
# ==================== Metadata Filtering Tests ====================
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_vector_search_with_metadata_filter(
self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents
):
def test_vector_search_with_metadata_filter(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents):
"""
Test vector search with metadata-based document filtering.
@ -1166,21 +1179,25 @@ class TestRetrievalService:
filtered_doc = sample_documents[0]
filtered_doc.metadata["category"] = "programming"
def side_effect_embedding(
def side_effect_retrieve(
flask_app,
dataset_id,
query,
top_k,
score_threshold,
reranking_model,
all_documents,
retrieval_method,
exceptions,
dataset,
query=None,
top_k=4,
score_threshold=None,
reranking_model=None,
reranking_mode="reranking_model",
weights=None,
document_ids_filter=None,
attachment_id=None,
all_documents=None,
exceptions=None,
):
all_documents.append(filtered_doc)
if all_documents is not None:
all_documents.append(filtered_doc)
mock_embedding_search.side_effect = side_effect_embedding
mock_retrieve.side_effect = side_effect_retrieve
# Act
results = RetrievalService.retrieve(
@ -1243,9 +1260,9 @@ class TestRetrievalService:
# Assert
assert results == []
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_retrieve_with_exception_handling(self, mock_get_dataset, mock_embedding_search, mock_dataset):
def test_retrieve_with_exception_handling(self, mock_get_dataset, mock_retrieve, mock_dataset):
"""
Test that exceptions during retrieval are properly handled.
@ -1256,22 +1273,26 @@ class TestRetrievalService:
# Arrange
mock_get_dataset.return_value = mock_dataset
# Make embedding_search add an exception to the exceptions list
# Make _retrieve add an exception to the exceptions list
def side_effect_with_exception(
flask_app,
dataset_id,
query,
top_k,
score_threshold,
reranking_model,
all_documents,
retrieval_method,
exceptions,
dataset,
query=None,
top_k=4,
score_threshold=None,
reranking_model=None,
reranking_mode="reranking_model",
weights=None,
document_ids_filter=None,
attachment_id=None,
all_documents=None,
exceptions=None,
):
exceptions.append("Search failed")
if exceptions is not None:
exceptions.append("Search failed")
mock_embedding_search.side_effect = side_effect_with_exception
mock_retrieve.side_effect = side_effect_with_exception
# Act & Assert
with pytest.raises(ValueError) as exc_info:
@ -1286,9 +1307,9 @@ class TestRetrievalService:
# ==================== Score Threshold Tests ====================
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_vector_search_with_score_threshold(self, mock_get_dataset, mock_embedding_search, mock_dataset):
def test_vector_search_with_score_threshold(self, mock_get_dataset, mock_retrieve, mock_dataset):
"""
Test vector search with score threshold filtering.
@ -1306,21 +1327,25 @@ class TestRetrievalService:
provider="dify",
)
def side_effect_embedding(
def side_effect_retrieve(
flask_app,
dataset_id,
query,
top_k,
score_threshold,
reranking_model,
all_documents,
retrieval_method,
exceptions,
dataset,
query=None,
top_k=4,
score_threshold=None,
reranking_model=None,
reranking_mode="reranking_model",
weights=None,
document_ids_filter=None,
attachment_id=None,
all_documents=None,
exceptions=None,
):
all_documents.append(high_score_doc)
if all_documents is not None:
all_documents.append(high_score_doc)
mock_embedding_search.side_effect = side_effect_embedding
mock_retrieve.side_effect = side_effect_retrieve
score_threshold = 0.8
@ -1339,9 +1364,9 @@ class TestRetrievalService:
# ==================== Top-K Limiting Tests ====================
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_retrieve_respects_top_k_limit(self, mock_get_dataset, mock_embedding_search, mock_dataset):
def test_retrieve_respects_top_k_limit(self, mock_get_dataset, mock_retrieve, mock_dataset):
"""
Test that retrieval respects top_k parameter.
@ -1362,22 +1387,26 @@ class TestRetrievalService:
for i in range(10)
]
def side_effect_embedding(
def side_effect_retrieve(
flask_app,
dataset_id,
query,
top_k,
score_threshold,
reranking_model,
all_documents,
retrieval_method,
exceptions,
dataset,
query=None,
top_k=4,
score_threshold=None,
reranking_model=None,
reranking_mode="reranking_model",
weights=None,
document_ids_filter=None,
attachment_id=None,
all_documents=None,
exceptions=None,
):
# Return only top_k documents
all_documents.extend(many_docs[:top_k])
if all_documents is not None:
all_documents.extend(many_docs[:top_k])
mock_embedding_search.side_effect = side_effect_embedding
mock_retrieve.side_effect = side_effect_retrieve
top_k = 3
@ -1390,9 +1419,9 @@ class TestRetrievalService:
)
# Assert
# Verify top_k was passed to embedding_search
assert mock_embedding_search.called
call_kwargs = mock_embedding_search.call_args.kwargs
# Verify _retrieve was called
assert mock_retrieve.called
call_kwargs = mock_retrieve.call_args.kwargs
assert call_kwargs["top_k"] == top_k
# Verify we got the right number of results
assert len(results) == top_k
@ -1421,11 +1450,9 @@ class TestRetrievalService:
# ==================== Reranking Tests ====================
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_semantic_search_with_reranking(
self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents
):
def test_semantic_search_with_reranking(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents):
"""
Test semantic search with reranking model.
@ -1439,22 +1466,26 @@ class TestRetrievalService:
# Simulate reranking changing order
reranked_docs = list(reversed(sample_documents))
def side_effect_embedding(
def side_effect_retrieve(
flask_app,
dataset_id,
query,
top_k,
score_threshold,
reranking_model,
all_documents,
retrieval_method,
exceptions,
dataset,
query=None,
top_k=4,
score_threshold=None,
reranking_model=None,
reranking_mode="reranking_model",
weights=None,
document_ids_filter=None,
attachment_id=None,
all_documents=None,
exceptions=None,
):
# embedding_search handles reranking internally
all_documents.extend(reranked_docs)
# _retrieve handles reranking internally
if all_documents is not None:
all_documents.extend(reranked_docs)
mock_embedding_search.side_effect = side_effect_embedding
mock_retrieve.side_effect = side_effect_retrieve
reranking_model = {
"reranking_provider_name": "cohere",
@ -1473,7 +1504,7 @@ class TestRetrievalService:
# Assert
# For semantic search with reranking, reranking_model should be passed
assert len(results) == 3
call_kwargs = mock_embedding_search.call_args.kwargs
call_kwargs = mock_retrieve.call_args.kwargs
assert call_kwargs["reranking_model"] == reranking_model

View File

@ -0,0 +1,86 @@
import pytest
import core.tools.utils.message_transformer as mt
from core.tools.entities.tool_entities import ToolInvokeMessage
class _FakeToolFile:
def __init__(self, mimetype: str):
self.id = "fake-tool-file-id"
self.mimetype = mimetype
class _FakeToolFileManager:
    """Fake ToolFileManager that records the arguments of its last create call."""

    # Class-level so tests can inspect the capture without holding the instance
    # the transformer constructed internally.
    last_call: dict | None = None

    def __init__(self, *args, **kwargs):
        # Accept and discard whatever the real manager's constructor takes.
        pass

    def create_file_by_raw(
        self,
        *,
        user_id: str,
        tenant_id: str,
        conversation_id: str | None,
        file_binary: bytes,
        mimetype: str,
        filename: str | None = None,
    ):
        captured = dict(
            user_id=user_id,
            tenant_id=tenant_id,
            conversation_id=conversation_id,
            file_binary=file_binary,
            mimetype=mimetype,
            filename=filename,
        )
        type(self).last_call = captured
        return _FakeToolFile(mimetype)
@pytest.fixture(autouse=True)
def _patch_tool_file_manager(monkeypatch):
    """Swap the transformer's ToolFileManager for the recording fake in every test."""
    # Patch the name inside the transformer module (mt) — the lookup site —
    # so the transformer resolves the fake at call time.
    monkeypatch.setattr(mt, "ToolFileManager", _FakeToolFileManager)
    # also ensure predictable URL generation (no need to patch; uses id and extension only)
    yield
    # Reset the class-level capture so state never leaks between tests.
    _FakeToolFileManager.last_call = None
def _gen(messages):
yield from messages
def test_transform_tool_invoke_messages_mimetype_key_present_but_none():
    """A BLOB message whose meta has mime_type=None must default to application/octet-stream."""
    # Arrange: a BLOB message whose meta contains a mime_type key set to None
    blob = b"hello"
    msg = ToolInvokeMessage(
        type=ToolInvokeMessage.MessageType.BLOB,
        message=ToolInvokeMessage.BlobMessage(blob=blob),
        meta={"mime_type": None, "filename": "greeting"},
    )

    # Act: run the streaming transformer over a single-message generator.
    out = list(
        mt.ToolFileMessageTransformer.transform_tool_invoke_messages(
            messages=_gen([msg]),
            user_id="u1",
            tenant_id="t1",
            conversation_id="c1",
        )
    )

    # Assert: default to application/octet-stream when mime_type is present but None
    assert _FakeToolFileManager.last_call is not None
    assert _FakeToolFileManager.last_call["mimetype"] == "application/octet-stream"

    # Should yield a BINARY_LINK (not IMAGE_LINK) and the URL ends with .bin
    assert len(out) == 1
    o = out[0]
    assert o.type == ToolInvokeMessage.MessageType.BINARY_LINK
    assert isinstance(o.message, ToolInvokeMessage.TextMessage)
    assert o.message.text.endswith(".bin")

    # meta is preserved (still contains mime_type: None)
    assert "mime_type" in (o.meta or {})
    assert o.meta["mime_type"] is None

View File

@ -0,0 +1,60 @@
"""
Test case for end node without value_type field (backward compatibility).
This test validates that end nodes work correctly even when the value_type
field is missing from the output configuration, ensuring backward compatibility
with older workflow definitions.
"""
from core.workflow.graph_events import (
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from .test_table_runner import TableTestRunner, WorkflowTestCase
def test_end_node_without_value_type_field():
    """
    Verify backward compatibility of end nodes lacking an explicit value_type.

    The fixture workflow forwards the `query` input from the start node
    straight to the end node, whose output config omits value_type. Workflow
    definitions created before value_type became a required field must still
    work: the end node should infer the type and emit the value unchanged.
    """
    expected_output = {"query": "test query"}
    expected_events = [
        # Graph start
        GraphRunStartedEvent,
        # Start node (streams the input value before succeeding)
        NodeRunStartedEvent,
        NodeRunStreamChunkEvent,
        NodeRunSucceededEvent,
        # End node
        NodeRunStartedEvent,
        NodeRunSucceededEvent,
        # Graph end
        GraphRunSucceededEvent,
    ]

    case = WorkflowTestCase(
        fixture_path="end_node_without_value_type_field_workflow",
        inputs={"query": "test query"},
        expected_outputs=expected_output,
        expected_event_sequence=expected_events,
        description="End node without value_type field should work correctly",
    )

    result = TableTestRunner().run_test_case(case)

    assert result.success, f"Test failed: {result.error}"
    assert result.actual_outputs == expected_output, (
        f"Expected output to be {{'query': 'test query'}}, got {result.actual_outputs}"
    )

View File

@ -1,3 +1,4 @@
import json
from unittest.mock import Mock, PropertyMock, patch
import httpx
@ -138,3 +139,95 @@ def test_is_file_with_no_content_disposition(mock_response):
type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512))
response = Response(mock_response)
assert response.is_file
# UTF-8 Encoding Tests
@pytest.mark.parametrize(
    ("content_bytes", "expected_text", "description"),
    [
        # Chinese UTF-8 bytes
        (
            b'{"message": "\xe4\xbd\xa0\xe5\xa5\xbd\xe4\xb8\x96\xe7\x95\x8c"}',
            '{"message": "你好世界"}',
            "Chinese characters UTF-8",
        ),
        # Japanese UTF-8 bytes
        (
            b'{"message": "\xe3\x81\x93\xe3\x82\x93\xe3\x81\xab\xe3\x81\xa1\xe3\x81\xaf"}',
            '{"message": "こんにちは"}',
            "Japanese characters UTF-8",
        ),
        # Korean UTF-8 bytes
        (
            b'{"message": "\xec\x95\x88\xeb\x85\x95\xed\x95\x98\xec\x84\xb8\xec\x9a\x94"}',
            '{"message": "안녕하세요"}',
            "Korean characters UTF-8",
        ),
        # Arabic UTF-8
        (b'{"text": "\xd9\x85\xd8\xb1\xd8\xad\xd8\xa8\xd8\xa7"}', '{"text": "مرحبا"}', "Arabic characters UTF-8"),
        # European characters UTF-8
        (b'{"text": "Caf\xc3\xa9 M\xc3\xbcnchen"}', '{"text": "Café München"}', "European accented characters"),
        # Simple ASCII
        (b'{"text": "Hello World"}', '{"text": "Hello World"}', "Simple ASCII text"),
    ],
)
def test_text_property_utf8_decoding(mock_response, content_bytes, expected_text, description):
    """Test that Response.text properly decodes UTF-8 content with charset_normalizer"""
    mock_response.headers = {"content-type": "application/json; charset=utf-8"}
    # ``content`` is a property on the mocked response type, so patch it on the type.
    type(mock_response).content = PropertyMock(return_value=content_bytes)

    # Mock httpx response.text to return something different (simulating potential encoding issues)
    mock_response.text = "incorrect-fallback-text"  # To ensure we are not falling back to httpx's text property

    response = Response(mock_response)

    # Our enhanced text property should decode properly using charset_normalizer
    assert response.text == expected_text, (
        f"Failed for {description}: got {repr(response.text)}, expected {repr(expected_text)}"
    )
def test_text_property_fallback_to_httpx(mock_response):
    """Test that Response.text falls back to httpx.text when charset_normalizer fails"""
    mock_response.headers = {"content-type": "application/json"}
    # Create malformed UTF-8 bytes — 0xFF/0xFE sequences are never valid UTF-8,
    # which should defeat charset detection and trigger the fallback path.
    malformed_bytes = b'{"text": "\xff\xfe\x00\x00 invalid"}'
    type(mock_response).content = PropertyMock(return_value=malformed_bytes)

    # Mock httpx.text to return some fallback value
    fallback_text = '{"text": "fallback"}'
    mock_response.text = fallback_text

    response = Response(mock_response)

    # Should fall back to httpx's text when charset_normalizer fails
    assert response.text == fallback_text
@pytest.mark.parametrize(
    ("json_content", "description"),
    [
        # JSON with escaped Unicode (like Flask jsonify())
        ('{"message": "\\u4f60\\u597d\\u4e16\\u754c"}', "JSON with escaped Unicode"),
        # JSON with mixed escape sequences and UTF-8
        ('{"mixed": "Hello \\u4f60\\u597d"}', "Mixed escaped and regular text"),
        # JSON with complex escape sequences
        ('{"complex": "\\ud83d\\ude00\\u4f60\\u597d"}', "Emoji and Chinese escapes"),
    ],
)
def test_text_property_with_escaped_unicode(mock_response, json_content, description):
    """Test Response.text with JSON containing Unicode escape sequences"""
    mock_response.headers = {"content-type": "application/json"}
    content_bytes = json_content.encode("utf-8")
    type(mock_response).content = PropertyMock(return_value=content_bytes)
    mock_response.text = json_content  # httpx would return the same for valid UTF-8

    response = Response(mock_response)

    # Should preserve the escape sequences (valid JSON) — the backslash escapes
    # are ASCII and must not be expanded into raw Unicode by the decoder.
    assert response.text == json_content, f"Failed for {description}"

    # The text should be valid JSON that can be parsed back to proper Unicode
    parsed = json.loads(response.text)
    assert isinstance(parsed, dict), f"Invalid JSON for {description}"

View File

@ -1149,3 +1149,258 @@ class TestModelIntegration:
# Assert
assert site.app_id == app.id
assert app.enable_site is True
class TestConversationStatusCount:
    """Test suite for Conversation.status_count property N+1 query fix.

    The property is expected to issue exactly two queries: one for the
    conversation's messages and one batch query for their workflow runs.
    All DB access is intercepted by patching ``models.model.db.session.scalars``.
    """

    def test_status_count_no_messages(self):
        """Test status_count returns None when conversation has no messages."""
        # Arrange
        conversation = Conversation(
            app_id=str(uuid4()),
            mode=AppMode.CHAT,
            name="Test Conversation",
            status="normal",
            from_source="api",
        )
        conversation.id = str(uuid4())

        # Mock the database query to return no messages
        with patch("models.model.db.session.scalars") as mock_scalars:
            mock_scalars.return_value.all.return_value = []

            # Act
            result = conversation.status_count

            # Assert
            assert result is None

    def test_status_count_messages_without_workflow_runs(self):
        """Test status_count when messages have no workflow_run_id."""
        # Arrange
        app_id = str(uuid4())
        conversation_id = str(uuid4())

        conversation = Conversation(
            app_id=app_id,
            mode=AppMode.CHAT,
            name="Test Conversation",
            status="normal",
            from_source="api",
        )
        conversation.id = conversation_id

        # Mock the database query to return no messages with workflow_run_id
        with patch("models.model.db.session.scalars") as mock_scalars:
            mock_scalars.return_value.all.return_value = []

            # Act
            result = conversation.status_count

            # Assert
            assert result is None

    def test_status_count_batch_loading_implementation(self):
        """Test that status_count uses batch loading instead of N+1 queries."""
        # Arrange
        from core.workflow.enums import WorkflowExecutionStatus

        app_id = str(uuid4())
        conversation_id = str(uuid4())

        # Create workflow run IDs
        workflow_run_id_1 = str(uuid4())
        workflow_run_id_2 = str(uuid4())
        workflow_run_id_3 = str(uuid4())

        conversation = Conversation(
            app_id=app_id,
            mode=AppMode.CHAT,
            name="Test Conversation",
            status="normal",
            from_source="api",
        )
        conversation.id = conversation_id

        # Mock messages with workflow_run_id
        mock_messages = [
            MagicMock(
                conversation_id=conversation_id,
                workflow_run_id=workflow_run_id_1,
            ),
            MagicMock(
                conversation_id=conversation_id,
                workflow_run_id=workflow_run_id_2,
            ),
            MagicMock(
                conversation_id=conversation_id,
                workflow_run_id=workflow_run_id_3,
            ),
        ]

        # Mock workflow runs with different statuses
        mock_workflow_runs = [
            MagicMock(
                id=workflow_run_id_1,
                status=WorkflowExecutionStatus.SUCCEEDED.value,
                app_id=app_id,
            ),
            MagicMock(
                id=workflow_run_id_2,
                status=WorkflowExecutionStatus.FAILED.value,
                app_id=app_id,
            ),
            MagicMock(
                id=workflow_run_id_3,
                status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED.value,
                app_id=app_id,
            ),
        ]

        # Track database calls
        calls_made = []

        # NOTE(review): routing is done by substring-matching the rendered SQL,
        # which is brittle against query refactors but adequate for this test.
        def mock_scalars(query):
            calls_made.append(str(query))
            mock_result = MagicMock()
            # Return messages for the first query (messages with workflow_run_id)
            if "messages" in str(query) and "conversation_id" in str(query):
                mock_result.all.return_value = mock_messages
            # Return workflow runs for the batch query
            elif "workflow_runs" in str(query):
                mock_result.all.return_value = mock_workflow_runs
            else:
                mock_result.all.return_value = []
            return mock_result

        # Act & Assert
        with patch("models.model.db.session.scalars", side_effect=mock_scalars):
            result = conversation.status_count

            # Verify only 2 database queries were made (not N+1)
            assert len(calls_made) == 2, f"Expected 2 queries, got {len(calls_made)}: {calls_made}"

            # Verify the first query gets messages
            assert "messages" in calls_made[0]
            assert "conversation_id" in calls_made[0]

            # Verify the second query batch loads workflow runs with proper filtering
            assert "workflow_runs" in calls_made[1]
            assert "app_id" in calls_made[1]  # Security filter applied
            assert "IN" in calls_made[1]  # Batch loading with IN clause

            # Verify correct status counts
            assert result["success"] == 1  # One SUCCEEDED
            assert result["failed"] == 1  # One FAILED
            assert result["partial_success"] == 1  # One PARTIAL_SUCCEEDED

    def test_status_count_app_id_filtering(self):
        """Test that status_count filters workflow runs by app_id for security."""
        # Arrange
        app_id = str(uuid4())
        # NOTE(review): other_app_id is never used below — the "wrong app" run is
        # simulated by returning an empty workflow_runs result instead.
        other_app_id = str(uuid4())
        conversation_id = str(uuid4())
        workflow_run_id = str(uuid4())

        conversation = Conversation(
            app_id=app_id,
            mode=AppMode.CHAT,
            name="Test Conversation",
            status="normal",
            from_source="api",
        )
        conversation.id = conversation_id

        # Mock message with workflow_run_id
        mock_messages = [
            MagicMock(
                conversation_id=conversation_id,
                workflow_run_id=workflow_run_id,
            ),
        ]

        calls_made = []

        def mock_scalars(query):
            calls_made.append(str(query))
            mock_result = MagicMock()
            if "messages" in str(query):
                mock_result.all.return_value = mock_messages
            elif "workflow_runs" in str(query):
                # Return empty list because no workflow run matches the correct app_id
                mock_result.all.return_value = []  # Workflow run filtered out by app_id
            else:
                mock_result.all.return_value = []
            return mock_result

        # Act
        with patch("models.model.db.session.scalars", side_effect=mock_scalars):
            result = conversation.status_count

            # Assert - query should include app_id filter
            workflow_query = calls_made[1]
            assert "app_id" in workflow_query

            # Since workflow run has wrong app_id, it shouldn't be included in counts
            assert result["success"] == 0
            assert result["failed"] == 0
            assert result["partial_success"] == 0

    def test_status_count_handles_invalid_workflow_status(self):
        """Test that status_count gracefully handles invalid workflow status values."""
        # Arrange
        app_id = str(uuid4())
        conversation_id = str(uuid4())
        workflow_run_id = str(uuid4())

        conversation = Conversation(
            app_id=app_id,
            mode=AppMode.CHAT,
            name="Test Conversation",
            status="normal",
            from_source="api",
        )
        conversation.id = conversation_id

        mock_messages = [
            MagicMock(
                conversation_id=conversation_id,
                workflow_run_id=workflow_run_id,
            ),
        ]

        # Mock workflow run with invalid status
        mock_workflow_runs = [
            MagicMock(
                id=workflow_run_id,
                status="invalid_status",  # Invalid status that should raise ValueError
                app_id=app_id,
            ),
        ]

        with patch("models.model.db.session.scalars") as mock_scalars:
            # Mock the messages query
            def mock_scalars_side_effect(query):
                mock_result = MagicMock()
                if "messages" in str(query):
                    mock_result.all.return_value = mock_messages
                elif "workflow_runs" in str(query):
                    mock_result.all.return_value = mock_workflow_runs
                else:
                    mock_result.all.return_value = []
                return mock_result

            mock_scalars.side_effect = mock_scalars_side_effect

            # Act - should not raise exception
            result = conversation.status_count

            # Assert - should handle invalid status gracefully (counted in no bucket)
            assert result["success"] == 0
            assert result["failed"] == 0
            assert result["partial_success"] == 0

View File

@ -14,7 +14,9 @@ def get_example_bucket() -> str:
def get_opendal_bucket() -> str:
    """Return the filesystem root used by the OpenDAL ``fs`` backend in tests.

    Honours OPENDAL_FS_ROOT (set by conftest) and defaults to /tmp/dify-storage
    so test artifacts never pollute the workspace. This resolves the merged-diff
    residue that left the removed ``return "./dify"`` line alongside the new body.
    """
    # Local import keeps this helper self-contained, matching the merged version.
    import os

    return os.environ.get("OPENDAL_FS_ROOT", "/tmp/dify-storage")
def get_example_filename() -> str:

View File

@ -21,20 +21,16 @@ class TestOpenDAL:
)
@pytest.fixture(scope="class", autouse=True)
def teardown_class(self, request):
def teardown_class(self):
"""Clean up after all tests in the class."""
def cleanup():
folder = Path(get_opendal_bucket())
if folder.exists() and folder.is_dir():
for item in folder.iterdir():
if item.is_file():
item.unlink()
elif item.is_dir():
item.rmdir()
folder.rmdir()
yield
return cleanup()
folder = Path(get_opendal_bucket())
if folder.exists() and folder.is_dir():
import shutil
shutil.rmtree(folder, ignore_errors=True)
def test_save_and_exists(self):
"""Test saving data and checking existence."""

View File

@ -117,7 +117,7 @@ import pytest
from core.entities.document_task import DocumentTask
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from services.document_indexing_task_proxy import DocumentIndexingTaskProxy
from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy
# ============================================================================
# Test Data Factory
@ -370,7 +370,7 @@ class TestDocumentIndexingTaskProxy:
# Features Property Tests
# ========================================================================
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
def test_features_property(self, mock_feature_service):
"""
Test cached_property features.
@ -400,7 +400,7 @@ class TestDocumentIndexingTaskProxy:
mock_feature_service.get_features.assert_called_once_with("tenant-123")
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
def test_features_property_with_different_tenants(self, mock_feature_service):
"""
Test features property with different tenant IDs.
@ -438,7 +438,7 @@ class TestDocumentIndexingTaskProxy:
# Direct Queue Routing Tests
# ========================================================================
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_direct_queue(self, mock_task):
"""
Test _send_to_direct_queue method.
@ -460,7 +460,7 @@ class TestDocumentIndexingTaskProxy:
# Assert
mock_task.delay.assert_called_once_with(tenant_id=tenant_id, dataset_id=dataset_id, document_ids=document_ids)
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task")
def test_send_to_direct_queue_with_priority_task(self, mock_task):
"""
Test _send_to_direct_queue with priority task function.
@ -481,7 +481,7 @@ class TestDocumentIndexingTaskProxy:
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
)
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_direct_queue_with_single_document(self, mock_task):
"""
Test _send_to_direct_queue with single document ID.
@ -502,7 +502,7 @@ class TestDocumentIndexingTaskProxy:
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1"]
)
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_direct_queue_with_empty_documents(self, mock_task):
"""
Test _send_to_direct_queue with empty document_ids list.
@ -525,7 +525,7 @@ class TestDocumentIndexingTaskProxy:
# Tenant Queue Routing Tests
# ========================================================================
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
"""
Test _send_to_tenant_queue when task key exists.
@ -564,7 +564,7 @@ class TestDocumentIndexingTaskProxy:
mock_task.delay.assert_not_called()
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_tenant_queue_without_task_key(self, mock_task):
"""
Test _send_to_tenant_queue when no task key exists.
@ -594,7 +594,7 @@ class TestDocumentIndexingTaskProxy:
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task")
def test_send_to_tenant_queue_with_priority_task(self, mock_task):
"""
Test _send_to_tenant_queue with priority task function.
@ -621,7 +621,7 @@ class TestDocumentIndexingTaskProxy:
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
)
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_tenant_queue_document_task_serialization(self, mock_task):
"""
Test DocumentTask serialization in _send_to_tenant_queue.
@ -659,7 +659,7 @@ class TestDocumentIndexingTaskProxy:
# Queue Type Selection Tests
# ========================================================================
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_default_tenant_queue(self, mock_task):
"""
Test _send_to_default_tenant_queue method.
@ -678,7 +678,7 @@ class TestDocumentIndexingTaskProxy:
# Assert
proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task")
def test_send_to_priority_tenant_queue(self, mock_task):
"""
Test _send_to_priority_tenant_queue method.
@ -697,7 +697,7 @@ class TestDocumentIndexingTaskProxy:
# Assert
proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task")
def test_send_to_priority_direct_queue(self, mock_task):
"""
Test _send_to_priority_direct_queue method.
@ -720,7 +720,7 @@ class TestDocumentIndexingTaskProxy:
# Dispatch Logic Tests
# ========================================================================
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service):
"""
Test _dispatch method when billing is enabled with SANDBOX plan.
@ -745,7 +745,7 @@ class TestDocumentIndexingTaskProxy:
# Assert
proxy._send_to_default_tenant_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
def test_dispatch_with_billing_enabled_team_plan(self, mock_feature_service):
"""
Test _dispatch method when billing is enabled with TEAM plan.
@ -770,7 +770,7 @@ class TestDocumentIndexingTaskProxy:
# Assert
proxy._send_to_priority_tenant_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
def test_dispatch_with_billing_enabled_professional_plan(self, mock_feature_service):
"""
Test _dispatch method when billing is enabled with PROFESSIONAL plan.
@ -795,7 +795,7 @@ class TestDocumentIndexingTaskProxy:
# Assert
proxy._send_to_priority_tenant_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
def test_dispatch_with_billing_disabled(self, mock_feature_service):
"""
Test _dispatch method when billing is disabled.
@ -818,7 +818,7 @@ class TestDocumentIndexingTaskProxy:
# Assert
proxy._send_to_priority_direct_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
def test_dispatch_edge_case_empty_plan(self, mock_feature_service):
"""
Test _dispatch method with empty plan string.
@ -842,7 +842,7 @@ class TestDocumentIndexingTaskProxy:
# Assert
proxy._send_to_priority_tenant_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
def test_dispatch_edge_case_none_plan(self, mock_feature_service):
"""
Test _dispatch method with None plan.
@ -870,7 +870,7 @@ class TestDocumentIndexingTaskProxy:
# Delay Method Tests
# ========================================================================
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
def test_delay_method(self, mock_feature_service):
"""
Test delay method integration.
@ -895,7 +895,7 @@ class TestDocumentIndexingTaskProxy:
# Assert
proxy._send_to_default_tenant_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
def test_delay_method_with_team_plan(self, mock_feature_service):
"""
Test delay method with TEAM plan.
@ -920,7 +920,7 @@ class TestDocumentIndexingTaskProxy:
# Assert
proxy._send_to_priority_tenant_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
def test_delay_method_with_billing_disabled(self, mock_feature_service):
"""
Test delay method with billing disabled.
@ -1021,7 +1021,7 @@ class TestDocumentIndexingTaskProxy:
# Batch Operations Tests
# ========================================================================
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
def test_batch_operation_with_multiple_documents(self, mock_task):
"""
Test batch operation with multiple documents.
@ -1044,7 +1044,7 @@ class TestDocumentIndexingTaskProxy:
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=document_ids
)
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
def test_batch_operation_with_large_batch(self, mock_task):
"""
Test batch operation with large batch of documents.
@ -1073,7 +1073,7 @@ class TestDocumentIndexingTaskProxy:
# Error Handling Tests
# ========================================================================
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_direct_queue_task_delay_failure(self, mock_task):
"""
Test _send_to_direct_queue when task.delay() raises an exception.
@ -1090,7 +1090,7 @@ class TestDocumentIndexingTaskProxy:
with pytest.raises(Exception, match="Task delay failed"):
proxy._send_to_direct_queue(mock_task)
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_tenant_queue_push_tasks_failure(self, mock_task):
"""
Test _send_to_tenant_queue when push_tasks raises an exception.
@ -1111,7 +1111,7 @@ class TestDocumentIndexingTaskProxy:
with pytest.raises(Exception, match="Push tasks failed"):
proxy._send_to_tenant_queue(mock_task)
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_tenant_queue_set_waiting_time_failure(self, mock_task):
"""
Test _send_to_tenant_queue when set_task_waiting_time raises an exception.
@ -1132,7 +1132,7 @@ class TestDocumentIndexingTaskProxy:
with pytest.raises(Exception, match="Set waiting time failed"):
proxy._send_to_tenant_queue(mock_task)
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
def test_dispatch_feature_service_failure(self, mock_feature_service):
"""
Test _dispatch when FeatureService.get_features raises an exception.
@ -1153,8 +1153,8 @@ class TestDocumentIndexingTaskProxy:
# Integration Tests
# ========================================================================
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
def test_full_flow_sandbox_plan(self, mock_task, mock_feature_service):
"""
Test full flow for SANDBOX plan with tenant queue.
@ -1187,8 +1187,8 @@ class TestDocumentIndexingTaskProxy:
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
)
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task")
def test_full_flow_team_plan(self, mock_task, mock_feature_service):
"""
Test full flow for TEAM plan with priority tenant queue.
@ -1221,8 +1221,8 @@ class TestDocumentIndexingTaskProxy:
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
)
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task")
def test_full_flow_billing_disabled(self, mock_task, mock_feature_service):
"""
Test full flow for billing disabled (self-hosted/enterprise).

View File

@ -3,7 +3,7 @@ from unittest.mock import Mock, patch
from core.entities.document_task import DocumentTask
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from services.document_indexing_task_proxy import DocumentIndexingTaskProxy
from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy
class DocumentIndexingTaskProxyTestDataFactory:
@ -59,7 +59,7 @@ class TestDocumentIndexingTaskProxy:
assert proxy._tenant_isolated_task_queue._tenant_id == tenant_id
assert proxy._tenant_isolated_task_queue._unique_key == "document_indexing"
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.base.FeatureService")
def test_features_property(self, mock_feature_service):
"""Test cached_property features."""
# Arrange
@ -77,7 +77,7 @@ class TestDocumentIndexingTaskProxy:
assert features1 is features2 # Should be the same instance due to caching
mock_feature_service.get_features.assert_called_once_with("tenant-123")
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_direct_queue(self, mock_task):
"""Test _send_to_direct_queue method."""
# Arrange
@ -92,7 +92,7 @@ class TestDocumentIndexingTaskProxy:
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
)
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
"""Test _send_to_tenant_queue when task key exists."""
# Arrange
@ -115,7 +115,7 @@ class TestDocumentIndexingTaskProxy:
assert pushed_tasks[0]["document_ids"] == ["doc-1", "doc-2", "doc-3"]
mock_task.delay.assert_not_called()
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_tenant_queue_without_task_key(self, mock_task):
"""Test _send_to_tenant_queue when no task key exists."""
# Arrange
@ -135,8 +135,7 @@ class TestDocumentIndexingTaskProxy:
)
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_default_tenant_queue(self, mock_task):
def test_send_to_default_tenant_queue(self):
"""Test _send_to_default_tenant_queue method."""
# Arrange
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
@ -146,10 +145,9 @@ class TestDocumentIndexingTaskProxy:
proxy._send_to_default_tenant_queue()
# Assert
proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
proxy._send_to_tenant_queue.assert_called_once_with(proxy.NORMAL_TASK_FUNC)
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
def test_send_to_priority_tenant_queue(self, mock_task):
def test_send_to_priority_tenant_queue(self):
"""Test _send_to_priority_tenant_queue method."""
# Arrange
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
@ -159,10 +157,9 @@ class TestDocumentIndexingTaskProxy:
proxy._send_to_priority_tenant_queue()
# Assert
proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
proxy._send_to_tenant_queue.assert_called_once_with(proxy.PRIORITY_TASK_FUNC)
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
def test_send_to_priority_direct_queue(self, mock_task):
def test_send_to_priority_direct_queue(self):
"""Test _send_to_priority_direct_queue method."""
# Arrange
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
@ -172,9 +169,9 @@ class TestDocumentIndexingTaskProxy:
proxy._send_to_priority_direct_queue()
# Assert
proxy._send_to_direct_queue.assert_called_once_with(mock_task)
proxy._send_to_direct_queue.assert_called_once_with(proxy.PRIORITY_TASK_FUNC)
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.base.FeatureService")
def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service):
"""Test _dispatch method when billing is enabled with sandbox plan."""
# Arrange
@ -191,7 +188,7 @@ class TestDocumentIndexingTaskProxy:
# Assert
proxy._send_to_default_tenant_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.base.FeatureService")
def test_dispatch_with_billing_enabled_non_sandbox_plan(self, mock_feature_service):
"""Test _dispatch method when billing is enabled with non-sandbox plan."""
# Arrange
@ -208,7 +205,7 @@ class TestDocumentIndexingTaskProxy:
# If billing enabled with non sandbox plan, should send to priority tenant queue
proxy._send_to_priority_tenant_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.base.FeatureService")
def test_dispatch_with_billing_disabled(self, mock_feature_service):
"""Test _dispatch method when billing is disabled."""
# Arrange
@ -223,7 +220,7 @@ class TestDocumentIndexingTaskProxy:
# If billing disabled, for example: self-hosted or enterprise, should send to priority direct queue
proxy._send_to_priority_direct_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.base.FeatureService")
def test_delay_method(self, mock_feature_service):
"""Test delay method integration."""
# Arrange
@ -256,7 +253,7 @@ class TestDocumentIndexingTaskProxy:
assert task.dataset_id == dataset_id
assert task.document_ids == document_ids
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.base.FeatureService")
def test_dispatch_edge_case_empty_plan(self, mock_feature_service):
"""Test _dispatch method with empty plan string."""
# Arrange
@ -271,7 +268,7 @@ class TestDocumentIndexingTaskProxy:
# Assert
proxy._send_to_priority_tenant_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.base.FeatureService")
def test_dispatch_edge_case_none_plan(self, mock_feature_service):
"""Test _dispatch method with None plan."""
# Arrange

View File

@ -0,0 +1,363 @@
from unittest.mock import Mock, patch
from core.entities.document_task import DocumentTask
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from services.document_indexing_proxy.duplicate_document_indexing_task_proxy import (
DuplicateDocumentIndexingTaskProxy,
)
class DuplicateDocumentIndexingTaskProxyTestDataFactory:
    """Factory class for creating test data and mock objects for DuplicateDocumentIndexingTaskProxy tests."""

    @staticmethod
    def create_mock_features(billing_enabled: bool = False, plan: CloudPlan = CloudPlan.SANDBOX) -> Mock:
        """Build a features mock exposing `billing.enabled` and `billing.subscription.plan`."""
        # Assemble the nested mocks bottom-up, then attach them to the root.
        subscription = Mock()
        subscription.plan = plan
        billing = Mock()
        billing.enabled = billing_enabled
        billing.subscription = subscription
        feature_root = Mock()
        feature_root.billing = billing
        return feature_root

    @staticmethod
    def create_mock_tenant_queue(has_task_key: bool = False) -> Mock:
        """Build a spec'd TenantIsolatedTaskQueue mock, optionally reporting an existing task key."""
        task_queue = Mock(spec=TenantIsolatedTaskQueue)
        if has_task_key:
            task_queue.get_task_key.return_value = "task_key"
        else:
            task_queue.get_task_key.return_value = None
        task_queue.push_tasks = Mock()
        task_queue.set_task_waiting_time = Mock()
        return task_queue

    @staticmethod
    def create_duplicate_document_task_proxy(
        tenant_id: str = "tenant-123", dataset_id: str = "dataset-456", document_ids: list[str] | None = None
    ) -> DuplicateDocumentIndexingTaskProxy:
        """Build a DuplicateDocumentIndexingTaskProxy, defaulting to a three-document batch."""
        resolved_ids = document_ids if document_ids is not None else ["doc-1", "doc-2", "doc-3"]
        return DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, resolved_ids)
class TestDuplicateDocumentIndexingTaskProxy:
    """Test cases for DuplicateDocumentIndexingTaskProxy class.

    Covers construction, class-level task wiring (QUEUE_NAME / task funcs),
    queue-routing helpers, and the billing-based _dispatch decision tree.
    Patch targets reference either the concrete proxy module or the shared
    `services.document_indexing_proxy.base` module, matching where each name
    is looked up at call time.
    """

    def test_initialization(self):
        """Test DuplicateDocumentIndexingTaskProxy initialization."""
        # Arrange
        tenant_id = "tenant-123"
        dataset_id = "dataset-456"
        document_ids = ["doc-1", "doc-2", "doc-3"]

        # Act
        proxy = DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)

        # Assert
        assert proxy._tenant_id == tenant_id
        assert proxy._dataset_id == dataset_id
        assert proxy._document_ids == document_ids
        # The proxy owns a real per-tenant queue keyed by its unique queue name.
        assert isinstance(proxy._tenant_isolated_task_queue, TenantIsolatedTaskQueue)
        assert proxy._tenant_isolated_task_queue._tenant_id == tenant_id
        assert proxy._tenant_isolated_task_queue._unique_key == "duplicate_document_indexing"

    def test_queue_name(self):
        """Test QUEUE_NAME class variable."""
        # Arrange & Act
        proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()

        # Assert
        assert proxy.QUEUE_NAME == "duplicate_document_indexing"

    def test_task_functions(self):
        """Test NORMAL_TASK_FUNC and PRIORITY_TASK_FUNC class variables."""
        # Arrange & Act
        proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()

        # Assert
        # Compare by __name__ so the test does not bind to Celery task identity.
        assert proxy.NORMAL_TASK_FUNC.__name__ == "normal_duplicate_document_indexing_task"
        assert proxy.PRIORITY_TASK_FUNC.__name__ == "priority_duplicate_document_indexing_task"

    @patch("services.document_indexing_proxy.base.FeatureService")
    def test_features_property(self, mock_feature_service):
        """Test cached_property features."""
        # Arrange
        mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features()
        mock_feature_service.get_features.return_value = mock_features
        proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()

        # Act
        features1 = proxy.features
        features2 = proxy.features  # Second call should use cached property

        # Assert
        assert features1 == mock_features
        assert features2 == mock_features
        assert features1 is features2  # Should be the same instance due to caching
        # cached_property guarantees the underlying service is hit exactly once.
        mock_feature_service.get_features.assert_called_once_with("tenant-123")

    @patch(
        "services.document_indexing_proxy.duplicate_document_indexing_task_proxy.normal_duplicate_document_indexing_task"
    )
    def test_send_to_direct_queue(self, mock_task):
        """Test _send_to_direct_queue method."""
        # Arrange
        proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
        mock_task.delay = Mock()

        # Act
        proxy._send_to_direct_queue(mock_task)

        # Assert
        # Direct-queue delivery bypasses the tenant queue and calls task.delay() immediately.
        mock_task.delay.assert_called_once_with(
            tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
        )

    @patch(
        "services.document_indexing_proxy.duplicate_document_indexing_task_proxy.normal_duplicate_document_indexing_task"
    )
    def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
        """Test _send_to_tenant_queue when task key exists."""
        # Arrange
        proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
        proxy._tenant_isolated_task_queue = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue(
            has_task_key=True
        )
        mock_task.delay = Mock()

        # Act
        proxy._send_to_tenant_queue(mock_task)

        # Assert
        # When a task key already exists, the payload is queued rather than dispatched.
        proxy._tenant_isolated_task_queue.push_tasks.assert_called_once()
        pushed_tasks = proxy._tenant_isolated_task_queue.push_tasks.call_args[0][0]
        assert len(pushed_tasks) == 1
        # The pushed dict must round-trip through the DocumentTask entity.
        assert isinstance(DocumentTask(**pushed_tasks[0]), DocumentTask)
        assert pushed_tasks[0]["tenant_id"] == "tenant-123"
        assert pushed_tasks[0]["dataset_id"] == "dataset-456"
        assert pushed_tasks[0]["document_ids"] == ["doc-1", "doc-2", "doc-3"]
        mock_task.delay.assert_not_called()

    @patch(
        "services.document_indexing_proxy.duplicate_document_indexing_task_proxy.normal_duplicate_document_indexing_task"
    )
    def test_send_to_tenant_queue_without_task_key(self, mock_task):
        """Test _send_to_tenant_queue when no task key exists."""
        # Arrange
        proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
        proxy._tenant_isolated_task_queue = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue(
            has_task_key=False
        )
        mock_task.delay = Mock()

        # Act
        proxy._send_to_tenant_queue(mock_task)

        # Assert
        # With no outstanding task key, the proxy records a waiting time and fires the task directly.
        proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once()
        mock_task.delay.assert_called_once_with(
            tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
        )
        proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()

    def test_send_to_default_tenant_queue(self):
        """Test _send_to_default_tenant_queue method."""
        # Arrange
        proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
        proxy._send_to_tenant_queue = Mock()

        # Act
        proxy._send_to_default_tenant_queue()

        # Assert
        # Default routing forwards the class-level normal task function.
        proxy._send_to_tenant_queue.assert_called_once_with(proxy.NORMAL_TASK_FUNC)

    def test_send_to_priority_tenant_queue(self):
        """Test _send_to_priority_tenant_queue method."""
        # Arrange
        proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
        proxy._send_to_tenant_queue = Mock()

        # Act
        proxy._send_to_priority_tenant_queue()

        # Assert
        # Priority routing forwards the class-level priority task function.
        proxy._send_to_tenant_queue.assert_called_once_with(proxy.PRIORITY_TASK_FUNC)

    def test_send_to_priority_direct_queue(self):
        """Test _send_to_priority_direct_queue method."""
        # Arrange
        proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
        proxy._send_to_direct_queue = Mock()

        # Act
        proxy._send_to_priority_direct_queue()

        # Assert
        proxy._send_to_direct_queue.assert_called_once_with(proxy.PRIORITY_TASK_FUNC)

    @patch("services.document_indexing_proxy.base.FeatureService")
    def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service):
        """Test _dispatch method when billing is enabled with sandbox plan."""
        # Arrange
        mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(
            billing_enabled=True, plan=CloudPlan.SANDBOX
        )
        mock_feature_service.get_features.return_value = mock_features
        proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
        proxy._send_to_default_tenant_queue = Mock()

        # Act
        proxy._dispatch()

        # Assert
        # Sandbox tenants on billing go through the default (non-priority) tenant queue.
        proxy._send_to_default_tenant_queue.assert_called_once()

    @patch("services.document_indexing_proxy.base.FeatureService")
    def test_dispatch_with_billing_enabled_non_sandbox_plan(self, mock_feature_service):
        """Test _dispatch method when billing is enabled with non-sandbox plan."""
        # Arrange
        mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(
            billing_enabled=True, plan=CloudPlan.TEAM
        )
        mock_feature_service.get_features.return_value = mock_features
        proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
        proxy._send_to_priority_tenant_queue = Mock()

        # Act
        proxy._dispatch()

        # Assert
        # If billing enabled with non sandbox plan, should send to priority tenant queue
        proxy._send_to_priority_tenant_queue.assert_called_once()

    @patch("services.document_indexing_proxy.base.FeatureService")
    def test_dispatch_with_billing_disabled(self, mock_feature_service):
        """Test _dispatch method when billing is disabled."""
        # Arrange
        mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=False)
        mock_feature_service.get_features.return_value = mock_features
        proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
        proxy._send_to_priority_direct_queue = Mock()

        # Act
        proxy._dispatch()

        # Assert
        # If billing disabled, for example: self-hosted or enterprise, should send to priority direct queue
        proxy._send_to_priority_direct_queue.assert_called_once()

    @patch("services.document_indexing_proxy.base.FeatureService")
    def test_delay_method(self, mock_feature_service):
        """Test delay method integration."""
        # Arrange
        mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(
            billing_enabled=True, plan=CloudPlan.SANDBOX
        )
        mock_feature_service.get_features.return_value = mock_features
        proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
        proxy._send_to_default_tenant_queue = Mock()

        # Act
        # delay() is the public entry point; it should route through _dispatch.
        proxy.delay()

        # Assert
        # If billing enabled with sandbox plan, should send to default tenant queue
        proxy._send_to_default_tenant_queue.assert_called_once()

    @patch("services.document_indexing_proxy.base.FeatureService")
    def test_dispatch_edge_case_empty_plan(self, mock_feature_service):
        """Test _dispatch method with empty plan string."""
        # Arrange
        mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(
            billing_enabled=True, plan=""
        )
        mock_feature_service.get_features.return_value = mock_features
        proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
        proxy._send_to_priority_tenant_queue = Mock()

        # Act
        proxy._dispatch()

        # Assert
        # An empty plan is not SANDBOX, so it falls through to the priority tenant queue.
        proxy._send_to_priority_tenant_queue.assert_called_once()

    @patch("services.document_indexing_proxy.base.FeatureService")
    def test_dispatch_edge_case_none_plan(self, mock_feature_service):
        """Test _dispatch method with None plan."""
        # Arrange
        mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(
            billing_enabled=True, plan=None
        )
        mock_feature_service.get_features.return_value = mock_features
        proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
        proxy._send_to_priority_tenant_queue = Mock()

        # Act
        proxy._dispatch()

        # Assert
        # A None plan is likewise treated as non-sandbox and routed to the priority tenant queue.
        proxy._send_to_priority_tenant_queue.assert_called_once()

    def test_initialization_with_empty_document_ids(self):
        """Test initialization with empty document_ids list."""
        # Arrange
        tenant_id = "tenant-123"
        dataset_id = "dataset-456"
        document_ids = []

        # Act
        proxy = DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)

        # Assert
        assert proxy._tenant_id == tenant_id
        assert proxy._dataset_id == dataset_id
        assert proxy._document_ids == document_ids

    def test_initialization_with_single_document_id(self):
        """Test initialization with single document_id."""
        # Arrange
        tenant_id = "tenant-123"
        dataset_id = "dataset-456"
        document_ids = ["doc-1"]

        # Act
        proxy = DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)

        # Assert
        assert proxy._tenant_id == tenant_id
        assert proxy._dataset_id == dataset_id
        assert proxy._document_ids == document_ids

    def test_initialization_with_large_batch(self):
        """Test initialization with large batch of document IDs."""
        # Arrange
        tenant_id = "tenant-123"
        dataset_id = "dataset-456"
        document_ids = [f"doc-{i}" for i in range(100)]

        # Act
        proxy = DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)

        # Assert
        assert proxy._tenant_id == tenant_id
        assert proxy._dataset_id == dataset_id
        assert proxy._document_ids == document_ids
        assert len(proxy._document_ids) == 100

    @patch("services.document_indexing_proxy.base.FeatureService")
    def test_dispatch_with_professional_plan(self, mock_feature_service):
        """Test _dispatch method when billing is enabled with professional plan."""
        # Arrange
        mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(
            billing_enabled=True, plan=CloudPlan.PROFESSIONAL
        )
        mock_feature_service.get_features.return_value = mock_features
        proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
        proxy._send_to_priority_tenant_queue = Mock()

        # Act
        proxy._dispatch()

        # Assert
        # PROFESSIONAL is a paid plan, so dispatch routes to the priority tenant queue.
        proxy._send_to_priority_tenant_queue.assert_called_once()

View File

@ -6,6 +6,7 @@ Target: 1500+ lines of comprehensive test coverage.
"""
import json
import re
from datetime import datetime
from unittest.mock import MagicMock, Mock, patch
@ -1791,8 +1792,8 @@ class TestExternalDatasetServiceFetchRetrieval:
@patch("services.external_knowledge_service.ExternalDatasetService.process_external_api")
@patch("services.external_knowledge_service.db")
def test_fetch_external_knowledge_retrieval_non_200_status(self, mock_db, mock_process, factory):
"""Test retrieval returns empty list on non-200 status."""
def test_fetch_external_knowledge_retrieval_non_200_status_raises_exception(self, mock_db, mock_process, factory):
"""Test that non-200 status code raises Exception with response text."""
# Arrange
binding = factory.create_external_knowledge_binding_mock()
api = factory.create_external_knowledge_api_mock()
@ -1817,12 +1818,103 @@ class TestExternalDatasetServiceFetchRetrieval:
mock_response = MagicMock()
mock_response.status_code = 500
mock_response.text = "Internal Server Error: Database connection failed"
mock_process.return_value = mock_response
# Act
result = ExternalDatasetService.fetch_external_knowledge_retrieval(
"tenant-123", "dataset-123", "query", {"top_k": 5}
)
# Act & Assert
with pytest.raises(Exception, match="Internal Server Error: Database connection failed"):
ExternalDatasetService.fetch_external_knowledge_retrieval(
"tenant-123", "dataset-123", "query", {"top_k": 5}
)
# Assert
assert result == []
@pytest.mark.parametrize(
("status_code", "error_message"),
[
(400, "Bad Request: Invalid query parameters"),
(401, "Unauthorized: Invalid API key"),
(403, "Forbidden: Access denied to resource"),
(404, "Not Found: Knowledge base not found"),
(429, "Too Many Requests: Rate limit exceeded"),
(500, "Internal Server Error: Database connection failed"),
(502, "Bad Gateway: External service unavailable"),
(503, "Service Unavailable: Maintenance mode"),
],
)
@patch("services.external_knowledge_service.ExternalDatasetService.process_external_api")
@patch("services.external_knowledge_service.db")
def test_fetch_external_knowledge_retrieval_various_error_status_codes(
self, mock_db, mock_process, factory, status_code, error_message
):
"""Test that various error status codes raise exceptions with response text."""
# Arrange
tenant_id = "tenant-123"
dataset_id = "dataset-123"
binding = factory.create_external_knowledge_binding_mock(
dataset_id=dataset_id, external_knowledge_api_id="api-123"
)
api = factory.create_external_knowledge_api_mock(api_id="api-123")
mock_binding_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == ExternalKnowledgeBindings:
return mock_binding_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_binding_query.filter_by.return_value = mock_binding_query
mock_binding_query.first.return_value = binding
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_response = MagicMock()
mock_response.status_code = status_code
mock_response.text = error_message
mock_process.return_value = mock_response
# Act & Assert
with pytest.raises(ValueError, match=re.escape(error_message)):
ExternalDatasetService.fetch_external_knowledge_retrieval(tenant_id, dataset_id, "query", {"top_k": 5})
@patch("services.external_knowledge_service.ExternalDatasetService.process_external_api")
@patch("services.external_knowledge_service.db")
def test_fetch_external_knowledge_retrieval_empty_response_text(self, mock_db, mock_process, factory):
"""Test exception with empty response text."""
# Arrange
binding = factory.create_external_knowledge_binding_mock()
api = factory.create_external_knowledge_api_mock()
mock_binding_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == ExternalKnowledgeBindings:
return mock_binding_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_binding_query.filter_by.return_value = mock_binding_query
mock_binding_query.first.return_value = binding
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_response = MagicMock()
mock_response.status_code = 503
mock_response.text = ""
mock_process.return_value = mock_response
# Act & Assert
with pytest.raises(Exception, match=""):
ExternalDatasetService.fetch_external_knowledge_retrieval(
"tenant-123", "dataset-123", "query", {"top_k": 5}
)

View File

@ -2,8 +2,6 @@ from pathlib import Path
from unittest.mock import Mock, create_autospec, patch
import pytest
from flask_restx import reqparse
from werkzeug.exceptions import BadRequest
from models.account import Account
from services.entities.knowledge_entities.knowledge_entities import MetadataArgs
@ -77,60 +75,39 @@ class TestMetadataBugCompleteValidation:
assert type_column.nullable is False, "type column should be nullable=False"
assert name_column.nullable is False, "name column should be nullable=False"
def test_4_fixed_api_layer_rejects_null(self, app):
"""Test Layer 4: Fixed API configuration properly rejects null values."""
# Test Console API create endpoint (fixed)
parser = (
reqparse.RequestParser()
.add_argument("type", type=str, required=True, nullable=False, location="json")
.add_argument("name", type=str, required=True, nullable=False, location="json")
)
def test_4_fixed_api_layer_rejects_null(self):
"""Test Layer 4: Fixed API configuration properly rejects null values using Pydantic."""
with pytest.raises((ValueError, TypeError)):
MetadataArgs.model_validate({"type": None, "name": None})
with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"):
with pytest.raises(BadRequest):
parser.parse_args()
with pytest.raises((ValueError, TypeError)):
MetadataArgs.model_validate({"type": "string", "name": None})
# Test with just name being null
with app.test_request_context(json={"type": "string", "name": None}, content_type="application/json"):
with pytest.raises(BadRequest):
parser.parse_args()
with pytest.raises((ValueError, TypeError)):
MetadataArgs.model_validate({"type": None, "name": "test"})
# Test with just type being null
with app.test_request_context(json={"type": None, "name": "test"}, content_type="application/json"):
with pytest.raises(BadRequest):
parser.parse_args()
def test_5_fixed_api_accepts_valid_values(self, app):
def test_5_fixed_api_accepts_valid_values(self):
"""Test that fixed API still accepts valid non-null values."""
parser = (
reqparse.RequestParser()
.add_argument("type", type=str, required=True, nullable=False, location="json")
.add_argument("name", type=str, required=True, nullable=False, location="json")
)
args = MetadataArgs.model_validate({"type": "string", "name": "valid_name"})
assert args.type == "string"
assert args.name == "valid_name"
with app.test_request_context(json={"type": "string", "name": "valid_name"}, content_type="application/json"):
args = parser.parse_args()
assert args["type"] == "string"
assert args["name"] == "valid_name"
def test_6_simulated_buggy_behavior(self):
"""Test simulating the original buggy behavior by bypassing Pydantic validation."""
mock_metadata_args = Mock()
mock_metadata_args.name = None
mock_metadata_args.type = None
def test_6_simulated_buggy_behavior(self, app):
"""Test simulating the original buggy behavior with nullable=True."""
# Simulate the old buggy configuration
buggy_parser = (
reqparse.RequestParser()
.add_argument("type", type=str, required=True, nullable=True, location="json")
.add_argument("name", type=str, required=True, nullable=True, location="json")
)
mock_user = create_autospec(Account, instance=True)
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"):
# This would pass in the buggy version
args = buggy_parser.parse_args()
assert args["type"] is None
assert args["name"] is None
# But would crash when trying to create MetadataArgs
with pytest.raises((ValueError, TypeError)):
MetadataArgs.model_validate(args)
with patch(
"services.metadata_service.current_account_with_tenant",
return_value=(mock_user, mock_user.current_tenant_id),
):
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.create_metadata("dataset-123", mock_metadata_args)
def test_7_end_to_end_validation_layers(self):
"""Test all validation layers work together correctly."""

View File

@ -1,7 +1,6 @@
from unittest.mock import Mock, create_autospec, patch
import pytest
from flask_restx import reqparse
from models.account import Account
from services.entities.knowledge_entities.knowledge_entities import MetadataArgs
@ -51,76 +50,16 @@ class TestMetadataNullableBug:
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.update_metadata_name("dataset-123", "metadata-456", None)
def test_api_parser_accepts_null_values(self, app):
"""Test that API parser configuration incorrectly accepts null values."""
# Simulate the current API parser configuration
parser = (
reqparse.RequestParser()
.add_argument("type", type=str, required=True, nullable=True, location="json")
.add_argument("name", type=str, required=True, nullable=True, location="json")
)
def test_api_layer_now_uses_pydantic_validation(self):
"""Verify that API layer relies on Pydantic validation instead of reqparse."""
invalid_payload = {"type": None, "name": None}
with pytest.raises((ValueError, TypeError)):
MetadataArgs.model_validate(invalid_payload)
# Simulate request data with null values
with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"):
# This should parse successfully due to nullable=True
args = parser.parse_args()
# Verify that null values are accepted
assert args["type"] is None
assert args["name"] is None
# This demonstrates the bug: API accepts None but business logic will crash
def test_integration_bug_scenario(self, app):
"""Test the complete bug scenario from API to service layer."""
# Step 1: API parser accepts null values (current buggy behavior)
parser = (
reqparse.RequestParser()
.add_argument("type", type=str, required=True, nullable=True, location="json")
.add_argument("name", type=str, required=True, nullable=True, location="json")
)
with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"):
args = parser.parse_args()
# Step 2: Try to create MetadataArgs with None values
# This should fail at Pydantic validation level
with pytest.raises((ValueError, TypeError)):
metadata_args = MetadataArgs.model_validate(args)
# Step 3: If we bypass Pydantic (simulating the bug scenario)
# Move this outside the request context to avoid Flask-Login issues
mock_metadata_args = Mock()
mock_metadata_args.name = None # From args["name"]
mock_metadata_args.type = None # From args["type"]
mock_user = create_autospec(Account, instance=True)
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
with patch(
"services.metadata_service.current_account_with_tenant",
return_value=(mock_user, mock_user.current_tenant_id),
):
# Step 4: Service layer crashes on len(None)
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.create_metadata("dataset-123", mock_metadata_args)
def test_correct_nullable_false_configuration_works(self, app):
"""Test that the correct nullable=False configuration works as expected."""
# This tests the FIXED configuration
parser = (
reqparse.RequestParser()
.add_argument("type", type=str, required=True, nullable=False, location="json")
.add_argument("name", type=str, required=True, nullable=False, location="json")
)
with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"):
# This should fail with BadRequest due to nullable=False
from werkzeug.exceptions import BadRequest
with pytest.raises(BadRequest):
parser.parse_args()
valid_payload = {"type": "string", "name": "valid"}
args = MetadataArgs.model_validate(valid_payload)
assert args.type == "string"
assert args.name == "valid"
if __name__ == "__main__":

View File

@ -0,0 +1,88 @@
import types
import pytest
from core.entities.provider_entities import CredentialConfiguration, CustomModelConfiguration
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import ConfigurateMethod
from models.provider import ProviderType
from services.model_provider_service import ModelProviderService
class _FakeConfigurations:
def __init__(self, provider_configuration: types.SimpleNamespace) -> None:
self._provider_configuration = provider_configuration
def values(self) -> list[types.SimpleNamespace]:
return [self._provider_configuration]
@pytest.fixture
def service_with_fake_configurations():
    """Build a ModelProviderService whose provider manager yields one fake
    customizable provider that still carries decrypted credentials.

    The plain-text ``credentials`` dict on the custom model simulates the
    leak source that the list endpoint is expected to strip.
    """
    # Minimal provider schema: only the fields ProviderResponse reads.
    provider_schema = types.SimpleNamespace(
        provider="langgenius/openai_api_compatible/openai_api_compatible",
        label=I18nObject(en_US="OpenAI API Compatible", zh_Hans="OpenAI API Compatible"),
        description=None,
        icon_small=None,
        icon_small_dark=None,
        icon_large=None,
        background=None,
        help=None,
        supported_model_types=[ModelType.LLM],
        configurate_methods=[ConfigurateMethod.CUSTOMIZABLE_MODEL],
        provider_credential_schema=None,
        model_credential_schema=None,
    )

    # A custom model that still carries decrypted credentials (the leak source).
    leaky_model = CustomModelConfiguration(
        model="gpt-4o-mini",
        model_type=ModelType.LLM,
        credentials={"api_key": "sk-plain-text", "endpoint": "https://example.com"},
        current_credential_id="cred-1",
        current_credential_name="API KEY 1",
        available_model_credentials=[],
        unadded_to_model_list=False,
    )

    custom_provider_ns = types.SimpleNamespace(
        current_credential_id="cred-1",
        current_credential_name="API KEY 1",
        available_credentials=[CredentialConfiguration(credential_id="cred-1", credential_name="API KEY 1")],
    )
    custom_configuration_ns = types.SimpleNamespace(
        provider=custom_provider_ns, models=[leaky_model], can_added_models=[]
    )
    system_configuration_ns = types.SimpleNamespace(
        enabled=False, current_quota_type=None, quota_configurations=[]
    )
    provider_configuration_ns = types.SimpleNamespace(
        provider=provider_schema,
        preferred_provider_type=ProviderType.CUSTOM,
        custom_configuration=custom_configuration_ns,
        system_configuration=system_configuration_ns,
        is_custom_configuration_available=lambda: True,
    )

    class _FakeProviderManager:
        # The service only ever calls get_configurations(tenant_id).
        def get_configurations(self, tenant_id: str) -> _FakeConfigurations:
            return _FakeConfigurations(provider_configuration_ns)

    service = ModelProviderService()
    service.provider_manager = _FakeProviderManager()
    return service
def test_get_provider_list_strips_credentials(service_with_fake_configurations: ModelProviderService):
    """The provider list endpoint must not expose decrypted credentials."""
    service = service_with_fake_configurations
    providers = service.get_provider_list(tenant_id="tenant-1", model_type=None)

    assert len(providers) == 1
    models = providers[0].custom_configuration.custom_models
    assert models is not None
    assert len(models) == 1
    # The sanitizer must have dropped the plain-text credentials in the list response.
    assert models[0].credentials is None

File diff suppressed because it is too large Load Diff

View File

@ -19,7 +19,7 @@ from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from extensions.ext_redis import redis_client
from models.dataset import Dataset, Document
from services.document_indexing_task_proxy import DocumentIndexingTaskProxy
from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy
from tasks.document_indexing_task import (
_document_indexing,
_document_indexing_with_tenant_queue,
@ -138,7 +138,9 @@ class TestTaskEnqueuing:
with patch.object(DocumentIndexingTaskProxy, "features") as mock_features:
mock_features.billing.enabled = False
with patch("services.document_indexing_task_proxy.priority_document_indexing_task") as mock_task:
# Mock the class variable directly
mock_task = Mock()
with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task):
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
# Act
@ -163,7 +165,9 @@ class TestTaskEnqueuing:
mock_features.billing.enabled = True
mock_features.billing.subscription.plan = CloudPlan.SANDBOX
with patch("services.document_indexing_task_proxy.normal_document_indexing_task") as mock_task:
# Mock the class variable directly
mock_task = Mock()
with patch.object(DocumentIndexingTaskProxy, "NORMAL_TASK_FUNC", mock_task):
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
# Act
@ -187,7 +191,9 @@ class TestTaskEnqueuing:
mock_features.billing.enabled = True
mock_features.billing.subscription.plan = CloudPlan.PROFESSIONAL
with patch("services.document_indexing_task_proxy.priority_document_indexing_task") as mock_task:
# Mock the class variable directly
mock_task = Mock()
with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task):
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
# Act
@ -211,7 +217,9 @@ class TestTaskEnqueuing:
mock_features.billing.enabled = True
mock_features.billing.subscription.plan = CloudPlan.PROFESSIONAL
with patch("services.document_indexing_task_proxy.priority_document_indexing_task") as mock_task:
# Mock the class variable directly
mock_task = Mock()
with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task):
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
# Act
@ -1493,7 +1501,9 @@ class TestEdgeCases:
mock_features.billing.enabled = True
mock_features.billing.subscription.plan = CloudPlan.PROFESSIONAL
with patch("services.document_indexing_task_proxy.priority_document_indexing_task") as mock_task:
# Mock the class variable directly
mock_task = Mock()
with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task):
# Act - Enqueue multiple tasks rapidly
for doc_ids in document_ids_list:
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, doc_ids)
@ -1898,7 +1908,7 @@ class TestRobustness:
- Error is propagated appropriately
"""
# Arrange
with patch("services.document_indexing_task_proxy.FeatureService.get_features") as mock_get_features:
with patch("services.document_indexing_proxy.base.FeatureService.get_features") as mock_get_features:
# Simulate FeatureService failure
mock_get_features.side_effect = Exception("Feature service unavailable")

View File

@ -0,0 +1,112 @@
"""
Unit tests for delete_account_task.
Covers:
- Billing enabled with existing account: calls billing and sends success email
- Billing disabled with existing account: skips billing, sends success email
- Account not found: still calls billing when enabled, does not send email
- Billing deletion raises: logs and re-raises, no email
"""
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from tasks.delete_account_task import delete_account_task
@pytest.fixture
def mock_db_session():
    """Patch db.session inside delete_account_task and yield the session mock.

    The query/where chain resolves to one shared mock so tests can set
    ``...first.return_value`` to control the account-lookup result.
    """
    with patch("tasks.delete_account_task.db.session") as session_mock:
        chained_query = MagicMock()
        # query(...) and .where(...) both return the same chainable mock.
        session_mock.query.return_value = chained_query
        chained_query.where.return_value = chained_query
        yield session_mock
@pytest.fixture
def mock_deps():
    """Patch BillingService and the deletion-success mail task together."""
    with (
        patch("tasks.delete_account_task.BillingService") as billing_mock,
        patch("tasks.delete_account_task.send_deletion_success_task") as mail_task_mock,
    ):
        # The task triggers email via .delay(); make sure it exists as a mock.
        mail_task_mock.delay = MagicMock()
        yield {"billing": billing_mock, "mail_task": mail_task_mock}
def _set_account_found(mock_db_session, email: str = "user@example.com"):
account = SimpleNamespace(email=email)
mock_db_session.query.return_value.where.return_value.first.return_value = account
return account
def _set_account_missing(mock_db_session):
mock_db_session.query.return_value.where.return_value.first.return_value = None
class TestDeleteAccountTask:
    """Behavioral tests for delete_account_task.

    Each test controls two collaborators — the patched db session (account
    lookup result) and the patched billing/mail dependencies — and then
    checks exactly which side effects fired.
    """

    def test_billing_enabled_account_exists_calls_billing_and_sends_email(self, mock_db_session, mock_deps):
        # Arrange
        account_id = "acc-123"
        account = _set_account_found(mock_db_session, email="a@b.com")
        # Enable billing
        with patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True):
            # Act
            delete_account_task(account_id)
            # Assert: billing deletion ran and the success email targeted the account's address.
            mock_deps["billing"].delete_account.assert_called_once_with(account_id)
            mock_deps["mail_task"].delay.assert_called_once_with(account.email)

    def test_billing_disabled_account_exists_sends_email_only(self, mock_db_session, mock_deps):
        # Arrange
        account_id = "acc-456"
        account = _set_account_found(mock_db_session, email="x@y.com")
        # Disable billing
        with patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", False):
            # Act
            delete_account_task(account_id)
            # Assert: billing skipped entirely, but the user is still notified.
            mock_deps["billing"].delete_account.assert_not_called()
            mock_deps["mail_task"].delay.assert_called_once_with(account.email)

    def test_account_not_found_billing_enabled_calls_billing_no_email(self, mock_db_session, mock_deps, caplog):
        # Arrange
        account_id = "missing-id"
        _set_account_missing(mock_db_session)
        # Enable billing
        with patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True):
            # Act
            delete_account_task(account_id)
            # Assert: billing cleanup still runs by id; no email without an address.
            mock_deps["billing"].delete_account.assert_called_once_with(account_id)
            mock_deps["mail_task"].delay.assert_not_called()
            # Optional: verify log contains not found message
            # NOTE(review): relies on the task's logger propagating at a captured
            # level — confirm logging config if this assertion flakes.
            assert any("not found" in rec.getMessage().lower() for rec in caplog.records)

    def test_billing_delete_raises_propagates_and_no_email(self, mock_db_session, mock_deps):
        # Arrange
        account_id = "acc-err"
        _set_account_found(mock_db_session, email="err@ex.com")
        mock_deps["billing"].delete_account.side_effect = RuntimeError("billing down")
        # Enable billing
        with patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True):
            # Act & Assert: the billing failure must propagate to the caller.
            with pytest.raises(RuntimeError):
                delete_account_task(account_id)
            # Ensure email was not sent
            mock_deps["mail_task"].delay.assert_not_called()

View File

@ -0,0 +1,567 @@
"""
Unit tests for duplicate document indexing tasks.
This module tests the duplicate document indexing task functionality including:
- Task enqueuing to different queues (normal, priority, tenant-isolated)
- Batch processing of multiple duplicate documents
- Progress tracking through task lifecycle
- Error handling and retry mechanisms
- Cleanup of old document data before re-indexing
"""
import uuid
from unittest.mock import MagicMock, Mock, patch
import pytest
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from models.dataset import Dataset, Document, DocumentSegment
from tasks.duplicate_document_indexing_task import (
_duplicate_document_indexing_task,
_duplicate_document_indexing_task_with_tenant_queue,
duplicate_document_indexing_task,
normal_duplicate_document_indexing_task,
priority_duplicate_document_indexing_task,
)
# ============================================================================
# Fixtures
# ============================================================================
@pytest.fixture
def tenant_id():
    """Fresh random tenant ID for each test."""
    return f"{uuid.uuid4()}"
@pytest.fixture
def dataset_id():
    """Fresh random dataset ID for each test."""
    return f"{uuid.uuid4()}"
@pytest.fixture
def document_ids():
    """Three fresh random document IDs."""
    return [f"{uuid.uuid4()}" for _ in range(3)]
@pytest.fixture
def mock_dataset(dataset_id, tenant_id):
    """Dataset mock configured for high-quality (embedding-based) indexing."""
    fake_dataset = Mock(spec=Dataset)
    fake_dataset.id = dataset_id
    fake_dataset.tenant_id = tenant_id
    fake_dataset.indexing_technique = "high_quality"
    fake_dataset.embedding_model_provider = "openai"
    fake_dataset.embedding_model = "text-embedding-ada-002"
    return fake_dataset
@pytest.fixture
def mock_documents(document_ids, dataset_id):
    """One waiting-state Document mock per ID, all in the same dataset."""

    def _make_document(doc_id):
        fake_doc = Mock(spec=Document)
        fake_doc.id = doc_id
        fake_doc.dataset_id = dataset_id
        fake_doc.indexing_status = "waiting"
        fake_doc.error = None
        fake_doc.stopped_at = None
        fake_doc.processing_started_at = None
        fake_doc.doc_form = "text_model"
        return fake_doc

    return [_make_document(doc_id) for doc_id in document_ids]
@pytest.fixture
def mock_document_segments(document_ids):
    """Three segment mocks per document, each with a unique index node ID."""
    segments = []
    for doc_id in document_ids:
        for position in range(3):
            fake_segment = Mock(spec=DocumentSegment)
            fake_segment.id = str(uuid.uuid4())
            fake_segment.document_id = doc_id
            fake_segment.index_node_id = f"node-{doc_id}-{position}"
            segments.append(fake_segment)
    return segments
@pytest.fixture
def mock_db_session():
    """Patch the task module's db.session with a chainable query mock."""
    with patch("tasks.duplicate_document_indexing_task.db.session") as session_mock:
        chained_query = MagicMock()
        # query(...).where(...) keeps resolving to the same chainable mock.
        session_mock.query.return_value = chained_query
        chained_query.where.return_value = chained_query
        session_mock.scalars.return_value = MagicMock()
        yield session_mock
@pytest.fixture
def mock_indexing_runner():
    """Replace IndexingRunner with a spec'd mock and yield the instance."""
    with patch("tasks.duplicate_document_indexing_task.IndexingRunner") as runner_cls:
        runner_instance = MagicMock(spec=IndexingRunner)
        # Constructing IndexingRunner() inside the task yields this instance.
        runner_cls.return_value = runner_instance
        yield runner_instance
@pytest.fixture
def mock_feature_service():
    """FeatureService mock: billing disabled, vector space 0/1000 by default."""
    with patch("tasks.duplicate_document_indexing_task.FeatureService") as service_mock:
        features = Mock()
        features.billing = Mock()
        features.billing.enabled = False
        features.vector_space = Mock()
        features.vector_space.size = 0
        features.vector_space.limit = 1000
        service_mock.get_features.return_value = features
        yield service_mock
@pytest.fixture
def mock_index_processor_factory():
    """IndexProcessorFactory mock whose processor exposes a clean() spy."""
    with patch("tasks.duplicate_document_indexing_task.IndexProcessorFactory") as factory_mock:
        processor = MagicMock()
        processor.clean = Mock()
        # factory(...).init_index_processor() hands the task this processor.
        factory_mock.return_value.init_index_processor.return_value = processor
        yield factory_mock
@pytest.fixture
def mock_tenant_isolated_queue():
    """TenantIsolatedTaskQueue mock with an empty queue by default."""
    with patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue") as queue_cls:
        queue_instance = MagicMock(spec=TenantIsolatedTaskQueue)
        queue_instance.pull_tasks.return_value = []
        queue_instance.delete_task_key = Mock()
        queue_instance.set_task_waiting_time = Mock()
        queue_cls.return_value = queue_instance
        yield queue_instance
# ============================================================================
# Tests for deprecated duplicate_document_indexing_task
# ============================================================================
class TestDuplicateDocumentIndexingTask:
    """Tests for the deprecated duplicate_document_indexing_task function.

    The deprecated entry point is expected to be a thin wrapper that
    delegates unchanged arguments to the core implementation.
    """

    @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task")
    def test_duplicate_document_indexing_task_calls_core_function(self, mock_core_func, dataset_id, document_ids):
        """Test that duplicate_document_indexing_task calls the core _duplicate_document_indexing_task function."""
        # Act
        duplicate_document_indexing_task(dataset_id, document_ids)
        # Assert
        mock_core_func.assert_called_once_with(dataset_id, document_ids)

    @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task")
    def test_duplicate_document_indexing_task_with_empty_document_ids(self, mock_core_func, dataset_id):
        """Test duplicate_document_indexing_task with empty document_ids list."""
        # Arrange
        document_ids = []
        # Act
        duplicate_document_indexing_task(dataset_id, document_ids)
        # Assert: the empty list is forwarded as-is, not filtered out.
        mock_core_func.assert_called_once_with(dataset_id, document_ids)
# ============================================================================
# Tests for _duplicate_document_indexing_task core function
# ============================================================================
class TestDuplicateDocumentIndexingTaskCore:
    """Tests for the _duplicate_document_indexing_task core function.

    The mocked session's ``first()`` is driven with a side_effect list:
    the first lookup returns the dataset, subsequent lookups return the
    documents in order.
    """

    def test_successful_duplicate_document_indexing(
        self,
        mock_db_session,
        mock_indexing_runner,
        mock_feature_service,
        mock_index_processor_factory,
        mock_dataset,
        mock_documents,
        mock_document_segments,
        dataset_id,
        document_ids,
    ):
        """Test successful duplicate document indexing flow."""
        # Arrange
        # first() yields the dataset, then each document in turn.
        mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
        mock_db_session.scalars.return_value.all.return_value = mock_document_segments
        # Act
        _duplicate_document_indexing_task(dataset_id, document_ids)
        # Assert
        # Verify IndexingRunner was called
        mock_indexing_runner.run.assert_called_once()
        # Verify all documents were set to parsing status
        for doc in mock_documents:
            assert doc.indexing_status == "parsing"
            assert doc.processing_started_at is not None
        # Verify session operations
        assert mock_db_session.commit.called
        assert mock_db_session.close.called

    def test_duplicate_document_indexing_dataset_not_found(self, mock_db_session, dataset_id, document_ids):
        """Test duplicate document indexing when dataset is not found."""
        # Arrange: every lookup misses, including the initial dataset fetch.
        mock_db_session.query.return_value.where.return_value.first.return_value = None
        # Act
        _duplicate_document_indexing_task(dataset_id, document_ids)
        # Assert
        # Should close the session at least once
        assert mock_db_session.close.called

    def test_duplicate_document_indexing_with_billing_enabled_sandbox_plan(
        self,
        mock_db_session,
        mock_feature_service,
        mock_dataset,
        dataset_id,
        document_ids,
    ):
        """Test duplicate document indexing with billing enabled and sandbox plan."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
        mock_features = mock_feature_service.get_features.return_value
        mock_features.billing.enabled = True
        mock_features.billing.subscription.plan = CloudPlan.SANDBOX
        # Act
        _duplicate_document_indexing_task(dataset_id, document_ids)
        # Assert
        # For sandbox plan with multiple documents, should fail
        # (the failure is persisted, hence the commit).
        mock_db_session.commit.assert_called()

    def test_duplicate_document_indexing_with_billing_limit_exceeded(
        self,
        mock_db_session,
        mock_feature_service,
        mock_dataset,
        mock_documents,
        dataset_id,
        document_ids,
    ):
        """Test duplicate document indexing when billing limit is exceeded."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
        mock_db_session.scalars.return_value.all.return_value = []  # No segments to clean
        mock_features = mock_feature_service.get_features.return_value
        mock_features.billing.enabled = True
        mock_features.billing.subscription.plan = CloudPlan.TEAM
        # Batch of 3 documents against only 10 remaining vector-space slots? No:
        # 990 of 1000 used — adding the batch pushes past the limit.
        mock_features.vector_space.size = 990
        mock_features.vector_space.limit = 1000
        # Act
        _duplicate_document_indexing_task(dataset_id, document_ids)
        # Assert
        # Should commit the session
        assert mock_db_session.commit.called
        # Should close the session
        assert mock_db_session.close.called

    def test_duplicate_document_indexing_runner_error(
        self,
        mock_db_session,
        mock_indexing_runner,
        mock_feature_service,
        mock_index_processor_factory,
        mock_dataset,
        mock_documents,
        dataset_id,
        document_ids,
    ):
        """Test duplicate document indexing when IndexingRunner raises an error."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
        mock_db_session.scalars.return_value.all.return_value = []
        mock_indexing_runner.run.side_effect = Exception("Indexing error")
        # Act: the task is expected to swallow the error rather than raise.
        _duplicate_document_indexing_task(dataset_id, document_ids)
        # Assert
        # Should close the session even after error
        mock_db_session.close.assert_called_once()

    def test_duplicate_document_indexing_document_is_paused(
        self,
        mock_db_session,
        mock_indexing_runner,
        mock_feature_service,
        mock_index_processor_factory,
        mock_dataset,
        mock_documents,
        dataset_id,
        document_ids,
    ):
        """Test duplicate document indexing when document is paused."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
        mock_db_session.scalars.return_value.all.return_value = []
        mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused")
        # Act
        _duplicate_document_indexing_task(dataset_id, document_ids)
        # Assert
        # Should handle DocumentIsPausedError gracefully
        mock_db_session.close.assert_called_once()

    def test_duplicate_document_indexing_cleans_old_segments(
        self,
        mock_db_session,
        mock_indexing_runner,
        mock_feature_service,
        mock_index_processor_factory,
        mock_dataset,
        mock_documents,
        mock_document_segments,
        dataset_id,
        document_ids,
    ):
        """Test that duplicate document indexing cleans old segments."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
        mock_db_session.scalars.return_value.all.return_value = mock_document_segments
        mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
        # Act
        _duplicate_document_indexing_task(dataset_id, document_ids)
        # Assert
        # Verify clean was called for each document
        assert mock_processor.clean.call_count == len(mock_documents)
        # Verify segments were deleted
        for segment in mock_document_segments:
            mock_db_session.delete.assert_any_call(segment)
# ============================================================================
# Tests for tenant queue wrapper function
# ============================================================================
class TestDuplicateDocumentIndexingTaskWithTenantQueue:
    """Tests for _duplicate_document_indexing_task_with_tenant_queue function.

    The wrapper runs the core task, then drains the tenant-isolated queue:
    it either deletes the task key (queue empty) or re-dispatches the next
    queued task via the supplied Celery task function's ``.delay``.
    """

    @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task")
    def test_tenant_queue_wrapper_calls_core_function(
        self,
        mock_core_func,
        mock_tenant_isolated_queue,
        tenant_id,
        dataset_id,
        document_ids,
    ):
        """Test that tenant queue wrapper calls the core function."""
        # Arrange
        mock_task_func = Mock()
        # Act
        _duplicate_document_indexing_task_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task_func)
        # Assert: tenant_id is queue bookkeeping only; the core gets dataset/docs.
        mock_core_func.assert_called_once_with(dataset_id, document_ids)

    @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task")
    def test_tenant_queue_wrapper_deletes_key_when_no_tasks(
        self,
        mock_core_func,
        mock_tenant_isolated_queue,
        tenant_id,
        dataset_id,
        document_ids,
    ):
        """Test that tenant queue wrapper deletes task key when no more tasks."""
        # Arrange
        mock_task_func = Mock()
        mock_tenant_isolated_queue.pull_tasks.return_value = []
        # Act
        _duplicate_document_indexing_task_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task_func)
        # Assert
        mock_tenant_isolated_queue.delete_task_key.assert_called_once()

    @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task")
    def test_tenant_queue_wrapper_processes_next_tasks(
        self,
        mock_core_func,
        mock_tenant_isolated_queue,
        tenant_id,
        dataset_id,
        document_ids,
    ):
        """Test that tenant queue wrapper processes next tasks from queue."""
        # Arrange
        mock_task_func = Mock()
        next_task = {
            "tenant_id": tenant_id,
            "dataset_id": dataset_id,
            "document_ids": document_ids,
        }
        mock_tenant_isolated_queue.pull_tasks.return_value = [next_task]
        # Act
        _duplicate_document_indexing_task_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task_func)
        # Assert: next queued task is re-dispatched with keyword arguments.
        mock_tenant_isolated_queue.set_task_waiting_time.assert_called_once()
        mock_task_func.delay.assert_called_once_with(
            tenant_id=tenant_id,
            dataset_id=dataset_id,
            document_ids=document_ids,
        )

    @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task")
    def test_tenant_queue_wrapper_handles_core_function_error(
        self,
        mock_core_func,
        mock_tenant_isolated_queue,
        tenant_id,
        dataset_id,
        document_ids,
    ):
        """Test that tenant queue wrapper handles errors from core function."""
        # Arrange
        mock_task_func = Mock()
        mock_core_func.side_effect = Exception("Core function error")
        # Act
        _duplicate_document_indexing_task_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task_func)
        # Assert
        # Should still check for next tasks even after error
        mock_tenant_isolated_queue.pull_tasks.assert_called_once()
# ============================================================================
# Tests for normal_duplicate_document_indexing_task
# ============================================================================
class TestNormalDuplicateDocumentIndexingTask:
    """Tests for normal_duplicate_document_indexing_task function.

    The normal-priority task is expected to delegate to the tenant-queue
    wrapper, passing itself as the re-dispatch task function.
    """

    @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue")
    def test_normal_task_calls_tenant_queue_wrapper(
        self,
        mock_wrapper_func,
        tenant_id,
        dataset_id,
        document_ids,
    ):
        """Test that normal task calls tenant queue wrapper."""
        # Act
        normal_duplicate_document_indexing_task(tenant_id, dataset_id, document_ids)
        # Assert: the task passes itself so follow-up tasks stay on this queue.
        mock_wrapper_func.assert_called_once_with(
            tenant_id, dataset_id, document_ids, normal_duplicate_document_indexing_task
        )

    @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue")
    def test_normal_task_with_empty_document_ids(
        self,
        mock_wrapper_func,
        tenant_id,
        dataset_id,
    ):
        """Test normal task with empty document_ids list."""
        # Arrange
        document_ids = []
        # Act
        normal_duplicate_document_indexing_task(tenant_id, dataset_id, document_ids)
        # Assert
        mock_wrapper_func.assert_called_once_with(
            tenant_id, dataset_id, document_ids, normal_duplicate_document_indexing_task
        )
# ============================================================================
# Tests for priority_duplicate_document_indexing_task
# ============================================================================
class TestPriorityDuplicateDocumentIndexingTask:
    """Tests for priority_duplicate_document_indexing_task function.

    Mirrors the normal-task tests: the priority task must delegate to the
    tenant-queue wrapper with itself as the re-dispatch task function,
    regardless of batch size.
    """

    @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue")
    def test_priority_task_calls_tenant_queue_wrapper(
        self,
        mock_wrapper_func,
        tenant_id,
        dataset_id,
        document_ids,
    ):
        """Test that priority task calls tenant queue wrapper."""
        # Act
        priority_duplicate_document_indexing_task(tenant_id, dataset_id, document_ids)
        # Assert
        mock_wrapper_func.assert_called_once_with(
            tenant_id, dataset_id, document_ids, priority_duplicate_document_indexing_task
        )

    @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue")
    def test_priority_task_with_single_document(
        self,
        mock_wrapper_func,
        tenant_id,
        dataset_id,
    ):
        """Test priority task with single document."""
        # Arrange
        document_ids = ["doc-1"]
        # Act
        priority_duplicate_document_indexing_task(tenant_id, dataset_id, document_ids)
        # Assert
        mock_wrapper_func.assert_called_once_with(
            tenant_id, dataset_id, document_ids, priority_duplicate_document_indexing_task
        )

    @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue")
    def test_priority_task_with_large_batch(
        self,
        mock_wrapper_func,
        tenant_id,
        dataset_id,
    ):
        """Test priority task with large batch of documents."""
        # Arrange: 100 documents — the list must be forwarded intact.
        document_ids = [f"doc-{i}" for i in range(100)]
        # Act
        priority_duplicate_document_indexing_task(tenant_id, dataset_id, document_ids)
        # Assert
        mock_wrapper_func.assert_called_once_with(
            tenant_id, dataset_id, document_ids, priority_duplicate_document_indexing_task
        )

View File

@ -8,10 +8,13 @@ from core.tools.utils.text_processing_utils import remove_leading_symbols
[
("...Hello, World!", "Hello, World!"),
("。测试中文标点", "测试中文标点"),
("!@#Test symbols", "Test symbols"),
# Note: ! is not in the removal pattern, only @# are removed, leaving "!Test symbols"
# The pattern intentionally excludes ! as per #11868 fix
("@#Test symbols", "Test symbols"),
("Hello, World!", "Hello, World!"),
("", ""),
(" ", " "),
("【测试】", "【测试】"),
],
)
def test_remove_leading_symbols(input_text, expected_output):