Merge branch 'main' into sandboxed-agent-rebase

Made-with: Cursor

# Conflicts:
#	api/tests/unit_tests/controllers/console/app/test_message.py
#	api/tests/unit_tests/controllers/console/app/test_statistic.py
#	api/tests/unit_tests/controllers/console/app/test_workflow_draft_variable.py
#	api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py
#	api/tests/unit_tests/controllers/console/auth/test_data_source_oauth.py
#	api/tests/unit_tests/controllers/console/auth/test_oauth_server.py
#	web/app/components/header/account-setting/data-source-page/data-source-notion/operate/index.tsx
#	web/app/components/header/account-setting/data-source-page/data-source-website/config-firecrawl-modal.tsx
#	web/app/components/header/account-setting/data-source-page/data-source-website/config-jina-reader-modal.tsx
#	web/app/components/header/account-setting/data-source-page/data-source-website/config-watercrawl-modal.tsx
#	web/app/components/header/account-setting/data-source-page/panel/config-item.tsx
#	web/app/components/header/account-setting/data-source-page/panel/index.tsx
#	web/app/components/workflow/nodes/knowledge-retrieval/node.tsx
#	web/package.json
#	web/pnpm-lock.yaml
This commit is contained in:
Novice
2026-03-24 11:19:50 +08:00
294 changed files with 20298 additions and 13491 deletions

View File

@ -525,3 +525,147 @@ class TestAPIBasedExtensionService:
# Try to get extension with wrong tenant ID
with pytest.raises(ValueError, match="API based extension is not found"):
APIBasedExtensionService.get_with_tenant_id(tenant2.id, created_extension.id)
def test_save_extension_api_key_exactly_four_chars_rejected(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""API key with exactly 4 characters should be rejected (boundary)."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
assert tenant is not None
extension_data = APIBasedExtension(
tenant_id=tenant.id,
name=fake.company(),
api_endpoint=f"https://{fake.domain_name()}/api",
api_key="1234",
)
with pytest.raises(ValueError, match="api_key must be at least 5 characters"):
APIBasedExtensionService.save(extension_data)
def test_save_extension_api_key_exactly_five_chars_accepted(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""API key with exactly 5 characters should be accepted (boundary)."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
assert tenant is not None
extension_data = APIBasedExtension(
tenant_id=tenant.id,
name=fake.company(),
api_endpoint=f"https://{fake.domain_name()}/api",
api_key="12345",
)
saved = APIBasedExtensionService.save(extension_data)
assert saved.id is not None
def test_save_extension_requestor_constructor_error(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Exception raised by requestor constructor is wrapped in ValueError."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
assert tenant is not None
mock_external_service_dependencies["requestor"].side_effect = RuntimeError("bad config")
extension_data = APIBasedExtension(
tenant_id=tenant.id,
name=fake.company(),
api_endpoint=f"https://{fake.domain_name()}/api",
api_key=fake.password(length=20),
)
with pytest.raises(ValueError, match="connection error: bad config"):
APIBasedExtensionService.save(extension_data)
def test_save_extension_network_exception(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Network exceptions during ping are wrapped in ValueError."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
assert tenant is not None
mock_external_service_dependencies["requestor_instance"].request.side_effect = ConnectionError(
"network failure"
)
extension_data = APIBasedExtension(
tenant_id=tenant.id,
name=fake.company(),
api_endpoint=f"https://{fake.domain_name()}/api",
api_key=fake.password(length=20),
)
with pytest.raises(ValueError, match="connection error: network failure"):
APIBasedExtensionService.save(extension_data)
def test_save_extension_update_duplicate_name_rejected(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Updating an existing extension to use another extension's name should fail."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
assert tenant is not None
ext1 = APIBasedExtensionService.save(
APIBasedExtension(
tenant_id=tenant.id,
name="Extension Alpha",
api_endpoint=f"https://{fake.domain_name()}/api",
api_key=fake.password(length=20),
)
)
ext2 = APIBasedExtensionService.save(
APIBasedExtension(
tenant_id=tenant.id,
name="Extension Beta",
api_endpoint=f"https://{fake.domain_name()}/api",
api_key=fake.password(length=20),
)
)
# Try to rename ext2 to ext1's name
ext2.name = "Extension Alpha"
with pytest.raises(ValueError, match="name must be unique, it is already existed"):
APIBasedExtensionService.save(ext2)
def test_get_all_returns_empty_for_different_tenant(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Extensions from one tenant should not be visible to another."""
fake = Faker()
_, tenant1 = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
_, tenant2 = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
assert tenant1 is not None
APIBasedExtensionService.save(
APIBasedExtension(
tenant_id=tenant1.id,
name=fake.company(),
api_endpoint=f"https://{fake.domain_name()}/api",
api_key=fake.password(length=20),
)
)
assert tenant2 is not None
result = APIBasedExtensionService.get_all_by_tenant_id(tenant2.id)
assert result == []

View File

@ -0,0 +1,80 @@
"""Testcontainers integration tests for AttachmentService."""
import base64
from datetime import UTC, datetime
from unittest.mock import patch
from uuid import uuid4
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
import services.attachment_service as attachment_service_module
from extensions.ext_database import db
from extensions.storage.storage_type import StorageType
from models.enums import CreatorUserRole
from models.model import UploadFile
from services.attachment_service import AttachmentService
class TestAttachmentService:
def _create_upload_file(self, db_session_with_containers, *, tenant_id: str | None = None) -> UploadFile:
upload_file = UploadFile(
tenant_id=tenant_id or str(uuid4()),
storage_type=StorageType.OPENDAL,
key=f"upload/{uuid4()}.txt",
name="test-file.txt",
size=100,
extension="txt",
mime_type="text/plain",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid4()),
created_at=datetime.now(UTC),
used=False,
)
db_session_with_containers.add(upload_file)
db_session_with_containers.commit()
return upload_file
def test_should_initialize_with_sessionmaker(self):
session_factory = sessionmaker()
service = AttachmentService(session_factory=session_factory)
assert service._session_maker is session_factory
def test_should_initialize_with_engine(self):
engine = create_engine("sqlite:///:memory:")
service = AttachmentService(session_factory=engine)
session = service._session_maker()
try:
assert session.bind == engine
finally:
session.close()
engine.dispose()
@pytest.mark.parametrize("invalid_session_factory", [None, "not-a-session-factory", 1])
def test_should_raise_assertion_error_for_invalid_session_factory(self, invalid_session_factory):
with pytest.raises(AssertionError, match="must be a sessionmaker or an Engine."):
AttachmentService(session_factory=invalid_session_factory)
def test_should_return_base64_when_file_exists(self, db_session_with_containers):
upload_file = self._create_upload_file(db_session_with_containers)
service = AttachmentService(session_factory=sessionmaker(bind=db.engine))
with patch.object(attachment_service_module.storage, "load_once", return_value=b"binary-content") as mock_load:
result = service.get_file_base64(upload_file.id)
assert result == base64.b64encode(b"binary-content").decode()
mock_load.assert_called_once_with(upload_file.key)
def test_should_raise_not_found_when_file_missing(self, db_session_with_containers):
service = AttachmentService(session_factory=sessionmaker(bind=db.engine))
with patch.object(attachment_service_module.storage, "load_once") as mock_load:
with pytest.raises(NotFound, match="File not found"):
service.get_file_base64(str(uuid4()))
mock_load.assert_not_called()

View File

@ -0,0 +1,58 @@
"""Testcontainers integration tests for ConversationVariableUpdater."""
from uuid import uuid4
import pytest
from sqlalchemy.orm import sessionmaker
from dify_graph.variables import StringVariable
from extensions.ext_database import db
from models.workflow import ConversationVariable
from services.conversation_variable_updater import ConversationVariableNotFoundError, ConversationVariableUpdater
class TestConversationVariableUpdater:
def _create_conversation_variable(
self, db_session_with_containers, *, conversation_id: str, variable: StringVariable, app_id: str | None = None
) -> ConversationVariable:
row = ConversationVariable(
id=variable.id,
conversation_id=conversation_id,
app_id=app_id or str(uuid4()),
data=variable.model_dump_json(),
)
db_session_with_containers.add(row)
db_session_with_containers.commit()
return row
def test_should_update_conversation_variable_data_and_commit(self, db_session_with_containers):
conversation_id = str(uuid4())
variable = StringVariable(id=str(uuid4()), name="topic", value="old value")
self._create_conversation_variable(
db_session_with_containers, conversation_id=conversation_id, variable=variable
)
updated_variable = StringVariable(id=variable.id, name="topic", value="new value")
updater = ConversationVariableUpdater(sessionmaker(bind=db.engine))
updater.update(conversation_id=conversation_id, variable=updated_variable)
db_session_with_containers.expire_all()
row = db_session_with_containers.get(ConversationVariable, (variable.id, conversation_id))
assert row is not None
assert row.data == updated_variable.model_dump_json()
def test_should_raise_not_found_when_variable_missing(self, db_session_with_containers):
conversation_id = str(uuid4())
variable = StringVariable(id=str(uuid4()), name="topic", value="value")
updater = ConversationVariableUpdater(sessionmaker(bind=db.engine))
with pytest.raises(ConversationVariableNotFoundError, match="conversation variable not found in the database"):
updater.update(conversation_id=conversation_id, variable=variable)
def test_should_do_nothing_when_flush_is_called(self, db_session_with_containers):
updater = ConversationVariableUpdater(sessionmaker(bind=db.engine))
result = updater.flush()
assert result is None

View File

@ -0,0 +1,104 @@
"""Testcontainers integration tests for CreditPoolService."""
from uuid import uuid4
import pytest
from core.errors.error import QuotaExceededError
from models import TenantCreditPool
from models.enums import ProviderQuotaType
from services.credit_pool_service import CreditPoolService
class TestCreditPoolService:
def _create_tenant_id(self) -> str:
return str(uuid4())
def test_create_default_pool(self, db_session_with_containers):
tenant_id = self._create_tenant_id()
pool = CreditPoolService.create_default_pool(tenant_id)
assert isinstance(pool, TenantCreditPool)
assert pool.tenant_id == tenant_id
assert pool.pool_type == ProviderQuotaType.TRIAL
assert pool.quota_used == 0
assert pool.quota_limit > 0
def test_get_pool_returns_pool_when_exists(self, db_session_with_containers):
tenant_id = self._create_tenant_id()
CreditPoolService.create_default_pool(tenant_id)
result = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type=ProviderQuotaType.TRIAL)
assert result is not None
assert result.tenant_id == tenant_id
assert result.pool_type == ProviderQuotaType.TRIAL
def test_get_pool_returns_none_when_not_exists(self, db_session_with_containers):
result = CreditPoolService.get_pool(tenant_id=self._create_tenant_id(), pool_type=ProviderQuotaType.TRIAL)
assert result is None
def test_check_credits_available_returns_false_when_no_pool(self, db_session_with_containers):
result = CreditPoolService.check_credits_available(tenant_id=self._create_tenant_id(), credits_required=10)
assert result is False
def test_check_credits_available_returns_true_when_sufficient(self, db_session_with_containers):
tenant_id = self._create_tenant_id()
CreditPoolService.create_default_pool(tenant_id)
result = CreditPoolService.check_credits_available(tenant_id=tenant_id, credits_required=10)
assert result is True
def test_check_credits_available_returns_false_when_insufficient(self, db_session_with_containers):
tenant_id = self._create_tenant_id()
pool = CreditPoolService.create_default_pool(tenant_id)
# Exhaust credits
pool.quota_used = pool.quota_limit
db_session_with_containers.commit()
result = CreditPoolService.check_credits_available(tenant_id=tenant_id, credits_required=1)
assert result is False
def test_check_and_deduct_credits_raises_when_no_pool(self, db_session_with_containers):
with pytest.raises(QuotaExceededError, match="Credit pool not found"):
CreditPoolService.check_and_deduct_credits(tenant_id=self._create_tenant_id(), credits_required=10)
def test_check_and_deduct_credits_raises_when_no_remaining(self, db_session_with_containers):
tenant_id = self._create_tenant_id()
pool = CreditPoolService.create_default_pool(tenant_id)
pool.quota_used = pool.quota_limit
db_session_with_containers.commit()
with pytest.raises(QuotaExceededError, match="No credits remaining"):
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=10)
def test_check_and_deduct_credits_deducts_required_amount(self, db_session_with_containers):
tenant_id = self._create_tenant_id()
CreditPoolService.create_default_pool(tenant_id)
credits_required = 10
result = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=credits_required)
assert result == credits_required
db_session_with_containers.expire_all()
pool = CreditPoolService.get_pool(tenant_id=tenant_id)
assert pool.quota_used == credits_required
def test_check_and_deduct_credits_caps_at_remaining(self, db_session_with_containers):
tenant_id = self._create_tenant_id()
pool = CreditPoolService.create_default_pool(tenant_id)
remaining = 5
pool.quota_used = pool.quota_limit - remaining
db_session_with_containers.commit()
result = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=200)
assert result == remaining
db_session_with_containers.expire_all()
updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
assert updated_pool.quota_used == pool.quota_limit

View File

@ -397,6 +397,68 @@ class TestDatasetPermissionServiceClearPartialMemberList:
class TestDatasetServiceCheckDatasetPermission:
"""Verify dataset access checks against persisted partial-member permissions."""
def test_check_dataset_permission_different_tenant_should_fail(self, db_session_with_containers):
"""Test that users from different tenants cannot access dataset."""
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
other_user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR)
dataset = DatasetPermissionTestDataFactory.create_dataset(
tenant.id, owner.id, permission=DatasetPermissionEnum.ALL_TEAM
)
with pytest.raises(NoPermissionError):
DatasetService.check_dataset_permission(dataset, other_user)
def test_check_dataset_permission_owner_can_access_any_dataset(self, db_session_with_containers):
"""Test that tenant owners can access any dataset regardless of permission level."""
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
creator, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
role=TenantAccountRole.NORMAL, tenant=tenant
)
dataset = DatasetPermissionTestDataFactory.create_dataset(
tenant.id, creator.id, permission=DatasetPermissionEnum.ONLY_ME
)
DatasetService.check_dataset_permission(dataset, owner)
def test_check_dataset_permission_only_me_creator_can_access(self, db_session_with_containers):
"""Test ONLY_ME permission allows only the dataset creator to access."""
creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR)
dataset = DatasetPermissionTestDataFactory.create_dataset(
tenant.id, creator.id, permission=DatasetPermissionEnum.ONLY_ME
)
DatasetService.check_dataset_permission(dataset, creator)
def test_check_dataset_permission_only_me_others_cannot_access(self, db_session_with_containers):
"""Test ONLY_ME permission denies access to non-creators."""
creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL)
other, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
role=TenantAccountRole.NORMAL, tenant=tenant
)
dataset = DatasetPermissionTestDataFactory.create_dataset(
tenant.id, creator.id, permission=DatasetPermissionEnum.ONLY_ME
)
with pytest.raises(NoPermissionError):
DatasetService.check_dataset_permission(dataset, other)
def test_check_dataset_permission_all_team_allows_access(self, db_session_with_containers):
"""Test ALL_TEAM permission allows any team member to access the dataset."""
creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL)
member, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
role=TenantAccountRole.NORMAL, tenant=tenant
)
dataset = DatasetPermissionTestDataFactory.create_dataset(
tenant.id, creator.id, permission=DatasetPermissionEnum.ALL_TEAM
)
DatasetService.check_dataset_permission(dataset, member)
def test_check_dataset_permission_partial_members_with_permission_success(self, db_session_with_containers):
"""
Test that user with explicit permission can access partial_members dataset.
@ -443,6 +505,16 @@ class TestDatasetServiceCheckDatasetPermission:
with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"):
DatasetService.check_dataset_permission(dataset, user)
def test_check_dataset_permission_partial_team_creator_can_access(self, db_session_with_containers):
"""Test PARTIAL_TEAM permission allows creator to access without explicit permission."""
creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR)
dataset = DatasetPermissionTestDataFactory.create_dataset(
tenant.id, creator.id, permission=DatasetPermissionEnum.PARTIAL_TEAM
)
DatasetService.check_dataset_permission(dataset, creator)
class TestDatasetServiceCheckDatasetOperatorPermission:
"""Verify operator permission checks against persisted partial-member permissions."""

View File

@ -694,3 +694,19 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
patched_dependencies["redis_client"].setex.assert_called_once_with(f"document_{doc1.id}_indexing", 600, 1)
patched_dependencies["add_task"].delay.assert_called_once_with(doc1.id)
def test_batch_update_invalid_action_raises_value_error(
self, db_session_with_containers: Session, patched_dependencies
):
"""Test that an invalid action raises ValueError."""
factory = DocumentBatchUpdateIntegrationDataFactory
dataset = factory.create_dataset(db_session_with_containers)
doc = factory.create_document(db_session_with_containers, dataset)
user = UserDouble(id=str(uuid4()))
patched_dependencies["redis_client"].get.return_value = None
with pytest.raises(ValueError, match="Invalid action"):
DocumentService.batch_update_document_status(
dataset=dataset, document_ids=[doc.id], action="invalid_action", user=user
)

View File

@ -0,0 +1,60 @@
"""Testcontainers integration tests for DatasetService.create_empty_rag_pipeline_dataset."""
from __future__ import annotations
from unittest.mock import Mock, patch
from uuid import uuid4
import pytest
from models.account import Account, Tenant, TenantAccountJoin
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
class TestDatasetServiceCreateRagPipelineDataset:
def _create_tenant_and_account(self, db_session_with_containers) -> tuple[Tenant, Account]:
tenant = Tenant(name=f"Tenant {uuid4()}")
db_session_with_containers.add(tenant)
db_session_with_containers.flush()
account = Account(
name=f"Account {uuid4()}",
email=f"ds_create_{uuid4()}@example.com",
password="hashed",
password_salt="salt",
interface_language="en-US",
timezone="UTC",
)
db_session_with_containers.add(account)
db_session_with_containers.flush()
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role="owner",
current=True,
)
db_session_with_containers.add(join)
db_session_with_containers.commit()
return tenant, account
def _build_entity(self, name: str = "Test Dataset") -> RagPipelineDatasetCreateEntity:
icon_info = IconInfo(icon="\U0001f4d9", icon_background="#FFF4ED", icon_type="emoji")
return RagPipelineDatasetCreateEntity(
name=name,
description="",
icon_info=icon_info,
permission="only_me",
)
def test_create_rag_pipeline_dataset_raises_when_current_user_id_is_none(self, db_session_with_containers):
tenant, _ = self._create_tenant_and_account(db_session_with_containers)
mock_user = Mock(id=None)
with patch("services.dataset_service.current_user", mock_user):
with pytest.raises(ValueError, match="Current user or current user id not found"):
DatasetService.create_empty_rag_pipeline_dataset(
tenant_id=tenant.id,
rag_pipeline_dataset_create_entity=self._build_entity(),
)

View File

@ -142,3 +142,11 @@ def test_apply_display_status_filter_returns_same_when_invalid(db_session_with_c
rows = db_session_with_containers.scalars(filtered).all()
assert {row.id for row in rows} == {doc1.id, doc2.id}
def test_normalize_display_status_alias_mapping():
"""Test that normalize_display_status maps aliases correctly."""
assert DocumentService.normalize_display_status("ACTIVE") == "available"
assert DocumentService.normalize_display_status("enabled") == "available"
assert DocumentService.normalize_display_status("archived") == "archived"
assert DocumentService.normalize_display_status("unknown") is None

View File

@ -414,3 +414,144 @@ class TestEndUserServiceGetEndUserById:
)
assert result is None
class TestEndUserServiceCreateBatch:
"""Integration tests for EndUserService.create_end_user_batch."""
@pytest.fixture
def factory(self):
return TestEndUserServiceFactory()
def _create_multiple_apps(self, db_session_with_containers, factory, count: int = 3):
"""Create multiple apps under the same tenant."""
first_app = factory.create_app_and_account(db_session_with_containers)
tenant_id = first_app.tenant_id
apps = [first_app]
for _ in range(count - 1):
app = App(
tenant_id=tenant_id,
name=f"App {uuid4()}",
description="",
mode="chat",
icon_type="emoji",
icon="bot",
icon_background="#FFFFFF",
enable_site=False,
enable_api=True,
api_rpm=100,
api_rph=100,
is_demo=False,
is_public=False,
is_universal=False,
created_by=first_app.created_by,
updated_by=first_app.updated_by,
)
db_session_with_containers.add(app)
db_session_with_containers.commit()
all_apps = db_session_with_containers.query(App).filter(App.tenant_id == tenant_id).all()
return tenant_id, all_apps
def test_create_batch_empty_app_ids(self, db_session_with_containers):
result = EndUserService.create_end_user_batch(
type=InvokeFrom.SERVICE_API, tenant_id=str(uuid4()), app_ids=[], user_id="user-1"
)
assert result == {}
def test_create_batch_creates_users_for_all_apps(self, db_session_with_containers, factory):
tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=3)
app_ids = [a.id for a in apps]
user_id = f"user-{uuid4()}"
result = EndUserService.create_end_user_batch(
type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id
)
assert len(result) == 3
for app_id in app_ids:
assert app_id in result
assert result[app_id].session_id == user_id
assert result[app_id].type == InvokeFrom.SERVICE_API
def test_create_batch_default_session_id(self, db_session_with_containers, factory):
tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2)
app_ids = [a.id for a in apps]
result = EndUserService.create_end_user_batch(
type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id=""
)
assert len(result) == 2
for end_user in result.values():
assert end_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
assert end_user._is_anonymous is True
def test_create_batch_deduplicate_app_ids(self, db_session_with_containers, factory):
tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2)
app_ids = [apps[0].id, apps[1].id, apps[0].id, apps[1].id]
user_id = f"user-{uuid4()}"
result = EndUserService.create_end_user_batch(
type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id
)
assert len(result) == 2
def test_create_batch_returns_existing_users(self, db_session_with_containers, factory):
tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2)
app_ids = [a.id for a in apps]
user_id = f"user-{uuid4()}"
# Create batch first time
first_result = EndUserService.create_end_user_batch(
type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id
)
# Create batch second time — should return existing users
second_result = EndUserService.create_end_user_batch(
type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id
)
assert len(second_result) == 2
for app_id in app_ids:
assert first_result[app_id].id == second_result[app_id].id
def test_create_batch_partial_existing_users(self, db_session_with_containers, factory):
tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=3)
user_id = f"user-{uuid4()}"
# Create for first 2 apps
first_result = EndUserService.create_end_user_batch(
type=InvokeFrom.SERVICE_API,
tenant_id=tenant_id,
app_ids=[apps[0].id, apps[1].id],
user_id=user_id,
)
# Create for all 3 apps — should reuse first 2, create 3rd
all_result = EndUserService.create_end_user_batch(
type=InvokeFrom.SERVICE_API,
tenant_id=tenant_id,
app_ids=[a.id for a in apps],
user_id=user_id,
)
assert len(all_result) == 3
assert all_result[apps[0].id].id == first_result[apps[0].id].id
assert all_result[apps[1].id].id == first_result[apps[1].id].id
assert all_result[apps[2].id].session_id == user_id
@pytest.mark.parametrize(
"invoke_type",
[InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP, InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER],
)
def test_create_batch_all_invoke_types(self, db_session_with_containers, invoke_type, factory):
tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=1)
user_id = f"user-{uuid4()}"
result = EndUserService.create_end_user_batch(
type=invoke_type, tenant_id=tenant_id, app_ids=[apps[0].id], user_id=user_id
)
assert len(result) == 1
assert result[apps[0].id].type == invoke_type

View File

@ -0,0 +1,96 @@
"""
Testcontainers integration tests for FileService helpers.
Covers:
- ZIP tempfile building (sanitization + deduplication + content writes)
- tenant-scoped batch lookup behavior (get_upload_files_by_ids)
"""
from __future__ import annotations
from datetime import UTC, datetime
from types import SimpleNamespace
from typing import Any
from uuid import uuid4
from zipfile import ZipFile
import pytest
import services.file_service as file_service_module
from extensions.storage.storage_type import StorageType
from models.enums import CreatorUserRole
from models.model import UploadFile
from services.file_service import FileService
def _create_upload_file(db_session, *, tenant_id: str, key: str, name: str) -> UploadFile:
upload_file = UploadFile(
tenant_id=tenant_id,
storage_type=StorageType.OPENDAL,
key=key,
name=name,
size=100,
extension="txt",
mime_type="text/plain",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid4()),
created_at=datetime.now(UTC),
used=False,
)
db_session.add(upload_file)
db_session.commit()
return upload_file
def test_build_upload_files_zip_tempfile_sanitizes_and_dedupes_names(monkeypatch: pytest.MonkeyPatch) -> None:
"""Ensure ZIP entry names are safe and unique while preserving extensions."""
upload_files: list[Any] = [
SimpleNamespace(name="a/b.txt", key="k1"),
SimpleNamespace(name="c/b.txt", key="k2"),
SimpleNamespace(name="../b.txt", key="k3"),
]
data_by_key: dict[str, list[bytes]] = {"k1": [b"one"], "k2": [b"two"], "k3": [b"three"]}
def _load(key: str, stream: bool = True) -> list[bytes]:
assert stream is True
return data_by_key[key]
monkeypatch.setattr(file_service_module.storage, "load", _load)
with FileService.build_upload_files_zip_tempfile(upload_files=upload_files) as tmp:
with ZipFile(tmp, mode="r") as zf:
assert zf.namelist() == ["b.txt", "b (1).txt", "b (2).txt"]
assert zf.read("b.txt") == b"one"
assert zf.read("b (1).txt") == b"two"
assert zf.read("b (2).txt") == b"three"
def test_get_upload_files_by_ids_returns_empty_when_no_ids(db_session_with_containers) -> None:
"""Ensure empty input returns an empty mapping without hitting the database."""
assert FileService.get_upload_files_by_ids(str(uuid4()), []) == {}
def test_get_upload_files_by_ids_returns_id_keyed_mapping(db_session_with_containers) -> None:
"""Ensure batch lookup returns a dict keyed by stringified UploadFile ids."""
tenant_id = str(uuid4())
file1 = _create_upload_file(db_session_with_containers, tenant_id=tenant_id, key="k1", name="file1.txt")
file2 = _create_upload_file(db_session_with_containers, tenant_id=tenant_id, key="k2", name="file2.txt")
result = FileService.get_upload_files_by_ids(tenant_id, [file1.id, file1.id, file2.id])
assert set(result.keys()) == {file1.id, file2.id}
assert result[file1.id].id == file1.id
assert result[file2.id].id == file2.id
def test_get_upload_files_by_ids_filters_by_tenant(db_session_with_containers) -> None:
"""Ensure files from other tenants are not returned."""
tenant_a = str(uuid4())
tenant_b = str(uuid4())
file_a = _create_upload_file(db_session_with_containers, tenant_id=tenant_a, key="ka", name="a.txt")
_create_upload_file(db_session_with_containers, tenant_id=tenant_b, key="kb", name="b.txt")
result = FileService.get_upload_files_by_ids(tenant_a, [file_a.id])
assert set(result.keys()) == {file_a.id}

View File

@ -8,6 +8,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from dify_graph.file.enums import FileType
from enums.cloud_plan import CloudPlan
from extensions.ext_redis import redis_client
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
@ -253,7 +254,7 @@ class TestMessagesCleanServiceIntegration:
# MessageFile
file = MessageFile(
message_id=message.id,
type="image",
type=FileType.IMAGE,
transfer_method="local_file",
url="http://example.com/test.jpg",
belongs_to=MessageFileBelongsTo.USER,

View File

@ -0,0 +1,174 @@
"""Testcontainers integration tests for OAuthServerService."""
from __future__ import annotations
import uuid
from typing import cast
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from werkzeug.exceptions import BadRequest
from models.model import OAuthProviderApp
from services.oauth_server import (
OAUTH_ACCESS_TOKEN_EXPIRES_IN,
OAUTH_ACCESS_TOKEN_REDIS_KEY,
OAUTH_AUTHORIZATION_CODE_REDIS_KEY,
OAUTH_REFRESH_TOKEN_EXPIRES_IN,
OAUTH_REFRESH_TOKEN_REDIS_KEY,
OAuthGrantType,
OAuthServerService,
)
class TestOAuthServerServiceGetProviderApp:
"""DB-backed tests for get_oauth_provider_app."""
def _create_oauth_provider_app(self, db_session_with_containers, *, client_id: str) -> OAuthProviderApp:
app = OAuthProviderApp(
app_icon="icon.png",
client_id=client_id,
client_secret=str(uuid4()),
app_label={"en-US": "Test OAuth App"},
redirect_uris=["https://example.com/callback"],
scope="read",
)
db_session_with_containers.add(app)
db_session_with_containers.commit()
return app
def test_get_oauth_provider_app_returns_app_when_exists(self, db_session_with_containers):
client_id = f"client-{uuid4()}"
created = self._create_oauth_provider_app(db_session_with_containers, client_id=client_id)
result = OAuthServerService.get_oauth_provider_app(client_id)
assert result is not None
assert result.client_id == client_id
assert result.id == created.id
def test_get_oauth_provider_app_returns_none_when_not_exists(self, db_session_with_containers):
result = OAuthServerService.get_oauth_provider_app(f"nonexistent-{uuid4()}")
assert result is None
class TestOAuthServerServiceTokenOperations:
"""Redis-backed tests for token sign/validate operations."""
@pytest.fixture
def mock_redis(self):
with patch("services.oauth_server.redis_client") as mock:
yield mock
def test_sign_authorization_code_stores_and_returns_code(self, mock_redis):
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000111")
with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid):
code = OAuthServerService.sign_oauth_authorization_code("client-1", "user-1")
assert code == str(deterministic_uuid)
mock_redis.set.assert_called_once_with(
OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code=code),
"user-1",
ex=600,
)
def test_sign_access_token_raises_bad_request_for_invalid_code(self, mock_redis):
mock_redis.get.return_value = None
with pytest.raises(BadRequest, match="invalid code"):
OAuthServerService.sign_oauth_access_token(
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
code="bad-code",
client_id="client-1",
)
def test_sign_access_token_issues_tokens_for_valid_code(self, mock_redis):
token_uuids = [
uuid.UUID("00000000-0000-0000-0000-000000000201"),
uuid.UUID("00000000-0000-0000-0000-000000000202"),
]
with patch("services.oauth_server.uuid.uuid4", side_effect=token_uuids):
mock_redis.get.return_value = b"user-1"
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
code="code-1",
client_id="client-1",
)
assert access_token == str(token_uuids[0])
assert refresh_token == str(token_uuids[1])
code_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code="code-1")
mock_redis.delete.assert_called_once_with(code_key)
mock_redis.set.assert_any_call(
OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token),
b"user-1",
ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN,
)
mock_redis.set.assert_any_call(
OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-1", token=refresh_token),
b"user-1",
ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
)
def test_sign_access_token_raises_bad_request_for_invalid_refresh_token(self, mock_redis):
mock_redis.get.return_value = None
with pytest.raises(BadRequest, match="invalid refresh token"):
OAuthServerService.sign_oauth_access_token(
grant_type=OAuthGrantType.REFRESH_TOKEN,
refresh_token="stale-token",
client_id="client-1",
)
def test_sign_access_token_issues_new_token_for_valid_refresh(self, mock_redis):
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000301")
with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid):
mock_redis.get.return_value = b"user-1"
access_token, returned_refresh = OAuthServerService.sign_oauth_access_token(
grant_type=OAuthGrantType.REFRESH_TOKEN,
refresh_token="refresh-1",
client_id="client-1",
)
assert access_token == str(deterministic_uuid)
assert returned_refresh == "refresh-1"
def test_sign_access_token_returns_none_for_unknown_grant_type(self, mock_redis):
grant_type = cast(OAuthGrantType, "invalid-grant-type")
result = OAuthServerService.sign_oauth_access_token(grant_type=grant_type, client_id="client-1")
assert result is None
def test_sign_refresh_token_stores_with_expected_expiry(self, mock_redis):
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000401")
with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid):
refresh_token = OAuthServerService._sign_oauth_refresh_token("client-2", "user-2")
assert refresh_token == str(deterministic_uuid)
mock_redis.set.assert_called_once_with(
OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-2", token=refresh_token),
"user-2",
ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
)
def test_validate_access_token_returns_none_when_not_found(self, mock_redis):
mock_redis.get.return_value = None
result = OAuthServerService.validate_oauth_access_token("client-1", "missing-token")
assert result is None
def test_validate_access_token_loads_user_when_exists(self, mock_redis):
mock_redis.get.return_value = b"user-88"
expected_user = MagicMock()
with patch("services.oauth_server.AccountService.load_user", return_value=expected_user) as mock_load:
result = OAuthServerService.validate_oauth_access_token("client-1", "access-token")
assert result is expected_user
mock_load.assert_called_once_with("user-88")

View File

@ -396,11 +396,6 @@ class TestSavedMessageService:
assert "User is required" in str(exc_info.value)
# Verify no database operations were performed
saved_messages = db_session_with_containers.query(SavedMessage).all()
assert len(saved_messages) == 0
def test_save_error_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test error handling when saving message with no user.
@ -497,124 +492,140 @@ class TestSavedMessageService:
# The message should still exist, only the saved_message should be deleted
assert db_session_with_containers.query(Message).where(Message.id == message.id).first() is not None
def test_pagination_by_last_id_error_no_user(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test error handling when no user is provided.
This test verifies:
- Proper error handling for missing user
- ValueError is raised when user is None
- No database operations are performed
"""
# Arrange: Create test data
fake = Faker()
def test_save_for_end_user(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""Test saving a message for an EndUser."""
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
end_user = self._create_test_end_user(db_session_with_containers, app)
message = self._create_test_message(db_session_with_containers, app, end_user)
# Act & Assert: Verify proper error handling
with pytest.raises(ValueError) as exc_info:
SavedMessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=10)
mock_external_service_dependencies["message_service"].get_message.return_value = message
assert "User is required" in str(exc_info.value)
SavedMessageService.save(app_model=app, user=end_user, message_id=message.id)
# Verify no database operations were performed for this specific test
# Note: We don't check total count as other tests may have created data
# Instead, we verify that the error was properly raised
pass
def test_save_error_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test error handling when saving message with no user.
This test verifies:
- Method returns early when user is None
- No database operations are performed
- No exceptions are raised
"""
# Arrange: Create test data
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
message = self._create_test_message(db_session_with_containers, app, account)
# Act: Execute the method under test with None user
result = SavedMessageService.save(app_model=app, user=None, message_id=message.id)
# Assert: Verify the expected outcomes
assert result is None
# Verify no saved message was created
saved_message = (
saved = (
db_session_with_containers.query(SavedMessage)
.where(
SavedMessage.app_id == app.id,
SavedMessage.message_id == message.id,
)
.where(SavedMessage.app_id == app.id, SavedMessage.message_id == message.id)
.first()
)
assert saved is not None
assert saved.created_by == end_user.id
assert saved.created_by_role == "end_user"
assert saved_message is None
def test_delete_success_existing_message(
def test_save_duplicate_is_idempotent(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful deletion of an existing saved message.
This test verifies:
- Proper deletion of existing saved message
- Correct database state after deletion
- No errors during deletion process
"""
# Arrange: Create test data
fake = Faker()
"""Test that saving an already-saved message does not create a duplicate."""
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
message = self._create_test_message(db_session_with_containers, app, account)
# Create a saved message first
saved_message = SavedMessage(
app_id=app.id,
message_id=message.id,
created_by_role="account",
created_by=account.id,
)
mock_external_service_dependencies["message_service"].get_message.return_value = message
db_session_with_containers.add(saved_message)
# Save once
SavedMessageService.save(app_model=app, user=account, message_id=message.id)
# Save again
SavedMessageService.save(app_model=app, user=account, message_id=message.id)
count = (
db_session_with_containers.query(SavedMessage)
.where(SavedMessage.app_id == app.id, SavedMessage.message_id == message.id)
.count()
)
assert count == 1
def test_delete_without_user_does_nothing(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test that deleting without a user is a no-op."""
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
message = self._create_test_message(db_session_with_containers, app, account)
# Pre-create a saved message
saved = SavedMessage(app_id=app.id, message_id=message.id, created_by_role="account", created_by=account.id)
db_session_with_containers.add(saved)
db_session_with_containers.commit()
# Verify saved message exists
SavedMessageService.delete(app_model=app, user=None, message_id=message.id)
# Should still exist
assert (
db_session_with_containers.query(SavedMessage)
.where(SavedMessage.app_id == app.id, SavedMessage.message_id == message.id)
.first()
is not None
)
def test_delete_non_existent_does_nothing(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test that deleting a non-existent saved message is a no-op."""
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Should not raise — use a valid UUID that doesn't exist in DB
from uuid import uuid4
SavedMessageService.delete(app_model=app, user=account, message_id=str(uuid4()))
def test_delete_for_end_user(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""Test deleting a saved message for an EndUser."""
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
end_user = self._create_test_end_user(db_session_with_containers, app)
message = self._create_test_message(db_session_with_containers, app, end_user)
saved = SavedMessage(app_id=app.id, message_id=message.id, created_by_role="end_user", created_by=end_user.id)
db_session_with_containers.add(saved)
db_session_with_containers.commit()
SavedMessageService.delete(app_model=app, user=end_user, message_id=message.id)
assert (
db_session_with_containers.query(SavedMessage)
.where(SavedMessage.app_id == app.id, SavedMessage.message_id == message.id)
.first()
is None
)
def test_delete_only_affects_own_saved_messages(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test that delete only removes the requesting user's saved message."""
app, account1 = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)
end_user = self._create_test_end_user(db_session_with_containers, app)
message = self._create_test_message(db_session_with_containers, app, account1)
# Both users save the same message
saved_account = SavedMessage(
app_id=app.id, message_id=message.id, created_by_role="account", created_by=account1.id
)
saved_end_user = SavedMessage(
app_id=app.id, message_id=message.id, created_by_role="end_user", created_by=end_user.id
)
db_session_with_containers.add_all([saved_account, saved_end_user])
db_session_with_containers.commit()
# Delete only account1's saved message
SavedMessageService.delete(app_model=app, user=account1, message_id=message.id)
# Account's saved message should be gone
assert (
db_session_with_containers.query(SavedMessage)
.where(
SavedMessage.app_id == app.id,
SavedMessage.message_id == message.id,
SavedMessage.created_by_role == "account",
SavedMessage.created_by == account.id,
SavedMessage.created_by == account1.id,
)
.first()
is not None
is None
)
# Act: Execute the method under test
SavedMessageService.delete(app_model=app, user=account, message_id=message.id)
# Assert: Verify the expected outcomes
# Check if saved message was deleted from database
deleted_saved_message = (
# End user's saved message should still exist
assert (
db_session_with_containers.query(SavedMessage)
.where(
SavedMessage.app_id == app.id,
SavedMessage.message_id == message.id,
SavedMessage.created_by_role == "account",
SavedMessage.created_by == account.id,
SavedMessage.created_by == end_user.id,
)
.first()
is not None
)
assert deleted_saved_message is None
# Verify database state
db_session_with_containers.commit()
# The message should still exist, only the saved_message should be deleted
assert db_session_with_containers.query(Message).where(Message.id == message.id).first() is not None

View File

@ -9,7 +9,7 @@ from werkzeug.exceptions import NotFound
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset
from models.enums import DataSourceType
from models.enums import DataSourceType, TagType
from models.model import App, Tag, TagBinding
from services.tag_service import TagService
@ -547,7 +547,7 @@ class TestTagService:
assert result is not None
assert len(result) == 1
assert result[0].name == "python_tag"
assert result[0].type == "app"
assert result[0].type == TagType.APP
assert result[0].tenant_id == tenant.id
def test_get_tag_by_tag_name_no_matches(
@ -638,7 +638,7 @@ class TestTagService:
# Verify all tags are returned
for tag in result:
assert tag.type == "app"
assert tag.type == TagType.APP
assert tag.tenant_id == tenant.id
assert tag.id in [t.id for t in tags]

View File

@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
from dify_graph.entities.workflow_execution import WorkflowExecutionStatus
from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun
from models.enums import CreatorUserRole
from models.workflow import WorkflowAppLogCreatedFrom
from services.account_service import AccountService, TenantService
# Delay import of AppService to avoid circular dependency
@ -221,7 +222,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -357,7 +358,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run_1.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -399,7 +400,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run_2.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -441,7 +442,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run_4.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -521,7 +522,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -627,7 +628,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -732,7 +733,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -860,7 +861,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -902,7 +903,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="web-app",
created_from=WorkflowAppLogCreatedFrom.WEB_APP,
created_by_role=CreatorUserRole.END_USER,
created_by=end_user.id,
)
@ -1037,7 +1038,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -1125,7 +1126,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -1279,7 +1280,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -1379,7 +1380,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -1481,7 +1482,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)

View File

@ -536,3 +536,151 @@ class TestApiToolManageService:
# Verify mock interactions
mock_external_service_dependencies["encrypter"].assert_called_once()
mock_external_service_dependencies["provider_controller"].from_db.assert_called_once()
def test_delete_api_tool_provider_success(
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test successful deletion of an API tool provider."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
schema = self._create_test_openapi_schema()
provider_name = fake.unique.word()
ApiToolManageService.create_api_tool_provider(
user_id=account.id,
tenant_id=tenant.id,
provider_name=provider_name,
icon={"content": "🔧", "background": "#FFF"},
credentials={"auth_type": "none"},
schema_type=ApiProviderSchemaType.OPENAPI,
schema=schema,
privacy_policy="",
custom_disclaimer="",
labels=[],
)
provider = (
db_session_with_containers.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name)
.first()
)
assert provider is not None
result = ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, provider_name)
assert result == {"result": "success"}
deleted = (
db_session_with_containers.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name)
.first()
)
assert deleted is None
def test_delete_api_tool_provider_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test deletion raises ValueError when provider not found."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
with pytest.raises(ValueError, match="you have not added provider"):
ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, "nonexistent")
def test_update_api_tool_provider_not_found(
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test update raises ValueError when original provider not found."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
with pytest.raises(ValueError, match="does not exists"):
ApiToolManageService.update_api_tool_provider(
user_id=account.id,
tenant_id=tenant.id,
provider_name="new-name",
original_provider="nonexistent",
icon={},
credentials={"auth_type": "none"},
_schema_type=ApiProviderSchemaType.OPENAPI,
schema=self._create_test_openapi_schema(),
privacy_policy=None,
custom_disclaimer="",
labels=[],
)
def test_update_api_tool_provider_missing_auth_type(
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test update raises ValueError when auth_type is missing from credentials."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
schema = self._create_test_openapi_schema()
provider_name = fake.unique.word()
ApiToolManageService.create_api_tool_provider(
user_id=account.id,
tenant_id=tenant.id,
provider_name=provider_name,
icon={"content": "🔧", "background": "#FFF"},
credentials={"auth_type": "none"},
schema_type=ApiProviderSchemaType.OPENAPI,
schema=schema,
privacy_policy="",
custom_disclaimer="",
labels=[],
)
with pytest.raises(ValueError, match="auth_type is required"):
ApiToolManageService.update_api_tool_provider(
user_id=account.id,
tenant_id=tenant.id,
provider_name=provider_name,
original_provider=provider_name,
icon={},
credentials={},
_schema_type=ApiProviderSchemaType.OPENAPI,
schema=schema,
privacy_policy=None,
custom_disclaimer="",
labels=[],
)
def test_list_api_tool_provider_tools_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test listing tools raises ValueError when provider not found."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
with pytest.raises(ValueError, match="you have not added provider"):
ApiToolManageService.list_api_tool_provider_tools(account.id, tenant.id, "nonexistent")
def test_test_api_tool_preview_invalid_schema_type(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test preview raises ValueError for invalid schema type."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
with pytest.raises(ValueError, match="invalid schema type"):
ApiToolManageService.test_api_tool_preview(
tenant_id=tenant.id,
provider_name="provider-a",
tool_name="tool-a",
credentials={"auth_type": "none"},
parameters={},
schema_type="bad-schema-type",
schema="schema",
)

View File

@ -1043,3 +1043,112 @@ class TestWorkflowToolManageService:
# After the fix, this should always be 0
# For now, we document that the record may exist, demonstrating the bug
# assert tool_count == 0 # Expected after fix
def test_delete_workflow_tool_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test successful deletion of a workflow tool."""
fake = Faker()
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)
tool_name = fake.unique.word()
WorkflowToolManageService.create_workflow_tool(
user_id=account.id,
tenant_id=account.current_tenant.id,
workflow_app_id=app.id,
name=tool_name,
label=fake.word(),
icon={"type": "emoji", "emoji": "🔧"},
description=fake.text(max_nb_chars=200),
parameters=self._create_test_workflow_tool_parameters(),
)
tool = (
db_session_with_containers.query(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.name == tool_name)
.first()
)
assert tool is not None
result = WorkflowToolManageService.delete_workflow_tool(account.id, account.current_tenant.id, tool.id)
assert result == {"result": "success"}
deleted = (
db_session_with_containers.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool.id).first()
)
assert deleted is None
def test_list_tenant_workflow_tools_empty(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test listing workflow tools when none exist returns empty list."""
fake = Faker()
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)
result = WorkflowToolManageService.list_tenant_workflow_tools(account.id, account.current_tenant.id)
assert result == []
def test_get_workflow_tool_by_tool_id_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test that get_workflow_tool_by_tool_id raises ValueError when tool not found."""
fake = Faker()
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)
with pytest.raises(ValueError, match="Tool not found"):
WorkflowToolManageService.get_workflow_tool_by_tool_id(account.id, account.current_tenant.id, fake.uuid4())
def test_get_workflow_tool_by_app_id_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test that get_workflow_tool_by_app_id raises ValueError when tool not found."""
fake = Faker()
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)
with pytest.raises(ValueError, match="Tool not found"):
WorkflowToolManageService.get_workflow_tool_by_app_id(account.id, account.current_tenant.id, fake.uuid4())
def test_list_single_workflow_tools_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test that list_single_workflow_tools raises ValueError when tool not found."""
fake = Faker()
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)
with pytest.raises(ValueError, match="not found"):
WorkflowToolManageService.list_single_workflow_tools(account.id, account.current_tenant.id, fake.uuid4())
def test_create_workflow_tool_with_labels(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test that labels are forwarded to ToolLabelManager when provided."""
fake = Faker()
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)
result = WorkflowToolManageService.create_workflow_tool(
user_id=account.id,
tenant_id=account.current_tenant.id,
workflow_app_id=app.id,
name=fake.unique.word(),
label=fake.word(),
icon={"type": "emoji", "emoji": "🔧"},
description=fake.text(max_nb_chars=200),
parameters=self._create_test_workflow_tool_parameters(),
labels=["label-1", "label-2"],
)
assert result == {"result": "success"}
mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_called_once()

View File

@ -0,0 +1,158 @@
"""Testcontainers integration tests for WorkflowService.delete_workflow."""
import json
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session, sessionmaker
from extensions.ext_database import db
from models.account import Account, Tenant, TenantAccountJoin
from models.model import App
from models.tools import WorkflowToolProvider
from models.workflow import Workflow
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
class TestWorkflowDeletion:
def _create_tenant_and_account(self, session: Session) -> tuple[Tenant, Account]:
tenant = Tenant(name=f"Tenant {uuid4()}")
session.add(tenant)
session.flush()
account = Account(
name=f"Account {uuid4()}",
email=f"wf_del_{uuid4()}@example.com",
password="hashed",
password_salt="salt",
interface_language="en-US",
timezone="UTC",
)
session.add(account)
session.flush()
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role="owner",
current=True,
)
session.add(join)
session.flush()
return tenant, account
def _create_app(self, session: Session, *, tenant: Tenant, account: Account, workflow_id: str | None = None) -> App:
app = App(
tenant_id=tenant.id,
name=f"App {uuid4()}",
description="",
mode="workflow",
icon_type="emoji",
icon="bot",
icon_background="#FFFFFF",
enable_site=False,
enable_api=True,
api_rpm=100,
api_rph=100,
is_demo=False,
is_public=False,
is_universal=False,
created_by=account.id,
updated_by=account.id,
workflow_id=workflow_id,
)
session.add(app)
session.flush()
return app
def _create_workflow(
self, session: Session, *, tenant: Tenant, app: App, account: Account, version: str = "1.0"
) -> Workflow:
workflow = Workflow(
id=str(uuid4()),
tenant_id=tenant.id,
app_id=app.id,
type="workflow",
version=version,
graph=json.dumps({"nodes": [], "edges": []}),
_features=json.dumps({}),
created_by=account.id,
updated_by=account.id,
)
session.add(workflow)
session.flush()
return workflow
def _create_tool_provider(
self, session: Session, *, tenant: Tenant, app: App, account: Account, version: str
) -> WorkflowToolProvider:
provider = WorkflowToolProvider(
name=f"tool-{uuid4()}",
label=f"Tool {uuid4()}",
icon="wrench",
app_id=app.id,
version=version,
user_id=account.id,
tenant_id=tenant.id,
description="test tool provider",
)
session.add(provider)
session.flush()
return provider
def test_delete_workflow_success(self, db_session_with_containers):
tenant, account = self._create_tenant_and_account(db_session_with_containers)
app = self._create_app(db_session_with_containers, tenant=tenant, account=account)
workflow = self._create_workflow(
db_session_with_containers, tenant=tenant, app=app, account=account, version="1.0"
)
db_session_with_containers.commit()
workflow_id = workflow.id
service = WorkflowService(sessionmaker(bind=db.engine))
result = service.delete_workflow(
session=db_session_with_containers, workflow_id=workflow_id, tenant_id=tenant.id
)
assert result is True
db_session_with_containers.expire_all()
assert db_session_with_containers.get(Workflow, workflow_id) is None
def test_delete_draft_workflow_raises_error(self, db_session_with_containers):
tenant, account = self._create_tenant_and_account(db_session_with_containers)
app = self._create_app(db_session_with_containers, tenant=tenant, account=account)
workflow = self._create_workflow(
db_session_with_containers, tenant=tenant, app=app, account=account, version="draft"
)
db_session_with_containers.commit()
service = WorkflowService(sessionmaker(bind=db.engine))
with pytest.raises(DraftWorkflowDeletionError):
service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id)
def test_delete_workflow_in_use_by_app_raises_error(self, db_session_with_containers):
tenant, account = self._create_tenant_and_account(db_session_with_containers)
app = self._create_app(db_session_with_containers, tenant=tenant, account=account)
workflow = self._create_workflow(
db_session_with_containers, tenant=tenant, app=app, account=account, version="1.0"
)
# Point app to this workflow
app.workflow_id = workflow.id
db_session_with_containers.commit()
service = WorkflowService(sessionmaker(bind=db.engine))
with pytest.raises(WorkflowInUseError, match="currently in use by app"):
service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id)
def test_delete_workflow_published_as_tool_raises_error(self, db_session_with_containers):
tenant, account = self._create_tenant_and_account(db_session_with_containers)
app = self._create_app(db_session_with_containers, tenant=tenant, account=account)
workflow = self._create_workflow(
db_session_with_containers, tenant=tenant, app=app, account=account, version="1.0"
)
self._create_tool_provider(db_session_with_containers, tenant=tenant, app=app, account=account, version="1.0")
db_session_with_containers.commit()
service = WorkflowService(sessionmaker(bind=db.engine))
with pytest.raises(WorkflowInUseError, match="published as a tool"):
service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id)