mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 01:18:05 +08:00
Merge branch 'main' into feat/agent-node-v2
This commit is contained in:
@ -26,16 +26,29 @@ redis_mock.hgetall = MagicMock(return_value={})
|
||||
redis_mock.hdel = MagicMock()
|
||||
redis_mock.incr = MagicMock(return_value=1)
|
||||
|
||||
# Ensure OpenDAL fs writes to tmp to avoid polluting workspace
|
||||
os.environ.setdefault("OPENDAL_SCHEME", "fs")
|
||||
os.environ.setdefault("OPENDAL_FS_ROOT", "/tmp/dify-storage")
|
||||
os.environ.setdefault("STORAGE_TYPE", "opendal")
|
||||
|
||||
# Add the API directory to Python path to ensure proper imports
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, PROJECT_DIR)
|
||||
|
||||
# apply the mock to the Redis client in the Flask app
|
||||
from extensions import ext_redis
|
||||
|
||||
redis_patcher = patch.object(ext_redis, "redis_client", redis_mock)
|
||||
redis_patcher.start()
|
||||
|
||||
def _patch_redis_clients_on_loaded_modules():
|
||||
"""Ensure any module-level redis_client references point to the shared redis_mock."""
|
||||
|
||||
import sys
|
||||
|
||||
for module in list(sys.modules.values()):
|
||||
if module is None:
|
||||
continue
|
||||
if hasattr(module, "redis_client"):
|
||||
module.redis_client = redis_mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -49,6 +62,15 @@ def _provide_app_context(app: Flask):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _patch_redis_clients():
|
||||
"""Patch redis_client to MagicMock only for unit test executions."""
|
||||
|
||||
with patch.object(ext_redis, "redis_client", redis_mock):
|
||||
_patch_redis_clients_on_loaded_modules()
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_redis_mock():
|
||||
"""reset the Redis mock before each test"""
|
||||
@ -63,3 +85,20 @@ def reset_redis_mock():
|
||||
redis_mock.hgetall.return_value = {}
|
||||
redis_mock.hdel.return_value = None
|
||||
redis_mock.incr.return_value = 1
|
||||
|
||||
# Keep any imported modules pointing at the mock between tests
|
||||
_patch_redis_clients_on_loaded_modules()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_secret_key():
|
||||
"""Ensure SECRET_KEY-dependent logic sees an empty config value by default."""
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
original = dify_config.SECRET_KEY
|
||||
dify_config.SECRET_KEY = ""
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
dify_config.SECRET_KEY = original
|
||||
|
||||
@ -0,0 +1,344 @@
|
||||
"""
|
||||
Unit tests for annotation import security features.
|
||||
|
||||
Tests rate limiting, concurrency control, file validation, and other
|
||||
security features added to prevent DoS attacks on the annotation import endpoint.
|
||||
"""
|
||||
|
||||
import io
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.datastructures import FileStorage
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
|
||||
class TestAnnotationImportRateLimiting:
|
||||
"""Test rate limiting for annotation import operations."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis(self):
|
||||
"""Mock Redis client for testing."""
|
||||
with patch("controllers.console.wraps.redis_client") as mock:
|
||||
yield mock
|
||||
|
||||
@pytest.fixture
|
||||
def mock_current_account(self):
|
||||
"""Mock current account with tenant."""
|
||||
with patch("controllers.console.wraps.current_account_with_tenant") as mock:
|
||||
mock.return_value = (MagicMock(id="user_id"), "test_tenant_id")
|
||||
yield mock
|
||||
|
||||
def test_rate_limit_per_minute_enforced(self, mock_redis, mock_current_account):
|
||||
"""Test that per-minute rate limit is enforced."""
|
||||
from controllers.console.wraps import annotation_import_rate_limit
|
||||
|
||||
# Simulate exceeding per-minute limit
|
||||
mock_redis.zcard.side_effect = [
|
||||
dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE + 1, # Minute check
|
||||
10, # Hour check
|
||||
]
|
||||
|
||||
@annotation_import_rate_limit
|
||||
def dummy_view():
|
||||
return "success"
|
||||
|
||||
# Should abort with 429
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
dummy_view()
|
||||
|
||||
# Verify it's a rate limit error
|
||||
assert "429" in str(exc_info.value) or "Too many" in str(exc_info.value)
|
||||
|
||||
def test_rate_limit_per_hour_enforced(self, mock_redis, mock_current_account):
|
||||
"""Test that per-hour rate limit is enforced."""
|
||||
from controllers.console.wraps import annotation_import_rate_limit
|
||||
|
||||
# Simulate exceeding per-hour limit
|
||||
mock_redis.zcard.side_effect = [
|
||||
3, # Minute check (under limit)
|
||||
dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR + 1, # Hour check (over limit)
|
||||
]
|
||||
|
||||
@annotation_import_rate_limit
|
||||
def dummy_view():
|
||||
return "success"
|
||||
|
||||
# Should abort with 429
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
dummy_view()
|
||||
|
||||
assert "429" in str(exc_info.value) or "Too many" in str(exc_info.value)
|
||||
|
||||
def test_rate_limit_within_limits_passes(self, mock_redis, mock_current_account):
|
||||
"""Test that requests within limits are allowed."""
|
||||
from controllers.console.wraps import annotation_import_rate_limit
|
||||
|
||||
# Simulate being under both limits
|
||||
mock_redis.zcard.return_value = 2
|
||||
|
||||
@annotation_import_rate_limit
|
||||
def dummy_view():
|
||||
return "success"
|
||||
|
||||
# Should succeed
|
||||
result = dummy_view()
|
||||
assert result == "success"
|
||||
|
||||
# Verify Redis operations were called
|
||||
assert mock_redis.zadd.called
|
||||
assert mock_redis.zremrangebyscore.called
|
||||
|
||||
|
||||
class TestAnnotationImportConcurrencyControl:
|
||||
"""Test concurrency control for annotation import operations."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis(self):
|
||||
"""Mock Redis client for testing."""
|
||||
with patch("controllers.console.wraps.redis_client") as mock:
|
||||
yield mock
|
||||
|
||||
@pytest.fixture
|
||||
def mock_current_account(self):
|
||||
"""Mock current account with tenant."""
|
||||
with patch("controllers.console.wraps.current_account_with_tenant") as mock:
|
||||
mock.return_value = (MagicMock(id="user_id"), "test_tenant_id")
|
||||
yield mock
|
||||
|
||||
def test_concurrency_limit_enforced(self, mock_redis, mock_current_account):
|
||||
"""Test that concurrent task limit is enforced."""
|
||||
from controllers.console.wraps import annotation_import_concurrency_limit
|
||||
|
||||
# Simulate max concurrent tasks already running
|
||||
mock_redis.zcard.return_value = dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT
|
||||
|
||||
@annotation_import_concurrency_limit
|
||||
def dummy_view():
|
||||
return "success"
|
||||
|
||||
# Should abort with 429
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
dummy_view()
|
||||
|
||||
assert "429" in str(exc_info.value) or "concurrent" in str(exc_info.value).lower()
|
||||
|
||||
def test_concurrency_within_limit_passes(self, mock_redis, mock_current_account):
|
||||
"""Test that requests within concurrency limits are allowed."""
|
||||
from controllers.console.wraps import annotation_import_concurrency_limit
|
||||
|
||||
# Simulate being under concurrent task limit
|
||||
mock_redis.zcard.return_value = 1
|
||||
|
||||
@annotation_import_concurrency_limit
|
||||
def dummy_view():
|
||||
return "success"
|
||||
|
||||
# Should succeed
|
||||
result = dummy_view()
|
||||
assert result == "success"
|
||||
|
||||
def test_stale_jobs_are_cleaned_up(self, mock_redis, mock_current_account):
|
||||
"""Test that old/stale job entries are removed."""
|
||||
from controllers.console.wraps import annotation_import_concurrency_limit
|
||||
|
||||
mock_redis.zcard.return_value = 0
|
||||
|
||||
@annotation_import_concurrency_limit
|
||||
def dummy_view():
|
||||
return "success"
|
||||
|
||||
dummy_view()
|
||||
|
||||
# Verify cleanup was called
|
||||
assert mock_redis.zremrangebyscore.called
|
||||
|
||||
|
||||
class TestAnnotationImportFileValidation:
|
||||
"""Test file validation in annotation import."""
|
||||
|
||||
def test_file_size_limit_enforced(self):
|
||||
"""Test that files exceeding size limit are rejected."""
|
||||
# Create a file larger than the limit
|
||||
max_size = dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT * 1024 * 1024
|
||||
large_content = b"x" * (max_size + 1024) # Exceed by 1KB
|
||||
|
||||
file = FileStorage(stream=io.BytesIO(large_content), filename="test.csv", content_type="text/csv")
|
||||
|
||||
# Should be rejected in controller
|
||||
# This would be tested in integration tests with actual endpoint
|
||||
|
||||
def test_empty_file_rejected(self):
|
||||
"""Test that empty files are rejected."""
|
||||
file = FileStorage(stream=io.BytesIO(b""), filename="test.csv", content_type="text/csv")
|
||||
|
||||
# Should be rejected
|
||||
# This would be tested in integration tests
|
||||
|
||||
def test_non_csv_file_rejected(self):
|
||||
"""Test that non-CSV files are rejected."""
|
||||
file = FileStorage(stream=io.BytesIO(b"test"), filename="test.txt", content_type="text/plain")
|
||||
|
||||
# Should be rejected based on extension
|
||||
# This would be tested in integration tests
|
||||
|
||||
|
||||
class TestAnnotationImportServiceValidation:
|
||||
"""Test service layer validation for annotation import."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app(self):
|
||||
"""Mock application object."""
|
||||
app = MagicMock()
|
||||
app.id = "app_id"
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
"""Mock database session."""
|
||||
with patch("services.annotation_service.db.session") as mock:
|
||||
yield mock
|
||||
|
||||
def test_max_records_limit_enforced(self, mock_app, mock_db_session):
|
||||
"""Test that files with too many records are rejected."""
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
# Create CSV with too many records
|
||||
max_records = dify_config.ANNOTATION_IMPORT_MAX_RECORDS
|
||||
csv_content = "question,answer\n"
|
||||
for i in range(max_records + 100):
|
||||
csv_content += f"Question {i},Answer {i}\n"
|
||||
|
||||
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
|
||||
|
||||
with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
|
||||
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")
|
||||
|
||||
with patch("services.annotation_service.FeatureService") as mock_features:
|
||||
mock_features.get_features.return_value.billing.enabled = False
|
||||
|
||||
result = AppAnnotationService.batch_import_app_annotations("app_id", file)
|
||||
|
||||
# Should return error about too many records
|
||||
assert "error_msg" in result
|
||||
assert "too many" in result["error_msg"].lower() or "maximum" in result["error_msg"].lower()
|
||||
|
||||
def test_min_records_limit_enforced(self, mock_app, mock_db_session):
|
||||
"""Test that files with too few valid records are rejected."""
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
# Create CSV with only header (no data rows)
|
||||
csv_content = "question,answer\n"
|
||||
|
||||
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
|
||||
|
||||
with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
|
||||
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")
|
||||
|
||||
result = AppAnnotationService.batch_import_app_annotations("app_id", file)
|
||||
|
||||
# Should return error about insufficient records
|
||||
assert "error_msg" in result
|
||||
assert "at least" in result["error_msg"].lower() or "minimum" in result["error_msg"].lower()
|
||||
|
||||
def test_invalid_csv_format_handled(self, mock_app, mock_db_session):
|
||||
"""Test that invalid CSV format is handled gracefully."""
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
# Create invalid CSV content
|
||||
csv_content = 'invalid,csv,format\nwith,unbalanced,quotes,and"stuff'
|
||||
|
||||
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
|
||||
|
||||
with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
|
||||
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")
|
||||
|
||||
result = AppAnnotationService.batch_import_app_annotations("app_id", file)
|
||||
|
||||
# Should return error message
|
||||
assert "error_msg" in result
|
||||
|
||||
def test_valid_import_succeeds(self, mock_app, mock_db_session):
|
||||
"""Test that valid import request succeeds."""
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
# Create valid CSV
|
||||
csv_content = "question,answer\nWhat is AI?,Artificial Intelligence\nWhat is ML?,Machine Learning\n"
|
||||
|
||||
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
|
||||
|
||||
with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
|
||||
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")
|
||||
|
||||
with patch("services.annotation_service.FeatureService") as mock_features:
|
||||
mock_features.get_features.return_value.billing.enabled = False
|
||||
|
||||
with patch("services.annotation_service.batch_import_annotations_task") as mock_task:
|
||||
with patch("services.annotation_service.redis_client"):
|
||||
result = AppAnnotationService.batch_import_app_annotations("app_id", file)
|
||||
|
||||
# Should return success response
|
||||
assert "job_id" in result
|
||||
assert "job_status" in result
|
||||
assert result["job_status"] == "waiting"
|
||||
assert "record_count" in result
|
||||
assert result["record_count"] == 2
|
||||
|
||||
|
||||
class TestAnnotationImportTaskOptimization:
|
||||
"""Test optimizations in batch import task."""
|
||||
|
||||
def test_task_has_timeout_configured(self):
|
||||
"""Test that task has proper timeout configuration."""
|
||||
from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task
|
||||
|
||||
# Verify task configuration
|
||||
assert hasattr(batch_import_annotations_task, "time_limit")
|
||||
assert hasattr(batch_import_annotations_task, "soft_time_limit")
|
||||
|
||||
# Check timeout values are reasonable
|
||||
# Hard limit should be 6 minutes (360s)
|
||||
# Soft limit should be 5 minutes (300s)
|
||||
# Note: actual values depend on Celery configuration
|
||||
|
||||
|
||||
class TestConfigurationValues:
|
||||
"""Test that security configuration values are properly set."""
|
||||
|
||||
def test_rate_limit_configs_exist(self):
|
||||
"""Test that rate limit configurations are defined."""
|
||||
assert hasattr(dify_config, "ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE")
|
||||
assert hasattr(dify_config, "ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR")
|
||||
|
||||
assert dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE > 0
|
||||
assert dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR > 0
|
||||
|
||||
def test_file_size_limit_config_exists(self):
|
||||
"""Test that file size limit configuration is defined."""
|
||||
assert hasattr(dify_config, "ANNOTATION_IMPORT_FILE_SIZE_LIMIT")
|
||||
assert dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT > 0
|
||||
assert dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT <= 10 # Reasonable max (10MB)
|
||||
|
||||
def test_record_limit_configs_exist(self):
|
||||
"""Test that record limit configurations are defined."""
|
||||
assert hasattr(dify_config, "ANNOTATION_IMPORT_MAX_RECORDS")
|
||||
assert hasattr(dify_config, "ANNOTATION_IMPORT_MIN_RECORDS")
|
||||
|
||||
assert dify_config.ANNOTATION_IMPORT_MAX_RECORDS > 0
|
||||
assert dify_config.ANNOTATION_IMPORT_MIN_RECORDS > 0
|
||||
assert dify_config.ANNOTATION_IMPORT_MIN_RECORDS < dify_config.ANNOTATION_IMPORT_MAX_RECORDS
|
||||
|
||||
def test_concurrency_limit_config_exists(self):
|
||||
"""Test that concurrency limit configuration is defined."""
|
||||
assert hasattr(dify_config, "ANNOTATION_IMPORT_MAX_CONCURRENT")
|
||||
assert dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT > 0
|
||||
assert dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT <= 10 # Reasonable upper bound
|
||||
@ -125,7 +125,7 @@ class TestPartnerTenants:
|
||||
resource = PartnerTenants()
|
||||
|
||||
# Act & Assert
|
||||
# reqparse will raise BadRequest for missing required field
|
||||
# Validation should raise BadRequest for missing required field
|
||||
with pytest.raises(BadRequest):
|
||||
resource.put(partner_key_encoded)
|
||||
|
||||
|
||||
407
api/tests/unit_tests/controllers/console/test_admin.py
Normal file
407
api/tests/unit_tests/controllers/console/test_admin.py
Normal file
@ -0,0 +1,407 @@
|
||||
"""Final working unit tests for admin endpoints - tests business logic directly."""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from controllers.console.admin import InsertExploreAppPayload
|
||||
from models.model import App, RecommendedApp
|
||||
|
||||
|
||||
class TestInsertExploreAppPayload:
|
||||
"""Test InsertExploreAppPayload validation."""
|
||||
|
||||
def test_valid_payload(self):
|
||||
"""Test creating payload with valid data."""
|
||||
payload_data = {
|
||||
"app_id": str(uuid.uuid4()),
|
||||
"desc": "Test app description",
|
||||
"copyright": "© 2024 Test Company",
|
||||
"privacy_policy": "https://example.com/privacy",
|
||||
"custom_disclaimer": "Custom disclaimer text",
|
||||
"language": "en-US",
|
||||
"category": "Productivity",
|
||||
"position": 1,
|
||||
}
|
||||
|
||||
payload = InsertExploreAppPayload.model_validate(payload_data)
|
||||
|
||||
assert payload.app_id == payload_data["app_id"]
|
||||
assert payload.desc == payload_data["desc"]
|
||||
assert payload.copyright == payload_data["copyright"]
|
||||
assert payload.privacy_policy == payload_data["privacy_policy"]
|
||||
assert payload.custom_disclaimer == payload_data["custom_disclaimer"]
|
||||
assert payload.language == payload_data["language"]
|
||||
assert payload.category == payload_data["category"]
|
||||
assert payload.position == payload_data["position"]
|
||||
|
||||
def test_minimal_payload(self):
|
||||
"""Test creating payload with only required fields."""
|
||||
payload_data = {
|
||||
"app_id": str(uuid.uuid4()),
|
||||
"language": "en-US",
|
||||
"category": "Productivity",
|
||||
"position": 1,
|
||||
}
|
||||
|
||||
payload = InsertExploreAppPayload.model_validate(payload_data)
|
||||
|
||||
assert payload.app_id == payload_data["app_id"]
|
||||
assert payload.desc is None
|
||||
assert payload.copyright is None
|
||||
assert payload.privacy_policy is None
|
||||
assert payload.custom_disclaimer is None
|
||||
assert payload.language == payload_data["language"]
|
||||
assert payload.category == payload_data["category"]
|
||||
assert payload.position == payload_data["position"]
|
||||
|
||||
def test_invalid_language(self):
|
||||
"""Test payload with invalid language code."""
|
||||
payload_data = {
|
||||
"app_id": str(uuid.uuid4()),
|
||||
"language": "invalid-lang",
|
||||
"category": "Productivity",
|
||||
"position": 1,
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="invalid-lang is not a valid language"):
|
||||
InsertExploreAppPayload.model_validate(payload_data)
|
||||
|
||||
|
||||
class TestAdminRequiredDecorator:
|
||||
"""Test admin_required decorator."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures."""
|
||||
# Mock dify_config
|
||||
self.dify_config_patcher = patch("controllers.console.admin.dify_config")
|
||||
self.mock_dify_config = self.dify_config_patcher.start()
|
||||
self.mock_dify_config.ADMIN_API_KEY = "test-admin-key"
|
||||
|
||||
# Mock extract_access_token
|
||||
self.token_patcher = patch("controllers.console.admin.extract_access_token")
|
||||
self.mock_extract_token = self.token_patcher.start()
|
||||
|
||||
def teardown_method(self):
|
||||
"""Clean up test fixtures."""
|
||||
self.dify_config_patcher.stop()
|
||||
self.token_patcher.stop()
|
||||
|
||||
def test_admin_required_success(self):
|
||||
"""Test successful admin authentication."""
|
||||
from controllers.console.admin import admin_required
|
||||
|
||||
@admin_required
|
||||
def test_view():
|
||||
return {"success": True}
|
||||
|
||||
self.mock_extract_token.return_value = "test-admin-key"
|
||||
result = test_view()
|
||||
assert result["success"] is True
|
||||
|
||||
def test_admin_required_invalid_token(self):
|
||||
"""Test admin_required with invalid token."""
|
||||
from controllers.console.admin import admin_required
|
||||
|
||||
@admin_required
|
||||
def test_view():
|
||||
return {"success": True}
|
||||
|
||||
self.mock_extract_token.return_value = "wrong-key"
|
||||
with pytest.raises(Unauthorized, match="API key is invalid"):
|
||||
test_view()
|
||||
|
||||
def test_admin_required_no_api_key_configured(self):
|
||||
"""Test admin_required when no API key is configured."""
|
||||
from controllers.console.admin import admin_required
|
||||
|
||||
self.mock_dify_config.ADMIN_API_KEY = None
|
||||
|
||||
@admin_required
|
||||
def test_view():
|
||||
return {"success": True}
|
||||
|
||||
with pytest.raises(Unauthorized, match="API key is invalid"):
|
||||
test_view()
|
||||
|
||||
def test_admin_required_missing_authorization_header(self):
|
||||
"""Test admin_required with missing authorization header."""
|
||||
from controllers.console.admin import admin_required
|
||||
|
||||
@admin_required
|
||||
def test_view():
|
||||
return {"success": True}
|
||||
|
||||
self.mock_extract_token.return_value = None
|
||||
with pytest.raises(Unauthorized, match="Authorization header is missing"):
|
||||
test_view()
|
||||
|
||||
|
||||
class TestExploreAppBusinessLogicDirect:
|
||||
"""Test the core business logic of explore app management directly."""
|
||||
|
||||
def test_data_fusion_logic(self):
|
||||
"""Test the data fusion logic between payload and site data."""
|
||||
# Test cases for different data scenarios
|
||||
test_cases = [
|
||||
{
|
||||
"name": "site_data_overrides_payload",
|
||||
"payload": {"desc": "Payload desc", "copyright": "Payload copyright"},
|
||||
"site": {"description": "Site desc", "copyright": "Site copyright"},
|
||||
"expected": {
|
||||
"desc": "Site desc",
|
||||
"copyright": "Site copyright",
|
||||
"privacy_policy": "",
|
||||
"custom_disclaimer": "",
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "payload_used_when_no_site",
|
||||
"payload": {"desc": "Payload desc", "copyright": "Payload copyright"},
|
||||
"site": None,
|
||||
"expected": {
|
||||
"desc": "Payload desc",
|
||||
"copyright": "Payload copyright",
|
||||
"privacy_policy": "",
|
||||
"custom_disclaimer": "",
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "empty_defaults_when_no_data",
|
||||
"payload": {},
|
||||
"site": None,
|
||||
"expected": {"desc": "", "copyright": "", "privacy_policy": "", "custom_disclaimer": ""},
|
||||
},
|
||||
]
|
||||
|
||||
for case in test_cases:
|
||||
# Simulate the data fusion logic
|
||||
payload_desc = case["payload"].get("desc")
|
||||
payload_copyright = case["payload"].get("copyright")
|
||||
payload_privacy_policy = case["payload"].get("privacy_policy")
|
||||
payload_custom_disclaimer = case["payload"].get("custom_disclaimer")
|
||||
|
||||
if case["site"]:
|
||||
site_desc = case["site"].get("description")
|
||||
site_copyright = case["site"].get("copyright")
|
||||
site_privacy_policy = case["site"].get("privacy_policy")
|
||||
site_custom_disclaimer = case["site"].get("custom_disclaimer")
|
||||
|
||||
# Site data takes precedence
|
||||
desc = site_desc or payload_desc or ""
|
||||
copyright = site_copyright or payload_copyright or ""
|
||||
privacy_policy = site_privacy_policy or payload_privacy_policy or ""
|
||||
custom_disclaimer = site_custom_disclaimer or payload_custom_disclaimer or ""
|
||||
else:
|
||||
# Use payload data or empty defaults
|
||||
desc = payload_desc or ""
|
||||
copyright = payload_copyright or ""
|
||||
privacy_policy = payload_privacy_policy or ""
|
||||
custom_disclaimer = payload_custom_disclaimer or ""
|
||||
|
||||
result = {
|
||||
"desc": desc,
|
||||
"copyright": copyright,
|
||||
"privacy_policy": privacy_policy,
|
||||
"custom_disclaimer": custom_disclaimer,
|
||||
}
|
||||
|
||||
assert result == case["expected"], f"Failed test case: {case['name']}"
|
||||
|
||||
def test_app_visibility_logic(self):
|
||||
"""Test that apps are made public when added to explore list."""
|
||||
# Create a mock app
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.is_public = False
|
||||
|
||||
# Simulate the business logic
|
||||
mock_app.is_public = True
|
||||
|
||||
assert mock_app.is_public is True
|
||||
|
||||
def test_recommended_app_creation_logic(self):
|
||||
"""Test the creation of RecommendedApp objects."""
|
||||
app_id = str(uuid.uuid4())
|
||||
payload_data = {
|
||||
"app_id": app_id,
|
||||
"desc": "Test app description",
|
||||
"copyright": "© 2024 Test Company",
|
||||
"privacy_policy": "https://example.com/privacy",
|
||||
"custom_disclaimer": "Custom disclaimer",
|
||||
"language": "en-US",
|
||||
"category": "Productivity",
|
||||
"position": 1,
|
||||
}
|
||||
|
||||
# Simulate the creation logic
|
||||
recommended_app = Mock(spec=RecommendedApp)
|
||||
recommended_app.app_id = payload_data["app_id"]
|
||||
recommended_app.description = payload_data["desc"]
|
||||
recommended_app.copyright = payload_data["copyright"]
|
||||
recommended_app.privacy_policy = payload_data["privacy_policy"]
|
||||
recommended_app.custom_disclaimer = payload_data["custom_disclaimer"]
|
||||
recommended_app.language = payload_data["language"]
|
||||
recommended_app.category = payload_data["category"]
|
||||
recommended_app.position = payload_data["position"]
|
||||
|
||||
# Verify the data
|
||||
assert recommended_app.app_id == app_id
|
||||
assert recommended_app.description == "Test app description"
|
||||
assert recommended_app.copyright == "© 2024 Test Company"
|
||||
assert recommended_app.privacy_policy == "https://example.com/privacy"
|
||||
assert recommended_app.custom_disclaimer == "Custom disclaimer"
|
||||
assert recommended_app.language == "en-US"
|
||||
assert recommended_app.category == "Productivity"
|
||||
assert recommended_app.position == 1
|
||||
|
||||
def test_recommended_app_update_logic(self):
|
||||
"""Test the update logic for existing RecommendedApp objects."""
|
||||
mock_recommended_app = Mock(spec=RecommendedApp)
|
||||
|
||||
update_data = {
|
||||
"desc": "Updated description",
|
||||
"copyright": "© 2024 Updated",
|
||||
"language": "fr-FR",
|
||||
"category": "Tools",
|
||||
"position": 2,
|
||||
}
|
||||
|
||||
# Simulate the update logic
|
||||
mock_recommended_app.description = update_data["desc"]
|
||||
mock_recommended_app.copyright = update_data["copyright"]
|
||||
mock_recommended_app.language = update_data["language"]
|
||||
mock_recommended_app.category = update_data["category"]
|
||||
mock_recommended_app.position = update_data["position"]
|
||||
|
||||
# Verify the updates
|
||||
assert mock_recommended_app.description == "Updated description"
|
||||
assert mock_recommended_app.copyright == "© 2024 Updated"
|
||||
assert mock_recommended_app.language == "fr-FR"
|
||||
assert mock_recommended_app.category == "Tools"
|
||||
assert mock_recommended_app.position == 2
|
||||
|
||||
def test_app_not_found_error_logic(self):
|
||||
"""Test error handling when app is not found."""
|
||||
app_id = str(uuid.uuid4())
|
||||
|
||||
# Simulate app lookup returning None
|
||||
found_app = None
|
||||
|
||||
# Test the error condition
|
||||
if not found_app:
|
||||
with pytest.raises(NotFound, match=f"App '{app_id}' is not found"):
|
||||
raise NotFound(f"App '{app_id}' is not found")
|
||||
|
||||
def test_recommended_app_not_found_error_logic(self):
|
||||
"""Test error handling when recommended app is not found for deletion."""
|
||||
app_id = str(uuid.uuid4())
|
||||
|
||||
# Simulate recommended app lookup returning None
|
||||
found_recommended_app = None
|
||||
|
||||
# Test the error condition
|
||||
if not found_recommended_app:
|
||||
with pytest.raises(NotFound, match=f"App '{app_id}' is not found in the explore list"):
|
||||
raise NotFound(f"App '{app_id}' is not found in the explore list")
|
||||
|
||||
def test_database_session_usage_patterns(self):
|
||||
"""Test the expected database session usage patterns."""
|
||||
# Mock session usage patterns
|
||||
mock_session = Mock()
|
||||
|
||||
# Test session.add pattern
|
||||
mock_recommended_app = Mock(spec=RecommendedApp)
|
||||
mock_session.add(mock_recommended_app)
|
||||
mock_session.commit()
|
||||
|
||||
# Verify session was used correctly
|
||||
mock_session.add.assert_called_once_with(mock_recommended_app)
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
# Test session.delete pattern
|
||||
mock_recommended_app_to_delete = Mock(spec=RecommendedApp)
|
||||
mock_session.delete(mock_recommended_app_to_delete)
|
||||
mock_session.commit()
|
||||
|
||||
# Verify delete pattern
|
||||
mock_session.delete.assert_called_once_with(mock_recommended_app_to_delete)
|
||||
|
||||
def test_payload_validation_integration(self):
|
||||
"""Test payload validation in the context of the business logic."""
|
||||
# Test valid payload
|
||||
valid_payload_data = {
|
||||
"app_id": str(uuid.uuid4()),
|
||||
"desc": "Test app description",
|
||||
"language": "en-US",
|
||||
"category": "Productivity",
|
||||
"position": 1,
|
||||
}
|
||||
|
||||
# This should succeed
|
||||
payload = InsertExploreAppPayload.model_validate(valid_payload_data)
|
||||
assert payload.app_id == valid_payload_data["app_id"]
|
||||
|
||||
# Test invalid payload
|
||||
invalid_payload_data = {
|
||||
"app_id": str(uuid.uuid4()),
|
||||
"language": "invalid-lang", # This should fail validation
|
||||
"category": "Productivity",
|
||||
"position": 1,
|
||||
}
|
||||
|
||||
# This should raise an exception
|
||||
with pytest.raises(ValueError, match="invalid-lang is not a valid language"):
|
||||
InsertExploreAppPayload.model_validate(invalid_payload_data)
|
||||
|
||||
|
||||
class TestExploreAppDataHandling:
|
||||
"""Test specific data handling scenarios."""
|
||||
|
||||
def test_uuid_validation(self):
|
||||
"""Test UUID validation and handling."""
|
||||
# Test valid UUID
|
||||
valid_uuid = str(uuid.uuid4())
|
||||
|
||||
# This should be a valid UUID
|
||||
assert uuid.UUID(valid_uuid) is not None
|
||||
|
||||
# Test invalid UUID
|
||||
invalid_uuid = "not-a-valid-uuid"
|
||||
|
||||
# This should raise a ValueError
|
||||
with pytest.raises(ValueError):
|
||||
uuid.UUID(invalid_uuid)
|
||||
|
||||
def test_language_validation(self):
|
||||
"""Test language validation against supported languages."""
|
||||
from constants.languages import supported_language
|
||||
|
||||
# Test supported language
|
||||
assert supported_language("en-US") == "en-US"
|
||||
assert supported_language("fr-FR") == "fr-FR"
|
||||
|
||||
# Test unsupported language
|
||||
with pytest.raises(ValueError, match="invalid-lang is not a valid language"):
|
||||
supported_language("invalid-lang")
|
||||
|
||||
def test_response_formatting(self):
|
||||
"""Test API response formatting."""
|
||||
# Test success responses
|
||||
create_response = {"result": "success"}
|
||||
update_response = {"result": "success"}
|
||||
delete_response = None # 204 No Content returns None
|
||||
|
||||
assert create_response["result"] == "success"
|
||||
assert update_response["result"] == "success"
|
||||
assert delete_response is None
|
||||
|
||||
# Test status codes
|
||||
create_status = 201 # Created
|
||||
update_status = 200 # OK
|
||||
delete_status = 204 # No Content
|
||||
|
||||
assert create_status == 201
|
||||
assert update_status == 200
|
||||
assert delete_status == 204
|
||||
@ -0,0 +1,25 @@
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.service_api.app.completion import ChatRequestPayload
|
||||
|
||||
|
||||
def test_chat_request_payload_accepts_blank_conversation_id():
    """An empty-string conversation_id is normalized to None by the payload model."""
    data = {"inputs": {}, "query": "hello", "conversation_id": ""}
    parsed = ChatRequestPayload.model_validate(data)
    assert parsed.conversation_id is None
|
||||
|
||||
|
||||
def test_chat_request_payload_validates_uuid():
    """A well-formed UUID conversation_id is preserved verbatim."""
    cid = str(uuid.uuid4())
    parsed = ChatRequestPayload.model_validate({"inputs": {}, "query": "hello", "conversation_id": cid})
    assert parsed.conversation_id == cid
|
||||
|
||||
|
||||
def test_chat_request_payload_rejects_invalid_uuid():
    """A conversation_id that is not a UUID fails validation."""
    data = {"inputs": {}, "query": "hello", "conversation_id": "invalid"}
    with pytest.raises(ValidationError):
        ChatRequestPayload.model_validate(data)
|
||||
@ -256,24 +256,18 @@ class TestFilePreviewApi:
|
||||
mock_app, # App query for tenant validation
|
||||
]
|
||||
|
||||
with patch("controllers.service_api.app.file_preview.reqparse") as mock_reqparse:
|
||||
# Mock request parsing
|
||||
mock_parser = Mock()
|
||||
mock_parser.parse_args.return_value = {"as_attachment": False}
|
||||
mock_reqparse.RequestParser.return_value = mock_parser
|
||||
# Test the core logic directly without Flask decorators
|
||||
# Validate file ownership
|
||||
result_message_file, result_upload_file = file_preview_api._validate_file_ownership(file_id, app_id)
|
||||
assert result_message_file == mock_message_file
|
||||
assert result_upload_file == mock_upload_file
|
||||
|
||||
# Test the core logic directly without Flask decorators
|
||||
# Validate file ownership
|
||||
result_message_file, result_upload_file = file_preview_api._validate_file_ownership(file_id, app_id)
|
||||
assert result_message_file == mock_message_file
|
||||
assert result_upload_file == mock_upload_file
|
||||
# Test file response building
|
||||
response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False)
|
||||
assert response is not None
|
||||
|
||||
# Test file response building
|
||||
response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False)
|
||||
assert response is not None
|
||||
|
||||
# Verify storage was called correctly
|
||||
mock_storage.load.assert_not_called() # Since we're testing components separately
|
||||
# Verify storage was called correctly
|
||||
mock_storage.load.assert_not_called() # Since we're testing components separately
|
||||
|
||||
@patch("controllers.service_api.app.file_preview.storage")
|
||||
def test_storage_error_handling(
|
||||
|
||||
@ -0,0 +1,20 @@
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.console.explore.conversation import ConversationRenamePayload as ConsolePayload
|
||||
from controllers.service_api.app.conversation import ConversationRenamePayload as ServicePayload
|
||||
|
||||
|
||||
@pytest.mark.parametrize("payload_cls", [ConsolePayload, ServicePayload])
def test_payload_allows_auto_generate_without_name(payload_cls):
    """When auto_generate is set, the name field is optional for both payload variants."""
    parsed = payload_cls.model_validate({"auto_generate": True})

    assert parsed.auto_generate is True
    assert parsed.name is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("payload_cls", [ConsolePayload, ServicePayload])
@pytest.mark.parametrize("value", [None, "", " "])
def test_payload_requires_name_when_not_auto_generate(payload_cls, value):
    """Without auto-generation, a missing or blank name must fail validation."""
    with pytest.raises(ValidationError):
        payload_cls.model_validate({"name": value, "auto_generate": False})
|
||||
129
api/tests/unit_tests/core/helper/test_tool_provider_cache.py
Normal file
129
api/tests/unit_tests/core/helper/test_tool_provider_cache.py
Normal file
@ -0,0 +1,129 @@
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from core.helper.tool_provider_cache import ToolProviderListCache
|
||||
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
|
||||
|
||||
|
||||
@pytest.fixture
def mock_redis_client():
    """Yield the redis client of core.helper.tool_provider_cache patched to a mock."""
    with patch("core.helper.tool_provider_cache.redis_client") as patched:
        yield patched
|
||||
|
||||
|
||||
class TestToolProviderListCache:
    """Unit tests covering ToolProviderListCache key generation, reads, writes,
    invalidation, and the redis_fallback error handling."""

    def test_generate_cache_key(self):
        """Keys embed tenant and type; omitting the type falls back to "all"."""
        tenant = "tenant_123"
        kind: ToolProviderTypeApiLiteral = "builtin"

        # Scenario 1: explicit type literal.
        assert (
            ToolProviderListCache._generate_cache_key(tenant, kind)
            == f"tool_providers:tenant_id:{tenant}:type:{kind}"
        )

        # Scenario 2: no type -> "all" bucket.
        assert ToolProviderListCache._generate_cache_key(tenant) == f"tool_providers:tenant_id:{tenant}:type:all"

    def test_get_cached_providers_hit(self, mock_redis_client):
        """A cache hit decodes the stored JSON bytes back into provider dicts."""
        tenant = "tenant_123"
        kind: ToolProviderTypeApiLiteral = "api"
        providers = [{"id": "tool", "name": "test_provider"}]
        mock_redis_client.get.return_value = json.dumps(providers).encode("utf-8")

        fetched = ToolProviderListCache.get_cached_providers(tenant, kind)

        mock_redis_client.get.assert_called_once_with(ToolProviderListCache._generate_cache_key(tenant, kind))
        assert fetched == providers

    def test_get_cached_providers_decode_error(self, mock_redis_client):
        """Corrupt cached bytes are treated as a miss rather than an error."""
        mock_redis_client.get.return_value = b"invalid_json_data"

        assert ToolProviderListCache.get_cached_providers("tenant_123") is None
        mock_redis_client.get.assert_called_once()

    def test_get_cached_providers_miss(self, mock_redis_client):
        """An absent key yields None."""
        mock_redis_client.get.return_value = None

        assert ToolProviderListCache.get_cached_providers("tenant_123") is None
        mock_redis_client.get.assert_called_once()

    def test_set_cached_providers(self, mock_redis_client):
        """Writing providers serializes them as JSON and applies the class TTL."""
        tenant = "tenant_123"
        kind: ToolProviderTypeApiLiteral = "builtin"
        providers = [{"id": "tool", "name": "test_provider"}]

        ToolProviderListCache.set_cached_providers(tenant, kind, providers)

        mock_redis_client.setex.assert_called_once_with(
            ToolProviderListCache._generate_cache_key(tenant, kind),
            ToolProviderListCache.CACHE_TTL,
            json.dumps(providers),
        )

    def test_invalidate_cache_specific_type(self, mock_redis_client):
        """Invalidating one type deletes exactly that key."""
        tenant = "tenant_123"
        kind: ToolProviderTypeApiLiteral = "workflow"

        ToolProviderListCache.invalidate_cache(tenant, kind)

        mock_redis_client.delete.assert_called_once_with(ToolProviderListCache._generate_cache_key(tenant, kind))

    def test_invalidate_cache_all_types(self, mock_redis_client):
        """Invalidating without a type scans all tenant keys and deletes them in one call."""
        tenant = "tenant_123"
        found_keys = [
            b"tool_providers:tenant_id:tenant_123:type:all",
            b"tool_providers:tenant_id:tenant_123:type:builtin",
        ]
        mock_redis_client.scan_iter.return_value = found_keys

        ToolProviderListCache.invalidate_cache(tenant)

        mock_redis_client.scan_iter.assert_called_once_with(f"tool_providers:tenant_id:{tenant}:*")
        mock_redis_client.delete.assert_called_once_with(*found_keys)

    def test_invalidate_cache_no_keys(self, mock_redis_client):
        """When the scan finds nothing, delete is never issued."""
        mock_redis_client.scan_iter.return_value = []

        ToolProviderListCache.invalidate_cache("tenant_123")

        mock_redis_client.delete.assert_not_called()

    def test_redis_fallback_default_return(self, mock_redis_client):
        """A Redis read error falls back to the decorator's default value (None)."""
        mock_redis_client.get.side_effect = RedisError("Redis connection error")

        assert ToolProviderListCache.get_cached_providers("tenant_123") is None
        mock_redis_client.get.assert_called_once()

    def test_redis_fallback_no_default(self, mock_redis_client):
        """A Redis write error is swallowed by the fallback decorator instead of propagating."""
        mock_redis_client.setex.side_effect = RedisError("Redis connection error")

        try:
            ToolProviderListCache.set_cached_providers("tenant_123", "mcp", [])
        except RedisError:
            pytest.fail("set_cached_providers should not raise RedisError (handled by fallback)")

        mock_redis_client.setex.assert_called_once()
|
||||
@ -53,7 +53,7 @@ from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from core.entities.embedding_type import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeConnectionError,
|
||||
@ -99,10 +99,10 @@ class TestCacheEmbeddingDocuments:
|
||||
|
||||
@pytest.fixture
|
||||
def sample_embedding_result(self):
|
||||
"""Create a sample TextEmbeddingResult for testing.
|
||||
"""Create a sample EmbeddingResult for testing.
|
||||
|
||||
Returns:
|
||||
TextEmbeddingResult: Mock embedding result with proper structure
|
||||
EmbeddingResult: Mock embedding result with proper structure
|
||||
"""
|
||||
# Create normalized embedding vectors (dimension 1536 for ada-002)
|
||||
embedding_vector = np.random.randn(1536)
|
||||
@ -118,7 +118,7 @@ class TestCacheEmbeddingDocuments:
|
||||
latency=0.5,
|
||||
)
|
||||
|
||||
return TextEmbeddingResult(
|
||||
return EmbeddingResult(
|
||||
model="text-embedding-ada-002",
|
||||
embeddings=[normalized_vector],
|
||||
usage=usage,
|
||||
@ -197,7 +197,7 @@ class TestCacheEmbeddingDocuments:
|
||||
latency=0.8,
|
||||
)
|
||||
|
||||
embedding_result = TextEmbeddingResult(
|
||||
embedding_result = EmbeddingResult(
|
||||
model="text-embedding-ada-002",
|
||||
embeddings=embeddings,
|
||||
usage=usage,
|
||||
@ -296,7 +296,7 @@ class TestCacheEmbeddingDocuments:
|
||||
latency=0.6,
|
||||
)
|
||||
|
||||
embedding_result = TextEmbeddingResult(
|
||||
embedding_result = EmbeddingResult(
|
||||
model="text-embedding-ada-002",
|
||||
embeddings=new_embeddings,
|
||||
usage=usage,
|
||||
@ -386,7 +386,7 @@ class TestCacheEmbeddingDocuments:
|
||||
latency=0.5,
|
||||
)
|
||||
|
||||
return TextEmbeddingResult(
|
||||
return EmbeddingResult(
|
||||
model="text-embedding-ada-002",
|
||||
embeddings=embeddings,
|
||||
usage=usage,
|
||||
@ -449,7 +449,7 @@ class TestCacheEmbeddingDocuments:
|
||||
latency=0.5,
|
||||
)
|
||||
|
||||
embedding_result = TextEmbeddingResult(
|
||||
embedding_result = EmbeddingResult(
|
||||
model="text-embedding-ada-002",
|
||||
embeddings=[valid_vector.tolist(), nan_vector],
|
||||
usage=usage,
|
||||
@ -629,7 +629,7 @@ class TestCacheEmbeddingQuery:
|
||||
latency=0.3,
|
||||
)
|
||||
|
||||
embedding_result = TextEmbeddingResult(
|
||||
embedding_result = EmbeddingResult(
|
||||
model="text-embedding-ada-002",
|
||||
embeddings=[normalized],
|
||||
usage=usage,
|
||||
@ -728,7 +728,7 @@ class TestCacheEmbeddingQuery:
|
||||
latency=0.3,
|
||||
)
|
||||
|
||||
embedding_result = TextEmbeddingResult(
|
||||
embedding_result = EmbeddingResult(
|
||||
model="text-embedding-ada-002",
|
||||
embeddings=[nan_vector],
|
||||
usage=usage,
|
||||
@ -793,7 +793,7 @@ class TestCacheEmbeddingQuery:
|
||||
latency=0.3,
|
||||
)
|
||||
|
||||
embedding_result = TextEmbeddingResult(
|
||||
embedding_result = EmbeddingResult(
|
||||
model="text-embedding-ada-002",
|
||||
embeddings=[normalized],
|
||||
usage=usage,
|
||||
@ -873,13 +873,13 @@ class TestEmbeddingModelSwitching:
|
||||
latency=0.3,
|
||||
)
|
||||
|
||||
result_ada = TextEmbeddingResult(
|
||||
result_ada = EmbeddingResult(
|
||||
model="text-embedding-ada-002",
|
||||
embeddings=[normalized_ada],
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
result_3_small = TextEmbeddingResult(
|
||||
result_3_small = EmbeddingResult(
|
||||
model="text-embedding-3-small",
|
||||
embeddings=[normalized_3_small],
|
||||
usage=usage,
|
||||
@ -953,13 +953,13 @@ class TestEmbeddingModelSwitching:
|
||||
latency=0.4,
|
||||
)
|
||||
|
||||
result_openai = TextEmbeddingResult(
|
||||
result_openai = EmbeddingResult(
|
||||
model="text-embedding-ada-002",
|
||||
embeddings=[normalized_openai],
|
||||
usage=usage_openai,
|
||||
)
|
||||
|
||||
result_cohere = TextEmbeddingResult(
|
||||
result_cohere = EmbeddingResult(
|
||||
model="embed-english-v3.0",
|
||||
embeddings=[normalized_cohere],
|
||||
usage=usage_cohere,
|
||||
@ -1042,7 +1042,7 @@ class TestEmbeddingDimensionValidation:
|
||||
latency=0.7,
|
||||
)
|
||||
|
||||
embedding_result = TextEmbeddingResult(
|
||||
embedding_result = EmbeddingResult(
|
||||
model="text-embedding-ada-002",
|
||||
embeddings=embeddings,
|
||||
usage=usage,
|
||||
@ -1095,7 +1095,7 @@ class TestEmbeddingDimensionValidation:
|
||||
latency=0.5,
|
||||
)
|
||||
|
||||
embedding_result = TextEmbeddingResult(
|
||||
embedding_result = EmbeddingResult(
|
||||
model="text-embedding-ada-002",
|
||||
embeddings=embeddings,
|
||||
usage=usage,
|
||||
@ -1148,7 +1148,7 @@ class TestEmbeddingDimensionValidation:
|
||||
latency=0.3,
|
||||
)
|
||||
|
||||
result_ada = TextEmbeddingResult(
|
||||
result_ada = EmbeddingResult(
|
||||
model="text-embedding-ada-002",
|
||||
embeddings=[normalized_ada],
|
||||
usage=usage_ada,
|
||||
@ -1181,7 +1181,7 @@ class TestEmbeddingDimensionValidation:
|
||||
latency=0.4,
|
||||
)
|
||||
|
||||
result_cohere = TextEmbeddingResult(
|
||||
result_cohere = EmbeddingResult(
|
||||
model="embed-english-v3.0",
|
||||
embeddings=[normalized_cohere],
|
||||
usage=usage_cohere,
|
||||
@ -1279,7 +1279,7 @@ class TestEmbeddingEdgeCases:
|
||||
latency=0.1,
|
||||
)
|
||||
|
||||
embedding_result = TextEmbeddingResult(
|
||||
embedding_result = EmbeddingResult(
|
||||
model="text-embedding-ada-002",
|
||||
embeddings=[normalized],
|
||||
usage=usage,
|
||||
@ -1322,7 +1322,7 @@ class TestEmbeddingEdgeCases:
|
||||
latency=1.5,
|
||||
)
|
||||
|
||||
embedding_result = TextEmbeddingResult(
|
||||
embedding_result = EmbeddingResult(
|
||||
model="text-embedding-ada-002",
|
||||
embeddings=[normalized],
|
||||
usage=usage,
|
||||
@ -1370,7 +1370,7 @@ class TestEmbeddingEdgeCases:
|
||||
latency=0.5,
|
||||
)
|
||||
|
||||
embedding_result = TextEmbeddingResult(
|
||||
embedding_result = EmbeddingResult(
|
||||
model="text-embedding-ada-002",
|
||||
embeddings=embeddings,
|
||||
usage=usage,
|
||||
@ -1422,7 +1422,7 @@ class TestEmbeddingEdgeCases:
|
||||
latency=0.2,
|
||||
)
|
||||
|
||||
embedding_result = TextEmbeddingResult(
|
||||
embedding_result = EmbeddingResult(
|
||||
model="text-embedding-ada-002",
|
||||
embeddings=embeddings,
|
||||
usage=usage,
|
||||
@ -1478,7 +1478,7 @@ class TestEmbeddingEdgeCases:
|
||||
)
|
||||
|
||||
# Model returns embeddings for all texts
|
||||
embedding_result = TextEmbeddingResult(
|
||||
embedding_result = EmbeddingResult(
|
||||
model="text-embedding-ada-002",
|
||||
embeddings=embeddings,
|
||||
usage=usage,
|
||||
@ -1546,7 +1546,7 @@ class TestEmbeddingEdgeCases:
|
||||
latency=0.8,
|
||||
)
|
||||
|
||||
embedding_result = TextEmbeddingResult(
|
||||
embedding_result = EmbeddingResult(
|
||||
model="text-embedding-ada-002",
|
||||
embeddings=embeddings,
|
||||
usage=usage,
|
||||
@ -1603,7 +1603,7 @@ class TestEmbeddingEdgeCases:
|
||||
latency=0.3,
|
||||
)
|
||||
|
||||
embedding_result = TextEmbeddingResult(
|
||||
embedding_result = EmbeddingResult(
|
||||
model="text-embedding-ada-002",
|
||||
embeddings=[normalized],
|
||||
usage=usage,
|
||||
@ -1657,7 +1657,7 @@ class TestEmbeddingEdgeCases:
|
||||
latency=0.5,
|
||||
)
|
||||
|
||||
embedding_result = TextEmbeddingResult(
|
||||
embedding_result = EmbeddingResult(
|
||||
model="text-embedding-ada-002",
|
||||
embeddings=embeddings,
|
||||
usage=usage,
|
||||
@ -1757,7 +1757,7 @@ class TestEmbeddingCachePerformance:
|
||||
latency=0.3,
|
||||
)
|
||||
|
||||
embedding_result = TextEmbeddingResult(
|
||||
embedding_result = EmbeddingResult(
|
||||
model="text-embedding-ada-002",
|
||||
embeddings=[normalized],
|
||||
usage=usage,
|
||||
@ -1826,7 +1826,7 @@ class TestEmbeddingCachePerformance:
|
||||
latency=0.5,
|
||||
)
|
||||
|
||||
return TextEmbeddingResult(
|
||||
return EmbeddingResult(
|
||||
model="text-embedding-ada-002",
|
||||
embeddings=embeddings,
|
||||
usage=usage,
|
||||
@ -1888,7 +1888,7 @@ class TestEmbeddingCachePerformance:
|
||||
latency=0.3,
|
||||
)
|
||||
|
||||
embedding_result = TextEmbeddingResult(
|
||||
embedding_result = EmbeddingResult(
|
||||
model="text-embedding-ada-002",
|
||||
embeddings=[normalized],
|
||||
usage=usage,
|
||||
|
||||
@ -1,7 +1,10 @@
|
||||
"""Primarily used for testing merged cell scenarios"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
from docx import Document
|
||||
|
||||
import core.rag.extractor.word_extractor as we
|
||||
from core.rag.extractor.word_extractor import WordExtractor
|
||||
|
||||
|
||||
@ -47,3 +50,85 @@ def test_parse_row():
|
||||
extractor = object.__new__(WordExtractor)
|
||||
for idx, row in enumerate(table.rows):
|
||||
assert extractor._parse_row(row, {}, 3) == gt[idx]
|
||||
|
||||
|
||||
def test_extract_images_from_docx(monkeypatch):
    """WordExtractor._extract_images_from_docx should persist both external and
    internal images to storage, return markdown image links keyed by rel id
    (external) / target part (internal), and record UploadFile rows in the DB
    session.
    """
    external_bytes = b"ext-bytes"
    internal_bytes = b"int-bytes"

    # Patch storage.save to capture writes instead of touching real storage.
    saves: list[tuple[str, bytes]] = []

    def save(key: str, data: bytes):
        saves.append((key, data))

    monkeypatch.setattr(we, "storage", SimpleNamespace(save=save))

    # Patch db.session to record adds/commit without a real database.
    class DummySession:
        def __init__(self):
            self.added = []
            self.committed = False

        def add(self, obj):
            self.added.append(obj)

        def commit(self):
            self.committed = True

    db_stub = SimpleNamespace(session=DummySession())
    monkeypatch.setattr(we, "db", db_stub)

    # Patch config values used for URL composition and storage type.
    monkeypatch.setattr(we.dify_config, "FILES_URL", "http://files.local", raising=False)
    monkeypatch.setattr(we.dify_config, "STORAGE_TYPE", "local", raising=False)

    # Patch UploadFile to avoid real DB models; ids are u1, u2, ... per instance.
    class FakeUploadFile:
        _i = 0

        def __init__(self, **kwargs):  # kwargs match the real signature fields
            type(self)._i += 1
            self.id = f"u{self._i}"

    monkeypatch.setattr(we, "UploadFile", FakeUploadFile)

    # Patch the external image fetcher so no network call is made.
    def fake_get(url: str):
        assert url == "https://example.com/image.png"
        return SimpleNamespace(status_code=200, headers={"Content-Type": "image/png"}, content=external_bytes)

    monkeypatch.setattr(we, "ssrf_proxy", SimpleNamespace(get=fake_get))

    # A hashable internal part object with a blob attribute.
    class HashablePart:
        def __init__(self, blob: bytes):
            self.blob = blob

        def __hash__(self) -> int:  # ensure it can be used as a dict key like real docx parts
            return id(self)

    # Build a minimal doc object with both external and internal image rels.
    internal_part = HashablePart(blob=internal_bytes)
    rel_ext = SimpleNamespace(is_external=True, target_ref="https://example.com/image.png")
    rel_int = SimpleNamespace(is_external=False, target_ref="word/media/image1.png", target_part=internal_part)
    doc = SimpleNamespace(part=SimpleNamespace(rels={"rId1": rel_ext, "rId2": rel_int}))

    extractor = object.__new__(WordExtractor)
    extractor.tenant_id = "t1"
    extractor.user_id = "u1"

    image_map = extractor._extract_images_from_docx(doc)

    # Returned map should contain entries for external (keyed by rId) and internal (keyed by target_part)
    assert set(image_map.keys()) == {"rId1", internal_part}
    # Fix: the previous assertion had a corrupted/unterminated string literal.
    # Each value is a markdown image link like
    # "![image](http://files.local/files/<id>/file-preview)" — assumed prefix;
    # confirm against WordExtractor's link formatting if this ever fails.
    assert all(v.startswith("![image](") and v.endswith("/file-preview)") for v in image_map.values())

    # Storage should receive both payloads
    payloads = {data for _, data in saves}
    assert external_bytes in payloads
    assert internal_bytes in payloads

    # DB interactions should be recorded
    assert len(db_stub.session.added) == 2
    assert db_stub.session.committed is True
|
||||
|
||||
@ -62,7 +62,7 @@ from core.indexing_runner import (
|
||||
IndexingRunner,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.models.document import ChildDocument, Document
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.dataset import Dataset, DatasetProcessRule
|
||||
@ -112,7 +112,7 @@ def create_mock_dataset_document(
|
||||
document_id: str | None = None,
|
||||
dataset_id: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
doc_form: str = IndexType.PARAGRAPH_INDEX,
|
||||
doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
|
||||
data_source_type: str = "upload_file",
|
||||
doc_language: str = "English",
|
||||
) -> Mock:
|
||||
@ -133,8 +133,8 @@ def create_mock_dataset_document(
|
||||
Mock: A configured mock DatasetDocument object with all required attributes.
|
||||
|
||||
Example:
|
||||
>>> doc = create_mock_dataset_document(doc_form=IndexType.QA_INDEX)
|
||||
>>> assert doc.doc_form == IndexType.QA_INDEX
|
||||
>>> doc = create_mock_dataset_document(doc_form=IndexStructureType.QA_INDEX)
|
||||
>>> assert doc.doc_form == IndexStructureType.QA_INDEX
|
||||
"""
|
||||
doc = Mock(spec=DatasetDocument)
|
||||
doc.id = document_id or str(uuid.uuid4())
|
||||
@ -276,7 +276,7 @@ class TestIndexingRunnerExtract:
|
||||
doc.id = str(uuid.uuid4())
|
||||
doc.dataset_id = str(uuid.uuid4())
|
||||
doc.tenant_id = str(uuid.uuid4())
|
||||
doc.doc_form = IndexType.PARAGRAPH_INDEX
|
||||
doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
|
||||
doc.data_source_type = "upload_file"
|
||||
doc.data_source_info_dict = {"upload_file_id": str(uuid.uuid4())}
|
||||
return doc
|
||||
@ -616,7 +616,7 @@ class TestIndexingRunnerLoad:
|
||||
doc = Mock(spec=DatasetDocument)
|
||||
doc.id = str(uuid.uuid4())
|
||||
doc.dataset_id = str(uuid.uuid4())
|
||||
doc.doc_form = IndexType.PARAGRAPH_INDEX
|
||||
doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
|
||||
return doc
|
||||
|
||||
@pytest.fixture
|
||||
@ -700,7 +700,7 @@ class TestIndexingRunnerLoad:
|
||||
"""Test loading with parent-child index structure."""
|
||||
# Arrange
|
||||
runner = IndexingRunner()
|
||||
sample_dataset_document.doc_form = IndexType.PARENT_CHILD_INDEX
|
||||
sample_dataset_document.doc_form = IndexStructureType.PARENT_CHILD_INDEX
|
||||
sample_dataset.indexing_technique = "high_quality"
|
||||
|
||||
# Add child documents
|
||||
@ -775,7 +775,7 @@ class TestIndexingRunnerRun:
|
||||
doc.id = str(uuid.uuid4())
|
||||
doc.dataset_id = str(uuid.uuid4())
|
||||
doc.tenant_id = str(uuid.uuid4())
|
||||
doc.doc_form = IndexType.PARAGRAPH_INDEX
|
||||
doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
|
||||
doc.doc_language = "English"
|
||||
doc.data_source_type = "upload_file"
|
||||
doc.data_source_info_dict = {"upload_file_id": str(uuid.uuid4())}
|
||||
@ -802,6 +802,21 @@ class TestIndexingRunnerRun:
|
||||
mock_process_rule.to_dict.return_value = {"mode": "automatic", "rules": {}}
|
||||
mock_dependencies["db"].session.scalar.return_value = mock_process_rule
|
||||
|
||||
# Mock current_user (Account) for _transform
|
||||
mock_current_user = MagicMock()
|
||||
mock_current_user.set_tenant_id = MagicMock()
|
||||
|
||||
# Setup db.session.query to return different results based on the model
|
||||
def mock_query_side_effect(model):
|
||||
mock_query_result = MagicMock()
|
||||
if model.__name__ == "Dataset":
|
||||
mock_query_result.filter_by.return_value.first.return_value = mock_dataset
|
||||
elif model.__name__ == "Account":
|
||||
mock_query_result.filter_by.return_value.first.return_value = mock_current_user
|
||||
return mock_query_result
|
||||
|
||||
mock_dependencies["db"].session.query.side_effect = mock_query_side_effect
|
||||
|
||||
# Mock processor
|
||||
mock_processor = MagicMock()
|
||||
mock_dependencies["factory"].return_value.init_index_processor.return_value = mock_processor
|
||||
@ -1268,7 +1283,7 @@ class TestIndexingRunnerLoadSegments:
|
||||
doc.id = str(uuid.uuid4())
|
||||
doc.dataset_id = str(uuid.uuid4())
|
||||
doc.created_by = str(uuid.uuid4())
|
||||
doc.doc_form = IndexType.PARAGRAPH_INDEX
|
||||
doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
|
||||
return doc
|
||||
|
||||
@pytest.fixture
|
||||
@ -1316,7 +1331,7 @@ class TestIndexingRunnerLoadSegments:
|
||||
"""Test loading segments for parent-child index."""
|
||||
# Arrange
|
||||
runner = IndexingRunner()
|
||||
sample_dataset_document.doc_form = IndexType.PARENT_CHILD_INDEX
|
||||
sample_dataset_document.doc_form = IndexStructureType.PARENT_CHILD_INDEX
|
||||
|
||||
# Add child documents
|
||||
for doc in sample_documents:
|
||||
@ -1413,7 +1428,7 @@ class TestIndexingRunnerEstimate:
|
||||
tenant_id=tenant_id,
|
||||
extract_settings=extract_settings,
|
||||
tmp_processing_rule={"mode": "automatic", "rules": {}},
|
||||
doc_form=IndexType.PARAGRAPH_INDEX,
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -26,6 +26,18 @@ from core.rag.rerank.rerank_type import RerankMode
|
||||
from core.rag.rerank.weight_rerank import WeightRerankRunner
|
||||
|
||||
|
||||
def create_mock_model_instance():
    """Build a Mock ModelInstance wired so check_model_support_vision can walk
    provider_model_bundle.configuration.tenant_id."""
    instance = Mock(spec=ModelInstance)

    # check_model_support_vision reads tenant_id through this attribute chain.
    bundle = Mock()
    bundle.configuration = Mock()
    bundle.configuration.tenant_id = "test-tenant-id"
    instance.provider_model_bundle = bundle

    instance.provider = "test-provider"
    instance.model = "test-model"
    return instance
|
||||
|
||||
|
||||
class TestRerankModelRunner:
|
||||
"""Unit tests for RerankModelRunner.
|
||||
|
||||
@ -37,10 +49,23 @@ class TestRerankModelRunner:
|
||||
- Metadata preservation and score injection
|
||||
"""
|
||||
|
||||
@pytest.fixture(autouse=True)
def mock_model_manager(self):
    """Auto-use fixture: patch ModelManager so vision support checks return False."""
    with patch("core.rag.rerank.rerank_model.ModelManager") as patched:
        patched.return_value.check_model_support_vision.return_value = False
        yield patched
||||
|
||||
@pytest.fixture
def mock_model_instance(self):
    """Create a mock ModelInstance for reranking.

    Delegates to the module-level create_mock_model_instance() helper so the
    fixture and the direct call sites share a single construction path instead
    of duplicating the provider_model_bundle wiring line-for-line.
    """
    return create_mock_model_instance()
|
||||
|
||||
@pytest.fixture
|
||||
@ -803,7 +828,7 @@ class TestRerankRunnerFactory:
|
||||
- Parameters are forwarded to runner constructor
|
||||
"""
|
||||
# Arrange: Mock model instance
|
||||
mock_model_instance = Mock(spec=ModelInstance)
|
||||
mock_model_instance = create_mock_model_instance()
|
||||
|
||||
# Act: Create runner via factory
|
||||
runner = RerankRunnerFactory.create_rerank_runner(
|
||||
@ -865,7 +890,7 @@ class TestRerankRunnerFactory:
|
||||
- String values are properly matched
|
||||
"""
|
||||
# Arrange: Mock model instance
|
||||
mock_model_instance = Mock(spec=ModelInstance)
|
||||
mock_model_instance = create_mock_model_instance()
|
||||
|
||||
# Act: Create runner using enum value
|
||||
runner = RerankRunnerFactory.create_rerank_runner(
|
||||
@ -886,6 +911,13 @@ class TestRerankIntegration:
|
||||
- Real-world usage scenarios
|
||||
"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_model_manager(self):
|
||||
"""Auto-use fixture to patch ModelManager for all tests in this class."""
|
||||
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
|
||||
mock_mm.return_value.check_model_support_vision.return_value = False
|
||||
yield mock_mm
|
||||
|
||||
def test_model_reranking_full_workflow(self):
|
||||
"""Test complete model-based reranking workflow.
|
||||
|
||||
@ -895,7 +927,7 @@ class TestRerankIntegration:
|
||||
- Top results are returned correctly
|
||||
"""
|
||||
# Arrange: Create mock model and documents
|
||||
mock_model_instance = Mock(spec=ModelInstance)
|
||||
mock_model_instance = create_mock_model_instance()
|
||||
mock_rerank_result = RerankResult(
|
||||
model="bge-reranker-base",
|
||||
docs=[
|
||||
@ -951,7 +983,7 @@ class TestRerankIntegration:
|
||||
- Normalization is consistent
|
||||
"""
|
||||
# Arrange: Create mock model with various scores
|
||||
mock_model_instance = Mock(spec=ModelInstance)
|
||||
mock_model_instance = create_mock_model_instance()
|
||||
mock_rerank_result = RerankResult(
|
||||
model="bge-reranker-base",
|
||||
docs=[
|
||||
@ -991,6 +1023,13 @@ class TestRerankEdgeCases:
|
||||
- Concurrent reranking scenarios
|
||||
"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_model_manager(self):
|
||||
"""Auto-use fixture to patch ModelManager for all tests in this class."""
|
||||
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
|
||||
mock_mm.return_value.check_model_support_vision.return_value = False
|
||||
yield mock_mm
|
||||
|
||||
def test_rerank_with_empty_metadata(self):
|
||||
"""Test reranking when documents have empty metadata.
|
||||
|
||||
@ -1000,7 +1039,7 @@ class TestRerankEdgeCases:
|
||||
- Empty metadata documents are processed correctly
|
||||
"""
|
||||
# Arrange: Create documents with empty metadata
|
||||
mock_model_instance = Mock(spec=ModelInstance)
|
||||
mock_model_instance = create_mock_model_instance()
|
||||
mock_rerank_result = RerankResult(
|
||||
model="bge-reranker-base",
|
||||
docs=[
|
||||
@ -1046,7 +1085,7 @@ class TestRerankEdgeCases:
|
||||
- Score comparison logic works at boundary
|
||||
"""
|
||||
# Arrange: Create mock with various scores including negatives
|
||||
mock_model_instance = Mock(spec=ModelInstance)
|
||||
mock_model_instance = create_mock_model_instance()
|
||||
mock_rerank_result = RerankResult(
|
||||
model="bge-reranker-base",
|
||||
docs=[
|
||||
@ -1082,7 +1121,7 @@ class TestRerankEdgeCases:
|
||||
- No overflow or precision issues
|
||||
"""
|
||||
# Arrange: All documents with perfect scores
|
||||
mock_model_instance = Mock(spec=ModelInstance)
|
||||
mock_model_instance = create_mock_model_instance()
|
||||
mock_rerank_result = RerankResult(
|
||||
model="bge-reranker-base",
|
||||
docs=[
|
||||
@ -1117,7 +1156,7 @@ class TestRerankEdgeCases:
|
||||
- Content encoding is preserved
|
||||
"""
|
||||
# Arrange: Documents with special characters
|
||||
mock_model_instance = Mock(spec=ModelInstance)
|
||||
mock_model_instance = create_mock_model_instance()
|
||||
mock_rerank_result = RerankResult(
|
||||
model="bge-reranker-base",
|
||||
docs=[
|
||||
@ -1159,7 +1198,7 @@ class TestRerankEdgeCases:
|
||||
- Content is not truncated unexpectedly
|
||||
"""
|
||||
# Arrange: Documents with very long content
|
||||
mock_model_instance = Mock(spec=ModelInstance)
|
||||
mock_model_instance = create_mock_model_instance()
|
||||
long_content = "This is a very long document. " * 1000 # ~30,000 characters
|
||||
|
||||
mock_rerank_result = RerankResult(
|
||||
@ -1196,7 +1235,7 @@ class TestRerankEdgeCases:
|
||||
- All documents are processed correctly
|
||||
"""
|
||||
# Arrange: Create 100 documents
|
||||
mock_model_instance = Mock(spec=ModelInstance)
|
||||
mock_model_instance = create_mock_model_instance()
|
||||
num_docs = 100
|
||||
|
||||
# Create rerank results for all documents
|
||||
@ -1287,7 +1326,7 @@ class TestRerankEdgeCases:
|
||||
- Documents can still be ranked
|
||||
"""
|
||||
# Arrange: Empty query
|
||||
mock_model_instance = Mock(spec=ModelInstance)
|
||||
mock_model_instance = create_mock_model_instance()
|
||||
mock_rerank_result = RerankResult(
|
||||
model="bge-reranker-base",
|
||||
docs=[
|
||||
@ -1325,6 +1364,13 @@ class TestRerankPerformance:
|
||||
- Score calculation optimization
|
||||
"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_model_manager(self):
|
||||
"""Auto-use fixture to patch ModelManager for all tests in this class."""
|
||||
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
|
||||
mock_mm.return_value.check_model_support_vision.return_value = False
|
||||
yield mock_mm
|
||||
|
||||
def test_rerank_batch_processing(self):
|
||||
"""Test that documents are processed in a single batch.
|
||||
|
||||
@ -1334,7 +1380,7 @@ class TestRerankPerformance:
|
||||
- Efficient batch processing
|
||||
"""
|
||||
# Arrange: Multiple documents
|
||||
mock_model_instance = Mock(spec=ModelInstance)
|
||||
mock_model_instance = create_mock_model_instance()
|
||||
mock_rerank_result = RerankResult(
|
||||
model="bge-reranker-base",
|
||||
docs=[RerankDocument(index=i, text=f"Doc {i}", score=0.9 - i * 0.1) for i in range(5)],
|
||||
@ -1435,6 +1481,13 @@ class TestRerankErrorHandling:
|
||||
- Error propagation
|
||||
"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_model_manager(self):
|
||||
"""Auto-use fixture to patch ModelManager for all tests in this class."""
|
||||
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
|
||||
mock_mm.return_value.check_model_support_vision.return_value = False
|
||||
yield mock_mm
|
||||
|
||||
def test_rerank_model_invocation_error(self):
|
||||
"""Test handling of model invocation errors.
|
||||
|
||||
@ -1444,7 +1497,7 @@ class TestRerankErrorHandling:
|
||||
- Error context is preserved
|
||||
"""
|
||||
# Arrange: Mock model that raises exception
|
||||
mock_model_instance = Mock(spec=ModelInstance)
|
||||
mock_model_instance = create_mock_model_instance()
|
||||
mock_model_instance.invoke_rerank.side_effect = RuntimeError("Model invocation failed")
|
||||
|
||||
documents = [
|
||||
@ -1470,7 +1523,7 @@ class TestRerankErrorHandling:
|
||||
- Invalid results don't corrupt output
|
||||
"""
|
||||
# Arrange: Rerank result with invalid index
|
||||
mock_model_instance = Mock(spec=ModelInstance)
|
||||
mock_model_instance = create_mock_model_instance()
|
||||
mock_rerank_result = RerankResult(
|
||||
model="bge-reranker-base",
|
||||
docs=[
|
||||
|
||||
@ -425,15 +425,15 @@ class TestRetrievalService:
|
||||
|
||||
# ==================== Vector Search Tests ====================
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
|
||||
def test_vector_search_basic(self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents):
|
||||
def test_vector_search_basic(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents):
|
||||
"""
|
||||
Test basic vector/semantic search functionality.
|
||||
|
||||
This test validates the core vector search flow:
|
||||
1. Dataset is retrieved from database
|
||||
2. embedding_search is called via ThreadPoolExecutor
|
||||
2. _retrieve is called via ThreadPoolExecutor
|
||||
3. Documents are added to shared all_documents list
|
||||
4. Results are returned to caller
|
||||
|
||||
@ -447,28 +447,28 @@ class TestRetrievalService:
|
||||
# Set up the mock dataset that will be "retrieved" from database
|
||||
mock_get_dataset.return_value = mock_dataset
|
||||
|
||||
# Create a side effect function that simulates embedding_search behavior
|
||||
# In the real implementation, embedding_search:
|
||||
# 1. Gets the dataset
|
||||
# 2. Creates a Vector instance
|
||||
# 3. Calls search_by_vector with embeddings
|
||||
# 4. Extends all_documents with results
|
||||
def side_effect_embedding_search(
|
||||
# Create a side effect function that simulates _retrieve behavior
|
||||
# _retrieve modifies the all_documents list in place
|
||||
def side_effect_retrieve(
|
||||
flask_app,
|
||||
dataset_id,
|
||||
query,
|
||||
top_k,
|
||||
score_threshold,
|
||||
reranking_model,
|
||||
all_documents,
|
||||
retrieval_method,
|
||||
exceptions,
|
||||
dataset,
|
||||
query=None,
|
||||
top_k=4,
|
||||
score_threshold=None,
|
||||
reranking_model=None,
|
||||
reranking_mode="reranking_model",
|
||||
weights=None,
|
||||
document_ids_filter=None,
|
||||
attachment_id=None,
|
||||
all_documents=None,
|
||||
exceptions=None,
|
||||
):
|
||||
"""Simulate embedding_search adding documents to the shared list."""
|
||||
all_documents.extend(sample_documents)
|
||||
"""Simulate _retrieve adding documents to the shared list."""
|
||||
if all_documents is not None:
|
||||
all_documents.extend(sample_documents)
|
||||
|
||||
mock_embedding_search.side_effect = side_effect_embedding_search
|
||||
mock_retrieve.side_effect = side_effect_retrieve
|
||||
|
||||
# Define test parameters
|
||||
query = "What is Python?" # Natural language query
|
||||
@ -481,7 +481,7 @@ class TestRetrievalService:
|
||||
# 1. Check if query is empty (early return if so)
|
||||
# 2. Get the dataset using _get_dataset
|
||||
# 3. Create ThreadPoolExecutor
|
||||
# 4. Submit embedding_search task
|
||||
# 4. Submit _retrieve task
|
||||
# 5. Wait for completion
|
||||
# 6. Return all_documents list
|
||||
results = RetrievalService.retrieve(
|
||||
@ -502,15 +502,13 @@ class TestRetrievalService:
|
||||
# Verify documents maintain their scores (highest score first in sample_documents)
|
||||
assert results[0].metadata["score"] == 0.95, "First document should have highest score from sample_documents"
|
||||
|
||||
# Verify embedding_search was called exactly once
|
||||
# Verify _retrieve was called exactly once
|
||||
# This confirms the search method was invoked by ThreadPoolExecutor
|
||||
mock_embedding_search.assert_called_once()
|
||||
mock_retrieve.assert_called_once()
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
|
||||
def test_vector_search_with_document_filter(
|
||||
self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents
|
||||
):
|
||||
def test_vector_search_with_document_filter(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents):
|
||||
"""
|
||||
Test vector search with document ID filtering.
|
||||
|
||||
@ -522,21 +520,25 @@ class TestRetrievalService:
|
||||
mock_get_dataset.return_value = mock_dataset
|
||||
filtered_docs = [sample_documents[0]]
|
||||
|
||||
def side_effect_embedding_search(
|
||||
def side_effect_retrieve(
|
||||
flask_app,
|
||||
dataset_id,
|
||||
query,
|
||||
top_k,
|
||||
score_threshold,
|
||||
reranking_model,
|
||||
all_documents,
|
||||
retrieval_method,
|
||||
exceptions,
|
||||
dataset,
|
||||
query=None,
|
||||
top_k=4,
|
||||
score_threshold=None,
|
||||
reranking_model=None,
|
||||
reranking_mode="reranking_model",
|
||||
weights=None,
|
||||
document_ids_filter=None,
|
||||
attachment_id=None,
|
||||
all_documents=None,
|
||||
exceptions=None,
|
||||
):
|
||||
all_documents.extend(filtered_docs)
|
||||
if all_documents is not None:
|
||||
all_documents.extend(filtered_docs)
|
||||
|
||||
mock_embedding_search.side_effect = side_effect_embedding_search
|
||||
mock_retrieve.side_effect = side_effect_retrieve
|
||||
document_ids_filter = [sample_documents[0].metadata["document_id"]]
|
||||
|
||||
# Act
|
||||
@ -552,12 +554,12 @@ class TestRetrievalService:
|
||||
assert len(results) == 1
|
||||
assert results[0].metadata["doc_id"] == "doc1"
|
||||
# Verify document_ids_filter was passed
|
||||
call_kwargs = mock_embedding_search.call_args.kwargs
|
||||
call_kwargs = mock_retrieve.call_args.kwargs
|
||||
assert call_kwargs["document_ids_filter"] == document_ids_filter
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
|
||||
def test_vector_search_empty_results(self, mock_get_dataset, mock_embedding_search, mock_dataset):
|
||||
def test_vector_search_empty_results(self, mock_get_dataset, mock_retrieve, mock_dataset):
|
||||
"""
|
||||
Test vector search when no results match the query.
|
||||
|
||||
@ -567,8 +569,8 @@ class TestRetrievalService:
|
||||
"""
|
||||
# Arrange
|
||||
mock_get_dataset.return_value = mock_dataset
|
||||
# embedding_search doesn't add anything to all_documents
|
||||
mock_embedding_search.side_effect = lambda *args, **kwargs: None
|
||||
# _retrieve doesn't add anything to all_documents
|
||||
mock_retrieve.side_effect = lambda *args, **kwargs: None
|
||||
|
||||
# Act
|
||||
results = RetrievalService.retrieve(
|
||||
@ -583,9 +585,9 @@ class TestRetrievalService:
|
||||
|
||||
# ==================== Keyword Search Tests ====================
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.keyword_search")
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
|
||||
def test_keyword_search_basic(self, mock_get_dataset, mock_keyword_search, mock_dataset, sample_documents):
|
||||
def test_keyword_search_basic(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents):
|
||||
"""
|
||||
Test basic keyword search functionality.
|
||||
|
||||
@ -597,12 +599,25 @@ class TestRetrievalService:
|
||||
# Arrange
|
||||
mock_get_dataset.return_value = mock_dataset
|
||||
|
||||
def side_effect_keyword_search(
|
||||
flask_app, dataset_id, query, top_k, all_documents, exceptions, document_ids_filter=None
|
||||
def side_effect_retrieve(
|
||||
flask_app,
|
||||
retrieval_method,
|
||||
dataset,
|
||||
query=None,
|
||||
top_k=4,
|
||||
score_threshold=None,
|
||||
reranking_model=None,
|
||||
reranking_mode="reranking_model",
|
||||
weights=None,
|
||||
document_ids_filter=None,
|
||||
attachment_id=None,
|
||||
all_documents=None,
|
||||
exceptions=None,
|
||||
):
|
||||
all_documents.extend(sample_documents)
|
||||
if all_documents is not None:
|
||||
all_documents.extend(sample_documents)
|
||||
|
||||
mock_keyword_search.side_effect = side_effect_keyword_search
|
||||
mock_retrieve.side_effect = side_effect_retrieve
|
||||
|
||||
query = "Python programming"
|
||||
top_k = 3
|
||||
@ -618,7 +633,7 @@ class TestRetrievalService:
|
||||
# Assert
|
||||
assert len(results) == 3
|
||||
assert all(isinstance(doc, Document) for doc in results)
|
||||
mock_keyword_search.assert_called_once()
|
||||
mock_retrieve.assert_called_once()
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.keyword_search")
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
|
||||
@ -1147,11 +1162,9 @@ class TestRetrievalService:
|
||||
|
||||
# ==================== Metadata Filtering Tests ====================
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
|
||||
def test_vector_search_with_metadata_filter(
|
||||
self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents
|
||||
):
|
||||
def test_vector_search_with_metadata_filter(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents):
|
||||
"""
|
||||
Test vector search with metadata-based document filtering.
|
||||
|
||||
@ -1166,21 +1179,25 @@ class TestRetrievalService:
|
||||
filtered_doc = sample_documents[0]
|
||||
filtered_doc.metadata["category"] = "programming"
|
||||
|
||||
def side_effect_embedding(
|
||||
def side_effect_retrieve(
|
||||
flask_app,
|
||||
dataset_id,
|
||||
query,
|
||||
top_k,
|
||||
score_threshold,
|
||||
reranking_model,
|
||||
all_documents,
|
||||
retrieval_method,
|
||||
exceptions,
|
||||
dataset,
|
||||
query=None,
|
||||
top_k=4,
|
||||
score_threshold=None,
|
||||
reranking_model=None,
|
||||
reranking_mode="reranking_model",
|
||||
weights=None,
|
||||
document_ids_filter=None,
|
||||
attachment_id=None,
|
||||
all_documents=None,
|
||||
exceptions=None,
|
||||
):
|
||||
all_documents.append(filtered_doc)
|
||||
if all_documents is not None:
|
||||
all_documents.append(filtered_doc)
|
||||
|
||||
mock_embedding_search.side_effect = side_effect_embedding
|
||||
mock_retrieve.side_effect = side_effect_retrieve
|
||||
|
||||
# Act
|
||||
results = RetrievalService.retrieve(
|
||||
@ -1243,9 +1260,9 @@ class TestRetrievalService:
|
||||
# Assert
|
||||
assert results == []
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
|
||||
def test_retrieve_with_exception_handling(self, mock_get_dataset, mock_embedding_search, mock_dataset):
|
||||
def test_retrieve_with_exception_handling(self, mock_get_dataset, mock_retrieve, mock_dataset):
|
||||
"""
|
||||
Test that exceptions during retrieval are properly handled.
|
||||
|
||||
@ -1256,22 +1273,26 @@ class TestRetrievalService:
|
||||
# Arrange
|
||||
mock_get_dataset.return_value = mock_dataset
|
||||
|
||||
# Make embedding_search add an exception to the exceptions list
|
||||
# Make _retrieve add an exception to the exceptions list
|
||||
def side_effect_with_exception(
|
||||
flask_app,
|
||||
dataset_id,
|
||||
query,
|
||||
top_k,
|
||||
score_threshold,
|
||||
reranking_model,
|
||||
all_documents,
|
||||
retrieval_method,
|
||||
exceptions,
|
||||
dataset,
|
||||
query=None,
|
||||
top_k=4,
|
||||
score_threshold=None,
|
||||
reranking_model=None,
|
||||
reranking_mode="reranking_model",
|
||||
weights=None,
|
||||
document_ids_filter=None,
|
||||
attachment_id=None,
|
||||
all_documents=None,
|
||||
exceptions=None,
|
||||
):
|
||||
exceptions.append("Search failed")
|
||||
if exceptions is not None:
|
||||
exceptions.append("Search failed")
|
||||
|
||||
mock_embedding_search.side_effect = side_effect_with_exception
|
||||
mock_retrieve.side_effect = side_effect_with_exception
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
@ -1286,9 +1307,9 @@ class TestRetrievalService:
|
||||
|
||||
# ==================== Score Threshold Tests ====================
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
|
||||
def test_vector_search_with_score_threshold(self, mock_get_dataset, mock_embedding_search, mock_dataset):
|
||||
def test_vector_search_with_score_threshold(self, mock_get_dataset, mock_retrieve, mock_dataset):
|
||||
"""
|
||||
Test vector search with score threshold filtering.
|
||||
|
||||
@ -1306,21 +1327,25 @@ class TestRetrievalService:
|
||||
provider="dify",
|
||||
)
|
||||
|
||||
def side_effect_embedding(
|
||||
def side_effect_retrieve(
|
||||
flask_app,
|
||||
dataset_id,
|
||||
query,
|
||||
top_k,
|
||||
score_threshold,
|
||||
reranking_model,
|
||||
all_documents,
|
||||
retrieval_method,
|
||||
exceptions,
|
||||
dataset,
|
||||
query=None,
|
||||
top_k=4,
|
||||
score_threshold=None,
|
||||
reranking_model=None,
|
||||
reranking_mode="reranking_model",
|
||||
weights=None,
|
||||
document_ids_filter=None,
|
||||
attachment_id=None,
|
||||
all_documents=None,
|
||||
exceptions=None,
|
||||
):
|
||||
all_documents.append(high_score_doc)
|
||||
if all_documents is not None:
|
||||
all_documents.append(high_score_doc)
|
||||
|
||||
mock_embedding_search.side_effect = side_effect_embedding
|
||||
mock_retrieve.side_effect = side_effect_retrieve
|
||||
|
||||
score_threshold = 0.8
|
||||
|
||||
@ -1339,9 +1364,9 @@ class TestRetrievalService:
|
||||
|
||||
# ==================== Top-K Limiting Tests ====================
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
|
||||
def test_retrieve_respects_top_k_limit(self, mock_get_dataset, mock_embedding_search, mock_dataset):
|
||||
def test_retrieve_respects_top_k_limit(self, mock_get_dataset, mock_retrieve, mock_dataset):
|
||||
"""
|
||||
Test that retrieval respects top_k parameter.
|
||||
|
||||
@ -1362,22 +1387,26 @@ class TestRetrievalService:
|
||||
for i in range(10)
|
||||
]
|
||||
|
||||
def side_effect_embedding(
|
||||
def side_effect_retrieve(
|
||||
flask_app,
|
||||
dataset_id,
|
||||
query,
|
||||
top_k,
|
||||
score_threshold,
|
||||
reranking_model,
|
||||
all_documents,
|
||||
retrieval_method,
|
||||
exceptions,
|
||||
dataset,
|
||||
query=None,
|
||||
top_k=4,
|
||||
score_threshold=None,
|
||||
reranking_model=None,
|
||||
reranking_mode="reranking_model",
|
||||
weights=None,
|
||||
document_ids_filter=None,
|
||||
attachment_id=None,
|
||||
all_documents=None,
|
||||
exceptions=None,
|
||||
):
|
||||
# Return only top_k documents
|
||||
all_documents.extend(many_docs[:top_k])
|
||||
if all_documents is not None:
|
||||
all_documents.extend(many_docs[:top_k])
|
||||
|
||||
mock_embedding_search.side_effect = side_effect_embedding
|
||||
mock_retrieve.side_effect = side_effect_retrieve
|
||||
|
||||
top_k = 3
|
||||
|
||||
@ -1390,9 +1419,9 @@ class TestRetrievalService:
|
||||
)
|
||||
|
||||
# Assert
|
||||
# Verify top_k was passed to embedding_search
|
||||
assert mock_embedding_search.called
|
||||
call_kwargs = mock_embedding_search.call_args.kwargs
|
||||
# Verify _retrieve was called
|
||||
assert mock_retrieve.called
|
||||
call_kwargs = mock_retrieve.call_args.kwargs
|
||||
assert call_kwargs["top_k"] == top_k
|
||||
# Verify we got the right number of results
|
||||
assert len(results) == top_k
|
||||
@ -1421,11 +1450,9 @@ class TestRetrievalService:
|
||||
|
||||
# ==================== Reranking Tests ====================
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
|
||||
def test_semantic_search_with_reranking(
|
||||
self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents
|
||||
):
|
||||
def test_semantic_search_with_reranking(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents):
|
||||
"""
|
||||
Test semantic search with reranking model.
|
||||
|
||||
@ -1439,22 +1466,26 @@ class TestRetrievalService:
|
||||
# Simulate reranking changing order
|
||||
reranked_docs = list(reversed(sample_documents))
|
||||
|
||||
def side_effect_embedding(
|
||||
def side_effect_retrieve(
|
||||
flask_app,
|
||||
dataset_id,
|
||||
query,
|
||||
top_k,
|
||||
score_threshold,
|
||||
reranking_model,
|
||||
all_documents,
|
||||
retrieval_method,
|
||||
exceptions,
|
||||
dataset,
|
||||
query=None,
|
||||
top_k=4,
|
||||
score_threshold=None,
|
||||
reranking_model=None,
|
||||
reranking_mode="reranking_model",
|
||||
weights=None,
|
||||
document_ids_filter=None,
|
||||
attachment_id=None,
|
||||
all_documents=None,
|
||||
exceptions=None,
|
||||
):
|
||||
# embedding_search handles reranking internally
|
||||
all_documents.extend(reranked_docs)
|
||||
# _retrieve handles reranking internally
|
||||
if all_documents is not None:
|
||||
all_documents.extend(reranked_docs)
|
||||
|
||||
mock_embedding_search.side_effect = side_effect_embedding
|
||||
mock_retrieve.side_effect = side_effect_retrieve
|
||||
|
||||
reranking_model = {
|
||||
"reranking_provider_name": "cohere",
|
||||
@ -1473,7 +1504,7 @@ class TestRetrievalService:
|
||||
# Assert
|
||||
# For semantic search with reranking, reranking_model should be passed
|
||||
assert len(results) == 3
|
||||
call_kwargs = mock_embedding_search.call_args.kwargs
|
||||
call_kwargs = mock_retrieve.call_args.kwargs
|
||||
assert call_kwargs["reranking_model"] == reranking_model
|
||||
|
||||
|
||||
|
||||
@ -0,0 +1,86 @@
|
||||
import pytest
|
||||
|
||||
import core.tools.utils.message_transformer as mt
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
|
||||
class _FakeToolFile:
|
||||
def __init__(self, mimetype: str):
|
||||
self.id = "fake-tool-file-id"
|
||||
self.mimetype = mimetype
|
||||
|
||||
|
||||
class _FakeToolFileManager:
|
||||
"""Fake ToolFileManager to capture the mimetype passed in."""
|
||||
|
||||
last_call: dict | None = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def create_file_by_raw(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
conversation_id: str | None,
|
||||
file_binary: bytes,
|
||||
mimetype: str,
|
||||
filename: str | None = None,
|
||||
):
|
||||
type(self).last_call = {
|
||||
"user_id": user_id,
|
||||
"tenant_id": tenant_id,
|
||||
"conversation_id": conversation_id,
|
||||
"file_binary": file_binary,
|
||||
"mimetype": mimetype,
|
||||
"filename": filename,
|
||||
}
|
||||
return _FakeToolFile(mimetype)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _patch_tool_file_manager(monkeypatch):
|
||||
# Patch the manager used inside the transformer module
|
||||
monkeypatch.setattr(mt, "ToolFileManager", _FakeToolFileManager)
|
||||
# also ensure predictable URL generation (no need to patch; uses id and extension only)
|
||||
yield
|
||||
_FakeToolFileManager.last_call = None
|
||||
|
||||
|
||||
def _gen(messages):
|
||||
yield from messages
|
||||
|
||||
|
||||
def test_transform_tool_invoke_messages_mimetype_key_present_but_none():
|
||||
# Arrange: a BLOB message whose meta contains a mime_type key set to None
|
||||
blob = b"hello"
|
||||
msg = ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB,
|
||||
message=ToolInvokeMessage.BlobMessage(blob=blob),
|
||||
meta={"mime_type": None, "filename": "greeting"},
|
||||
)
|
||||
|
||||
# Act
|
||||
out = list(
|
||||
mt.ToolFileMessageTransformer.transform_tool_invoke_messages(
|
||||
messages=_gen([msg]),
|
||||
user_id="u1",
|
||||
tenant_id="t1",
|
||||
conversation_id="c1",
|
||||
)
|
||||
)
|
||||
|
||||
# Assert: default to application/octet-stream when mime_type is present but None
|
||||
assert _FakeToolFileManager.last_call is not None
|
||||
assert _FakeToolFileManager.last_call["mimetype"] == "application/octet-stream"
|
||||
|
||||
# Should yield a BINARY_LINK (not IMAGE_LINK) and the URL ends with .bin
|
||||
assert len(out) == 1
|
||||
o = out[0]
|
||||
assert o.type == ToolInvokeMessage.MessageType.BINARY_LINK
|
||||
assert isinstance(o.message, ToolInvokeMessage.TextMessage)
|
||||
assert o.message.text.endswith(".bin")
|
||||
# meta is preserved (still contains mime_type: None)
|
||||
assert "mime_type" in (o.meta or {})
|
||||
assert o.meta["mime_type"] is None
|
||||
@ -0,0 +1,60 @@
|
||||
"""
|
||||
Test case for end node without value_type field (backward compatibility).
|
||||
|
||||
This test validates that end nodes work correctly even when the value_type
|
||||
field is missing from the output configuration, ensuring backward compatibility
|
||||
with older workflow definitions.
|
||||
"""
|
||||
|
||||
from core.workflow.graph_events import (
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
|
||||
from .test_table_runner import TableTestRunner, WorkflowTestCase
|
||||
|
||||
|
||||
def test_end_node_without_value_type_field():
|
||||
"""
|
||||
Test that end node works without explicit value_type field.
|
||||
|
||||
The fixture implements a simple workflow that:
|
||||
1. Takes a query input from start node
|
||||
2. Passes it directly to end node
|
||||
3. End node outputs the value without specifying value_type
|
||||
4. Should correctly infer the type and output the value
|
||||
|
||||
This ensures backward compatibility with workflow definitions
|
||||
created before value_type became a required field.
|
||||
"""
|
||||
fixture_name = "end_node_without_value_type_field_workflow"
|
||||
|
||||
case = WorkflowTestCase(
|
||||
fixture_path=fixture_name,
|
||||
inputs={"query": "test query"},
|
||||
expected_outputs={"query": "test query"},
|
||||
expected_event_sequence=[
|
||||
# Graph start
|
||||
GraphRunStartedEvent,
|
||||
# Start node
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent, # Start node streams the input value
|
||||
NodeRunSucceededEvent,
|
||||
# End node
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
# Graph end
|
||||
GraphRunSucceededEvent,
|
||||
],
|
||||
description="End node without value_type field should work correctly",
|
||||
)
|
||||
|
||||
runner = TableTestRunner()
|
||||
result = runner.run_test_case(case)
|
||||
assert result.success, f"Test failed: {result.error}"
|
||||
assert result.actual_outputs == {"query": "test query"}, (
|
||||
f"Expected output to be {{'query': 'test query'}}, got {result.actual_outputs}"
|
||||
)
|
||||
@ -1,3 +1,4 @@
|
||||
import json
|
||||
from unittest.mock import Mock, PropertyMock, patch
|
||||
|
||||
import httpx
|
||||
@ -138,3 +139,95 @@ def test_is_file_with_no_content_disposition(mock_response):
|
||||
type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512))
|
||||
response = Response(mock_response)
|
||||
assert response.is_file
|
||||
|
||||
|
||||
# UTF-8 Encoding Tests
|
||||
@pytest.mark.parametrize(
|
||||
("content_bytes", "expected_text", "description"),
|
||||
[
|
||||
# Chinese UTF-8 bytes
|
||||
(
|
||||
b'{"message": "\xe4\xbd\xa0\xe5\xa5\xbd\xe4\xb8\x96\xe7\x95\x8c"}',
|
||||
'{"message": "你好世界"}',
|
||||
"Chinese characters UTF-8",
|
||||
),
|
||||
# Japanese UTF-8 bytes
|
||||
(
|
||||
b'{"message": "\xe3\x81\x93\xe3\x82\x93\xe3\x81\xab\xe3\x81\xa1\xe3\x81\xaf"}',
|
||||
'{"message": "こんにちは"}',
|
||||
"Japanese characters UTF-8",
|
||||
),
|
||||
# Korean UTF-8 bytes
|
||||
(
|
||||
b'{"message": "\xec\x95\x88\xeb\x85\x95\xed\x95\x98\xec\x84\xb8\xec\x9a\x94"}',
|
||||
'{"message": "안녕하세요"}',
|
||||
"Korean characters UTF-8",
|
||||
),
|
||||
# Arabic UTF-8
|
||||
(b'{"text": "\xd9\x85\xd8\xb1\xd8\xad\xd8\xa8\xd8\xa7"}', '{"text": "مرحبا"}', "Arabic characters UTF-8"),
|
||||
# European characters UTF-8
|
||||
(b'{"text": "Caf\xc3\xa9 M\xc3\xbcnchen"}', '{"text": "Café München"}', "European accented characters"),
|
||||
# Simple ASCII
|
||||
(b'{"text": "Hello World"}', '{"text": "Hello World"}', "Simple ASCII text"),
|
||||
],
|
||||
)
|
||||
def test_text_property_utf8_decoding(mock_response, content_bytes, expected_text, description):
|
||||
"""Test that Response.text properly decodes UTF-8 content with charset_normalizer"""
|
||||
mock_response.headers = {"content-type": "application/json; charset=utf-8"}
|
||||
type(mock_response).content = PropertyMock(return_value=content_bytes)
|
||||
# Mock httpx response.text to return something different (simulating potential encoding issues)
|
||||
mock_response.text = "incorrect-fallback-text" # To ensure we are not falling back to httpx's text property
|
||||
|
||||
response = Response(mock_response)
|
||||
|
||||
# Our enhanced text property should decode properly using charset_normalizer
|
||||
assert response.text == expected_text, (
|
||||
f"Failed for {description}: got {repr(response.text)}, expected {repr(expected_text)}"
|
||||
)
|
||||
|
||||
|
||||
def test_text_property_fallback_to_httpx(mock_response):
|
||||
"""Test that Response.text falls back to httpx.text when charset_normalizer fails"""
|
||||
mock_response.headers = {"content-type": "application/json"}
|
||||
|
||||
# Create malformed UTF-8 bytes
|
||||
malformed_bytes = b'{"text": "\xff\xfe\x00\x00 invalid"}'
|
||||
type(mock_response).content = PropertyMock(return_value=malformed_bytes)
|
||||
|
||||
# Mock httpx.text to return some fallback value
|
||||
fallback_text = '{"text": "fallback"}'
|
||||
mock_response.text = fallback_text
|
||||
|
||||
response = Response(mock_response)
|
||||
|
||||
# Should fall back to httpx's text when charset_normalizer fails
|
||||
assert response.text == fallback_text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("json_content", "description"),
|
||||
[
|
||||
# JSON with escaped Unicode (like Flask jsonify())
|
||||
('{"message": "\\u4f60\\u597d\\u4e16\\u754c"}', "JSON with escaped Unicode"),
|
||||
# JSON with mixed escape sequences and UTF-8
|
||||
('{"mixed": "Hello \\u4f60\\u597d"}', "Mixed escaped and regular text"),
|
||||
# JSON with complex escape sequences
|
||||
('{"complex": "\\ud83d\\ude00\\u4f60\\u597d"}', "Emoji and Chinese escapes"),
|
||||
],
|
||||
)
|
||||
def test_text_property_with_escaped_unicode(mock_response, json_content, description):
|
||||
"""Test Response.text with JSON containing Unicode escape sequences"""
|
||||
mock_response.headers = {"content-type": "application/json"}
|
||||
|
||||
content_bytes = json_content.encode("utf-8")
|
||||
type(mock_response).content = PropertyMock(return_value=content_bytes)
|
||||
mock_response.text = json_content # httpx would return the same for valid UTF-8
|
||||
|
||||
response = Response(mock_response)
|
||||
|
||||
# Should preserve the escape sequences (valid JSON)
|
||||
assert response.text == json_content, f"Failed for {description}"
|
||||
|
||||
# The text should be valid JSON that can be parsed back to proper Unicode
|
||||
parsed = json.loads(response.text)
|
||||
assert isinstance(parsed, dict), f"Invalid JSON for {description}"
|
||||
|
||||
@ -1149,3 +1149,258 @@ class TestModelIntegration:
|
||||
# Assert
|
||||
assert site.app_id == app.id
|
||||
assert app.enable_site is True
|
||||
|
||||
|
||||
class TestConversationStatusCount:
|
||||
"""Test suite for Conversation.status_count property N+1 query fix."""
|
||||
|
||||
def test_status_count_no_messages(self):
|
||||
"""Test status_count returns None when conversation has no messages."""
|
||||
# Arrange
|
||||
conversation = Conversation(
|
||||
app_id=str(uuid4()),
|
||||
mode=AppMode.CHAT,
|
||||
name="Test Conversation",
|
||||
status="normal",
|
||||
from_source="api",
|
||||
)
|
||||
conversation.id = str(uuid4())
|
||||
|
||||
# Mock the database query to return no messages
|
||||
with patch("models.model.db.session.scalars") as mock_scalars:
|
||||
mock_scalars.return_value.all.return_value = []
|
||||
|
||||
# Act
|
||||
result = conversation.status_count
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
def test_status_count_messages_without_workflow_runs(self):
|
||||
"""Test status_count when messages have no workflow_run_id."""
|
||||
# Arrange
|
||||
app_id = str(uuid4())
|
||||
conversation_id = str(uuid4())
|
||||
|
||||
conversation = Conversation(
|
||||
app_id=app_id,
|
||||
mode=AppMode.CHAT,
|
||||
name="Test Conversation",
|
||||
status="normal",
|
||||
from_source="api",
|
||||
)
|
||||
conversation.id = conversation_id
|
||||
|
||||
# Mock the database query to return no messages with workflow_run_id
|
||||
with patch("models.model.db.session.scalars") as mock_scalars:
|
||||
mock_scalars.return_value.all.return_value = []
|
||||
|
||||
# Act
|
||||
result = conversation.status_count
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
def test_status_count_batch_loading_implementation(self):
|
||||
"""Test that status_count uses batch loading instead of N+1 queries."""
|
||||
# Arrange
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
|
||||
app_id = str(uuid4())
|
||||
conversation_id = str(uuid4())
|
||||
|
||||
# Create workflow run IDs
|
||||
workflow_run_id_1 = str(uuid4())
|
||||
workflow_run_id_2 = str(uuid4())
|
||||
workflow_run_id_3 = str(uuid4())
|
||||
|
||||
conversation = Conversation(
|
||||
app_id=app_id,
|
||||
mode=AppMode.CHAT,
|
||||
name="Test Conversation",
|
||||
status="normal",
|
||||
from_source="api",
|
||||
)
|
||||
conversation.id = conversation_id
|
||||
|
||||
# Mock messages with workflow_run_id
|
||||
mock_messages = [
|
||||
MagicMock(
|
||||
conversation_id=conversation_id,
|
||||
workflow_run_id=workflow_run_id_1,
|
||||
),
|
||||
MagicMock(
|
||||
conversation_id=conversation_id,
|
||||
workflow_run_id=workflow_run_id_2,
|
||||
),
|
||||
MagicMock(
|
||||
conversation_id=conversation_id,
|
||||
workflow_run_id=workflow_run_id_3,
|
||||
),
|
||||
]
|
||||
|
||||
# Mock workflow runs with different statuses
|
||||
mock_workflow_runs = [
|
||||
MagicMock(
|
||||
id=workflow_run_id_1,
|
||||
status=WorkflowExecutionStatus.SUCCEEDED.value,
|
||||
app_id=app_id,
|
||||
),
|
||||
MagicMock(
|
||||
id=workflow_run_id_2,
|
||||
status=WorkflowExecutionStatus.FAILED.value,
|
||||
app_id=app_id,
|
||||
),
|
||||
MagicMock(
|
||||
id=workflow_run_id_3,
|
||||
status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED.value,
|
||||
app_id=app_id,
|
||||
),
|
||||
]
|
||||
|
||||
# Track database calls
|
||||
calls_made = []
|
||||
|
||||
def mock_scalars(query):
|
||||
calls_made.append(str(query))
|
||||
mock_result = MagicMock()
|
||||
|
||||
# Return messages for the first query (messages with workflow_run_id)
|
||||
if "messages" in str(query) and "conversation_id" in str(query):
|
||||
mock_result.all.return_value = mock_messages
|
||||
# Return workflow runs for the batch query
|
||||
elif "workflow_runs" in str(query):
|
||||
mock_result.all.return_value = mock_workflow_runs
|
||||
else:
|
||||
mock_result.all.return_value = []
|
||||
|
||||
return mock_result
|
||||
|
||||
# Act & Assert
|
||||
with patch("models.model.db.session.scalars", side_effect=mock_scalars):
|
||||
result = conversation.status_count
|
||||
|
||||
# Verify only 2 database queries were made (not N+1)
|
||||
assert len(calls_made) == 2, f"Expected 2 queries, got {len(calls_made)}: {calls_made}"
|
||||
|
||||
# Verify the first query gets messages
|
||||
assert "messages" in calls_made[0]
|
||||
assert "conversation_id" in calls_made[0]
|
||||
|
||||
# Verify the second query batch loads workflow runs with proper filtering
|
||||
assert "workflow_runs" in calls_made[1]
|
||||
assert "app_id" in calls_made[1] # Security filter applied
|
||||
assert "IN" in calls_made[1] # Batch loading with IN clause
|
||||
|
||||
# Verify correct status counts
|
||||
assert result["success"] == 1 # One SUCCEEDED
|
||||
assert result["failed"] == 1 # One FAILED
|
||||
assert result["partial_success"] == 1 # One PARTIAL_SUCCEEDED
|
||||
|
||||
def test_status_count_app_id_filtering(self):
|
||||
"""Test that status_count filters workflow runs by app_id for security."""
|
||||
# Arrange
|
||||
app_id = str(uuid4())
|
||||
other_app_id = str(uuid4())
|
||||
conversation_id = str(uuid4())
|
||||
workflow_run_id = str(uuid4())
|
||||
|
||||
conversation = Conversation(
|
||||
app_id=app_id,
|
||||
mode=AppMode.CHAT,
|
||||
name="Test Conversation",
|
||||
status="normal",
|
||||
from_source="api",
|
||||
)
|
||||
conversation.id = conversation_id
|
||||
|
||||
# Mock message with workflow_run_id
|
||||
mock_messages = [
|
||||
MagicMock(
|
||||
conversation_id=conversation_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
),
|
||||
]
|
||||
|
||||
calls_made = []
|
||||
|
||||
def mock_scalars(query):
|
||||
calls_made.append(str(query))
|
||||
mock_result = MagicMock()
|
||||
|
||||
if "messages" in str(query):
|
||||
mock_result.all.return_value = mock_messages
|
||||
elif "workflow_runs" in str(query):
|
||||
# Return empty list because no workflow run matches the correct app_id
|
||||
mock_result.all.return_value = [] # Workflow run filtered out by app_id
|
||||
else:
|
||||
mock_result.all.return_value = []
|
||||
|
||||
return mock_result
|
||||
|
||||
# Act
|
||||
with patch("models.model.db.session.scalars", side_effect=mock_scalars):
|
||||
result = conversation.status_count
|
||||
|
||||
# Assert - query should include app_id filter
|
||||
workflow_query = calls_made[1]
|
||||
assert "app_id" in workflow_query
|
||||
|
||||
# Since workflow run has wrong app_id, it shouldn't be included in counts
|
||||
assert result["success"] == 0
|
||||
assert result["failed"] == 0
|
||||
assert result["partial_success"] == 0
|
||||
|
||||
def test_status_count_handles_invalid_workflow_status(self):
|
||||
"""Test that status_count gracefully handles invalid workflow status values."""
|
||||
# Arrange
|
||||
app_id = str(uuid4())
|
||||
conversation_id = str(uuid4())
|
||||
workflow_run_id = str(uuid4())
|
||||
|
||||
conversation = Conversation(
|
||||
app_id=app_id,
|
||||
mode=AppMode.CHAT,
|
||||
name="Test Conversation",
|
||||
status="normal",
|
||||
from_source="api",
|
||||
)
|
||||
conversation.id = conversation_id
|
||||
|
||||
mock_messages = [
|
||||
MagicMock(
|
||||
conversation_id=conversation_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
),
|
||||
]
|
||||
|
||||
# Mock workflow run with invalid status
|
||||
mock_workflow_runs = [
|
||||
MagicMock(
|
||||
id=workflow_run_id,
|
||||
status="invalid_status", # Invalid status that should raise ValueError
|
||||
app_id=app_id,
|
||||
),
|
||||
]
|
||||
|
||||
with patch("models.model.db.session.scalars") as mock_scalars:
|
||||
# Mock the messages query
|
||||
def mock_scalars_side_effect(query):
|
||||
mock_result = MagicMock()
|
||||
if "messages" in str(query):
|
||||
mock_result.all.return_value = mock_messages
|
||||
elif "workflow_runs" in str(query):
|
||||
mock_result.all.return_value = mock_workflow_runs
|
||||
else:
|
||||
mock_result.all.return_value = []
|
||||
return mock_result
|
||||
|
||||
mock_scalars.side_effect = mock_scalars_side_effect
|
||||
|
||||
# Act - should not raise exception
|
||||
result = conversation.status_count
|
||||
|
||||
# Assert - should handle invalid status gracefully
|
||||
assert result["success"] == 0
|
||||
assert result["failed"] == 0
|
||||
assert result["partial_success"] == 0
|
||||
|
||||
@ -14,7 +14,9 @@ def get_example_bucket() -> str:
|
||||
|
||||
|
||||
def get_opendal_bucket() -> str:
|
||||
return "./dify"
|
||||
import os
|
||||
|
||||
return os.environ.get("OPENDAL_FS_ROOT", "/tmp/dify-storage")
|
||||
|
||||
|
||||
def get_example_filename() -> str:
|
||||
|
||||
@ -21,20 +21,16 @@ class TestOpenDAL:
|
||||
)
|
||||
|
||||
@pytest.fixture(scope="class", autouse=True)
|
||||
def teardown_class(self, request):
|
||||
def teardown_class(self):
|
||||
"""Clean up after all tests in the class."""
|
||||
|
||||
def cleanup():
|
||||
folder = Path(get_opendal_bucket())
|
||||
if folder.exists() and folder.is_dir():
|
||||
for item in folder.iterdir():
|
||||
if item.is_file():
|
||||
item.unlink()
|
||||
elif item.is_dir():
|
||||
item.rmdir()
|
||||
folder.rmdir()
|
||||
yield
|
||||
|
||||
return cleanup()
|
||||
folder = Path(get_opendal_bucket())
|
||||
if folder.exists() and folder.is_dir():
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(folder, ignore_errors=True)
|
||||
|
||||
def test_save_and_exists(self):
|
||||
"""Test saving data and checking existence."""
|
||||
|
||||
@ -117,7 +117,7 @@ import pytest
|
||||
from core.entities.document_task import DocumentTask
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from services.document_indexing_task_proxy import DocumentIndexingTaskProxy
|
||||
from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy
|
||||
|
||||
# ============================================================================
|
||||
# Test Data Factory
|
||||
@ -370,7 +370,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
# Features Property Tests
|
||||
# ========================================================================
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
|
||||
def test_features_property(self, mock_feature_service):
|
||||
"""
|
||||
Test cached_property features.
|
||||
@ -400,7 +400,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
|
||||
mock_feature_service.get_features.assert_called_once_with("tenant-123")
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
|
||||
def test_features_property_with_different_tenants(self, mock_feature_service):
|
||||
"""
|
||||
Test features property with different tenant IDs.
|
||||
@ -438,7 +438,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
# Direct Queue Routing Tests
|
||||
# ========================================================================
|
||||
|
||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
def test_send_to_direct_queue(self, mock_task):
|
||||
"""
|
||||
Test _send_to_direct_queue method.
|
||||
@ -460,7 +460,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
# Assert
|
||||
mock_task.delay.assert_called_once_with(tenant_id=tenant_id, dataset_id=dataset_id, document_ids=document_ids)
|
||||
|
||||
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task")
|
||||
def test_send_to_direct_queue_with_priority_task(self, mock_task):
|
||||
"""
|
||||
Test _send_to_direct_queue with priority task function.
|
||||
@ -481,7 +481,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
|
||||
)
|
||||
|
||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
def test_send_to_direct_queue_with_single_document(self, mock_task):
|
||||
"""
|
||||
Test _send_to_direct_queue with single document ID.
|
||||
@ -502,7 +502,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1"]
|
||||
)
|
||||
|
||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
def test_send_to_direct_queue_with_empty_documents(self, mock_task):
|
||||
"""
|
||||
Test _send_to_direct_queue with empty document_ids list.
|
||||
@ -525,7 +525,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
# Tenant Queue Routing Tests
|
||||
# ========================================================================
|
||||
|
||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
|
||||
"""
|
||||
Test _send_to_tenant_queue when task key exists.
|
||||
@ -564,7 +564,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
|
||||
mock_task.delay.assert_not_called()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
def test_send_to_tenant_queue_without_task_key(self, mock_task):
|
||||
"""
|
||||
Test _send_to_tenant_queue when no task key exists.
|
||||
@ -594,7 +594,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
|
||||
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task")
|
||||
def test_send_to_tenant_queue_with_priority_task(self, mock_task):
|
||||
"""
|
||||
Test _send_to_tenant_queue with priority task function.
|
||||
@ -621,7 +621,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
|
||||
)
|
||||
|
||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
def test_send_to_tenant_queue_document_task_serialization(self, mock_task):
|
||||
"""
|
||||
Test DocumentTask serialization in _send_to_tenant_queue.
|
||||
@ -659,7 +659,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
# Queue Type Selection Tests
|
||||
# ========================================================================
|
||||
|
||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
def test_send_to_default_tenant_queue(self, mock_task):
|
||||
"""
|
||||
Test _send_to_default_tenant_queue method.
|
||||
@ -678,7 +678,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
# Assert
|
||||
proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
|
||||
|
||||
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task")
|
||||
def test_send_to_priority_tenant_queue(self, mock_task):
|
||||
"""
|
||||
Test _send_to_priority_tenant_queue method.
|
||||
@ -697,7 +697,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
# Assert
|
||||
proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
|
||||
|
||||
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task")
|
||||
def test_send_to_priority_direct_queue(self, mock_task):
|
||||
"""
|
||||
Test _send_to_priority_direct_queue method.
|
||||
@ -720,7 +720,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
# Dispatch Logic Tests
|
||||
# ========================================================================
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
|
||||
def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service):
|
||||
"""
|
||||
Test _dispatch method when billing is enabled with SANDBOX plan.
|
||||
@ -745,7 +745,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
# Assert
|
||||
proxy._send_to_default_tenant_queue.assert_called_once()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
|
||||
def test_dispatch_with_billing_enabled_team_plan(self, mock_feature_service):
|
||||
"""
|
||||
Test _dispatch method when billing is enabled with TEAM plan.
|
||||
@ -770,7 +770,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
# Assert
|
||||
proxy._send_to_priority_tenant_queue.assert_called_once()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
|
||||
def test_dispatch_with_billing_enabled_professional_plan(self, mock_feature_service):
|
||||
"""
|
||||
Test _dispatch method when billing is enabled with PROFESSIONAL plan.
|
||||
@ -795,7 +795,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
# Assert
|
||||
proxy._send_to_priority_tenant_queue.assert_called_once()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
|
||||
def test_dispatch_with_billing_disabled(self, mock_feature_service):
|
||||
"""
|
||||
Test _dispatch method when billing is disabled.
|
||||
@ -818,7 +818,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
# Assert
|
||||
proxy._send_to_priority_direct_queue.assert_called_once()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
|
||||
def test_dispatch_edge_case_empty_plan(self, mock_feature_service):
|
||||
"""
|
||||
Test _dispatch method with empty plan string.
|
||||
@ -842,7 +842,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
# Assert
|
||||
proxy._send_to_priority_tenant_queue.assert_called_once()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
|
||||
def test_dispatch_edge_case_none_plan(self, mock_feature_service):
|
||||
"""
|
||||
Test _dispatch method with None plan.
|
||||
@ -870,7 +870,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
# Delay Method Tests
|
||||
# ========================================================================
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
|
||||
def test_delay_method(self, mock_feature_service):
|
||||
"""
|
||||
Test delay method integration.
|
||||
@ -895,7 +895,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
# Assert
|
||||
proxy._send_to_default_tenant_queue.assert_called_once()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
|
||||
def test_delay_method_with_team_plan(self, mock_feature_service):
|
||||
"""
|
||||
Test delay method with TEAM plan.
|
||||
@ -920,7 +920,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
# Assert
|
||||
proxy._send_to_priority_tenant_queue.assert_called_once()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
|
||||
def test_delay_method_with_billing_disabled(self, mock_feature_service):
|
||||
"""
|
||||
Test delay method with billing disabled.
|
||||
@ -1021,7 +1021,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
# Batch Operations Tests
|
||||
# ========================================================================
|
||||
|
||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
def test_batch_operation_with_multiple_documents(self, mock_task):
|
||||
"""
|
||||
Test batch operation with multiple documents.
|
||||
@ -1044,7 +1044,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=document_ids
|
||||
)
|
||||
|
||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
def test_batch_operation_with_large_batch(self, mock_task):
|
||||
"""
|
||||
Test batch operation with large batch of documents.
|
||||
@ -1073,7 +1073,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
# Error Handling Tests
|
||||
# ========================================================================
|
||||
|
||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
def test_send_to_direct_queue_task_delay_failure(self, mock_task):
|
||||
"""
|
||||
Test _send_to_direct_queue when task.delay() raises an exception.
|
||||
@ -1090,7 +1090,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
with pytest.raises(Exception, match="Task delay failed"):
|
||||
proxy._send_to_direct_queue(mock_task)
|
||||
|
||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
def test_send_to_tenant_queue_push_tasks_failure(self, mock_task):
|
||||
"""
|
||||
Test _send_to_tenant_queue when push_tasks raises an exception.
|
||||
@ -1111,7 +1111,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
with pytest.raises(Exception, match="Push tasks failed"):
|
||||
proxy._send_to_tenant_queue(mock_task)
|
||||
|
||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
def test_send_to_tenant_queue_set_waiting_time_failure(self, mock_task):
|
||||
"""
|
||||
Test _send_to_tenant_queue when set_task_waiting_time raises an exception.
|
||||
@ -1132,7 +1132,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
with pytest.raises(Exception, match="Set waiting time failed"):
|
||||
proxy._send_to_tenant_queue(mock_task)
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
|
||||
def test_dispatch_feature_service_failure(self, mock_feature_service):
|
||||
"""
|
||||
Test _dispatch when FeatureService.get_features raises an exception.
|
||||
@ -1153,8 +1153,8 @@ class TestDocumentIndexingTaskProxy:
|
||||
# Integration Tests
|
||||
# ========================================================================
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
def test_full_flow_sandbox_plan(self, mock_task, mock_feature_service):
|
||||
"""
|
||||
Test full flow for SANDBOX plan with tenant queue.
|
||||
@ -1187,8 +1187,8 @@ class TestDocumentIndexingTaskProxy:
|
||||
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
|
||||
)
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task")
|
||||
def test_full_flow_team_plan(self, mock_task, mock_feature_service):
|
||||
"""
|
||||
Test full flow for TEAM plan with priority tenant queue.
|
||||
@ -1221,8 +1221,8 @@ class TestDocumentIndexingTaskProxy:
|
||||
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
|
||||
)
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task")
|
||||
def test_full_flow_billing_disabled(self, mock_task, mock_feature_service):
|
||||
"""
|
||||
Test full flow for billing disabled (self-hosted/enterprise).
|
||||
|
||||
@ -3,7 +3,7 @@ from unittest.mock import Mock, patch
|
||||
from core.entities.document_task import DocumentTask
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from services.document_indexing_task_proxy import DocumentIndexingTaskProxy
|
||||
from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy
|
||||
|
||||
|
||||
class DocumentIndexingTaskProxyTestDataFactory:
|
||||
@ -59,7 +59,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
assert proxy._tenant_isolated_task_queue._tenant_id == tenant_id
|
||||
assert proxy._tenant_isolated_task_queue._unique_key == "document_indexing"
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
@patch("services.document_indexing_proxy.base.FeatureService")
|
||||
def test_features_property(self, mock_feature_service):
|
||||
"""Test cached_property features."""
|
||||
# Arrange
|
||||
@ -77,7 +77,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
assert features1 is features2 # Should be the same instance due to caching
|
||||
mock_feature_service.get_features.assert_called_once_with("tenant-123")
|
||||
|
||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
def test_send_to_direct_queue(self, mock_task):
|
||||
"""Test _send_to_direct_queue method."""
|
||||
# Arrange
|
||||
@ -92,7 +92,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
|
||||
)
|
||||
|
||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
|
||||
"""Test _send_to_tenant_queue when task key exists."""
|
||||
# Arrange
|
||||
@ -115,7 +115,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
assert pushed_tasks[0]["document_ids"] == ["doc-1", "doc-2", "doc-3"]
|
||||
mock_task.delay.assert_not_called()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
def test_send_to_tenant_queue_without_task_key(self, mock_task):
|
||||
"""Test _send_to_tenant_queue when no task key exists."""
|
||||
# Arrange
|
||||
@ -135,8 +135,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
)
|
||||
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
def test_send_to_default_tenant_queue(self, mock_task):
|
||||
def test_send_to_default_tenant_queue(self):
|
||||
"""Test _send_to_default_tenant_queue method."""
|
||||
# Arrange
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
@ -146,10 +145,9 @@ class TestDocumentIndexingTaskProxy:
|
||||
proxy._send_to_default_tenant_queue()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
|
||||
proxy._send_to_tenant_queue.assert_called_once_with(proxy.NORMAL_TASK_FUNC)
|
||||
|
||||
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
|
||||
def test_send_to_priority_tenant_queue(self, mock_task):
|
||||
def test_send_to_priority_tenant_queue(self):
|
||||
"""Test _send_to_priority_tenant_queue method."""
|
||||
# Arrange
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
@ -159,10 +157,9 @@ class TestDocumentIndexingTaskProxy:
|
||||
proxy._send_to_priority_tenant_queue()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
|
||||
proxy._send_to_tenant_queue.assert_called_once_with(proxy.PRIORITY_TASK_FUNC)
|
||||
|
||||
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
|
||||
def test_send_to_priority_direct_queue(self, mock_task):
|
||||
def test_send_to_priority_direct_queue(self):
|
||||
"""Test _send_to_priority_direct_queue method."""
|
||||
# Arrange
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
@ -172,9 +169,9 @@ class TestDocumentIndexingTaskProxy:
|
||||
proxy._send_to_priority_direct_queue()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_direct_queue.assert_called_once_with(mock_task)
|
||||
proxy._send_to_direct_queue.assert_called_once_with(proxy.PRIORITY_TASK_FUNC)
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
@patch("services.document_indexing_proxy.base.FeatureService")
|
||||
def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service):
|
||||
"""Test _dispatch method when billing is enabled with sandbox plan."""
|
||||
# Arrange
|
||||
@ -191,7 +188,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
# Assert
|
||||
proxy._send_to_default_tenant_queue.assert_called_once()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
@patch("services.document_indexing_proxy.base.FeatureService")
|
||||
def test_dispatch_with_billing_enabled_non_sandbox_plan(self, mock_feature_service):
|
||||
"""Test _dispatch method when billing is enabled with non-sandbox plan."""
|
||||
# Arrange
|
||||
@ -208,7 +205,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
# If billing enabled with non sandbox plan, should send to priority tenant queue
|
||||
proxy._send_to_priority_tenant_queue.assert_called_once()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
@patch("services.document_indexing_proxy.base.FeatureService")
|
||||
def test_dispatch_with_billing_disabled(self, mock_feature_service):
|
||||
"""Test _dispatch method when billing is disabled."""
|
||||
# Arrange
|
||||
@ -223,7 +220,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
# If billing disabled, for example: self-hosted or enterprise, should send to priority direct queue
|
||||
proxy._send_to_priority_direct_queue.assert_called_once()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
@patch("services.document_indexing_proxy.base.FeatureService")
|
||||
def test_delay_method(self, mock_feature_service):
|
||||
"""Test delay method integration."""
|
||||
# Arrange
|
||||
@ -256,7 +253,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
assert task.dataset_id == dataset_id
|
||||
assert task.document_ids == document_ids
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
@patch("services.document_indexing_proxy.base.FeatureService")
|
||||
def test_dispatch_edge_case_empty_plan(self, mock_feature_service):
|
||||
"""Test _dispatch method with empty plan string."""
|
||||
# Arrange
|
||||
@ -271,7 +268,7 @@ class TestDocumentIndexingTaskProxy:
|
||||
# Assert
|
||||
proxy._send_to_priority_tenant_queue.assert_called_once()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
@patch("services.document_indexing_proxy.base.FeatureService")
|
||||
def test_dispatch_edge_case_none_plan(self, mock_feature_service):
|
||||
"""Test _dispatch method with None plan."""
|
||||
# Arrange
|
||||
|
||||
@ -0,0 +1,363 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from core.entities.document_task import DocumentTask
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from services.document_indexing_proxy.duplicate_document_indexing_task_proxy import (
|
||||
DuplicateDocumentIndexingTaskProxy,
|
||||
)
|
||||
|
||||
|
||||
class DuplicateDocumentIndexingTaskProxyTestDataFactory:
|
||||
"""Factory class for creating test data and mock objects for DuplicateDocumentIndexingTaskProxy tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_mock_features(billing_enabled: bool = False, plan: CloudPlan = CloudPlan.SANDBOX) -> Mock:
|
||||
"""Create mock features with billing configuration."""
|
||||
features = Mock()
|
||||
features.billing = Mock()
|
||||
features.billing.enabled = billing_enabled
|
||||
features.billing.subscription = Mock()
|
||||
features.billing.subscription.plan = plan
|
||||
return features
|
||||
|
||||
@staticmethod
|
||||
def create_mock_tenant_queue(has_task_key: bool = False) -> Mock:
|
||||
"""Create mock TenantIsolatedTaskQueue."""
|
||||
queue = Mock(spec=TenantIsolatedTaskQueue)
|
||||
queue.get_task_key.return_value = "task_key" if has_task_key else None
|
||||
queue.push_tasks = Mock()
|
||||
queue.set_task_waiting_time = Mock()
|
||||
return queue
|
||||
|
||||
@staticmethod
|
||||
def create_duplicate_document_task_proxy(
|
||||
tenant_id: str = "tenant-123", dataset_id: str = "dataset-456", document_ids: list[str] | None = None
|
||||
) -> DuplicateDocumentIndexingTaskProxy:
|
||||
"""Create DuplicateDocumentIndexingTaskProxy instance for testing."""
|
||||
if document_ids is None:
|
||||
document_ids = ["doc-1", "doc-2", "doc-3"]
|
||||
return DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
||||
|
||||
|
||||
class TestDuplicateDocumentIndexingTaskProxy:
|
||||
"""Test cases for DuplicateDocumentIndexingTaskProxy class."""
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test DuplicateDocumentIndexingTaskProxy initialization."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
dataset_id = "dataset-456"
|
||||
document_ids = ["doc-1", "doc-2", "doc-3"]
|
||||
|
||||
# Act
|
||||
proxy = DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
||||
|
||||
# Assert
|
||||
assert proxy._tenant_id == tenant_id
|
||||
assert proxy._dataset_id == dataset_id
|
||||
assert proxy._document_ids == document_ids
|
||||
assert isinstance(proxy._tenant_isolated_task_queue, TenantIsolatedTaskQueue)
|
||||
assert proxy._tenant_isolated_task_queue._tenant_id == tenant_id
|
||||
assert proxy._tenant_isolated_task_queue._unique_key == "duplicate_document_indexing"
|
||||
|
||||
def test_queue_name(self):
|
||||
"""Test QUEUE_NAME class variable."""
|
||||
# Arrange & Act
|
||||
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
|
||||
|
||||
# Assert
|
||||
assert proxy.QUEUE_NAME == "duplicate_document_indexing"
|
||||
|
||||
def test_task_functions(self):
|
||||
"""Test NORMAL_TASK_FUNC and PRIORITY_TASK_FUNC class variables."""
|
||||
# Arrange & Act
|
||||
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
|
||||
|
||||
# Assert
|
||||
assert proxy.NORMAL_TASK_FUNC.__name__ == "normal_duplicate_document_indexing_task"
|
||||
assert proxy.PRIORITY_TASK_FUNC.__name__ == "priority_duplicate_document_indexing_task"
|
||||
|
||||
@patch("services.document_indexing_proxy.base.FeatureService")
|
||||
def test_features_property(self, mock_feature_service):
|
||||
"""Test cached_property features."""
|
||||
# Arrange
|
||||
mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features()
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
|
||||
|
||||
# Act
|
||||
features1 = proxy.features
|
||||
features2 = proxy.features # Second call should use cached property
|
||||
|
||||
# Assert
|
||||
assert features1 == mock_features
|
||||
assert features2 == mock_features
|
||||
assert features1 is features2 # Should be the same instance due to caching
|
||||
mock_feature_service.get_features.assert_called_once_with("tenant-123")
|
||||
|
||||
@patch(
|
||||
"services.document_indexing_proxy.duplicate_document_indexing_task_proxy.normal_duplicate_document_indexing_task"
|
||||
)
|
||||
def test_send_to_direct_queue(self, mock_task):
|
||||
"""Test _send_to_direct_queue method."""
|
||||
# Arrange
|
||||
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
|
||||
mock_task.delay = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_direct_queue(mock_task)
|
||||
|
||||
# Assert
|
||||
mock_task.delay.assert_called_once_with(
|
||||
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
|
||||
)
|
||||
|
||||
@patch(
|
||||
"services.document_indexing_proxy.duplicate_document_indexing_task_proxy.normal_duplicate_document_indexing_task"
|
||||
)
|
||||
def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
|
||||
"""Test _send_to_tenant_queue when task key exists."""
|
||||
# Arrange
|
||||
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
|
||||
proxy._tenant_isolated_task_queue = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue(
|
||||
has_task_key=True
|
||||
)
|
||||
mock_task.delay = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_tenant_queue(mock_task)
|
||||
|
||||
# Assert
|
||||
proxy._tenant_isolated_task_queue.push_tasks.assert_called_once()
|
||||
pushed_tasks = proxy._tenant_isolated_task_queue.push_tasks.call_args[0][0]
|
||||
assert len(pushed_tasks) == 1
|
||||
assert isinstance(DocumentTask(**pushed_tasks[0]), DocumentTask)
|
||||
assert pushed_tasks[0]["tenant_id"] == "tenant-123"
|
||||
assert pushed_tasks[0]["dataset_id"] == "dataset-456"
|
||||
assert pushed_tasks[0]["document_ids"] == ["doc-1", "doc-2", "doc-3"]
|
||||
mock_task.delay.assert_not_called()
|
||||
|
||||
@patch(
|
||||
"services.document_indexing_proxy.duplicate_document_indexing_task_proxy.normal_duplicate_document_indexing_task"
|
||||
)
|
||||
def test_send_to_tenant_queue_without_task_key(self, mock_task):
|
||||
"""Test _send_to_tenant_queue when no task key exists."""
|
||||
# Arrange
|
||||
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
|
||||
proxy._tenant_isolated_task_queue = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue(
|
||||
has_task_key=False
|
||||
)
|
||||
mock_task.delay = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_tenant_queue(mock_task)
|
||||
|
||||
# Assert
|
||||
proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once()
|
||||
mock_task.delay.assert_called_once_with(
|
||||
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
|
||||
)
|
||||
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
|
||||
|
||||
def test_send_to_default_tenant_queue(self):
|
||||
"""Test _send_to_default_tenant_queue method."""
|
||||
# Arrange
|
||||
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
|
||||
proxy._send_to_tenant_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_default_tenant_queue()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_tenant_queue.assert_called_once_with(proxy.NORMAL_TASK_FUNC)
|
||||
|
||||
def test_send_to_priority_tenant_queue(self):
|
||||
"""Test _send_to_priority_tenant_queue method."""
|
||||
# Arrange
|
||||
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
|
||||
proxy._send_to_tenant_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_priority_tenant_queue()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_tenant_queue.assert_called_once_with(proxy.PRIORITY_TASK_FUNC)
|
||||
|
||||
def test_send_to_priority_direct_queue(self):
|
||||
"""Test _send_to_priority_direct_queue method."""
|
||||
# Arrange
|
||||
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
|
||||
proxy._send_to_direct_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_priority_direct_queue()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_direct_queue.assert_called_once_with(proxy.PRIORITY_TASK_FUNC)
|
||||
|
||||
@patch("services.document_indexing_proxy.base.FeatureService")
|
||||
def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service):
|
||||
"""Test _dispatch method when billing is enabled with sandbox plan."""
|
||||
# Arrange
|
||||
mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(
|
||||
billing_enabled=True, plan=CloudPlan.SANDBOX
|
||||
)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
|
||||
proxy._send_to_default_tenant_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_default_tenant_queue.assert_called_once()
|
||||
|
||||
@patch("services.document_indexing_proxy.base.FeatureService")
|
||||
def test_dispatch_with_billing_enabled_non_sandbox_plan(self, mock_feature_service):
|
||||
"""Test _dispatch method when billing is enabled with non-sandbox plan."""
|
||||
# Arrange
|
||||
mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(
|
||||
billing_enabled=True, plan=CloudPlan.TEAM
|
||||
)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
|
||||
proxy._send_to_priority_tenant_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# Assert
|
||||
# If billing enabled with non sandbox plan, should send to priority tenant queue
|
||||
proxy._send_to_priority_tenant_queue.assert_called_once()
|
||||
|
||||
@patch("services.document_indexing_proxy.base.FeatureService")
|
||||
def test_dispatch_with_billing_disabled(self, mock_feature_service):
|
||||
"""Test _dispatch method when billing is disabled."""
|
||||
# Arrange
|
||||
mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=False)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
|
||||
proxy._send_to_priority_direct_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# Assert
|
||||
# If billing disabled, for example: self-hosted or enterprise, should send to priority direct queue
|
||||
proxy._send_to_priority_direct_queue.assert_called_once()
|
||||
|
||||
@patch("services.document_indexing_proxy.base.FeatureService")
|
||||
def test_delay_method(self, mock_feature_service):
|
||||
"""Test delay method integration."""
|
||||
# Arrange
|
||||
mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(
|
||||
billing_enabled=True, plan=CloudPlan.SANDBOX
|
||||
)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
|
||||
proxy._send_to_default_tenant_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy.delay()
|
||||
|
||||
# Assert
|
||||
# If billing enabled with sandbox plan, should send to default tenant queue
|
||||
proxy._send_to_default_tenant_queue.assert_called_once()
|
||||
|
||||
@patch("services.document_indexing_proxy.base.FeatureService")
|
||||
def test_dispatch_edge_case_empty_plan(self, mock_feature_service):
|
||||
"""Test _dispatch method with empty plan string."""
|
||||
# Arrange
|
||||
mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(
|
||||
billing_enabled=True, plan=""
|
||||
)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
|
||||
proxy._send_to_priority_tenant_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_priority_tenant_queue.assert_called_once()
|
||||
|
||||
@patch("services.document_indexing_proxy.base.FeatureService")
|
||||
def test_dispatch_edge_case_none_plan(self, mock_feature_service):
|
||||
"""Test _dispatch method with None plan."""
|
||||
# Arrange
|
||||
mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(
|
||||
billing_enabled=True, plan=None
|
||||
)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
|
||||
proxy._send_to_priority_tenant_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_priority_tenant_queue.assert_called_once()
|
||||
|
||||
def test_initialization_with_empty_document_ids(self):
|
||||
"""Test initialization with empty document_ids list."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
dataset_id = "dataset-456"
|
||||
document_ids = []
|
||||
|
||||
# Act
|
||||
proxy = DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
||||
|
||||
# Assert
|
||||
assert proxy._tenant_id == tenant_id
|
||||
assert proxy._dataset_id == dataset_id
|
||||
assert proxy._document_ids == document_ids
|
||||
|
||||
def test_initialization_with_single_document_id(self):
|
||||
"""Test initialization with single document_id."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
dataset_id = "dataset-456"
|
||||
document_ids = ["doc-1"]
|
||||
|
||||
# Act
|
||||
proxy = DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
||||
|
||||
# Assert
|
||||
assert proxy._tenant_id == tenant_id
|
||||
assert proxy._dataset_id == dataset_id
|
||||
assert proxy._document_ids == document_ids
|
||||
|
||||
def test_initialization_with_large_batch(self):
|
||||
"""Test initialization with large batch of document IDs."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
dataset_id = "dataset-456"
|
||||
document_ids = [f"doc-{i}" for i in range(100)]
|
||||
|
||||
# Act
|
||||
proxy = DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
||||
|
||||
# Assert
|
||||
assert proxy._tenant_id == tenant_id
|
||||
assert proxy._dataset_id == dataset_id
|
||||
assert proxy._document_ids == document_ids
|
||||
assert len(proxy._document_ids) == 100
|
||||
|
||||
@patch("services.document_indexing_proxy.base.FeatureService")
|
||||
def test_dispatch_with_professional_plan(self, mock_feature_service):
|
||||
"""Test _dispatch method when billing is enabled with professional plan."""
|
||||
# Arrange
|
||||
mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(
|
||||
billing_enabled=True, plan=CloudPlan.PROFESSIONAL
|
||||
)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
|
||||
proxy._send_to_priority_tenant_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_priority_tenant_queue.assert_called_once()
|
||||
@ -6,6 +6,7 @@ Target: 1500+ lines of comprehensive test coverage.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
@ -1791,8 +1792,8 @@ class TestExternalDatasetServiceFetchRetrieval:
|
||||
|
||||
@patch("services.external_knowledge_service.ExternalDatasetService.process_external_api")
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_fetch_external_knowledge_retrieval_non_200_status(self, mock_db, mock_process, factory):
|
||||
"""Test retrieval returns empty list on non-200 status."""
|
||||
def test_fetch_external_knowledge_retrieval_non_200_status_raises_exception(self, mock_db, mock_process, factory):
|
||||
"""Test that non-200 status code raises Exception with response text."""
|
||||
# Arrange
|
||||
binding = factory.create_external_knowledge_binding_mock()
|
||||
api = factory.create_external_knowledge_api_mock()
|
||||
@ -1817,12 +1818,103 @@ class TestExternalDatasetServiceFetchRetrieval:
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_response.text = "Internal Server Error: Database connection failed"
|
||||
mock_process.return_value = mock_response
|
||||
|
||||
# Act
|
||||
result = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
"tenant-123", "dataset-123", "query", {"top_k": 5}
|
||||
)
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Internal Server Error: Database connection failed"):
|
||||
ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
"tenant-123", "dataset-123", "query", {"top_k": 5}
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
@pytest.mark.parametrize(
|
||||
("status_code", "error_message"),
|
||||
[
|
||||
(400, "Bad Request: Invalid query parameters"),
|
||||
(401, "Unauthorized: Invalid API key"),
|
||||
(403, "Forbidden: Access denied to resource"),
|
||||
(404, "Not Found: Knowledge base not found"),
|
||||
(429, "Too Many Requests: Rate limit exceeded"),
|
||||
(500, "Internal Server Error: Database connection failed"),
|
||||
(502, "Bad Gateway: External service unavailable"),
|
||||
(503, "Service Unavailable: Maintenance mode"),
|
||||
],
|
||||
)
|
||||
@patch("services.external_knowledge_service.ExternalDatasetService.process_external_api")
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_fetch_external_knowledge_retrieval_various_error_status_codes(
|
||||
self, mock_db, mock_process, factory, status_code, error_message
|
||||
):
|
||||
"""Test that various error status codes raise exceptions with response text."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
dataset_id = "dataset-123"
|
||||
|
||||
binding = factory.create_external_knowledge_binding_mock(
|
||||
dataset_id=dataset_id, external_knowledge_api_id="api-123"
|
||||
)
|
||||
api = factory.create_external_knowledge_api_mock(api_id="api-123")
|
||||
|
||||
mock_binding_query = MagicMock()
|
||||
mock_api_query = MagicMock()
|
||||
|
||||
def query_side_effect(model):
|
||||
if model == ExternalKnowledgeBindings:
|
||||
return mock_binding_query
|
||||
elif model == ExternalKnowledgeApis:
|
||||
return mock_api_query
|
||||
return MagicMock()
|
||||
|
||||
mock_db.session.query.side_effect = query_side_effect
|
||||
|
||||
mock_binding_query.filter_by.return_value = mock_binding_query
|
||||
mock_binding_query.first.return_value = binding
|
||||
|
||||
mock_api_query.filter_by.return_value = mock_api_query
|
||||
mock_api_query.first.return_value = api
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = status_code
|
||||
mock_response.text = error_message
|
||||
mock_process.return_value = mock_response
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match=re.escape(error_message)):
|
||||
ExternalDatasetService.fetch_external_knowledge_retrieval(tenant_id, dataset_id, "query", {"top_k": 5})
|
||||
|
||||
@patch("services.external_knowledge_service.ExternalDatasetService.process_external_api")
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_fetch_external_knowledge_retrieval_empty_response_text(self, mock_db, mock_process, factory):
|
||||
"""Test exception with empty response text."""
|
||||
# Arrange
|
||||
binding = factory.create_external_knowledge_binding_mock()
|
||||
api = factory.create_external_knowledge_api_mock()
|
||||
|
||||
mock_binding_query = MagicMock()
|
||||
mock_api_query = MagicMock()
|
||||
|
||||
def query_side_effect(model):
|
||||
if model == ExternalKnowledgeBindings:
|
||||
return mock_binding_query
|
||||
elif model == ExternalKnowledgeApis:
|
||||
return mock_api_query
|
||||
return MagicMock()
|
||||
|
||||
mock_db.session.query.side_effect = query_side_effect
|
||||
|
||||
mock_binding_query.filter_by.return_value = mock_binding_query
|
||||
mock_binding_query.first.return_value = binding
|
||||
|
||||
mock_api_query.filter_by.return_value = mock_api_query
|
||||
mock_api_query.first.return_value = api
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 503
|
||||
mock_response.text = ""
|
||||
mock_process.return_value = mock_response
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match=""):
|
||||
ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
"tenant-123", "dataset-123", "query", {"top_k": 5}
|
||||
)
|
||||
|
||||
@ -2,8 +2,6 @@ from pathlib import Path
|
||||
from unittest.mock import Mock, create_autospec, patch
|
||||
|
||||
import pytest
|
||||
from flask_restx import reqparse
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from models.account import Account
|
||||
from services.entities.knowledge_entities.knowledge_entities import MetadataArgs
|
||||
@ -77,60 +75,39 @@ class TestMetadataBugCompleteValidation:
|
||||
assert type_column.nullable is False, "type column should be nullable=False"
|
||||
assert name_column.nullable is False, "name column should be nullable=False"
|
||||
|
||||
def test_4_fixed_api_layer_rejects_null(self, app):
|
||||
"""Test Layer 4: Fixed API configuration properly rejects null values."""
|
||||
# Test Console API create endpoint (fixed)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("type", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
def test_4_fixed_api_layer_rejects_null(self):
|
||||
"""Test Layer 4: Fixed API configuration properly rejects null values using Pydantic."""
|
||||
with pytest.raises((ValueError, TypeError)):
|
||||
MetadataArgs.model_validate({"type": None, "name": None})
|
||||
|
||||
with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"):
|
||||
with pytest.raises(BadRequest):
|
||||
parser.parse_args()
|
||||
with pytest.raises((ValueError, TypeError)):
|
||||
MetadataArgs.model_validate({"type": "string", "name": None})
|
||||
|
||||
# Test with just name being null
|
||||
with app.test_request_context(json={"type": "string", "name": None}, content_type="application/json"):
|
||||
with pytest.raises(BadRequest):
|
||||
parser.parse_args()
|
||||
with pytest.raises((ValueError, TypeError)):
|
||||
MetadataArgs.model_validate({"type": None, "name": "test"})
|
||||
|
||||
# Test with just type being null
|
||||
with app.test_request_context(json={"type": None, "name": "test"}, content_type="application/json"):
|
||||
with pytest.raises(BadRequest):
|
||||
parser.parse_args()
|
||||
|
||||
def test_5_fixed_api_accepts_valid_values(self, app):
|
||||
def test_5_fixed_api_accepts_valid_values(self):
|
||||
"""Test that fixed API still accepts valid non-null values."""
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("type", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = MetadataArgs.model_validate({"type": "string", "name": "valid_name"})
|
||||
assert args.type == "string"
|
||||
assert args.name == "valid_name"
|
||||
|
||||
with app.test_request_context(json={"type": "string", "name": "valid_name"}, content_type="application/json"):
|
||||
args = parser.parse_args()
|
||||
assert args["type"] == "string"
|
||||
assert args["name"] == "valid_name"
|
||||
def test_6_simulated_buggy_behavior(self):
|
||||
"""Test simulating the original buggy behavior by bypassing Pydantic validation."""
|
||||
mock_metadata_args = Mock()
|
||||
mock_metadata_args.name = None
|
||||
mock_metadata_args.type = None
|
||||
|
||||
def test_6_simulated_buggy_behavior(self, app):
|
||||
"""Test simulating the original buggy behavior with nullable=True."""
|
||||
# Simulate the old buggy configuration
|
||||
buggy_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("type", type=str, required=True, nullable=True, location="json")
|
||||
.add_argument("name", type=str, required=True, nullable=True, location="json")
|
||||
)
|
||||
mock_user = create_autospec(Account, instance=True)
|
||||
mock_user.current_tenant_id = "tenant-123"
|
||||
mock_user.id = "user-456"
|
||||
|
||||
with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"):
|
||||
# This would pass in the buggy version
|
||||
args = buggy_parser.parse_args()
|
||||
assert args["type"] is None
|
||||
assert args["name"] is None
|
||||
|
||||
# But would crash when trying to create MetadataArgs
|
||||
with pytest.raises((ValueError, TypeError)):
|
||||
MetadataArgs.model_validate(args)
|
||||
with patch(
|
||||
"services.metadata_service.current_account_with_tenant",
|
||||
return_value=(mock_user, mock_user.current_tenant_id),
|
||||
):
|
||||
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
|
||||
MetadataService.create_metadata("dataset-123", mock_metadata_args)
|
||||
|
||||
def test_7_end_to_end_validation_layers(self):
|
||||
"""Test all validation layers work together correctly."""
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from unittest.mock import Mock, create_autospec, patch
|
||||
|
||||
import pytest
|
||||
from flask_restx import reqparse
|
||||
|
||||
from models.account import Account
|
||||
from services.entities.knowledge_entities.knowledge_entities import MetadataArgs
|
||||
@ -51,76 +50,16 @@ class TestMetadataNullableBug:
|
||||
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
|
||||
MetadataService.update_metadata_name("dataset-123", "metadata-456", None)
|
||||
|
||||
def test_api_parser_accepts_null_values(self, app):
|
||||
"""Test that API parser configuration incorrectly accepts null values."""
|
||||
# Simulate the current API parser configuration
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("type", type=str, required=True, nullable=True, location="json")
|
||||
.add_argument("name", type=str, required=True, nullable=True, location="json")
|
||||
)
|
||||
def test_api_layer_now_uses_pydantic_validation(self):
|
||||
"""Verify that API layer relies on Pydantic validation instead of reqparse."""
|
||||
invalid_payload = {"type": None, "name": None}
|
||||
with pytest.raises((ValueError, TypeError)):
|
||||
MetadataArgs.model_validate(invalid_payload)
|
||||
|
||||
# Simulate request data with null values
|
||||
with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"):
|
||||
# This should parse successfully due to nullable=True
|
||||
args = parser.parse_args()
|
||||
|
||||
# Verify that null values are accepted
|
||||
assert args["type"] is None
|
||||
assert args["name"] is None
|
||||
|
||||
# This demonstrates the bug: API accepts None but business logic will crash
|
||||
|
||||
def test_integration_bug_scenario(self, app):
|
||||
"""Test the complete bug scenario from API to service layer."""
|
||||
# Step 1: API parser accepts null values (current buggy behavior)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("type", type=str, required=True, nullable=True, location="json")
|
||||
.add_argument("name", type=str, required=True, nullable=True, location="json")
|
||||
)
|
||||
|
||||
with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"):
|
||||
args = parser.parse_args()
|
||||
|
||||
# Step 2: Try to create MetadataArgs with None values
|
||||
# This should fail at Pydantic validation level
|
||||
with pytest.raises((ValueError, TypeError)):
|
||||
metadata_args = MetadataArgs.model_validate(args)
|
||||
|
||||
# Step 3: If we bypass Pydantic (simulating the bug scenario)
|
||||
# Move this outside the request context to avoid Flask-Login issues
|
||||
mock_metadata_args = Mock()
|
||||
mock_metadata_args.name = None # From args["name"]
|
||||
mock_metadata_args.type = None # From args["type"]
|
||||
|
||||
mock_user = create_autospec(Account, instance=True)
|
||||
mock_user.current_tenant_id = "tenant-123"
|
||||
mock_user.id = "user-456"
|
||||
|
||||
with patch(
|
||||
"services.metadata_service.current_account_with_tenant",
|
||||
return_value=(mock_user, mock_user.current_tenant_id),
|
||||
):
|
||||
# Step 4: Service layer crashes on len(None)
|
||||
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
|
||||
MetadataService.create_metadata("dataset-123", mock_metadata_args)
|
||||
|
||||
def test_correct_nullable_false_configuration_works(self, app):
    """The fixed parser configuration (nullable=False) rejects null JSON fields."""
    # Build the corrected parser: explicit nulls must be refused outright.
    fixed_parser = (
        reqparse.RequestParser()
        .add_argument("type", type=str, required=True, nullable=False, location="json")
        .add_argument("name", type=str, required=True, nullable=False, location="json")
    )

    with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"):
        from werkzeug.exceptions import BadRequest

        # With nullable=False the null payload is rejected before reaching Pydantic.
        with pytest.raises(BadRequest):
            fixed_parser.parse_args()

    # A well-formed payload still validates cleanly end to end.
    valid_payload = {"type": "string", "name": "valid"}
    args = MetadataArgs.model_validate(valid_payload)
    assert args.type == "string"
    assert args.name == "valid"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -0,0 +1,88 @@
|
||||
import types
|
||||
|
||||
import pytest
|
||||
|
||||
from core.entities.provider_entities import CredentialConfiguration, CustomModelConfiguration
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.provider_entities import ConfigurateMethod
|
||||
from models.provider import ProviderType
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
|
||||
class _FakeConfigurations:
|
||||
def __init__(self, provider_configuration: types.SimpleNamespace) -> None:
|
||||
self._provider_configuration = provider_configuration
|
||||
|
||||
def values(self) -> list[types.SimpleNamespace]:
|
||||
return [self._provider_configuration]
|
||||
|
||||
|
||||
@pytest.fixture
def service_with_fake_configurations():
    """ModelProviderService wired to an in-memory fake provider configuration.

    The fake custom model intentionally carries decrypted credentials so tests
    can verify the service strips them from list responses.
    """
    # Minimal provider schema: only the fields ProviderResponse actually reads.
    provider_schema = types.SimpleNamespace(
        provider="langgenius/openai_api_compatible/openai_api_compatible",
        label=I18nObject(en_US="OpenAI API Compatible", zh_Hans="OpenAI API Compatible"),
        description=None,
        icon_small=None,
        icon_small_dark=None,
        icon_large=None,
        background=None,
        help=None,
        supported_model_types=[ModelType.LLM],
        configurate_methods=[ConfigurateMethod.CUSTOMIZABLE_MODEL],
        provider_credential_schema=None,
        model_credential_schema=None,
    )

    # Decrypted credentials are present on purpose — this is the leak source.
    model_with_secrets = CustomModelConfiguration(
        model="gpt-4o-mini",
        model_type=ModelType.LLM,
        credentials={"api_key": "sk-plain-text", "endpoint": "https://example.com"},
        current_credential_id="cred-1",
        current_credential_name="API KEY 1",
        available_model_credentials=[],
        unadded_to_model_list=False,
    )

    custom_provider = types.SimpleNamespace(
        current_credential_id="cred-1",
        current_credential_name="API KEY 1",
        available_credentials=[CredentialConfiguration(credential_id="cred-1", credential_name="API KEY 1")],
    )

    custom_configuration = types.SimpleNamespace(
        provider=custom_provider,
        models=[model_with_secrets],
        can_added_models=[],
    )

    system_configuration = types.SimpleNamespace(enabled=False, current_quota_type=None, quota_configurations=[])

    provider_configuration = types.SimpleNamespace(
        provider=provider_schema,
        preferred_provider_type=ProviderType.CUSTOM,
        custom_configuration=custom_configuration,
        system_configuration=system_configuration,
        is_custom_configuration_available=lambda: True,
    )

    class _FakeProviderManager:
        # Drop-in replacement for ProviderManager used by the service under test.
        def get_configurations(self, tenant_id: str) -> _FakeConfigurations:
            return _FakeConfigurations(provider_configuration)

    service = ModelProviderService()
    service.provider_manager = _FakeProviderManager()
    return service
|
||||
|
||||
|
||||
def test_get_provider_list_strips_credentials(service_with_fake_configurations: ModelProviderService):
    """get_provider_list must never echo decrypted credentials back to callers."""
    providers = service_with_fake_configurations.get_provider_list(tenant_id="tenant-1", model_type=None)

    assert len(providers) == 1

    models = providers[0].custom_configuration.custom_models
    assert models is not None
    assert len(models) == 1
    # The response sanitizer is expected to blank out plaintext credentials.
    assert models[0].credentials is None
|
||||
1232
api/tests/unit_tests/tasks/test_clean_dataset_task.py
Normal file
1232
api/tests/unit_tests/tasks/test_clean_dataset_task.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -19,7 +19,7 @@ from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset, Document
|
||||
from services.document_indexing_task_proxy import DocumentIndexingTaskProxy
|
||||
from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy
|
||||
from tasks.document_indexing_task import (
|
||||
_document_indexing,
|
||||
_document_indexing_with_tenant_queue,
|
||||
@ -138,7 +138,9 @@ class TestTaskEnqueuing:
|
||||
with patch.object(DocumentIndexingTaskProxy, "features") as mock_features:
|
||||
mock_features.billing.enabled = False
|
||||
|
||||
with patch("services.document_indexing_task_proxy.priority_document_indexing_task") as mock_task:
|
||||
# Mock the class variable directly
|
||||
mock_task = Mock()
|
||||
with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task):
|
||||
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
||||
|
||||
# Act
|
||||
@ -163,7 +165,9 @@ class TestTaskEnqueuing:
|
||||
mock_features.billing.enabled = True
|
||||
mock_features.billing.subscription.plan = CloudPlan.SANDBOX
|
||||
|
||||
with patch("services.document_indexing_task_proxy.normal_document_indexing_task") as mock_task:
|
||||
# Mock the class variable directly
|
||||
mock_task = Mock()
|
||||
with patch.object(DocumentIndexingTaskProxy, "NORMAL_TASK_FUNC", mock_task):
|
||||
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
||||
|
||||
# Act
|
||||
@ -187,7 +191,9 @@ class TestTaskEnqueuing:
|
||||
mock_features.billing.enabled = True
|
||||
mock_features.billing.subscription.plan = CloudPlan.PROFESSIONAL
|
||||
|
||||
with patch("services.document_indexing_task_proxy.priority_document_indexing_task") as mock_task:
|
||||
# Mock the class variable directly
|
||||
mock_task = Mock()
|
||||
with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task):
|
||||
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
||||
|
||||
# Act
|
||||
@ -211,7 +217,9 @@ class TestTaskEnqueuing:
|
||||
mock_features.billing.enabled = True
|
||||
mock_features.billing.subscription.plan = CloudPlan.PROFESSIONAL
|
||||
|
||||
with patch("services.document_indexing_task_proxy.priority_document_indexing_task") as mock_task:
|
||||
# Mock the class variable directly
|
||||
mock_task = Mock()
|
||||
with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task):
|
||||
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
||||
|
||||
# Act
|
||||
@ -1493,7 +1501,9 @@ class TestEdgeCases:
|
||||
mock_features.billing.enabled = True
|
||||
mock_features.billing.subscription.plan = CloudPlan.PROFESSIONAL
|
||||
|
||||
with patch("services.document_indexing_task_proxy.priority_document_indexing_task") as mock_task:
|
||||
# Mock the class variable directly
|
||||
mock_task = Mock()
|
||||
with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task):
|
||||
# Act - Enqueue multiple tasks rapidly
|
||||
for doc_ids in document_ids_list:
|
||||
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, doc_ids)
|
||||
@ -1898,7 +1908,7 @@ class TestRobustness:
|
||||
- Error is propagated appropriately
|
||||
"""
|
||||
# Arrange
|
||||
with patch("services.document_indexing_task_proxy.FeatureService.get_features") as mock_get_features:
|
||||
with patch("services.document_indexing_proxy.base.FeatureService.get_features") as mock_get_features:
|
||||
# Simulate FeatureService failure
|
||||
mock_get_features.side_effect = Exception("Feature service unavailable")
|
||||
|
||||
|
||||
112
api/tests/unit_tests/tasks/test_delete_account_task.py
Normal file
112
api/tests/unit_tests/tasks/test_delete_account_task.py
Normal file
@ -0,0 +1,112 @@
|
||||
"""
|
||||
Unit tests for delete_account_task.
|
||||
|
||||
Covers:
|
||||
- Billing enabled with existing account: calls billing and sends success email
|
||||
- Billing disabled with existing account: skips billing, sends success email
|
||||
- Account not found: still calls billing when enabled, does not send email
|
||||
- Billing deletion raises: logs and re-raises, no email
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tasks.delete_account_task import delete_account_task
|
||||
|
||||
|
||||
@pytest.fixture
def mock_db_session():
    """Replace db.session in delete_account_task with a chained MagicMock."""
    with patch("tasks.delete_account_task.db.session") as session:
        # query(...).where(...) collapses onto one query mock so first() is easy to stub.
        query = MagicMock()
        session.query.return_value = query
        query.where.return_value = query
        yield session
|
||||
|
||||
|
||||
@pytest.fixture
def mock_deps():
    """Patch BillingService and the deletion-success mail task for the duration of a test."""
    with patch("tasks.delete_account_task.BillingService") as billing:
        with patch("tasks.delete_account_task.send_deletion_success_task") as mail_task:
            # Give the mail task an explicit .delay so call assertions are unambiguous.
            mail_task.delay = MagicMock()
            yield {
                "billing": billing,
                "mail_task": mail_task,
            }
|
||||
|
||||
|
||||
def _set_account_found(mock_db_session, email: str = "user@example.com"):
|
||||
account = SimpleNamespace(email=email)
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = account
|
||||
return account
|
||||
|
||||
|
||||
def _set_account_missing(mock_db_session):
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
|
||||
class TestDeleteAccountTask:
    """Behavioral tests for delete_account_task across billing/account combinations."""

    def test_billing_enabled_account_exists_calls_billing_and_sends_email(self, mock_db_session, mock_deps):
        """Billing on + account present: billing deletion runs and success mail is queued."""
        account_id = "acc-123"
        target_account = _set_account_found(mock_db_session, email="a@b.com")

        with patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True):
            delete_account_task(account_id)

        mock_deps["billing"].delete_account.assert_called_once_with(account_id)
        mock_deps["mail_task"].delay.assert_called_once_with(target_account.email)

    def test_billing_disabled_account_exists_sends_email_only(self, mock_db_session, mock_deps):
        """Billing off + account present: billing is skipped but mail still goes out."""
        account_id = "acc-456"
        target_account = _set_account_found(mock_db_session, email="x@y.com")

        with patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", False):
            delete_account_task(account_id)

        mock_deps["billing"].delete_account.assert_not_called()
        mock_deps["mail_task"].delay.assert_called_once_with(target_account.email)

    def test_account_not_found_billing_enabled_calls_billing_no_email(self, mock_db_session, mock_deps, caplog):
        """Missing account: billing cleanup still runs, but no mail is sent."""
        account_id = "missing-id"
        _set_account_missing(mock_db_session)

        with patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True):
            delete_account_task(account_id)

        mock_deps["billing"].delete_account.assert_called_once_with(account_id)
        mock_deps["mail_task"].delay.assert_not_called()
        # The task is expected to log the missing-account condition.
        assert any("not found" in record.getMessage().lower() for record in caplog.records)

    def test_billing_delete_raises_propagates_and_no_email(self, mock_db_session, mock_deps):
        """Billing failure: the exception escapes the task and no mail is sent."""
        account_id = "acc-err"
        _set_account_found(mock_db_session, email="err@ex.com")
        mock_deps["billing"].delete_account.side_effect = RuntimeError("billing down")

        with patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True):
            with pytest.raises(RuntimeError):
                delete_account_task(account_id)

        mock_deps["mail_task"].delay.assert_not_called()
|
||||
@ -0,0 +1,567 @@
|
||||
"""
|
||||
Unit tests for duplicate document indexing tasks.
|
||||
|
||||
This module tests the duplicate document indexing task functionality including:
|
||||
- Task enqueuing to different queues (normal, priority, tenant-isolated)
|
||||
- Batch processing of multiple duplicate documents
|
||||
- Progress tracking through task lifecycle
|
||||
- Error handling and retry mechanisms
|
||||
- Cleanup of old document data before re-indexing
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from tasks.duplicate_document_indexing_task import (
|
||||
_duplicate_document_indexing_task,
|
||||
_duplicate_document_indexing_task_with_tenant_queue,
|
||||
duplicate_document_indexing_task,
|
||||
normal_duplicate_document_indexing_task,
|
||||
priority_duplicate_document_indexing_task,
|
||||
)
|
||||
|
||||
# ============================================================================
|
||||
# Fixtures
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
def tenant_id():
    """Fresh random tenant ID for each test."""
    generated = uuid.uuid4()
    return str(generated)
|
||||
|
||||
|
||||
@pytest.fixture
def dataset_id():
    """Fresh random dataset ID for each test."""
    generated = uuid.uuid4()
    return str(generated)
|
||||
|
||||
|
||||
@pytest.fixture
def document_ids():
    """Three fresh random document IDs per test."""
    return [f"{uuid.uuid4()}" for _ in range(3)]
|
||||
|
||||
|
||||
@pytest.fixture
def mock_dataset(dataset_id, tenant_id):
    """Spec'd Dataset mock with the attributes the indexing task reads."""
    dataset = Mock(spec=Dataset)
    for attr, value in {
        "id": dataset_id,
        "tenant_id": tenant_id,
        "indexing_technique": "high_quality",
        "embedding_model_provider": "openai",
        "embedding_model": "text-embedding-ada-002",
    }.items():
        setattr(dataset, attr, value)
    return dataset
|
||||
|
||||
|
||||
@pytest.fixture
def mock_documents(document_ids, dataset_id):
    """One spec'd Document mock per requested document ID, all in 'waiting' state."""

    def _build(doc_id):
        # Fresh mock per ID so per-document state mutations don't bleed across tests.
        doc = Mock(spec=Document)
        doc.id = doc_id
        doc.dataset_id = dataset_id
        doc.indexing_status = "waiting"
        doc.error = None
        doc.stopped_at = None
        doc.processing_started_at = None
        doc.doc_form = "text_model"
        return doc

    return [_build(doc_id) for doc_id in document_ids]
|
||||
|
||||
|
||||
@pytest.fixture
def mock_document_segments(document_ids):
    """Three spec'd DocumentSegment mocks per document, with unique index node IDs."""

    def _segment(doc_id, position):
        segment = Mock(spec=DocumentSegment)
        segment.id = str(uuid.uuid4())
        segment.document_id = doc_id
        segment.index_node_id = f"node-{doc_id}-{position}"
        return segment

    return [_segment(doc_id, i) for doc_id in document_ids for i in range(3)]
|
||||
|
||||
|
||||
@pytest.fixture
def mock_db_session():
    """Replace db.session in duplicate_document_indexing_task with a chained MagicMock."""
    with patch("tasks.duplicate_document_indexing_task.db.session") as session:
        # query(...).where(...) collapses onto one query mock so first() is easy to stub.
        query = MagicMock()
        session.query.return_value = query
        query.where.return_value = query
        session.scalars.return_value = MagicMock()
        yield session
|
||||
|
||||
|
||||
@pytest.fixture
def mock_indexing_runner():
    """Patch IndexingRunner and yield the instance the task will construct."""
    with patch("tasks.duplicate_document_indexing_task.IndexingRunner") as runner_cls:
        runner = MagicMock(spec=IndexingRunner)
        runner_cls.return_value = runner
        yield runner
|
||||
|
||||
|
||||
@pytest.fixture
def mock_feature_service():
    """Patch FeatureService with billing disabled and an empty vector space by default."""
    with patch("tasks.duplicate_document_indexing_task.FeatureService") as service_cls:
        features = Mock()
        features.billing = Mock()
        features.billing.enabled = False
        features.vector_space = Mock()
        features.vector_space.size = 0
        features.vector_space.limit = 1000
        service_cls.get_features.return_value = features
        yield service_cls
|
||||
|
||||
|
||||
@pytest.fixture
def mock_index_processor_factory():
    """Patch IndexProcessorFactory; the factory's processor exposes a stubbed clean()."""
    with patch("tasks.duplicate_document_indexing_task.IndexProcessorFactory") as factory_cls:
        processor = MagicMock()
        processor.clean = Mock()
        factory_cls.return_value.init_index_processor.return_value = processor
        yield factory_cls
|
||||
|
||||
|
||||
@pytest.fixture
def mock_tenant_isolated_queue():
    """Patch TenantIsolatedTaskQueue and yield an instance with an empty queue."""
    with patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue") as queue_cls:
        queue = MagicMock(spec=TenantIsolatedTaskQueue)
        queue.pull_tasks.return_value = []
        queue.delete_task_key = Mock()
        queue.set_task_waiting_time = Mock()
        queue_cls.return_value = queue
        yield queue
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for deprecated duplicate_document_indexing_task
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestDuplicateDocumentIndexingTask:
    """The deprecated entry point must delegate straight to the core function."""

    @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task")
    def test_duplicate_document_indexing_task_calls_core_function(self, mock_core_func, dataset_id, document_ids):
        """Delegation: the wrapper forwards both arguments unchanged."""
        duplicate_document_indexing_task(dataset_id, document_ids)

        mock_core_func.assert_called_once_with(dataset_id, document_ids)

    @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task")
    def test_duplicate_document_indexing_task_with_empty_document_ids(self, mock_core_func, dataset_id):
        """Delegation also holds for an empty batch of document IDs."""
        empty_batch = []

        duplicate_document_indexing_task(dataset_id, empty_batch)

        mock_core_func.assert_called_once_with(dataset_id, empty_batch)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for _duplicate_document_indexing_task core function
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestDuplicateDocumentIndexingTaskCore:
    """Tests for the _duplicate_document_indexing_task core function.

    NOTE(review): these tests wire ``first.side_effect`` as ``[dataset] + documents``,
    i.e. they assume the task looks up the dataset first and then each document in
    order — confirm against the task implementation if that order changes.
    """

    def test_successful_duplicate_document_indexing(
        self,
        mock_db_session,
        mock_indexing_runner,
        mock_feature_service,
        mock_index_processor_factory,
        mock_dataset,
        mock_documents,
        mock_document_segments,
        dataset_id,
        document_ids,
    ):
        """Test successful duplicate document indexing flow."""
        # Arrange
        # Dataset lookup first, then one lookup per document (order-sensitive).
        mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
        mock_db_session.scalars.return_value.all.return_value = mock_document_segments

        # Act
        _duplicate_document_indexing_task(dataset_id, document_ids)

        # Assert
        # Verify IndexingRunner was called
        mock_indexing_runner.run.assert_called_once()

        # Verify all documents were set to parsing status
        for doc in mock_documents:
            assert doc.indexing_status == "parsing"
            assert doc.processing_started_at is not None

        # Verify session operations
        assert mock_db_session.commit.called
        assert mock_db_session.close.called

    def test_duplicate_document_indexing_dataset_not_found(self, mock_db_session, dataset_id, document_ids):
        """Test duplicate document indexing when dataset is not found."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.return_value = None

        # Act
        _duplicate_document_indexing_task(dataset_id, document_ids)

        # Assert
        # Should close the session at least once
        assert mock_db_session.close.called

    def test_duplicate_document_indexing_with_billing_enabled_sandbox_plan(
        self,
        mock_db_session,
        mock_feature_service,
        mock_dataset,
        dataset_id,
        document_ids,
    ):
        """Test duplicate document indexing with billing enabled and sandbox plan."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
        mock_features = mock_feature_service.get_features.return_value
        mock_features.billing.enabled = True
        mock_features.billing.subscription.plan = CloudPlan.SANDBOX

        # Act
        _duplicate_document_indexing_task(dataset_id, document_ids)

        # Assert
        # For sandbox plan with multiple documents, should fail
        # NOTE(review): only the commit is asserted here; the failure path itself
        # (document error state) is not pinned — confirm intent.
        mock_db_session.commit.assert_called()

    def test_duplicate_document_indexing_with_billing_limit_exceeded(
        self,
        mock_db_session,
        mock_feature_service,
        mock_dataset,
        mock_documents,
        dataset_id,
        document_ids,
    ):
        """Test duplicate document indexing when billing limit is exceeded."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
        mock_db_session.scalars.return_value.all.return_value = []  # No segments to clean
        mock_features = mock_feature_service.get_features.return_value
        mock_features.billing.enabled = True
        mock_features.billing.subscription.plan = CloudPlan.TEAM
        # Vector space nearly full: 990 of 1000 used.
        mock_features.vector_space.size = 990
        mock_features.vector_space.limit = 1000

        # Act
        _duplicate_document_indexing_task(dataset_id, document_ids)

        # Assert
        # Should commit the session
        assert mock_db_session.commit.called
        # Should close the session
        assert mock_db_session.close.called

    def test_duplicate_document_indexing_runner_error(
        self,
        mock_db_session,
        mock_indexing_runner,
        mock_feature_service,
        mock_index_processor_factory,
        mock_dataset,
        mock_documents,
        dataset_id,
        document_ids,
    ):
        """Test duplicate document indexing when IndexingRunner raises an error."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
        mock_db_session.scalars.return_value.all.return_value = []
        mock_indexing_runner.run.side_effect = Exception("Indexing error")

        # Act
        _duplicate_document_indexing_task(dataset_id, document_ids)

        # Assert
        # Should close the session even after error
        mock_db_session.close.assert_called_once()

    def test_duplicate_document_indexing_document_is_paused(
        self,
        mock_db_session,
        mock_indexing_runner,
        mock_feature_service,
        mock_index_processor_factory,
        mock_dataset,
        mock_documents,
        dataset_id,
        document_ids,
    ):
        """Test duplicate document indexing when document is paused."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
        mock_db_session.scalars.return_value.all.return_value = []
        mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused")

        # Act
        _duplicate_document_indexing_task(dataset_id, document_ids)

        # Assert
        # Should handle DocumentIsPausedError gracefully
        mock_db_session.close.assert_called_once()

    def test_duplicate_document_indexing_cleans_old_segments(
        self,
        mock_db_session,
        mock_indexing_runner,
        mock_feature_service,
        mock_index_processor_factory,
        mock_dataset,
        mock_documents,
        mock_document_segments,
        dataset_id,
        document_ids,
    ):
        """Test that duplicate document indexing cleans old segments."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
        mock_db_session.scalars.return_value.all.return_value = mock_document_segments
        mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value

        # Act
        _duplicate_document_indexing_task(dataset_id, document_ids)

        # Assert
        # Verify clean was called for each document
        assert mock_processor.clean.call_count == len(mock_documents)

        # Verify segments were deleted
        for segment in mock_document_segments:
            mock_db_session.delete.assert_any_call(segment)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for tenant queue wrapper function
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestDuplicateDocumentIndexingTaskWithTenantQueue:
    """Queue bookkeeping around the core call in _duplicate_document_indexing_task_with_tenant_queue."""

    @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task")
    def test_tenant_queue_wrapper_calls_core_function(
        self,
        mock_core_func,
        mock_tenant_isolated_queue,
        tenant_id,
        dataset_id,
        document_ids,
    ):
        """The wrapper must hand the dataset/documents pair to the core function."""
        task_func_stub = Mock()

        _duplicate_document_indexing_task_with_tenant_queue(tenant_id, dataset_id, document_ids, task_func_stub)

        mock_core_func.assert_called_once_with(dataset_id, document_ids)

    @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task")
    def test_tenant_queue_wrapper_deletes_key_when_no_tasks(
        self,
        mock_core_func,
        mock_tenant_isolated_queue,
        tenant_id,
        dataset_id,
        document_ids,
    ):
        """An empty queue after processing means the tenant's task key is removed."""
        task_func_stub = Mock()
        mock_tenant_isolated_queue.pull_tasks.return_value = []

        _duplicate_document_indexing_task_with_tenant_queue(tenant_id, dataset_id, document_ids, task_func_stub)

        mock_tenant_isolated_queue.delete_task_key.assert_called_once()

    @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task")
    def test_tenant_queue_wrapper_processes_next_tasks(
        self,
        mock_core_func,
        mock_tenant_isolated_queue,
        tenant_id,
        dataset_id,
        document_ids,
    ):
        """A queued follow-up task is re-dispatched via the supplied task function."""
        task_func_stub = Mock()
        queued_payload = {
            "tenant_id": tenant_id,
            "dataset_id": dataset_id,
            "document_ids": document_ids,
        }
        mock_tenant_isolated_queue.pull_tasks.return_value = [queued_payload]

        _duplicate_document_indexing_task_with_tenant_queue(tenant_id, dataset_id, document_ids, task_func_stub)

        mock_tenant_isolated_queue.set_task_waiting_time.assert_called_once()
        task_func_stub.delay.assert_called_once_with(
            tenant_id=tenant_id,
            dataset_id=dataset_id,
            document_ids=document_ids,
        )

    @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task")
    def test_tenant_queue_wrapper_handles_core_function_error(
        self,
        mock_core_func,
        mock_tenant_isolated_queue,
        tenant_id,
        dataset_id,
        document_ids,
    ):
        """A core-function failure must not stop the wrapper from draining the queue."""
        task_func_stub = Mock()
        mock_core_func.side_effect = Exception("Core function error")

        _duplicate_document_indexing_task_with_tenant_queue(tenant_id, dataset_id, document_ids, task_func_stub)

        # Queue is still consulted even though the core call blew up.
        mock_tenant_isolated_queue.pull_tasks.assert_called_once()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for normal_duplicate_document_indexing_task
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestNormalDuplicateDocumentIndexingTask:
    """normal_duplicate_document_indexing_task must defer to the tenant-queue wrapper."""

    @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue")
    def test_normal_task_calls_tenant_queue_wrapper(
        self,
        mock_wrapper_func,
        tenant_id,
        dataset_id,
        document_ids,
    ):
        """The normal task passes itself as the re-dispatch target."""
        normal_duplicate_document_indexing_task(tenant_id, dataset_id, document_ids)

        mock_wrapper_func.assert_called_once_with(
            tenant_id, dataset_id, document_ids, normal_duplicate_document_indexing_task
        )

    @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue")
    def test_normal_task_with_empty_document_ids(
        self,
        mock_wrapper_func,
        tenant_id,
        dataset_id,
    ):
        """Delegation also holds for an empty batch of document IDs."""
        empty_batch = []

        normal_duplicate_document_indexing_task(tenant_id, dataset_id, empty_batch)

        mock_wrapper_func.assert_called_once_with(
            tenant_id, dataset_id, empty_batch, normal_duplicate_document_indexing_task
        )
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for priority_duplicate_document_indexing_task
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestPriorityDuplicateDocumentIndexingTask:
    """priority_duplicate_document_indexing_task must defer to the tenant-queue wrapper."""

    @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue")
    def test_priority_task_calls_tenant_queue_wrapper(
        self,
        mock_wrapper_func,
        tenant_id,
        dataset_id,
        document_ids,
    ):
        """The priority task passes itself as the re-dispatch target."""
        priority_duplicate_document_indexing_task(tenant_id, dataset_id, document_ids)

        mock_wrapper_func.assert_called_once_with(
            tenant_id, dataset_id, document_ids, priority_duplicate_document_indexing_task
        )

    @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue")
    def test_priority_task_with_single_document(
        self,
        mock_wrapper_func,
        tenant_id,
        dataset_id,
    ):
        """Delegation holds for a one-element batch."""
        single_batch = ["doc-1"]

        priority_duplicate_document_indexing_task(tenant_id, dataset_id, single_batch)

        mock_wrapper_func.assert_called_once_with(
            tenant_id, dataset_id, single_batch, priority_duplicate_document_indexing_task
        )

    @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue")
    def test_priority_task_with_large_batch(
        self,
        mock_wrapper_func,
        tenant_id,
        dataset_id,
    ):
        """Delegation holds for a large (100-document) batch."""
        large_batch = [f"doc-{i}" for i in range(100)]

        priority_duplicate_document_indexing_task(tenant_id, dataset_id, large_batch)

        mock_wrapper_func.assert_called_once_with(
            tenant_id, dataset_id, large_batch, priority_duplicate_document_indexing_task
        )
|
||||
@ -8,10 +8,13 @@ from core.tools.utils.text_processing_utils import remove_leading_symbols
|
||||
[
|
||||
("...Hello, World!", "Hello, World!"),
|
||||
("。测试中文标点", "测试中文标点"),
|
||||
("!@#Test symbols", "Test symbols"),
|
||||
# Note: ! is not in the removal pattern, only @# are removed, leaving "!Test symbols"
|
||||
# The pattern intentionally excludes ! as per #11868 fix
|
||||
("@#Test symbols", "Test symbols"),
|
||||
("Hello, World!", "Hello, World!"),
|
||||
("", ""),
|
||||
(" ", " "),
|
||||
("【测试】", "【测试】"),
|
||||
],
|
||||
)
|
||||
def test_remove_leading_symbols(input_text, expected_output):
|
||||
|
||||
Reference in New Issue
Block a user