mirror of
https://github.com/langgenius/dify.git
synced 2026-04-29 23:18:05 +08:00
Merge branch 'main' into feat/rag-2
This commit is contained in:
@ -0,0 +1,168 @@
|
||||
"""
|
||||
Unit tests for App description validation functions.
|
||||
|
||||
This test module validates the 400-character limit enforcement
|
||||
for App descriptions across all creation and editing endpoints.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
# Add the API root to Python path for imports
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
|
||||
|
||||
|
||||
class TestAppDescriptionValidationUnit:
    """Unit tests for the 400-character App description validation."""

    def test_validate_description_length_function(self):
        """Test the _validate_description_length function directly"""
        from controllers.console.app.app import _validate_description_length

        # Valid inputs are returned unchanged; None passes through as-is.
        assert _validate_description_length("") == ""
        assert _validate_description_length("x" * 400) == "x" * 400
        assert _validate_description_length(None) is None

        # Every over-limit length must raise with the same message.
        for length in (401, 500, 1000):
            with pytest.raises(ValueError) as exc_info:
                _validate_description_length("x" * length)
            assert "Description cannot exceed 400 characters." in str(exc_info.value)

    def test_validation_consistency_with_dataset(self):
        """Test that App and Dataset validation functions are consistent"""
        from controllers.console.app.app import _validate_description_length as app_validate
        from controllers.console.datasets.datasets import _validate_description_length as dataset_validate
        from controllers.service_api.dataset.dataset import _validate_description_length as service_dataset_validate

        validators = (app_validate, dataset_validate, service_dataset_validate)

        # Identical valid inputs must yield identical results across all three.
        for sample in ("x" * 400, "", None):
            first, second, third = (validate(sample) for validate in validators)
            assert first == second == third

        def captured_message(validate):
            # Return the ValueError text for an over-limit input, or None if
            # (unexpectedly) nothing was raised.
            try:
                validate("x" * 401)
            except ValueError as error:
                return str(error)
            return None

        # Identical invalid inputs must produce the identical error message.
        messages = [captured_message(validate) for validate in validators]
        assert messages[0] == messages[1] == messages[2]
        assert messages[0] == "Description cannot exceed 400 characters."

    def test_boundary_values(self):
        """Test boundary values for description validation"""
        from controllers.console.app.app import _validate_description_length

        # Exactly at and just under the limit are accepted verbatim.
        for acceptable in ("x" * 400, "x" * 399):
            assert _validate_description_length(acceptable) == acceptable

        # Just over the limit is rejected.
        with pytest.raises(ValueError):
            _validate_description_length("x" * 401)

    def test_edge_cases(self):
        """Test edge cases for description validation"""
        from controllers.console.app.app import _validate_description_length

        # None, empty, and single-character descriptions all pass through.
        assert _validate_description_length(None) is None
        assert _validate_description_length("") == ""
        assert _validate_description_length("a") == "a"

        # The limit counts characters, not bytes: 400 CJK characters pass...
        chinese_at_limit = "测试" * 200  # 400 characters in Chinese
        assert _validate_description_length(chinese_at_limit) == chinese_at_limit

        # ...while 402 CJK characters exceed it.
        with pytest.raises(ValueError):
            _validate_description_length("测试" * 201)  # 402 characters

    def test_whitespace_handling(self):
        """Test how validation handles whitespace"""
        from controllers.console.app.app import _validate_description_length

        # Whitespace counts toward the limit like any other character.
        all_spaces = " " * 400
        assert _validate_description_length(all_spaces) == all_spaces
        with pytest.raises(ValueError):
            _validate_description_length(" " * 401)

        # Mixed letters and spaces behave the same way.
        mixed_at_limit = "a" * 200 + " " * 200
        assert _validate_description_length(mixed_at_limit) == mixed_at_limit
        with pytest.raises(ValueError):
            _validate_description_length("a" * 200 + " " * 201)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this module as a plain script without the pytest runner.
    import traceback

    suite = TestAppDescriptionValidationUnit()
    test_names = [name for name in dir(suite) if name.startswith("test_")]

    passed = 0
    failed = 0

    for name in test_names:
        print(f"Running {name}...")
        try:
            getattr(suite, name)()
        except Exception as e:
            # Report the failure and keep going so all tests are attempted.
            print(f"❌ {name} FAILED: {str(e)}")
            traceback.print_exc()
            failed += 1
        else:
            print(f"✅ {name} PASSED")
            passed += 1

    print(f"\n📊 Test Results: {passed} passed, {failed} failed")

    if failed:
        print("💥 Some tests failed!")
        sys.exit(1)
    else:
        print("🎉 All tests passed!")
|
||||
@ -39,10 +39,7 @@ class TestClickzettaVector(AbstractVectorTest):
|
||||
)
|
||||
|
||||
with setup_mock_redis():
|
||||
vector = ClickzettaVector(
|
||||
collection_name="test_collection_" + str(os.getpid()),
|
||||
config=config
|
||||
)
|
||||
vector = ClickzettaVector(collection_name="test_collection_" + str(os.getpid()), config=config)
|
||||
|
||||
yield vector
|
||||
|
||||
@ -114,7 +111,7 @@ class TestClickzettaVector(AbstractVectorTest):
|
||||
"category": "technical" if i % 2 == 0 else "general",
|
||||
"document_id": f"doc_{i // 3}", # Group documents
|
||||
"importance": i,
|
||||
}
|
||||
},
|
||||
)
|
||||
documents.append(doc)
|
||||
# Create varied embeddings
|
||||
@ -124,22 +121,14 @@ class TestClickzettaVector(AbstractVectorTest):
|
||||
|
||||
# Test vector search with document filter
|
||||
query_vector = [0.5, 1.0, 1.5, 2.0]
|
||||
results = vector_store.search_by_vector(
|
||||
query_vector,
|
||||
top_k=5,
|
||||
document_ids_filter=["doc_0", "doc_1"]
|
||||
)
|
||||
results = vector_store.search_by_vector(query_vector, top_k=5, document_ids_filter=["doc_0", "doc_1"])
|
||||
assert len(results) > 0
|
||||
# All results should belong to doc_0 or doc_1 groups
|
||||
for result in results:
|
||||
assert result.metadata["document_id"] in ["doc_0", "doc_1"]
|
||||
|
||||
# Test score threshold
|
||||
results = vector_store.search_by_vector(
|
||||
query_vector,
|
||||
top_k=10,
|
||||
score_threshold=0.5
|
||||
)
|
||||
results = vector_store.search_by_vector(query_vector, top_k=10, score_threshold=0.5)
|
||||
# Check that all results have a score above threshold
|
||||
for result in results:
|
||||
assert result.metadata.get("score", 0) >= 0.5
|
||||
@ -154,7 +143,7 @@ class TestClickzettaVector(AbstractVectorTest):
|
||||
for i in range(batch_size):
|
||||
doc = Document(
|
||||
page_content=f"Batch document {i}: This is a test document for batch processing.",
|
||||
metadata={"doc_id": f"batch_doc_{i}", "batch": "test_batch"}
|
||||
metadata={"doc_id": f"batch_doc_{i}", "batch": "test_batch"},
|
||||
)
|
||||
documents.append(doc)
|
||||
embeddings.append([0.1 * (i % 10), 0.2 * (i % 10), 0.3 * (i % 10), 0.4 * (i % 10)])
|
||||
@ -179,7 +168,7 @@ class TestClickzettaVector(AbstractVectorTest):
|
||||
# Test special characters in content
|
||||
special_doc = Document(
|
||||
page_content="Special chars: 'quotes', \"double\", \\backslash, \n newline",
|
||||
metadata={"doc_id": "special_doc", "test": "edge_case"}
|
||||
metadata={"doc_id": "special_doc", "test": "edge_case"},
|
||||
)
|
||||
embeddings = [[0.1, 0.2, 0.3, 0.4]]
|
||||
|
||||
@ -199,20 +188,18 @@ class TestClickzettaVector(AbstractVectorTest):
|
||||
# Prepare documents with various language content
|
||||
documents = [
|
||||
Document(
|
||||
page_content="云器科技提供强大的Lakehouse解决方案",
|
||||
metadata={"doc_id": "cn_doc_1", "lang": "chinese"}
|
||||
page_content="云器科技提供强大的Lakehouse解决方案", metadata={"doc_id": "cn_doc_1", "lang": "chinese"}
|
||||
),
|
||||
Document(
|
||||
page_content="Clickzetta provides powerful Lakehouse solutions",
|
||||
metadata={"doc_id": "en_doc_1", "lang": "english"}
|
||||
metadata={"doc_id": "en_doc_1", "lang": "english"},
|
||||
),
|
||||
Document(
|
||||
page_content="Lakehouse是现代数据架构的重要组成部分",
|
||||
metadata={"doc_id": "cn_doc_2", "lang": "chinese"}
|
||||
page_content="Lakehouse是现代数据架构的重要组成部分", metadata={"doc_id": "cn_doc_2", "lang": "chinese"}
|
||||
),
|
||||
Document(
|
||||
page_content="Modern data architecture includes Lakehouse technology",
|
||||
metadata={"doc_id": "en_doc_2", "lang": "english"}
|
||||
metadata={"doc_id": "en_doc_2", "lang": "english"},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
"""
|
||||
Test Clickzetta integration in Docker environment
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
@ -20,7 +21,7 @@ def test_clickzetta_connection():
|
||||
service=os.getenv("CLICKZETTA_SERVICE", "api.clickzetta.com"),
|
||||
workspace=os.getenv("CLICKZETTA_WORKSPACE", "test_workspace"),
|
||||
vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default"),
|
||||
database=os.getenv("CLICKZETTA_SCHEMA", "dify")
|
||||
database=os.getenv("CLICKZETTA_SCHEMA", "dify"),
|
||||
)
|
||||
|
||||
with conn.cursor() as cursor:
|
||||
@ -36,7 +37,7 @@ def test_clickzetta_connection():
|
||||
|
||||
# Check if test collection exists
|
||||
test_collection = "collection_test_dataset"
|
||||
if test_collection in [t[1] for t in tables if t[0] == 'dify']:
|
||||
if test_collection in [t[1] for t in tables if t[0] == "dify"]:
|
||||
cursor.execute(f"DESCRIBE dify.{test_collection}")
|
||||
columns = cursor.fetchall()
|
||||
print(f"✓ Table structure for {test_collection}:")
|
||||
@ -55,6 +56,7 @@ def test_clickzetta_connection():
|
||||
print(f"✗ Connection test failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def test_dify_api():
|
||||
"""Test Dify API with Clickzetta backend"""
|
||||
print("\n=== Testing Dify API ===")
|
||||
@ -83,6 +85,7 @@ def test_dify_api():
|
||||
print(f"✗ API test failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def verify_table_structure():
|
||||
"""Verify the table structure meets Dify requirements"""
|
||||
print("\n=== Verifying Table Structure ===")
|
||||
@ -91,15 +94,10 @@ def verify_table_structure():
|
||||
"id": "VARCHAR",
|
||||
"page_content": "VARCHAR",
|
||||
"metadata": "VARCHAR", # JSON stored as VARCHAR in Clickzetta
|
||||
"vector": "ARRAY<FLOAT>"
|
||||
"vector": "ARRAY<FLOAT>",
|
||||
}
|
||||
|
||||
expected_metadata_fields = [
|
||||
"doc_id",
|
||||
"doc_hash",
|
||||
"document_id",
|
||||
"dataset_id"
|
||||
]
|
||||
expected_metadata_fields = ["doc_id", "doc_hash", "document_id", "dataset_id"]
|
||||
|
||||
print("✓ Expected table structure:")
|
||||
for col, dtype in expected_columns.items():
|
||||
@ -117,6 +115,7 @@ def verify_table_structure():
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
"""Run all tests"""
|
||||
print("Starting Clickzetta integration tests for Dify Docker\n")
|
||||
@ -137,9 +136,9 @@ def main():
|
||||
results.append((test_name, False))
|
||||
|
||||
# Summary
|
||||
print("\n" + "="*50)
|
||||
print("\n" + "=" * 50)
|
||||
print("Test Summary:")
|
||||
print("="*50)
|
||||
print("=" * 50)
|
||||
|
||||
passed = sum(1 for _, success in results if success)
|
||||
total = len(results)
|
||||
@ -161,5 +160,6 @@ def main():
|
||||
print("\n⚠️ Some tests failed. Please check the errors above.")
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,487 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from models.api_based_extension import APIBasedExtension
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.api_based_extension_service import APIBasedExtensionService
|
||||
|
||||
|
||||
class TestAPIBasedExtensionService:
|
||||
"""Integration tests for APIBasedExtensionService using testcontainers."""
|
||||
|
||||
    @pytest.fixture
    def mock_external_service_dependencies(self):
        """Mock setup for external service dependencies.

        Patches the billing feature check used during account creation and the
        HTTP requestor used by the extension service's connectivity ping, then
        yields the mocks so tests can adjust responses or assert on calls.
        """
        with (
            patch("services.account_service.FeatureService") as mock_account_feature_service,
            patch("services.api_based_extension_service.APIBasedExtensionRequestor") as mock_requestor,
        ):
            # Setup default mock returns: billing disabled so account/tenant
            # creation is not limited by plan checks.
            mock_account_feature_service.get_features.return_value.billing.enabled = False

            # Mock successful ping response ({"result": "pong"} is treated as
            # a healthy endpoint by the save() connectivity check).
            mock_requestor_instance = mock_requestor.return_value
            mock_requestor_instance.request.return_value = {"result": "pong"}

            yield {
                "account_feature_service": mock_account_feature_service,
                "requestor": mock_requestor,
                "requestor_instance": mock_requestor_instance,
            }
|
||||
|
||||
    def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Helper method to create a test account and tenant for testing.

        Args:
            db_session_with_containers: Database session from testcontainers infrastructure
            mock_external_service_dependencies: Mock dependencies

        Returns:
            tuple: (account, tenant) - Created account and tenant instances
        """
        fake = Faker()

        # Setup mocks for account creation: allow self-registration so
        # AccountService.create_account does not reject the signup.
        mock_external_service_dependencies[
            "account_feature_service"
        ].get_system_features.return_value.is_allow_register = True

        # Create account and tenant
        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        # NOTE(review): assumes create_owner_tenant_if_not_exist sets
        # account.current_tenant as a side effect — confirm against the service.
        tenant = account.current_tenant

        return account, tenant
|
||||
|
||||
    def test_save_extension_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful saving of API-based extension.

        Verifies the saved record's fields, its persistence in the database,
        and that the connectivity ping was performed exactly once.
        """
        fake = Faker()
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )

        # Setup extension data
        extension_data = APIBasedExtension()
        extension_data.tenant_id = tenant.id
        extension_data.name = fake.company()
        extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
        extension_data.api_key = fake.password(length=20)

        # Save extension
        saved_extension = APIBasedExtensionService.save(extension_data)

        # Verify extension was saved correctly
        assert saved_extension.id is not None
        assert saved_extension.tenant_id == tenant.id
        assert saved_extension.name == extension_data.name
        assert saved_extension.api_endpoint == extension_data.api_endpoint
        assert saved_extension.api_key == extension_data.api_key  # Should be decrypted when retrieved
        assert saved_extension.created_at is not None

        # Verify extension was saved to database
        from extensions.ext_database import db

        # refresh() re-reads the row from the DB; it would fail if the row
        # had not actually been persisted.
        db.session.refresh(saved_extension)
        assert saved_extension.id is not None

        # Verify ping connection was called
        mock_external_service_dependencies["requestor_instance"].request.assert_called_once()
|
||||
|
||||
    def test_save_extension_validation_errors(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test validation errors when saving extension with invalid data.

        Covers empty-string name, api_endpoint, and api_key; the same
        extension object is reused, fixing one field before testing the next.
        """
        fake = Faker()
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )

        # Test empty name
        extension_data = APIBasedExtension()
        extension_data.tenant_id = tenant.id
        extension_data.name = ""
        extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
        extension_data.api_key = fake.password(length=20)

        with pytest.raises(ValueError, match="name must not be empty"):
            APIBasedExtensionService.save(extension_data)

        # Test empty api_endpoint
        extension_data.name = fake.company()
        extension_data.api_endpoint = ""

        with pytest.raises(ValueError, match="api_endpoint must not be empty"):
            APIBasedExtensionService.save(extension_data)

        # Test empty api_key
        extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
        extension_data.api_key = ""

        with pytest.raises(ValueError, match="api_key must not be empty"):
            APIBasedExtensionService.save(extension_data)
|
||||
|
||||
    def test_get_all_by_tenant_id_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful retrieval of all extensions by tenant ID.

        Creates three extensions and verifies the listing returns all of them,
        scoped to the tenant and ordered newest-first.
        """
        fake = Faker()
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )

        # Create multiple extensions
        extensions = []
        for i in range(3):
            extension_data = APIBasedExtension()
            extension_data.tenant_id = tenant.id
            extension_data.name = f"Extension {i}: {fake.company()}"
            extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
            extension_data.api_key = fake.password(length=20)

            saved_extension = APIBasedExtensionService.save(extension_data)
            extensions.append(saved_extension)

        # Get all extensions for tenant
        extension_list = APIBasedExtensionService.get_all_by_tenant_id(tenant.id)

        # Verify results
        assert len(extension_list) == 3

        # Verify all extensions belong to the correct tenant and are ordered by created_at desc
        for i, extension in enumerate(extension_list):
            assert extension.tenant_id == tenant.id
            assert extension.api_key is not None  # Should be decrypted
            if i > 0:
                # Verify descending order (newer first); <= allows equal
                # timestamps when rows are created within the same tick.
                assert extension.created_at <= extension_list[i - 1].created_at
|
||||
|
||||
    def test_get_with_tenant_id_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful retrieval of extension by tenant ID and extension ID.
        """
        fake = Faker()
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )

        # Create an extension
        extension_data = APIBasedExtension()
        extension_data.tenant_id = tenant.id
        extension_data.name = fake.company()
        extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
        extension_data.api_key = fake.password(length=20)

        created_extension = APIBasedExtensionService.save(extension_data)

        # Get extension by ID
        retrieved_extension = APIBasedExtensionService.get_with_tenant_id(tenant.id, created_extension.id)

        # Verify extension was retrieved correctly — every stored field should
        # round-trip, including the API key (decrypted on read).
        assert retrieved_extension is not None
        assert retrieved_extension.id == created_extension.id
        assert retrieved_extension.tenant_id == tenant.id
        assert retrieved_extension.name == extension_data.name
        assert retrieved_extension.api_endpoint == extension_data.api_endpoint
        assert retrieved_extension.api_key == extension_data.api_key  # Should be decrypted
        assert retrieved_extension.created_at is not None
|
||||
|
||||
    def test_get_with_tenant_id_not_found(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test retrieval of extension when extension is not found.
        """
        fake = Faker()
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )
        # A random UUID that no saved extension can have.
        non_existent_extension_id = fake.uuid4()

        # Try to get non-existent extension
        with pytest.raises(ValueError, match="API based extension is not found"):
            APIBasedExtensionService.get_with_tenant_id(tenant.id, non_existent_extension_id)
|
||||
|
||||
    def test_delete_extension_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful deletion of extension.

        Saves an extension, deletes it, then queries the table directly to
        confirm the row is gone.
        """
        fake = Faker()
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )

        # Create an extension first
        extension_data = APIBasedExtension()
        extension_data.tenant_id = tenant.id
        extension_data.name = fake.company()
        extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
        extension_data.api_key = fake.password(length=20)

        created_extension = APIBasedExtensionService.save(extension_data)
        # Keep the id before deletion so we can look it up afterwards.
        extension_id = created_extension.id

        # Delete the extension
        APIBasedExtensionService.delete(created_extension)

        # Verify extension was deleted
        from extensions.ext_database import db

        deleted_extension = db.session.query(APIBasedExtension).filter(APIBasedExtension.id == extension_id).first()
        assert deleted_extension is None
|
||||
|
||||
    def test_save_extension_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test validation error when saving extension with duplicate name.

        Names must be unique within a tenant; the second save with the same
        name is expected to raise.
        """
        fake = Faker()
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )

        # Create first extension
        extension_data1 = APIBasedExtension()
        extension_data1.tenant_id = tenant.id
        extension_data1.name = "Test Extension"
        extension_data1.api_endpoint = f"https://{fake.domain_name()}/api"
        extension_data1.api_key = fake.password(length=20)

        APIBasedExtensionService.save(extension_data1)

        # Try to create second extension with same name
        extension_data2 = APIBasedExtension()
        extension_data2.tenant_id = tenant.id
        extension_data2.name = "Test Extension"  # Same name
        extension_data2.api_endpoint = f"https://{fake.domain_name()}/api"
        extension_data2.api_key = fake.password(length=20)

        with pytest.raises(ValueError, match="name must be unique, it is already existed"):
            APIBasedExtensionService.save(extension_data2)
|
||||
|
||||
    def test_save_extension_update_existing(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful update of existing extension.

        Saves an extension, mutates its fields, saves again, and verifies the
        update both on the returned object and via a fresh retrieval. Also
        checks the connectivity ping ran once per save (two calls total).
        """
        fake = Faker()
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )

        # Create initial extension
        extension_data = APIBasedExtension()
        extension_data.tenant_id = tenant.id
        extension_data.name = fake.company()
        extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
        extension_data.api_key = fake.password(length=20)

        created_extension = APIBasedExtensionService.save(extension_data)

        # Save original values for later comparison
        original_name = created_extension.name
        original_endpoint = created_extension.api_endpoint

        # Update the extension
        new_name = fake.company()
        new_endpoint = f"https://{fake.domain_name()}/api"
        new_api_key = fake.password(length=20)

        created_extension.name = new_name
        created_extension.api_endpoint = new_endpoint
        created_extension.api_key = new_api_key

        updated_extension = APIBasedExtensionService.save(created_extension)

        # Verify extension was updated correctly (same row, new values)
        assert updated_extension.id == created_extension.id
        assert updated_extension.tenant_id == tenant.id
        assert updated_extension.name == new_name
        assert updated_extension.api_endpoint == new_endpoint

        # Verify original values were changed
        # NOTE(review): assumes Faker produced distinct values for the second
        # round — collisions are possible in principle, though unlikely.
        assert updated_extension.name != original_name
        assert updated_extension.api_endpoint != original_endpoint

        # Verify ping connection was called for both create and update
        assert mock_external_service_dependencies["requestor_instance"].request.call_count == 2

        # Verify the update by retrieving the extension again
        retrieved_extension = APIBasedExtensionService.get_with_tenant_id(tenant.id, created_extension.id)
        assert retrieved_extension.name == new_name
        assert retrieved_extension.api_endpoint == new_endpoint
        assert retrieved_extension.api_key == new_api_key  # Should be decrypted when retrieved
|
||||
|
||||
    def test_save_extension_connection_error(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test connection error when saving extension with invalid endpoint.

        The mocked requestor raises, so save() should surface the connection
        failure instead of persisting the extension.
        """
        fake = Faker()
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )

        # Mock connection error
        mock_external_service_dependencies["requestor_instance"].request.side_effect = ValueError(
            "connection error: request timeout"
        )

        # Setup extension data
        extension_data = APIBasedExtension()
        extension_data.tenant_id = tenant.id
        extension_data.name = fake.company()
        extension_data.api_endpoint = "https://invalid-endpoint.com/api"
        extension_data.api_key = fake.password(length=20)

        # Try to save extension with connection error
        with pytest.raises(ValueError, match="connection error: request timeout"):
            APIBasedExtensionService.save(extension_data)
|
||||
|
||||
    def test_save_extension_invalid_api_key_length(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test validation error when saving extension with API key that is too short.
        """
        fake = Faker()
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )

        # Setup extension data with short API key
        extension_data = APIBasedExtension()
        extension_data.tenant_id = tenant.id
        extension_data.name = fake.company()
        extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
        extension_data.api_key = "1234"  # Less than 5 characters

        # Try to save extension with short API key
        with pytest.raises(ValueError, match="api_key must be at least 5 characters"):
            APIBasedExtensionService.save(extension_data)
|
||||
|
||||
    def test_save_extension_empty_fields(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test validation errors when saving extension with empty required fields.

        Complements test_save_extension_validation_errors by using None
        instead of empty strings; the same error messages are expected.
        """
        fake = Faker()
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )

        # Test with None values
        extension_data = APIBasedExtension()
        extension_data.tenant_id = tenant.id
        extension_data.name = None
        extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
        extension_data.api_key = fake.password(length=20)

        with pytest.raises(ValueError, match="name must not be empty"):
            APIBasedExtensionService.save(extension_data)

        # Test with None api_endpoint
        extension_data.name = fake.company()
        extension_data.api_endpoint = None

        with pytest.raises(ValueError, match="api_endpoint must not be empty"):
            APIBasedExtensionService.save(extension_data)

        # Test with None api_key
        extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
        extension_data.api_key = None

        with pytest.raises(ValueError, match="api_key must not be empty"):
            APIBasedExtensionService.save(extension_data)
|
||||
|
||||
    def test_get_all_by_tenant_id_empty_list(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test retrieval of extensions when no extensions exist for tenant.
        """
        fake = Faker()
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )

        # Get all extensions for tenant (none exist)
        extension_list = APIBasedExtensionService.get_all_by_tenant_id(tenant.id)

        # Verify empty list is returned (not None, and truly a list)
        assert len(extension_list) == 0
        assert extension_list == []
|
||||
|
||||
    def test_save_extension_invalid_ping_response(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test validation error when ping response is invalid.

        The ping must return {"result": "pong"}; any other result value makes
        save() raise with the full response in the message.
        """
        fake = Faker()
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )

        # Mock invalid ping response
        mock_external_service_dependencies["requestor_instance"].request.return_value = {"result": "invalid"}

        # Setup extension data
        extension_data = APIBasedExtension()
        extension_data.tenant_id = tenant.id
        extension_data.name = fake.company()
        extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
        extension_data.api_key = fake.password(length=20)

        # Try to save extension with invalid ping response
        with pytest.raises(ValueError, match="{'result': 'invalid'}"):
            APIBasedExtensionService.save(extension_data)
|
||||
|
||||
def test_save_extension_missing_ping_result(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test validation error when ping response is missing result field.
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Mock ping response without result field
|
||||
mock_external_service_dependencies["requestor_instance"].request.return_value = {"status": "ok"}
|
||||
|
||||
# Setup extension data
|
||||
extension_data = APIBasedExtension()
|
||||
extension_data.tenant_id = tenant.id
|
||||
extension_data.name = fake.company()
|
||||
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
|
||||
extension_data.api_key = fake.password(length=20)
|
||||
|
||||
# Try to save extension with missing ping result
|
||||
with pytest.raises(ValueError, match="{'status': 'ok'}"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
|
||||
def test_get_with_tenant_id_wrong_tenant(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test retrieval of extension when tenant ID doesn't match.
|
||||
"""
|
||||
fake = Faker()
|
||||
account1, tenant1 = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Create second account and tenant
|
||||
account2, tenant2 = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Create extension in first tenant
|
||||
extension_data = APIBasedExtension()
|
||||
extension_data.tenant_id = tenant1.id
|
||||
extension_data.name = fake.company()
|
||||
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
|
||||
extension_data.api_key = fake.password(length=20)
|
||||
|
||||
created_extension = APIBasedExtensionService.save(extension_data)
|
||||
|
||||
# Try to get extension with wrong tenant ID
|
||||
with pytest.raises(ValueError, match="API based extension is not found"):
|
||||
APIBasedExtensionService.get_with_tenant_id(tenant2.id, created_extension.id)
|
||||
@ -0,0 +1,473 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from faker import Faker
|
||||
|
||||
from models.model import App, AppModelConfig
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.app_dsl_service import AppDslService, ImportMode, ImportStatus
|
||||
from services.app_service import AppService
|
||||
|
||||
|
||||
class TestAppDslService:
|
||||
"""Integration tests for AppDslService using testcontainers."""
|
||||
|
||||
    @pytest.fixture
    def mock_external_service_dependencies(self):
        """Mock setup for external service dependencies.

        Patches every collaborator that AppDslService / AppService reach
        out to (workflow persistence, dependency analysis, draft variables,
        SSRF-safe HTTP, redis, app lifecycle signals, model/feature/
        enterprise services) and yields the mocks keyed by short names so
        individual tests can override return values or assert on calls.
        """
        with (
            patch("services.app_dsl_service.WorkflowService") as mock_workflow_service,
            patch("services.app_dsl_service.DependenciesAnalysisService") as mock_dependencies_service,
            patch("services.app_dsl_service.WorkflowDraftVariableService") as mock_draft_variable_service,
            patch("services.app_dsl_service.ssrf_proxy") as mock_ssrf_proxy,
            patch("services.app_dsl_service.redis_client") as mock_redis_client,
            patch("services.app_dsl_service.app_was_created") as mock_app_was_created,
            patch("services.app_dsl_service.app_model_config_was_updated") as mock_app_model_config_was_updated,
            patch("services.app_service.ModelManager") as mock_model_manager,
            patch("services.app_service.FeatureService") as mock_feature_service,
            patch("services.app_service.EnterpriseService") as mock_enterprise_service,
        ):
            # Setup default mock returns
            # No draft workflow exists unless a test installs one explicitly.
            mock_workflow_service.return_value.get_draft_workflow.return_value = None
            mock_workflow_service.return_value.sync_draft_workflow.return_value = MagicMock()
            # Dependency analysis reports a clean slate by default.
            mock_dependencies_service.generate_latest_dependencies.return_value = []
            mock_dependencies_service.get_leaked_dependencies.return_value = []
            mock_dependencies_service.generate_dependencies.return_value = []
            mock_draft_variable_service.return_value.delete_workflow_variables.return_value = None
            # Remote YAML fetches "succeed" with placeholder bytes; tests
            # that care about the payload replace this return value.
            mock_ssrf_proxy.get.return_value.content = b"test content"
            mock_ssrf_proxy.get.return_value.raise_for_status.return_value = None
            # Redis behaves as an empty, write-accepting cache.
            mock_redis_client.setex.return_value = None
            mock_redis_client.get.return_value = None
            mock_redis_client.delete.return_value = None
            # App lifecycle signals become no-ops.
            mock_app_was_created.send.return_value = None
            mock_app_model_config_was_updated.send.return_value = None

            # Mock ModelManager for app service
            mock_model_instance = mock_model_manager.return_value
            mock_model_instance.get_default_model_instance.return_value = None
            mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo")

            # Mock FeatureService and EnterpriseService
            mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False
            mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None
            mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None

            # Yield inside the with-block so all patches stay active for
            # the duration of each test that uses this fixture.
            yield {
                "workflow_service": mock_workflow_service,
                "dependencies_service": mock_dependencies_service,
                "draft_variable_service": mock_draft_variable_service,
                "ssrf_proxy": mock_ssrf_proxy,
                "redis_client": mock_redis_client,
                "app_was_created": mock_app_was_created,
                "app_model_config_was_updated": mock_app_model_config_was_updated,
                "model_manager": mock_model_manager,
                "feature_service": mock_feature_service,
                "enterprise_service": mock_enterprise_service,
            }
|
||||
|
||||
    def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Helper method to create a test app and account for testing.

        Args:
            db_session_with_containers: Database session from testcontainers infrastructure
            mock_external_service_dependencies: Mock dependencies

        Returns:
            tuple: (app, account) - Created app and account instances
        """
        fake = Faker()

        # Setup mocks for account creation
        # NOTE(review): FeatureService is patched here (account_service
        # module path) only so that registration is allowed during
        # create_account; app_service.FeatureService is patched by the
        # fixture — confirm the with-scope covers all calls that need it.
        with patch("services.account_service.FeatureService") as mock_account_feature_service:
            mock_account_feature_service.get_system_features.return_value.is_allow_register = True

            # Create account and tenant first
            account = AccountService.create_account(
                email=fake.email(),
                name=fake.name(),
                interface_language="en-US",
                password=fake.password(length=12),
            )
            TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
            tenant = account.current_tenant

            # Setup app creation arguments
            app_args = {
                "name": fake.company(),
                "description": fake.text(max_nb_chars=100),
                "mode": "chat",
                "icon_type": "emoji",
                "icon": "🤖",
                "icon_background": "#FF6B6B",
                "api_rph": 100,
                "api_rpm": 10,
            }

            # Create app
            app_service = AppService()
            app = app_service.create_app(tenant.id, app_args, account)

        return app, account
|
||||
|
||||
def _create_simple_yaml_content(self, app_name="Test App", app_mode="chat"):
|
||||
"""
|
||||
Helper method to create simple YAML content for testing.
|
||||
"""
|
||||
yaml_data = {
|
||||
"version": "0.3.0",
|
||||
"kind": "app",
|
||||
"app": {
|
||||
"name": app_name,
|
||||
"mode": app_mode,
|
||||
"icon": "🤖",
|
||||
"icon_background": "#FFEAD5",
|
||||
"description": "Test app description",
|
||||
"use_icon_as_answer_icon": False,
|
||||
},
|
||||
"model_config": {
|
||||
"model": {
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo",
|
||||
"mode": "chat",
|
||||
"completion_params": {
|
||||
"max_tokens": 1000,
|
||||
"temperature": 0.7,
|
||||
"top_p": 1.0,
|
||||
},
|
||||
},
|
||||
"pre_prompt": "You are a helpful assistant.",
|
||||
"prompt_type": "simple",
|
||||
},
|
||||
}
|
||||
return yaml.dump(yaml_data, allow_unicode=True)
|
||||
|
||||
    def test_import_app_yaml_content_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful app import from YAML content.

        Imports a minimal chat-app DSL and verifies the reported result,
        the persisted App row, and the associated model configuration.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)

        # Create YAML content
        yaml_content = self._create_simple_yaml_content(fake.company(), "chat")

        # Import app; the explicit name/description passed here override
        # the values embedded in the DSL document.
        dsl_service = AppDslService(db_session_with_containers)
        result = dsl_service.import_app(
            account=account,
            import_mode=ImportMode.YAML_CONTENT,
            yaml_content=yaml_content,
            name="Imported App",
            description="Imported app description",
        )

        # Verify import result
        assert result.status == ImportStatus.COMPLETED
        assert result.app_id is not None
        assert result.app_mode == "chat"
        assert result.imported_dsl_version == "0.3.0"
        assert result.error == ""

        # Verify app was created in database
        imported_app = db_session_with_containers.query(App).filter(App.id == result.app_id).first()
        assert imported_app is not None
        assert imported_app.name == "Imported App"
        assert imported_app.description == "Imported app description"
        assert imported_app.mode == "chat"
        assert imported_app.tenant_id == account.current_tenant_id
        assert imported_app.created_by == account.id

        # Verify model config was created
        model_config = (
            db_session_with_containers.query(AppModelConfig).filter(AppModelConfig.app_id == result.app_id).first()
        )
        assert model_config is not None
        # The provider and model_id are stored in the model field as JSON
        model_dict = model_config.model_dict
        assert model_dict["provider"] == "openai"
        assert model_dict["name"] == "gpt-3.5-turbo"
|
||||
|
||||
    def test_import_app_yaml_url_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful app import from YAML URL.

        The remote fetch goes through the mocked ssrf_proxy, so no network
        access happens; the mock returns the generated DSL bytes.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)

        # Create YAML content for mock response
        yaml_content = self._create_simple_yaml_content(fake.company(), "chat")

        # Setup mock response
        mock_response = MagicMock()
        mock_response.content = yaml_content.encode("utf-8")
        mock_response.raise_for_status.return_value = None
        mock_external_service_dependencies["ssrf_proxy"].get.return_value = mock_response

        # Import app from URL
        dsl_service = AppDslService(db_session_with_containers)
        result = dsl_service.import_app(
            account=account,
            import_mode=ImportMode.YAML_URL,
            yaml_url="https://example.com/app.yaml",
            name="URL Imported App",
            description="App imported from URL",
        )

        # Verify import result
        assert result.status == ImportStatus.COMPLETED
        assert result.app_id is not None
        assert result.app_mode == "chat"
        assert result.imported_dsl_version == "0.3.0"
        assert result.error == ""

        # Verify app was created in database
        imported_app = db_session_with_containers.query(App).filter(App.id == result.app_id).first()
        assert imported_app is not None
        assert imported_app.name == "URL Imported App"
        assert imported_app.description == "App imported from URL"
        assert imported_app.mode == "chat"
        assert imported_app.tenant_id == account.current_tenant_id

        # Verify ssrf_proxy was called with redirects allowed and the
        # service's connect/read timeout pair.
        mock_external_service_dependencies["ssrf_proxy"].get.assert_called_once_with(
            "https://example.com/app.yaml", follow_redirects=True, timeout=(10, 10)
        )
|
||||
|
||||
def test_import_app_invalid_yaml_format(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test app import with invalid YAML format.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create invalid YAML content
|
||||
invalid_yaml = "invalid: yaml: content: ["
|
||||
|
||||
# Import app with invalid YAML
|
||||
dsl_service = AppDslService(db_session_with_containers)
|
||||
result = dsl_service.import_app(
|
||||
account=account,
|
||||
import_mode=ImportMode.YAML_CONTENT,
|
||||
yaml_content=invalid_yaml,
|
||||
name="Invalid App",
|
||||
)
|
||||
|
||||
# Verify import failed
|
||||
assert result.status == ImportStatus.FAILED
|
||||
assert result.app_id is None
|
||||
assert "Invalid YAML format" in result.error
|
||||
assert result.imported_dsl_version == ""
|
||||
|
||||
# Verify no app was created in database
|
||||
apps_count = db_session_with_containers.query(App).filter(App.tenant_id == account.current_tenant_id).count()
|
||||
assert apps_count == 1 # Only the original test app
|
||||
|
||||
def test_import_app_missing_yaml_content(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test app import with missing YAML content.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Import app without YAML content
|
||||
dsl_service = AppDslService(db_session_with_containers)
|
||||
result = dsl_service.import_app(
|
||||
account=account,
|
||||
import_mode=ImportMode.YAML_CONTENT,
|
||||
name="Missing Content App",
|
||||
)
|
||||
|
||||
# Verify import failed
|
||||
assert result.status == ImportStatus.FAILED
|
||||
assert result.app_id is None
|
||||
assert "yaml_content is required" in result.error
|
||||
assert result.imported_dsl_version == ""
|
||||
|
||||
# Verify no app was created in database
|
||||
apps_count = db_session_with_containers.query(App).filter(App.tenant_id == account.current_tenant_id).count()
|
||||
assert apps_count == 1 # Only the original test app
|
||||
|
||||
def test_import_app_missing_yaml_url(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test app import with missing YAML URL.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Import app without YAML URL
|
||||
dsl_service = AppDslService(db_session_with_containers)
|
||||
result = dsl_service.import_app(
|
||||
account=account,
|
||||
import_mode=ImportMode.YAML_URL,
|
||||
name="Missing URL App",
|
||||
)
|
||||
|
||||
# Verify import failed
|
||||
assert result.status == ImportStatus.FAILED
|
||||
assert result.app_id is None
|
||||
assert "yaml_url is required" in result.error
|
||||
assert result.imported_dsl_version == ""
|
||||
|
||||
# Verify no app was created in database
|
||||
apps_count = db_session_with_containers.query(App).filter(App.tenant_id == account.current_tenant_id).count()
|
||||
assert apps_count == 1 # Only the original test app
|
||||
|
||||
def test_import_app_invalid_import_mode(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test app import with invalid import mode.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create YAML content
|
||||
yaml_content = self._create_simple_yaml_content(fake.company(), "chat")
|
||||
|
||||
# Import app with invalid mode should raise ValueError
|
||||
dsl_service = AppDslService(db_session_with_containers)
|
||||
with pytest.raises(ValueError, match="Invalid import_mode: invalid-mode"):
|
||||
dsl_service.import_app(
|
||||
account=account,
|
||||
import_mode="invalid-mode",
|
||||
yaml_content=yaml_content,
|
||||
name="Invalid Mode App",
|
||||
)
|
||||
|
||||
# Verify no app was created in database
|
||||
apps_count = db_session_with_containers.query(App).filter(App.tenant_id == account.current_tenant_id).count()
|
||||
assert apps_count == 1 # Only the original test app
|
||||
|
||||
    def test_export_dsl_chat_app_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful DSL export for chat app.

        Persists a model config for the fixture-created app, exports the
        DSL, and checks the YAML round-trips the app metadata.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)

        # Create model config for the app
        model_config = AppModelConfig()
        model_config.id = fake.uuid4()
        model_config.app_id = app.id
        model_config.provider = "openai"
        model_config.model_id = "gpt-3.5-turbo"
        # The model column stores the full model settings as a JSON string.
        model_config.model = json.dumps(
            {
                "provider": "openai",
                "name": "gpt-3.5-turbo",
                "mode": "chat",
                "completion_params": {
                    "max_tokens": 1000,
                    "temperature": 0.7,
                },
            }
        )
        model_config.pre_prompt = "You are a helpful assistant."
        model_config.prompt_type = "simple"
        model_config.created_by = account.id
        model_config.updated_by = account.id

        # Set the app_model_config_id to link the config
        app.app_model_config_id = model_config.id

        # Persist config and the app link before exporting.
        db_session_with_containers.add(model_config)
        db_session_with_containers.commit()

        # Export DSL
        exported_dsl = AppDslService.export_dsl(app, include_secret=False)

        # Parse exported YAML
        exported_data = yaml.safe_load(exported_dsl)

        # Verify exported data structure
        assert exported_data["kind"] == "app"
        assert exported_data["app"]["name"] == app.name
        assert exported_data["app"]["mode"] == app.mode
        assert exported_data["app"]["icon"] == app.icon
        assert exported_data["app"]["icon_background"] == app.icon_background
        assert exported_data["app"]["description"] == app.description

        # Verify model config was exported
        assert "model_config" in exported_data
        # The exported model_config structure may be different from the database structure
        # Check that the model config exists and has the expected content
        assert exported_data["model_config"] is not None

        # Verify dependencies were exported
        assert "dependencies" in exported_data
        assert isinstance(exported_data["dependencies"], list)
|
||||
|
||||
    def test_export_dsl_workflow_app_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful DSL export for workflow app.

        Switches the fixture app to workflow mode and installs a mocked
        draft workflow, then checks the export embeds the workflow graph.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)

        # Update app to workflow mode
        app.mode = "workflow"
        db_session_with_containers.commit()

        # Mock workflow service to return a workflow with a minimal
        # single-node graph.
        mock_workflow = MagicMock()
        mock_workflow.to_dict.return_value = {
            "graph": {"nodes": [{"id": "start", "type": "start", "data": {"type": "start"}}], "edges": []},
            "features": {},
            "environment_variables": [],
            "conversation_variables": [],
        }
        mock_external_service_dependencies[
            "workflow_service"
        ].return_value.get_draft_workflow.return_value = mock_workflow

        # Export DSL
        exported_dsl = AppDslService.export_dsl(app, include_secret=False)

        # Parse exported YAML
        exported_data = yaml.safe_load(exported_dsl)

        # Verify exported data structure
        assert exported_data["kind"] == "app"
        assert exported_data["app"]["name"] == app.name
        assert exported_data["app"]["mode"] == "workflow"

        # Verify workflow was exported
        assert "workflow" in exported_data
        assert "graph" in exported_data["workflow"]
        assert "nodes" in exported_data["workflow"]["graph"]

        # Verify dependencies were exported
        assert "dependencies" in exported_data
        assert isinstance(exported_data["dependencies"], list)

        # Verify workflow service was called
        mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once_with(
            app
        )
|
||||
|
||||
def test_check_dependencies_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful dependency checking.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Mock Redis to return dependencies
|
||||
mock_dependencies_json = '{"app_id": "' + app.id + '", "dependencies": []}'
|
||||
mock_external_service_dependencies["redis_client"].get.return_value = mock_dependencies_json
|
||||
|
||||
# Check dependencies
|
||||
dsl_service = AppDslService(db_session_with_containers)
|
||||
result = dsl_service.check_dependencies(app_model=app)
|
||||
|
||||
# Verify result
|
||||
assert result.leaked_dependencies == []
|
||||
|
||||
# Verify Redis was queried
|
||||
mock_external_service_dependencies["redis_client"].get.assert_called_once_with(
|
||||
f"app_check_dependencies:{app.id}"
|
||||
)
|
||||
|
||||
# Verify dependencies service was called
|
||||
mock_external_service_dependencies["dependencies_service"].get_leaked_dependencies.assert_called_once()
|
||||
@ -0,0 +1,252 @@
|
||||
import pytest
|
||||
|
||||
from controllers.console.app.app import _validate_description_length as app_validate
|
||||
from controllers.console.datasets.datasets import _validate_description_length as dataset_validate
|
||||
from controllers.service_api.dataset.dataset import _validate_description_length as service_dataset_validate
|
||||
|
||||
|
||||
class TestDescriptionValidationUnit:
|
||||
"""Unit tests for description validation functions in App and Dataset APIs"""
|
||||
|
||||
def test_app_validate_description_length_valid(self):
|
||||
"""Test App validation function with valid descriptions"""
|
||||
# Empty string should be valid
|
||||
assert app_validate("") == ""
|
||||
|
||||
# None should be valid
|
||||
assert app_validate(None) is None
|
||||
|
||||
# Short description should be valid
|
||||
short_desc = "Short description"
|
||||
assert app_validate(short_desc) == short_desc
|
||||
|
||||
# Exactly 400 characters should be valid
|
||||
exactly_400 = "x" * 400
|
||||
assert app_validate(exactly_400) == exactly_400
|
||||
|
||||
# Just under limit should be valid
|
||||
just_under = "x" * 399
|
||||
assert app_validate(just_under) == just_under
|
||||
|
||||
def test_app_validate_description_length_invalid(self):
|
||||
"""Test App validation function with invalid descriptions"""
|
||||
# 401 characters should fail
|
||||
just_over = "x" * 401
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
app_validate(just_over)
|
||||
assert "Description cannot exceed 400 characters." in str(exc_info.value)
|
||||
|
||||
# 500 characters should fail
|
||||
way_over = "x" * 500
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
app_validate(way_over)
|
||||
assert "Description cannot exceed 400 characters." in str(exc_info.value)
|
||||
|
||||
# 1000 characters should fail
|
||||
very_long = "x" * 1000
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
app_validate(very_long)
|
||||
assert "Description cannot exceed 400 characters." in str(exc_info.value)
|
||||
|
||||
def test_dataset_validate_description_length_valid(self):
|
||||
"""Test Dataset validation function with valid descriptions"""
|
||||
# Empty string should be valid
|
||||
assert dataset_validate("") == ""
|
||||
|
||||
# Short description should be valid
|
||||
short_desc = "Short description"
|
||||
assert dataset_validate(short_desc) == short_desc
|
||||
|
||||
# Exactly 400 characters should be valid
|
||||
exactly_400 = "x" * 400
|
||||
assert dataset_validate(exactly_400) == exactly_400
|
||||
|
||||
# Just under limit should be valid
|
||||
just_under = "x" * 399
|
||||
assert dataset_validate(just_under) == just_under
|
||||
|
||||
def test_dataset_validate_description_length_invalid(self):
|
||||
"""Test Dataset validation function with invalid descriptions"""
|
||||
# 401 characters should fail
|
||||
just_over = "x" * 401
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
dataset_validate(just_over)
|
||||
assert "Description cannot exceed 400 characters." in str(exc_info.value)
|
||||
|
||||
# 500 characters should fail
|
||||
way_over = "x" * 500
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
dataset_validate(way_over)
|
||||
assert "Description cannot exceed 400 characters." in str(exc_info.value)
|
||||
|
||||
def test_service_dataset_validate_description_length_valid(self):
|
||||
"""Test Service Dataset validation function with valid descriptions"""
|
||||
# Empty string should be valid
|
||||
assert service_dataset_validate("") == ""
|
||||
|
||||
# None should be valid
|
||||
assert service_dataset_validate(None) is None
|
||||
|
||||
# Short description should be valid
|
||||
short_desc = "Short description"
|
||||
assert service_dataset_validate(short_desc) == short_desc
|
||||
|
||||
# Exactly 400 characters should be valid
|
||||
exactly_400 = "x" * 400
|
||||
assert service_dataset_validate(exactly_400) == exactly_400
|
||||
|
||||
# Just under limit should be valid
|
||||
just_under = "x" * 399
|
||||
assert service_dataset_validate(just_under) == just_under
|
||||
|
||||
def test_service_dataset_validate_description_length_invalid(self):
|
||||
"""Test Service Dataset validation function with invalid descriptions"""
|
||||
# 401 characters should fail
|
||||
just_over = "x" * 401
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
service_dataset_validate(just_over)
|
||||
assert "Description cannot exceed 400 characters." in str(exc_info.value)
|
||||
|
||||
# 500 characters should fail
|
||||
way_over = "x" * 500
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
service_dataset_validate(way_over)
|
||||
assert "Description cannot exceed 400 characters." in str(exc_info.value)
|
||||
|
||||
    def test_app_dataset_validation_consistency(self):
        """Test that App and Dataset validation functions behave identically

        The three validators (console App, console Dataset, service-API
        Dataset) must accept the same inputs, reject the same inputs, and
        raise the exact same error message.
        """
        test_cases = [
            "",  # Empty string
            "Short description",  # Normal description
            "x" * 100,  # Medium description
            "x" * 400,  # Exactly at limit
        ]

        # Test valid cases produce same results
        for test_desc in test_cases:
            assert app_validate(test_desc) == dataset_validate(test_desc) == service_dataset_validate(test_desc)

        # Test invalid cases produce same errors
        invalid_cases = [
            "x" * 401,  # Just over limit
            "x" * 500,  # Way over limit
            "x" * 1000,  # Very long
        ]

        for invalid_desc in invalid_cases:
            # Captured error messages; None means "did not raise".
            app_error = None
            dataset_error = None
            service_dataset_error = None

            # Capture App validation error
            try:
                app_validate(invalid_desc)
            except ValueError as e:
                app_error = str(e)

            # Capture Dataset validation error
            try:
                dataset_validate(invalid_desc)
            except ValueError as e:
                dataset_error = str(e)

            # Capture Service Dataset validation error
            try:
                service_dataset_validate(invalid_desc)
            except ValueError as e:
                service_dataset_error = str(e)

            # All should produce errors
            assert app_error is not None, f"App validation should fail for {len(invalid_desc)} characters"
            assert dataset_error is not None, f"Dataset validation should fail for {len(invalid_desc)} characters"
            error_msg = f"Service Dataset validation should fail for {len(invalid_desc)} characters"
            assert service_dataset_error is not None, error_msg

            # Errors should be identical
            error_msg = f"Error messages should be identical for {len(invalid_desc)} characters"
            assert app_error == dataset_error == service_dataset_error, error_msg
            assert app_error == "Description cannot exceed 400 characters."
|
||||
|
||||
def test_boundary_values(self):
|
||||
"""Test boundary values around the 400 character limit"""
|
||||
boundary_tests = [
|
||||
(0, True), # Empty
|
||||
(1, True), # Minimum
|
||||
(399, True), # Just under limit
|
||||
(400, True), # Exactly at limit
|
||||
(401, False), # Just over limit
|
||||
(402, False), # Over limit
|
||||
(500, False), # Way over limit
|
||||
]
|
||||
|
||||
for length, should_pass in boundary_tests:
|
||||
test_desc = "x" * length
|
||||
|
||||
if should_pass:
|
||||
# Should not raise exception
|
||||
assert app_validate(test_desc) == test_desc
|
||||
assert dataset_validate(test_desc) == test_desc
|
||||
assert service_dataset_validate(test_desc) == test_desc
|
||||
else:
|
||||
# Should raise ValueError
|
||||
with pytest.raises(ValueError):
|
||||
app_validate(test_desc)
|
||||
with pytest.raises(ValueError):
|
||||
dataset_validate(test_desc)
|
||||
with pytest.raises(ValueError):
|
||||
service_dataset_validate(test_desc)
|
||||
|
||||
def test_special_characters(self):
|
||||
"""Test validation with special characters, Unicode, etc."""
|
||||
# Unicode characters
|
||||
unicode_desc = "测试描述" * 100 # Chinese characters
|
||||
if len(unicode_desc) <= 400:
|
||||
assert app_validate(unicode_desc) == unicode_desc
|
||||
assert dataset_validate(unicode_desc) == unicode_desc
|
||||
assert service_dataset_validate(unicode_desc) == unicode_desc
|
||||
|
||||
# Special characters
|
||||
special_desc = "Special chars: !@#$%^&*()_+-=[]{}|;':\",./<>?" * 10
|
||||
if len(special_desc) <= 400:
|
||||
assert app_validate(special_desc) == special_desc
|
||||
assert dataset_validate(special_desc) == special_desc
|
||||
assert service_dataset_validate(special_desc) == special_desc
|
||||
|
||||
# Mixed content
|
||||
mixed_desc = "Mixed content: 测试 123 !@# " * 15
|
||||
if len(mixed_desc) <= 400:
|
||||
assert app_validate(mixed_desc) == mixed_desc
|
||||
assert dataset_validate(mixed_desc) == mixed_desc
|
||||
assert service_dataset_validate(mixed_desc) == mixed_desc
|
||||
elif len(mixed_desc) > 400:
|
||||
with pytest.raises(ValueError):
|
||||
app_validate(mixed_desc)
|
||||
with pytest.raises(ValueError):
|
||||
dataset_validate(mixed_desc)
|
||||
with pytest.raises(ValueError):
|
||||
service_dataset_validate(mixed_desc)
|
||||
|
||||
def test_whitespace_handling(self):
|
||||
"""Test validation with various whitespace scenarios"""
|
||||
# Leading/trailing whitespace
|
||||
whitespace_desc = " Description with whitespace "
|
||||
if len(whitespace_desc) <= 400:
|
||||
assert app_validate(whitespace_desc) == whitespace_desc
|
||||
assert dataset_validate(whitespace_desc) == whitespace_desc
|
||||
assert service_dataset_validate(whitespace_desc) == whitespace_desc
|
||||
|
||||
# Newlines and tabs
|
||||
multiline_desc = "Line 1\nLine 2\tTabbed content"
|
||||
if len(multiline_desc) <= 400:
|
||||
assert app_validate(multiline_desc) == multiline_desc
|
||||
assert dataset_validate(multiline_desc) == multiline_desc
|
||||
assert service_dataset_validate(multiline_desc) == multiline_desc
|
||||
|
||||
# Only whitespace over limit
|
||||
only_spaces = " " * 401
|
||||
with pytest.raises(ValueError):
|
||||
app_validate(only_spaces)
|
||||
with pytest.raises(ValueError):
|
||||
dataset_validate(only_spaces)
|
||||
with pytest.raises(ValueError):
|
||||
service_dataset_validate(only_spaces)
|
||||
@ -0,0 +1,336 @@
|
||||
"""
|
||||
Unit tests for Service API File Preview endpoint
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.service_api.app.error import FileAccessDeniedError, FileNotFoundError
|
||||
from controllers.service_api.app.file_preview import FilePreviewApi
|
||||
from models.model import App, EndUser, Message, MessageFile, UploadFile
|
||||
|
||||
|
||||
class TestFilePreviewApi:
    """Test suite for FilePreviewApi.

    Exercises the endpoint's two private helpers in isolation:
    ``_validate_file_ownership`` (the MessageFile -> Message -> UploadFile -> App
    lookup chain) and ``_build_file_response`` (response header construction).
    All database access goes through a patched ``db`` whose query chain is fed
    with ``side_effect`` lists, so the ORDER of entries below must match the
    order of queries inside the implementation.
    """

    @pytest.fixture
    def file_preview_api(self):
        """Create FilePreviewApi instance for testing"""
        return FilePreviewApi()

    @pytest.fixture
    def mock_app(self):
        """Mock App model"""
        app = Mock(spec=App)
        app.id = str(uuid.uuid4())
        app.tenant_id = str(uuid.uuid4())
        return app

    @pytest.fixture
    def mock_end_user(self):
        """Mock EndUser model"""
        end_user = Mock(spec=EndUser)
        end_user.id = str(uuid.uuid4())
        return end_user

    @pytest.fixture
    def mock_upload_file(self):
        """Mock UploadFile model"""
        upload_file = Mock(spec=UploadFile)
        upload_file.id = str(uuid.uuid4())
        upload_file.name = "test_file.jpg"
        upload_file.mime_type = "image/jpeg"
        upload_file.size = 1024
        upload_file.key = "storage/key/test_file.jpg"
        upload_file.tenant_id = str(uuid.uuid4())
        return upload_file

    @pytest.fixture
    def mock_message_file(self):
        """Mock MessageFile model"""
        message_file = Mock(spec=MessageFile)
        message_file.id = str(uuid.uuid4())
        message_file.upload_file_id = str(uuid.uuid4())
        message_file.message_id = str(uuid.uuid4())
        return message_file

    @pytest.fixture
    def mock_message(self):
        """Mock Message model"""
        message = Mock(spec=Message)
        message.id = str(uuid.uuid4())
        message.app_id = str(uuid.uuid4())
        return message

    def test_validate_file_ownership_success(
        self, file_preview_api, mock_app, mock_upload_file, mock_message_file, mock_message
    ):
        """Test successful file ownership validation"""
        file_id = str(uuid.uuid4())
        app_id = mock_app.id

        # Set up the mocks so every link of the ownership chain is consistent:
        # the upload file belongs to the app's tenant, the message belongs to
        # the app, and the message file points at both the message and file.
        mock_upload_file.tenant_id = mock_app.tenant_id
        mock_message.app_id = app_id
        mock_message_file.upload_file_id = file_id
        mock_message_file.message_id = mock_message.id

        with patch("controllers.service_api.app.file_preview.db") as mock_db:
            # Mock database queries — consumed in this exact order by the
            # implementation's successive .query(...).where(...).first() calls.
            mock_db.session.query.return_value.where.return_value.first.side_effect = [
                mock_message_file,  # MessageFile query
                mock_message,  # Message query
                mock_upload_file,  # UploadFile query
                mock_app,  # App query for tenant validation
            ]

            # Execute the method
            result_message_file, result_upload_file = file_preview_api._validate_file_ownership(file_id, app_id)

            # Assertions
            assert result_message_file == mock_message_file
            assert result_upload_file == mock_upload_file

    def test_validate_file_ownership_file_not_found(self, file_preview_api):
        """Test file ownership validation when MessageFile not found"""
        file_id = str(uuid.uuid4())
        app_id = str(uuid.uuid4())

        with patch("controllers.service_api.app.file_preview.db") as mock_db:
            # Mock MessageFile not found — the very first query returns None.
            mock_db.session.query.return_value.where.return_value.first.return_value = None

            # Execute and assert exception
            with pytest.raises(FileNotFoundError) as exc_info:
                file_preview_api._validate_file_ownership(file_id, app_id)

            assert "File not found in message context" in str(exc_info.value)

    def test_validate_file_ownership_access_denied(self, file_preview_api, mock_message_file):
        """Test file ownership validation when Message not owned by app"""
        file_id = str(uuid.uuid4())
        app_id = str(uuid.uuid4())

        with patch("controllers.service_api.app.file_preview.db") as mock_db:
            # Mock MessageFile found but Message not owned by app
            mock_db.session.query.return_value.where.return_value.first.side_effect = [
                mock_message_file,  # MessageFile query - found
                None,  # Message query - not found (access denied)
            ]

            # Execute and assert exception
            with pytest.raises(FileAccessDeniedError) as exc_info:
                file_preview_api._validate_file_ownership(file_id, app_id)

            assert "not owned by requesting app" in str(exc_info.value)

    def test_validate_file_ownership_upload_file_not_found(self, file_preview_api, mock_message_file, mock_message):
        """Test file ownership validation when UploadFile not found"""
        file_id = str(uuid.uuid4())
        app_id = str(uuid.uuid4())

        with patch("controllers.service_api.app.file_preview.db") as mock_db:
            # Mock MessageFile and Message found but UploadFile not found
            mock_db.session.query.return_value.where.return_value.first.side_effect = [
                mock_message_file,  # MessageFile query - found
                mock_message,  # Message query - found
                None,  # UploadFile query - not found
            ]

            # Execute and assert exception
            with pytest.raises(FileNotFoundError) as exc_info:
                file_preview_api._validate_file_ownership(file_id, app_id)

            assert "Upload file record not found" in str(exc_info.value)

    def test_validate_file_ownership_tenant_mismatch(
        self, file_preview_api, mock_app, mock_upload_file, mock_message_file, mock_message
    ):
        """Test file ownership validation with tenant mismatch"""
        file_id = str(uuid.uuid4())
        app_id = mock_app.id

        # Set up tenant mismatch: the whole chain resolves, but the upload
        # file's tenant differs from the app's tenant, so access is denied.
        mock_upload_file.tenant_id = "different_tenant_id"
        mock_app.tenant_id = "app_tenant_id"
        mock_message.app_id = app_id
        mock_message_file.upload_file_id = file_id
        mock_message_file.message_id = mock_message.id

        with patch("controllers.service_api.app.file_preview.db") as mock_db:
            # Mock database queries
            mock_db.session.query.return_value.where.return_value.first.side_effect = [
                mock_message_file,  # MessageFile query
                mock_message,  # Message query
                mock_upload_file,  # UploadFile query
                mock_app,  # App query for tenant validation
            ]

            # Execute and assert exception
            with pytest.raises(FileAccessDeniedError) as exc_info:
                file_preview_api._validate_file_ownership(file_id, app_id)

            assert "tenant mismatch" in str(exc_info.value)

    def test_validate_file_ownership_invalid_input(self, file_preview_api):
        """Test file ownership validation with invalid input"""

        # Test with empty file_id — rejected before any DB access.
        with pytest.raises(FileAccessDeniedError) as exc_info:
            file_preview_api._validate_file_ownership("", "app_id")
        assert "Invalid file or app identifier" in str(exc_info.value)

        # Test with empty app_id
        with pytest.raises(FileAccessDeniedError) as exc_info:
            file_preview_api._validate_file_ownership("file_id", "")
        assert "Invalid file or app identifier" in str(exc_info.value)

    def test_build_file_response_basic(self, file_preview_api, mock_upload_file):
        """Test basic file response building"""
        mock_generator = Mock()

        response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False)

        # Check response properties: content type mirrors the upload record,
        # the body streams (direct_passthrough), and size/caching headers are set.
        assert response.mimetype == mock_upload_file.mime_type
        assert response.direct_passthrough is True
        assert response.headers["Content-Length"] == str(mock_upload_file.size)
        assert "Cache-Control" in response.headers

    def test_build_file_response_as_attachment(self, file_preview_api, mock_upload_file):
        """Test file response building with attachment flag"""
        mock_generator = Mock()

        response = file_preview_api._build_file_response(mock_generator, mock_upload_file, True)

        # Check attachment-specific headers: forced download with the original
        # filename and a generic binary content type.
        assert "attachment" in response.headers["Content-Disposition"]
        assert mock_upload_file.name in response.headers["Content-Disposition"]
        assert response.headers["Content-Type"] == "application/octet-stream"

    def test_build_file_response_audio_video(self, file_preview_api, mock_upload_file):
        """Test file response building for audio/video files"""
        mock_generator = Mock()
        mock_upload_file.mime_type = "video/mp4"

        response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False)

        # Check Range support for media files (enables browser seeking).
        assert response.headers["Accept-Ranges"] == "bytes"

    def test_build_file_response_no_size(self, file_preview_api, mock_upload_file):
        """Test file response building when size is unknown"""
        mock_generator = Mock()
        mock_upload_file.size = 0  # Unknown size

        response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False)

        # Content-Length should not be set when size is unknown
        assert "Content-Length" not in response.headers

    @patch("controllers.service_api.app.file_preview.storage")
    def test_get_method_integration(
        self, mock_storage, file_preview_api, mock_app, mock_end_user, mock_upload_file, mock_message_file, mock_message
    ):
        """Test the full GET method integration (without decorator)"""
        file_id = str(uuid.uuid4())
        app_id = mock_app.id

        # Set up mocks for a fully consistent ownership chain.
        mock_upload_file.tenant_id = mock_app.tenant_id
        mock_message.app_id = app_id
        mock_message_file.upload_file_id = file_id
        mock_message_file.message_id = mock_message.id

        mock_generator = Mock()
        mock_storage.load.return_value = mock_generator

        with patch("controllers.service_api.app.file_preview.db") as mock_db:
            # Mock database queries
            mock_db.session.query.return_value.where.return_value.first.side_effect = [
                mock_message_file,  # MessageFile query
                mock_message,  # Message query
                mock_upload_file,  # UploadFile query
                mock_app,  # App query for tenant validation
            ]

            with patch("controllers.service_api.app.file_preview.reqparse") as mock_reqparse:
                # Mock request parsing
                mock_parser = Mock()
                mock_parser.parse_args.return_value = {"as_attachment": False}
                mock_reqparse.RequestParser.return_value = mock_parser

                # Test the core logic directly without Flask decorators
                # Validate file ownership
                result_message_file, result_upload_file = file_preview_api._validate_file_ownership(file_id, app_id)
                assert result_message_file == mock_message_file
                assert result_upload_file == mock_upload_file

                # Test file response building
                response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False)
                assert response is not None

                # Verify storage was called correctly
                mock_storage.load.assert_not_called()  # Since we're testing components separately

    @patch("controllers.service_api.app.file_preview.storage")
    def test_storage_error_handling(
        self, mock_storage, file_preview_api, mock_app, mock_upload_file, mock_message_file, mock_message
    ):
        """Test storage error handling in the core logic"""
        file_id = str(uuid.uuid4())
        app_id = mock_app.id

        # Set up mocks — ownership validation succeeds; only storage fails.
        mock_upload_file.tenant_id = mock_app.tenant_id
        mock_message.app_id = app_id
        mock_message_file.upload_file_id = file_id
        mock_message_file.message_id = mock_message.id

        # Mock storage error
        mock_storage.load.side_effect = Exception("Storage error")

        with patch("controllers.service_api.app.file_preview.db") as mock_db:
            # Mock database queries for validation
            mock_db.session.query.return_value.where.return_value.first.side_effect = [
                mock_message_file,  # MessageFile query
                mock_message,  # Message query
                mock_upload_file,  # UploadFile query
                mock_app,  # App query for tenant validation
            ]

            # First validate file ownership works
            result_message_file, result_upload_file = file_preview_api._validate_file_ownership(file_id, app_id)
            assert result_message_file == mock_message_file
            assert result_upload_file == mock_upload_file

            # Test storage error handling
            with pytest.raises(Exception) as exc_info:
                mock_storage.load(mock_upload_file.key, stream=True)

            assert "Storage error" in str(exc_info.value)

    @patch("controllers.service_api.app.file_preview.logger")
    def test_validate_file_ownership_unexpected_error_logging(self, mock_logger, file_preview_api):
        """Test that unexpected errors are logged properly"""
        file_id = str(uuid.uuid4())
        app_id = str(uuid.uuid4())

        with patch("controllers.service_api.app.file_preview.db") as mock_db:
            # Mock database query to raise unexpected exception
            mock_db.session.query.side_effect = Exception("Unexpected database error")

            # Execute and assert exception — unexpected errors are wrapped in
            # FileAccessDeniedError rather than leaked to the caller.
            with pytest.raises(FileAccessDeniedError) as exc_info:
                file_preview_api._validate_file_ownership(file_id, app_id)

            # Verify error message
            assert "File access validation failed" in str(exc_info.value)

            # Verify logging was called with structured context for debugging.
            mock_logger.exception.assert_called_once_with(
                "Unexpected error during file ownership validation",
                extra={"file_id": file_id, "app_id": app_id, "error": "Unexpected database error"},
            )
|
||||
@ -0,0 +1,419 @@
|
||||
"""Test conversation variable handling in AdvancedChatAppRunner."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.variables import SegmentType
|
||||
from factories import variable_factory
|
||||
from models import ConversationVariable, Workflow
|
||||
|
||||
|
||||
class TestAdvancedChatAppRunnerConversationVariables:
    """Test that AdvancedChatAppRunner correctly handles conversation variables.

    Each test builds a workflow declaring conversation variables, seeds the
    mocked DB session with some subset of them, runs ``runner.run()`` with the
    workflow machinery patched out, and asserts which missing variables get
    persisted via ``session.add_all``.
    """

    def test_missing_conversation_variables_are_added(self):
        """Test that new conversation variables added to workflow are created for existing conversations."""
        # Setup
        app_id = str(uuid4())
        conversation_id = str(uuid4())
        workflow_id = str(uuid4())

        # Create workflow with two conversation variables
        workflow_vars = [
            variable_factory.build_conversation_variable_from_mapping(
                {
                    "id": "var1",
                    "name": "existing_var",
                    "value_type": SegmentType.STRING,
                    "value": "default1",
                }
            ),
            variable_factory.build_conversation_variable_from_mapping(
                {
                    "id": "var2",
                    "name": "new_var",
                    "value_type": SegmentType.STRING,
                    "value": "default2",
                }
            ),
        ]

        # Mock workflow with conversation variables
        mock_workflow = MagicMock(spec=Workflow)
        mock_workflow.conversation_variables = workflow_vars
        mock_workflow.tenant_id = str(uuid4())
        mock_workflow.app_id = app_id
        mock_workflow.id = workflow_id
        mock_workflow.type = "chat"
        mock_workflow.graph_dict = {}
        mock_workflow.environment_variables = []

        # Create existing conversation variable (only var1 exists in DB)
        existing_db_var = MagicMock(spec=ConversationVariable)
        existing_db_var.id = "var1"
        existing_db_var.app_id = app_id
        existing_db_var.conversation_id = conversation_id
        existing_db_var.to_variable = MagicMock(return_value=workflow_vars[0])

        # Mock conversation and message
        mock_conversation = MagicMock()
        mock_conversation.app_id = app_id
        mock_conversation.id = conversation_id

        mock_message = MagicMock()
        mock_message.id = str(uuid4())

        # Mock app config
        mock_app_config = MagicMock()
        mock_app_config.app_id = app_id
        mock_app_config.workflow_id = workflow_id
        mock_app_config.tenant_id = str(uuid4())

        # Mock app generate entity
        mock_app_generate_entity = MagicMock(spec=AdvancedChatAppGenerateEntity)
        mock_app_generate_entity.app_config = mock_app_config
        mock_app_generate_entity.inputs = {}
        mock_app_generate_entity.query = "test query"
        mock_app_generate_entity.files = []
        mock_app_generate_entity.user_id = str(uuid4())
        mock_app_generate_entity.invoke_from = InvokeFrom.SERVICE_API
        mock_app_generate_entity.workflow_run_id = str(uuid4())
        mock_app_generate_entity.call_depth = 0
        mock_app_generate_entity.single_iteration_run = None
        mock_app_generate_entity.single_loop_run = None
        mock_app_generate_entity.trace_manager = None

        # Create runner
        runner = AdvancedChatAppRunner(
            application_generate_entity=mock_app_generate_entity,
            queue_manager=MagicMock(),
            conversation=mock_conversation,
            message=mock_message,
            dialogue_count=1,
            variable_loader=MagicMock(),
            workflow=mock_workflow,
            system_user_id=str(uuid4()),
            app=MagicMock(),
        )

        # Mock database session
        mock_session = MagicMock(spec=Session)

        # First query returns only existing variable
        mock_scalars_result = MagicMock()
        mock_scalars_result.all.return_value = [existing_db_var]
        mock_session.scalars.return_value = mock_scalars_result

        # Track what gets added to session
        added_items = []

        def track_add_all(items):
            # Capture the exact objects the runner tries to persist.
            added_items.extend(items)

        mock_session.add_all.side_effect = track_add_all

        # Patch the necessary components so run() executes only the
        # conversation-variable reconciliation path deterministically.
        with (
            patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
            patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
            patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
            patch.object(runner, "_init_graph") as mock_init_graph,
            patch.object(runner, "handle_input_moderation", return_value=False),
            patch.object(runner, "handle_annotation_reply", return_value=False),
            patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class,
            patch("core.app.apps.advanced_chat.app_runner.VariablePool") as mock_variable_pool_class,
        ):
            # Setup mocks
            mock_session_class.return_value.__enter__.return_value = mock_session
            mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock()  # App exists
            mock_db.engine = MagicMock()

            # Mock graph initialization
            mock_init_graph.return_value = MagicMock()

            # Mock workflow entry
            mock_workflow_entry = MagicMock()
            mock_workflow_entry.run.return_value = iter([])  # Empty generator
            mock_workflow_entry_class.return_value = mock_workflow_entry

            # Run the method
            runner.run()

            # Verify that the missing variable was added
            assert len(added_items) == 1, "Should have added exactly one missing variable"

            # Check that the added item is the missing variable (var2)
            added_var = added_items[0]
            assert hasattr(added_var, "id"), "Added item should be a ConversationVariable"
            # Note: Since we're mocking ConversationVariable.from_variable,
            # we can't directly check the id, but we can verify add_all was called
            assert mock_session.add_all.called, "Session add_all should have been called"
            assert mock_session.commit.called, "Session commit should have been called"

    def test_no_variables_creates_all(self):
        """Test that all conversation variables are created when none exist in DB."""
        # Setup
        app_id = str(uuid4())
        conversation_id = str(uuid4())
        workflow_id = str(uuid4())

        # Create workflow with conversation variables
        workflow_vars = [
            variable_factory.build_conversation_variable_from_mapping(
                {
                    "id": "var1",
                    "name": "var1",
                    "value_type": SegmentType.STRING,
                    "value": "default1",
                }
            ),
            variable_factory.build_conversation_variable_from_mapping(
                {
                    "id": "var2",
                    "name": "var2",
                    "value_type": SegmentType.STRING,
                    "value": "default2",
                }
            ),
        ]

        # Mock workflow
        mock_workflow = MagicMock(spec=Workflow)
        mock_workflow.conversation_variables = workflow_vars
        mock_workflow.tenant_id = str(uuid4())
        mock_workflow.app_id = app_id
        mock_workflow.id = workflow_id
        mock_workflow.type = "chat"
        mock_workflow.graph_dict = {}
        mock_workflow.environment_variables = []

        # Mock conversation and message
        mock_conversation = MagicMock()
        mock_conversation.app_id = app_id
        mock_conversation.id = conversation_id

        mock_message = MagicMock()
        mock_message.id = str(uuid4())

        # Mock app config
        mock_app_config = MagicMock()
        mock_app_config.app_id = app_id
        mock_app_config.workflow_id = workflow_id
        mock_app_config.tenant_id = str(uuid4())

        # Mock app generate entity
        mock_app_generate_entity = MagicMock(spec=AdvancedChatAppGenerateEntity)
        mock_app_generate_entity.app_config = mock_app_config
        mock_app_generate_entity.inputs = {}
        mock_app_generate_entity.query = "test query"
        mock_app_generate_entity.files = []
        mock_app_generate_entity.user_id = str(uuid4())
        mock_app_generate_entity.invoke_from = InvokeFrom.SERVICE_API
        mock_app_generate_entity.workflow_run_id = str(uuid4())
        mock_app_generate_entity.call_depth = 0
        mock_app_generate_entity.single_iteration_run = None
        mock_app_generate_entity.single_loop_run = None
        mock_app_generate_entity.trace_manager = None

        # Create runner
        runner = AdvancedChatAppRunner(
            application_generate_entity=mock_app_generate_entity,
            queue_manager=MagicMock(),
            conversation=mock_conversation,
            message=mock_message,
            dialogue_count=1,
            variable_loader=MagicMock(),
            workflow=mock_workflow,
            system_user_id=str(uuid4()),
            app=MagicMock(),
        )

        # Mock database session
        mock_session = MagicMock(spec=Session)

        # Query returns empty list (no existing variables)
        mock_scalars_result = MagicMock()
        mock_scalars_result.all.return_value = []
        mock_session.scalars.return_value = mock_scalars_result

        # Track what gets added to session
        added_items = []

        def track_add_all(items):
            added_items.extend(items)

        mock_session.add_all.side_effect = track_add_all

        # Patch the necessary components (ConversationVariable is also patched
        # here so from_variable returns inspectable mocks).
        with (
            patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
            patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
            patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
            patch.object(runner, "_init_graph") as mock_init_graph,
            patch.object(runner, "handle_input_moderation", return_value=False),
            patch.object(runner, "handle_annotation_reply", return_value=False),
            patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class,
            patch("core.app.apps.advanced_chat.app_runner.VariablePool") as mock_variable_pool_class,
            patch("core.app.apps.advanced_chat.app_runner.ConversationVariable") as mock_conv_var_class,
        ):
            # Setup mocks
            mock_session_class.return_value.__enter__.return_value = mock_session
            mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock()  # App exists
            mock_db.engine = MagicMock()

            # Mock ConversationVariable.from_variable to return mock objects
            mock_conv_vars = []
            for var in workflow_vars:
                mock_cv = MagicMock()
                mock_cv.id = var.id
                mock_cv.to_variable.return_value = var
                mock_conv_vars.append(mock_cv)

            mock_conv_var_class.from_variable.side_effect = mock_conv_vars

            # Mock graph initialization
            mock_init_graph.return_value = MagicMock()

            # Mock workflow entry
            mock_workflow_entry = MagicMock()
            mock_workflow_entry.run.return_value = iter([])  # Empty generator
            mock_workflow_entry_class.return_value = mock_workflow_entry

            # Run the method
            runner.run()

            # Verify that all variables were created
            assert len(added_items) == 2, "Should have added both variables"
            assert mock_session.add_all.called, "Session add_all should have been called"
            assert mock_session.commit.called, "Session commit should have been called"

    def test_all_variables_exist_no_changes(self):
        """Test that no changes are made when all variables already exist in DB."""
        # Setup
        app_id = str(uuid4())
        conversation_id = str(uuid4())
        workflow_id = str(uuid4())

        # Create workflow with conversation variables
        workflow_vars = [
            variable_factory.build_conversation_variable_from_mapping(
                {
                    "id": "var1",
                    "name": "var1",
                    "value_type": SegmentType.STRING,
                    "value": "default1",
                }
            ),
            variable_factory.build_conversation_variable_from_mapping(
                {
                    "id": "var2",
                    "name": "var2",
                    "value_type": SegmentType.STRING,
                    "value": "default2",
                }
            ),
        ]

        # Mock workflow
        mock_workflow = MagicMock(spec=Workflow)
        mock_workflow.conversation_variables = workflow_vars
        mock_workflow.tenant_id = str(uuid4())
        mock_workflow.app_id = app_id
        mock_workflow.id = workflow_id
        mock_workflow.type = "chat"
        mock_workflow.graph_dict = {}
        mock_workflow.environment_variables = []

        # Create existing conversation variables (both exist in DB)
        existing_db_vars = []
        for var in workflow_vars:
            db_var = MagicMock(spec=ConversationVariable)
            db_var.id = var.id
            db_var.app_id = app_id
            db_var.conversation_id = conversation_id
            db_var.to_variable = MagicMock(return_value=var)
            existing_db_vars.append(db_var)

        # Mock conversation and message
        mock_conversation = MagicMock()
        mock_conversation.app_id = app_id
        mock_conversation.id = conversation_id

        mock_message = MagicMock()
        mock_message.id = str(uuid4())

        # Mock app config
        mock_app_config = MagicMock()
        mock_app_config.app_id = app_id
        mock_app_config.workflow_id = workflow_id
        mock_app_config.tenant_id = str(uuid4())

        # Mock app generate entity
        mock_app_generate_entity = MagicMock(spec=AdvancedChatAppGenerateEntity)
        mock_app_generate_entity.app_config = mock_app_config
        mock_app_generate_entity.inputs = {}
        mock_app_generate_entity.query = "test query"
        mock_app_generate_entity.files = []
        mock_app_generate_entity.user_id = str(uuid4())
        mock_app_generate_entity.invoke_from = InvokeFrom.SERVICE_API
        mock_app_generate_entity.workflow_run_id = str(uuid4())
        mock_app_generate_entity.call_depth = 0
        mock_app_generate_entity.single_iteration_run = None
        mock_app_generate_entity.single_loop_run = None
        mock_app_generate_entity.trace_manager = None

        # Create runner
        runner = AdvancedChatAppRunner(
            application_generate_entity=mock_app_generate_entity,
            queue_manager=MagicMock(),
            conversation=mock_conversation,
            message=mock_message,
            dialogue_count=1,
            variable_loader=MagicMock(),
            workflow=mock_workflow,
            system_user_id=str(uuid4()),
            app=MagicMock(),
        )

        # Mock database session
        mock_session = MagicMock(spec=Session)

        # Query returns all existing variables
        mock_scalars_result = MagicMock()
        mock_scalars_result.all.return_value = existing_db_vars
        mock_session.scalars.return_value = mock_scalars_result

        # Patch the necessary components
        with (
            patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
            patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
            patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
            patch.object(runner, "_init_graph") as mock_init_graph,
            patch.object(runner, "handle_input_moderation", return_value=False),
            patch.object(runner, "handle_annotation_reply", return_value=False),
            patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class,
            patch("core.app.apps.advanced_chat.app_runner.VariablePool") as mock_variable_pool_class,
        ):
            # Setup mocks
            mock_session_class.return_value.__enter__.return_value = mock_session
            mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock()  # App exists
            mock_db.engine = MagicMock()

            # Mock graph initialization
            mock_init_graph.return_value = MagicMock()

            # Mock workflow entry
            mock_workflow_entry = MagicMock()
            mock_workflow_entry.run.return_value = iter([])  # Empty generator
            mock_workflow_entry_class.return_value = mock_workflow_entry

            # Run the method
            runner.run()

            # Verify that no variables were added (commit still happens as part
            # of the reconciliation transaction).
            assert not mock_session.add_all.called, "Session add_all should not have been called"
            assert mock_session.commit.called, "Session commit should still be called"
|
||||
127
api/tests/unit_tests/services/test_conversation_service.py
Normal file
127
api/tests/unit_tests/services/test_conversation_service.py
Normal file
@ -0,0 +1,127 @@
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from services.conversation_service import ConversationService
|
||||
|
||||
|
||||
class TestConversationService:
    """Unit tests for ConversationService.pagination_by_last_id include/exclude filtering."""

    @staticmethod
    def _mock_actors():
        # Minimal stand-ins for the session, app, and user the service expects.
        session = MagicMock()
        app_model = MagicMock(id=str(uuid.uuid4()))
        user = MagicMock(id=str(uuid.uuid4()))
        return session, app_model, user

    @staticmethod
    def _chainable_stmt():
        # A select-statement mock whose builder methods (where/order_by/limit)
        # return the mock itself, so chained calls behave like SQLAlchemy's API.
        stmt = MagicMock()
        stmt.where.return_value = stmt
        stmt.order_by.return_value = stmt
        stmt.limit.return_value = stmt
        stmt.subquery.return_value = MagicMock()
        return stmt

    def test_pagination_with_empty_include_ids(self):
        """An empty include_ids list short-circuits to an empty page."""
        session, app_model, user = self._mock_actors()

        page = ConversationService.pagination_by_last_id(
            session=session,
            app_model=app_model,
            user=user,
            last_id=None,
            limit=20,
            invoke_from=InvokeFrom.WEB_APP,
            include_ids=[],  # nothing can match an empty whitelist
            exclude_ids=None,
        )

        assert page.data == []
        assert page.has_more is False
        assert page.limit == 20

    def test_pagination_with_non_empty_include_ids(self):
        """A non-empty include_ids list is applied as a where-clause filter."""
        session, app_model, user = self._mock_actors()
        session.scalars.return_value.all.return_value = [
            MagicMock(id=str(uuid.uuid4())) for _ in range(3)
        ]
        session.scalar.return_value = 0

        with patch("services.conversation_service.select") as select_mock:
            stmt = self._chainable_stmt()
            select_mock.return_value = stmt

            ConversationService.pagination_by_last_id(
                session=session,
                app_model=app_model,
                user=user,
                last_id=None,
                limit=20,
                invoke_from=InvokeFrom.WEB_APP,
                include_ids=["conv1", "conv2"],
                exclude_ids=None,
            )

            # The whitelist must have been translated into a where() call.
            assert stmt.where.called

    def test_pagination_with_empty_exclude_ids(self):
        """An empty exclude_ids list must not filter anything out."""
        session, app_model, user = self._mock_actors()
        session.scalars.return_value.all.return_value = [
            MagicMock(id=str(uuid.uuid4())) for _ in range(5)
        ]
        session.scalar.return_value = 0

        with patch("services.conversation_service.select") as select_mock:
            select_mock.return_value = self._chainable_stmt()

            page = ConversationService.pagination_by_last_id(
                session=session,
                app_model=app_model,
                user=user,
                last_id=None,
                limit=20,
                invoke_from=InvokeFrom.WEB_APP,
                include_ids=None,
                exclude_ids=[],  # empty blacklist: keep everything
            )

            # All five mocked conversations survive the (absent) filter.
            assert len(page.data) == 5

    def test_pagination_with_non_empty_exclude_ids(self):
        """A non-empty exclude_ids list is applied as a where-clause filter."""
        session, app_model, user = self._mock_actors()
        session.scalars.return_value.all.return_value = [
            MagicMock(id=str(uuid.uuid4())) for _ in range(3)
        ]
        session.scalar.return_value = 0

        with patch("services.conversation_service.select") as select_mock:
            stmt = self._chainable_stmt()
            select_mock.return_value = stmt

            ConversationService.pagination_by_last_id(
                session=session,
                app_model=app_model,
                user=user,
                last_id=None,
                limit=20,
                invoke_from=InvokeFrom.WEB_APP,
                include_ids=None,
                exclude_ids=["conv1", "conv2"],
            )

            # The blacklist must have been translated into a where() call.
            assert stmt.where.called
|
||||
Reference in New Issue
Block a user