Merge branch 'main' into feat/agent-node-v2

This commit is contained in:
Novice
2025-12-30 10:20:42 +08:00
232 changed files with 18692 additions and 2696 deletions

View File

@ -7,11 +7,14 @@ CODE_LANGUAGE = CodeLanguage.JINJA2
def test_jinja2():
"""Test basic Jinja2 template rendering."""
template = "Hello {{template}}"
# Template must be base64 encoded to match the new safe embedding approach
template_b64 = base64.b64encode(template.encode("utf-8")).decode("utf-8")
inputs = base64.b64encode(b'{"template": "World"}').decode("utf-8")
code = (
Jinja2TemplateTransformer.get_runner_script()
.replace(Jinja2TemplateTransformer._code_placeholder, template)
.replace(Jinja2TemplateTransformer._template_b64_placeholder, template_b64)
.replace(Jinja2TemplateTransformer._inputs_placeholder, inputs)
)
result = CodeExecutor.execute_code(
@ -21,6 +24,7 @@ def test_jinja2():
def test_jinja2_with_code_template():
"""Test template rendering via the high-level workflow API."""
result = CodeExecutor.execute_workflow_code_template(
language=CODE_LANGUAGE, code="Hello {{template}}", inputs={"template": "World"}
)
@ -28,7 +32,64 @@ def test_jinja2_with_code_template():
def test_jinja2_get_runner_script():
    """Test that runner script contains required placeholders."""
    script = Jinja2TemplateTransformer.get_runner_script()
    # Each placeholder must appear exactly once; the result tag wraps the
    # output and therefore appears twice (open and close).
    expected_counts = {
        Jinja2TemplateTransformer._code_placeholder: 1,
        Jinja2TemplateTransformer._template_b64_placeholder: 1,
        Jinja2TemplateTransformer._inputs_placeholder: 1,
        Jinja2TemplateTransformer._result_tag: 2,
    }
    for marker, occurrences in expected_counts.items():
        assert script.count(marker) == occurrences
def test_jinja2_template_with_special_characters():
    """
    Test that templates with special characters (quotes, newlines) render correctly.

    This is a regression test for issue #26818 where textarea pre-fill values
    containing special characters would break template rendering.
    """
    # Template mixing triple quotes, single quotes, double quotes, and newlines
    html_template = """<html>
<body>
<input value="{{ task.get('Task ID', '') }}"/>
<textarea>{{ task.get('Issues', 'No issues reported') }}</textarea>
<p>Status: "{{ status }}"</p>
<pre>'''code block'''</pre>
</body>
</html>"""
    inputs = {"task": {"Task ID": "TASK-123", "Issues": "Line 1\nLine 2\nLine 3"}, "status": "completed"}
    rendered = CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code=html_template, inputs=inputs)
    # Verify the template rendered correctly with all special characters
    output = rendered["result"]
    expected_fragments = [
        'value="TASK-123"',
        "<textarea>Line 1\nLine 2\nLine 3</textarea>",
        'Status: "completed"',
        "'''code block'''",
    ]
    for fragment in expected_fragments:
        assert fragment in output
def test_jinja2_template_with_html_textarea_prefill():
    """
    Specific test for HTML textarea with Jinja2 variable pre-fill.

    Verifies fix for issue #26818.
    """
    notes_content = "This is a multi-line note.\nWith special chars: 'single' and \"double\" quotes."
    source_template = "<textarea name='notes'>{{ notes }}</textarea>"
    outcome = CodeExecutor.execute_workflow_code_template(
        language=CODE_LANGUAGE, code=source_template, inputs={"notes": notes_content}
    )
    # The note must pass through verbatim, quotes and newline included.
    assert outcome["result"] == f"<textarea name='notes'>{notes_content}</textarea>"
def test_jinja2_assemble_runner_script_encodes_template():
    """Test that assemble_runner_script properly base64 encodes the template."""
    raw_template = "Hello {{ name }}!"
    script = Jinja2TemplateTransformer.assemble_runner_script(raw_template, {"name": "World"})
    # The base64 form of the template must be embedded in the script...
    encoded = base64.b64encode(raw_template.encode("utf-8")).decode("utf-8")
    assert encoded in script
    # ...while the raw (unencoded) template must not leak through.
    assert "Hello {{ name }}!" not in script

View File

@ -6,6 +6,7 @@ import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.nodes.http_request.node import HttpRequestNode
from core.workflow.nodes.node_factory import DifyNodeFactory
@ -169,13 +170,14 @@ def test_custom_authorization_header(setup_http_mock):
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_custom_auth_with_empty_api_key_does_not_set_header(setup_http_mock):
"""Test: In custom authentication mode, when the api_key is empty, no header should be set."""
def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock):
"""Test: In custom authentication mode, when the api_key is empty, AuthorizationConfigError should be raised."""
from core.workflow.nodes.http_request.entities import (
HttpRequestNodeAuthorization,
HttpRequestNodeData,
HttpRequestNodeTimeout,
)
from core.workflow.nodes.http_request.exc import AuthorizationConfigError
from core.workflow.nodes.http_request.executor import Executor
from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable
@ -208,16 +210,13 @@ def test_custom_auth_with_empty_api_key_does_not_set_header(setup_http_mock):
ssl_verify=True,
)
# Create executor
executor = Executor(
node_data=node_data, timeout=HttpRequestNodeTimeout(connect=10, read=30, write=10), variable_pool=variable_pool
)
# Get assembled headers
headers = executor._assembling_headers()
# When api_key is empty, the custom header should NOT be set
assert "X-Custom-Auth" not in headers
# Create executor should raise AuthorizationConfigError
with pytest.raises(AuthorizationConfigError, match="API key is required"):
Executor(
node_data=node_data,
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=10),
variable_pool=variable_pool,
)
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
@ -305,9 +304,10 @@ def test_basic_authorization_with_custom_header_ignored(setup_http_mock):
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_custom_authorization_with_empty_api_key(setup_http_mock):
"""
Test that custom authorization doesn't set header when api_key is empty.
This test verifies the fix for issue #23554.
Test that custom authorization raises error when api_key is empty.
This test verifies the fix for issue #21830.
"""
node = init_http_node(
config={
"id": "1",
@ -333,11 +333,10 @@ def test_custom_authorization_with_empty_api_key(setup_http_mock):
)
result = node._run()
assert result.process_data is not None
data = result.process_data.get("request", "")
# Custom header should NOT be set when api_key is empty
assert "X-Custom-Auth:" not in data
# Should fail with AuthorizationConfigError
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert "API key is required" in result.error
assert result.error_type == "AuthorizationConfigError"
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)

View File

@ -0,0 +1,365 @@
import json
from unittest.mock import patch
import pytest
from extensions.ext_redis import redis_client
from services.billing_service import BillingService
class TestBillingServiceGetPlanBulkWithCache:
    """
    Comprehensive integration tests for get_plan_bulk_with_cache using testcontainers.

    This test class covers all major scenarios:
    - Cache hit/miss scenarios
    - Redis operation failures and fallback behavior
    - Invalid cache data handling
    - TTL expiration handling
    - Error recovery and logging
    """

    @pytest.fixture(autouse=True)
    def setup_redis_cleanup(self, flask_app_with_containers):
        """Clean up Redis cache before and after each test."""
        with flask_app_with_containers.app_context():
            # NOTE(review): nothing is actually deleted before the yield, so
            # cleanup only runs AFTER each test — confirm "before" in the
            # docstring is intended.
            yield
            # Clean up after test: delete all plan cache keys created above.
            pattern = f"{BillingService._PLAN_CACHE_KEY_PREFIX}*"
            keys = redis_client.keys(pattern)
            if keys:
                redis_client.delete(*keys)

    def _create_test_plan_data(self, plan: str = "sandbox", expiration_date: int = 1735689600):
        """Helper to create test SubscriptionPlan data."""
        return {"plan": plan, "expiration_date": expiration_date}

    def _set_cache(self, tenant_id: str, plan_data: dict, ttl: int = 600):
        """Helper to set cache data in Redis."""
        cache_key = BillingService._make_plan_cache_key(tenant_id)
        json_str = json.dumps(plan_data)
        redis_client.setex(cache_key, ttl, json_str)

    def _get_cache(self, tenant_id: str):
        """Helper to get cache data from Redis, decoded to str, or None if absent."""
        cache_key = BillingService._make_plan_cache_key(tenant_id)
        value = redis_client.get(cache_key)
        if value:
            # redis-py may return bytes depending on client configuration.
            if isinstance(value, bytes):
                return value.decode("utf-8")
            return value
        return None

    def test_get_plan_bulk_with_cache_all_cache_hit(self, flask_app_with_containers):
        """Test bulk plan retrieval when all tenants are in cache."""
        with flask_app_with_containers.app_context():
            # Arrange
            tenant_ids = ["tenant-1", "tenant-2", "tenant-3"]
            expected_plans = {
                "tenant-1": self._create_test_plan_data("sandbox", 1735689600),
                "tenant-2": self._create_test_plan_data("professional", 1767225600),
                "tenant-3": self._create_test_plan_data("team", 1798761600),
            }
            # Pre-populate cache
            for tenant_id, plan_data in expected_plans.items():
                self._set_cache(tenant_id, plan_data)
            # Act
            with patch.object(BillingService, "get_plan_bulk") as mock_get_plan_bulk:
                result = BillingService.get_plan_bulk_with_cache(tenant_ids)
                # Assert
                assert len(result) == 3
                assert result["tenant-1"]["plan"] == "sandbox"
                assert result["tenant-1"]["expiration_date"] == 1735689600
                assert result["tenant-2"]["plan"] == "professional"
                assert result["tenant-2"]["expiration_date"] == 1767225600
                assert result["tenant-3"]["plan"] == "team"
                assert result["tenant-3"]["expiration_date"] == 1798761600
                # Verify API was not called — every plan came from cache.
                mock_get_plan_bulk.assert_not_called()

    def test_get_plan_bulk_with_cache_all_cache_miss(self, flask_app_with_containers):
        """Test bulk plan retrieval when all tenants are not in cache."""
        with flask_app_with_containers.app_context():
            # Arrange
            tenant_ids = ["tenant-1", "tenant-2"]
            expected_plans = {
                "tenant-1": self._create_test_plan_data("sandbox", 1735689600),
                "tenant-2": self._create_test_plan_data("professional", 1767225600),
            }
            # Act
            with patch.object(BillingService, "get_plan_bulk", return_value=expected_plans) as mock_get_plan_bulk:
                result = BillingService.get_plan_bulk_with_cache(tenant_ids)
                # Assert
                assert len(result) == 2
                assert result["tenant-1"]["plan"] == "sandbox"
                assert result["tenant-2"]["plan"] == "professional"
                # Verify API was called with correct tenant_ids
                mock_get_plan_bulk.assert_called_once_with(tenant_ids)
                # Verify data was written back to cache
                cached_1 = self._get_cache("tenant-1")
                cached_2 = self._get_cache("tenant-2")
                assert cached_1 is not None
                assert cached_2 is not None
                # Verify cache content matches what the API returned
                cached_data_1 = json.loads(cached_1)
                cached_data_2 = json.loads(cached_2)
                assert cached_data_1 == expected_plans["tenant-1"]
                assert cached_data_2 == expected_plans["tenant-2"]
                # Verify a TTL was set on the cache entry
                cache_key_1 = BillingService._make_plan_cache_key("tenant-1")
                ttl_1 = redis_client.ttl(cache_key_1)
                assert ttl_1 > 0
                assert ttl_1 <= 600  # Should be <= 600 seconds

    def test_get_plan_bulk_with_cache_partial_cache_hit(self, flask_app_with_containers):
        """Test bulk plan retrieval when some tenants are in cache, some are not."""
        with flask_app_with_containers.app_context():
            # Arrange
            tenant_ids = ["tenant-1", "tenant-2", "tenant-3"]
            # Pre-populate cache for tenant-1 and tenant-2
            self._set_cache("tenant-1", self._create_test_plan_data("sandbox", 1735689600))
            self._set_cache("tenant-2", self._create_test_plan_data("professional", 1767225600))
            # tenant-3 is not in cache
            missing_plan = {"tenant-3": self._create_test_plan_data("team", 1798761600)}
            # Act
            with patch.object(BillingService, "get_plan_bulk", return_value=missing_plan) as mock_get_plan_bulk:
                result = BillingService.get_plan_bulk_with_cache(tenant_ids)
                # Assert
                assert len(result) == 3
                assert result["tenant-1"]["plan"] == "sandbox"
                assert result["tenant-2"]["plan"] == "professional"
                assert result["tenant-3"]["plan"] == "team"
                # Verify API was called only for the missing tenant
                mock_get_plan_bulk.assert_called_once_with(["tenant-3"])
                # Verify tenant-3 data was written to cache
                cached_3 = self._get_cache("tenant-3")
                assert cached_3 is not None
                cached_data_3 = json.loads(cached_3)
                assert cached_data_3 == missing_plan["tenant-3"]

    def test_get_plan_bulk_with_cache_redis_mget_failure(self, flask_app_with_containers):
        """Test fallback to API when Redis mget fails."""
        with flask_app_with_containers.app_context():
            # Arrange
            tenant_ids = ["tenant-1", "tenant-2"]
            expected_plans = {
                "tenant-1": self._create_test_plan_data("sandbox", 1735689600),
                "tenant-2": self._create_test_plan_data("professional", 1767225600),
            }
            # Act — only mget is broken; get/setex still work, so the
            # write-back path below still functions.
            with (
                patch.object(redis_client, "mget", side_effect=Exception("Redis connection error")),
                patch.object(BillingService, "get_plan_bulk", return_value=expected_plans) as mock_get_plan_bulk,
            ):
                result = BillingService.get_plan_bulk_with_cache(tenant_ids)
                # Assert
                assert len(result) == 2
                assert result["tenant-1"]["plan"] == "sandbox"
                assert result["tenant-2"]["plan"] == "professional"
                # Verify API was called for all tenants (fallback)
                mock_get_plan_bulk.assert_called_once_with(tenant_ids)
                # Verify data was written to cache after fallback
                cached_1 = self._get_cache("tenant-1")
                cached_2 = self._get_cache("tenant-2")
                assert cached_1 is not None
                assert cached_2 is not None

    def test_get_plan_bulk_with_cache_invalid_json_in_cache(self, flask_app_with_containers):
        """Test fallback to API when cache contains invalid JSON."""
        with flask_app_with_containers.app_context():
            # Arrange
            tenant_ids = ["tenant-1", "tenant-2", "tenant-3"]
            # Set valid cache for tenant-1
            self._set_cache("tenant-1", self._create_test_plan_data("sandbox", 1735689600))
            # Set invalid JSON for tenant-2
            cache_key_2 = BillingService._make_plan_cache_key("tenant-2")
            redis_client.setex(cache_key_2, 600, "invalid json {")
            # tenant-3 is not in cache
            expected_plans = {
                "tenant-2": self._create_test_plan_data("professional", 1767225600),
                "tenant-3": self._create_test_plan_data("team", 1798761600),
            }
            # Act
            with patch.object(BillingService, "get_plan_bulk", return_value=expected_plans) as mock_get_plan_bulk:
                result = BillingService.get_plan_bulk_with_cache(tenant_ids)
                # Assert
                assert len(result) == 3
                assert result["tenant-1"]["plan"] == "sandbox"  # From cache
                assert result["tenant-2"]["plan"] == "professional"  # From API (fallback)
                assert result["tenant-3"]["plan"] == "team"  # From API
                # Verify API was called for tenant-2 and tenant-3
                mock_get_plan_bulk.assert_called_once_with(["tenant-2", "tenant-3"])
                # Verify tenant-2's invalid JSON was replaced with correct data in cache
                cached_2 = self._get_cache("tenant-2")
                assert cached_2 is not None
                cached_data_2 = json.loads(cached_2)
                assert cached_data_2 == expected_plans["tenant-2"]
                assert cached_data_2["plan"] == "professional"
                assert cached_data_2["expiration_date"] == 1767225600
                # Verify tenant-2 cache has correct TTL
                cache_key_2_new = BillingService._make_plan_cache_key("tenant-2")
                ttl_2 = redis_client.ttl(cache_key_2_new)
                assert ttl_2 > 0
                assert ttl_2 <= 600
                # Verify tenant-3 data was also written to cache
                cached_3 = self._get_cache("tenant-3")
                assert cached_3 is not None
                cached_data_3 = json.loads(cached_3)
                assert cached_data_3 == expected_plans["tenant-3"]

    def test_get_plan_bulk_with_cache_invalid_plan_data_in_cache(self, flask_app_with_containers):
        """Test fallback to API when cache data doesn't match SubscriptionPlan schema."""
        with flask_app_with_containers.app_context():
            # Arrange
            tenant_ids = ["tenant-1", "tenant-2", "tenant-3"]
            # Set valid cache for tenant-1
            self._set_cache("tenant-1", self._create_test_plan_data("sandbox", 1735689600))
            # Set invalid plan data for tenant-2 (missing expiration_date)
            cache_key_2 = BillingService._make_plan_cache_key("tenant-2")
            invalid_data = json.dumps({"plan": "professional"})  # Missing expiration_date
            redis_client.setex(cache_key_2, 600, invalid_data)
            # tenant-3 is not in cache
            expected_plans = {
                "tenant-2": self._create_test_plan_data("professional", 1767225600),
                "tenant-3": self._create_test_plan_data("team", 1798761600),
            }
            # Act
            with patch.object(BillingService, "get_plan_bulk", return_value=expected_plans) as mock_get_plan_bulk:
                result = BillingService.get_plan_bulk_with_cache(tenant_ids)
                # Assert
                assert len(result) == 3
                assert result["tenant-1"]["plan"] == "sandbox"  # From cache
                assert result["tenant-2"]["plan"] == "professional"  # From API (fallback)
                assert result["tenant-3"]["plan"] == "team"  # From API
                # Verify API was called for tenant-2 and tenant-3
                mock_get_plan_bulk.assert_called_once_with(["tenant-2", "tenant-3"])

    def test_get_plan_bulk_with_cache_redis_pipeline_failure(self, flask_app_with_containers):
        """Test that pipeline failure doesn't affect return value."""
        with flask_app_with_containers.app_context():
            # Arrange
            tenant_ids = ["tenant-1", "tenant-2"]
            expected_plans = {
                "tenant-1": self._create_test_plan_data("sandbox", 1735689600),
                "tenant-2": self._create_test_plan_data("professional", 1767225600),
            }
            # Act
            with (
                patch.object(BillingService, "get_plan_bulk", return_value=expected_plans),
                patch.object(redis_client, "pipeline") as mock_pipeline,
            ):
                # Create a mock pipeline that fails on execute
                mock_pipe = mock_pipeline.return_value
                mock_pipe.execute.side_effect = Exception("Pipeline execution failed")
                result = BillingService.get_plan_bulk_with_cache(tenant_ids)
                # Assert - Function should still return correct result despite pipeline failure
                assert len(result) == 2
                assert result["tenant-1"]["plan"] == "sandbox"
                assert result["tenant-2"]["plan"] == "professional"
                # Verify pipeline was attempted
                mock_pipeline.assert_called_once()

    def test_get_plan_bulk_with_cache_empty_tenant_ids(self, flask_app_with_containers):
        """Test with empty tenant_ids list."""
        with flask_app_with_containers.app_context():
            # Act
            with patch.object(BillingService, "get_plan_bulk") as mock_get_plan_bulk:
                result = BillingService.get_plan_bulk_with_cache([])
                # Assert
                assert result == {}
                assert len(result) == 0
                # Verify no API calls
                mock_get_plan_bulk.assert_not_called()
                # Verify no Redis operations (mget with empty list would return empty list)
                # But we should check that mget was not called at all
                # Since we can't easily verify this without more mocking, we just verify the result

    def test_get_plan_bulk_with_cache_ttl_expired(self, flask_app_with_containers):
        """Test that expired cache keys are treated as cache misses."""
        with flask_app_with_containers.app_context():
            # Arrange
            tenant_ids = ["tenant-1", "tenant-2"]
            # Set cache for tenant-1 with very short TTL (1 second) to simulate expiration
            self._set_cache("tenant-1", self._create_test_plan_data("sandbox", 1735689600), ttl=1)
            # Wait for TTL to expire (key will be deleted by Redis)
            import time

            time.sleep(2)
            # Verify cache is expired (key doesn't exist)
            cache_key_1 = BillingService._make_plan_cache_key("tenant-1")
            exists = redis_client.exists(cache_key_1)
            assert exists == 0  # Key doesn't exist (expired)
            # tenant-2 is not in cache
            expected_plans = {
                "tenant-1": self._create_test_plan_data("sandbox", 1735689600),
                "tenant-2": self._create_test_plan_data("professional", 1767225600),
            }
            # Act
            with patch.object(BillingService, "get_plan_bulk", return_value=expected_plans) as mock_get_plan_bulk:
                result = BillingService.get_plan_bulk_with_cache(tenant_ids)
                # Assert
                assert len(result) == 2
                assert result["tenant-1"]["plan"] == "sandbox"
                assert result["tenant-2"]["plan"] == "professional"
                # Verify API was called for both tenants (tenant-1 expired, tenant-2 missing)
                mock_get_plan_bulk.assert_called_once_with(tenant_ids)
                # Verify both were written to cache with correct TTL
                cache_key_1_new = BillingService._make_plan_cache_key("tenant-1")
                cache_key_2 = BillingService._make_plan_cache_key("tenant-2")
                ttl_1_new = redis_client.ttl(cache_key_1_new)
                ttl_2 = redis_client.ttl(cache_key_2)
                assert ttl_1_new > 0
                assert ttl_1_new <= 600
                assert ttl_2 > 0
                assert ttl_2 <= 600

View File

@ -0,0 +1,682 @@
from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.plugin.entities.plugin_daemon import CredentialType
from core.trigger.entities.entities import Subscription as TriggerSubscriptionEntity
from extensions.ext_database import db
from models.provider_ids import TriggerProviderID
from models.trigger import TriggerSubscription
from services.trigger.trigger_provider_service import TriggerProviderService
class TestTriggerProviderService:
"""Integration tests for TriggerProviderService using testcontainers."""
@pytest.fixture
def mock_external_service_dependencies(self):
    """Mock setup for external service dependencies.

    Patches TriggerManager, the service-module redis client, the subscription
    cache invalidator, and FeatureService, and yields the mocks as a dict so
    tests can configure return values and inspect calls.
    """
    with (
        patch("services.trigger.trigger_provider_service.TriggerManager") as mock_trigger_manager,
        patch("services.trigger.trigger_provider_service.redis_client") as mock_redis_client,
        patch("services.trigger.trigger_provider_service.delete_cache_for_subscription") as mock_delete_cache,
        patch("services.account_service.FeatureService") as mock_account_feature_service,
    ):
        # Setup default mock returns
        mock_provider_controller = MagicMock()
        mock_provider_controller.get_credential_schema_config.return_value = MagicMock()
        mock_provider_controller.get_properties_schema.return_value = MagicMock()
        mock_trigger_manager.get_trigger_provider.return_value = mock_provider_controller
        # Mock redis lock so code using `with redis_client.lock(...)` works.
        mock_lock = MagicMock()
        mock_lock.__enter__ = MagicMock(return_value=None)
        mock_lock.__exit__ = MagicMock(return_value=None)
        mock_redis_client.lock.return_value = mock_lock
        # Setup account feature service mock (allow account registration)
        mock_account_feature_service.get_system_features.return_value.is_allow_register = True
        yield {
            "trigger_manager": mock_trigger_manager,
            "redis_client": mock_redis_client,
            "delete_cache": mock_delete_cache,
            "provider_controller": mock_provider_controller,
            "account_feature_service": mock_account_feature_service,
        }
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Helper method to create a test account and tenant for testing.

    Args:
        db_session_with_containers: Database session from testcontainers infrastructure
        mock_external_service_dependencies: Mock dependencies

    Returns:
        tuple: (account, tenant) - Created account and tenant instances
    """
    fake = Faker()
    # Local import — presumably to avoid import-time side effects or a
    # circular import; TODO confirm against the project layout.
    from services.account_service import AccountService, TenantService

    # Setup mocks for account creation
    mock_external_service_dependencies[
        "account_feature_service"
    ].get_system_features.return_value.is_allow_register = True
    mock_external_service_dependencies[
        "trigger_manager"
    ].get_trigger_provider.return_value = mock_external_service_dependencies["provider_controller"]
    # Create account and tenant
    account = AccountService.create_account(
        email=fake.email(),
        name=fake.name(),
        interface_language="en-US",
        password=fake.password(length=12),
    )
    TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
    tenant = account.current_tenant
    return account, tenant
def _create_test_subscription(
    self,
    db_session_with_containers,
    tenant_id,
    user_id,
    provider_id,
    credential_type,
    credentials,
    mock_external_service_dependencies,
):
    """
    Helper method to create a test trigger subscription.

    Args:
        db_session_with_containers: Database session
        tenant_id: Tenant ID
        user_id: User ID
        provider_id: Provider ID
        credential_type: Credential type
        credentials: Credentials dict (stored encrypted)
        mock_external_service_dependencies: Mock dependencies

    Returns:
        TriggerSubscription: Created subscription instance
    """
    fake = Faker()
    from core.helper.provider_cache import NoOpProviderCredentialCache
    from core.helper.provider_encryption import create_provider_encrypter

    # Use mock provider controller to encrypt credentials
    provider_controller = mock_external_service_dependencies["provider_controller"]
    # Create encrypter for credentials so the row matches production storage.
    credential_encrypter, _ = create_provider_encrypter(
        tenant_id=tenant_id,
        config=provider_controller.get_credential_schema_config(credential_type),
        cache=NoOpProviderCredentialCache(),
    )
    subscription = TriggerSubscription(
        name=fake.word(),
        tenant_id=tenant_id,
        user_id=user_id,
        provider_id=str(provider_id),
        endpoint_id=fake.uuid4(),
        parameters={"param1": "value1"},
        properties={"prop1": "value1"},
        credentials=dict(credential_encrypter.encrypt(credentials)),
        credential_type=credential_type.value,
        # -1 appears to mean "never expires" — TODO confirm model semantics.
        credential_expires_at=-1,
        expires_at=-1,
    )
    db.session.add(subscription)
    db.session.commit()
    db.session.refresh(subscription)
    return subscription
def test_rebuild_trigger_subscription_success_with_merged_credentials(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test successful rebuild with credential merging (HIDDEN_VALUE handling).

    This test verifies:
    - Credentials are properly merged (HIDDEN_VALUE replaced with existing values)
    - Single transaction wraps all operations
    - Merged credentials are used for subscribe and update
    - Database state is correctly updated
    """
    # (removed an unused `fake = Faker()` local — nothing in this test used it)
    account, tenant = self._create_test_account_and_tenant(
        db_session_with_containers, mock_external_service_dependencies
    )
    provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
    credential_type = CredentialType.API_KEY
    # Create initial subscription with credentials
    original_credentials = {"api_key": "original-secret-key", "api_secret": "original-secret"}
    subscription = self._create_test_subscription(
        db_session_with_containers,
        tenant.id,
        account.id,
        provider_id,
        credential_type,
        original_credentials,
        mock_external_service_dependencies,
    )
    # Prepare new credentials with HIDDEN_VALUE for api_key (should keep original)
    # and new value for api_secret (should update)
    new_credentials = {
        "api_key": HIDDEN_VALUE,  # Should be replaced with original
        "api_secret": "new-secret-value",  # Should be updated
    }
    # Mock subscribe_trigger to return a new subscription entity
    new_subscription_entity = TriggerSubscriptionEntity(
        endpoint=subscription.endpoint_id,
        parameters={"param1": "value1"},
        properties={"prop1": "new_prop_value"},
        expires_at=1234567890,
    )
    mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
    # Mock unsubscribe_trigger
    mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
    # Execute rebuild
    TriggerProviderService.rebuild_trigger_subscription(
        tenant_id=tenant.id,
        provider_id=provider_id,
        subscription_id=subscription.id,
        credentials=new_credentials,
        parameters={"param1": "updated_value"},
        name="updated_name",
    )
    # Verify unsubscribe was called with decrypted original credentials
    mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.assert_called_once()
    unsubscribe_call_args = mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.call_args
    assert unsubscribe_call_args.kwargs["tenant_id"] == tenant.id
    assert unsubscribe_call_args.kwargs["provider_id"] == provider_id
    assert unsubscribe_call_args.kwargs["credential_type"] == credential_type
    # Verify subscribe was called with merged credentials (api_key from original, api_secret new)
    mock_external_service_dependencies["trigger_manager"].subscribe_trigger.assert_called_once()
    subscribe_call_args = mock_external_service_dependencies["trigger_manager"].subscribe_trigger.call_args
    subscribe_credentials = subscribe_call_args.kwargs["credentials"]
    assert subscribe_credentials["api_key"] == original_credentials["api_key"]  # Merged from original
    assert subscribe_credentials["api_secret"] == "new-secret-value"  # New value
    # Verify database state was updated
    db.session.refresh(subscription)
    assert subscription.name == "updated_name"
    assert subscription.parameters == {"param1": "updated_value"}
    # Verify credentials in DB were updated with merged values (decrypt to check)
    from core.helper.provider_cache import NoOpProviderCredentialCache
    from core.helper.provider_encryption import create_provider_encrypter

    # Use mock provider controller to decrypt credentials
    provider_controller = mock_external_service_dependencies["provider_controller"]
    credential_encrypter, _ = create_provider_encrypter(
        tenant_id=tenant.id,
        config=provider_controller.get_credential_schema_config(credential_type),
        cache=NoOpProviderCredentialCache(),
    )
    decrypted_db_credentials = dict(credential_encrypter.decrypt(subscription.credentials))
    assert decrypted_db_credentials["api_key"] == original_credentials["api_key"]
    assert decrypted_db_credentials["api_secret"] == "new-secret-value"
    # Verify cache was cleared
    mock_external_service_dependencies["delete_cache"].assert_called_once_with(
        tenant_id=tenant.id,
        provider_id=subscription.provider_id,
        subscription_id=subscription.id,
    )
def test_rebuild_trigger_subscription_with_all_new_credentials(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test rebuild when all credentials are new (no HIDDEN_VALUE).

    This test verifies:
    - All new credentials are used when no HIDDEN_VALUE is present
    - Merged credentials contain only new values
    """
    # (removed an unused `fake = Faker()` local — nothing in this test used it)
    account, tenant = self._create_test_account_and_tenant(
        db_session_with_containers, mock_external_service_dependencies
    )
    provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
    credential_type = CredentialType.API_KEY
    # Create initial subscription
    original_credentials = {"api_key": "original-key", "api_secret": "original-secret"}
    subscription = self._create_test_subscription(
        db_session_with_containers,
        tenant.id,
        account.id,
        provider_id,
        credential_type,
        original_credentials,
        mock_external_service_dependencies,
    )
    # All new credentials (no HIDDEN_VALUE)
    new_credentials = {
        "api_key": "completely-new-key",
        "api_secret": "completely-new-secret",
    }
    new_subscription_entity = TriggerSubscriptionEntity(
        endpoint=subscription.endpoint_id,
        parameters={},
        properties={},
        expires_at=-1,
    )
    mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
    mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
    # Execute rebuild
    TriggerProviderService.rebuild_trigger_subscription(
        tenant_id=tenant.id,
        provider_id=provider_id,
        subscription_id=subscription.id,
        credentials=new_credentials,
        parameters={},
    )
    # Verify subscribe was called with all new credentials
    subscribe_call_args = mock_external_service_dependencies["trigger_manager"].subscribe_trigger.call_args
    subscribe_credentials = subscribe_call_args.kwargs["credentials"]
    assert subscribe_credentials["api_key"] == "completely-new-key"
    assert subscribe_credentials["api_secret"] == "completely-new-secret"
def test_rebuild_trigger_subscription_with_all_hidden_values(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test rebuild when all credentials are HIDDEN_VALUE (preserve all existing).

    This test verifies:
    - All HIDDEN_VALUE credentials are replaced with existing values
    - Original credentials are preserved
    """
    # (removed an unused `fake = Faker()` local — nothing in this test used it)
    account, tenant = self._create_test_account_and_tenant(
        db_session_with_containers, mock_external_service_dependencies
    )
    provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
    credential_type = CredentialType.API_KEY
    original_credentials = {"api_key": "original-key", "api_secret": "original-secret"}
    subscription = self._create_test_subscription(
        db_session_with_containers,
        tenant.id,
        account.id,
        provider_id,
        credential_type,
        original_credentials,
        mock_external_service_dependencies,
    )
    # All HIDDEN_VALUE (should preserve all original)
    new_credentials = {
        "api_key": HIDDEN_VALUE,
        "api_secret": HIDDEN_VALUE,
    }
    new_subscription_entity = TriggerSubscriptionEntity(
        endpoint=subscription.endpoint_id,
        parameters={},
        properties={},
        expires_at=-1,
    )
    mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
    mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
    # Execute rebuild
    TriggerProviderService.rebuild_trigger_subscription(
        tenant_id=tenant.id,
        provider_id=provider_id,
        subscription_id=subscription.id,
        credentials=new_credentials,
        parameters={},
    )
    # Verify subscribe was called with all original credentials
    subscribe_call_args = mock_external_service_dependencies["trigger_manager"].subscribe_trigger.call_args
    subscribe_credentials = subscribe_call_args.kwargs["credentials"]
    assert subscribe_credentials["api_key"] == original_credentials["api_key"]
    assert subscribe_credentials["api_secret"] == original_credentials["api_secret"]
def test_rebuild_trigger_subscription_with_missing_key_uses_unknown_value(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test rebuild when HIDDEN_VALUE is used for a key that doesn't exist in original.
This test verifies:
- UNKNOWN_VALUE is used when HIDDEN_VALUE key doesn't exist in original credentials
"""
# Faker instance kept for parity with sibling tests; no random data is drawn here.
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
# Original has only api_key
original_credentials = {"api_key": "original-key"}
subscription = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
original_credentials,
mock_external_service_dependencies,
)
# HIDDEN_VALUE for non-existent key should use UNKNOWN_VALUE
new_credentials = {
"api_key": HIDDEN_VALUE,
"non_existent_key": HIDDEN_VALUE,  # This key doesn't exist in original
}
new_subscription_entity = TriggerSubscriptionEntity(
endpoint=subscription.endpoint_id,
parameters={},
properties={},
expires_at=-1,
)
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
# Execute rebuild
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=subscription.id,
credentials=new_credentials,
parameters={},
)
# Verify subscribe was called with original api_key and UNKNOWN_VALUE for missing key
subscribe_call_args = mock_external_service_dependencies["trigger_manager"].subscribe_trigger.call_args
subscribe_credentials = subscribe_call_args.kwargs["credentials"]
assert subscribe_credentials["api_key"] == original_credentials["api_key"]
assert subscribe_credentials["non_existent_key"] == UNKNOWN_VALUE
def test_rebuild_trigger_subscription_rollback_on_error(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test that transaction is rolled back on error.
This test verifies:
- Database transaction is rolled back when an error occurs
- Original subscription state is preserved
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
original_credentials = {"api_key": "original-key"}
subscription = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
original_credentials,
mock_external_service_dependencies,
)
# Snapshot pre-rebuild state; .copy() guards against in-place mutation of parameters.
original_name = subscription.name
original_parameters = subscription.parameters.copy()
# Make subscribe_trigger raise an error
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.side_effect = ValueError(
"Subscribe failed"
)
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
# Execute rebuild and expect error
with pytest.raises(ValueError, match="Subscribe failed"):
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=subscription.id,
credentials={"api_key": "new-key"},
parameters={},
)
# Verify subscription state was not changed (rolled back)
# refresh() re-reads from the DB so the assertions see persisted, not cached, state.
db.session.refresh(subscription)
assert subscription.name == original_name
assert subscription.parameters == original_parameters
def test_rebuild_trigger_subscription_unsubscribe_error_continues(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test that unsubscribe errors are handled gracefully and operation continues.
This test verifies:
- Unsubscribe errors are caught and logged but don't stop the rebuild
- Rebuild continues even if unsubscribe fails
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
original_credentials = {"api_key": "original-key"}
subscription = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
original_credentials,
mock_external_service_dependencies,
)
# Make unsubscribe_trigger raise an error (should be caught and continue)
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.side_effect = ValueError(
"Unsubscribe failed"
)
new_subscription_entity = TriggerSubscriptionEntity(
endpoint=subscription.endpoint_id,
parameters={},
properties={},
expires_at=-1,
)
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
# Execute rebuild - should succeed despite unsubscribe error
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=subscription.id,
credentials={"api_key": "new-key"},
parameters={},
)
# Verify subscribe was still called (operation continued)
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.assert_called_once()
# Verify subscription was updated
# Empty parameters here mirror the rebuild call above and the mocked entity.
db.session.refresh(subscription)
assert subscription.parameters == {}
def test_rebuild_trigger_subscription_subscription_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test error when subscription is not found.
This test verifies:
- Proper error is raised when subscription doesn't exist
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
# A random UUID that was never persisted, so the lookup must fail.
fake_subscription_id = fake.uuid4()
with pytest.raises(ValueError, match="not found"):
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=fake_subscription_id,
credentials={},
parameters={},
)
def test_rebuild_trigger_subscription_provider_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test error when provider is not found.
This test verifies:
- Proper error is raised when provider doesn't exist
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("non_existent_org/non_existent_plugin/non_existent_provider")
# Make get_trigger_provider return None
mock_external_service_dependencies["trigger_manager"].get_trigger_provider.return_value = None
# The regex allows any text between "Provider" and "not found" in the message.
with pytest.raises(ValueError, match="Provider.*not found"):
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=fake.uuid4(),
credentials={},
parameters={},
)
def test_rebuild_trigger_subscription_unsupported_credential_type(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test error when credential type is not supported for rebuild.
This test verifies:
- Proper error is raised for unsupported credential types (not OAUTH2 or API_KEY)
"""
# Faker instance kept for parity with sibling tests; no random data is drawn here.
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.UNAUTHORIZED  # Not supported
subscription = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
{},
mock_external_service_dependencies,
)
with pytest.raises(ValueError, match="Credential type not supported for rebuild"):
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=subscription.id,
credentials={},
parameters={},
)
def test_rebuild_trigger_subscription_name_uniqueness_check(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test that name uniqueness is checked when updating name.
This test verifies:
- Error is raised when new name conflicts with existing subscription
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
# Create first subscription
subscription1 = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
{"api_key": "key1"},
mock_external_service_dependencies,
)
# Create second subscription with different name
subscription2 = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
{"api_key": "key2"},
mock_external_service_dependencies,
)
new_subscription_entity = TriggerSubscriptionEntity(
endpoint=subscription2.endpoint_id,
parameters={},
properties={},
expires_at=-1,
)
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
# Try to rename subscription2 to subscription1's name (should fail)
with pytest.raises(ValueError, match="already exists"):
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=subscription2.id,
credentials={"api_key": "new-key"},
parameters={},
name=subscription1.name,  # Conflicting name
)

View File

@ -2,7 +2,9 @@ from unittest.mock import patch
import pytest
from faker import Faker
from pydantic import TypeAdapter, ValidationError
from core.tools.entities.tool_entities import ApiProviderSchemaType
from models import Account, Tenant
from models.tools import ApiToolProvider
from services.tools.api_tools_manage_service import ApiToolManageService
@ -298,7 +300,7 @@ class TestApiToolManageService:
provider_name = fake.company()
icon = {"type": "emoji", "value": "🔧"}
credentials = {"auth_type": "none", "api_key_header": "X-API-Key", "api_key_value": ""}
schema_type = "openapi"
schema_type = ApiProviderSchemaType.OPENAPI
schema = self._create_test_openapi_schema()
privacy_policy = "https://example.com/privacy"
custom_disclaimer = "Custom disclaimer text"
@ -364,7 +366,7 @@ class TestApiToolManageService:
provider_name = fake.company()
icon = {"type": "emoji", "value": "🔧"}
credentials = {"auth_type": "none"}
schema_type = "openapi"
schema_type = ApiProviderSchemaType.OPENAPI
schema = self._create_test_openapi_schema()
privacy_policy = "https://example.com/privacy"
custom_disclaimer = "Custom disclaimer text"
@ -428,21 +430,10 @@ class TestApiToolManageService:
labels = ["test"]
# Act & Assert: Try to create provider with invalid schema type
with pytest.raises(ValueError) as exc_info:
ApiToolManageService.create_api_tool_provider(
user_id=account.id,
tenant_id=tenant.id,
provider_name=provider_name,
icon=icon,
credentials=credentials,
schema_type=schema_type,
schema=schema,
privacy_policy=privacy_policy,
custom_disclaimer=custom_disclaimer,
labels=labels,
)
with pytest.raises(ValidationError) as exc_info:
TypeAdapter(ApiProviderSchemaType).validate_python(schema_type)
assert "invalid schema type" in str(exc_info.value)
assert "validation error" in str(exc_info.value)
def test_create_api_tool_provider_missing_auth_type(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
@ -464,7 +455,7 @@ class TestApiToolManageService:
provider_name = fake.company()
icon = {"type": "emoji", "value": "🔧"}
credentials = {} # Missing auth_type
schema_type = "openapi"
schema_type = ApiProviderSchemaType.OPENAPI
schema = self._create_test_openapi_schema()
privacy_policy = "https://example.com/privacy"
custom_disclaimer = "Custom disclaimer text"
@ -507,7 +498,7 @@ class TestApiToolManageService:
provider_name = fake.company()
icon = {"type": "emoji", "value": "🔑"}
credentials = {"auth_type": "api_key", "api_key_header": "X-API-Key", "api_key_value": fake.uuid4()}
schema_type = "openapi"
schema_type = ApiProviderSchemaType.OPENAPI
schema = self._create_test_openapi_schema()
privacy_policy = "https://example.com/privacy"
custom_disclaimer = "Custom disclaimer text"

View File

@ -1308,18 +1308,17 @@ class TestMCPToolManageService:
type("MockTool", (), {"model_dump": lambda self: {"name": "test_tool_2", "description": "Test tool 2"}})(),
]
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client:
# Setup mock client
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
mock_client_instance.list_tools.return_value = mock_tools
# Act: Execute the method under test
from extensions.ext_database import db
service = MCPToolManageService(db.session())
result = service._reconnect_provider(
result = MCPToolManageService._reconnect_with_url(
server_url="https://example.com/mcp",
provider=mcp_provider,
headers={"X-Test": "1"},
timeout=mcp_provider.timeout,
sse_read_timeout=mcp_provider.sse_read_timeout,
)
# Assert: Verify the expected outcomes
@ -1337,8 +1336,12 @@ class TestMCPToolManageService:
assert tools_data[1]["name"] == "test_tool_2"
# Verify mock interactions
provider_entity = mcp_provider.to_entity()
mock_mcp_client.assert_called_once()
mock_mcp_client.assert_called_once_with(
server_url="https://example.com/mcp",
headers={"X-Test": "1"},
timeout=mcp_provider.timeout,
sse_read_timeout=mcp_provider.sse_read_timeout,
)
def test_re_connect_mcp_provider_auth_error(self, db_session_with_containers, mock_external_service_dependencies):
"""
@ -1361,19 +1364,18 @@ class TestMCPToolManageService:
)
# Mock MCPClient to raise authentication error
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client:
from core.mcp.error import MCPAuthError
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required")
# Act: Execute the method under test
from extensions.ext_database import db
service = MCPToolManageService(db.session())
result = service._reconnect_provider(
result = MCPToolManageService._reconnect_with_url(
server_url="https://example.com/mcp",
provider=mcp_provider,
headers={},
timeout=mcp_provider.timeout,
sse_read_timeout=mcp_provider.sse_read_timeout,
)
# Assert: Verify the expected outcomes
@ -1404,18 +1406,17 @@ class TestMCPToolManageService:
)
# Mock MCPClient to raise connection error
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client:
from core.mcp.error import MCPError
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
mock_client_instance.list_tools.side_effect = MCPError("Connection failed")
# Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session())
with pytest.raises(ValueError, match="Failed to re-connect MCP server: Connection failed"):
service._reconnect_provider(
MCPToolManageService._reconnect_with_url(
server_url="https://example.com/mcp",
provider=mcp_provider,
headers={"X-Test": "1"},
timeout=mcp_provider.timeout,
sse_read_timeout=mcp_provider.sse_read_timeout,
)

View File

@ -705,3 +705,207 @@ class TestWorkflowToolManageService:
db.session.refresh(created_tool)
assert created_tool.name == first_tool_name
assert created_tool.updated_at is not None
def test_create_workflow_tool_with_file_parameter_default(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test workflow tool creation with FILE parameter having a file object as default.
This test verifies:
- FILE parameters can have file object defaults
- The default value (dict with id/base64Url) is properly handled
- Tool creation succeeds without Pydantic validation errors
Related issue: Array[File] default value causes Pydantic validation errors.
"""
fake = Faker()
# Create test data
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)
# Create workflow graph with a FILE variable that has a default value
workflow_graph = {
"nodes": [
{
"id": "start_node",
"data": {
"type": "start",
"variables": [
{
"variable": "document",
"label": "Document",
"type": "file",
"required": False,
# File default as stored by the frontend: an id plus (empty) base64Url.
"default": {"id": fake.uuid4(), "base64Url": ""},
}
],
},
}
]
}
# The model stores the graph as a JSON string, so serialize before assignment.
workflow.graph = json.dumps(workflow_graph)
# Setup workflow tool parameters with FILE type
file_parameters = [
{
"name": "document",
"description": "Upload a document",
"form": "form",
"type": "file",
"required": False,
}
]
# Execute the method under test
# Note: from_db is mocked, so this test primarily validates the parameter configuration
result = WorkflowToolManageService.create_workflow_tool(
user_id=account.id,
tenant_id=account.current_tenant.id,
workflow_app_id=app.id,
name=fake.word(),
label=fake.word(),
icon={"type": "emoji", "emoji": "📄"},
description=fake.text(max_nb_chars=200),
parameters=file_parameters,
)
# Verify the result
assert result == {"result": "success"}
def test_create_workflow_tool_with_files_parameter_default(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test workflow tool creation with FILES (Array[File]) parameter having file objects as default.
This test verifies:
- FILES parameters can have a list of file objects as default
- The default value (list of dicts with id/base64Url) is properly handled
- Tool creation succeeds without Pydantic validation errors
Related issue: Array[File] default value causes 4 Pydantic validation errors
because PluginParameter.default only accepts Union[float, int, str, bool] | None.
"""
fake = Faker()
# Create test data
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)
# Create workflow graph with a FILE_LIST variable that has a default value
workflow_graph = {
"nodes": [
{
"id": "start_node",
"data": {
"type": "start",
"variables": [
{
"variable": "documents",
"label": "Documents",
"type": "file-list",
"required": False,
# Two file entries exercise the list-of-dicts default path.
"default": [
{"id": fake.uuid4(), "base64Url": ""},
{"id": fake.uuid4(), "base64Url": ""},
],
}
],
},
}
]
}
# The model stores the graph as a JSON string, so serialize before assignment.
workflow.graph = json.dumps(workflow_graph)
# Setup workflow tool parameters with FILES type
files_parameters = [
{
"name": "documents",
"description": "Upload multiple documents",
"form": "form",
"type": "files",
"required": False,
}
]
# Execute the method under test
# Note: from_db is mocked, so this test primarily validates the parameter configuration
result = WorkflowToolManageService.create_workflow_tool(
user_id=account.id,
tenant_id=account.current_tenant.id,
workflow_app_id=app.id,
name=fake.word(),
label=fake.word(),
icon={"type": "emoji", "emoji": "📁"},
description=fake.text(max_nb_chars=200),
parameters=files_parameters,
)
# Verify the result
assert result == {"result": "success"}
def test_create_workflow_tool_db_commit_before_validation(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test that database commit happens before validation, causing DB pollution on validation failure.
This test verifies the second bug:
- WorkflowToolProvider is committed to database BEFORE from_db validation
- If validation fails, the record remains in the database
- Subsequent attempts fail with "Tool already exists" error
This demonstrates why we need to validate BEFORE database commit.
"""
fake = Faker()
# Create test data
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)
tool_name = fake.word()
# Mock from_db to raise validation error
mock_external_service_dependencies["workflow_tool_provider_controller"].from_db.side_effect = ValueError(
"Validation failed: default parameter type mismatch"
)
# Attempt to create workflow tool (will fail at validation stage)
with pytest.raises(ValueError) as exc_info:
WorkflowToolManageService.create_workflow_tool(
user_id=account.id,
tenant_id=account.current_tenant.id,
workflow_app_id=app.id,
name=tool_name,
label=fake.word(),
icon={"type": "emoji", "emoji": "🔧"},
description=fake.text(max_nb_chars=200),
parameters=self._create_test_workflow_tool_parameters(),
)
assert "Validation failed" in str(exc_info.value)
# Verify the tool was NOT created in database
# This is the expected behavior (no pollution)
# Local import mirrors the surrounding tests' style of deferring db access.
from extensions.ext_database import db
tool_count = (
db.session.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == account.current_tenant.id,
WorkflowToolProvider.name == tool_name,
)
.count()
)
# The record should NOT exist because the transaction should be rolled back
# Currently, due to the bug, the record might exist (this test documents the bug)
# After the fix, this should always be 0
# For now, we document that the record may exist, demonstrating the bug
# NOTE(review): tool_count is computed but intentionally not asserted until the
# fix lands; re-enable the assertion below at that point.
# assert tool_count == 0  # Expected after fix

View File

@ -12,10 +12,12 @@ class TestJinja2CodeExecutor(CodeExecutorTestMixin):
_, Jinja2TemplateTransformer = self.jinja2_imports
template = "Hello {{template}}"
# Template must be base64 encoded to match the new safe embedding approach
template_b64 = base64.b64encode(template.encode("utf-8")).decode("utf-8")
inputs = base64.b64encode(b'{"template": "World"}').decode("utf-8")
code = (
Jinja2TemplateTransformer.get_runner_script()
.replace(Jinja2TemplateTransformer._code_placeholder, template)
.replace(Jinja2TemplateTransformer._template_b64_placeholder, template_b64)
.replace(Jinja2TemplateTransformer._inputs_placeholder, inputs)
)
result = CodeExecutor.execute_code(
@ -37,6 +39,34 @@ class TestJinja2CodeExecutor(CodeExecutorTestMixin):
_, Jinja2TemplateTransformer = self.jinja2_imports
runner_script = Jinja2TemplateTransformer.get_runner_script()
assert runner_script.count(Jinja2TemplateTransformer._code_placeholder) == 1
assert runner_script.count(Jinja2TemplateTransformer._template_b64_placeholder) == 1
assert runner_script.count(Jinja2TemplateTransformer._inputs_placeholder) == 1
assert runner_script.count(Jinja2TemplateTransformer._result_tag) == 2
def test_jinja2_template_with_special_characters(self, flask_app_with_containers):
"""
Test that templates with special characters (quotes, newlines) render correctly.
This is a regression test for issue #26818 where textarea pre-fill values
containing special characters would break template rendering.
"""
CodeExecutor, CodeLanguage = self.code_executor_imports
# Template with triple quotes, single quotes, double quotes, and newlines
# Exercises the base64 template-embedding path: naive string substitution would
# break on the quotes/newlines below.
template = """<html>
<body>
<input value="{{ task.get('Task ID', '') }}"/>
<textarea>{{ task.get('Issues', 'No issues reported') }}</textarea>
<p>Status: "{{ status }}"</p>
<pre>'''code block'''</pre>
</body>
</html>"""
inputs = {"task": {"Task ID": "TASK-123", "Issues": "Line 1\nLine 2\nLine 3"}, "status": "completed"}
result = CodeExecutor.execute_workflow_code_template(language=CodeLanguage.JINJA2, code=template, inputs=inputs)
# Verify the template rendered correctly with all special characters
output = result["result"]
assert 'value="TASK-123"' in output
assert "<textarea>Line 1\nLine 2\nLine 3</textarea>" in output
assert 'Status: "completed"' in output
assert "'''code block'''" in output

View File

@ -0,0 +1,46 @@
from flask import Response
from controllers.common.file_response import enforce_download_for_html, is_html_content
class TestFileResponseHelpers:
"""Unit tests for the HTML download-enforcement helpers in controllers.common.file_response."""
def test_is_html_content_detects_mime_type(self):
"""An HTML MIME type (with charset suffix) is detected even when filename/extension disagree."""
mime_type = "text/html; charset=UTF-8"
result = is_html_content(mime_type, filename="file.txt", extension="txt")
assert result is True
def test_is_html_content_detects_extension(self):
"""A .html filename is detected as HTML even with a non-HTML MIME type and no extension arg."""
result = is_html_content("text/plain", filename="report.html", extension=None)
assert result is True
def test_enforce_download_for_html_sets_headers(self):
"""HTML content is forced to download: attachment disposition, octet-stream type, nosniff."""
response = Response("payload", mimetype="text/html")
updated = enforce_download_for_html(
response,
mime_type="text/html",
filename="unsafe.html",
extension="html",
)
# Returns True to signal the response was modified.
assert updated is True
assert "attachment" in response.headers["Content-Disposition"]
assert response.headers["Content-Type"] == "application/octet-stream"
assert response.headers["X-Content-Type-Options"] == "nosniff"
def test_enforce_download_for_html_no_change_for_non_html(self):
"""Non-HTML content passes through untouched: no headers added, False returned."""
response = Response("payload", mimetype="text/plain")
updated = enforce_download_for_html(
response,
mime_type="text/plain",
filename="notes.txt",
extension="txt",
)
assert updated is False
assert "Content-Disposition" not in response.headers
assert "X-Content-Type-Options" not in response.headers

View File

@ -163,34 +163,17 @@ class TestActivateApi:
"account": mock_account,
}
@pytest.fixture
def mock_token_pair(self):
"""Create mock token pair object."""
token_pair = MagicMock()
token_pair.access_token = "access_token"
token_pair.refresh_token = "refresh_token"
token_pair.csrf_token = "csrf_token"
token_pair.model_dump.return_value = {
"access_token": "access_token",
"refresh_token": "refresh_token",
"csrf_token": "csrf_token",
}
return token_pair
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
@patch("controllers.console.auth.activate.db")
@patch("controllers.console.auth.activate.AccountService.login")
def test_successful_account_activation(
self,
mock_login,
mock_db,
mock_revoke_token,
mock_get_invitation,
app,
mock_invitation,
mock_account,
mock_token_pair,
):
"""
Test successful account activation.
@ -198,12 +181,10 @@ class TestActivateApi:
Verifies that:
- Account is activated with user preferences
- Account status is set to ACTIVE
- User is logged in after activation
- Invitation token is revoked
"""
# Arrange
mock_get_invitation.return_value = mock_invitation
mock_login.return_value = mock_token_pair
# Act
with app.test_request_context(
@ -230,7 +211,6 @@ class TestActivateApi:
assert mock_account.initialized_at is not None
mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token")
mock_db.session.commit.assert_called_once()
mock_login.assert_called_once()
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
def test_activation_with_invalid_token(self, mock_get_invitation, app):
@ -264,17 +244,14 @@ class TestActivateApi:
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
@patch("controllers.console.auth.activate.db")
@patch("controllers.console.auth.activate.AccountService.login")
def test_activation_sets_interface_theme(
self,
mock_login,
mock_db,
mock_revoke_token,
mock_get_invitation,
app,
mock_invitation,
mock_account,
mock_token_pair,
):
"""
Test that activation sets default interface theme.
@ -284,7 +261,6 @@ class TestActivateApi:
"""
# Arrange
mock_get_invitation.return_value = mock_invitation
mock_login.return_value = mock_token_pair
# Act
with app.test_request_context(
@ -317,17 +293,14 @@ class TestActivateApi:
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
@patch("controllers.console.auth.activate.db")
@patch("controllers.console.auth.activate.AccountService.login")
def test_activation_with_different_locales(
self,
mock_login,
mock_db,
mock_revoke_token,
mock_get_invitation,
app,
mock_invitation,
mock_account,
mock_token_pair,
language,
timezone,
):
@ -341,7 +314,6 @@ class TestActivateApi:
"""
# Arrange
mock_get_invitation.return_value = mock_invitation
mock_login.return_value = mock_token_pair
# Act
with app.test_request_context(
@ -367,27 +339,23 @@ class TestActivateApi:
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
@patch("controllers.console.auth.activate.db")
@patch("controllers.console.auth.activate.AccountService.login")
def test_activation_returns_token_data(
def test_activation_returns_success_response(
self,
mock_login,
mock_db,
mock_revoke_token,
mock_get_invitation,
app,
mock_invitation,
mock_token_pair,
):
"""
Test that activation returns authentication tokens.
Test that activation returns a success response without authentication tokens.
Verifies that:
- Token pair is returned in response
- All token types are included (access, refresh, csrf)
- Response contains a success result
- No token data is returned
"""
# Arrange
mock_get_invitation.return_value = mock_invitation
mock_login.return_value = mock_token_pair
# Act
with app.test_request_context(
@ -406,24 +374,18 @@ class TestActivateApi:
response = api.post()
# Assert
assert "data" in response
assert response["data"]["access_token"] == "access_token"
assert response["data"]["refresh_token"] == "refresh_token"
assert response["data"]["csrf_token"] == "csrf_token"
assert response == {"result": "success"}
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
@patch("controllers.console.auth.activate.db")
@patch("controllers.console.auth.activate.AccountService.login")
def test_activation_without_workspace_id(
self,
mock_login,
mock_db,
mock_revoke_token,
mock_get_invitation,
app,
mock_invitation,
mock_token_pair,
):
"""
Test account activation without workspace_id.
@ -434,7 +396,6 @@ class TestActivateApi:
"""
# Arrange
mock_get_invitation.return_value = mock_invitation
mock_login.return_value = mock_token_pair
# Act
with app.test_request_context(

View File

@ -0,0 +1,236 @@
from __future__ import annotations
import builtins
import uuid
from datetime import UTC, datetime
from unittest.mock import MagicMock
import pytest
from flask import Flask
from flask.views import MethodView as FlaskMethodView
_NEEDS_METHOD_VIEW_CLEANUP = False
if not hasattr(builtins, "MethodView"):
builtins.MethodView = FlaskMethodView
_NEEDS_METHOD_VIEW_CLEANUP = True
from constants import HIDDEN_VALUE
from controllers.console.extension import (
APIBasedExtensionAPI,
APIBasedExtensionDetailAPI,
CodeBasedExtensionAPI,
)
if _NEEDS_METHOD_VIEW_CLEANUP:
delattr(builtins, "MethodView")
from models.account import AccountStatus
from models.api_based_extension import APIBasedExtension
def _make_extension(
*,
name: str = "Sample Extension",
api_endpoint: str = "https://example.com/api",
api_key: str = "super-secret-key",
) -> APIBasedExtension:
"""Build an in-memory APIBasedExtension for tenant-123 (never persisted to the DB)."""
extension = APIBasedExtension(
tenant_id="tenant-123",
name=name,
api_endpoint=api_endpoint,
api_key=api_key,
)
# id/created_at are normally DB-assigned; set them manually since we never commit.
extension.id = f"{uuid.uuid4()}"
extension.created_at = datetime.now(tz=UTC)
return extension
@pytest.fixture(autouse=True)
def _mock_console_guards(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
"""Bypass console decorators so handlers can run in isolation."""
import controllers.console.extension as extension_module
from controllers.console import wraps as wraps_module
# A fake active account bound to tenant-123, matching _make_extension's tenant.
account = MagicMock()
account.status = AccountStatus.ACTIVE
account.current_tenant_id = "tenant-123"
account.id = "account-123"
account.is_authenticated = True
monkeypatch.setattr(wraps_module.dify_config, "EDITION", "CLOUD")
monkeypatch.setattr("libs.login.dify_config.LOGIN_DISABLED", True)
# NOTE(review): INIT_PASSWORD presumably triggers a setup guard when set — confirm.
monkeypatch.delenv("INIT_PASSWORD", raising=False)
# Patch both the controller module and the wraps module: each imported the helper
# separately, so patching one alone would leave the other using the real lookup.
monkeypatch.setattr(extension_module, "current_account_with_tenant", lambda: (account, "tenant-123"))
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (account, "tenant-123"))
# The login_required decorator consults the shared LocalProxy in libs.login.
monkeypatch.setattr("libs.login.current_user", account)
monkeypatch.setattr("libs.login.check_csrf_token", lambda *_, **__: None)
return account
@pytest.fixture(autouse=True)
def _restx_mask_defaults(app: Flask):
    """Ensure flask-restx mask settings exist so response marshalling works."""
    for key, value in (("RESTX_MASK_HEADER", "X-Fields"), ("RESTX_MASK_SWAGGER", False)):
        app.config.setdefault(key, value)
def test_code_based_extension_get_returns_service_data(app: Flask, monkeypatch: pytest.MonkeyPatch):
    """GET /code-based-extension echoes the module name and returns the service payload."""
    expected_data = {"entrypoint": "main:agent"}
    get_extension = MagicMock(return_value=expected_data)
    monkeypatch.setattr(
        "controllers.console.extension.CodeBasedExtensionService.get_code_based_extension",
        get_extension,
    )
    request_ctx = app.test_request_context(
        "/console/api/code-based-extension",
        method="GET",
        query_string={"module": "workflow.tools"},
    )
    with request_ctx:
        result = CodeBasedExtensionAPI().get()
    assert result == {"module": "workflow.tools", "data": expected_data}
    get_extension.assert_called_once_with("workflow.tools")
def test_api_based_extension_get_returns_tenant_extensions(app: Flask, monkeypatch: pytest.MonkeyPatch):
    """GET /api-based-extension lists the tenant's extensions with a partially obfuscated key."""
    ext = _make_extension(name="Weather API", api_key="abcdefghi123")
    list_mock = MagicMock(return_value=[ext])
    monkeypatch.setattr(
        "controllers.console.extension.APIBasedExtensionService.get_all_by_tenant_id",
        list_mock,
    )
    with app.test_request_context("/console/api/api-based-extension", method="GET"):
        result = APIBasedExtensionAPI().get()
    first = result[0]
    assert first["id"] == ext.id
    assert first["name"] == "Weather API"
    assert first["api_endpoint"] == ext.api_endpoint
    assert first["api_key"].startswith(ext.api_key[:3])
    list_mock.assert_called_once_with("tenant-123")
def test_api_based_extension_post_creates_extension(app: Flask, monkeypatch: pytest.MonkeyPatch):
    """POST /api-based-extension builds an extension from the payload and saves it."""
    saved_extension = _make_extension(name="Docs API", api_key="saved-secret")
    save_mock = MagicMock(return_value=saved_extension)
    monkeypatch.setattr("controllers.console.extension.APIBasedExtensionService.save", save_mock)
    payload = {
        "name": "Docs API",
        "api_endpoint": "https://docs.example.com/hook",
        "api_key": "plain-secret",
    }
    with app.test_request_context("/console/api/api-based-extension", method="POST", json=payload):
        response = APIBasedExtensionAPI().post()
    # The controller passes the newly built model positionally to save().
    args, _ = save_mock.call_args
    created_extension: APIBasedExtension = args[0]
    # Tenant id comes from the mocked current_account_with_tenant, not the payload.
    assert created_extension.tenant_id == "tenant-123"
    assert created_extension.name == payload["name"]
    assert created_extension.api_endpoint == payload["api_endpoint"]
    assert created_extension.api_key == payload["api_key"]
    assert response["name"] == saved_extension.name
    save_mock.assert_called_once()
def test_api_based_extension_detail_get_fetches_extension(app: Flask, monkeypatch: pytest.MonkeyPatch):
    """GET /api-based-extension/<id> resolves the record via the tenant-scoped service."""
    ext = _make_extension(name="Docs API", api_key="abcdefg12345")
    fetch_mock = MagicMock(return_value=ext)
    monkeypatch.setattr(
        "controllers.console.extension.APIBasedExtensionService.get_with_tenant_id",
        fetch_mock,
    )
    ext_uuid = uuid.uuid4()
    with app.test_request_context(f"/console/api/api-based-extension/{ext_uuid}", method="GET"):
        result = APIBasedExtensionDetailAPI().get(ext_uuid)
    assert result["id"] == ext.id
    assert result["name"] == ext.name
    fetch_mock.assert_called_once_with("tenant-123", str(ext_uuid))
def test_api_based_extension_detail_post_keeps_hidden_api_key(app: Flask, monkeypatch: pytest.MonkeyPatch):
    """POST with the HIDDEN_VALUE sentinel must not overwrite the stored api_key."""
    existing_extension = _make_extension(name="Docs API", api_key="keep-me")
    get_mock = MagicMock(return_value=existing_extension)
    save_mock = MagicMock(return_value=existing_extension)
    monkeypatch.setattr(
        "controllers.console.extension.APIBasedExtensionService.get_with_tenant_id",
        get_mock,
    )
    monkeypatch.setattr("controllers.console.extension.APIBasedExtensionService.save", save_mock)
    payload = {
        "name": "Docs API Updated",
        "api_endpoint": "https://docs.example.com/v2",
        # HIDDEN_VALUE is the placeholder the UI sends when the key was not edited.
        "api_key": HIDDEN_VALUE,
    }
    extension_id = uuid.uuid4()
    with app.test_request_context(
        f"/console/api/api-based-extension/{extension_id}",
        method="POST",
        json=payload,
    ):
        response = APIBasedExtensionDetailAPI().post(extension_id)
    # Name/endpoint are updated in place while the api_key keeps its original value.
    assert existing_extension.name == payload["name"]
    assert existing_extension.api_endpoint == payload["api_endpoint"]
    assert existing_extension.api_key == "keep-me"
    save_mock.assert_called_once_with(existing_extension)
    assert response["name"] == payload["name"]
def test_api_based_extension_detail_post_updates_api_key_when_provided(app: Flask, monkeypatch: pytest.MonkeyPatch):
    """POST with a concrete api_key (not the hidden sentinel) replaces the stored key."""
    existing_extension = _make_extension(name="Docs API", api_key="old-secret")
    get_mock = MagicMock(return_value=existing_extension)
    save_mock = MagicMock(return_value=existing_extension)
    monkeypatch.setattr(
        "controllers.console.extension.APIBasedExtensionService.get_with_tenant_id",
        get_mock,
    )
    monkeypatch.setattr("controllers.console.extension.APIBasedExtensionService.save", save_mock)
    payload = {
        "name": "Docs API Updated",
        "api_endpoint": "https://docs.example.com/v2",
        "api_key": "new-secret",
    }
    extension_id = uuid.uuid4()
    with app.test_request_context(
        f"/console/api/api-based-extension/{extension_id}",
        method="POST",
        json=payload,
    ):
        response = APIBasedExtensionDetailAPI().post(extension_id)
    # The in-memory record is mutated with the new key and then persisted.
    assert existing_extension.api_key == "new-secret"
    save_mock.assert_called_once_with(existing_extension)
    assert response["name"] == payload["name"]
def test_api_based_extension_detail_delete_removes_extension(app: Flask, monkeypatch: pytest.MonkeyPatch):
    """DELETE /api-based-extension/<id> removes the record and returns 204."""
    ext = _make_extension()
    fetch_mock = MagicMock(return_value=ext)
    remove_mock = MagicMock()
    monkeypatch.setattr(
        "controllers.console.extension.APIBasedExtensionService.get_with_tenant_id",
        fetch_mock,
    )
    monkeypatch.setattr("controllers.console.extension.APIBasedExtensionService.delete", remove_mock)
    ext_uuid = uuid.uuid4()
    request_ctx = app.test_request_context(
        f"/console/api/api-based-extension/{ext_uuid}",
        method="DELETE",
    )
    with request_ctx:
        body, status = APIBasedExtensionDetailAPI().delete(ext_uuid)
    remove_mock.assert_called_once_with(ext)
    assert body == {"result": "success"}
    assert status == 204

View File

@ -0,0 +1,145 @@
"""Unit tests for load balancing credential validation APIs."""
from __future__ import annotations
import builtins
import importlib
import sys
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from flask import Flask
from flask.views import MethodView
from werkzeug.exceptions import Forbidden
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
if not hasattr(builtins, "MethodView"):
builtins.MethodView = MethodView # type: ignore[attr-defined]
from models.account import TenantAccountRole
@pytest.fixture
def app() -> Flask:
    """Provide a minimal Flask app configured for testing."""
    test_app = Flask(__name__)
    test_app.config.update(TESTING=True)
    return test_app
@pytest.fixture
def load_balancing_module(monkeypatch: pytest.MonkeyPatch):
    """Reload controller module with lightweight decorators for testing."""
    from controllers.console import console_ns, wraps
    from libs import login
    def _noop(func):
        # Identity decorator: replaces auth/setup guards so handlers run bare.
        return func
    monkeypatch.setattr(login, "login_required", _noop)
    monkeypatch.setattr(wraps, "setup_required", _noop)
    monkeypatch.setattr(wraps, "account_initialization_required", _noop)
    def _noop_route(*args, **kwargs):  # type: ignore[override]
        # Route registration is irrelevant here; hand the class back unchanged.
        def _decorator(cls):
            return cls
        return _decorator
    monkeypatch.setattr(console_ns, "route", _noop_route)
    # Drop any cached copy so the import below picks up the patched decorators.
    module_name = "controllers.console.workspace.load_balancing_config"
    sys.modules.pop(module_name, None)
    module = importlib.import_module(module_name)
    return module
def _mock_user(role: TenantAccountRole) -> SimpleNamespace:
    """Return a minimal account stub exposing only ``current_role``."""
    user = SimpleNamespace()
    user.current_role = role
    return user
def _prepare_context(module, monkeypatch: pytest.MonkeyPatch, role=TenantAccountRole.OWNER):
    """Patch the controller module with a fake user/tenant and service; return the service mock."""
    fake_user = _mock_user(role)
    monkeypatch.setattr(module, "current_account_with_tenant", lambda: (fake_user, "tenant-123"))
    service_stub = MagicMock()
    monkeypatch.setattr(module, "ModelLoadBalancingService", lambda: service_stub)
    return service_stub
def _request_payload():
    """Return the JSON body shared by the credential-validation requests."""
    payload = {
        "model": "gpt-4o",
        "model_type": ModelType.LLM,
        "credentials": {"api_key": "sk-***"},
    }
    return payload
def test_validate_credentials_success(app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch):
    """Valid credentials yield a success result and are forwarded to the service."""
    service = _prepare_context(load_balancing_module, monkeypatch)
    with app.test_request_context(
        "/workspaces/current/model-providers/openai/models/load-balancing-configs/credentials-validate",
        method="POST",
        json=_request_payload(),
    ):
        response = load_balancing_module.LoadBalancingCredentialsValidateApi().post(provider="openai")
    assert response == {"result": "success"}
    # The controller forwards every payload field as a keyword argument.
    service.validate_load_balancing_credentials.assert_called_once_with(
        tenant_id="tenant-123",
        provider="openai",
        model="gpt-4o",
        model_type=ModelType.LLM,
        credentials={"api_key": "sk-***"},
    )
def test_validate_credentials_returns_error_message(app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch):
    """A CredentialsValidateFailedError is converted into an error result, not raised."""
    service = _prepare_context(load_balancing_module, monkeypatch)
    service.validate_load_balancing_credentials.side_effect = CredentialsValidateFailedError("invalid credentials")
    request_ctx = app.test_request_context(
        "/workspaces/current/model-providers/openai/models/load-balancing-configs/credentials-validate",
        method="POST",
        json=_request_payload(),
    )
    with request_ctx:
        outcome = load_balancing_module.LoadBalancingCredentialsValidateApi().post(provider="openai")
    assert outcome == {"result": "error", "error": "invalid credentials"}
def test_validate_credentials_requires_privileged_role(
    app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch
):
    """A NORMAL-role member gets a 403 Forbidden before the service is reached."""
    _prepare_context(load_balancing_module, monkeypatch, role=TenantAccountRole.NORMAL)
    with app.test_request_context(
        "/workspaces/current/model-providers/openai/models/load-balancing-configs/credentials-validate",
        method="POST",
        json=_request_payload(),
    ):
        api = load_balancing_module.LoadBalancingCredentialsValidateApi()
        with pytest.raises(Forbidden):
            api.post(provider="openai")
def test_validate_credentials_with_config_id(app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch):
    """The per-config endpoint additionally forwards config_id to the service."""
    service = _prepare_context(load_balancing_module, monkeypatch)
    with app.test_request_context(
        "/workspaces/current/model-providers/openai/models/load-balancing-configs/cfg-1/credentials-validate",
        method="POST",
        json=_request_payload(),
    ):
        response = load_balancing_module.LoadBalancingConfigCredentialsValidateApi().post(
            provider="openai", config_id="cfg-1"
        )
    assert response == {"result": "success"}
    service.validate_load_balancing_credentials.assert_called_once_with(
        tenant_id="tenant-123",
        provider="openai",
        model="gpt-4o",
        model_type=ModelType.LLM,
        credentials={"api_key": "sk-***"},
        config_id="cfg-1",
    )

View File

@ -0,0 +1,100 @@
import json
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from flask_restx import Api
from controllers.console.workspace.tool_providers import ToolProviderMCPApi
from core.db.session_factory import configure_session_factory
from extensions.ext_database import db
from services.tools.mcp_tools_manage_service import ReconnectResult
# Backward-compat fixtures referenced by @pytest.mark.usefixtures in this file.
# They are intentionally no-ops because the test already patches the required
# behaviors explicitly via @patch and context managers below.
@pytest.fixture
def _mock_cache():
    # Intentional no-op: the tests patch the relevant behavior explicitly.
    return None
@pytest.fixture
def _mock_user_tenant():
    # Intentional no-op: the tests patch the relevant behavior explicitly.
    return None
@pytest.fixture
def client():
    """Build a Flask test client with the MCP tool-provider route and an in-memory DB."""
    flask_app = Flask(__name__)
    flask_app.config["TESTING"] = True
    flask_app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
    restx_api = Api(flask_app)
    restx_api.add_resource(ToolProviderMCPApi, "/console/api/workspaces/current/tool-provider/mcp")
    db.init_app(flask_app)
    # The controller code resolves sessions through the shared session factory.
    with flask_app.app_context():
        configure_session_factory(db.engine)
    return flask_app.test_client()
@patch(
    "controllers.console.workspace.tool_providers.current_account_with_tenant", return_value=(MagicMock(id="u1"), "t1")
)
@patch("controllers.console.workspace.tool_providers.Session")
@patch("controllers.console.workspace.tool_providers.MCPToolManageService._reconnect_with_url")
@pytest.mark.usefixtures("_mock_cache", "_mock_user_tenant")
def test_create_mcp_provider_populates_tools(mock_reconnect, mock_session, mock_current_account_with_tenant, client):
    """Creating an MCP provider surfaces the tools returned by the reconnect step."""
    # Arrange: reconnect returns tools immediately
    mock_reconnect.return_value = ReconnectResult(
        authed=True,
        tools=json.dumps(
            [{"name": "ping", "description": "ok", "inputSchema": {"type": "object"}, "outputSchema": {}}]
        ),
        encrypted_credentials="{}",
    )
    # Fake service.create_provider -> returns object with id for reload
    svc = MagicMock()
    create_result = MagicMock()
    create_result.id = "provider-1"
    svc.create_provider.return_value = create_result
    svc.get_provider.return_value = MagicMock(id="provider-1", tenant_id="t1")  # used by reload path
    mock_session.return_value.__enter__.return_value = MagicMock()
    # Patch MCPToolManageService constructed inside controller
    with patch("controllers.console.workspace.tool_providers.MCPToolManageService", return_value=svc):
        payload = {
            "server_url": "http://example.com/mcp",
            "name": "demo",
            "icon": "😀",
            "icon_type": "emoji",
            "icon_background": "#000",
            "server_identifier": "demo-sid",
            "configuration": {"timeout": 5, "sse_read_timeout": 30},
            "headers": {},
            "authentication": {},
        }
        # Act
        with (
            patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"),  # bypass setup_required DB check
            patch("controllers.console.wraps.current_account_with_tenant", return_value=(MagicMock(id="u1"), "t1")),
            patch("libs.login.check_csrf_token", return_value=None),  # bypass CSRF in login_required
            patch("libs.login._get_user", return_value=MagicMock(id="u1", is_authenticated=True)),  # login
            patch(
                "services.tools.tools_transform_service.ToolTransformService.mcp_provider_to_user_provider",
                return_value={"id": "provider-1", "tools": [{"name": "ping"}]},
            ),
        ):
            resp = client.post(
                "/console/api/workspaces/current/tool-provider/mcp",
                data=json.dumps(payload),
                content_type="application/json",
            )
    # Assert
    assert resp.status_code == 200
    body = resp.get_json()
    assert body.get("id") == "provider-1"
    # If the transformed payload contains a "tools" field, make sure it is non-empty.
    assert isinstance(body.get("tools"), list)
    assert body["tools"]

View File

@ -41,6 +41,7 @@ class TestFilePreviewApi:
upload_file = Mock(spec=UploadFile)
upload_file.id = str(uuid.uuid4())
upload_file.name = "test_file.jpg"
upload_file.extension = "jpg"
upload_file.mime_type = "image/jpeg"
upload_file.size = 1024
upload_file.key = "storage/key/test_file.jpg"
@ -210,6 +211,19 @@ class TestFilePreviewApi:
assert mock_upload_file.name in response.headers["Content-Disposition"]
assert response.headers["Content-Type"] == "application/octet-stream"
def test_build_file_response_html_forces_attachment(self, file_preview_api, mock_upload_file):
    """Test HTML files are forced to download"""
    mock_generator = Mock()
    # HTML must never be rendered inline from the preview endpoint.
    mock_upload_file.mime_type = "text/html"
    mock_upload_file.name = "unsafe.html"
    mock_upload_file.extension = "html"
    response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False)
    assert "attachment" in response.headers["Content-Disposition"]
    # Content type is downgraded to a generic binary type and sniffing disabled.
    assert response.headers["Content-Type"] == "application/octet-stream"
    assert response.headers["X-Content-Type-Options"] == "nosniff"
def test_build_file_response_audio_video(self, file_preview_api, mock_upload_file):
"""Test file response building for audio/video files"""
mock_generator = Mock()

View File

@ -0,0 +1,195 @@
"""Unit tests for controllers.web.forgot_password endpoints."""
from __future__ import annotations
import base64
import builtins
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from flask.views import MethodView
# Ensure flask_restx.api finds MethodView during import.
if not hasattr(builtins, "MethodView"):
builtins.MethodView = MethodView # type: ignore[attr-defined]
def _load_controller_module():
    """Import controllers.web.forgot_password using a stub package."""
    import importlib
    import importlib.util
    import sys
    from types import ModuleType
    parent_module_name = "controllers.web"
    module_name = f"{parent_module_name}.forgot_password"
    if parent_module_name not in sys.modules:
        from flask_restx import Namespace
        # Register a minimal stand-in for the real controllers.web package so
        # the submodule imports without triggering the package's side effects.
        stub = ModuleType(parent_module_name)
        stub.__file__ = "controllers/web/__init__.py"
        stub.__path__ = ["controllers/web"]
        stub.__package__ = "controllers"
        stub.__spec__ = importlib.util.spec_from_loader(parent_module_name, loader=None, is_package=True)
        # The controller registers its routes on web_ns at import time.
        stub.web_ns = Namespace("web", description="Web API", path="/")
        sys.modules[parent_module_name] = stub
    return importlib.import_module(module_name)
# Import the controller once at module load and expose its API classes for tests.
forgot_password_module = _load_controller_module()
ForgotPasswordCheckApi = forgot_password_module.ForgotPasswordCheckApi
ForgotPasswordResetApi = forgot_password_module.ForgotPasswordResetApi
ForgotPasswordSendEmailApi = forgot_password_module.ForgotPasswordSendEmailApi
@pytest.fixture
def app() -> Flask:
    """Configure a minimal Flask app for request contexts."""
    test_app = Flask(__name__)
    test_app.config.update(TESTING=True)
    return test_app
@pytest.fixture(autouse=True)
def _enable_web_endpoint_guards():
    """Stub enterprise and feature toggles used by route decorators."""
    feature_flags = SimpleNamespace(enable_email_password_login=True)
    # Keep the patches active for the whole duration of each test.
    with (
        patch("controllers.console.wraps.dify_config.ENTERPRISE_ENABLED", True),
        patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"),
        patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
    ):
        yield
@pytest.fixture(autouse=True)
def _mock_controller_db():
    """Replace controller-level db reference with a simple stub."""
    # Engine stub for the controller's Session(db.engine) usage.
    fake_db = SimpleNamespace(engine=MagicMock(name="engine"))
    # wraps.db.session.query(...).first() must return truthy so setup checks pass.
    fake_wraps_db = SimpleNamespace(
        session=MagicMock(query=MagicMock(return_value=MagicMock(first=MagicMock(return_value=True))))
    )
    with (
        patch("controllers.web.forgot_password.db", fake_db),
        patch("controllers.console.wraps.db", fake_wraps_db),
    ):
        yield fake_db
@patch("controllers.web.forgot_password.AccountService.send_reset_password_email", return_value="reset-token")
@patch("controllers.web.forgot_password.Session")
@patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
@patch("controllers.web.forgot_password.extract_remote_ip", return_value="203.0.113.10")
def test_send_reset_email_success(
    mock_extract_ip: MagicMock,
    mock_is_ip_limit: MagicMock,
    mock_session: MagicMock,
    mock_send_email: MagicMock,
    app: Flask,
):
    """POST /forgot-password returns token when email exists and limits allow."""
    # NOTE: @patch decorators apply bottom-up, so the first parameter maps to
    # the last decorator (extract_remote_ip) and so on.
    mock_account = MagicMock()
    session_ctx = MagicMock()
    # Make `with Session(...)` yield a session whose account lookup succeeds.
    mock_session.return_value.__enter__.return_value = session_ctx
    session_ctx.execute.return_value.scalar_one_or_none.return_value = mock_account
    with app.test_request_context(
        "/forgot-password",
        method="POST",
        json={"email": "user@example.com"},
    ):
        response = ForgotPasswordSendEmailApi().post()
    assert response == {"result": "success", "data": "reset-token"}
    mock_extract_ip.assert_called_once()
    mock_is_ip_limit.assert_called_once_with("203.0.113.10")
    mock_send_email.assert_called_once_with(account=mock_account, email="user@example.com", language="en-US")
@patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
@patch("controllers.web.forgot_password.AccountService.generate_reset_password_token", return_value=({}, "new-token"))
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.web.forgot_password.AccountService.is_forgot_password_error_rate_limit", return_value=False)
def test_check_token_success(
    mock_is_rate_limited: MagicMock,
    mock_get_data: MagicMock,
    mock_revoke: MagicMock,
    mock_generate: MagicMock,
    mock_reset_limit: MagicMock,
    app: Flask,
):
    """POST /forgot-password/validity validates the code and refreshes token."""
    # NOTE: @patch decorators apply bottom-up; parameters map to decorators in
    # reverse order (is_forgot_password_error_rate_limit is the first parameter).
    mock_get_data.return_value = {"email": "user@example.com", "code": "123456"}
    with app.test_request_context(
        "/forgot-password/validity",
        method="POST",
        json={"email": "user@example.com", "code": "123456", "token": "old-token"},
    ):
        response = ForgotPasswordCheckApi().post()
    assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"}
    mock_is_rate_limited.assert_called_once_with("user@example.com")
    mock_get_data.assert_called_once_with("old-token")
    # The old token is revoked and a fresh one is minted for the reset phase.
    mock_revoke.assert_called_once_with("old-token")
    mock_generate.assert_called_once_with(
        "user@example.com",
        code="123456",
        additional_data={"phase": "reset"},
    )
    mock_reset_limit.assert_called_once_with("user@example.com")
@patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value")
@patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef")
@patch("controllers.web.forgot_password.Session")
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
def test_reset_password_success(
    mock_get_data: MagicMock,
    mock_revoke_token: MagicMock,
    mock_session: MagicMock,
    mock_token_bytes: MagicMock,
    mock_hash_password: MagicMock,
    app: Flask,
):
    """POST /forgot-password/resets updates the stored password when token is valid."""
    # The token data must carry the "reset" phase for the handler to proceed.
    mock_get_data.return_value = {"email": "user@example.com", "phase": "reset"}
    account = MagicMock()
    session_ctx = MagicMock()
    # Make `with Session(...)` yield a session whose account lookup succeeds.
    mock_session.return_value.__enter__.return_value = session_ctx
    session_ctx.execute.return_value.scalar_one_or_none.return_value = account
    with app.test_request_context(
        "/forgot-password/resets",
        method="POST",
        json={
            "token": "reset-token",
            "new_password": "StrongPass123!",
            "password_confirm": "StrongPass123!",
        },
    ):
        response = ForgotPasswordResetApi().post()
    assert response == {"result": "success"}
    mock_get_data.assert_called_once_with("reset-token")
    mock_revoke_token.assert_called_once_with("reset-token")
    mock_token_bytes.assert_called_once_with(16)
    mock_hash_password.assert_called_once_with("StrongPass123!", b"0123456789abcdef")
    # Both password hash and salt are persisted base64-encoded.
    expected_password = base64.b64encode(b"hashed-value").decode()
    assert account.password == expected_password
    expected_salt = base64.b64encode(b"0123456789abcdef").decode()
    assert account.password_salt == expected_salt
    session_ctx.commit.assert_called_once()

View File

@ -287,7 +287,7 @@ def test_validate_inputs_optional_file_with_empty_string():
def test_validate_inputs_optional_file_list_with_empty_list():
"""Test that optional FILE_LIST variable with empty list returns None"""
"""Test that optional FILE_LIST variable with empty list returns empty list (not None)"""
base_app_generator = BaseAppGenerator()
var_file_list = VariableEntity(
@ -302,6 +302,28 @@ def test_validate_inputs_optional_file_list_with_empty_list():
value=[],
)
# Empty list should be preserved, not converted to None
# This allows downstream components like document_extractor to handle empty lists properly
assert result == []
def test_validate_inputs_optional_file_list_with_empty_string():
    """Test that optional FILE_LIST variable with empty string returns None"""
    generator = BaseAppGenerator()
    file_list_var = VariableEntity(
        variable="test_file_list",
        label="test_file_list",
        type=VariableEntityType.FILE_LIST,
        required=False,
    )
    validated = generator._validate_inputs(
        variable_entity=file_list_var,
        value="",
    )
    # An empty string for an optional file list means the input was never set.
    assert validated is None

View File

@ -0,0 +1,420 @@
from types import SimpleNamespace
from unittest.mock import ANY, Mock, patch
import pytest
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
from core.app.entities.queue_entities import (
QueueAgentMessageEvent,
QueueErrorEvent,
QueueLLMChunkEvent,
QueueMessageEndEvent,
QueueMessageFileEvent,
QueuePingEvent,
)
from core.app.entities.task_entities import (
EasyUITaskState,
ErrorStreamResponse,
MessageEndStreamResponse,
MessageFileStreamResponse,
MessageReplaceStreamResponse,
MessageStreamResponse,
PingStreamResponse,
StreamEvent,
)
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
from core.base.tts import AppGeneratorTTSPublisher
from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult
from core.model_runtime.entities.message_entities import TextPromptMessageContent
from core.ops.ops_trace_manager import TraceQueueManager
from models.model import AppMode
class TestEasyUIBasedGenerateTaskPipelineProcessStreamResponse:
"""Test cases for EasyUIBasedGenerateTaskPipeline._process_stream_response method."""
@pytest.fixture
def mock_application_generate_entity(self):
    """Create a mock application generate entity."""
    entity = Mock(spec=ChatAppGenerateEntity)
    entity.task_id = "test-task-id"
    entity.app_id = "test-app-id"
    # minimal app_config used by pipeline internals; SimpleNamespace keeps
    # plain attribute access without pulling in the real config classes
    entity.app_config = SimpleNamespace(
        tenant_id="test-tenant-id",
        app_id="test-app-id",
        app_mode=AppMode.CHAT,
        app_model_config_dict={},
        additional_features=None,
        sensitive_word_avoidance=None,
    )
    # minimal model_conf for LLMResult init
    entity.model_conf = SimpleNamespace(
        model="test-model",
        provider_model_bundle=SimpleNamespace(model_type_instance=Mock()),
        credentials={},
    )
    return entity
@pytest.fixture
def mock_queue_manager(self):
    """Create a mock queue manager constrained to the AppQueueManager API."""
    return Mock(spec=AppQueueManager)
@pytest.fixture
def mock_message_cycle_manager(self):
    """Create a mock message cycle manager."""
    manager = Mock()
    # Default every conversion helper to a typed stream-response mock so tests
    # can assert on call arguments without building real response objects.
    manager.get_message_event_type.return_value = StreamEvent.MESSAGE
    manager.message_to_stream_response.return_value = Mock(spec=MessageStreamResponse)
    manager.message_file_to_stream_response.return_value = Mock(spec=MessageFileStreamResponse)
    manager.message_replace_to_stream_response.return_value = Mock(spec=MessageReplaceStreamResponse)
    manager.handle_retriever_resources = Mock()
    manager.handle_annotation_reply.return_value = None
    return manager
@pytest.fixture
def mock_conversation(self):
    """Create a mock conversation in chat mode."""
    conv = Mock()
    conv.id = "test-conversation-id"
    conv.mode = "chat"
    return conv
@pytest.fixture
def mock_message(self):
    """Create a mock message with a fixed id and creation timestamp."""
    msg = Mock()
    msg.id = "test-message-id"
    msg.created_at = Mock()
    msg.created_at.timestamp.return_value = 1234567890
    return msg
@pytest.fixture
def mock_task_state(self):
    """Create a mock task state."""
    task_state = Mock(spec=EasyUITaskState)
    # Create LLM result mock
    llm_result = Mock(spec=RuntimeLLMResult)
    llm_result.prompt_messages = []
    llm_result.message = Mock()
    # Start with empty content so chunk handlers can accumulate onto it.
    llm_result.message.content = ""
    task_state.llm_result = llm_result
    task_state.answer = ""
    return task_state
@pytest.fixture
def pipeline(
    self,
    mock_application_generate_entity,
    mock_queue_manager,
    mock_conversation,
    mock_message,
    mock_message_cycle_manager,
    mock_task_state,
):
    """Create an EasyUIBasedGenerateTaskPipeline instance with mocked dependencies."""
    # Patch EasyUITaskState so the constructor does not build a real task state.
    with patch(
        "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.EasyUITaskState", return_value=mock_task_state
    ):
        pipeline = EasyUIBasedGenerateTaskPipeline(
            application_generate_entity=mock_application_generate_entity,
            queue_manager=mock_queue_manager,
            conversation=mock_conversation,
            message=mock_message,
            stream=True,
        )
    # Swap in the mocked collaborators after construction.
    pipeline._message_cycle_manager = mock_message_cycle_manager
    pipeline._task_state = mock_task_state
    return pipeline
def test_get_message_event_type_called_once_when_first_llm_chunk_arrives(
    self, pipeline, mock_message_cycle_manager
):
    """Expect get_message_event_type to be called when processing the first LLM chunk event."""
    # Setup a minimal LLM chunk event
    chunk = Mock()
    chunk.delta.message.content = "hi"
    chunk.prompt_messages = []
    llm_chunk_event = Mock(spec=QueueLLMChunkEvent)
    llm_chunk_event.chunk = chunk
    mock_queue_message = Mock()
    mock_queue_message.event = llm_chunk_event
    pipeline.queue_manager.listen.return_value = [mock_queue_message]
    # Execute: drain the generator; responses themselves are not asserted here.
    list(pipeline._process_stream_response(publisher=None, trace_manager=None))
    # Assert
    mock_message_cycle_manager.get_message_event_type.assert_called_once_with(message_id="test-message-id")
def test_llm_chunk_event_with_text_content(self, pipeline, mock_message_cycle_manager, mock_task_state):
    """Test handling of LLM chunk events with text content."""
    # Setup
    chunk = Mock()
    chunk.delta.message.content = "Hello, world!"
    chunk.prompt_messages = []
    llm_chunk_event = Mock(spec=QueueLLMChunkEvent)
    llm_chunk_event.chunk = chunk
    mock_queue_message = Mock()
    mock_queue_message.event = llm_chunk_event
    pipeline.queue_manager.listen.return_value = [mock_queue_message]
    mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE
    # Execute
    responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
    # Assert: one stream response per chunk, forwarded with the resolved event type.
    assert len(responses) == 1
    mock_message_cycle_manager.message_to_stream_response.assert_called_once_with(
        answer="Hello, world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE
    )
    # The chunk text is accumulated onto the task state's LLM result.
    assert mock_task_state.llm_result.message.content == "Hello, world!"
def test_llm_chunk_event_with_list_content(self, pipeline, mock_message_cycle_manager, mock_task_state):
    """Test handling of LLM chunk events with list content."""
    # Setup: mixed list of a TextPromptMessageContent part and a plain string.
    text_content = Mock(spec=TextPromptMessageContent)
    text_content.data = "Hello"
    chunk = Mock()
    chunk.delta.message.content = [text_content, " world!"]
    chunk.prompt_messages = []
    llm_chunk_event = Mock(spec=QueueLLMChunkEvent)
    llm_chunk_event.chunk = chunk
    mock_queue_message = Mock()
    mock_queue_message.event = llm_chunk_event
    pipeline.queue_manager.listen.return_value = [mock_queue_message]
    mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE
    # Execute
    responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
    # Assert: both parts are flattened into a single answer string.
    assert len(responses) == 1
    mock_message_cycle_manager.message_to_stream_response.assert_called_once_with(
        answer="Hello world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE
    )
    assert mock_task_state.llm_result.message.content == "Hello world!"
def test_agent_message_event(self, pipeline, mock_message_cycle_manager, mock_task_state):
    """Test handling of agent message events."""
    # Setup
    chunk = Mock()
    chunk.delta.message.content = "Agent response"
    agent_message_event = Mock(spec=QueueAgentMessageEvent)
    agent_message_event.chunk = chunk
    mock_queue_message = Mock()
    mock_queue_message.event = agent_message_event
    pipeline.queue_manager.listen.return_value = [mock_queue_message]
    # Ensure method under assertion is a mock to track calls
    pipeline._agent_message_to_stream_response = Mock(return_value=Mock())
    # Execute
    responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
    # Assert
    assert len(responses) == 1
    # Agent messages should use _agent_message_to_stream_response
    pipeline._agent_message_to_stream_response.assert_called_once_with(
        answer="Agent response", message_id="test-message-id"
    )
def test_message_end_event(self, pipeline, mock_message_cycle_manager, mock_task_state):
    """Test handling of message end events."""
    # Setup: the end event carries the final LLM result.
    llm_result = Mock(spec=RuntimeLLMResult)
    llm_result.message = Mock()
    llm_result.message.content = "Final response"
    message_end_event = Mock(spec=QueueMessageEndEvent)
    message_end_event.llm_result = llm_result
    mock_queue_message = Mock()
    mock_queue_message.event = message_end_event
    pipeline.queue_manager.listen.return_value = [mock_queue_message]
    # Stub persistence and conversion so no database work is attempted.
    pipeline._save_message = Mock()
    pipeline._message_end_to_stream_response = Mock(return_value=Mock(spec=MessageEndStreamResponse))
    # Patch db.engine used inside pipeline for session creation
    with patch(
        "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock())
    ):
        # Execute
        responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
    # Assert: the task state adopts the final result, and the message is saved.
    assert len(responses) == 1
    assert mock_task_state.llm_result == llm_result
    pipeline._save_message.assert_called_once()
    pipeline._message_end_to_stream_response.assert_called_once()
def test_error_event(self, pipeline):
    """Error events flow through handle_error and error_to_stream_response."""
    # Arrange: a single queued error event.
    event = Mock(spec=QueueErrorEvent)
    event.error = Exception("Test error")
    queue_message = Mock(event=event)
    pipeline.queue_manager.listen.return_value = [queue_message]
    pipeline.handle_error = Mock(return_value=Exception("Test error"))
    pipeline.error_to_stream_response = Mock(return_value=Mock(spec=ErrorStreamResponse))
    # The pipeline opens a DB session while processing; stub out db.engine.
    with patch(
        "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock())
    ):
        responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
    # Assert: one response, produced through the error-handling path.
    assert len(responses) == 1
    pipeline.handle_error.assert_called_once()
    pipeline.error_to_stream_response.assert_called_once()
def test_ping_event(self, pipeline):
    """A ping event yields exactly one ping stream response."""
    # Arrange a queue containing only a ping event.
    queue_message = Mock(event=Mock(spec=QueuePingEvent))
    pipeline.queue_manager.listen.return_value = [queue_message]
    pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse))
    # Act & assert.
    responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
    assert len(responses) == 1
    pipeline.ping_stream_response.assert_called_once()
def test_file_event(self, pipeline, mock_message_cycle_manager):
    """File events are delegated to the message cycle manager."""
    # Arrange: one queued file event referencing a stored message file.
    event = Mock(spec=QueueMessageFileEvent)
    event.message_file_id = "file-id"
    queue_message = Mock(event=event)
    pipeline.queue_manager.listen.return_value = [queue_message]
    expected_response = Mock(spec=MessageFileStreamResponse)
    mock_message_cycle_manager.message_file_to_stream_response.return_value = expected_response
    # Act
    responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
    # Assert: the manager converts the event and its response is streamed as-is.
    assert len(responses) == 1
    assert responses[0] == expected_response
    mock_message_cycle_manager.message_file_to_stream_response.assert_called_once_with(event)
def test_publisher_is_called_with_messages(self, pipeline):
    """A TTS publisher receives every queue message plus a terminal None."""
    publisher = Mock(spec=AppGeneratorTTSPublisher)
    queue_message = Mock(event=Mock(spec=QueuePingEvent))
    pipeline.queue_manager.listen.return_value = [queue_message]
    pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse))
    # Drain the stream with the publisher attached.
    list(pipeline._process_stream_response(publisher=publisher, trace_manager=None))
    # One publish per message, plus a final publish(None) signalling completion.
    assert publisher.publish.call_count == 2
    publisher.publish.assert_any_call(queue_message)
    publisher.publish.assert_any_call(None)
def test_trace_manager_passed_to_save_message(self, pipeline):
    """Test that trace manager is passed to _save_message."""
    # Setup: an end-of-message event with no LLM result payload.
    trace_manager = Mock(spec=TraceQueueManager)
    message_end_event = Mock(spec=QueueMessageEndEvent)
    message_end_event.llm_result = None
    mock_queue_message = Mock()
    mock_queue_message.event = message_end_event
    pipeline.queue_manager.listen.return_value = [mock_queue_message]
    pipeline._save_message = Mock()
    pipeline._message_end_to_stream_response = Mock(return_value=Mock(spec=MessageEndStreamResponse))
    # Patch db.engine used inside pipeline for session creation
    with patch(
        "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock())
    ):
        # Execute
        list(pipeline._process_stream_response(publisher=None, trace_manager=trace_manager))
    # Assert: the trace manager is forwarded to _save_message; the session is
    # created internally, so it is matched loosely with ANY.
    pipeline._save_message.assert_called_once_with(session=ANY, trace_manager=trace_manager)
def test_multiple_events_sequence(self, pipeline, mock_message_cycle_manager, mock_task_state):
    """Test handling multiple events in sequence."""
    # Setup: two LLM chunks separated by a ping, all delivered by one listen().
    chunk1 = Mock()
    chunk1.delta.message.content = "Hello"
    chunk1.prompt_messages = []
    chunk2 = Mock()
    chunk2.delta.message.content = " world!"
    chunk2.prompt_messages = []
    llm_chunk_event1 = Mock(spec=QueueLLMChunkEvent)
    llm_chunk_event1.chunk = chunk1
    ping_event = Mock(spec=QueuePingEvent)
    llm_chunk_event2 = Mock(spec=QueueLLMChunkEvent)
    llm_chunk_event2.chunk = chunk2
    mock_queue_messages = [
        Mock(event=llm_chunk_event1),
        Mock(event=ping_event),
        Mock(event=llm_chunk_event2),
    ]
    pipeline.queue_manager.listen.return_value = mock_queue_messages
    # Pre-computed event type avoids per-chunk DB lookups (N+1 optimization).
    mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE
    pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse))
    # Execute
    responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
    # Assert: one streamed response per event; chunk text accumulates in order.
    assert len(responses) == 3
    assert mock_task_state.llm_result.message.content == "Hello world!"
    # Verify calls to message_to_stream_response
    assert mock_message_cycle_manager.message_to_stream_response.call_count == 2
    mock_message_cycle_manager.message_to_stream_response.assert_any_call(
        answer="Hello", message_id="test-message-id", event_type=StreamEvent.MESSAGE
    )
    mock_message_cycle_manager.message_to_stream_response.assert_any_call(
        answer=" world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE
    )

View File

@ -0,0 +1,166 @@
"""Unit tests for the message cycle manager optimization."""
from types import SimpleNamespace
from unittest.mock import ANY, Mock, patch
import pytest
from flask import current_app
from core.app.entities.task_entities import MessageStreamResponse, StreamEvent
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
class TestMessageCycleManagerOptimization:
    """Test cases for the message cycle manager optimization that prevents N+1 queries."""

    @pytest.fixture
    def mock_application_generate_entity(self):
        """Create a mock application generate entity."""
        entity = Mock()
        entity.task_id = "test-task-id"
        return entity

    @pytest.fixture
    def message_cycle_manager(self, mock_application_generate_entity):
        """Create a message cycle manager instance."""
        task_state = Mock()
        return MessageCycleManager(application_generate_entity=mock_application_generate_entity, task_state=task_state)

    def test_get_message_event_type_with_message_file(self, message_cycle_manager):
        """Test get_message_event_type returns MESSAGE_FILE when message has files."""
        with (
            patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
            patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
        ):
            # Setup mock session and message file
            mock_session = Mock()
            mock_session_class.return_value.__enter__.return_value = mock_session
            mock_message_file = Mock()
            # Current implementation uses session.query(...).scalar()
            mock_session.query.return_value.scalar.return_value = mock_message_file
            # Execute (needs an app context; the manager reads Flask state)
            with current_app.app_context():
                result = message_cycle_manager.get_message_event_type("test-message-id")
            # Assert
            assert result == StreamEvent.MESSAGE_FILE
            mock_session.query.return_value.scalar.assert_called_once()

    def test_get_message_event_type_without_message_file(self, message_cycle_manager):
        """Test get_message_event_type returns MESSAGE when message has no files."""
        with (
            patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
            patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
        ):
            # Setup mock session and no message file
            mock_session = Mock()
            mock_session_class.return_value.__enter__.return_value = mock_session
            # Current implementation uses session.query(...).scalar()
            mock_session.query.return_value.scalar.return_value = None
            # Execute
            with current_app.app_context():
                result = message_cycle_manager.get_message_event_type("test-message-id")
            # Assert
            assert result == StreamEvent.MESSAGE
            mock_session.query.return_value.scalar.assert_called_once()

    def test_message_to_stream_response_with_precomputed_event_type(self, message_cycle_manager):
        """MessageCycleManager.message_to_stream_response expects a valid event_type; callers should precompute it."""
        with (
            patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
            patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
        ):
            # Setup mock session and message file
            mock_session = Mock()
            mock_session_class.return_value.__enter__.return_value = mock_session
            mock_message_file = Mock()
            # Current implementation uses session.query(...).scalar()
            mock_session.query.return_value.scalar.return_value = mock_message_file
            # Execute: compute event type once, then pass to message_to_stream_response
            with current_app.app_context():
                event_type = message_cycle_manager.get_message_event_type("test-message-id")
                result = message_cycle_manager.message_to_stream_response(
                    answer="Hello world", message_id="test-message-id", event_type=event_type
                )
            # Assert: response carries the precomputed event; only one DB query ran.
            assert isinstance(result, MessageStreamResponse)
            assert result.answer == "Hello world"
            assert result.id == "test-message-id"
            assert result.event == StreamEvent.MESSAGE_FILE
            mock_session.query.return_value.scalar.assert_called_once()

    def test_message_to_stream_response_with_event_type_skips_query(self, message_cycle_manager):
        """Test that message_to_stream_response skips database query when event_type is provided."""
        with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class:
            # Execute with event_type provided
            result = message_cycle_manager.message_to_stream_response(
                answer="Hello world", message_id="test-message-id", event_type=StreamEvent.MESSAGE
            )
            # Assert
            assert isinstance(result, MessageStreamResponse)
            assert result.answer == "Hello world"
            assert result.id == "test-message-id"
            assert result.event == StreamEvent.MESSAGE
            # Should not query database when event_type is provided
            mock_session_class.assert_not_called()

    def test_message_to_stream_response_with_from_variable_selector(self, message_cycle_manager):
        """Test message_to_stream_response with from_variable_selector parameter."""
        result = message_cycle_manager.message_to_stream_response(
            answer="Hello world",
            message_id="test-message-id",
            from_variable_selector=["var1", "var2"],
            event_type=StreamEvent.MESSAGE,
        )
        # The variable selector must be passed through to the response unchanged.
        assert isinstance(result, MessageStreamResponse)
        assert result.answer == "Hello world"
        assert result.id == "test-message-id"
        assert result.from_variable_selector == ["var1", "var2"]
        assert result.event == StreamEvent.MESSAGE

    def test_optimization_usage_example(self, message_cycle_manager):
        """Test the optimization pattern that should be used by callers."""
        # Step 1: Get event type once (this queries database)
        with (
            patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
            patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
        ):
            mock_session = Mock()
            mock_session_class.return_value.__enter__.return_value = mock_session
            # Current implementation uses session.query(...).scalar()
            mock_session.query.return_value.scalar.return_value = None  # No files
            with current_app.app_context():
                event_type = message_cycle_manager.get_message_event_type("test-message-id")
            # Should query database once
            mock_session_class.assert_called_once_with(ANY, expire_on_commit=False)
            assert event_type == StreamEvent.MESSAGE
        # Step 2: Use event_type for multiple calls (no additional queries)
        with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class:
            mock_session_class.return_value.__enter__.return_value = Mock()
            chunk1_response = message_cycle_manager.message_to_stream_response(
                answer="Chunk 1", message_id="test-message-id", event_type=event_type
            )
            chunk2_response = message_cycle_manager.message_to_stream_response(
                answer="Chunk 2", message_id="test-message-id", event_type=event_type
            )
            # Should not query database again
            mock_session_class.assert_not_called()
            assert chunk1_response.event == StreamEvent.MESSAGE
            assert chunk2_response.event == StreamEvent.MESSAGE
            assert chunk1_response.answer == "Chunk 1"
            assert chunk2_response.answer == "Chunk 2"

View File

@ -96,7 +96,7 @@ class TestNotionExtractorAuthentication:
def test_init_with_integration_token_fallback(self, mock_get_token, mock_config, mock_document_model):
"""Test NotionExtractor falls back to integration token when credential not found."""
# Arrange
mock_get_token.return_value = None
mock_get_token.side_effect = Exception("No credential id found")
mock_config.NOTION_INTEGRATION_TOKEN = "integration-token-fallback"
# Act
@ -105,7 +105,7 @@ class TestNotionExtractorAuthentication:
notion_obj_id="page-456",
notion_page_type="page",
tenant_id="tenant-789",
credential_id="cred-123",
credential_id=None,
document_model=mock_document_model,
)
@ -117,7 +117,7 @@ class TestNotionExtractorAuthentication:
def test_init_missing_credentials_raises_error(self, mock_get_token, mock_config, mock_document_model):
"""Test NotionExtractor raises error when no credentials available."""
# Arrange
mock_get_token.return_value = None
mock_get_token.side_effect = Exception("No credential id found")
mock_config.NOTION_INTEGRATION_TOKEN = None
# Act & Assert
@ -127,7 +127,7 @@ class TestNotionExtractorAuthentication:
notion_obj_id="page-456",
notion_page_type="page",
tenant_id="tenant-789",
credential_id="cred-123",
credential_id=None,
document_model=mock_document_model,
)
assert "Must specify `integration_token`" in str(exc_info.value)

View File

@ -1,52 +1,109 @@
import secrets
from unittest.mock import MagicMock, patch
import pytest
from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, make_request
from core.helper.ssrf_proxy import (
SSRF_DEFAULT_MAX_RETRIES,
_get_user_provided_host_header,
make_request,
)
@patch("httpx.Client.request")
def test_successful_request(mock_request):
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_successful_request(mock_get_client):
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_request.return_value = mock_response
mock_client.send.return_value = mock_response
mock_client.request.return_value = mock_response
mock_get_client.return_value = mock_client
response = make_request("GET", "http://example.com")
assert response.status_code == 200
@patch("httpx.Client.request")
def test_retry_exceed_max_retries(mock_request):
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_retry_exceed_max_retries(mock_get_client):
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 500
side_effects = [mock_response] * SSRF_DEFAULT_MAX_RETRIES
mock_request.side_effect = side_effects
mock_client.send.return_value = mock_response
mock_client.request.return_value = mock_response
mock_get_client.return_value = mock_client
with pytest.raises(Exception) as e:
make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES - 1)
assert str(e.value) == f"Reached maximum retries ({SSRF_DEFAULT_MAX_RETRIES - 1}) for URL http://example.com"
@patch("httpx.Client.request")
def test_retry_logic_success(mock_request):
side_effects = []
class TestGetUserProvidedHostHeader:
"""Tests for _get_user_provided_host_header function."""
for _ in range(SSRF_DEFAULT_MAX_RETRIES):
status_code = secrets.choice(STATUS_FORCELIST)
mock_response = MagicMock()
mock_response.status_code = status_code
side_effects.append(mock_response)
def test_returns_none_when_headers_is_none(self):
assert _get_user_provided_host_header(None) is None
mock_response_200 = MagicMock()
mock_response_200.status_code = 200
side_effects.append(mock_response_200)
def test_returns_none_when_headers_is_empty(self):
assert _get_user_provided_host_header({}) is None
mock_request.side_effect = side_effects
def test_returns_none_when_host_header_not_present(self):
headers = {"Content-Type": "application/json", "Authorization": "Bearer token"}
assert _get_user_provided_host_header(headers) is None
response = make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES)
def test_returns_host_header_lowercase(self):
headers = {"host": "example.com"}
assert _get_user_provided_host_header(headers) == "example.com"
def test_returns_host_header_uppercase(self):
headers = {"HOST": "example.com"}
assert _get_user_provided_host_header(headers) == "example.com"
def test_returns_host_header_mixed_case(self):
headers = {"HoSt": "example.com"}
assert _get_user_provided_host_header(headers) == "example.com"
def test_returns_host_header_from_multiple_headers(self):
headers = {"Content-Type": "application/json", "Host": "api.example.com", "Authorization": "Bearer token"}
assert _get_user_provided_host_header(headers) == "api.example.com"
def test_returns_first_host_header_when_duplicates(self):
headers = {"host": "first.com", "Host": "second.com"}
# Should return the first one encountered (iteration order is preserved in dict)
result = _get_user_provided_host_header(headers)
assert result in ("first.com", "second.com")
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_host_header_preservation_without_user_header(mock_get_client):
    """Test that when no Host header is provided, the default behavior is maintained."""
    # Stub client: both send() and request() yield an HTTP 200 response.
    stub_client = MagicMock()
    tracked_request = MagicMock()
    tracked_request.headers = {}
    ok_response = MagicMock()
    ok_response.status_code = 200
    stub_client.send.return_value = ok_response
    stub_client.request.return_value = ok_response
    mock_get_client.return_value = stub_client
    response = make_request("GET", "http://example.com")
    assert response.status_code == 200
    # Host should not be set if not provided by user
    assert "Host" not in tracked_request.headers or tracked_request.headers.get("Host") is None
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_host_header_preservation_with_user_header(mock_get_client):
"""Test that user-provided Host header is preserved in the request."""
mock_client = MagicMock()
mock_request = MagicMock()
mock_request.headers = {}
mock_response = MagicMock()
mock_response.status_code = 200
mock_client.send.return_value = mock_response
mock_client.request.return_value = mock_response
mock_get_client.return_value = mock_client
custom_host = "custom.example.com:8080"
response = make_request("GET", "http://example.com", headers={"Host": custom_host})
assert response.status_code == 200
assert mock_request.call_count == SSRF_DEFAULT_MAX_RETRIES + 1
assert mock_request.call_args_list[0][1].get("method") == "GET"

View File

@ -1,129 +0,0 @@
import json
from unittest.mock import patch
import pytest
from redis.exceptions import RedisError
from core.helper.tool_provider_cache import ToolProviderListCache
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
@pytest.fixture
def mock_redis_client():
    """Yield a patched stand-in for the module-level Redis client."""
    with patch("core.helper.tool_provider_cache.redis_client") as redis_mock:
        yield redis_mock
class TestToolProviderListCache:
    """Test class for ToolProviderListCache"""

    def test_generate_cache_key(self):
        """Test cache key generation logic"""
        # Scenario 1: Specify typ (valid literal value)
        tenant_id = "tenant_123"
        typ: ToolProviderTypeApiLiteral = "builtin"
        expected_key = f"tool_providers:tenant_id:{tenant_id}:type:{typ}"
        assert ToolProviderListCache._generate_cache_key(tenant_id, typ) == expected_key
        # Scenario 2: typ is None (defaults to "all")
        expected_key_all = f"tool_providers:tenant_id:{tenant_id}:type:all"
        assert ToolProviderListCache._generate_cache_key(tenant_id) == expected_key_all

    def test_get_cached_providers_hit(self, mock_redis_client):
        """Test get cached providers - cache hit and successful decoding"""
        tenant_id = "tenant_123"
        typ: ToolProviderTypeApiLiteral = "api"
        mock_providers = [{"id": "tool", "name": "test_provider"}]
        # Redis returns the JSON payload as bytes, as the real client would.
        mock_redis_client.get.return_value = json.dumps(mock_providers).encode("utf-8")
        result = ToolProviderListCache.get_cached_providers(tenant_id, typ)
        mock_redis_client.get.assert_called_once_with(ToolProviderListCache._generate_cache_key(tenant_id, typ))
        assert result == mock_providers

    def test_get_cached_providers_decode_error(self, mock_redis_client):
        """Test get cached providers - cache hit but decoding failed"""
        tenant_id = "tenant_123"
        # Malformed payload must be treated as a miss, not raise.
        mock_redis_client.get.return_value = b"invalid_json_data"
        result = ToolProviderListCache.get_cached_providers(tenant_id)
        assert result is None
        mock_redis_client.get.assert_called_once()

    def test_get_cached_providers_miss(self, mock_redis_client):
        """Test get cached providers - cache miss"""
        tenant_id = "tenant_123"
        mock_redis_client.get.return_value = None
        result = ToolProviderListCache.get_cached_providers(tenant_id)
        assert result is None
        mock_redis_client.get.assert_called_once()

    def test_set_cached_providers(self, mock_redis_client):
        """Test set cached providers"""
        tenant_id = "tenant_123"
        typ: ToolProviderTypeApiLiteral = "builtin"
        mock_providers = [{"id": "tool", "name": "test_provider"}]
        cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
        ToolProviderListCache.set_cached_providers(tenant_id, typ, mock_providers)
        # Entries are written with the class-level TTL via SETEX.
        mock_redis_client.setex.assert_called_once_with(
            cache_key, ToolProviderListCache.CACHE_TTL, json.dumps(mock_providers)
        )

    def test_invalidate_cache_specific_type(self, mock_redis_client):
        """Test invalidate cache - specific type"""
        tenant_id = "tenant_123"
        typ: ToolProviderTypeApiLiteral = "workflow"
        cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
        ToolProviderListCache.invalidate_cache(tenant_id, typ)
        mock_redis_client.delete.assert_called_once_with(cache_key)

    def test_invalidate_cache_all_types(self, mock_redis_client):
        """Test invalidate cache - clear all tenant cache"""
        tenant_id = "tenant_123"
        mock_keys = [
            b"tool_providers:tenant_id:tenant_123:type:all",
            b"tool_providers:tenant_id:tenant_123:type:builtin",
        ]
        mock_redis_client.scan_iter.return_value = mock_keys
        ToolProviderListCache.invalidate_cache(tenant_id)
        # Without a type, all keys under the tenant prefix are scanned and
        # deleted in a single DEL call.
        mock_redis_client.scan_iter.assert_called_once_with(f"tool_providers:tenant_id:{tenant_id}:*")
        mock_redis_client.delete.assert_called_once_with(*mock_keys)

    def test_invalidate_cache_no_keys(self, mock_redis_client):
        """Test invalidate cache - no cache keys for tenant"""
        tenant_id = "tenant_123"
        mock_redis_client.scan_iter.return_value = []
        ToolProviderListCache.invalidate_cache(tenant_id)
        mock_redis_client.delete.assert_not_called()

    def test_redis_fallback_default_return(self, mock_redis_client):
        """Test redis_fallback decorator - default return value (Redis error)"""
        mock_redis_client.get.side_effect = RedisError("Redis connection error")
        result = ToolProviderListCache.get_cached_providers("tenant_123")
        assert result is None
        mock_redis_client.get.assert_called_once()

    def test_redis_fallback_no_default(self, mock_redis_client):
        """Test redis_fallback decorator - no default return value (Redis error)"""
        mock_redis_client.setex.side_effect = RedisError("Redis connection error")
        try:
            ToolProviderListCache.set_cached_providers("tenant_123", "mcp", [])
        except RedisError:
            pytest.fail("set_cached_providers should not raise RedisError (handled by fallback)")
        mock_redis_client.setex.assert_called_once()

View File

@ -1,6 +1,9 @@
import re
from datetime import datetime
import pytest
from core.ops.utils import validate_project_name, validate_url, validate_url_with_path
from core.ops.utils import generate_dotted_order, validate_project_name, validate_url, validate_url_with_path
class TestValidateUrl:
@ -136,3 +139,51 @@ class TestValidateProjectName:
"""Test custom default name"""
result = validate_project_name("", "Custom Default")
assert result == "Custom Default"
class TestGenerateDottedOrder:
    """Test cases for generate_dotted_order function"""

    def test_dotted_order_has_6_digit_microseconds(self):
        """Timestamps must carry full 6-digit microseconds for LangSmith.

        LangSmith API expects timestamps in format: YYYYMMDDTHHMMSSffffffZ
        (6-digit microseconds). Previously, the code truncated to 3 digits
        which caused API errors: 'cannot parse .111 as .000000'
        """
        started_at = datetime(2025, 12, 23, 4, 19, 55, 111000)
        dotted = generate_dotted_order("test-run-id", started_at)
        # Pull the fractional-seconds digits out of the leading timestamp.
        timestamp_match = re.match(r"^(\d{8}T\d{6})(\d+)Z", dotted)
        assert timestamp_match is not None, "Timestamp format should match YYYYMMDDTHHMMSSffffffZ"
        microseconds = timestamp_match.group(2)
        assert len(microseconds) == 6, f"Microseconds should be 6 digits, got {len(microseconds)}: {microseconds}"

    def test_dotted_order_format_matches_langsmith_expected(self):
        """Test that dotted_order format matches LangSmith API expected format."""
        started_at = datetime(2025, 1, 15, 10, 30, 45, 123456)
        # LangSmith expects: YYYYMMDDTHHMMSSffffffZ followed by run_id
        assert generate_dotted_order("abc123", started_at) == "20250115T103045123456Zabc123"

    def test_dotted_order_with_parent(self):
        """Test dotted_order generation with parent order uses dot separator."""
        started_at = datetime(2025, 12, 23, 4, 19, 55, 111000)
        parent = "20251223T041955000000Zparent-run-id"
        dotted = generate_dotted_order("child-run-id", started_at, parent)
        assert dotted == "20251223T041955000000Zparent-run-id.20251223T041955111000Zchild-run-id"

    def test_dotted_order_without_parent_has_no_dot(self):
        """Test dotted_order generation without parent has no dot separator."""
        started_at = datetime(2025, 12, 23, 4, 19, 55, 111000)
        assert "." not in generate_dotted_order("test-run-id", started_at, None)

View File

@ -0,0 +1,327 @@
import unittest
from unittest.mock import MagicMock, patch
import pytest
from core.rag.datasource.vdb.pgvector.pgvector import (
PGVector,
PGVectorConfig,
)
class TestPGVector(unittest.TestCase):
def setUp(self):
    # Shared connection config for all tests; pg_bigm is disabled here so
    # individual tests can opt in with their own config.
    self.config = PGVectorConfig(
        host="localhost",
        port=5432,
        user="test_user",
        password="test_password",
        database="test_db",
        min_connection=1,
        max_connection=5,
        pg_bigm=False,
    )
    self.collection_name = "test_collection"
self.collection_name = "test_collection"
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
def test_init(self, mock_pool_class):
    """Test PGVector initialization."""
    mock_pool_class.return_value = MagicMock()
    vector_store = PGVector(self.collection_name, self.config)
    # The constructor derives the table name from the collection, wires up
    # the connection pool, and takes pg_bigm from the config (off here).
    assert vector_store._collection_name == self.collection_name
    assert vector_store.table_name == f"embedding_{self.collection_name}"
    assert vector_store.get_type() == "pgvector"
    assert vector_store.pool is not None
    assert vector_store.pg_bigm is False
    assert vector_store.index_hash is not None
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
def test_init_with_pg_bigm(self, mock_pool_class):
    """Test PGVector initialization with pg_bigm enabled."""
    mock_pool_class.return_value = MagicMock()
    # Same connection settings as setUp, but with pg_bigm switched on.
    bigm_config = PGVectorConfig(
        host="localhost",
        port=5432,
        user="test_user",
        password="test_password",
        database="test_db",
        min_connection=1,
        max_connection=5,
        pg_bigm=True,
    )
    assert PGVector(self.collection_name, bigm_config).pg_bigm is True
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
def test_create_collection_basic(self, mock_redis, mock_pool_class):
    """Test basic collection creation."""
    # Mock Redis operations (bottom-most decorator maps to the first mock arg)
    mock_lock = MagicMock()
    mock_lock.__enter__ = MagicMock()
    mock_lock.__exit__ = MagicMock()
    mock_redis.lock.return_value = mock_lock
    mock_redis.get.return_value = None
    mock_redis.set.return_value = None
    # Mock the connection pool
    mock_pool = MagicMock()
    mock_pool_class.return_value = mock_pool
    # Mock connection and cursor
    mock_conn = MagicMock()
    mock_cursor = MagicMock()
    mock_pool.getconn.return_value = mock_conn
    mock_conn.cursor.return_value = mock_cursor
    mock_cursor.fetchone.return_value = [1]  # vector extension exists
    pgvector = PGVector(self.collection_name, self.config)
    pgvector._create_collection(1536)
    # Verify SQL execution calls
    assert mock_cursor.execute.called
    # Check that CREATE TABLE was called with correct dimension
    create_table_calls = [call for call in mock_cursor.execute.call_args_list if "CREATE TABLE" in str(call)]
    assert len(create_table_calls) == 1
    assert "vector(1536)" in create_table_calls[0][0][0]
    # Check that CREATE INDEX was called (dimension <= 2000)
    create_index_calls = [
        call for call in mock_cursor.execute.call_args_list if "CREATE INDEX" in str(call) and "hnsw" in str(call)
    ]
    assert len(create_index_calls) == 1
    # Verify Redis cache was set
    mock_redis.set.assert_called_once()
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
def test_create_collection_with_large_dimension(self, mock_redis, mock_pool_class):
    """Test collection creation with dimension > 2000 (no HNSW index)."""
    # Mock Redis operations
    mock_lock = MagicMock()
    mock_lock.__enter__ = MagicMock()
    mock_lock.__exit__ = MagicMock()
    mock_redis.lock.return_value = mock_lock
    mock_redis.get.return_value = None
    mock_redis.set.return_value = None
    # Mock the connection pool
    mock_pool = MagicMock()
    mock_pool_class.return_value = mock_pool
    # Mock connection and cursor
    mock_conn = MagicMock()
    mock_cursor = MagicMock()
    mock_pool.getconn.return_value = mock_conn
    mock_conn.cursor.return_value = mock_cursor
    mock_cursor.fetchone.return_value = [1]  # vector extension exists
    pgvector = PGVector(self.collection_name, self.config)
    pgvector._create_collection(3072)  # Dimension > 2000
    # Check that CREATE TABLE was called
    create_table_calls = [call for call in mock_cursor.execute.call_args_list if "CREATE TABLE" in str(call)]
    assert len(create_table_calls) == 1
    assert "vector(3072)" in create_table_calls[0][0][0]
    # Check that HNSW index was NOT created (dimension > 2000)
    hnsw_index_calls = [call for call in mock_cursor.execute.call_args_list if "hnsw" in str(call)]
    assert len(hnsw_index_calls) == 0
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
def test_create_collection_with_pg_bigm(self, mock_redis, mock_pool_class):
    """Test collection creation with pg_bigm enabled."""
    # pg_bigm enabled; otherwise identical to the setUp config.
    config = PGVectorConfig(
        host="localhost",
        port=5432,
        user="test_user",
        password="test_password",
        database="test_db",
        min_connection=1,
        max_connection=5,
        pg_bigm=True,
    )
    # Mock Redis operations
    mock_lock = MagicMock()
    mock_lock.__enter__ = MagicMock()
    mock_lock.__exit__ = MagicMock()
    mock_redis.lock.return_value = mock_lock
    mock_redis.get.return_value = None
    mock_redis.set.return_value = None
    # Mock the connection pool
    mock_pool = MagicMock()
    mock_pool_class.return_value = mock_pool
    # Mock connection and cursor
    mock_conn = MagicMock()
    mock_cursor = MagicMock()
    mock_pool.getconn.return_value = mock_conn
    mock_conn.cursor.return_value = mock_cursor
    mock_cursor.fetchone.return_value = [1]  # vector extension exists
    pgvector = PGVector(self.collection_name, config)
    pgvector._create_collection(1536)
    # Check that pg_bigm index was created
    bigm_index_calls = [call for call in mock_cursor.execute.call_args_list if "gin_bigm_ops" in str(call)]
    assert len(bigm_index_calls) == 1
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
def test_create_collection_creates_vector_extension(self, mock_redis, mock_pool_class):
    """Test that vector extension is created if it doesn't exist."""
    # Mock Redis operations
    mock_lock = MagicMock()
    mock_lock.__enter__ = MagicMock()
    mock_lock.__exit__ = MagicMock()
    mock_redis.lock.return_value = mock_lock
    mock_redis.get.return_value = None
    mock_redis.set.return_value = None
    # Mock the connection pool
    mock_pool = MagicMock()
    mock_pool_class.return_value = mock_pool
    # Mock connection and cursor
    mock_conn = MagicMock()
    mock_cursor = MagicMock()
    mock_pool.getconn.return_value = mock_conn
    mock_conn.cursor.return_value = mock_cursor
    # First call: vector extension doesn't exist
    mock_cursor.fetchone.return_value = None
    pgvector = PGVector(self.collection_name, self.config)
    pgvector._create_collection(1536)
    # Check that CREATE EXTENSION was called
    create_extension_calls = [
        call for call in mock_cursor.execute.call_args_list if "CREATE EXTENSION vector" in str(call)
    ]
    assert len(create_extension_calls) == 1
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
def test_create_collection_with_cache_hit(self, mock_redis, mock_pool_class):
"""Test that collection creation is skipped when cache exists."""
# Mock Redis operations - cache exists
mock_lock = MagicMock()
mock_lock.__enter__ = MagicMock()
mock_lock.__exit__ = MagicMock()
mock_redis.lock.return_value = mock_lock
mock_redis.get.return_value = 1 # Cache exists
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
# Mock connection and cursor
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.getconn.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
pgvector = PGVector(self.collection_name, self.config)
pgvector._create_collection(1536)
# Check that no SQL was executed (early return due to cache)
assert mock_cursor.execute.call_count == 0
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
def test_create_collection_with_redis_lock(self, mock_redis, mock_pool_class):
"""Test that Redis lock is used during collection creation."""
# Mock Redis operations
mock_lock = MagicMock()
mock_lock.__enter__ = MagicMock()
mock_lock.__exit__ = MagicMock()
mock_redis.lock.return_value = mock_lock
mock_redis.get.return_value = None
mock_redis.set.return_value = None
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
# Mock connection and cursor
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.getconn.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.return_value = [1] # vector extension exists
pgvector = PGVector(self.collection_name, self.config)
pgvector._create_collection(1536)
# Verify Redis lock was acquired with correct lock name
mock_redis.lock.assert_called_once_with("vector_indexing_test_collection_lock", timeout=20)
# Verify lock context manager was entered and exited
mock_lock.__enter__.assert_called_once()
mock_lock.__exit__.assert_called_once()
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
def test_get_cursor_context_manager(self, mock_pool_class):
"""Test that _get_cursor properly manages connection lifecycle."""
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.getconn.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
pgvector = PGVector(self.collection_name, self.config)
with pgvector._get_cursor() as cur:
assert cur == mock_cursor
# Verify connection lifecycle methods were called
mock_pool.getconn.assert_called_once()
mock_cursor.close.assert_called_once()
mock_conn.commit.assert_called_once()
mock_pool.putconn.assert_called_once_with(mock_conn)
@pytest.mark.parametrize(
    "invalid_config_override",
    [
        {"host": ""},  # empty host
        {"port": 0},  # invalid port
        {"user": ""},  # empty user
        {"password": ""},  # empty password
        {"database": ""},  # empty database
        {"min_connection": 0},  # invalid min_connection
        {"max_connection": 0},  # invalid max_connection
        {"min_connection": 10, "max_connection": 5},  # min greater than max
    ],
)
def test_config_validation_parametrized(invalid_config_override):
    """Each invalid override applied to a valid base config must raise ValueError."""
    valid_config = {
        "host": "localhost",
        "port": 5432,
        "user": "test_user",
        "password": "test_password",
        "database": "test_db",
        "min_connection": 1,
        "max_connection": 5,
    }
    with pytest.raises(ValueError):
        PGVectorConfig(**{**valid_config, **invalid_config_override})
if __name__ == "__main__":
unittest.main()

View File

@ -1,5 +1,7 @@
import os
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
@ -25,3 +27,35 @@ def test_firecrawl_web_extractor_crawl_mode(mocker: MockerFixture):
assert job_id is not None
assert isinstance(job_id, str)
def test_build_url_normalizes_slashes_for_crawl(mocker: MockerFixture):
    """The crawl endpoint URL is the same whether base_url has a trailing slash or not."""
    api_key = "fc-"
    for base in ("https://custom.firecrawl.dev", "https://custom.firecrawl.dev/"):
        app = FirecrawlApp(api_key=api_key, base_url=base)
        post = mocker.patch("httpx.post")
        response = MagicMock()
        response.status_code = 200
        response.json.return_value = {"id": "job123"}
        post.return_value = response
        app.crawl_url("https://example.com", params=None)
        # First positional argument of httpx.post is the request URL.
        assert post.call_args[0][0] == "https://custom.firecrawl.dev/v2/crawl"
def test_error_handler_handles_non_json_error_bodies(mocker: MockerFixture):
    """A non-JSON error body must not surface as a JSON decoding failure."""
    api_key = "fc-"
    app = FirecrawlApp(api_key=api_key, base_url="https://custom.firecrawl.dev/")
    post = mocker.patch("httpx.post")
    response = MagicMock()
    response.status_code = 404
    response.text = "Not Found"
    # response.json() blowing up simulates a plain-text error body.
    response.json.side_effect = Exception("Not JSON")
    post.return_value = response
    with pytest.raises(Exception) as excinfo:
        app.scrape_url("https://example.com")
    # Should not raise a JSONDecodeError; current behavior reports status code only
    assert str(excinfo.value) == "Failed to scrape URL. Status code: 404"

View File

@ -132,3 +132,36 @@ def test_extract_images_from_docx(monkeypatch):
# DB interactions should be recorded
assert len(db_stub.session.added) == 2
assert db_stub.session.committed is True
def test_extract_images_from_docx_uses_internal_files_url():
    """Test that INTERNAL_FILES_URL takes precedence over FILES_URL for plugin access.

    Exercises the same URL-precedence expression used in word_extractor.py
    against a temporarily mutated global ``dify_config``.
    """
    from configs import dify_config

    # Snapshot the current values so the shared config object can be restored.
    original_files_url = getattr(dify_config, "FILES_URL", None)
    original_internal_files_url = getattr(dify_config, "INTERNAL_FILES_URL", None)
    try:
        # Set both URLs - INTERNAL should take precedence
        dify_config.FILES_URL = "http://external.example.com"
        dify_config.INTERNAL_FILES_URL = "http://internal.docker:5001"
        upload_file_id = "test_file_id"
        # This is the pattern we fixed in the word extractor
        base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
        generated_url = f"{base_url}/files/{upload_file_id}/file-preview"
        # Verify that INTERNAL_FILES_URL is used instead of FILES_URL
        assert "http://internal.docker:5001" in generated_url, f"Expected internal URL, got: {generated_url}"
        assert "http://external.example.com" not in generated_url, f"Should not use external URL, got: {generated_url}"
    finally:
        # Restore unconditionally. The previous version skipped restoration
        # when an original value was None, which leaked the test's URLs into
        # the process-wide config and could affect unrelated tests.
        dify_config.FILES_URL = original_files_url
        dify_config.INTERNAL_FILES_URL = original_internal_files_url

View File

@ -421,7 +421,18 @@ class TestRetrievalService:
# In real code, this waits for all futures to complete
# In tests, futures complete immediately, so wait is a no-op
with patch("core.rag.datasource.retrieval_service.concurrent.futures.wait"):
yield mock_executor
# Mock concurrent.futures.as_completed for early error propagation
# In real code, this yields futures as they complete
# In tests, we yield all futures immediately since they're already done
def mock_as_completed(futures_list, timeout=None):
"""Mock as_completed that yields futures immediately."""
yield from futures_list
with patch(
"core.rag.datasource.retrieval_service.concurrent.futures.as_completed",
side_effect=mock_as_completed,
):
yield mock_executor
# ==================== Vector Search Tests ====================

View File

@ -0,0 +1,873 @@
"""
Unit tests for DatasetRetrieval.process_metadata_filter_func.
This module provides comprehensive test coverage for the process_metadata_filter_func
method in the DatasetRetrieval class, which is responsible for building SQLAlchemy
filter expressions based on metadata filtering conditions.
Conditions Tested:
==================
1. **String Conditions**: contains, not contains, start with, end with
2. **Equality Conditions**: is / =, is not / ≠
3. **Null Conditions**: empty, not empty
4. **Numeric Comparisons**: before / <, after / >, ≤ / <=, ≥ / >=
5. **List Conditions**: in
6. **Edge Cases**: None values, different data types (str, int, float)
Test Architecture:
==================
- Direct instantiation of DatasetRetrieval
- Mocking of DatasetDocument model attributes
- Verification of SQLAlchemy filter expressions
- Follows Arrange-Act-Assert (AAA) pattern
Running Tests:
==============
# Run all tests in this module
uv run --project api pytest \
api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py -v
# Run a specific test
uv run --project api pytest \
api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py::\
TestProcessMetadataFilterFunc::test_contains_condition -v
"""
from unittest.mock import MagicMock
import pytest
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
class TestProcessMetadataFilterFunc:
    """
    Comprehensive test suite for process_metadata_filter_func.

    Method under test::

        def process_metadata_filter_func(
            self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
        ) -> list

    The method builds SQLAlchemy filter expressions against the
    DatasetDocument.doc_metadata JSON field, appends them to the ``filters``
    list it receives, and returns that same list.

    Covered conditions:
    - string matching: contains, not contains, start with, end with
    - equality: is / =, is not / ≠
    - null checks: empty, not empty (value may be None)
    - numeric comparisons: before / <, after / >, ≤ / <=, ≥ / >=
    - list membership: in (comma-separated string, list, tuple)
    - edge cases: None values, unknown conditions, empty strings,
      zero/negative numbers, mixed data types
    """

    @pytest.fixture
    def retrieval(self):
        """Create a DatasetRetrieval instance for testing."""
        return DatasetRetrieval()

    @pytest.fixture
    def mock_doc_metadata(self):
        """
        Mock for the DatasetDocument.doc_metadata JSON field.

        NOTE(review): the tests below exercise the real method and do not use
        this fixture; it is kept as a building block for tests that need to
        stub the JSON accessor explicitly.
        """
        mock_metadata_field = MagicMock()
        # String accessor: supports LIKE / NOT LIKE / equality / IN.
        mock_string_access = MagicMock()
        mock_string_access.like = MagicMock()
        mock_string_access.notlike = MagicMock()
        mock_string_access.__eq__ = MagicMock(return_value=MagicMock())
        mock_string_access.__ne__ = MagicMock(return_value=MagicMock())
        mock_string_access.in_ = MagicMock(return_value=MagicMock())
        # Numeric accessor: comparison operators used by the as_float() path.
        mock_float_access = MagicMock()
        for op in ("__eq__", "__ne__", "__lt__", "__gt__", "__le__", "__ge__"):
            setattr(mock_float_access, op, MagicMock(return_value=MagicMock()))
        # Null-check helpers (IS NULL / IS NOT NULL).
        mock_null_access = MagicMock()
        mock_null_access.is_ = MagicMock(return_value=MagicMock())
        mock_null_access.isnot = MagicMock(return_value=MagicMock())

        def getitem_side_effect(name):
            # Known numeric keys get the float accessor; everything else the
            # string accessor.
            if name in ("year", "price", "rating"):
                return mock_float_access
            return mock_string_access

        mock_metadata_field.__getitem__ = MagicMock(side_effect=getitem_side_effect)
        mock_metadata_field.as_string.return_value = mock_string_access
        mock_metadata_field.as_float.return_value = mock_float_access
        # Fix: the previous version referenced an undefined ``metadata_name``
        # here (a NameError the moment the fixture was used). Expose the
        # null-check helpers on the string accessor instead.
        mock_string_access.is_ = mock_null_access.is_
        mock_string_access.isnot = mock_null_access.isnot
        return mock_metadata_field

    def _apply(self, retrieval, condition, metadata_name, value, filters=None, sequence=0):
        """Invoke process_metadata_filter_func once and return the filters list.

        Asserts the invariant shared by every test: the method returns the
        very list it was given (it appends in place).
        """
        filters = [] if filters is None else filters
        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
        assert result == filters
        return filters

    # ==================== String Condition Tests ====================

    def test_contains_condition_string_value(self, retrieval):
        """'contains' adds a LIKE %value% filter."""
        assert len(self._apply(retrieval, "contains", "author", "John")) == 1

    def test_not_contains_condition(self, retrieval):
        """'not contains' adds a NOT LIKE %value% filter."""
        assert len(self._apply(retrieval, "not contains", "title", "banned")) == 1

    def test_start_with_condition(self, retrieval):
        """'start with' adds a LIKE value% filter."""
        assert len(self._apply(retrieval, "start with", "category", "tech")) == 1

    def test_end_with_condition(self, retrieval):
        """'end with' adds a LIKE %value filter."""
        assert len(self._apply(retrieval, "end with", "filename", ".pdf")) == 1

    # ==================== Equality Condition Tests ====================

    def test_is_condition_with_string_value(self, retrieval):
        """'is' with a string value adds a string-equality filter."""
        assert len(self._apply(retrieval, "is", "author", "Jane Doe")) == 1

    def test_equals_condition_with_string_value(self, retrieval):
        """'=' behaves like 'is' for string values."""
        assert len(self._apply(retrieval, "=", "category", "technology")) == 1

    def test_is_condition_with_int_value(self, retrieval):
        """Integer values go through the numeric (as_float) comparison path."""
        assert len(self._apply(retrieval, "is", "year", 2023)) == 1

    def test_is_condition_with_float_value(self, retrieval):
        """Float values go through the numeric (as_float) comparison path."""
        assert len(self._apply(retrieval, "is", "price", 19.99)) == 1

    def test_is_not_condition_with_string_value(self, retrieval):
        """'is not' adds a string-inequality filter."""
        assert len(self._apply(retrieval, "is not", "author", "Unknown")) == 1

    def test_not_equals_condition(self, retrieval):
        """'≠' behaves like 'is not'."""
        assert len(self._apply(retrieval, "≠", "category", "archived")) == 1

    def test_is_not_condition_with_numeric_value(self, retrieval):
        """'is not' with a number uses the numeric inequality path."""
        assert len(self._apply(retrieval, "is not", "year", 2000)) == 1

    # ==================== Null Condition Tests ====================

    def test_empty_condition(self, retrieval):
        """'empty' adds an IS NULL filter; value may be None."""
        assert len(self._apply(retrieval, "empty", "author", None)) == 1

    def test_not_empty_condition(self, retrieval):
        """'not empty' adds an IS NOT NULL filter; value may be None."""
        assert len(self._apply(retrieval, "not empty", "description", None)) == 1

    # ==================== Numeric Comparison Tests ====================

    def test_before_condition(self, retrieval):
        """'before' adds a less-than filter."""
        assert len(self._apply(retrieval, "before", "year", 2020)) == 1

    def test_less_than_condition(self, retrieval):
        """'<' behaves like 'before'."""
        assert len(self._apply(retrieval, "<", "price", 100.0)) == 1

    def test_after_condition(self, retrieval):
        """'after' adds a greater-than filter."""
        assert len(self._apply(retrieval, "after", "year", 2020)) == 1

    def test_greater_than_condition(self, retrieval):
        """'>' behaves like 'after'."""
        assert len(self._apply(retrieval, ">", "rating", 4.5)) == 1

    def test_less_than_or_equal_condition_unicode(self, retrieval):
        """'≤' adds a less-than-or-equal filter."""
        assert len(self._apply(retrieval, "≤", "price", 50.0)) == 1

    def test_less_than_or_equal_condition_ascii(self, retrieval):
        """'<=' behaves like '≤'."""
        assert len(self._apply(retrieval, "<=", "year", 2023)) == 1

    def test_greater_than_or_equal_condition_unicode(self, retrieval):
        """'≥' adds a greater-than-or-equal filter."""
        assert len(self._apply(retrieval, "≥", "rating", 3.5)) == 1

    def test_greater_than_or_equal_condition_ascii(self, retrieval):
        """'>=' behaves like '≥'."""
        assert len(self._apply(retrieval, ">=", "year", 2000)) == 1

    # ==================== List/In Condition Tests ====================

    def test_in_condition_with_comma_separated_string(self, retrieval):
        """Comma-separated strings are split and trimmed before the IN."""
        assert len(self._apply(retrieval, "in", "category", "tech, science, AI ")) == 1

    def test_in_condition_with_list_value(self, retrieval):
        """None entries in a list value are filtered out of the IN."""
        assert len(self._apply(retrieval, "in", "tags", ["python", "javascript", None, "golang"])) == 1

    def test_in_condition_with_tuple_value(self, retrieval):
        """Tuples are processed like lists."""
        assert len(self._apply(retrieval, "in", "category", ("tech", "science", "ai"))) == 1

    def test_in_condition_with_empty_string(self, retrieval):
        """An empty string yields no usable values -> a literal(False) filter."""
        assert len(self._apply(retrieval, "in", "category", "")) == 1

    def test_in_condition_with_only_whitespace(self, retrieval):
        """Whitespace-only entries are stripped away -> a literal(False) filter."""
        assert len(self._apply(retrieval, "in", "category", " , , ")) == 1

    def test_in_condition_with_single_string(self, retrieval):
        """A single non-comma string becomes a one-element IN."""
        assert len(self._apply(retrieval, "in", "category", "technology")) == 1

    # ==================== Edge Case Tests ====================

    def test_none_value_with_non_empty_condition(self, retrieval):
        """None values add no filter (except for empty / not empty)."""
        assert self._apply(retrieval, "contains", "author", None) == []

    def test_none_value_with_equals_condition(self, retrieval):
        """None with 'is' adds no filter."""
        assert self._apply(retrieval, "is", "author", None) == []

    def test_none_value_with_numeric_condition(self, retrieval):
        """None with a numeric comparison adds no filter."""
        assert self._apply(retrieval, ">", "year", None) == []

    def test_existing_filters_preserved(self, retrieval):
        """New expressions are appended; pre-existing entries stay in place."""
        existing_filter = MagicMock()
        filters = self._apply(retrieval, "contains", "author", "test", filters=[existing_filter])
        assert len(filters) == 2
        assert filters[0] is existing_filter

    def test_multiple_filters_accumulated(self, retrieval):
        """Each call appends one more filter to the shared list."""
        filters = []
        self._apply(retrieval, "contains", "author", "John", filters=filters, sequence=0)
        assert len(filters) == 1
        self._apply(retrieval, ">", "year", 2020, filters=filters, sequence=1)
        assert len(filters) == 2
        self._apply(retrieval, "is", "category", "tech", filters=filters, sequence=2)
        assert len(filters) == 3

    def test_unknown_condition(self, retrieval):
        """Unsupported condition names are ignored — no filter is added."""
        assert self._apply(retrieval, "unknown_condition", "author", "test") == []

    def test_empty_string_value_with_contains(self, retrieval):
        """Empty string is a valid (non-None) value and still adds a LIKE filter."""
        assert len(self._apply(retrieval, "contains", "author", "")) == 1

    def test_special_characters_in_value(self, retrieval):
        """Special characters in the value are handled by the LIKE expression."""
        assert len(self._apply(retrieval, "contains", "title", "C++ & Python's features")) == 1

    def test_zero_value_with_numeric_condition(self, retrieval):
        """Zero is a valid value, not treated as missing."""
        assert len(self._apply(retrieval, ">", "price", 0)) == 1

    def test_negative_value_with_numeric_condition(self, retrieval):
        """Negative numbers compare correctly."""
        assert len(self._apply(retrieval, "<", "temperature", -10.5)) == 1

    def test_float_value_with_integer_comparison(self, retrieval):
        """Float values work with >= comparisons."""
        assert len(self._apply(retrieval, ">=", "rating", 4.5)) == 1

View File

@ -901,6 +901,13 @@ class TestFixedRecursiveCharacterTextSplitter:
# Verify no empty chunks
assert all(len(chunk) > 0 for chunk in result)
def test_double_slash_n(self):
    """An escaped separator ("\\n\\n---\\n\\n") still splits on its literal newline form."""
    text = "chunk 1\n\nsubchunk 1.\nsubchunk 2.\n\n---\n\nchunk 2\n\nsubchunk 1\nsubchunk 2."
    splitter = FixedRecursiveCharacterTextSplitter(fixed_separator="\\n\\n---\\n\\n")
    assert splitter.split_text(text) == [
        "chunk 1\n\nsubchunk 1.\nsubchunk 2.",
        "chunk 2\n\nsubchunk 1\nsubchunk 2.",
    ]
# ============================================================================
# Test Metadata Preservation

View File

@ -1,3 +1,5 @@
import pytest
from core.workflow.nodes.http_request import (
BodyData,
HttpRequestNodeAuthorization,
@ -5,6 +7,7 @@ from core.workflow.nodes.http_request import (
HttpRequestNodeData,
)
from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout
from core.workflow.nodes.http_request.exc import AuthorizationConfigError
from core.workflow.nodes.http_request.executor import Executor
from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable
@ -348,3 +351,127 @@ def test_init_params():
executor = create_executor("key1:value1\n\nkey2:value2\n\n")
executor._init_params()
assert executor.params == [("key1", "value1"), ("key2", "value2")]
def test_empty_api_key_raises_error_bearer():
    """Test that empty API key raises AuthorizationConfigError for bearer auth."""
    variables = VariablePool(system_variables=SystemVariable.empty())
    data = HttpRequestNodeData(
        title="test",
        method="get",
        url="http://example.com",
        headers="",
        params="",
        authorization=HttpRequestNodeAuthorization(
            type="api-key",
            config={"type": "bearer", "api_key": ""},
        ),
    )
    # Constructing the executor validates the authorization config eagerly.
    with pytest.raises(AuthorizationConfigError, match="API key is required"):
        Executor(
            node_data=data,
            timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
            variable_pool=variables,
        )
def test_empty_api_key_raises_error_basic():
    """Test that empty API key raises AuthorizationConfigError for basic auth."""
    variables = VariablePool(system_variables=SystemVariable.empty())
    data = HttpRequestNodeData(
        title="test",
        method="get",
        url="http://example.com",
        headers="",
        params="",
        authorization=HttpRequestNodeAuthorization(
            type="api-key",
            config={"type": "basic", "api_key": ""},
        ),
    )
    # Constructing the executor validates the authorization config eagerly.
    with pytest.raises(AuthorizationConfigError, match="API key is required"):
        Executor(
            node_data=data,
            timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
            variable_pool=variables,
        )
def test_empty_api_key_raises_error_custom():
    """Test that empty API key raises AuthorizationConfigError for custom auth."""
    variable_pool = VariablePool(system_variables=SystemVariable.empty())
    # "custom" auth additionally carries the header name the key would be sent in.
    node_data = HttpRequestNodeData(
        title="test",
        method="get",
        url="http://example.com",
        headers="",
        params="",
        authorization=HttpRequestNodeAuthorization(
            type="api-key",
            config={"type": "custom", "api_key": "", "header": "X-Custom-Auth"},
        ),
    )
    timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
    # Executor construction performs the authorization validation.
    with pytest.raises(AuthorizationConfigError, match="API key is required"):
        Executor(
            node_data=node_data,
            timeout=timeout,
            variable_pool=variable_pool,
        )
def test_whitespace_only_api_key_raises_error():
    """Test that whitespace-only API key raises AuthorizationConfigError."""
    variable_pool = VariablePool(system_variables=SystemVariable.empty())
    # A key of only spaces must be treated the same as an empty key,
    # i.e. the validation presumably strips before checking — the raised
    # error below confirms that behaviour.
    node_data = HttpRequestNodeData(
        title="test",
        method="get",
        url="http://example.com",
        headers="",
        params="",
        authorization=HttpRequestNodeAuthorization(
            type="api-key",
            config={"type": "bearer", "api_key": "   "},
        ),
    )
    timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
    with pytest.raises(AuthorizationConfigError, match="API key is required"):
        Executor(
            node_data=node_data,
            timeout=timeout,
            variable_pool=variable_pool,
        )
def test_valid_api_key_works():
    """Test that valid API key works correctly for bearer auth."""
    variable_pool = VariablePool(system_variables=SystemVariable.empty())
    node_data = HttpRequestNodeData(
        title="test",
        method="get",
        url="http://example.com",
        headers="",
        params="",
        authorization=HttpRequestNodeAuthorization(
            type="api-key",
            config={"type": "bearer", "api_key": "valid-api-key-123"},
        ),
    )
    timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
    executor = Executor(
        node_data=node_data,
        timeout=timeout,
        variable_pool=variable_pool,
    )
    # Should not raise an error
    headers = executor._assembling_headers()
    assert "Authorization" in headers
    # Bearer auth prepends the "Bearer " scheme to the configured key.
    assert headers["Authorization"] == "Bearer valid-api-key-123"

View File

@ -1,3 +1,4 @@
import json
import time
import pytest
@ -46,14 +47,16 @@ def make_start_node(user_inputs, variables):
def test_json_object_valid_schema():
schema = {
"type": "object",
"properties": {
"age": {"type": "number"},
"name": {"type": "string"},
},
"required": ["age"],
}
schema = json.dumps(
{
"type": "object",
"properties": {
"age": {"type": "number"},
"name": {"type": "string"},
},
"required": ["age"],
}
)
variables = [
VariableEntity(
@ -65,7 +68,7 @@ def test_json_object_valid_schema():
)
]
user_inputs = {"profile": {"age": 20, "name": "Tom"}}
user_inputs = {"profile": json.dumps({"age": 20, "name": "Tom"})}
node = make_start_node(user_inputs, variables)
result = node._run()
@ -74,12 +77,23 @@ def test_json_object_valid_schema():
def test_json_object_invalid_json_string():
schema = json.dumps(
{
"type": "object",
"properties": {
"age": {"type": "number"},
"name": {"type": "string"},
},
"required": ["age", "name"],
}
)
variables = [
VariableEntity(
variable="profile",
label="profile",
type=VariableEntityType.JSON_OBJECT,
required=True,
json_schema=schema,
)
]
@ -88,38 +102,21 @@ def test_json_object_invalid_json_string():
node = make_start_node(user_inputs, variables)
with pytest.raises(ValueError, match="profile must be a JSON object"):
node._run()
@pytest.mark.parametrize("value", ["[1, 2, 3]", "123"])
def test_json_object_valid_json_but_not_object(value):
variables = [
VariableEntity(
variable="profile",
label="profile",
type=VariableEntityType.JSON_OBJECT,
required=True,
)
]
user_inputs = {"profile": value}
node = make_start_node(user_inputs, variables)
with pytest.raises(ValueError, match="profile must be a JSON object"):
with pytest.raises(ValueError, match='{"age": 20, "name": "Tom" must be a valid JSON object'):
node._run()
def test_json_object_does_not_match_schema():
schema = {
"type": "object",
"properties": {
"age": {"type": "number"},
"name": {"type": "string"},
},
"required": ["age", "name"],
}
schema = json.dumps(
{
"type": "object",
"properties": {
"age": {"type": "number"},
"name": {"type": "string"},
},
"required": ["age", "name"],
}
)
variables = [
VariableEntity(
@ -132,7 +129,7 @@ def test_json_object_does_not_match_schema():
]
# age is a string, which violates the schema (expects number)
user_inputs = {"profile": {"age": "twenty", "name": "Tom"}}
user_inputs = {"profile": json.dumps({"age": "twenty", "name": "Tom"})}
node = make_start_node(user_inputs, variables)
@ -141,14 +138,16 @@ def test_json_object_does_not_match_schema():
def test_json_object_missing_required_schema_field():
schema = {
"type": "object",
"properties": {
"age": {"type": "number"},
"name": {"type": "string"},
},
"required": ["age", "name"],
}
schema = json.dumps(
{
"type": "object",
"properties": {
"age": {"type": "number"},
"name": {"type": "string"},
},
"required": ["age", "name"],
}
)
variables = [
VariableEntity(
@ -161,7 +160,7 @@ def test_json_object_missing_required_schema_field():
]
# Missing required field "name"
user_inputs = {"profile": {"age": 20}}
user_inputs = {"profile": json.dumps({"age": 20})}
node = make_start_node(user_inputs, variables)
@ -214,7 +213,7 @@ def test_json_object_optional_variable_not_provided():
variable="profile",
label="profile",
type=VariableEntityType.JSON_OBJECT,
required=False,
required=True,
)
]
@ -223,5 +222,5 @@ def test_json_object_optional_variable_not_provided():
node = make_start_node(user_inputs, variables)
# Current implementation raises a validation error even when the variable is optional
with pytest.raises(ValueError, match="profile must be a JSON object"):
with pytest.raises(ValueError, match="profile is required in input form"):
node._run()

View File

@ -1,4 +1,4 @@
from unittest.mock import patch
from unittest.mock import MagicMock, patch
import pytest
from qcloud_cos import CosConfig
@ -18,3 +18,72 @@ class TestTencentCos(BaseStorageTest):
with patch.object(CosConfig, "__init__", return_value=None):
self.storage = TencentCosStorage()
self.storage.bucket_name = get_example_bucket()
class TestTencentCosConfiguration:
    """Tests for TencentCosStorage initialization with different configurations.

    Exercises the Domain-vs-Region branch: when a custom domain is configured,
    CosConfig should be built with ``Domain`` and no ``Region``; otherwise with
    ``Region`` and no ``Domain``.
    """

    def test_init_with_custom_domain(self):
        """Test initialization with custom domain configured."""
        # Mock dify_config to return custom domain configuration
        mock_dify_config = MagicMock()
        mock_dify_config.TENCENT_COS_CUSTOM_DOMAIN = "cos.example.com"
        mock_dify_config.TENCENT_COS_SECRET_ID = "test-secret-id"
        mock_dify_config.TENCENT_COS_SECRET_KEY = "test-secret-key"
        mock_dify_config.TENCENT_COS_SCHEME = "https"
        # Mock CosConfig and CosS3Client
        mock_config_instance = MagicMock()
        mock_client = MagicMock()
        with (
            patch("extensions.storage.tencent_cos_storage.dify_config", mock_dify_config),
            patch(
                "extensions.storage.tencent_cos_storage.CosConfig", return_value=mock_config_instance
            ) as mock_cos_config,
            patch("extensions.storage.tencent_cos_storage.CosS3Client", return_value=mock_client),
        ):
            TencentCosStorage()
        # Verify CosConfig was called with Domain parameter (not Region)
        mock_cos_config.assert_called_once()
        call_kwargs = mock_cos_config.call_args[1]
        assert "Domain" in call_kwargs
        assert call_kwargs["Domain"] == "cos.example.com"
        assert "Region" not in call_kwargs
        assert call_kwargs["SecretId"] == "test-secret-id"
        assert call_kwargs["SecretKey"] == "test-secret-key"
        assert call_kwargs["Scheme"] == "https"

    def test_init_with_region(self):
        """Test initialization with region configured (no custom domain)."""
        # Mock dify_config to return region configuration
        mock_dify_config = MagicMock()
        mock_dify_config.TENCENT_COS_CUSTOM_DOMAIN = None
        mock_dify_config.TENCENT_COS_REGION = "ap-guangzhou"
        mock_dify_config.TENCENT_COS_SECRET_ID = "test-secret-id"
        mock_dify_config.TENCENT_COS_SECRET_KEY = "test-secret-key"
        mock_dify_config.TENCENT_COS_SCHEME = "https"
        # Mock CosConfig and CosS3Client
        mock_config_instance = MagicMock()
        mock_client = MagicMock()
        with (
            patch("extensions.storage.tencent_cos_storage.dify_config", mock_dify_config),
            patch(
                "extensions.storage.tencent_cos_storage.CosConfig", return_value=mock_config_instance
            ) as mock_cos_config,
            patch("extensions.storage.tencent_cos_storage.CosS3Client", return_value=mock_client),
        ):
            TencentCosStorage()
        # Verify CosConfig was called with Region parameter (not Domain)
        mock_cos_config.assert_called_once()
        call_kwargs = mock_cos_config.call_args[1]
        assert "Region" in call_kwargs
        assert call_kwargs["Region"] == "ap-guangzhou"
        assert "Domain" not in call_kwargs
        assert call_kwargs["SecretId"] == "test-secret-id"
        assert call_kwargs["SecretKey"] == "test-secret-key"
        assert call_kwargs["Scheme"] == "https"

View File

@ -1,3 +1,4 @@
import json
from unittest.mock import MagicMock, patch
import httpx
@ -110,9 +111,11 @@ class TestFirecrawlAuth:
@pytest.mark.parametrize(
("status_code", "response_text", "has_json_error", "expected_error_contains"),
[
(403, '{"error": "Forbidden"}', True, "Failed to authorize. Status code: 403. Error: Forbidden"),
(404, "", True, "Unexpected error occurred while trying to authorize. Status code: 404"),
(401, "Not JSON", True, "Expecting value"), # JSON decode error
(403, '{"error": "Forbidden"}', False, "Failed to authorize. Status code: 403. Error: Forbidden"),
# empty body falls back to generic message
(404, "", True, "Failed to authorize. Status code: 404. Error: Unknown error occurred"),
# non-JSON body is surfaced directly
(401, "Not JSON", True, "Failed to authorize. Status code: 401. Error: Not JSON"),
],
)
@patch("services.auth.firecrawl.firecrawl.httpx.post")
@ -124,12 +127,14 @@ class TestFirecrawlAuth:
mock_response.status_code = status_code
mock_response.text = response_text
if has_json_error:
mock_response.json.side_effect = Exception("Not JSON")
mock_response.json.side_effect = json.JSONDecodeError("Not JSON", "", 0)
else:
mock_response.json.return_value = {"error": "Forbidden"}
mock_post.return_value = mock_response
with pytest.raises(Exception) as exc_info:
auth_instance.validate_credentials()
assert expected_error_contains in str(exc_info.value)
assert str(exc_info.value) == expected_error_contains
@pytest.mark.parametrize(
("exception_type", "exception_message"),
@ -164,20 +169,21 @@ class TestFirecrawlAuth:
@patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_should_use_custom_base_url_in_validation(self, mock_post):
"""Test that custom base URL is used in validation"""
"""Test that custom base URL is used in validation and normalized"""
mock_response = MagicMock()
mock_response.status_code = 200
mock_post.return_value = mock_response
credentials = {
"auth_type": "bearer",
"config": {"api_key": "test_api_key_123", "base_url": "https://custom.firecrawl.dev"},
}
auth = FirecrawlAuth(credentials)
result = auth.validate_credentials()
for base in ("https://custom.firecrawl.dev", "https://custom.firecrawl.dev/"):
credentials = {
"auth_type": "bearer",
"config": {"api_key": "test_api_key_123", "base_url": base},
}
auth = FirecrawlAuth(credentials)
result = auth.validate_credentials()
assert result is True
assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl"
assert result is True
assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl"
@patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_should_handle_timeout_with_retry_suggestion(self, mock_post, auth_instance):

View File

@ -0,0 +1,71 @@
from unittest.mock import MagicMock
import httpx
from models import Account
from services import app_dsl_service
from services.app_dsl_service import AppDslService, ImportMode, ImportStatus
def _build_response(url: str, status_code: int, content: bytes = b"") -> httpx.Response:
    """Build an httpx.Response for *url* carrying the given status code and body."""
    return httpx.Response(
        status_code=status_code,
        request=httpx.Request("GET", url),
        content=content,
    )
def _pending_yaml_content(version: str = "99.0.0") -> bytes:
return (f'version: "{version}"\nkind: app\napp:\n name: Loop Test\n mode: workflow\n').encode()
def _account_mock() -> MagicMock:
    """Build a MagicMock standing in for an Account bound to tenant "tenant-1"."""
    fake_account = MagicMock(spec=Account)
    fake_account.current_tenant_id = "tenant-1"
    return fake_account
def test_import_app_yaml_url_user_attachments_keeps_original_url(monkeypatch):
    """GitHub user-attachments URLs must fall back to the original URL when the
    raw.githubusercontent.com rewrite 404s."""
    yaml_url = "https://github.com/user-attachments/files/24290802/loop-test.yml"
    raw_url = "https://raw.githubusercontent.com/user-attachments/files/24290802/loop-test.yml"
    yaml_bytes = _pending_yaml_content()
    def fake_get(url: str, **kwargs):
        # Simulate the rewritten raw URL being missing; only the original works.
        if url == raw_url:
            return _build_response(url, status_code=404)
        assert url == yaml_url
        return _build_response(url, status_code=200, content=yaml_bytes)
    monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", fake_get)
    service = AppDslService(MagicMock())
    result = service.import_app(
        account=_account_mock(),
        import_mode=ImportMode.YAML_URL,
        yaml_url=yaml_url,
    )
    # Version "99.0.0" is ahead of the current DSL version, so the import stays pending.
    assert result.status == ImportStatus.PENDING
    assert result.imported_dsl_version == "99.0.0"
def test_import_app_yaml_url_github_blob_rewrites_to_raw(monkeypatch):
    """A github.com /blob/ URL must be rewritten to its raw.githubusercontent.com
    equivalent and fetched only once."""
    yaml_url = "https://github.com/acme/repo/blob/main/app.yml"
    raw_url = "https://raw.githubusercontent.com/acme/repo/main/app.yml"
    yaml_bytes = _pending_yaml_content()
    requested_urls: list[str] = []
    def fake_get(url: str, **kwargs):
        # Record every fetch so we can assert exactly which URL(s) were requested.
        requested_urls.append(url)
        assert url == raw_url
        return _build_response(url, status_code=200, content=yaml_bytes)
    monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", fake_get)
    service = AppDslService(MagicMock())
    result = service.import_app(
        account=_account_mock(),
        import_mode=ImportMode.YAML_URL,
        yaml_url=yaml_url,
    )
    assert result.status == ImportStatus.PENDING
    # Only the rewritten raw URL should have been hit — never the /blob/ URL.
    assert requested_urls == [raw_url]

View File

@ -1156,6 +1156,235 @@ class TestBillingServiceEdgeCases:
assert "Only team owner or team admin can perform this action" in str(exc_info.value)
class TestBillingServiceSubscriptionOperations:
    """Unit tests for subscription operations in BillingService.

    Tests cover:
    - Bulk plan retrieval with chunking
    - Expired subscription cleanup whitelist retrieval

    NOTE(review): these tests assume get_plan_bulk batches requests in chunks
    of 200 tenants against POST /subscription/plan/batch and swallows per-batch
    errors — confirm against the BillingService implementation.
    """

    @pytest.fixture
    def mock_send_request(self):
        """Mock _send_request method."""
        with patch.object(BillingService, "_send_request") as mock:
            yield mock

    def test_get_plan_bulk_with_empty_list(self, mock_send_request):
        """Test bulk plan retrieval with empty tenant list."""
        # Arrange
        tenant_ids = []
        # Act
        result = BillingService.get_plan_bulk(tenant_ids)
        # Assert — no tenants means no API traffic at all.
        assert result == {}
        mock_send_request.assert_not_called()

    def test_get_plan_bulk_with_chunking(self, mock_send_request):
        """Test bulk plan retrieval with more than 200 tenants (chunking logic)."""
        # Arrange - 250 tenants to test chunking (chunk_size = 200)
        tenant_ids = [f"tenant-{i}" for i in range(250)]
        # First chunk: tenants 0-199
        first_chunk_response = {
            "data": {f"tenant-{i}": {"plan": "sandbox", "expiration_date": 1735689600} for i in range(200)}
        }
        # Second chunk: tenants 200-249
        second_chunk_response = {
            "data": {f"tenant-{i}": {"plan": "professional", "expiration_date": 1767225600} for i in range(200, 250)}
        }
        mock_send_request.side_effect = [first_chunk_response, second_chunk_response]
        # Act
        result = BillingService.get_plan_bulk(tenant_ids)
        # Assert — results from both chunks are merged into one mapping.
        assert len(result) == 250
        assert result["tenant-0"]["plan"] == "sandbox"
        assert result["tenant-199"]["plan"] == "sandbox"
        assert result["tenant-200"]["plan"] == "professional"
        assert result["tenant-249"]["plan"] == "professional"
        assert mock_send_request.call_count == 2
        # Verify first chunk call
        first_call = mock_send_request.call_args_list[0]
        assert first_call[0][0] == "POST"
        assert first_call[0][1] == "/subscription/plan/batch"
        assert len(first_call[1]["json"]["tenant_ids"]) == 200
        # Verify second chunk call carries only the 50 remaining tenants
        second_call = mock_send_request.call_args_list[1]
        assert len(second_call[1]["json"]["tenant_ids"]) == 50

    def test_get_plan_bulk_with_partial_batch_failure(self, mock_send_request):
        """Test bulk plan retrieval when one batch fails but others succeed."""
        # Arrange - 250 tenants, second batch will fail
        tenant_ids = [f"tenant-{i}" for i in range(250)]
        # First chunk succeeds
        first_chunk_response = {
            "data": {f"tenant-{i}": {"plan": "sandbox", "expiration_date": 1735689600} for i in range(200)}
        }
        # Second chunk fails - need to create a mock that raises when called
        def side_effect_func(*args, **kwargs):
            if mock_send_request.call_count == 1:
                return first_chunk_response
            else:
                raise ValueError("API error")
        mock_send_request.side_effect = side_effect_func
        # Act
        result = BillingService.get_plan_bulk(tenant_ids)
        # Assert - should only have data from first batch; the failing
        # batch is dropped rather than aborting the whole call.
        assert len(result) == 200
        assert result["tenant-0"]["plan"] == "sandbox"
        assert result["tenant-199"]["plan"] == "sandbox"
        assert "tenant-200" not in result
        assert mock_send_request.call_count == 2

    def test_get_plan_bulk_with_all_batches_failing(self, mock_send_request):
        """Test bulk plan retrieval when all batches fail."""
        # Arrange
        tenant_ids = [f"tenant-{i}" for i in range(250)]
        # All chunks fail
        def side_effect_func(*args, **kwargs):
            raise ValueError("API error")
        mock_send_request.side_effect = side_effect_func
        # Act
        result = BillingService.get_plan_bulk(tenant_ids)
        # Assert - should return empty dict, but both chunks were still attempted.
        assert result == {}
        assert mock_send_request.call_count == 2

    def test_get_plan_bulk_with_exactly_200_tenants(self, mock_send_request):
        """Test bulk plan retrieval with exactly 200 tenants (boundary condition)."""
        # Arrange — exactly one full chunk; no second request expected.
        tenant_ids = [f"tenant-{i}" for i in range(200)]
        mock_send_request.return_value = {
            "data": {f"tenant-{i}": {"plan": "sandbox", "expiration_date": 1735689600} for i in range(200)}
        }
        # Act
        result = BillingService.get_plan_bulk(tenant_ids)
        # Assert
        assert len(result) == 200
        assert mock_send_request.call_count == 1

    def test_get_plan_bulk_with_empty_data_response(self, mock_send_request):
        """Test bulk plan retrieval with empty data in response."""
        # Arrange
        tenant_ids = ["tenant-1", "tenant-2"]
        mock_send_request.return_value = {"data": {}}
        # Act
        result = BillingService.get_plan_bulk(tenant_ids)
        # Assert
        assert result == {}

    def test_get_plan_bulk_with_invalid_tenant_plan_skipped(self, mock_send_request):
        """Test bulk plan retrieval when one tenant has invalid plan data (should skip that tenant)."""
        # Arrange
        tenant_ids = ["tenant-valid-1", "tenant-invalid", "tenant-valid-2"]
        # Response with one invalid tenant plan (missing expiration_date) and two valid ones
        mock_send_request.return_value = {
            "data": {
                "tenant-valid-1": {"plan": "sandbox", "expiration_date": 1735689600},
                "tenant-invalid": {"plan": "professional"},  # Missing expiration_date field
                "tenant-valid-2": {"plan": "team", "expiration_date": 1767225600},
            }
        }
        # Act
        with patch("services.billing_service.logger") as mock_logger:
            result = BillingService.get_plan_bulk(tenant_ids)
        # Assert - should only contain valid tenants
        assert len(result) == 2
        assert "tenant-valid-1" in result
        assert "tenant-valid-2" in result
        assert "tenant-invalid" not in result
        # Verify valid tenants have correct data
        assert result["tenant-valid-1"]["plan"] == "sandbox"
        assert result["tenant-valid-1"]["expiration_date"] == 1735689600
        assert result["tenant-valid-2"]["plan"] == "team"
        assert result["tenant-valid-2"]["expiration_date"] == 1767225600
        # Verify exception was logged for the invalid tenant
        # (logger.exception is called with a %s-style format string plus args).
        mock_logger.exception.assert_called_once()
        log_call_args = mock_logger.exception.call_args[0]
        assert "get_plan_bulk: failed to validate subscription plan for tenant" in log_call_args[0]
        assert "tenant-invalid" in log_call_args[1]

    def test_get_expired_subscription_cleanup_whitelist_success(self, mock_send_request):
        """Test successful retrieval of expired subscription cleanup whitelist."""
        # Arrange — the API returns full whitelist records; only tenant_id survives.
        api_response = [
            {
                "created_at": "2025-10-16T01:56:17",
                "tenant_id": "36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe6",
                "contact": "example@dify.ai",
                "id": "36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe5",
                "expired_at": "2026-01-01T01:56:17",
                "updated_at": "2025-10-16T01:56:17",
            },
            {
                "created_at": "2025-10-16T02:00:00",
                "tenant_id": "tenant-2",
                "contact": "test@example.com",
                "id": "whitelist-id-2",
                "expired_at": "2026-02-01T00:00:00",
                "updated_at": "2025-10-16T02:00:00",
            },
            {
                "created_at": "2025-10-16T03:00:00",
                "tenant_id": "tenant-3",
                "contact": "another@example.com",
                "id": "whitelist-id-3",
                "expired_at": "2026-03-01T00:00:00",
                "updated_at": "2025-10-16T03:00:00",
            },
        ]
        mock_send_request.return_value = {"data": api_response}
        # Act
        result = BillingService.get_expired_subscription_cleanup_whitelist()
        # Assert - should return only tenant_ids, preserving API order
        assert result == ["36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe6", "tenant-2", "tenant-3"]
        assert len(result) == 3
        assert result[0] == "36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe6"
        assert result[1] == "tenant-2"
        assert result[2] == "tenant-3"
        mock_send_request.assert_called_once_with("GET", "/subscription/cleanup/whitelist")

    def test_get_expired_subscription_cleanup_whitelist_empty_list(self, mock_send_request):
        """Test retrieval of empty cleanup whitelist."""
        # Arrange
        mock_send_request.return_value = {"data": []}
        # Act
        result = BillingService.get_expired_subscription_cleanup_whitelist()
        # Assert
        assert result == []
        assert len(result) == 0
class TestBillingServiceIntegrationScenarios:
"""Integration-style tests simulating real-world usage scenarios.

View File

@ -0,0 +1,472 @@
"""
Unit tests for SegmentService.get_segments method.
Tests the retrieval of document segments with pagination and filtering:
- Basic pagination (page, limit)
- Status filtering
- Keyword search
- Ordering by position and id (to avoid duplicate data)
"""
from unittest.mock import Mock, create_autospec, patch
import pytest
from models.dataset import DocumentSegment
class SegmentServiceTestDataFactory:
    """Factory for building DocumentSegment mocks used by the segment tests."""

    @staticmethod
    def create_segment_mock(
        segment_id: str = "segment-123",
        document_id: str = "doc-123",
        tenant_id: str = "tenant-123",
        dataset_id: str = "dataset-123",
        position: int = 1,
        content: str = "Test content",
        status: str = "completed",
        **kwargs,
    ) -> Mock:
        """
        Create a mock document segment.

        Args:
            segment_id: Unique identifier for the segment
            document_id: Parent document ID
            tenant_id: Tenant ID the segment belongs to
            dataset_id: Parent dataset ID
            position: Position within the document
            content: Segment text content
            status: Indexing status
            **kwargs: Additional attributes to set on the mock

        Returns:
            Mock: DocumentSegment mock object
        """
        mock_segment = create_autospec(DocumentSegment, instance=True)
        # Apply the named defaults first, then any extra attributes; kwargs
        # win if they repeat a named attribute (same precedence as before).
        attributes = {
            "id": segment_id,
            "document_id": document_id,
            "tenant_id": tenant_id,
            "dataset_id": dataset_id,
            "position": position,
            "content": content,
            "status": status,
        }
        attributes.update(kwargs)
        for attr_name, attr_value in attributes.items():
            setattr(mock_segment, attr_name, attr_value)
        return mock_segment
class TestSegmentServiceGetSegments:
    """
    Comprehensive unit tests for SegmentService.get_segments method.

    Tests cover:
    - Basic pagination functionality
    - Status list filtering
    - Keyword search filtering
    - Ordering (position + id for uniqueness)
    - Empty results
    - Combined filters

    NOTE(review): the ``where.call_count`` assertions assume get_segments issues
    one combined base filter plus one extra .where() per optional filter —
    confirm against services.dataset_service.SegmentService.
    """

    @pytest.fixture
    def mock_segment_service_dependencies(self):
        """
        Common mock setup for segment service dependencies.

        Patches:
        - db: Database operations and pagination
        - select: SQLAlchemy query builder
        """
        with (
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.select") as mock_select,
        ):
            yield {
                "db": mock_db,
                "select": mock_select,
            }

    def test_get_segments_basic_pagination(self, mock_segment_service_dependencies):
        """
        Test basic pagination functionality.

        Verifies:
        - Query is built with document_id and tenant_id filters
        - Pagination uses correct page and limit parameters
        - Returns segments and total count
        """
        # Arrange
        document_id = "doc-123"
        tenant_id = "tenant-123"
        page = 1
        limit = 20
        # Create mock segments
        segment1 = SegmentServiceTestDataFactory.create_segment_mock(
            segment_id="seg-1", position=1, content="First segment"
        )
        segment2 = SegmentServiceTestDataFactory.create_segment_mock(
            segment_id="seg-2", position=2, content="Second segment"
        )
        # Mock pagination result
        mock_paginated = Mock()
        mock_paginated.items = [segment1, segment2]
        mock_paginated.total = 2
        mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
        # Mock select builder — every chained call returns the same query mock.
        mock_query = Mock()
        mock_segment_service_dependencies["select"].return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        # Act
        from services.dataset_service import SegmentService
        items, total = SegmentService.get_segments(document_id=document_id, tenant_id=tenant_id, page=page, limit=limit)
        # Assert
        assert len(items) == 2
        assert total == 2
        assert items[0].id == "seg-1"
        assert items[1].id == "seg-2"
        mock_segment_service_dependencies["db"].paginate.assert_called_once()
        call_kwargs = mock_segment_service_dependencies["db"].paginate.call_args[1]
        assert call_kwargs["page"] == page
        assert call_kwargs["per_page"] == limit
        assert call_kwargs["max_per_page"] == 100
        assert call_kwargs["error_out"] is False

    def test_get_segments_with_status_filter(self, mock_segment_service_dependencies):
        """
        Test filtering by status list.

        Verifies:
        - Status list filter is applied to query
        - Only segments with matching status are returned
        """
        # Arrange
        document_id = "doc-123"
        tenant_id = "tenant-123"
        status_list = ["completed", "indexing"]
        segment1 = SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-1", status="completed")
        segment2 = SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-2", status="indexing")
        mock_paginated = Mock()
        mock_paginated.items = [segment1, segment2]
        mock_paginated.total = 2
        mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
        mock_query = Mock()
        mock_segment_service_dependencies["select"].return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        # Act
        from services.dataset_service import SegmentService
        items, total = SegmentService.get_segments(
            document_id=document_id, tenant_id=tenant_id, status_list=status_list
        )
        # Assert
        assert len(items) == 2
        assert total == 2
        # Verify where was called multiple times (base filters + status filter)
        assert mock_query.where.call_count >= 2

    def test_get_segments_with_empty_status_list(self, mock_segment_service_dependencies):
        """
        Test with empty status list.

        Verifies:
        - Empty status list is handled correctly
        - No status filter is applied to avoid WHERE false condition
        """
        # Arrange
        document_id = "doc-123"
        tenant_id = "tenant-123"
        status_list = []
        segment = SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-1")
        mock_paginated = Mock()
        mock_paginated.items = [segment]
        mock_paginated.total = 1
        mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
        mock_query = Mock()
        mock_segment_service_dependencies["select"].return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        # Act
        from services.dataset_service import SegmentService
        items, total = SegmentService.get_segments(
            document_id=document_id, tenant_id=tenant_id, status_list=status_list
        )
        # Assert
        assert len(items) == 1
        assert total == 1
        # Should only be called once (base filters, no status filter)
        assert mock_query.where.call_count == 1

    def test_get_segments_with_keyword_search(self, mock_segment_service_dependencies):
        """
        Test keyword search functionality.

        Verifies:
        - Keyword filter uses ilike for case-insensitive search
        - Search pattern includes wildcards (%keyword%)
        """
        # Arrange
        document_id = "doc-123"
        tenant_id = "tenant-123"
        keyword = "search term"
        segment = SegmentServiceTestDataFactory.create_segment_mock(
            segment_id="seg-1", content="This contains search term"
        )
        mock_paginated = Mock()
        mock_paginated.items = [segment]
        mock_paginated.total = 1
        mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
        mock_query = Mock()
        mock_segment_service_dependencies["select"].return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        # Act
        from services.dataset_service import SegmentService
        items, total = SegmentService.get_segments(document_id=document_id, tenant_id=tenant_id, keyword=keyword)
        # Assert
        assert len(items) == 1
        assert total == 1
        # Verify where was called for base filters + keyword filter
        assert mock_query.where.call_count == 2

    def test_get_segments_ordering_by_position_and_id(self, mock_segment_service_dependencies):
        """
        Test ordering by position and id.

        Verifies:
        - Results are ordered by position ASC
        - Results are secondarily ordered by id ASC to ensure uniqueness
        - This prevents duplicate data across pages when positions are not unique
        """
        # Arrange
        document_id = "doc-123"
        tenant_id = "tenant-123"
        # Create segments with same position but different ids
        segment1 = SegmentServiceTestDataFactory.create_segment_mock(
            segment_id="seg-1", position=1, content="Content 1"
        )
        segment2 = SegmentServiceTestDataFactory.create_segment_mock(
            segment_id="seg-2", position=1, content="Content 2"
        )
        segment3 = SegmentServiceTestDataFactory.create_segment_mock(
            segment_id="seg-3", position=2, content="Content 3"
        )
        mock_paginated = Mock()
        mock_paginated.items = [segment1, segment2, segment3]
        mock_paginated.total = 3
        mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
        mock_query = Mock()
        mock_segment_service_dependencies["select"].return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        # Act
        from services.dataset_service import SegmentService
        items, total = SegmentService.get_segments(document_id=document_id, tenant_id=tenant_id)
        # Assert — a single order_by call is expected to carry both sort keys.
        assert len(items) == 3
        assert total == 3
        mock_query.order_by.assert_called_once()

    def test_get_segments_empty_results(self, mock_segment_service_dependencies):
        """
        Test when no segments match the criteria.

        Verifies:
        - Empty list is returned for items
        - Total count is 0
        """
        # Arrange
        document_id = "non-existent-doc"
        tenant_id = "tenant-123"
        mock_paginated = Mock()
        mock_paginated.items = []
        mock_paginated.total = 0
        mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
        mock_query = Mock()
        mock_segment_service_dependencies["select"].return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        # Act
        from services.dataset_service import SegmentService
        items, total = SegmentService.get_segments(document_id=document_id, tenant_id=tenant_id)
        # Assert
        assert items == []
        assert total == 0

    def test_get_segments_combined_filters(self, mock_segment_service_dependencies):
        """
        Test with multiple filters combined.

        Verifies:
        - All filters work together correctly
        - Status list and keyword search both applied
        """
        # Arrange
        document_id = "doc-123"
        tenant_id = "tenant-123"
        status_list = ["completed"]
        keyword = "important"
        page = 2
        limit = 10
        segment = SegmentServiceTestDataFactory.create_segment_mock(
            segment_id="seg-1",
            status="completed",
            content="This is important information",
        )
        mock_paginated = Mock()
        mock_paginated.items = [segment]
        mock_paginated.total = 1
        mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
        mock_query = Mock()
        mock_segment_service_dependencies["select"].return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        # Act
        from services.dataset_service import SegmentService
        items, total = SegmentService.get_segments(
            document_id=document_id,
            tenant_id=tenant_id,
            status_list=status_list,
            keyword=keyword,
            page=page,
            limit=limit,
        )
        # Assert
        assert len(items) == 1
        assert total == 1
        # Verify filters: base + status + keyword
        assert mock_query.where.call_count == 3
        # Verify pagination parameters
        call_kwargs = mock_segment_service_dependencies["db"].paginate.call_args[1]
        assert call_kwargs["page"] == page
        assert call_kwargs["per_page"] == limit

    def test_get_segments_with_none_status_list(self, mock_segment_service_dependencies):
        """
        Test with None status list.

        Verifies:
        - None status list is handled correctly
        - No status filter is applied
        """
        # Arrange
        document_id = "doc-123"
        tenant_id = "tenant-123"
        segment = SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-1")
        mock_paginated = Mock()
        mock_paginated.items = [segment]
        mock_paginated.total = 1
        mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
        mock_query = Mock()
        mock_segment_service_dependencies["select"].return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        # Act
        from services.dataset_service import SegmentService
        items, total = SegmentService.get_segments(
            document_id=document_id,
            tenant_id=tenant_id,
            status_list=None,
        )
        # Assert
        assert len(items) == 1
        assert total == 1
        # Should only be called once (base filters only, no status filter)
        assert mock_query.where.call_count == 1

    def test_get_segments_pagination_max_per_page_limit(self, mock_segment_service_dependencies):
        """
        Test that max_per_page is correctly set to 100.

        Verifies:
        - max_per_page parameter is set to 100
        - This prevents excessive page sizes
        """
        # Arrange
        document_id = "doc-123"
        tenant_id = "tenant-123"
        limit = 200  # Request more than max_per_page
        mock_paginated = Mock()
        mock_paginated.items = []
        mock_paginated.total = 0
        mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
        mock_query = Mock()
        mock_segment_service_dependencies["select"].return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        # Act
        from services.dataset_service import SegmentService
        SegmentService.get_segments(
            document_id=document_id,
            tenant_id=tenant_id,
            limit=limit,
        )
        # Assert — the clamp is enforced by Flask-SQLAlchemy's paginate, not by us.
        call_kwargs = mock_segment_service_dependencies["db"].paginate.call_args[1]
        assert call_kwargs["max_per_page"] == 100

View File

@ -0,0 +1,520 @@
"""
Unit tests for document indexing sync task.
This module tests the document indexing sync task functionality including:
- Syncing Notion documents when updated
- Validating document and data source existence
- Credential validation and retrieval
- Cleaning old segments before re-indexing
- Error handling and edge cases
"""
import uuid
from unittest.mock import MagicMock, Mock, patch
import pytest
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from models.dataset import Dataset, Document, DocumentSegment
from tasks.document_indexing_sync_task import document_indexing_sync_task
# ============================================================================
# Fixtures
# ============================================================================
@pytest.fixture
def tenant_id():
    """Fresh random tenant ID (UUID string) per test."""
    return f"{uuid.uuid4()}"


@pytest.fixture
def dataset_id():
    """Fresh random dataset ID (UUID string) per test."""
    return f"{uuid.uuid4()}"


@pytest.fixture
def document_id():
    """Fresh random document ID (UUID string) per test."""
    return f"{uuid.uuid4()}"


@pytest.fixture
def notion_workspace_id():
    """Fresh random Notion workspace ID (UUID string) per test."""
    return f"{uuid.uuid4()}"


@pytest.fixture
def notion_page_id():
    """Fresh random Notion page ID (UUID string) per test."""
    return f"{uuid.uuid4()}"


@pytest.fixture
def credential_id():
    """Fresh random datasource credential ID (UUID string) per test."""
    return f"{uuid.uuid4()}"
@pytest.fixture
def mock_dataset(dataset_id, tenant_id):
    """Mock Dataset configured for high-quality (embedding-based) indexing."""
    dataset = Mock(spec=Dataset)
    dataset.configure_mock(
        id=dataset_id,
        tenant_id=tenant_id,
        indexing_technique="high_quality",
        embedding_model_provider="openai",
        embedding_model="text-embedding-ada-002",
    )
    return dataset
@pytest.fixture
def mock_document(document_id, dataset_id, tenant_id, notion_workspace_id, notion_page_id, credential_id):
    """Mock Document wired up as a completed Notion page import."""
    document = Mock(spec=Document)
    document.configure_mock(
        id=document_id,
        dataset_id=dataset_id,
        tenant_id=tenant_id,
        data_source_type="notion_import",
        indexing_status="completed",
        error=None,
        stopped_at=None,
        processing_started_at=None,
        doc_form="text_model",
        # Shape mirrors what the sync task reads from data_source_info.
        data_source_info_dict={
            "notion_workspace_id": notion_workspace_id,
            "notion_page_id": notion_page_id,
            "type": "page",
            "last_edited_time": "2024-01-01T00:00:00Z",
            "credential_id": credential_id,
        },
    )
    return document
@pytest.fixture
def mock_document_segments(document_id):
    """Three mock DocumentSegment objects belonging to the document."""

    def _build_segment(index):
        seg = Mock(spec=DocumentSegment)
        seg.id = str(uuid.uuid4())
        seg.document_id = document_id
        seg.index_node_id = f"node-{document_id}-{index}"
        return seg

    return [_build_segment(i) for i in range(3)]
@pytest.fixture
def mock_db_session():
    """Patch the task module's db.session with a permissive MagicMock chain."""
    with patch("tasks.document_indexing_sync_task.db.session") as session:
        query = MagicMock()
        session.query.return_value = query
        # query.where(...) returns the same mock so chained calls keep working.
        query.where.return_value = query
        session.scalars.return_value = MagicMock()
        yield session
@pytest.fixture
def mock_datasource_provider_service():
    """Patch DatasourceProviderService to hand back a valid integration secret."""
    with patch("tasks.document_indexing_sync_task.DatasourceProviderService") as service_cls:
        service = MagicMock()
        service.get_datasource_credentials.return_value = {"integration_secret": "test_token"}
        service_cls.return_value = service
        yield service
@pytest.fixture
def mock_notion_extractor():
    """Patch NotionExtractor; by default the remote page looks updated."""
    with patch("tasks.document_indexing_sync_task.NotionExtractor") as extractor_cls:
        extractor = MagicMock()
        # Default to a timestamp newer than the document's stored one.
        extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"  # Updated time
        extractor_cls.return_value = extractor
        yield extractor
@pytest.fixture
def mock_index_processor_factory():
    """Patch IndexProcessorFactory with a processor exposing a clean() mock."""
    with patch("tasks.document_indexing_sync_task.IndexProcessorFactory") as factory:
        processor = MagicMock()
        processor.clean = Mock()
        factory.return_value.init_index_processor.return_value = processor
        yield factory
@pytest.fixture
def mock_indexing_runner():
    """Patch IndexingRunner with a spec'd MagicMock whose run() is trackable."""
    with patch("tasks.document_indexing_sync_task.IndexingRunner") as runner_cls:
        runner = MagicMock(spec=IndexingRunner)
        runner.run = Mock()
        runner_cls.return_value = runner
        yield runner
# ============================================================================
# Tests for document_indexing_sync_task
# ============================================================================
class TestDocumentIndexingSyncTask:
    """Tests for the document_indexing_sync_task function.

    The task re-indexes a Notion-imported document when the remote page's
    last-edited time differs from the one stored in data_source_info.
    These tests mock the DB session, credential service, Notion extractor,
    index processor factory and indexing runner.

    NOTE(review): several tests feed ``first.side_effect = [mock_document,
    mock_dataset]`` — this relies on the task querying Document first and
    Dataset second, in that exact order.
    """

    def test_document_not_found(self, mock_db_session, dataset_id, document_id):
        """Test that task handles document not found gracefully."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.return_value = None
        # Act
        document_indexing_sync_task(dataset_id, document_id)
        # Assert: the session is still closed even though nothing was synced.
        mock_db_session.close.assert_called_once()

    def test_missing_notion_workspace_id(self, mock_db_session, mock_document, dataset_id, document_id):
        """Test that task raises error when notion_workspace_id is missing."""
        # Arrange
        mock_document.data_source_info_dict = {"notion_page_id": "page123", "type": "page"}
        mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
        # Act & Assert
        with pytest.raises(ValueError, match="no notion page found"):
            document_indexing_sync_task(dataset_id, document_id)

    def test_missing_notion_page_id(self, mock_db_session, mock_document, dataset_id, document_id):
        """Test that task raises error when notion_page_id is missing."""
        # Arrange
        mock_document.data_source_info_dict = {"notion_workspace_id": "ws123", "type": "page"}
        mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
        # Act & Assert
        with pytest.raises(ValueError, match="no notion page found"):
            document_indexing_sync_task(dataset_id, document_id)

    def test_empty_data_source_info(self, mock_db_session, mock_document, dataset_id, document_id):
        """Test that task raises error when data_source_info is empty."""
        # Arrange
        mock_document.data_source_info_dict = None
        mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
        # Act & Assert
        with pytest.raises(ValueError, match="no notion page found"):
            document_indexing_sync_task(dataset_id, document_id)

    def test_credential_not_found(
        self,
        mock_db_session,
        mock_datasource_provider_service,
        mock_document,
        dataset_id,
        document_id,
    ):
        """Test that task handles missing credentials by updating document status."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
        mock_datasource_provider_service.get_datasource_credentials.return_value = None
        # Act
        document_indexing_sync_task(dataset_id, document_id)
        # Assert: the document is flagged as errored instead of raising.
        assert mock_document.indexing_status == "error"
        assert "Datasource credential not found" in mock_document.error
        assert mock_document.stopped_at is not None
        mock_db_session.commit.assert_called()
        mock_db_session.close.assert_called()

    def test_page_not_updated(
        self,
        mock_db_session,
        mock_datasource_provider_service,
        mock_notion_extractor,
        mock_document,
        dataset_id,
        document_id,
    ):
        """Test that task does nothing when page has not been updated."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
        # Return same time as stored in document
        mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
        # Act
        document_indexing_sync_task(dataset_id, document_id)
        # Assert
        # Document status should remain unchanged
        assert mock_document.indexing_status == "completed"
        # No session operations should be performed beyond the initial query
        mock_db_session.close.assert_not_called()

    def test_successful_sync_when_page_updated(
        self,
        mock_db_session,
        mock_datasource_provider_service,
        mock_notion_extractor,
        mock_index_processor_factory,
        mock_indexing_runner,
        mock_dataset,
        mock_document,
        mock_document_segments,
        dataset_id,
        document_id,
    ):
        """Test successful sync flow when Notion page has been updated."""
        # Arrange: first() returns the document, then the dataset (query order).
        mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
        mock_db_session.scalars.return_value.all.return_value = mock_document_segments
        # NotionExtractor returns updated time
        mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
        # Act
        document_indexing_sync_task(dataset_id, document_id)
        # Assert
        # Verify document status was updated to parsing
        assert mock_document.indexing_status == "parsing"
        assert mock_document.processing_started_at is not None
        # Verify segments were cleaned
        mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
        mock_processor.clean.assert_called_once()
        # Verify segments were deleted from database
        for segment in mock_document_segments:
            mock_db_session.delete.assert_any_call(segment)
        # Verify indexing runner was called
        mock_indexing_runner.run.assert_called_once_with([mock_document])
        # Verify session operations
        assert mock_db_session.commit.called
        mock_db_session.close.assert_called_once()

    def test_dataset_not_found_during_cleaning(
        self,
        mock_db_session,
        mock_datasource_provider_service,
        mock_notion_extractor,
        mock_document,
        dataset_id,
        document_id,
    ):
        """Test that task handles dataset not found during cleaning phase."""
        # Arrange: document lookup succeeds, dataset lookup returns None.
        mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, None]
        mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
        # Act
        document_indexing_sync_task(dataset_id, document_id)
        # Assert
        # Document should still be set to parsing
        assert mock_document.indexing_status == "parsing"
        # Session should be closed after error
        mock_db_session.close.assert_called_once()

    def test_cleaning_error_continues_to_indexing(
        self,
        mock_db_session,
        mock_datasource_provider_service,
        mock_notion_extractor,
        mock_index_processor_factory,
        mock_indexing_runner,
        mock_dataset,
        mock_document,
        dataset_id,
        document_id,
    ):
        """Test that indexing continues even if cleaning fails."""
        # Arrange: segment lookup blows up during the cleaning phase.
        mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
        mock_db_session.scalars.return_value.all.side_effect = Exception("Cleaning error")
        mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
        # Act
        document_indexing_sync_task(dataset_id, document_id)
        # Assert
        # Indexing should still be attempted despite cleaning error
        mock_indexing_runner.run.assert_called_once_with([mock_document])
        mock_db_session.close.assert_called_once()

    def test_indexing_runner_document_paused_error(
        self,
        mock_db_session,
        mock_datasource_provider_service,
        mock_notion_extractor,
        mock_index_processor_factory,
        mock_indexing_runner,
        mock_dataset,
        mock_document,
        mock_document_segments,
        dataset_id,
        document_id,
    ):
        """Test that DocumentIsPausedError is handled gracefully."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
        mock_db_session.scalars.return_value.all.return_value = mock_document_segments
        mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
        mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused")
        # Act: must not propagate the pause error.
        document_indexing_sync_task(dataset_id, document_id)
        # Assert
        # Session should be closed after handling error
        mock_db_session.close.assert_called_once()

    def test_indexing_runner_general_error(
        self,
        mock_db_session,
        mock_datasource_provider_service,
        mock_notion_extractor,
        mock_index_processor_factory,
        mock_indexing_runner,
        mock_dataset,
        mock_document,
        mock_document_segments,
        dataset_id,
        document_id,
    ):
        """Test that general exceptions during indexing are handled."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
        mock_db_session.scalars.return_value.all.return_value = mock_document_segments
        mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
        mock_indexing_runner.run.side_effect = Exception("Indexing error")
        # Act: must not propagate the generic error.
        document_indexing_sync_task(dataset_id, document_id)
        # Assert
        # Session should be closed after error
        mock_db_session.close.assert_called_once()

    def test_notion_extractor_initialized_with_correct_params(
        self,
        mock_db_session,
        mock_datasource_provider_service,
        mock_notion_extractor,
        mock_document,
        dataset_id,
        document_id,
        notion_workspace_id,
        notion_page_id,
    ):
        """Test that NotionExtractor is initialized with correct parameters."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
        mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"  # No update
        # Act: re-patch locally so the constructor call itself can be inspected.
        with patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_extractor_class:
            mock_extractor = MagicMock()
            mock_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
            mock_extractor_class.return_value = mock_extractor
            document_indexing_sync_task(dataset_id, document_id)
            # Assert
            mock_extractor_class.assert_called_once_with(
                notion_workspace_id=notion_workspace_id,
                notion_obj_id=notion_page_id,
                notion_page_type="page",
                notion_access_token="test_token",
                tenant_id=mock_document.tenant_id,
            )

    def test_datasource_credentials_requested_correctly(
        self,
        mock_db_session,
        mock_datasource_provider_service,
        mock_notion_extractor,
        mock_document,
        dataset_id,
        document_id,
        credential_id,
    ):
        """Test that datasource credentials are requested with correct parameters."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
        mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
        # Act
        document_indexing_sync_task(dataset_id, document_id)
        # Assert
        mock_datasource_provider_service.get_datasource_credentials.assert_called_once_with(
            tenant_id=mock_document.tenant_id,
            credential_id=credential_id,
            provider="notion_datasource",
            plugin_id="langgenius/notion_datasource",
        )

    def test_credential_id_missing_uses_none(
        self,
        mock_db_session,
        mock_datasource_provider_service,
        mock_notion_extractor,
        mock_document,
        dataset_id,
        document_id,
    ):
        """Test that task handles missing credential_id by passing None."""
        # Arrange: data_source_info without a credential_id key.
        mock_document.data_source_info_dict = {
            "notion_workspace_id": "ws123",
            "notion_page_id": "page123",
            "type": "page",
            "last_edited_time": "2024-01-01T00:00:00Z",
        }
        mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
        mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
        # Act
        document_indexing_sync_task(dataset_id, document_id)
        # Assert
        mock_datasource_provider_service.get_datasource_credentials.assert_called_once_with(
            tenant_id=mock_document.tenant_id,
            credential_id=None,
            provider="notion_datasource",
            plugin_id="langgenius/notion_datasource",
        )

    def test_index_processor_clean_called_with_correct_params(
        self,
        mock_db_session,
        mock_datasource_provider_service,
        mock_notion_extractor,
        mock_index_processor_factory,
        mock_indexing_runner,
        mock_dataset,
        mock_document,
        mock_document_segments,
        dataset_id,
        document_id,
    ):
        """Test that index processor clean is called with correct parameters."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
        mock_db_session.scalars.return_value.all.return_value = mock_document_segments
        mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
        # Act
        document_indexing_sync_task(dataset_id, document_id)
        # Assert: clean() receives the dataset plus all segment index-node IDs.
        mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
        expected_node_ids = [seg.index_node_id for seg in mock_document_segments]
        mock_processor.clean.assert_called_once_with(
            mock_dataset, expected_node_ids, with_keywords=True, delete_child_chunks=True
        )

View File

@ -0,0 +1,122 @@
import base64
from unittest.mock import Mock, patch
import pytest
from core.mcp.types import (
AudioContent,
BlobResourceContents,
CallToolResult,
EmbeddedResource,
ImageContent,
TextResourceContents,
)
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage
from core.tools.mcp_tool.tool import MCPTool
def _make_mcp_tool(output_schema: dict | None = None) -> MCPTool:
    """Build an MCPTool backed by a mocked runtime for invocation tests."""
    tool_identity = ToolIdentity(
        author="test",
        name="test_mcp_tool",
        label=I18nObject(en_US="Test MCP Tool", zh_Hans="测试MCP工具"),
        provider="test_provider",
    )
    tool_entity = ToolEntity(identity=tool_identity, output_schema=output_schema or {})
    mock_runtime = Mock(spec=ToolRuntime)
    mock_runtime.credentials = {}
    return MCPTool(
        entity=tool_entity,
        runtime=mock_runtime,
        tenant_id="test_tenant",
        icon="",
        server_url="https://server.invalid",
        provider_id="provider_1",
        headers={},
    )
class TestMCPToolInvoke:
    """Tests for MCPTool._invoke message conversion from MCP call results."""

    @pytest.mark.parametrize(
        ("content_factory", "mime_type"),
        [
            (
                lambda b64, mt: ImageContent(type="image", data=b64, mimeType=mt),
                "image/png",
            ),
            (
                lambda b64, mt: AudioContent(type="audio", data=b64, mimeType=mt),
                "audio/mpeg",
            ),
        ],
    )
    def test_invoke_image_or_audio_yields_blob(self, content_factory, mime_type) -> None:
        """Image/audio content should surface as one decoded BLOB message."""
        tool = _make_mcp_tool()
        payload = b"\x00\x01test-bytes\x02"
        encoded = base64.b64encode(payload).decode()
        call_result = CallToolResult(content=[content_factory(encoded, mime_type)])
        with patch.object(tool, "invoke_remote_mcp_tool", return_value=call_result):
            messages = list(tool._invoke(user_id="test_user", tool_parameters={}))
        assert len(messages) == 1
        blob_message = messages[0]
        assert blob_message.type == ToolInvokeMessage.MessageType.BLOB
        assert isinstance(blob_message.message, ToolInvokeMessage.BlobMessage)
        # The base64 payload must be decoded back to the original bytes.
        assert blob_message.message.blob == payload
        assert blob_message.meta == {"mime_type": mime_type}

    def test_invoke_embedded_text_resource_yields_text(self) -> None:
        """An embedded text resource should be forwarded as a TEXT message."""
        tool = _make_mcp_tool()
        resource = TextResourceContents(uri="file://test.txt", mimeType="text/plain", text="hello world")
        call_result = CallToolResult(content=[EmbeddedResource(type="resource", resource=resource)])
        with patch.object(tool, "invoke_remote_mcp_tool", return_value=call_result):
            messages = list(tool._invoke(user_id="test_user", tool_parameters={}))
        assert len(messages) == 1
        text_message = messages[0]
        assert text_message.type == ToolInvokeMessage.MessageType.TEXT
        assert isinstance(text_message.message, ToolInvokeMessage.TextMessage)
        assert text_message.message.text == "hello world"

    @pytest.mark.parametrize(
        ("mime_type", "expected_mime"),
        [("application/pdf", "application/pdf"), (None, "application/octet-stream")],
    )
    def test_invoke_embedded_blob_resource_yields_blob(self, mime_type, expected_mime) -> None:
        """Embedded blob resources decode to BLOBs; missing MIME falls back to octet-stream."""
        tool = _make_mcp_tool()
        payload = b"binary-data"
        encoded = base64.b64encode(payload).decode()
        resource = BlobResourceContents(uri="file://doc.bin", mimeType=mime_type, blob=encoded)
        call_result = CallToolResult(content=[EmbeddedResource(type="resource", resource=resource)])
        with patch.object(tool, "invoke_remote_mcp_tool", return_value=call_result):
            messages = list(tool._invoke(user_id="test_user", tool_parameters={}))
        assert len(messages) == 1
        blob_message = messages[0]
        assert blob_message.type == ToolInvokeMessage.MessageType.BLOB
        assert isinstance(blob_message.message, ToolInvokeMessage.BlobMessage)
        assert blob_message.message.blob == payload
        assert blob_message.meta == {"mime_type": expected_mime}

    def test_invoke_yields_variables_when_structured_content_and_schema(self) -> None:
        """With an output schema, each structuredContent key becomes a variable message."""
        tool = _make_mcp_tool(output_schema={"type": "object"})
        call_result = CallToolResult(content=[], structuredContent={"a": 1, "b": "x"})
        with patch.object(tool, "invoke_remote_mcp_tool", return_value=call_result):
            messages = list(tool._invoke(user_id="test_user", tool_parameters={}))
        # One variable message per structured-content key.
        assert len(messages) == 2
        variable_messages = [m for m in messages if isinstance(m.message, ToolInvokeMessage.VariableMessage)]
        assert {m.message.variable_name for m in variable_messages} == {"a", "b"}
        # Validate values
        extracted = {m.message.variable_name: m.message.variable_value for m in variable_messages}
        assert extracted == {"a": 1, "b": "x"}