mirror of
https://github.com/langgenius/dify.git
synced 2026-05-02 08:28:03 +08:00
Merge branch 'main' into feat/agent-node-v2
This commit is contained in:
@ -7,11 +7,14 @@ CODE_LANGUAGE = CodeLanguage.JINJA2
|
||||
|
||||
|
||||
def test_jinja2():
|
||||
"""Test basic Jinja2 template rendering."""
|
||||
template = "Hello {{template}}"
|
||||
# Template must be base64 encoded to match the new safe embedding approach
|
||||
template_b64 = base64.b64encode(template.encode("utf-8")).decode("utf-8")
|
||||
inputs = base64.b64encode(b'{"template": "World"}').decode("utf-8")
|
||||
code = (
|
||||
Jinja2TemplateTransformer.get_runner_script()
|
||||
.replace(Jinja2TemplateTransformer._code_placeholder, template)
|
||||
.replace(Jinja2TemplateTransformer._template_b64_placeholder, template_b64)
|
||||
.replace(Jinja2TemplateTransformer._inputs_placeholder, inputs)
|
||||
)
|
||||
result = CodeExecutor.execute_code(
|
||||
@ -21,6 +24,7 @@ def test_jinja2():
|
||||
|
||||
|
||||
def test_jinja2_with_code_template():
|
||||
"""Test template rendering via the high-level workflow API."""
|
||||
result = CodeExecutor.execute_workflow_code_template(
|
||||
language=CODE_LANGUAGE, code="Hello {{template}}", inputs={"template": "World"}
|
||||
)
|
||||
@ -28,7 +32,64 @@ def test_jinja2_with_code_template():
|
||||
|
||||
|
||||
def test_jinja2_get_runner_script():
|
||||
"""Test that runner script contains required placeholders."""
|
||||
runner_script = Jinja2TemplateTransformer.get_runner_script()
|
||||
assert runner_script.count(Jinja2TemplateTransformer._code_placeholder) == 1
|
||||
assert runner_script.count(Jinja2TemplateTransformer._template_b64_placeholder) == 1
|
||||
assert runner_script.count(Jinja2TemplateTransformer._inputs_placeholder) == 1
|
||||
assert runner_script.count(Jinja2TemplateTransformer._result_tag) == 2
|
||||
|
||||
|
||||
def test_jinja2_template_with_special_characters():
|
||||
"""
|
||||
Test that templates with special characters (quotes, newlines) render correctly.
|
||||
This is a regression test for issue #26818 where textarea pre-fill values
|
||||
containing special characters would break template rendering.
|
||||
"""
|
||||
# Template with triple quotes, single quotes, double quotes, and newlines
|
||||
template = """<html>
|
||||
<body>
|
||||
<input value="{{ task.get('Task ID', '') }}"/>
|
||||
<textarea>{{ task.get('Issues', 'No issues reported') }}</textarea>
|
||||
<p>Status: "{{ status }}"</p>
|
||||
<pre>'''code block'''</pre>
|
||||
</body>
|
||||
</html>"""
|
||||
inputs = {"task": {"Task ID": "TASK-123", "Issues": "Line 1\nLine 2\nLine 3"}, "status": "completed"}
|
||||
|
||||
result = CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code=template, inputs=inputs)
|
||||
|
||||
# Verify the template rendered correctly with all special characters
|
||||
output = result["result"]
|
||||
assert 'value="TASK-123"' in output
|
||||
assert "<textarea>Line 1\nLine 2\nLine 3</textarea>" in output
|
||||
assert 'Status: "completed"' in output
|
||||
assert "'''code block'''" in output
|
||||
|
||||
|
||||
def test_jinja2_template_with_html_textarea_prefill():
|
||||
"""
|
||||
Specific test for HTML textarea with Jinja2 variable pre-fill.
|
||||
Verifies fix for issue #26818.
|
||||
"""
|
||||
template = "<textarea name='notes'>{{ notes }}</textarea>"
|
||||
notes_content = "This is a multi-line note.\nWith special chars: 'single' and \"double\" quotes."
|
||||
inputs = {"notes": notes_content}
|
||||
|
||||
result = CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code=template, inputs=inputs)
|
||||
|
||||
expected_output = f"<textarea name='notes'>{notes_content}</textarea>"
|
||||
assert result["result"] == expected_output
|
||||
|
||||
|
||||
def test_jinja2_assemble_runner_script_encodes_template():
|
||||
"""Test that assemble_runner_script properly base64 encodes the template."""
|
||||
template = "Hello {{ name }}!"
|
||||
inputs = {"name": "World"}
|
||||
|
||||
script = Jinja2TemplateTransformer.assemble_runner_script(template, inputs)
|
||||
|
||||
# The template should be base64 encoded in the script
|
||||
template_b64 = base64.b64encode(template.encode("utf-8")).decode("utf-8")
|
||||
assert template_b64 in script
|
||||
# The raw template should NOT appear in the script (it's encoded)
|
||||
assert "Hello {{ name }}!" not in script
|
||||
|
||||
@ -6,6 +6,7 @@ import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.nodes.http_request.node import HttpRequestNode
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
@ -169,13 +170,14 @@ def test_custom_authorization_header(setup_http_mock):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
|
||||
def test_custom_auth_with_empty_api_key_does_not_set_header(setup_http_mock):
|
||||
"""Test: In custom authentication mode, when the api_key is empty, no header should be set."""
|
||||
def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock):
|
||||
"""Test: In custom authentication mode, when the api_key is empty, AuthorizationConfigError should be raised."""
|
||||
from core.workflow.nodes.http_request.entities import (
|
||||
HttpRequestNodeAuthorization,
|
||||
HttpRequestNodeData,
|
||||
HttpRequestNodeTimeout,
|
||||
)
|
||||
from core.workflow.nodes.http_request.exc import AuthorizationConfigError
|
||||
from core.workflow.nodes.http_request.executor import Executor
|
||||
from core.workflow.runtime import VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
@ -208,16 +210,13 @@ def test_custom_auth_with_empty_api_key_does_not_set_header(setup_http_mock):
|
||||
ssl_verify=True,
|
||||
)
|
||||
|
||||
# Create executor
|
||||
executor = Executor(
|
||||
node_data=node_data, timeout=HttpRequestNodeTimeout(connect=10, read=30, write=10), variable_pool=variable_pool
|
||||
)
|
||||
|
||||
# Get assembled headers
|
||||
headers = executor._assembling_headers()
|
||||
|
||||
# When api_key is empty, the custom header should NOT be set
|
||||
assert "X-Custom-Auth" not in headers
|
||||
# Create executor should raise AuthorizationConfigError
|
||||
with pytest.raises(AuthorizationConfigError, match="API key is required"):
|
||||
Executor(
|
||||
node_data=node_data,
|
||||
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=10),
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
|
||||
@ -305,9 +304,10 @@ def test_basic_authorization_with_custom_header_ignored(setup_http_mock):
|
||||
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
|
||||
def test_custom_authorization_with_empty_api_key(setup_http_mock):
|
||||
"""
|
||||
Test that custom authorization doesn't set header when api_key is empty.
|
||||
This test verifies the fix for issue #23554.
|
||||
Test that custom authorization raises error when api_key is empty.
|
||||
This test verifies the fix for issue #21830.
|
||||
"""
|
||||
|
||||
node = init_http_node(
|
||||
config={
|
||||
"id": "1",
|
||||
@ -333,11 +333,10 @@ def test_custom_authorization_with_empty_api_key(setup_http_mock):
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
assert result.process_data is not None
|
||||
data = result.process_data.get("request", "")
|
||||
|
||||
# Custom header should NOT be set when api_key is empty
|
||||
assert "X-Custom-Auth:" not in data
|
||||
# Should fail with AuthorizationConfigError
|
||||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert "API key is required" in result.error
|
||||
assert result.error_type == "AuthorizationConfigError"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
|
||||
|
||||
@ -0,0 +1,365 @@
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
from services.billing_service import BillingService
|
||||
|
||||
|
||||
class TestBillingServiceGetPlanBulkWithCache:
|
||||
"""
|
||||
Comprehensive integration tests for get_plan_bulk_with_cache using testcontainers.
|
||||
|
||||
This test class covers all major scenarios:
|
||||
- Cache hit/miss scenarios
|
||||
- Redis operation failures and fallback behavior
|
||||
- Invalid cache data handling
|
||||
- TTL expiration handling
|
||||
- Error recovery and logging
|
||||
"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_redis_cleanup(self, flask_app_with_containers):
|
||||
"""Clean up Redis cache before and after each test."""
|
||||
with flask_app_with_containers.app_context():
|
||||
# Clean up before test
|
||||
yield
|
||||
# Clean up after test
|
||||
# Delete all test cache keys
|
||||
pattern = f"{BillingService._PLAN_CACHE_KEY_PREFIX}*"
|
||||
keys = redis_client.keys(pattern)
|
||||
if keys:
|
||||
redis_client.delete(*keys)
|
||||
|
||||
def _create_test_plan_data(self, plan: str = "sandbox", expiration_date: int = 1735689600):
|
||||
"""Helper to create test SubscriptionPlan data."""
|
||||
return {"plan": plan, "expiration_date": expiration_date}
|
||||
|
||||
def _set_cache(self, tenant_id: str, plan_data: dict, ttl: int = 600):
|
||||
"""Helper to set cache data in Redis."""
|
||||
cache_key = BillingService._make_plan_cache_key(tenant_id)
|
||||
json_str = json.dumps(plan_data)
|
||||
redis_client.setex(cache_key, ttl, json_str)
|
||||
|
||||
def _get_cache(self, tenant_id: str):
|
||||
"""Helper to get cache data from Redis."""
|
||||
cache_key = BillingService._make_plan_cache_key(tenant_id)
|
||||
value = redis_client.get(cache_key)
|
||||
if value:
|
||||
if isinstance(value, bytes):
|
||||
return value.decode("utf-8")
|
||||
return value
|
||||
return None
|
||||
|
||||
def test_get_plan_bulk_with_cache_all_cache_hit(self, flask_app_with_containers):
|
||||
"""Test bulk plan retrieval when all tenants are in cache."""
|
||||
with flask_app_with_containers.app_context():
|
||||
# Arrange
|
||||
tenant_ids = ["tenant-1", "tenant-2", "tenant-3"]
|
||||
expected_plans = {
|
||||
"tenant-1": self._create_test_plan_data("sandbox", 1735689600),
|
||||
"tenant-2": self._create_test_plan_data("professional", 1767225600),
|
||||
"tenant-3": self._create_test_plan_data("team", 1798761600),
|
||||
}
|
||||
|
||||
# Pre-populate cache
|
||||
for tenant_id, plan_data in expected_plans.items():
|
||||
self._set_cache(tenant_id, plan_data)
|
||||
|
||||
# Act
|
||||
with patch.object(BillingService, "get_plan_bulk") as mock_get_plan_bulk:
|
||||
result = BillingService.get_plan_bulk_with_cache(tenant_ids)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 3
|
||||
assert result["tenant-1"]["plan"] == "sandbox"
|
||||
assert result["tenant-1"]["expiration_date"] == 1735689600
|
||||
assert result["tenant-2"]["plan"] == "professional"
|
||||
assert result["tenant-2"]["expiration_date"] == 1767225600
|
||||
assert result["tenant-3"]["plan"] == "team"
|
||||
assert result["tenant-3"]["expiration_date"] == 1798761600
|
||||
|
||||
# Verify API was not called
|
||||
mock_get_plan_bulk.assert_not_called()
|
||||
|
||||
def test_get_plan_bulk_with_cache_all_cache_miss(self, flask_app_with_containers):
|
||||
"""Test bulk plan retrieval when all tenants are not in cache."""
|
||||
with flask_app_with_containers.app_context():
|
||||
# Arrange
|
||||
tenant_ids = ["tenant-1", "tenant-2"]
|
||||
expected_plans = {
|
||||
"tenant-1": self._create_test_plan_data("sandbox", 1735689600),
|
||||
"tenant-2": self._create_test_plan_data("professional", 1767225600),
|
||||
}
|
||||
|
||||
# Act
|
||||
with patch.object(BillingService, "get_plan_bulk", return_value=expected_plans) as mock_get_plan_bulk:
|
||||
result = BillingService.get_plan_bulk_with_cache(tenant_ids)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert result["tenant-1"]["plan"] == "sandbox"
|
||||
assert result["tenant-2"]["plan"] == "professional"
|
||||
|
||||
# Verify API was called with correct tenant_ids
|
||||
mock_get_plan_bulk.assert_called_once_with(tenant_ids)
|
||||
|
||||
# Verify data was written to cache
|
||||
cached_1 = self._get_cache("tenant-1")
|
||||
cached_2 = self._get_cache("tenant-2")
|
||||
assert cached_1 is not None
|
||||
assert cached_2 is not None
|
||||
|
||||
# Verify cache content
|
||||
cached_data_1 = json.loads(cached_1)
|
||||
cached_data_2 = json.loads(cached_2)
|
||||
assert cached_data_1 == expected_plans["tenant-1"]
|
||||
assert cached_data_2 == expected_plans["tenant-2"]
|
||||
|
||||
# Verify TTL is set
|
||||
cache_key_1 = BillingService._make_plan_cache_key("tenant-1")
|
||||
ttl_1 = redis_client.ttl(cache_key_1)
|
||||
assert ttl_1 > 0
|
||||
assert ttl_1 <= 600 # Should be <= 600 seconds
|
||||
|
||||
def test_get_plan_bulk_with_cache_partial_cache_hit(self, flask_app_with_containers):
|
||||
"""Test bulk plan retrieval when some tenants are in cache, some are not."""
|
||||
with flask_app_with_containers.app_context():
|
||||
# Arrange
|
||||
tenant_ids = ["tenant-1", "tenant-2", "tenant-3"]
|
||||
# Pre-populate cache for tenant-1 and tenant-2
|
||||
self._set_cache("tenant-1", self._create_test_plan_data("sandbox", 1735689600))
|
||||
self._set_cache("tenant-2", self._create_test_plan_data("professional", 1767225600))
|
||||
|
||||
# tenant-3 is not in cache
|
||||
missing_plan = {"tenant-3": self._create_test_plan_data("team", 1798761600)}
|
||||
|
||||
# Act
|
||||
with patch.object(BillingService, "get_plan_bulk", return_value=missing_plan) as mock_get_plan_bulk:
|
||||
result = BillingService.get_plan_bulk_with_cache(tenant_ids)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 3
|
||||
assert result["tenant-1"]["plan"] == "sandbox"
|
||||
assert result["tenant-2"]["plan"] == "professional"
|
||||
assert result["tenant-3"]["plan"] == "team"
|
||||
|
||||
# Verify API was called only for missing tenant
|
||||
mock_get_plan_bulk.assert_called_once_with(["tenant-3"])
|
||||
|
||||
# Verify tenant-3 data was written to cache
|
||||
cached_3 = self._get_cache("tenant-3")
|
||||
assert cached_3 is not None
|
||||
cached_data_3 = json.loads(cached_3)
|
||||
assert cached_data_3 == missing_plan["tenant-3"]
|
||||
|
||||
def test_get_plan_bulk_with_cache_redis_mget_failure(self, flask_app_with_containers):
|
||||
"""Test fallback to API when Redis mget fails."""
|
||||
with flask_app_with_containers.app_context():
|
||||
# Arrange
|
||||
tenant_ids = ["tenant-1", "tenant-2"]
|
||||
expected_plans = {
|
||||
"tenant-1": self._create_test_plan_data("sandbox", 1735689600),
|
||||
"tenant-2": self._create_test_plan_data("professional", 1767225600),
|
||||
}
|
||||
|
||||
# Act
|
||||
with (
|
||||
patch.object(redis_client, "mget", side_effect=Exception("Redis connection error")),
|
||||
patch.object(BillingService, "get_plan_bulk", return_value=expected_plans) as mock_get_plan_bulk,
|
||||
):
|
||||
result = BillingService.get_plan_bulk_with_cache(tenant_ids)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert result["tenant-1"]["plan"] == "sandbox"
|
||||
assert result["tenant-2"]["plan"] == "professional"
|
||||
|
||||
# Verify API was called for all tenants (fallback)
|
||||
mock_get_plan_bulk.assert_called_once_with(tenant_ids)
|
||||
|
||||
# Verify data was written to cache after fallback
|
||||
cached_1 = self._get_cache("tenant-1")
|
||||
cached_2 = self._get_cache("tenant-2")
|
||||
assert cached_1 is not None
|
||||
assert cached_2 is not None
|
||||
|
||||
def test_get_plan_bulk_with_cache_invalid_json_in_cache(self, flask_app_with_containers):
|
||||
"""Test fallback to API when cache contains invalid JSON."""
|
||||
with flask_app_with_containers.app_context():
|
||||
# Arrange
|
||||
tenant_ids = ["tenant-1", "tenant-2", "tenant-3"]
|
||||
|
||||
# Set valid cache for tenant-1
|
||||
self._set_cache("tenant-1", self._create_test_plan_data("sandbox", 1735689600))
|
||||
|
||||
# Set invalid JSON for tenant-2
|
||||
cache_key_2 = BillingService._make_plan_cache_key("tenant-2")
|
||||
redis_client.setex(cache_key_2, 600, "invalid json {")
|
||||
|
||||
# tenant-3 is not in cache
|
||||
expected_plans = {
|
||||
"tenant-2": self._create_test_plan_data("professional", 1767225600),
|
||||
"tenant-3": self._create_test_plan_data("team", 1798761600),
|
||||
}
|
||||
|
||||
# Act
|
||||
with patch.object(BillingService, "get_plan_bulk", return_value=expected_plans) as mock_get_plan_bulk:
|
||||
result = BillingService.get_plan_bulk_with_cache(tenant_ids)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 3
|
||||
assert result["tenant-1"]["plan"] == "sandbox" # From cache
|
||||
assert result["tenant-2"]["plan"] == "professional" # From API (fallback)
|
||||
assert result["tenant-3"]["plan"] == "team" # From API
|
||||
|
||||
# Verify API was called for tenant-2 and tenant-3
|
||||
mock_get_plan_bulk.assert_called_once_with(["tenant-2", "tenant-3"])
|
||||
|
||||
# Verify tenant-2's invalid JSON was replaced with correct data in cache
|
||||
cached_2 = self._get_cache("tenant-2")
|
||||
assert cached_2 is not None
|
||||
cached_data_2 = json.loads(cached_2)
|
||||
assert cached_data_2 == expected_plans["tenant-2"]
|
||||
assert cached_data_2["plan"] == "professional"
|
||||
assert cached_data_2["expiration_date"] == 1767225600
|
||||
|
||||
# Verify tenant-2 cache has correct TTL
|
||||
cache_key_2_new = BillingService._make_plan_cache_key("tenant-2")
|
||||
ttl_2 = redis_client.ttl(cache_key_2_new)
|
||||
assert ttl_2 > 0
|
||||
assert ttl_2 <= 600
|
||||
|
||||
# Verify tenant-3 data was also written to cache
|
||||
cached_3 = self._get_cache("tenant-3")
|
||||
assert cached_3 is not None
|
||||
cached_data_3 = json.loads(cached_3)
|
||||
assert cached_data_3 == expected_plans["tenant-3"]
|
||||
|
||||
def test_get_plan_bulk_with_cache_invalid_plan_data_in_cache(self, flask_app_with_containers):
|
||||
"""Test fallback to API when cache data doesn't match SubscriptionPlan schema."""
|
||||
with flask_app_with_containers.app_context():
|
||||
# Arrange
|
||||
tenant_ids = ["tenant-1", "tenant-2", "tenant-3"]
|
||||
|
||||
# Set valid cache for tenant-1
|
||||
self._set_cache("tenant-1", self._create_test_plan_data("sandbox", 1735689600))
|
||||
|
||||
# Set invalid plan data for tenant-2 (missing expiration_date)
|
||||
cache_key_2 = BillingService._make_plan_cache_key("tenant-2")
|
||||
invalid_data = json.dumps({"plan": "professional"}) # Missing expiration_date
|
||||
redis_client.setex(cache_key_2, 600, invalid_data)
|
||||
|
||||
# tenant-3 is not in cache
|
||||
expected_plans = {
|
||||
"tenant-2": self._create_test_plan_data("professional", 1767225600),
|
||||
"tenant-3": self._create_test_plan_data("team", 1798761600),
|
||||
}
|
||||
|
||||
# Act
|
||||
with patch.object(BillingService, "get_plan_bulk", return_value=expected_plans) as mock_get_plan_bulk:
|
||||
result = BillingService.get_plan_bulk_with_cache(tenant_ids)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 3
|
||||
assert result["tenant-1"]["plan"] == "sandbox" # From cache
|
||||
assert result["tenant-2"]["plan"] == "professional" # From API (fallback)
|
||||
assert result["tenant-3"]["plan"] == "team" # From API
|
||||
|
||||
# Verify API was called for tenant-2 and tenant-3
|
||||
mock_get_plan_bulk.assert_called_once_with(["tenant-2", "tenant-3"])
|
||||
|
||||
def test_get_plan_bulk_with_cache_redis_pipeline_failure(self, flask_app_with_containers):
|
||||
"""Test that pipeline failure doesn't affect return value."""
|
||||
with flask_app_with_containers.app_context():
|
||||
# Arrange
|
||||
tenant_ids = ["tenant-1", "tenant-2"]
|
||||
expected_plans = {
|
||||
"tenant-1": self._create_test_plan_data("sandbox", 1735689600),
|
||||
"tenant-2": self._create_test_plan_data("professional", 1767225600),
|
||||
}
|
||||
|
||||
# Act
|
||||
with (
|
||||
patch.object(BillingService, "get_plan_bulk", return_value=expected_plans),
|
||||
patch.object(redis_client, "pipeline") as mock_pipeline,
|
||||
):
|
||||
# Create a mock pipeline that fails on execute
|
||||
mock_pipe = mock_pipeline.return_value
|
||||
mock_pipe.execute.side_effect = Exception("Pipeline execution failed")
|
||||
|
||||
result = BillingService.get_plan_bulk_with_cache(tenant_ids)
|
||||
|
||||
# Assert - Function should still return correct result despite pipeline failure
|
||||
assert len(result) == 2
|
||||
assert result["tenant-1"]["plan"] == "sandbox"
|
||||
assert result["tenant-2"]["plan"] == "professional"
|
||||
|
||||
# Verify pipeline was attempted
|
||||
mock_pipeline.assert_called_once()
|
||||
|
||||
def test_get_plan_bulk_with_cache_empty_tenant_ids(self, flask_app_with_containers):
|
||||
"""Test with empty tenant_ids list."""
|
||||
with flask_app_with_containers.app_context():
|
||||
# Act
|
||||
with patch.object(BillingService, "get_plan_bulk") as mock_get_plan_bulk:
|
||||
result = BillingService.get_plan_bulk_with_cache([])
|
||||
|
||||
# Assert
|
||||
assert result == {}
|
||||
assert len(result) == 0
|
||||
|
||||
# Verify no API calls
|
||||
mock_get_plan_bulk.assert_not_called()
|
||||
|
||||
# Verify no Redis operations (mget with empty list would return empty list)
|
||||
# But we should check that mget was not called at all
|
||||
# Since we can't easily verify this without more mocking, we just verify the result
|
||||
|
||||
def test_get_plan_bulk_with_cache_ttl_expired(self, flask_app_with_containers):
|
||||
"""Test that expired cache keys are treated as cache misses."""
|
||||
with flask_app_with_containers.app_context():
|
||||
# Arrange
|
||||
tenant_ids = ["tenant-1", "tenant-2"]
|
||||
|
||||
# Set cache for tenant-1 with very short TTL (1 second) to simulate expiration
|
||||
self._set_cache("tenant-1", self._create_test_plan_data("sandbox", 1735689600), ttl=1)
|
||||
|
||||
# Wait for TTL to expire (key will be deleted by Redis)
|
||||
import time
|
||||
|
||||
time.sleep(2)
|
||||
|
||||
# Verify cache is expired (key doesn't exist)
|
||||
cache_key_1 = BillingService._make_plan_cache_key("tenant-1")
|
||||
exists = redis_client.exists(cache_key_1)
|
||||
assert exists == 0 # Key doesn't exist (expired)
|
||||
|
||||
# tenant-2 is not in cache
|
||||
expected_plans = {
|
||||
"tenant-1": self._create_test_plan_data("sandbox", 1735689600),
|
||||
"tenant-2": self._create_test_plan_data("professional", 1767225600),
|
||||
}
|
||||
|
||||
# Act
|
||||
with patch.object(BillingService, "get_plan_bulk", return_value=expected_plans) as mock_get_plan_bulk:
|
||||
result = BillingService.get_plan_bulk_with_cache(tenant_ids)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert result["tenant-1"]["plan"] == "sandbox"
|
||||
assert result["tenant-2"]["plan"] == "professional"
|
||||
|
||||
# Verify API was called for both tenants (tenant-1 expired, tenant-2 missing)
|
||||
mock_get_plan_bulk.assert_called_once_with(tenant_ids)
|
||||
|
||||
# Verify both were written to cache with correct TTL
|
||||
cache_key_1_new = BillingService._make_plan_cache_key("tenant-1")
|
||||
cache_key_2 = BillingService._make_plan_cache_key("tenant-2")
|
||||
ttl_1_new = redis_client.ttl(cache_key_1_new)
|
||||
ttl_2 = redis_client.ttl(cache_key_2)
|
||||
assert ttl_1_new > 0
|
||||
assert ttl_1_new <= 600
|
||||
assert ttl_2 > 0
|
||||
assert ttl_2 <= 600
|
||||
@ -0,0 +1,682 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.trigger.entities.entities import Subscription as TriggerSubscriptionEntity
|
||||
from extensions.ext_database import db
|
||||
from models.provider_ids import TriggerProviderID
|
||||
from models.trigger import TriggerSubscription
|
||||
from services.trigger.trigger_provider_service import TriggerProviderService
|
||||
|
||||
|
||||
class TestTriggerProviderService:
|
||||
"""Integration tests for TriggerProviderService using testcontainers."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("services.trigger.trigger_provider_service.TriggerManager") as mock_trigger_manager,
|
||||
patch("services.trigger.trigger_provider_service.redis_client") as mock_redis_client,
|
||||
patch("services.trigger.trigger_provider_service.delete_cache_for_subscription") as mock_delete_cache,
|
||||
patch("services.account_service.FeatureService") as mock_account_feature_service,
|
||||
):
|
||||
# Setup default mock returns
|
||||
mock_provider_controller = MagicMock()
|
||||
mock_provider_controller.get_credential_schema_config.return_value = MagicMock()
|
||||
mock_provider_controller.get_properties_schema.return_value = MagicMock()
|
||||
mock_trigger_manager.get_trigger_provider.return_value = mock_provider_controller
|
||||
|
||||
# Mock redis lock
|
||||
mock_lock = MagicMock()
|
||||
mock_lock.__enter__ = MagicMock(return_value=None)
|
||||
mock_lock.__exit__ = MagicMock(return_value=None)
|
||||
mock_redis_client.lock.return_value = mock_lock
|
||||
|
||||
# Setup account feature service mock
|
||||
mock_account_feature_service.get_system_features.return_value.is_allow_register = True
|
||||
|
||||
yield {
|
||||
"trigger_manager": mock_trigger_manager,
|
||||
"redis_client": mock_redis_client,
|
||||
"delete_cache": mock_delete_cache,
|
||||
"provider_controller": mock_provider_controller,
|
||||
"account_feature_service": mock_account_feature_service,
|
||||
}
|
||||
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test account and tenant for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
|
||||
Returns:
|
||||
tuple: (account, tenant) - Created account and tenant instances
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
from services.account_service import AccountService, TenantService
|
||||
|
||||
# Setup mocks for account creation
|
||||
mock_external_service_dependencies[
|
||||
"account_feature_service"
|
||||
].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies[
|
||||
"trigger_manager"
|
||||
].get_trigger_provider.return_value = mock_external_service_dependencies["provider_controller"]
|
||||
|
||||
# Create account and tenant
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
return account, tenant
|
||||
|
||||
def _create_test_subscription(
|
||||
self,
|
||||
db_session_with_containers,
|
||||
tenant_id,
|
||||
user_id,
|
||||
provider_id,
|
||||
credential_type,
|
||||
credentials,
|
||||
mock_external_service_dependencies,
|
||||
):
|
||||
"""
|
||||
Helper method to create a test trigger subscription.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session
|
||||
tenant_id: Tenant ID
|
||||
user_id: User ID
|
||||
provider_id: Provider ID
|
||||
credential_type: Credential type
|
||||
credentials: Credentials dict
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
|
||||
Returns:
|
||||
TriggerSubscription: Created subscription instance
|
||||
"""
|
||||
fake = Faker()
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.helper.provider_encryption import create_provider_encrypter
|
||||
|
||||
# Use mock provider controller to encrypt credentials
|
||||
provider_controller = mock_external_service_dependencies["provider_controller"]
|
||||
|
||||
# Create encrypter for credentials
|
||||
credential_encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=provider_controller.get_credential_schema_config(credential_type),
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
subscription = TriggerSubscription(
|
||||
name=fake.word(),
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider_id=str(provider_id),
|
||||
endpoint_id=fake.uuid4(),
|
||||
parameters={"param1": "value1"},
|
||||
properties={"prop1": "value1"},
|
||||
credentials=dict(credential_encrypter.encrypt(credentials)),
|
||||
credential_type=credential_type.value,
|
||||
credential_expires_at=-1,
|
||||
expires_at=-1,
|
||||
)
|
||||
|
||||
db.session.add(subscription)
|
||||
db.session.commit()
|
||||
db.session.refresh(subscription)
|
||||
|
||||
return subscription
|
||||
|
||||
def test_rebuild_trigger_subscription_success_with_merged_credentials(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful rebuild with credential merging (HIDDEN_VALUE handling).
|
||||
|
||||
This test verifies:
|
||||
- Credentials are properly merged (HIDDEN_VALUE replaced with existing values)
|
||||
- Single transaction wraps all operations
|
||||
- Merged credentials are used for subscribe and update
|
||||
- Database state is correctly updated
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.API_KEY
|
||||
|
||||
# Create initial subscription with credentials
|
||||
original_credentials = {"api_key": "original-secret-key", "api_secret": "original-secret"}
|
||||
subscription = self._create_test_subscription(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
account.id,
|
||||
provider_id,
|
||||
credential_type,
|
||||
original_credentials,
|
||||
mock_external_service_dependencies,
|
||||
)
|
||||
|
||||
# Prepare new credentials with HIDDEN_VALUE for api_key (should keep original)
|
||||
# and new value for api_secret (should update)
|
||||
new_credentials = {
|
||||
"api_key": HIDDEN_VALUE, # Should be replaced with original
|
||||
"api_secret": "new-secret-value", # Should be updated
|
||||
}
|
||||
|
||||
# Mock subscribe_trigger to return a new subscription entity
|
||||
new_subscription_entity = TriggerSubscriptionEntity(
|
||||
endpoint=subscription.endpoint_id,
|
||||
parameters={"param1": "value1"},
|
||||
properties={"prop1": "new_prop_value"},
|
||||
expires_at=1234567890,
|
||||
)
|
||||
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
|
||||
|
||||
# Mock unsubscribe_trigger
|
||||
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
|
||||
|
||||
# Execute rebuild
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription.id,
|
||||
credentials=new_credentials,
|
||||
parameters={"param1": "updated_value"},
|
||||
name="updated_name",
|
||||
)
|
||||
|
||||
# Verify unsubscribe was called with decrypted original credentials
|
||||
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.assert_called_once()
|
||||
unsubscribe_call_args = mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.call_args
|
||||
assert unsubscribe_call_args.kwargs["tenant_id"] == tenant.id
|
||||
assert unsubscribe_call_args.kwargs["provider_id"] == provider_id
|
||||
assert unsubscribe_call_args.kwargs["credential_type"] == credential_type
|
||||
|
||||
# Verify subscribe was called with merged credentials (api_key from original, api_secret new)
|
||||
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.assert_called_once()
|
||||
subscribe_call_args = mock_external_service_dependencies["trigger_manager"].subscribe_trigger.call_args
|
||||
subscribe_credentials = subscribe_call_args.kwargs["credentials"]
|
||||
assert subscribe_credentials["api_key"] == original_credentials["api_key"] # Merged from original
|
||||
assert subscribe_credentials["api_secret"] == "new-secret-value" # New value
|
||||
|
||||
# Verify database state was updated
|
||||
db.session.refresh(subscription)
|
||||
assert subscription.name == "updated_name"
|
||||
assert subscription.parameters == {"param1": "updated_value"}
|
||||
|
||||
# Verify credentials in DB were updated with merged values (decrypt to check)
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.helper.provider_encryption import create_provider_encrypter
|
||||
|
||||
# Use mock provider controller to decrypt credentials
|
||||
provider_controller = mock_external_service_dependencies["provider_controller"]
|
||||
credential_encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant.id,
|
||||
config=provider_controller.get_credential_schema_config(credential_type),
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
decrypted_db_credentials = dict(credential_encrypter.decrypt(subscription.credentials))
|
||||
assert decrypted_db_credentials["api_key"] == original_credentials["api_key"]
|
||||
assert decrypted_db_credentials["api_secret"] == "new-secret-value"
|
||||
|
||||
# Verify cache was cleared
|
||||
mock_external_service_dependencies["delete_cache"].assert_called_once_with(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=subscription.provider_id,
|
||||
subscription_id=subscription.id,
|
||||
)
|
||||
|
||||
def test_rebuild_trigger_subscription_with_all_new_credentials(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test rebuild when all credentials are new (no HIDDEN_VALUE).
|
||||
|
||||
This test verifies:
|
||||
- All new credentials are used when no HIDDEN_VALUE is present
|
||||
- Merged credentials contain only new values
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.API_KEY
|
||||
|
||||
# Create initial subscription
|
||||
original_credentials = {"api_key": "original-key", "api_secret": "original-secret"}
|
||||
subscription = self._create_test_subscription(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
account.id,
|
||||
provider_id,
|
||||
credential_type,
|
||||
original_credentials,
|
||||
mock_external_service_dependencies,
|
||||
)
|
||||
|
||||
# All new credentials (no HIDDEN_VALUE)
|
||||
new_credentials = {
|
||||
"api_key": "completely-new-key",
|
||||
"api_secret": "completely-new-secret",
|
||||
}
|
||||
|
||||
new_subscription_entity = TriggerSubscriptionEntity(
|
||||
endpoint=subscription.endpoint_id,
|
||||
parameters={},
|
||||
properties={},
|
||||
expires_at=-1,
|
||||
)
|
||||
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
|
||||
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
|
||||
|
||||
# Execute rebuild
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription.id,
|
||||
credentials=new_credentials,
|
||||
parameters={},
|
||||
)
|
||||
|
||||
# Verify subscribe was called with all new credentials
|
||||
subscribe_call_args = mock_external_service_dependencies["trigger_manager"].subscribe_trigger.call_args
|
||||
subscribe_credentials = subscribe_call_args.kwargs["credentials"]
|
||||
assert subscribe_credentials["api_key"] == "completely-new-key"
|
||||
assert subscribe_credentials["api_secret"] == "completely-new-secret"
|
||||
|
||||
def test_rebuild_trigger_subscription_with_all_hidden_values(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test rebuild when all credentials are HIDDEN_VALUE (preserve all existing).
|
||||
|
||||
This test verifies:
|
||||
- All HIDDEN_VALUE credentials are replaced with existing values
|
||||
- Original credentials are preserved
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.API_KEY
|
||||
|
||||
original_credentials = {"api_key": "original-key", "api_secret": "original-secret"}
|
||||
subscription = self._create_test_subscription(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
account.id,
|
||||
provider_id,
|
||||
credential_type,
|
||||
original_credentials,
|
||||
mock_external_service_dependencies,
|
||||
)
|
||||
|
||||
# All HIDDEN_VALUE (should preserve all original)
|
||||
new_credentials = {
|
||||
"api_key": HIDDEN_VALUE,
|
||||
"api_secret": HIDDEN_VALUE,
|
||||
}
|
||||
|
||||
new_subscription_entity = TriggerSubscriptionEntity(
|
||||
endpoint=subscription.endpoint_id,
|
||||
parameters={},
|
||||
properties={},
|
||||
expires_at=-1,
|
||||
)
|
||||
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
|
||||
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
|
||||
|
||||
# Execute rebuild
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription.id,
|
||||
credentials=new_credentials,
|
||||
parameters={},
|
||||
)
|
||||
|
||||
# Verify subscribe was called with all original credentials
|
||||
subscribe_call_args = mock_external_service_dependencies["trigger_manager"].subscribe_trigger.call_args
|
||||
subscribe_credentials = subscribe_call_args.kwargs["credentials"]
|
||||
assert subscribe_credentials["api_key"] == original_credentials["api_key"]
|
||||
assert subscribe_credentials["api_secret"] == original_credentials["api_secret"]
|
||||
|
||||
def test_rebuild_trigger_subscription_with_missing_key_uses_unknown_value(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test rebuild when HIDDEN_VALUE is used for a key that doesn't exist in original.
|
||||
|
||||
This test verifies:
|
||||
- UNKNOWN_VALUE is used when HIDDEN_VALUE key doesn't exist in original credentials
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.API_KEY
|
||||
|
||||
# Original has only api_key
|
||||
original_credentials = {"api_key": "original-key"}
|
||||
subscription = self._create_test_subscription(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
account.id,
|
||||
provider_id,
|
||||
credential_type,
|
||||
original_credentials,
|
||||
mock_external_service_dependencies,
|
||||
)
|
||||
|
||||
# HIDDEN_VALUE for non-existent key should use UNKNOWN_VALUE
|
||||
new_credentials = {
|
||||
"api_key": HIDDEN_VALUE,
|
||||
"non_existent_key": HIDDEN_VALUE, # This key doesn't exist in original
|
||||
}
|
||||
|
||||
new_subscription_entity = TriggerSubscriptionEntity(
|
||||
endpoint=subscription.endpoint_id,
|
||||
parameters={},
|
||||
properties={},
|
||||
expires_at=-1,
|
||||
)
|
||||
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
|
||||
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
|
||||
|
||||
# Execute rebuild
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription.id,
|
||||
credentials=new_credentials,
|
||||
parameters={},
|
||||
)
|
||||
|
||||
# Verify subscribe was called with original api_key and UNKNOWN_VALUE for missing key
|
||||
subscribe_call_args = mock_external_service_dependencies["trigger_manager"].subscribe_trigger.call_args
|
||||
subscribe_credentials = subscribe_call_args.kwargs["credentials"]
|
||||
assert subscribe_credentials["api_key"] == original_credentials["api_key"]
|
||||
assert subscribe_credentials["non_existent_key"] == UNKNOWN_VALUE
|
||||
|
||||
def test_rebuild_trigger_subscription_rollback_on_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test that transaction is rolled back on error.
|
||||
|
||||
This test verifies:
|
||||
- Database transaction is rolled back when an error occurs
|
||||
- Original subscription state is preserved
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.API_KEY
|
||||
|
||||
original_credentials = {"api_key": "original-key"}
|
||||
subscription = self._create_test_subscription(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
account.id,
|
||||
provider_id,
|
||||
credential_type,
|
||||
original_credentials,
|
||||
mock_external_service_dependencies,
|
||||
)
|
||||
|
||||
original_name = subscription.name
|
||||
original_parameters = subscription.parameters.copy()
|
||||
|
||||
# Make subscribe_trigger raise an error
|
||||
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.side_effect = ValueError(
|
||||
"Subscribe failed"
|
||||
)
|
||||
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
|
||||
|
||||
# Execute rebuild and expect error
|
||||
with pytest.raises(ValueError, match="Subscribe failed"):
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription.id,
|
||||
credentials={"api_key": "new-key"},
|
||||
parameters={},
|
||||
)
|
||||
|
||||
# Verify subscription state was not changed (rolled back)
|
||||
db.session.refresh(subscription)
|
||||
assert subscription.name == original_name
|
||||
assert subscription.parameters == original_parameters
|
||||
|
||||
def test_rebuild_trigger_subscription_unsubscribe_error_continues(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test that unsubscribe errors are handled gracefully and operation continues.
|
||||
|
||||
This test verifies:
|
||||
- Unsubscribe errors are caught and logged but don't stop the rebuild
|
||||
- Rebuild continues even if unsubscribe fails
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.API_KEY
|
||||
|
||||
original_credentials = {"api_key": "original-key"}
|
||||
subscription = self._create_test_subscription(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
account.id,
|
||||
provider_id,
|
||||
credential_type,
|
||||
original_credentials,
|
||||
mock_external_service_dependencies,
|
||||
)
|
||||
|
||||
# Make unsubscribe_trigger raise an error (should be caught and continue)
|
||||
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.side_effect = ValueError(
|
||||
"Unsubscribe failed"
|
||||
)
|
||||
|
||||
new_subscription_entity = TriggerSubscriptionEntity(
|
||||
endpoint=subscription.endpoint_id,
|
||||
parameters={},
|
||||
properties={},
|
||||
expires_at=-1,
|
||||
)
|
||||
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
|
||||
|
||||
# Execute rebuild - should succeed despite unsubscribe error
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription.id,
|
||||
credentials={"api_key": "new-key"},
|
||||
parameters={},
|
||||
)
|
||||
|
||||
# Verify subscribe was still called (operation continued)
|
||||
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.assert_called_once()
|
||||
|
||||
# Verify subscription was updated
|
||||
db.session.refresh(subscription)
|
||||
assert subscription.parameters == {}
|
||||
|
||||
def test_rebuild_trigger_subscription_subscription_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error when subscription is not found.
|
||||
|
||||
This test verifies:
|
||||
- Proper error is raised when subscription doesn't exist
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
fake_subscription_id = fake.uuid4()
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=fake_subscription_id,
|
||||
credentials={},
|
||||
parameters={},
|
||||
)
|
||||
|
||||
def test_rebuild_trigger_subscription_provider_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error when provider is not found.
|
||||
|
||||
This test verifies:
|
||||
- Proper error is raised when provider doesn't exist
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("non_existent_org/non_existent_plugin/non_existent_provider")
|
||||
|
||||
# Make get_trigger_provider return None
|
||||
mock_external_service_dependencies["trigger_manager"].get_trigger_provider.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="Provider.*not found"):
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=fake.uuid4(),
|
||||
credentials={},
|
||||
parameters={},
|
||||
)
|
||||
|
||||
def test_rebuild_trigger_subscription_unsupported_credential_type(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error when credential type is not supported for rebuild.
|
||||
|
||||
This test verifies:
|
||||
- Proper error is raised for unsupported credential types (not OAUTH2 or API_KEY)
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.UNAUTHORIZED # Not supported
|
||||
|
||||
subscription = self._create_test_subscription(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
account.id,
|
||||
provider_id,
|
||||
credential_type,
|
||||
{},
|
||||
mock_external_service_dependencies,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Credential type not supported for rebuild"):
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription.id,
|
||||
credentials={},
|
||||
parameters={},
|
||||
)
|
||||
|
||||
def test_rebuild_trigger_subscription_name_uniqueness_check(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test that name uniqueness is checked when updating name.
|
||||
|
||||
This test verifies:
|
||||
- Error is raised when new name conflicts with existing subscription
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.API_KEY
|
||||
|
||||
# Create first subscription
|
||||
subscription1 = self._create_test_subscription(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
account.id,
|
||||
provider_id,
|
||||
credential_type,
|
||||
{"api_key": "key1"},
|
||||
mock_external_service_dependencies,
|
||||
)
|
||||
|
||||
# Create second subscription with different name
|
||||
subscription2 = self._create_test_subscription(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
account.id,
|
||||
provider_id,
|
||||
credential_type,
|
||||
{"api_key": "key2"},
|
||||
mock_external_service_dependencies,
|
||||
)
|
||||
|
||||
new_subscription_entity = TriggerSubscriptionEntity(
|
||||
endpoint=subscription2.endpoint_id,
|
||||
parameters={},
|
||||
properties={},
|
||||
expires_at=-1,
|
||||
)
|
||||
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
|
||||
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
|
||||
|
||||
# Try to rename subscription2 to subscription1's name (should fail)
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription2.id,
|
||||
credentials={"api_key": "new-key"},
|
||||
parameters={},
|
||||
name=subscription1.name, # Conflicting name
|
||||
)
|
||||
@ -2,7 +2,9 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType
|
||||
from models import Account, Tenant
|
||||
from models.tools import ApiToolProvider
|
||||
from services.tools.api_tools_manage_service import ApiToolManageService
|
||||
@ -298,7 +300,7 @@ class TestApiToolManageService:
|
||||
provider_name = fake.company()
|
||||
icon = {"type": "emoji", "value": "🔧"}
|
||||
credentials = {"auth_type": "none", "api_key_header": "X-API-Key", "api_key_value": ""}
|
||||
schema_type = "openapi"
|
||||
schema_type = ApiProviderSchemaType.OPENAPI
|
||||
schema = self._create_test_openapi_schema()
|
||||
privacy_policy = "https://example.com/privacy"
|
||||
custom_disclaimer = "Custom disclaimer text"
|
||||
@ -364,7 +366,7 @@ class TestApiToolManageService:
|
||||
provider_name = fake.company()
|
||||
icon = {"type": "emoji", "value": "🔧"}
|
||||
credentials = {"auth_type": "none"}
|
||||
schema_type = "openapi"
|
||||
schema_type = ApiProviderSchemaType.OPENAPI
|
||||
schema = self._create_test_openapi_schema()
|
||||
privacy_policy = "https://example.com/privacy"
|
||||
custom_disclaimer = "Custom disclaimer text"
|
||||
@ -428,21 +430,10 @@ class TestApiToolManageService:
|
||||
labels = ["test"]
|
||||
|
||||
# Act & Assert: Try to create provider with invalid schema type
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ApiToolManageService.create_api_tool_provider(
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
provider_name=provider_name,
|
||||
icon=icon,
|
||||
credentials=credentials,
|
||||
schema_type=schema_type,
|
||||
schema=schema,
|
||||
privacy_policy=privacy_policy,
|
||||
custom_disclaimer=custom_disclaimer,
|
||||
labels=labels,
|
||||
)
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
TypeAdapter(ApiProviderSchemaType).validate_python(schema_type)
|
||||
|
||||
assert "invalid schema type" in str(exc_info.value)
|
||||
assert "validation error" in str(exc_info.value)
|
||||
|
||||
def test_create_api_tool_provider_missing_auth_type(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
|
||||
@ -464,7 +455,7 @@ class TestApiToolManageService:
|
||||
provider_name = fake.company()
|
||||
icon = {"type": "emoji", "value": "🔧"}
|
||||
credentials = {} # Missing auth_type
|
||||
schema_type = "openapi"
|
||||
schema_type = ApiProviderSchemaType.OPENAPI
|
||||
schema = self._create_test_openapi_schema()
|
||||
privacy_policy = "https://example.com/privacy"
|
||||
custom_disclaimer = "Custom disclaimer text"
|
||||
@ -507,7 +498,7 @@ class TestApiToolManageService:
|
||||
provider_name = fake.company()
|
||||
icon = {"type": "emoji", "value": "🔑"}
|
||||
credentials = {"auth_type": "api_key", "api_key_header": "X-API-Key", "api_key_value": fake.uuid4()}
|
||||
schema_type = "openapi"
|
||||
schema_type = ApiProviderSchemaType.OPENAPI
|
||||
schema = self._create_test_openapi_schema()
|
||||
privacy_policy = "https://example.com/privacy"
|
||||
custom_disclaimer = "Custom disclaimer text"
|
||||
|
||||
@ -1308,18 +1308,17 @@ class TestMCPToolManageService:
|
||||
type("MockTool", (), {"model_dump": lambda self: {"name": "test_tool_2", "description": "Test tool 2"}})(),
|
||||
]
|
||||
|
||||
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
|
||||
with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client:
|
||||
# Setup mock client
|
||||
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
|
||||
mock_client_instance.list_tools.return_value = mock_tools
|
||||
|
||||
# Act: Execute the method under test
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = MCPToolManageService(db.session())
|
||||
result = service._reconnect_provider(
|
||||
result = MCPToolManageService._reconnect_with_url(
|
||||
server_url="https://example.com/mcp",
|
||||
provider=mcp_provider,
|
||||
headers={"X-Test": "1"},
|
||||
timeout=mcp_provider.timeout,
|
||||
sse_read_timeout=mcp_provider.sse_read_timeout,
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
@ -1337,8 +1336,12 @@ class TestMCPToolManageService:
|
||||
assert tools_data[1]["name"] == "test_tool_2"
|
||||
|
||||
# Verify mock interactions
|
||||
provider_entity = mcp_provider.to_entity()
|
||||
mock_mcp_client.assert_called_once()
|
||||
mock_mcp_client.assert_called_once_with(
|
||||
server_url="https://example.com/mcp",
|
||||
headers={"X-Test": "1"},
|
||||
timeout=mcp_provider.timeout,
|
||||
sse_read_timeout=mcp_provider.sse_read_timeout,
|
||||
)
|
||||
|
||||
def test_re_connect_mcp_provider_auth_error(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
@ -1361,19 +1364,18 @@ class TestMCPToolManageService:
|
||||
)
|
||||
|
||||
# Mock MCPClient to raise authentication error
|
||||
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
|
||||
with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client:
|
||||
from core.mcp.error import MCPAuthError
|
||||
|
||||
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
|
||||
mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required")
|
||||
|
||||
# Act: Execute the method under test
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = MCPToolManageService(db.session())
|
||||
result = service._reconnect_provider(
|
||||
result = MCPToolManageService._reconnect_with_url(
|
||||
server_url="https://example.com/mcp",
|
||||
provider=mcp_provider,
|
||||
headers={},
|
||||
timeout=mcp_provider.timeout,
|
||||
sse_read_timeout=mcp_provider.sse_read_timeout,
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
@ -1404,18 +1406,17 @@ class TestMCPToolManageService:
|
||||
)
|
||||
|
||||
# Mock MCPClient to raise connection error
|
||||
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
|
||||
with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client:
|
||||
from core.mcp.error import MCPError
|
||||
|
||||
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
|
||||
mock_client_instance.list_tools.side_effect = MCPError("Connection failed")
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = MCPToolManageService(db.session())
|
||||
with pytest.raises(ValueError, match="Failed to re-connect MCP server: Connection failed"):
|
||||
service._reconnect_provider(
|
||||
MCPToolManageService._reconnect_with_url(
|
||||
server_url="https://example.com/mcp",
|
||||
provider=mcp_provider,
|
||||
headers={"X-Test": "1"},
|
||||
timeout=mcp_provider.timeout,
|
||||
sse_read_timeout=mcp_provider.sse_read_timeout,
|
||||
)
|
||||
|
||||
@ -705,3 +705,207 @@ class TestWorkflowToolManageService:
|
||||
db.session.refresh(created_tool)
|
||||
assert created_tool.name == first_tool_name
|
||||
assert created_tool.updated_at is not None
|
||||
|
||||
def test_create_workflow_tool_with_file_parameter_default(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test workflow tool creation with FILE parameter having a file object as default.
|
||||
|
||||
This test verifies:
|
||||
- FILE parameters can have file object defaults
|
||||
- The default value (dict with id/base64Url) is properly handled
|
||||
- Tool creation succeeds without Pydantic validation errors
|
||||
|
||||
Related issue: Array[File] default value causes Pydantic validation errors.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create test data
|
||||
app, account, workflow = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Create workflow graph with a FILE variable that has a default value
|
||||
workflow_graph = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "start_node",
|
||||
"data": {
|
||||
"type": "start",
|
||||
"variables": [
|
||||
{
|
||||
"variable": "document",
|
||||
"label": "Document",
|
||||
"type": "file",
|
||||
"required": False,
|
||||
"default": {"id": fake.uuid4(), "base64Url": ""},
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
workflow.graph = json.dumps(workflow_graph)
|
||||
|
||||
# Setup workflow tool parameters with FILE type
|
||||
file_parameters = [
|
||||
{
|
||||
"name": "document",
|
||||
"description": "Upload a document",
|
||||
"form": "form",
|
||||
"type": "file",
|
||||
"required": False,
|
||||
}
|
||||
]
|
||||
|
||||
# Execute the method under test
|
||||
# Note: from_db is mocked, so this test primarily validates the parameter configuration
|
||||
result = WorkflowToolManageService.create_workflow_tool(
|
||||
user_id=account.id,
|
||||
tenant_id=account.current_tenant.id,
|
||||
workflow_app_id=app.id,
|
||||
name=fake.word(),
|
||||
label=fake.word(),
|
||||
icon={"type": "emoji", "emoji": "📄"},
|
||||
description=fake.text(max_nb_chars=200),
|
||||
parameters=file_parameters,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == {"result": "success"}
|
||||
|
||||
def test_create_workflow_tool_with_files_parameter_default(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test workflow tool creation with FILES (Array[File]) parameter having file objects as default.
|
||||
|
||||
This test verifies:
|
||||
- FILES parameters can have a list of file objects as default
|
||||
- The default value (list of dicts with id/base64Url) is properly handled
|
||||
- Tool creation succeeds without Pydantic validation errors
|
||||
|
||||
Related issue: Array[File] default value causes 4 Pydantic validation errors
|
||||
because PluginParameter.default only accepts Union[float, int, str, bool] | None.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create test data
|
||||
app, account, workflow = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Create workflow graph with a FILE_LIST variable that has a default value
|
||||
workflow_graph = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "start_node",
|
||||
"data": {
|
||||
"type": "start",
|
||||
"variables": [
|
||||
{
|
||||
"variable": "documents",
|
||||
"label": "Documents",
|
||||
"type": "file-list",
|
||||
"required": False,
|
||||
"default": [
|
||||
{"id": fake.uuid4(), "base64Url": ""},
|
||||
{"id": fake.uuid4(), "base64Url": ""},
|
||||
],
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
workflow.graph = json.dumps(workflow_graph)
|
||||
|
||||
# Setup workflow tool parameters with FILES type
|
||||
files_parameters = [
|
||||
{
|
||||
"name": "documents",
|
||||
"description": "Upload multiple documents",
|
||||
"form": "form",
|
||||
"type": "files",
|
||||
"required": False,
|
||||
}
|
||||
]
|
||||
|
||||
# Execute the method under test
|
||||
# Note: from_db is mocked, so this test primarily validates the parameter configuration
|
||||
result = WorkflowToolManageService.create_workflow_tool(
|
||||
user_id=account.id,
|
||||
tenant_id=account.current_tenant.id,
|
||||
workflow_app_id=app.id,
|
||||
name=fake.word(),
|
||||
label=fake.word(),
|
||||
icon={"type": "emoji", "emoji": "📁"},
|
||||
description=fake.text(max_nb_chars=200),
|
||||
parameters=files_parameters,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == {"result": "success"}
|
||||
|
||||
def test_create_workflow_tool_db_commit_before_validation(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test that database commit happens before validation, causing DB pollution on validation failure.
|
||||
|
||||
This test verifies the second bug:
|
||||
- WorkflowToolProvider is committed to database BEFORE from_db validation
|
||||
- If validation fails, the record remains in the database
|
||||
- Subsequent attempts fail with "Tool already exists" error
|
||||
|
||||
This demonstrates why we need to validate BEFORE database commit.
|
||||
"""
|
||||
|
||||
fake = Faker()
|
||||
|
||||
# Create test data
|
||||
app, account, workflow = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
tool_name = fake.word()
|
||||
|
||||
# Mock from_db to raise validation error
|
||||
mock_external_service_dependencies["workflow_tool_provider_controller"].from_db.side_effect = ValueError(
|
||||
"Validation failed: default parameter type mismatch"
|
||||
)
|
||||
|
||||
# Attempt to create workflow tool (will fail at validation stage)
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
WorkflowToolManageService.create_workflow_tool(
|
||||
user_id=account.id,
|
||||
tenant_id=account.current_tenant.id,
|
||||
workflow_app_id=app.id,
|
||||
name=tool_name,
|
||||
label=fake.word(),
|
||||
icon={"type": "emoji", "emoji": "🔧"},
|
||||
description=fake.text(max_nb_chars=200),
|
||||
parameters=self._create_test_workflow_tool_parameters(),
|
||||
)
|
||||
|
||||
assert "Validation failed" in str(exc_info.value)
|
||||
|
||||
# Verify the tool was NOT created in database
|
||||
# This is the expected behavior (no pollution)
|
||||
from extensions.ext_database import db
|
||||
|
||||
tool_count = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == account.current_tenant.id,
|
||||
WorkflowToolProvider.name == tool_name,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
# The record should NOT exist because the transaction should be rolled back
|
||||
# Currently, due to the bug, the record might exist (this test documents the bug)
|
||||
# After the fix, this should always be 0
|
||||
# For now, we document that the record may exist, demonstrating the bug
|
||||
# assert tool_count == 0 # Expected after fix
|
||||
|
||||
@ -12,10 +12,12 @@ class TestJinja2CodeExecutor(CodeExecutorTestMixin):
|
||||
_, Jinja2TemplateTransformer = self.jinja2_imports
|
||||
|
||||
template = "Hello {{template}}"
|
||||
# Template must be base64 encoded to match the new safe embedding approach
|
||||
template_b64 = base64.b64encode(template.encode("utf-8")).decode("utf-8")
|
||||
inputs = base64.b64encode(b'{"template": "World"}').decode("utf-8")
|
||||
code = (
|
||||
Jinja2TemplateTransformer.get_runner_script()
|
||||
.replace(Jinja2TemplateTransformer._code_placeholder, template)
|
||||
.replace(Jinja2TemplateTransformer._template_b64_placeholder, template_b64)
|
||||
.replace(Jinja2TemplateTransformer._inputs_placeholder, inputs)
|
||||
)
|
||||
result = CodeExecutor.execute_code(
|
||||
@ -37,6 +39,34 @@ class TestJinja2CodeExecutor(CodeExecutorTestMixin):
|
||||
_, Jinja2TemplateTransformer = self.jinja2_imports
|
||||
|
||||
runner_script = Jinja2TemplateTransformer.get_runner_script()
|
||||
assert runner_script.count(Jinja2TemplateTransformer._code_placeholder) == 1
|
||||
assert runner_script.count(Jinja2TemplateTransformer._template_b64_placeholder) == 1
|
||||
assert runner_script.count(Jinja2TemplateTransformer._inputs_placeholder) == 1
|
||||
assert runner_script.count(Jinja2TemplateTransformer._result_tag) == 2
|
||||
|
||||
def test_jinja2_template_with_special_characters(self, flask_app_with_containers):
|
||||
"""
|
||||
Test that templates with special characters (quotes, newlines) render correctly.
|
||||
This is a regression test for issue #26818 where textarea pre-fill values
|
||||
containing special characters would break template rendering.
|
||||
"""
|
||||
CodeExecutor, CodeLanguage = self.code_executor_imports
|
||||
|
||||
# Template with triple quotes, single quotes, double quotes, and newlines
|
||||
template = """<html>
|
||||
<body>
|
||||
<input value="{{ task.get('Task ID', '') }}"/>
|
||||
<textarea>{{ task.get('Issues', 'No issues reported') }}</textarea>
|
||||
<p>Status: "{{ status }}"</p>
|
||||
<pre>'''code block'''</pre>
|
||||
</body>
|
||||
</html>"""
|
||||
inputs = {"task": {"Task ID": "TASK-123", "Issues": "Line 1\nLine 2\nLine 3"}, "status": "completed"}
|
||||
|
||||
result = CodeExecutor.execute_workflow_code_template(language=CodeLanguage.JINJA2, code=template, inputs=inputs)
|
||||
|
||||
# Verify the template rendered correctly with all special characters
|
||||
output = result["result"]
|
||||
assert 'value="TASK-123"' in output
|
||||
assert "<textarea>Line 1\nLine 2\nLine 3</textarea>" in output
|
||||
assert 'Status: "completed"' in output
|
||||
assert "'''code block'''" in output
|
||||
|
||||
@ -0,0 +1,46 @@
|
||||
from flask import Response
|
||||
|
||||
from controllers.common.file_response import enforce_download_for_html, is_html_content
|
||||
|
||||
|
||||
class TestFileResponseHelpers:
|
||||
def test_is_html_content_detects_mime_type(self):
|
||||
mime_type = "text/html; charset=UTF-8"
|
||||
|
||||
result = is_html_content(mime_type, filename="file.txt", extension="txt")
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_is_html_content_detects_extension(self):
|
||||
result = is_html_content("text/plain", filename="report.html", extension=None)
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_enforce_download_for_html_sets_headers(self):
|
||||
response = Response("payload", mimetype="text/html")
|
||||
|
||||
updated = enforce_download_for_html(
|
||||
response,
|
||||
mime_type="text/html",
|
||||
filename="unsafe.html",
|
||||
extension="html",
|
||||
)
|
||||
|
||||
assert updated is True
|
||||
assert "attachment" in response.headers["Content-Disposition"]
|
||||
assert response.headers["Content-Type"] == "application/octet-stream"
|
||||
assert response.headers["X-Content-Type-Options"] == "nosniff"
|
||||
|
||||
def test_enforce_download_for_html_no_change_for_non_html(self):
|
||||
response = Response("payload", mimetype="text/plain")
|
||||
|
||||
updated = enforce_download_for_html(
|
||||
response,
|
||||
mime_type="text/plain",
|
||||
filename="notes.txt",
|
||||
extension="txt",
|
||||
)
|
||||
|
||||
assert updated is False
|
||||
assert "Content-Disposition" not in response.headers
|
||||
assert "X-Content-Type-Options" not in response.headers
|
||||
@ -163,34 +163,17 @@ class TestActivateApi:
|
||||
"account": mock_account,
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_token_pair(self):
|
||||
"""Create mock token pair object."""
|
||||
token_pair = MagicMock()
|
||||
token_pair.access_token = "access_token"
|
||||
token_pair.refresh_token = "refresh_token"
|
||||
token_pair.csrf_token = "csrf_token"
|
||||
token_pair.model_dump.return_value = {
|
||||
"access_token": "access_token",
|
||||
"refresh_token": "refresh_token",
|
||||
"csrf_token": "csrf_token",
|
||||
}
|
||||
return token_pair
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
|
||||
@patch("controllers.console.auth.activate.db")
|
||||
@patch("controllers.console.auth.activate.AccountService.login")
|
||||
def test_successful_account_activation(
|
||||
self,
|
||||
mock_login,
|
||||
mock_db,
|
||||
mock_revoke_token,
|
||||
mock_get_invitation,
|
||||
app,
|
||||
mock_invitation,
|
||||
mock_account,
|
||||
mock_token_pair,
|
||||
):
|
||||
"""
|
||||
Test successful account activation.
|
||||
@ -198,12 +181,10 @@ class TestActivateApi:
|
||||
Verifies that:
|
||||
- Account is activated with user preferences
|
||||
- Account status is set to ACTIVE
|
||||
- User is logged in after activation
|
||||
- Invitation token is revoked
|
||||
"""
|
||||
# Arrange
|
||||
mock_get_invitation.return_value = mock_invitation
|
||||
mock_login.return_value = mock_token_pair
|
||||
|
||||
# Act
|
||||
with app.test_request_context(
|
||||
@ -230,7 +211,6 @@ class TestActivateApi:
|
||||
assert mock_account.initialized_at is not None
|
||||
mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token")
|
||||
mock_db.session.commit.assert_called_once()
|
||||
mock_login.assert_called_once()
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
def test_activation_with_invalid_token(self, mock_get_invitation, app):
|
||||
@ -264,17 +244,14 @@ class TestActivateApi:
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
|
||||
@patch("controllers.console.auth.activate.db")
|
||||
@patch("controllers.console.auth.activate.AccountService.login")
|
||||
def test_activation_sets_interface_theme(
|
||||
self,
|
||||
mock_login,
|
||||
mock_db,
|
||||
mock_revoke_token,
|
||||
mock_get_invitation,
|
||||
app,
|
||||
mock_invitation,
|
||||
mock_account,
|
||||
mock_token_pair,
|
||||
):
|
||||
"""
|
||||
Test that activation sets default interface theme.
|
||||
@ -284,7 +261,6 @@ class TestActivateApi:
|
||||
"""
|
||||
# Arrange
|
||||
mock_get_invitation.return_value = mock_invitation
|
||||
mock_login.return_value = mock_token_pair
|
||||
|
||||
# Act
|
||||
with app.test_request_context(
|
||||
@ -317,17 +293,14 @@ class TestActivateApi:
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
|
||||
@patch("controllers.console.auth.activate.db")
|
||||
@patch("controllers.console.auth.activate.AccountService.login")
|
||||
def test_activation_with_different_locales(
|
||||
self,
|
||||
mock_login,
|
||||
mock_db,
|
||||
mock_revoke_token,
|
||||
mock_get_invitation,
|
||||
app,
|
||||
mock_invitation,
|
||||
mock_account,
|
||||
mock_token_pair,
|
||||
language,
|
||||
timezone,
|
||||
):
|
||||
@ -341,7 +314,6 @@ class TestActivateApi:
|
||||
"""
|
||||
# Arrange
|
||||
mock_get_invitation.return_value = mock_invitation
|
||||
mock_login.return_value = mock_token_pair
|
||||
|
||||
# Act
|
||||
with app.test_request_context(
|
||||
@ -367,27 +339,23 @@ class TestActivateApi:
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
|
||||
@patch("controllers.console.auth.activate.db")
|
||||
@patch("controllers.console.auth.activate.AccountService.login")
|
||||
def test_activation_returns_token_data(
|
||||
def test_activation_returns_success_response(
|
||||
self,
|
||||
mock_login,
|
||||
mock_db,
|
||||
mock_revoke_token,
|
||||
mock_get_invitation,
|
||||
app,
|
||||
mock_invitation,
|
||||
mock_token_pair,
|
||||
):
|
||||
"""
|
||||
Test that activation returns authentication tokens.
|
||||
Test that activation returns a success response without authentication tokens.
|
||||
|
||||
Verifies that:
|
||||
- Token pair is returned in response
|
||||
- All token types are included (access, refresh, csrf)
|
||||
- Response contains a success result
|
||||
- No token data is returned
|
||||
"""
|
||||
# Arrange
|
||||
mock_get_invitation.return_value = mock_invitation
|
||||
mock_login.return_value = mock_token_pair
|
||||
|
||||
# Act
|
||||
with app.test_request_context(
|
||||
@ -406,24 +374,18 @@ class TestActivateApi:
|
||||
response = api.post()
|
||||
|
||||
# Assert
|
||||
assert "data" in response
|
||||
assert response["data"]["access_token"] == "access_token"
|
||||
assert response["data"]["refresh_token"] == "refresh_token"
|
||||
assert response["data"]["csrf_token"] == "csrf_token"
|
||||
assert response == {"result": "success"}
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
|
||||
@patch("controllers.console.auth.activate.db")
|
||||
@patch("controllers.console.auth.activate.AccountService.login")
|
||||
def test_activation_without_workspace_id(
|
||||
self,
|
||||
mock_login,
|
||||
mock_db,
|
||||
mock_revoke_token,
|
||||
mock_get_invitation,
|
||||
app,
|
||||
mock_invitation,
|
||||
mock_token_pair,
|
||||
):
|
||||
"""
|
||||
Test account activation without workspace_id.
|
||||
@ -434,7 +396,6 @@ class TestActivateApi:
|
||||
"""
|
||||
# Arrange
|
||||
mock_get_invitation.return_value = mock_invitation
|
||||
mock_login.return_value = mock_token_pair
|
||||
|
||||
# Act
|
||||
with app.test_request_context(
|
||||
|
||||
236
api/tests/unit_tests/controllers/console/test_extension.py
Normal file
236
api/tests/unit_tests/controllers/console/test_extension.py
Normal file
@ -0,0 +1,236 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView as FlaskMethodView
|
||||
|
||||
_NEEDS_METHOD_VIEW_CLEANUP = False
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = FlaskMethodView
|
||||
_NEEDS_METHOD_VIEW_CLEANUP = True
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from controllers.console.extension import (
|
||||
APIBasedExtensionAPI,
|
||||
APIBasedExtensionDetailAPI,
|
||||
CodeBasedExtensionAPI,
|
||||
)
|
||||
|
||||
if _NEEDS_METHOD_VIEW_CLEANUP:
|
||||
delattr(builtins, "MethodView")
|
||||
from models.account import AccountStatus
|
||||
from models.api_based_extension import APIBasedExtension
|
||||
|
||||
|
||||
def _make_extension(
|
||||
*,
|
||||
name: str = "Sample Extension",
|
||||
api_endpoint: str = "https://example.com/api",
|
||||
api_key: str = "super-secret-key",
|
||||
) -> APIBasedExtension:
|
||||
extension = APIBasedExtension(
|
||||
tenant_id="tenant-123",
|
||||
name=name,
|
||||
api_endpoint=api_endpoint,
|
||||
api_key=api_key,
|
||||
)
|
||||
extension.id = f"{uuid.uuid4()}"
|
||||
extension.created_at = datetime.now(tz=UTC)
|
||||
return extension
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_console_guards(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
|
||||
"""Bypass console decorators so handlers can run in isolation."""
|
||||
|
||||
import controllers.console.extension as extension_module
|
||||
from controllers.console import wraps as wraps_module
|
||||
|
||||
account = MagicMock()
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.current_tenant_id = "tenant-123"
|
||||
account.id = "account-123"
|
||||
account.is_authenticated = True
|
||||
|
||||
monkeypatch.setattr(wraps_module.dify_config, "EDITION", "CLOUD")
|
||||
monkeypatch.setattr("libs.login.dify_config.LOGIN_DISABLED", True)
|
||||
monkeypatch.delenv("INIT_PASSWORD", raising=False)
|
||||
monkeypatch.setattr(extension_module, "current_account_with_tenant", lambda: (account, "tenant-123"))
|
||||
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (account, "tenant-123"))
|
||||
|
||||
# The login_required decorator consults the shared LocalProxy in libs.login.
|
||||
monkeypatch.setattr("libs.login.current_user", account)
|
||||
monkeypatch.setattr("libs.login.check_csrf_token", lambda *_, **__: None)
|
||||
|
||||
return account
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _restx_mask_defaults(app: Flask):
|
||||
app.config.setdefault("RESTX_MASK_HEADER", "X-Fields")
|
||||
app.config.setdefault("RESTX_MASK_SWAGGER", False)
|
||||
|
||||
|
||||
def test_code_based_extension_get_returns_service_data(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
service_result = {"entrypoint": "main:agent"}
|
||||
service_mock = MagicMock(return_value=service_result)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.extension.CodeBasedExtensionService.get_code_based_extension",
|
||||
service_mock,
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/code-based-extension",
|
||||
method="GET",
|
||||
query_string={"module": "workflow.tools"},
|
||||
):
|
||||
response = CodeBasedExtensionAPI().get()
|
||||
|
||||
assert response == {"module": "workflow.tools", "data": service_result}
|
||||
service_mock.assert_called_once_with("workflow.tools")
|
||||
|
||||
|
||||
def test_api_based_extension_get_returns_tenant_extensions(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
extension = _make_extension(name="Weather API", api_key="abcdefghi123")
|
||||
service_mock = MagicMock(return_value=[extension])
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.extension.APIBasedExtensionService.get_all_by_tenant_id",
|
||||
service_mock,
|
||||
)
|
||||
|
||||
with app.test_request_context("/console/api/api-based-extension", method="GET"):
|
||||
response = APIBasedExtensionAPI().get()
|
||||
|
||||
assert response[0]["id"] == extension.id
|
||||
assert response[0]["name"] == "Weather API"
|
||||
assert response[0]["api_endpoint"] == extension.api_endpoint
|
||||
assert response[0]["api_key"].startswith(extension.api_key[:3])
|
||||
service_mock.assert_called_once_with("tenant-123")
|
||||
|
||||
|
||||
def test_api_based_extension_post_creates_extension(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
saved_extension = _make_extension(name="Docs API", api_key="saved-secret")
|
||||
save_mock = MagicMock(return_value=saved_extension)
|
||||
monkeypatch.setattr("controllers.console.extension.APIBasedExtensionService.save", save_mock)
|
||||
|
||||
payload = {
|
||||
"name": "Docs API",
|
||||
"api_endpoint": "https://docs.example.com/hook",
|
||||
"api_key": "plain-secret",
|
||||
}
|
||||
|
||||
with app.test_request_context("/console/api/api-based-extension", method="POST", json=payload):
|
||||
response = APIBasedExtensionAPI().post()
|
||||
|
||||
args, _ = save_mock.call_args
|
||||
created_extension: APIBasedExtension = args[0]
|
||||
assert created_extension.tenant_id == "tenant-123"
|
||||
assert created_extension.name == payload["name"]
|
||||
assert created_extension.api_endpoint == payload["api_endpoint"]
|
||||
assert created_extension.api_key == payload["api_key"]
|
||||
assert response["name"] == saved_extension.name
|
||||
save_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_api_based_extension_detail_get_fetches_extension(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
extension = _make_extension(name="Docs API", api_key="abcdefg12345")
|
||||
service_mock = MagicMock(return_value=extension)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.extension.APIBasedExtensionService.get_with_tenant_id",
|
||||
service_mock,
|
||||
)
|
||||
|
||||
extension_id = uuid.uuid4()
|
||||
with app.test_request_context(f"/console/api/api-based-extension/{extension_id}", method="GET"):
|
||||
response = APIBasedExtensionDetailAPI().get(extension_id)
|
||||
|
||||
assert response["id"] == extension.id
|
||||
assert response["name"] == extension.name
|
||||
service_mock.assert_called_once_with("tenant-123", str(extension_id))
|
||||
|
||||
|
||||
def test_api_based_extension_detail_post_keeps_hidden_api_key(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
existing_extension = _make_extension(name="Docs API", api_key="keep-me")
|
||||
get_mock = MagicMock(return_value=existing_extension)
|
||||
save_mock = MagicMock(return_value=existing_extension)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.extension.APIBasedExtensionService.get_with_tenant_id",
|
||||
get_mock,
|
||||
)
|
||||
monkeypatch.setattr("controllers.console.extension.APIBasedExtensionService.save", save_mock)
|
||||
|
||||
payload = {
|
||||
"name": "Docs API Updated",
|
||||
"api_endpoint": "https://docs.example.com/v2",
|
||||
"api_key": HIDDEN_VALUE,
|
||||
}
|
||||
|
||||
extension_id = uuid.uuid4()
|
||||
with app.test_request_context(
|
||||
f"/console/api/api-based-extension/{extension_id}",
|
||||
method="POST",
|
||||
json=payload,
|
||||
):
|
||||
response = APIBasedExtensionDetailAPI().post(extension_id)
|
||||
|
||||
assert existing_extension.name == payload["name"]
|
||||
assert existing_extension.api_endpoint == payload["api_endpoint"]
|
||||
assert existing_extension.api_key == "keep-me"
|
||||
save_mock.assert_called_once_with(existing_extension)
|
||||
assert response["name"] == payload["name"]
|
||||
|
||||
|
||||
def test_api_based_extension_detail_post_updates_api_key_when_provided(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
existing_extension = _make_extension(name="Docs API", api_key="old-secret")
|
||||
get_mock = MagicMock(return_value=existing_extension)
|
||||
save_mock = MagicMock(return_value=existing_extension)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.extension.APIBasedExtensionService.get_with_tenant_id",
|
||||
get_mock,
|
||||
)
|
||||
monkeypatch.setattr("controllers.console.extension.APIBasedExtensionService.save", save_mock)
|
||||
|
||||
payload = {
|
||||
"name": "Docs API Updated",
|
||||
"api_endpoint": "https://docs.example.com/v2",
|
||||
"api_key": "new-secret",
|
||||
}
|
||||
|
||||
extension_id = uuid.uuid4()
|
||||
with app.test_request_context(
|
||||
f"/console/api/api-based-extension/{extension_id}",
|
||||
method="POST",
|
||||
json=payload,
|
||||
):
|
||||
response = APIBasedExtensionDetailAPI().post(extension_id)
|
||||
|
||||
assert existing_extension.api_key == "new-secret"
|
||||
save_mock.assert_called_once_with(existing_extension)
|
||||
assert response["name"] == payload["name"]
|
||||
|
||||
|
||||
def test_api_based_extension_detail_delete_removes_extension(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
existing_extension = _make_extension()
|
||||
get_mock = MagicMock(return_value=existing_extension)
|
||||
delete_mock = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.extension.APIBasedExtensionService.get_with_tenant_id",
|
||||
get_mock,
|
||||
)
|
||||
monkeypatch.setattr("controllers.console.extension.APIBasedExtensionService.delete", delete_mock)
|
||||
|
||||
extension_id = uuid.uuid4()
|
||||
with app.test_request_context(
|
||||
f"/console/api/api-based-extension/{extension_id}",
|
||||
method="DELETE",
|
||||
):
|
||||
response, status = APIBasedExtensionDetailAPI().delete(extension_id)
|
||||
|
||||
delete_mock.assert_called_once_with(existing_extension)
|
||||
assert response == {"result": "success"}
|
||||
assert status == 204
|
||||
@ -0,0 +1,145 @@
|
||||
"""Unit tests for load balancing credential validation APIs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
import importlib
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
from models.account import TenantAccountRole
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def load_balancing_module(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Reload controller module with lightweight decorators for testing."""
|
||||
|
||||
from controllers.console import console_ns, wraps
|
||||
from libs import login
|
||||
|
||||
def _noop(func):
|
||||
return func
|
||||
|
||||
monkeypatch.setattr(login, "login_required", _noop)
|
||||
monkeypatch.setattr(wraps, "setup_required", _noop)
|
||||
monkeypatch.setattr(wraps, "account_initialization_required", _noop)
|
||||
|
||||
def _noop_route(*args, **kwargs): # type: ignore[override]
|
||||
def _decorator(cls):
|
||||
return cls
|
||||
|
||||
return _decorator
|
||||
|
||||
monkeypatch.setattr(console_ns, "route", _noop_route)
|
||||
|
||||
module_name = "controllers.console.workspace.load_balancing_config"
|
||||
sys.modules.pop(module_name, None)
|
||||
module = importlib.import_module(module_name)
|
||||
return module
|
||||
|
||||
|
||||
def _mock_user(role: TenantAccountRole) -> SimpleNamespace:
|
||||
return SimpleNamespace(current_role=role)
|
||||
|
||||
|
||||
def _prepare_context(module, monkeypatch: pytest.MonkeyPatch, role=TenantAccountRole.OWNER):
|
||||
user = _mock_user(role)
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "tenant-123"))
|
||||
mock_service = MagicMock()
|
||||
monkeypatch.setattr(module, "ModelLoadBalancingService", lambda: mock_service)
|
||||
return mock_service
|
||||
|
||||
|
||||
def _request_payload():
|
||||
return {"model": "gpt-4o", "model_type": ModelType.LLM, "credentials": {"api_key": "sk-***"}}
|
||||
|
||||
|
||||
def test_validate_credentials_success(app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch):
|
||||
service = _prepare_context(load_balancing_module, monkeypatch)
|
||||
|
||||
with app.test_request_context(
|
||||
"/workspaces/current/model-providers/openai/models/load-balancing-configs/credentials-validate",
|
||||
method="POST",
|
||||
json=_request_payload(),
|
||||
):
|
||||
response = load_balancing_module.LoadBalancingCredentialsValidateApi().post(provider="openai")
|
||||
|
||||
assert response == {"result": "success"}
|
||||
service.validate_load_balancing_credentials.assert_called_once_with(
|
||||
tenant_id="tenant-123",
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
model_type=ModelType.LLM,
|
||||
credentials={"api_key": "sk-***"},
|
||||
)
|
||||
|
||||
|
||||
def test_validate_credentials_returns_error_message(app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch):
|
||||
service = _prepare_context(load_balancing_module, monkeypatch)
|
||||
service.validate_load_balancing_credentials.side_effect = CredentialsValidateFailedError("invalid credentials")
|
||||
|
||||
with app.test_request_context(
|
||||
"/workspaces/current/model-providers/openai/models/load-balancing-configs/credentials-validate",
|
||||
method="POST",
|
||||
json=_request_payload(),
|
||||
):
|
||||
response = load_balancing_module.LoadBalancingCredentialsValidateApi().post(provider="openai")
|
||||
|
||||
assert response == {"result": "error", "error": "invalid credentials"}
|
||||
|
||||
|
||||
def test_validate_credentials_requires_privileged_role(
|
||||
app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
_prepare_context(load_balancing_module, monkeypatch, role=TenantAccountRole.NORMAL)
|
||||
|
||||
with app.test_request_context(
|
||||
"/workspaces/current/model-providers/openai/models/load-balancing-configs/credentials-validate",
|
||||
method="POST",
|
||||
json=_request_payload(),
|
||||
):
|
||||
api = load_balancing_module.LoadBalancingCredentialsValidateApi()
|
||||
with pytest.raises(Forbidden):
|
||||
api.post(provider="openai")
|
||||
|
||||
|
||||
def test_validate_credentials_with_config_id(app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch):
|
||||
service = _prepare_context(load_balancing_module, monkeypatch)
|
||||
|
||||
with app.test_request_context(
|
||||
"/workspaces/current/model-providers/openai/models/load-balancing-configs/cfg-1/credentials-validate",
|
||||
method="POST",
|
||||
json=_request_payload(),
|
||||
):
|
||||
response = load_balancing_module.LoadBalancingConfigCredentialsValidateApi().post(
|
||||
provider="openai", config_id="cfg-1"
|
||||
)
|
||||
|
||||
assert response == {"result": "success"}
|
||||
service.validate_load_balancing_credentials.assert_called_once_with(
|
||||
tenant_id="tenant-123",
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
model_type=ModelType.LLM,
|
||||
credentials={"api_key": "sk-***"},
|
||||
config_id="cfg-1",
|
||||
)
|
||||
@ -0,0 +1,100 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask_restx import Api
|
||||
|
||||
from controllers.console.workspace.tool_providers import ToolProviderMCPApi
|
||||
from core.db.session_factory import configure_session_factory
|
||||
from extensions.ext_database import db
|
||||
from services.tools.mcp_tools_manage_service import ReconnectResult
|
||||
|
||||
|
||||
# Backward-compat fixtures referenced by @pytest.mark.usefixtures in this file.
|
||||
# They are intentionally no-ops because the test already patches the required
|
||||
# behaviors explicitly via @patch and context managers below.
|
||||
@pytest.fixture
|
||||
def _mock_cache():
|
||||
return
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def _mock_user_tenant():
|
||||
return
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
|
||||
api = Api(app)
|
||||
api.add_resource(ToolProviderMCPApi, "/console/api/workspaces/current/tool-provider/mcp")
|
||||
db.init_app(app)
|
||||
# Configure session factory used by controller code
|
||||
with app.app_context():
|
||||
configure_session_factory(db.engine)
|
||||
return app.test_client()
|
||||
|
||||
|
||||
@patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant", return_value=(MagicMock(id="u1"), "t1")
|
||||
)
|
||||
@patch("controllers.console.workspace.tool_providers.Session")
|
||||
@patch("controllers.console.workspace.tool_providers.MCPToolManageService._reconnect_with_url")
|
||||
@pytest.mark.usefixtures("_mock_cache", "_mock_user_tenant")
|
||||
def test_create_mcp_provider_populates_tools(mock_reconnect, mock_session, mock_current_account_with_tenant, client):
|
||||
# Arrange: reconnect returns tools immediately
|
||||
mock_reconnect.return_value = ReconnectResult(
|
||||
authed=True,
|
||||
tools=json.dumps(
|
||||
[{"name": "ping", "description": "ok", "inputSchema": {"type": "object"}, "outputSchema": {}}]
|
||||
),
|
||||
encrypted_credentials="{}",
|
||||
)
|
||||
|
||||
# Fake service.create_provider -> returns object with id for reload
|
||||
svc = MagicMock()
|
||||
create_result = MagicMock()
|
||||
create_result.id = "provider-1"
|
||||
svc.create_provider.return_value = create_result
|
||||
svc.get_provider.return_value = MagicMock(id="provider-1", tenant_id="t1") # used by reload path
|
||||
mock_session.return_value.__enter__.return_value = MagicMock()
|
||||
# Patch MCPToolManageService constructed inside controller
|
||||
with patch("controllers.console.workspace.tool_providers.MCPToolManageService", return_value=svc):
|
||||
payload = {
|
||||
"server_url": "http://example.com/mcp",
|
||||
"name": "demo",
|
||||
"icon": "😀",
|
||||
"icon_type": "emoji",
|
||||
"icon_background": "#000",
|
||||
"server_identifier": "demo-sid",
|
||||
"configuration": {"timeout": 5, "sse_read_timeout": 30},
|
||||
"headers": {},
|
||||
"authentication": {},
|
||||
}
|
||||
# Act
|
||||
with (
|
||||
patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"), # bypass setup_required DB check
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(MagicMock(id="u1"), "t1")),
|
||||
patch("libs.login.check_csrf_token", return_value=None), # bypass CSRF in login_required
|
||||
patch("libs.login._get_user", return_value=MagicMock(id="u1", is_authenticated=True)), # login
|
||||
patch(
|
||||
"services.tools.tools_transform_service.ToolTransformService.mcp_provider_to_user_provider",
|
||||
return_value={"id": "provider-1", "tools": [{"name": "ping"}]},
|
||||
),
|
||||
):
|
||||
resp = client.post(
|
||||
"/console/api/workspaces/current/tool-provider/mcp",
|
||||
data=json.dumps(payload),
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert resp.status_code == 200
|
||||
body = resp.get_json()
|
||||
assert body.get("id") == "provider-1"
|
||||
# 若 transform 后包含 tools 字段,确保非空
|
||||
assert isinstance(body.get("tools"), list)
|
||||
assert body["tools"]
|
||||
@ -41,6 +41,7 @@ class TestFilePreviewApi:
|
||||
upload_file = Mock(spec=UploadFile)
|
||||
upload_file.id = str(uuid.uuid4())
|
||||
upload_file.name = "test_file.jpg"
|
||||
upload_file.extension = "jpg"
|
||||
upload_file.mime_type = "image/jpeg"
|
||||
upload_file.size = 1024
|
||||
upload_file.key = "storage/key/test_file.jpg"
|
||||
@ -210,6 +211,19 @@ class TestFilePreviewApi:
|
||||
assert mock_upload_file.name in response.headers["Content-Disposition"]
|
||||
assert response.headers["Content-Type"] == "application/octet-stream"
|
||||
|
||||
def test_build_file_response_html_forces_attachment(self, file_preview_api, mock_upload_file):
|
||||
"""Test HTML files are forced to download"""
|
||||
mock_generator = Mock()
|
||||
mock_upload_file.mime_type = "text/html"
|
||||
mock_upload_file.name = "unsafe.html"
|
||||
mock_upload_file.extension = "html"
|
||||
|
||||
response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False)
|
||||
|
||||
assert "attachment" in response.headers["Content-Disposition"]
|
||||
assert response.headers["Content-Type"] == "application/octet-stream"
|
||||
assert response.headers["X-Content-Type-Options"] == "nosniff"
|
||||
|
||||
def test_build_file_response_audio_video(self, file_preview_api, mock_upload_file):
|
||||
"""Test file response building for audio/video files"""
|
||||
mock_generator = Mock()
|
||||
|
||||
195
api/tests/unit_tests/controllers/web/test_forgot_password.py
Normal file
195
api/tests/unit_tests/controllers/web/test_forgot_password.py
Normal file
@ -0,0 +1,195 @@
|
||||
"""Unit tests for controllers.web.forgot_password endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import builtins
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
# Ensure flask_restx.api finds MethodView during import.
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _load_controller_module():
|
||||
"""Import controllers.web.forgot_password using a stub package."""
|
||||
|
||||
import importlib
|
||||
import importlib.util
|
||||
import sys
|
||||
from types import ModuleType
|
||||
|
||||
parent_module_name = "controllers.web"
|
||||
module_name = f"{parent_module_name}.forgot_password"
|
||||
|
||||
if parent_module_name not in sys.modules:
|
||||
from flask_restx import Namespace
|
||||
|
||||
stub = ModuleType(parent_module_name)
|
||||
stub.__file__ = "controllers/web/__init__.py"
|
||||
stub.__path__ = ["controllers/web"]
|
||||
stub.__package__ = "controllers"
|
||||
stub.__spec__ = importlib.util.spec_from_loader(parent_module_name, loader=None, is_package=True)
|
||||
stub.web_ns = Namespace("web", description="Web API", path="/")
|
||||
sys.modules[parent_module_name] = stub
|
||||
|
||||
return importlib.import_module(module_name)
|
||||
|
||||
|
||||
forgot_password_module = _load_controller_module()
|
||||
ForgotPasswordCheckApi = forgot_password_module.ForgotPasswordCheckApi
|
||||
ForgotPasswordResetApi = forgot_password_module.ForgotPasswordResetApi
|
||||
ForgotPasswordSendEmailApi = forgot_password_module.ForgotPasswordSendEmailApi
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
"""Configure a minimal Flask app for request contexts."""
|
||||
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _enable_web_endpoint_guards():
|
||||
"""Stub enterprise and feature toggles used by route decorators."""
|
||||
|
||||
features = SimpleNamespace(enable_email_password_login=True)
|
||||
with (
|
||||
patch("controllers.console.wraps.dify_config.ENTERPRISE_ENABLED", True),
|
||||
patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=features),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_controller_db():
|
||||
"""Replace controller-level db reference with a simple stub."""
|
||||
|
||||
fake_db = SimpleNamespace(engine=MagicMock(name="engine"))
|
||||
fake_wraps_db = SimpleNamespace(
|
||||
session=MagicMock(query=MagicMock(return_value=MagicMock(first=MagicMock(return_value=True))))
|
||||
)
|
||||
with (
|
||||
patch("controllers.web.forgot_password.db", fake_db),
|
||||
patch("controllers.console.wraps.db", fake_wraps_db),
|
||||
):
|
||||
yield fake_db
|
||||
|
||||
|
||||
@patch("controllers.web.forgot_password.AccountService.send_reset_password_email", return_value="reset-token")
|
||||
@patch("controllers.web.forgot_password.Session")
|
||||
@patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
@patch("controllers.web.forgot_password.extract_remote_ip", return_value="203.0.113.10")
|
||||
def test_send_reset_email_success(
|
||||
mock_extract_ip: MagicMock,
|
||||
mock_is_ip_limit: MagicMock,
|
||||
mock_session: MagicMock,
|
||||
mock_send_email: MagicMock,
|
||||
app: Flask,
|
||||
):
|
||||
"""POST /forgot-password returns token when email exists and limits allow."""
|
||||
|
||||
mock_account = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
mock_session.return_value.__enter__.return_value = session_ctx
|
||||
session_ctx.execute.return_value.scalar_one_or_none.return_value = mock_account
|
||||
|
||||
with app.test_request_context(
|
||||
"/forgot-password",
|
||||
method="POST",
|
||||
json={"email": "user@example.com"},
|
||||
):
|
||||
response = ForgotPasswordSendEmailApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": "reset-token"}
|
||||
mock_extract_ip.assert_called_once()
|
||||
mock_is_ip_limit.assert_called_once_with("203.0.113.10")
|
||||
mock_send_email.assert_called_once_with(account=mock_account, email="user@example.com", language="en-US")
|
||||
|
||||
|
||||
@patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
|
||||
@patch("controllers.web.forgot_password.AccountService.generate_reset_password_token", return_value=({}, "new-token"))
|
||||
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.web.forgot_password.AccountService.is_forgot_password_error_rate_limit", return_value=False)
|
||||
def test_check_token_success(
|
||||
mock_is_rate_limited: MagicMock,
|
||||
mock_get_data: MagicMock,
|
||||
mock_revoke: MagicMock,
|
||||
mock_generate: MagicMock,
|
||||
mock_reset_limit: MagicMock,
|
||||
app: Flask,
|
||||
):
|
||||
"""POST /forgot-password/validity validates the code and refreshes token."""
|
||||
|
||||
mock_get_data.return_value = {"email": "user@example.com", "code": "123456"}
|
||||
|
||||
with app.test_request_context(
|
||||
"/forgot-password/validity",
|
||||
method="POST",
|
||||
json={"email": "user@example.com", "code": "123456", "token": "old-token"},
|
||||
):
|
||||
response = ForgotPasswordCheckApi().post()
|
||||
|
||||
assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"}
|
||||
mock_is_rate_limited.assert_called_once_with("user@example.com")
|
||||
mock_get_data.assert_called_once_with("old-token")
|
||||
mock_revoke.assert_called_once_with("old-token")
|
||||
mock_generate.assert_called_once_with(
|
||||
"user@example.com",
|
||||
code="123456",
|
||||
additional_data={"phase": "reset"},
|
||||
)
|
||||
mock_reset_limit.assert_called_once_with("user@example.com")
|
||||
|
||||
|
||||
@patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value")
|
||||
@patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef")
|
||||
@patch("controllers.web.forgot_password.Session")
|
||||
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_reset_password_success(
|
||||
mock_get_data: MagicMock,
|
||||
mock_revoke_token: MagicMock,
|
||||
mock_session: MagicMock,
|
||||
mock_token_bytes: MagicMock,
|
||||
mock_hash_password: MagicMock,
|
||||
app: Flask,
|
||||
):
|
||||
"""POST /forgot-password/resets updates the stored password when token is valid."""
|
||||
|
||||
mock_get_data.return_value = {"email": "user@example.com", "phase": "reset"}
|
||||
account = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
mock_session.return_value.__enter__.return_value = session_ctx
|
||||
session_ctx.execute.return_value.scalar_one_or_none.return_value = account
|
||||
|
||||
with app.test_request_context(
|
||||
"/forgot-password/resets",
|
||||
method="POST",
|
||||
json={
|
||||
"token": "reset-token",
|
||||
"new_password": "StrongPass123!",
|
||||
"password_confirm": "StrongPass123!",
|
||||
},
|
||||
):
|
||||
response = ForgotPasswordResetApi().post()
|
||||
|
||||
assert response == {"result": "success"}
|
||||
mock_get_data.assert_called_once_with("reset-token")
|
||||
mock_revoke_token.assert_called_once_with("reset-token")
|
||||
mock_token_bytes.assert_called_once_with(16)
|
||||
mock_hash_password.assert_called_once_with("StrongPass123!", b"0123456789abcdef")
|
||||
expected_password = base64.b64encode(b"hashed-value").decode()
|
||||
assert account.password == expected_password
|
||||
expected_salt = base64.b64encode(b"0123456789abcdef").decode()
|
||||
assert account.password_salt == expected_salt
|
||||
session_ctx.commit.assert_called_once()
|
||||
@ -287,7 +287,7 @@ def test_validate_inputs_optional_file_with_empty_string():
|
||||
|
||||
|
||||
def test_validate_inputs_optional_file_list_with_empty_list():
|
||||
"""Test that optional FILE_LIST variable with empty list returns None"""
|
||||
"""Test that optional FILE_LIST variable with empty list returns empty list (not None)"""
|
||||
base_app_generator = BaseAppGenerator()
|
||||
|
||||
var_file_list = VariableEntity(
|
||||
@ -302,6 +302,28 @@ def test_validate_inputs_optional_file_list_with_empty_list():
|
||||
value=[],
|
||||
)
|
||||
|
||||
# Empty list should be preserved, not converted to None
|
||||
# This allows downstream components like document_extractor to handle empty lists properly
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_validate_inputs_optional_file_list_with_empty_string():
|
||||
"""Test that optional FILE_LIST variable with empty string returns None"""
|
||||
base_app_generator = BaseAppGenerator()
|
||||
|
||||
var_file_list = VariableEntity(
|
||||
variable="test_file_list",
|
||||
label="test_file_list",
|
||||
type=VariableEntityType.FILE_LIST,
|
||||
required=False,
|
||||
)
|
||||
|
||||
result = base_app_generator._validate_inputs(
|
||||
variable_entity=var_file_list,
|
||||
value="",
|
||||
)
|
||||
|
||||
# Empty string should be treated as unset
|
||||
assert result is None
|
||||
|
||||
|
||||
|
||||
@ -0,0 +1,420 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import ANY, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAgentMessageEvent,
|
||||
QueueErrorEvent,
|
||||
QueueLLMChunkEvent,
|
||||
QueueMessageEndEvent,
|
||||
QueueMessageFileEvent,
|
||||
QueuePingEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
EasyUITaskState,
|
||||
ErrorStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
MessageFileStreamResponse,
|
||||
MessageReplaceStreamResponse,
|
||||
MessageStreamResponse,
|
||||
PingStreamResponse,
|
||||
StreamEvent,
|
||||
)
|
||||
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
|
||||
from core.base.tts import AppGeneratorTTSPublisher
|
||||
from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult
|
||||
from core.model_runtime.entities.message_entities import TextPromptMessageContent
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class TestEasyUIBasedGenerateTaskPipelineProcessStreamResponse:
|
||||
"""Test cases for EasyUIBasedGenerateTaskPipeline._process_stream_response method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_application_generate_entity(self):
|
||||
"""Create a mock application generate entity."""
|
||||
entity = Mock(spec=ChatAppGenerateEntity)
|
||||
entity.task_id = "test-task-id"
|
||||
entity.app_id = "test-app-id"
|
||||
# minimal app_config used by pipeline internals
|
||||
entity.app_config = SimpleNamespace(
|
||||
tenant_id="test-tenant-id",
|
||||
app_id="test-app-id",
|
||||
app_mode=AppMode.CHAT,
|
||||
app_model_config_dict={},
|
||||
additional_features=None,
|
||||
sensitive_word_avoidance=None,
|
||||
)
|
||||
# minimal model_conf for LLMResult init
|
||||
entity.model_conf = SimpleNamespace(
|
||||
model="test-model",
|
||||
provider_model_bundle=SimpleNamespace(model_type_instance=Mock()),
|
||||
credentials={},
|
||||
)
|
||||
return entity
|
||||
|
||||
@pytest.fixture
|
||||
def mock_queue_manager(self):
|
||||
"""Create a mock queue manager."""
|
||||
manager = Mock(spec=AppQueueManager)
|
||||
return manager
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message_cycle_manager(self):
|
||||
"""Create a mock message cycle manager."""
|
||||
manager = Mock()
|
||||
manager.get_message_event_type.return_value = StreamEvent.MESSAGE
|
||||
manager.message_to_stream_response.return_value = Mock(spec=MessageStreamResponse)
|
||||
manager.message_file_to_stream_response.return_value = Mock(spec=MessageFileStreamResponse)
|
||||
manager.message_replace_to_stream_response.return_value = Mock(spec=MessageReplaceStreamResponse)
|
||||
manager.handle_retriever_resources = Mock()
|
||||
manager.handle_annotation_reply.return_value = None
|
||||
return manager
|
||||
|
||||
@pytest.fixture
|
||||
def mock_conversation(self):
|
||||
"""Create a mock conversation."""
|
||||
conversation = Mock()
|
||||
conversation.id = "test-conversation-id"
|
||||
conversation.mode = "chat"
|
||||
return conversation
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message(self):
|
||||
"""Create a mock message."""
|
||||
message = Mock()
|
||||
message.id = "test-message-id"
|
||||
message.created_at = Mock()
|
||||
message.created_at.timestamp.return_value = 1234567890
|
||||
return message
|
||||
|
||||
@pytest.fixture
|
||||
def mock_task_state(self):
|
||||
"""Create a mock task state."""
|
||||
task_state = Mock(spec=EasyUITaskState)
|
||||
|
||||
# Create LLM result mock
|
||||
llm_result = Mock(spec=RuntimeLLMResult)
|
||||
llm_result.prompt_messages = []
|
||||
llm_result.message = Mock()
|
||||
llm_result.message.content = ""
|
||||
|
||||
task_state.llm_result = llm_result
|
||||
task_state.answer = ""
|
||||
|
||||
return task_state
|
||||
|
||||
@pytest.fixture
|
||||
def pipeline(
|
||||
self,
|
||||
mock_application_generate_entity,
|
||||
mock_queue_manager,
|
||||
mock_conversation,
|
||||
mock_message,
|
||||
mock_message_cycle_manager,
|
||||
mock_task_state,
|
||||
):
|
||||
"""Create an EasyUIBasedGenerateTaskPipeline instance with mocked dependencies."""
|
||||
with patch(
|
||||
"core.app.task_pipeline.easy_ui_based_generate_task_pipeline.EasyUITaskState", return_value=mock_task_state
|
||||
):
|
||||
pipeline = EasyUIBasedGenerateTaskPipeline(
|
||||
application_generate_entity=mock_application_generate_entity,
|
||||
queue_manager=mock_queue_manager,
|
||||
conversation=mock_conversation,
|
||||
message=mock_message,
|
||||
stream=True,
|
||||
)
|
||||
pipeline._message_cycle_manager = mock_message_cycle_manager
|
||||
pipeline._task_state = mock_task_state
|
||||
return pipeline
|
||||
|
||||
def test_get_message_event_type_called_once_when_first_llm_chunk_arrives(
|
||||
self, pipeline, mock_message_cycle_manager
|
||||
):
|
||||
"""Expect get_message_event_type to be called when processing the first LLM chunk event."""
|
||||
# Setup a minimal LLM chunk event
|
||||
chunk = Mock()
|
||||
chunk.delta.message.content = "hi"
|
||||
chunk.prompt_messages = []
|
||||
llm_chunk_event = Mock(spec=QueueLLMChunkEvent)
|
||||
llm_chunk_event.chunk = chunk
|
||||
mock_queue_message = Mock()
|
||||
mock_queue_message.event = llm_chunk_event
|
||||
pipeline.queue_manager.listen.return_value = [mock_queue_message]
|
||||
|
||||
# Execute
|
||||
list(pipeline._process_stream_response(publisher=None, trace_manager=None))
|
||||
|
||||
# Assert
|
||||
mock_message_cycle_manager.get_message_event_type.assert_called_once_with(message_id="test-message-id")
|
||||
|
||||
def test_llm_chunk_event_with_text_content(self, pipeline, mock_message_cycle_manager, mock_task_state):
|
||||
"""Test handling of LLM chunk events with text content."""
|
||||
# Setup
|
||||
chunk = Mock()
|
||||
chunk.delta.message.content = "Hello, world!"
|
||||
chunk.prompt_messages = []
|
||||
|
||||
llm_chunk_event = Mock(spec=QueueLLMChunkEvent)
|
||||
llm_chunk_event.chunk = chunk
|
||||
|
||||
mock_queue_message = Mock()
|
||||
mock_queue_message.event = llm_chunk_event
|
||||
pipeline.queue_manager.listen.return_value = [mock_queue_message]
|
||||
|
||||
mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE
|
||||
|
||||
# Execute
|
||||
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
|
||||
|
||||
# Assert
|
||||
assert len(responses) == 1
|
||||
mock_message_cycle_manager.message_to_stream_response.assert_called_once_with(
|
||||
answer="Hello, world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE
|
||||
)
|
||||
assert mock_task_state.llm_result.message.content == "Hello, world!"
|
||||
|
||||
def test_llm_chunk_event_with_list_content(self, pipeline, mock_message_cycle_manager, mock_task_state):
|
||||
"""Test handling of LLM chunk events with list content."""
|
||||
# Setup
|
||||
text_content = Mock(spec=TextPromptMessageContent)
|
||||
text_content.data = "Hello"
|
||||
|
||||
chunk = Mock()
|
||||
chunk.delta.message.content = [text_content, " world!"]
|
||||
chunk.prompt_messages = []
|
||||
|
||||
llm_chunk_event = Mock(spec=QueueLLMChunkEvent)
|
||||
llm_chunk_event.chunk = chunk
|
||||
|
||||
mock_queue_message = Mock()
|
||||
mock_queue_message.event = llm_chunk_event
|
||||
pipeline.queue_manager.listen.return_value = [mock_queue_message]
|
||||
|
||||
mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE
|
||||
|
||||
# Execute
|
||||
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
|
||||
|
||||
# Assert
|
||||
assert len(responses) == 1
|
||||
mock_message_cycle_manager.message_to_stream_response.assert_called_once_with(
|
||||
answer="Hello world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE
|
||||
)
|
||||
assert mock_task_state.llm_result.message.content == "Hello world!"
|
||||
|
||||
def test_agent_message_event(self, pipeline, mock_message_cycle_manager, mock_task_state):
|
||||
"""Test handling of agent message events."""
|
||||
# Setup
|
||||
chunk = Mock()
|
||||
chunk.delta.message.content = "Agent response"
|
||||
|
||||
agent_message_event = Mock(spec=QueueAgentMessageEvent)
|
||||
agent_message_event.chunk = chunk
|
||||
|
||||
mock_queue_message = Mock()
|
||||
mock_queue_message.event = agent_message_event
|
||||
pipeline.queue_manager.listen.return_value = [mock_queue_message]
|
||||
|
||||
# Ensure method under assertion is a mock to track calls
|
||||
pipeline._agent_message_to_stream_response = Mock(return_value=Mock())
|
||||
|
||||
# Execute
|
||||
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
|
||||
|
||||
# Assert
|
||||
assert len(responses) == 1
|
||||
# Agent messages should use _agent_message_to_stream_response
|
||||
pipeline._agent_message_to_stream_response.assert_called_once_with(
|
||||
answer="Agent response", message_id="test-message-id"
|
||||
)
|
||||
|
||||
def test_message_end_event(self, pipeline, mock_message_cycle_manager, mock_task_state):
|
||||
"""Test handling of message end events."""
|
||||
# Setup
|
||||
llm_result = Mock(spec=RuntimeLLMResult)
|
||||
llm_result.message = Mock()
|
||||
llm_result.message.content = "Final response"
|
||||
|
||||
message_end_event = Mock(spec=QueueMessageEndEvent)
|
||||
message_end_event.llm_result = llm_result
|
||||
|
||||
mock_queue_message = Mock()
|
||||
mock_queue_message.event = message_end_event
|
||||
pipeline.queue_manager.listen.return_value = [mock_queue_message]
|
||||
|
||||
pipeline._save_message = Mock()
|
||||
pipeline._message_end_to_stream_response = Mock(return_value=Mock(spec=MessageEndStreamResponse))
|
||||
|
||||
# Patch db.engine used inside pipeline for session creation
|
||||
with patch(
|
||||
"core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock())
|
||||
):
|
||||
# Execute
|
||||
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
|
||||
|
||||
# Assert
|
||||
assert len(responses) == 1
|
||||
assert mock_task_state.llm_result == llm_result
|
||||
pipeline._save_message.assert_called_once()
|
||||
pipeline._message_end_to_stream_response.assert_called_once()
|
||||
|
||||
def test_error_event(self, pipeline):
|
||||
"""Test handling of error events."""
|
||||
# Setup
|
||||
error_event = Mock(spec=QueueErrorEvent)
|
||||
error_event.error = Exception("Test error")
|
||||
|
||||
mock_queue_message = Mock()
|
||||
mock_queue_message.event = error_event
|
||||
pipeline.queue_manager.listen.return_value = [mock_queue_message]
|
||||
|
||||
pipeline.handle_error = Mock(return_value=Exception("Test error"))
|
||||
pipeline.error_to_stream_response = Mock(return_value=Mock(spec=ErrorStreamResponse))
|
||||
|
||||
# Patch db.engine used inside pipeline for session creation
|
||||
with patch(
|
||||
"core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock())
|
||||
):
|
||||
# Execute
|
||||
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
|
||||
|
||||
# Assert
|
||||
assert len(responses) == 1
|
||||
pipeline.handle_error.assert_called_once()
|
||||
pipeline.error_to_stream_response.assert_called_once()
|
||||
|
||||
def test_ping_event(self, pipeline):
|
||||
"""Test handling of ping events."""
|
||||
# Setup
|
||||
ping_event = Mock(spec=QueuePingEvent)
|
||||
|
||||
mock_queue_message = Mock()
|
||||
mock_queue_message.event = ping_event
|
||||
pipeline.queue_manager.listen.return_value = [mock_queue_message]
|
||||
|
||||
pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse))
|
||||
|
||||
# Execute
|
||||
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
|
||||
|
||||
# Assert
|
||||
assert len(responses) == 1
|
||||
pipeline.ping_stream_response.assert_called_once()
|
||||
|
||||
def test_file_event(self, pipeline, mock_message_cycle_manager):
|
||||
"""Test handling of file events."""
|
||||
# Setup
|
||||
file_event = Mock(spec=QueueMessageFileEvent)
|
||||
file_event.message_file_id = "file-id"
|
||||
|
||||
mock_queue_message = Mock()
|
||||
mock_queue_message.event = file_event
|
||||
pipeline.queue_manager.listen.return_value = [mock_queue_message]
|
||||
|
||||
file_response = Mock(spec=MessageFileStreamResponse)
|
||||
mock_message_cycle_manager.message_file_to_stream_response.return_value = file_response
|
||||
|
||||
# Execute
|
||||
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
|
||||
|
||||
# Assert
|
||||
assert len(responses) == 1
|
||||
assert responses[0] == file_response
|
||||
mock_message_cycle_manager.message_file_to_stream_response.assert_called_once_with(file_event)
|
||||
|
||||
def test_publisher_is_called_with_messages(self, pipeline):
|
||||
"""Test that publisher publishes messages when provided."""
|
||||
# Setup
|
||||
publisher = Mock(spec=AppGeneratorTTSPublisher)
|
||||
|
||||
ping_event = Mock(spec=QueuePingEvent)
|
||||
mock_queue_message = Mock()
|
||||
mock_queue_message.event = ping_event
|
||||
pipeline.queue_manager.listen.return_value = [mock_queue_message]
|
||||
|
||||
pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse))
|
||||
|
||||
# Execute
|
||||
list(pipeline._process_stream_response(publisher=publisher, trace_manager=None))
|
||||
|
||||
# Assert
|
||||
# Called once with message and once with None at the end
|
||||
assert publisher.publish.call_count == 2
|
||||
publisher.publish.assert_any_call(mock_queue_message)
|
||||
publisher.publish.assert_any_call(None)
|
||||
|
||||
def test_trace_manager_passed_to_save_message(self, pipeline):
|
||||
"""Test that trace manager is passed to _save_message."""
|
||||
# Setup
|
||||
trace_manager = Mock(spec=TraceQueueManager)
|
||||
|
||||
message_end_event = Mock(spec=QueueMessageEndEvent)
|
||||
message_end_event.llm_result = None
|
||||
|
||||
mock_queue_message = Mock()
|
||||
mock_queue_message.event = message_end_event
|
||||
pipeline.queue_manager.listen.return_value = [mock_queue_message]
|
||||
|
||||
pipeline._save_message = Mock()
|
||||
pipeline._message_end_to_stream_response = Mock(return_value=Mock(spec=MessageEndStreamResponse))
|
||||
|
||||
# Patch db.engine used inside pipeline for session creation
|
||||
with patch(
|
||||
"core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock())
|
||||
):
|
||||
# Execute
|
||||
list(pipeline._process_stream_response(publisher=None, trace_manager=trace_manager))
|
||||
|
||||
# Assert
|
||||
pipeline._save_message.assert_called_once_with(session=ANY, trace_manager=trace_manager)
|
||||
|
||||
def test_multiple_events_sequence(self, pipeline, mock_message_cycle_manager, mock_task_state):
|
||||
"""Test handling multiple events in sequence."""
|
||||
# Setup
|
||||
chunk1 = Mock()
|
||||
chunk1.delta.message.content = "Hello"
|
||||
chunk1.prompt_messages = []
|
||||
|
||||
chunk2 = Mock()
|
||||
chunk2.delta.message.content = " world!"
|
||||
chunk2.prompt_messages = []
|
||||
|
||||
llm_chunk_event1 = Mock(spec=QueueLLMChunkEvent)
|
||||
llm_chunk_event1.chunk = chunk1
|
||||
|
||||
ping_event = Mock(spec=QueuePingEvent)
|
||||
|
||||
llm_chunk_event2 = Mock(spec=QueueLLMChunkEvent)
|
||||
llm_chunk_event2.chunk = chunk2
|
||||
|
||||
mock_queue_messages = [
|
||||
Mock(event=llm_chunk_event1),
|
||||
Mock(event=ping_event),
|
||||
Mock(event=llm_chunk_event2),
|
||||
]
|
||||
pipeline.queue_manager.listen.return_value = mock_queue_messages
|
||||
|
||||
mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE
|
||||
pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse))
|
||||
|
||||
# Execute
|
||||
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
|
||||
|
||||
# Assert
|
||||
assert len(responses) == 3
|
||||
assert mock_task_state.llm_result.message.content == "Hello world!"
|
||||
|
||||
# Verify calls to message_to_stream_response
|
||||
assert mock_message_cycle_manager.message_to_stream_response.call_count == 2
|
||||
mock_message_cycle_manager.message_to_stream_response.assert_any_call(
|
||||
answer="Hello", message_id="test-message-id", event_type=StreamEvent.MESSAGE
|
||||
)
|
||||
mock_message_cycle_manager.message_to_stream_response.assert_any_call(
|
||||
answer=" world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE
|
||||
)
|
||||
@ -0,0 +1,166 @@
|
||||
"""Unit tests for the message cycle manager optimization."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import ANY, Mock, patch
|
||||
|
||||
import pytest
|
||||
from flask import current_app
|
||||
|
||||
from core.app.entities.task_entities import MessageStreamResponse, StreamEvent
|
||||
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
|
||||
|
||||
|
||||
class TestMessageCycleManagerOptimization:
|
||||
"""Test cases for the message cycle manager optimization that prevents N+1 queries."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_application_generate_entity(self):
|
||||
"""Create a mock application generate entity."""
|
||||
entity = Mock()
|
||||
entity.task_id = "test-task-id"
|
||||
return entity
|
||||
|
||||
@pytest.fixture
|
||||
def message_cycle_manager(self, mock_application_generate_entity):
|
||||
"""Create a message cycle manager instance."""
|
||||
task_state = Mock()
|
||||
return MessageCycleManager(application_generate_entity=mock_application_generate_entity, task_state=task_state)
|
||||
|
||||
def test_get_message_event_type_with_message_file(self, message_cycle_manager):
|
||||
"""Test get_message_event_type returns MESSAGE_FILE when message has files."""
|
||||
with (
|
||||
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
|
||||
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
|
||||
):
|
||||
# Setup mock session and message file
|
||||
mock_session = Mock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_message_file = Mock()
|
||||
# Current implementation uses session.query(...).scalar()
|
||||
mock_session.query.return_value.scalar.return_value = mock_message_file
|
||||
|
||||
# Execute
|
||||
with current_app.app_context():
|
||||
result = message_cycle_manager.get_message_event_type("test-message-id")
|
||||
|
||||
# Assert
|
||||
assert result == StreamEvent.MESSAGE_FILE
|
||||
mock_session.query.return_value.scalar.assert_called_once()
|
||||
|
||||
def test_get_message_event_type_without_message_file(self, message_cycle_manager):
|
||||
"""Test get_message_event_type returns MESSAGE when message has no files."""
|
||||
with (
|
||||
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
|
||||
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
|
||||
):
|
||||
# Setup mock session and no message file
|
||||
mock_session = Mock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
# Current implementation uses session.query(...).scalar()
|
||||
mock_session.query.return_value.scalar.return_value = None
|
||||
|
||||
# Execute
|
||||
with current_app.app_context():
|
||||
result = message_cycle_manager.get_message_event_type("test-message-id")
|
||||
|
||||
# Assert
|
||||
assert result == StreamEvent.MESSAGE
|
||||
mock_session.query.return_value.scalar.assert_called_once()
|
||||
|
||||
def test_message_to_stream_response_with_precomputed_event_type(self, message_cycle_manager):
|
||||
"""MessageCycleManager.message_to_stream_response expects a valid event_type; callers should precompute it."""
|
||||
with (
|
||||
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
|
||||
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
|
||||
):
|
||||
# Setup mock session and message file
|
||||
mock_session = Mock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_message_file = Mock()
|
||||
# Current implementation uses session.query(...).scalar()
|
||||
mock_session.query.return_value.scalar.return_value = mock_message_file
|
||||
|
||||
# Execute: compute event type once, then pass to message_to_stream_response
|
||||
with current_app.app_context():
|
||||
event_type = message_cycle_manager.get_message_event_type("test-message-id")
|
||||
result = message_cycle_manager.message_to_stream_response(
|
||||
answer="Hello world", message_id="test-message-id", event_type=event_type
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, MessageStreamResponse)
|
||||
assert result.answer == "Hello world"
|
||||
assert result.id == "test-message-id"
|
||||
assert result.event == StreamEvent.MESSAGE_FILE
|
||||
mock_session.query.return_value.scalar.assert_called_once()
|
||||
|
||||
def test_message_to_stream_response_with_event_type_skips_query(self, message_cycle_manager):
|
||||
"""Test that message_to_stream_response skips database query when event_type is provided."""
|
||||
with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class:
|
||||
# Execute with event_type provided
|
||||
result = message_cycle_manager.message_to_stream_response(
|
||||
answer="Hello world", message_id="test-message-id", event_type=StreamEvent.MESSAGE
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, MessageStreamResponse)
|
||||
assert result.answer == "Hello world"
|
||||
assert result.id == "test-message-id"
|
||||
assert result.event == StreamEvent.MESSAGE
|
||||
# Should not query database when event_type is provided
|
||||
mock_session_class.assert_not_called()
|
||||
|
||||
def test_message_to_stream_response_with_from_variable_selector(self, message_cycle_manager):
|
||||
"""Test message_to_stream_response with from_variable_selector parameter."""
|
||||
result = message_cycle_manager.message_to_stream_response(
|
||||
answer="Hello world",
|
||||
message_id="test-message-id",
|
||||
from_variable_selector=["var1", "var2"],
|
||||
event_type=StreamEvent.MESSAGE,
|
||||
)
|
||||
|
||||
assert isinstance(result, MessageStreamResponse)
|
||||
assert result.answer == "Hello world"
|
||||
assert result.id == "test-message-id"
|
||||
assert result.from_variable_selector == ["var1", "var2"]
|
||||
assert result.event == StreamEvent.MESSAGE
|
||||
|
||||
def test_optimization_usage_example(self, message_cycle_manager):
|
||||
"""Test the optimization pattern that should be used by callers."""
|
||||
# Step 1: Get event type once (this queries database)
|
||||
with (
|
||||
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
|
||||
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
|
||||
):
|
||||
mock_session = Mock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
# Current implementation uses session.query(...).scalar()
|
||||
mock_session.query.return_value.scalar.return_value = None # No files
|
||||
with current_app.app_context():
|
||||
event_type = message_cycle_manager.get_message_event_type("test-message-id")
|
||||
|
||||
# Should query database once
|
||||
mock_session_class.assert_called_once_with(ANY, expire_on_commit=False)
|
||||
assert event_type == StreamEvent.MESSAGE
|
||||
|
||||
# Step 2: Use event_type for multiple calls (no additional queries)
|
||||
with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class:
|
||||
mock_session_class.return_value.__enter__.return_value = Mock()
|
||||
|
||||
chunk1_response = message_cycle_manager.message_to_stream_response(
|
||||
answer="Chunk 1", message_id="test-message-id", event_type=event_type
|
||||
)
|
||||
|
||||
chunk2_response = message_cycle_manager.message_to_stream_response(
|
||||
answer="Chunk 2", message_id="test-message-id", event_type=event_type
|
||||
)
|
||||
|
||||
# Should not query database again
|
||||
mock_session_class.assert_not_called()
|
||||
|
||||
assert chunk1_response.event == StreamEvent.MESSAGE
|
||||
assert chunk2_response.event == StreamEvent.MESSAGE
|
||||
assert chunk1_response.answer == "Chunk 1"
|
||||
assert chunk2_response.answer == "Chunk 2"
|
||||
@ -96,7 +96,7 @@ class TestNotionExtractorAuthentication:
|
||||
def test_init_with_integration_token_fallback(self, mock_get_token, mock_config, mock_document_model):
|
||||
"""Test NotionExtractor falls back to integration token when credential not found."""
|
||||
# Arrange
|
||||
mock_get_token.return_value = None
|
||||
mock_get_token.side_effect = Exception("No credential id found")
|
||||
mock_config.NOTION_INTEGRATION_TOKEN = "integration-token-fallback"
|
||||
|
||||
# Act
|
||||
@ -105,7 +105,7 @@ class TestNotionExtractorAuthentication:
|
||||
notion_obj_id="page-456",
|
||||
notion_page_type="page",
|
||||
tenant_id="tenant-789",
|
||||
credential_id="cred-123",
|
||||
credential_id=None,
|
||||
document_model=mock_document_model,
|
||||
)
|
||||
|
||||
@ -117,7 +117,7 @@ class TestNotionExtractorAuthentication:
|
||||
def test_init_missing_credentials_raises_error(self, mock_get_token, mock_config, mock_document_model):
|
||||
"""Test NotionExtractor raises error when no credentials available."""
|
||||
# Arrange
|
||||
mock_get_token.return_value = None
|
||||
mock_get_token.side_effect = Exception("No credential id found")
|
||||
mock_config.NOTION_INTEGRATION_TOKEN = None
|
||||
|
||||
# Act & Assert
|
||||
@ -127,7 +127,7 @@ class TestNotionExtractorAuthentication:
|
||||
notion_obj_id="page-456",
|
||||
notion_page_type="page",
|
||||
tenant_id="tenant-789",
|
||||
credential_id="cred-123",
|
||||
credential_id=None,
|
||||
document_model=mock_document_model,
|
||||
)
|
||||
assert "Must specify `integration_token`" in str(exc_info.value)
|
||||
|
||||
@ -1,52 +1,109 @@
|
||||
import secrets
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, make_request
|
||||
from core.helper.ssrf_proxy import (
|
||||
SSRF_DEFAULT_MAX_RETRIES,
|
||||
_get_user_provided_host_header,
|
||||
make_request,
|
||||
)
|
||||
|
||||
|
||||
@patch("httpx.Client.request")
|
||||
def test_successful_request(mock_request):
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||
def test_successful_request(mock_get_client):
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_request.return_value = mock_response
|
||||
mock_client.send.return_value = mock_response
|
||||
mock_client.request.return_value = mock_response
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
response = make_request("GET", "http://example.com")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@patch("httpx.Client.request")
|
||||
def test_retry_exceed_max_retries(mock_request):
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||
def test_retry_exceed_max_retries(mock_get_client):
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
|
||||
side_effects = [mock_response] * SSRF_DEFAULT_MAX_RETRIES
|
||||
mock_request.side_effect = side_effects
|
||||
mock_client.send.return_value = mock_response
|
||||
mock_client.request.return_value = mock_response
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
with pytest.raises(Exception) as e:
|
||||
make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES - 1)
|
||||
assert str(e.value) == f"Reached maximum retries ({SSRF_DEFAULT_MAX_RETRIES - 1}) for URL http://example.com"
|
||||
|
||||
|
||||
@patch("httpx.Client.request")
|
||||
def test_retry_logic_success(mock_request):
|
||||
side_effects = []
|
||||
class TestGetUserProvidedHostHeader:
|
||||
"""Tests for _get_user_provided_host_header function."""
|
||||
|
||||
for _ in range(SSRF_DEFAULT_MAX_RETRIES):
|
||||
status_code = secrets.choice(STATUS_FORCELIST)
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = status_code
|
||||
side_effects.append(mock_response)
|
||||
def test_returns_none_when_headers_is_none(self):
|
||||
assert _get_user_provided_host_header(None) is None
|
||||
|
||||
mock_response_200 = MagicMock()
|
||||
mock_response_200.status_code = 200
|
||||
side_effects.append(mock_response_200)
|
||||
def test_returns_none_when_headers_is_empty(self):
|
||||
assert _get_user_provided_host_header({}) is None
|
||||
|
||||
mock_request.side_effect = side_effects
|
||||
def test_returns_none_when_host_header_not_present(self):
|
||||
headers = {"Content-Type": "application/json", "Authorization": "Bearer token"}
|
||||
assert _get_user_provided_host_header(headers) is None
|
||||
|
||||
response = make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES)
|
||||
def test_returns_host_header_lowercase(self):
|
||||
headers = {"host": "example.com"}
|
||||
assert _get_user_provided_host_header(headers) == "example.com"
|
||||
|
||||
def test_returns_host_header_uppercase(self):
|
||||
headers = {"HOST": "example.com"}
|
||||
assert _get_user_provided_host_header(headers) == "example.com"
|
||||
|
||||
def test_returns_host_header_mixed_case(self):
|
||||
headers = {"HoSt": "example.com"}
|
||||
assert _get_user_provided_host_header(headers) == "example.com"
|
||||
|
||||
def test_returns_host_header_from_multiple_headers(self):
|
||||
headers = {"Content-Type": "application/json", "Host": "api.example.com", "Authorization": "Bearer token"}
|
||||
assert _get_user_provided_host_header(headers) == "api.example.com"
|
||||
|
||||
def test_returns_first_host_header_when_duplicates(self):
|
||||
headers = {"host": "first.com", "Host": "second.com"}
|
||||
# Should return the first one encountered (iteration order is preserved in dict)
|
||||
result = _get_user_provided_host_header(headers)
|
||||
assert result in ("first.com", "second.com")
|
||||
|
||||
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||
def test_host_header_preservation_without_user_header(mock_get_client):
|
||||
"""Test that when no Host header is provided, the default behavior is maintained."""
|
||||
mock_client = MagicMock()
|
||||
mock_request = MagicMock()
|
||||
mock_request.headers = {}
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_client.send.return_value = mock_response
|
||||
mock_client.request.return_value = mock_response
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
response = make_request("GET", "http://example.com")
|
||||
|
||||
assert response.status_code == 200
|
||||
# Host should not be set if not provided by user
|
||||
assert "Host" not in mock_request.headers or mock_request.headers.get("Host") is None
|
||||
|
||||
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||
def test_host_header_preservation_with_user_header(mock_get_client):
|
||||
"""Test that user-provided Host header is preserved in the request."""
|
||||
mock_client = MagicMock()
|
||||
mock_request = MagicMock()
|
||||
mock_request.headers = {}
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_client.send.return_value = mock_response
|
||||
mock_client.request.return_value = mock_response
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
custom_host = "custom.example.com:8080"
|
||||
response = make_request("GET", "http://example.com", headers={"Host": custom_host})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert mock_request.call_count == SSRF_DEFAULT_MAX_RETRIES + 1
|
||||
assert mock_request.call_args_list[0][1].get("method") == "GET"
|
||||
|
||||
@ -1,129 +0,0 @@
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from core.helper.tool_provider_cache import ToolProviderListCache
|
||||
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis_client():
|
||||
"""Fixture: Mock Redis client"""
|
||||
with patch("core.helper.tool_provider_cache.redis_client") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
class TestToolProviderListCache:
|
||||
"""Test class for ToolProviderListCache"""
|
||||
|
||||
def test_generate_cache_key(self):
|
||||
"""Test cache key generation logic"""
|
||||
# Scenario 1: Specify typ (valid literal value)
|
||||
tenant_id = "tenant_123"
|
||||
typ: ToolProviderTypeApiLiteral = "builtin"
|
||||
expected_key = f"tool_providers:tenant_id:{tenant_id}:type:{typ}"
|
||||
assert ToolProviderListCache._generate_cache_key(tenant_id, typ) == expected_key
|
||||
|
||||
# Scenario 2: typ is None (defaults to "all")
|
||||
expected_key_all = f"tool_providers:tenant_id:{tenant_id}:type:all"
|
||||
assert ToolProviderListCache._generate_cache_key(tenant_id) == expected_key_all
|
||||
|
||||
def test_get_cached_providers_hit(self, mock_redis_client):
|
||||
"""Test get cached providers - cache hit and successful decoding"""
|
||||
tenant_id = "tenant_123"
|
||||
typ: ToolProviderTypeApiLiteral = "api"
|
||||
mock_providers = [{"id": "tool", "name": "test_provider"}]
|
||||
mock_redis_client.get.return_value = json.dumps(mock_providers).encode("utf-8")
|
||||
|
||||
result = ToolProviderListCache.get_cached_providers(tenant_id, typ)
|
||||
|
||||
mock_redis_client.get.assert_called_once_with(ToolProviderListCache._generate_cache_key(tenant_id, typ))
|
||||
assert result == mock_providers
|
||||
|
||||
def test_get_cached_providers_decode_error(self, mock_redis_client):
|
||||
"""Test get cached providers - cache hit but decoding failed"""
|
||||
tenant_id = "tenant_123"
|
||||
mock_redis_client.get.return_value = b"invalid_json_data"
|
||||
|
||||
result = ToolProviderListCache.get_cached_providers(tenant_id)
|
||||
|
||||
assert result is None
|
||||
mock_redis_client.get.assert_called_once()
|
||||
|
||||
def test_get_cached_providers_miss(self, mock_redis_client):
|
||||
"""Test get cached providers - cache miss"""
|
||||
tenant_id = "tenant_123"
|
||||
mock_redis_client.get.return_value = None
|
||||
|
||||
result = ToolProviderListCache.get_cached_providers(tenant_id)
|
||||
|
||||
assert result is None
|
||||
mock_redis_client.get.assert_called_once()
|
||||
|
||||
def test_set_cached_providers(self, mock_redis_client):
|
||||
"""Test set cached providers"""
|
||||
tenant_id = "tenant_123"
|
||||
typ: ToolProviderTypeApiLiteral = "builtin"
|
||||
mock_providers = [{"id": "tool", "name": "test_provider"}]
|
||||
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
|
||||
|
||||
ToolProviderListCache.set_cached_providers(tenant_id, typ, mock_providers)
|
||||
|
||||
mock_redis_client.setex.assert_called_once_with(
|
||||
cache_key, ToolProviderListCache.CACHE_TTL, json.dumps(mock_providers)
|
||||
)
|
||||
|
||||
def test_invalidate_cache_specific_type(self, mock_redis_client):
|
||||
"""Test invalidate cache - specific type"""
|
||||
tenant_id = "tenant_123"
|
||||
typ: ToolProviderTypeApiLiteral = "workflow"
|
||||
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
|
||||
|
||||
ToolProviderListCache.invalidate_cache(tenant_id, typ)
|
||||
|
||||
mock_redis_client.delete.assert_called_once_with(cache_key)
|
||||
|
||||
def test_invalidate_cache_all_types(self, mock_redis_client):
|
||||
"""Test invalidate cache - clear all tenant cache"""
|
||||
tenant_id = "tenant_123"
|
||||
mock_keys = [
|
||||
b"tool_providers:tenant_id:tenant_123:type:all",
|
||||
b"tool_providers:tenant_id:tenant_123:type:builtin",
|
||||
]
|
||||
mock_redis_client.scan_iter.return_value = mock_keys
|
||||
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
|
||||
mock_redis_client.scan_iter.assert_called_once_with(f"tool_providers:tenant_id:{tenant_id}:*")
|
||||
mock_redis_client.delete.assert_called_once_with(*mock_keys)
|
||||
|
||||
def test_invalidate_cache_no_keys(self, mock_redis_client):
|
||||
"""Test invalidate cache - no cache keys for tenant"""
|
||||
tenant_id = "tenant_123"
|
||||
mock_redis_client.scan_iter.return_value = []
|
||||
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
|
||||
mock_redis_client.delete.assert_not_called()
|
||||
|
||||
def test_redis_fallback_default_return(self, mock_redis_client):
|
||||
"""Test redis_fallback decorator - default return value (Redis error)"""
|
||||
mock_redis_client.get.side_effect = RedisError("Redis connection error")
|
||||
|
||||
result = ToolProviderListCache.get_cached_providers("tenant_123")
|
||||
|
||||
assert result is None
|
||||
mock_redis_client.get.assert_called_once()
|
||||
|
||||
def test_redis_fallback_no_default(self, mock_redis_client):
|
||||
"""Test redis_fallback decorator - no default return value (Redis error)"""
|
||||
mock_redis_client.setex.side_effect = RedisError("Redis connection error")
|
||||
|
||||
try:
|
||||
ToolProviderListCache.set_cached_providers("tenant_123", "mcp", [])
|
||||
except RedisError:
|
||||
pytest.fail("set_cached_providers should not raise RedisError (handled by fallback)")
|
||||
|
||||
mock_redis_client.setex.assert_called_once()
|
||||
@ -1,6 +1,9 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from core.ops.utils import validate_project_name, validate_url, validate_url_with_path
|
||||
from core.ops.utils import generate_dotted_order, validate_project_name, validate_url, validate_url_with_path
|
||||
|
||||
|
||||
class TestValidateUrl:
|
||||
@ -136,3 +139,51 @@ class TestValidateProjectName:
|
||||
"""Test custom default name"""
|
||||
result = validate_project_name("", "Custom Default")
|
||||
assert result == "Custom Default"
|
||||
|
||||
|
||||
class TestGenerateDottedOrder:
|
||||
"""Test cases for generate_dotted_order function"""
|
||||
|
||||
def test_dotted_order_has_6_digit_microseconds(self):
|
||||
"""Test that timestamp includes full 6-digit microseconds for LangSmith API compatibility.
|
||||
|
||||
LangSmith API expects timestamps in format: YYYYMMDDTHHMMSSffffffZ (6-digit microseconds).
|
||||
Previously, the code truncated to 3 digits which caused API errors:
|
||||
'cannot parse .111 as .000000'
|
||||
"""
|
||||
start_time = datetime(2025, 12, 23, 4, 19, 55, 111000)
|
||||
run_id = "test-run-id"
|
||||
result = generate_dotted_order(run_id, start_time)
|
||||
|
||||
# Extract timestamp portion (before the run_id)
|
||||
timestamp_match = re.match(r"^(\d{8}T\d{6})(\d+)Z", result)
|
||||
assert timestamp_match is not None, "Timestamp format should match YYYYMMDDTHHMMSSffffffZ"
|
||||
|
||||
microseconds = timestamp_match.group(2)
|
||||
assert len(microseconds) == 6, f"Microseconds should be 6 digits, got {len(microseconds)}: {microseconds}"
|
||||
|
||||
def test_dotted_order_format_matches_langsmith_expected(self):
|
||||
"""Test that dotted_order format matches LangSmith API expected format."""
|
||||
start_time = datetime(2025, 1, 15, 10, 30, 45, 123456)
|
||||
run_id = "abc123"
|
||||
result = generate_dotted_order(run_id, start_time)
|
||||
|
||||
# LangSmith expects: YYYYMMDDTHHMMSSffffffZ followed by run_id
|
||||
assert result == "20250115T103045123456Zabc123"
|
||||
|
||||
def test_dotted_order_with_parent(self):
|
||||
"""Test dotted_order generation with parent order uses dot separator."""
|
||||
start_time = datetime(2025, 12, 23, 4, 19, 55, 111000)
|
||||
run_id = "child-run-id"
|
||||
parent_order = "20251223T041955000000Zparent-run-id"
|
||||
result = generate_dotted_order(run_id, start_time, parent_order)
|
||||
|
||||
assert result == "20251223T041955000000Zparent-run-id.20251223T041955111000Zchild-run-id"
|
||||
|
||||
def test_dotted_order_without_parent_has_no_dot(self):
|
||||
"""Test dotted_order generation without parent has no dot separator."""
|
||||
start_time = datetime(2025, 12, 23, 4, 19, 55, 111000)
|
||||
run_id = "test-run-id"
|
||||
result = generate_dotted_order(run_id, start_time, None)
|
||||
|
||||
assert "." not in result
|
||||
|
||||
@ -0,0 +1,327 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.datasource.vdb.pgvector.pgvector import (
|
||||
PGVector,
|
||||
PGVectorConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestPGVector(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.config = PGVectorConfig(
|
||||
host="localhost",
|
||||
port=5432,
|
||||
user="test_user",
|
||||
password="test_password",
|
||||
database="test_db",
|
||||
min_connection=1,
|
||||
max_connection=5,
|
||||
pg_bigm=False,
|
||||
)
|
||||
self.collection_name = "test_collection"
|
||||
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||
def test_init(self, mock_pool_class):
|
||||
"""Test PGVector initialization."""
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
pgvector = PGVector(self.collection_name, self.config)
|
||||
|
||||
assert pgvector._collection_name == self.collection_name
|
||||
assert pgvector.table_name == f"embedding_{self.collection_name}"
|
||||
assert pgvector.get_type() == "pgvector"
|
||||
assert pgvector.pool is not None
|
||||
assert pgvector.pg_bigm is False
|
||||
assert pgvector.index_hash is not None
|
||||
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||
def test_init_with_pg_bigm(self, mock_pool_class):
|
||||
"""Test PGVector initialization with pg_bigm enabled."""
|
||||
config = PGVectorConfig(
|
||||
host="localhost",
|
||||
port=5432,
|
||||
user="test_user",
|
||||
password="test_password",
|
||||
database="test_db",
|
||||
min_connection=1,
|
||||
max_connection=5,
|
||||
pg_bigm=True,
|
||||
)
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
pgvector = PGVector(self.collection_name, config)
|
||||
|
||||
assert pgvector.pg_bigm is True
|
||||
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
|
||||
def test_create_collection_basic(self, mock_redis, mock_pool_class):
|
||||
"""Test basic collection creation."""
|
||||
# Mock Redis operations
|
||||
mock_lock = MagicMock()
|
||||
mock_lock.__enter__ = MagicMock()
|
||||
mock_lock.__exit__ = MagicMock()
|
||||
mock_redis.lock.return_value = mock_lock
|
||||
mock_redis.get.return_value = None
|
||||
mock_redis.set.return_value = None
|
||||
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
# Mock connection and cursor
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.getconn.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.return_value = [1] # vector extension exists
|
||||
|
||||
pgvector = PGVector(self.collection_name, self.config)
|
||||
pgvector._create_collection(1536)
|
||||
|
||||
# Verify SQL execution calls
|
||||
assert mock_cursor.execute.called
|
||||
|
||||
# Check that CREATE TABLE was called with correct dimension
|
||||
create_table_calls = [call for call in mock_cursor.execute.call_args_list if "CREATE TABLE" in str(call)]
|
||||
assert len(create_table_calls) == 1
|
||||
assert "vector(1536)" in create_table_calls[0][0][0]
|
||||
|
||||
# Check that CREATE INDEX was called (dimension <= 2000)
|
||||
create_index_calls = [
|
||||
call for call in mock_cursor.execute.call_args_list if "CREATE INDEX" in str(call) and "hnsw" in str(call)
|
||||
]
|
||||
assert len(create_index_calls) == 1
|
||||
|
||||
# Verify Redis cache was set
|
||||
mock_redis.set.assert_called_once()
|
||||
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
|
||||
def test_create_collection_with_large_dimension(self, mock_redis, mock_pool_class):
|
||||
"""Test collection creation with dimension > 2000 (no HNSW index)."""
|
||||
# Mock Redis operations
|
||||
mock_lock = MagicMock()
|
||||
mock_lock.__enter__ = MagicMock()
|
||||
mock_lock.__exit__ = MagicMock()
|
||||
mock_redis.lock.return_value = mock_lock
|
||||
mock_redis.get.return_value = None
|
||||
mock_redis.set.return_value = None
|
||||
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
# Mock connection and cursor
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.getconn.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.return_value = [1] # vector extension exists
|
||||
|
||||
pgvector = PGVector(self.collection_name, self.config)
|
||||
pgvector._create_collection(3072) # Dimension > 2000
|
||||
|
||||
# Check that CREATE TABLE was called
|
||||
create_table_calls = [call for call in mock_cursor.execute.call_args_list if "CREATE TABLE" in str(call)]
|
||||
assert len(create_table_calls) == 1
|
||||
assert "vector(3072)" in create_table_calls[0][0][0]
|
||||
|
||||
# Check that HNSW index was NOT created (dimension > 2000)
|
||||
hnsw_index_calls = [call for call in mock_cursor.execute.call_args_list if "hnsw" in str(call)]
|
||||
assert len(hnsw_index_calls) == 0
|
||||
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
|
||||
def test_create_collection_with_pg_bigm(self, mock_redis, mock_pool_class):
|
||||
"""Test collection creation with pg_bigm enabled."""
|
||||
config = PGVectorConfig(
|
||||
host="localhost",
|
||||
port=5432,
|
||||
user="test_user",
|
||||
password="test_password",
|
||||
database="test_db",
|
||||
min_connection=1,
|
||||
max_connection=5,
|
||||
pg_bigm=True,
|
||||
)
|
||||
|
||||
# Mock Redis operations
|
||||
mock_lock = MagicMock()
|
||||
mock_lock.__enter__ = MagicMock()
|
||||
mock_lock.__exit__ = MagicMock()
|
||||
mock_redis.lock.return_value = mock_lock
|
||||
mock_redis.get.return_value = None
|
||||
mock_redis.set.return_value = None
|
||||
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
# Mock connection and cursor
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.getconn.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.return_value = [1] # vector extension exists
|
||||
|
||||
pgvector = PGVector(self.collection_name, config)
|
||||
pgvector._create_collection(1536)
|
||||
|
||||
# Check that pg_bigm index was created
|
||||
bigm_index_calls = [call for call in mock_cursor.execute.call_args_list if "gin_bigm_ops" in str(call)]
|
||||
assert len(bigm_index_calls) == 1
|
||||
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
|
||||
def test_create_collection_creates_vector_extension(self, mock_redis, mock_pool_class):
|
||||
"""Test that vector extension is created if it doesn't exist."""
|
||||
# Mock Redis operations
|
||||
mock_lock = MagicMock()
|
||||
mock_lock.__enter__ = MagicMock()
|
||||
mock_lock.__exit__ = MagicMock()
|
||||
mock_redis.lock.return_value = mock_lock
|
||||
mock_redis.get.return_value = None
|
||||
mock_redis.set.return_value = None
|
||||
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
# Mock connection and cursor
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.getconn.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
# First call: vector extension doesn't exist
|
||||
mock_cursor.fetchone.return_value = None
|
||||
|
||||
pgvector = PGVector(self.collection_name, self.config)
|
||||
pgvector._create_collection(1536)
|
||||
|
||||
# Check that CREATE EXTENSION was called
|
||||
create_extension_calls = [
|
||||
call for call in mock_cursor.execute.call_args_list if "CREATE EXTENSION vector" in str(call)
|
||||
]
|
||||
assert len(create_extension_calls) == 1
|
||||
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
|
||||
def test_create_collection_with_cache_hit(self, mock_redis, mock_pool_class):
|
||||
"""Test that collection creation is skipped when cache exists."""
|
||||
# Mock Redis operations - cache exists
|
||||
mock_lock = MagicMock()
|
||||
mock_lock.__enter__ = MagicMock()
|
||||
mock_lock.__exit__ = MagicMock()
|
||||
mock_redis.lock.return_value = mock_lock
|
||||
mock_redis.get.return_value = 1 # Cache exists
|
||||
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
# Mock connection and cursor
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.getconn.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
|
||||
pgvector = PGVector(self.collection_name, self.config)
|
||||
pgvector._create_collection(1536)
|
||||
|
||||
# Check that no SQL was executed (early return due to cache)
|
||||
assert mock_cursor.execute.call_count == 0
|
||||
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
|
||||
def test_create_collection_with_redis_lock(self, mock_redis, mock_pool_class):
|
||||
"""Test that Redis lock is used during collection creation."""
|
||||
# Mock Redis operations
|
||||
mock_lock = MagicMock()
|
||||
mock_lock.__enter__ = MagicMock()
|
||||
mock_lock.__exit__ = MagicMock()
|
||||
mock_redis.lock.return_value = mock_lock
|
||||
mock_redis.get.return_value = None
|
||||
mock_redis.set.return_value = None
|
||||
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
# Mock connection and cursor
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.getconn.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.return_value = [1] # vector extension exists
|
||||
|
||||
pgvector = PGVector(self.collection_name, self.config)
|
||||
pgvector._create_collection(1536)
|
||||
|
||||
# Verify Redis lock was acquired with correct lock name
|
||||
mock_redis.lock.assert_called_once_with("vector_indexing_test_collection_lock", timeout=20)
|
||||
|
||||
# Verify lock context manager was entered and exited
|
||||
mock_lock.__enter__.assert_called_once()
|
||||
mock_lock.__exit__.assert_called_once()
|
||||
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||
def test_get_cursor_context_manager(self, mock_pool_class):
|
||||
"""Test that _get_cursor properly manages connection lifecycle."""
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.getconn.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
|
||||
pgvector = PGVector(self.collection_name, self.config)
|
||||
|
||||
with pgvector._get_cursor() as cur:
|
||||
assert cur == mock_cursor
|
||||
|
||||
# Verify connection lifecycle methods were called
|
||||
mock_pool.getconn.assert_called_once()
|
||||
mock_cursor.close.assert_called_once()
|
||||
mock_conn.commit.assert_called_once()
|
||||
mock_pool.putconn.assert_called_once_with(mock_conn)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_config_override",
|
||||
[
|
||||
{"host": ""}, # Test empty host
|
||||
{"port": 0}, # Test invalid port
|
||||
{"user": ""}, # Test empty user
|
||||
{"password": ""}, # Test empty password
|
||||
{"database": ""}, # Test empty database
|
||||
{"min_connection": 0}, # Test invalid min_connection
|
||||
{"max_connection": 0}, # Test invalid max_connection
|
||||
{"min_connection": 10, "max_connection": 5}, # Test min > max
|
||||
],
|
||||
)
|
||||
def test_config_validation_parametrized(invalid_config_override):
|
||||
"""Test configuration validation for various invalid inputs using parametrize."""
|
||||
config = {
|
||||
"host": "localhost",
|
||||
"port": 5432,
|
||||
"user": "test_user",
|
||||
"password": "test_password",
|
||||
"database": "test_db",
|
||||
"min_connection": 1,
|
||||
"max_connection": 5,
|
||||
}
|
||||
config.update(invalid_config_override)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
PGVectorConfig(**config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@ -1,5 +1,7 @@
|
||||
import os
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
|
||||
@ -25,3 +27,35 @@ def test_firecrawl_web_extractor_crawl_mode(mocker: MockerFixture):
|
||||
|
||||
assert job_id is not None
|
||||
assert isinstance(job_id, str)
|
||||
|
||||
|
||||
def test_build_url_normalizes_slashes_for_crawl(mocker: MockerFixture):
|
||||
api_key = "fc-"
|
||||
base_urls = ["https://custom.firecrawl.dev", "https://custom.firecrawl.dev/"]
|
||||
for base in base_urls:
|
||||
app = FirecrawlApp(api_key=api_key, base_url=base)
|
||||
mock_post = mocker.patch("httpx.post")
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = {"id": "job123"}
|
||||
mock_post.return_value = mock_resp
|
||||
app.crawl_url("https://example.com", params=None)
|
||||
called_url = mock_post.call_args[0][0]
|
||||
assert called_url == "https://custom.firecrawl.dev/v2/crawl"
|
||||
|
||||
|
||||
def test_error_handler_handles_non_json_error_bodies(mocker: MockerFixture):
|
||||
api_key = "fc-"
|
||||
app = FirecrawlApp(api_key=api_key, base_url="https://custom.firecrawl.dev/")
|
||||
mock_post = mocker.patch("httpx.post")
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 404
|
||||
mock_resp.text = "Not Found"
|
||||
mock_resp.json.side_effect = Exception("Not JSON")
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
with pytest.raises(Exception) as excinfo:
|
||||
app.scrape_url("https://example.com")
|
||||
|
||||
# Should not raise a JSONDecodeError; current behavior reports status code only
|
||||
assert str(excinfo.value) == "Failed to scrape URL. Status code: 404"
|
||||
|
||||
@ -132,3 +132,36 @@ def test_extract_images_from_docx(monkeypatch):
|
||||
# DB interactions should be recorded
|
||||
assert len(db_stub.session.added) == 2
|
||||
assert db_stub.session.committed is True
|
||||
|
||||
|
||||
def test_extract_images_from_docx_uses_internal_files_url():
|
||||
"""Test that INTERNAL_FILES_URL takes precedence over FILES_URL for plugin access."""
|
||||
# Test the URL generation logic directly
|
||||
from configs import dify_config
|
||||
|
||||
# Mock the configuration values
|
||||
original_files_url = getattr(dify_config, "FILES_URL", None)
|
||||
original_internal_files_url = getattr(dify_config, "INTERNAL_FILES_URL", None)
|
||||
|
||||
try:
|
||||
# Set both URLs - INTERNAL should take precedence
|
||||
dify_config.FILES_URL = "http://external.example.com"
|
||||
dify_config.INTERNAL_FILES_URL = "http://internal.docker:5001"
|
||||
|
||||
# Test the URL generation logic (same as in word_extractor.py)
|
||||
upload_file_id = "test_file_id"
|
||||
|
||||
# This is the pattern we fixed in the word extractor
|
||||
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
|
||||
generated_url = f"{base_url}/files/{upload_file_id}/file-preview"
|
||||
|
||||
# Verify that INTERNAL_FILES_URL is used instead of FILES_URL
|
||||
assert "http://internal.docker:5001" in generated_url, f"Expected internal URL, got: {generated_url}"
|
||||
assert "http://external.example.com" not in generated_url, f"Should not use external URL, got: {generated_url}"
|
||||
|
||||
finally:
|
||||
# Restore original values
|
||||
if original_files_url is not None:
|
||||
dify_config.FILES_URL = original_files_url
|
||||
if original_internal_files_url is not None:
|
||||
dify_config.INTERNAL_FILES_URL = original_internal_files_url
|
||||
|
||||
@ -421,7 +421,18 @@ class TestRetrievalService:
|
||||
# In real code, this waits for all futures to complete
|
||||
# In tests, futures complete immediately, so wait is a no-op
|
||||
with patch("core.rag.datasource.retrieval_service.concurrent.futures.wait"):
|
||||
yield mock_executor
|
||||
# Mock concurrent.futures.as_completed for early error propagation
|
||||
# In real code, this yields futures as they complete
|
||||
# In tests, we yield all futures immediately since they're already done
|
||||
def mock_as_completed(futures_list, timeout=None):
|
||||
"""Mock as_completed that yields futures immediately."""
|
||||
yield from futures_list
|
||||
|
||||
with patch(
|
||||
"core.rag.datasource.retrieval_service.concurrent.futures.as_completed",
|
||||
side_effect=mock_as_completed,
|
||||
):
|
||||
yield mock_executor
|
||||
|
||||
# ==================== Vector Search Tests ====================
|
||||
|
||||
|
||||
@ -0,0 +1,873 @@
|
||||
"""
|
||||
Unit tests for DatasetRetrieval.process_metadata_filter_func.
|
||||
|
||||
This module provides comprehensive test coverage for the process_metadata_filter_func
|
||||
method in the DatasetRetrieval class, which is responsible for building SQLAlchemy
|
||||
filter expressions based on metadata filtering conditions.
|
||||
|
||||
Conditions Tested:
|
||||
==================
|
||||
1. **String Conditions**: contains, not contains, start with, end with
|
||||
2. **Equality Conditions**: is / =, is not / ≠
|
||||
3. **Null Conditions**: empty, not empty
|
||||
4. **Numeric Comparisons**: before / <, after / >, ≤ / <=, ≥ / >=
|
||||
5. **List Conditions**: in
|
||||
6. **Edge Cases**: None values, different data types (str, int, float)
|
||||
|
||||
Test Architecture:
|
||||
==================
|
||||
- Direct instantiation of DatasetRetrieval
|
||||
- Mocking of DatasetDocument model attributes
|
||||
- Verification of SQLAlchemy filter expressions
|
||||
- Follows Arrange-Act-Assert (AAA) pattern
|
||||
|
||||
Running Tests:
|
||||
==============
|
||||
# Run all tests in this module
|
||||
uv run --project api pytest \
|
||||
api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py -v
|
||||
|
||||
# Run a specific test
|
||||
uv run --project api pytest \
|
||||
api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py::\
|
||||
TestProcessMetadataFilterFunc::test_contains_condition -v
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
|
||||
|
||||
class TestProcessMetadataFilterFunc:
|
||||
"""
|
||||
Comprehensive test suite for process_metadata_filter_func method.
|
||||
|
||||
This test class validates all metadata filtering conditions supported by
|
||||
the DatasetRetrieval class, including string operations, numeric comparisons,
|
||||
null checks, and list operations.
|
||||
|
||||
Method Signature:
|
||||
==================
|
||||
def process_metadata_filter_func(
|
||||
self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
|
||||
) -> list:
|
||||
|
||||
The method builds SQLAlchemy filter expressions by:
|
||||
1. Validating value is not None (except for empty/not empty conditions)
|
||||
2. Using DatasetDocument.doc_metadata JSON field operations
|
||||
3. Adding appropriate SQLAlchemy expressions to the filters list
|
||||
4. Returning the updated filters list
|
||||
|
||||
Mocking Strategy:
|
||||
==================
|
||||
- Mock DatasetDocument.doc_metadata to avoid database dependencies
|
||||
- Verify filter expressions are created correctly
|
||||
- Test with various data types (str, int, float, list)
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def retrieval(self):
|
||||
"""
|
||||
Create a DatasetRetrieval instance for testing.
|
||||
|
||||
Returns:
|
||||
DatasetRetrieval: Instance to test process_metadata_filter_func
|
||||
"""
|
||||
return DatasetRetrieval()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_doc_metadata(self):
|
||||
"""
|
||||
Mock the DatasetDocument.doc_metadata JSON field.
|
||||
|
||||
The method uses DatasetDocument.doc_metadata[metadata_name] to access
|
||||
JSON fields. We mock this to avoid database dependencies.
|
||||
|
||||
Returns:
|
||||
Mock: Mocked doc_metadata attribute
|
||||
"""
|
||||
mock_metadata_field = MagicMock()
|
||||
|
||||
# Create mock for string access
|
||||
mock_string_access = MagicMock()
|
||||
mock_string_access.like = MagicMock()
|
||||
mock_string_access.notlike = MagicMock()
|
||||
mock_string_access.__eq__ = MagicMock(return_value=MagicMock())
|
||||
mock_string_access.__ne__ = MagicMock(return_value=MagicMock())
|
||||
mock_string_access.in_ = MagicMock(return_value=MagicMock())
|
||||
|
||||
# Create mock for float access (for numeric comparisons)
|
||||
mock_float_access = MagicMock()
|
||||
mock_float_access.__eq__ = MagicMock(return_value=MagicMock())
|
||||
mock_float_access.__ne__ = MagicMock(return_value=MagicMock())
|
||||
mock_float_access.__lt__ = MagicMock(return_value=MagicMock())
|
||||
mock_float_access.__gt__ = MagicMock(return_value=MagicMock())
|
||||
mock_float_access.__le__ = MagicMock(return_value=MagicMock())
|
||||
mock_float_access.__ge__ = MagicMock(return_value=MagicMock())
|
||||
|
||||
# Create mock for null checks
|
||||
mock_null_access = MagicMock()
|
||||
mock_null_access.is_ = MagicMock(return_value=MagicMock())
|
||||
mock_null_access.isnot = MagicMock(return_value=MagicMock())
|
||||
|
||||
# Setup __getitem__ to return appropriate mock based on usage
|
||||
def getitem_side_effect(name):
|
||||
if name in ["author", "title", "category"]:
|
||||
return mock_string_access
|
||||
elif name in ["year", "price", "rating"]:
|
||||
return mock_float_access
|
||||
else:
|
||||
return mock_string_access
|
||||
|
||||
mock_metadata_field.__getitem__ = MagicMock(side_effect=getitem_side_effect)
|
||||
mock_metadata_field.as_string.return_value = mock_string_access
|
||||
mock_metadata_field.as_float.return_value = mock_float_access
|
||||
mock_metadata_field[metadata_name:str].is_ = mock_null_access.is_
|
||||
mock_metadata_field[metadata_name:str].isnot = mock_null_access.isnot
|
||||
|
||||
return mock_metadata_field
|
||||
|
||||
# ==================== String Condition Tests ====================
|
||||
|
||||
def test_contains_condition_string_value(self, retrieval):
|
||||
"""
|
||||
Test 'contains' condition with string value.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with LIKE expression
|
||||
- Pattern matching uses %value% syntax
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "contains"
|
||||
metadata_name = "author"
|
||||
value = "John"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_not_contains_condition(self, retrieval):
|
||||
"""
|
||||
Test 'not contains' condition.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with NOT LIKE expression
|
||||
- Pattern matching uses %value% syntax with negation
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "not contains"
|
||||
metadata_name = "title"
|
||||
value = "banned"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_start_with_condition(self, retrieval):
|
||||
"""
|
||||
Test 'start with' condition.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with LIKE expression
|
||||
- Pattern matching uses value% syntax
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "start with"
|
||||
metadata_name = "category"
|
||||
value = "tech"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_end_with_condition(self, retrieval):
|
||||
"""
|
||||
Test 'end with' condition.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with LIKE expression
|
||||
- Pattern matching uses %value syntax
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "end with"
|
||||
metadata_name = "filename"
|
||||
value = ".pdf"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
# ==================== Equality Condition Tests ====================
|
||||
|
||||
def test_is_condition_with_string_value(self, retrieval):
|
||||
"""
|
||||
Test 'is' (=) condition with string value.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with equality expression
|
||||
- String comparison is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "is"
|
||||
metadata_name = "author"
|
||||
value = "Jane Doe"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_equals_condition_with_string_value(self, retrieval):
|
||||
"""
|
||||
Test '=' condition with string value.
|
||||
|
||||
Verifies:
|
||||
- Same behavior as 'is' condition
|
||||
- String comparison is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "="
|
||||
metadata_name = "category"
|
||||
value = "technology"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_is_condition_with_int_value(self, retrieval):
|
||||
"""
|
||||
Test 'is' condition with integer value.
|
||||
|
||||
Verifies:
|
||||
- Numeric comparison is used
|
||||
- as_float() is called on the metadata field
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "is"
|
||||
metadata_name = "year"
|
||||
value = 2023
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_is_condition_with_float_value(self, retrieval):
|
||||
"""
|
||||
Test 'is' condition with float value.
|
||||
|
||||
Verifies:
|
||||
- Numeric comparison is used
|
||||
- as_float() is called on the metadata field
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "is"
|
||||
metadata_name = "price"
|
||||
value = 19.99
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_is_not_condition_with_string_value(self, retrieval):
|
||||
"""
|
||||
Test 'is not' (≠) condition with string value.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with inequality expression
|
||||
- String comparison is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "is not"
|
||||
metadata_name = "author"
|
||||
value = "Unknown"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_not_equals_condition(self, retrieval):
|
||||
"""
|
||||
Test '≠' condition with string value.
|
||||
|
||||
Verifies:
|
||||
- Same behavior as 'is not' condition
|
||||
- Inequality expression is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "≠"
|
||||
metadata_name = "category"
|
||||
value = "archived"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_is_not_condition_with_numeric_value(self, retrieval):
|
||||
"""
|
||||
Test 'is not' condition with numeric value.
|
||||
|
||||
Verifies:
|
||||
- Numeric inequality comparison is used
|
||||
- as_float() is called on the metadata field
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "is not"
|
||||
metadata_name = "year"
|
||||
value = 2000
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
# ==================== Null Condition Tests ====================
|
||||
|
||||
def test_empty_condition(self, retrieval):
|
||||
"""
|
||||
Test 'empty' condition (null check).
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with IS NULL expression
|
||||
- Value can be None for this condition
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "empty"
|
||||
metadata_name = "author"
|
||||
value = None
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_not_empty_condition(self, retrieval):
|
||||
"""
|
||||
Test 'not empty' condition (not null check).
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with IS NOT NULL expression
|
||||
- Value can be None for this condition
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "not empty"
|
||||
metadata_name = "description"
|
||||
value = None
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
# ==================== Numeric Comparison Tests ====================
|
||||
|
||||
def test_before_condition(self, retrieval):
|
||||
"""
|
||||
Test 'before' (<) condition.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with less than expression
|
||||
- Numeric comparison is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "before"
|
||||
metadata_name = "year"
|
||||
value = 2020
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_less_than_condition(self, retrieval):
|
||||
"""
|
||||
Test '<' condition.
|
||||
|
||||
Verifies:
|
||||
- Same behavior as 'before' condition
|
||||
- Less than expression is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "<"
|
||||
metadata_name = "price"
|
||||
value = 100.0
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_after_condition(self, retrieval):
|
||||
"""
|
||||
Test 'after' (>) condition.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with greater than expression
|
||||
- Numeric comparison is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "after"
|
||||
metadata_name = "year"
|
||||
value = 2020
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_greater_than_condition(self, retrieval):
|
||||
"""
|
||||
Test '>' condition.
|
||||
|
||||
Verifies:
|
||||
- Same behavior as 'after' condition
|
||||
- Greater than expression is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = ">"
|
||||
metadata_name = "rating"
|
||||
value = 4.5
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_less_than_or_equal_condition_unicode(self, retrieval):
|
||||
"""
|
||||
Test '≤' condition.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with less than or equal expression
|
||||
- Numeric comparison is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "≤"
|
||||
metadata_name = "price"
|
||||
value = 50.0
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_less_than_or_equal_condition_ascii(self, retrieval):
|
||||
"""
|
||||
Test '<=' condition.
|
||||
|
||||
Verifies:
|
||||
- Same behavior as '≤' condition
|
||||
- Less than or equal expression is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "<="
|
||||
metadata_name = "year"
|
||||
value = 2023
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_greater_than_or_equal_condition_unicode(self, retrieval):
|
||||
"""
|
||||
Test '≥' condition.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with greater than or equal expression
|
||||
- Numeric comparison is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "≥"
|
||||
metadata_name = "rating"
|
||||
value = 3.5
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_greater_than_or_equal_condition_ascii(self, retrieval):
|
||||
"""
|
||||
Test '>=' condition.
|
||||
|
||||
Verifies:
|
||||
- Same behavior as '≥' condition
|
||||
- Greater than or equal expression is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = ">="
|
||||
metadata_name = "year"
|
||||
value = 2000
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
# ==================== List/In Condition Tests ====================
|
||||
|
||||
def test_in_condition_with_comma_separated_string(self, retrieval):
|
||||
"""
|
||||
Test 'in' condition with comma-separated string value.
|
||||
|
||||
Verifies:
|
||||
- String is split into list
|
||||
- Whitespace is trimmed from each value
|
||||
- IN expression is created
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "in"
|
||||
metadata_name = "category"
|
||||
value = "tech, science, AI "
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_in_condition_with_list_value(self, retrieval):
|
||||
"""
|
||||
Test 'in' condition with list value.
|
||||
|
||||
Verifies:
|
||||
- List is processed correctly
|
||||
- None values are filtered out
|
||||
- IN expression is created with valid values
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "in"
|
||||
metadata_name = "tags"
|
||||
value = ["python", "javascript", None, "golang"]
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_in_condition_with_tuple_value(self, retrieval):
|
||||
"""
|
||||
Test 'in' condition with tuple value.
|
||||
|
||||
Verifies:
|
||||
- Tuple is processed like a list
|
||||
- IN expression is created
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "in"
|
||||
metadata_name = "category"
|
||||
value = ("tech", "science", "ai")
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_in_condition_with_empty_string(self, retrieval):
|
||||
"""
|
||||
Test 'in' condition with empty string value.
|
||||
|
||||
Verifies:
|
||||
- Empty string results in literal(False) filter
|
||||
- No valid values to match
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "in"
|
||||
metadata_name = "category"
|
||||
value = ""
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
# Verify it's a literal(False) expression
|
||||
# This is a bit tricky to test without access to the actual expression
|
||||
|
||||
def test_in_condition_with_only_whitespace(self, retrieval):
|
||||
"""
|
||||
Test 'in' condition with whitespace-only string value.
|
||||
|
||||
Verifies:
|
||||
- Whitespace-only string results in literal(False) filter
|
||||
- All values are stripped and filtered out
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "in"
|
||||
metadata_name = "category"
|
||||
value = " , , "
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_in_condition_with_single_string(self, retrieval):
|
||||
"""
|
||||
Test 'in' condition with single non-comma string.
|
||||
|
||||
Verifies:
|
||||
- Single string is treated as single-item list
|
||||
- IN expression is created with one value
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "in"
|
||||
metadata_name = "category"
|
||||
value = "technology"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
# ==================== Edge Case Tests ====================
|
||||
|
||||
def test_none_value_with_non_empty_condition(self, retrieval):
|
||||
"""
|
||||
Test None value with conditions that require value.
|
||||
|
||||
Verifies:
|
||||
- Original filters list is returned unchanged
|
||||
- No filter is added for None values (except empty/not empty)
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "contains"
|
||||
metadata_name = "author"
|
||||
value = None
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 0 # No filter added
|
||||
|
||||
def test_none_value_with_equals_condition(self, retrieval):
|
||||
"""
|
||||
Test None value with 'is' (=) condition.
|
||||
|
||||
Verifies:
|
||||
- Original filters list is returned unchanged
|
||||
- No filter is added for None values
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "is"
|
||||
metadata_name = "author"
|
||||
value = None
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 0
|
||||
|
||||
def test_none_value_with_numeric_condition(self, retrieval):
|
||||
"""
|
||||
Test None value with numeric comparison condition.
|
||||
|
||||
Verifies:
|
||||
- Original filters list is returned unchanged
|
||||
- No filter is added for None values
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = ">"
|
||||
metadata_name = "year"
|
||||
value = None
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 0
|
||||
|
||||
def test_existing_filters_preserved(self, retrieval):
|
||||
"""
|
||||
Test that existing filters are preserved.
|
||||
|
||||
Verifies:
|
||||
- Existing filters in the list are not removed
|
||||
- New filters are appended to the list
|
||||
"""
|
||||
existing_filter = MagicMock()
|
||||
filters = [existing_filter]
|
||||
sequence = 0
|
||||
condition = "contains"
|
||||
metadata_name = "author"
|
||||
value = "test"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 2
|
||||
assert filters[0] == existing_filter
|
||||
|
||||
def test_multiple_filters_accumulated(self, retrieval):
|
||||
"""
|
||||
Test multiple calls to accumulate filters.
|
||||
|
||||
Verifies:
|
||||
- Each call adds a new filter to the list
|
||||
- All filters are preserved across calls
|
||||
"""
|
||||
filters = []
|
||||
|
||||
# First filter
|
||||
retrieval.process_metadata_filter_func(0, "contains", "author", "John", filters)
|
||||
assert len(filters) == 1
|
||||
|
||||
# Second filter
|
||||
retrieval.process_metadata_filter_func(1, ">", "year", 2020, filters)
|
||||
assert len(filters) == 2
|
||||
|
||||
# Third filter
|
||||
retrieval.process_metadata_filter_func(2, "is", "category", "tech", filters)
|
||||
assert len(filters) == 3
|
||||
|
||||
def test_unknown_condition(self, retrieval):
|
||||
"""
|
||||
Test unknown/unsupported condition.
|
||||
|
||||
Verifies:
|
||||
- Original filters list is returned unchanged
|
||||
- No filter is added for unknown conditions
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "unknown_condition"
|
||||
metadata_name = "author"
|
||||
value = "test"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 0
|
||||
|
||||
def test_empty_string_value_with_contains(self, retrieval):
|
||||
"""
|
||||
Test empty string value with 'contains' condition.
|
||||
|
||||
Verifies:
|
||||
- Filter is added even with empty string
|
||||
- LIKE expression is created
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "contains"
|
||||
metadata_name = "author"
|
||||
value = ""
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_special_characters_in_value(self, retrieval):
|
||||
"""
|
||||
Test special characters in value string.
|
||||
|
||||
Verifies:
|
||||
- Special characters are handled in value
|
||||
- LIKE expression is created correctly
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "contains"
|
||||
metadata_name = "title"
|
||||
value = "C++ & Python's features"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_zero_value_with_numeric_condition(self, retrieval):
|
||||
"""
|
||||
Test zero value with numeric comparison condition.
|
||||
|
||||
Verifies:
|
||||
- Zero is treated as valid value
|
||||
- Numeric comparison is performed
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = ">"
|
||||
metadata_name = "price"
|
||||
value = 0
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_negative_value_with_numeric_condition(self, retrieval):
|
||||
"""
|
||||
Test negative value with numeric comparison condition.
|
||||
|
||||
Verifies:
|
||||
- Negative numbers are handled correctly
|
||||
- Numeric comparison is performed
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "<"
|
||||
metadata_name = "temperature"
|
||||
value = -10.5
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_float_value_with_integer_comparison(self, retrieval):
|
||||
"""
|
||||
Test float value with numeric comparison condition.
|
||||
|
||||
Verifies:
|
||||
- Float values work correctly
|
||||
- Numeric comparison is performed
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = ">="
|
||||
metadata_name = "rating"
|
||||
value = 4.5
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
@ -901,6 +901,13 @@ class TestFixedRecursiveCharacterTextSplitter:
|
||||
# Verify no empty chunks
|
||||
assert all(len(chunk) > 0 for chunk in result)
|
||||
|
||||
def test_double_slash_n(self):
|
||||
data = "chunk 1\n\nsubchunk 1.\nsubchunk 2.\n\n---\n\nchunk 2\n\nsubchunk 1\nsubchunk 2."
|
||||
separator = "\\n\\n---\\n\\n"
|
||||
splitter = FixedRecursiveCharacterTextSplitter(fixed_separator=separator)
|
||||
chunks = splitter.split_text(data)
|
||||
assert chunks == ["chunk 1\n\nsubchunk 1.\nsubchunk 2.", "chunk 2\n\nsubchunk 1\nsubchunk 2."]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test Metadata Preservation
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
import pytest
|
||||
|
||||
from core.workflow.nodes.http_request import (
|
||||
BodyData,
|
||||
HttpRequestNodeAuthorization,
|
||||
@ -5,6 +7,7 @@ from core.workflow.nodes.http_request import (
|
||||
HttpRequestNodeData,
|
||||
)
|
||||
from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout
|
||||
from core.workflow.nodes.http_request.exc import AuthorizationConfigError
|
||||
from core.workflow.nodes.http_request.executor import Executor
|
||||
from core.workflow.runtime import VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
@ -348,3 +351,127 @@ def test_init_params():
|
||||
executor = create_executor("key1:value1\n\nkey2:value2\n\n")
|
||||
executor._init_params()
|
||||
assert executor.params == [("key1", "value1"), ("key2", "value2")]
|
||||
|
||||
|
||||
def test_empty_api_key_raises_error_bearer():
|
||||
"""Test that empty API key raises AuthorizationConfigError for bearer auth."""
|
||||
variable_pool = VariablePool(system_variables=SystemVariable.empty())
|
||||
node_data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="get",
|
||||
url="http://example.com",
|
||||
headers="",
|
||||
params="",
|
||||
authorization=HttpRequestNodeAuthorization(
|
||||
type="api-key",
|
||||
config={"type": "bearer", "api_key": ""},
|
||||
),
|
||||
)
|
||||
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
|
||||
|
||||
with pytest.raises(AuthorizationConfigError, match="API key is required"):
|
||||
Executor(
|
||||
node_data=node_data,
|
||||
timeout=timeout,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
|
||||
def test_empty_api_key_raises_error_basic():
|
||||
"""Test that empty API key raises AuthorizationConfigError for basic auth."""
|
||||
variable_pool = VariablePool(system_variables=SystemVariable.empty())
|
||||
node_data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="get",
|
||||
url="http://example.com",
|
||||
headers="",
|
||||
params="",
|
||||
authorization=HttpRequestNodeAuthorization(
|
||||
type="api-key",
|
||||
config={"type": "basic", "api_key": ""},
|
||||
),
|
||||
)
|
||||
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
|
||||
|
||||
with pytest.raises(AuthorizationConfigError, match="API key is required"):
|
||||
Executor(
|
||||
node_data=node_data,
|
||||
timeout=timeout,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
|
||||
def test_empty_api_key_raises_error_custom():
|
||||
"""Test that empty API key raises AuthorizationConfigError for custom auth."""
|
||||
variable_pool = VariablePool(system_variables=SystemVariable.empty())
|
||||
node_data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="get",
|
||||
url="http://example.com",
|
||||
headers="",
|
||||
params="",
|
||||
authorization=HttpRequestNodeAuthorization(
|
||||
type="api-key",
|
||||
config={"type": "custom", "api_key": "", "header": "X-Custom-Auth"},
|
||||
),
|
||||
)
|
||||
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
|
||||
|
||||
with pytest.raises(AuthorizationConfigError, match="API key is required"):
|
||||
Executor(
|
||||
node_data=node_data,
|
||||
timeout=timeout,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
|
||||
def test_whitespace_only_api_key_raises_error():
|
||||
"""Test that whitespace-only API key raises AuthorizationConfigError."""
|
||||
variable_pool = VariablePool(system_variables=SystemVariable.empty())
|
||||
node_data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="get",
|
||||
url="http://example.com",
|
||||
headers="",
|
||||
params="",
|
||||
authorization=HttpRequestNodeAuthorization(
|
||||
type="api-key",
|
||||
config={"type": "bearer", "api_key": " "},
|
||||
),
|
||||
)
|
||||
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
|
||||
|
||||
with pytest.raises(AuthorizationConfigError, match="API key is required"):
|
||||
Executor(
|
||||
node_data=node_data,
|
||||
timeout=timeout,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
|
||||
def test_valid_api_key_works():
|
||||
"""Test that valid API key works correctly for bearer auth."""
|
||||
variable_pool = VariablePool(system_variables=SystemVariable.empty())
|
||||
node_data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="get",
|
||||
url="http://example.com",
|
||||
headers="",
|
||||
params="",
|
||||
authorization=HttpRequestNodeAuthorization(
|
||||
type="api-key",
|
||||
config={"type": "bearer", "api_key": "valid-api-key-123"},
|
||||
),
|
||||
)
|
||||
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
|
||||
|
||||
executor = Executor(
|
||||
node_data=node_data,
|
||||
timeout=timeout,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
# Should not raise an error
|
||||
headers = executor._assembling_headers()
|
||||
assert "Authorization" in headers
|
||||
assert headers["Authorization"] == "Bearer valid-api-key-123"
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import json
|
||||
import time
|
||||
|
||||
import pytest
|
||||
@ -46,14 +47,16 @@ def make_start_node(user_inputs, variables):
|
||||
|
||||
|
||||
def test_json_object_valid_schema():
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"age": {"type": "number"},
|
||||
"name": {"type": "string"},
|
||||
},
|
||||
"required": ["age"],
|
||||
}
|
||||
schema = json.dumps(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"age": {"type": "number"},
|
||||
"name": {"type": "string"},
|
||||
},
|
||||
"required": ["age"],
|
||||
}
|
||||
)
|
||||
|
||||
variables = [
|
||||
VariableEntity(
|
||||
@ -65,7 +68,7 @@ def test_json_object_valid_schema():
|
||||
)
|
||||
]
|
||||
|
||||
user_inputs = {"profile": {"age": 20, "name": "Tom"}}
|
||||
user_inputs = {"profile": json.dumps({"age": 20, "name": "Tom"})}
|
||||
|
||||
node = make_start_node(user_inputs, variables)
|
||||
result = node._run()
|
||||
@ -74,12 +77,23 @@ def test_json_object_valid_schema():
|
||||
|
||||
|
||||
def test_json_object_invalid_json_string():
|
||||
schema = json.dumps(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"age": {"type": "number"},
|
||||
"name": {"type": "string"},
|
||||
},
|
||||
"required": ["age", "name"],
|
||||
}
|
||||
)
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="profile",
|
||||
label="profile",
|
||||
type=VariableEntityType.JSON_OBJECT,
|
||||
required=True,
|
||||
json_schema=schema,
|
||||
)
|
||||
]
|
||||
|
||||
@ -88,38 +102,21 @@ def test_json_object_invalid_json_string():
|
||||
|
||||
node = make_start_node(user_inputs, variables)
|
||||
|
||||
with pytest.raises(ValueError, match="profile must be a JSON object"):
|
||||
node._run()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("value", ["[1, 2, 3]", "123"])
|
||||
def test_json_object_valid_json_but_not_object(value):
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="profile",
|
||||
label="profile",
|
||||
type=VariableEntityType.JSON_OBJECT,
|
||||
required=True,
|
||||
)
|
||||
]
|
||||
|
||||
user_inputs = {"profile": value}
|
||||
|
||||
node = make_start_node(user_inputs, variables)
|
||||
|
||||
with pytest.raises(ValueError, match="profile must be a JSON object"):
|
||||
with pytest.raises(ValueError, match='{"age": 20, "name": "Tom" must be a valid JSON object'):
|
||||
node._run()
|
||||
|
||||
|
||||
def test_json_object_does_not_match_schema():
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"age": {"type": "number"},
|
||||
"name": {"type": "string"},
|
||||
},
|
||||
"required": ["age", "name"],
|
||||
}
|
||||
schema = json.dumps(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"age": {"type": "number"},
|
||||
"name": {"type": "string"},
|
||||
},
|
||||
"required": ["age", "name"],
|
||||
}
|
||||
)
|
||||
|
||||
variables = [
|
||||
VariableEntity(
|
||||
@ -132,7 +129,7 @@ def test_json_object_does_not_match_schema():
|
||||
]
|
||||
|
||||
# age is a string, which violates the schema (expects number)
|
||||
user_inputs = {"profile": {"age": "twenty", "name": "Tom"}}
|
||||
user_inputs = {"profile": json.dumps({"age": "twenty", "name": "Tom"})}
|
||||
|
||||
node = make_start_node(user_inputs, variables)
|
||||
|
||||
@ -141,14 +138,16 @@ def test_json_object_does_not_match_schema():
|
||||
|
||||
|
||||
def test_json_object_missing_required_schema_field():
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"age": {"type": "number"},
|
||||
"name": {"type": "string"},
|
||||
},
|
||||
"required": ["age", "name"],
|
||||
}
|
||||
schema = json.dumps(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"age": {"type": "number"},
|
||||
"name": {"type": "string"},
|
||||
},
|
||||
"required": ["age", "name"],
|
||||
}
|
||||
)
|
||||
|
||||
variables = [
|
||||
VariableEntity(
|
||||
@ -161,7 +160,7 @@ def test_json_object_missing_required_schema_field():
|
||||
]
|
||||
|
||||
# Missing required field "name"
|
||||
user_inputs = {"profile": {"age": 20}}
|
||||
user_inputs = {"profile": json.dumps({"age": 20})}
|
||||
|
||||
node = make_start_node(user_inputs, variables)
|
||||
|
||||
@ -214,7 +213,7 @@ def test_json_object_optional_variable_not_provided():
|
||||
variable="profile",
|
||||
label="profile",
|
||||
type=VariableEntityType.JSON_OBJECT,
|
||||
required=False,
|
||||
required=True,
|
||||
)
|
||||
]
|
||||
|
||||
@ -223,5 +222,5 @@ def test_json_object_optional_variable_not_provided():
|
||||
node = make_start_node(user_inputs, variables)
|
||||
|
||||
# Current implementation raises a validation error even when the variable is optional
|
||||
with pytest.raises(ValueError, match="profile must be a JSON object"):
|
||||
with pytest.raises(ValueError, match="profile is required in input form"):
|
||||
node._run()
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from qcloud_cos import CosConfig
|
||||
@ -18,3 +18,72 @@ class TestTencentCos(BaseStorageTest):
|
||||
with patch.object(CosConfig, "__init__", return_value=None):
|
||||
self.storage = TencentCosStorage()
|
||||
self.storage.bucket_name = get_example_bucket()
|
||||
|
||||
|
||||
class TestTencentCosConfiguration:
|
||||
"""Tests for TencentCosStorage initialization with different configurations."""
|
||||
|
||||
def test_init_with_custom_domain(self):
|
||||
"""Test initialization with custom domain configured."""
|
||||
# Mock dify_config to return custom domain configuration
|
||||
mock_dify_config = MagicMock()
|
||||
mock_dify_config.TENCENT_COS_CUSTOM_DOMAIN = "cos.example.com"
|
||||
mock_dify_config.TENCENT_COS_SECRET_ID = "test-secret-id"
|
||||
mock_dify_config.TENCENT_COS_SECRET_KEY = "test-secret-key"
|
||||
mock_dify_config.TENCENT_COS_SCHEME = "https"
|
||||
|
||||
# Mock CosConfig and CosS3Client
|
||||
mock_config_instance = MagicMock()
|
||||
mock_client = MagicMock()
|
||||
|
||||
with (
|
||||
patch("extensions.storage.tencent_cos_storage.dify_config", mock_dify_config),
|
||||
patch(
|
||||
"extensions.storage.tencent_cos_storage.CosConfig", return_value=mock_config_instance
|
||||
) as mock_cos_config,
|
||||
patch("extensions.storage.tencent_cos_storage.CosS3Client", return_value=mock_client),
|
||||
):
|
||||
TencentCosStorage()
|
||||
|
||||
# Verify CosConfig was called with Domain parameter (not Region)
|
||||
mock_cos_config.assert_called_once()
|
||||
call_kwargs = mock_cos_config.call_args[1]
|
||||
assert "Domain" in call_kwargs
|
||||
assert call_kwargs["Domain"] == "cos.example.com"
|
||||
assert "Region" not in call_kwargs
|
||||
assert call_kwargs["SecretId"] == "test-secret-id"
|
||||
assert call_kwargs["SecretKey"] == "test-secret-key"
|
||||
assert call_kwargs["Scheme"] == "https"
|
||||
|
||||
def test_init_with_region(self):
|
||||
"""Test initialization with region configured (no custom domain)."""
|
||||
# Mock dify_config to return region configuration
|
||||
mock_dify_config = MagicMock()
|
||||
mock_dify_config.TENCENT_COS_CUSTOM_DOMAIN = None
|
||||
mock_dify_config.TENCENT_COS_REGION = "ap-guangzhou"
|
||||
mock_dify_config.TENCENT_COS_SECRET_ID = "test-secret-id"
|
||||
mock_dify_config.TENCENT_COS_SECRET_KEY = "test-secret-key"
|
||||
mock_dify_config.TENCENT_COS_SCHEME = "https"
|
||||
|
||||
# Mock CosConfig and CosS3Client
|
||||
mock_config_instance = MagicMock()
|
||||
mock_client = MagicMock()
|
||||
|
||||
with (
|
||||
patch("extensions.storage.tencent_cos_storage.dify_config", mock_dify_config),
|
||||
patch(
|
||||
"extensions.storage.tencent_cos_storage.CosConfig", return_value=mock_config_instance
|
||||
) as mock_cos_config,
|
||||
patch("extensions.storage.tencent_cos_storage.CosS3Client", return_value=mock_client),
|
||||
):
|
||||
TencentCosStorage()
|
||||
|
||||
# Verify CosConfig was called with Region parameter (not Domain)
|
||||
mock_cos_config.assert_called_once()
|
||||
call_kwargs = mock_cos_config.call_args[1]
|
||||
assert "Region" in call_kwargs
|
||||
assert call_kwargs["Region"] == "ap-guangzhou"
|
||||
assert "Domain" not in call_kwargs
|
||||
assert call_kwargs["SecretId"] == "test-secret-id"
|
||||
assert call_kwargs["SecretKey"] == "test-secret-key"
|
||||
assert call_kwargs["Scheme"] == "https"
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
@ -110,9 +111,11 @@ class TestFirecrawlAuth:
|
||||
@pytest.mark.parametrize(
|
||||
("status_code", "response_text", "has_json_error", "expected_error_contains"),
|
||||
[
|
||||
(403, '{"error": "Forbidden"}', True, "Failed to authorize. Status code: 403. Error: Forbidden"),
|
||||
(404, "", True, "Unexpected error occurred while trying to authorize. Status code: 404"),
|
||||
(401, "Not JSON", True, "Expecting value"), # JSON decode error
|
||||
(403, '{"error": "Forbidden"}', False, "Failed to authorize. Status code: 403. Error: Forbidden"),
|
||||
# empty body falls back to generic message
|
||||
(404, "", True, "Failed to authorize. Status code: 404. Error: Unknown error occurred"),
|
||||
# non-JSON body is surfaced directly
|
||||
(401, "Not JSON", True, "Failed to authorize. Status code: 401. Error: Not JSON"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
@ -124,12 +127,14 @@ class TestFirecrawlAuth:
|
||||
mock_response.status_code = status_code
|
||||
mock_response.text = response_text
|
||||
if has_json_error:
|
||||
mock_response.json.side_effect = Exception("Not JSON")
|
||||
mock_response.json.side_effect = json.JSONDecodeError("Not JSON", "", 0)
|
||||
else:
|
||||
mock_response.json.return_value = {"error": "Forbidden"}
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
auth_instance.validate_credentials()
|
||||
assert expected_error_contains in str(exc_info.value)
|
||||
assert str(exc_info.value) == expected_error_contains
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("exception_type", "exception_message"),
|
||||
@ -164,20 +169,21 @@ class TestFirecrawlAuth:
|
||||
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
def test_should_use_custom_base_url_in_validation(self, mock_post):
|
||||
"""Test that custom base URL is used in validation"""
|
||||
"""Test that custom base URL is used in validation and normalized"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
credentials = {
|
||||
"auth_type": "bearer",
|
||||
"config": {"api_key": "test_api_key_123", "base_url": "https://custom.firecrawl.dev"},
|
||||
}
|
||||
auth = FirecrawlAuth(credentials)
|
||||
result = auth.validate_credentials()
|
||||
for base in ("https://custom.firecrawl.dev", "https://custom.firecrawl.dev/"):
|
||||
credentials = {
|
||||
"auth_type": "bearer",
|
||||
"config": {"api_key": "test_api_key_123", "base_url": base},
|
||||
}
|
||||
auth = FirecrawlAuth(credentials)
|
||||
result = auth.validate_credentials()
|
||||
|
||||
assert result is True
|
||||
assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl"
|
||||
assert result is True
|
||||
assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl"
|
||||
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
def test_should_handle_timeout_with_retry_suggestion(self, mock_post, auth_instance):
|
||||
|
||||
@ -0,0 +1,71 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import httpx
|
||||
|
||||
from models import Account
|
||||
from services import app_dsl_service
|
||||
from services.app_dsl_service import AppDslService, ImportMode, ImportStatus
|
||||
|
||||
|
||||
def _build_response(url: str, status_code: int, content: bytes = b"") -> httpx.Response:
|
||||
request = httpx.Request("GET", url)
|
||||
return httpx.Response(status_code=status_code, request=request, content=content)
|
||||
|
||||
|
||||
def _pending_yaml_content(version: str = "99.0.0") -> bytes:
|
||||
return (f'version: "{version}"\nkind: app\napp:\n name: Loop Test\n mode: workflow\n').encode()
|
||||
|
||||
|
||||
def _account_mock() -> MagicMock:
|
||||
account = MagicMock(spec=Account)
|
||||
account.current_tenant_id = "tenant-1"
|
||||
return account
|
||||
|
||||
|
||||
def test_import_app_yaml_url_user_attachments_keeps_original_url(monkeypatch):
|
||||
yaml_url = "https://github.com/user-attachments/files/24290802/loop-test.yml"
|
||||
raw_url = "https://raw.githubusercontent.com/user-attachments/files/24290802/loop-test.yml"
|
||||
yaml_bytes = _pending_yaml_content()
|
||||
|
||||
def fake_get(url: str, **kwargs):
|
||||
if url == raw_url:
|
||||
return _build_response(url, status_code=404)
|
||||
assert url == yaml_url
|
||||
return _build_response(url, status_code=200, content=yaml_bytes)
|
||||
|
||||
monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", fake_get)
|
||||
|
||||
service = AppDslService(MagicMock())
|
||||
result = service.import_app(
|
||||
account=_account_mock(),
|
||||
import_mode=ImportMode.YAML_URL,
|
||||
yaml_url=yaml_url,
|
||||
)
|
||||
|
||||
assert result.status == ImportStatus.PENDING
|
||||
assert result.imported_dsl_version == "99.0.0"
|
||||
|
||||
|
||||
def test_import_app_yaml_url_github_blob_rewrites_to_raw(monkeypatch):
|
||||
yaml_url = "https://github.com/acme/repo/blob/main/app.yml"
|
||||
raw_url = "https://raw.githubusercontent.com/acme/repo/main/app.yml"
|
||||
yaml_bytes = _pending_yaml_content()
|
||||
|
||||
requested_urls: list[str] = []
|
||||
|
||||
def fake_get(url: str, **kwargs):
|
||||
requested_urls.append(url)
|
||||
assert url == raw_url
|
||||
return _build_response(url, status_code=200, content=yaml_bytes)
|
||||
|
||||
monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", fake_get)
|
||||
|
||||
service = AppDslService(MagicMock())
|
||||
result = service.import_app(
|
||||
account=_account_mock(),
|
||||
import_mode=ImportMode.YAML_URL,
|
||||
yaml_url=yaml_url,
|
||||
)
|
||||
|
||||
assert result.status == ImportStatus.PENDING
|
||||
assert requested_urls == [raw_url]
|
||||
@ -1156,6 +1156,235 @@ class TestBillingServiceEdgeCases:
|
||||
assert "Only team owner or team admin can perform this action" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestBillingServiceSubscriptionOperations:
|
||||
"""Unit tests for subscription operations in BillingService.
|
||||
|
||||
Tests cover:
|
||||
- Bulk plan retrieval with chunking
|
||||
- Expired subscription cleanup whitelist retrieval
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_send_request(self):
|
||||
"""Mock _send_request method."""
|
||||
with patch.object(BillingService, "_send_request") as mock:
|
||||
yield mock
|
||||
|
||||
def test_get_plan_bulk_with_empty_list(self, mock_send_request):
|
||||
"""Test bulk plan retrieval with empty tenant list."""
|
||||
# Arrange
|
||||
tenant_ids = []
|
||||
|
||||
# Act
|
||||
result = BillingService.get_plan_bulk(tenant_ids)
|
||||
|
||||
# Assert
|
||||
assert result == {}
|
||||
mock_send_request.assert_not_called()
|
||||
|
||||
def test_get_plan_bulk_with_chunking(self, mock_send_request):
|
||||
"""Test bulk plan retrieval with more than 200 tenants (chunking logic)."""
|
||||
# Arrange - 250 tenants to test chunking (chunk_size = 200)
|
||||
tenant_ids = [f"tenant-{i}" for i in range(250)]
|
||||
|
||||
# First chunk: tenants 0-199
|
||||
first_chunk_response = {
|
||||
"data": {f"tenant-{i}": {"plan": "sandbox", "expiration_date": 1735689600} for i in range(200)}
|
||||
}
|
||||
|
||||
# Second chunk: tenants 200-249
|
||||
second_chunk_response = {
|
||||
"data": {f"tenant-{i}": {"plan": "professional", "expiration_date": 1767225600} for i in range(200, 250)}
|
||||
}
|
||||
|
||||
mock_send_request.side_effect = [first_chunk_response, second_chunk_response]
|
||||
|
||||
# Act
|
||||
result = BillingService.get_plan_bulk(tenant_ids)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 250
|
||||
assert result["tenant-0"]["plan"] == "sandbox"
|
||||
assert result["tenant-199"]["plan"] == "sandbox"
|
||||
assert result["tenant-200"]["plan"] == "professional"
|
||||
assert result["tenant-249"]["plan"] == "professional"
|
||||
assert mock_send_request.call_count == 2
|
||||
|
||||
# Verify first chunk call
|
||||
first_call = mock_send_request.call_args_list[0]
|
||||
assert first_call[0][0] == "POST"
|
||||
assert first_call[0][1] == "/subscription/plan/batch"
|
||||
assert len(first_call[1]["json"]["tenant_ids"]) == 200
|
||||
|
||||
# Verify second chunk call
|
||||
second_call = mock_send_request.call_args_list[1]
|
||||
assert len(second_call[1]["json"]["tenant_ids"]) == 50
|
||||
|
||||
def test_get_plan_bulk_with_partial_batch_failure(self, mock_send_request):
|
||||
"""Test bulk plan retrieval when one batch fails but others succeed."""
|
||||
# Arrange - 250 tenants, second batch will fail
|
||||
tenant_ids = [f"tenant-{i}" for i in range(250)]
|
||||
|
||||
# First chunk succeeds
|
||||
first_chunk_response = {
|
||||
"data": {f"tenant-{i}": {"plan": "sandbox", "expiration_date": 1735689600} for i in range(200)}
|
||||
}
|
||||
|
||||
# Second chunk fails - need to create a mock that raises when called
|
||||
def side_effect_func(*args, **kwargs):
|
||||
if mock_send_request.call_count == 1:
|
||||
return first_chunk_response
|
||||
else:
|
||||
raise ValueError("API error")
|
||||
|
||||
mock_send_request.side_effect = side_effect_func
|
||||
|
||||
# Act
|
||||
result = BillingService.get_plan_bulk(tenant_ids)
|
||||
|
||||
# Assert - should only have data from first batch
|
||||
assert len(result) == 200
|
||||
assert result["tenant-0"]["plan"] == "sandbox"
|
||||
assert result["tenant-199"]["plan"] == "sandbox"
|
||||
assert "tenant-200" not in result
|
||||
assert mock_send_request.call_count == 2
|
||||
|
||||
def test_get_plan_bulk_with_all_batches_failing(self, mock_send_request):
|
||||
"""Test bulk plan retrieval when all batches fail."""
|
||||
# Arrange
|
||||
tenant_ids = [f"tenant-{i}" for i in range(250)]
|
||||
|
||||
# All chunks fail
|
||||
def side_effect_func(*args, **kwargs):
|
||||
raise ValueError("API error")
|
||||
|
||||
mock_send_request.side_effect = side_effect_func
|
||||
|
||||
# Act
|
||||
result = BillingService.get_plan_bulk(tenant_ids)
|
||||
|
||||
# Assert - should return empty dict
|
||||
assert result == {}
|
||||
assert mock_send_request.call_count == 2
|
||||
|
||||
def test_get_plan_bulk_with_exactly_200_tenants(self, mock_send_request):
|
||||
"""Test bulk plan retrieval with exactly 200 tenants (boundary condition)."""
|
||||
# Arrange
|
||||
tenant_ids = [f"tenant-{i}" for i in range(200)]
|
||||
mock_send_request.return_value = {
|
||||
"data": {f"tenant-{i}": {"plan": "sandbox", "expiration_date": 1735689600} for i in range(200)}
|
||||
}
|
||||
|
||||
# Act
|
||||
result = BillingService.get_plan_bulk(tenant_ids)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 200
|
||||
assert mock_send_request.call_count == 1
|
||||
|
||||
def test_get_plan_bulk_with_empty_data_response(self, mock_send_request):
|
||||
"""Test bulk plan retrieval with empty data in response."""
|
||||
# Arrange
|
||||
tenant_ids = ["tenant-1", "tenant-2"]
|
||||
mock_send_request.return_value = {"data": {}}
|
||||
|
||||
# Act
|
||||
result = BillingService.get_plan_bulk(tenant_ids)
|
||||
|
||||
# Assert
|
||||
assert result == {}
|
||||
|
||||
def test_get_plan_bulk_with_invalid_tenant_plan_skipped(self, mock_send_request):
|
||||
"""Test bulk plan retrieval when one tenant has invalid plan data (should skip that tenant)."""
|
||||
# Arrange
|
||||
tenant_ids = ["tenant-valid-1", "tenant-invalid", "tenant-valid-2"]
|
||||
|
||||
# Response with one invalid tenant plan (missing expiration_date) and two valid ones
|
||||
mock_send_request.return_value = {
|
||||
"data": {
|
||||
"tenant-valid-1": {"plan": "sandbox", "expiration_date": 1735689600},
|
||||
"tenant-invalid": {"plan": "professional"}, # Missing expiration_date field
|
||||
"tenant-valid-2": {"plan": "team", "expiration_date": 1767225600},
|
||||
}
|
||||
}
|
||||
|
||||
# Act
|
||||
with patch("services.billing_service.logger") as mock_logger:
|
||||
result = BillingService.get_plan_bulk(tenant_ids)
|
||||
|
||||
# Assert - should only contain valid tenants
|
||||
assert len(result) == 2
|
||||
assert "tenant-valid-1" in result
|
||||
assert "tenant-valid-2" in result
|
||||
assert "tenant-invalid" not in result
|
||||
|
||||
# Verify valid tenants have correct data
|
||||
assert result["tenant-valid-1"]["plan"] == "sandbox"
|
||||
assert result["tenant-valid-1"]["expiration_date"] == 1735689600
|
||||
assert result["tenant-valid-2"]["plan"] == "team"
|
||||
assert result["tenant-valid-2"]["expiration_date"] == 1767225600
|
||||
|
||||
# Verify exception was logged for the invalid tenant
|
||||
mock_logger.exception.assert_called_once()
|
||||
log_call_args = mock_logger.exception.call_args[0]
|
||||
assert "get_plan_bulk: failed to validate subscription plan for tenant" in log_call_args[0]
|
||||
assert "tenant-invalid" in log_call_args[1]
|
||||
|
||||
def test_get_expired_subscription_cleanup_whitelist_success(self, mock_send_request):
|
||||
"""Test successful retrieval of expired subscription cleanup whitelist."""
|
||||
# Arrange
|
||||
api_response = [
|
||||
{
|
||||
"created_at": "2025-10-16T01:56:17",
|
||||
"tenant_id": "36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe6",
|
||||
"contact": "example@dify.ai",
|
||||
"id": "36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe5",
|
||||
"expired_at": "2026-01-01T01:56:17",
|
||||
"updated_at": "2025-10-16T01:56:17",
|
||||
},
|
||||
{
|
||||
"created_at": "2025-10-16T02:00:00",
|
||||
"tenant_id": "tenant-2",
|
||||
"contact": "test@example.com",
|
||||
"id": "whitelist-id-2",
|
||||
"expired_at": "2026-02-01T00:00:00",
|
||||
"updated_at": "2025-10-16T02:00:00",
|
||||
},
|
||||
{
|
||||
"created_at": "2025-10-16T03:00:00",
|
||||
"tenant_id": "tenant-3",
|
||||
"contact": "another@example.com",
|
||||
"id": "whitelist-id-3",
|
||||
"expired_at": "2026-03-01T00:00:00",
|
||||
"updated_at": "2025-10-16T03:00:00",
|
||||
},
|
||||
]
|
||||
mock_send_request.return_value = {"data": api_response}
|
||||
|
||||
# Act
|
||||
result = BillingService.get_expired_subscription_cleanup_whitelist()
|
||||
|
||||
# Assert - should return only tenant_ids
|
||||
assert result == ["36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe6", "tenant-2", "tenant-3"]
|
||||
assert len(result) == 3
|
||||
assert result[0] == "36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe6"
|
||||
assert result[1] == "tenant-2"
|
||||
assert result[2] == "tenant-3"
|
||||
mock_send_request.assert_called_once_with("GET", "/subscription/cleanup/whitelist")
|
||||
|
||||
def test_get_expired_subscription_cleanup_whitelist_empty_list(self, mock_send_request):
|
||||
"""Test retrieval of empty cleanup whitelist."""
|
||||
# Arrange
|
||||
mock_send_request.return_value = {"data": []}
|
||||
|
||||
# Act
|
||||
result = BillingService.get_expired_subscription_cleanup_whitelist()
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
assert len(result) == 0
|
||||
|
||||
|
||||
class TestBillingServiceIntegrationScenarios:
|
||||
"""Integration-style tests simulating real-world usage scenarios.
|
||||
|
||||
|
||||
@ -0,0 +1,472 @@
|
||||
"""
|
||||
Unit tests for SegmentService.get_segments method.
|
||||
|
||||
Tests the retrieval of document segments with pagination and filtering:
|
||||
- Basic pagination (page, limit)
|
||||
- Status filtering
|
||||
- Keyword search
|
||||
- Ordering by position and id (to avoid duplicate data)
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock, create_autospec, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from models.dataset import DocumentSegment
|
||||
|
||||
|
||||
class SegmentServiceTestDataFactory:
|
||||
"""
|
||||
Factory class for creating test data and mock objects for segment tests.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def create_segment_mock(
|
||||
segment_id: str = "segment-123",
|
||||
document_id: str = "doc-123",
|
||||
tenant_id: str = "tenant-123",
|
||||
dataset_id: str = "dataset-123",
|
||||
position: int = 1,
|
||||
content: str = "Test content",
|
||||
status: str = "completed",
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""
|
||||
Create a mock document segment.
|
||||
|
||||
Args:
|
||||
segment_id: Unique identifier for the segment
|
||||
document_id: Parent document ID
|
||||
tenant_id: Tenant ID the segment belongs to
|
||||
dataset_id: Parent dataset ID
|
||||
position: Position within the document
|
||||
content: Segment text content
|
||||
status: Indexing status
|
||||
**kwargs: Additional attributes
|
||||
|
||||
Returns:
|
||||
Mock: DocumentSegment mock object
|
||||
"""
|
||||
segment = create_autospec(DocumentSegment, instance=True)
|
||||
segment.id = segment_id
|
||||
segment.document_id = document_id
|
||||
segment.tenant_id = tenant_id
|
||||
segment.dataset_id = dataset_id
|
||||
segment.position = position
|
||||
segment.content = content
|
||||
segment.status = status
|
||||
for key, value in kwargs.items():
|
||||
setattr(segment, key, value)
|
||||
return segment
|
||||
|
||||
|
||||
class TestSegmentServiceGetSegments:
|
||||
"""
|
||||
Comprehensive unit tests for SegmentService.get_segments method.
|
||||
|
||||
Tests cover:
|
||||
- Basic pagination functionality
|
||||
- Status list filtering
|
||||
- Keyword search filtering
|
||||
- Ordering (position + id for uniqueness)
|
||||
- Empty results
|
||||
- Combined filters
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_segment_service_dependencies(self):
|
||||
"""
|
||||
Common mock setup for segment service dependencies.
|
||||
|
||||
Patches:
|
||||
- db: Database operations and pagination
|
||||
- select: SQLAlchemy query builder
|
||||
"""
|
||||
with (
|
||||
patch("services.dataset_service.db") as mock_db,
|
||||
patch("services.dataset_service.select") as mock_select,
|
||||
):
|
||||
yield {
|
||||
"db": mock_db,
|
||||
"select": mock_select,
|
||||
}
|
||||
|
||||
def test_get_segments_basic_pagination(self, mock_segment_service_dependencies):
|
||||
"""
|
||||
Test basic pagination functionality.
|
||||
|
||||
Verifies:
|
||||
- Query is built with document_id and tenant_id filters
|
||||
- Pagination uses correct page and limit parameters
|
||||
- Returns segments and total count
|
||||
"""
|
||||
# Arrange
|
||||
document_id = "doc-123"
|
||||
tenant_id = "tenant-123"
|
||||
page = 1
|
||||
limit = 20
|
||||
|
||||
# Create mock segments
|
||||
segment1 = SegmentServiceTestDataFactory.create_segment_mock(
|
||||
segment_id="seg-1", position=1, content="First segment"
|
||||
)
|
||||
segment2 = SegmentServiceTestDataFactory.create_segment_mock(
|
||||
segment_id="seg-2", position=2, content="Second segment"
|
||||
)
|
||||
|
||||
# Mock pagination result
|
||||
mock_paginated = Mock()
|
||||
mock_paginated.items = [segment1, segment2]
|
||||
mock_paginated.total = 2
|
||||
|
||||
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
|
||||
|
||||
# Mock select builder
|
||||
mock_query = Mock()
|
||||
mock_segment_service_dependencies["select"].return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
|
||||
# Act
|
||||
from services.dataset_service import SegmentService
|
||||
|
||||
items, total = SegmentService.get_segments(document_id=document_id, tenant_id=tenant_id, page=page, limit=limit)
|
||||
|
||||
# Assert
|
||||
assert len(items) == 2
|
||||
assert total == 2
|
||||
assert items[0].id == "seg-1"
|
||||
assert items[1].id == "seg-2"
|
||||
mock_segment_service_dependencies["db"].paginate.assert_called_once()
|
||||
call_kwargs = mock_segment_service_dependencies["db"].paginate.call_args[1]
|
||||
assert call_kwargs["page"] == page
|
||||
assert call_kwargs["per_page"] == limit
|
||||
assert call_kwargs["max_per_page"] == 100
|
||||
assert call_kwargs["error_out"] is False
|
||||
|
||||
def test_get_segments_with_status_filter(self, mock_segment_service_dependencies):
|
||||
"""
|
||||
Test filtering by status list.
|
||||
|
||||
Verifies:
|
||||
- Status list filter is applied to query
|
||||
- Only segments with matching status are returned
|
||||
"""
|
||||
# Arrange
|
||||
document_id = "doc-123"
|
||||
tenant_id = "tenant-123"
|
||||
status_list = ["completed", "indexing"]
|
||||
|
||||
segment1 = SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-1", status="completed")
|
||||
segment2 = SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-2", status="indexing")
|
||||
|
||||
mock_paginated = Mock()
|
||||
mock_paginated.items = [segment1, segment2]
|
||||
mock_paginated.total = 2
|
||||
|
||||
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
|
||||
|
||||
mock_query = Mock()
|
||||
mock_segment_service_dependencies["select"].return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
|
||||
# Act
|
||||
from services.dataset_service import SegmentService
|
||||
|
||||
items, total = SegmentService.get_segments(
|
||||
document_id=document_id, tenant_id=tenant_id, status_list=status_list
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(items) == 2
|
||||
assert total == 2
|
||||
# Verify where was called multiple times (base filters + status filter)
|
||||
assert mock_query.where.call_count >= 2
|
||||
|
||||
def test_get_segments_with_empty_status_list(self, mock_segment_service_dependencies):
|
||||
"""
|
||||
Test with empty status list.
|
||||
|
||||
Verifies:
|
||||
- Empty status list is handled correctly
|
||||
- No status filter is applied to avoid WHERE false condition
|
||||
"""
|
||||
# Arrange
|
||||
document_id = "doc-123"
|
||||
tenant_id = "tenant-123"
|
||||
status_list = []
|
||||
|
||||
segment = SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-1")
|
||||
|
||||
mock_paginated = Mock()
|
||||
mock_paginated.items = [segment]
|
||||
mock_paginated.total = 1
|
||||
|
||||
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
|
||||
|
||||
mock_query = Mock()
|
||||
mock_segment_service_dependencies["select"].return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
|
||||
# Act
|
||||
from services.dataset_service import SegmentService
|
||||
|
||||
items, total = SegmentService.get_segments(
|
||||
document_id=document_id, tenant_id=tenant_id, status_list=status_list
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(items) == 1
|
||||
assert total == 1
|
||||
# Should only be called once (base filters, no status filter)
|
||||
assert mock_query.where.call_count == 1
|
||||
|
||||
def test_get_segments_with_keyword_search(self, mock_segment_service_dependencies):
|
||||
"""
|
||||
Test keyword search functionality.
|
||||
|
||||
Verifies:
|
||||
- Keyword filter uses ilike for case-insensitive search
|
||||
- Search pattern includes wildcards (%keyword%)
|
||||
"""
|
||||
# Arrange
|
||||
document_id = "doc-123"
|
||||
tenant_id = "tenant-123"
|
||||
keyword = "search term"
|
||||
|
||||
segment = SegmentServiceTestDataFactory.create_segment_mock(
|
||||
segment_id="seg-1", content="This contains search term"
|
||||
)
|
||||
|
||||
mock_paginated = Mock()
|
||||
mock_paginated.items = [segment]
|
||||
mock_paginated.total = 1
|
||||
|
||||
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
|
||||
|
||||
mock_query = Mock()
|
||||
mock_segment_service_dependencies["select"].return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
|
||||
# Act
|
||||
from services.dataset_service import SegmentService
|
||||
|
||||
items, total = SegmentService.get_segments(document_id=document_id, tenant_id=tenant_id, keyword=keyword)
|
||||
|
||||
# Assert
|
||||
assert len(items) == 1
|
||||
assert total == 1
|
||||
# Verify where was called for base filters + keyword filter
|
||||
assert mock_query.where.call_count == 2
|
||||
|
||||
def test_get_segments_ordering_by_position_and_id(self, mock_segment_service_dependencies):
|
||||
"""
|
||||
Test ordering by position and id.
|
||||
|
||||
Verifies:
|
||||
- Results are ordered by position ASC
|
||||
- Results are secondarily ordered by id ASC to ensure uniqueness
|
||||
- This prevents duplicate data across pages when positions are not unique
|
||||
"""
|
||||
# Arrange
|
||||
document_id = "doc-123"
|
||||
tenant_id = "tenant-123"
|
||||
|
||||
# Create segments with same position but different ids
|
||||
segment1 = SegmentServiceTestDataFactory.create_segment_mock(
|
||||
segment_id="seg-1", position=1, content="Content 1"
|
||||
)
|
||||
segment2 = SegmentServiceTestDataFactory.create_segment_mock(
|
||||
segment_id="seg-2", position=1, content="Content 2"
|
||||
)
|
||||
segment3 = SegmentServiceTestDataFactory.create_segment_mock(
|
||||
segment_id="seg-3", position=2, content="Content 3"
|
||||
)
|
||||
|
||||
mock_paginated = Mock()
|
||||
mock_paginated.items = [segment1, segment2, segment3]
|
||||
mock_paginated.total = 3
|
||||
|
||||
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
|
||||
|
||||
mock_query = Mock()
|
||||
mock_segment_service_dependencies["select"].return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
|
||||
# Act
|
||||
from services.dataset_service import SegmentService
|
||||
|
||||
items, total = SegmentService.get_segments(document_id=document_id, tenant_id=tenant_id)
|
||||
|
||||
# Assert
|
||||
assert len(items) == 3
|
||||
assert total == 3
|
||||
mock_query.order_by.assert_called_once()
|
||||
|
||||
def test_get_segments_empty_results(self, mock_segment_service_dependencies):
|
||||
"""
|
||||
Test when no segments match the criteria.
|
||||
|
||||
Verifies:
|
||||
- Empty list is returned for items
|
||||
- Total count is 0
|
||||
"""
|
||||
# Arrange
|
||||
document_id = "non-existent-doc"
|
||||
tenant_id = "tenant-123"
|
||||
|
||||
mock_paginated = Mock()
|
||||
mock_paginated.items = []
|
||||
mock_paginated.total = 0
|
||||
|
||||
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
|
||||
|
||||
mock_query = Mock()
|
||||
mock_segment_service_dependencies["select"].return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
|
||||
# Act
|
||||
from services.dataset_service import SegmentService
|
||||
|
||||
items, total = SegmentService.get_segments(document_id=document_id, tenant_id=tenant_id)
|
||||
|
||||
# Assert
|
||||
assert items == []
|
||||
assert total == 0
|
||||
|
||||
def test_get_segments_combined_filters(self, mock_segment_service_dependencies):
|
||||
"""
|
||||
Test with multiple filters combined.
|
||||
|
||||
Verifies:
|
||||
- All filters work together correctly
|
||||
- Status list and keyword search both applied
|
||||
"""
|
||||
# Arrange
|
||||
document_id = "doc-123"
|
||||
tenant_id = "tenant-123"
|
||||
status_list = ["completed"]
|
||||
keyword = "important"
|
||||
page = 2
|
||||
limit = 10
|
||||
|
||||
segment = SegmentServiceTestDataFactory.create_segment_mock(
|
||||
segment_id="seg-1",
|
||||
status="completed",
|
||||
content="This is important information",
|
||||
)
|
||||
|
||||
mock_paginated = Mock()
|
||||
mock_paginated.items = [segment]
|
||||
mock_paginated.total = 1
|
||||
|
||||
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
|
||||
|
||||
mock_query = Mock()
|
||||
mock_segment_service_dependencies["select"].return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
|
||||
# Act
|
||||
from services.dataset_service import SegmentService
|
||||
|
||||
items, total = SegmentService.get_segments(
|
||||
document_id=document_id,
|
||||
tenant_id=tenant_id,
|
||||
status_list=status_list,
|
||||
keyword=keyword,
|
||||
page=page,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(items) == 1
|
||||
assert total == 1
|
||||
# Verify filters: base + status + keyword
|
||||
assert mock_query.where.call_count == 3
|
||||
# Verify pagination parameters
|
||||
call_kwargs = mock_segment_service_dependencies["db"].paginate.call_args[1]
|
||||
assert call_kwargs["page"] == page
|
||||
assert call_kwargs["per_page"] == limit
|
||||
|
||||
def test_get_segments_with_none_status_list(self, mock_segment_service_dependencies):
|
||||
"""
|
||||
Test with None status list.
|
||||
|
||||
Verifies:
|
||||
- None status list is handled correctly
|
||||
- No status filter is applied
|
||||
"""
|
||||
# Arrange
|
||||
document_id = "doc-123"
|
||||
tenant_id = "tenant-123"
|
||||
|
||||
segment = SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-1")
|
||||
|
||||
mock_paginated = Mock()
|
||||
mock_paginated.items = [segment]
|
||||
mock_paginated.total = 1
|
||||
|
||||
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
|
||||
|
||||
mock_query = Mock()
|
||||
mock_segment_service_dependencies["select"].return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
|
||||
# Act
|
||||
from services.dataset_service import SegmentService
|
||||
|
||||
items, total = SegmentService.get_segments(
|
||||
document_id=document_id,
|
||||
tenant_id=tenant_id,
|
||||
status_list=None,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(items) == 1
|
||||
assert total == 1
|
||||
# Should only be called once (base filters only, no status filter)
|
||||
assert mock_query.where.call_count == 1
|
||||
|
||||
def test_get_segments_pagination_max_per_page_limit(self, mock_segment_service_dependencies):
|
||||
"""
|
||||
Test that max_per_page is correctly set to 100.
|
||||
|
||||
Verifies:
|
||||
- max_per_page parameter is set to 100
|
||||
- This prevents excessive page sizes
|
||||
"""
|
||||
# Arrange
|
||||
document_id = "doc-123"
|
||||
tenant_id = "tenant-123"
|
||||
limit = 200 # Request more than max_per_page
|
||||
|
||||
mock_paginated = Mock()
|
||||
mock_paginated.items = []
|
||||
mock_paginated.total = 0
|
||||
|
||||
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
|
||||
|
||||
mock_query = Mock()
|
||||
mock_segment_service_dependencies["select"].return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
|
||||
# Act
|
||||
from services.dataset_service import SegmentService
|
||||
|
||||
SegmentService.get_segments(
|
||||
document_id=document_id,
|
||||
tenant_id=tenant_id,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
# Assert
|
||||
call_kwargs = mock_segment_service_dependencies["db"].paginate.call_args[1]
|
||||
assert call_kwargs["max_per_page"] == 100
|
||||
520
api/tests/unit_tests/tasks/test_document_indexing_sync_task.py
Normal file
520
api/tests/unit_tests/tasks/test_document_indexing_sync_task.py
Normal file
@ -0,0 +1,520 @@
|
||||
"""
|
||||
Unit tests for document indexing sync task.
|
||||
|
||||
This module tests the document indexing sync task functionality including:
|
||||
- Syncing Notion documents when updated
|
||||
- Validating document and data source existence
|
||||
- Credential validation and retrieval
|
||||
- Cleaning old segments before re-indexing
|
||||
- Error handling and edge cases
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from tasks.document_indexing_sync_task import document_indexing_sync_task
|
||||
|
||||
# ============================================================================
|
||||
# Fixtures
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tenant_id():
|
||||
"""Generate a unique tenant ID for testing."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_id():
|
||||
"""Generate a unique dataset ID for testing."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def document_id():
|
||||
"""Generate a unique document ID for testing."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def notion_workspace_id():
|
||||
"""Generate a Notion workspace ID for testing."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def notion_page_id():
|
||||
"""Generate a Notion page ID for testing."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def credential_id():
|
||||
"""Generate a credential ID for testing."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dataset(dataset_id, tenant_id):
|
||||
"""Create a mock Dataset object."""
|
||||
dataset = Mock(spec=Dataset)
|
||||
dataset.id = dataset_id
|
||||
dataset.tenant_id = tenant_id
|
||||
dataset.indexing_technique = "high_quality"
|
||||
dataset.embedding_model_provider = "openai"
|
||||
dataset.embedding_model = "text-embedding-ada-002"
|
||||
return dataset
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document(document_id, dataset_id, tenant_id, notion_workspace_id, notion_page_id, credential_id):
|
||||
"""Create a mock Document object with Notion data source."""
|
||||
doc = Mock(spec=Document)
|
||||
doc.id = document_id
|
||||
doc.dataset_id = dataset_id
|
||||
doc.tenant_id = tenant_id
|
||||
doc.data_source_type = "notion_import"
|
||||
doc.indexing_status = "completed"
|
||||
doc.error = None
|
||||
doc.stopped_at = None
|
||||
doc.processing_started_at = None
|
||||
doc.doc_form = "text_model"
|
||||
doc.data_source_info_dict = {
|
||||
"notion_workspace_id": notion_workspace_id,
|
||||
"notion_page_id": notion_page_id,
|
||||
"type": "page",
|
||||
"last_edited_time": "2024-01-01T00:00:00Z",
|
||||
"credential_id": credential_id,
|
||||
}
|
||||
return doc
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document_segments(document_id):
|
||||
"""Create mock DocumentSegment objects."""
|
||||
segments = []
|
||||
for i in range(3):
|
||||
segment = Mock(spec=DocumentSegment)
|
||||
segment.id = str(uuid.uuid4())
|
||||
segment.document_id = document_id
|
||||
segment.index_node_id = f"node-{document_id}-{i}"
|
||||
segments.append(segment)
|
||||
return segments
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
"""Mock database session."""
|
||||
with patch("tasks.document_indexing_sync_task.db.session") as mock_session:
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_session.scalars.return_value = MagicMock()
|
||||
yield mock_session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_datasource_provider_service():
|
||||
"""Mock DatasourceProviderService."""
|
||||
with patch("tasks.document_indexing_sync_task.DatasourceProviderService") as mock_service_class:
|
||||
mock_service = MagicMock()
|
||||
mock_service.get_datasource_credentials.return_value = {"integration_secret": "test_token"}
|
||||
mock_service_class.return_value = mock_service
|
||||
yield mock_service
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_notion_extractor():
|
||||
"""Mock NotionExtractor."""
|
||||
with patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_extractor_class:
|
||||
mock_extractor = MagicMock()
|
||||
mock_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" # Updated time
|
||||
mock_extractor_class.return_value = mock_extractor
|
||||
yield mock_extractor
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_index_processor_factory():
|
||||
"""Mock IndexProcessorFactory."""
|
||||
with patch("tasks.document_indexing_sync_task.IndexProcessorFactory") as mock_factory:
|
||||
mock_processor = MagicMock()
|
||||
mock_processor.clean = Mock()
|
||||
mock_factory.return_value.init_index_processor.return_value = mock_processor
|
||||
yield mock_factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_indexing_runner():
|
||||
"""Mock IndexingRunner."""
|
||||
with patch("tasks.document_indexing_sync_task.IndexingRunner") as mock_runner_class:
|
||||
mock_runner = MagicMock(spec=IndexingRunner)
|
||||
mock_runner.run = Mock()
|
||||
mock_runner_class.return_value = mock_runner
|
||||
yield mock_runner
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for document_indexing_sync_task
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestDocumentIndexingSyncTask:
|
||||
"""Tests for the document_indexing_sync_task function."""
|
||||
|
||||
def test_document_not_found(self, mock_db_session, dataset_id, document_id):
|
||||
"""Test that task handles document not found gracefully."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
mock_db_session.close.assert_called_once()
|
||||
|
||||
def test_missing_notion_workspace_id(self, mock_db_session, mock_document, dataset_id, document_id):
|
||||
"""Test that task raises error when notion_workspace_id is missing."""
|
||||
# Arrange
|
||||
mock_document.data_source_info_dict = {"notion_page_id": "page123", "type": "page"}
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="no notion page found"):
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
def test_missing_notion_page_id(self, mock_db_session, mock_document, dataset_id, document_id):
|
||||
"""Test that task raises error when notion_page_id is missing."""
|
||||
# Arrange
|
||||
mock_document.data_source_info_dict = {"notion_workspace_id": "ws123", "type": "page"}
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="no notion page found"):
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
def test_empty_data_source_info(self, mock_db_session, mock_document, dataset_id, document_id):
|
||||
"""Test that task raises error when data_source_info is empty."""
|
||||
# Arrange
|
||||
mock_document.data_source_info_dict = None
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="no notion page found"):
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
def test_credential_not_found(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_datasource_provider_service,
|
||||
mock_document,
|
||||
dataset_id,
|
||||
document_id,
|
||||
):
|
||||
"""Test that task handles missing credentials by updating document status."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
|
||||
mock_datasource_provider_service.get_datasource_credentials.return_value = None
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
assert mock_document.indexing_status == "error"
|
||||
assert "Datasource credential not found" in mock_document.error
|
||||
assert mock_document.stopped_at is not None
|
||||
mock_db_session.commit.assert_called()
|
||||
mock_db_session.close.assert_called()
|
||||
|
||||
def test_page_not_updated(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_datasource_provider_service,
|
||||
mock_notion_extractor,
|
||||
mock_document,
|
||||
dataset_id,
|
||||
document_id,
|
||||
):
|
||||
"""Test that task does nothing when page has not been updated."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
|
||||
# Return same time as stored in document
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
# Document status should remain unchanged
|
||||
assert mock_document.indexing_status == "completed"
|
||||
# No session operations should be performed beyond the initial query
|
||||
mock_db_session.close.assert_not_called()
|
||||
|
||||
def test_successful_sync_when_page_updated(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_datasource_provider_service,
|
||||
mock_notion_extractor,
|
||||
mock_index_processor_factory,
|
||||
mock_indexing_runner,
|
||||
mock_dataset,
|
||||
mock_document,
|
||||
mock_document_segments,
|
||||
dataset_id,
|
||||
document_id,
|
||||
):
|
||||
"""Test successful sync flow when Notion page has been updated."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
|
||||
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
|
||||
# NotionExtractor returns updated time
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
# Verify document status was updated to parsing
|
||||
assert mock_document.indexing_status == "parsing"
|
||||
assert mock_document.processing_started_at is not None
|
||||
|
||||
# Verify segments were cleaned
|
||||
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
|
||||
mock_processor.clean.assert_called_once()
|
||||
|
||||
# Verify segments were deleted from database
|
||||
for segment in mock_document_segments:
|
||||
mock_db_session.delete.assert_any_call(segment)
|
||||
|
||||
# Verify indexing runner was called
|
||||
mock_indexing_runner.run.assert_called_once_with([mock_document])
|
||||
|
||||
# Verify session operations
|
||||
assert mock_db_session.commit.called
|
||||
mock_db_session.close.assert_called_once()
|
||||
|
||||
def test_dataset_not_found_during_cleaning(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_datasource_provider_service,
|
||||
mock_notion_extractor,
|
||||
mock_document,
|
||||
dataset_id,
|
||||
document_id,
|
||||
):
|
||||
"""Test that task handles dataset not found during cleaning phase."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, None]
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
# Document should still be set to parsing
|
||||
assert mock_document.indexing_status == "parsing"
|
||||
# Session should be closed after error
|
||||
mock_db_session.close.assert_called_once()
|
||||
|
||||
def test_cleaning_error_continues_to_indexing(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_datasource_provider_service,
|
||||
mock_notion_extractor,
|
||||
mock_index_processor_factory,
|
||||
mock_indexing_runner,
|
||||
mock_dataset,
|
||||
mock_document,
|
||||
dataset_id,
|
||||
document_id,
|
||||
):
|
||||
"""Test that indexing continues even if cleaning fails."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
|
||||
mock_db_session.scalars.return_value.all.side_effect = Exception("Cleaning error")
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
# Indexing should still be attempted despite cleaning error
|
||||
mock_indexing_runner.run.assert_called_once_with([mock_document])
|
||||
mock_db_session.close.assert_called_once()
|
||||
|
||||
def test_indexing_runner_document_paused_error(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_datasource_provider_service,
|
||||
mock_notion_extractor,
|
||||
mock_index_processor_factory,
|
||||
mock_indexing_runner,
|
||||
mock_dataset,
|
||||
mock_document,
|
||||
mock_document_segments,
|
||||
dataset_id,
|
||||
document_id,
|
||||
):
|
||||
"""Test that DocumentIsPausedError is handled gracefully."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
|
||||
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||
mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused")
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
# Session should be closed after handling error
|
||||
mock_db_session.close.assert_called_once()
|
||||
|
||||
def test_indexing_runner_general_error(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_datasource_provider_service,
|
||||
mock_notion_extractor,
|
||||
mock_index_processor_factory,
|
||||
mock_indexing_runner,
|
||||
mock_dataset,
|
||||
mock_document,
|
||||
mock_document_segments,
|
||||
dataset_id,
|
||||
document_id,
|
||||
):
|
||||
"""Test that general exceptions during indexing are handled."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
|
||||
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||
mock_indexing_runner.run.side_effect = Exception("Indexing error")
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
# Session should be closed after error
|
||||
mock_db_session.close.assert_called_once()
|
||||
|
||||
def test_notion_extractor_initialized_with_correct_params(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_datasource_provider_service,
|
||||
mock_notion_extractor,
|
||||
mock_document,
|
||||
dataset_id,
|
||||
document_id,
|
||||
notion_workspace_id,
|
||||
notion_page_id,
|
||||
):
|
||||
"""Test that NotionExtractor is initialized with correct parameters."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z" # No update
|
||||
|
||||
# Act
|
||||
with patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_extractor_class:
|
||||
mock_extractor = MagicMock()
|
||||
mock_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
|
||||
mock_extractor_class.return_value = mock_extractor
|
||||
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
mock_extractor_class.assert_called_once_with(
|
||||
notion_workspace_id=notion_workspace_id,
|
||||
notion_obj_id=notion_page_id,
|
||||
notion_page_type="page",
|
||||
notion_access_token="test_token",
|
||||
tenant_id=mock_document.tenant_id,
|
||||
)
|
||||
|
||||
def test_datasource_credentials_requested_correctly(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_datasource_provider_service,
|
||||
mock_notion_extractor,
|
||||
mock_document,
|
||||
dataset_id,
|
||||
document_id,
|
||||
credential_id,
|
||||
):
|
||||
"""Test that datasource credentials are requested with correct parameters."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
mock_datasource_provider_service.get_datasource_credentials.assert_called_once_with(
|
||||
tenant_id=mock_document.tenant_id,
|
||||
credential_id=credential_id,
|
||||
provider="notion_datasource",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
)
|
||||
|
||||
def test_credential_id_missing_uses_none(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_datasource_provider_service,
|
||||
mock_notion_extractor,
|
||||
mock_document,
|
||||
dataset_id,
|
||||
document_id,
|
||||
):
|
||||
"""Test that task handles missing credential_id by passing None."""
|
||||
# Arrange
|
||||
mock_document.data_source_info_dict = {
|
||||
"notion_workspace_id": "ws123",
|
||||
"notion_page_id": "page123",
|
||||
"type": "page",
|
||||
"last_edited_time": "2024-01-01T00:00:00Z",
|
||||
}
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
mock_datasource_provider_service.get_datasource_credentials.assert_called_once_with(
|
||||
tenant_id=mock_document.tenant_id,
|
||||
credential_id=None,
|
||||
provider="notion_datasource",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
)
|
||||
|
||||
def test_index_processor_clean_called_with_correct_params(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_datasource_provider_service,
|
||||
mock_notion_extractor,
|
||||
mock_index_processor_factory,
|
||||
mock_indexing_runner,
|
||||
mock_dataset,
|
||||
mock_document,
|
||||
mock_document_segments,
|
||||
dataset_id,
|
||||
document_id,
|
||||
):
|
||||
"""Test that index processor clean is called with correct parameters."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
|
||||
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
|
||||
expected_node_ids = [seg.index_node_id for seg in mock_document_segments]
|
||||
mock_processor.clean.assert_called_once_with(
|
||||
mock_dataset, expected_node_ids, with_keywords=True, delete_child_chunks=True
|
||||
)
|
||||
122
api/tests/unit_tests/tools/test_mcp_tool.py
Normal file
122
api/tests/unit_tests/tools/test_mcp_tool.py
Normal file
@ -0,0 +1,122 @@
|
||||
import base64
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.mcp.types import (
|
||||
AudioContent,
|
||||
BlobResourceContents,
|
||||
CallToolResult,
|
||||
EmbeddedResource,
|
||||
ImageContent,
|
||||
TextResourceContents,
|
||||
)
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage
|
||||
from core.tools.mcp_tool.tool import MCPTool
|
||||
|
||||
|
||||
def _make_mcp_tool(output_schema: dict | None = None) -> MCPTool:
|
||||
identity = ToolIdentity(
|
||||
author="test",
|
||||
name="test_mcp_tool",
|
||||
label=I18nObject(en_US="Test MCP Tool", zh_Hans="测试MCP工具"),
|
||||
provider="test_provider",
|
||||
)
|
||||
entity = ToolEntity(identity=identity, output_schema=output_schema or {})
|
||||
runtime = Mock(spec=ToolRuntime)
|
||||
runtime.credentials = {}
|
||||
return MCPTool(
|
||||
entity=entity,
|
||||
runtime=runtime,
|
||||
tenant_id="test_tenant",
|
||||
icon="",
|
||||
server_url="https://server.invalid",
|
||||
provider_id="provider_1",
|
||||
headers={},
|
||||
)
|
||||
|
||||
|
||||
class TestMCPToolInvoke:
|
||||
@pytest.mark.parametrize(
|
||||
("content_factory", "mime_type"),
|
||||
[
|
||||
(
|
||||
lambda b64, mt: ImageContent(type="image", data=b64, mimeType=mt),
|
||||
"image/png",
|
||||
),
|
||||
(
|
||||
lambda b64, mt: AudioContent(type="audio", data=b64, mimeType=mt),
|
||||
"audio/mpeg",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_invoke_image_or_audio_yields_blob(self, content_factory, mime_type) -> None:
|
||||
tool = _make_mcp_tool()
|
||||
raw = b"\x00\x01test-bytes\x02"
|
||||
b64 = base64.b64encode(raw).decode()
|
||||
content = content_factory(b64, mime_type)
|
||||
result = CallToolResult(content=[content])
|
||||
|
||||
with patch.object(tool, "invoke_remote_mcp_tool", return_value=result):
|
||||
messages = list(tool._invoke(user_id="test_user", tool_parameters={}))
|
||||
|
||||
assert len(messages) == 1
|
||||
msg = messages[0]
|
||||
assert msg.type == ToolInvokeMessage.MessageType.BLOB
|
||||
assert isinstance(msg.message, ToolInvokeMessage.BlobMessage)
|
||||
assert msg.message.blob == raw
|
||||
assert msg.meta == {"mime_type": mime_type}
|
||||
|
||||
def test_invoke_embedded_text_resource_yields_text(self) -> None:
|
||||
tool = _make_mcp_tool()
|
||||
text_resource = TextResourceContents(uri="file://test.txt", mimeType="text/plain", text="hello world")
|
||||
content = EmbeddedResource(type="resource", resource=text_resource)
|
||||
result = CallToolResult(content=[content])
|
||||
|
||||
with patch.object(tool, "invoke_remote_mcp_tool", return_value=result):
|
||||
messages = list(tool._invoke(user_id="test_user", tool_parameters={}))
|
||||
|
||||
assert len(messages) == 1
|
||||
msg = messages[0]
|
||||
assert msg.type == ToolInvokeMessage.MessageType.TEXT
|
||||
assert isinstance(msg.message, ToolInvokeMessage.TextMessage)
|
||||
assert msg.message.text == "hello world"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("mime_type", "expected_mime"),
|
||||
[("application/pdf", "application/pdf"), (None, "application/octet-stream")],
|
||||
)
|
||||
def test_invoke_embedded_blob_resource_yields_blob(self, mime_type, expected_mime) -> None:
|
||||
tool = _make_mcp_tool()
|
||||
raw = b"binary-data"
|
||||
b64 = base64.b64encode(raw).decode()
|
||||
blob_resource = BlobResourceContents(uri="file://doc.bin", mimeType=mime_type, blob=b64)
|
||||
content = EmbeddedResource(type="resource", resource=blob_resource)
|
||||
result = CallToolResult(content=[content])
|
||||
|
||||
with patch.object(tool, "invoke_remote_mcp_tool", return_value=result):
|
||||
messages = list(tool._invoke(user_id="test_user", tool_parameters={}))
|
||||
|
||||
assert len(messages) == 1
|
||||
msg = messages[0]
|
||||
assert msg.type == ToolInvokeMessage.MessageType.BLOB
|
||||
assert isinstance(msg.message, ToolInvokeMessage.BlobMessage)
|
||||
assert msg.message.blob == raw
|
||||
assert msg.meta == {"mime_type": expected_mime}
|
||||
|
||||
def test_invoke_yields_variables_when_structured_content_and_schema(self) -> None:
|
||||
tool = _make_mcp_tool(output_schema={"type": "object"})
|
||||
result = CallToolResult(content=[], structuredContent={"a": 1, "b": "x"})
|
||||
|
||||
with patch.object(tool, "invoke_remote_mcp_tool", return_value=result):
|
||||
messages = list(tool._invoke(user_id="test_user", tool_parameters={}))
|
||||
|
||||
# Expect two variable messages corresponding to keys a and b
|
||||
assert len(messages) == 2
|
||||
var_msgs = [m for m in messages if isinstance(m.message, ToolInvokeMessage.VariableMessage)]
|
||||
assert {m.message.variable_name for m in var_msgs} == {"a", "b"}
|
||||
# Validate values
|
||||
values = {m.message.variable_name: m.message.variable_value for m in var_msgs}
|
||||
assert values == {"a": 1, "b": "x"}
|
||||
Reference in New Issue
Block a user