Merge branch main into feat/rag-2

This commit is contained in:
twwu
2025-07-24 17:40:04 +08:00
608 changed files with 8175 additions and 3026 deletions

View File

@ -214,7 +214,7 @@ class TestDraftVariableLoader(unittest.TestCase):
def tearDown(self):
with Session(bind=db.engine, expire_on_commit=False) as session:
session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.app_id == self._test_app_id).delete(
session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == self._test_app_id).delete(
synchronize_session=False
)
session.commit()

View File

@ -1,4 +1,7 @@
import os
import uuid
import tablestore
from core.rag.datasource.vdb.tablestore.tablestore_vector import (
TableStoreConfig,
@ -6,6 +9,8 @@ from core.rag.datasource.vdb.tablestore.tablestore_vector import (
)
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest,
get_example_document,
get_example_text,
setup_mock_redis,
)
@ -29,6 +34,49 @@ class TableStoreVectorTest(AbstractVectorTest):
assert len(ids) == 1
assert ids[0] == self.example_doc_id
def create_vector(self):
self.vector.create(
texts=[get_example_document(doc_id=self.example_doc_id)],
embeddings=[self.example_embedding],
)
while True:
search_response = self.vector._tablestore_client.search(
table_name=self.vector._table_name,
index_name=self.vector._index_name,
search_query=tablestore.SearchQuery(query=tablestore.MatchAllQuery(), get_total_count=True, limit=0),
columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX),
)
if search_response.total_count == 1:
break
def search_by_vector(self):
super().search_by_vector()
docs = self.vector.search_by_vector(self.example_embedding, document_ids_filter=[self.example_doc_id])
assert len(docs) == 1
assert docs[0].metadata["doc_id"] == self.example_doc_id
assert docs[0].metadata["score"] > 0
docs = self.vector.search_by_vector(self.example_embedding, document_ids_filter=[str(uuid.uuid4())])
assert len(docs) == 0
def search_by_full_text(self):
super().search_by_full_text()
docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[self.example_doc_id])
assert len(docs) == 1
assert docs[0].metadata["doc_id"] == self.example_doc_id
assert not hasattr(docs[0], "score")
docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[str(uuid.uuid4())])
assert len(docs) == 0
def run_all_tests(self):
try:
self.vector.delete()
except Exception:
pass
return super().run_all_tests()
def test_tablestore_vector(setup_mock_redis):
TableStoreVectorTest().run_all_tests()

View File

@ -44,7 +44,7 @@ class TestEncryptToken:
"""Test successful token encryption"""
mock_tenant = MagicMock()
mock_tenant.encrypt_public_key = "mock_public_key"
mock_query.return_value.filter.return_value.first.return_value = mock_tenant
mock_query.return_value.where.return_value.first.return_value = mock_tenant
mock_encrypt.return_value = b"encrypted_data"
result = encrypt_token("tenant-123", "test_token")
@ -55,7 +55,7 @@ class TestEncryptToken:
@patch("models.engine.db.session.query")
def test_tenant_not_found(self, mock_query):
"""Test error when tenant doesn't exist"""
mock_query.return_value.filter.return_value.first.return_value = None
mock_query.return_value.where.return_value.first.return_value = None
with pytest.raises(ValueError) as exc_info:
encrypt_token("invalid-tenant", "test_token")
@ -127,7 +127,7 @@ class TestEncryptDecryptIntegration:
# Setup mock tenant
mock_tenant = MagicMock()
mock_tenant.encrypt_public_key = "mock_public_key"
mock_query.return_value.filter.return_value.first.return_value = mock_tenant
mock_query.return_value.where.return_value.first.return_value = mock_tenant
# Setup mock encryption/decryption
original_token = "test_token_123"
@ -153,7 +153,7 @@ class TestSecurity:
# Setup mock tenant
mock_tenant = MagicMock()
mock_tenant.encrypt_public_key = "tenant1_public_key"
mock_query.return_value.filter.return_value.first.return_value = mock_tenant
mock_query.return_value.where.return_value.first.return_value = mock_tenant
mock_encrypt.return_value = b"encrypted_for_tenant1"
# Encrypt token for tenant1
@ -186,7 +186,7 @@ class TestSecurity:
def test_encryption_randomness(self, mock_encrypt, mock_query):
"""Ensure same plaintext produces different ciphertext"""
mock_tenant = MagicMock(encrypt_public_key="key")
mock_query.return_value.filter.return_value.first.return_value = mock_tenant
mock_query.return_value.where.return_value.first.return_value = mock_tenant
# Different outputs for same input
mock_encrypt.side_effect = [b"enc1", b"enc2", b"enc3"]
@ -211,7 +211,7 @@ class TestEdgeCases:
"""Test encryption of empty token"""
mock_tenant = MagicMock()
mock_tenant.encrypt_public_key = "mock_public_key"
mock_query.return_value.filter.return_value.first.return_value = mock_tenant
mock_query.return_value.where.return_value.first.return_value = mock_tenant
mock_encrypt.return_value = b"encrypted_empty"
result = encrypt_token("tenant-123", "")
@ -225,7 +225,7 @@ class TestEdgeCases:
"""Test tokens containing special/unicode characters"""
mock_tenant = MagicMock()
mock_tenant.encrypt_public_key = "mock_public_key"
mock_query.return_value.filter.return_value.first.return_value = mock_tenant
mock_query.return_value.where.return_value.first.return_value = mock_tenant
mock_encrypt.return_value = b"encrypted_special"
# Test various special characters
@ -248,7 +248,7 @@ class TestEdgeCases:
"""Test behavior when token exceeds RSA encryption limits"""
mock_tenant = MagicMock()
mock_tenant.encrypt_public_key = "mock_public_key"
mock_query.return_value.filter.return_value.first.return_value = mock_tenant
mock_query.return_value.where.return_value.first.return_value = mock_tenant
# RSA 2048-bit can only encrypt ~245 bytes
# The actual limit depends on padding scheme

View File

@ -0,0 +1,86 @@
import pytest
from core.helper.trace_id_helper import extract_external_trace_id_from_args, get_external_trace_id, is_valid_trace_id
class DummyRequest:
def __init__(self, headers=None, args=None, json=None, is_json=False):
self.headers = headers or {}
self.args = args or {}
self.json = json
self.is_json = is_json
class TestTraceIdHelper:
"""Test cases for trace_id_helper.py"""
@pytest.mark.parametrize(
("trace_id", "expected"),
[
("abc123", True),
("A-B_C-123", True),
("a" * 128, True),
("", False),
("a" * 129, False),
("abc!@#", False),
("空格", False),
("with space", False),
],
)
def test_is_valid_trace_id(self, trace_id, expected):
"""Test trace_id validation for various cases"""
assert is_valid_trace_id(trace_id) is expected
def test_get_external_trace_id_from_header(self):
"""Should extract valid trace_id from header"""
req = DummyRequest(headers={"X-Trace-Id": "abc123"})
assert get_external_trace_id(req) == "abc123"
def test_get_external_trace_id_from_args(self):
"""Should extract valid trace_id from args if header missing"""
req = DummyRequest(args={"trace_id": "abc123"})
assert get_external_trace_id(req) == "abc123"
def test_get_external_trace_id_from_json(self):
"""Should extract valid trace_id from JSON body if header and args missing"""
req = DummyRequest(is_json=True, json={"trace_id": "abc123"})
assert get_external_trace_id(req) == "abc123"
def test_get_external_trace_id_priority(self):
"""Header > args > json priority"""
req = DummyRequest(
headers={"X-Trace-Id": "header_id"},
args={"trace_id": "args_id"},
is_json=True,
json={"trace_id": "json_id"},
)
assert get_external_trace_id(req) == "header_id"
req2 = DummyRequest(args={"trace_id": "args_id"}, is_json=True, json={"trace_id": "json_id"})
assert get_external_trace_id(req2) == "args_id"
req3 = DummyRequest(is_json=True, json={"trace_id": "json_id"})
assert get_external_trace_id(req3) == "json_id"
@pytest.mark.parametrize(
"req",
[
DummyRequest(headers={"X-Trace-Id": "!!!"}),
DummyRequest(args={"trace_id": "!!!"}),
DummyRequest(is_json=True, json={"trace_id": "!!!"}),
DummyRequest(),
],
)
def test_get_external_trace_id_invalid(self, req):
"""Should return None for invalid or missing trace_id"""
assert get_external_trace_id(req) is None
@pytest.mark.parametrize(
("args", "expected"),
[
({"external_trace_id": "abc123"}, {"external_trace_id": "abc123"}),
({"other": "value"}, {}),
({}, {}),
],
)
def test_extract_external_trace_id_from_args(self, args, expected):
"""Test extraction of external_trace_id from args mapping"""
assert extract_external_trace_id_from_args(args) == expected

View File

@ -54,8 +54,7 @@ def mock_tool_file():
mock.mimetype = "application/pdf"
mock.original_url = "http://example.com/tool.pdf"
mock.size = 2048
with patch("factories.file_factory.db.session.query") as mock_query:
mock_query.return_value.filter.return_value.first.return_value = mock
with patch("factories.file_factory.db.session.scalar", return_value=mock):
yield mock
@ -153,8 +152,7 @@ def test_build_from_remote_url(mock_http_head):
def test_tool_file_not_found():
"""Test ToolFile not found in database."""
with patch("factories.file_factory.db.session.query") as mock_query:
mock_query.return_value.filter.return_value.first.return_value = None
with patch("factories.file_factory.db.session.scalar", return_value=None):
mapping = tool_file_mapping()
with pytest.raises(ValueError, match=f"ToolFile {TEST_TOOL_FILE_ID} not found"):
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)

View File

@ -0,0 +1,539 @@
"""
Unit tests for EmailI18nService
Tests the email internationalization service with mocked dependencies
following Domain-Driven Design principles.
"""
from typing import Any
from unittest.mock import MagicMock
import pytest
from libs.email_i18n import (
EmailI18nConfig,
EmailI18nService,
EmailLanguage,
EmailTemplate,
EmailType,
FlaskEmailRenderer,
FlaskMailSender,
create_default_email_config,
get_email_i18n_service,
)
from services.feature_service import BrandingModel
class MockEmailRenderer:
"""Mock implementation of EmailRenderer protocol"""
def __init__(self) -> None:
self.rendered_templates: list[tuple[str, dict[str, Any]]] = []
def render_template(self, template_path: str, **context: Any) -> str:
"""Mock render_template that returns a formatted string"""
self.rendered_templates.append((template_path, context))
return f"<html>Rendered {template_path} with {context}</html>"
class MockBrandingService:
"""Mock implementation of BrandingService protocol"""
def __init__(self, enabled: bool = False, application_title: str = "Dify") -> None:
self.enabled = enabled
self.application_title = application_title
def get_branding_config(self) -> BrandingModel:
"""Return mock branding configuration"""
branding_model = MagicMock(spec=BrandingModel)
branding_model.enabled = self.enabled
branding_model.application_title = self.application_title
return branding_model
class MockEmailSender:
"""Mock implementation of EmailSender protocol"""
def __init__(self) -> None:
self.sent_emails: list[dict[str, str]] = []
def send_email(self, to: str, subject: str, html_content: str) -> None:
"""Mock send_email that records sent emails"""
self.sent_emails.append(
{
"to": to,
"subject": subject,
"html_content": html_content,
}
)
class TestEmailI18nService:
"""Test cases for EmailI18nService"""
@pytest.fixture
def email_config(self) -> EmailI18nConfig:
"""Create test email configuration"""
return EmailI18nConfig(
templates={
EmailType.RESET_PASSWORD: {
EmailLanguage.EN_US: EmailTemplate(
subject="Reset Your {application_title} Password",
template_path="reset_password_en.html",
branded_template_path="branded/reset_password_en.html",
),
EmailLanguage.ZH_HANS: EmailTemplate(
subject="重置您的 {application_title} 密码",
template_path="reset_password_zh.html",
branded_template_path="branded/reset_password_zh.html",
),
},
EmailType.INVITE_MEMBER: {
EmailLanguage.EN_US: EmailTemplate(
subject="Join {application_title} Workspace",
template_path="invite_member_en.html",
branded_template_path="branded/invite_member_en.html",
),
},
}
)
@pytest.fixture
def mock_renderer(self) -> MockEmailRenderer:
"""Create mock email renderer"""
return MockEmailRenderer()
@pytest.fixture
def mock_branding_service(self) -> MockBrandingService:
"""Create mock branding service"""
return MockBrandingService()
@pytest.fixture
def mock_sender(self) -> MockEmailSender:
"""Create mock email sender"""
return MockEmailSender()
@pytest.fixture
def email_service(
self,
email_config: EmailI18nConfig,
mock_renderer: MockEmailRenderer,
mock_branding_service: MockBrandingService,
mock_sender: MockEmailSender,
) -> EmailI18nService:
"""Create EmailI18nService with mocked dependencies"""
return EmailI18nService(
config=email_config,
renderer=mock_renderer,
branding_service=mock_branding_service,
sender=mock_sender,
)
def test_send_email_with_english_language(
self,
email_service: EmailI18nService,
mock_renderer: MockEmailRenderer,
mock_sender: MockEmailSender,
) -> None:
"""Test sending email with English language"""
email_service.send_email(
email_type=EmailType.RESET_PASSWORD,
language_code="en-US",
to="test@example.com",
template_context={"reset_link": "https://example.com/reset"},
)
# Verify renderer was called with correct template
assert len(mock_renderer.rendered_templates) == 1
template_path, context = mock_renderer.rendered_templates[0]
assert template_path == "reset_password_en.html"
assert context["reset_link"] == "https://example.com/reset"
assert context["branding_enabled"] is False
assert context["application_title"] == "Dify"
# Verify email was sent
assert len(mock_sender.sent_emails) == 1
sent_email = mock_sender.sent_emails[0]
assert sent_email["to"] == "test@example.com"
assert sent_email["subject"] == "Reset Your Dify Password"
assert "reset_password_en.html" in sent_email["html_content"]
def test_send_email_with_chinese_language(
self,
email_service: EmailI18nService,
mock_sender: MockEmailSender,
) -> None:
"""Test sending email with Chinese language"""
email_service.send_email(
email_type=EmailType.RESET_PASSWORD,
language_code="zh-Hans",
to="test@example.com",
template_context={"reset_link": "https://example.com/reset"},
)
# Verify email was sent with Chinese subject
assert len(mock_sender.sent_emails) == 1
sent_email = mock_sender.sent_emails[0]
assert sent_email["subject"] == "重置您的 Dify 密码"
def test_send_email_with_branding_enabled(
self,
email_config: EmailI18nConfig,
mock_renderer: MockEmailRenderer,
mock_sender: MockEmailSender,
) -> None:
"""Test sending email with branding enabled"""
# Create branding service with branding enabled
branding_service = MockBrandingService(enabled=True, application_title="MyApp")
email_service = EmailI18nService(
config=email_config,
renderer=mock_renderer,
branding_service=branding_service,
sender=mock_sender,
)
email_service.send_email(
email_type=EmailType.RESET_PASSWORD,
language_code="en-US",
to="test@example.com",
)
# Verify branded template was used
assert len(mock_renderer.rendered_templates) == 1
template_path, context = mock_renderer.rendered_templates[0]
assert template_path == "branded/reset_password_en.html"
assert context["branding_enabled"] is True
assert context["application_title"] == "MyApp"
# Verify subject includes custom application title
assert len(mock_sender.sent_emails) == 1
sent_email = mock_sender.sent_emails[0]
assert sent_email["subject"] == "Reset Your MyApp Password"
def test_send_email_with_language_fallback(
self,
email_service: EmailI18nService,
mock_sender: MockEmailSender,
) -> None:
"""Test language fallback to English when requested language not available"""
# Request invite member in Chinese (not configured)
email_service.send_email(
email_type=EmailType.INVITE_MEMBER,
language_code="zh-Hans",
to="test@example.com",
)
# Should fall back to English
assert len(mock_sender.sent_emails) == 1
sent_email = mock_sender.sent_emails[0]
assert sent_email["subject"] == "Join Dify Workspace"
def test_send_email_with_unknown_language_code(
self,
email_service: EmailI18nService,
mock_sender: MockEmailSender,
) -> None:
"""Test unknown language code falls back to English"""
email_service.send_email(
email_type=EmailType.RESET_PASSWORD,
language_code="fr-FR", # French not configured
to="test@example.com",
)
# Should use English
assert len(mock_sender.sent_emails) == 1
sent_email = mock_sender.sent_emails[0]
assert sent_email["subject"] == "Reset Your Dify Password"
def test_send_change_email_old_phase(
self,
email_config: EmailI18nConfig,
mock_renderer: MockEmailRenderer,
mock_sender: MockEmailSender,
mock_branding_service: MockBrandingService,
) -> None:
"""Test sending change email for old email verification"""
# Add change email templates to config
email_config.templates[EmailType.CHANGE_EMAIL_OLD] = {
EmailLanguage.EN_US: EmailTemplate(
subject="Verify your current email",
template_path="change_email_old_en.html",
branded_template_path="branded/change_email_old_en.html",
),
}
email_service = EmailI18nService(
config=email_config,
renderer=mock_renderer,
branding_service=mock_branding_service,
sender=mock_sender,
)
email_service.send_change_email(
language_code="en-US",
to="old@example.com",
code="123456",
phase="old_email",
)
# Verify correct template and context
assert len(mock_renderer.rendered_templates) == 1
template_path, context = mock_renderer.rendered_templates[0]
assert template_path == "change_email_old_en.html"
assert context["to"] == "old@example.com"
assert context["code"] == "123456"
def test_send_change_email_new_phase(
self,
email_config: EmailI18nConfig,
mock_renderer: MockEmailRenderer,
mock_sender: MockEmailSender,
mock_branding_service: MockBrandingService,
) -> None:
"""Test sending change email for new email verification"""
# Add change email templates to config
email_config.templates[EmailType.CHANGE_EMAIL_NEW] = {
EmailLanguage.EN_US: EmailTemplate(
subject="Verify your new email",
template_path="change_email_new_en.html",
branded_template_path="branded/change_email_new_en.html",
),
}
email_service = EmailI18nService(
config=email_config,
renderer=mock_renderer,
branding_service=mock_branding_service,
sender=mock_sender,
)
email_service.send_change_email(
language_code="en-US",
to="new@example.com",
code="654321",
phase="new_email",
)
# Verify correct template and context
assert len(mock_renderer.rendered_templates) == 1
template_path, context = mock_renderer.rendered_templates[0]
assert template_path == "change_email_new_en.html"
assert context["to"] == "new@example.com"
assert context["code"] == "654321"
def test_send_change_email_invalid_phase(
self,
email_service: EmailI18nService,
) -> None:
"""Test sending change email with invalid phase raises error"""
with pytest.raises(ValueError, match="Invalid phase: invalid_phase"):
email_service.send_change_email(
language_code="en-US",
to="test@example.com",
code="123456",
phase="invalid_phase",
)
def test_send_raw_email_single_recipient(
self,
email_service: EmailI18nService,
mock_sender: MockEmailSender,
) -> None:
"""Test sending raw email to single recipient"""
email_service.send_raw_email(
to="test@example.com",
subject="Test Subject",
html_content="<html>Test Content</html>",
)
assert len(mock_sender.sent_emails) == 1
sent_email = mock_sender.sent_emails[0]
assert sent_email["to"] == "test@example.com"
assert sent_email["subject"] == "Test Subject"
assert sent_email["html_content"] == "<html>Test Content</html>"
def test_send_raw_email_multiple_recipients(
self,
email_service: EmailI18nService,
mock_sender: MockEmailSender,
) -> None:
"""Test sending raw email to multiple recipients"""
recipients = ["user1@example.com", "user2@example.com", "user3@example.com"]
email_service.send_raw_email(
to=recipients,
subject="Test Subject",
html_content="<html>Test Content</html>",
)
# Should send individual emails to each recipient
assert len(mock_sender.sent_emails) == 3
for i, recipient in enumerate(recipients):
sent_email = mock_sender.sent_emails[i]
assert sent_email["to"] == recipient
assert sent_email["subject"] == "Test Subject"
assert sent_email["html_content"] == "<html>Test Content</html>"
def test_get_template_missing_email_type(
self,
email_config: EmailI18nConfig,
) -> None:
"""Test getting template for missing email type raises error"""
with pytest.raises(ValueError, match="No templates configured for email type"):
email_config.get_template(EmailType.EMAIL_CODE_LOGIN, EmailLanguage.EN_US)
def test_get_template_missing_language_and_english(
self,
email_config: EmailI18nConfig,
) -> None:
"""Test error when neither requested language nor English fallback exists"""
# Add template without English fallback
email_config.templates[EmailType.EMAIL_CODE_LOGIN] = {
EmailLanguage.ZH_HANS: EmailTemplate(
subject="Test",
template_path="test.html",
branded_template_path="branded/test.html",
),
}
with pytest.raises(ValueError, match="No template found for"):
# Request a language that doesn't exist and no English fallback
email_config.get_template(EmailType.EMAIL_CODE_LOGIN, EmailLanguage.EN_US)
def test_subject_templating_with_variables(
self,
email_config: EmailI18nConfig,
mock_renderer: MockEmailRenderer,
mock_sender: MockEmailSender,
mock_branding_service: MockBrandingService,
) -> None:
"""Test subject templating with custom variables"""
# Add template with variable in subject
email_config.templates[EmailType.OWNER_TRANSFER_NEW_NOTIFY] = {
EmailLanguage.EN_US: EmailTemplate(
subject="You are now the owner of {WorkspaceName}",
template_path="owner_transfer_en.html",
branded_template_path="branded/owner_transfer_en.html",
),
}
email_service = EmailI18nService(
config=email_config,
renderer=mock_renderer,
branding_service=mock_branding_service,
sender=mock_sender,
)
email_service.send_email(
email_type=EmailType.OWNER_TRANSFER_NEW_NOTIFY,
language_code="en-US",
to="test@example.com",
template_context={"WorkspaceName": "My Workspace"},
)
# Verify subject was templated correctly
assert len(mock_sender.sent_emails) == 1
sent_email = mock_sender.sent_emails[0]
assert sent_email["subject"] == "You are now the owner of My Workspace"
def test_email_language_from_language_code(self) -> None:
"""Test EmailLanguage.from_language_code method"""
assert EmailLanguage.from_language_code("zh-Hans") == EmailLanguage.ZH_HANS
assert EmailLanguage.from_language_code("en-US") == EmailLanguage.EN_US
assert EmailLanguage.from_language_code("fr-FR") == EmailLanguage.EN_US # Fallback
assert EmailLanguage.from_language_code("unknown") == EmailLanguage.EN_US # Fallback
class TestEmailI18nIntegration:
"""Integration tests for email i18n components"""
def test_create_default_email_config(self) -> None:
"""Test creating default email configuration"""
config = create_default_email_config()
# Verify key email types have at least English template
expected_types = [
EmailType.RESET_PASSWORD,
EmailType.INVITE_MEMBER,
EmailType.EMAIL_CODE_LOGIN,
EmailType.CHANGE_EMAIL_OLD,
EmailType.CHANGE_EMAIL_NEW,
EmailType.OWNER_TRANSFER_CONFIRM,
EmailType.OWNER_TRANSFER_OLD_NOTIFY,
EmailType.OWNER_TRANSFER_NEW_NOTIFY,
EmailType.ACCOUNT_DELETION_SUCCESS,
EmailType.ACCOUNT_DELETION_VERIFICATION,
EmailType.QUEUE_MONITOR_ALERT,
EmailType.DOCUMENT_CLEAN_NOTIFY,
]
for email_type in expected_types:
assert email_type in config.templates
assert EmailLanguage.EN_US in config.templates[email_type]
# Verify some have Chinese translations
assert EmailLanguage.ZH_HANS in config.templates[EmailType.RESET_PASSWORD]
assert EmailLanguage.ZH_HANS in config.templates[EmailType.INVITE_MEMBER]
def test_get_email_i18n_service(self) -> None:
"""Test getting global email i18n service instance"""
service1 = get_email_i18n_service()
service2 = get_email_i18n_service()
# Should return the same instance
assert service1 is service2
def test_flask_email_renderer(self) -> None:
"""Test FlaskEmailRenderer implementation"""
renderer = FlaskEmailRenderer()
# Should raise TemplateNotFound when template doesn't exist
from jinja2.exceptions import TemplateNotFound
with pytest.raises(TemplateNotFound):
renderer.render_template("test.html", foo="bar")
def test_flask_mail_sender_not_initialized(self) -> None:
"""Test FlaskMailSender when mail is not initialized"""
sender = FlaskMailSender()
# Mock mail.is_inited() to return False
import libs.email_i18n
original_mail = libs.email_i18n.mail
mock_mail = MagicMock()
mock_mail.is_inited.return_value = False
libs.email_i18n.mail = mock_mail
try:
# Should not send email when mail is not initialized
sender.send_email("test@example.com", "Subject", "<html>Content</html>")
mock_mail.send.assert_not_called()
finally:
# Restore original mail
libs.email_i18n.mail = original_mail
def test_flask_mail_sender_initialized(self) -> None:
"""Test FlaskMailSender when mail is initialized"""
sender = FlaskMailSender()
# Mock mail.is_inited() to return True
import libs.email_i18n
original_mail = libs.email_i18n.mail
mock_mail = MagicMock()
mock_mail.is_inited.return_value = True
libs.email_i18n.mail = mock_mail
try:
# Should send email when mail is initialized
sender.send_email("test@example.com", "Subject", "<html>Content</html>")
mock_mail.send.assert_called_once_with(
to="test@example.com",
subject="Subject",
html="<html>Content</html>",
)
finally:
# Restore original mail
libs.email_i18n.mail = original_mail

View File

@ -6,7 +6,7 @@ import pytest
import sqlalchemy as sa
from sqlalchemy import exc as sa_exc
from sqlalchemy import insert
from sqlalchemy.orm import DeclarativeBase, Mapped, Session
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column
from sqlalchemy.sql.sqltypes import VARCHAR
from models.types import EnumText
@ -32,22 +32,26 @@ class _EnumWithLongValue(StrEnum):
class _User(_Base):
__tablename__ = "users"
id: Mapped[int] = sa.Column(sa.Integer, primary_key=True)
name: Mapped[str] = sa.Column(sa.String(length=255), nullable=False)
user_type: Mapped[_UserType] = sa.Column(EnumText(enum_class=_UserType), nullable=False, default=_UserType.normal)
user_type_nullable: Mapped[_UserType | None] = sa.Column(EnumText(enum_class=_UserType), nullable=True)
id: Mapped[int] = mapped_column(sa.Integer, primary_key=True)
name: Mapped[str] = mapped_column(sa.String(length=255), nullable=False)
user_type: Mapped[_UserType] = mapped_column(
EnumText(enum_class=_UserType), nullable=False, default=_UserType.normal
)
user_type_nullable: Mapped[_UserType | None] = mapped_column(EnumText(enum_class=_UserType), nullable=True)
class _ColumnTest(_Base):
__tablename__ = "column_test"
id: Mapped[int] = sa.Column(sa.Integer, primary_key=True)
id: Mapped[int] = mapped_column(sa.Integer, primary_key=True)
user_type: Mapped[_UserType] = sa.Column(EnumText(enum_class=_UserType), nullable=False, default=_UserType.normal)
explicit_length: Mapped[_UserType | None] = sa.Column(
user_type: Mapped[_UserType] = mapped_column(
EnumText(enum_class=_UserType), nullable=False, default=_UserType.normal
)
explicit_length: Mapped[_UserType | None] = mapped_column(
EnumText(_UserType, length=50), nullable=True, default=_UserType.normal
)
long_value: Mapped[_EnumWithLongValue] = sa.Column(EnumText(enum_class=_EnumWithLongValue), nullable=False)
long_value: Mapped[_EnumWithLongValue] = mapped_column(EnumText(enum_class=_EnumWithLongValue), nullable=False)
_T = TypeVar("_T")
@ -110,12 +114,12 @@ class TestEnumText:
session.commit()
with Session(engine) as session:
user = session.query(_User).filter(_User.id == admin_user_id).first()
user = session.query(_User).where(_User.id == admin_user_id).first()
assert user.user_type == _UserType.admin
assert user.user_type_nullable is None
with Session(engine) as session:
user = session.query(_User).filter(_User.id == normal_user_id).first()
user = session.query(_User).where(_User.id == normal_user_id).first()
assert user.user_type == _UserType.normal
assert user.user_type_nullable == _UserType.normal
@ -184,4 +188,4 @@ class TestEnumText:
with pytest.raises(ValueError) as exc:
with Session(engine) as session:
_user = session.query(_User).filter(_User.id == 1).first()
_user = session.query(_User).where(_User.id == 1).first()

View File

@ -28,7 +28,7 @@ class TestApiKeyAuthService:
mock_binding.provider = self.provider
mock_binding.disabled = False
mock_session.query.return_value.filter.return_value.all.return_value = [mock_binding]
mock_session.query.return_value.where.return_value.all.return_value = [mock_binding]
result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
@ -39,7 +39,7 @@ class TestApiKeyAuthService:
@patch("services.auth.api_key_auth_service.db.session")
def test_get_provider_auth_list_empty(self, mock_session):
"""Test get provider auth list - empty result"""
mock_session.query.return_value.filter.return_value.all.return_value = []
mock_session.query.return_value.where.return_value.all.return_value = []
result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
@ -48,13 +48,13 @@ class TestApiKeyAuthService:
@patch("services.auth.api_key_auth_service.db.session")
def test_get_provider_auth_list_filters_disabled(self, mock_session):
"""Test get provider auth list - filters disabled items"""
mock_session.query.return_value.filter.return_value.all.return_value = []
mock_session.query.return_value.where.return_value.all.return_value = []
ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
# Verify filter conditions include disabled.is_(False)
filter_call = mock_session.query.return_value.filter.call_args[0]
assert len(filter_call) == 2 # tenant_id and disabled filter conditions
# Verify where conditions include disabled.is_(False)
where_call = mock_session.query.return_value.where.call_args[0]
assert len(where_call) == 2 # tenant_id and disabled filter conditions
@patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
@ -138,7 +138,8 @@ class TestApiKeyAuthService:
# Mock database query result
mock_binding = Mock()
mock_binding.credentials = json.dumps(self.mock_credentials)
mock_session.query.return_value.filter.return_value.first.return_value = mock_binding
mock_session.query.return_value.where.return_value.first.return_value = mock_binding
mock_session.query.return_value.where.return_value.first.return_value = mock_binding
result = ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
@ -148,7 +149,7 @@ class TestApiKeyAuthService:
@patch("services.auth.api_key_auth_service.db.session")
def test_get_auth_credentials_not_found(self, mock_session):
"""Test get auth credentials - not found"""
mock_session.query.return_value.filter.return_value.first.return_value = None
mock_session.query.return_value.where.return_value.first.return_value = None
result = ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
@ -157,13 +158,13 @@ class TestApiKeyAuthService:
@patch("services.auth.api_key_auth_service.db.session")
def test_get_auth_credentials_filters_correctly(self, mock_session):
"""Test get auth credentials - applies correct filters"""
mock_session.query.return_value.filter.return_value.first.return_value = None
mock_session.query.return_value.where.return_value.first.return_value = None
ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
# Verify filter conditions are correct
filter_call = mock_session.query.return_value.filter.call_args[0]
assert len(filter_call) == 4 # tenant_id, category, provider, disabled
# Verify where conditions are correct
where_call = mock_session.query.return_value.where.call_args[0]
assert len(where_call) == 4 # tenant_id, category, provider, disabled
@patch("services.auth.api_key_auth_service.db.session")
def test_get_auth_credentials_json_parsing(self, mock_session):
@ -173,7 +174,7 @@ class TestApiKeyAuthService:
mock_binding = Mock()
mock_binding.credentials = json.dumps(special_credentials, ensure_ascii=False)
mock_session.query.return_value.filter.return_value.first.return_value = mock_binding
mock_session.query.return_value.where.return_value.first.return_value = mock_binding
result = ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
@ -185,7 +186,7 @@ class TestApiKeyAuthService:
"""Test delete provider auth - success scenario"""
# Mock database query result
mock_binding = Mock()
mock_session.query.return_value.filter.return_value.first.return_value = mock_binding
mock_session.query.return_value.where.return_value.first.return_value = mock_binding
ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id)
@ -196,7 +197,7 @@ class TestApiKeyAuthService:
@patch("services.auth.api_key_auth_service.db.session")
def test_delete_provider_auth_not_found(self, mock_session):
"""Test delete provider auth - not found"""
mock_session.query.return_value.filter.return_value.first.return_value = None
mock_session.query.return_value.where.return_value.first.return_value = None
ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id)
@ -207,13 +208,13 @@ class TestApiKeyAuthService:
@patch("services.auth.api_key_auth_service.db.session")
def test_delete_provider_auth_filters_by_tenant(self, mock_session):
"""Test delete provider auth - filters by tenant"""
mock_session.query.return_value.filter.return_value.first.return_value = None
mock_session.query.return_value.where.return_value.first.return_value = None
ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id)
# Verify filter conditions include tenant_id and binding_id
filter_call = mock_session.query.return_value.filter.call_args[0]
assert len(filter_call) == 2
# Verify where conditions include tenant_id and binding_id
where_call = mock_session.query.return_value.where.call_args[0]
assert len(where_call) == 2
def test_validate_api_key_auth_args_success(self):
"""Test API key auth args validation - success scenario"""
@ -336,7 +337,7 @@ class TestApiKeyAuthService:
# Mock database returning invalid JSON
mock_binding = Mock()
mock_binding.credentials = "invalid json content"
mock_session.query.return_value.filter.return_value.first.return_value = mock_binding
mock_session.query.return_value.where.return_value.first.return_value = mock_binding
with pytest.raises(json.JSONDecodeError):
ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)

View File

@ -0,0 +1,234 @@
"""
API Key Authentication System Integration Tests
"""
import json
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import Mock, patch
import pytest
import requests
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
from services.auth.api_key_auth_service import ApiKeyAuthService
from services.auth.auth_type import AuthType
class TestAuthIntegration:
def setup_method(self):
self.tenant_id_1 = "tenant_123"
self.tenant_id_2 = "tenant_456" # For multi-tenant isolation testing
self.category = "search"
# Realistic authentication configurations
self.firecrawl_credentials = {"auth_type": "bearer", "config": {"api_key": "fc_test_key_123"}}
self.jina_credentials = {"auth_type": "bearer", "config": {"api_key": "jina_test_key_456"}}
self.watercrawl_credentials = {"auth_type": "x-api-key", "config": {"api_key": "wc_test_key_789"}}
@patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.firecrawl.firecrawl.requests.post")
@patch("services.auth.api_key_auth_service.encrypter.encrypt_token")
def test_end_to_end_auth_flow(self, mock_encrypt, mock_http, mock_session):
"""Test complete authentication flow: request → validation → encryption → storage"""
mock_http.return_value = self._create_success_response()
mock_encrypt.return_value = "encrypted_fc_test_key_123"
mock_session.add = Mock()
mock_session.commit = Mock()
args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials}
ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args)
mock_http.assert_called_once()
call_args = mock_http.call_args
assert "https://api.firecrawl.dev/v1/crawl" in call_args[0][0]
assert call_args[1]["headers"]["Authorization"] == "Bearer fc_test_key_123"
mock_encrypt.assert_called_once_with(self.tenant_id_1, "fc_test_key_123")
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
@patch("services.auth.firecrawl.firecrawl.requests.post")
def test_cross_component_integration(self, mock_http):
"""Test factory → provider → HTTP call integration"""
mock_http.return_value = self._create_success_response()
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, self.firecrawl_credentials)
result = factory.validate_credentials()
assert result is True
mock_http.assert_called_once()
@patch("services.auth.api_key_auth_service.db.session")
def test_multi_tenant_isolation(self, mock_session):
"""Ensure complete tenant data isolation"""
tenant1_binding = self._create_mock_binding(self.tenant_id_1, AuthType.FIRECRAWL, self.firecrawl_credentials)
tenant2_binding = self._create_mock_binding(self.tenant_id_2, AuthType.JINA, self.jina_credentials)
mock_session.query.return_value.where.return_value.all.return_value = [tenant1_binding]
result1 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_1)
mock_session.query.return_value.where.return_value.all.return_value = [tenant2_binding]
result2 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_2)
assert len(result1) == 1
assert result1[0].tenant_id == self.tenant_id_1
assert len(result2) == 1
assert result2[0].tenant_id == self.tenant_id_2
@patch("services.auth.api_key_auth_service.db.session")
def test_cross_tenant_access_prevention(self, mock_session):
"""Test prevention of cross-tenant credential access"""
mock_session.query.return_value.where.return_value.first.return_value = None
result = ApiKeyAuthService.get_auth_credentials(self.tenant_id_2, self.category, AuthType.FIRECRAWL)
assert result is None
def test_sensitive_data_protection(self):
"""Ensure API keys don't leak to logs"""
credentials_with_secrets = {
"auth_type": "bearer",
"config": {"api_key": "super_secret_key_do_not_log", "secret": "another_secret"},
}
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, credentials_with_secrets)
factory_str = str(factory)
assert "super_secret_key_do_not_log" not in factory_str
assert "another_secret" not in factory_str
@patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.firecrawl.firecrawl.requests.post")
@patch("services.auth.api_key_auth_service.encrypter.encrypt_token")
def test_concurrent_creation_safety(self, mock_encrypt, mock_http, mock_session):
"""Test concurrent authentication creation safety"""
mock_http.return_value = self._create_success_response()
mock_encrypt.return_value = "encrypted_key"
mock_session.add = Mock()
mock_session.commit = Mock()
args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials}
results = []
exceptions = []
def create_auth():
try:
ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args)
results.append("success")
except Exception as e:
exceptions.append(e)
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(create_auth) for _ in range(5)]
for future in futures:
future.result()
assert len(results) == 5
assert len(exceptions) == 0
assert mock_session.add.call_count == 5
assert mock_session.commit.call_count == 5
@pytest.mark.parametrize(
"invalid_input",
[
None, # Null input
{}, # Empty dictionary - missing required fields
{"auth_type": "bearer"}, # Missing config section
{"auth_type": "bearer", "config": {}}, # Missing api_key
],
)
def test_invalid_input_boundary(self, invalid_input):
"""Test boundary handling for invalid inputs"""
with pytest.raises((ValueError, KeyError, TypeError, AttributeError)):
ApiKeyAuthFactory(AuthType.FIRECRAWL, invalid_input)
@patch("services.auth.firecrawl.firecrawl.requests.post")
def test_http_error_handling(self, mock_http):
"""Test proper HTTP error handling"""
mock_response = Mock()
mock_response.status_code = 401
mock_response.text = '{"error": "Unauthorized"}'
mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError("Unauthorized")
mock_http.return_value = mock_response
# PT012: Split into single statement for pytest.raises
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, self.firecrawl_credentials)
with pytest.raises((requests.exceptions.HTTPError, Exception)):
factory.validate_credentials()
@patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.firecrawl.firecrawl.requests.post")
def test_network_failure_recovery(self, mock_http, mock_session):
"""Test system recovery from network failures"""
mock_http.side_effect = requests.exceptions.RequestException("Network timeout")
mock_session.add = Mock()
mock_session.commit = Mock()
args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials}
with pytest.raises(requests.exceptions.RequestException):
ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args)
mock_session.commit.assert_not_called()
@pytest.mark.parametrize(
("provider", "credentials"),
[
(AuthType.FIRECRAWL, {"auth_type": "bearer", "config": {"api_key": "fc_key"}}),
(AuthType.JINA, {"auth_type": "bearer", "config": {"api_key": "jina_key"}}),
(AuthType.WATERCRAWL, {"auth_type": "x-api-key", "config": {"api_key": "wc_key"}}),
],
)
def test_all_providers_factory_creation(self, provider, credentials):
"""Test factory creation for all supported providers"""
try:
auth_class = ApiKeyAuthFactory.get_apikey_auth_factory(provider)
assert auth_class is not None
factory = ApiKeyAuthFactory(provider, credentials)
assert factory.auth is not None
except ImportError:
pytest.skip(f"Provider {provider} not implemented yet")
def _create_success_response(self, status_code=200):
"""Create successful HTTP response mock"""
mock_response = Mock()
mock_response.status_code = status_code
mock_response.json.return_value = {"status": "success"}
mock_response.raise_for_status.return_value = None
return mock_response
def _create_mock_binding(self, tenant_id: str, provider: str, credentials: dict) -> Mock:
"""Create realistic database binding mock"""
mock_binding = Mock()
mock_binding.id = f"binding_{provider}_{tenant_id}"
mock_binding.tenant_id = tenant_id
mock_binding.category = self.category
mock_binding.provider = provider
mock_binding.credentials = json.dumps(credentials, ensure_ascii=False)
mock_binding.disabled = False
mock_binding.created_at = Mock()
mock_binding.created_at.timestamp.return_value = 1640995200
mock_binding.updated_at = Mock()
mock_binding.updated_at.timestamp.return_value = 1640995200
return mock_binding
def test_integration_coverage_validation(self):
"""Validate integration test coverage meets quality standards"""
core_scenarios = {
"business_logic": ["end_to_end_auth_flow", "cross_component_integration"],
"security": ["multi_tenant_isolation", "cross_tenant_access_prevention", "sensitive_data_protection"],
"reliability": ["concurrent_creation_safety", "network_failure_recovery"],
"compatibility": ["all_providers_factory_creation"],
"boundaries": ["invalid_input_boundary", "http_error_handling"],
}
total_scenarios = sum(len(scenarios) for scenarios in core_scenarios.values())
assert total_scenarios >= 10
security_tests = core_scenarios["security"]
assert "multi_tenant_isolation" in security_tests
assert "sensitive_data_protection" in security_tests
assert True

View File

@ -708,9 +708,9 @@ class TestTenantService:
with patch("services.account_service.db") as mock_db:
# Mock the join query that returns the tenant_account_join
mock_query = MagicMock()
mock_filter = MagicMock()
mock_filter.first.return_value = mock_tenant_join
mock_query.filter.return_value = mock_filter
mock_where = MagicMock()
mock_where.first.return_value = mock_tenant_join
mock_query.where.return_value = mock_where
mock_query.join.return_value = mock_query
mock_db.session.query.return_value = mock_query
@ -1381,10 +1381,10 @@ class TestRegisterService:
# Mock database queries - complex query mocking
mock_query1 = MagicMock()
mock_query1.filter.return_value.first.return_value = mock_tenant
mock_query1.where.return_value.first.return_value = mock_tenant
mock_query2 = MagicMock()
mock_query2.join.return_value.filter.return_value.first.return_value = (mock_account, "normal")
mock_query2.join.return_value.where.return_value.first.return_value = (mock_account, "normal")
mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2]
@ -1449,7 +1449,7 @@ class TestRegisterService:
mock_query1.filter.return_value.first.return_value = mock_tenant
mock_query2 = MagicMock()
mock_query2.join.return_value.filter.return_value.first.return_value = None # No account found
mock_query2.join.return_value.where.return_value.first.return_value = None # No account found
mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2]
@ -1482,7 +1482,7 @@ class TestRegisterService:
mock_query1.filter.return_value.first.return_value = mock_tenant
mock_query2 = MagicMock()
mock_query2.join.return_value.filter.return_value.first.return_value = (mock_account, "normal")
mock_query2.join.return_value.where.return_value.first.return_value = (mock_account, "normal")
mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2]

View File

@ -43,7 +43,7 @@ def test_delete_workflow_success(workflow_setup):
# Setup mocks
# Mock the tool provider query to return None (not published as a tool)
workflow_setup["session"].query.return_value.filter.return_value.first.return_value = None
workflow_setup["session"].query.return_value.where.return_value.first.return_value = None
workflow_setup["session"].scalar = MagicMock(
side_effect=[workflow_setup["workflow"], None]
@ -106,7 +106,7 @@ def test_delete_workflow_published_as_tool_error(workflow_setup):
# Mock the tool provider query
mock_tool_provider = MagicMock(spec=WorkflowToolProvider)
workflow_setup["session"].query.return_value.filter.return_value.first.return_value = mock_tool_provider
workflow_setup["session"].query.return_value.where.return_value.first.return_value = mock_tool_provider
workflow_setup["session"].scalar = MagicMock(
side_effect=[workflow_setup["workflow"], None]

View File

@ -95,7 +95,7 @@ def test_included_position_data(prepare_example_positions_yaml):
position_map = get_position_map(folder_path=prepare_example_positions_yaml, file_name="example_positions.yaml")
pin_list = ["forth", "first"]
include_set = {"forth", "first"}
exclude_set = {}
exclude_set = set()
position_map = pin_position_map(original_position_map=position_map, pin_list=pin_list)